SimpleLLMFunc 0.2.3__tar.gz → 0.2.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/PKG-INFO +3 -2
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/__init__.py +4 -0
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/interface/__init__.py +4 -0
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/interface/openai_compatible.py +100 -14
- simplellmfunc-0.2.6/SimpleLLMFunc/interface/token_bucket.py +232 -0
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/llm_decorator/llm_chat_decorator.py +227 -412
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/llm_decorator/llm_function_decorator.py +37 -244
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/llm_decorator/utils.py +335 -1
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/logger/logger.py +30 -4
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/tool/tool.py +224 -99
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/pyproject.toml +5 -5
- simplellmfunc-0.2.3/SimpleLLMFunc/.DS_Store +0 -0
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/LICENSE +0 -0
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/README.md +0 -0
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/config.py +0 -0
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/interface/key_pool.py +0 -0
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/interface/llm_interface.py +0 -0
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/llm_decorator/__init__.py +0 -0
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/llm_decorator/multimodal_types.py +0 -0
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/logger/__init__.py +0 -0
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/logger/logger_config.py +0 -0
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/tool/__init__.py +0 -0
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/type/__init__.py +0 -0
- {simplellmfunc-0.2.3 → simplellmfunc-0.2.6}/SimpleLLMFunc/utils.py +0 -0
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: SimpleLLMFunc
|
|
3
|
-
Version: 0.2.3
|
|
3
|
+
Version: 0.2.6
|
|
4
4
|
Summary: 一个轻量但完备的LLM/Agent应用开发框架,提供装饰器实现将函数DocString作为Prompt而无需函数体具体实现但能够享受函数定义和类型标注带来效率提升的开发体验。以最Code的方式,用最少的代码将LLM能力集成到任意Python项目中。
|
|
5
5
|
Author: Ni Jingzhe
|
|
6
|
-
Author-email: nijingzhe@
|
|
6
|
+
Author-email: nijingzhe@zju.edu.cn
|
|
7
7
|
Requires-Python: >=3.10
|
|
8
8
|
Classifier: Programming Language :: Python :: 3
|
|
9
9
|
Classifier: Programming Language :: Python :: 3.10
|
|
@@ -11,6 +11,7 @@ Classifier: Programming Language :: Python :: 3.11
|
|
|
11
11
|
Classifier: Programming Language :: Python :: 3.12
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.13
|
|
13
13
|
Requires-Dist: httpx[socks] (>=0.28.1,<0.29.0)
|
|
14
|
+
Requires-Dist: nest-asyncio (>=1.6.0,<2.0.0)
|
|
14
15
|
Requires-Dist: openai (>=1.84.0,<2.0.0)
|
|
15
16
|
Requires-Dist: pydantic (>=2.11.5,<3.0.0)
|
|
16
17
|
Requires-Dist: pydantic-settings (>=2.9.1,<3.0.0)
|
|
@@ -1,7 +1,11 @@
|
|
|
1
1
|
from SimpleLLMFunc.interface.key_pool import APIKeyPool
|
|
2
2
|
from SimpleLLMFunc.interface.openai_compatible import OpenAICompatible
|
|
3
|
+
from SimpleLLMFunc.interface.token_bucket import TokenBucket, RateLimitManager, rate_limit_manager
|
|
3
4
|
|
|
4
5
|
__all__ = [
|
|
5
6
|
"APIKeyPool",
|
|
6
7
|
"OpenAICompatible",
|
|
8
|
+
"TokenBucket",
|
|
9
|
+
"RateLimitManager",
|
|
10
|
+
"rate_limit_manager",
|
|
7
11
|
]
|
|
@@ -7,6 +7,7 @@ from typing import Optional, Dict, Literal, Iterable, Any, AsyncGenerator
|
|
|
7
7
|
from openai import AsyncOpenAI
|
|
8
8
|
from SimpleLLMFunc.interface.llm_interface import LLM_Interface
|
|
9
9
|
from SimpleLLMFunc.interface.key_pool import APIKeyPool
|
|
10
|
+
from SimpleLLMFunc.interface.token_bucket import TokenBucket, rate_limit_manager
|
|
10
11
|
from SimpleLLMFunc.logger import (
|
|
11
12
|
app_log,
|
|
12
13
|
push_warning,
|
|
@@ -62,32 +63,40 @@ class OpenAICompatible(LLM_Interface):
|
|
|
62
63
|
{
|
|
63
64
|
"model_name": "gpt-3.5-turbo",
|
|
64
65
|
"api_keys": [key1, key2, key3],
|
|
65
|
-
"base_url": "https://api.openai.com/v1"
|
|
66
|
+
"base_url": "https://api.openai.com/v1",
|
|
66
67
|
"max_retries": 5,
|
|
67
|
-
"retry_delay": 1.0
|
|
68
|
+
"retry_delay": 1.0,
|
|
69
|
+
"rate_limit_capacity": 10,
|
|
70
|
+
"rate_limit_refill_rate": 1.0
|
|
68
71
|
},
|
|
69
72
|
{
|
|
70
73
|
"model_name": "gpt-4",
|
|
71
74
|
"api_keys": [key1, key2, key3],
|
|
72
|
-
"base_url": "https://api.openai.com/v1"
|
|
75
|
+
"base_url": "https://api.openai.com/v1",
|
|
73
76
|
"max_retries": 5,
|
|
74
|
-
"retry_delay": 1.0
|
|
77
|
+
"retry_delay": 1.0,
|
|
78
|
+
"rate_limit_capacity": 5,
|
|
79
|
+
"rate_limit_refill_rate": 0.5
|
|
75
80
|
}
|
|
76
81
|
],
|
|
77
82
|
"zhipu": [
|
|
78
83
|
{
|
|
79
84
|
"model_name": "gpt-3.5-turbo",
|
|
80
85
|
"api_keys": [key1, key2, key3],
|
|
81
|
-
"base_url": "https://open.bigmodel.cn/api/paas/v4/"
|
|
86
|
+
"base_url": "https://open.bigmodel.cn/api/paas/v4/",
|
|
82
87
|
"max_retries": 5,
|
|
83
|
-
"retry_delay": 1.0
|
|
88
|
+
"retry_delay": 1.0,
|
|
89
|
+
"rate_limit_capacity": 15,
|
|
90
|
+
"rate_limit_refill_rate": 2.0
|
|
84
91
|
},
|
|
85
92
|
{
|
|
86
93
|
"model_name": "gpt-4",
|
|
87
94
|
"api_keys": [key1, key2, key3],
|
|
88
|
-
"base_url": "https://open.bigmodel.cn/api/paas/v4/"
|
|
95
|
+
"base_url": "https://open.bigmodel.cn/api/paas/v4/",
|
|
89
96
|
"max_retries": 5,
|
|
90
|
-
"retry_delay": 1.0
|
|
97
|
+
"retry_delay": 1.0,
|
|
98
|
+
"rate_limit_capacity": 8,
|
|
99
|
+
"rate_limit_refill_rate": 1.5
|
|
91
100
|
}
|
|
92
101
|
]
|
|
93
102
|
}
|
|
@@ -145,6 +154,8 @@ class OpenAICompatible(LLM_Interface):
|
|
|
145
154
|
base_url = model_info["base_url"]
|
|
146
155
|
max_retries = model_info.get("max_retries", 5)
|
|
147
156
|
retry_delay = model_info.get("retry_delay", 1.0)
|
|
157
|
+
rate_limit_capacity = model_info.get("rate_limit_capacity", 10)
|
|
158
|
+
rate_limit_refill_rate = model_info.get("rate_limit_refill_rate", 1.0)
|
|
148
159
|
|
|
149
160
|
# 创建APIKeyPool实例
|
|
150
161
|
key_pool = APIKeyPool(api_keys, f"{provider_id}-{model_name}")
|
|
@@ -156,6 +167,8 @@ class OpenAICompatible(LLM_Interface):
|
|
|
156
167
|
base_url=base_url,
|
|
157
168
|
max_retries=max_retries,
|
|
158
169
|
retry_delay=retry_delay,
|
|
170
|
+
rate_limit_capacity=rate_limit_capacity,
|
|
171
|
+
rate_limit_refill_rate=rate_limit_refill_rate,
|
|
159
172
|
)
|
|
160
173
|
|
|
161
174
|
all_providers_dict[provider_id][model_name] = instance
|
|
@@ -190,6 +203,18 @@ class OpenAICompatible(LLM_Interface):
|
|
|
190
203
|
f"OpenAICompatible(model_name={self.model_name}, base_url={self.base_url})"
|
|
191
204
|
)
|
|
192
205
|
|
|
206
|
+
def get_rate_limit_status(self) -> Dict[str, Any]:
    """Report the current state of this instance's token bucket.

    Returns:
        The dictionary produced by ``self.token_bucket.get_info()``.
    """
    bucket = self.token_bucket
    status = bucket.get_info()
    return status
