api-key-manager 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- api_key_manager-2.1.0.dist-info/METADATA +709 -0
- api_key_manager-2.1.0.dist-info/RECORD +73 -0
- api_key_manager-2.1.0.dist-info/WHEEL +5 -0
- api_key_manager-2.1.0.dist-info/entry_points.txt +2 -0
- api_key_manager-2.1.0.dist-info/top_level.txt +1 -0
- key_manager/__init__.py +16 -0
- key_manager/__main__.py +5 -0
- key_manager/api_models.py +358 -0
- key_manager/checker.py +51 -0
- key_manager/cli.py +270 -0
- key_manager/config.py +61 -0
- key_manager/core.py +205 -0
- key_manager/detector.py +335 -0
- key_manager/errors.py +179 -0
- key_manager/i18n.py +142 -0
- key_manager/logger.py +207 -0
- key_manager/model_capabilities.py +412 -0
- key_manager/parser.py +153 -0
- key_manager/providers/__init__.py +283 -0
- key_manager/providers/ai302.py +109 -0
- key_manager/providers/anthropic.py +109 -0
- key_manager/providers/baichuan.py +97 -0
- key_manager/providers/base.py +312 -0
- key_manager/providers/cerebras.py +109 -0
- key_manager/providers/cohere.py +90 -0
- key_manager/providers/cstcloud.py +122 -0
- key_manager/providers/dashscope.py +120 -0
- key_manager/providers/dashscope_coding.py +122 -0
- key_manager/providers/deepseek.py +166 -0
- key_manager/providers/dmxapi.py +109 -0
- key_manager/providers/doubao.py +109 -0
- key_manager/providers/fireworks.py +109 -0
- key_manager/providers/google.py +99 -0
- key_manager/providers/grok.py +109 -0
- key_manager/providers/groq.py +109 -0
- key_manager/providers/huggingface.py +54 -0
- key_manager/providers/hyperbolic.py +109 -0
- key_manager/providers/infini.py +135 -0
- key_manager/providers/infini_coding.py +124 -0
- key_manager/providers/kimi.py +121 -0
- key_manager/providers/kimi_coding.py +124 -0
- key_manager/providers/longcat.py +123 -0
- key_manager/providers/mimo.py +109 -0
- key_manager/providers/mimo_plan.py +140 -0
- key_manager/providers/minimax.py +97 -0
- key_manager/providers/minimax_plan.py +122 -0
- key_manager/providers/mistral.py +109 -0
- key_manager/providers/models_registry.py +2901 -0
- key_manager/providers/modelscope.py +134 -0
- key_manager/providers/nvidia.py +109 -0
- key_manager/providers/ocoolai.py +109 -0
- key_manager/providers/openai.py +140 -0
- key_manager/providers/openrouter.py +119 -0
- key_manager/providers/perplexity.py +109 -0
- key_manager/providers/poe.py +109 -0
- key_manager/providers/ppio.py +109 -0
- key_manager/providers/replicate.py +54 -0
- key_manager/providers/siliconflow.py +121 -0
- key_manager/providers/stepfun.py +132 -0
- key_manager/providers/tencent_hunyuan.py +122 -0
- key_manager/providers/together.py +134 -0
- key_manager/providers/yi.py +97 -0
- key_manager/providers/zai.py +109 -0
- key_manager/providers/zhipu.py +127 -0
- key_manager/providers/zhipu_coding.py +124 -0
- key_manager/proxy.py +70 -0
- key_manager/ssrf.py +68 -0
- key_manager/storage.py +134 -0
- key_manager/tester.py +137 -0
- key_manager/url_override.py +5 -0
- key_manager/validator.py +185 -0
- key_manager/web.py +1512 -0
- key_manager/webhook.py +257 -0
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import time
|
|
3
|
+
from .base import ProviderBase, CheckResult, TestResult
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BaichuanProvider(ProviderBase):
    """Baichuan provider using an OpenAI-style ``/v1`` API (``/models``, ``/chat/completions``)."""

    name = "baichuan"
    base_url = "https://api.baichuan-ai.com/v1"
    check_endpoint = "/models"
    # Single source of truth for the model used by check/test_token_limit.
    check_model = "Baichuan4-Turbo"

    def build_headers(self, key: str) -> dict:
        """Return bearer-token authorization headers for *key*."""
        return {"Authorization": f"Bearer {key}"}

    async def get_models(self, client, key: str) -> list[str]:
        """Return model IDs listed by ``GET /models``; ``[]`` on any failure."""
        headers = self.build_headers(key)
        try:
            resp = await client.get(f"{self.get_base_url()}/models", headers=headers)
            if resp.status_code == 200:
                data = resp.json()
                if "data" in data:
                    return [m["id"] for m in data["data"] if "id" in m]
            return []
        except Exception:
            return []

    async def check(self, client, key: str) -> CheckResult:
        """Validate *key* with a 5-token ``/chat/completions`` request on ``check_model``.

        Returns valid only on HTTP 200. 401/403 and 429 get fixed messages. Any
        other status carries ``error.message`` from the JSON body, when present.
        Transport errors yield ``status_code=None`` and the exception text.
        """
        headers = self.build_headers(key)
        headers["Content-Type"] = "application/json"
        start = time.monotonic()
        try:
            resp = await client.post(
                f"{self.get_base_url()}/chat/completions",
                headers=headers,
                json={"model": self.check_model, "messages": [{"role": "user", "content": "hi"}], "max_tokens": 5},
            )
            latency = (time.monotonic() - start) * 1000
            if resp.status_code == 200:
                return CheckResult(True, 200, latency, None)
            elif resp.status_code in (401, 403):
                return CheckResult(False, resp.status_code, latency, "invalid key or forbidden")
            elif resp.status_code == 429:
                return CheckResult(False, 429, latency, "rate limited")
            else:
                try:
                    data = resp.json()
                    error_msg = data.get("error", {}).get("message", f"status {resp.status_code}")
                except Exception:  # non-JSON body, or "error" is not an object
                    error_msg = f"status {resp.status_code}"
                return CheckResult(False, resp.status_code, latency, error_msg)
        except Exception as e:
            return CheckResult(False, None, (time.monotonic() - start) * 1000, str(e))

    async def test_token_limit(self, client, key: str, token_steps: list[int]) -> TestResult:
        """Return the largest ``max_tokens`` step that succeeded.

        Steps are tried in order. The loop stops at the first 400/413, any other
        non-200/429 status, or a transport error. On 429 it sleeps 1s and moves
        on to the next step without retrying the current one.
        """
        headers = self.build_headers(key)
        last_success = None
        for step in token_steps:
            try:
                resp = await client.post(
                    f"{self.get_base_url()}/chat/completions",
                    headers=headers,
                    json={"model": self.check_model, "messages": [{"role": "user", "content": "hi"}], "max_tokens": step},
                )
                if resp.status_code == 200:
                    last_success = step
                elif resp.status_code in (400, 413):
                    break
                elif resp.status_code == 429:
                    await asyncio.sleep(1)
                    continue
                else:
                    break
            except Exception:
                break
        return TestResult(max_tokens=last_success)

    async def check_real(self, client, key: str) -> CheckResult:
        """Alias for :meth:`check`, which already issues a real completion request."""
        return await self.check(client, key)

    async def test_concurrency(self, client, key: str, concurrency_steps: list[int]) -> TestResult:
        """Return the largest step at which fewer than 30% of parallel probes failed.

        Non-positive steps are skipped; they would otherwise divide by zero.
        """
        headers = self.build_headers(key)
        last_success = None
        for step in concurrency_steps:
            if step <= 0:
                continue
            tasks = [self._probe(client, headers) for _ in range(step)]
            results = await asyncio.gather(*tasks, return_exceptions=True)
            # _probe returns False for any non-200 response or error.
            failures = sum(1 for r in results if not isinstance(r, Exception) and not r)
            if failures / step >= 0.3:
                break
            last_success = step
        return TestResult(max_concurrency=last_success)

    async def _probe(self, client, headers) -> bool:
        """Issue one ``GET`` on ``check_endpoint``; True only for HTTP 200."""
        try:
            resp = await client.get(f"{self.get_base_url()}{self.check_endpoint}", headers=headers)
            return resp.status_code == 200
        except Exception:
            return False
