tamar-model-client 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tamar_model_client/async_client.py +580 -540
- tamar_model_client/circuit_breaker.py +140 -0
- tamar_model_client/core/__init__.py +40 -0
- tamar_model_client/core/base_client.py +221 -0
- tamar_model_client/core/http_fallback.py +249 -0
- tamar_model_client/core/logging_setup.py +194 -0
- tamar_model_client/core/request_builder.py +221 -0
- tamar_model_client/core/response_handler.py +136 -0
- tamar_model_client/core/utils.py +171 -0
- tamar_model_client/error_handler.py +315 -0
- tamar_model_client/exceptions.py +388 -7
- tamar_model_client/json_formatter.py +36 -1
- tamar_model_client/sync_client.py +590 -486
- {tamar_model_client-0.1.19.dist-info → tamar_model_client-0.1.21.dist-info}/METADATA +289 -61
- tamar_model_client-0.1.21.dist-info/RECORD +35 -0
- {tamar_model_client-0.1.19.dist-info → tamar_model_client-0.1.21.dist-info}/top_level.txt +1 -0
- tests/__init__.py +1 -0
- tests/stream_hanging_analysis.py +357 -0
- tests/test_google_azure_final.py +448 -0
- tests/test_simple.py +235 -0
- tamar_model_client-0.1.19.dist-info/RECORD +0 -22
- {tamar_model_client-0.1.19.dist-info → tamar_model_client-0.1.21.dist-info}/WHEEL +0 -0
@@ -0,0 +1,140 @@
|
|
1
|
+
"""
|
2
|
+
Circuit Breaker implementation for resilient client
|
3
|
+
|
4
|
+
This module provides a thread-safe circuit breaker pattern implementation
|
5
|
+
to handle failures gracefully and prevent cascading failures.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import time
|
9
|
+
import logging
|
10
|
+
from enum import Enum
|
11
|
+
from threading import Lock
|
12
|
+
from typing import Optional
|
13
|
+
|
14
|
+
from .core.logging_setup import get_protected_logger
|
15
|
+
|
16
|
+
logger = get_protected_logger(__name__)
|
17
|
+
|
18
|
+
|
19
|
+
class CircuitState(Enum):
    """Lifecycle states of a CircuitBreaker."""

    CLOSED = "closed"  # Normal operation
    OPEN = "open"  # Circuit is broken, requests fail fast
    HALF_OPEN = "half_open"  # Testing if service has recovered


class CircuitBreaker:
    """
    Thread-safe circuit breaker.

    Failures reported through ``record_failure`` are counted (with per-code weights);
    once the count reaches ``failure_threshold`` the circuit opens and
    ``should_fallback`` returns True. When more than ``recovery_timeout`` seconds have
    passed since the last failure, the next ``is_open`` check moves the circuit to
    HALF_OPEN so traffic can probe the service: a success closes the circuit again,
    a failure re-opens it. A success while CLOSED resets the count, so only
    consecutive failures accumulate.

    Every state read/transition performed by these methods holds ``self._lock``.
    """

    def __init__(self, failure_threshold: int = 5, recovery_timeout: int = 60):
        """
        Initialize the circuit breaker.

        Args:
            failure_threshold: Weighted number of consecutive failures that opens the circuit.
            recovery_timeout: Seconds after the last failure before a recovery probe is allowed.
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        # float, because ABORTED failures only add 0.5
        self.failure_count: float = 0
        self.last_failure_time: Optional[float] = None
        self.state = CircuitState.CLOSED
        self._lock = Lock()

    @property
    def is_open(self) -> bool:
        """
        Return True while the circuit is OPEN and the recovery window has not elapsed.

        Side effect: if the circuit is OPEN and more than ``recovery_timeout`` seconds
        have passed since the last failure, it transitions to HALF_OPEN and returns False.
        """
        with self._lock:
            if self.state != CircuitState.OPEN:
                return False
            if (self.last_failure_time is not None
                    and time.time() - self.last_failure_time > self.recovery_timeout):
                self.state = CircuitState.HALF_OPEN
                logger.info("🔄 Circuit breaker entering HALF_OPEN state")
                return False
            return True

    def record_success(self) -> None:
        """Record a successful request (closes a HALF_OPEN circuit, resets the CLOSED count)."""
        with self._lock:
            if self.state == CircuitState.HALF_OPEN:
                # Success while probing means the service has recovered
                self.state = CircuitState.CLOSED
                self.failure_count = 0
                logger.info("🔺 Circuit breaker recovered to CLOSED state")
            elif self.state == CircuitState.CLOSED and self.failure_count > 0:
                # Only consecutive failures count towards the threshold
                self.failure_count = 0

    def record_failure(self, error_code=None) -> None:
        """
        Record a failed request.

        Args:
            error_code: Optional gRPC ``StatusCode`` used to classify the failure.
                UNAUTHENTICATED / PERMISSION_DENIED / INVALID_ARGUMENT are ignored,
                ABORTED counts as half a failure, anything else (or None) counts as one.
        """
        with self._lock:
            weight = 1
            if error_code:
                # grpc is only needed to classify a code, so plain record_failure()
                # calls neither pay for nor depend on the import.
                import grpc

                if self._should_ignore_for_circuit_breaker(error_code):
                    return
                if error_code == grpc.StatusCode.ABORTED:
                    # ABORTED is usually a transient concurrency conflict
                    weight = 0.5

            self.failure_count += weight
            self.last_failure_time = time.time()

            # A failed probe in HALF_OPEN re-opens the circuit immediately.
            should_open = (self.state == CircuitState.HALF_OPEN
                           or self.failure_count >= self.failure_threshold)
            if should_open and self.state != CircuitState.OPEN:
                self.state = CircuitState.OPEN
                logger.warning(
                    f"🔻 Circuit breaker OPENED after {self.failure_count} failures",
                    extra={
                        "failure_count": self.failure_count,
                        "threshold": self.failure_threshold,
                        "trigger_error": error_code.name if error_code else "unknown"
                    }
                )

    def _should_ignore_for_circuit_breaker(self, error_code) -> bool:
        """
        Return True for error codes that should not count towards opening the circuit.

        Ignored codes describe client-side problems rather than an unavailable service:
        authentication failures, permission problems, and invalid arguments.
        """
        import grpc
        ignored_codes = {
            grpc.StatusCode.UNAUTHENTICATED,  # auth problem, not a service problem
            grpc.StatusCode.PERMISSION_DENIED,  # permission problem, not a service problem
            grpc.StatusCode.INVALID_ARGUMENT,  # bad arguments, not a service problem
        }
        return error_code in ignored_codes

    def should_fallback(self) -> bool:
        """
        Return True when requests should go to the fallback path.

        ``is_open`` already returns False for HALF_OPEN (and for an OPEN circuit whose
        recovery window elapsed), so no extra, unlocked state check is needed.
        """
        return self.is_open

    def get_state(self) -> str:
        """Return the current state name ("closed", "open" or "half_open")."""
        with self._lock:
            return self.state.value

    def reset(self) -> None:
        """Reset the circuit breaker to its initial CLOSED state."""
        with self._lock:
            self.state = CircuitState.CLOSED
            self.failure_count = 0
            self.last_failure_time = None
            logger.info("🔄 Circuit breaker reset to CLOSED state")
