python-tools/xinghuo.py
2024-07-22 18:07:41 +09:00

116 lines
4.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

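"""
Invocation of the Spark (Xinghuo) cognitive large language model.

Documentation: https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html

Differences between the model versions:

Spark 4.0 Ultra
    Benchmarked across the board against GPT 4-Turbo.
    The most powerful version of the large language model. It surpasses
    GPT 4-Turbo in text generation, language understanding, knowledge Q&A,
    logical reasoning and mathematics, and its optimized web-search pipeline
    provides more accurate answers.
Spark3.5 Max
    Overall close to GPT 4-Turbo.
    Flagship large language model with parameters on the hundred-billion
    scale. Its core capabilities are fully upgraded, with stronger
    mathematics, Chinese, coding and multimodal abilities. Suited to
    business scenarios that demand higher quality, such as mathematical
    computation and logical reasoning.
Spark Pro
    Version supporting 128K long text.
    Lower latency.
    Professional-grade large language model with parameters on the
    ten-billion scale, specially optimized for medical, education and code
    scenarios, with lower latency in search scenarios. Suited to business
    scenarios that demand higher performance and response speed, such as
    text processing and intelligent Q&A.
Spark Lite
    Well-rounded capabilities.
    Flexible and economical.
    Lightweight large language model with faster responses, suited to
    customized scenarios such as low-compute inference and model
    fine-tuning; it can meet an enterprise's need for rapid product
    validation.
"""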
from openai import OpenAI
from rich.console import Console
from rich.markdown import Markdown
from rich.table import Table
from rich.prompt import Prompt
from rich.panel import Panel
from dotenv import load_dotenv
import os
# Interactive multi-turn chat against an OpenAI-compatible endpoint.
#
# Flow: ask for a question and a model name, stream each answer to the
# terminal line by line as Markdown, keep the whole conversation in
# ``message`` so follow-up questions carry context, and print a per-round
# token-usage table once the user ends the session with an empty question.
#
# Exit status: 0 on normal end, empty first question, or Ctrl-C; 1 on any
# other failure.

# Created before the ``try`` so both ``except`` handlers can always log.
console = Console()
try:
    # Load variables (BASE_URL / API_KEY are read below) from a .env file.
    load_dotenv()
    question = Prompt.ask("请输入问题?")
    if not question:
        console.log("问题不能为空")
        raise SystemExit(0)
    # NOTE(review): the prompt text mentions generalv2 but it is not among
    # the accepted choices — confirm whether that is intentional.
    model = Prompt.ask(
        "\n使用哪个模型general指向Lite版本generalv2指向V2.0版本generalv3指向Pro版本"
        "generalv3.5指向Max版本4.0Ultra指向4.0 Ultra版本\n",
        choices=['general', 'generalv3', 'generalv3.5', '4.0Ultra'],
        default="generalv3.5",
    )

    # One row is appended per completed question/answer round.
    table = Table(title="消耗明细")
    table.add_column("会话id", justify="center", style="cyan", no_wrap=True)
    table.add_column("code", justify="center", style="cyan")
    table.add_column("描述信息", justify="center", style="cyan")
    table.add_column("历史消耗token量", justify="center", style="magenta")
    table.add_column("回答消耗token量", justify="center", style="yellow")
    table.add_column("总消耗token量", justify="center", style="red")

    # Full conversation history, resent with every request.
    message = []
    client = OpenAI(
        base_url=os.getenv("BASE_URL"),
        api_key=os.getenv("API_KEY"),
    )

    while question:
        with console.status("[bold green]请求中..."):
            console.log('请求履历:')
            message.append({
                "role": "user",
                "content": question,
            })
            console.log(message)
            console.log('请求发送中...')
            response = client.chat.completions.create(
                model=model,
                messages=message,
                stream=True,
            )
            if response.response.status_code != 200:
                raise RuntimeError('Request failed')
            console.log('请求完成。 :smiley:')
            console.log('开始接收数据...')

            line_buffer = ''    # text received but not yet printed (no newline yet)
            reply_parts = []    # every delta, in order; joined into the full answer
            last_chunk = None   # final chunk seen; supplies id / code / message
            usage = None        # most recent non-empty usage block seen
            for chunk in response:
                # ``code`` / ``message`` are not standard chunk attributes
                # (presumably Spark extensions), so read them defensively.
                code = getattr(chunk, "code", 0)
                if code not in (0, None):
                    raise RuntimeError(
                        getattr(chunk, "message", None) or f"error code {code}"
                    )
                last_chunk = chunk
                usage = getattr(chunk, "usage", None) or usage

                # A chunk may carry no choices or an empty/None delta.
                delta = chunk.choices[0].delta.content if chunk.choices else None
                if not delta:
                    continue

                # Record only the new text. Appending the pending line buffer
                # here instead would duplicate text in the stored reply.
                reply_parts.append(delta)
                line_buffer += delta
                # Render each completed line as soon as its newline arrives.
                while '\n' in line_buffer:
                    part, line_buffer = line_buffer.split('\n', 1)
                    console.log(Markdown(part))
            if line_buffer:
                console.log(Markdown(line_buffer))

            if last_chunk is None:
                raise RuntimeError('Request failed: empty response stream')

            table.add_row(
                last_chunk.id,
                str(getattr(last_chunk, "code", 0)),
                str(getattr(last_chunk, "message", "")),
                # ``usage`` may be absent; show "-" rather than crashing.
                str(getattr(usage, "prompt_tokens", "-")),
                str(getattr(usage, "completion_tokens", "-")),
                str(getattr(usage, "total_tokens", "-")),
            )
            message.append({
                "role": "assistant",
                "content": "".join(reply_parts),
            })
            console.log('模型回答结束。 :smiley:')
        question = Prompt.ask("如有问题请继续提问...(按回车结束)")
except KeyboardInterrupt:
    console.log("用户主动退出")
    raise SystemExit(0)
except Exception as e:
    console.log(e, log_locals=True)
    # Non-zero status so callers and shells can detect the failure.
    raise SystemExit(1)
console.print(table)