|
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from key_manager.url_override import custom_base_url
|
|
6
|
+
|
|
7
|
+
# Canonical error-type identifiers; each key maps to itself so callers can use
# the dict both as a membership set and as a lookup table.
ERROR_TYPES = {
    error_type: error_type
    for error_type in (
        "invalid_key",
        "rate_limited",
        "insufficient_balance",
        "quota_exceeded",
        "account_suspended",
        "forbidden",
        "not_found",
        "server_error",
        "timeout",
        "connection_error",
        "unknown",
    )
}
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def simplify_error(error_msg: str, status_code: int = None) -> str:
    """Turn a raw provider error into a short, user-friendly (Chinese) message.

    Resolution order:
      1. empty/falsy message -> ``""``;
      2. well-known HTTP statuses (401/403, 402, 429, >=500);
      3. keyword matching on the lower-cased message (400 lands here);
      4. otherwise the original message, truncated to 100 chars plus "...".
    """
    if not error_msg:
        return ""

    lowered = error_msg.lower()

    # Guard clauses on the HTTP status; 400 intentionally falls through.
    if status_code in (401, 403):
        return "Key 无效或无权限"
    if status_code == 402:
        return "余额不足"
    if status_code == 429:
        return "请求过于频繁,请稍后重试"
    if status_code and status_code >= 500:
        return "服务商内部错误"

    # "invalid" alone is too generic; require it next to a key/token/api hint.
    if "invalid" in lowered and any(word in lowered for word in ("key", "token", "api")):
        return "Key 无效"

    # Ordered (needles, message) rules; the first matching rule wins.
    rules = (
        (("authentication", "unauthorized"), "认证失败"),
        (("expired",), "Key 已过期"),
        (("rate limit", "too many"), "请求过于频繁"),
        (("insufficient", "balance", "overdue", "payment"), "余额不足"),
        (("suspended", "banned"), "账号被封禁"),
        (("forbidden", "permission", "access denied"), "无权限访问"),
        (("not found", "does not exist"), "模型不存在"),
        (("timeout",), "请求超时"),
        (("connection",), "连接失败"),
    )
    for needles, friendly in rules:
        if any(needle in lowered for needle in needles):
            return friendly

    return error_msg if len(error_msg) <= 100 else error_msg[:100] + "..."
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass
class CheckResult:
    """Outcome of a key validity check or provider probe.

    Field order is part of the interface: callers construct it positionally,
    e.g. ``CheckResult(True, 200, latency, None)``.
    """

    valid: bool                             # whether the key was judged usable
    status_code: Optional[int]              # HTTP status; None when the request itself failed
    latency_ms: float                       # measured request latency in milliseconds
    error: Optional[str]                    # human-readable error, None on success
    rate_limit_info: Optional[dict] = None  # optional provider rate-limit details
    error_type: Optional[str] = None        # optional classification string
    response_body: Optional[str] = None     # optional (possibly truncated) raw response body
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@dataclass
class TestResult:
    """Result of a capability test; every measurement is optional."""

    max_tokens: Optional[int] = None       # largest max_tokens step that succeeded
    max_concurrency: Optional[int] = None  # largest concurrency step that passed
    rpm_limit: Optional[int] = None        # requests-per-minute limit, if measured
    models: Optional[list[str]] = None     # model IDs, if collected
    error: Optional[str] = None            # error text, if the test failed
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@dataclass
class BalanceResult:
    """Account balance lookup result; ``supported=False`` when a provider has no balance API."""

    supported: bool                  # whether balance lookup is implemented for the provider
    balance: Optional[float] = None  # numeric balance, when available
    currency: str = "USD"            # currency code of ``balance``
    raw: Optional[dict] = None       # raw provider payload, if kept
    error: Optional[str] = None      # error text, if the lookup failed
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class ProviderBase(ABC):
    """Abstract base for all API providers.

    Subclasses must supply ``name``, ``base_url``, ``check_endpoint``,
    ``build_headers``, ``test_token_limit`` and ``test_concurrency``. The
    default ``check``, ``probe`` and ``test_concurrency_for_model``
    implementations POST to ``/chat/completions``. The default ``get_models``
    GETs ``check_endpoint``. Providers with other APIs override these.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Provider identifier; also the key looked up in ``PROVIDER_MODELS``."""
        ...

    @property
    @abstractmethod
    def base_url(self) -> str:
        """Default API base URL; endpoint paths are appended to it."""
        ...

    def get_base_url(self) -> str:
        """Return effective base URL, respecting per-request override."""
        # NOTE(review): custom_base_url is read via .get(None); presumably a
        # contextvars.ContextVar holding a per-request override — confirm.
        override = custom_base_url.get(None)
        return override if override else self.base_url

    @property
    @abstractmethod
    def check_endpoint(self) -> str:
        """Path appended to the base URL; the default ``get_models`` GETs it."""
        ...

    @property
    def check_model(self) -> str:
        """Model to use for probe/check. Override in subclass if needed."""
        return "gpt-3.5-turbo"  # Default model

    @abstractmethod
    def build_headers(self, key: str) -> dict:
        """Return the HTTP headers (including auth) for requests made with *key*."""
        ...

    @staticmethod
    def _extract_error_message(resp) -> str:
        """Return ``error.message`` from a JSON error body, else ``"status <code>"``.

        If ``error.message`` is present but null, that value is returned as-is.
        """
        fallback = f"status {resp.status_code}"
        try:
            data = resp.json()
        except Exception:  # body is not valid JSON
            return fallback
        error = data.get("error") if isinstance(data, dict) else None
        if isinstance(error, dict):
            return error.get("message", fallback)
        return fallback

    async def check(self, client, key: str) -> CheckResult:
        """Check key validity using multiple models from PROVIDER_MODELS.

        Strategy:
        1. Get models from PROVIDER_MODELS (Cherry Studio sync), falling back
           to ``check_model`` when the registry has none for this provider.
        2. Try the first 5 models with /chat/completions.
        3. Stop at the first decisive answer: 200 means valid, 401/403 means
           invalid key, 429 means rate limited (reported as not valid).
        4. Any other status or error tries the next model. If all fail, return
           the last simplified error with the last attempt's status and latency.

        Providers with non-standard APIs (Anthropic, Google, etc.) should override this.
        """
        import time

        from .models_registry import PROVIDER_MODELS

        headers = self.build_headers(key)
        headers["Content-Type"] = "application/json"

        models = PROVIDER_MODELS.get(self.name, [])
        if not models:
            models = [self.check_model]

        test_models = models[:5]
        last_error = ""
        last_status = None
        latency = 0.0

        for model in test_models:
            start = time.monotonic()
            try:
                resp = await client.post(
                    f"{self.get_base_url()}/chat/completions",
                    headers=headers,
                    json={"model": model, "messages": [{"role": "user", "content": "hi"}], "max_tokens": 5},
                )
                latency = (time.monotonic() - start) * 1000

                if resp.status_code == 200:
                    return CheckResult(True, 200, latency, None)
                elif resp.status_code in (401, 403):
                    return CheckResult(False, resp.status_code, latency, "invalid key or forbidden")
                elif resp.status_code == 429:
                    return CheckResult(False, 429, latency, "rate limited")

                # Inconclusive status (e.g. model not available): remember a
                # readable error and try the next model.
                last_status = resp.status_code
                last_error = simplify_error(self._extract_error_message(resp), resp.status_code)
            except Exception as e:
                latency = (time.monotonic() - start) * 1000
                last_error = str(e)
                last_status = None

        # All models failed
        return CheckResult(False, last_status, latency, last_error or "all models failed")

    @abstractmethod
    async def test_token_limit(self, client, key: str,
                               token_steps: list[int]) -> TestResult:
        """Measure the largest usable ``max_tokens`` among *token_steps*."""
        ...

    @abstractmethod
    async def test_concurrency(self, client, key: str,
                               concurrency_steps: list[int]) -> TestResult:
        """Measure the largest tolerated concurrency among *concurrency_steps*."""
        ...

    async def probe(self, client, key: str) -> CheckResult:
        """Probe for provider detection.

        Strategy:
        1. Get models from PROVIDER_MODELS (Cherry Studio sync), falling back
           to ``check_model``.
        2. Try first 5 models with /chat/completions.
        3. 200 or 429 is returned as valid, 401/403 as invalid. Each result
           carries the first 500 characters of the response body.
        4. If all fail, return last error body for signature matching.
        """
        import time

        from .models_registry import PROVIDER_MODELS

        headers = self.build_headers(key)
        headers["Content-Type"] = "application/json"

        models = PROVIDER_MODELS.get(self.name, [])
        if not models:
            models = [self.check_model]

        test_models = models[:5]
        last_body = ""
        last_status = None

        for model in test_models:
            start = time.monotonic()
            try:
                resp = await client.post(
                    f"{self.get_base_url()}/chat/completions",
                    headers=headers,
                    json={"model": model, "messages": [{"role": "user", "content": "hi"}], "max_tokens": 5},
                )
                latency = (time.monotonic() - start) * 1000
                body = resp.text[:500]
                last_body = body
                last_status = resp.status_code

                if resp.status_code == 200:
                    return CheckResult(True, 200, latency, None, response_body=body)
                elif resp.status_code in (401, 403):
                    return CheckResult(False, resp.status_code, latency, "invalid key", response_body=body)
                elif resp.status_code == 429:
                    return CheckResult(True, 429, latency, "rate limited", response_body=body)
                # For other errors, try next model
            except Exception as e:
                last_body = str(e)
                last_status = None

        # All models failed - return last error for signature matching
        latency = (time.monotonic() - start) * 1000 if test_models else 0
        return CheckResult(False, last_status, latency, "all models failed", response_body=last_body)

    async def get_models(self, client, key: str) -> list[str]:
        """Get list of available models. Override in subclass for custom implementation.

        Understands ``{"data": [{"id": ...}]}``, ``{"models": [{"name": ...}]}``
        and a bare list of objects carrying ``id`` or ``name``. Non-object
        entries and empty identifiers are skipped. Returns ``[]`` on any failure.
        """
        headers = self.build_headers(key)
        try:
            resp = await client.get(
                f"{self.get_base_url()}{self.check_endpoint}",
                headers=headers
            )
            if resp.status_code == 200:
                data = resp.json()
                if isinstance(data, dict):
                    if "data" in data:
                        return [m["id"] for m in data["data"] if isinstance(m, dict) and m.get("id")]
                    elif "models" in data:
                        return [m["name"] for m in data["models"] if isinstance(m, dict) and m.get("name")]
                elif isinstance(data, list):
                    ids = (m.get("id", m.get("name", "")) for m in data if isinstance(m, dict))
                    return [model_id for model_id in ids if model_id]
            return []
        except Exception:
            return []

    async def test_concurrency_for_model(self, client, key: str, model: str, concurrency_steps: list[int]) -> TestResult:
        """Test concurrency for a specific model using chat completions. Override in subclass for custom implementation.

        Each step fires that many parallel ``_probe_model`` requests. The
        result is the last step before the first step where at least 30% of
        probes failed. Non-positive steps are skipped; they would otherwise
        divide by zero.
        """
        # base.py has no module-level asyncio import; import locally like `time` above.
        import asyncio

        headers = self.build_headers(key)
        headers["Content-Type"] = "application/json"
        last_success = None
        for step in concurrency_steps:
            if step <= 0:
                continue
            tasks = [self._probe_model(client, headers, model) for _ in range(step)]
            results = await asyncio.gather(*tasks, return_exceptions=True)
            # _probe_model returns False for any non-200 response or error.
            failures = sum(1 for r in results if not isinstance(r, Exception) and not r)
            if failures / step >= 0.3:
                break
            last_success = step
        return TestResult(max_concurrency=last_success)

    async def _probe_model(self, client, headers, model: str) -> bool:
        """Probe a specific model with a minimal chat completion request. Override in subclass for custom implementation."""
        try:
            resp = await client.post(
                f"{self.get_base_url()}/chat/completions",
                headers=headers,
                json={
                    "model": model,
                    "messages": [{"role": "user", "content": "hi"}],
                    "max_tokens": 1
                }
            )
            return resp.status_code == 200
        except Exception:
            return False

    async def get_balance(self, client, key: str) -> BalanceResult:
        """Get account balance. Override in subclass for custom implementation."""
        return BalanceResult(supported=False)
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import time
|
|
3
|
+
from .base import ProviderBase, CheckResult, TestResult
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CerebrasProvider(ProviderBase):
    """Cerebras provider using an OpenAI-style ``/v1`` API."""

    name = "cerebras"
    base_url = "https://api.cerebras.ai/v1"
    check_endpoint = "/models"
    # Used by check/test_token_limit. Also used by the inherited probe() when
    # the registry has no Cerebras models; the base default is "gpt-3.5-turbo".
    check_model = "llama3.1-8b"

    def build_headers(self, key: str) -> dict:
        """Return bearer-token authorization headers for *key*."""
        return {"Authorization": f"Bearer {key}"}

    async def get_models(self, client, key: str) -> list[str]:
        """Return model IDs listed by ``GET check_endpoint``; ``[]`` on any failure."""
        headers = self.build_headers(key)
        try:
            resp = await client.get(
                f"{self.get_base_url()}{self.check_endpoint}",
                headers=headers
            )
            if resp.status_code == 200:
                data = resp.json()
                if "data" in data:
                    return [m["id"] for m in data["data"] if "id" in m]
            return []
        except Exception:
            return []

    async def check(self, client, key: str) -> CheckResult:
        """Real usage test - try to make a minimal chat completion request.

        Returns valid only on HTTP 200. 401/403 and 429 get fixed messages.
        Other statuses carry ``error.message`` from the JSON body, when present.
        """
        headers = self.build_headers(key)
        headers["Content-Type"] = "application/json"
        start = time.monotonic()
        try:
            resp = await client.post(
                f"{self.get_base_url()}/chat/completions",
                headers=headers,
                json={
                    "model": self.check_model,
                    "messages": [{"role": "user", "content": "hi"}],
                    "max_tokens": 5
                }
            )
            latency = (time.monotonic() - start) * 1000

            if resp.status_code == 200:
                return CheckResult(True, 200, latency, None)
            elif resp.status_code in (401, 403):
                return CheckResult(False, resp.status_code, latency, "invalid key or forbidden")
            elif resp.status_code == 429:
                return CheckResult(False, 429, latency, "rate limited")
            else:
                try:
                    data = resp.json()
                    error_msg = data.get("error", {}).get("message", f"status {resp.status_code}")
                except Exception:  # non-JSON body, or "error" is not an object
                    error_msg = f"status {resp.status_code}"
                return CheckResult(False, resp.status_code, latency, error_msg)
        except Exception as e:
            return CheckResult(False, None, (time.monotonic() - start) * 1000, str(e))

    async def test_token_limit(self, client, key: str, token_steps: list[int]) -> TestResult:
        """Return the largest ``max_tokens`` step that succeeded.

        Stops at the first 400/413, other non-200/429 status, or transport
        error. On 429 it sleeps 1s and moves on to the next step.
        """
        headers = self.build_headers(key)
        last_success = None
        for step in token_steps:
            try:
                resp = await client.post(
                    f"{self.get_base_url()}/chat/completions",
                    headers=headers,
                    json={
                        "model": self.check_model,
                        "messages": [{"role": "user", "content": "hi"}],
                        "max_tokens": step
                    }
                )
                if resp.status_code == 200:
                    last_success = step
                elif resp.status_code in (400, 413):
                    break
                elif resp.status_code == 429:
                    await asyncio.sleep(1)
                    continue
                else:
                    break
            except Exception:
                break
        return TestResult(max_tokens=last_success)

    async def check_real(self, client, key: str) -> CheckResult:
        """Alias for :meth:`check`, which already issues a real completion request."""
        return await self.check(client, key)

    async def test_concurrency(self, client, key: str, concurrency_steps: list[int]) -> TestResult:
        """Return the largest step at which fewer than 30% of parallel probes failed.

        Non-positive steps are skipped; they would otherwise divide by zero.
        """
        headers = self.build_headers(key)
        last_success = None
        for step in concurrency_steps:
            if step <= 0:
                continue
            tasks = [self._probe(client, headers) for _ in range(step)]
            results = await asyncio.gather(*tasks, return_exceptions=True)
            # _probe returns False for any non-200 response or error.
            failures = sum(1 for r in results if not isinstance(r, Exception) and not r)
            if failures / step >= 0.3:
                break
            last_success = step
        return TestResult(max_concurrency=last_success)

    async def _probe(self, client, headers) -> bool:
        """Issue one ``GET`` on ``check_endpoint``; True only for HTTP 200."""
        try:
            resp = await client.get(f"{self.get_base_url()}{self.check_endpoint}", headers=headers)
            return resp.status_code == 200
        except Exception:
            return False
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import time
|
|
3
|
+
from .base import ProviderBase, CheckResult, TestResult
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CohereProvider(ProviderBase):
    """Cohere provider (native Cohere API rather than the OpenAI-style one)."""

    name = "cohere"
    base_url = "https://api.cohere.ai"
    check_endpoint = "/v1/check-api-key"

    def build_headers(self, key: str) -> dict:
        """Return bearer-token authorization headers for *key*."""
        return {"Authorization": f"Bearer {key}"}

    async def get_models(self, client, key: str) -> list[str]:
        """Return model identifiers from ``GET /v1/models``; ``[]`` on any failure.

        Cohere lists models as ``{"models": [{"name": ...}, ...]}``. The
        OpenAI-style ``{"data": [{"id": ...}]}`` shape is still accepted first
        for backward compatibility.
        """
        headers = self.build_headers(key)
        try:
            resp = await client.get(
                f"{self.get_base_url()}/v1/models",
                headers=headers
            )
            if resp.status_code == 200:
                data = resp.json()
                if "data" in data:
                    return [m["id"] for m in data["data"] if "id" in m]
                if "models" in data:
                    return [m["name"] for m in data["models"] if "name" in m]
            return []
        except Exception:
            return []

    async def check(self, client, key: str) -> CheckResult:
        """Validate *key* via ``POST /v1/check-api-key``.

        200 means valid and 401/403 means invalid. 429 is reported as valid but
        rate limited. Other statuses are invalid with a ``"status <code>"`` error.
        """
        headers = self.build_headers(key)
        start = time.monotonic()
        try:
            resp = await client.post(f"{self.get_base_url()}{self.check_endpoint}", headers=headers)
            latency = (time.monotonic() - start) * 1000
            if resp.status_code == 200:
                return CheckResult(True, 200, latency, None)
            elif resp.status_code in (401, 403):
                return CheckResult(False, resp.status_code, latency, "invalid key")
            elif resp.status_code == 429:
                return CheckResult(True, 429, latency, "rate limited")
            else:
                return CheckResult(False, resp.status_code, latency, f"status {resp.status_code}")
        except Exception as e:
            return CheckResult(False, None, (time.monotonic() - start) * 1000, str(e))

    async def test_token_limit(self, client, key: str, token_steps: list[int]) -> TestResult:
        """Return the largest ``max_tokens`` step that succeeded on ``POST /v1/chat``.

        Stops at the first 400/413, other non-200/429 status, or transport
        error. On 429 it sleeps 1s and moves on to the next step.
        """
        headers = self.build_headers(key)
        last_success = None
        for step in token_steps:
            try:
                resp = await client.post(
                    f"{self.get_base_url()}/v1/chat",
                    headers=headers,
                    json={
                        "model": "command-r",
                        "message": "hi",
                        "max_tokens": step
                    }
                )
                if resp.status_code == 200:
                    last_success = step
                elif resp.status_code in (400, 413):
                    break
                elif resp.status_code == 429:
                    await asyncio.sleep(1)
                    continue
                else:
                    break
            except Exception:
                break
        return TestResult(max_tokens=last_success)

    async def test_concurrency(self, client, key: str, concurrency_steps: list[int]) -> TestResult:
        """Return the largest step at which fewer than 30% of parallel probes failed.

        Non-positive steps are skipped; they would otherwise divide by zero.
        """
        headers = self.build_headers(key)
        last_success = None
        for step in concurrency_steps:
            if step <= 0:
                continue
            tasks = [self._probe(client, headers) for _ in range(step)]
            results = await asyncio.gather(*tasks, return_exceptions=True)
            # _probe returns False for any non-200 response or error.
            failures = sum(1 for r in results if not isinstance(r, Exception) and not r)
            if failures / step >= 0.3:
                break
            last_success = step
        return TestResult(max_concurrency=last_success)

    async def _probe(self, client, headers) -> bool:
        """Issue one ``POST`` on ``check_endpoint``; True only for HTTP 200."""
        try:
            resp = await client.post(f"{self.get_base_url()}{self.check_endpoint}", headers=headers)
            return resp.status_code == 200
        except Exception:
            return False
|