|
@@ -0,0 +1,40 @@
|
|
1
|
+
"""
|
2
|
+
Core components for Tamar Model Client
|
3
|
+
|
4
|
+
This package contains shared components used by both sync and async clients.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from .utils import (
|
8
|
+
is_effective_value,
|
9
|
+
serialize_value,
|
10
|
+
remove_none_from_dict,
|
11
|
+
generate_request_id,
|
12
|
+
set_request_id,
|
13
|
+
get_request_id
|
14
|
+
)
|
15
|
+
|
16
|
+
from .logging_setup import (
|
17
|
+
setup_logger,
|
18
|
+
RequestIdFilter,
|
19
|
+
TamarLoggerAdapter,
|
20
|
+
get_protected_logger,
|
21
|
+
reset_logger_config,
|
22
|
+
MAX_MESSAGE_LENGTH
|
23
|
+
)
|
24
|
+
|
25
|
+
# Public API of the core package, grouped by the submodule each name comes from.
__all__ = [
    # Utils
    'is_effective_value', 'serialize_value', 'remove_none_from_dict',
    'generate_request_id', 'set_request_id', 'get_request_id',
    # Logging
    'setup_logger', 'RequestIdFilter', 'TamarLoggerAdapter',
    'get_protected_logger', 'reset_logger_config', 'MAX_MESSAGE_LENGTH',
]
|
@@ -0,0 +1,221 @@
|
|
1
|
+
"""
|
2
|
+
Base client class for Tamar Model Client
|
3
|
+
|
4
|
+
This module provides the base client class with shared initialization logic
|
5
|
+
and configuration management for both sync and async clients.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import os
|
9
|
+
import logging
|
10
|
+
from typing import Optional, Dict, Any
|
11
|
+
from abc import ABC, abstractmethod
|
12
|
+
|
13
|
+
from ..auth import JWTAuthHandler
|
14
|
+
from ..error_handler import GrpcErrorHandler, ErrorRecoveryStrategy
|
15
|
+
from .logging_setup import MAX_MESSAGE_LENGTH, get_protected_logger
|
16
|
+
|
17
|
+
|
18
|
+
class BaseClient(ABC):
    """
    Abstract base client shared by the sync and async clients.

    Shared responsibilities:
    - configuration management (arguments with environment-variable fallbacks)
    - authentication setup (JWT signing, or a caller-supplied pre-generated token)
    - gRPC channel option construction
    - error-handler initialization
    - optional circuit-breaker / HTTP-fallback ("resilient mode") setup

    Subclasses must implement ``close``, ``__enter__`` and ``__exit__``.
    """

    def __init__(
            self,
            server_address: Optional[str] = None,
            jwt_secret_key: Optional[str] = None,
            jwt_token: Optional[str] = None,
            default_payload: Optional[dict] = None,
            token_expires_in: int = 3600,
            max_retries: Optional[int] = None,
            retry_delay: Optional[float] = None,
            logger_name: Optional[str] = None,
    ):
        """
        Initialize the shared client state.

        Args:
            server_address: gRPC server address as "host:port"; falls back to the
                MODEL_MANAGER_SERVER_ADDRESS environment variable.
            jwt_secret_key: JWT signing key; falls back to MODEL_MANAGER_SERVER_JWT_SECRET_KEY.
                When available, a fresh token is signed for every request.
            jwt_token: Pre-generated JWT (optional), sent as-is when no signing key is configured.
            default_payload: Payload used when signing JWTs.
            token_expires_in: Lifetime of signed JWTs, in seconds.
            max_retries: Maximum retry count (default: MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES, else 3).
            retry_delay: Initial retry delay in seconds (default:
                MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY, else 1.0).
            logger_name: Logger name (defaults to this module's name).

        Raises:
            ValueError: If no server address is provided by argument or environment.
        """
        # === Server address ===
        self.server_address = server_address or os.getenv("MODEL_MANAGER_SERVER_ADDRESS")
        if not self.server_address:
            raise ValueError("Server address must be provided via argument or environment variable.")

        # Default invoke timeout (seconds)
        self.default_invoke_timeout = float(os.getenv("MODEL_MANAGER_SERVER_INVOKE_TIMEOUT", 30.0))

        # === JWT authentication ===
        self.jwt_secret_key = jwt_secret_key or os.getenv("MODEL_MANAGER_SERVER_JWT_SECRET_KEY")
        self.jwt_handler = JWTAuthHandler(self.jwt_secret_key) if self.jwt_secret_key else None
        self.jwt_token = jwt_token  # caller-supplied pre-generated token (optional)
        self.default_payload = default_payload
        self.token_expires_in = token_expires_in

        # === TLS / authority ===
        self.use_tls = os.getenv("MODEL_MANAGER_SERVER_GRPC_USE_TLS", "true").lower() == "true"
        self.default_authority = os.getenv("MODEL_MANAGER_SERVER_GRPC_DEFAULT_AUTHORITY")

        # === Retry configuration ===
        self.max_retries = max_retries if max_retries is not None else int(
            os.getenv("MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES", 3))
        self.retry_delay = retry_delay if retry_delay is not None else float(
            os.getenv("MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY", 1.0))

        # === Logging ===
        self.logger = get_protected_logger(logger_name or __name__)

        # === Error handling ===
        self.error_handler = GrpcErrorHandler(self.logger)
        self.recovery_strategy = ErrorRecoveryStrategy(self)

        # === Connection state ===
        self._closed = False

        # === Circuit breaker / fallback ===
        self._init_resilient_features()

    def build_channel_options(self) -> list:
        """
        Build the gRPC channel options.

        Returns:
            list: ``(key, value)`` channel option tuples covering message-size limits,
            keepalive (30 s ping interval, 10 s ping timeout), connection lifetime
            settings, and performance options (BDP probing, built-in retries), plus
            ``grpc.default_authority`` when configured.
        """
        options = [
            # Message size limits
            ('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
            ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),

            # Keepalive
            ('grpc.keepalive_time_ms', 30000),  # send a keepalive ping every 30 s
            ('grpc.keepalive_timeout_ms', 10000),  # ping ack timeout 10 s
            ('grpc.keepalive_permit_without_calls', True),  # ping even when idle
            ('grpc.http2.max_pings_without_data', 2),  # max pings without data

            # Connection management
            # NOTE(review): gRPC core names these server-side args "grpc.max_connection_idle_ms",
            # "grpc.max_connection_age_ms" and "grpc.max_connection_age_grace_ms"; the
            # "grpc.http2." spelling below may be silently ignored on a client channel — confirm.
            ('grpc.http2.min_time_between_pings_ms', 10000),  # min ping interval 10 s
            ('grpc.http2.max_connection_idle_ms', 300000),  # max idle time 5 min
            ('grpc.http2.max_connection_age_ms', 3600000),  # max connection age 1 h
            ('grpc.http2.max_connection_age_grace_ms', 5000),  # graceful close window 5 s

            # Performance
            ('grpc.http2.bdp_probe', 1),  # bandwidth-delay-product probing
            ('grpc.enable_retries', 1),  # built-in retries
        ]

        if self.default_authority:
            options.append(("grpc.default_authority", self.default_authority))

        return options

    def _build_auth_metadata(self, request_id: str) -> list:
        """
        Build per-request gRPC metadata carrying the request ID and authentication.

        With a signing handler, a fresh JWT is signed on every call (and cached in
        ``self.jwt_token``). Without one, a caller-supplied ``jwt_token`` is sent as-is.

        Args:
            request_id: Unique identifier of the current request.

        Returns:
            list: metadata tuples; always ``x-request-id``, plus ``authorization`` when a
            token is available.
        """
        metadata = [("x-request-id", request_id)]

        if self.jwt_handler:
            self.jwt_token = self.jwt_handler.encode_token(
                self.default_payload,
                expires_in=self.token_expires_in
            )
        # Checked separately so a pre-generated token is used when no secret key is set.
        if self.jwt_token:
            metadata.append(("authorization", f"Bearer {self.jwt_token}"))

        return metadata

    @abstractmethod
    def close(self):
        """Close the client connection (implemented by subclasses)."""
        pass

    @abstractmethod
    def __enter__(self):
        """Enter the context manager (implemented by subclasses)."""
        pass

    @abstractmethod
    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the context manager (implemented by subclasses)."""
        pass

    def _init_resilient_features(self):
        """
        Initialize circuit-breaker / HTTP-fallback features from the environment.

        Resilient mode requires MODEL_CLIENT_RESILIENT_ENABLED=true and
        MODEL_CLIENT_HTTP_FALLBACK_URL. If the URL is missing, it logs a warning and
        disables itself. ``circuit_breaker`` and ``http_fallback_url`` are always
        defined afterwards (``None`` when resilient mode is off).
        """
        self.resilient_enabled = os.getenv('MODEL_CLIENT_RESILIENT_ENABLED', 'false').lower() == 'true'

        # Defaults first, so the attributes exist on every exit path (including misconfiguration).
        self.circuit_breaker = None
        self.http_fallback_url = None
        if not self.resilient_enabled:
            return

        # HTTP fallback base URL
        self.http_fallback_url = os.getenv('MODEL_CLIENT_HTTP_FALLBACK_URL')
        if not self.http_fallback_url:
            self.logger.warning("🔶 Resilient mode enabled but MODEL_CLIENT_HTTP_FALLBACK_URL not set")
            self.resilient_enabled = False
            return

        # Imported lazily: only needed when resilient mode is on
        from ..circuit_breaker import CircuitBreaker
        self.circuit_breaker = CircuitBreaker(
            failure_threshold=int(os.getenv('MODEL_CLIENT_CIRCUIT_BREAKER_THRESHOLD', '5')),
            recovery_timeout=int(os.getenv('MODEL_CLIENT_CIRCUIT_BREAKER_TIMEOUT', '60'))
        )

        # HTTP clients are created lazily on first fallback
        self._http_client = None
        self._http_session = None  # used by the async client

        self.logger.info(
            "🛡️ Resilient mode enabled",
            extra={
                "http_fallback_url": self.http_fallback_url,
                "circuit_breaker_threshold": self.circuit_breaker.failure_threshold,
                "circuit_breaker_timeout": self.circuit_breaker.recovery_timeout
            }
        )

    def get_resilient_metrics(self):
        """Return a dict of circuit-breaker metrics, or None when resilient mode is inactive."""
        if not self.resilient_enabled or not self.circuit_breaker:
            return None

        return {
            "enabled": self.resilient_enabled,
            "circuit_state": self.circuit_breaker.get_state(),
            "failure_count": self.circuit_breaker.failure_count,
            "last_failure_time": self.circuit_breaker.last_failure_time,
            "http_fallback_url": self.http_fallback_url
        }
|
@@ -0,0 +1,249 @@
|
|
1
|
+
"""
|
2
|
+
HTTP fallback functionality for resilient clients
|
3
|
+
|
4
|
+
This module provides mixin classes for HTTP-based fallback when gRPC
|
5
|
+
connections fail, supporting both synchronous and asynchronous clients.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import json
|
9
|
+
import logging
|
10
|
+
from typing import Optional, Iterator, AsyncIterator, Dict, Any
|
11
|
+
|
12
|
+
from . import generate_request_id, get_protected_logger
|
13
|
+
from ..schemas import ModelRequest, ModelResponse
|
14
|
+
|
15
|
+
logger = get_protected_logger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
class HttpFallbackMixin:
    """
    HTTP fallback functionality for synchronous clients.

    Intended to be mixed into a client class. It reads ``self.http_fallback_url`` and,
    when present, ``self.jwt_handler``, ``self.jwt_token``, ``self.default_payload`` and
    ``self.token_expires_in``. It lazily stores a ``requests.Session`` in
    ``self._http_client``.
    """

    def _ensure_http_client(self) -> None:
        """Lazily create the shared ``requests.Session`` with default JSON headers."""
        if not hasattr(self, '_http_client') or not self._http_client:
            import requests
            self._http_client = requests.Session()
            # The JWT is attached per request in _invoke_http_fallback, not on the session.
            self._http_client.headers.update({
                'Content-Type': 'application/json',
                'User-Agent': 'TamarModelClient/1.0'
            })

    def _convert_to_http_format(self, model_request: ModelRequest) -> Dict[str, Any]:
        """
        Convert a ModelRequest into the JSON payload sent to the fallback endpoint.

        Always includes provider/model/user_context/stream. Adds messages/contents,
        channel and invoke_type when set. Extra model fields are added last and never
        overwrite keys that are already present.
        """
        payload = {
            "provider": model_request.provider.value,
            "model": model_request.model,
            "user_context": model_request.user_context.model_dump(),
            "stream": model_request.stream
        }

        # Provider-specific fields
        if hasattr(model_request, 'messages') and model_request.messages:
            payload['messages'] = model_request.messages
        if hasattr(model_request, 'contents') and model_request.contents:
            payload['contents'] = model_request.contents

        # Optional fields
        if model_request.channel:
            payload['channel'] = model_request.channel.value
        if model_request.invoke_type:
            payload['invoke_type'] = model_request.invoke_type.value

        # Extra parameters
        if hasattr(model_request, 'model_extra') and model_request.model_extra:
            for key, value in model_request.model_extra.items():
                if key not in payload:
                    payload[key] = value

        return payload

    def _handle_http_stream(self, url: str, payload: Dict[str, Any],
                            timeout: Optional[float], request_id: str,
                            headers: Dict[str, str]) -> Iterator[ModelResponse]:
        """
        POST ``payload`` and yield one ModelResponse per SSE ``data:`` event.

        This is a generator, so the request is only sent once iteration starts. The
        stream ends at ``[DONE]`` or end of body. Undecodable events are logged and
        skipped.

        Raises:
            requests.HTTPError: If the endpoint returns an error status.
        """
        response = self._http_client.post(
            url,
            json=payload,
            timeout=timeout or 30,
            headers=headers,
            stream=True
        )
        try:
            response.raise_for_status()

            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                line_str = raw_line.decode('utf-8')
                if not line_str.startswith('data:'):
                    continue
                # SSE allows both "data:x" and "data: x"; drop one optional leading space.
                data_str = line_str[5:]
                if data_str.startswith(' '):
                    data_str = data_str[1:]
                if not data_str:
                    continue
                if data_str == '[DONE]':
                    break
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse streaming response: {data_str}")
                    continue
                yield ModelResponse(**data)
        finally:
            # With stream=True the connection stays checked out until the body is consumed
            # or the response is closed; close it even if the consumer stops early.
            response.close()

    def _invoke_http_fallback(self, model_request: ModelRequest,
                              timeout: Optional[float] = None,
                              request_id: Optional[str] = None) -> Any:
        """
        Send ``model_request`` to ``{http_fallback_url}/v1/invoke`` over HTTP.

        Args:
            model_request: Request to forward.
            timeout: Request timeout in seconds (defaults to 30).
            request_id: Request ID to propagate; generated when omitted.

        Returns:
            A ModelResponse for non-streaming requests. For streaming requests, returns an
            iterator of ModelResponse chunks; the HTTP call happens on first iteration.

        Raises:
            requests.HTTPError: If the fallback endpoint returns an error status.
        """
        self._ensure_http_client()

        if not request_id:
            request_id = generate_request_id()

        logger.warning(
            "🔻 Using HTTP fallback for request",
            extra={
                "request_id": request_id,
                "provider": model_request.provider.value,
                "model": model_request.model,
                "fallback_url": self.http_fallback_url
            }
        )

        http_payload = self._convert_to_http_format(model_request)

        # Tolerate a configured base URL with a trailing slash.
        url = f"{self.http_fallback_url.rstrip('/')}/v1/invoke"

        headers = {'X-Request-ID': request_id}
        jwt_handler = getattr(self, 'jwt_handler', None)
        if jwt_handler:
            # Otherwise self.jwt_token may only be refreshed on the gRPC path, so re-sign
            # here to avoid sending an expired token during a long fallback period.
            self.jwt_token = jwt_handler.encode_token(
                getattr(self, 'default_payload', None),
                expires_in=getattr(self, 'token_expires_in', 3600)
            )
        if getattr(self, 'jwt_token', None):
            headers['Authorization'] = f'Bearer {self.jwt_token}'

        if model_request.stream:
            return self._handle_http_stream(url, http_payload, timeout, request_id, headers)

        response = self._http_client.post(
            url,
            json=http_payload,
            timeout=timeout or 30,
            headers=headers
        )
        response.raise_for_status()
        return ModelResponse(**response.json())
