柚子快報激活碼778899分享:RUST學(xué)習(xí)筆記(Day 3)
柚子快報激活碼778899分享:RUST學(xué)習(xí)筆記(Day 3)
今天學(xué)習(xí)用Rust來實現(xiàn)開源 LLM代表LLaMA模型。?本次使用的是karpathy/llama2.c: Inference Llama 2 in one file of pure C?的 Rust 實現(xiàn)的版本中的:danielgrittner/llama2-rs: LLaMA2 + Rust。僅涉及推理部份。
配置
struct Config {
dim: usize, // transformer dimension
hidden_dim: usize, // for ffn layers
n_layers: usize, // number of layers
n_heads: usize, // number of query heads
head_size: usize, // size of each head (dim / n_heads)
n_kv_heads: usize, // number of key/value heads
shared_weights: bool,
vocab_size: usize, // vocabulary size
seq_len: usize, // max. sequence length
}
在上述代碼中,我們定義了一個名為?Config?的結(jié)構(gòu)體(struct),用于表示某種配置信息。結(jié)構(gòu)體包含了多個字段,每個字段都有對應(yīng)的字段名和類型注釋。
dim: usize:transformer 的維度。hidden_dim: usize:用于 ffn 層(feed-forward network,前饋神經(jīng)網(wǎng)絡(luò))的隱藏層維度。n_layers: usize:層數(shù)。n_heads: usize:查詢頭的數(shù)量。head_size: usize:每個查詢頭的大小(dim / n_heads)。n_kv_heads: usize:鍵/值頭的數(shù)量。shared_weights: bool:指示是否使用共享權(quán)重。vocab_size: usize:詞匯表大小。seq_len: usize:最大序列長度。
該結(jié)構(gòu)體定義了一個用于存儲具有不同配置信息的對象。通過創(chuàng)建?Config?的實例,并為每個字段提供適當(dāng)?shù)闹担覀兛梢栽诖a中使用配置對象來管理和傳遞相關(guān)的設(shè)置。
dim?就是上面一直說的 Dim,hidden_dim?僅在 FFN 層,因為 FFN 層需要先擴大再縮小。n_heads?和?n_kv_heads?是 Query 的 Head 數(shù)和 KV 的 Head 數(shù),簡單起見可以認(rèn)為它們是相等的。
參數(shù)
struct TransformerWeights {
// Token Embedding Table
token_embedding_table: Vec
// Weights for RMSNorm
rms_att_weight: Vec
rms_ffn_weight: Vec
// Weights for matmuls in attn
wq: Vec
wk: Vec
wv: Vec
wo: Vec
// Weights for ffn
w1: Vec
w2: Vec
w3: Vec
// final RMSNorm
rms_final_weights: Vec
// freq_cis for RoPE relatively positional embeddings
freq_cis_real: Vec
freq_cis_imag: Vec
// (optional) classifier weights for the logits, on the last layer
wcls: Vec
}
上述代碼定義了一個名為?TransformerWeights?的結(jié)構(gòu)體(struct),用于存儲一組轉(zhuǎn)換器(transformer)的權(quán)重值。結(jié)構(gòu)體包含了多個字段,每個字段都指定了對應(yīng)的字段名和類型。
token_embedding_table: Vec
該結(jié)構(gòu)體定義了一個用于存儲轉(zhuǎn)換器權(quán)重值的對象。通過創(chuàng)建?TransformerWeights?的實例,并為每個字段提供適當(dāng)?shù)闹?,我們可以在代碼中使用該對象來管理和傳遞相關(guān)的權(quán)重。
其中:freq_?開頭的兩個參數(shù),它們是和位置編碼有關(guān)的參數(shù),也就是說,我們每次生成一個 Token 時,都需要傳入當(dāng)前位置的位置信息。位置編碼在 Transformer 中是比較重要的,因為 Self Attention 本質(zhì)上是無序的,而語言的先后順序在有些時候是很重要的
加載參數(shù)
fn byte_chunk_to_vec
where
T: Clone,
{
unsafe {
// 獲取起始位置的原始指針
let data = byte_chunk.as_ptr() as *const T;
// 從原始指針創(chuàng)建一個 T 類型的切片,注意number_elements是element的數(shù)量,而不是bytes
// 這句是 unsafe 的
let slice_data: &[T] = std::slice::from_raw_parts(data, number_elements);
// 將切片轉(zhuǎn)為 Vec,需要 T 可以 Clone
slice_data.to_vec()
}
}
fn byte_chunk_to_vec
這段代碼的目的是將一個字節(jié)切片轉(zhuǎn)換為元素類型為?T?的向量。但由于涉及到指針操作和不安全的代碼,因此要特別小心使用。where T: Clone?約束確保元素類型?T?可以被克隆來進行復(fù)制操作。byte_chunk?表示原始的字節(jié)切片,number_elements?表示結(jié)果向量中元素的個數(shù)。
unsafe的用法
獲取原始指針:代碼中使用?byte_chunk.as_ptr()?方法獲取?byte_chunk?字節(jié)切片的原始指針,然后將其轉(zhuǎn)換為?T?類型的常量指針。這個操作涉及到底層的指針操作,訪問和操作內(nèi)存的原始指針需要使用?unsafe?代碼塊。 從原始指針創(chuàng)建切片:使用?std::slice::from_raw_parts()?方法根據(jù)原始指針和元素數(shù)量創(chuàng)建了一個?T?類型的切片。這個方法也需要使用?unsafe?代碼塊,因為它接觸到了指針和內(nèi)存操作。
需要注意的是,使用?unsafe?關(guān)鍵字會打開 Rust 中的一些安全性限制。在使用?unsafe?代碼塊時,需要確保代碼正確地處理了指針和內(nèi)存操作,以避免造成內(nèi)存安全和未定義行為。
在這段代碼中使用?unsafe?主要是為了利用底層的指針操作和內(nèi)存操作,來直接操作原始數(shù)據(jù)。這樣可以避免不必要的數(shù)據(jù)復(fù)制,并提高性能。但同時,需要非常小心,確保代碼在使用指針和訪問內(nèi)存時不會引發(fā)潛在的錯誤。
as *const T的用法
在?let data = byte_chunk.as_ptr() as *const T?這段代碼中的?*?運算符是指針類型轉(zhuǎn)換中的解引用運算符,訪問指針?biāo)赶虻臄?shù)據(jù)。
在這段代碼中,byte_chunk.as_ptr()?返回了?byte_chunk?字節(jié)切片的原始指針(raw pointer),然后使用?as *const T?將其轉(zhuǎn)換為?*const T?類型的常量指針,以便與類型?T?相匹配。
需要注意的是,在這個上下文中的?*?運算符并不是乘法運算符,它是指針類型轉(zhuǎn)換語法的一部分。解引用運算符不能在 Safe Rust 中直接使用,因為它是一個不安全操作,需要使用?unsafe?關(guān)鍵字包裹。例如,let value = *ptr;?將會將?value?變量設(shè)置為指針?ptr?所指向位置的值。
加載模型
讀取原始的 bin 文件并指定對應(yīng)的參數(shù)大小
et token_embedding_table_size = config.vocab_size * config.dim;
// offset.. 表示從 offset 往后的所有元素
let token_embedding_table: Vec
這行代碼計算了?token_embedding_table?的大小,即詞嵌入表的大小。它將?config.vocab_size(詞匯表大小)乘以?config.dim(維度),并將結(jié)果賦值給變量?token_embedding_table_size。這行代碼創(chuàng)建了一個名為?token_embedding_table?的?Vec
因為向量數(shù)據(jù)通常取決于從某個字節(jié)切片轉(zhuǎn)換而來,所以需要借助?byte_chunk_to_vec?函數(shù)來執(zhí)行轉(zhuǎn)換。
簡單的內(nèi)容就到這里了,后面上硬菜,未完待續(xù)~~
柚子快報激活碼778899分享:RUST學(xué)習(xí)筆記(Day 3)
精彩文章
本文內(nèi)容根據(jù)網(wǎng)絡(luò)資料整理,出于傳遞更多信息之目的,不代表金鑰匙跨境贊同其觀點和立場。
轉(zhuǎn)載請注明,如有侵權(quán),聯(lián)系刪除。