柚子快報(bào)邀請(qǐng)碼778899分享:人工智能 給rwkv
柚子快報(bào)邀請(qǐng)碼778899分享:人工智能 給rwkv
項(xiàng)目地址
rwkv_pytorch
服務(wù)端
import json
import uuid
import time
import torch
from src.model import RWKV_RNN
from src.sampler import sample_logits
from src.rwkv_tokenizer import RWKV_TOKENIZER
from flask import Flask, request, jsonify, Response
# Flask application instance; routes below register against it.
app = Flask(__name__)
# 初始化模型和分詞器
def init_model(model_path='E:/RWKV_Pytorch/weight/RWKV-x060-World-1B6-v2-20240208-ctx4096',
               vocab_path='asset/rwkv_vocab_v20230424.txt',
               device='cpu'):
    """Load the RWKV model and tokenizer.

    Args:
        model_path: Path (without extension) of the RWKV weight file.
        vocab_path: Path of the tokenizer vocabulary file.
        device: Compute backend, one of 'cpu', 'cuda', 'musa', 'npu'.

    Returns:
        (model, tokenizer, device); `model` is already moved to `device`.

    Raises:
        ValueError: if `device` is not a supported backend.
        (Original used `assert`, which is stripped under `python -O`.)
    """
    if device not in ('cpu', 'cuda', 'musa', 'npu'):
        raise ValueError(f"unsupported device: {device!r}")
    # The vendor backends register themselves with torch on import.
    if device == "musa":
        import torch_musa  # noqa: F401
    elif device == "npu":
        import torch_npu  # noqa: F401
    args = {
        'MODEL_NAME': model_path,
        'vocab_size': 65536,
        'device': device,
        'onnx_opset': '18',
    }
    model = RWKV_RNN(args).to(device)
    tokenizer = RWKV_TOKENIZER(vocab_path)
    return model, tokenizer, device
def format_messages_to_prompt(messages):
    """Render a chat message list into the plain-text prompt RWKV expects.

    Each message becomes "<Role>: <content>\n\n"; roles outside the known
    mapping are labelled 'Unknown'. A trailing "Assistant: " is appended so
    the model continues as the assistant.
    """
    role_names = {
        "system": "System",
        "assistant": "Assistant",
        "user": "User"
    }
    rendered = [
        f"{role_names.get(msg['role'], 'Unknown')}: {msg['content']}\n\n"
        for msg in messages
    ]
    return "".join(rendered) + "Assistant: "
def generate_text_stream(prompt: str, temperature=1.5, top_p=0.1, max_tokens=2048, stop=('\n\nUser',)):
    """Stream a completion for `prompt` as SSE "data: ..." lines.

    Yields OpenAI-style chat.completion.chunk JSON payloads, then a final
    chunk whose finish_reason is "stop" (a stop sequence was generated) or
    "length" (max_tokens reached), then the "data:[DONE]" sentinel.

    Uses the module-level `model`, `tokenizer` and `device` globals.
    Note: `stop` default is a tuple to avoid a shared mutable default.
    """
    encoded_input = tokenizer.encode([prompt])
    token = torch.tensor(encoded_input).long().to(device)
    state = torch.zeros(1, model.state_size[0], model.state_size[1]).to(device)
    with torch.no_grad():
        token_out, state_out = model.forward_parallel(token, state)
    del token  # the prompt tensor is no longer needed; only last logits are
    out = token_out[:, -1]
    generated_tokens = ''
    stop_suffixes = tuple(stop)
    hit_stop = False
    for _ in range(max_tokens):
        token_sampled = sample_logits(out, temperature, top_p)
        with torch.no_grad():
            out, state = model.forward(token_sampled, state)
        last_token = tokenizer.decode(token_sampled.unsqueeze(1).tolist())[0]
        generated_tokens += last_token
        if generated_tokens.endswith(stop_suffixes):
            response = {
                "object": "chat.completion.chunk",
                "model": "rwkv",
                "choices": [{
                    "delta": "",
                    "index": 0,
                    "finish_reason": "stop"
                }]
            }
            yield f"data: {json.dumps(response)}\n\n"
            # BUG FIX: the original kept sampling after emitting the "stop"
            # chunk (no break), burning tokens and possibly streaming more
            # content after the client saw finish_reason "stop".
            hit_stop = True
            break
        response = {
            "object": "chat.completion.chunk",
            "model": "rwkv",
            "choices": [{
                "delta": {"content": last_token},
                "index": 0,
                "finish_reason": None
            }]
        }
        yield f"data: {json.dumps(response)}\n\n"
    if not hit_stop:
        response = {
            "object": "chat.completion.chunk",
            "model": "rwkv",
            "choices": [{
                "delta": "",
                "index": 0,
                "finish_reason": "length"
            }]
        }
        yield f"data: {json.dumps(response)}\n\n"
    yield f"data:[DONE]\n\n"
