isa-model 0.2.8__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- isa_model/inference/services/llm/openai_llm_service.py +116 -7
- {isa_model-0.2.8.dist-info → isa_model-0.3.0.dist-info}/METADATA +1 -1
- {isa_model-0.2.8.dist-info → isa_model-0.3.0.dist-info}/RECORD +5 -5
- {isa_model-0.2.8.dist-info → isa_model-0.3.0.dist-info}/WHEEL +0 -0
- {isa_model-0.2.8.dist-info → isa_model-0.3.0.dist-info}/top_level.txt +0 -0
@@ -65,17 +65,21 @@ class OpenAILLMService(BaseLLMService):
|
|
65
65
|
|
66
66
|
return bound_service
|
67
67
|
|
68
|
-
async def ainvoke(self, prompt: Union[str, List[Any], Any]) -> str:
|
68
|
+
async def ainvoke(self, prompt: Union[str, List[Any], Any]) -> Union[str, Any]:
|
69
69
|
"""Universal invocation method"""
|
70
70
|
if isinstance(prompt, str):
|
71
71
|
return await self.acompletion(prompt)
|
72
72
|
elif isinstance(prompt, list):
|
73
|
+
if not prompt:
|
74
|
+
raise ValueError("Empty message list provided")
|
75
|
+
|
73
76
|
# 检查是否是 LangGraph 消息对象
|
74
|
-
|
77
|
+
first_msg = prompt[0]
|
78
|
+
if hasattr(first_msg, 'content') and hasattr(first_msg, 'type'):
|
75
79
|
# 转换 LangGraph 消息对象为标准格式
|
76
80
|
converted_messages = []
|
77
81
|
for msg in prompt:
|
78
|
-
if hasattr(msg, 'type'):
|
82
|
+
if hasattr(msg, 'type') and hasattr(msg, 'content'):
|
79
83
|
# LangGraph 消息对象
|
80
84
|
msg_dict = {"content": msg.content}
|
81
85
|
|
@@ -106,14 +110,39 @@ class OpenAILLMService(BaseLLMService):
|
|
106
110
|
msg_dict["role"] = "user" # 默认为用户消息
|
107
111
|
|
108
112
|
converted_messages.append(msg_dict)
|
109
|
-
|
113
|
+
elif isinstance(msg, dict):
|
110
114
|
# 已经是字典格式
|
111
115
|
converted_messages.append(msg)
|
116
|
+
else:
|
117
|
+
# 处理其他类型(如字符串)
|
118
|
+
converted_messages.append({"role": "user", "content": str(msg)})
|
112
119
|
|
113
|
-
|
114
|
-
|
120
|
+
# 如果绑定了工具,返回 AIMessage 对象以兼容 LangGraph
|
121
|
+
if self._has_bound_tools():
|
122
|
+
return await self.achat_with_message_response(converted_messages)
|
123
|
+
else:
|
124
|
+
return await self.achat(converted_messages)
|
125
|
+
elif isinstance(first_msg, dict):
|
115
126
|
# 标准字典格式的消息
|
116
|
-
|
127
|
+
if self._has_bound_tools():
|
128
|
+
return await self.achat_with_message_response(prompt)
|
129
|
+
else:
|
130
|
+
return await self.achat(prompt)
|
131
|
+
else:
|
132
|
+
# 处理其他格式,如字符串列表
|
133
|
+
converted_messages = []
|
134
|
+
for msg in prompt:
|
135
|
+
if isinstance(msg, str):
|
136
|
+
converted_messages.append({"role": "user", "content": msg})
|
137
|
+
elif isinstance(msg, dict):
|
138
|
+
converted_messages.append(msg)
|
139
|
+
else:
|
140
|
+
converted_messages.append({"role": "user", "content": str(msg)})
|
141
|
+
|
142
|
+
if self._has_bound_tools():
|
143
|
+
return await self.achat_with_message_response(converted_messages)
|
144
|
+
else:
|
145
|
+
return await self.achat(converted_messages)
|
117
146
|
else:
|
118
147
|
raise ValueError("Prompt must be a string or a list of messages")
|
119
148
|
|
@@ -161,6 +190,86 @@ class OpenAILLMService(BaseLLMService):
|
|
161
190
|
logger.error(f"Error in chat completion: {e}")
|
162
191
|
raise
|
163
192
|
|
193
|
+
async def achat_with_message_response(self, messages: List[Dict[str, str]]) -> Any:
    """Run one chat completion and return the reply as a message object.

    The assistant reply is wrapped in an object exposing ``content`` and
    ``tool_calls`` so that callers working with message objects (the
    LangGraph integration in ``ainvoke``) can see which tools the model
    asked for instead of receiving plain text.

    Args:
        messages: Conversation in chat-completion format, i.e. a list of
            dicts such as ``{"role": "user", "content": "..."}``.

    Returns:
        A ``langchain_core.messages.AIMessage`` when ``langchain_core`` is
        importable. Otherwise a minimal stand-in object with ``content``
        (str), ``tool_calls`` (list of ``{"name", "args", "id"}`` dicts) and
        ``type`` (``"ai"``) attributes.

    Raises:
        Exception: Any error raised while building the request, calling the
            API or decoding tool-call arguments (``json.JSONDecodeError``)
            is logged and re-raised unchanged.
    """
    try:
        request = {
            "model": self.model_name,
            "messages": messages,
            "temperature": self.config.get("temperature", 0.7),
            "max_tokens": self.config.get("max_tokens", 1024),
        }

        # Only advertise tools when some are bound; the model then decides
        # on its own whether to call one.
        if self._has_bound_tools():
            request["tools"] = self._get_bound_tools()
            request["tool_choice"] = "auto"

        response = await self.client.chat.completions.create(**request)

        usage = response.usage
        if usage:
            self.last_token_usage = {
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens,
            }
            # Fold this call into the running totals.
            # NOTE(review): the request counter is bumped together with the
            # token totals, i.e. only when the API reported usage - confirm
            # this matches the sibling chat method.
            for key, value in self.last_token_usage.items():
                self.total_token_usage[key] += value
            self.total_token_usage["requests_count"] += 1

        message = response.choices[0].message
        # The API returns content=None when the reply consists of tool calls only.
        content = message.content or ""

        # Convert tool calls once, in the {"name", "args", "id"} shape used by
        # both return types. An empty arguments string is treated as "no
        # arguments" rather than a JSON error.
        tool_calls = [
            {
                "name": tc.function.name,
                "args": json.loads(tc.function.arguments or "{}"),
                "id": tc.id,
            }
            for tc in (message.tool_calls or [])
        ]

        # Keep the try body limited to the import so that only a missing
        # langchain_core triggers the fallback.
        try:
            from langchain_core.messages import AIMessage
        except ImportError:
            class SimpleMessage:
                """Minimal stand-in for AIMessage when langchain_core is absent."""

                # ainvoke() only treats objects that have both ``content`` and
                # ``type`` as messages; without ``type`` a reply passed back in
                # would be stringified via str(msg).
                type = "ai"

                def __init__(self, content, tool_calls=None):
                    self.content = content
                    self.tool_calls = tool_calls or []

            return SimpleMessage(content, tool_calls)

        if tool_calls:
            return AIMessage(content=content, tool_calls=tool_calls)
        return AIMessage(content=content)

    except Exception as e:
        logger.error(f"Error in chat completion with message response: {e}")
        raise
