isa_model-0.3.4-py3-none-any.whl → isa_model-0.3.6-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those published versions.
- isa_model/__init__.py +30 -1
- isa_model/client.py +770 -0
- isa_model/core/config/__init__.py +16 -0
- isa_model/core/config/config_manager.py +514 -0
- isa_model/core/config.py +426 -0
- isa_model/core/models/model_billing_tracker.py +476 -0
- isa_model/core/models/model_manager.py +399 -0
- isa_model/core/models/model_repo.py +343 -0
- isa_model/core/pricing_manager.py +426 -0
- isa_model/core/services/__init__.py +19 -0
- isa_model/core/services/intelligent_model_selector.py +547 -0
- isa_model/core/types.py +291 -0
- isa_model/deployment/__init__.py +2 -0
- isa_model/deployment/cloud/__init__.py +9 -0
- isa_model/deployment/cloud/modal/__init__.py +10 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +766 -0
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +532 -0
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +406 -0
- isa_model/deployment/cloud/modal/register_models.py +321 -0
- isa_model/deployment/runtime/deployed_service.py +338 -0
- isa_model/deployment/services/__init__.py +9 -0
- isa_model/deployment/services/auto_deploy_vision_service.py +537 -0
- isa_model/deployment/services/model_service.py +332 -0
- isa_model/deployment/services/service_monitor.py +356 -0
- isa_model/deployment/services/service_registry.py +527 -0
- isa_model/eval/__init__.py +80 -44
- isa_model/eval/config/__init__.py +10 -0
- isa_model/eval/config/evaluation_config.py +108 -0
- isa_model/eval/evaluators/__init__.py +18 -0
- isa_model/eval/evaluators/base_evaluator.py +503 -0
- isa_model/eval/evaluators/llm_evaluator.py +472 -0
- isa_model/eval/factory.py +417 -709
- isa_model/eval/infrastructure/__init__.py +24 -0
- isa_model/eval/infrastructure/experiment_tracker.py +466 -0
- isa_model/eval/metrics.py +191 -21
- isa_model/inference/ai_factory.py +187 -387
- isa_model/inference/providers/modal_provider.py +109 -0
- isa_model/inference/providers/yyds_provider.py +108 -0
- isa_model/inference/services/__init__.py +2 -1
- isa_model/inference/services/audio/base_stt_service.py +65 -1
- isa_model/inference/services/audio/base_tts_service.py +75 -1
- isa_model/inference/services/audio/openai_stt_service.py +189 -151
- isa_model/inference/services/audio/openai_tts_service.py +12 -10
- isa_model/inference/services/audio/replicate_tts_service.py +61 -56
- isa_model/inference/services/base_service.py +55 -55
- isa_model/inference/services/embedding/base_embed_service.py +65 -1
- isa_model/inference/services/embedding/ollama_embed_service.py +103 -43
- isa_model/inference/services/embedding/openai_embed_service.py +8 -10
- isa_model/inference/services/helpers/stacked_config.py +148 -0
- isa_model/inference/services/img/__init__.py +18 -0
- isa_model/inference/services/{vision → img}/base_image_gen_service.py +80 -35
- isa_model/inference/services/img/flux_professional_service.py +603 -0
- isa_model/inference/services/img/helpers/base_stacked_service.py +274 -0
- isa_model/inference/services/{vision → img}/replicate_image_gen_service.py +210 -69
- isa_model/inference/services/llm/__init__.py +3 -3
- isa_model/inference/services/llm/base_llm_service.py +519 -35
- isa_model/inference/services/llm/{llm_adapter.py → helpers/llm_adapter.py} +40 -0
- isa_model/inference/services/llm/helpers/llm_prompts.py +258 -0
- isa_model/inference/services/llm/helpers/llm_utils.py +280 -0
- isa_model/inference/services/llm/ollama_llm_service.py +150 -15
- isa_model/inference/services/llm/openai_llm_service.py +134 -31
- isa_model/inference/services/llm/yyds_llm_service.py +255 -0
- isa_model/inference/services/vision/__init__.py +38 -4
- isa_model/inference/services/vision/base_vision_service.py +241 -96
- isa_model/inference/services/vision/disabled/isA_vision_service.py +500 -0
- isa_model/inference/services/vision/doc_analysis_service.py +640 -0
- isa_model/inference/services/vision/helpers/base_stacked_service.py +274 -0
- isa_model/inference/services/vision/helpers/image_utils.py +272 -3
- isa_model/inference/services/vision/helpers/vision_prompts.py +297 -0
- isa_model/inference/services/vision/openai_vision_service.py +109 -170
- isa_model/inference/services/vision/replicate_vision_service.py +508 -0
- isa_model/inference/services/vision/ui_analysis_service.py +823 -0
- isa_model/scripts/register_models.py +370 -0
- isa_model/scripts/register_models_with_embeddings.py +510 -0
- isa_model/serving/__init__.py +19 -0
- isa_model/serving/api/__init__.py +10 -0
- isa_model/serving/api/fastapi_server.py +89 -0
- isa_model/serving/api/middleware/__init__.py +9 -0
- isa_model/serving/api/middleware/request_logger.py +88 -0
- isa_model/serving/api/routes/__init__.py +5 -0
- isa_model/serving/api/routes/health.py +82 -0
- isa_model/serving/api/routes/llm.py +19 -0
- isa_model/serving/api/routes/ui_analysis.py +223 -0
- isa_model/serving/api/routes/unified.py +202 -0
- isa_model/serving/api/routes/vision.py +19 -0
- isa_model/serving/api/schemas/__init__.py +17 -0
- isa_model/serving/api/schemas/common.py +33 -0
- isa_model/serving/api/schemas/ui_analysis.py +78 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/METADATA +4 -1
- isa_model-0.3.6.dist-info/RECORD +147 -0
- isa_model/core/model_manager.py +0 -208
- isa_model/core/model_registry.py +0 -342
- isa_model/inference/billing_tracker.py +0 -406
- isa_model/inference/services/llm/triton_llm_service.py +0 -481
- isa_model/inference/services/vision/ollama_vision_service.py +0 -194
- isa_model-0.3.4.dist-info/RECORD +0 -91
- /isa_model/core/{model_storage.py → models/model_storage.py} +0 -0
- /isa_model/inference/services/{vision → embedding}/helpers/text_splitter.py +0 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/WHEEL +0 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/top_level.txt +0 -0
isa_model/inference/services/llm/ollama_llm_service.py
@@ -3,19 +3,21 @@ import httpx
 import json
 from typing import Dict, Any, List, Union, AsyncGenerator, Optional, Callable
 from isa_model.inference.services.llm.base_llm_service import BaseLLMService
-from isa_model.inference.providers.base_provider import BaseProvider
 
 logger = logging.getLogger(__name__)
 
 class OllamaLLMService(BaseLLMService):
     """Ollama LLM service with unified invoke interface and proper adapter support"""
 
-    def __init__(self,
-        super().__init__(
+    def __init__(self, provider_name: str, model_name: str = "llama3.2:3b-instruct-fp16", **kwargs):
+        super().__init__(provider_name, model_name, **kwargs)
+
+        # Get configuration from centralized config manager
+        provider_config = self.get_provider_config()
 
         # Create HTTP client for Ollama API
-        base_url =
-        timeout =
+        base_url = provider_config.get("base_url", "http://localhost:11434")
+        timeout = provider_config.get("timeout", 60)
 
         self.client = httpx.AsyncClient(
             base_url=base_url,
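For orientation, a minimal usage sketch of the reworked constructor, based only on the signature and defaults visible in the hunk above (not on package documentation): the service is now keyed by provider name plus model name, and connection settings fall back to the centralized provider config.

from isa_model.inference.services.llm.ollama_llm_service import OllamaLLMService

# Assumes a local Ollama daemon; base_url defaults to http://localhost:11434
# and timeout to 60s unless the provider config overrides them.
svc = OllamaLLMService(provider_name="ollama", model_name="llama3.2:3b-instruct-fp16")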
@@ -31,13 +33,14 @@ class OllamaLLMService(BaseLLMService):
     def _ensure_client(self):
         """Ensure the HTTP client is available and not closed"""
         if not hasattr(self, 'client') or not self.client or self.client.is_closed:
-
-
+            provider_config = self.get_provider_config()
+            base_url = provider_config.get("base_url", "http://localhost:11434")
+            timeout = provider_config.get("timeout", 60)
             self.client = httpx.AsyncClient(base_url=base_url, timeout=timeout)
 
     def _create_bound_copy(self) -> 'OllamaLLMService':
         """Create a copy of this service for tool binding"""
-        bound_service = OllamaLLMService(self.
+        bound_service = OllamaLLMService(self.provider_name, self.model_name)
         bound_service._bound_tools = self._bound_tools.copy()
         return bound_service
 
@@ -70,14 +73,15 @@ class OllamaLLMService(BaseLLMService):
             messages = self._prepare_messages(input_data)
 
             # Prepare request parameters
+            provider_config = self.get_provider_config()
             payload = {
                 "model": self.model_name,
                 "messages": messages,
                 "stream": self.streaming,
                 "options": {
-                    "temperature":
-                    "top_p":
-                    "num_predict":
+                    "temperature": provider_config.get("temperature", 0.7),
+                    "top_p": provider_config.get("top_p", 0.9),
+                    "num_predict": provider_config.get("max_tokens", 2048)
                 }
             }
 
@@ -86,9 +90,15 @@ class OllamaLLMService(BaseLLMService):
             if tool_schemas:
                 payload["tools"] = tool_schemas
 
-            # Handle streaming
+            # Handle streaming vs non-streaming
             if self.streaming:
-
+                # TRUE STREAMING MODE - collect all chunks from the stream
+                content_chunks = []
+                async for token in self.astream(input_data):
+                    content_chunks.append(token)
+                content = "".join(content_chunks)
+
+                return self._format_response(content, input_data)
 
             # Regular request
             response = await self.client.post("/api/chat", json=payload)
@@ -98,6 +108,7 @@ class OllamaLLMService(BaseLLMService):
             # Update token usage if available
             if "eval_count" in result:
                 self._update_token_usage(result)
+                await self._track_ollama_billing(result)
 
             # Handle tool calls if present - let adapter process the complete message
             message = result["message"]
@@ -190,6 +201,44 @@ class OllamaLLMService(BaseLLMService):
         # Get final response from the model
         return await self.ainvoke(messages)
 
+    async def _track_streaming_usage(self, messages: List[Dict[str, str]], content: str):
+        """Track usage for streaming requests (estimated)"""
+        # Create a mock usage object for tracking
+        class MockUsage:
+            def __init__(self):
+                self.prompt_tokens = len(str(messages)) // 4  # Rough estimate
+                self.completion_tokens = len(content) // 4  # Rough estimate
+                self.total_tokens = self.prompt_tokens + self.completion_tokens
+
+        usage = MockUsage()
+        self._update_token_usage_from_mock(usage)
+
+        # Track billing
+        await self._track_llm_usage(
+            operation="chat_stream",
+            input_tokens=usage.prompt_tokens,
+            output_tokens=usage.completion_tokens,
+            metadata={
+                "model": self.model_name,
+                "provider": "ollama",
+                "streaming": True
+            }
+        )
+
+    def _update_token_usage_from_mock(self, usage):
+        """Update token usage statistics from mock usage object"""
+        self.last_token_usage = {
+            "prompt_tokens": usage.prompt_tokens,
+            "completion_tokens": usage.completion_tokens,
+            "total_tokens": usage.total_tokens
+        }
+
+        # Update total usage
+        self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
+        self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
+        self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
+        self.total_token_usage["requests_count"] += 1
+
     def _update_token_usage(self, result: Dict[str, Any]):
         """Update token usage statistics"""
         self.last_token_usage = {
@@ -204,6 +253,21 @@ class OllamaLLMService(BaseLLMService):
         self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
         self.total_token_usage["requests_count"] += 1
 
+    async def _track_ollama_billing(self, result: Dict[str, Any]):
+        """Track billing information for Ollama requests"""
+        prompt_tokens = result.get("prompt_eval_count", 0)
+        completion_tokens = result.get("eval_count", 0)
+
+        await self._track_llm_usage(
+            operation="chat",
+            input_tokens=prompt_tokens,
+            output_tokens=completion_tokens,
+            metadata={
+                "model": self.model_name,
+                "provider": "ollama"
+            }
+        )
+
     def get_token_usage(self) -> Dict[str, Any]:
         """Get total token usage statistics"""
         return self.total_token_usage
@@ -214,9 +278,10 @@ class OllamaLLMService(BaseLLMService):
 
     def get_model_info(self) -> Dict[str, Any]:
         """Get information about the current model"""
+        provider_config = self.get_provider_config()
         return {
             "name": self.model_name,
-            "max_tokens":
+            "max_tokens": provider_config.get("max_tokens", 2048),
             "supports_streaming": True,
             "supports_functions": True,
             "provider": "ollama"
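A quick illustration of the config-backed get_model_info(), reusing the svc instance from the sketch above; the field values assume the defaults shown in the hunk, while real values come from the active provider config.

info = svc.get_model_info()
# {"name": "llama3.2:3b-instruct-fp16", "max_tokens": 2048,
#  "supports_streaming": True, "supports_functions": True, "provider": "ollama"}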
@@ -230,4 +295,74 @@ class OllamaLLMService(BaseLLMService):
             if not self.client.is_closed:
                 await self.client.aclose()
         except Exception as e:
-            logger.warning(f"Error closing Ollama client: {e}")
+            logger.warning(f"Error closing Ollama client: {e}")
+
+    async def astream(self, input_data: Union[str, List[Dict[str, str]], Any]) -> AsyncGenerator[str, None]:
+        """
+        True streaming method that yields tokens one by one as they arrive
+
+        Args:
+            input_data: Can be:
+                - str: Simple text prompt
+                - list: Message history like [{"role": "user", "content": "hello"}]
+                - Any: LangChain message objects or other formats
+
+        Yields:
+            Individual tokens as they arrive from the model
+        """
+        try:
+            # Ensure client is available
+            self._ensure_client()
+
+            # Use adapter manager to prepare messages
+            messages = self._prepare_messages(input_data)
+
+            # Prepare request parameters for streaming
+            provider_config = self.get_provider_config()
+            payload = {
+                "model": self.model_name,
+                "messages": messages,
+                "stream": True,  # Force streaming for astream
+                "options": {
+                    "temperature": provider_config.get("temperature", 0.7),
+                    "top_p": provider_config.get("top_p", 0.9),
+                    "num_predict": provider_config.get("max_tokens", 2048)
+                }
+            }
+
+            # Add tools if bound using adapter manager
+            tool_schemas = await self._prepare_tools_for_request()
+            if tool_schemas:
+                payload["tools"] = tool_schemas
+
+            # Stream tokens one by one
+            content_chunks = []
+            try:
+                async with self.client.stream("POST", "/api/chat", json=payload) as response:
+                    response.raise_for_status()
+                    async for line in response.aiter_lines():
+                        if line.strip():
+                            try:
+                                chunk = json.loads(line)
+                                if "message" in chunk and "content" in chunk["message"]:
+                                    content = chunk["message"]["content"]
+                                    if content:
+                                        content_chunks.append(content)
+                                        yield content
+                            except json.JSONDecodeError:
+                                continue
+
+                # Track usage after streaming is complete (estimated)
+                full_content = "".join(content_chunks)
+                await self._track_streaming_usage(messages, full_content)
+
+            except Exception as e:
+                logger.error(f"Error in streaming: {e}")
+                raise
+
+        except httpx.RequestError as e:
+            logger.error(f"HTTP request error in astream: {e}")
+            raise
+        except Exception as e:
+            logger.error(f"Error in astream: {e}")
+            raise
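The new astream() generator can be consumed as below; this is a hedged sketch assuming an asyncio entry point, and (per the hunk above) usage and billing are tracked from token estimates after the stream completes.

import asyncio

async def main():
    svc = OllamaLLMService("ollama", "llama3.2:3b-instruct-fp16")
    # Tokens arrive incrementally as the model generates them.
    async for token in svc.astream("Explain streaming in one sentence."):
        print(token, end="", flush=True)
    await svc.close()

asyncio.run(main())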
isa_model/inference/services/llm/openai_llm_service.py
@@ -7,19 +7,18 @@ from typing import Dict, Any, List, Union, AsyncGenerator, Optional, Callable
 from openai import AsyncOpenAI
 
 from isa_model.inference.services.llm.base_llm_service import BaseLLMService
-from
-from isa_model.inference.billing_tracker import ServiceType
+from ....core.types import ServiceType
 
 logger = logging.getLogger(__name__)
 
 class OpenAILLMService(BaseLLMService):
     """OpenAI LLM service implementation with unified invoke interface"""
 
-    def __init__(self,
-        super().__init__(
+    def __init__(self, model_name: str = "gpt-4o-mini", provider_name: str = "openai", **kwargs):
+        super().__init__(provider_name, model_name, **kwargs)
 
-        # Get
-        provider_config =
+        # Get configuration from centralized config manager
+        provider_config = self.get_provider_config()
 
         # Initialize AsyncOpenAI client with provider configuration
         try:
@@ -28,7 +27,7 @@ class OpenAILLMService(BaseLLMService):
 
             self.client = AsyncOpenAI(
                 api_key=provider_config["api_key"],
-                base_url=provider_config.get("
+                base_url=provider_config.get("api_base_url", "https://api.openai.com/v1"),
                 organization=provider_config.get("organization")
             )
 
@@ -44,7 +43,7 @@ class OpenAILLMService(BaseLLMService):
 
     def _create_bound_copy(self) -> 'OpenAILLMService':
         """Create a copy of this service for tool binding"""
-        bound_service = OpenAILLMService(self.
+        bound_service = OpenAILLMService(self.model_name, self.provider_name)
         bound_service._bound_tools = self._bound_tools.copy()
         return bound_service
 
@@ -67,6 +66,58 @@ class OpenAILLMService(BaseLLMService):
 
         return bound_service
 
+    async def astream(self, input_data: Union[str, List[Dict[str, str]], Any]) -> AsyncGenerator[str, None]:
+        """
+        True streaming method - yields tokens one by one as they arrive
+
+        Args:
+            input_data: Same as ainvoke
+
+        Yields:
+            Individual tokens as they arrive from the API
+        """
+        try:
+            # Use adapter manager to prepare messages
+            messages = self._prepare_messages(input_data)
+
+            # Prepare request kwargs
+            provider_config = self.get_provider_config()
+            kwargs = {
+                "model": self.model_name,
+                "messages": messages,
+                "temperature": provider_config.get("temperature", 0.7),
+                "max_tokens": provider_config.get("max_tokens", 1024),
+                "stream": True
+            }
+
+            # Add tools if bound using adapter manager
+            tool_schemas = await self._prepare_tools_for_request()
+            if tool_schemas:
+                kwargs["tools"] = tool_schemas
+                kwargs["tool_choice"] = "auto"
+
+            # Stream tokens one by one
+            content_chunks = []
+            try:
+                stream = await self.client.chat.completions.create(**kwargs)
+                async for chunk in stream:
+                    content = chunk.choices[0].delta.content
+                    if content:
+                        content_chunks.append(content)
+                        yield content
+
+                # Track usage after streaming is complete
+                full_content = "".join(content_chunks)
+                self._track_streaming_usage(messages, full_content)
+
+            except Exception as e:
+                logger.error(f"Error in streaming: {e}")
+                raise
+
+        except Exception as e:
+            logger.error(f"Error in astream: {e}")
+            raise
+
     async def ainvoke(self, input_data: Union[str, List[Dict[str, str]], Any]) -> Union[str, Any]:
         """Unified invoke method for all input types"""
         try:
@@ -74,11 +125,12 @@ class OpenAILLMService(BaseLLMService):
             messages = self._prepare_messages(input_data)
 
             # Prepare request kwargs
+            provider_config = self.get_provider_config()
             kwargs = {
                 "model": self.model_name,
                 "messages": messages,
-                "temperature":
-                "max_tokens":
+                "temperature": provider_config.get("temperature", 0.7),
+                "max_tokens": provider_config.get("max_tokens", 1024)
             }
 
             # Add tools if bound using adapter manager
@@ -89,23 +141,12 @@ class OpenAILLMService(BaseLLMService):
 
             # Handle streaming vs non-streaming
             if self.streaming:
-                #
+                # TRUE STREAMING MODE - collect all chunks from the stream
                 content_chunks = []
-                async for
-                    content_chunks.append(
+                async for token in self.astream(input_data):
+                    content_chunks.append(token)
                 content = "".join(content_chunks)
 
-                # Create a mock usage object for tracking
-                class MockUsage:
-                    def __init__(self):
-                        self.prompt_tokens = len(str(messages)) // 4  # Rough estimate
-                        self.completion_tokens = len(content) // 4  # Rough estimate
-                        self.total_tokens = self.prompt_tokens + self.completion_tokens
-
-                usage = MockUsage()
-                self._update_token_usage(usage)
-                self._track_billing(usage)
-
                 return self._format_response(content, input_data)
             else:
                 # Non-streaming mode
@@ -115,7 +156,7 @@ class OpenAILLMService(BaseLLMService):
                 # Update usage tracking
                 if response.usage:
                     self._update_token_usage(response.usage)
-                    self._track_billing(response.usage)
+                    await self._track_billing(response.usage)
 
                 # Handle tool calls if present - let adapter process the complete message
                 if message.tool_calls:
@@ -129,9 +170,28 @@ class OpenAILLMService(BaseLLMService):
             logger.error(f"Error in ainvoke: {e}")
             raise
 
+    def _track_streaming_usage(self, messages: List[Dict[str, str]], content: str):
+        """Track usage for streaming requests (estimated)"""
+        # Create a mock usage object for tracking
+        class MockUsage:
+            def __init__(self):
+                self.prompt_tokens = len(str(messages)) // 4  # Rough estimate
+                self.completion_tokens = len(content) // 4  # Rough estimate
+                self.total_tokens = self.prompt_tokens + self.completion_tokens
+
+        usage = MockUsage()
+        self._update_token_usage(usage)
+        # Fire and forget async tracking
+        import asyncio
+        try:
+            loop = asyncio.get_event_loop()
+            loop.create_task(self._track_billing(usage))
+        except:
+            # If no event loop, skip tracking
+            pass
 
     async def _stream_response(self, kwargs: Dict[str, Any]) -> AsyncGenerator[str, None]:
-        """Handle streaming responses"""
+        """Handle streaming responses - DEPRECATED: Use astream() instead"""
         kwargs["stream"] = True
 
         async def stream_generator():
@@ -162,16 +222,17 @@ class OpenAILLMService(BaseLLMService):
         self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
         self.total_token_usage["requests_count"] += 1
 
-    def _track_billing(self, usage):
+    async def _track_billing(self, usage):
         """Track billing information"""
-        self.
+        provider_config = self.get_provider_config()
+        await self._track_usage(
            service_type=ServiceType.LLM,
            operation="chat",
            input_tokens=usage.prompt_tokens,
            output_tokens=usage.completion_tokens,
            metadata={
-                "temperature":
-                "max_tokens":
+                "temperature": provider_config.get("temperature", 0.7),
+                "max_tokens": provider_config.get("max_tokens", 1024)
            }
        )
 
@@ -185,15 +246,57 @@ class OpenAILLMService(BaseLLMService):
 
     def get_model_info(self) -> Dict[str, Any]:
         """Get information about the current model"""
+        provider_config = self.get_provider_config()
         return {
             "name": self.model_name,
-            "max_tokens":
+            "max_tokens": provider_config.get("max_tokens", 1024),
             "supports_streaming": True,
             "supports_functions": True,
             "provider": "openai"
         }
 
+    async def chat(
+        self,
+        input_data: Union[str, List[Dict[str, str]], Any],
+        max_tokens: Optional[int] = None
+    ) -> Dict[str, Any]:
+        """
+        Chat method that wraps ainvoke for compatibility with base class
+
+        Args:
+            input_data: Input messages
+            max_tokens: Maximum tokens to generate
+
+        Returns:
+            Dict containing chat response
+        """
+        try:
+            # Call ainvoke and get the response
+            response = await self.ainvoke(input_data)
+
+            # Return in expected format
+            return {
+                "text": response if isinstance(response, str) else str(response),
+                "success": True,
+                "metadata": {
+                    "model": self.model_name,
+                    "provider": self.provider_name,
+                    "max_tokens": max_tokens or self.max_tokens
+                }
+            }
+        except Exception as e:
+            logger.error(f"Chat method failed: {e}")
+            return {
+                "text": "",
+                "success": False,
+                "error": str(e),
+                "metadata": {
+                    "model": self.model_name,
+                    "provider": self.provider_name
+                }
+            }
+
     async def close(self):
         """Close the backend client"""
         await self.client.close()
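Finally, a hedged sketch of the new chat() wrapper on OpenAILLMService; only the signature and return shape come from the diff, and it assumes an API key is already present in the provider config.

import asyncio
from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService

async def main():
    svc = OpenAILLMService(model_name="gpt-4o-mini")  # provider_name defaults to "openai"
    result = await svc.chat("Summarize the 0.3.6 changes in one sentence.")
    # chat() returns a dict with "text", "success", optional "error", and "metadata".
    print(result["text"] if result["success"] else result["error"])
    await svc.close()

asyncio.run(main())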