柚子快報激活碼778899分享:LLaMA詳細解讀
柚子快報激活碼778899分享:LLaMA詳細解讀
LLaMA 是目前為止,效果最好的開源 LLM 之一。精讀 LLaMA 的論文及代碼,可以很好的了解 LLM 的內部原理。本文對 LLaMA 論文進行了介紹,同時附上了關鍵部分的代碼,并對代碼做了注釋。
摘要
LLaMA是一個系列模型,模型參數(shù)量從7B到65B。在大部分的任務上,LLaMA-13B強于GPT-3(175B)。LLaMA-65B的性能,可以和最好的LM相媲美,如Chinchilla-70B 和 PaLM-540B。
一、引言
一般而言,模型越大,效果越好。然而有文獻指出[1],當給定計算量的預算之后,最好的performance,并不是最大的模型,而是在一個小模型上用更多的數(shù)據(jù)進行訓練。針對給定的計算量預算,scaling laws可以計算如何選擇數(shù)據(jù)量的大小和模型的大小。然而這忽略了inference的預算,而這一點在模型推理時非常關鍵。當給定一個模型performance目標之后,最好的模型不是訓練最快的模型,而是推理最快的模型。盡管在這種情況下,訓練一個更大的模型成本會更低。
文獻[2]中推薦,訓練一個 10B 的模型,需要 200B 的 tokens,而本文的實驗發(fā)現(xiàn),一個7B的模型,經過 1T tokens 訓練之后,performance 仍然在增加。本文的目標在于,通過在超大規(guī)模的數(shù)據(jù)上訓練,給出一系列可能最好 performance 的 LLM。
二、預訓練數(shù)據(jù)
2.1 數(shù)據(jù)集
一共有1.4T的tokens,大部分的訓練數(shù)據(jù)都只用了一次,除了Wikipedia 和 Books 使用了大概2個epochs。
Pre-training data
2.2 tokenizer
使用byte pair encoding (BPE) 算法,使用的是Sentence-Piece的實現(xiàn)。所有數(shù)字被拆分為單獨的digit,所有未知的UTF-8 字符,回退到字節(jié)來進行分解。因此,LLaMA 可以通過byte 的方式,構造出很多不在 vocab 中的字符,從而也具有較好的多語言能力。
三、網絡結構改進
使用了基于transformer的架構,并做了如下3點改進:
3.1 Pre-normalization
為了提高訓練的穩(wěn)定性,對每個transformer層的輸入進行歸一化,而不是輸出進行歸一化。
同時,使用 RMS Norm 歸一化函數(shù)。RMS Norm 的全稱為 Root Mean Square layer normalization。與 layer Norm 相比,RMS Norm的主要區(qū)別在于去掉了減去均值的部分,計算公式為:
RMS Norm 的作者認為這種模式在簡化了Layer Norm 的計算,可以在減少約 7%~64% 的計算時間[3]。
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return (self.weight * hidden_states).to(input_dtype)
3.2 SwiGLU
使用SwiGLU替代了ReLU作為激活函數(shù)。和PaLM中不同,維度采用而不是?4??。
SwiGLU 在論文[4]?中提出,相比于其他的激活函數(shù)變體,可以取得 log-perplexity 的最優(yōu)值(和 GEGLU 并列)。
GLU Variants Improve Transformer
SwiGLU 及幾種類似變體的計算公式如下:
其中,。代碼如下:
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
# config 中 hidden_act = 'silu'
# 'silu' 和 'swish' 對應的激活函數(shù)均為:SiLUActivation
# https://github.com/huggingface/transformers/blob/717dadc6f36be9f50abc66adfd918f9b0e6e3502/src/transformers/activations.py#L229
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
# 對應上述公式的 SwiGLU
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
從代碼可以看到 LlamaMLP 中一共有 3 個 Linear 層,原因就在于 SwiGLU 激活函數(shù)比類似 ReLU 的激活函數(shù),需要多一個 Linear 層進行門控。
3.3 RoPE
RoPE 的核心思想是“通過絕對位置編碼的方式實現(xiàn)相對位置編碼”,可以說是具備了絕對位置編碼的方便性,同時可以表示不同 token 之間的相對位置關系。[5]?不同于原始 Transformers 論文中,將 pos embedding 和 token embedding 進行相加,RoPE 是將位置編碼和 query (或者 key) 進行相乘。具體如下:
Rotary Position Embedding
其中,左側的矩陣????表示位置第???個位置的位置編碼,右側的向量????表示對應位置的 query 向量。兩者相乘,即可得到增加了位置信息的 query (或者 key)。由于????的稀疏性,上述矩陣乘法可以等價于:
Rotary Position Embedding 的簡化實現(xiàn)
其中 ? 是逐位對應相乘,。
RoPE的代碼實現(xiàn)如下[6]:
# 代碼增加了注釋,可以看到和原始公式的對應關系。
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
# 此處 inv_freq 對應公式中的 theta
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
# 此處 freqs 對應公式中的 m * theta, t 對應公式中的 m,表示位置
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
# 此處和原始公式不同,theta_0 和 theta_0 不再相鄰
# 而是分在向量的前半部分和后半部分
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
# 大部分情況下,直接從這里返回
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
# 此次和原始推導中不同,正負號不是間隔的,而是分前半部分和后半部分。但對于結果沒有影響
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
# 對應上圖中 RoPE 的簡化計算
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
四、高效實現(xiàn)
加速訓練:
使用了xformers庫。減少了activation checkpointing 中,重新計算 activation 的計算量。手動實現(xiàn) transformer 層的反向傳遞函數(shù),保存了計算成本高的 activations,例如線性層的輸出。通過使用 model parallelism 和 sequence parallelism 來減少顯存的使用量。盡可能地將 activations 的計算和GPU之間的通訊進行并行。
加速效果:
65B的模型,在2048個80G的A100 GPU上,可以達到380 tokens/sec/GPU的速度。訓練1.4T tokens需要21天。
五、主要結果與結論
Massive Multitask LanguageUnderstanding
LLaMA-13B 優(yōu)于 GPT-3,盡管只有1/10大小。 LLaMA-65B 是可以與 Chinchilla-70B 和 PaLM-540B 這種最佳的LLM相競爭的模型。經過微調之后,LLaMA的效果有顯著的提升。
未來打算發(fā)布在更大的語料上預訓練上的更大的模型,因為隨著數(shù)據(jù)和模型的增大,可以看到 performance 的穩(wěn)定提升。
優(yōu)化器
LLaMA使用了AdamW優(yōu)化器進行訓練,優(yōu)化器的超參數(shù)為 =0.9, =0.95
(關于AdamW這個大模型訓練的優(yōu)化器,可參考當前訓練神經網絡最快的方式:AdamW優(yōu)化算法+超級收斂 | 機器之心[6])
下表為LLaMA不同參數(shù)大小模型的具體設置:
表2: LLaMA不同參數(shù)大小模型的具體設置
參數(shù)維度(dim)head個數(shù)layer層數(shù)學習率batch sizetoken數(shù)量6.7B409632323.0e?44M1.0T13.0B512040403.0e?44M1.0T32.5B665652601.5e?44M1.4T65.2B819264801.5e?44M1.4T
訓練結果
如下圖所示,7B、13B、33B和65模型的訓練損失均呈下降趨勢,且在所有token上訓練完后,loss仍沒有收斂的趨勢。因此,在此時,增加訓練的token數(shù)量,仍然可以使模型繼續(xù)學習。
(LLaMA2就是在此結論的基礎上,使用了更多的token進行訓練)
高效部署
研究團隊做了一些優(yōu)化來提高模型的訓練速度:
因果多頭注意的有效實現(xiàn):使用因果多頭注意的有效實現(xiàn)來減少內存使用和運行時間。該實現(xiàn)可在xformers庫中獲得,其靈感來自于固定激活值顯存優(yōu)化和FlashAttention。這是通過不存儲注意力權重和不計算由于語言建模任務的因果性質而被掩蓋的key/query分數(shù)來實現(xiàn)的。 激活重計算:為了進一步提高訓練效率,通過檢查點減少了在向后傳遞過程中重新計算的激活量。更準確地說,節(jié)省了計算成本高的激活,比如線性層的輸出。這是通過手動實現(xiàn)transformer層的backward函數(shù)來實現(xiàn)的,而不是依賴于PyTorch的autograd。 模型并行和序列并行:為了從這種優(yōu)化中充分受益,需要通過使用模型和序列并行來減少模型的內存使用。此外,還盡可能地重疊激活的計算和gpu之間通過網絡的通信。
筆者NOTE:LLM的高效訓練是LLM工程實現(xiàn)的基礎,對于這部分,各位小伙伴還是需要深入地了解一下各種并行策略、因果多頭注意的有效實現(xiàn)、 激活重計算、混合精度訓練。
參考
^Training Compute-Optimal Large Language Models?https://arxiv.org/abs/2203.15556^Training Compute-Optimal Large Language Models?https://arxiv.org/abs/2203.15556^Root Mean Square Layer Normalization?https://arxiv.org/pdf/1910.07467.pdf^GLU Variants Improve Transformer?https://arxiv.org/pdf/2002.05202.pdf^Transformer升級之路:2、博采眾長的旋轉式位置編碼?Transformer升級之路:2、博采眾長的旋轉式位置編碼 - 科學空間|Scientific Spaces^transformers/src/transformers/models/llama/modeling_llama.py at main · huggingface/transformers · GitHub
柚子快報激活碼778899分享:LLaMA詳細解讀
推薦閱讀
本文內容根據(jù)網絡資料整理,出于傳遞更多信息之目的,不代表金鑰匙跨境贊同其觀點和立場。
轉載請注明,如有侵權,聯(lián)系刪除。