isa-model 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/config/__init__.py +9 -0
- isa_model/config/config_manager.py +213 -0
- isa_model/core/model_manager.py +5 -0
- isa_model/core/model_registry.py +39 -6
- isa_model/core/storage/supabase_storage.py +344 -0
- isa_model/core/vision_models_init.py +116 -0
- isa_model/deployment/cloud/__init__.py +9 -0
- isa_model/deployment/cloud/modal/__init__.py +10 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +612 -0
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +305 -0
- isa_model/inference/ai_factory.py +238 -14
- isa_model/inference/providers/modal_provider.py +109 -0
- isa_model/inference/providers/yyds_provider.py +108 -0
- isa_model/inference/services/__init__.py +2 -1
- isa_model/inference/services/base_service.py +0 -38
- isa_model/inference/services/llm/base_llm_service.py +32 -0
- isa_model/inference/services/llm/llm_adapter.py +73 -3
- isa_model/inference/services/llm/ollama_llm_service.py +104 -3
- isa_model/inference/services/llm/openai_llm_service.py +67 -15
- isa_model/inference/services/llm/yyds_llm_service.py +254 -0
- isa_model/inference/services/stacked/__init__.py +26 -0
- isa_model/inference/services/stacked/base_stacked_service.py +269 -0
- isa_model/inference/services/stacked/config.py +426 -0
- isa_model/inference/services/stacked/doc_analysis_service.py +640 -0
- isa_model/inference/services/stacked/flux_professional_service.py +579 -0
- isa_model/inference/services/stacked/ui_analysis_service.py +1319 -0
- isa_model/inference/services/vision/base_image_gen_service.py +0 -34
- isa_model/inference/services/vision/base_vision_service.py +46 -2
- isa_model/inference/services/vision/isA_vision_service.py +402 -0
- isa_model/inference/services/vision/openai_vision_service.py +151 -9
- isa_model/inference/services/vision/replicate_image_gen_service.py +166 -38
- isa_model/inference/services/vision/replicate_vision_service.py +693 -0
- isa_model/serving/__init__.py +19 -0
- isa_model/serving/api/__init__.py +10 -0
- isa_model/serving/api/fastapi_server.py +84 -0
- isa_model/serving/api/middleware/__init__.py +9 -0
- isa_model/serving/api/middleware/request_logger.py +88 -0
- isa_model/serving/api/routes/__init__.py +5 -0
- isa_model/serving/api/routes/health.py +82 -0
- isa_model/serving/api/routes/llm.py +19 -0
- isa_model/serving/api/routes/ui_analysis.py +223 -0
- isa_model/serving/api/routes/vision.py +19 -0
- isa_model/serving/api/schemas/__init__.py +17 -0
- isa_model/serving/api/schemas/common.py +33 -0
- isa_model/serving/api/schemas/ui_analysis.py +78 -0
- {isa_model-0.3.3.dist-info → isa_model-0.3.5.dist-info}/METADATA +1 -1
- {isa_model-0.3.3.dist-info → isa_model-0.3.5.dist-info}/RECORD +49 -17
- {isa_model-0.3.3.dist-info → isa_model-0.3.5.dist-info}/WHEEL +0 -0
- {isa_model-0.3.3.dist-info → isa_model-0.3.5.dist-info}/top_level.txt +0 -0
isa_model/inference/services/llm/yyds_llm_service.py
@@ -0,0 +1,254 @@
+import logging
+from typing import Dict, Any, List, Union, AsyncGenerator
+
+# Uses the OpenAI-compatible async client
+from openai import AsyncOpenAI
+
+from isa_model.inference.services.llm.base_llm_service import BaseLLMService
+from isa_model.inference.providers.base_provider import BaseProvider
+from isa_model.inference.billing_tracker import ServiceType
+
+logger = logging.getLogger(__name__)
+
+class YydsLLMService(BaseLLMService):
+    """YYDS LLM service implementation with unified invoke interface"""
+
+    def __init__(self, provider: 'BaseProvider', model_name: str = "claude-sonnet-4-20250514"):
+        super().__init__(provider, model_name)
+
+        # Get full configuration from provider (including sensitive data)
+        provider_config = provider.get_full_config()
+
+        # Initialize AsyncOpenAI client with provider configuration
+        try:
+            if not provider_config.get("api_key"):
+                raise ValueError("YYDS API key not found in provider configuration")
+
+            self.client = AsyncOpenAI(
+                api_key=provider_config["api_key"],
+                base_url=provider_config.get("base_url", "https://api.yyds.com/v1"),
+                organization=provider_config.get("organization")
+            )
+
+            logger.info(f"Initialized YydsLLMService with model {self.model_name} and endpoint {self.client.base_url}")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize YYDS client: {e}")
+            raise ValueError(f"Failed to initialize YYDS client. Check your API key configuration: {e}") from e
+
+        self.last_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+        self.total_token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "requests_count": 0}
+
+
+    def _create_bound_copy(self) -> 'YydsLLMService':
+        """Create a copy of this service for tool binding"""
+        bound_service = YydsLLMService(self.provider, self.model_name)
+        bound_service._bound_tools = self._bound_tools.copy()
+        return bound_service
+
+    def bind_tools(self, tools: List[Any], **kwargs) -> 'YydsLLMService':
+        """
+        Bind tools to this LLM service for function calling
+
+        Args:
+            tools: List of tools (functions, dicts, or LangChain tools)
+            **kwargs: Additional arguments for tool binding
+
+        Returns:
+            New LLM service instance with tools bound
+        """
+        # Create a copy of this service
+        bound_service = self._create_bound_copy()
+
+        # Use base class method to bind tools
+        bound_service._bound_tools = tools
+
+        return bound_service
+
+    async def astream(self, input_data: Union[str, List[Dict[str, str]], Any]) -> AsyncGenerator[str, None]:
+        """
+        True streaming method - yields tokens one by one as they arrive
+
+        Args:
+            input_data: Same as ainvoke
+
+        Yields:
+            Individual tokens as they arrive from the API
+        """
+        try:
+            # Use adapter manager to prepare messages
+            messages = self._prepare_messages(input_data)
+
+            # Prepare request kwargs
+            kwargs = {
+                "model": self.model_name,
+                "messages": messages,
+                "temperature": self.config.get("temperature", 0.7),
+                "max_tokens": self.config.get("max_tokens", 1024),
+                "stream": True
+            }
+
+            # Add tools if bound using adapter manager
+            tool_schemas = await self._prepare_tools_for_request()
+            if tool_schemas:
+                kwargs["tools"] = tool_schemas
+                kwargs["tool_choice"] = "auto"
+
+            # Stream tokens one by one
+            content_chunks = []
+            try:
+                stream = await self.client.chat.completions.create(**kwargs)
+                async for chunk in stream:
+                    content = chunk.choices[0].delta.content
+                    if content:
+                        content_chunks.append(content)
+                        yield content
+
+                # Track usage after streaming is complete
+                full_content = "".join(content_chunks)
+                self._track_streaming_usage(messages, full_content)
+
+            except Exception as e:
+                logger.error(f"Error in streaming: {e}")
+                raise
+
+        except Exception as e:
+            logger.error(f"Error in astream: {e}")
+            raise
+
+    async def ainvoke(self, input_data: Union[str, List[Dict[str, str]], Any]) -> Union[str, Any]:
+        """Unified invoke method for all input types"""
+        try:
+            # Use adapter manager to prepare messages
+            messages = self._prepare_messages(input_data)
+
+            # Prepare request kwargs
+            kwargs = {
+                "model": self.model_name,
+                "messages": messages,
+                "temperature": self.config.get("temperature", 0.7),
+                "max_tokens": self.config.get("max_tokens", 1024)
+            }
+
+            # Add tools if bound using adapter manager
+            tool_schemas = await self._prepare_tools_for_request()
+            if tool_schemas:
+                kwargs["tools"] = tool_schemas
+                kwargs["tool_choice"] = "auto"
+
+            # Handle streaming vs non-streaming
+            if self.streaming:
+                # TRUE STREAMING MODE - collect all chunks from the stream
+                content_chunks = []
+                async for token in self.astream(input_data):
+                    content_chunks.append(token)
+                content = "".join(content_chunks)
+
+                return self._format_response(content, input_data)
+            else:
+                # Non-streaming mode
+                response = await self.client.chat.completions.create(**kwargs)
+                message = response.choices[0].message
+
+                # Update usage tracking
+                if response.usage:
+                    self._update_token_usage(response.usage)
+                    self._track_billing(response.usage)
+
+                # Handle tool calls if present - let adapter process the complete message
+                if message.tool_calls:
+                    # Pass the complete message object to adapter for proper tool_calls handling
+                    return self._format_response(message, input_data)
+
+                # Return appropriate format based on input type
+                return self._format_response(message.content or "", input_data)
+
+        except Exception as e:
+            logger.error(f"Error in ainvoke: {e}")
+            raise
+
+    def _track_streaming_usage(self, messages: List[Dict[str, str]], content: str):
+        """Track usage for streaming requests (estimated)"""
+        # Create a mock usage object for tracking
+        class MockUsage:
+            def __init__(self):
+                self.prompt_tokens = len(str(messages)) // 4  # Rough estimate
+                self.completion_tokens = len(content) // 4  # Rough estimate
+                self.total_tokens = self.prompt_tokens + self.completion_tokens
+
+        usage = MockUsage()
+        self._update_token_usage(usage)
+        self._track_billing(usage)
+
+    async def _stream_response(self, kwargs: Dict[str, Any]) -> AsyncGenerator[str, None]:
+        """Handle streaming responses - DEPRECATED: Use astream() instead"""
+        kwargs["stream"] = True
+
+        async def stream_generator():
+            try:
+                stream = await self.client.chat.completions.create(**kwargs)
+                async for chunk in stream:
+                    content = chunk.choices[0].delta.content
+                    if content:
+                        yield content
+            except Exception as e:
+                logger.error(f"Error in streaming: {e}")
+                raise
+
+        return stream_generator()
+
+
+    def _update_token_usage(self, usage):
+        """Update token usage statistics"""
+        self.last_token_usage = {
+            "prompt_tokens": usage.prompt_tokens,
+            "completion_tokens": usage.completion_tokens,
+            "total_tokens": usage.total_tokens
+        }
+
+        # Update total usage
+        self.total_token_usage["prompt_tokens"] += self.last_token_usage["prompt_tokens"]
+        self.total_token_usage["completion_tokens"] += self.last_token_usage["completion_tokens"]
+        self.total_token_usage["total_tokens"] += self.last_token_usage["total_tokens"]
+        self.total_token_usage["requests_count"] += 1
+
+    def _track_billing(self, usage):
+        """Track billing information"""
+        self._track_usage(
+            service_type=ServiceType.LLM,
+            operation="chat",
+            input_tokens=usage.prompt_tokens,
+            output_tokens=usage.completion_tokens,
+            metadata={
+                "temperature": self.config.get("temperature", 0.7),
+                "max_tokens": self.config.get("max_tokens", 1024)
+            }
+        )
+
+    def get_token_usage(self) -> Dict[str, Any]:
+        """Get total token usage statistics"""
+        return self.total_token_usage
+
+    def get_last_token_usage(self) -> Dict[str, int]:
+        """Get token usage from last request"""
+        return self.last_token_usage
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """Get information about the current model"""
+        return {
+            "name": self.model_name,
+            "max_tokens": self.config.get("max_tokens", 1024),
+            "supports_streaming": True,
+            "supports_functions": True,
+            "provider": "yyds",
+            "pricing": {
+                "input_tokens_per_1k": 0.0045,
+                "output_tokens_per_1k": 0.0225,
+                "currency": "USD"
+            }
+        }
+
+
+    async def close(self):
+        """Close the backend client"""
+        await self.client.close()
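For orientation, here is a minimal usage sketch for the new service. It is not code from the package: the `provider` argument is assumed to be a configured provider instance from the new yyds_provider.py (whose contents are not shown in this diff) exposing get_full_config() with an api_key/base_url, and the prompts are invented.

# Hypothetical sketch -- constructing `provider` is an assumption; only the
# calls made on YydsLLMService below appear in the hunk above.
import asyncio
from isa_model.inference.services.llm.yyds_llm_service import YydsLLMService

async def demo(provider) -> None:
    llm = YydsLLMService(provider, model_name="claude-sonnet-4-20250514")

    # Unified one-shot call: a plain string in, a formatted response out.
    reply = await llm.ainvoke("Summarize dependency injection in one line.")
    print(reply)

    # True streaming: tokens are yielded as they arrive from the API.
    async for token in llm.astream("Write a haiku about packaging."):
        print(token, end="", flush=True)

    print(llm.get_token_usage())  # cumulative prompt/completion/request totals
    await llm.close()

# asyncio.run(demo(provider)) once a provider has been constructed.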
isa_model/inference/services/stacked/__init__.py
@@ -0,0 +1,26 @@
+"""
+Stacked Services - Multi-model orchestration services
+
+This module provides stacked services that combine multiple AI models
+in sequence or parallel to solve complex tasks.
+"""
+
+from .base_stacked_service import BaseStackedService, LayerConfig, LayerType, LayerResult
+from .ui_analysis_service import UIAnalysisService
+from .doc_analysis_service import DocAnalysisStackedService
+from .flux_professional_service import FluxProfessionalService
+from .config import ConfigManager, StackedServiceConfig, WorkflowType, get_ui_analysis_config
+
+__all__ = [
+    'BaseStackedService',
+    'LayerConfig',
+    'LayerType',
+    'LayerResult',
+    'UIAnalysisService',
+    'DocAnalysisStackedService',
+    'FluxProfessionalService',
+    'ConfigManager',
+    'StackedServiceConfig',
+    'WorkflowType',
+    'get_ui_analysis_config'
+]
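These package-level re-exports give callers a stable import path for the orchestration primitives, so downstream code does not need to know which submodule defines each name. A trivial sketch of the resulting surface (names taken from the __all__ list above; nothing else is assumed):

from isa_model.inference.services.stacked import (
    BaseStackedService,
    LayerConfig,
    LayerType,
    LayerResult,
    UIAnalysisService,
)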
isa_model/inference/services/stacked/base_stacked_service.py
@@ -0,0 +1,269 @@
+"""
+Base Stacked Service for orchestrating multiple AI models
+"""
+
+from abc import ABC, abstractmethod
+from typing import Dict, Any, List, Optional, Union, Callable
+import time
+import asyncio
+import logging
+from dataclasses import dataclass
+from enum import Enum
+
+logger = logging.getLogger(__name__)
+
+class LayerType(Enum):
+    """Types of processing layers"""
+    INTELLIGENCE = "intelligence"        # High-level understanding
+    DETECTION = "detection"              # Element/object detection
+    CLASSIFICATION = "classification"    # Detailed classification
+    VALIDATION = "validation"            # Result validation
+    TRANSFORMATION = "transformation"    # Data transformation
+    GENERATION = "generation"            # Content generation
+    ENHANCEMENT = "enhancement"          # Quality enhancement
+    CONTROL = "control"                  # Precise control/refinement
+    UPSCALING = "upscaling"              # Resolution enhancement
+
+@dataclass
+class LayerConfig:
+    """Configuration for a processing layer"""
+    name: str
+    layer_type: LayerType
+    service_type: str       # e.g., 'vision', 'llm'
+    model_name: str
+    parameters: Dict[str, Any]
+    depends_on: List[str]   # Layer dependencies
+    timeout: float = 30.0
+    retry_count: int = 1
+    fallback_enabled: bool = True
+
+@dataclass
+class LayerResult:
+    """Result from a processing layer"""
+    layer_name: str
+    success: bool
+    data: Any
+    metadata: Dict[str, Any]
+    execution_time: float
+    error: Optional[str] = None
+
+class BaseStackedService(ABC):
+    """
+    Base class for stacked services that orchestrate multiple AI models
+    """
+
+    def __init__(self, ai_factory, service_name: str):
+        self.ai_factory = ai_factory
+        self.service_name = service_name
+        self.layers: List[LayerConfig] = []
+        self.services: Dict[str, Any] = {}
+        self.results: Dict[str, LayerResult] = {}
+
+    def add_layer(self, config: LayerConfig):
+        """Add a processing layer to the stack"""
+        self.layers.append(config)
+        logger.info(f"Added layer {config.name} ({config.layer_type.value}) to {self.service_name}")
+
+    async def initialize_services(self):
+        """Initialize all required services"""
+        for layer in self.layers:
+            service_key = f"{layer.service_type}_{layer.model_name}"
+
+            if service_key not in self.services:
+                if layer.service_type == 'vision':
+                    if layer.model_name == "default":
+                        # Use the default vision service
+                        service = self.ai_factory.get_vision()
+                    elif layer.model_name == "omniparser":
+                        # Use the Replicate omniparser
+                        service = self.ai_factory.get_vision(model_name="omniparser", provider="replicate")
+                    else:
+                        # Any other explicitly specified model
+                        service = self.ai_factory.get_vision(model_name=layer.model_name)
+                elif layer.service_type == 'llm':
+                    if layer.model_name == "default":
+                        service = self.ai_factory.get_llm()
+                    else:
+                        service = self.ai_factory.get_llm(model_name=layer.model_name)
+                elif layer.service_type == 'image_gen':
+                    if layer.model_name == "default":
+                        service = self.ai_factory.get_image_gen()
+                    else:
+                        service = self.ai_factory.get_image_gen(model_name=layer.model_name)
+                else:
+                    raise ValueError(f"Unsupported service type: {layer.service_type}")
+
+                self.services[service_key] = service
+                logger.info(f"Initialized {service_key} service")
+
+    async def execute_layer(self, layer: LayerConfig, context: Dict[str, Any]) -> LayerResult:
+        """Execute a single layer"""
+        start_time = time.time()
+
+        try:
+            # Check dependencies
+            for dep in layer.depends_on:
+                if dep not in self.results or not self.results[dep].success:
+                    raise ValueError(f"Dependency {dep} failed or not executed")
+
+            # Get the service
+            service_key = f"{layer.service_type}_{layer.model_name}"
+            service = self.services[service_key]
+
+            # Execute layer with timeout
+            data = await asyncio.wait_for(
+                self.execute_layer_logic(layer, service, context),
+                timeout=layer.timeout
+            )
+
+            execution_time = time.time() - start_time
+
+            result = LayerResult(
+                layer_name=layer.name,
+                success=True,
+                data=data,
+                metadata={
+                    "layer_type": layer.layer_type.value,
+                    "model": layer.model_name,
+                    "parameters": layer.parameters
+                },
+                execution_time=execution_time
+            )
+
+            logger.info(f"Layer {layer.name} completed in {execution_time:.2f}s")
+            return result
+
+        except Exception as e:
+            execution_time = time.time() - start_time
+            error_msg = str(e)
+
+            logger.error(f"Layer {layer.name} failed after {execution_time:.2f}s: {error_msg}")
+
+            result = LayerResult(
+                layer_name=layer.name,
+                success=False,
+                data=None,
+                metadata={
+                    "layer_type": layer.layer_type.value,
+                    "model": layer.model_name,
+                    "parameters": layer.parameters
+                },
+                execution_time=execution_time,
+                error=error_msg
+            )
+
+            # Try fallback if enabled
+            if layer.fallback_enabled:
+                fallback_result = await self.execute_fallback(layer, context, error_msg)
+                if fallback_result:
+                    result.data = fallback_result
+                    result.success = True
+                    result.error = f"Fallback used: {error_msg}"
+
+            return result
+
+    @abstractmethod
+    async def execute_layer_logic(self, layer: LayerConfig, service: Any, context: Dict[str, Any]) -> Any:
+        """Execute the specific logic for a layer - to be implemented by subclasses"""
+        pass
+
+    async def execute_fallback(self, layer: LayerConfig, context: Dict[str, Any], error: str) -> Optional[Any]:
+        """Execute fallback logic for a failed layer - can be overridden by subclasses"""
+        return None
+
+    async def invoke(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Invoke the entire stack of layers"""
+        logger.info(f"Starting {self.service_name} stack invocation")
+        stack_start_time = time.time()
+
+        # Initialize services if not done
+        if not self.services:
+            await self.initialize_services()
+
+        # Clear previous results
+        self.results.clear()
+
+        # Build execution order based on dependencies
+        execution_order = self._build_execution_order()
+
+        # Execute layers in order
+        context = {"input": input_data, "results": self.results}
+
+        for layer in execution_order:
+            result = await self.execute_layer(layer, context)
+            self.results[layer.name] = result
+
+            # Update context with result
+            context["results"] = self.results
+
+            # Stop if critical layer fails
+            if not result.success and not layer.fallback_enabled:
+                logger.error(f"Critical layer {layer.name} failed, stopping execution")
+                break
+
+        total_time = time.time() - stack_start_time
+
+        # Generate final result
+        final_result = {
+            "service": self.service_name,
+            "success": all(r.success for r in self.results.values()),
+            "total_execution_time": total_time,
+            "layer_results": {name: result for name, result in self.results.items()},
+            "final_output": self.generate_final_output(self.results)
+        }
+
+        logger.info(f"{self.service_name} stack invocation completed in {total_time:.2f}s")
+        return final_result

+    def _build_execution_order(self) -> List[LayerConfig]:
+        """Build execution order based on dependencies"""
+        # Simple topological sort
+        ordered = []
+        remaining = self.layers.copy()
+
+        while remaining:
+            # Find layers with no unmet dependencies
+            ready = []
+            for layer in remaining:
+                deps_met = all(dep in [l.name for l in ordered] for dep in layer.depends_on)
+                if deps_met:
+                    ready.append(layer)
+
+            if not ready:
+                raise ValueError("Circular dependency detected in layer configuration")
+
+            # Add ready layers to order
+            ordered.extend(ready)
+            for layer in ready:
+                remaining.remove(layer)
+
+        return ordered
+
+    @abstractmethod
+    def generate_final_output(self, results: Dict[str, LayerResult]) -> Any:
+        """Generate final output from all layer results - to be implemented by subclasses"""
+        pass
+
+    async def close(self):
+        """Close all services"""
+        for service in self.services.values():
+            if hasattr(service, 'close'):
+                await service.close()
+        self.services.clear()
+        logger.info(f"Closed all services for {self.service_name}")
+
+    def get_performance_metrics(self) -> Dict[str, Any]:
+        """Get performance metrics for the stack"""
+        if not self.results:
+            return {}
+
+        metrics = {
+            "total_layers": len(self.results),
+            "successful_layers": sum(1 for r in self.results.values() if r.success),
+            "failed_layers": sum(1 for r in self.results.values() if not r.success),
+            "total_execution_time": sum(r.execution_time for r in self.results.values()),
+            "layer_times": {name: r.execution_time for name, r in self.results.items()},
+            "layer_success": {name: r.success for name, r in self.results.items()}
+        }
+
+        return metrics
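To make the orchestration flow concrete, the sketch below subclasses BaseStackedService with two dependent LLM layers. It is an illustration under assumptions, not code from the package: the ai_factory object, the layer names, and the ainvoke() call on the service returned by get_llm() are hypothetical here (ainvoke() does exist on YydsLLMService earlier in this diff, but the factory's return type is not shown).

# Hypothetical two-layer stack: summarize input text, then review the summary.
from isa_model.inference.services.stacked import (
    BaseStackedService, LayerConfig, LayerType
)

class SummaryReviewService(BaseStackedService):
    """Summarize input text, then review the summary with a second LLM pass."""

    def __init__(self, ai_factory):
        super().__init__(ai_factory, "summary_review")
        self.add_layer(LayerConfig(
            name="summarize", layer_type=LayerType.INTELLIGENCE,
            service_type="llm", model_name="default",
            parameters={}, depends_on=[],
        ))
        self.add_layer(LayerConfig(
            name="review", layer_type=LayerType.VALIDATION,
            service_type="llm", model_name="default",
            parameters={}, depends_on=["summarize"],  # ordered by _build_execution_order
        ))

    async def execute_layer_logic(self, layer, service, context):
        if layer.name == "summarize":
            return await service.ainvoke(f"Summarize: {context['input']['text']}")
        # The review layer reads the earlier result through the shared context.
        summary = context["results"]["summarize"].data
        return await service.ainvoke(f"List any factual errors in this summary: {summary}")

    def generate_final_output(self, results):
        return {name: r.data for name, r in results.items() if r.success}

# Usage, given some factory object:
#     service = SummaryReviewService(ai_factory)
#     report = await service.invoke({"text": "..."})
#     print(report["final_output"], service.get_performance_metrics())

Because both layers resolve to the same "llm_default" service key, initialize_services() creates the underlying client once and reuses it for both layers.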