isa-model 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/core/model_manager.py +69 -4
- isa_model/inference/ai_factory.py +335 -46
- isa_model/inference/billing_tracker.py +406 -0
- isa_model/inference/providers/base_provider.py +51 -4
- isa_model/inference/providers/ollama_provider.py +37 -18
- isa_model/inference/providers/openai_provider.py +65 -36
- isa_model/inference/providers/replicate_provider.py +42 -30
- isa_model/inference/services/audio/base_stt_service.py +21 -2
- isa_model/inference/services/audio/openai_realtime_service.py +353 -0
- isa_model/inference/services/audio/openai_stt_service.py +252 -0
- isa_model/inference/services/audio/openai_tts_service.py +48 -9
- isa_model/inference/services/audio/replicate_tts_service.py +239 -0
- isa_model/inference/services/base_service.py +36 -1
- isa_model/inference/services/embedding/openai_embed_service.py +223 -0
- isa_model/inference/services/llm/base_llm_service.py +88 -192
- isa_model/inference/services/llm/llm_adapter.py +459 -0
- isa_model/inference/services/llm/ollama_llm_service.py +111 -185
- isa_model/inference/services/llm/openai_llm_service.py +115 -360
- isa_model/inference/services/vision/helpers/image_utils.py +4 -3
- isa_model/inference/services/vision/ollama_vision_service.py +11 -3
- isa_model/inference/services/vision/openai_vision_service.py +275 -41
- isa_model/inference/services/vision/replicate_image_gen_service.py +233 -205
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/METADATA +1 -1
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/RECORD +26 -21
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/WHEEL +0 -0
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/top_level.txt +0 -0
isa_model/inference/services/llm/openai_llm_service.py
@@ -3,415 +3,177 @@ import os
 import json
 from typing import Dict, Any, List, Union, AsyncGenerator, Optional, Callable
 
-# Use the official OpenAI client
+# Use the official OpenAI client library
 from openai import AsyncOpenAI
-from dotenv import load_dotenv
 
 from isa_model.inference.services.llm.base_llm_service import BaseLLMService
 from isa_model.inference.providers.base_provider import BaseProvider
-
-# Load environment variables from .env.local
-load_dotenv(dotenv_path='.env.local')
+from isa_model.inference.billing_tracker import ServiceType
 
 logger = logging.getLogger(__name__)
 
 class OpenAILLMService(BaseLLMService):
-    """OpenAI LLM service implementation"""
+    """OpenAI LLM service implementation with unified invoke interface"""
 
-    def __init__(self, provider: 'BaseProvider', model_name: str = "gpt-
+    def __init__(self, provider: 'BaseProvider', model_name: str = "gpt-4.1-nano"):
         super().__init__(provider, model_name)
 
-        #
+        # Get full configuration from provider (including sensitive data)
+        provider_config = provider.get_full_config()
+
+        # Initialize AsyncOpenAI client with provider configuration
         try:
-
-
+            if not provider_config.get("api_key"):
+                raise ValueError("OpenAI API key not found in provider configuration")
 
             self.client = AsyncOpenAI(
-                api_key=api_key,
-                base_url=base_url
+                api_key=provider_config["api_key"],
+                base_url=provider_config.get("base_url", "https://api.openai.com/v1"),
+                organization=provider_config.get("organization")
            )
-
-            logger.
-
+
+            logger.info(f"Initialized OpenAILLMService with model {self.model_name} and endpoint {self.client.base_url}")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize OpenAI client: {e}")
+            raise ValueError(f"Failed to initialize OpenAI client. Check your API key configuration: {e}") from e
 
         self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
         self.total_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "requests_count": 0}
 
-        # Tool binding attributes
-        self._bound_tools: List[Dict[str, Any]] = []
-        self._tool_binding_kwargs: Dict[str, Any] = {}
-        self._tool_functions: Dict[str, Callable] = {}
-
-        logger.info(f"Initialized OpenAILLMService with model {self.model_name} and endpoint {self.client.base_url}")
 
     def _create_bound_copy(self) -> 'OpenAILLMService':
         """Create a copy of this service for tool binding"""
         bound_service = OpenAILLMService(self.provider, self.model_name)
         bound_service._bound_tools = self._bound_tools.copy()
-        bound_service._tool_binding_kwargs = self._tool_binding_kwargs.copy()
-        bound_service._tool_functions = self._tool_functions.copy()
         return bound_service
 
-    def bind_tools(self, tools: List[
-        """
+    def bind_tools(self, tools: List[Any], **kwargs) -> 'OpenAILLMService':
+        """
+        Bind tools to this LLM service for function calling
+
+        Args:
+            tools: List of tools (functions, dicts, or LangChain tools)
+            **kwargs: Additional arguments for tool binding
+
+        Returns:
+            New LLM service instance with tools bound
+        """
+        # Create a copy of this service
         bound_service = self._create_bound_copy()
-        bound_service._bound_tools = self._convert_tools_to_schema(tools)
-        bound_service._tool_binding_kwargs = kwargs
 
-        #
-
-        if callable(tool):
-            bound_service._tool_functions[tool.__name__] = tool
+        # Use base class method to bind tools
+        bound_service._bound_tools = tools
 
         return bound_service
 
-    async def ainvoke(self,
-        """
-        if isinstance(prompt, str):
-            return await self.acompletion(prompt)
-        elif isinstance(prompt, list):
-            if not prompt:
-                raise ValueError("Empty message list provided")
-
-            # Check whether these are LangGraph message objects
-            first_msg = prompt[0]
-            if hasattr(first_msg, 'content') and hasattr(first_msg, 'type'):
-                # Convert LangGraph message objects to the standard format
-                converted_messages = []
-                for msg in prompt:
-                    if hasattr(msg, 'type') and hasattr(msg, 'content'):
-                        # LangGraph message object
-                        msg_dict = {"content": msg.content}
-
-                        # Set the role based on the message type
-                        if msg.type == "system":
-                            msg_dict["role"] = "system"
-                        elif msg.type == "human":
-                            msg_dict["role"] = "user"
-                        elif msg.type == "ai":
-                            msg_dict["role"] = "assistant"
-                            # Handle tool calls
-                            if hasattr(msg, 'tool_calls') and msg.tool_calls:
-                                msg_dict["tool_calls"] = [
-                                    {
-                                        "id": tc.get("id", f"call_{i}"),
-                                        "type": "function",
-                                        "function": {
-                                            "name": tc["name"],
-                                            "arguments": json.dumps(tc["args"])
-                                        }
-                                    } for i, tc in enumerate(msg.tool_calls)
-                                ]
-                        elif msg.type == "tool":
-                            msg_dict["role"] = "tool"
-                            if hasattr(msg, 'tool_call_id'):
-                                msg_dict["tool_call_id"] = msg.tool_call_id
-                        else:
-                            msg_dict["role"] = "user"  # Default to a user message
-
-                        converted_messages.append(msg_dict)
-                    elif isinstance(msg, dict):
-                        # Already in dict format
-                        converted_messages.append(msg)
-                    else:
-                        # Handle other types (e.g. strings)
-                        converted_messages.append({"role": "user", "content": str(msg)})
-
-                # If tools are bound, return an AIMessage object for LangGraph compatibility
-                if self._has_bound_tools():
-                    return await self.achat_with_message_response(converted_messages)
-                else:
-                    return await self.achat(converted_messages)
-            elif isinstance(first_msg, dict):
-                # Messages already in the standard dict format
-                if self._has_bound_tools():
-                    return await self.achat_with_message_response(prompt)
-                else:
-                    return await self.achat(prompt)
-            else:
-                # Handle other formats, e.g. lists of strings
-                converted_messages = []
-                for msg in prompt:
-                    if isinstance(msg, str):
-                        converted_messages.append({"role": "user", "content": msg})
-                    elif isinstance(msg, dict):
-                        converted_messages.append(msg)
-                    else:
-                        converted_messages.append({"role": "user", "content": str(msg)})
-
-                if self._has_bound_tools():
-                    return await self.achat_with_message_response(converted_messages)
-                else:
-                    return await self.achat(converted_messages)
-        else:
-            raise ValueError("Prompt must be a string or a list of messages")
-
-    async def achat(self, messages: List[Dict[str, str]]) -> str:
-        """Chat completion method"""
+    async def ainvoke(self, input_data: Union[str, List[Dict[str, str]], Any]) -> Union[str, Any]:
+        """Unified invoke method for all input types"""
         try:
-
-
+            # Use adapter manager to prepare messages
+            messages = self._prepare_messages(input_data)
 
+            # Prepare request kwargs
             kwargs = {
                 "model": self.model_name,
                 "messages": messages,
-                "temperature": temperature,
-                "max_tokens": max_tokens
+                "temperature": self.config.get("temperature", 0.7),
+                "max_tokens": self.config.get("max_tokens", 1024)
             }
 
-            # Add tools if bound
-
-
+            # Add tools if bound using adapter manager
+            tool_schemas = await self._prepare_tools_for_request()
+            if tool_schemas:
+                kwargs["tools"] = tool_schemas
                 kwargs["tool_choice"] = "auto"
 
-
-
-
-
-
-
-
-            }
+            # Handle streaming vs non-streaming
+            if self.streaming:
+                # Streaming mode - collect all chunks
+                content_chunks = []
+                async for chunk in await self._stream_response(kwargs):
+                    content_chunks.append(chunk)
+                content = "".join(content_chunks)
 
-            #
-
-
-
-
-
-            # Handle tool calls if present
-            message = response.choices[0].message
-            if message.tool_calls:
-                return await self._handle_tool_calls(message, messages)
-
-            return message.content or ""
-
-        except Exception as e:
-            logger.error(f"Error in chat completion: {e}")
-            raise
-
-    async def achat_with_message_response(self, messages: List[Dict[str, str]]) -> Any:
-        """Chat completion method that returns message object for LangGraph compatibility"""
-        try:
-            temperature = self.config.get("temperature", 0.7)
-            max_tokens = self.config.get("max_tokens", 1024)
-
-            kwargs = {
-                "model": self.model_name,
-                "messages": messages,
-                "temperature": temperature,
-                "max_tokens": max_tokens
-            }
-
-            # Add tools if bound
-            if self._has_bound_tools():
-                kwargs["tools"] = self._get_bound_tools()
-                kwargs["tool_choice"] = "auto"
-
-            response = await self.client.chat.completions.create(**kwargs)
-
-            if response.usage:
-                self.last_token_usage = {
-                    "prompt_tokens": response.usage.prompt_tokens,
-                    "completion_tokens": response.usage.completion_tokens,
-                    "total_tokens": response.usage.total_tokens
-                }
+                # Create a mock usage object for tracking
+                class MockUsage:
+                    def __init__(self):
+                        self.prompt_tokens = len(str(messages)) // 4  # Rough estimate
+                        self.completion_tokens = len(content) // 4  # Rough estimate
+                        self.total_tokens = self.prompt_tokens + self.completion_tokens
 
-
-                self.
-                self.
-                self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
-                self.total_token_usage["requests_count"] += 1
-
-            # Create a LangGraph-compatible message object
-            message = response.choices[0].message
-
-            # Try to import LangGraph message classes
-            try:
-                from langchain_core.messages import AIMessage
+                usage = MockUsage()
+                self._update_token_usage(usage)
+                self._track_billing(usage)
 
-
-
-
-
-
-                        "name": tc.function.name,
-                        "args": json.loads(tc.function.arguments),
-                        "id": tc.id
-                    })
-
-                    return AIMessage(
-                        content=message.content or "",
-                        tool_calls=tool_calls
-                    )
-                else:
-                    return AIMessage(content=message.content or "")
-
-            except ImportError:
-                # Fallback: create a simple object with content attribute
-                class SimpleMessage:
-                    def __init__(self, content, tool_calls=None):
-                        self.content = content
-                        self.tool_calls = tool_calls or []
+                return self._format_response(content, input_data)
+            else:
+                # Non-streaming mode
+                response = await self.client.chat.completions.create(**kwargs)
+                message = response.choices[0].message
 
+                # Update usage tracking
+                if response.usage:
+                    self._update_token_usage(response.usage)
+                    self._track_billing(response.usage)
+
+                # Handle tool calls if present - let adapter process the complete message
                 if message.tool_calls:
-                    tool_calls
-
-
-
-
-                        "id": tc.id
-                    })
-                    return SimpleMessage(message.content or "", tool_calls)
-                else:
-                    return SimpleMessage(message.content or "")
+                    # Pass the complete message object to adapter for proper tool_calls handling
+                    return self._format_response(message, input_data)
+
+                # Return appropriate format based on input type
+                return self._format_response(message.content or "", input_data)
 
         except Exception as e:
-            logger.error(f"Error in
+            logger.error(f"Error in ainvoke: {e}")
             raise
 
-
-
-
-
-            "role": "assistant",
-            "content": assistant_message.content or "",
-            "tool_calls": [
-                {
-                    "id": tc.id,
-                    "type": tc.type,
-                    "function": {
-                        "name": tc.function.name,
-                        "arguments": tc.function.arguments
-                    }
-                } for tc in assistant_message.tool_calls
-            ]
-        }]
+
+    async def _stream_response(self, kwargs: Dict[str, Any]) -> AsyncGenerator[str, None]:
+        """Handle streaming responses"""
+        kwargs["stream"] = True
 
-
-        for tool_call in assistant_message.tool_calls:
-            function_name = tool_call.function.name
-            arguments = json.loads(tool_call.function.arguments)
-
+        async def stream_generator():
             try:
-
-
-
-                if
-
-                else:
-                    result = f"Error: Function {function_name} not found"
-
-                # Add tool result to messages
-                messages.append({
-                    "role": "tool",
-                    "content": str(result),
-                    "tool_call_id": tool_call.id
-                })
-
+                stream = await self.client.chat.completions.create(**kwargs)
+                async for chunk in stream:
+                    content = chunk.choices[0].delta.content
+                    if content:
+                        yield content
             except Exception as e:
-                logger.error(f"Error
-
-                    "role": "tool",
-                    "content": f"Error executing {function_name}: {str(e)}",
-                    "tool_call_id": tool_call.id
-                })
+                logger.error(f"Error in streaming: {e}")
+                raise
 
-
-        try:
-            kwargs = {
-                "model": self.model_name,
-                "messages": messages,
-                "temperature": self.config.get("temperature", 0.7),
-                "max_tokens": self.config.get("max_tokens", 1024)
-            }
-
-            response = await self.client.chat.completions.create(**kwargs)
-            return response.choices[0].message.content or ""
-
-        except Exception as e:
-            logger.error(f"Error getting final response after tool calls: {e}")
-            raise
-
-    async def acompletion(self, prompt: str) -> str:
-        """Text completion method (using chat API)"""
-        messages = [{"role": "user", "content": prompt}]
-        return await self.achat(messages)
+        return stream_generator()
 
-    async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
-        """Generate multiple completions"""
-        try:
-            temperature = self.config.get("temperature", 0.7)
-            max_tokens = self.config.get("max_tokens", 1024)
-
-            kwargs = {
-                "model": self.model_name,
-                "messages": messages,
-                "temperature": temperature,
-                "max_tokens": max_tokens,
-                "n": n
-            }
-
-            # Add tools if bound
-            if self._has_bound_tools():
-                kwargs["tools"] = self._get_bound_tools()
-                kwargs["tool_choice"] = "auto"
-
-            response = await self.client.chat.completions.create(**kwargs)
-
-            if response.usage:
-                self.last_token_usage = {
-                    "prompt_tokens": response.usage.prompt_tokens,
-                    "completion_tokens": response.usage.completion_tokens,
-                    "total_tokens": response.usage.total_tokens
-                }
-
-            # Update total usage
-            self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
-            self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
-            self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
-            self.total_token_usage["requests_count"] += 1
-
-            return [choice.message.content or "" for choice in response.choices]
-        except Exception as e:
-            logger.error(f"Error in generate: {e}")
-            raise
 
-
-        """
-
-
-
-
-
-
-
-
-
-
+    def _update_token_usage(self, usage):
+        """Update token usage statistics"""
+        self.last_token_usage = {
+            "prompt_tokens": usage.prompt_tokens,
+            "completion_tokens": usage.completion_tokens,
+            "total_tokens": usage.total_tokens
+        }
+
+        # Update total usage
+        self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
+        self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
+        self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
+        self.total_token_usage["requests_count"] += 1
+
+    def _track_billing(self, usage):
+        """Track billing information"""
+        self._track_usage(
+            service_type=ServiceType.LLM,
+            operation="chat",
+            input_tokens=usage.prompt_tokens,
+            output_tokens=usage.completion_tokens,
+            metadata={
+                "temperature": self.config.get("temperature", 0.7),
+                "max_tokens": self.config.get("max_tokens", 1024)
             }
-
-        # Add tools if bound
-        if self._has_bound_tools():
-            kwargs["tools"] = self._get_bound_tools()
-            kwargs["tool_choice"] = "auto"
-
-        stream = await self.client.chat.completions.create(**kwargs)
-
-        async for chunk in stream:
-            content = chunk.choices[0].delta.content
-            if content:
-                yield content
-
-        except Exception as e:
-            logger.error(f"Error in stream chat: {e}")
-            raise
-
-    async def astream_completion(self, prompt: str) -> AsyncGenerator[str, None]:
-        """Stream completion responses"""
-        messages = [{"role": "user", "content": prompt}]
-        async for chunk in self.astream_chat(messages):
-            yield chunk
+        )
 
     def get_token_usage(self) -> Dict[str, Any]:
         """Get total token usage statistics"""
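The net effect of this hunk: the old `ainvoke`/`achat`/`achat_with_message_response`/`acompletion`/`agenerate` family collapses into a single `ainvoke` that accepts a plain string, an OpenAI-style message list, or LangGraph-style message objects, with conversion delegated to base-class helpers (`_prepare_messages`, `_prepare_tools_for_request`, `_format_response`). A minimal caller-side sketch, assuming an already-configured `provider` object (in 0.3.2 services are normally obtained through `ai_factory`, whose details are outside this hunk):

```python
import asyncio

async def main():
    # `provider` is assumed to be a configured OpenAI provider instance.
    service = OpenAILLMService(provider, model_name="gpt-4.1-nano")

    # A bare string and an OpenAI-style message list now take the same path.
    one_liner = await service.ainvoke("Summarize this diff in one sentence.")
    reply = await service.ainvoke([
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "What does _track_billing record?"},
    ])
    print(one_liner, reply)
    await service.close()

asyncio.run(main())
```

Note the `MockUsage` shim in the streaming branch: streamed chat completions do not return a `usage` payload, so the service falls back to a rough four-characters-per-token estimate, which means billed token counts for streamed calls are approximate.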
@@ -431,13 +193,6 @@ class OpenAILLMService(BaseLLMService):
             "provider": "openai"
         }
 
-    def _has_bound_tools(self) -> bool:
-        """Check if this service has bound tools"""
-        return bool(self._bound_tools)
-
-    def _get_bound_tools(self) -> List[Dict[str, Any]]:
-        """Get the bound tools schema"""
-        return self._bound_tools
 
     async def close(self):
         """Close the backend client"""
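With these helpers deleted, tool state lives entirely in `_bound_tools` and schema conversion happens in the shared base class at request time. A hedged sketch of the new binding flow (the tool function and prompt are illustrative, not from the package):

```python
# Hypothetical tool; bind_tools accepts plain callables, dicts, or
# LangChain tools and returns a new bound copy of the service.
def get_weather(city: str) -> str:
    """Toy tool: return a canned forecast for `city`."""
    return f"Sunny in {city}"

async def ask_weather(service: OpenAILLMService) -> str:
    bound = service.bind_tools([get_weather])
    # The base-class adapter converts the callable to an OpenAI tool
    # schema via _prepare_tools_for_request() before the API call.
    return await bound.ainvoke("What's the weather in Paris?")
```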
isa_model/inference/services/vision/helpers/image_utils.py
@@ -2,9 +2,10 @@ from io import BytesIO
 from PIL import Image
 from typing import Union
 import base64
-from app.config.config_manager import config_manager
+# from app.config.config_manager import config_manager  # Commented out to fix import
+import logging
 
-logger =
+logger = logging.getLogger(__name__)
 
 def compress_image(image_data: Union[bytes, BytesIO], max_size: int = 1024) -> bytes:
     """Compress an image to reduce its size
@@ -30,7 +31,7 @@ def compress_image(image_data: Union[bytes, BytesIO], max_size: int = 1024) -> bytes:
     # Compute the new size, preserving the aspect ratio
     ratio = max_size / max(img.size)
     if ratio < 1:
-        new_size =
+        new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio))
         img = img.resize(new_size, Image.Resampling.LANCZOS)
 
     # Save the compressed image
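This hunk restores the aspect-ratio-preserving downscale inside `compress_image`: the longest edge is capped at `max_size` and both dimensions are scaled by the same ratio. A standalone sketch of the same logic; the final `convert`/`save` step is an assumption, since the hunk ends before the save call:

```python
from io import BytesIO
from PIL import Image

def shrink_to_max(image_bytes: bytes, max_size: int = 1024) -> bytes:
    """Downscale so the longest edge is at most max_size, keeping aspect ratio."""
    img = Image.open(BytesIO(image_bytes))
    ratio = max_size / max(img.size)
    if ratio < 1:  # only shrink, never enlarge
        new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio))
        img = img.resize(new_size, Image.Resampling.LANCZOS)
    buf = BytesIO()
    img.convert("RGB").save(buf, format="JPEG", quality=85)  # assumed save step
    return buf.getvalue()
```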
isa_model/inference/services/vision/ollama_vision_service.py
@@ -7,6 +7,7 @@ from tenacity import retry, stop_after_attempt, wait_exponential
 from isa_model.inference.services.vision.base_vision_service import BaseVisionService
 from isa_model.inference.providers.base_provider import BaseProvider
 import logging
+import requests
 
 logger = logging.getLogger(__name__)
 
@@ -19,10 +20,17 @@ class OllamaVisionService(BaseVisionService):
         self.temperature = self.config.get('temperature', 0.7)
 
     def _get_image_data(self, image: Union[str, BinaryIO]) -> bytes:
-        """
+        """Get image data, supporting both local files and URLs"""
         if isinstance(image, str):
-
-
+            # Check if it's a URL
+            if image.startswith(('http://', 'https://')):
+                response = requests.get(image)
+                response.raise_for_status()
+                return response.content
+            else:
+                # Local file path
+                with open(image, 'rb') as f:
+                    return f.read()
         else:
             return image.read()
 
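The new `_get_image_data` dispatches on the string prefix: http(s) strings are fetched with `requests`, anything else is treated as a local path. A standalone sketch of the same dispatch; the `timeout` is an addition not present in the diff (the service calls `requests.get(image)` unbounded, and because the call is synchronous it will also block the event loop of an async caller):

```python
import requests

def load_image_bytes(image: str) -> bytes:
    """Fetch bytes from an http(s) URL or read them from a local file path."""
    if image.startswith(("http://", "https://")):
        response = requests.get(image, timeout=30)  # timeout added here
        response.raise_for_status()
        return response.content
    with open(image, "rb") as f:
        return f.read()
```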