柚子快報(bào)激活碼778899分享:RUST學(xué)習(xí)筆記(Day 3)
柚子快報(bào)激活碼778899分享:RUST學(xué)習(xí)筆記(Day 3)
今天學(xué)習(xí)用Rust來(lái)實(shí)現(xiàn)開(kāi)源 LLM代表LLaMA模型。?本次使用的是karpathy/llama2.c: Inference Llama 2 in one file of pure C?的 Rust 實(shí)現(xiàn)的版本中的:danielgrittner/llama2-rs: LLaMA2 + Rust。僅涉及推理部份。
配置
struct Config {
dim: usize, // transformer dimension
hidden_dim: usize, // for ffn layers
n_layers: usize, // number of layers
n_heads: usize, // number of query heads
head_size: usize, // size of each head (dim / n_heads)
n_kv_heads: usize, // number of key/value heads
shared_weights: bool,
vocab_size: usize, // vocabulary size
seq_len: usize, // max. sequence length
}
在上述代碼中,我們定義了一個(gè)名為?Config?的結(jié)構(gòu)體(struct),用于表示某種配置信息。結(jié)構(gòu)體包含了多個(gè)字段,每個(gè)字段都有對(duì)應(yīng)的字段名和類(lèi)型注釋。
dim: usize:transformer 的維度。hidden_dim: usize:用于 ffn 層(feed-forward network,前饋神經(jīng)網(wǎng)絡(luò))的隱藏層維度。n_layers: usize:層數(shù)。n_heads: usize:查詢頭的數(shù)量。head_size: usize:每個(gè)查詢頭的大?。╠im / n_heads)。n_kv_heads: usize:鍵/值頭的數(shù)量。shared_weights: bool:指示是否使用共享權(quán)重。vocab_size: usize:詞匯表大小。seq_len: usize:最大序列長(zhǎng)度。
該結(jié)構(gòu)體定義了一個(gè)用于存儲(chǔ)具有不同配置信息的對(duì)象。通過(guò)創(chuàng)建?Config?的實(shí)例,并為每個(gè)字段提供適當(dāng)?shù)闹?,我們可以在代碼中使用配置對(duì)象來(lái)管理和傳遞相關(guān)的設(shè)置。
dim?就是上面一直說(shuō)的 Dim,hidden_dim?僅在 FFN 層,因?yàn)?FFN 層需要先擴(kuò)大再縮小。n_heads?和?n_kv_heads?是 Query 的 Head 數(shù)和 KV 的 Head 數(shù),簡(jiǎn)單起見(jiàn)可以認(rèn)為它們是相等的。
參數(shù)
struct TransformerWeights {
// Token Embedding Table
token_embedding_table: Vec
// Weights for RMSNorm
rms_att_weight: Vec
rms_ffn_weight: Vec
// Weights for matmuls in attn
wq: Vec
wk: Vec
wv: Vec
wo: Vec
// Weights for ffn
w1: Vec
w2: Vec
w3: Vec
// final RMSNorm
rms_final_weights: Vec
// freq_cis for RoPE relatively positional embeddings
freq_cis_real: Vec
freq_cis_imag: Vec
// (optional) classifier weights for the logits, on the last layer
wcls: Vec
}
上述代碼定義了一個(gè)名為?TransformerWeights?的結(jié)構(gòu)體(struct),用于存儲(chǔ)一組轉(zhuǎn)換器(transformer)的權(quán)重值。結(jié)構(gòu)體包含了多個(gè)字段,每個(gè)字段都指定了對(duì)應(yīng)的字段名和類(lèi)型。
token_embedding_table: Vec
該結(jié)構(gòu)體定義了一個(gè)用于存儲(chǔ)轉(zhuǎn)換器權(quán)重值的對(duì)象。通過(guò)創(chuàng)建?TransformerWeights?的實(shí)例,并為每個(gè)字段提供適當(dāng)?shù)闹?,我們可以在代碼中使用該對(duì)象來(lái)管理和傳遞相關(guān)的權(quán)重。
其中:freq_?開(kāi)頭的兩個(gè)參數(shù),它們是和位置編碼有關(guān)的參數(shù),也就是說(shuō),我們每次生成一個(gè) Token 時(shí),都需要傳入當(dāng)前位置的位置信息。位置編碼在 Transformer 中是比較重要的,因?yàn)?Self Attention 本質(zhì)上是無(wú)序的,而語(yǔ)言的先后順序在有些時(shí)候是很重要的
加載參數(shù)
fn byte_chunk_to_vec
where
T: Clone,
{
unsafe {
// 獲取起始位置的原始指針
let data = byte_chunk.as_ptr() as *const T;
// 從原始指針創(chuàng)建一個(gè) T 類(lèi)型的切片,注意number_elements是element的數(shù)量,而不是bytes
// 這句是 unsafe 的
let slice_data: &[T] = std::slice::from_raw_parts(data, number_elements);
// 將切片轉(zhuǎn)為 Vec,需要 T 可以 Clone
slice_data.to_vec()
}
}
fn byte_chunk_to_vec
這段代碼的目的是將一個(gè)字節(jié)切片轉(zhuǎn)換為元素類(lèi)型為?T?的向量。但由于涉及到指針操作和不安全的代碼,因此要特別小心使用。where T: Clone?約束確保元素類(lèi)型?T?可以被克隆來(lái)進(jìn)行復(fù)制操作。byte_chunk?表示原始的字節(jié)切片,number_elements?表示結(jié)果向量中元素的個(gè)數(shù)。
unsafe的用法
獲取原始指針:代碼中使用?byte_chunk.as_ptr()?方法獲取?byte_chunk?字節(jié)切片的原始指針,然后將其轉(zhuǎn)換為?T?類(lèi)型的常量指針。這個(gè)操作涉及到底層的指針操作,訪問(wèn)和操作內(nèi)存的原始指針需要使用?unsafe?代碼塊。 從原始指針創(chuàng)建切片:使用?std::slice::from_raw_parts()?方法根據(jù)原始指針和元素?cái)?shù)量創(chuàng)建了一個(gè)?T?類(lèi)型的切片。這個(gè)方法也需要使用?unsafe?代碼塊,因?yàn)樗佑|到了指針和內(nèi)存操作。
需要注意的是,使用?unsafe?關(guān)鍵字會(huì)打開(kāi) Rust 中的一些安全性限制。在使用?unsafe?代碼塊時(shí),需要確保代碼正確地處理了指針和內(nèi)存操作,以避免造成內(nèi)存安全和未定義行為。
在這段代碼中使用?unsafe?主要是為了利用底層的指針操作和內(nèi)存操作,來(lái)直接操作原始數(shù)據(jù)。這樣可以避免不必要的數(shù)據(jù)復(fù)制,并提高性能。但同時(shí),需要非常小心,確保代碼在使用指針和訪問(wèn)內(nèi)存時(shí)不會(huì)引發(fā)潛在的錯(cuò)誤。
as *const T的用法
在?let data = byte_chunk.as_ptr() as *const T?這段代碼中的?*?運(yùn)算符是指針類(lèi)型轉(zhuǎn)換中的解引用運(yùn)算符,訪問(wèn)指針?biāo)赶虻臄?shù)據(jù)。
在這段代碼中,byte_chunk.as_ptr()?返回了?byte_chunk?字節(jié)切片的原始指針(raw pointer),然后使用?as *const T?將其轉(zhuǎn)換為?*const T?類(lèi)型的常量指針,以便與類(lèi)型?T?相匹配。
需要注意的是,在這個(gè)上下文中的?*?運(yùn)算符并不是乘法運(yùn)算符,它是指針類(lèi)型轉(zhuǎn)換語(yǔ)法的一部分。解引用運(yùn)算符不能在 Safe Rust 中直接使用,因?yàn)樗且粋€(gè)不安全操作,需要使用?unsafe?關(guān)鍵字包裹。例如,let value = *ptr;?將會(huì)將?value?變量設(shè)置為指針?ptr?所指向位置的值。
加載模型
讀取原始的 bin 文件并指定對(duì)應(yīng)的參數(shù)大小
et token_embedding_table_size = config.vocab_size * config.dim;
// offset.. 表示從 offset 往后的所有元素
let token_embedding_table: Vec
這行代碼計(jì)算了?token_embedding_table?的大小,即詞嵌入表的大小。它將?config.vocab_size(詞匯表大小)乘以?config.dim(維度),并將結(jié)果賦值給變量?token_embedding_table_size。這行代碼創(chuàng)建了一個(gè)名為?token_embedding_table?的?Vec
因?yàn)橄蛄繑?shù)據(jù)通常取決于從某個(gè)字節(jié)切片轉(zhuǎn)換而來(lái),所以需要借助?byte_chunk_to_vec?函數(shù)來(lái)執(zhí)行轉(zhuǎn)換。
簡(jiǎn)單的內(nèi)容就到這里了,后面上硬菜,未完待續(xù)~~
柚子快報(bào)激活碼778899分享:RUST學(xué)習(xí)筆記(Day 3)
精彩文章
本文內(nèi)容根據(jù)網(wǎng)絡(luò)資料整理,出于傳遞更多信息之目的,不代表金鑰匙跨境贊同其觀點(diǎn)和立場(chǎng)。
轉(zhuǎn)載請(qǐng)注明,如有侵權(quán),聯(lián)系刪除。