循環(huán)神經(jīng)網(wǎng)絡(luò)是如何工作的
基于圖展開和參數(shù)共享的思想,我們可以設(shè)計(jì)各種循環(huán)神經(jīng)網(wǎng)絡(luò)。
計(jì)算循環(huán)網(wǎng)絡(luò)(將 x值的輸入序列映射到輸出值 o 的對(duì)應(yīng)序列) 訓(xùn)練損失的計(jì)算圖。損失L 衡量每個(gè) o與相應(yīng)的訓(xùn)練目標(biāo) v 的距離。當(dāng)使用 softmax 輸出時(shí),我們假設(shè) o 是未歸一化的對(duì)數(shù)概率。損失 L 內(nèi)部計(jì)算,并將其與目標(biāo) y 比較。RNN輸入到隱藏的連接由權(quán)重矩陣 U參數(shù)化,隱藏到隱藏的循環(huán)連接由權(quán)重矩陣 W參數(shù)化以及隱藏到輸出的連接由權(quán)矩陣 V 參數(shù)化。(左) 使用循環(huán)連接繪制的RNN和它的損失。(右) 同一網(wǎng)絡(luò)被視為展開的計(jì)算圖,其中每個(gè)節(jié)點(diǎn)現(xiàn)在與一個(gè)特定的時(shí)間實(shí)例相關(guān)聯(lián)。
循環(huán)神經(jīng)網(wǎng)絡(luò)中一些重要的設(shè)計(jì)模式包括以下幾種:
1. 每個(gè)時(shí)間步都有輸出,并且隱藏單元之間有循環(huán)連接的循環(huán)網(wǎng)絡(luò),如上圖所示。
2. 每個(gè)時(shí)間步都產(chǎn)生一個(gè)輸出,只有當(dāng)前時(shí)刻的輸出到下個(gè)時(shí)刻的隱藏單元之間
有循環(huán)連接的循環(huán)網(wǎng)絡(luò)。
3. 隱藏單元之間存在循環(huán)連接,但讀取整個(gè)序列后產(chǎn)生單個(gè)輸出的循環(huán)網(wǎng)絡(luò)。
任何圖靈可計(jì)算的函數(shù)都可以通過這樣一個(gè)有限維的循環(huán)網(wǎng)絡(luò)計(jì)算,在這個(gè)意義上上圖的循環(huán)神經(jīng)網(wǎng)絡(luò)是萬能的。RNN經(jīng)過若干時(shí)間步后讀取輸出,這與由圖靈機(jī)所用的時(shí)間步是漸近線性的,與輸入長(zhǎng)度也是漸近線性的 (Siegelmann and Sontag, 1991; Siegelmann, 1995; Siegelmann and Sontag, 1995;Hyotyniemi, 1996)。由圖靈機(jī)計(jì)算的函數(shù)是離散的,所以這些結(jié)果都是函數(shù)的具體實(shí)現(xiàn),而不是近似。RNN作為圖靈機(jī)使用時(shí),需要一個(gè)二進(jìn)制序列作為輸入,其輸出必須離散化后提供二進(jìn)制輸出。利用單個(gè)有限大小的特定RNN計(jì)算在此設(shè)置下的所有函數(shù)是可能的(Siegelmann and Sontag (1995) 用了 886 個(gè)單元)。圖靈機(jī)的 ‘‘輸入’’ 是要計(jì)算函數(shù)的詳細(xì)說明 (specification),所以模擬此圖靈機(jī)的相同網(wǎng)絡(luò)足以應(yīng)付所有問題。用于證明的理論RNN可以通過激活和權(quán)重(由無限精度的有理數(shù)表示)來模擬無限堆棧。
現(xiàn)在我們研究上圖中RNN的前向傳播公式。這個(gè)圖沒有指定隱藏單元的激活函數(shù)。我們假設(shè)使用雙曲正切激活函數(shù)。此外,圖中沒有明確指定何種形式的輸出和損失函數(shù)。我們假定輸出是離散的,如用于預(yù)測(cè)詞或字符的RNN。一種代表離散變量的自然方式是把輸出 o作為每個(gè)離散變量可能值的非標(biāo)準(zhǔn)化對(duì)數(shù)概率。然后,我們可以應(yīng)用softmax 函數(shù)后續(xù)處理后,獲得標(biāo)準(zhǔn)化后概率的輸出向量 。RNN從特定的初始狀態(tài) h(0) 開始前向傳播。從 t = 1 到 t = τ 的每個(gè)時(shí)間步,我們應(yīng)用以下更新方程:
其中的參數(shù)的偏置向量 b和 c 連同權(quán)重矩陣 U、V 和 W,分別對(duì)應(yīng)于輸入到隱藏、隱藏到輸出和隱藏到隱藏的連接。這個(gè)循環(huán)網(wǎng)絡(luò)將一個(gè)輸入序列映射到相同長(zhǎng)度的輸出序列。與 x序列配對(duì)的 y 的總損失就是所有時(shí)間步的損失之和。例如,L(t) 為給定的的負(fù)對(duì)數(shù)似然,則
其中, 需要讀取模型輸出向量的項(xiàng)。
關(guān)于各個(gè)參數(shù)計(jì)算這個(gè)損失函數(shù)的梯度是昂貴的操作。梯度計(jì)算涉及執(zhí)行一次前向傳播(如在上圖展開圖中從左到右的傳播),接著是由右到左的反向傳播。運(yùn)行時(shí)間是 O(τ ),并且不能通過并行化來降低,因?yàn)榍跋騻鞑D是固有循序的; 每個(gè)時(shí)間步只能一前一后地計(jì)算。前向傳播中的各個(gè)狀態(tài)必須保存,直到它們反向傳播中被再次使用,因此內(nèi)存代價(jià)也是 O(τ )。應(yīng)用于展開圖且代價(jià)為 O(τ ) 的反向傳播算法稱為通過時(shí)間反向傳播 (back-propagaTIon through TIme, BPTT)。
此類RNN的唯一循環(huán)是從輸出到隱藏層的反饋連接。在每個(gè)時(shí)間步 t,輸入為,隱藏層激活為。(左) 回路原理圖。(右) 展開的計(jì)算圖。這樣的RNN沒有前面介紹的 RNN 那樣強(qiáng)大(只能表示更小的函數(shù)集合)。上圖中的RNN可以選擇將其想要的關(guān)于過去的任何信息放入隱藏表示 ? 中并且將 ? 傳播到未來。該圖中RNN被訓(xùn)練為將特定輸出值放入 o中,并且 o是允許傳播到未來的唯一信息。此處沒有從 h 前向傳播的直接連接。之前的 h僅通過產(chǎn)生的預(yù)測(cè)間接地連接到當(dāng)前。o通常缺乏過去的重要息,除非它非常高維且內(nèi)容豐富。這使得該圖中的RNN不那么強(qiáng)大,但是它更容易訓(xùn)練,因?yàn)槊總€(gè)時(shí)間步可以與其他時(shí)間步分離訓(xùn)練,允許訓(xùn)練期間更多的并行化。
關(guān)于時(shí)間展開的循環(huán)神經(jīng)網(wǎng)絡(luò),在序列結(jié)束時(shí)具有單個(gè)輸出。這樣的網(wǎng)絡(luò)可以用于概括序列并產(chǎn)生用于進(jìn)一步處理的固定大小的表示。在結(jié)束處可能存在目標(biāo)(如此處所示),或者通過更下游模塊的反向傳播來獲得輸出上的梯度。