isa-model 0.3.4-py3-none-any.whl → 0.3.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. isa_model/__init__.py +30 -1
  2. isa_model/client.py +770 -0
  3. isa_model/core/config/__init__.py +16 -0
  4. isa_model/core/config/config_manager.py +514 -0
  5. isa_model/core/config.py +426 -0
  6. isa_model/core/models/model_billing_tracker.py +476 -0
  7. isa_model/core/models/model_manager.py +399 -0
  8. isa_model/core/models/model_repo.py +343 -0
  9. isa_model/core/pricing_manager.py +426 -0
  10. isa_model/core/services/__init__.py +19 -0
  11. isa_model/core/services/intelligent_model_selector.py +547 -0
  12. isa_model/core/types.py +291 -0
  13. isa_model/deployment/__init__.py +2 -0
  14. isa_model/deployment/cloud/__init__.py +9 -0
  15. isa_model/deployment/cloud/modal/__init__.py +10 -0
  16. isa_model/deployment/cloud/modal/isa_vision_doc_service.py +766 -0
  17. isa_model/deployment/cloud/modal/isa_vision_table_service.py +532 -0
  18. isa_model/deployment/cloud/modal/isa_vision_ui_service.py +406 -0
  19. isa_model/deployment/cloud/modal/register_models.py +321 -0
  20. isa_model/deployment/runtime/deployed_service.py +338 -0
  21. isa_model/deployment/services/__init__.py +9 -0
  22. isa_model/deployment/services/auto_deploy_vision_service.py +537 -0
  23. isa_model/deployment/services/model_service.py +332 -0
  24. isa_model/deployment/services/service_monitor.py +356 -0
  25. isa_model/deployment/services/service_registry.py +527 -0
  26. isa_model/eval/__init__.py +80 -44
  27. isa_model/eval/config/__init__.py +10 -0
  28. isa_model/eval/config/evaluation_config.py +108 -0
  29. isa_model/eval/evaluators/__init__.py +18 -0
  30. isa_model/eval/evaluators/base_evaluator.py +503 -0
  31. isa_model/eval/evaluators/llm_evaluator.py +472 -0
  32. isa_model/eval/factory.py +417 -709
  33. isa_model/eval/infrastructure/__init__.py +24 -0
  34. isa_model/eval/infrastructure/experiment_tracker.py +466 -0
  35. isa_model/eval/metrics.py +191 -21
  36. isa_model/inference/ai_factory.py +187 -387
  37. isa_model/inference/providers/modal_provider.py +109 -0
  38. isa_model/inference/providers/yyds_provider.py +108 -0
  39. isa_model/inference/services/__init__.py +2 -1
  40. isa_model/inference/services/audio/base_stt_service.py +65 -1
  41. isa_model/inference/services/audio/base_tts_service.py +75 -1
  42. isa_model/inference/services/audio/openai_stt_service.py +189 -151
  43. isa_model/inference/services/audio/openai_tts_service.py +12 -10
  44. isa_model/inference/services/audio/replicate_tts_service.py +61 -56
  45. isa_model/inference/services/base_service.py +55 -55
  46. isa_model/inference/services/embedding/base_embed_service.py +65 -1
  47. isa_model/inference/services/embedding/ollama_embed_service.py +103 -43
  48. isa_model/inference/services/embedding/openai_embed_service.py +8 -10
  49. isa_model/inference/services/helpers/stacked_config.py +148 -0
  50. isa_model/inference/services/img/__init__.py +18 -0
  51. isa_model/inference/services/{vision → img}/base_image_gen_service.py +80 -35
  52. isa_model/inference/services/img/flux_professional_service.py +603 -0
  53. isa_model/inference/services/img/helpers/base_stacked_service.py +274 -0
  54. isa_model/inference/services/{vision → img}/replicate_image_gen_service.py +210 -69
  55. isa_model/inference/services/llm/__init__.py +3 -3
  56. isa_model/inference/services/llm/base_llm_service.py +519 -35
  57. isa_model/inference/services/llm/{llm_adapter.py → helpers/llm_adapter.py} +40 -0
  58. isa_model/inference/services/llm/helpers/llm_prompts.py +258 -0
  59. isa_model/inference/services/llm/helpers/llm_utils.py +280 -0
  60. isa_model/inference/services/llm/ollama_llm_service.py +150 -15
  61. isa_model/inference/services/llm/openai_llm_service.py +134 -31
  62. isa_model/inference/services/llm/yyds_llm_service.py +255 -0
  63. isa_model/inference/services/vision/__init__.py +38 -4
  64. isa_model/inference/services/vision/base_vision_service.py +241 -96
  65. isa_model/inference/services/vision/disabled/isA_vision_service.py +500 -0
  66. isa_model/inference/services/vision/doc_analysis_service.py +640 -0
  67. isa_model/inference/services/vision/helpers/base_stacked_service.py +274 -0
  68. isa_model/inference/services/vision/helpers/image_utils.py +272 -3
  69. isa_model/inference/services/vision/helpers/vision_prompts.py +297 -0
  70. isa_model/inference/services/vision/openai_vision_service.py +109 -170
  71. isa_model/inference/services/vision/replicate_vision_service.py +508 -0
  72. isa_model/inference/services/vision/ui_analysis_service.py +823 -0
  73. isa_model/scripts/register_models.py +370 -0
  74. isa_model/scripts/register_models_with_embeddings.py +510 -0
  75. isa_model/serving/__init__.py +19 -0
  76. isa_model/serving/api/__init__.py +10 -0
  77. isa_model/serving/api/fastapi_server.py +89 -0
  78. isa_model/serving/api/middleware/__init__.py +9 -0
  79. isa_model/serving/api/middleware/request_logger.py +88 -0
  80. isa_model/serving/api/routes/__init__.py +5 -0
  81. isa_model/serving/api/routes/health.py +82 -0
  82. isa_model/serving/api/routes/llm.py +19 -0
  83. isa_model/serving/api/routes/ui_analysis.py +223 -0
  84. isa_model/serving/api/routes/unified.py +202 -0
  85. isa_model/serving/api/routes/vision.py +19 -0
  86. isa_model/serving/api/schemas/__init__.py +17 -0
  87. isa_model/serving/api/schemas/common.py +33 -0
  88. isa_model/serving/api/schemas/ui_analysis.py +78 -0
  89. {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/METADATA +4 -1
  90. isa_model-0.3.6.dist-info/RECORD +147 -0
  91. isa_model/core/model_manager.py +0 -208
  92. isa_model/core/model_registry.py +0 -342
  93. isa_model/inference/billing_tracker.py +0 -406
  94. isa_model/inference/services/llm/triton_llm_service.py +0 -481
  95. isa_model/inference/services/vision/ollama_vision_service.py +0 -194
  96. isa_model-0.3.4.dist-info/RECORD +0 -91
  97. /isa_model/core/{model_storage.py → models/model_storage.py} +0 -0
  98. /isa_model/inference/services/{vision → embedding}/helpers/text_splitter.py +0 -0
  99. {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/WHEEL +0 -0
  100. {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/top_level.txt +0 -0
isa_model/inference/services/llm/ollama_llm_service.py

@@ -3,19 +3,21 @@ import httpx
 import json
 from typing import Dict, Any, List, Union, AsyncGenerator, Optional, Callable
 from isa_model.inference.services.llm.base_llm_service import BaseLLMService
-from isa_model.inference.providers.base_provider import BaseProvider

 logger = logging.getLogger(__name__)

 class OllamaLLMService(BaseLLMService):
     """Ollama LLM service with unified invoke interface and proper adapter support"""

-    def __init__(self, provider: 'BaseProvider', model_name: str = "llama3.2:3b-instruct-fp16"):
-        super().__init__(provider, model_name)
+    def __init__(self, provider_name: str, model_name: str = "llama3.2:3b-instruct-fp16", **kwargs):
+        super().__init__(provider_name, model_name, **kwargs)
+
+        # Get configuration from centralized config manager
+        provider_config = self.get_provider_config()

         # Create HTTP client for Ollama API
-        base_url = self.config.get("base_url", "http://localhost:11434")
-        timeout = self.config.get("timeout", 60)
+        base_url = provider_config.get("base_url", "http://localhost:11434")
+        timeout = provider_config.get("timeout", 60)

         self.client = httpx.AsyncClient(
             base_url=base_url,
@@ -31,13 +33,14 @@ class OllamaLLMService(BaseLLMService):
     def _ensure_client(self):
         """Ensure the HTTP client is available and not closed"""
         if not hasattr(self, 'client') or not self.client or self.client.is_closed:
-            base_url = self.config.get("base_url", "http://localhost:11434")
-            timeout = self.config.get("timeout", 60)
+            provider_config = self.get_provider_config()
+            base_url = provider_config.get("base_url", "http://localhost:11434")
+            timeout = provider_config.get("timeout", 60)
             self.client = httpx.AsyncClient(base_url=base_url, timeout=timeout)

     def _create_bound_copy(self) -> 'OllamaLLMService':
         """Create a copy of this service for tool binding"""
-        bound_service = OllamaLLMService(self.provider, self.model_name)
+        bound_service = OllamaLLMService(self.provider_name, self.model_name)
         bound_service._bound_tools = self._bound_tools.copy()
         return bound_service

@@ -70,14 +73,15 @@ class OllamaLLMService(BaseLLMService):
         messages = self._prepare_messages(input_data)

         # Prepare request parameters
+        provider_config = self.get_provider_config()
         payload = {
             "model": self.model_name,
             "messages": messages,
             "stream": self.streaming,
             "options": {
-                "temperature": self.config.get("temperature", 0.7),
-                "top_p": self.config.get("top_p", 0.9),
-                "num_predict": self.config.get("max_tokens", 2048)
+                "temperature": provider_config.get("temperature", 0.7),
+                "top_p": provider_config.get("top_p", 0.9),
+                "num_predict": provider_config.get("max_tokens", 2048)
             }
         }

@@ -86,9 +90,15 @@
         if tool_schemas:
             payload["tools"] = tool_schemas

-        # Handle streaming
+        # Handle streaming vs non-streaming
         if self.streaming:
-            return self._stream_response(payload)
+            # TRUE STREAMING MODE - collect all chunks from the stream
+            content_chunks = []
+            async for token in self.astream(input_data):
+                content_chunks.append(token)
+            content = "".join(content_chunks)
+
+            return self._format_response(content, input_data)

         # Regular request
         response = await self.client.post("/api/chat", json=payload)
@@ -98,6 +108,7 @@
         # Update token usage if available
         if "eval_count" in result:
             self._update_token_usage(result)
+            await self._track_ollama_billing(result)

         # Handle tool calls if present - let adapter process the complete message
         message = result["message"]
@@ -190,6 +201,44 @@
         # Get final response from the model
         return await self.ainvoke(messages)

+    async def _track_streaming_usage(self, messages: List[Dict[str, str]], content: str):
+        """Track usage for streaming requests (estimated)"""
+        # Create a mock usage object for tracking
+        class MockUsage:
+            def __init__(self):
+                self.prompt_tokens = len(str(messages)) // 4  # Rough estimate
+                self.completion_tokens = len(content) // 4  # Rough estimate
+                self.total_tokens = self.prompt_tokens + self.completion_tokens
+
+        usage = MockUsage()
+        self._update_token_usage_from_mock(usage)
+
+        # Track billing
+        await self._track_llm_usage(
+            operation="chat_stream",
+            input_tokens=usage.prompt_tokens,
+            output_tokens=usage.completion_tokens,
+            metadata={
+                "model": self.model_name,
+                "provider": "ollama",
+                "streaming": True
+            }
+        )
+
+    def _update_token_usage_from_mock(self, usage):
+        """Update token usage statistics from mock usage object"""
+        self.last_token_usage = {
+            "prompt_tokens": usage.prompt_tokens,
+            "completion_tokens": usage.completion_tokens,
+            "total_tokens": usage.total_tokens
+        }
+
+        # Update total usage
+        self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
+        self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
+        self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
+        self.total_token_usage["requests_count"] += 1
+
     def _update_token_usage(self, result: Dict[str, Any]):
         """Update token usage statistics"""
         self.last_token_usage = {
@@ -204,6 +253,21 @@
         self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
         self.total_token_usage["requests_count"] += 1

+    async def _track_ollama_billing(self, result: Dict[str, Any]):
+        """Track billing information for Ollama requests"""
+        prompt_tokens = result.get("prompt_eval_count", 0)
+        completion_tokens = result.get("eval_count", 0)
+
+        await self._track_llm_usage(
+            operation="chat",
+            input_tokens=prompt_tokens,
+            output_tokens=completion_tokens,
+            metadata={
+                "model": self.model_name,
+                "provider": "ollama"
+            }
+        )
+
     def get_token_usage(self) -> Dict[str, Any]:
         """Get total token usage statistics"""
         return self.total_token_usage
@@ -214,9 +278,10 @@

     def get_model_info(self) -> Dict[str, Any]:
         """Get information about the current model"""
+        provider_config = self.get_provider_config()
         return {
             "name": self.model_name,
-            "max_tokens": self.config.get("max_tokens", 2048),
+            "max_tokens": provider_config.get("max_tokens", 2048),
             "supports_streaming": True,
             "supports_functions": True,
             "provider": "ollama"
@@ -230,4 +295,74 @@
             if not self.client.is_closed:
                 await self.client.aclose()
         except Exception as e:
-            logger.warning(f"Error closing Ollama client: {e}")
+            logger.warning(f"Error closing Ollama client: {e}")
+
+    async def astream(self, input_data: Union[str, List[Dict[str, str]], Any]) -> AsyncGenerator[str, None]:
+        """
+        True streaming method that yields tokens one by one as they arrive
+
+        Args:
+            input_data: Can be:
+                - str: Simple text prompt
+                - list: Message history like [{"role": "user", "content": "hello"}]
+                - Any: LangChain message objects or other formats
+
+        Yields:
+            Individual tokens as they arrive from the model
+        """
+        try:
+            # Ensure client is available
+            self._ensure_client()
+
+            # Use adapter manager to prepare messages
+            messages = self._prepare_messages(input_data)
+
+            # Prepare request parameters for streaming
+            provider_config = self.get_provider_config()
+            payload = {
+                "model": self.model_name,
+                "messages": messages,
+                "stream": True,  # Force streaming for astream
+                "options": {
+                    "temperature": provider_config.get("temperature", 0.7),
+                    "top_p": provider_config.get("top_p", 0.9),
+                    "num_predict": provider_config.get("max_tokens", 2048)
+                }
+            }
+
+            # Add tools if bound using adapter manager
+            tool_schemas = await self._prepare_tools_for_request()
+            if tool_schemas:
+                payload["tools"] = tool_schemas
+
+            # Stream tokens one by one
+            content_chunks = []
+            try:
+                async with self.client.stream("POST", "/api/chat", json=payload) as response:
+                    response.raise_for_status()
+                    async for line in response.aiter_lines():
+                        if line.strip():
+                            try:
+                                chunk = json.loads(line)
+                                if "message" in chunk and "content" in chunk["message"]:
+                                    content = chunk["message"]["content"]
+                                    if content:
+                                        content_chunks.append(content)
+                                        yield content
+                            except json.JSONDecodeError:
+                                continue
+
+                # Track usage after streaming is complete (estimated)
+                full_content = "".join(content_chunks)
+                await self._track_streaming_usage(messages, full_content)
+
+            except Exception as e:
+                logger.error(f"Error in streaming: {e}")
+                raise
+
+        except httpx.RequestError as e:
+            logger.error(f"HTTP request error in astream: {e}")
+            raise
+        except Exception as e:
+            logger.error(f"Error in astream: {e}")
+            raise
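
Taken together, these hunks change OllamaLLMService to take a provider name string instead of a BaseProvider object, route settings through the centralized config manager, and add a token-by-token astream() generator alongside ainvoke(). A minimal consumption sketch (not part of the package itself): it assumes a local Ollama server at the default base_url shown above and that the class is importable from the path listed in the file table.

import asyncio

from isa_model.inference.services.llm.ollama_llm_service import OllamaLLMService


async def main():
    # 0.3.6 constructor: provider name string, not a BaseProvider instance
    service = OllamaLLMService(provider_name="ollama", model_name="llama3.2:3b-instruct-fp16")
    try:
        # Non-streaming call: ainvoke() posts to /api/chat and returns the formatted response
        reply = await service.ainvoke("Say hello in one short sentence.")
        print(reply)

        # Streaming call: astream() yields tokens as Ollama emits them
        async for token in service.astream("Count from 1 to 5."):
            print(token, end="", flush=True)
        print()
    finally:
        await service.close()


asyncio.run(main())
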
isa_model/inference/services/llm/openai_llm_service.py

@@ -7,19 +7,18 @@ from typing import Dict, Any, List, Union, AsyncGenerator, Optional, Callable
 from openai import AsyncOpenAI

 from isa_model.inference.services.llm.base_llm_service import BaseLLMService
-from isa_model.inference.providers.base_provider import BaseProvider
-from isa_model.inference.billing_tracker import ServiceType
+from ....core.types import ServiceType

 logger = logging.getLogger(__name__)

 class OpenAILLMService(BaseLLMService):
     """OpenAI LLM service implementation with unified invoke interface"""

-    def __init__(self, provider: 'BaseProvider', model_name: str = "gpt-4.1-nano"):
-        super().__init__(provider, model_name)
+    def __init__(self, model_name: str = "gpt-4o-mini", provider_name: str = "openai", **kwargs):
+        super().__init__(provider_name, model_name, **kwargs)

-        # Get full configuration from provider (including sensitive data)
-        provider_config = provider.get_full_config()
+        # Get configuration from centralized config manager
+        provider_config = self.get_provider_config()

         # Initialize AsyncOpenAI client with provider configuration
         try:
@@ -28,7 +27,7 @@ class OpenAILLMService(BaseLLMService):

             self.client = AsyncOpenAI(
                 api_key=provider_config["api_key"],
-                base_url=provider_config.get("base_url", "https://api.openai.com/v1"),
+                base_url=provider_config.get("api_base_url", "https://api.openai.com/v1"),
                 organization=provider_config.get("organization")
             )

@@ -44,7 +43,7 @@ class OpenAILLMService(BaseLLMService):

     def _create_bound_copy(self) -> 'OpenAILLMService':
         """Create a copy of this service for tool binding"""
-        bound_service = OpenAILLMService(self.provider, self.model_name)
+        bound_service = OpenAILLMService(self.model_name, self.provider_name)
         bound_service._bound_tools = self._bound_tools.copy()
         return bound_service

@@ -67,6 +66,58 @@ class OpenAILLMService(BaseLLMService):

         return bound_service

+    async def astream(self, input_data: Union[str, List[Dict[str, str]], Any]) -> AsyncGenerator[str, None]:
+        """
+        True streaming method - yields tokens one by one as they arrive
+
+        Args:
+            input_data: Same as ainvoke
+
+        Yields:
+            Individual tokens as they arrive from the API
+        """
+        try:
+            # Use adapter manager to prepare messages
+            messages = self._prepare_messages(input_data)
+
+            # Prepare request kwargs
+            provider_config = self.get_provider_config()
+            kwargs = {
+                "model": self.model_name,
+                "messages": messages,
+                "temperature": provider_config.get("temperature", 0.7),
+                "max_tokens": provider_config.get("max_tokens", 1024),
+                "stream": True
+            }
+
+            # Add tools if bound using adapter manager
+            tool_schemas = await self._prepare_tools_for_request()
+            if tool_schemas:
+                kwargs["tools"] = tool_schemas
+                kwargs["tool_choice"] = "auto"
+
+            # Stream tokens one by one
+            content_chunks = []
+            try:
+                stream = await self.client.chat.completions.create(**kwargs)
+                async for chunk in stream:
+                    content = chunk.choices[0].delta.content
+                    if content:
+                        content_chunks.append(content)
+                        yield content
+
+                # Track usage after streaming is complete
+                full_content = "".join(content_chunks)
+                self._track_streaming_usage(messages, full_content)
+
+            except Exception as e:
+                logger.error(f"Error in streaming: {e}")
+                raise
+
+        except Exception as e:
+            logger.error(f"Error in astream: {e}")
+            raise
+
     async def ainvoke(self, input_data: Union[str, List[Dict[str, str]], Any]) -> Union[str, Any]:
         """Unified invoke method for all input types"""
         try:
@@ -74,11 +125,12 @@ class OpenAILLMService(BaseLLMService):
             messages = self._prepare_messages(input_data)

             # Prepare request kwargs
+            provider_config = self.get_provider_config()
             kwargs = {
                 "model": self.model_name,
                 "messages": messages,
-                "temperature": self.config.get("temperature", 0.7),
-                "max_tokens": self.config.get("max_tokens", 1024)
+                "temperature": provider_config.get("temperature", 0.7),
+                "max_tokens": provider_config.get("max_tokens", 1024)
             }

             # Add tools if bound using adapter manager
@@ -89,23 +141,12 @@ class OpenAILLMService(BaseLLMService):

             # Handle streaming vs non-streaming
             if self.streaming:
-                # Streaming mode - collect all chunks
+                # TRUE STREAMING MODE - collect all chunks from the stream
                 content_chunks = []
-                async for chunk in await self._stream_response(kwargs):
-                    content_chunks.append(chunk)
+                async for token in self.astream(input_data):
+                    content_chunks.append(token)
                 content = "".join(content_chunks)

-                # Create a mock usage object for tracking
-                class MockUsage:
-                    def __init__(self):
-                        self.prompt_tokens = len(str(messages)) // 4  # Rough estimate
-                        self.completion_tokens = len(content) // 4  # Rough estimate
-                        self.total_tokens = self.prompt_tokens + self.completion_tokens
-
-                usage = MockUsage()
-                self._update_token_usage(usage)
-                self._track_billing(usage)
-
                 return self._format_response(content, input_data)
             else:
                 # Non-streaming mode
@@ -115,7 +156,7 @@ class OpenAILLMService(BaseLLMService):
                 # Update usage tracking
                 if response.usage:
                     self._update_token_usage(response.usage)
-                    self._track_billing(response.usage)
+                    await self._track_billing(response.usage)

                 # Handle tool calls if present - let adapter process the complete message
                 if message.tool_calls:
@@ -129,9 +170,28 @@ class OpenAILLMService(BaseLLMService):
             logger.error(f"Error in ainvoke: {e}")
             raise

+    def _track_streaming_usage(self, messages: List[Dict[str, str]], content: str):
+        """Track usage for streaming requests (estimated)"""
+        # Create a mock usage object for tracking
+        class MockUsage:
+            def __init__(self):
+                self.prompt_tokens = len(str(messages)) // 4  # Rough estimate
+                self.completion_tokens = len(content) // 4  # Rough estimate
+                self.total_tokens = self.prompt_tokens + self.completion_tokens
+
+        usage = MockUsage()
+        self._update_token_usage(usage)
+        # Fire and forget async tracking
+        import asyncio
+        try:
+            loop = asyncio.get_event_loop()
+            loop.create_task(self._track_billing(usage))
+        except:
+            # If no event loop, skip tracking
+            pass

     async def _stream_response(self, kwargs: Dict[str, Any]) -> AsyncGenerator[str, None]:
-        """Handle streaming responses"""
+        """Handle streaming responses - DEPRECATED: Use astream() instead"""
         kwargs["stream"] = True

         async def stream_generator():
@@ -162,16 +222,17 @@ class OpenAILLMService(BaseLLMService):
         self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
         self.total_token_usage["requests_count"] += 1

-    def _track_billing(self, usage):
+    async def _track_billing(self, usage):
         """Track billing information"""
-        self._track_usage(
+        provider_config = self.get_provider_config()
+        await self._track_usage(
             service_type=ServiceType.LLM,
             operation="chat",
             input_tokens=usage.prompt_tokens,
             output_tokens=usage.completion_tokens,
             metadata={
-                "temperature": self.config.get("temperature", 0.7),
-                "max_tokens": self.config.get("max_tokens", 1024)
+                "temperature": provider_config.get("temperature", 0.7),
+                "max_tokens": provider_config.get("max_tokens", 1024)
             }
         )

@@ -185,15 +246,57 @@ class OpenAILLMService(BaseLLMService):

     def get_model_info(self) -> Dict[str, Any]:
         """Get information about the current model"""
+        provider_config = self.get_provider_config()
         return {
             "name": self.model_name,
-            "max_tokens": self.config.get("max_tokens", 1024),
+            "max_tokens": provider_config.get("max_tokens", 1024),
             "supports_streaming": True,
             "supports_functions": True,
             "provider": "openai"
         }


+    async def chat(
+        self,
+        input_data: Union[str, List[Dict[str, str]], Any],
+        max_tokens: Optional[int] = None
+    ) -> Dict[str, Any]:
+        """
+        Chat method that wraps ainvoke for compatibility with base class
+
+        Args:
+            input_data: Input messages
+            max_tokens: Maximum tokens to generate
+
+        Returns:
+            Dict containing chat response
+        """
+        try:
+            # Call ainvoke and get the response
+            response = await self.ainvoke(input_data)
+
+            # Return in expected format
+            return {
+                "text": response if isinstance(response, str) else str(response),
+                "success": True,
+                "metadata": {
+                    "model": self.model_name,
+                    "provider": self.provider_name,
+                    "max_tokens": max_tokens or self.max_tokens
+                }
+            }
+        except Exception as e:
+            logger.error(f"Chat method failed: {e}")
+            return {
+                "text": "",
+                "success": False,
+                "error": str(e),
+                "metadata": {
+                    "model": self.model_name,
+                    "provider": self.provider_name
+                }
+            }
+
     async def close(self):
         """Close the backend client"""
         await self.client.close()
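
The OpenAILLMService changes mirror the Ollama ones: the default model becomes gpt-4o-mini, the constructor takes the model name first and a provider name string second, billing tracking becomes async, and the new astream() and chat() methods sit alongside ainvoke(). A minimal usage sketch (not part of the package): it assumes the centralized config manager can already resolve an OpenAI api_key and that the class is importable from the path shown in the file list above.

import asyncio

from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService


async def main():
    # 0.3.6 constructor: model name first, provider name second; api_key/api_base_url
    # come from the centralized config manager rather than a BaseProvider object
    service = OpenAILLMService(model_name="gpt-4o-mini", provider_name="openai")
    try:
        # Token-by-token streaming via the new astream() generator
        async for token in service.astream("Write one sentence about unified diffs."):
            print(token, end="", flush=True)
        print()

        # chat() wraps ainvoke() and returns a dict with "text", "success", and "metadata" keys
        result = await service.chat("Summarize the 0.3.6 changes in one line.")
        if result["success"]:
            print(result["text"])
    finally:
        await service.close()


asyncio.run(main())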