isa-model 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. isa_model/core/model_manager.py +69 -4
  2. isa_model/inference/ai_factory.py +335 -46
  3. isa_model/inference/billing_tracker.py +406 -0
  4. isa_model/inference/providers/base_provider.py +51 -4
  5. isa_model/inference/providers/ollama_provider.py +37 -18
  6. isa_model/inference/providers/openai_provider.py +65 -36
  7. isa_model/inference/providers/replicate_provider.py +42 -30
  8. isa_model/inference/services/audio/base_stt_service.py +21 -2
  9. isa_model/inference/services/audio/openai_realtime_service.py +353 -0
  10. isa_model/inference/services/audio/openai_stt_service.py +252 -0
  11. isa_model/inference/services/audio/openai_tts_service.py +48 -9
  12. isa_model/inference/services/audio/replicate_tts_service.py +239 -0
  13. isa_model/inference/services/base_service.py +36 -1
  14. isa_model/inference/services/embedding/openai_embed_service.py +223 -0
  15. isa_model/inference/services/llm/base_llm_service.py +88 -192
  16. isa_model/inference/services/llm/llm_adapter.py +459 -0
  17. isa_model/inference/services/llm/ollama_llm_service.py +111 -185
  18. isa_model/inference/services/llm/openai_llm_service.py +115 -360
  19. isa_model/inference/services/vision/helpers/image_utils.py +4 -3
  20. isa_model/inference/services/vision/ollama_vision_service.py +11 -3
  21. isa_model/inference/services/vision/openai_vision_service.py +275 -41
  22. isa_model/inference/services/vision/replicate_image_gen_service.py +233 -205
  23. {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/METADATA +1 -1
  24. {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/RECORD +26 -21
  25. {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/WHEEL +0 -0
  26. {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/top_level.txt +0 -0
isa_model/inference/services/llm/openai_llm_service.py

@@ -3,415 +3,177 @@ import os
  import json
  from typing import Dict, Any, List, Union, AsyncGenerator, Optional, Callable
 
- # Use the official OpenAI library and dotenv
+ # Use the official OpenAI library
  from openai import AsyncOpenAI
- from dotenv import load_dotenv
 
  from isa_model.inference.services.llm.base_llm_service import BaseLLMService
  from isa_model.inference.providers.base_provider import BaseProvider
-
- # Load environment variables from the .env.local file
- load_dotenv(dotenv_path='.env.local')
+ from isa_model.inference.billing_tracker import ServiceType
 
  logger = logging.getLogger(__name__)
 
  class OpenAILLMService(BaseLLMService):
-     """OpenAI LLM service implementation"""
+     """OpenAI LLM service implementation with unified invoke interface"""
 
-     def __init__(self, provider: 'BaseProvider', model_name: str = "gpt-3.5-turbo"):
+     def __init__(self, provider: 'BaseProvider', model_name: str = "gpt-4.1-nano"):
          super().__init__(provider, model_name)
 
-         # Initialize the AsyncOpenAI client from the provider config
+         # Get full configuration from provider (including sensitive data)
+         provider_config = provider.get_full_config()
+
+         # Initialize AsyncOpenAI client with provider configuration
          try:
-             api_key = provider.config.get("api_key") or os.getenv("OPENAI_API_KEY")
-             base_url = provider.config.get("api_base") or os.getenv("OPENAI_API_BASE")
+             if not provider_config.get("api_key"):
+                 raise ValueError("OpenAI API key not found in provider configuration")
 
              self.client = AsyncOpenAI(
-                 api_key=api_key,
-                 base_url=base_url
+                 api_key=provider_config["api_key"],
+                 base_url=provider_config.get("base_url", "https://api.openai.com/v1"),
+                 organization=provider_config.get("organization")
              )
-         except TypeError as e:
-             logger.error("Failed to initialize the OpenAI client. Check that OPENAI_API_KEY is set correctly in your .env.local file.")
-             raise ValueError("OPENAI_API_KEY is not set.") from e
+
+             logger.info(f"Initialized OpenAILLMService with model {self.model_name} and endpoint {self.client.base_url}")
+
+         except Exception as e:
+             logger.error(f"Failed to initialize OpenAI client: {e}")
+             raise ValueError(f"Failed to initialize OpenAI client. Check your API key configuration: {e}") from e
 
          self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
          self.total_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "requests_count": 0}
 
-         # Tool binding attributes
-         self._bound_tools: List[Dict[str, Any]] = []
-         self._tool_binding_kwargs: Dict[str, Any] = {}
-         self._tool_functions: Dict[str, Callable] = {}
-
-         logger.info(f"Initialized OpenAILLMService with model {self.model_name} and endpoint {self.client.base_url}")
 
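The rewritten constructor sources credentials from the provider's get_full_config() instead of .env.local, and fails fast with a clear error when no API key is configured. A minimal construction sketch; the OpenAIProvider constructor arguments are an assumption for illustration, since only get_full_config() is visible in this diff:

    # Hypothetical wiring; the provider's constructor signature is assumed.
    from isa_model.inference.providers.openai_provider import OpenAIProvider
    from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService

    provider = OpenAIProvider(config={"api_key": "sk-..."})  # assumed kwargs
    service = OpenAILLMService(provider, model_name="gpt-4.1-nano")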
      def _create_bound_copy(self) -> 'OpenAILLMService':
          """Create a copy of this service for tool binding"""
          bound_service = OpenAILLMService(self.provider, self.model_name)
          bound_service._bound_tools = self._bound_tools.copy()
-         bound_service._tool_binding_kwargs = self._tool_binding_kwargs.copy()
-         bound_service._tool_functions = self._tool_functions.copy()
          return bound_service
 
-     def bind_tools(self, tools: List[Union[Dict[str, Any], Callable]], **kwargs) -> 'OpenAILLMService':
-         """Bind tools to this LLM service for function calling"""
+     def bind_tools(self, tools: List[Any], **kwargs) -> 'OpenAILLMService':
+         """
+         Bind tools to this LLM service for function calling
+
+         Args:
+             tools: List of tools (functions, dicts, or LangChain tools)
+             **kwargs: Additional arguments for tool binding
+
+         Returns:
+             New LLM service instance with tools bound
+         """
+         # Create a copy of this service
          bound_service = self._create_bound_copy()
-         bound_service._bound_tools = self._convert_tools_to_schema(tools)
-         bound_service._tool_binding_kwargs = kwargs
 
-         # Store the actual functions for execution
-         for tool in tools:
-             if callable(tool):
-                 bound_service._tool_functions[tool.__name__] = tool
+         # Use base class method to bind tools
+         bound_service._bound_tools = tools
 
          return bound_service
 
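bind_tools no longer converts schemas or captures callables itself; it stores the tools as given and leaves schema conversion to the base class at request time. A hedged usage sketch, reusing the service from the construction sketch above (the tool function is invented for illustration):

    # Illustrative only: get_weather is a hypothetical tool function.
    def get_weather(city: str) -> str:
        """Return a canned weather report for a city."""
        return f"Sunny in {city}"

    # bind_tools returns a new service instance; the original is unchanged.
    bound_service = service.bind_tools([get_weather])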
-     async def ainvoke(self, prompt: Union[str, List[Any], Any]) -> Union[str, Any]:
-         """Universal invocation method"""
-         if isinstance(prompt, str):
-             return await self.acompletion(prompt)
-         elif isinstance(prompt, list):
-             if not prompt:
-                 raise ValueError("Empty message list provided")
-
-             # Check whether these are LangGraph message objects
-             first_msg = prompt[0]
-             if hasattr(first_msg, 'content') and hasattr(first_msg, 'type'):
-                 # Convert LangGraph message objects to the standard format
-                 converted_messages = []
-                 for msg in prompt:
-                     if hasattr(msg, 'type') and hasattr(msg, 'content'):
-                         # LangGraph message object
-                         msg_dict = {"content": msg.content}
-
-                         # Set the role based on the message type
-                         if msg.type == "system":
-                             msg_dict["role"] = "system"
-                         elif msg.type == "human":
-                             msg_dict["role"] = "user"
-                         elif msg.type == "ai":
-                             msg_dict["role"] = "assistant"
-                             # Handle tool calls
-                             if hasattr(msg, 'tool_calls') and msg.tool_calls:
-                                 msg_dict["tool_calls"] = [
-                                     {
-                                         "id": tc.get("id", f"call_{i}"),
-                                         "type": "function",
-                                         "function": {
-                                             "name": tc["name"],
-                                             "arguments": json.dumps(tc["args"])
-                                         }
-                                     } for i, tc in enumerate(msg.tool_calls)
-                                 ]
-                         elif msg.type == "tool":
-                             msg_dict["role"] = "tool"
-                             if hasattr(msg, 'tool_call_id'):
-                                 msg_dict["tool_call_id"] = msg.tool_call_id
-                         else:
-                             msg_dict["role"] = "user"  # Default to a user message
-
-                         converted_messages.append(msg_dict)
-                     elif isinstance(msg, dict):
-                         # Already in dict format
-                         converted_messages.append(msg)
-                     else:
-                         # Handle other types (e.g. strings)
-                         converted_messages.append({"role": "user", "content": str(msg)})
-
-                 # If tools are bound, return an AIMessage object for LangGraph compatibility
-                 if self._has_bound_tools():
-                     return await self.achat_with_message_response(converted_messages)
-                 else:
-                     return await self.achat(converted_messages)
-             elif isinstance(first_msg, dict):
-                 # Messages already in standard dict format
-                 if self._has_bound_tools():
-                     return await self.achat_with_message_response(prompt)
-                 else:
-                     return await self.achat(prompt)
-             else:
-                 # Handle other formats, e.g. lists of strings
-                 converted_messages = []
-                 for msg in prompt:
-                     if isinstance(msg, str):
-                         converted_messages.append({"role": "user", "content": msg})
-                     elif isinstance(msg, dict):
-                         converted_messages.append(msg)
-                     else:
-                         converted_messages.append({"role": "user", "content": str(msg)})
-
-                 if self._has_bound_tools():
-                     return await self.achat_with_message_response(converted_messages)
-                 else:
-                     return await self.achat(converted_messages)
-         else:
-             raise ValueError("Prompt must be a string or a list of messages")
-
-     async def achat(self, messages: List[Dict[str, str]]) -> str:
-         """Chat completion method"""
+     async def ainvoke(self, input_data: Union[str, List[Dict[str, str]], Any]) -> Union[str, Any]:
+         """Unified invoke method for all input types"""
          try:
-             temperature = self.config.get("temperature", 0.7)
-             max_tokens = self.config.get("max_tokens", 1024)
+             # Use adapter manager to prepare messages
+             messages = self._prepare_messages(input_data)
 
+             # Prepare request kwargs
              kwargs = {
                  "model": self.model_name,
                  "messages": messages,
-                 "temperature": temperature,
-                 "max_tokens": max_tokens
+                 "temperature": self.config.get("temperature", 0.7),
+                 "max_tokens": self.config.get("max_tokens", 1024)
              }
 
-             # Add tools if bound
-             if self._has_bound_tools():
-                 kwargs["tools"] = self._get_bound_tools()
+             # Add tools if bound using adapter manager
+             tool_schemas = await self._prepare_tools_for_request()
+             if tool_schemas:
+                 kwargs["tools"] = tool_schemas
                  kwargs["tool_choice"] = "auto"
 
-             response = await self.client.chat.completions.create(**kwargs)
-
-             if response.usage:
-                 self.last_token_usage = {
-                     "prompt_tokens": response.usage.prompt_tokens,
-                     "completion_tokens": response.usage.completion_tokens,
-                     "total_tokens": response.usage.total_tokens
-                 }
+             # Handle streaming vs non-streaming
+             if self.streaming:
+                 # Streaming mode - collect all chunks
+                 content_chunks = []
+                 async for chunk in await self._stream_response(kwargs):
+                     content_chunks.append(chunk)
+                 content = "".join(content_chunks)
 
-             # Update total usage
-             self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
-             self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
-             self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
-             self.total_token_usage["requests_count"] += 1
-
-             # Handle tool calls if present
-             message = response.choices[0].message
-             if message.tool_calls:
-                 return await self._handle_tool_calls(message, messages)
-
-             return message.content or ""
-
-         except Exception as e:
-             logger.error(f"Error in chat completion: {e}")
-             raise
-
-     async def achat_with_message_response(self, messages: List[Dict[str, str]]) -> Any:
-         """Chat completion method that returns message object for LangGraph compatibility"""
-         try:
-             temperature = self.config.get("temperature", 0.7)
-             max_tokens = self.config.get("max_tokens", 1024)
-
-             kwargs = {
-                 "model": self.model_name,
-                 "messages": messages,
-                 "temperature": temperature,
-                 "max_tokens": max_tokens
-             }
-
-             # Add tools if bound
-             if self._has_bound_tools():
-                 kwargs["tools"] = self._get_bound_tools()
-                 kwargs["tool_choice"] = "auto"
-
-             response = await self.client.chat.completions.create(**kwargs)
-
-             if response.usage:
-                 self.last_token_usage = {
-                     "prompt_tokens": response.usage.prompt_tokens,
-                     "completion_tokens": response.usage.completion_tokens,
-                     "total_tokens": response.usage.total_tokens
-                 }
+                 # Create a mock usage object for tracking
+                 class MockUsage:
+                     def __init__(self):
+                         self.prompt_tokens = len(str(messages)) // 4  # Rough estimate
+                         self.completion_tokens = len(content) // 4  # Rough estimate
+                         self.total_tokens = self.prompt_tokens + self.completion_tokens
 
-             # Update total usage
-             self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
-             self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
-             self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
-             self.total_token_usage["requests_count"] += 1
-
-             # Create a LangGraph-compatible message object
-             message = response.choices[0].message
-
-             # Try to import LangGraph message classes
-             try:
-                 from langchain_core.messages import AIMessage
+                 usage = MockUsage()
+                 self._update_token_usage(usage)
+                 self._track_billing(usage)
 
-                 # Create AIMessage with tool calls if present
-                 if message.tool_calls:
-                     tool_calls = []
-                     for tc in message.tool_calls:
-                         tool_calls.append({
-                             "name": tc.function.name,
-                             "args": json.loads(tc.function.arguments),
-                             "id": tc.id
-                         })
-
-                     return AIMessage(
-                         content=message.content or "",
-                         tool_calls=tool_calls
-                     )
-                 else:
-                     return AIMessage(content=message.content or "")
-
-             except ImportError:
-                 # Fallback: create a simple object with content attribute
-                 class SimpleMessage:
-                     def __init__(self, content, tool_calls=None):
-                         self.content = content
-                         self.tool_calls = tool_calls or []
+                 return self._format_response(content, input_data)
+             else:
+                 # Non-streaming mode
+                 response = await self.client.chat.completions.create(**kwargs)
+                 message = response.choices[0].message
 
+                 # Update usage tracking
+                 if response.usage:
+                     self._update_token_usage(response.usage)
+                     self._track_billing(response.usage)
+
+                 # Handle tool calls if present - let adapter process the complete message
                  if message.tool_calls:
-                     tool_calls = []
-                     for tc in message.tool_calls:
-                         tool_calls.append({
-                             "name": tc.function.name,
-                             "args": json.loads(tc.function.arguments),
-                             "id": tc.id
-                         })
-                     return SimpleMessage(message.content or "", tool_calls)
-                 else:
-                     return SimpleMessage(message.content or "")
+                     # Pass the complete message object to adapter for proper tool_calls handling
+                     return self._format_response(message, input_data)
+
+                 # Return appropriate format based on input type
+                 return self._format_response(message.content or "", input_data)
 
          except Exception as e:
-             logger.error(f"Error in chat completion with message response: {e}")
+             logger.error(f"Error in ainvoke: {e}")
              raise
 
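With message preparation delegated to _prepare_messages and response shaping to _format_response, ainvoke becomes the single entry point for strings, OpenAI-style dict messages, and LangChain/LangGraph message objects. Note that the streaming path never sees response.usage, so MockUsage falls back to a rough four-characters-per-token estimate. A hedged usage sketch:

    # Assumes `service` was constructed as in the earlier sketch.
    async def demo(service):
        # Plain string input.
        text = await service.ainvoke("Summarize the 0.3.2 changes in one line.")

        # Standard chat-message dicts go through the same entry point.
        reply = await service.ainvoke([
            {"role": "system", "content": "You are terse."},
            {"role": "user", "content": "What does the billing tracker do?"},
        ])
        return text, reply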
-     async def _handle_tool_calls(self, assistant_message, original_messages: List[Dict[str, str]]) -> str:
-         """Handle tool calls from the assistant"""
-         # Add assistant message with tool calls to conversation
-         messages = original_messages + [{
-             "role": "assistant",
-             "content": assistant_message.content or "",
-             "tool_calls": [
-                 {
-                     "id": tc.id,
-                     "type": tc.type,
-                     "function": {
-                         "name": tc.function.name,
-                         "arguments": tc.function.arguments
-                     }
-                 } for tc in assistant_message.tool_calls
-             ]
-         }]
+
+     async def _stream_response(self, kwargs: Dict[str, Any]) -> AsyncGenerator[str, None]:
+         """Handle streaming responses"""
+         kwargs["stream"] = True
 
-         # Execute each tool call
-         for tool_call in assistant_message.tool_calls:
-             function_name = tool_call.function.name
-             arguments = json.loads(tool_call.function.arguments)
-
+         async def stream_generator():
              try:
-                 # Execute the tool
-                 if function_name in self._tool_functions:
-                     result = self._tool_functions[function_name](**arguments)
-                     if hasattr(result, '__await__'):  # Handle async functions
-                         result = await result
-                 else:
-                     result = f"Error: Function {function_name} not found"
-
-                 # Add tool result to messages
-                 messages.append({
-                     "role": "tool",
-                     "content": str(result),
-                     "tool_call_id": tool_call.id
-                 })
-
+                 stream = await self.client.chat.completions.create(**kwargs)
+                 async for chunk in stream:
+                     content = chunk.choices[0].delta.content
+                     if content:
+                         yield content
              except Exception as e:
-                 logger.error(f"Error executing tool {function_name}: {e}")
-                 messages.append({
-                     "role": "tool",
-                     "content": f"Error executing {function_name}: {str(e)}",
-                     "tool_call_id": tool_call.id
-                 })
+                 logger.error(f"Error in streaming: {e}")
+                 raise
 
-         # Get final response from the model with all context
-         try:
-             kwargs = {
-                 "model": self.model_name,
-                 "messages": messages,
-                 "temperature": self.config.get("temperature", 0.7),
-                 "max_tokens": self.config.get("max_tokens", 1024)
-             }
-
-             response = await self.client.chat.completions.create(**kwargs)
-             return response.choices[0].message.content or ""
-
-         except Exception as e:
-             logger.error(f"Error getting final response after tool calls: {e}")
-             raise
-
-     async def acompletion(self, prompt: str) -> str:
-         """Text completion method (using chat API)"""
-         messages = [{"role": "user", "content": prompt}]
-         return await self.achat(messages)
+         return stream_generator()
 
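_stream_response is an async method that returns an async generator rather than being an async generator itself, which is why ainvoke writes `async for chunk in await self._stream_response(kwargs)`. A minimal standalone sketch of that pattern, independent of this codebase:

    import asyncio

    async def make_stream():
        # Awaiting this coroutine yields the generator, which is then
        # iterated with `async for` - mirroring _stream_response's shape.
        async def gen():
            for piece in ("hel", "lo"):
                yield piece
        return gen()

    async def main():
        async for chunk in await make_stream():
            print(chunk, end="")

    asyncio.run(main())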
-     async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
-         """Generate multiple completions"""
-         try:
-             temperature = self.config.get("temperature", 0.7)
-             max_tokens = self.config.get("max_tokens", 1024)
-
-             kwargs = {
-                 "model": self.model_name,
-                 "messages": messages,
-                 "temperature": temperature,
-                 "max_tokens": max_tokens,
-                 "n": n
-             }
-
-             # Add tools if bound
-             if self._has_bound_tools():
-                 kwargs["tools"] = self._get_bound_tools()
-                 kwargs["tool_choice"] = "auto"
-
-             response = await self.client.chat.completions.create(**kwargs)
-
-             if response.usage:
-                 self.last_token_usage = {
-                     "prompt_tokens": response.usage.prompt_tokens,
-                     "completion_tokens": response.usage.completion_tokens,
-                     "total_tokens": response.usage.total_tokens
-                 }
-
-             # Update total usage
-             self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
-             self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
-             self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
-             self.total_token_usage["requests_count"] += 1
-
-             return [choice.message.content or "" for choice in response.choices]
-         except Exception as e:
-             logger.error(f"Error in generate: {e}")
-             raise
 
-     async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
-         """Stream chat responses"""
-         try:
-             temperature = self.config.get("temperature", 0.7)
-             max_tokens = self.config.get("max_tokens", 1024)
-
-             kwargs = {
-                 "model": self.model_name,
-                 "messages": messages,
-                 "temperature": temperature,
-                 "max_tokens": max_tokens,
-                 "stream": True
+     def _update_token_usage(self, usage):
+         """Update token usage statistics"""
+         self.last_token_usage = {
+             "prompt_tokens": usage.prompt_tokens,
+             "completion_tokens": usage.completion_tokens,
+             "total_tokens": usage.total_tokens
+         }
+
+         # Update total usage
+         self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
+         self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
+         self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
+         self.total_token_usage["requests_count"] += 1
+
+     def _track_billing(self, usage):
+         """Track billing information"""
+         self._track_usage(
+             service_type=ServiceType.LLM,
+             operation="chat",
+             input_tokens=usage.prompt_tokens,
+             output_tokens=usage.completion_tokens,
+             metadata={
+                 "temperature": self.config.get("temperature", 0.7),
+                 "max_tokens": self.config.get("max_tokens", 1024)
              }
-
-             # Add tools if bound
-             if self._has_bound_tools():
-                 kwargs["tools"] = self._get_bound_tools()
-                 kwargs["tool_choice"] = "auto"
-
-             stream = await self.client.chat.completions.create(**kwargs)
-
-             async for chunk in stream:
-                 content = chunk.choices[0].delta.content
-                 if content:
-                     yield content
-
-         except Exception as e:
-             logger.error(f"Error in stream chat: {e}")
-             raise
-
-     async def astream_completion(self, prompt: str) -> AsyncGenerator[str, None]:
-         """Stream completion responses"""
-         messages = [{"role": "user", "content": prompt}]
-         async for chunk in self.astream_chat(messages):
-             yield chunk
+         )
 
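_update_token_usage keeps both the last request's counts and running totals, which get_token_usage() can then report. A small worked illustration of the accumulation (the request sizes are made up):

    # Two hypothetical requests of 100/50 and 200/80 tokens accumulate to
    # 300 prompt, 130 completion, 430 total tokens across 2 requests.
    totals = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "requests_count": 0}
    for prompt_toks, completion_toks in [(100, 50), (200, 80)]:
        totals["prompt_tokens"] += prompt_toks
        totals["completion_tokens"] += completion_toks
        totals["total_tokens"] += prompt_toks + completion_toks
        totals["requests_count"] += 1
    print(totals)  # {'prompt_tokens': 300, 'completion_tokens': 130, 'total_tokens': 430, 'requests_count': 2}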
      def get_token_usage(self) -> Dict[str, Any]:
          """Get total token usage statistics"""
@@ -431,13 +193,6 @@ class OpenAILLMService(BaseLLMService):
              "provider": "openai"
          }
 
-     def _has_bound_tools(self) -> bool:
-         """Check if this service has bound tools"""
-         return bool(self._bound_tools)
-
-     def _get_bound_tools(self) -> List[Dict[str, Any]]:
-         """Get the bound tools schema"""
-         return self._bound_tools
 
      async def close(self):
          """Close the backend client"""
isa_model/inference/services/vision/helpers/image_utils.py

@@ -2,9 +2,10 @@ from io import BytesIO
  from PIL import Image
  from typing import Union
  import base64
- from app.config.config_manager import config_manager
+ # from app.config.config_manager import config_manager  # Commented out to fix import
+ import logging
 
- logger = config_manager.get_logger(__name__)
+ logger = logging.getLogger(__name__)
 
  def compress_image(image_data: Union[bytes, BytesIO], max_size: int = 1024) -> bytes:
      """Compress an image to reduce its size
@@ -30,7 +31,7 @@ def compress_image(image_data: Union[bytes, BytesIO], max_size: int = 1024) -> bytes
      # Compute the new size, preserving the aspect ratio
      ratio = max_size / max(img.size)
      if ratio < 1:
-         new_size = tuple(int(dim * ratio) for dim in img.size)
+         new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio))
          img = img.resize(new_size, Image.Resampling.LANCZOS)
 
      # Save the compressed image
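The rewritten resize line builds the 2-tuple explicitly instead of a variadic tuple(...), which satisfies type checkers for Image.resize. A quick check of the ratio arithmetic with made-up dimensions:

    # A hypothetical 2048x1536 image with max_size=1024:
    # ratio = 1024 / 2048 = 0.5, so the new size is (1024, 768).
    size = (2048, 1536)
    max_size = 1024
    ratio = max_size / max(size)
    if ratio < 1:
        new_size = (int(size[0] * ratio), int(size[1] * ratio))
    print(new_size)  # (1024, 768)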
isa_model/inference/services/vision/ollama_vision_service.py

@@ -7,6 +7,7 @@ from tenacity import retry, stop_after_attempt, wait_exponential
  from isa_model.inference.services.vision.base_vision_service import BaseVisionService
  from isa_model.inference.providers.base_provider import BaseProvider
  import logging
+ import requests
 
  logger = logging.getLogger(__name__)
 
@@ -19,10 +20,17 @@ class OllamaVisionService(BaseVisionService):
          self.temperature = self.config.get('temperature', 0.7)
 
      def _get_image_data(self, image: Union[str, BinaryIO]) -> bytes:
-         """Get image data"""
+         """Get image data, supporting local files and URLs"""
          if isinstance(image, str):
-             with open(image, 'rb') as f:
-                 return f.read()
+             # Check if it's a URL
+             if image.startswith(('http://', 'https://')):
+                 response = requests.get(image)
+                 response.raise_for_status()
+                 return response.content
+             else:
+                 # Local file path
+                 with open(image, 'rb') as f:
+                     return f.read()
          else:
              return image.read()
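The new URL branch uses a synchronous requests.get with no timeout, which can block the event loop in async callers. A hedged standalone sketch of the same dispatch logic with an explicit timeout added (the timeout is my addition, not in the diff):

    import requests

    def get_image_data(image: str) -> bytes:
        """Load image bytes from a URL or a local file path (illustrative helper)."""
        if image.startswith(("http://", "https://")):
            response = requests.get(image, timeout=30)  # timeout not in the original
            response.raise_for_status()
            return response.content
        with open(image, "rb") as f:
            return f.read()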