sycommon-python-lib 0.1.56b5__py3-none-any.whl → 0.1.57b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sycommon/config/Config.py +24 -3
- sycommon/config/LangfuseConfig.py +15 -0
- sycommon/config/SentryConfig.py +13 -0
- sycommon/llm/embedding.py +269 -50
- sycommon/llm/get_llm.py +9 -218
- sycommon/llm/struct_token.py +192 -0
- sycommon/llm/sy_langfuse.py +103 -0
- sycommon/llm/usage_token.py +117 -0
- sycommon/logging/kafka_log.py +187 -433
- sycommon/middleware/exception.py +10 -16
- sycommon/middleware/timeout.py +2 -1
- sycommon/middleware/traceid.py +81 -76
- sycommon/notice/uvicorn_monitor.py +32 -27
- sycommon/rabbitmq/rabbitmq_client.py +247 -242
- sycommon/rabbitmq/rabbitmq_pool.py +201 -123
- sycommon/rabbitmq/rabbitmq_service.py +25 -843
- sycommon/rabbitmq/rabbitmq_service_client_manager.py +211 -0
- sycommon/rabbitmq/rabbitmq_service_connection_monitor.py +73 -0
- sycommon/rabbitmq/rabbitmq_service_consumer_manager.py +285 -0
- sycommon/rabbitmq/rabbitmq_service_core.py +117 -0
- sycommon/rabbitmq/rabbitmq_service_producer_manager.py +238 -0
- sycommon/sentry/__init__.py +0 -0
- sycommon/sentry/sy_sentry.py +35 -0
- sycommon/services.py +122 -96
- sycommon/synacos/nacos_client_base.py +121 -0
- sycommon/synacos/nacos_config_manager.py +107 -0
- sycommon/synacos/nacos_heartbeat_manager.py +144 -0
- sycommon/synacos/nacos_service.py +63 -783
- sycommon/synacos/nacos_service_discovery.py +157 -0
- sycommon/synacos/nacos_service_registration.py +270 -0
- sycommon/tools/env.py +62 -0
- sycommon/tools/merge_headers.py +20 -0
- sycommon/tools/snowflake.py +101 -153
- {sycommon_python_lib-0.1.56b5.dist-info → sycommon_python_lib-0.1.57b4.dist-info}/METADATA +10 -8
- {sycommon_python_lib-0.1.56b5.dist-info → sycommon_python_lib-0.1.57b4.dist-info}/RECORD +38 -20
- {sycommon_python_lib-0.1.56b5.dist-info → sycommon_python_lib-0.1.57b4.dist-info}/WHEEL +0 -0
- {sycommon_python_lib-0.1.56b5.dist-info → sycommon_python_lib-0.1.57b4.dist-info}/entry_points.txt +0 -0
- {sycommon_python_lib-0.1.56b5.dist-info → sycommon_python_lib-0.1.57b4.dist-info}/top_level.txt +0 -0
sycommon/config/Config.py
CHANGED
|
@@ -15,13 +15,13 @@ class Config(metaclass=SingletonMeta):
|
|
|
15
15
|
with open(config_file, 'r', encoding='utf-8') as f:
|
|
16
16
|
self.config = yaml.safe_load(f)
|
|
17
17
|
self.MaxBytes = self.config.get('MaxBytes', 209715200)
|
|
18
|
-
self.Timeout = self.config.get('Timeout',
|
|
18
|
+
self.Timeout = self.config.get('Timeout', 600000)
|
|
19
19
|
self.MaxRetries = self.config.get('MaxRetries', 3)
|
|
20
|
-
self.OCR = self.config.get('OCR', None)
|
|
21
|
-
self.INVOICE_OCR = self.config.get('INVOICE_OCR', None)
|
|
22
20
|
self.llm_configs = []
|
|
23
21
|
self.embedding_configs = []
|
|
24
22
|
self.reranker_configs = []
|
|
23
|
+
self.sentry_configs = []
|
|
24
|
+
self.langfuse_configs = []
|
|
25
25
|
self._process_config()
|
|
26
26
|
|
|
27
27
|
def get_llm_config(self, model_name):
|
|
@@ -42,6 +42,18 @@ class Config(metaclass=SingletonMeta):
|
|
|
42
42
|
return llm
|
|
43
43
|
raise ValueError(f"No configuration found for model: {model_name}")
|
|
44
44
|
|
|
45
|
+
def get_sentry_config(self, name):
    """Look up a Sentry configuration entry by its name.

    Args:
        name: Value compared against the ``'name'`` key of each entry in
            ``self.sentry_configs``.

    Returns:
        The first entry whose ``'name'`` equals ``name``.

    Raises:
        ValueError: If no entry carries the requested name.
    """
    # Lazy generator: scanning stops at the first match, as a loop would.
    matching = (entry for entry in self.sentry_configs
                if entry.get('name') == name)
    found = next(matching, None)
    if found is None:
        raise ValueError(f"No configuration found for server: {name}")
    return found
|
|
50
|
+
|
|
51
|
+
def get_langfuse_config(self, name):
    """Look up a Langfuse configuration entry by its name.

    Args:
        name: Value compared against the ``'name'`` key of each entry in
            ``self.langfuse_configs``.

    Returns:
        The first entry whose ``'name'`` equals ``name``.

    Raises:
        ValueError: If no entry carries the requested name.
    """
    lookup = (candidate for candidate in self.langfuse_configs
              if candidate.get('name') == name)
    hit = next(lookup, None)
    if hit is not None:
        return hit
    raise ValueError(f"No configuration found for server: {name}")
|
|
56
|
+
|
|
45
57
|
def _process_config(self):
|
|
46
58
|
llm_config_list = self.config.get('LLMConfig', [])
|
|
47
59
|
for llm_config in llm_config_list:
|
|
@@ -71,6 +83,15 @@ class Config(metaclass=SingletonMeta):
|
|
|
71
83
|
except ValueError as e:
|
|
72
84
|
print(f"Invalid LLM configuration: {e}")
|
|
73
85
|
|
|
86
|
+
sentry_config_list = self.config.get('SentryConfig', [])
|
|
87
|
+
for sentry_config in sentry_config_list:
|
|
88
|
+
try:
|
|
89
|
+
from sycommon.config.SentryConfig import SentryConfig
|
|
90
|
+
validated_config = SentryConfig(**sentry_config)
|
|
91
|
+
self.sentry_configs.append(validated_config.model_dump())
|
|
92
|
+
except ValueError as e:
|
|
93
|
+
print(f"Invalid Sentry configuration: {e}")
|
|
94
|
+
|
|
74
95
|
def set_attr(self, share_configs: dict):
|
|
75
96
|
self.config = {**self.config, **
|
|
76
97
|
share_configs.get('llm', {}), **share_configs}
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from pydantic import BaseModel
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class LangfuseConfig(BaseModel):
    # Validated shape of one named Langfuse configuration entry.
    # All five fields are required; pydantic rejects entries missing any.
    name: str
    secretKey: str
    publicKey: str
    baseUrl: str
    enable: bool

    @classmethod
    def from_config(cls, server_name: str):
        """Build an instance from the entry registered under ``server_name``.

        The entry is fetched with ``Config().get_langfuse_config``, which
        raises ``ValueError`` when no entry has that name.
        """
        # Function-scope import, as in the original
        # (presumably to avoid a circular import with Config — TODO confirm).
        from sycommon.config.Config import Config

        raw_entry = Config().get_langfuse_config(server_name)
        return cls(**raw_entry)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from pydantic import BaseModel
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class SentryConfig(BaseModel):
    # Validated shape of one named Sentry configuration entry.
    # All three fields are required; pydantic rejects entries missing any.
    name: str
    dsn: str
    enable: bool

    @classmethod
    def from_config(cls, server_name: str):
        """Build an instance from the entry registered under ``server_name``.

        The entry is fetched with ``Config().get_sentry_config``, which
        raises ``ValueError`` when no entry has that name.
        """
        # Function-scope import, as in the original
        # (presumably to avoid a circular import with Config — TODO confirm).
        from sycommon.config.Config import Config

        raw_entry = Config().get_sentry_config(server_name)
        return cls(**raw_entry)
|
sycommon/llm/embedding.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
import json
|
|
3
2
|
import aiohttp
|
|
4
|
-
|
|
5
|
-
|
|
3
|
+
import atexit
|
|
4
|
+
from typing import Union, List, Optional, Dict
|
|
6
5
|
from sycommon.config.Config import SingletonMeta
|
|
7
6
|
from sycommon.config.EmbeddingConfig import EmbeddingConfig
|
|
8
7
|
from sycommon.config.RerankerConfig import RerankerConfig
|
|
@@ -23,20 +22,128 @@ class Embedding(metaclass=SingletonMeta):
|
|
|
23
22
|
self.reranker_base_url = RerankerConfig.from_config(
|
|
24
23
|
self.default_reranker_model).baseUrl
|
|
25
24
|
|
|
25
|
+
# [修复] 缓存配置URL,避免高并发下重复读取配置文件
|
|
26
|
+
self._embedding_url_cache: Dict[str, str] = {
|
|
27
|
+
self.default_embedding_model: self.embeddings_base_url
|
|
28
|
+
}
|
|
29
|
+
self._reranker_url_cache: Dict[str, str] = {
|
|
30
|
+
self.default_reranker_model: self.reranker_base_url
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
# [修复] 缓存模型的向量维度,用于生成兜底零向量
|
|
34
|
+
self._model_dim_cache: Dict[str, int] = {}
|
|
35
|
+
|
|
26
36
|
# 并发信号量
|
|
27
37
|
self.semaphore = asyncio.Semaphore(self.max_concurrency)
|
|
38
|
+
self.default_timeout = aiohttp.ClientTimeout(total=None)
|
|
28
39
|
|
|
29
|
-
|
|
40
|
+
# 核心优化:创建全局可复用的ClientSession(连接池复用)
|
|
41
|
+
self.session = None
|
|
42
|
+
# 重试配置(可根据需要调整)
|
|
43
|
+
self.max_retry_attempts = 3 # 最大重试次数
|
|
44
|
+
self.retry_wait_base = 0.5 # 基础等待时间(秒)
|
|
45
|
+
|
|
46
|
+
# [修复] 注册退出钩子,确保程序结束时关闭连接池
|
|
47
|
+
atexit.register(self._sync_close_session)
|
|
48
|
+
|
|
49
|
+
async def init_session(self):
    """Ensure ``self.session`` holds an open ``aiohttp.ClientSession``.

    Does nothing when a session already exists and is not closed; otherwise
    builds a new session backed by a connection pool sized from
    ``self.max_concurrency``.
    """
    existing = self.session
    if existing is not None and not existing.closed:
        return

    pool_options = dict(
        limit=self.max_concurrency * 2,       # total connections in the pool
        limit_per_host=self.max_concurrency,  # connections per host
        ttl_dns_cache=300,                    # DNS cache lifetime
        enable_cleanup_closed=True,           # clean up closed connections
    )
    self.session = aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(**pool_options),
        timeout=self.default_timeout,
    )
|
|
63
|
+
|
|
64
|
+
async def close_session(self):
    """Close the shared session if one exists and is still open."""
    active = self.session
    if not active or active.closed:
        return
    await active.close()
|
|
68
|
+
|
|
69
|
+
def _sync_close_session(self):
|
|
70
|
+
"""同步关闭Session的封装,供atexit调用"""
|
|
71
|
+
# 注意:atexit在主线程运行,如果当前没有事件循环,这个操作可能会受限
|
|
72
|
+
# 但它能捕获大多数正常退出的场景。对于asyncio程序,建议显式调用cleanup
|
|
73
|
+
try:
|
|
74
|
+
loop = asyncio.get_event_loop()
|
|
75
|
+
if loop.is_running():
|
|
76
|
+
# 如果loop还在跑,创建一个任务去关闭
|
|
77
|
+
loop.create_task(self.close_session())
|
|
78
|
+
else:
|
|
79
|
+
# 如果loop已经停止,尝试运行一次
|
|
80
|
+
loop.run_until_complete(self.close_session())
|
|
81
|
+
except Exception:
|
|
82
|
+
# 静默处理清理失败,避免退出报错
|
|
83
|
+
pass
|
|
84
|
+
|
|
85
|
+
async def _retry_request(self, func, *args, **kwargs):
|
|
86
|
+
"""
|
|
87
|
+
原生异步重试封装函数
|
|
88
|
+
Args:
|
|
89
|
+
func: 待重试的异步函数
|
|
90
|
+
*args: 函数参数
|
|
91
|
+
**kwargs: 函数关键字参数
|
|
92
|
+
Returns:
|
|
93
|
+
函数执行结果,重试失败返回None
|
|
94
|
+
"""
|
|
95
|
+
attempt = 0
|
|
96
|
+
while attempt < self.max_retry_attempts:
|
|
97
|
+
try:
|
|
98
|
+
return await func(*args, **kwargs)
|
|
99
|
+
except (aiohttp.ClientConnectionResetError, asyncio.TimeoutError, aiohttp.ClientError) as e:
|
|
100
|
+
attempt += 1
|
|
101
|
+
if attempt >= self.max_retry_attempts:
|
|
102
|
+
SYLogger.error(
|
|
103
|
+
f"Request failed after {attempt} retries: {str(e)}")
|
|
104
|
+
return None
|
|
105
|
+
# 指数退避等待:0.5s → 1s → 2s(最大不超过5s)
|
|
106
|
+
wait_time = min(self.retry_wait_base * (2 ** (attempt - 1)), 5)
|
|
107
|
+
SYLogger.warning(
|
|
108
|
+
f"Retry {func.__name__} (attempt {attempt}/{self.max_retry_attempts}): {str(e)}, wait {wait_time}s")
|
|
109
|
+
await asyncio.sleep(wait_time)
|
|
110
|
+
except Exception as e:
|
|
111
|
+
# 非重试类异常直接返回None
|
|
112
|
+
SYLogger.error(
|
|
113
|
+
f"Non-retryable error in {func.__name__}: {str(e)}")
|
|
114
|
+
return None
|
|
115
|
+
return None
|
|
116
|
+
|
|
117
|
+
def _get_embedding_url(self, model: str) -> str:
|
|
118
|
+
"""获取Embedding URL(带缓存)"""
|
|
119
|
+
if model not in self._embedding_url_cache:
|
|
120
|
+
self._embedding_url_cache[model] = EmbeddingConfig.from_config(
|
|
121
|
+
model).baseUrl
|
|
122
|
+
return self._embedding_url_cache[model]
|
|
123
|
+
|
|
124
|
+
def _get_reranker_url(self, model: str) -> str:
|
|
125
|
+
"""获取Reranker URL(带缓存)"""
|
|
126
|
+
if model not in self._reranker_url_cache:
|
|
127
|
+
self._reranker_url_cache[model] = RerankerConfig.from_config(
|
|
128
|
+
model).baseUrl
|
|
129
|
+
return self._reranker_url_cache[model]
|
|
130
|
+
|
|
131
|
+
async def _get_embeddings_http_core(
|
|
30
132
|
self,
|
|
31
133
|
input: Union[str, List[str]],
|
|
32
134
|
encoding_format: str = None,
|
|
33
135
|
model: str = None,
|
|
136
|
+
timeout: aiohttp.ClientTimeout = None,
|
|
34
137
|
**kwargs
|
|
35
138
|
):
|
|
139
|
+
"""embedding请求核心逻辑(剥离重试,供重试封装调用)"""
|
|
140
|
+
await self.init_session() # 确保Session已初始化
|
|
36
141
|
async with self.semaphore:
|
|
37
|
-
|
|
142
|
+
request_timeout = timeout or self.default_timeout
|
|
38
143
|
target_model = model or self.default_embedding_model
|
|
39
|
-
|
|
144
|
+
|
|
145
|
+
# [修复] 使用缓存获取URL
|
|
146
|
+
target_base_url = self._get_embedding_url(target_model)
|
|
40
147
|
url = f"{target_base_url}/v1/embeddings"
|
|
41
148
|
|
|
42
149
|
request_body = {
|
|
@@ -46,16 +153,33 @@ class Embedding(metaclass=SingletonMeta):
|
|
|
46
153
|
}
|
|
47
154
|
request_body.update(kwargs)
|
|
48
155
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
156
|
+
# 复用全局Session
|
|
157
|
+
async with self.session.post(
|
|
158
|
+
url,
|
|
159
|
+
json=request_body,
|
|
160
|
+
timeout=request_timeout
|
|
161
|
+
) as response:
|
|
162
|
+
if response.status != 200:
|
|
163
|
+
error_detail = await response.text()
|
|
164
|
+
SYLogger.error(
|
|
165
|
+
f"Embedding request failed (model: {target_model}): {error_detail}")
|
|
166
|
+
return None
|
|
167
|
+
return await response.json()
|
|
57
168
|
|
|
58
|
-
async def
|
|
169
|
+
async def _get_embeddings_http_async(
|
|
170
|
+
self,
|
|
171
|
+
input: Union[str, List[str]],
|
|
172
|
+
encoding_format: str = None,
|
|
173
|
+
model: str = None,
|
|
174
|
+
timeout: aiohttp.ClientTimeout = None, ** kwargs
|
|
175
|
+
):
|
|
176
|
+
"""对外暴露的embedding请求方法(包含重试)"""
|
|
177
|
+
return await self._retry_request(
|
|
178
|
+
self._get_embeddings_http_core,
|
|
179
|
+
input, encoding_format, model, timeout, ** kwargs
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
async def _get_reranker_http_core(
|
|
59
183
|
self,
|
|
60
184
|
documents: List[str],
|
|
61
185
|
query: str,
|
|
@@ -64,12 +188,16 @@ class Embedding(metaclass=SingletonMeta):
|
|
|
64
188
|
max_chunks_per_doc: Optional[int] = None,
|
|
65
189
|
return_documents: Optional[bool] = True,
|
|
66
190
|
return_len: Optional[bool] = True,
|
|
67
|
-
**kwargs
|
|
191
|
+
timeout: aiohttp.ClientTimeout = None, ** kwargs
|
|
68
192
|
):
|
|
193
|
+
"""reranker请求核心逻辑(剥离重试,供重试封装调用)"""
|
|
194
|
+
await self.init_session() # 确保Session已初始化
|
|
69
195
|
async with self.semaphore:
|
|
70
|
-
|
|
196
|
+
request_timeout = timeout or self.default_timeout
|
|
71
197
|
target_model = model or self.default_reranker_model
|
|
72
|
-
|
|
198
|
+
|
|
199
|
+
# [修复] 使用缓存获取URL
|
|
200
|
+
target_base_url = self._get_reranker_url(target_model)
|
|
73
201
|
url = f"{target_base_url}/v1/rerank"
|
|
74
202
|
|
|
75
203
|
request_body = {
|
|
@@ -80,23 +208,45 @@ class Embedding(metaclass=SingletonMeta):
|
|
|
80
208
|
"max_chunks_per_doc": max_chunks_per_doc,
|
|
81
209
|
"return_documents": return_documents,
|
|
82
210
|
"return_len": return_len,
|
|
83
|
-
"kwargs": json.dumps(kwargs),
|
|
84
211
|
}
|
|
85
212
|
request_body.update(kwargs)
|
|
86
213
|
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
214
|
+
# 复用全局Session
|
|
215
|
+
async with self.session.post(
|
|
216
|
+
url,
|
|
217
|
+
json=request_body,
|
|
218
|
+
timeout=request_timeout
|
|
219
|
+
) as response:
|
|
220
|
+
if response.status != 200:
|
|
221
|
+
error_detail = await response.text()
|
|
222
|
+
SYLogger.error(
|
|
223
|
+
f"Rerank request failed (model: {target_model}): {error_detail}")
|
|
224
|
+
return None
|
|
225
|
+
return await response.json()
|
|
226
|
+
|
|
227
|
+
async def _get_reranker_http_async(
|
|
228
|
+
self,
|
|
229
|
+
documents: List[str],
|
|
230
|
+
query: str,
|
|
231
|
+
top_n: Optional[int] = None,
|
|
232
|
+
model: str = None,
|
|
233
|
+
max_chunks_per_doc: Optional[int] = None,
|
|
234
|
+
return_documents: Optional[bool] = True,
|
|
235
|
+
return_len: Optional[bool] = True,
|
|
236
|
+
timeout: aiohttp.ClientTimeout = None, ** kwargs
|
|
237
|
+
):
|
|
238
|
+
"""对外暴露的reranker请求方法(包含重试)"""
|
|
239
|
+
return await self._retry_request(
|
|
240
|
+
self._get_reranker_http_core,
|
|
241
|
+
documents, query, top_n, model, max_chunks_per_doc,
|
|
242
|
+
return_documents, return_len, timeout, **kwargs
|
|
243
|
+
)
|
|
95
244
|
|
|
96
245
|
async def get_embeddings(
|
|
97
246
|
self,
|
|
98
247
|
corpus: List[str],
|
|
99
|
-
model: str = None
|
|
248
|
+
model: str = None,
|
|
249
|
+
timeout: Optional[Union[int, float]] = None
|
|
100
250
|
):
|
|
101
251
|
"""
|
|
102
252
|
获取语料库的嵌入向量,结果顺序与输入语料库顺序一致
|
|
@@ -104,34 +254,89 @@ class Embedding(metaclass=SingletonMeta):
|
|
|
104
254
|
Args:
|
|
105
255
|
corpus: 待生成嵌入向量的文本列表
|
|
106
256
|
model: 可选,指定使用的embedding模型名称,默认使用bge-large-zh-v1.5
|
|
257
|
+
timeout: 可选,超时时间(秒):
|
|
258
|
+
- 传int/float:表示总超时时间(秒)
|
|
259
|
+
- 不传/None:使用默认永不超时配置
|
|
107
260
|
"""
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
results = await asyncio.gather(*tasks)
|
|
114
|
-
|
|
115
|
-
vectors = []
|
|
116
|
-
for result in results:
|
|
117
|
-
if result is None:
|
|
118
|
-
zero_vector = [0.0] * 1024
|
|
119
|
-
vectors.append(zero_vector)
|
|
261
|
+
request_timeout = None
|
|
262
|
+
if timeout is not None:
|
|
263
|
+
if isinstance(timeout, (int, float)):
|
|
264
|
+
request_timeout = aiohttp.ClientTimeout(total=timeout)
|
|
265
|
+
else:
|
|
120
266
|
SYLogger.warning(
|
|
121
|
-
f"
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
vectors.append(item["embedding"])
|
|
267
|
+
f"Invalid timeout type: {type(timeout)}, must be int/float, use default timeout")
|
|
268
|
+
|
|
269
|
+
actual_model = model or self.default_embedding_model
|
|
125
270
|
|
|
126
271
|
SYLogger.info(
|
|
127
|
-
f"
|
|
128
|
-
|
|
272
|
+
f"Requesting embeddings for corpus: {len(corpus)} items (model: {actual_model}, max_concurrency: {self.max_concurrency}, timeout: {timeout or 'None'})")
|
|
273
|
+
|
|
274
|
+
all_vectors = []
|
|
275
|
+
|
|
276
|
+
# [修复] 增加 Chunk 处理逻辑,防止 corpus 过大导致内存溢出或协程过多
|
|
277
|
+
# 每次最多处理 max_concurrency * 2 个请求,避免一次性创建几十万个协程
|
|
278
|
+
batch_size = self.max_concurrency * 2
|
|
279
|
+
|
|
280
|
+
for i in range(0, len(corpus), batch_size):
|
|
281
|
+
batch_texts = corpus[i: i + batch_size]
|
|
282
|
+
|
|
283
|
+
# 给每个异步任务传入模型名称和超时配置
|
|
284
|
+
tasks = [self._get_embeddings_http_async(
|
|
285
|
+
text, model=model, timeout=request_timeout) for text in batch_texts]
|
|
286
|
+
results = await asyncio.gather(*tasks)
|
|
287
|
+
|
|
288
|
+
for result in results:
|
|
289
|
+
if result is None:
|
|
290
|
+
# [修复] 尝试获取真实维度或使用配置兜底,不再硬编码 1024
|
|
291
|
+
dim = self._model_dim_cache.get(actual_model)
|
|
292
|
+
|
|
293
|
+
# 如果缓存中没有维度,尝试从配置对象获取(假设Config类有dimension属性)
|
|
294
|
+
if dim is None:
|
|
295
|
+
try:
|
|
296
|
+
config = EmbeddingConfig.from_config(actual_model)
|
|
297
|
+
if hasattr(config, 'dimension'):
|
|
298
|
+
dim = config.dimension
|
|
299
|
+
else:
|
|
300
|
+
# 最后的兜底:如果配置也没有,必须有一个默认值防止崩溃
|
|
301
|
+
# bge-large 通常是 1024
|
|
302
|
+
dim = 1024
|
|
303
|
+
SYLogger.warning(
|
|
304
|
+
f"Cannot get dimension from config for {actual_model}, use default 1024")
|
|
305
|
+
except Exception:
|
|
306
|
+
dim = 1024
|
|
307
|
+
|
|
308
|
+
zero_vector = [0.0] * dim
|
|
309
|
+
all_vectors.append(zero_vector)
|
|
310
|
+
SYLogger.warning(
|
|
311
|
+
f"Embedding request failed, append zero vector ({dim}D) for model {actual_model}")
|
|
312
|
+
continue
|
|
313
|
+
|
|
314
|
+
# 从返回结果中提取向量并更新维度缓存
|
|
315
|
+
# 正常情况下 result["data"] 是一个列表
|
|
316
|
+
try:
|
|
317
|
+
for item in result["data"]:
|
|
318
|
+
embedding = item["embedding"]
|
|
319
|
+
# [修复] 动态学习并缓存维度
|
|
320
|
+
if actual_model not in self._model_dim_cache:
|
|
321
|
+
self._model_dim_cache[actual_model] = len(
|
|
322
|
+
embedding)
|
|
323
|
+
all_vectors.append(embedding)
|
|
324
|
+
except (KeyError, TypeError) as e:
|
|
325
|
+
SYLogger.error(f"Failed to parse embedding result: {e}")
|
|
326
|
+
# 解析失败也补零
|
|
327
|
+
dim = self._model_dim_cache.get(actual_model, 1024)
|
|
328
|
+
all_vectors.append([0.0] * dim)
|
|
329
|
+
|
|
330
|
+
SYLogger.info(
|
|
331
|
+
f"Embeddings for corpus created: {len(all_vectors)} vectors (model: {actual_model})")
|
|
332
|
+
return all_vectors
|
|
129
333
|
|
|
130
334
|
async def get_reranker(
|
|
131
335
|
self,
|
|
132
336
|
top_results: List[str],
|
|
133
337
|
query: str,
|
|
134
|
-
model: str = None
|
|
338
|
+
model: str = None,
|
|
339
|
+
timeout: Optional[Union[int, float]] = None
|
|
135
340
|
):
|
|
136
341
|
"""
|
|
137
342
|
对搜索结果进行重排序
|
|
@@ -140,10 +345,24 @@ class Embedding(metaclass=SingletonMeta):
|
|
|
140
345
|
top_results: 待重排序的文本列表
|
|
141
346
|
query: 排序参考的查询语句
|
|
142
347
|
model: 可选,指定使用的reranker模型名称,默认使用bge-reranker-large
|
|
348
|
+
timeout: 可选,超时时间(秒):
|
|
349
|
+
- 传int/float:表示总超时时间(秒)
|
|
350
|
+
- 不传/None:使用默认永不超时配置
|
|
143
351
|
"""
|
|
352
|
+
request_timeout = None
|
|
353
|
+
if timeout is not None:
|
|
354
|
+
if isinstance(timeout, (int, float)):
|
|
355
|
+
request_timeout = aiohttp.ClientTimeout(total=timeout)
|
|
356
|
+
else:
|
|
357
|
+
SYLogger.warning(
|
|
358
|
+
f"Invalid timeout type: {type(timeout)}, must be int/float, use default timeout")
|
|
359
|
+
|
|
360
|
+
actual_model = model or self.default_reranker_model
|
|
144
361
|
SYLogger.info(
|
|
145
|
-
f"Requesting reranker for top_results: {top_results} (model: {
|
|
146
|
-
|
|
362
|
+
f"Requesting reranker for top_results: {top_results} (model: {actual_model}, max_concurrency: {self.max_concurrency}, timeout: {timeout or 'None'})")
|
|
363
|
+
|
|
364
|
+
data = await self._get_reranker_http_async(
|
|
365
|
+
top_results, query, model=model, timeout=request_timeout)
|
|
147
366
|
SYLogger.info(
|
|
148
|
-
f"Reranker for top_results
|
|
367
|
+
f"Reranker for top_results completed (model: {actual_model})")
|
|
149
368
|
return data
|