|
|
213
|
+
|
|
214
|
+
def reset_rate_limit(self) -> None:
    """Reset this instance's token bucket by delegating to its ``reset()``."""
    bucket = self.token_bucket
    bucket.reset()
|
|
217
|
+
|
|
193
218
|
def __init__(
|
|
194
219
|
self,
|
|
195
220
|
api_key_pool: APIKeyPool,
|
|
@@ -197,6 +222,8 @@ class OpenAICompatible(LLM_Interface):
|
|
|
197
222
|
base_url: str,
|
|
198
223
|
max_retries: int = 5,
|
|
199
224
|
retry_delay: float = 1.0,
|
|
225
|
+
rate_limit_capacity: int = 10,
|
|
226
|
+
rate_limit_refill_rate: float = 1.0,
|
|
200
227
|
):
|
|
201
228
|
"""初始化OpenAI兼容的LLM接口
|
|
202
229
|
|
|
@@ -206,6 +233,8 @@ class OpenAICompatible(LLM_Interface):
|
|
|
206
233
|
base_url: API基础URL,例如"https://api.openai.com/v1"或"https://open.bigmodel.cn/api/paas/v4/"
|
|
207
234
|
max_retries: 最大重试次数
|
|
208
235
|
retry_delay: 重试间隔时间(秒)
|
|
236
|
+
rate_limit_capacity: 令牌桶容量(最大令牌数)
|
|
237
|
+
rate_limit_refill_rate: 令牌补充速率(令牌数/秒)
|
|
209
238
|
"""
|
|
210
239
|
super().__init__(api_key_pool, model_name)
|
|
211
240
|
self.max_retries = max_retries
|
|
@@ -215,10 +244,49 @@ class OpenAICompatible(LLM_Interface):
|
|
|
215
244
|
self.model_name = model_name
|
|
216
245
|
|
|
217
246
|
self.key_pool = api_key_pool
|
|
247
|
+
|
|
248
|
+
# 创建令牌桶,使用provider和model作为唯一标识
|
|
249
|
+
bucket_id = f"{base_url}_{model_name}"
|
|
250
|
+
self.token_bucket = rate_limit_manager.get_or_create_bucket(
|
|
251
|
+
bucket_id=bucket_id,
|
|
252
|
+
capacity=rate_limit_capacity,
|
|
253
|
+
refill_rate=rate_limit_refill_rate
|
|
254
|
+
)
|
|
255
|
+
|
|
218
256
|
self.client = AsyncOpenAI(
|
|
219
257
|
api_key=api_key_pool.get_least_loaded_key(), base_url=self.base_url
|
|
220
258
|
)
|
|
221
259
|
|
|
260
|
+
async def _get_or_create_client(self, key: str) -> AsyncOpenAI:
    """Return a client bound to ``key``, rebuilding it only when necessary.

    The existing ``self.client`` is reused when it is already bound to
    ``key``. Otherwise the old client (if any) is closed on a best-effort
    basis and a new ``AsyncOpenAI`` client is created for ``key``.

    Args:
        key: API key the returned client must use.

    Returns:
        The client stored in ``self.client`` after the call.
    """
    client = getattr(self, "client", None)
    current_key = getattr(self, "_current_key", None)

    # A client may exist without a recorded key (nothing but this method
    # sets ``_current_key``). In that case fall back to the key the client
    # itself reports, so a freshly built client is not closed and rebuilt
    # on the very first request.
    if client is not None and current_key is None:
        current_key = getattr(client, "api_key", None)

    if client is not None and current_key == key:
        self._current_key = key
        return client

    if client is not None:
        # NOTE(review): the client is shared by the whole instance, so
        # closing it here may affect requests still running on it from
        # other coroutines -- confirm this is acceptable.
        try:
            await client.close()  # type: ignore
        except Exception as exc:
            # Closing is best-effort: report the failure and carry on.
            push_warning(
                f"{self.model_name} 关闭旧客户端失败: {exc}",
                location=get_location(),
            )

    self.client = AsyncOpenAI(api_key=key, base_url=self.base_url)
    self._current_key = key
    return self.client
