tamar-model-client 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,140 @@
1
+ """
2
+ Circuit Breaker implementation for resilient client
3
+
4
+ This module provides a thread-safe circuit breaker pattern implementation
5
+ to handle failures gracefully and prevent cascading failures.
6
+ """
7
+
8
+ import time
9
+ import logging
10
+ from enum import Enum
11
+ from threading import Lock
12
+ from typing import Optional
13
+
14
+ from .core.logging_setup import get_protected_logger
15
+
16
+ logger = get_protected_logger(__name__)
17
+
18
+
19
class CircuitState(Enum):
    """Lifecycle states of a circuit breaker.

    Each member's value is the lowercase string form of the state.
    """

    # Requests flow through normally.
    CLOSED = "closed"

    # The circuit is tripped; requests fail fast.
    OPEN = "open"

    # Probing whether the service has recovered.
    HALF_OPEN = "half_open"
24
+
25
+
26
class CircuitBreaker:
    """
    Thread-safe circuit breaker implementation.

    The breaker counts failures reported through ``record_failure``. Once the
    count reaches ``failure_threshold`` it switches to OPEN and ``is_open``
    reports True so callers can fail fast. After ``recovery_timeout`` seconds
    without a new counted failure, the next ``is_open`` check moves the
    breaker to HALF_OPEN and lets requests through again; a success recorded
    in that state closes the circuit, a counted failure re-opens it.
    """

    def __init__(self, failure_threshold: int = 5, recovery_timeout: int = 60):
        """
        Initialize the circuit breaker.

        Args:
            failure_threshold: Failure count at which the circuit opens.
            recovery_timeout: Seconds to wait after the last counted failure
                before attempting recovery.
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        # May become fractional: ABORTED errors count as half a failure.
        self.failure_count: float = 0
        # Wall-clock timestamp (time.time()) of the last counted failure.
        self.last_failure_time: Optional[float] = None
        self.state = CircuitState.CLOSED
        self._lock = Lock()

    @property
    def is_open(self) -> bool:
        """
        Report whether the circuit is currently open.

        Side effect: when the circuit is OPEN and the recovery timeout has
        elapsed since the last counted failure, the state is switched to
        HALF_OPEN and False is returned so a recovery attempt can go through.
        """
        with self._lock:
            if self.state != CircuitState.OPEN:
                return False

            # Check if we should attempt recovery
            if (self.last_failure_time and
                    time.time() - self.last_failure_time > self.recovery_timeout):
                self.state = CircuitState.HALF_OPEN
                logger.info("🔄 Circuit breaker entering HALF_OPEN state")
                return False
            return True

    def record_success(self) -> None:
        """Record a successful request and clear the failure count if applicable."""
        with self._lock:
            if self.state == CircuitState.HALF_OPEN:
                # Success in half-open state means service has recovered
                self.state = CircuitState.CLOSED
                self.failure_count = 0
                logger.info("🔺 Circuit breaker recovered to CLOSED state")
            elif self.state == CircuitState.CLOSED and self.failure_count > 0:
                # Reset failure count on success
                self.failure_count = 0

    def record_failure(self, error_code=None) -> None:
        """
        Record a failed request.

        Error codes accepted by ``_should_ignore_for_circuit_breaker`` are not
        counted at all. ``grpc.StatusCode.ABORTED`` counts as half a failure;
        every other failure (including ``error_code=None``) counts as one.

        Args:
            error_code: gRPC status code used to classify the failure, or
                None when no code is available.
        """
        # Classify the error before taking the lock so that the import and
        # the lookups do not extend the critical section. ``grpc`` is only
        # needed when an error code was actually supplied.
        weight = 1
        if error_code is not None:
            if self._should_ignore_for_circuit_breaker(error_code):
                return

            import grpc
            # ABORTED is usually a transient concurrency conflict, so it only
            # counts as half a failure.
            if error_code == grpc.StatusCode.ABORTED:
                weight = 0.5

        # Resolve the label defensively: a code without ``.name`` must not
        # raise while the breaker state is being updated.
        if error_code is None:
            trigger_error = "unknown"
        else:
            trigger_error = getattr(error_code, "name", str(error_code))

        with self._lock:
            self.failure_count += weight
            self.last_failure_time = time.time()

            if (self.failure_count >= self.failure_threshold
                    and self.state != CircuitState.OPEN):
                self.state = CircuitState.OPEN
                logger.warning(
                    f"🔻 Circuit breaker OPENED after {self.failure_count} failures",
                    extra={
                        "failure_count": self.failure_count,
                        "threshold": self.failure_threshold,
                        "trigger_error": trigger_error
                    }
                )

    def _should_ignore_for_circuit_breaker(self, error_code) -> bool:
        """
        Decide whether an error should be ignored by the circuit breaker.

        The ignored codes describe problems with the request itself rather
        than with service availability: UNAUTHENTICATED, PERMISSION_DENIED
        and INVALID_ARGUMENT.

        NOTE(review): the original documentation also listed client-cancelled
        requests as ignored, but CANCELLED is not in the set below - confirm
        which one is intended.
        """
        import grpc
        ignored_codes = {
            grpc.StatusCode.UNAUTHENTICATED,  # authentication problem, not a service problem
            grpc.StatusCode.PERMISSION_DENIED,  # authorization problem, not a service problem
            grpc.StatusCode.INVALID_ARGUMENT,  # bad request arguments, not a service problem
        }
        return error_code in ignored_codes

    def should_fallback(self) -> bool:
        """Check if fallback should be used (circuit open and not probing)."""
        return self.is_open and self.state != CircuitState.HALF_OPEN

    def get_state(self) -> str:
        """Get the current circuit state as its string value."""
        return self.state.value

    def reset(self) -> None:
        """Reset the circuit breaker to its initial CLOSED state."""
        with self._lock:
            self.state = CircuitState.CLOSED
            self.failure_count = 0
            self.last_failure_time = None
            logger.info("🔄 Circuit breaker reset to CLOSED state")
@@ -0,0 +1,40 @@
1
+ """
2
+ Core components for Tamar Model Client
3
+
4
+ This package contains shared components used by both sync and async clients.
5
+ """
6
+
7
+ from .utils import (
8
+ is_effective_value,
9
+ serialize_value,
10
+ remove_none_from_dict,
11
+ generate_request_id,
12
+ set_request_id,
13
+ get_request_id
14
+ )
15
+
16
+ from .logging_setup import (
17
+ setup_logger,
18
+ RequestIdFilter,
19
+ TamarLoggerAdapter,
20
+ get_protected_logger,
21
+ reset_logger_config,
22
+ MAX_MESSAGE_LENGTH
23
+ )
24
+
25
# Public names re-exported by this package.
__all__ = [
    # Helpers imported from .utils
    'is_effective_value', 'serialize_value', 'remove_none_from_dict',
    'generate_request_id', 'set_request_id', 'get_request_id',
    # Logging helpers imported from .logging_setup
    'setup_logger', 'RequestIdFilter', 'TamarLoggerAdapter',
    'get_protected_logger', 'reset_logger_config', 'MAX_MESSAGE_LENGTH',
]
@@ -0,0 +1,221 @@
1
+ """
2
+ Base client class for Tamar Model Client
3
+
4
+ This module provides the base client class with shared initialization logic
5
+ and configuration management for both sync and async clients.
6
+ """
7
+
8
+ import os
9
+ import logging
10
+ from typing import Optional, Dict, Any
11
+ from abc import ABC, abstractmethod
12
+
13
+ from ..auth import JWTAuthHandler
14
+ from ..error_handler import GrpcErrorHandler, ErrorRecoveryStrategy
15
+ from .logging_setup import MAX_MESSAGE_LENGTH, get_protected_logger
16
+
17
+
18
class BaseClient(ABC):
    """
    Abstract base client.

    Provides the functionality shared by the sync and async clients:
    - configuration management
    - authentication setup
    - gRPC channel option construction
    - error handler initialization
    - optional circuit-breaker / HTTP-fallback ("resilient") configuration
    """

    def __init__(
            self,
            server_address: Optional[str] = None,
            jwt_secret_key: Optional[str] = None,
            jwt_token: Optional[str] = None,
            default_payload: Optional[dict] = None,
            token_expires_in: int = 3600,
            max_retries: Optional[int] = None,
            retry_delay: Optional[float] = None,
            logger_name: Optional[str] = None,
    ):
        """
        Initialize the base client.

        Args:
            server_address: gRPC server address in "host:port" form; falls
                back to the MODEL_MANAGER_SERVER_ADDRESS environment variable.
            jwt_secret_key: JWT signing key used to generate auth tokens;
                falls back to MODEL_MANAGER_SERVER_JWT_SECRET_KEY.
            jwt_token: Pre-generated JWT token (optional).
            default_payload: Default payload for generated JWT tokens.
            token_expires_in: JWT token lifetime in seconds.
            max_retries: Maximum number of retries (defaults to
                MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES, then 3).
            retry_delay: Initial retry delay in seconds (defaults to
                MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY, then 1.0).
            logger_name: Logger name; defaults to this module's name.

        Raises:
            ValueError: If no server address is provided by argument or
                environment variable.
        """
        # === Server address ===
        self.server_address = server_address or os.getenv("MODEL_MANAGER_SERVER_ADDRESS")
        if not self.server_address:
            raise ValueError("Server address must be provided via argument or environment variable.")

        # Default invoke timeout (seconds)
        self.default_invoke_timeout = float(os.getenv("MODEL_MANAGER_SERVER_INVOKE_TIMEOUT", 30.0))

        # === JWT authentication ===
        self.jwt_secret_key = jwt_secret_key or os.getenv("MODEL_MANAGER_SERVER_JWT_SECRET_KEY")
        self.jwt_handler = JWTAuthHandler(self.jwt_secret_key) if self.jwt_secret_key else None
        self.jwt_token = jwt_token  # pre-generated token passed by the caller (optional)
        self.default_payload = default_payload
        self.token_expires_in = token_expires_in

        # === TLS / authority ===
        self.use_tls = os.getenv("MODEL_MANAGER_SERVER_GRPC_USE_TLS", "true").lower() == "true"
        self.default_authority = os.getenv("MODEL_MANAGER_SERVER_GRPC_DEFAULT_AUTHORITY")

        # === Retry configuration ===
        self.max_retries = max_retries if max_retries is not None else int(
            os.getenv("MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES", 3))
        self.retry_delay = retry_delay if retry_delay is not None else float(
            os.getenv("MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY", 1.0))

        # === Logging ===
        self.logger = get_protected_logger(logger_name or __name__)

        # === Error handling ===
        self.error_handler = GrpcErrorHandler(self.logger)
        self.recovery_strategy = ErrorRecoveryStrategy(self)

        # === Connection state ===
        self._closed = False

        # === Circuit breaker / fallback configuration ===
        self._init_resilient_features()

    def build_channel_options(self) -> list:
        """
        Build the gRPC channel options.

        Returns:
            list: (name, value) tuples covering message size limits,
            keepalive settings, connection lifetime settings, performance
            flags and, when configured, the default authority.
        """
        options = [
            # Message size limits
            ('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
            ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),

            # Core keepalive settings
            ('grpc.keepalive_time_ms', 30000),  # send a keepalive ping every 30s
            ('grpc.keepalive_timeout_ms', 10000),  # wait up to 10s for the ping ack
            ('grpc.keepalive_permit_without_calls', True),  # ping even when idle
            ('grpc.http2.max_pings_without_data', 2),  # max pings without data

            # Connection management settings
            # NOTE(review): gRPC core documents the idle/age arguments without
            # the "http2." prefix and as server-side settings - confirm these
            # keys have any effect on a client channel.
            ('grpc.http2.min_time_between_pings_ms', 10000),  # min 10s between pings
            ('grpc.http2.max_connection_idle_ms', 300000),  # max idle time: 5 minutes
            ('grpc.http2.max_connection_age_ms', 3600000),  # max connection age: 1 hour
            ('grpc.http2.max_connection_age_grace_ms', 5000),  # graceful shutdown: 5s

            # Performance settings
            ('grpc.http2.bdp_probe', 1),  # enable bandwidth-delay probing
            ('grpc.enable_retries', 1),  # enable built-in retries
        ]

        if self.default_authority:
            options.append(("grpc.default_authority", self.default_authority))

        return options

    def _build_auth_metadata(self, request_id: str) -> list:
        """
        Build the gRPC metadata for one request.

        The metadata always carries the request id. When a JWT handler is
        configured, a fresh token is generated on every call (and stored in
        ``self.jwt_token``). Without a handler, a pre-generated ``jwt_token``
        supplied by the caller is sent as-is, matching what the HTTP fallback
        path does with the same attribute.

        Args:
            request_id: Unique identifier of the current request.

        Returns:
            list: Metadata tuples containing the request id and, when
            available, the authorization header.
        """
        metadata = [("x-request-id", request_id)]

        if self.jwt_handler:
            self.jwt_token = self.jwt_handler.encode_token(
                self.default_payload,
                expires_in=self.token_expires_in
            )
            metadata.append(("authorization", f"Bearer {self.jwt_token}"))
        elif self.jwt_token:
            # No signing key: fall back to the token the caller provided
            # instead of silently sending an unauthenticated request.
            metadata.append(("authorization", f"Bearer {self.jwt_token}"))

        return metadata

    @abstractmethod
    def close(self):
        """Close the client connection (implemented by subclasses)."""
        pass

    @abstractmethod
    def __enter__(self):
        """Enter the context manager (implemented by subclasses)."""
        pass

    @abstractmethod
    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the context manager (implemented by subclasses)."""
        pass

    def _init_resilient_features(self):
        """
        Initialize the circuit-breaker / HTTP-fallback features.

        Resilient mode is enabled only when MODEL_CLIENT_RESILIENT_ENABLED is
        "true" and MODEL_CLIENT_HTTP_FALLBACK_URL is set. ``circuit_breaker``
        and ``http_fallback_url`` are always defined afterwards (None when
        the mode is off) so later attribute reads cannot fail.
        """
        # Defaults first, so every exit path leaves the attributes defined.
        self.circuit_breaker = None
        self.http_fallback_url = None

        # Whether resilient mode is requested
        self.resilient_enabled = os.getenv('MODEL_CLIENT_RESILIENT_ENABLED', 'false').lower() == 'true'
        if not self.resilient_enabled:
            return

        # HTTP fallback base URL
        self.http_fallback_url = os.getenv('MODEL_CLIENT_HTTP_FALLBACK_URL')

        if not self.http_fallback_url:
            self.logger.warning("🔶 Resilient mode enabled but MODEL_CLIENT_HTTP_FALLBACK_URL not set")
            self.resilient_enabled = False
            return

        # Initialize the circuit breaker
        from ..circuit_breaker import CircuitBreaker
        self.circuit_breaker = CircuitBreaker(
            failure_threshold=int(os.getenv('MODEL_CLIENT_CIRCUIT_BREAKER_THRESHOLD', '5')),
            recovery_timeout=int(os.getenv('MODEL_CLIENT_CIRCUIT_BREAKER_TIMEOUT', '60'))
        )

        # HTTP clients are created lazily
        self._http_client = None
        self._http_session = None  # used by the async client

        self.logger.info(
            "🛡️ Resilient mode enabled",
            extra={
                "http_fallback_url": self.http_fallback_url,
                "circuit_breaker_threshold": self.circuit_breaker.failure_threshold,
                "circuit_breaker_timeout": self.circuit_breaker.recovery_timeout
            }
        )

    def get_resilient_metrics(self):
        """
        Return circuit-breaker metrics.

        Returns:
            None when resilient mode is off; otherwise a dict with the
            enabled flag, circuit state, failure count, last failure time
            and the HTTP fallback URL.
        """
        if not self.resilient_enabled or not self.circuit_breaker:
            return None

        return {
            "enabled": self.resilient_enabled,
            "circuit_state": self.circuit_breaker.get_state(),
            "failure_count": self.circuit_breaker.failure_count,
            "last_failure_time": self.circuit_breaker.last_failure_time,
            "http_fallback_url": self.http_fallback_url
        }
