isa-model 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. isa_model/config/__init__.py +9 -0
  2. isa_model/config/config_manager.py +213 -0
  3. isa_model/core/model_manager.py +5 -0
  4. isa_model/core/model_registry.py +39 -6
  5. isa_model/core/storage/supabase_storage.py +344 -0
  6. isa_model/core/vision_models_init.py +116 -0
  7. isa_model/deployment/cloud/__init__.py +9 -0
  8. isa_model/deployment/cloud/modal/__init__.py +10 -0
  9. isa_model/deployment/cloud/modal/isa_vision_doc_service.py +612 -0
  10. isa_model/deployment/cloud/modal/isa_vision_ui_service.py +305 -0
  11. isa_model/inference/ai_factory.py +238 -14
  12. isa_model/inference/providers/modal_provider.py +109 -0
  13. isa_model/inference/providers/yyds_provider.py +108 -0
  14. isa_model/inference/services/__init__.py +2 -1
  15. isa_model/inference/services/base_service.py +0 -38
  16. isa_model/inference/services/llm/base_llm_service.py +32 -0
  17. isa_model/inference/services/llm/llm_adapter.py +73 -3
  18. isa_model/inference/services/llm/ollama_llm_service.py +104 -3
  19. isa_model/inference/services/llm/openai_llm_service.py +67 -15
  20. isa_model/inference/services/llm/yyds_llm_service.py +254 -0
  21. isa_model/inference/services/stacked/__init__.py +26 -0
  22. isa_model/inference/services/stacked/base_stacked_service.py +269 -0
  23. isa_model/inference/services/stacked/config.py +426 -0
  24. isa_model/inference/services/stacked/doc_analysis_service.py +640 -0
  25. isa_model/inference/services/stacked/flux_professional_service.py +579 -0
  26. isa_model/inference/services/stacked/ui_analysis_service.py +1319 -0
  27. isa_model/inference/services/vision/base_image_gen_service.py +0 -34
  28. isa_model/inference/services/vision/base_vision_service.py +46 -2
  29. isa_model/inference/services/vision/isA_vision_service.py +402 -0
  30. isa_model/inference/services/vision/openai_vision_service.py +151 -9
  31. isa_model/inference/services/vision/replicate_image_gen_service.py +166 -38
  32. isa_model/inference/services/vision/replicate_vision_service.py +693 -0
  33. isa_model/serving/__init__.py +19 -0
  34. isa_model/serving/api/__init__.py +10 -0
  35. isa_model/serving/api/fastapi_server.py +84 -0
  36. isa_model/serving/api/middleware/__init__.py +9 -0
  37. isa_model/serving/api/middleware/request_logger.py +88 -0
  38. isa_model/serving/api/routes/__init__.py +5 -0
  39. isa_model/serving/api/routes/health.py +82 -0
  40. isa_model/serving/api/routes/llm.py +19 -0
  41. isa_model/serving/api/routes/ui_analysis.py +223 -0
  42. isa_model/serving/api/routes/vision.py +19 -0
  43. isa_model/serving/api/schemas/__init__.py +17 -0
  44. isa_model/serving/api/schemas/common.py +33 -0
  45. isa_model/serving/api/schemas/ui_analysis.py +78 -0
  46. {isa_model-0.3.3.dist-info → isa_model-0.3.5.dist-info}/METADATA +1 -1
  47. {isa_model-0.3.3.dist-info → isa_model-0.3.5.dist-info}/RECORD +49 -17
  48. {isa_model-0.3.3.dist-info → isa_model-0.3.5.dist-info}/WHEEL +0 -0
  49. {isa_model-0.3.3.dist-info → isa_model-0.3.5.dist-info}/top_level.txt +0 -0
File: isa_model/inference/providers/modal_provider.py (new file)
@@ -0,0 +1,109 @@
+ """
+ Modal Provider
+ 
+ Provider for ISA self-hosted Modal services
+ No API keys needed since we deploy our own services
+ """
+ 
+ import os
+ import logging
+ from typing import Dict, Any, Optional, List
+ from .base_provider import BaseProvider
+ from isa_model.inference.base import ModelType, Capability
+ 
+ logger = logging.getLogger(__name__)
+ 
+ class ModalProvider(BaseProvider):
+     """Provider for ISA Modal services"""
+ 
+     def __init__(self, config: Optional[Dict[str, Any]] = None):
+         super().__init__(config)
+         self.name = "modal"
+         self.base_url = "https://modal.com"  # Not used directly
+ 
+     def _load_provider_env_vars(self):
+         """Load Modal-specific environment variables"""
+         # Modal doesn't need API keys for deployed services
+         # But we can load Modal token if available
+         modal_token = os.getenv("MODAL_TOKEN_ID") or os.getenv("MODAL_TOKEN_SECRET")
+         if modal_token:
+             self.config["modal_token"] = modal_token
+ 
+         # Set default config
+         if "timeout" not in self.config:
+             self.config["timeout"] = 300
+         if "deployment_region" not in self.config:
+             self.config["deployment_region"] = "us-east-1"
+         if "gpu_type" not in self.config:
+             self.config["gpu_type"] = "T4"
+ 
+     def get_api_key(self) -> str:
+         """Modal services don't need API keys for deployed apps"""
+         return "modal-deployed-service"  # Placeholder
+ 
+     def get_base_url(self) -> str:
+         """Get base URL for Modal services"""
+         return self.base_url
+ 
+     def validate_credentials(self) -> bool:
+         """
+         Validate Modal credentials
+         For deployed services, we assume they're accessible
+         """
+         try:
+             # Check if Modal is available
+             import modal
+             return True
+         except ImportError:
+             logger.warning("Modal package not available")
+             return False
+ 
+     def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
+         """Get Modal provider capabilities"""
+         return {
+             ModelType.VISION: [
+                 Capability.OBJECT_DETECTION,
+                 Capability.IMAGE_ANALYSIS,
+                 Capability.UI_DETECTION,
+                 Capability.OCR,
+                 Capability.DOCUMENT_ANALYSIS
+             ]
+         }
+ 
+     def get_models(self, model_type: ModelType) -> List[str]:
+         """Get available models for given type"""
+         if model_type == ModelType.VISION:
+             return [
+                 "omniparser-v2.0",
+                 "table-transformer-detection",
+                 "table-transformer-structure-v1.1",
+                 "paddleocr-3.0",
+                 "yolov8"
+             ]
+         return []
+ 
+     def is_reasoning_model(self, model_name: str) -> bool:
+         """Check if the model is optimized for reasoning tasks"""
+         # Vision models are not reasoning models
+         return False
+ 
+     def get_default_config(self) -> Dict[str, Any]:
+         """Get default configuration for Modal services"""
+         return {
+             "timeout": 300,  # 5 minutes
+             "max_retries": 3,
+             "deployment_region": "us-east-1",
+             "gpu_type": "T4"
+         }
+ 
+     def get_billing_info(self) -> Dict[str, Any]:
+         """Get billing information for Modal services"""
+         return {
+             "provider": "modal",
+             "billing_model": "compute_usage",
+             "cost_per_hour": {
+                 "T4": 0.60,
+                 "A100": 4.00
+             },
+             "note": "Costs depend on actual usage time, scales to zero when not in use"
+         }
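
For orientation, a minimal usage sketch of the new provider (hypothetical driver code, not part of this diff; it assumes only the methods shown above):

    from isa_model.inference.providers.modal_provider import ModalProvider
    from isa_model.inference.base import ModelType

    provider = ModalProvider()                    # no API key needed
    if provider.validate_credentials():           # True when `modal` is importable
        print(provider.get_models(ModelType.VISION))         # ["omniparser-v2.0", ...]
        print(provider.get_billing_info()["cost_per_hour"])  # {"T4": 0.60, "A100": 4.0}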
File: isa_model/inference/providers/yyds_provider.py (new file)
@@ -0,0 +1,108 @@
+ from isa_model.inference.providers.base_provider import BaseProvider
+ from isa_model.inference.base import ModelType, Capability
+ from typing import Dict, List, Any
+ import logging
+ import os
+ 
+ logger = logging.getLogger(__name__)
+ 
+ class YydsProvider(BaseProvider):
+     """Provider for YYDS API with proper API key management"""
+ 
+     def __init__(self, config=None):
+         """Initialize the YYDS Provider with centralized config management"""
+         super().__init__(config)
+         self.name = "yyds"
+ 
+         logger.info(f"Initialized YydsProvider with URL: {self.config.get('base_url', 'https://api.yyds.com/v1')}")
+ 
+         if not self.has_valid_credentials():
+             logger.warning("YYDS API key not found. Set YYDS_API_KEY environment variable or pass api_key in config.")
+ 
+     def _load_provider_env_vars(self):
+         """Load YYDS-specific environment variables"""
+         # Set defaults first
+         defaults = {
+             "base_url": "https://api.yyds.com/v1",
+             "timeout": 60,
+             "temperature": 0.7,
+             "top_p": 0.9,
+             "max_tokens": 1024
+         }
+ 
+         # Apply defaults only if not already set
+         for key, value in defaults.items():
+             if key not in self.config:
+                 self.config[key] = value
+ 
+         # Load from environment variables (override config if present)
+         env_mappings = {
+             "api_key": "YYDS_API_KEY",
+             "base_url": "YYDS_API_BASE",
+             "organization": "YYDS_ORGANIZATION"
+         }
+ 
+         for config_key, env_var in env_mappings.items():
+             env_value = os.getenv(env_var)
+             if env_value:
+                 self.config[config_key] = env_value
+ 
+     def _validate_config(self):
+         """Validate YYDS configuration"""
+         if not self.config.get("api_key"):
+             logger.debug("YYDS API key not set - some functionality may not work")
+ 
+     def get_model_pricing(self, model_name: str) -> Dict[str, float]:
+         """Get pricing information for a model - delegated to ModelManager"""
+         # Import here to avoid circular imports
+         from isa_model.core.model_manager import ModelManager
+         model_manager = ModelManager()
+         return model_manager.get_model_pricing("yyds", model_name)
+ 
+     def calculate_cost(self, model_name: str, input_tokens: int, output_tokens: int) -> float:
+         """Calculate cost for a request - delegated to ModelManager"""
+         # Import here to avoid circular imports
+         from isa_model.core.model_manager import ModelManager
+         model_manager = ModelManager()
+         return model_manager.calculate_cost("yyds", model_name, input_tokens, output_tokens)
+ 
+     def set_api_key(self, api_key: str):
+         """Set the API key after initialization"""
+         self.config["api_key"] = api_key
+         logger.info("YYDS API key updated")
+ 
+     def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
+         """Get provider capabilities by model type"""
+         return {
+             ModelType.LLM: [
+                 Capability.CHAT,
+                 Capability.COMPLETION
+             ]
+         }
+ 
+     def get_models(self, model_type: ModelType) -> List[str]:
+         """Get available models for given type"""
+         if model_type == ModelType.LLM:
+             return ["claude-sonnet-4-20250514", "claude-3-5-sonnet-20241022"]
+         else:
+             return []
+ 
+     def get_default_model(self, model_type: ModelType) -> str:
+         """Get default model for a given type"""
+         if model_type == ModelType.LLM:
+             return "claude-sonnet-4-20250514"
+         else:
+             return ""
+ 
+     def get_config(self) -> Dict[str, Any]:
+         """Get provider configuration"""
+         # Return a copy without sensitive information
+         config_copy = self.config.copy()
+         if "api_key" in config_copy:
+             config_copy["api_key"] = "***" if config_copy["api_key"] else ""
+         return config_copy
+ 
+     def is_reasoning_model(self, model_name: str) -> bool:
+         """Check if the model is optimized for reasoning tasks"""
+         reasoning_models = ["claude-sonnet-4", "claude-3-5-sonnet"]
+         return any(rm in model_name.lower() for rm in reasoning_models)
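
A similar sketch for the YYDS provider, assuming only the environment variables and methods defined above (the key value is a placeholder):

    import os
    os.environ["YYDS_API_KEY"] = "sk-placeholder"  # hypothetical key, for illustration

    from isa_model.inference.providers.yyds_provider import YydsProvider
    from isa_model.inference.base import ModelType

    provider = YydsProvider()
    print(provider.get_default_model(ModelType.LLM))  # "claude-sonnet-4-20250514"
    print(provider.get_config()["api_key"])           # "***" (masked copy)
    print(provider.is_reasoning_model("claude-3-5-sonnet-20241022"))  # True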
@@ -5,7 +5,8 @@ File: isa_model/inference/services/__init__.py
  This module contains service implementations for different AI model types.
  """
  
- from .base_service import BaseService, BaseLLMService, BaseEmbeddingService
+ from .base_service import BaseService, BaseEmbeddingService
+ from .llm.base_llm_service import BaseLLMService
  
  __all__ = [
      "BaseService",
File: isa_model/inference/services/base_service.py
@@ -52,44 +52,6 @@ class BaseService(ABC):
          yield
          return self
  
- class BaseLLMService(BaseService):
-     """Base class for LLM services"""
- 
-     @abstractmethod
-     async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> T:
-         """Universal invocation method"""
-         pass
- 
-     @abstractmethod
-     async def achat(self, messages: List[Dict[str, str]]) -> T:
-         """Chat completion method"""
-         pass
- 
-     @abstractmethod
-     async def acompletion(self, prompt: str) -> T:
-         """Text completion method"""
-         pass
- 
-     @abstractmethod
-     async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[T]:
-         """Generate multiple completions"""
-         pass
- 
-     @abstractmethod
-     async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
-         """Stream chat responses"""
-         pass
- 
-     @abstractmethod
-     def get_token_usage(self) -> Any:
-         """Get total token usage statistics"""
-         pass
- 
-     @abstractmethod
-     def get_last_token_usage(self) -> Dict[str, int]:
-         """Get token usage from last request"""
-         pass
- 
  class BaseEmbeddingService(BaseService):
      """Base class for embedding services"""
  
File: isa_model/inference/services/llm/base_llm_service.py
@@ -51,6 +51,22 @@ class BaseLLMService(BaseService):
          """Execute a tool call via the adapter manager"""
          return await self.adapter_manager.execute_tool(tool_name, arguments, self._tool_mappings)
  
+     @abstractmethod
+     async def astream(self, input_data: Union[str, List[Dict[str, str]], Any]) -> AsyncGenerator[str, None]:
+         """
+         True streaming method that yields tokens one by one as they arrive
+ 
+         Args:
+             input_data: Can be:
+                 - str: Simple text prompt
+                 - list: Message history like [{"role": "user", "content": "hello"}]
+                 - Any: LangChain message objects or other formats
+ 
+         Yields:
+             Individual tokens as they arrive from the model
+         """
+         pass
+ 
      @abstractmethod
      async def ainvoke(self, input_data: Union[str, List[Dict[str, str]], Any]) -> Union[str, Any]:
          """
@@ -67,6 +83,22 @@
          """
          pass
  
+     def stream(self, input_data: Union[str, List[Dict[str, str]], Any]):
+         """
+         Synchronous wrapper for astream - returns the async generator
+ 
+         Args:
+             input_data: Same as astream
+ 
+         Returns:
+             AsyncGenerator that yields tokens
+ 
+         Usage:
+             async for token in llm.stream("Hello"):
+                 print(token, end="", flush=True)
+         """
+         return self.astream(input_data)
+ 
      def invoke(self, input_data: Union[str, List[Dict[str, str]], Any]) -> Union[str, Any]:
          """
          Synchronous wrapper for ainvoke
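
Because stream() simply returns the async generator from astream(), the caller drives it with `async for`; a minimal consumer sketch (assumes `llm` is any concrete BaseLLMService from this package):

    import asyncio

    async def demo(llm):
        # stream() and astream() are interchangeable in this loop
        async for token in llm.stream("Hello"):
            print(token, end="", flush=True)
        print()

    # asyncio.run(demo(llm))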
File: isa_model/inference/services/llm/llm_adapter.py
@@ -120,7 +120,12 @@ class LangChainMessageAdapter:
              msg_dict["role"] = "tool"
              if hasattr(msg, 'tool_call_id'):
                  msg_dict["tool_call_id"] = msg.tool_call_id
+         elif msg.type == "function":  # Legacy function message
+             msg_dict["role"] = "function"
+             if hasattr(msg, 'name'):
+                 msg_dict["name"] = msg.name
          else:
+             # Unknown message type, default to user
              msg_dict["role"] = "user"
  
          converted_messages.append(msg_dict)
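
The new branch covers LangChain's legacy FunctionMessage type; a sketch of the conversion it performs (assumes langchain_core is installed):

    from langchain_core.messages import FunctionMessage

    msg = FunctionMessage(name="get_weather", content='{"temp": 21}')
    # msg.type == "function", so the adapter now produces:
    #   {"role": "function", "name": "get_weather", "content": '{"temp": 21}'}
    # instead of silently falling through to {"role": "user", ...}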
@@ -245,6 +250,69 @@ class LangChainToolAdapter:
              return f"Error executing LangChain tool {tool.name}: {str(e)}"
  
  
+ # ============= OpenAI-format tool adapter =============
+ 
+ class DictToolAdapter:
+     """Adapter for OpenAI-format tool dicts"""
+ 
+     def __init__(self):
+         self.adapter_name = "dict_tool"
+         self.priority = 9  # Higher priority than Python functions
+ 
+     def can_handle(self, tool: Any) -> bool:
+         """Check whether this is an OpenAI-format tool dict"""
+         return (isinstance(tool, dict) and
+                 tool.get("type") == "function" and
+                 "function" in tool and
+                 isinstance(tool["function"], dict) and
+                 "name" in tool["function"])
+ 
+     def to_openai_schema(self, tool: Any) -> Dict[str, Any]:
+         """Tool is already in OpenAI format; return it unchanged"""
+         return tool
+ 
+     async def execute_tool(self, tool: Any, arguments: Dict[str, Any]) -> Any:
+         """Execute an OpenAI-format tool (usually requires an external executor)"""
+         # OpenAI-format tool dicts cannot be executed directly;
+         # return an indication and let the caller handle it
+         tool_name = tool["function"]["name"]
+         return f"Error: Cannot execute dict tool {tool_name} directly. Requires external executor."
+ 
+ 
+ # ============= MCP tool adapter =============
+ 
+ class MCPToolAdapter:
+     """MCP tool adapter - handles the MCP protocol tool format"""
+ 
+     def __init__(self):
+         self.adapter_name = "mcp_tool"
+         self.priority = 7  # High priority, between LangChain and Dict
+ 
+     def can_handle(self, tool: Any) -> bool:
+         """Check whether this is an MCP tool format"""
+         return (isinstance(tool, dict) and
+                 "name" in tool and
+                 "description" in tool and
+                 "inputSchema" in tool and
+                 isinstance(tool["inputSchema"], dict))
+ 
+     def to_openai_schema(self, tool: Any) -> Dict[str, Any]:
+         """Convert an MCP tool to an OpenAI schema"""
+         return {
+             "type": "function",
+             "function": {
+                 "name": tool["name"],
+                 "description": tool["description"],
+                 "parameters": tool.get("inputSchema", {"type": "object", "properties": {}})
+             }
+         }
+ 
+     async def execute_tool(self, tool: Any, arguments: Dict[str, Any]) -> Any:
+         """MCP tool execution is handled externally; return an informative message"""
+         tool_name = tool["name"]
+         return f"MCP tool {tool_name} execution should be handled externally by MCP client"
+ 
+ 
  # ============= Python function adapter =============
  
  class PythonFunctionAdapter:
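
For reference, an MCP-style tool definition and the OpenAI schema the new adapter produces (the tool itself is illustrative, not from this diff):

    from isa_model.inference.services.llm.llm_adapter import MCPToolAdapter  # path per this diff

    mcp_tool = {
        "name": "search_docs",
        "description": "Search the documentation index",
        "inputSchema": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        },
    }

    adapter = MCPToolAdapter()
    assert adapter.can_handle(mcp_tool)
    schema = adapter.to_openai_schema(mcp_tool)
    # {"type": "function",
    #  "function": {"name": "search_docs",
    #               "description": "Search the documentation index",
    #               "parameters": {... the inputSchema ...}}}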
@@ -391,10 +459,12 @@ class AdapterManager:
              StandardMessageAdapter()   # Fallback adapter
          ]
  
-         # Tool adapters
+         # Tool adapters (ordered by priority)
          self.tool_adapters = [
-             LangChainToolAdapter(),
-             PythonFunctionAdapter()
+             DictToolAdapter(),         # Highest priority - OpenAI-format tools
+             LangChainToolAdapter(),    # Medium priority - LangChain tools
+             MCPToolAdapter(),          # High priority - MCP tools
+             PythonFunctionAdapter()    # Lowest priority - Python functions
          ]
  
      def register_custom_adapter(self, adapter, adapter_type: str):
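
The dispatch loop itself is outside this hunk, but with a priority-ordered list the usual pattern is first-match-wins, which is why DictToolAdapter must precede PythonFunctionAdapter (a sketch under that assumption, not the package's actual code):

    def find_tool_adapter(tool, adapters):
        # First adapter whose can_handle() accepts the tool wins,
        # so list order encodes priority.
        for adapter in adapters:
            if adapter.can_handle(tool):
                return adapter
        return None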
File: isa_model/inference/services/llm/ollama_llm_service.py
@@ -86,9 +86,15 @@ class OllamaLLMService(BaseLLMService):
          if tool_schemas:
              payload["tools"] = tool_schemas
  
-         # Handle streaming
+         # Handle streaming vs non-streaming
          if self.streaming:
-             return self._stream_response(payload)
+             # TRUE STREAMING MODE - collect all chunks from the stream
+             content_chunks = []
+             async for token in self.astream(input_data):
+                 content_chunks.append(token)
+             content = "".join(content_chunks)
+ 
+             return self._format_response(content, input_data)
  
          # Regular request
          response = await self.client.post("/api/chat", json=payload)
@@ -190,6 +196,32 @@
          # Get final response from the model
          return await self.ainvoke(messages)
  
+     def _track_streaming_usage(self, messages: List[Dict[str, str]], content: str):
+         """Track usage for streaming requests (estimated)"""
+         # Create a mock usage object for tracking
+         class MockUsage:
+             def __init__(self):
+                 self.prompt_tokens = len(str(messages)) // 4  # Rough estimate
+                 self.completion_tokens = len(content) // 4  # Rough estimate
+                 self.total_tokens = self.prompt_tokens + self.completion_tokens
+ 
+         usage = MockUsage()
+         self._update_token_usage_from_mock(usage)
+ 
+     def _update_token_usage_from_mock(self, usage):
+         """Update token usage statistics from mock usage object"""
+         self.last_token_usage = {
+             "prompt_tokens": usage.prompt_tokens,
+             "completion_tokens": usage.completion_tokens,
+             "total_tokens": usage.total_tokens
+         }
+ 
+         # Update total usage
+         self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
+         self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
+         self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
+         self.total_token_usage["requests_count"] += 1
+ 
      def _update_token_usage(self, result: Dict[str, Any]):
          """Update token usage statistics"""
          self.last_token_usage = {
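
Exact token counts are unavailable on this streaming path, so usage is estimated at roughly four characters per token; worked through on a small example (illustrative values):

    messages = [{"role": "user", "content": "Tell me a joke"}]
    content = "Why did the tokenizer stop? It ran out of context."  # 50 chars

    prompt_tokens = len(str(messages)) // 4           # 47 chars // 4 == 11
    completion_tokens = len(content) // 4             # 50 chars // 4 == 12
    total_tokens = prompt_tokens + completion_tokens  # 23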
@@ -230,4 +262,73 @@
              if not self.client.is_closed:
                  await self.client.aclose()
          except Exception as e:
-             logger.warning(f"Error closing Ollama client: {e}")
+             logger.warning(f"Error closing Ollama client: {e}")
+ 
+     async def astream(self, input_data: Union[str, List[Dict[str, str]], Any]) -> AsyncGenerator[str, None]:
+         """
+         True streaming method that yields tokens one by one as they arrive
+ 
+         Args:
+             input_data: Can be:
+                 - str: Simple text prompt
+                 - list: Message history like [{"role": "user", "content": "hello"}]
+                 - Any: LangChain message objects or other formats
+ 
+         Yields:
+             Individual tokens as they arrive from the model
+         """
+         try:
+             # Ensure client is available
+             self._ensure_client()
+ 
+             # Use adapter manager to prepare messages
+             messages = self._prepare_messages(input_data)
+ 
+             # Prepare request parameters for streaming
+             payload = {
+                 "model": self.model_name,
+                 "messages": messages,
+                 "stream": True,  # Force streaming for astream
+                 "options": {
+                     "temperature": self.config.get("temperature", 0.7),
+                     "top_p": self.config.get("top_p", 0.9),
+                     "num_predict": self.config.get("max_tokens", 2048)
+                 }
+             }
+ 
+             # Add tools if bound using adapter manager
+             tool_schemas = await self._prepare_tools_for_request()
+             if tool_schemas:
+                 payload["tools"] = tool_schemas
+ 
+             # Stream tokens one by one
+             content_chunks = []
+             try:
+                 async with self.client.stream("POST", "/api/chat", json=payload) as response:
+                     response.raise_for_status()
+                     async for line in response.aiter_lines():
+                         if line.strip():
+                             try:
+                                 chunk = json.loads(line)
+                                 if "message" in chunk and "content" in chunk["message"]:
+                                     content = chunk["message"]["content"]
+                                     if content:
+                                         content_chunks.append(content)
+                                         yield content
+                             except json.JSONDecodeError:
+                                 continue
+ 
+                 # Track usage after streaming is complete (estimated)
+                 full_content = "".join(content_chunks)
+                 self._track_streaming_usage(messages, full_content)
+ 
+             except Exception as e:
+                 logger.error(f"Error in streaming: {e}")
+                 raise
+ 
+         except httpx.RequestError as e:
+             logger.error(f"HTTP request error in astream: {e}")
+             raise
+         except Exception as e:
+             logger.error(f"Error in astream: {e}")
+             raise
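
End-to-end, the new Ollama streaming path can be exercised like this (a hypothetical sketch: constructor arguments and a running local Ollama daemon are assumed, and neither is defined in this hunk):

    import asyncio
    from isa_model.inference.services.llm.ollama_llm_service import OllamaLLMService

    async def main(service: OllamaLLMService):
        async for token in service.astream("Explain unified diffs in one line"):
            print(token, end="", flush=True)
        # Once the generator is exhausted, estimated usage has been recorded:
        print(service.last_token_usage)

    # asyncio.run(main(service))  # service construction elided; see ai_factory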
File: isa_model/inference/services/llm/openai_llm_service.py
@@ -67,6 +67,57 @@ class OpenAILLMService(BaseLLMService):
  
          return bound_service
  
+     async def astream(self, input_data: Union[str, List[Dict[str, str]], Any]) -> AsyncGenerator[str, None]:
+         """
+         True streaming method - yields tokens one by one as they arrive
+ 
+         Args:
+             input_data: Same as ainvoke
+ 
+         Yields:
+             Individual tokens as they arrive from the API
+         """
+         try:
+             # Use adapter manager to prepare messages
+             messages = self._prepare_messages(input_data)
+ 
+             # Prepare request kwargs
+             kwargs = {
+                 "model": self.model_name,
+                 "messages": messages,
+                 "temperature": self.config.get("temperature", 0.7),
+                 "max_tokens": self.config.get("max_tokens", 1024),
+                 "stream": True
+             }
+ 
+             # Add tools if bound using adapter manager
+             tool_schemas = await self._prepare_tools_for_request()
+             if tool_schemas:
+                 kwargs["tools"] = tool_schemas
+                 kwargs["tool_choice"] = "auto"
+ 
+             # Stream tokens one by one
+             content_chunks = []
+             try:
+                 stream = await self.client.chat.completions.create(**kwargs)
+                 async for chunk in stream:
+                     content = chunk.choices[0].delta.content
+                     if content:
+                         content_chunks.append(content)
+                         yield content
+ 
+                 # Track usage after streaming is complete
+                 full_content = "".join(content_chunks)
+                 self._track_streaming_usage(messages, full_content)
+ 
+             except Exception as e:
+                 logger.error(f"Error in streaming: {e}")
+                 raise
+ 
+         except Exception as e:
+             logger.error(f"Error in astream: {e}")
+             raise
+ 
      async def ainvoke(self, input_data: Union[str, List[Dict[str, str]], Any]) -> Union[str, Any]:
          """Unified invoke method for all input types"""
          try:
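
The OpenAI service now exposes the same contract, so callers can swap backends without changing the consuming loop (sketch; client construction is outside this hunk):

    async def collect(service) -> str:
        # Streaming and joining yields the same text that ainvoke()
        # returns in streaming mode, since ainvoke() now delegates to astream().
        chunks = []
        async for token in service.astream([{"role": "user", "content": "hi"}]):
            chunks.append(token)
        return "".join(chunks)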
@@ -89,23 +140,12 @@
  
          # Handle streaming vs non-streaming
          if self.streaming:
-             # Streaming mode - collect all chunks
+             # TRUE STREAMING MODE - collect all chunks from the stream
              content_chunks = []
-             async for chunk in await self._stream_response(kwargs):
-                 content_chunks.append(chunk)
+             async for token in self.astream(input_data):
+                 content_chunks.append(token)
              content = "".join(content_chunks)
  
-             # Create a mock usage object for tracking
-             class MockUsage:
-                 def __init__(self):
-                     self.prompt_tokens = len(str(messages)) // 4  # Rough estimate
-                     self.completion_tokens = len(content) // 4  # Rough estimate
-                     self.total_tokens = self.prompt_tokens + self.completion_tokens
- 
-             usage = MockUsage()
-             self._update_token_usage(usage)
-             self._track_billing(usage)
- 
              return self._format_response(content, input_data)
          else:
              # Non-streaming mode
@@ -129,9 +169,21 @@
              logger.error(f"Error in ainvoke: {e}")
              raise
  
+     def _track_streaming_usage(self, messages: List[Dict[str, str]], content: str):
+         """Track usage for streaming requests (estimated)"""
+         # Create a mock usage object for tracking
+         class MockUsage:
+             def __init__(self):
+                 self.prompt_tokens = len(str(messages)) // 4  # Rough estimate
+                 self.completion_tokens = len(content) // 4  # Rough estimate
+                 self.total_tokens = self.prompt_tokens + self.completion_tokens
+ 
+         usage = MockUsage()
+         self._update_token_usage(usage)
+         self._track_billing(usage)
  
      async def _stream_response(self, kwargs: Dict[str, Any]) -> AsyncGenerator[str, None]:
-         """Handle streaming responses"""
+         """Handle streaming responses - DEPRECATED: Use astream() instead"""
          kwargs["stream"] = True
  
          async def stream_generator():