isa-model 0.3.4__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/config/__init__.py +9 -0
- isa_model/config/config_manager.py +213 -0
- isa_model/core/model_manager.py +5 -0
- isa_model/core/model_registry.py +39 -6
- isa_model/core/storage/supabase_storage.py +344 -0
- isa_model/core/vision_models_init.py +116 -0
- isa_model/deployment/cloud/__init__.py +9 -0
- isa_model/deployment/cloud/modal/__init__.py +10 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +612 -0
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +305 -0
- isa_model/inference/ai_factory.py +238 -14
- isa_model/inference/providers/modal_provider.py +109 -0
- isa_model/inference/providers/yyds_provider.py +108 -0
- isa_model/inference/services/__init__.py +2 -1
- isa_model/inference/services/base_service.py +0 -38
- isa_model/inference/services/llm/base_llm_service.py +32 -0
- isa_model/inference/services/llm/llm_adapter.py +40 -0
- isa_model/inference/services/llm/ollama_llm_service.py +104 -3
- isa_model/inference/services/llm/openai_llm_service.py +67 -15
- isa_model/inference/services/llm/yyds_llm_service.py +254 -0
- isa_model/inference/services/stacked/__init__.py +26 -0
- isa_model/inference/services/stacked/base_stacked_service.py +269 -0
- isa_model/inference/services/stacked/config.py +426 -0
- isa_model/inference/services/stacked/doc_analysis_service.py +640 -0
- isa_model/inference/services/stacked/flux_professional_service.py +579 -0
- isa_model/inference/services/stacked/ui_analysis_service.py +1319 -0
- isa_model/inference/services/vision/base_image_gen_service.py +0 -34
- isa_model/inference/services/vision/base_vision_service.py +46 -2
- isa_model/inference/services/vision/isA_vision_service.py +402 -0
- isa_model/inference/services/vision/openai_vision_service.py +151 -9
- isa_model/inference/services/vision/replicate_image_gen_service.py +166 -38
- isa_model/inference/services/vision/replicate_vision_service.py +693 -0
- isa_model/serving/__init__.py +19 -0
- isa_model/serving/api/__init__.py +10 -0
- isa_model/serving/api/fastapi_server.py +84 -0
- isa_model/serving/api/middleware/__init__.py +9 -0
- isa_model/serving/api/middleware/request_logger.py +88 -0
- isa_model/serving/api/routes/__init__.py +5 -0
- isa_model/serving/api/routes/health.py +82 -0
- isa_model/serving/api/routes/llm.py +19 -0
- isa_model/serving/api/routes/ui_analysis.py +223 -0
- isa_model/serving/api/routes/vision.py +19 -0
- isa_model/serving/api/schemas/__init__.py +17 -0
- isa_model/serving/api/schemas/common.py +33 -0
- isa_model/serving/api/schemas/ui_analysis.py +78 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.5.dist-info}/METADATA +1 -1
- {isa_model-0.3.4.dist-info → isa_model-0.3.5.dist-info}/RECORD +49 -17
- {isa_model-0.3.4.dist-info → isa_model-0.3.5.dist-info}/WHEEL +0 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,109 @@ File: isa_model/inference/providers/modal_provider.py
+"""
+Modal Provider
+
+Provider for ISA self-hosted Modal services
+No API keys needed since we deploy our own services
+"""
+
+import os
+import logging
+from typing import Dict, Any, Optional, List
+from .base_provider import BaseProvider
+from isa_model.inference.base import ModelType, Capability
+
+logger = logging.getLogger(__name__)
+
+class ModalProvider(BaseProvider):
+    """Provider for ISA Modal services"""
+
+    def __init__(self, config: Optional[Dict[str, Any]] = None):
+        super().__init__(config)
+        self.name = "modal"
+        self.base_url = "https://modal.com"  # Not used directly
+
+    def _load_provider_env_vars(self):
+        """Load Modal-specific environment variables"""
+        # Modal doesn't need API keys for deployed services
+        # But we can load Modal token if available
+        modal_token = os.getenv("MODAL_TOKEN_ID") or os.getenv("MODAL_TOKEN_SECRET")
+        if modal_token:
+            self.config["modal_token"] = modal_token
+
+        # Set default config
+        if "timeout" not in self.config:
+            self.config["timeout"] = 300
+        if "deployment_region" not in self.config:
+            self.config["deployment_region"] = "us-east-1"
+        if "gpu_type" not in self.config:
+            self.config["gpu_type"] = "T4"
+
+    def get_api_key(self) -> str:
+        """Modal services don't need API keys for deployed apps"""
+        return "modal-deployed-service"  # Placeholder
+
+    def get_base_url(self) -> str:
+        """Get base URL for Modal services"""
+        return self.base_url
+
+    def validate_credentials(self) -> bool:
+        """
+        Validate Modal credentials
+        For deployed services, we assume they're accessible
+        """
+        try:
+            # Check if Modal is available
+            import modal
+            return True
+        except ImportError:
+            logger.warning("Modal package not available")
+            return False
+
+    def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
+        """Get Modal provider capabilities"""
+        return {
+            ModelType.VISION: [
+                Capability.OBJECT_DETECTION,
+                Capability.IMAGE_ANALYSIS,
+                Capability.UI_DETECTION,
+                Capability.OCR,
+                Capability.DOCUMENT_ANALYSIS
+            ]
+        }
+
+    def get_models(self, model_type: ModelType) -> List[str]:
+        """Get available models for given type"""
+        if model_type == ModelType.VISION:
+            return [
+                "omniparser-v2.0",
+                "table-transformer-detection",
+                "table-transformer-structure-v1.1",
+                "paddleocr-3.0",
+                "yolov8"
+            ]
+        return []
+
+    def is_reasoning_model(self, model_name: str) -> bool:
+        """Check if the model is optimized for reasoning tasks"""
+        # Vision models are not reasoning models
+        return False
+
+    def get_default_config(self) -> Dict[str, Any]:
+        """Get default configuration for Modal services"""
+        return {
+            "timeout": 300,  # 5 minutes
+            "max_retries": 3,
+            "deployment_region": "us-east-1",
+            "gpu_type": "T4"
+        }
+
+    def get_billing_info(self) -> Dict[str, Any]:
+        """Get billing information for Modal services"""
+        return {
+            "provider": "modal",
+            "billing_model": "compute_usage",
+            "cost_per_hour": {
+                "T4": 0.60,
+                "A100": 4.00
+            },
+            "note": "Costs depend on actual usage time, scales to zero when not in use"
+        }
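
For orientation, a minimal usage sketch of the new provider (illustrative only; it assumes the wheel is installed as isa_model and that BaseProvider accepts an empty config):

from isa_model.inference.base import ModelType
from isa_model.inference.providers.modal_provider import ModalProvider

provider = ModalProvider()                      # defaults filled in by _load_provider_env_vars()
print(provider.validate_credentials())          # True only if the `modal` package is importable
print(provider.get_models(ModelType.VISION))    # ["omniparser-v2.0", ..., "yolov8"]
print(provider.get_billing_info()["cost_per_hour"]["T4"])  # 0.60, per the billing table above
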
@@ -0,0 +1,108 @@ File: isa_model/inference/providers/yyds_provider.py
+from isa_model.inference.providers.base_provider import BaseProvider
+from isa_model.inference.base import ModelType, Capability
+from typing import Dict, List, Any
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+class YydsProvider(BaseProvider):
+    """Provider for YYDS API with proper API key management"""
+
+    def __init__(self, config=None):
+        """Initialize the YYDS Provider with centralized config management"""
+        super().__init__(config)
+        self.name = "yyds"
+
+        logger.info(f"Initialized YydsProvider with URL: {self.config.get('base_url', 'https://api.yyds.com/v1')}")
+
+        if not self.has_valid_credentials():
+            logger.warning("YYDS API key not found. Set YYDS_API_KEY environment variable or pass api_key in config.")
+
+    def _load_provider_env_vars(self):
+        """Load YYDS-specific environment variables"""
+        # Set defaults first
+        defaults = {
+            "base_url": "https://api.yyds.com/v1",
+            "timeout": 60,
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "max_tokens": 1024
+        }
+
+        # Apply defaults only if not already set
+        for key, value in defaults.items():
+            if key not in self.config:
+                self.config[key] = value
+
+        # Load from environment variables (override config if present)
+        env_mappings = {
+            "api_key": "YYDS_API_KEY",
+            "base_url": "YYDS_API_BASE",
+            "organization": "YYDS_ORGANIZATION"
+        }
+
+        for config_key, env_var in env_mappings.items():
+            env_value = os.getenv(env_var)
+            if env_value:
+                self.config[config_key] = env_value
+
+    def _validate_config(self):
+        """Validate YYDS configuration"""
+        if not self.config.get("api_key"):
+            logger.debug("YYDS API key not set - some functionality may not work")
+
+    def get_model_pricing(self, model_name: str) -> Dict[str, float]:
+        """Get pricing information for a model - delegated to ModelManager"""
+        # Import here to avoid circular imports
+        from isa_model.core.model_manager import ModelManager
+        model_manager = ModelManager()
+        return model_manager.get_model_pricing("yyds", model_name)
+
+    def calculate_cost(self, model_name: str, input_tokens: int, output_tokens: int) -> float:
+        """Calculate cost for a request - delegated to ModelManager"""
+        # Import here to avoid circular imports
+        from isa_model.core.model_manager import ModelManager
+        model_manager = ModelManager()
+        return model_manager.calculate_cost("yyds", model_name, input_tokens, output_tokens)
+
+    def set_api_key(self, api_key: str):
+        """Set the API key after initialization"""
+        self.config["api_key"] = api_key
+        logger.info("YYDS API key updated")
+
+    def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
+        """Get provider capabilities by model type"""
+        return {
+            ModelType.LLM: [
+                Capability.CHAT,
+                Capability.COMPLETION
+            ]
+        }
+
+    def get_models(self, model_type: ModelType) -> List[str]:
+        """Get available models for given type"""
+        if model_type == ModelType.LLM:
+            return ["claude-sonnet-4-20250514", "claude-3-5-sonnet-20241022"]
+        else:
+            return []
+
+    def get_default_model(self, model_type: ModelType) -> str:
+        """Get default model for a given type"""
+        if model_type == ModelType.LLM:
+            return "claude-sonnet-4-20250514"
+        else:
+            return ""
+
+    def get_config(self) -> Dict[str, Any]:
+        """Get provider configuration"""
+        # Return a copy without sensitive information
+        config_copy = self.config.copy()
+        if "api_key" in config_copy:
+            config_copy["api_key"] = "***" if config_copy["api_key"] else ""
+        return config_copy
+
+    def is_reasoning_model(self, model_name: str) -> bool:
+        """Check if the model is optimized for reasoning tasks"""
+        reasoning_models = ["claude-sonnet-4", "claude-3-5-sonnet"]
+        return any(rm in model_name.lower() for rm in reasoning_models)
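
A short sketch of the configuration-resolution order the new provider implements: hard-coded defaults, then explicit config, then YYDS_* environment variables. The values below are placeholders, and the sketch assumes BaseProvider calls _load_provider_env_vars() during construction, as the defaults above imply:

import os
from isa_model.inference.providers.yyds_provider import YydsProvider

os.environ["YYDS_API_KEY"] = "sk-placeholder"                # placeholder, not a real key
os.environ["YYDS_API_BASE"] = "https://example.invalid/v1"   # overrides the default base_url

provider = YydsProvider(config={"temperature": 0.2})  # explicit config value is kept
print(provider.get_config())                          # api_key comes back masked as "***"
print(provider.is_reasoning_model("claude-sonnet-4-20250514"))  # True
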
@@ -5,7 +5,8 @@ File: isa_model/inference/services/__init__.py
 This module contains service implementations for different AI model types.
 """
 
-from .base_service import BaseService,
+from .base_service import BaseService, BaseEmbeddingService
+from .llm.base_llm_service import BaseLLMService
 
 __all__ = [
     "BaseService",
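
With the re-export change above, BaseLLMService now lives in llm/base_llm_service.py but stays importable from the services package; a quick sanity check (assuming the wheel is installed):

from isa_model.inference.services import BaseLLMService
from isa_model.inference.services.llm.base_llm_service import BaseLLMService as Canonical

assert BaseLLMService is Canonical  # both names resolve to the same class
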
@@ -52,44 +52,6 @@ class BaseService(ABC):
         yield
         return self
 
-class BaseLLMService(BaseService):
-    """Base class for LLM services"""
-
-    @abstractmethod
-    async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> T:
-        """Universal invocation method"""
-        pass
-
-    @abstractmethod
-    async def achat(self, messages: List[Dict[str, str]]) -> T:
-        """Chat completion method"""
-        pass
-
-    @abstractmethod
-    async def acompletion(self, prompt: str) -> T:
-        """Text completion method"""
-        pass
-
-    @abstractmethod
-    async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[T]:
-        """Generate multiple completions"""
-        pass
-
-    @abstractmethod
-    async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
-        """Stream chat responses"""
-        pass
-
-    @abstractmethod
-    def get_token_usage(self) -> Any:
-        """Get total token usage statistics"""
-        pass
-
-    @abstractmethod
-    def get_last_token_usage(self) -> Dict[str, int]:
-        """Get token usage from last request"""
-        pass
-
 class BaseEmbeddingService(BaseService):
     """Base class for embedding services"""
 
@@ -51,6 +51,22 @@ class BaseLLMService(BaseService):
         """Execute tool calls via the adapter manager"""
         return await self.adapter_manager.execute_tool(tool_name, arguments, self._tool_mappings)
 
+    @abstractmethod
+    async def astream(self, input_data: Union[str, List[Dict[str, str]], Any]) -> AsyncGenerator[str, None]:
+        """
+        True streaming method that yields tokens one by one as they arrive
+
+        Args:
+            input_data: Can be:
+                - str: Simple text prompt
+                - list: Message history like [{"role": "user", "content": "hello"}]
+                - Any: LangChain message objects or other formats
+
+        Yields:
+            Individual tokens as they arrive from the model
+        """
+        pass
+
     @abstractmethod
     async def ainvoke(self, input_data: Union[str, List[Dict[str, str]], Any]) -> Union[str, Any]:
         """
@@ -67,6 +83,22 @@ class BaseLLMService(BaseService):
         """
         pass
 
+    def stream(self, input_data: Union[str, List[Dict[str, str]], Any]):
+        """
+        Synchronous wrapper for astream - returns the async generator
+
+        Args:
+            input_data: Same as astream
+
+        Returns:
+            AsyncGenerator that yields tokens
+
+        Usage:
+            async for token in llm.stream("Hello"):
+                print(token, end="", flush=True)
+        """
+        return self.astream(input_data)
+
     def invoke(self, input_data: Union[str, List[Dict[str, str]], Any]) -> Union[str, Any]:
         """
         Synchronous wrapper for ainvoke
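
Together, the abstract astream() and the stream() wrapper give every LLM service the same token-by-token interface. A minimal consumption sketch, assuming llm is any concrete BaseLLMService subclass such as the OpenAI or Ollama implementations later in this diff:

import asyncio

async def demo(llm):
    # astream() yields tokens as they arrive from the model
    async for token in llm.astream("Write a haiku about diffs"):
        print(token, end="", flush=True)

    # stream() just returns the same async generator without awaiting it first
    async for token in llm.stream([{"role": "user", "content": "hello"}]):
        print(token, end="", flush=True)

# asyncio.run(demo(llm))  # supply a concrete service instance
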
@@ -120,7 +120,12 @@ class LangChainMessageAdapter:
                 msg_dict["role"] = "tool"
                 if hasattr(msg, 'tool_call_id'):
                     msg_dict["tool_call_id"] = msg.tool_call_id
+            elif msg.type == "function":  # Legacy function message
+                msg_dict["role"] = "function"
+                if hasattr(msg, 'name'):
+                    msg_dict["name"] = msg.name
            else:
+                # Unknown message type, default to user
                 msg_dict["role"] = "user"
 
             converted_messages.append(msg_dict)
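
The new elif branch maps LangChain's legacy function message onto the OpenAI-style {"role": "function", "name": ...} dict. A stand-alone illustration of just that branch, using a made-up stand-in object because the adapter's public entry point is not shown in this hunk:

class FakeFunctionMessage:            # stand-in for a legacy LangChain FunctionMessage
    type = "function"
    name = "get_weather"
    content = '{"temp_c": 21}'

msg = FakeFunctionMessage()
msg_dict = {"content": msg.content}
if msg.type == "function":            # the branch added above
    msg_dict["role"] = "function"
    if hasattr(msg, "name"):
        msg_dict["name"] = msg.name

print(msg_dict)  # {'content': '{"temp_c": 21}', 'role': 'function', 'name': 'get_weather'}
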
@@ -274,6 +279,40 @@ class DictToolAdapter:
         return f"Error: Cannot execute dict tool {tool_name} directly. Requires external executor."
 
 
+# ============= MCP tool adapter =============
+
+class MCPToolAdapter:
+    """MCP tool adapter - handles tools in the MCP protocol format"""
+
+    def __init__(self):
+        self.adapter_name = "mcp_tool"
+        self.priority = 7  # High priority, between LangChain and Dict
+
+    def can_handle(self, tool: Any) -> bool:
+        """Check whether a tool is in MCP tool format"""
+        return (isinstance(tool, dict) and
+                "name" in tool and
+                "description" in tool and
+                "inputSchema" in tool and
+                isinstance(tool["inputSchema"], dict))
+
+    def to_openai_schema(self, tool: Any) -> Dict[str, Any]:
+        """Convert an MCP tool to an OpenAI schema"""
+        return {
+            "type": "function",
+            "function": {
+                "name": tool["name"],
+                "description": tool["description"],
+                "parameters": tool.get("inputSchema", {"type": "object", "properties": {}})
+            }
+        }
+
+    async def execute_tool(self, tool: Any, arguments: Dict[str, Any]) -> Any:
+        """MCP tool execution is handled externally; return an informational message here"""
+        tool_name = tool["name"]
+        return f"MCP tool {tool_name} execution should be handled externally by MCP client"
+
+
 # ============= Python function adapter =============
 
 class PythonFunctionAdapter:
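
A sketch of the MCP-to-OpenAI conversion the new adapter performs, fed with a made-up MCP tool definition:

from isa_model.inference.services.llm.llm_adapter import MCPToolAdapter

mcp_tool = {
    "name": "search_docs",                       # made-up tool, for illustration only
    "description": "Search the documentation index",
    "inputSchema": {
        "type": "object",
        "properties": {"query": {"type": "string"}},
        "required": ["query"],
    },
}

adapter = MCPToolAdapter()
assert adapter.can_handle(mcp_tool)
print(adapter.to_openai_schema(mcp_tool))
# {"type": "function", "function": {"name": "search_docs",
#  "description": "Search the documentation index", "parameters": {...inputSchema...}}}
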
@@ -424,6 +463,7 @@ class AdapterManager:
         self.tool_adapters = [
             DictToolAdapter(),        # Highest priority - OpenAI-format tools
             LangChainToolAdapter(),   # Medium priority - LangChain tools
+            MCPToolAdapter(),         # High priority - MCP tools
             PythonFunctionAdapter()   # Lowest priority - Python functions
         ]
 
@@ -86,9 +86,15 @@ class OllamaLLMService(BaseLLMService):
             if tool_schemas:
                 payload["tools"] = tool_schemas
 
-            # Handle streaming
+            # Handle streaming vs non-streaming
             if self.streaming:
-
+                # TRUE STREAMING MODE - collect all chunks from the stream
+                content_chunks = []
+                async for token in self.astream(input_data):
+                    content_chunks.append(token)
+                content = "".join(content_chunks)
+
+                return self._format_response(content, input_data)
 
             # Regular request
             response = await self.client.post("/api/chat", json=payload)
@@ -190,6 +196,32 @@ class OllamaLLMService(BaseLLMService):
         # Get final response from the model
         return await self.ainvoke(messages)
 
+    def _track_streaming_usage(self, messages: List[Dict[str, str]], content: str):
+        """Track usage for streaming requests (estimated)"""
+        # Create a mock usage object for tracking
+        class MockUsage:
+            def __init__(self):
+                self.prompt_tokens = len(str(messages)) // 4  # Rough estimate
+                self.completion_tokens = len(content) // 4  # Rough estimate
+                self.total_tokens = self.prompt_tokens + self.completion_tokens
+
+        usage = MockUsage()
+        self._update_token_usage_from_mock(usage)
+
+    def _update_token_usage_from_mock(self, usage):
+        """Update token usage statistics from mock usage object"""
+        self.last_token_usage = {
+            "prompt_tokens": usage.prompt_tokens,
+            "completion_tokens": usage.completion_tokens,
+            "total_tokens": usage.total_tokens
+        }
+
+        # Update total usage
+        self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
+        self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
+        self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
+        self.total_token_usage["requests_count"] += 1
+
     def _update_token_usage(self, result: Dict[str, Any]):
         """Update token usage statistics"""
         self.last_token_usage = {
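
Streaming responses carry no server-reported usage, so the helpers above fall back to a rough four-characters-per-token estimate over the serialized messages and the concatenated streamed content. The arithmetic, in isolation:

messages = [{"role": "user", "content": "hello"}]
streamed_content = "x" * 400                     # pretend 400 characters were streamed

prompt_tokens = len(str(messages)) // 4          # len("[{'role': 'user', 'content': 'hello'}]") // 4 == 9
completion_tokens = len(streamed_content) // 4   # 400 // 4 == 100
print(prompt_tokens, completion_tokens, prompt_tokens + completion_tokens)  # 9 100 109
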
@@ -230,4 +262,73 @@ class OllamaLLMService(BaseLLMService):
             if not self.client.is_closed:
                 await self.client.aclose()
         except Exception as e:
-            logger.warning(f"Error closing Ollama client: {e}")
+            logger.warning(f"Error closing Ollama client: {e}")
+
+    async def astream(self, input_data: Union[str, List[Dict[str, str]], Any]) -> AsyncGenerator[str, None]:
+        """
+        True streaming method that yields tokens one by one as they arrive
+
+        Args:
+            input_data: Can be:
+                - str: Simple text prompt
+                - list: Message history like [{"role": "user", "content": "hello"}]
+                - Any: LangChain message objects or other formats
+
+        Yields:
+            Individual tokens as they arrive from the model
+        """
+        try:
+            # Ensure client is available
+            self._ensure_client()
+
+            # Use adapter manager to prepare messages
+            messages = self._prepare_messages(input_data)
+
+            # Prepare request parameters for streaming
+            payload = {
+                "model": self.model_name,
+                "messages": messages,
+                "stream": True,  # Force streaming for astream
+                "options": {
+                    "temperature": self.config.get("temperature", 0.7),
+                    "top_p": self.config.get("top_p", 0.9),
+                    "num_predict": self.config.get("max_tokens", 2048)
+                }
+            }
+
+            # Add tools if bound using adapter manager
+            tool_schemas = await self._prepare_tools_for_request()
+            if tool_schemas:
+                payload["tools"] = tool_schemas
+
+            # Stream tokens one by one
+            content_chunks = []
+            try:
+                async with self.client.stream("POST", "/api/chat", json=payload) as response:
+                    response.raise_for_status()
+                    async for line in response.aiter_lines():
+                        if line.strip():
+                            try:
+                                chunk = json.loads(line)
+                                if "message" in chunk and "content" in chunk["message"]:
+                                    content = chunk["message"]["content"]
+                                    if content:
+                                        content_chunks.append(content)
+                                        yield content
+                            except json.JSONDecodeError:
+                                continue
+
+                # Track usage after streaming is complete (estimated)
+                full_content = "".join(content_chunks)
+                self._track_streaming_usage(messages, full_content)
+
+            except Exception as e:
+                logger.error(f"Error in streaming: {e}")
+                raise
+
+        except httpx.RequestError as e:
+            logger.error(f"HTTP request error in astream: {e}")
+            raise
+        except Exception as e:
+            logger.error(f"Error in astream: {e}")
+            raise
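
The streaming loop above consumes Ollama's newline-delimited JSON chat responses and keeps only message.content from each chunk. The lines it expects look roughly like this (abridged; real chunks also carry timing and status fields):

import json

lines = [
    '{"model": "llama3", "message": {"role": "assistant", "content": "Hel"}, "done": false}',
    '{"model": "llama3", "message": {"role": "assistant", "content": "lo!"}, "done": true}',
]

for line in lines:
    chunk = json.loads(line)
    if "message" in chunk and "content" in chunk["message"]:
        print(chunk["message"]["content"], end="")   # prints "Hello!"
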
@@ -67,6 +67,57 @@ class OpenAILLMService(BaseLLMService):
 
         return bound_service
 
+    async def astream(self, input_data: Union[str, List[Dict[str, str]], Any]) -> AsyncGenerator[str, None]:
+        """
+        True streaming method - yields tokens one by one as they arrive
+
+        Args:
+            input_data: Same as ainvoke
+
+        Yields:
+            Individual tokens as they arrive from the API
+        """
+        try:
+            # Use adapter manager to prepare messages
+            messages = self._prepare_messages(input_data)
+
+            # Prepare request kwargs
+            kwargs = {
+                "model": self.model_name,
+                "messages": messages,
+                "temperature": self.config.get("temperature", 0.7),
+                "max_tokens": self.config.get("max_tokens", 1024),
+                "stream": True
+            }
+
+            # Add tools if bound using adapter manager
+            tool_schemas = await self._prepare_tools_for_request()
+            if tool_schemas:
+                kwargs["tools"] = tool_schemas
+                kwargs["tool_choice"] = "auto"
+
+            # Stream tokens one by one
+            content_chunks = []
+            try:
+                stream = await self.client.chat.completions.create(**kwargs)
+                async for chunk in stream:
+                    content = chunk.choices[0].delta.content
+                    if content:
+                        content_chunks.append(content)
+                        yield content
+
+                # Track usage after streaming is complete
+                full_content = "".join(content_chunks)
+                self._track_streaming_usage(messages, full_content)
+
+            except Exception as e:
+                logger.error(f"Error in streaming: {e}")
+                raise
+
+        except Exception as e:
+            logger.error(f"Error in astream: {e}")
+            raise
+
     async def ainvoke(self, input_data: Union[str, List[Dict[str, str]], Any]) -> Union[str, Any]:
         """Unified invoke method for all input types"""
         try:
@@ -89,23 +140,12 @@ class OpenAILLMService(BaseLLMService):
 
             # Handle streaming vs non-streaming
             if self.streaming:
-                #
+                # TRUE STREAMING MODE - collect all chunks from the stream
                 content_chunks = []
-                async for
-                    content_chunks.append(
+                async for token in self.astream(input_data):
+                    content_chunks.append(token)
                 content = "".join(content_chunks)
 
-                # Create a mock usage object for tracking
-                class MockUsage:
-                    def __init__(self):
-                        self.prompt_tokens = len(str(messages)) // 4  # Rough estimate
-                        self.completion_tokens = len(content) // 4  # Rough estimate
-                        self.total_tokens = self.prompt_tokens + self.completion_tokens
-
-                usage = MockUsage()
-                self._update_token_usage(usage)
-                self._track_billing(usage)
-
                 return self._format_response(content, input_data)
             else:
                 # Non-streaming mode
@@ -129,9 +169,21 @@ class OpenAILLMService(BaseLLMService):
             logger.error(f"Error in ainvoke: {e}")
             raise
 
+    def _track_streaming_usage(self, messages: List[Dict[str, str]], content: str):
+        """Track usage for streaming requests (estimated)"""
+        # Create a mock usage object for tracking
+        class MockUsage:
+            def __init__(self):
+                self.prompt_tokens = len(str(messages)) // 4  # Rough estimate
+                self.completion_tokens = len(content) // 4  # Rough estimate
+                self.total_tokens = self.prompt_tokens + self.completion_tokens
+
+        usage = MockUsage()
+        self._update_token_usage(usage)
+        self._track_billing(usage)
 
     async def _stream_response(self, kwargs: Dict[str, Any]) -> AsyncGenerator[str, None]:
-        """Handle streaming responses"""
+        """Handle streaming responses - DEPRECATED: Use astream() instead"""
         kwargs["stream"] = True
 
         async def stream_generator():
|