|
141
|
+
|
142
|
+
|
143
|
+
class AsyncHttpFallbackMixin:
    """
    HTTP fallback functionality for asynchronous clients.

    Intended to be mixed into an async client class. It reads ``self.http_fallback_url``
    and, when present, ``self.jwt_handler``, ``self.jwt_token``, ``self.default_payload``
    and ``self.token_expires_in``. It lazily stores an ``aiohttp.ClientSession`` in
    ``self._http_session``.
    """

    async def _ensure_http_client(self) -> None:
        """Lazily (re)create the shared ``aiohttp.ClientSession`` with default JSON headers."""
        session = getattr(self, '_http_session', None)
        # Also recreate a session that was closed elsewhere; reusing it would fail.
        if not session or session.closed:
            import aiohttp
            self._http_session = aiohttp.ClientSession(
                headers={
                    'Content-Type': 'application/json',
                    'User-Agent': 'AsyncTamarModelClient/1.0'
                }
            )
            # The JWT is attached per request in _invoke_http_fallback, not on the session.

    def _convert_to_http_format(self, model_request: ModelRequest) -> Dict[str, Any]:
        """Convert a ModelRequest to the HTTP payload (same format as the sync mixin)."""
        # Pure data transformation; no need for an async variant.
        return HttpFallbackMixin._convert_to_http_format(self, model_request)

    async def _handle_http_stream(self, url: str, payload: Dict[str, Any],
                                  timeout: Optional[float], request_id: str,
                                  headers: Dict[str, str]) -> AsyncIterator[ModelResponse]:
        """
        POST ``payload`` and yield one ModelResponse per SSE ``data:`` event.

        This is an async generator, so the request is only sent once iteration starts.
        The stream ends at ``[DONE]`` or end of body. Undecodable events are logged and
        skipped.

        Raises:
            aiohttp.ClientResponseError: If the endpoint returns an error status.
        """
        import aiohttp

        # Always pass an explicit ClientTimeout. In aiohttp 3.x an explicit timeout=None
        # becomes ClientTimeout(total=None), i.e. no limit at all.
        timeout_obj = aiohttp.ClientTimeout(total=timeout or 30)

        async with self._http_session.post(
                url,
                json=payload,
                timeout=timeout_obj,
                headers=headers
        ) as response:
            response.raise_for_status()

            async for line_bytes in response.content:
                if not line_bytes:
                    continue
                line_str = line_bytes.decode('utf-8').strip()
                if not line_str.startswith('data:'):
                    continue
                # SSE allows both "data:x" and "data: x"; drop one optional leading space.
                data_str = line_str[5:]
                if data_str.startswith(' '):
                    data_str = data_str[1:]
                if not data_str:
                    continue
                if data_str == '[DONE]':
                    break
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse streaming response: {data_str}")
                    continue
                yield ModelResponse(**data)

    async def _invoke_http_fallback(self, model_request: ModelRequest,
                                    timeout: Optional[float] = None,
                                    request_id: Optional[str] = None) -> Any:
        """
        Send ``model_request`` to ``{http_fallback_url}/v1/invoke`` over HTTP (async).

        Args:
            model_request: Request to forward.
            timeout: Request timeout in seconds (defaults to 30).
            request_id: Request ID to propagate; generated when omitted.

        Returns:
            A ModelResponse for non-streaming requests. For streaming requests, returns an
            async iterator of ModelResponse chunks; the HTTP call happens on first iteration.

        Raises:
            aiohttp.ClientResponseError: If the fallback endpoint returns an error status.
        """
        await self._ensure_http_client()

        if not request_id:
            request_id = generate_request_id()

        logger.warning(
            "🔻 Using HTTP fallback for request",
            extra={
                "request_id": request_id,
                "provider": model_request.provider.value,
                "model": model_request.model,
                "fallback_url": self.http_fallback_url
            }
        )

        http_payload = self._convert_to_http_format(model_request)

        # Tolerate a configured base URL with a trailing slash.
        url = f"{self.http_fallback_url.rstrip('/')}/v1/invoke"

        headers = {'X-Request-ID': request_id}
        jwt_handler = getattr(self, 'jwt_handler', None)
        if jwt_handler:
            # Otherwise self.jwt_token may only be refreshed on the gRPC path, so re-sign
            # here to avoid sending an expired token during a long fallback period.
            self.jwt_token = jwt_handler.encode_token(
                getattr(self, 'default_payload', None),
                expires_in=getattr(self, 'token_expires_in', 3600)
            )
        if getattr(self, 'jwt_token', None):
            headers['Authorization'] = f'Bearer {self.jwt_token}'

        if model_request.stream:
            return self._handle_http_stream(url, http_payload, timeout, request_id, headers)

        import aiohttp
        # Same reasoning as in _handle_http_stream: never hand aiohttp a bare None.
        timeout_obj = aiohttp.ClientTimeout(total=timeout or 30)

        async with self._http_session.post(
                url,
                json=http_payload,
                timeout=timeout_obj,
                headers=headers
        ) as response:
            response.raise_for_status()
            data = await response.json()
            return ModelResponse(**data)

    async def _cleanup_http_session(self) -> None:
        """Close and drop the shared HTTP session, if one exists."""
        if hasattr(self, '_http_session') and self._http_session:
            await self._http_session.close()
            self._http_session = None
|