Files
AgentCoord/backend/AgentCoord/LLMAPI/LLMAPI.py

331 lines
11 KiB
Python
Raw Normal View History

2024-04-07 15:04:00 +08:00
import asyncio
import openai
import yaml
from termcolor import colored
import os
# Load configuration (API key, API base, model).
# Precedence for every setting: environment variable, then
# config/config.yaml (relative to the current working directory), then the
# built-in default.
yaml_file = os.path.join(os.getcwd(), "config", "config.yaml")
try:
    with open(yaml_file, "r", encoding="utf-8") as file:
        yaml_data = yaml.safe_load(file)
except Exception:
    # Best effort: a missing or unreadable config file simply means that only
    # environment variables and defaults are used.
    yaml_data = {}
# yaml.safe_load returns None for an empty document and may return a list or
# scalar for other content; every lookup below needs a mapping.
if not isinstance(yaml_data, dict):
    yaml_data = {}

# The trailing `or <default>` also covers keys that are present in the YAML
# file but set to null / empty.
OPENAI_API_BASE = (
    os.getenv("OPENAI_API_BASE")
    or yaml_data.get("OPENAI_API_BASE")
    or "https://api.openai.com"
)
openai.api_base = OPENAI_API_BASE
OPENAI_API_KEY = (
    os.getenv("OPENAI_API_KEY") or yaml_data.get("OPENAI_API_KEY") or ""
)
openai.api_key = OPENAI_API_KEY
MODEL: str = (
    os.getenv("OPENAI_API_MODEL")
    or yaml_data.get("OPENAI_API_MODEL")
    or "gpt-4-turbo-preview"
)

# FAST_DESIGN_MODE may arrive as a string (environment variable, or a quoted
# YAML value) or as a real boolean (unquoted YAML value); normalise to bool.
_fast_design_raw = os.getenv("FAST_DESIGN_MODE")
if _fast_design_raw is None:
    _fast_design_raw = yaml_data.get("FAST_DESIGN_MODE", False)
if isinstance(_fast_design_raw, str):
    FAST_DESIGN_MODE: bool = _fast_design_raw.strip().lower() in (
        "true",
        "1",
        "yes",
    )
else:
    FAST_DESIGN_MODE = bool(_fast_design_raw)

GROQ_API_KEY = os.getenv("GROQ_API_KEY") or yaml_data.get("GROQ_API_KEY") or ""
MISTRAL_API_KEY = (
    os.getenv("MISTRAL_API_KEY") or yaml_data.get("MISTRAL_API_KEY") or ""
)
# for LLM completion
def LLM_Completion(
    messages: list[dict], stream: bool = True, useGroq: bool = True
) -> str:
    """Validate ``messages`` and run a chat completion on the selected backend.

    Args:
        messages: Chat messages; each must be a dict with non-blank ``role``
            and ``content`` entries.
        stream: When true, the request is driven through an asyncio event
            loop (Groq or the streaming OpenAI-style call). When false, the
            blocking ``_chat_completion`` helper is used.
        useGroq: Request the Groq backend. It is only honoured when
            ``GROQ_API_KEY`` is set and ``FAST_DESIGN_MODE`` is enabled.

    Returns:
        The reply text returned by the chosen backend helper.

    Raises:
        ValueError: If ``messages`` is empty, or an entry is not a dict or
            has an empty/blank ``role`` or ``content``.
        RuntimeError: If obtaining the event loop fails for a reason other
            than "no current event loop in this thread".
    """
    # Validate up front so a malformed request never reaches the API.
    if not messages:
        raise ValueError("Messages list is empty")
    for i, msg in enumerate(messages):
        if not isinstance(msg, dict):
            raise ValueError(f"Message at index {i} is not a dictionary")
        role = msg.get("role")
        if not role or str(role).strip() == "":
            raise ValueError(f"Message at index {i} has empty 'role'")
        raw_content = msg.get("content")
        if not raw_content or str(raw_content).strip() == "":
            raise ValueError(f"Message at index {i} has empty 'content'")
        # Extra check: very short content often means a prompt template was
        # formatted incorrectly, so warn without failing.
        content = str(raw_content).strip()
        if len(content) < 10:  # minimum-length threshold
            print(f"[WARNING] Message at index {i} has very short content: '{content}'", flush=True)

    # Groq is used only when a key is configured, the caller asked for it and
    # fast design mode is on; every other combination falls back to the
    # OpenAI-style backend.
    use_groq = bool(GROQ_API_KEY) and bool(useGroq) and bool(FAST_DESIGN_MODE)

    if not stream:
        return _chat_completion(messages=messages)

    try:
        loop = asyncio.get_event_loop()
    except RuntimeError as ex:
        # Only the "no loop in this thread" case is recoverable; anything
        # else is re-raised instead of leaving `loop` unbound.
        if "There is no current event loop in thread" not in str(ex):
            raise
        loop = None
    if loop is None or loop.is_closed():
        # A closed loop cannot run anything, so replace it with a fresh one.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    if use_groq:
        return loop.run_until_complete(
            _achat_completion_stream_groq(messages=messages)
        )
    return loop.run_until_complete(
        _achat_completion_stream(messages=messages)
    )