|
|
279
|
+
|
|
280
|
+
async def aclose(self):
    """Close the client connection held by this instance, if there is one.

    Closing is best-effort: a failure is reported through ``push_warning``
    instead of being raised, and ``self.client`` is set to ``None`` either
    way.
    """
    client = getattr(self, "client", None)
    if client is None:
        return

    try:
        await client.close()  # type: ignore
    except Exception as exc:
        # Do not fail shutdown because of a close error, but do not hide it.
        push_warning(
            f"{self.model_name} 关闭客户端失败: {exc}",
            location=get_location(),
        )
    finally:
        self.client = None
|
|
289
|
+
|
|
222
290
|
async def chat(
|
|
223
291
|
self,
|
|
224
292
|
trace_id: str = get_current_trace_id(),
|
|
@@ -246,18 +314,27 @@ class OpenAICompatible(LLM_Interface):
|
|
|
246
314
|
LLM的响应内容
|
|
247
315
|
"""
|
|
248
316
|
key = self.key_pool.get_least_loaded_key()
|
|
249
|
-
|
|
317
|
+
client = await self._get_or_create_client(key)
|
|
250
318
|
|
|
251
319
|
attempt = 0
|
|
252
320
|
while attempt < self.max_retries:
|
|
253
321
|
try:
|
|
322
|
+
# 获取令牌桶令牌,设置30秒超时
|
|
323
|
+
token_acquired = await self.token_bucket.acquire(tokens_needed=1, timeout=30.0)
|
|
324
|
+
if not token_acquired:
|
|
325
|
+
push_warning(
|
|
326
|
+
f"{self.model_name} 令牌桶获取令牌超时,跳过此次请求",
|
|
327
|
+
location=get_location(),
|
|
328
|
+
)
|
|
329
|
+
raise Exception("Rate limit: 令牌桶获取令牌超时")
|
|
330
|
+
|
|
254
331
|
self.key_pool.increment_task_count(key)
|
|
255
332
|
data = json.dumps(messages, ensure_ascii=False, indent=4)
|
|
256
333
|
push_debug(
|
|
257
334
|
f"OpenAICompatible::chat: {self.model_name} request with API key: {key}, and message: {data}",
|
|
258
335
|
location=get_location(),
|
|
259
336
|
)
|
|
260
|
-
response: Dict[Any, Any] = await
|
|
337
|
+
response: Dict[Any, Any] = await client.chat.completions.create( # type: ignore
|
|
261
338
|
messages=messages, # type: ignore
|
|
262
339
|
model=self.model_name,
|
|
263
340
|
stream=stream,
|
|
@@ -295,7 +372,7 @@ class OpenAICompatible(LLM_Interface):
|
|
|
295
372
|
)
|
|
296
373
|
|
|
297
374
|
key = self.key_pool.get_least_loaded_key()
|
|
298
|
-
|
|
375
|
+
client = await self._get_or_create_client(key)
|
|
299
376
|
|
|
300
377
|
if attempt >= self.max_retries:
|
|
301
378
|
push_error(
|
|
@@ -333,18 +410,27 @@ class OpenAICompatible(LLM_Interface):
|
|
|
333
410
|
LLM的响应块
|
|
334
411
|
"""
|
|
335
412
|
key = self.key_pool.get_least_loaded_key()
|
|
336
|
-
|
|
413
|
+
client = await self._get_or_create_client(key)
|
|
337
414
|
|
|
338
415
|
attempt = 0
|
|
339
416
|
while attempt < self.max_retries:
|
|
340
417
|
try:
|
|
418
|
+
# 获取令牌桶令牌,设置30秒超时
|
|
419
|
+
token_acquired = await self.token_bucket.acquire(tokens_needed=1, timeout=30.0)
|
|
420
|
+
if not token_acquired:
|
|
421
|
+
push_warning(
|
|
422
|
+
f"{self.model_name} 流式请求令牌桶获取令牌超时,跳过此次请求",
|
|
423
|
+
location=get_location(),
|
|
424
|
+
)
|
|
425
|
+
raise Exception("Rate limit: 令牌桶获取令牌超时")
|
|
426
|
+
|
|
341
427
|
self.key_pool.increment_task_count(key)
|
|
342
428
|
data = json.dumps(messages, ensure_ascii=False, indent=4)
|
|
343
429
|
push_debug(
|
|
344
430
|
f"OpenAICompatible::chat_stream: {self.model_name} request with API key: {key}, and message: {data}",
|
|
345
431
|
location=get_location(),
|
|
346
432
|
)
|
|
347
|
-
response = await
|
|
433
|
+
response = await client.chat.completions.create( # type: ignore
|
|
348
434
|
messages=messages, # type: ignore
|
|
349
435
|
model=self.model_name,
|
|
350
436
|
stream=stream,
|
|
@@ -388,7 +474,7 @@ class OpenAICompatible(LLM_Interface):
|
|
|
388
474
|
)
|
|
389
475
|
|
|
390
476
|
key = self.key_pool.get_least_loaded_key()
|
|
391
|
-
|
|
477
|
+
client = await self._get_or_create_client(key)
|
|
392
478
|
|
|
393
479
|
if attempt >= self.max_retries:
|
|
394
480
|
push_error(
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import time
|
|
3
|
+
from typing import Optional, Dict, Any
|
|
4
|
+
import threading
|
|
5
|
+
from SimpleLLMFunc.logger import push_debug, push_warning, get_location
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TokenBucket:
    """Token-bucket rate limiter used to throttle API requests.

    The bucket starts full with ``capacity`` tokens and regains
    ``refill_rate`` tokens per second, never exceeding ``capacity``. This
    allows short bursts while bounding the long-term average request rate.

    Instances are shared per ``bucket_id``: constructing a bucket with an id
    that already exists returns the existing instance and ignores the new
    ``capacity`` / ``refill_rate`` arguments.
    """

    # Registry of buckets keyed by bucket_id (one instance per id).
    _instances: Dict[str, 'TokenBucket'] = {}
    # Class-level lock guarding the registry and first-time initialisation.
    # Each instance shadows this name with its own lock in __init__.
    _lock = threading.Lock()
    # Upper bound, in seconds, on a single sleep inside acquire().
    _MAX_POLL_INTERVAL = 0.1

    def __new__(cls, bucket_id: str, capacity: int = 10, refill_rate: float = 1.0) -> 'TokenBucket':
        """Return the shared instance for ``bucket_id``, creating it if needed."""
        with cls._lock:
            if bucket_id not in cls._instances:
                cls._instances[bucket_id] = super().__new__(cls)
            return cls._instances[bucket_id]

    def __init__(self, bucket_id: str, capacity: int = 10, refill_rate: float = 1.0):
        """Initialise the bucket the first time it is constructed.

        Args:
            bucket_id: Unique identifier of the bucket.
            capacity: Maximum number of tokens the bucket can hold.
            refill_rate: Tokens added per second.
        """
        # __init__ runs on every TokenBucket(...) call, including calls that
        # returned an existing instance from __new__. The check is done under
        # the class lock so two threads cannot both initialise one instance.
        with type(self)._lock:
            if getattr(self, 'initialized', False):
                return

            self.bucket_id = bucket_id
            self.capacity = capacity
            self.refill_rate = refill_rate
            self.tokens = float(capacity)  # the bucket starts full
            self.last_refill_time = time.time()
            # Per-instance lock protecting tokens / last_refill_time. It is
            # never held across an ``await``.
            self._lock = threading.Lock()
            self.initialized = True

        push_debug(
            f"TokenBucket {bucket_id} 初始化完成: capacity={capacity}, refill_rate={refill_rate}",
            location=get_location()
        )

    def _refill_tokens(self) -> None:
        """Add the tokens earned since the last refill.

        Callers are expected to hold ``self._lock``.
        """
        current_time = time.time()
        time_passed = current_time - self.last_refill_time

        # time.time() can step backwards (clock adjustment); never let that,
        # or a non-positive refill rate, take tokens away.
        tokens_to_add = max(0.0, time_passed * self.refill_rate)

        # Top up, but never above capacity.
        self.tokens = min(self.capacity, self.tokens + tokens_to_add)
        self.last_refill_time = current_time

        push_debug(
            f"TokenBucket {self.bucket_id} 补充令牌: 添加={tokens_to_add:.2f}, 当前={self.tokens:.2f}",
            location=get_location()
        )

    async def acquire(self, tokens_needed: int = 1, timeout: Optional[float] = None) -> bool:
        """Wait asynchronously until ``tokens_needed`` tokens can be taken.

        Args:
            tokens_needed: Number of tokens to take.
            timeout: Maximum time to wait in seconds; ``None`` waits
                indefinitely.

        Returns:
            ``True`` if the tokens were taken; ``False`` if the timeout
            expired or the request exceeds the bucket capacity and so can
            never be satisfied.
        """
        if tokens_needed > self.capacity:
            # The bucket never holds more than ``capacity`` tokens, so
            # waiting would never succeed.
            push_warning(
                f"TokenBucket {self.bucket_id} 请求的令牌数超过桶容量: 需要={tokens_needed}, 容量={self.capacity}",
                location=get_location()
            )
            return False

        # A monotonic clock keeps the timeout immune to wall-clock changes.
        deadline = None if timeout is None else time.monotonic() + timeout

        while True:
            with self._lock:
                self._refill_tokens()

                if self.tokens >= tokens_needed:
                    self.tokens -= tokens_needed
                    push_debug(
                        f"TokenBucket {self.bucket_id} 成功获取 {tokens_needed} 个令牌, 剩余={self.tokens:.2f}",
                        location=get_location()
                    )
                    return True

                available = self.tokens
                if self.refill_rate > 0:
                    # Time until enough tokens have been refilled.
                    wait_time = (tokens_needed - available) / self.refill_rate
                else:
                    # No refill: only reset() can add tokens, so just poll.
                    wait_time = self._MAX_POLL_INTERVAL

            # Timeout handling happens outside the lock.
            if deadline is not None:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    push_warning(
                        f"TokenBucket {self.bucket_id} 获取令牌超时: 需要={tokens_needed}, 可用={available:.2f}",
                        location=get_location()
                    )
                    return False
                # Never sleep past the deadline.
                wait_time = min(wait_time, remaining)

            # Sleep in short slices so a reset() is noticed promptly.
            wait_time = min(wait_time, self._MAX_POLL_INTERVAL)

            push_debug(
                f"TokenBucket {self.bucket_id} 等待令牌补充: 需要={tokens_needed}, 可用={available:.2f}, 等待={wait_time:.3f}s",
                location=get_location()
            )

            await asyncio.sleep(wait_time)

    def try_acquire(self, tokens_needed: int = 1) -> bool:
        """Try to take tokens synchronously without waiting.

        Args:
            tokens_needed: Number of tokens to take.

        Returns:
            ``True`` if the tokens were taken, ``False`` if too few were
            available.
        """
        with self._lock:
            self._refill_tokens()

            if self.tokens >= tokens_needed:
                self.tokens -= tokens_needed
                push_debug(
                    f"TokenBucket {self.bucket_id} 同步获取 {tokens_needed} 个令牌成功, 剩余={self.tokens:.2f}",
                    location=get_location()
                )
                return True

            push_debug(
                f"TokenBucket {self.bucket_id} 同步获取 {tokens_needed} 个令牌失败, 可用={self.tokens:.2f}",
                location=get_location()
            )
            return False

    def get_available_tokens(self) -> float:
        """Return the number of tokens currently available, after a refill."""
        with self._lock:
            self._refill_tokens()
            return self.tokens

    def get_info(self) -> Dict[str, Any]:
        """Return a snapshot of the bucket state as a dictionary."""
        with self._lock:
            self._refill_tokens()
            return {
                "bucket_id": self.bucket_id,
                "capacity": self.capacity,
                "refill_rate": self.refill_rate,
                "available_tokens": self.tokens,
                "last_refill_time": self.last_refill_time
            }

    def reset(self) -> None:
        """Refill the bucket to full capacity."""
        with self._lock:
            self.tokens = float(self.capacity)
            self.last_refill_time = time.time()
            push_debug(
                f"TokenBucket {self.bucket_id} 已重置,令牌数={self.tokens}",
                location=get_location()
            )

    def __repr__(self) -> str:
        """Return a short description of the bucket."""
        return (
            f"TokenBucket(id={self.bucket_id}, capacity={self.capacity}, "
            f"refill_rate={self.refill_rate}, tokens={self.tokens:.2f})"
        )
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class RateLimitManager:
    """Registry that hands out one token bucket per bucket id.

    All access to the internal mapping goes through ``self._lock``. Methods
    that call into the buckets themselves first take a snapshot under the
    lock and then release it, so a concurrent create/remove cannot
    invalidate the iteration.
    """

    def __init__(self):
        # bucket_id -> bucket
        self._buckets: Dict[str, "TokenBucket"] = {}
        self._lock = threading.Lock()

    def get_or_create_bucket(
        self,
        bucket_id: str,
        capacity: int = 10,
        refill_rate: float = 1.0
    ) -> "TokenBucket":
        """Return the bucket registered under ``bucket_id``, creating it if absent.

        Args:
            bucket_id: Identifier of the bucket.
            capacity: Bucket capacity; used only when the bucket is created.
            refill_rate: Refill rate; used only when the bucket is created.

        Returns:
            The ``TokenBucket`` registered under ``bucket_id``.
        """
        with self._lock:
            if bucket_id not in self._buckets:
                self._buckets[bucket_id] = TokenBucket(bucket_id, capacity, refill_rate)
            return self._buckets[bucket_id]

    def get_bucket(self, bucket_id: str) -> Optional["TokenBucket"]:
        """Return the bucket registered under ``bucket_id``, or ``None``."""
        with self._lock:
            return self._buckets.get(bucket_id)

    def remove_bucket(self, bucket_id: str) -> bool:
        """Unregister a bucket.

        Returns:
            ``True`` if a bucket was registered under ``bucket_id`` and has
            been removed, ``False`` otherwise.
        """
        # NOTE(review): TokenBucket keeps its own per-id instance registry,
        # so a bucket created again under the same id after removal appears
        # to be the same object with its old settings -- confirm intended.
        with self._lock:
            if bucket_id in self._buckets:
                del self._buckets[bucket_id]
                return True
            return False

    def list_buckets(self) -> Dict[str, Dict[str, Any]]:
        """Return the ``get_info()`` state of every registered bucket, keyed by id."""
        with self._lock:
            snapshot = list(self._buckets.items())
        return {bucket_id: bucket.get_info() for bucket_id, bucket in snapshot}

    def reset_all(self) -> None:
        """Call ``reset()`` on every registered bucket."""
        with self._lock:
            buckets = list(self._buckets.values())
        for bucket in buckets:
            bucket.reset()
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
# 全局速率限制管理器实例
|
|
232
|
+
rate_limit_manager = RateLimitManager()
|