python-basekit 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- basekit/__init__.py +0 -0
- basekit/ai/clients/anthropic.py +192 -0
- basekit/ai/clients/dashscope.py +133 -0
- basekit/ai/clients/gemini.py +277 -0
- basekit/ai/clients/openai.py +384 -0
- basekit/ai/schema.py +160 -0
- basekit/ai/utils.py +63 -0
- basekit/cache/clients/sqlite.py +90 -0
- basekit/cache/schema.py +53 -0
- basekit/cache/utils.py +113 -0
- basekit/database.py +33 -0
- basekit/http/clients/curl_cffi.py +99 -0
- basekit/http/clients/httpx.py +100 -0
- basekit/http/schema.py +111 -0
- basekit/http/utils.py +31 -0
- basekit/limiter.py +179 -0
- basekit/py.typed +0 -0
- basekit/utils/batch.py +21 -0
- basekit/utils/console.py +26 -0
- basekit/utils/html.py +54 -0
- basekit/utils/jinja.py +20 -0
- basekit/utils/markdown.py +38 -0
- basekit/utils/mime.py +39 -0
- basekit/utils/misc.py +35 -0
- python_basekit-0.0.11.dist-info/METADATA +46 -0
- python_basekit-0.0.11.dist-info/RECORD +27 -0
- python_basekit-0.0.11.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,384 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from typing import ClassVar, override
|
|
3
|
+
|
|
4
|
+
import logfire
|
|
5
|
+
from httpx import Response
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from basekit.ai.schema import AIClient, Tool, Usage, ValidationError
|
|
9
|
+
from basekit.ai.utils import transform_schema
|
|
10
|
+
from basekit.utils.misc import recursive_merge
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class OpenAIClient(AIClient):
    """AI client for the OpenAI Responses API (``POST {base_url}/responses``).

    Builds on :class:`AIClient`, which supplies retry, rate limiting and
    optional caching; this class only constructs request payloads and
    parses Responses-API result objects.
    """

    # No extra fields join the cache key beyond those of the base class.
    cache_key_fields: ClassVar[tuple[str, ...]] = ()

    # Bearer token used for every request.
    api_key: str
    # Override for proxies / OpenAI-compatible gateways.
    base_url: str = "https://api.openai.com/v1"

    def _headers(self, /) -> dict:
        """Return the authorization and content-type headers."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _message_url(self, /) -> str:
        """Return the Responses API endpoint URL."""
        return f"{self.base_url}/responses"

    def _message_params(self, system: str, prompt: str, /) -> dict:
        """Build the base payload: model, system instructions, user input."""
        return {
            "model": self.model.request_name,
            "instructions": system,
            "input": [{"role": "user", "content": prompt}],
        }

    def _structured_message_params(self, name: str, schema: dict, /) -> dict:
        """Build the extra payload enforcing strict JSON-schema output."""
        config: dict = {"name": name, "schema": schema, "strict": True}
        return {"text": {"format": {"type": "json_schema", **config}}}

    def _tool_use_params(self, tool: dict, /) -> dict:
        """Build the extra payload forcing the model to call ``tool``."""
        return {"tools": [tool], "tool_choice": "required"}

    def _check_status(self, data: dict, /) -> bool:
        """Return True when ``data`` has no status or reports ``completed``."""
        return "status" not in data or data["status"] == "completed"

    def _get_block(self, data: dict, filter_: Callable[[dict], bool], /) -> dict:
        """Return the single completed output block matching ``filter_``.

        Raises:
            ValueError: if the response did not complete, or the number of
                matching completed blocks is not exactly one.
        """
        if not self._check_status(data):
            raise ValueError(f"STATUS:{data['status']}")

        blocks: list[dict] = [
            o for o in data["output"] if filter_(o) and self._check_status(o)
        ]
        if len(blocks) != 1:
            raise ValueError(f"LEN(BLOCKS):{len(blocks)}")

        return blocks[0]

    def _get_text(self, block: dict, /) -> str:
        """Extract the unique ``output_text`` content from a message block.

        Raises:
            ValueError: if there is not exactly one text content item.
        """
        contents: list[dict] = [
            c for c in block["content"] if c["type"] == "output_text"
        ]
        if len(contents) != 1:
            raise ValueError(f"LEN(CONTENTS):{len(contents)}")

        return contents[0]["text"]

    def _get_usage(self, data: dict, /) -> Usage:
        """Map the Responses API ``usage`` object onto :class:`Usage`."""
        usage: dict = data["usage"]
        return Usage(
            input_tokens=usage["input_tokens"],
            # The Responses API does not report cache-creation tokens.
            cache_creation_tokens=0,
            cache_read_tokens=usage["input_tokens_details"]["cached_tokens"],
            output_tokens=usage["output_tokens"],
        )

    @override
    async def _create_message_once(
        self, system: str, prompt: str, extra: dict | None
    ) -> str:
        """Perform one plain text-generation request and return the text.

        Any failure while parsing the response is re-raised as
        :class:`ValidationError` so the base class retry logic can retry it.
        """
        with logfire.span("openai client | create message") as span:
            span.set_attribute(key="model", value=self.model)
            span.set_attribute(key="system", value=system)
            span.set_attribute(key="prompt", value=prompt)
            span.set_attribute(key="extra", value=extra)

            # Limiter guards both concurrency and request rate.
            async with self._limiter:
                response: Response = await self._http_client.post(
                    self._message_url(),
                    headers=self._headers(),
                    # Later dicts win on conflict: base < model extras < caller extras.
                    data=recursive_merge(
                        self._message_params(system, prompt),
                        self.model.extra_params,
                        extra,
                    ),
                )

            try:
                data: dict = response.json()
                block: dict = self._get_block(data, lambda x: x["type"] == "message")
                message: str = self._get_text(block)
                usage: Usage = self._get_usage(data)

                span.set_attribute(key="return", value=message)
                span.set_attribute(key="usage", value=usage)
                span.message = f"{span.message} -> {usage.format()} tokens"
                return message
            except Exception as exc:
                raise ValidationError("生成消息失败") from exc

    @override
    async def _create_structured_message_once[T: BaseModel](
        self, system: str, prompt: str, schema: type[T], extra: dict | None
    ) -> T:
        """Perform one structured-output request and validate it as ``schema``."""
        schema_name: str = schema.__name__
        with logfire.span(
            f"openai client | create structured message | {schema_name}"
        ) as span:
            # Strict-mode schemas require the transformed (restricted) form.
            schema_data: dict = transform_schema(schema.model_json_schema())

            span.set_attribute(key="model", value=self.model)
            span.set_attribute(key="system", value=system)
            span.set_attribute(key="prompt", value=prompt)
            span.set_attribute(key="schema", value=schema_data)
            span.set_attribute(key="extra", value=extra)

            async with self._limiter:
                response: Response = await self._http_client.post(
                    self._message_url(),
                    headers=self._headers(),
                    data=recursive_merge(
                        self._message_params(system, prompt),
                        self._structured_message_params(schema_name, schema_data),
                        self.model.extra_params,
                        extra,
                    ),
                )

            try:
                data: dict = response.json()
                block: dict = self._get_block(data, lambda x: x["type"] == "message")
                message: str = self._get_text(block)
                result: T = schema.model_validate_json(json_data=message)
                usage: Usage = self._get_usage(data)

                span.set_attribute(key="return", value=result)
                span.set_attribute(key="usage", value=usage)
                span.message = f"{span.message} -> {usage.format()} tokens"
                return result
            except Exception as exc:
                raise ValidationError("生成结构化消息失败") from exc

    @override
    async def _create_tool_use_once[T: BaseModel](
        self, system: str, prompt: str, tool: Tool[T], extra: dict | None
    ) -> T:
        """Force one call of ``tool`` and validate its arguments as ``T``."""
        tool_name: str = tool.name
        with logfire.span(f"openai client | create tool use | {tool_name}") as span:
            # Responses API uses a flat function-tool object (no "function" wrapper).
            tool_: dict = {
                "type": "function",
                "name": tool.name,
                "description": tool.description,
                "parameters": transform_schema(tool.input_schema.model_json_schema()),
                "strict": True,
            }

            span.set_attribute(key="model", value=self.model)
            span.set_attribute(key="system", value=system)
            span.set_attribute(key="prompt", value=prompt)
            span.set_attribute(key="tool", value=tool_)
            span.set_attribute(key="extra", value=extra)

            async with self._limiter:
                response: Response = await self._http_client.post(
                    self._message_url(),
                    headers=self._headers(),
                    data=recursive_merge(
                        self._message_params(system, prompt),
                        self._tool_use_params(tool_),
                        self.model.extra_params,
                        extra,
                    ),
                )

            try:
                data: dict = response.json()
                block: dict = self._get_block(
                    data, lambda x: x["type"] == "function_call"
                )
                # Tool arguments arrive as a JSON string; validate directly.
                result: T = tool.input_schema.model_validate_json(
                    json_data=block["arguments"]
                )
                usage: Usage = self._get_usage(data)

                span.set_attribute(key="return", value=result)
                span.set_attribute(key="usage", value=usage)
                span.message = f"{span.message} -> {usage.format()} tokens"
                return result
            except Exception as exc:
                raise ValidationError("生成工具调用失败") from exc
|
|
200
|
+
|
|
201
|
+
class OpenAILegacyClient(AIClient):
    """AI client for the legacy Chat Completions API (``/chat/completions``).

    Useful for OpenAI-compatible providers that do not implement the newer
    Responses API. Payload construction and response parsing differ from
    :class:`OpenAIClient`; retry/limits/caching come from :class:`AIClient`.
    """

    # No extra fields join the cache key beyond those of the base class.
    cache_key_fields: ClassVar[tuple[str, ...]] = ()

    # Bearer token used for every request.
    api_key: str
    # Override for proxies / OpenAI-compatible gateways.
    base_url: str = "https://api.openai.com/v1"

    def _headers(self, /) -> dict:
        """Return the authorization and content-type headers."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _message_url(self, /) -> str:
        """Return the Chat Completions endpoint URL."""
        return f"{self.base_url}/chat/completions"

    def _message_params(self, system: str, prompt: str, /) -> dict:
        """Build the base payload as a system + user message pair."""
        return {
            "model": self.model.request_name,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": prompt},
            ],
        }

    def _structured_message_params(self, name: str, schema: dict, /) -> dict:
        """Build the ``response_format`` payload for strict JSON-schema output."""
        config: dict = {"name": name, "schema": schema, "strict": True}
        return {"response_format": {"type": "json_schema", "json_schema": config}}

    def _tool_use_params(self, tool: dict, /) -> dict:
        """Build the extra payload forcing the model to call ``tool``."""
        return {"tools": [tool], "tool_choice": "required"}

    def _get_block(self, data: dict, finish_reason: str, /) -> dict:
        """Return the message of the single choice with ``finish_reason``.

        Raises:
            ValueError: if there is not exactly one choice, or its finish
                reason differs from the expected one.
        """
        choices: list[dict] = data["choices"]
        if len(choices) != 1:
            raise ValueError(f"LEN(BLOCKS):{len(choices)}")

        choice: dict = choices[0]
        if choice["finish_reason"] != finish_reason:
            raise ValueError(f"FINISH REASON:{choice['finish_reason']}")

        return choice["message"]

    def _get_arguments(self, block: dict, /) -> str:
        """Extract the single tool call's raw JSON argument string.

        Raises:
            ValueError: if the message does not hold exactly one tool call.
        """
        arguments: list[str] = [t["function"]["arguments"] for t in block["tool_calls"]]
        if len(arguments) != 1:
            raise ValueError(f"LEN(ARGUMENTS):{len(arguments)}")

        return arguments[0]

    def _get_usage(self, data: dict, /) -> Usage:
        """Map the Chat Completions ``usage`` object onto :class:`Usage`."""
        usage: dict = data["usage"]
        return Usage(
            input_tokens=usage["prompt_tokens"],
            # Chat Completions does not report cache-creation tokens.
            cache_creation_tokens=0,
            cache_read_tokens=usage["prompt_tokens_details"]["cached_tokens"],
            output_tokens=usage["completion_tokens"],
        )

    @override
    async def _create_message_once(
        self, system: str, prompt: str, extra: dict | None
    ) -> str:
        """Perform one plain text-generation request and return the text.

        Parsing failures are re-raised as :class:`ValidationError` so the
        base class retry logic can retry them.
        """
        with logfire.span("openai client (legacy) | create message") as span:
            span.set_attribute(key="model", value=self.model)
            span.set_attribute(key="system", value=system)
            span.set_attribute(key="prompt", value=prompt)
            span.set_attribute(key="extra", value=extra)

            # Limiter guards both concurrency and request rate.
            async with self._limiter:
                response: Response = await self._http_client.post(
                    self._message_url(),
                    headers=self._headers(),
                    # Later dicts win on conflict: base < model extras < caller extras.
                    data=recursive_merge(
                        self._message_params(system, prompt),
                        self.model.extra_params,
                        extra,
                    ),
                )

            try:
                data: dict = response.json()
                block: dict = self._get_block(data, "stop")
                message: str = block["content"]
                usage: Usage = self._get_usage(data)

                span.set_attribute(key="return", value=message)
                span.set_attribute(key="usage", value=usage)
                span.message = f"{span.message} -> {usage.format()} tokens"
                return message
            except Exception as exc:
                raise ValidationError("生成消息失败") from exc

    @override
    async def _create_structured_message_once[T: BaseModel](
        self, system: str, prompt: str, schema: type[T], extra: dict | None
    ) -> T:
        """Perform one structured-output request and validate it as ``schema``."""
        schema_name: str = schema.__name__
        with logfire.span(
            f"openai client (legacy) | create structured message | {schema_name}"
        ) as span:
            # Strict-mode schemas require the transformed (restricted) form.
            schema_data: dict = transform_schema(schema.model_json_schema())

            span.set_attribute(key="model", value=self.model)
            span.set_attribute(key="system", value=system)
            span.set_attribute(key="prompt", value=prompt)
            span.set_attribute(key="schema", value=schema_data)
            span.set_attribute(key="extra", value=extra)

            async with self._limiter:
                response: Response = await self._http_client.post(
                    self._message_url(),
                    headers=self._headers(),
                    data=recursive_merge(
                        self._message_params(system, prompt),
                        self._structured_message_params(schema_name, schema_data),
                        self.model.extra_params,
                        extra,
                    ),
                )

            try:
                data: dict = response.json()
                block: dict = self._get_block(data, "stop")
                result: T = schema.model_validate_json(json_data=block["content"])
                usage: Usage = self._get_usage(data)

                span.set_attribute(key="return", value=result)
                span.set_attribute(key="usage", value=usage)
                span.message = f"{span.message} -> {usage.format()} tokens"
                return result
            except Exception as exc:
                raise ValidationError("生成结构化消息失败") from exc

    @override
    async def _create_tool_use_once[T: BaseModel](
        self, system: str, prompt: str, tool: Tool[T], extra: dict | None
    ) -> T:
        """Force one call of ``tool`` and validate its arguments as ``T``."""
        tool_name: str = tool.name
        with logfire.span(
            f"openai client (legacy) | create tool use | {tool_name}"
        ) as span:
            # Chat Completions wraps the function spec in a "function" object.
            tool_: dict = {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": transform_schema(
                        tool.input_schema.model_json_schema()
                    ),
                    "strict": True,
                },
            }

            span.set_attribute(key="model", value=self.model)
            span.set_attribute(key="system", value=system)
            span.set_attribute(key="prompt", value=prompt)
            span.set_attribute(key="tool", value=tool_)
            span.set_attribute(key="extra", value=extra)

            async with self._limiter:
                response: Response = await self._http_client.post(
                    self._message_url(),
                    headers=self._headers(),
                    data=recursive_merge(
                        self._message_params(system, prompt),
                        self._tool_use_params(tool_),
                        self.model.extra_params,
                        extra,
                    ),
                )

            try:
                data: dict = response.json()
                block: dict = self._get_block(data, "tool_calls")
                arguments: str = self._get_arguments(block)
                result: T = tool.input_schema.model_validate_json(json_data=arguments)
                usage: Usage = self._get_usage(data)

                span.set_attribute(key="return", value=result)
                span.set_attribute(key="usage", value=usage)
                span.message = f"{span.message} -> {usage.format()} tokens"
                return result
            except Exception as exc:
                raise ValidationError("生成工具调用失败") from exc