def generate_text(prompt, temperature=1.5, top_p=0.1, max_tokens=2048, stop=('\n\nUser',)):
    """Generate a full (non-streaming) completion for `prompt`.

    Returns:
        (text, if_max_token, usage):
          - text: generated string with the terminating stop sequence removed;
          - if_max_token: True when generation ran to `max_tokens`
            (finish_reason "length"), False when a stop sequence ended it;
          - usage: dict with prompt_tokens / completion_tokens / total_tokens.

    Side effect: echoes each sampled token to stdout (debug trace).
    Uses the module-level `model`, `tokenizer` and `device` globals.
    Note: `stop` default is a tuple to avoid a shared mutable default.
    """
    encoded_input = tokenizer.encode([prompt])
    token = torch.tensor(encoded_input).long().to(device)
    state = torch.zeros(1, model.state_size[0], model.state_size[1]).to(device)
    prompt_tokens = len(encoded_input[0])
    with torch.no_grad():
        token_out, state_out = model.forward_parallel(token, state)
    del token
    out = token_out[:, -1]
    completion_tokens = 0
    if_max_token = True
    generated_tokens = ''
    for _ in range(max_tokens):
        token_sampled = sample_logits(out, temperature, top_p)
        with torch.no_grad():
            out, state = model.forward(token_sampled, state)
        last_token = tokenizer.decode(token_sampled.unsqueeze(1).tolist())[0]
        completion_tokens += 1
        print(last_token, end='')
        generated_tokens += last_token
        for stop_token in stop:
            if generated_tokens.endswith(stop_token):
                # BUG FIX: was replace(stop_token, ""), which also deleted any
                # earlier in-text occurrences of the stop string; strip only
                # the terminating suffix.
                generated_tokens = generated_tokens[:-len(stop_token)]
                if_max_token = False
                break
        if not if_max_token:
            break
    total_tokens = prompt_tokens + completion_tokens
    usage = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}
    return generated_tokens, if_max_token, usage
@app.route('/events', methods=['POST'])
def sse_request():
    """Chat-completion endpoint.

    Accepts an OpenAI-style JSON body (messages, temperature, top_p,
    max_tokens, stop, stream). With stream=true responds with an SSE
    event stream; otherwise with a single chat.completion JSON object.
    """
    try:
        data = request.json
        messages = data.get('messages', [])
        stream = data.get('stream', True) == True
        temperature = float(data.get('temperature', 0.5))
        top_p = float(data.get('top_p', 0.9))
        max_tokens = int(data.get('max_tokens', 100))
        stop = data.get('stop', ['\n\nUser'])
        prompt = format_messages_to_prompt(messages)
        if stream:
            return Response(generate_text_stream(prompt=prompt, temperature=temperature, top_p=top_p,
                                                 max_tokens=max_tokens, stop=stop),
                            content_type='text/event-stream')
        completion, if_max_token, usage = generate_text(prompt, temperature=temperature, top_p=top_p,
                                                        max_tokens=max_tokens, stop=stop)
        # BUG FIX: if_max_token=True means generation ran all the way to
        # max_tokens, i.e. finish_reason "length"; the original mapping
        # ('"stop" if if_max_token else "length"') was inverted.
        finish_reason = "length" if if_max_token else "stop"
        response = {
            "id": str(uuid.uuid4()),
            "object": "chat.completion",
            "created": int(time.time()),
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": completion,
                },
                "finish_reason": finish_reason
            }],
            "usage": usage
        }
        return json.dumps(response)
    except Exception as e:  # top-level HTTP boundary: report the error as JSON
        return json.dumps({"error": str(e)}), 500
if __name__ == '__main__':
    # Populate the module-level globals the handlers above rely on,
    # then start the Flask development server (default port 5000).
    model, tokenizer, device = init_model()
    app.run(debug=False)