|
272
|
+
|
164
273
|
async def _handle_tool_calls(self, assistant_message, original_messages: List[Dict[str, str]]) -> str:
|
165
274
|
"""Handle tool calls from the assistant"""
|
166
275
|
# Add assistant message with tool calls to conversation
|
@@ -40,7 +40,7 @@ isa_model/inference/services/embedding/openai_embed_service.py,sha256=47DEQpj8HB
|
|
40
40
|
isa_model/inference/services/llm/__init__.py,sha256=C6t9w33j3Ap4oGcJal9-htifKe0rxwws_kC3F-_B_Ps,341
|
41
41
|
isa_model/inference/services/llm/base_llm_service.py,sha256=hf4egO9_s3rOQYwyhDS6O_8ECIAltkj4Ir89PTosraE,8381
|
42
42
|
isa_model/inference/services/llm/ollama_llm_service.py,sha256=EfLdoovyrChYBlGreQukpSZt5l6DkfXwjjmPPovmm70,12934
|
43
|
-
isa_model/inference/services/llm/openai_llm_service.py,sha256=
|
43
|
+
isa_model/inference/services/llm/openai_llm_service.py,sha256=XarEWzPg3DnITxrhkVtdR1RC0puklFIAUALgC61P8LM,19279
|
44
44
|
isa_model/inference/services/llm/triton_llm_service.py,sha256=ZFo7JoZ799Nvyi8Cz1jfWOa6TUn0hDRJtBrotadMAd4,17673
|
45
45
|
isa_model/inference/services/ml/base_ml_service.py,sha256=mLBA6ENowa3KVzNqHyhWxf_Pr-cJJj84lDE4TniPzYI,2894
|
46
46
|
isa_model/inference/services/ml/sklearn_ml_service.py,sha256=Lf9JrwvI25lca7JBbjB_e66eAUtXFbwxZ3Hs13dVGkA,5512
|
@@ -80,7 +80,7 @@ isa_model/training/core/config.py,sha256=oqgKpBvtzrN6jwLIQYQ2707lH6nmjrktRiSxp9i
|
|
80
80
|
isa_model/training/core/dataset.py,sha256=XCFsnf0NUMU1dJpdvo_CAMyvXB-9_RCUEiy8TU50e20,7802
|
81
81
|
isa_model/training/core/trainer.py,sha256=h5TjqjdFr0Fsv5y4-0siy1KmOlqLfliVaUXybvuoeXU,26932
|
82
82
|
isa_model/training/core/utils.py,sha256=Nik0M2ssfNbWqP6fKO0Kfyhzr_H6Q19ioxB-qCYbn5E,8387
|
83
|
-
isa_model-0.
|
84
|
-
isa_model-0.
|
85
|
-
isa_model-0.
|
86
|
-
isa_model-0.
|
83
|
+
isa_model-0.3.0.dist-info/METADATA,sha256=vKAOkCdWjst6VFeisv1QxEHUEzxJpMOd5FO-RMG_C6M,12226
|
84
|
+
isa_model-0.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
85
|
+
isa_model-0.3.0.dist-info/top_level.txt,sha256=eHSy_Xb3kNkh2kK11mi1mZh0Wz91AQ5b8k2KFYO-rE8,10
|
86
|
+
isa_model-0.3.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|