async def _achat_completion_stream_groq(messages: list[dict]) -> str:
    """Request a JSON-object chat completion from Groq, retrying on failure.

    Despite the name, the request is sent with ``stream=False`` and the whole
    reply is read from the first choice.

    Args:
        messages: Chat messages passed to the Groq client unchanged.

    Returns:
        The content of the first choice's message.

    Raises:
        RuntimeError: If all attempts fail; the last underlying exception is
            attached as ``__cause__``.
    """
    from groq import AsyncGroq

    client = AsyncGroq(api_key=GROQ_API_KEY)
    max_attempts = 5
    last_error = None
    for _ in range(max_attempts):
        print("Attempt to use Groq (Fase Design Mode):")
        try:
            response = await client.chat.completions.create(
                messages=messages,
                # model='gemma-7b-it',
                model="mixtral-8x7b-32768",
                # model='llama2-70b-4096',
                temperature=0.3,
                response_format={"type": "json_object"},
                stream=False,
            )
            break
        except Exception as err:
            # Remember the failure and retry; it is reported if every
            # attempt fails.
            last_error = err
    else:
        # `raise "failed"` is itself a TypeError in Python 3 and hides the
        # real cause, so raise a proper chained exception instead.
        raise RuntimeError(
            f"Groq chat completion failed after {max_attempts} attempts"
        ) from last_error

    full_reply_content = response.choices[0].message.content
    print(colored(full_reply_content, "blue", "on_white"), end="")
    print()
    return full_reply_content
async def _achat_completion_stream_mixtral(messages: list[dict]) -> str:
    """Request a chat completion from Mistral's ``open-mixtral-8x7b`` model.

    The last message is always sent with role ``"user"``. The caller's
    ``messages`` list is not modified.

    Args:
        messages: Chat messages; each dict must provide ``role`` and
            ``content``.

    Returns:
        The content of the first choice's message.

    Raises:
        ValueError: If ``messages`` is empty.
        RuntimeError: If all attempts fail; the last underlying exception is
            attached as ``__cause__``.
    """
    from mistralai.client import MistralClient
    from mistralai.models.chat_completion import ChatMessage

    if not messages:
        raise ValueError("Messages list is empty")

    client = MistralClient(api_key=MISTRAL_API_KEY)
    # client=AsyncGroq(api_key=GROQ_API_KEY)

    # Build the request payload once, forcing the final role to "user" on the
    # copy rather than mutating the caller's dicts.
    last_index = len(messages) - 1
    chat_messages = [
        ChatMessage(
            role="user" if index == last_index else message["role"],
            content=message["content"],
        )
        for index, message in enumerate(messages)
    ]

    max_attempts = 5
    last_error = None
    for _ in range(max_attempts):
        try:
            response = client.chat(
                messages=chat_messages,
                # model = "mistral-small-latest",
                model="open-mixtral-8x7b",
                # response_format={"type": "json_object"},
            )
            break  # success: stop retrying
        except Exception as err:
            last_error = err
    else:
        # `raise "failed"` is itself a TypeError in Python 3 and hides the
        # real cause, so raise a proper chained exception instead.
        raise RuntimeError(
            f"Mistral chat completion failed after {max_attempts} attempts"
        ) from last_error

    full_reply_content = response.choices[0].message.content
    print(colored(full_reply_content, "blue", "on_white"), end="")
    print()
    return full_reply_content
async def _achat_completion_stream_gpt35(messages: list[dict]) -> str:
    """Stream a ``gpt-3.5-turbo-16k`` chat completion and return the full text.

    Each received content fragment is echoed to stdout in colour as it
    arrives.

    Args:
        messages: Chat messages passed to the API unchanged.

    Returns:
        The concatenation of all streamed content fragments.
    """
    openai.api_key = OPENAI_API_KEY
    openai.api_base = OPENAI_API_BASE
    response = await openai.ChatCompletion.acreate(
        messages=messages,
        max_tokens=4096,
        n=1,
        stop=None,
        temperature=0.3,
        timeout=3,
        model="gpt-3.5-turbo-16k",
        stream=True,
    )

    reply_parts = []
    async for chunk in response:
        choices = chunk["choices"]
        if len(choices) == 0:
            continue
        delta = choices[0].get("delta", {})
        text = delta.get("content")
        # A delta may have no "content" key, or carry an explicit null;
        # neither contributes text, and None would break colored()/join().
        if text is None:
            continue
        reply_parts.append(text)
        print(colored(text, "blue", "on_white"), end="")
    print()
    return "".join(reply_parts)