解釋
首先引入了需要的庫(kù),包括json用于處理JSON數(shù)據(jù),uuid用于生成唯一標(biāo)識(shí)符,time用于獲取當(dāng)前時(shí)間戳,torch用于構(gòu)建和運(yùn)行模型,Flask用于構(gòu)建API。定義了一個(gè)名為app的Flask應(yīng)用。init_model函數(shù)用于初始化模型和分詞器。其中,模型參數(shù)通過(guò)字典args指定。format_messages_to_prompt函數(shù)用于將消息格式化為提示字符串,以便于模型生成回復(fù)。遍歷消息列表,獲取每個(gè)消息的角色和內(nèi)容,并添加到提示字符串中。generate_text_stream函數(shù)用于以流的形式生成文本。首先將輸入的提示字符串編碼為張量,然后利用模型生成回復(fù),并利用yield關(guān)鍵字將回復(fù)以SSE(服務(wù)器發(fā)送事件)的形式返回。generate_text函數(shù)用于一次性生成完整的文本回復(fù)。與generate_text_stream函數(shù)類(lèi)似,不同的是返回的是完整的回復(fù)字符串。sse_request函數(shù)是Flask應(yīng)用的主要邏輯,用于處理POST請(qǐng)求。從請(qǐng)求的JSON數(shù)據(jù)中獲取參數(shù),并根據(jù)參數(shù)的設(shè)置調(diào)用相應(yīng)的生成函數(shù)。如果參數(shù)中設(shè)置了stream=True,則返回流式生成的回復(fù);否則返回一次性生成的回復(fù)。在__main__函數(shù)中初始化模型和分詞器,然后運(yùn)行Flask應(yīng)用。
客戶(hù)端
import json
import requests
from requests import RequestException
# Server endpoint URL (assumes the Flask app is running on local port 5000).
url = 'http://localhost:5000/events'
# POST請(qǐng)求示例
def post_request_stream():
    """POST a streaming chat request and print the reply as it arrives.

    Parses each SSE "data: ..." line, stops at the "[DONE]" sentinel and
    prints only the text content of each delta chunk.
    """
    # Request payload (OpenAI-style chat completion parameters).
    data = {
        'messages': [
            {'role': 'system', 'content': '你好!'},
            {'role': 'user', 'content': '你能告訴我今天的天氣嗎?'}
        ],
        'temperature': 0.5,
        'top_p': 0.9,
        'max_tokens': 100,
        'stop': ['\n\nUser'],
        'stream': True
    }
    try:
        with requests.post(url, json=data, stream=True) as r:
            for line in r.iter_lines():
                if not line:
                    continue
                decoded_line = line.decode('utf-8')
                payload = decoded_line[5:].strip()  # drop the "data:" prefix
                # BUG FIX: the terminal "data:[DONE]" sentinel is not JSON and
                # previously raised JSONDecodeError; stop cleanly instead.
                if payload == '[DONE]':
                    break
                delta = json.loads(payload)["choices"][0]["delta"]
                # Content chunks carry {"content": ...}; finish-reason chunks
                # carry an empty string — print only real text.
                if isinstance(delta, dict):
                    print(delta.get("content", ""), end="")
    except RequestException as e:
        print(f'An error occurred: {e}')
def post_request():
    """POST a non-streaming chat request and print each JSON response line."""
    # Request payload (OpenAI-style chat completion parameters).
    payload = {
        'messages': [
            {'role': 'system', 'content': '你好!'},
            {'role': 'user', 'content': '你能告訴我今天的天氣嗎?'}
        ],
        'temperature': 0.5,
        'top_p': 0.9,
        'max_tokens': 100,
        'stop': ['\n\nUser'],
        'stream': False
    }
    try:
        with requests.post(url, json=payload, stream=True) as r:
            for raw_line in r.iter_lines():
                if raw_line:
                    # Decode the body and print the parsed completion object.
                    text = raw_line.decode('utf-8')
                    print(json.loads(text))
    except RequestException as e:
        print(f'An error occurred: {e}')
if __name__ == '__main__':
    # Demo entry point: run the streaming request; uncomment the line below
    # to try the non-streaming variant instead.
    # post_request()
    post_request_stream()
解釋
這段代碼是一個(gè)用于向服務(wù)器發(fā)送POST請(qǐng)求的示例代碼。
首先,我們需要導(dǎo)入一些必要的庫(kù)。json庫(kù)用于處理JSON數(shù)據(jù),requests庫(kù)用于發(fā)送HTTP請(qǐng)求,RequestException用于處理請(qǐng)求異常。
接下來(lái),我們需要配置服務(wù)器的URL。在這個(gè)示例中,假設(shè)服務(wù)器運(yùn)行在本地端口5000上。
代碼中定義了兩個(gè)函數(shù)post_request_stream和post_request,分別用于發(fā)送帶有流式響應(yīng)和非流式響應(yīng)的POST請(qǐng)求。
post_request_stream函數(shù)構(gòu)造了一個(gè)包含各種參數(shù)的數(shù)據(jù)字典,并使用requests.post方法發(fā)送POST請(qǐng)求。在請(qǐng)求的參數(shù)中,stream參數(shù)被設(shè)置為T(mén)rue,表示我們希望獲得一個(gè)流式的響應(yīng)。接著,我們使用r.iter_lines()方法來(lái)迭代獲取服務(wù)器發(fā)送的消息。每收到一行消息,我們將其解碼并打印出來(lái)。
post_request函數(shù)的代碼結(jié)構(gòu)與post_request_stream函數(shù)相似,不同之處在于stream參數(shù)被設(shè)置為False,表示我們希望獲得一個(gè)非流式的響應(yīng)。
最后,在程序的主體部分,我們調(diào)用post_request_stream函數(shù)來(lái)發(fā)送流式的POST請(qǐng)求,并注釋掉了post_request函數(shù)的調(diào)用。
柚子快報(bào)邀請(qǐng)碼778899分享:人工智能 給rwkv
精彩內(nèi)容
本文內(nèi)容根據(jù)網(wǎng)絡(luò)資料整理,出于傳遞更多信息之目的,不代表金鑰匙跨境贊同其觀點(diǎn)和立場(chǎng)。
轉(zhuǎn)載請(qǐng)注明,如有侵權(quán),聯(lián)系刪除。