|
basekit/ai/schema.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
from collections.abc import AsyncIterator, Awaitable, Callable
|
|
2
|
+
from contextlib import asynccontextmanager
|
|
3
|
+
from typing import Any, ClassVar, Literal, override
|
|
4
|
+
|
|
5
|
+
from httpx import HTTPStatusError, RequestError
|
|
6
|
+
from pydantic import BaseModel, ConfigDict, PrivateAttr
|
|
7
|
+
from tenacity import (
|
|
8
|
+
retry,
|
|
9
|
+
retry_if_exception,
|
|
10
|
+
stop_after_attempt,
|
|
11
|
+
wait_exponential_jitter,
|
|
12
|
+
)
|
|
13
|
+
from tenacity.stop import stop_base
|
|
14
|
+
from tenacity.wait import wait_base
|
|
15
|
+
|
|
16
|
+
from basekit.cache.schema import Cache
|
|
17
|
+
from basekit.http.clients.httpx import HttpxClient
|
|
18
|
+
from basekit.http.schema import DEFAULT_LIMIT
|
|
19
|
+
from basekit.limiter import AsyncLimiter
|
|
20
|
+
|
|
21
|
+
type ImageSize = Literal["1K", "2K", "4K"]
|
|
22
|
+
type ImageRatio = Literal[
|
|
23
|
+
"9:16", "2:3", "3:4", "4:5", "1:1", "5:4", "4:3", "3:2", "16:9"
|
|
24
|
+
]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class NotSupportedError(ValueError):
    """Raised when an input or feature is not supported by a client/helper."""

    pass
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ValidationError(ValueError):
    """Raised when a provider response cannot be parsed or validated.

    Treated as retryable by ``AIClient._should_retry``.
    """

    pass
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Usage(BaseModel):
    """Token accounting for a single AI API call."""

    # Tokens consumed by the request input.
    input_tokens: int
    # Tokens spent creating a prompt-cache entry (0 when unreported).
    cache_creation_tokens: int
    # Input tokens served from the prompt cache.
    cache_read_tokens: int
    # Tokens produced by the model.
    output_tokens: int

    def format(self) -> str:
        """Render a compact one-line summary: ``↓in(+create,→read)/↑out``."""
        cache_part = f"(+{self.cache_creation_tokens},→{self.cache_read_tokens})"
        return f"↓{self.input_tokens}{cache_part}/↑{self.output_tokens}"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class Tool[T: BaseModel](BaseModel):
    """Declarative description of a tool the model may be asked to call."""

    # Tool identifier as sent to the provider.
    name: str
    # Natural-language description shown to the model.
    description: str
    # Pydantic model that declares and validates the tool's arguments.
    input_schema: type[T]
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class Model(BaseModel):
    """Configuration for one upstream AI model."""

    # Fields that participate in the cache key of cached AI calls.
    cache_key_fields: ClassVar[tuple[str, ...]] = ("name", "extra_params")

    # Canonical model name.
    name: str
    # Provider-specific name override; falls back to ``name`` when unset.
    alias: str | None = None
    # Max simultaneous in-flight requests (None = unlimited).
    concurrency_limit: int | None = None
    # Max requests per rate window (None = unlimited).
    rate_limit: int | None = None
    # Extra request parameters merged into every payload.
    extra_params: dict | None = None

    @property
    def request_name(self) -> str:
        """Name actually sent upstream: the alias when set, else ``name``."""
        if self.alias:
            return self.alias
        return self.name
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class AIClient(BaseModel):
    """Base class for provider-specific AI clients.

    Owns the HTTP client, the concurrency/rate limiter, the retry policy and
    an optional response cache. Subclasses implement the ``*_once`` hooks;
    callers use the public ``create_*`` wrappers, which add retry + caching.
    """

    # arbitrary_types_allowed: retry_stop/retry_wait are tenacity objects.
    model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
    # Fields that participate in the cache key of cached AI calls.
    cache_key_fields: ClassVar[tuple[str, ...]] = ("model",)

    model: Model
    # NOTE(review): these defaults are single tenacity instances; presumably
    # safe to share across client instances — confirm they are stateless.
    retry_stop: stop_base = stop_after_attempt(max_attempt_number=3)
    retry_wait: wait_base = wait_exponential_jitter()
    # Optional cache; when set it wraps the retried call (see ``_wrap``).
    cache: Cache | None = None

    _http_client: HttpxClient = PrivateAttr()
    _limiter: AsyncLimiter = PrivateAttr()

    @override
    def model_post_init(self, context: Any) -> None:
        # Size the HTTP pool from the model's concurrency limit; generous
        # 600 s timeout because generation requests can run long.
        self._http_client = HttpxClient(
            timeout=600.0, limit=self.model.concurrency_limit or DEFAULT_LIMIT
        )
        self._limiter = AsyncLimiter(
            concurrency_limit=self.model.concurrency_limit,
            rate_limit=self.model.rate_limit,
            rate_window=60.0,  # rate_limit is interpreted per minute
        )

    @asynccontextmanager
    async def lifespan(self) -> AsyncIterator[None]:
        """Open/close the underlying HTTP client for the client's lifetime."""
        async with self._http_client.lifespan():
            yield

    def _should_retry(self, exc: BaseException, /) -> bool:
        """Retry on network errors, 429/5xx responses, and parse failures."""
        if isinstance(exc, RequestError):
            return True
        if isinstance(exc, HTTPStatusError):
            status_code: int = exc.response.status_code
            return status_code == 429 or status_code >= 500
        if isinstance(exc, ValidationError):
            return True
        return False

    def _wrap[T, **P](
        self, func: Callable[P, Awaitable[T]], /
    ) -> Callable[P, Awaitable[T]]:
        """Wrap a ``*_once`` coroutine with retry, then (optionally) caching.

        The cache sits outside the retry wrapper, so cache hits skip the
        request entirely and only successful results are cached.
        """
        wrapped_func: Callable[P, Awaitable[T]] = retry(
            retry=retry_if_exception(predicate=self._should_retry),
            stop=self.retry_stop,
            wait=self.retry_wait,
            reraise=True,
        )(func)
        return self.cache(wrapped_func) if self.cache else wrapped_func

    async def _create_message_once(
        self, system: str, prompt: str, extra: dict | None
    ) -> str:
        """Subclass hook: one plain text-generation attempt."""
        raise NotImplementedError

    async def create_message(
        self, system: str, prompt: str, extra: dict | None = None
    ) -> str:
        """Generate a plain text message (with retry + optional cache)."""
        return await self._wrap(self._create_message_once)(
            system=system, prompt=prompt, extra=extra
        )

    async def _create_structured_message_once[T: BaseModel](
        self, system: str, prompt: str, schema: type[T], extra: dict | None
    ) -> T:
        """Subclass hook: one structured-output attempt."""
        raise NotImplementedError

    async def create_structured_message[T: BaseModel](
        self, system: str, prompt: str, schema: type[T], extra: dict | None = None
    ) -> T:
        """Generate output validated against ``schema`` (with retry/cache)."""
        return await self._wrap(self._create_structured_message_once)(
            system=system, prompt=prompt, schema=schema, extra=extra
        )

    async def _create_tool_use_once[T: BaseModel](
        self, system: str, prompt: str, tool: Tool[T], extra: dict | None
    ) -> T:
        """Subclass hook: one forced tool-call attempt."""
        raise NotImplementedError

    async def create_tool_use[T: BaseModel](
        self, system: str, prompt: str, tool: Tool[T], extra: dict | None = None
    ) -> T:
        """Force a call of ``tool`` and return its validated arguments."""
        return await self._wrap(self._create_tool_use_once)(
            system=system, prompt=prompt, tool=tool, extra=extra
        )

    async def _create_image_once(
        self, prompt: str, ratio: ImageRatio, size: ImageSize, extra: dict | None
    ) -> bytes:
        """Subclass hook: one image-generation attempt."""
        raise NotImplementedError

    async def create_image(
        self, prompt: str, ratio: ImageRatio, size: ImageSize, extra: dict | None = None
    ) -> bytes:
        """Generate an image and return its raw bytes (with retry/cache)."""
        return await self._wrap(self._create_image_once)(
            prompt=prompt, ratio=ratio, size=size, extra=extra
        )