@@ -0,0 +1,249 @@
1
+ """
2
+ HTTP fallback functionality for resilient clients
3
+
4
+ This module provides mixin classes for HTTP-based fallback when gRPC
5
+ connections fail, supporting both synchronous and asynchronous clients.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ from typing import Optional, Iterator, AsyncIterator, Dict, Any
11
+
12
+ from . import generate_request_id, get_protected_logger
13
+ from ..schemas import ModelRequest, ModelResponse
14
+
15
+ logger = get_protected_logger(__name__)
16
+
17
+
18
class HttpFallbackMixin:
    """
    HTTP fallback functionality for synchronous clients.

    The host class is expected to provide ``http_fallback_url`` and,
    optionally, ``jwt_token``; the ``requests`` session is created lazily
    and stored in ``_http_client``.
    """

    def _ensure_http_client(self) -> None:
        """Create the shared ``requests.Session`` on first use."""
        if not hasattr(self, '_http_client') or not self._http_client:
            import requests
            self._http_client = requests.Session()

            # The JWT token is attached per request, not on the session.
            # Set default headers
            self._http_client.headers.update({
                'Content-Type': 'application/json',
                'User-Agent': 'TamarModelClient/1.0'
            })

    def _convert_to_http_format(self, model_request: ModelRequest) -> Dict[str, Any]:
        """
        Convert a ModelRequest into the JSON payload sent to the HTTP endpoint.

        Required fields are always included; ``messages``/``contents``,
        ``channel`` and ``invoke_type`` are added only when set. Entries from
        ``model_extra`` are copied last and never override an existing key.
        """
        payload = {
            "provider": model_request.provider.value,
            "model": model_request.model,
            "user_context": model_request.user_context.model_dump(),
            "stream": model_request.stream
        }

        # Add provider-specific fields
        if hasattr(model_request, 'messages') and model_request.messages:
            payload['messages'] = model_request.messages
        if hasattr(model_request, 'contents') and model_request.contents:
            payload['contents'] = model_request.contents

        # Add optional fields
        if model_request.channel:
            payload['channel'] = model_request.channel.value
        if model_request.invoke_type:
            payload['invoke_type'] = model_request.invoke_type.value

        # Add extra parameters without overriding the fields set above
        if hasattr(model_request, 'model_extra') and model_request.model_extra:
            for key, value in model_request.model_extra.items():
                payload.setdefault(key, value)

        return payload

    def _handle_http_stream(self, url: str, payload: Dict[str, Any],
                            timeout: Optional[float], request_id: str,
                            headers: Dict[str, str]) -> Iterator[ModelResponse]:
        """
        POST the payload and yield one ModelResponse per SSE ``data:`` line.

        This is a generator, so the request is only sent once iteration
        starts. The stream ends at a ``[DONE]`` marker; lines that are not
        valid JSON are logged and skipped. The response is always closed -
        on normal completion, on error, and when the consumer stops early -
        so the underlying connection is not leaked.

        Args:
            url: Full endpoint URL.
            payload: JSON-serializable request body.
            timeout: Timeout in seconds passed to requests; 30 when falsy.
            request_id: Request identifier (not used here; the id travels in
                ``headers``).
            headers: Per-request headers.
        """
        response = self._http_client.post(
            url,
            json=payload,
            timeout=timeout or 30,
            headers=headers,
            stream=True
        )
        try:
            response.raise_for_status()

            # Parse SSE stream
            for line in response.iter_lines():
                if not line:
                    continue
                line_str = line.decode('utf-8')
                if not line_str.startswith('data: '):
                    continue
                data_str = line_str[6:]
                if data_str == '[DONE]':
                    break
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse streaming response: {data_str}")
                    continue
                yield ModelResponse(**data)
        finally:
            # With stream=True the connection stays checked out until the
            # body is consumed or the response is closed.
            response.close()

    def _invoke_http_fallback(self, model_request: ModelRequest,
                              timeout: Optional[float] = None,
                              request_id: Optional[str] = None) -> Any:
        """
        Send the request over HTTP instead of gRPC.

        Args:
            model_request: The request to send.
            timeout: Timeout in seconds; 30 when not given.
            request_id: Request identifier; generated when not given.

        Returns:
            An iterator of ModelResponse for streaming requests, otherwise a
            single ModelResponse.
        """
        self._ensure_http_client()

        # Generate request ID if not provided
        if not request_id:
            request_id = generate_request_id()

        # Log fallback usage
        logger.warning(
            "🔻 Using HTTP fallback for request",
            extra={
                "request_id": request_id,
                "provider": model_request.provider.value,
                "model": model_request.model,
                "fallback_url": self.http_fallback_url
            }
        )

        # Convert to HTTP format
        http_payload = self._convert_to_http_format(model_request)

        # Construct URL; tolerate a trailing slash on the configured base URL
        base_url = (self.http_fallback_url or "").rstrip('/')
        url = f"{base_url}/v1/invoke"

        # Build headers with authentication
        headers = {'X-Request-ID': request_id}
        if hasattr(self, 'jwt_token') and self.jwt_token:
            headers['Authorization'] = f'Bearer {self.jwt_token}'

        if model_request.stream:
            # Return streaming iterator
            return self._handle_http_stream(url, http_payload, timeout, request_id, headers)

        # Non-streaming request
        response = self._http_client.post(
            url,
            json=http_payload,
            timeout=timeout or 30,
            headers=headers
        )
        response.raise_for_status()

        # Parse response
        data = response.json()
        return ModelResponse(**data)
+
142
+
143
class AsyncHttpFallbackMixin:
    """
    HTTP fallback functionality for asynchronous clients.

    The host class is expected to provide ``http_fallback_url`` and,
    optionally, ``jwt_token``; the aiohttp session is created lazily and
    stored in ``_http_session``.
    """

    @staticmethod
    def _fallback_timeout_seconds(timeout: Optional[float]) -> float:
        """Return the timeout to apply, defaulting to 30 seconds when falsy."""
        return timeout or 30

    async def _ensure_http_client(self) -> None:
        """Create the shared ``aiohttp.ClientSession`` on first use."""
        if not hasattr(self, '_http_session') or not self._http_session:
            import aiohttp
            # The JWT token is attached per request, not on the session.
            self._http_session = aiohttp.ClientSession(
                headers={
                    'Content-Type': 'application/json',
                    'User-Agent': 'AsyncTamarModelClient/1.0'
                }
            )

    def _convert_to_http_format(self, model_request: ModelRequest) -> Dict[str, Any]:
        """Convert ModelRequest to HTTP payload format (reuses the sync version)."""
        # The conversion does no I/O, so the synchronous implementation is
        # called directly with this instance.
        return HttpFallbackMixin._convert_to_http_format(self, model_request)

    async def _handle_http_stream(self, url: str, payload: Dict[str, Any],
                                  timeout: Optional[float], request_id: str,
                                  headers: Dict[str, str]) -> AsyncIterator[ModelResponse]:
        """
        POST the payload and yield one ModelResponse per SSE ``data:`` line.

        The stream ends at a ``[DONE]`` marker; lines that are not valid JSON
        are logged and skipped.

        Args:
            url: Full endpoint URL.
            payload: JSON-serializable request body.
            timeout: Total timeout in seconds. When not given, no total limit
                is set but connecting and each read are limited to 30
                seconds, mirroring the connect/read semantics of the
                synchronous fallback.
            request_id: Request identifier (not used here; the id travels in
                ``headers``).
            headers: Per-request headers.
        """
        import aiohttp

        if timeout:
            timeout_obj = aiohttp.ClientTimeout(total=timeout)
        else:
            # Previously ``timeout=None`` was forwarded to aiohttp, so the
            # intended 30s default never applied. Bound connect and per-read
            # time instead of total time so long-running streams still work.
            default_seconds = self._fallback_timeout_seconds(timeout)
            timeout_obj = aiohttp.ClientTimeout(
                total=None,
                sock_connect=default_seconds,
                sock_read=default_seconds
            )

        async with self._http_session.post(
                url,
                json=payload,
                timeout=timeout_obj,
                headers=headers
        ) as response:
            response.raise_for_status()

            # Parse SSE stream
            async for line_bytes in response.content:
                if not line_bytes:
                    continue
                line_str = line_bytes.decode('utf-8').strip()
                if not line_str.startswith('data: '):
                    continue
                data_str = line_str[6:]
                if data_str == '[DONE]':
                    break
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse streaming response: {data_str}")
                    continue
                yield ModelResponse(**data)

    async def _invoke_http_fallback(self, model_request: ModelRequest,
                                    timeout: Optional[float] = None,
                                    request_id: Optional[str] = None) -> Any:
        """
        Send the request over HTTP instead of gRPC.

        Args:
            model_request: The request to send.
            timeout: Timeout in seconds; 30 when not given.
            request_id: Request identifier; generated when not given.

        Returns:
            An async iterator of ModelResponse for streaming requests,
            otherwise a single ModelResponse.
        """
        await self._ensure_http_client()

        # Generate request ID if not provided
        if not request_id:
            request_id = generate_request_id()

        # Log fallback usage
        logger.warning(
            "🔻 Using HTTP fallback for request",
            extra={
                "request_id": request_id,
                "provider": model_request.provider.value,
                "model": model_request.model,
                "fallback_url": self.http_fallback_url
            }
        )

        # Convert to HTTP format
        http_payload = self._convert_to_http_format(model_request)

        # Construct URL; tolerate a trailing slash on the configured base URL
        base_url = (self.http_fallback_url or "").rstrip('/')
        url = f"{base_url}/v1/invoke"

        # Build headers with authentication
        headers = {'X-Request-ID': request_id}
        if hasattr(self, 'jwt_token') and self.jwt_token:
            headers['Authorization'] = f'Bearer {self.jwt_token}'

        if model_request.stream:
            # Return async streaming iterator
            return self._handle_http_stream(url, http_payload, timeout, request_id, headers)

        # Non-streaming request: always apply a total timeout (30s default)
        # so a hung fallback endpoint cannot block the caller indefinitely.
        import aiohttp
        timeout_obj = aiohttp.ClientTimeout(total=self._fallback_timeout_seconds(timeout))

        async with self._http_session.post(
                url,
                json=http_payload,
                timeout=timeout_obj,
                headers=headers
        ) as response:
            response.raise_for_status()

            # Parse response
            data = await response.json()
            return ModelResponse(**data)

    async def _cleanup_http_session(self) -> None:
        """Close the aiohttp session, if one was created, and forget it."""
        if hasattr(self, '_http_session') and self._http_session:
            await self._http_session.close()
            self._http_session = None