柚子快報邀請碼778899分享:Llama網(wǎng)絡結構介紹
柚子快報邀請碼778899分享:Llama網(wǎng)絡結構介紹
LLaMA現(xiàn)在已經是開源社區(qū)里炙手可熱的模型了,但是原文中僅僅介紹了其和標準Transformer的差別,并沒有一個全局的模型介紹。因此打算寫篇文章,爭取讓讀者不參考任何其他資料把LLaMA的模型搞懂。
結構
如圖所示為LLaMA的示意圖,由Attention和MLP層堆疊而成 LLaMA模型主要由Attention和MLP層堆疊而成,具有以下特點: 1、前置的RMSNorm:RMSNorm是一種歸一化技術,用于穩(wěn)定模型的訓練過程,提高模型的收斂速度。 2、Q、K上的RoPE旋轉式位置編碼:位置編碼用于捕捉序列中的位置信息,RoPE旋轉式位置編碼能夠有效地處理長序列,提高模型的性能。 3、Causal mask:該機制保證每個位置只能看到前面的tokens,確保了模型的自回歸性質。 4、使用了Group Query Attention:通過使用分組查詢注意力(GQA),LLaMA能夠在保持性能的同時,降低模型的計算復雜度,提高推理速度。 5、MLP表達式:down(up(x) * SILU(gate(x))),其中down, up, gate都是線性層 LLaMA各個不同大小的結構設置如下表所示。其中最大的65B的LLaMA用了2048張80GB的A100,batch size為4百萬,訓練一次需要21天。
Group Query Attention(V2 only)
自回歸模型生成回答時,需要前面生成的KV緩存起來,來加速計算。多頭注意力機制(MHA)需要的緩存量很大,Multi-Query Attention指出多個頭之間可以共享KV對。Group Query Attention沒有像MQA一樣極端,將query分組,組內共享KV,效果接近MHA,速度上與MQA可比較。p.s. 這個技術falcon已經用上了,當時falcon說自己用的是multi query attention,因為當group=1時,GQA和MQA是等價的。falcon支持設置不同的G。
RMSNorm
這是在BERT、GPT等模型中廣泛使用的LayerNorm: RMSNorm(root mean square)發(fā)現(xiàn)LayerNorm的中心偏移沒什么用(減去均值等操作)。將其去掉之后,效果幾乎不變,但是速度提升了40%。最終公式為: 注意除了沒有減均值,加偏置以外,分母上求的RMS而不是方差。
LLaMA在 Attention Layer和MLP的輸入上使用了RMSNorm,相比在輸出上使用,訓練會更加穩(wěn)定。
SwiGLU
LLaMA沒有使用ReLU,而是使用了SwiGLU,有時也被稱為SiLU。公式為: ,效果類似平滑版的ReLU:
RoPE
LLaMA使用了Rotary Position Embedding。對于Q的第m個位置向量q,通過以下方法注入位置編碼:
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000):
super().__init__()
theta = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
t = torch.arange(max_position_mbeddings)
freqs = torch.einsum("i,j->ij", t, theta)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos())
self.register_buffer("sin_cached", emb.sin())
def forward(self, seq_len=None):
return self.cos_cached[:, :, :seq_len, ...], self.sin_cached[:, :, :seq_len, ...]
# 在LlamaAttention通過以下命令調用:
cos, sin = self.rotary_emb(seq_len=kv_seq_len)
以下代碼將q沿著最后一個維度劈成兩半,將后一半乘-1,然后連接在第一半之前,就得到了上式第三項。
# 在接下來的apply_rotary_pos_emb函數(shù)里調用
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
最后通過以下代碼得到結合了位置編碼的Q,K(K和Q使用同樣的方式進行位置編碼)。
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
q_embed = (q * cos[position_ids]) + (rotate_half(q) * sin[position_ids])
k_embed = (k * cos[position_ids]) + (rotate_half(k) * sin[position_ids])
return q_embed, k_embed
# 在LlamaAttention中通過以下命令調用:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
絕對位置編碼的優(yōu)點是計算速度快等,缺點是拓展長度比較麻煩,且絕對位置并沒有什么實際意義。而相對位置編碼對學習token之間的關系很有意義,比如距離的很遠的兩個token之間的關聯(lián)大概率很小,使用相對位置編碼往往能夠獲得更好的效果。此外拓展長度也更容易,因為不論context size多長,只需關注最長距離以內的輸入即可。相對位置編碼的缺點是沒有絕對位置編碼計算速度快。
當我們計算Attention時,RoPE可以變成相對位置編碼。 從上面這個公式可以看出,q和k的attention依賴相對距離m-n。因此RoPE為q、k注入的絕對位置編碼,計算得到的attention,卻變成了相對位置編碼。妙的很,我這里為了不參考其他文章就很容易搞懂LLaMA的結構,簡化了很多東西,推薦大家看一看RoPE原作者蘇劍林的博客了解更多信息。
本文只關注LLaMA缺失的模型結構方面的介紹,對于文章的翻譯可以參考其他的文章, 例如:靳偉,LLaMA大模型是如何煉成的, 其他參考文章:https://zhuanlan.zhihu.com/p/636784644 原文:https://arxiv.org/pdf/2302.13971.pdf。 文中參考的代碼是huggingface的transformers庫實現(xiàn)的版本,并不是Meta官方的代碼。 備注說明:受筆者水平限制,如果哪里講的不對,或者不夠清晰易懂,歡迎在評論區(qū)與我交流。
柚子快報邀請碼778899分享:Llama網(wǎng)絡結構介紹
參考閱讀
本文內容根據(jù)網(wǎng)絡資料整理,出于傳遞更多信息之目的,不代表金鑰匙跨境贊同其觀點和立場。
轉載請注明,如有侵權,聯(lián)系刪除。