|
basekit/ai/utils.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from typing import cast
|
|
2
|
+
|
|
3
|
+
from basekit.ai.schema import NotSupportedError
|
|
4
|
+
from basekit.utils.misc import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def transform_schema(schema: Any, /, *, gemini: bool = False) -> dict:
|
|
8
|
+
if not isinstance(schema, dict):
|
|
9
|
+
raise NotSupportedError(f"仅支持字典类型模式,SCHEMA:{type(schema)}")
|
|
10
|
+
|
|
11
|
+
strict_schema: dict = {}
|
|
12
|
+
|
|
13
|
+
if "$ref" in schema:
|
|
14
|
+
strict_schema["$ref"] = schema["$ref"]
|
|
15
|
+
return strict_schema
|
|
16
|
+
|
|
17
|
+
if "$defs" in schema:
|
|
18
|
+
strict_schema["$defs"] = {
|
|
19
|
+
key: transform_schema(value, gemini=gemini)
|
|
20
|
+
for key, value in schema["$defs"].items()
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
if "title" in schema:
|
|
24
|
+
strict_schema["title"] = schema["title"]
|
|
25
|
+
|
|
26
|
+
if "description" in schema:
|
|
27
|
+
strict_schema["description"] = schema["description"]
|
|
28
|
+
|
|
29
|
+
if "enum" in schema:
|
|
30
|
+
strict_schema["enum"] = schema["enum"]
|
|
31
|
+
|
|
32
|
+
if "const" in schema:
|
|
33
|
+
strict_schema["const"] = schema["const"]
|
|
34
|
+
|
|
35
|
+
if "default" in schema:
|
|
36
|
+
strict_schema["default"] = schema["default"]
|
|
37
|
+
|
|
38
|
+
if "anyOf" in schema:
|
|
39
|
+
strict_schema["anyOf"] = [
|
|
40
|
+
transform_schema(cast(dict, item), gemini=gemini)
|
|
41
|
+
for item in schema["anyOf"]
|
|
42
|
+
]
|
|
43
|
+
|
|
44
|
+
if "type" not in schema:
|
|
45
|
+
return strict_schema
|
|
46
|
+
|
|
47
|
+
strict_schema["type"] = schema["type"]
|
|
48
|
+
|
|
49
|
+
match strict_schema["type"]:
|
|
50
|
+
case "object":
|
|
51
|
+
strict_schema["properties"] = {
|
|
52
|
+
key: transform_schema(value, gemini=gemini)
|
|
53
|
+
for key, value in schema["properties"].items()
|
|
54
|
+
}
|
|
55
|
+
strict_schema["required"] = list(strict_schema["properties"].keys())
|
|
56
|
+
if not gemini:
|
|
57
|
+
strict_schema["additionalProperties"] = False
|
|
58
|
+
case "array":
|
|
59
|
+
strict_schema["items"] = transform_schema(schema["items"], gemini=gemini)
|
|
60
|
+
case "string" | "number" | "integer" | "boolean" | "null":
|
|
61
|
+
pass
|
|
62
|
+
|
|
63
|
+
return strict_schema
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
from collections.abc import AsyncIterator
|
|
3
|
+
from contextlib import asynccontextmanager
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import override
|
|
6
|
+
|
|
7
|
+
from pydantic import PrivateAttr
|
|
8
|
+
from sqlalchemy import (
|
|
9
|
+
Column,
|
|
10
|
+
CursorResult,
|
|
11
|
+
Delete,
|
|
12
|
+
LargeBinary,
|
|
13
|
+
MetaData,
|
|
14
|
+
Row,
|
|
15
|
+
Select,
|
|
16
|
+
String,
|
|
17
|
+
Table,
|
|
18
|
+
delete,
|
|
19
|
+
select,
|
|
20
|
+
)
|
|
21
|
+
from sqlalchemy.dialects.sqlite import Insert, insert
|
|
22
|
+
|
|
23
|
+
from basekit.cache.schema import Cache
|
|
24
|
+
from basekit.database import DatabaseClient
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class SQLiteCache(Cache):
    """Byte-oriented cache backed by a SQLite database (via aiosqlite).

    Keys are stored alongside a SHA-256 hex digest, which serves as the
    primary key; lookups and deletes go through the digest so the key blob
    itself never needs to be indexed.
    """

    # Database file location; None selects an in-memory database.
    path: Path | None = None

    _client: DatabaseClient = PrivateAttr()
    _metadata: MetaData = PrivateAttr()
    _table: Table = PrivateAttr()

    @override
    def model_post_init(self, context) -> None:
        # Normalize the path and build the engine URL up front.
        self.path = self.path.expanduser().resolve() if self.path else None
        path: str = self.path.as_posix() if self.path else ":memory:"
        url: str = f"sqlite+aiosqlite:///{path}"
        self._client = DatabaseClient(url=url)
        self._metadata = MetaData()
        self._table = Table(
            "cache",
            self._metadata,
            # 64 hex chars of SHA-256; primary key for O(log n) lookup.
            Column("sha256", String(length=64), nullable=False, primary_key=True),
            Column("key", LargeBinary(), nullable=False),
            Column("value", LargeBinary(), nullable=False),
        )

    @override
    @asynccontextmanager
    async def lifespan(self) -> AsyncIterator[None]:
        """Open the database and ensure the cache table exists."""
        async with self._client.lifespan():
            await self._client.create_schema(self._metadata)
            yield

    def _sha256(self, key: bytes, /) -> str:
        """Return the hex SHA-256 digest used as the row's primary key."""
        return hashlib.sha256(string=key).hexdigest()

    @override
    async def set(self, key: bytes, value: bytes, /) -> None:
        """Insert or overwrite the value stored under ``key`` (upsert)."""
        async with self._client.connection() as connection:
            statement: Insert = (
                insert(table=self._table)
                .values(sha256=self._sha256(key), key=key, value=value)
                .on_conflict_do_update(
                    index_elements=[self._table.c.sha256],
                    set_={"key": key, "value": value},
                )
            )
            await connection.execute(statement=statement)
            await connection.commit()

    @override
    async def get(self, key: bytes, /) -> bytes | None:
        """Return the value stored under ``key``, or None when absent."""
        async with self._client.connection() as connection:
            statement: Select = select(self._table.c.value).where(
                self._table.c.sha256 == self._sha256(key)
            )
            result: CursorResult = await connection.execute(statement=statement)
            row: Row | None = result.one_or_none()
            return row.value if row else None

    @override
    async def delete(self, key: bytes, /) -> None:
        """Remove the entry stored under ``key`` (no-op when absent)."""
        async with self._client.connection() as connection:
            statement: Delete = delete(table=self._table).where(
                self._table.c.sha256 == self._sha256(key)
            )
            await connection.execute(statement=statement)
            await connection.commit()
|