柚子快報(bào)激活碼778899分享:論文閱讀 《RWKV》論文筆記
柚子快報(bào)激活碼778899分享:論文閱讀 《RWKV》論文筆記
原文出處
[2305.13048] RWKV: Reinventing RNNs for the Transformer Era (arxiv.org)
原文筆記
What
RWKV(RawKuv):Reinventing RNNs for the Transformer Era
本文貢獻(xiàn)如下:
提出了 RWKV 網(wǎng)絡(luò)架構(gòu),結(jié)合了RNNS 和Transformer 的優(yōu)點(diǎn),同時(shí)緩解了它們已知的限制
我們提出了一種新的線性注意力機(jī)制
展示了 RWKV 在處理涉及大規(guī)模模型和長(zhǎng)距離依賴關(guān)系的任務(wù)時(shí)的性能、效率和擴(kuò)展能力
RWKV的突出賣點(diǎn):
O(1)推理復(fù)雜度真的非常香
單 token 推理時(shí)間恒定,總推理時(shí)間隨序列長(zhǎng)度線性增加內(nèi)存占用恒定,不隨序列長(zhǎng)度增加推理時(shí)間和內(nèi)存占用隨模型尺寸線性增長(zhǎng)
優(yōu)勢(shì)是數(shù)量級(jí)級(jí)別的,這意味著:
大模型的硬件限制和部署成本將大幅降低,CPU及非 NV 加速卡均可部署服務(wù)器上部署大模型的成本將大幅降低,普通臺(tái)式機(jī)和筆記本將能在本地部署大模型手機(jī)端部署也成為可能
RWKV 將推動(dòng)大模型進(jìn)行一次架構(gòu)遷移!
Why
Transformer具有出色的序列建模能力,一次處理一整句話,或一整段話,可以并行訓(xùn)練,但是同樣面臨著計(jì)算復(fù)雜度高,內(nèi)存占用大,計(jì)算成本高的難題
傳統(tǒng)的Transformer的總推理時(shí)間隨序列長(zhǎng)度二次增加(在序列特別長(zhǎng)的情況下有可能三次增加)
自注意力機(jī)制的二次復(fù)雜度使其成為涉及長(zhǎng)序列和受限資源的任務(wù)的計(jì)算和內(nèi)存密集型。這刺激了增強(qiáng) Transformer 可擴(kuò)展性的研究,有時(shí)犧牲了一些其有效性
循環(huán)神經(jīng)網(wǎng)絡(luò) (RNN) 在內(nèi)存和計(jì)算需求方面表現(xiàn)出線性縮放,內(nèi)存占用小,計(jì)算量小,(因?yàn)樗看沃惶幚硪徊康臄?shù)據(jù))但由于并行化和可擴(kuò)展性的限制,難以與 Transformer 匹配相同的性能。
RNN在訓(xùn)練長(zhǎng)序列時(shí)容易出現(xiàn)梯度消失問題
RNN 在訓(xùn)練過程中對(duì)前一步結(jié)果依賴,無(wú)法在時(shí)間維度上進(jìn)行并行化,限制了其可擴(kuò)展性(無(wú)法獲得很大的rnn模型)
RWKV 背后的動(dòng)機(jī)是平衡計(jì)算效率和神經(jīng)網(wǎng)絡(luò)的表達(dá)能力。它提供了一種處理具有數(shù)十億個(gè)參數(shù)的大規(guī)模模型的解決方案,以降低計(jì)算成本表現(xiàn)出具有競(jìng)爭(zhēng)力的性能。實(shí)驗(yàn)表明,RWKV 解決了 AI 中的縮放和部署挑戰(zhàn),特別是對(duì)于順序數(shù)據(jù)處理,指向更可持續(xù)和高效 AI 模型。
Challenge
Idea
model
原文翻譯
Abstract
Transformers 徹底改變了幾乎所有自然語(yǔ)言處理 (NLP) 任務(wù),但受到內(nèi)存和計(jì)算復(fù)雜性的影響,這些復(fù)雜性隨序列長(zhǎng)度呈二次方擴(kuò)展。相比之下,循環(huán)神經(jīng)網(wǎng)絡(luò) (RNN) 在內(nèi)存和計(jì)算需求方面表現(xiàn)出線性縮放,但由于并行化和可擴(kuò)展性的限制,難以與 Transformer 匹配相同的性能。我們提出了一種新穎的模型架構(gòu),即感知加權(quán)鍵值 (RWKV),它將變壓器的高效并行訓(xùn)練與 RNN 的有效推理相結(jié)合。
我們的方法利用了線性注意力機(jī)制,并允許我們將模型制定為 Transformer 或 RNN,從而在訓(xùn)練期間并行化計(jì)算并在推理過程中保持恒定的計(jì)算和內(nèi)存復(fù)雜性。到目前為止,我們將我們的模型擴(kuò)展到多達(dá) 14 億個(gè)參數(shù),是迄今為止訓(xùn)練的最大密集 RNN,發(fā)現(xiàn) RWKV 的性能與類似大小的 Transformer 相當(dāng),這表明未來(lái)的工作可以利用這種架構(gòu)來(lái)創(chuàng)建更有效的模型。這項(xiàng)工作為協(xié)調(diào)序列處理任務(wù)中計(jì)算效率和模型性能之間的權(quán)衡邁出了重要的一步。
Introduction
深度學(xué)習(xí)極大地推動(dòng)了人工智能,影響了一系列科學(xué)和工業(yè)用途。這些通常涉及復(fù)雜的順序數(shù)據(jù)處理任務(wù)比如自然語(yǔ)言理解任務(wù),會(huì)話AI,時(shí)間序列分析,和間接順序格式,如圖像和圖表(Brown等人,2020;Ismail Fawaz等人,2019;Wu等人,2020;Albalak等人,2022)。這些技術(shù)中占主導(dǎo)地位包括 RNN 和 Transformers (Vaswani et al., 2017),每種都有特定的優(yōu)點(diǎn)和缺點(diǎn)。RNN 需要更少的內(nèi)存,特別是對(duì)于處理長(zhǎng)序列。然而,它們?cè)谟?xùn)練過程中在時(shí)間維度上存在梯度消失問題和非并行性,限制了它們的可擴(kuò)展性(Hochreiter,1998;Le 和 Zuidema,2016)。
Transformers 已經(jīng)成為一種強(qiáng)大的替代方案,擅長(zhǎng)管理局部和遠(yuǎn)程依賴項(xiàng)并支持并行訓(xùn)練(Tay 等人,2022 年)。諸如GPT-3 (Brown et al., 2020)、ChatGPT (OpenAI, 2022;Koco?n et al., 2023),LLAMA (Touvron et al., 2023) 和 Chinchilla (Hoffmann et al., 2022) 展示了 Transformer 在 NLP 中的潛力。然而,自注意力機(jī)制的二次復(fù)雜度使其成為涉及長(zhǎng)序列和受限資源的任務(wù)的計(jì)算和內(nèi)存密集型。這刺激了增強(qiáng) Transformer 可擴(kuò)展性的研究,有時(shí)犧牲了一些其有效性(Wang 等人,2020;Zaheer 等人,2020;Dao 等人,2022a)。
為了應(yīng)對(duì)這些挑戰(zhàn),我們引入了感知加權(quán)鍵值 (RWKV) 模型,結(jié)合了 RNN 和 Transformer 的優(yōu)勢(shì),同時(shí)規(guī)避了關(guān)鍵缺陷。RWKV 通過高效的線性縮放緩解了與 Transformer (Katharopoulos et al., 2020) 相關(guān)的內(nèi)存瓶頸和二次縮放,同時(shí)保持 Transformer 的表達(dá)能力,例如并行訓(xùn)練和魯棒可擴(kuò)展性。RWKV 用線性注意力的變體重新制定注意力機(jī)制,用更有效的通道定向注意力替換傳統(tǒng)的點(diǎn)積令牌交互。這種實(shí)現(xiàn),沒有近似,提供了最低的計(jì)算和內(nèi)存復(fù)雜性;見表 1。
RWKV 背后的動(dòng)機(jī)是平衡計(jì)算效率和神經(jīng)網(wǎng)絡(luò)的表達(dá)能力。它提供了一種處理具有數(shù)十億個(gè)參數(shù)的大規(guī)模模型的解決方案,以降低計(jì)算成本表現(xiàn)出具有競(jìng)爭(zhēng)力的性能。實(shí)驗(yàn)表明,RWKV 解決了 AI 中的縮放和部署挑戰(zhàn),特別是對(duì)于順序數(shù)據(jù)處理,指向更可持續(xù)和高效 AI 模型。我們?cè)诒疚闹械呢暙I(xiàn)如下:
RWKV 的引入,一種新穎的架構(gòu),結(jié)合了 RNN 和 Transformer 優(yōu)勢(shì),同時(shí)減輕了它們的局限性。詳細(xì)的實(shí)驗(yàn),展示了 RWKV 在大規(guī)模模型的基準(zhǔn)數(shù)據(jù)集上的性能和效率。預(yù)訓(xùn)練模型的釋放,從 1690 萬(wàn)個(gè)參數(shù)到 14 億個(gè)參數(shù),在 Pile 上訓(xùn)練(Gao 等人,2020;Biderman 等人,2022)。
2 Background
在這里,我們簡(jiǎn)要回顧了 RNN 和 Transformer 的基本原理。
2.1 Recurrent Neural Networks (RNNs)
LSTM (Hochreiter and Schmidhuber, 1997) 和 GRU (Chung et al., 2014) 等流行的 RNN 架構(gòu)的原理可以概括為以下公式(如 LSTM 所示,其他架構(gòu)可以類似地推理):
盡管RNN可以分解為兩個(gè)線性塊(W和U)和一個(gè)特定于RNN的塊(1)-(6),如Bradbury等人所述。(2017),依賴于先前時(shí)間步長(zhǎng)的數(shù)據(jù)依賴禁止并行化這些典型的RNN。
2.2 Transformers and AFT
由Vaswani等人(2017)介紹,Transformers是一類神經(jīng)網(wǎng)絡(luò),已經(jīng)成為幾個(gè)NLP任務(wù)的主要架構(gòu)。Transformer 不是像 RNN 那樣逐步操作序列,而是依靠注意力機(jī)制來(lái)捕獲所有輸入和輸出tokens之間的關(guān)系:
其中為方便起見,省略了多頭和比例因子 1√dkis。核心 QK? 乘法是一個(gè)在序列中的每個(gè)令牌之間成對(duì)注意力分?jǐn)?shù)的集合,可以分解為向量操作:
AFT (Zhai et al., 2021),表述為
其中 {wt,i} ∈ RT ×T 是學(xué)習(xí)的成對(duì)位置偏差,每個(gè) wt,i 是一個(gè)標(biāo)量。
受 AFT 的啟發(fā),RWKV 采用類似的方法。但是,為簡(jiǎn)單起見,它修改了交互權(quán)重,使其可以轉(zhuǎn)化為 RNN。RWKV 中的每個(gè) wt,i 是一個(gè)通道時(shí)間衰減向量乘以相對(duì)位置并從當(dāng)前時(shí)間向后跟蹤,因?yàn)樗p:
其中 w ∈ (R≥0)^d,d 是通道數(shù)。我們要求 w 是非負(fù)的,以確保 e^wt,i ≤ 1 并且每通道權(quán)重在時(shí)間上向后衰減。
距離當(dāng)前token越遠(yuǎn)的token它就會(huì)衰減的越多,越近的token它就會(huì)衰減的越少,但實(shí)際情況比這個(gè)還要復(fù)雜一點(diǎn),后邊有個(gè)圖來(lái)可視化這一部分(channel的信息衰減))
3 RWKV
RWKV模型架構(gòu)由四個(gè)基本元素構(gòu)成,這四個(gè)基本元素本質(zhì)上都是時(shí)間混合的和通道混合的:
R:Receptance向量充當(dāng)過去信息的接收器(作為過去信息的接受程度的接受向量)
W:Weight表示位置權(quán)重衰減向量,即模型中的可訓(xùn)練參數(shù)(可訓(xùn)練的模型參數(shù))
K:鍵向量,類似于傳統(tǒng)注意力機(jī)制中的K。(用每一個(gè)token自身的一個(gè)值來(lái)對(duì)位置向量進(jìn)行調(diào)制)??
V:值向量,類似于傳統(tǒng)注意力機(jī)制中的V。
這些核心元素在每個(gè)時(shí)間步乘法交互,如圖 2 所示。
3.1?Architecture
RWKV 模型由堆疊的殘差塊組成。每個(gè)塊由一個(gè)時(shí)間混合和一個(gè)通道混合子塊組成,實(shí)現(xiàn)循環(huán)結(jié)構(gòu)以利用過去的信息。
該模型使用了獨(dú)特的類似注意力的分?jǐn)?shù)更新過程,其中包括一個(gè)隨時(shí)間變化的 softmax 操作,以提高數(shù)值穩(wěn)定性和減輕消失梯度(對(duì)于嚴(yán)格的證明,請(qǐng)參見附錄 H)。它確保梯度沿著最相關(guān)的路徑傳播。此外,架構(gòu)中包含的層歸一化 (Ba et al., 2016) 有助于穩(wěn)定梯度,有效地解決梯度消失和爆炸的問題。這些設(shè)計(jì)元素不僅增強(qiáng)了深度神經(jīng)網(wǎng)絡(luò)的訓(xùn)練動(dòng)態(tài),而且促進(jìn)了多層的堆疊,通過捕獲不同抽象級(jí)別的復(fù)雜模式,從而比傳統(tǒng)的 RNN 模型具有更好的性能(另見附錄 I)。
3.1.1 Token Shift
在該架構(gòu)中,計(jì)算中涉及的所有線性投影向量(R, K, V,通道混合中的R ', K ')都是通過當(dāng)前時(shí)間步輸入和前一個(gè)時(shí)間步輸入之間的線性插值產(chǎn)生的,促進(jìn)令牌移位。
時(shí)間混合計(jì)算的向量是塊當(dāng)前輸入和先前輸入的線性組合的線性投影:
通道混合輸入也是如此:
使用 PyTorch (Paszke et al., 2019) 庫(kù) asnn 在每個(gè)塊的時(shí)間維度上實(shí)現(xiàn)令牌移位作為一個(gè)簡(jiǎn)單的偏移量。ZeroPad2d((0,0,1,-1))。
(在模型參數(shù)較小的時(shí)候與Transofrmer的效果還是有一定差距的)
3.1.2 WKV Operator
我們模型中的 W KV 算子的計(jì)算與 Attention Free Transformer (AFT) 中使用的方法并行(Zhai 等人,2021 年)。然而,與 W 是一個(gè)成對(duì)矩陣的 AFT 不同,我們的模型將 W 視為由相對(duì)位置修改的通道向量。在我們的模型中,這種循環(huán)行為由 W KV 向量的時(shí)間相關(guān)更新定義,形式化如下等式:
為了規(guī)避 W 的任何潛在退化,我們引入了一個(gè)單獨(dú)關(guān)注當(dāng)前標(biāo)記的向量 U。有關(guān)這方面的更多信息可以在附錄 I 中找到。
3.1.3 Output Gating
使用 sigmoid 在時(shí)間混合和通道混合塊中實(shí)現(xiàn)輸出門控,接受度,σ(r)。W KV 算子后輸出向量 ot 由下式給出:
3.2 Transformer-like Training
RWKV 可以使用一種稱為時(shí)間并行模式的技術(shù)有效地并行化,讓人想起 Transformer。在單個(gè)層中處理一批序列的時(shí)間復(fù)雜度為 O(BT d2),主要由矩陣乘法 Wλ 組成,其中 λ ∈ {r, k, v, o}(假設(shè) B 序列、Tmaximum 標(biāo)記和 d 個(gè)通道)。相比之下,更新注意力分?jǐn)?shù)wkvt涉及串行掃描(更多細(xì)節(jié)見附錄D),復(fù)雜度為O(BT d)。矩陣乘法可以類似于 Wλ 并行化,其中傳統(tǒng) Transformer 中的 λ ∈ {Q, K, V, O}。逐元素 W KV 計(jì)算依賴于時(shí)間,但可以很容易地沿其他兩個(gè)維度并行化 (Lei et al., 2018)3。
3.3 RNN-like Inference
循環(huán)網(wǎng)絡(luò)通常利用狀態(tài) t 的輸出作為狀態(tài) t + 1 的輸入。在語(yǔ)言模型的自回歸解碼推理中也可以觀察到這種用法,其中每個(gè)令牌必須在傳遞到下一步之前計(jì)算。RWKV 利用了這種類似 RNN 的結(jié)構(gòu),稱為時(shí)間順序模式。在這種情況下,RWKV 可以方便地在推理過程中遞歸制定用于解碼,如附錄 D 所示。
下略
參考文獻(xiàn)
RWKV:在Transformer時(shí)代重塑RNN_嗶哩嗶哩_bilibili
RWKV-6論文解讀_嗶哩嗶哩_bilibili
柚子快報(bào)激活碼778899分享:論文閱讀 《RWKV》論文筆記
精彩鏈接
本文內(nèi)容根據(jù)網(wǎng)絡(luò)資料整理,出于傳遞更多信息之目的,不代表金鑰匙跨境贊同其觀點(diǎn)和立場(chǎng)。
轉(zhuǎn)載請(qǐng)注明,如有侵權(quán),聯(lián)系刪除。