async def _achat_completion_json(messages: list[dict]) -> str:
    """Request a JSON-object chat completion from ``MODEL``, retrying on failure.

    Args:
        messages: Chat messages passed to the API unchanged.

    Returns:
        The content of the first choice's message.

    Raises:
        RuntimeError: If all attempts fail; the last underlying exception is
            attached as ``__cause__``.
    """
    openai.api_key = OPENAI_API_KEY
    openai.api_base = OPENAI_API_BASE
    max_attempts = 5
    last_error = None
    for _ in range(max_attempts):
        try:
            response = await openai.ChatCompletion.acreate(
                messages=messages,
                max_tokens=4096,
                n=1,
                stop=None,
                temperature=0.3,
                timeout=3,
                model=MODEL,
                response_format={"type": "json_object"},
            )
            break
        except Exception as err:
            # Remember the failure and retry; it is reported if every
            # attempt fails.
            last_error = err
    else:
        # `raise "failed"` is itself a TypeError in Python 3 and hides the
        # real cause, so raise a proper chained exception instead.
        raise RuntimeError(
            f"JSON chat completion failed after {max_attempts} attempts"
        ) from last_error

    full_reply_content = response.choices[0].message.content
    print(colored(full_reply_content, "blue", "on_white"), end="")
    print()
    return full_reply_content
async def _achat_completion_stream(messages: list[dict]) -> str:
    """Stream a chat completion built by ``_cons_kwargs`` and return the text.

    Each received content fragment is echoed to stdout in colour as it
    arrives.

    Args:
        messages: Chat messages; they are passed through ``_cons_kwargs``
            before the request is sent.

    Returns:
        The concatenation of all streamed content fragments.
    """
    openai.api_key = OPENAI_API_KEY
    openai.api_base = OPENAI_API_BASE
    response = await openai.ChatCompletion.acreate(
        **_cons_kwargs(messages), stream=True
    )

    reply_parts = []
    async for chunk in response:
        choices = chunk["choices"]
        if len(choices) == 0:
            continue
        delta = choices[0].get("delta", {})
        text = delta.get("content")
        # A delta may have no "content" key, or carry an explicit null;
        # neither contributes text, and None would break colored()/join().
        if text is None:
            continue
        reply_parts.append(text)
        print(colored(text, "blue", "on_white"), end="")
    print()
    return "".join(reply_parts)
def _chat_completion(messages: list[dict]) -> str:
    """Send a blocking chat-completion request and return the reply text.

    The request arguments come from ``_cons_kwargs``; the text is read from
    the first choice of the response.
    """
    request_kwargs = _cons_kwargs(messages)
    response = openai.ChatCompletion.create(**request_kwargs)
    first_choice = response["choices"][0]
    return first_choice["message"]["content"]
def _cons_kwargs(messages: list[dict]) -> dict:
    """Build the keyword arguments for an OpenAI-style chat completion call.

    Messages are normalised on copies, so the caller's list and dicts are
    left untouched:

    * a non-dict entry becomes a ``"user"`` message holding ``str(entry)``
      (an empty string for ``None``);
    * a missing or ``None`` role becomes ``"user"``; a missing or ``None``
      content becomes ``""``; other values are converted to ``str`` and
      stripped;
    * messages whose role or content ends up empty are dropped.

    Args:
        messages: Chat messages to send.

    Returns:
        A dict with ``messages``, ``max_tokens``, ``temperature`` and
        ``model``. Unless the configured ``MODEL`` name contains
        ``"deepseek"``, it also contains ``n``, ``stop`` and ``timeout``.
    """
    kwargs = {
        "messages": messages,
        "max_tokens": 4096,
        "temperature": 0.5,
    }
    print("[DEBUG] kwargs =", kwargs)
    # Debug output.
    print(f'[DEBUG] _cons_kwargs messages: {messages}', flush=True)

    # Detect and repair null values, working on copies of the messages.
    normalized = []
    for i, msg in enumerate(messages):
        # Make sure the entry is a dictionary.
        if not isinstance(msg, dict):
            print(f"[ERROR] Message {i} is not a dictionary: {msg}", flush=True)
            normalized.append(
                {"role": "user", "content": str(msg) if msg is not None else ""}
            )
            continue

        clean = dict(msg)
        # Make sure role and content exist and are not None.
        if clean.get("role") is None:
            print(f"[ERROR] Message {i} missing role, setting to 'user'", flush=True)
            clean["role"] = "user"
        else:
            clean["role"] = str(clean["role"]).strip()
        if clean.get("content") is None:
            print(f"[ERROR] Message {i} missing content, setting to empty string", flush=True)
            clean["content"] = ""
        else:
            clean["content"] = str(clean["content"]).strip()
        normalized.append(clean)

    # Adjust parameters for the API provider in use.
    if "deepseek" in MODEL.lower():
        # DeepSeek: n / stop / timeout are deliberately not sent because the
        # API may not support them.
        print("[DEBUG] DeepSeek API detected, adjusting parameters", flush=True)
    else:
        # OpenAI-compatible API.
        kwargs["n"] = 1
        kwargs["stop"] = None
        kwargs["timeout"] = 3
    kwargs["model"] = MODEL

    # Keep only messages that have a non-empty role and content.
    kwargs["messages"] = [
        msg for msg in normalized if msg["role"] and msg["content"]
    ]
    print(f"[DEBUG] Final kwargs for API call: {kwargs.keys()}", flush=True)
    return kwargs