"""
|
||
星火认知大模型调用
|
||
文档地址: https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html
|
||
各版本区别:
|
||
Spark 4.0 Ultra
|
||
全面对标GPT 4-Turbo
|
||
最强大的大语言模型版本,文本生成、语言理解、知识问答、逻辑推理、数学能力等方面实现超越GPT 4-Turbo,优化联网搜索链路,提供更精准回答。
|
||
|
||
Spark3.5 Max
|
||
整体接近GPT 4-Turbo
|
||
旗舰级大语言模型,具有千亿级参数,核心能力全面升级,具备更强的数学、中文、代码和多模态能力。适用于数理计算、逻辑推理等对效果有更高要求的业务场景。
|
||
Spark Pro
|
||
支持128K长文本版本
|
||
延时更低
|
||
专业级大语言模型,具有百亿级参数,在医疗、教育和代码等场景进行了专项优化,搜索场景延时更低。适用于文本、智能问答等对性能和响应速度有更高要求的业务场景。
|
||
Spark Lite
|
||
能力全面
|
||
灵活经济
|
||
轻量级大语言模型,具有更高的响应速度,适用于低算力推理与模型精调等定制化场景,可满足企业产品快速验证的需求。
|
||
"""
from openai import OpenAI
from rich.console import Console
from rich.markdown import Markdown
from rich.table import Table
from rich.prompt import Prompt
from rich.panel import Panel
from dotenv import load_dotenv
import os

# Interactive console chat loop: streams answers from the Spark (星火) LLM
# through an OpenAI-compatible endpoint, renders each completed line as
# Markdown, keeps the full conversation history, and prints a per-round
# token-usage table on exit.
try:
    console = Console()
    load_dotenv()  # loads BASE_URL / API_KEY from a .env file into the environment

    question = Prompt.ask("请输入问题?")
    if not question:
        console.log("问题不能为空")
        exit(0)

    model = Prompt.ask("""
使用哪个模型(general指向Lite版本;generalv2指向V2.0版本;generalv3指向Pro版本;generalv3.5指向Max版本;4.0Ultra指向4.0 Ultra版本)
""", choices=['general', 'generalv3', 'generalv3.5', '4.0Ultra'],
                       default="generalv3.5")

    # One row per question/answer round, printed after the loop ends.
    table = Table(title="消耗明细")
    table.add_column("会话id", justify="center", style="cyan", no_wrap=True)
    table.add_column("code", justify="center", style="cyan")
    table.add_column("描述信息", justify="center", style="cyan")
    table.add_column("历史消耗token量", justify="center", style="magenta")
    table.add_column("回答消耗token量", justify="center", style="yellow")
    table.add_column("总消耗token量", justify="center", style="red")

    # Running chat history; re-sent to the model on every round so it has context.
    message = []

    client = OpenAI(
        base_url=os.getenv("BASE_URL"),
        api_key=os.getenv("API_KEY"),
    )

    while question:
        with console.status("[bold green]请求中...") as status:
            console.log('请求履历:')
            message.append({
                "role": "user",
                "content": question
            })
            console.log(message)
            console.log('请求发送中...')
            response = client.chat.completions.create(
                model=model,
                messages=message,
                stream=True,
            )

            if response.response.status_code != 200:
                raise Exception('Request failed')

            data = []          # every received chunk; the last one carries usage totals
            content = ''       # buffer of streamed text not yet printed
            total_repeat = ''  # complete assistant answer, stored into the history
            console.log('请求完成。 :smiley:')
            console.log('开始接收数据...')
            for chunk in response:
                # Spark signals per-chunk errors through a non-zero `code`
                # field — NOTE(review): non-standard OpenAI field, assumed
                # present on this endpoint's chunks.
                if chunk.code != 0:
                    raise Exception(chunk.message)
                # delta.content may be None (e.g. role-only or final chunks);
                # guard before concatenating to avoid a TypeError.
                delta = chunk.choices[0].delta.content if chunk.choices else None
                if delta:
                    content += delta
                    # BUGFIX: append only the new delta. The original did
                    # `total_repeat += content`, re-adding the whole buffer on
                    # every chunk and storing duplicated text in the history.
                    total_repeat += delta
                data.append(chunk)
                # Flush each completed line to the console as Markdown.
                while '\n' in content:
                    part, content = content.split('\n', 1)
                    console.log(Markdown(part))

            # Print whatever trailing text never got a final newline.
            if content:
                console.log(Markdown(content))

            # The last chunk holds this round's token accounting.
            res = data[-1]
            table.add_row(
                res.id,
                str(res.code),
                str(res.message),
                str(res.usage.prompt_tokens),
                str(res.usage.completion_tokens),
                str(res.usage.total_tokens),
            )

        # Record the assistant's full answer so the next round has context.
        message.append({
            "role": "assistant",
            "content": total_repeat
        })
        console.log('模型回答结束。 :smiley:')
        question = Prompt.ask("如有问题请继续提问...(按回车结束)")
except KeyboardInterrupt as e:
    console.log("用户主动退出")
    exit(0)
except Exception as e:
    console.log(e, log_locals=True)
    # BUGFIX: exit non-zero on failure so shells/CI can distinguish an
    # error from a normal quit (the original exited 0 here).
    exit(1)

console.print(table)