isa-model 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- isa_model/core/model_manager.py +69 -4
- isa_model/inference/ai_factory.py +335 -46
- isa_model/inference/billing_tracker.py +406 -0
- isa_model/inference/providers/base_provider.py +51 -4
- isa_model/inference/providers/ollama_provider.py +37 -18
- isa_model/inference/providers/openai_provider.py +65 -36
- isa_model/inference/providers/replicate_provider.py +42 -30
- isa_model/inference/services/audio/base_stt_service.py +21 -2
- isa_model/inference/services/audio/openai_realtime_service.py +353 -0
- isa_model/inference/services/audio/openai_stt_service.py +252 -0
- isa_model/inference/services/audio/openai_tts_service.py +48 -9
- isa_model/inference/services/audio/replicate_tts_service.py +239 -0
- isa_model/inference/services/base_service.py +36 -1
- isa_model/inference/services/embedding/openai_embed_service.py +223 -0
- isa_model/inference/services/llm/base_llm_service.py +88 -192
- isa_model/inference/services/llm/llm_adapter.py +459 -0
- isa_model/inference/services/llm/ollama_llm_service.py +111 -185
- isa_model/inference/services/llm/openai_llm_service.py +115 -360
- isa_model/inference/services/vision/helpers/image_utils.py +4 -3
- isa_model/inference/services/vision/ollama_vision_service.py +11 -3
- isa_model/inference/services/vision/openai_vision_service.py +275 -41
- isa_model/inference/services/vision/replicate_image_gen_service.py +233 -205
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/METADATA +1 -1
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/RECORD +26 -21
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/WHEEL +0 -0
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/top_level.txt +0 -0
isa_model/inference/services/llm/ollama_llm_service.py

@@ -8,9 +8,9 @@ from isa_model.inference.providers.base_provider import BaseProvider
 logger = logging.getLogger(__name__)
 
 class OllamaLLMService(BaseLLMService):
-    """Ollama LLM service
+    """Ollama LLM service with unified invoke interface and proper adapter support"""
 
-    def __init__(self, provider: 'BaseProvider', model_name: str = "llama3.
+    def __init__(self, provider: 'BaseProvider', model_name: str = "llama3.2:3b-instruct-fp16"):
         super().__init__(provider, model_name)
 
         # Create HTTP client for Ollama API
@@ -25,50 +25,55 @@ class OllamaLLMService(BaseLLMService):
         self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
         self.total_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "requests_count": 0}
 
-        # Tool binding attributes
-        self._bound_tools: List[Dict[str, Any]] = []
-        self._tool_binding_kwargs: Dict[str, Any] = {}
-        self._tool_functions: Dict[str, Callable] = {}
 
         logger.info(f"Initialized OllamaLLMService with model {model_name} at {base_url}")
 
+    def _ensure_client(self):
+        """Ensure the HTTP client is available and not closed"""
+        if not hasattr(self, 'client') or not self.client or self.client.is_closed:
+            base_url = self.config.get("base_url", "http://localhost:11434")
+            timeout = self.config.get("timeout", 60)
+            self.client = httpx.AsyncClient(base_url=base_url, timeout=timeout)
+
     def _create_bound_copy(self) -> 'OllamaLLMService':
         """Create a copy of this service for tool binding"""
         bound_service = OllamaLLMService(self.provider, self.model_name)
         bound_service._bound_tools = self._bound_tools.copy()
-        bound_service._tool_binding_kwargs = self._tool_binding_kwargs.copy()
-        bound_service._tool_functions = self._tool_functions.copy()
         return bound_service
 
-    def bind_tools(self, tools: List[
+    def bind_tools(self, tools: List[Any], **kwargs) -> 'OllamaLLMService':
         """Bind tools to this LLM service for function calling"""
         bound_service = self._create_bound_copy()
-
-        bound_service.
-
-        # Store the actual functions for execution
-        for tool in tools:
-            if callable(tool):
-                bound_service._tool_functions[tool.__name__] = tool
+        # Use base class method to bind tools
+        bound_service._bound_tools = tools
 
         return bound_service
 
-    async def ainvoke(self,
-        """
-
-
-
-
-
-
-
-
-
+    async def ainvoke(self, input_data: Union[str, List[Dict[str, str]], Any]) -> Union[str, Any]:
+        """
+        Universal async invocation method that handles different input types
+
+        Args:
+            input_data: Can be:
+                - str: Simple text prompt
+                - list: Message history like [{"role": "user", "content": "hello"}]
+                - Any: LangChain message objects or other formats
+
+        Returns:
+            Model response (string for simple cases, object for complex cases)
+        """
         try:
+            # Ensure client is available
+            self._ensure_client()
+
+            # Use adapter manager to prepare messages (consistent with OpenAI service)
+            messages = self._prepare_messages(input_data)
+
+            # Prepare request parameters
             payload = {
                 "model": self.model_name,
                 "messages": messages,
-                "stream":
+                "stream": self.streaming,
                 "options": {
                     "temperature": self.config.get("temperature", 0.7),
                     "top_p": self.config.get("top_p", 0.9),
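The new `ainvoke` entry point above accepts a plain prompt string, a message list, or framework-specific message objects, and `bind_tools` now defers schema conversion to the shared adapter layer. A minimal usage sketch, assuming an `OllamaProvider` class in `isa_model/inference/providers/ollama_provider.py` that can be constructed with its default configuration (its constructor is not part of this hunk) and a running local Ollama server:

```python
# Hedged usage sketch: only OllamaLLMService's signatures appear in this diff;
# OllamaProvider's default construction is an assumption.
import asyncio

from isa_model.inference.providers.ollama_provider import OllamaProvider
from isa_model.inference.services.llm.ollama_llm_service import OllamaLLMService


def get_weather(city: str) -> str:
    """Toy tool: return a canned weather report for a city."""
    return f"It is sunny in {city}."


async def main() -> None:
    # model_name now defaults to "llama3.2:3b-instruct-fp16"
    service = OllamaLLMService(OllamaProvider())
    try:
        # 1. Simple text prompt
        print(await service.ainvoke("What is the capital of France?"))

        # 2. Message history, as documented in the ainvoke docstring
        history = [
            {"role": "system", "content": "You are terse."},
            {"role": "user", "content": "hello"},
        ]
        print(await service.ainvoke(history))

        # 3. bind_tools returns a bound copy; the original service is unchanged
        bound = service.bind_tools([get_weather])
        print(await bound.ainvoke("What's the weather in Paris?"))
    finally:
        await service.close()


asyncio.run(main())
```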
@@ -76,66 +81,96 @@ class OllamaLLMService(BaseLLMService):
                 }
             }
 
-            # Add tools if bound
-
-
+            # Add tools if bound using adapter manager
+            tool_schemas = await self._prepare_tools_for_request()
+            if tool_schemas:
+                payload["tools"] = tool_schemas
+
+            # Handle streaming
+            if self.streaming:
+                return self._stream_response(payload)
 
+            # Regular request
             response = await self.client.post("/api/chat", json=payload)
             response.raise_for_status()
             result = response.json()
 
             # Update token usage if available
             if "eval_count" in result:
-                self.
-                    "prompt_tokens": result.get("prompt_eval_count", 0),
-                    "completion_tokens": result.get("eval_count", 0),
-                    "total_tokens": result.get("prompt_eval_count", 0) + result.get("eval_count", 0)
-                }
-
-                # Update total usage
-                self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
-                self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
-                self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
-                self.total_token_usage["requests_count"] += 1
+                self._update_token_usage(result)
 
-            # Handle tool calls if present
+            # Handle tool calls if present - let adapter process the complete message
             message = result["message"]
             if "tool_calls" in message and message["tool_calls"]:
-
+                # Create message object similar to OpenAI format for adapter processing
+                message_obj = type('OllamaMessage', (), {
+                    'content': message.get("content", ""),
+                    'tool_calls': message["tool_calls"]
+                })()
+                # Pass the complete message object to adapter for proper tool_calls handling
+                return self._format_response(message_obj, input_data)
 
-
+            # Return appropriate format based on input type
+            return self._format_response(message.get("content", ""), input_data)
 
         except httpx.RequestError as e:
-            logger.error(f"HTTP request error in
+            logger.error(f"HTTP request error in ainvoke: {e}")
             raise
         except Exception as e:
             logger.error(f"Error in chat completion: {e}")
             raise
 
+    def _prepare_messages(self, input_data: Union[str, List[Dict[str, str]], Any]) -> List[Dict[str, str]]:
+        """Use adapter manager to convert messages (consistent with OpenAI service)"""
+        return self.adapter_manager.convert_messages(input_data)
+
+
+    def _format_response(self, response: Union[str, Any], original_input: Any) -> Union[str, Any]:
+        """Use adapter manager to format response (consistent with OpenAI service)"""
+        return self.adapter_manager.format_response(response, original_input)
+
+
+    async def _stream_response(self, payload: Dict[str, Any]) -> AsyncGenerator[str, None]:
+        """Handle streaming responses"""
+        async def stream_generator():
+            try:
+                async with self.client.stream("POST", "/api/chat", json=payload) as response:
+                    response.raise_for_status()
+                    async for line in response.aiter_lines():
+                        if line.strip():
+                            try:
+                                chunk = json.loads(line)
+                                if "message" in chunk and "content" in chunk["message"]:
+                                    content = chunk["message"]["content"]
+                                    if content:
+                                        yield content
+                            except json.JSONDecodeError:
+                                continue
+            except Exception as e:
+                logger.error(f"Error in streaming: {e}")
+                raise
+
+        return stream_generator()
+
     async def _handle_tool_calls(self, assistant_message: Dict[str, Any], original_messages: List[Dict[str, str]]) -> str:
-        """Handle tool calls from the assistant"""
+        """Handle tool calls from the assistant using adapter manager"""
         tool_calls = assistant_message.get("tool_calls", [])
 
         # Add assistant message with tool calls to conversation
         messages = original_messages + [assistant_message]
 
-        # Execute each tool call
+        # Execute each tool call using adapter manager
         for tool_call in tool_calls:
             function_name = tool_call["function"]["name"]
-            arguments = tool_call["function"]["arguments"]
 
             try:
                 # Parse arguments if they're a string
+                arguments = tool_call["function"]["arguments"]
                 if isinstance(arguments, str):
                     arguments = json.loads(arguments)
 
-                #
-
-                result = self._tool_functions[function_name](**arguments)
-                if hasattr(result, '__await__'): # Handle async functions
-                    result = await result
-                else:
-                    result = f"Error: Function {function_name} not found"
+                # Use adapter manager to execute tool
+                result = await self._execute_tool_call(function_name, arguments)
 
                 # Add tool result to messages
                 messages.append({
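The added `_stream_response` wraps Ollama's newline-delimited JSON streaming in an async generator. A standalone sketch of the same protocol with a bare `httpx` client, assuming a local Ollama server on the default port (the model name is only an example):

```python
# Standalone sketch of what _stream_response does against the Ollama API:
# POST /api/chat with "stream" enabled, read NDJSON lines, and pull text out
# of chunk["message"]["content"]. Assumes Ollama is listening on localhost:11434.
import asyncio
import json

import httpx


async def stream_chat(prompt: str, model: str = "llama3.2:3b-instruct-fp16") -> None:
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "stream": True,
        "options": {"temperature": 0.7, "top_p": 0.9},
    }
    async with httpx.AsyncClient(base_url="http://localhost:11434", timeout=60) as client:
        async with client.stream("POST", "/api/chat", json=payload) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.strip():
                    continue
                try:
                    chunk = json.loads(line)
                except json.JSONDecodeError:
                    continue  # skip malformed lines, as the service does
                content = chunk.get("message", {}).get("content", "")
                if content:
                    print(content, end="", flush=True)
    print()


asyncio.run(stream_chat("Explain asyncio in one paragraph."))
```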
@@ -153,128 +188,21 @@ class OllamaLLMService(BaseLLMService):
                 })
 
         # Get final response from the model
-        return await self.
+        return await self.ainvoke(messages)
 
-
-        """
-
-
-
-
-
-
-
-
-
-
-
-
-            response = await self.client.post("/api/generate", json=payload)
-            response.raise_for_status()
-            result = response.json()
-
-            # Update token usage if available
-            if "eval_count" in result:
-                self.last_token_usage = {
-                    "prompt_tokens": result.get("prompt_eval_count", 0),
-                    "completion_tokens": result.get("eval_count", 0),
-                    "total_tokens": result.get("prompt_eval_count", 0) + result.get("eval_count", 0)
-                }
-
-                # Update total usage
-                self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
-                self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
-                self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
-                self.total_token_usage["requests_count"] += 1
-
-            return result["response"]
-
-        except httpx.RequestError as e:
-            logger.error(f"HTTP request error in text completion: {e}")
-            raise
-        except Exception as e:
-            logger.error(f"Error in text completion: {e}")
-            raise
-
-    async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
-        """Generate multiple completions"""
-        results = []
-        for _ in range(n):
-            result = await self.achat(messages)
-            results.append(result)
-        return results
-
-    async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
-        """Stream chat responses"""
-        try:
-            payload = {
-                "model": self.model_name,
-                "messages": messages,
-                "stream": True,
-                "options": {
-                    "temperature": self.config.get("temperature", 0.7),
-                    "top_p": self.config.get("top_p", 0.9),
-                    "num_predict": self.config.get("max_tokens", 2048)
-                }
-            }
-
-            # Add tools if bound
-            if self._has_bound_tools():
-                payload["tools"] = self._get_bound_tools()
-
-            async with self.client.stream("POST", "/api/chat", json=payload) as response:
-                response.raise_for_status()
-                async for line in response.aiter_lines():
-                    if line.strip():
-                        try:
-                            chunk = json.loads(line)
-                            if "message" in chunk and "content" in chunk["message"]:
-                                content = chunk["message"]["content"]
-                                if content:
-                                    yield content
-                        except json.JSONDecodeError:
-                            continue
-
-        except httpx.RequestError as e:
-            logger.error(f"HTTP request error in stream chat: {e}")
-            raise
-        except Exception as e:
-            logger.error(f"Error in stream chat: {e}")
-            raise
-
-    async def astream_completion(self, prompt: str) -> AsyncGenerator[str, None]:
-        """Stream completion responses"""
-        try:
-            payload = {
-                "model": self.model_name,
-                "prompt": prompt,
-                "stream": True,
-                "options": {
-                    "temperature": self.config.get("temperature", 0.7),
-                    "top_p": self.config.get("top_p", 0.9),
-                    "num_predict": self.config.get("max_tokens", 2048)
-                }
-            }
-
-            async with self.client.stream("POST", "/api/generate", json=payload) as response:
-                response.raise_for_status()
-                async for line in response.aiter_lines():
-                    if line.strip():
-                        try:
-                            chunk = json.loads(line)
-                            if "response" in chunk:
-                                content = chunk["response"]
-                                if content:
-                                    yield content
-                        except json.JSONDecodeError:
-                            continue
-
-        except httpx.RequestError as e:
-            logger.error(f"HTTP request error in stream completion: {e}")
-            raise
-        except Exception as e:
-            logger.error(f"Error in stream completion: {e}")
-            raise
+    def _update_token_usage(self, result: Dict[str, Any]):
+        """Update token usage statistics"""
+        self.last_token_usage = {
+            "prompt_tokens": result.get("prompt_eval_count", 0),
+            "completion_tokens": result.get("eval_count", 0),
+            "total_tokens": result.get("prompt_eval_count", 0) + result.get("eval_count", 0)
+        }
+
+        # Update total usage
+        self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
+        self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
+        self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
+        self.total_token_usage["requests_count"] += 1
 
     def get_token_usage(self) -> Dict[str, Any]:
         """Get total token usage statistics"""
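The per-request and cumulative token accounting removed from the old chat/completion paths is now centralised in `_update_token_usage`, keyed off Ollama's `prompt_eval_count` and `eval_count` fields. A standalone sketch of the same arithmetic on illustrative numbers:

```python
# Mirror of the service's bookkeeping: snapshot the last call, then accumulate.
# The response dicts below are fabricated examples, not real Ollama output.
from typing import Any, Dict

last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
total_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "requests_count": 0}


def update_token_usage(result: Dict[str, Any]) -> None:
    """Record the last request's usage and add it to the running totals."""
    last_token_usage.update({
        "prompt_tokens": result.get("prompt_eval_count", 0),
        "completion_tokens": result.get("eval_count", 0),
        "total_tokens": result.get("prompt_eval_count", 0) + result.get("eval_count", 0),
    })
    total_token_usage["prompt_tokens"] += last_token_usage["prompt_tokens"]
    total_token_usage["completion_tokens"] += last_token_usage["completion_tokens"]
    total_token_usage["total_tokens"] += last_token_usage["total_tokens"]
    total_token_usage["requests_count"] += 1


# Two sample responses: 12 + 30 = 42 tokens, then 8 + 20 = 28 tokens.
update_token_usage({"prompt_eval_count": 12, "eval_count": 30})
update_token_usage({"prompt_eval_count": 8, "eval_count": 20})
print(last_token_usage)   # {'prompt_tokens': 8, 'completion_tokens': 20, 'total_tokens': 28}
print(total_token_usage)  # totals: 20 prompt, 50 completion, 70 total, 2 requests
```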
@@ -294,14 +222,12 @@ class OllamaLLMService(BaseLLMService):
             "provider": "ollama"
         }
 
-    def _has_bound_tools(self) -> bool:
-        """Check if this service has bound tools"""
-        return bool(self._bound_tools)
-
-    def _get_bound_tools(self) -> List[Dict[str, Any]]:
-        """Get the bound tools schema"""
-        return self._bound_tools
 
     async def close(self):
         """Close the HTTP client"""
-
+        if hasattr(self, 'client') and self.client:
+            try:
+                if not self.client.is_closed:
+                    await self.client.aclose()
+            except Exception as e:
+                logger.warning(f"Error closing Ollama client: {e}")