python-basekit 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,384 @@
1
+ from collections.abc import Callable
2
+ from typing import ClassVar, override
3
+
4
+ import logfire
5
+ from httpx import Response
6
+ from pydantic import BaseModel
7
+
8
+ from basekit.ai.schema import AIClient, Tool, Usage, ValidationError
9
+ from basekit.ai.utils import transform_schema
10
+ from basekit.utils.misc import recursive_merge
11
+
12
+
13
+ class OpenAIClient(AIClient):
14
+ cache_key_fields: ClassVar[tuple[str, ...]] = ()
15
+
16
+ api_key: str
17
+ base_url: str = "https://api.openai.com/v1"
18
+
19
+ def _headers(self, /) -> dict:
20
+ return {
21
+ "Authorization": f"Bearer {self.api_key}",
22
+ "Content-Type": "application/json",
23
+ }
24
+
25
+ def _message_url(self, /) -> str:
26
+ return f"{self.base_url}/responses"
27
+
28
+ def _message_params(self, system: str, prompt: str, /) -> dict:
29
+ return {
30
+ "model": self.model.request_name,
31
+ "instructions": system,
32
+ "input": [{"role": "user", "content": prompt}],
33
+ }
34
+
35
+ def _structured_message_params(self, name: str, schema: dict, /) -> dict:
36
+ config: dict = {"name": name, "schema": schema, "strict": True}
37
+ return {"text": {"format": {"type": "json_schema", **config}}}
38
+
39
+ def _tool_use_params(self, tool: dict, /) -> dict:
40
+ return {"tools": [tool], "tool_choice": "required"}
41
+
42
+ def _check_status(self, data: dict, /) -> bool:
43
+ return "status" not in data or data["status"] == "completed"
44
+
45
+ def _get_block(self, data: dict, filter_: Callable[[dict], bool], /) -> dict:
46
+ if not self._check_status(data):
47
+ raise ValueError(f"STATUS:{data['status']}")
48
+
49
+ blocks: list[dict] = [
50
+ o for o in data["output"] if filter_(o) and self._check_status(o)
51
+ ]
52
+ if len(blocks) != 1:
53
+ raise ValueError(f"LEN(BLOCKS):{len(blocks)}")
54
+
55
+ return blocks[0]
56
+
57
+ def _get_text(self, block: dict, /) -> str:
58
+ contents: list[dict] = [
59
+ c for c in block["content"] if c["type"] == "output_text"
60
+ ]
61
+ if len(contents) != 1:
62
+ raise ValueError(f"LEN(CONTENTS):{len(contents)}")
63
+
64
+ return contents[0]["text"]
65
+
66
+ def _get_usage(self, data: dict, /) -> Usage:
67
+ usage: dict = data["usage"]
68
+ return Usage(
69
+ input_tokens=usage["input_tokens"],
70
+ cache_creation_tokens=0,
71
+ cache_read_tokens=usage["input_tokens_details"]["cached_tokens"],
72
+ output_tokens=usage["output_tokens"],
73
+ )
74
+
75
+ @override
76
+ async def _create_message_once(
77
+ self, system: str, prompt: str, extra: dict | None
78
+ ) -> str:
79
+ with logfire.span("openai client | create message") as span:
80
+ span.set_attribute(key="model", value=self.model)
81
+ span.set_attribute(key="system", value=system)
82
+ span.set_attribute(key="prompt", value=prompt)
83
+ span.set_attribute(key="extra", value=extra)
84
+
85
+ async with self._limiter:
86
+ response: Response = await self._http_client.post(
87
+ self._message_url(),
88
+ headers=self._headers(),
89
+ data=recursive_merge(
90
+ self._message_params(system, prompt),
91
+ self.model.extra_params,
92
+ extra,
93
+ ),
94
+ )
95
+
96
+ try:
97
+ data: dict = response.json()
98
+ block: dict = self._get_block(data, lambda x: x["type"] == "message")
99
+ message: str = self._get_text(block)
100
+ usage: Usage = self._get_usage(data)
101
+
102
+ span.set_attribute(key="return", value=message)
103
+ span.set_attribute(key="usage", value=usage)
104
+ span.message = f"{span.message} -> {usage.format()} tokens"
105
+ return message
106
+ except Exception as exc:
107
+ raise ValidationError("生成消息失败") from exc
108
+
109
+ @override
110
+ async def _create_structured_message_once[T: BaseModel](
111
+ self, system: str, prompt: str, schema: type[T], extra: dict | None
112
+ ) -> T:
113
+ schema_name: str = schema.__name__
114
+ with logfire.span(
115
+ f"openai client | create structured message | {schema_name}"
116
+ ) as span:
117
+ schema_data: dict = transform_schema(schema.model_json_schema())
118
+
119
+ span.set_attribute(key="model", value=self.model)
120
+ span.set_attribute(key="system", value=system)
121
+ span.set_attribute(key="prompt", value=prompt)
122
+ span.set_attribute(key="schema", value=schema_data)
123
+ span.set_attribute(key="extra", value=extra)
124
+
125
+ async with self._limiter:
126
+ response: Response = await self._http_client.post(
127
+ self._message_url(),
128
+ headers=self._headers(),
129
+ data=recursive_merge(
130
+ self._message_params(system, prompt),
131
+ self._structured_message_params(schema_name, schema_data),
132
+ self.model.extra_params,
133
+ extra,
134
+ ),
135
+ )
136
+
137
+ try:
138
+ data: dict = response.json()
139
+ block: dict = self._get_block(data, lambda x: x["type"] == "message")
140
+ message: str = self._get_text(block)
141
+ result: T = schema.model_validate_json(json_data=message)
142
+ usage: Usage = self._get_usage(data)
143
+
144
+ span.set_attribute(key="return", value=result)
145
+ span.set_attribute(key="usage", value=usage)
146
+ span.message = f"{span.message} -> {usage.format()} tokens"
147
+ return result
148
+ except Exception as exc:
149
+ raise ValidationError("生成结构化消息失败") from exc
150
+
151
+ @override
152
+ async def _create_tool_use_once[T: BaseModel](
153
+ self, system: str, prompt: str, tool: Tool[T], extra: dict | None
154
+ ) -> T:
155
+ tool_name: str = tool.name
156
+ with logfire.span(f"openai client | create tool use | {tool_name}") as span:
157
+ tool_: dict = {
158
+ "type": "function",
159
+ "name": tool.name,
160
+ "description": tool.description,
161
+ "parameters": transform_schema(tool.input_schema.model_json_schema()),
162
+ "strict": True,
163
+ }
164
+
165
+ span.set_attribute(key="model", value=self.model)
166
+ span.set_attribute(key="system", value=system)
167
+ span.set_attribute(key="prompt", value=prompt)
168
+ span.set_attribute(key="tool", value=tool_)
169
+ span.set_attribute(key="extra", value=extra)
170
+
171
+ async with self._limiter:
172
+ response: Response = await self._http_client.post(
173
+ self._message_url(),
174
+ headers=self._headers(),
175
+ data=recursive_merge(
176
+ self._message_params(system, prompt),
177
+ self._tool_use_params(tool_),
178
+ self.model.extra_params,
179
+ extra,
180
+ ),
181
+ )
182
+
183
+ try:
184
+ data: dict = response.json()
185
+ block: dict = self._get_block(
186
+ data, lambda x: x["type"] == "function_call"
187
+ )
188
+ result: T = tool.input_schema.model_validate_json(
189
+ json_data=block["arguments"]
190
+ )
191
+ usage: Usage = self._get_usage(data)
192
+
193
+ span.set_attribute(key="return", value=result)
194
+ span.set_attribute(key="usage", value=usage)
195
+ span.message = f"{span.message} -> {usage.format()} tokens"
196
+ return result
197
+ except Exception as exc:
198
+ raise ValidationError("生成工具调用失败") from exc
199
+
200
+
201
+ class OpenAILegacyClient(AIClient):
202
+ cache_key_fields: ClassVar[tuple[str, ...]] = ()
203
+
204
+ api_key: str
205
+ base_url: str = "https://api.openai.com/v1"
206
+
207
+ def _headers(self, /) -> dict:
208
+ return {
209
+ "Authorization": f"Bearer {self.api_key}",
210
+ "Content-Type": "application/json",
211
+ }
212
+
213
+ def _message_url(self, /) -> str:
214
+ return f"{self.base_url}/chat/completions"
215
+
216
+ def _message_params(self, system: str, prompt: str, /) -> dict:
217
+ return {
218
+ "model": self.model.request_name,
219
+ "messages": [
220
+ {"role": "system", "content": system},
221
+ {"role": "user", "content": prompt},
222
+ ],
223
+ }
224
+
225
+ def _structured_message_params(self, name: str, schema: dict, /) -> dict:
226
+ config: dict = {"name": name, "schema": schema, "strict": True}
227
+ return {"response_format": {"type": "json_schema", "json_schema": config}}
228
+
229
+ def _tool_use_params(self, tool: dict, /) -> dict:
230
+ return {"tools": [tool], "tool_choice": "required"}
231
+
232
+ def _get_block(self, data: dict, finish_reason: str, /) -> dict:
233
+ choices: list[dict] = data["choices"]
234
+ if len(choices) != 1:
235
+ raise ValueError(f"LEN(BLOCKS):{len(choices)}")
236
+
237
+ choice: dict = choices[0]
238
+ if choice["finish_reason"] != finish_reason:
239
+ raise ValueError(f"FINISH REASON:{choice['finish_reason']}")
240
+
241
+ return choice["message"]
242
+
243
+ def _get_arguments(self, block: dict, /) -> str:
244
+ arguments: list[str] = [t["function"]["arguments"] for t in block["tool_calls"]]
245
+ if len(arguments) != 1:
246
+ raise ValueError(f"LEN(ARGUMENTS):{len(arguments)}")
247
+
248
+ return arguments[0]
249
+
250
+ def _get_usage(self, data: dict, /) -> Usage:
251
+ usage: dict = data["usage"]
252
+ return Usage(
253
+ input_tokens=usage["prompt_tokens"],
254
+ cache_creation_tokens=0,
255
+ cache_read_tokens=usage["prompt_tokens_details"]["cached_tokens"],
256
+ output_tokens=usage["completion_tokens"],
257
+ )
258
+
259
+ @override
260
+ async def _create_message_once(
261
+ self, system: str, prompt: str, extra: dict | None
262
+ ) -> str:
263
+ with logfire.span("openai client (legacy) | create message") as span:
264
+ span.set_attribute(key="model", value=self.model)
265
+ span.set_attribute(key="system", value=system)
266
+ span.set_attribute(key="prompt", value=prompt)
267
+ span.set_attribute(key="extra", value=extra)
268
+
269
+ async with self._limiter:
270
+ response: Response = await self._http_client.post(
271
+ self._message_url(),
272
+ headers=self._headers(),
273
+ data=recursive_merge(
274
+ self._message_params(system, prompt),
275
+ self.model.extra_params,
276
+ extra,
277
+ ),
278
+ )
279
+
280
+ try:
281
+ data: dict = response.json()
282
+ block: dict = self._get_block(data, "stop")
283
+ message: str = block["content"]
284
+ usage: Usage = self._get_usage(data)
285
+
286
+ span.set_attribute(key="return", value=message)
287
+ span.set_attribute(key="usage", value=usage)
288
+ span.message = f"{span.message} -> {usage.format()} tokens"
289
+ return message
290
+ except Exception as exc:
291
+ raise ValidationError("生成消息失败") from exc
292
+
293
+ @override
294
+ async def _create_structured_message_once[T: BaseModel](
295
+ self, system: str, prompt: str, schema: type[T], extra: dict | None
296
+ ) -> T:
297
+ schema_name: str = schema.__name__
298
+ with logfire.span(
299
+ f"openai client (legacy) | create structured message | {schema_name}"
300
+ ) as span:
301
+ schema_data: dict = transform_schema(schema.model_json_schema())
302
+
303
+ span.set_attribute(key="model", value=self.model)
304
+ span.set_attribute(key="system", value=system)
305
+ span.set_attribute(key="prompt", value=prompt)
306
+ span.set_attribute(key="schema", value=schema_data)
307
+ span.set_attribute(key="extra", value=extra)
308
+
309
+ async with self._limiter:
310
+ response: Response = await self._http_client.post(
311
+ self._message_url(),
312
+ headers=self._headers(),
313
+ data=recursive_merge(
314
+ self._message_params(system, prompt),
315
+ self._structured_message_params(schema_name, schema_data),
316
+ self.model.extra_params,
317
+ extra,
318
+ ),
319
+ )
320
+
321
+ try:
322
+ data: dict = response.json()
323
+ block: dict = self._get_block(data, "stop")
324
+ result: T = schema.model_validate_json(json_data=block["content"])
325
+ usage: Usage = self._get_usage(data)
326
+
327
+ span.set_attribute(key="return", value=result)
328
+ span.set_attribute(key="usage", value=usage)
329
+ span.message = f"{span.message} -> {usage.format()} tokens"
330
+ return result
331
+ except Exception as exc:
332
+ raise ValidationError("生成结构化消息失败") from exc
333
+
334
+ @override
335
+ async def _create_tool_use_once[T: BaseModel](
336
+ self, system: str, prompt: str, tool: Tool[T], extra: dict | None
337
+ ) -> T:
338
+ tool_name: str = tool.name
339
+ with logfire.span(
340
+ f"openai client (legacy) | create tool use | {tool_name}"
341
+ ) as span:
342
+ tool_: dict = {
343
+ "type": "function",
344
+ "function": {
345
+ "name": tool.name,
346
+ "description": tool.description,
347
+ "parameters": transform_schema(
348
+ tool.input_schema.model_json_schema()
349
+ ),
350
+ "strict": True,
351
+ },
352
+ }
353
+
354
+ span.set_attribute(key="model", value=self.model)
355
+ span.set_attribute(key="system", value=system)
356
+ span.set_attribute(key="prompt", value=prompt)
357
+ span.set_attribute(key="tool", value=tool_)
358
+ span.set_attribute(key="extra", value=extra)
359
+
360
+ async with self._limiter:
361
+ response: Response = await self._http_client.post(
362
+ self._message_url(),
363
+ headers=self._headers(),
364
+ data=recursive_merge(
365
+ self._message_params(system, prompt),
366
+ self._tool_use_params(tool_),
367
+ self.model.extra_params,
368
+ extra,
369
+ ),
370
+ )
371
+
372
+ try:
373
+ data: dict = response.json()
374
+ block: dict = self._get_block(data, "tool_calls")
375
+ arguments: str = self._get_arguments(block)
376
+ result: T = tool.input_schema.model_validate_json(json_data=arguments)
377
+ usage: Usage = self._get_usage(data)
378
+
379
+ span.set_attribute(key="return", value=result)
380
+ span.set_attribute(key="usage", value=usage)
381
+ span.message = f"{span.message} -> {usage.format()} tokens"
382
+ return result
383
+ except Exception as exc:
384
+ raise ValidationError("生成工具调用失败") from exc
basekit/ai/schema.py ADDED
@@ -0,0 +1,160 @@
1
+ from collections.abc import AsyncIterator, Awaitable, Callable
2
+ from contextlib import asynccontextmanager
3
+ from typing import Any, ClassVar, Literal, override
4
+
5
+ from httpx import HTTPStatusError, RequestError
6
+ from pydantic import BaseModel, ConfigDict, PrivateAttr
7
+ from tenacity import (
8
+ retry,
9
+ retry_if_exception,
10
+ stop_after_attempt,
11
+ wait_exponential_jitter,
12
+ )
13
+ from tenacity.stop import stop_base
14
+ from tenacity.wait import wait_base
15
+
16
+ from basekit.cache.schema import Cache
17
+ from basekit.http.clients.httpx import HttpxClient
18
+ from basekit.http.schema import DEFAULT_LIMIT
19
+ from basekit.limiter import AsyncLimiter
20
+
21
+ type ImageSize = Literal["1K", "2K", "4K"]
22
+ type ImageRatio = Literal[
23
+ "9:16", "2:3", "3:4", "4:5", "1:1", "5:4", "4:3", "3:2", "16:9"
24
+ ]
25
+
26
+
27
class NotSupportedError(ValueError):
    """Raised when an input or feature is not supported."""
29
+
30
+
31
class ValidationError(ValueError):
    """Raised when a model response cannot be decoded or validated."""
33
+
34
+
35
class Usage(BaseModel):
    """Token counters reported for a single model response."""

    input_tokens: int
    cache_creation_tokens: int
    cache_read_tokens: int
    output_tokens: int

    def format(self) -> str:
        """Render the counters as ``↓input(+created,→read)/↑output``."""
        cached: str = f"+{self.cache_creation_tokens},→{self.cache_read_tokens}"
        return f"↓{self.input_tokens}({cached})/↑{self.output_tokens}"
43
+
44
+
45
+ class Tool[T: BaseModel](BaseModel):
46
+ name: str
47
+ description: str
48
+ input_schema: type[T]
49
+
50
+
51
class Model(BaseModel):
    """Configuration of one remote model: identity, limits and extra parameters."""

    cache_key_fields: ClassVar[tuple[str, ...]] = ("name", "extra_params")

    name: str
    alias: str | None = None
    concurrency_limit: int | None = None
    rate_limit: int | None = None
    extra_params: dict | None = None

    @property
    def request_name(self) -> str:
        """Return the name sent in requests: ``alias`` when truthy, else ``name``."""
        if self.alias:
            return self.alias
        return self.name
63
+
64
+
65
+ class AIClient(BaseModel):
66
+ model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
67
+ cache_key_fields: ClassVar[tuple[str, ...]] = ("model",)
68
+
69
+ model: Model
70
+ retry_stop: stop_base = stop_after_attempt(max_attempt_number=3)
71
+ retry_wait: wait_base = wait_exponential_jitter()
72
+ cache: Cache | None = None
73
+
74
+ _http_client: HttpxClient = PrivateAttr()
75
+ _limiter: AsyncLimiter = PrivateAttr()
76
+
77
+ @override
78
+ def model_post_init(self, context: Any) -> None:
79
+ self._http_client = HttpxClient(
80
+ timeout=600.0, limit=self.model.concurrency_limit or DEFAULT_LIMIT
81
+ )
82
+ self._limiter = AsyncLimiter(
83
+ concurrency_limit=self.model.concurrency_limit,
84
+ rate_limit=self.model.rate_limit,
85
+ rate_window=60.0,
86
+ )
87
+
88
+ @asynccontextmanager
89
+ async def lifespan(self) -> AsyncIterator[None]:
90
+ async with self._http_client.lifespan():
91
+ yield
92
+
93
+ def _should_retry(self, exc: BaseException, /) -> bool:
94
+ if isinstance(exc, RequestError):
95
+ return True
96
+ if isinstance(exc, HTTPStatusError):
97
+ status_code: int = exc.response.status_code
98
+ return status_code == 429 or status_code >= 500
99
+ if isinstance(exc, ValidationError):
100
+ return True
101
+ return False
102
+
103
+ def _wrap[T, **P](
104
+ self, func: Callable[P, Awaitable[T]], /
105
+ ) -> Callable[P, Awaitable[T]]:
106
+ wrapped_func: Callable[P, Awaitable[T]] = retry(
107
+ retry=retry_if_exception(predicate=self._should_retry),
108
+ stop=self.retry_stop,
109
+ wait=self.retry_wait,
110
+ reraise=True,
111
+ )(func)
112
+ return self.cache(wrapped_func) if self.cache else wrapped_func
113
+
114
+ async def _create_message_once(
115
+ self, system: str, prompt: str, extra: dict | None
116
+ ) -> str:
117
+ raise NotImplementedError
118
+
119
+ async def create_message(
120
+ self, system: str, prompt: str, extra: dict | None = None
121
+ ) -> str:
122
+ return await self._wrap(self._create_message_once)(
123
+ system=system, prompt=prompt, extra=extra
124
+ )
125
+
126
+ async def _create_structured_message_once[T: BaseModel](
127
+ self, system: str, prompt: str, schema: type[T], extra: dict | None
128
+ ) -> T:
129
+ raise NotImplementedError
130
+
131
+ async def create_structured_message[T: BaseModel](
132
+ self, system: str, prompt: str, schema: type[T], extra: dict | None = None
133
+ ) -> T:
134
+ return await self._wrap(self._create_structured_message_once)(
135
+ system=system, prompt=prompt, schema=schema, extra=extra
136
+ )
137
+
138
+ async def _create_tool_use_once[T: BaseModel](
139
+ self, system: str, prompt: str, tool: Tool[T], extra: dict | None
140
+ ) -> T:
141
+ raise NotImplementedError
142
+
143
+ async def create_tool_use[T: BaseModel](
144
+ self, system: str, prompt: str, tool: Tool[T], extra: dict | None = None
145
+ ) -> T:
146
+ return await self._wrap(self._create_tool_use_once)(
147
+ system=system, prompt=prompt, tool=tool, extra=extra
148
+ )
149
+
150
+ async def _create_image_once(
151
+ self, prompt: str, ratio: ImageRatio, size: ImageSize, extra: dict | None
152
+ ) -> bytes:
153
+ raise NotImplementedError
154
+
155
+ async def create_image(
156
+ self, prompt: str, ratio: ImageRatio, size: ImageSize, extra: dict | None = None
157
+ ) -> bytes:
158
+ return await self._wrap(self._create_image_once)(
159
+ prompt=prompt, ratio=ratio, size=size, extra=extra
160
+ )
basekit/ai/utils.py ADDED
@@ -0,0 +1,63 @@
1
+ from typing import cast
2
+
3
+ from basekit.ai.schema import NotSupportedError
4
+ from basekit.utils.misc import Any
5
+
6
+
7
+ def transform_schema(schema: Any, /, *, gemini: bool = False) -> dict:
8
+ if not isinstance(schema, dict):
9
+ raise NotSupportedError(f"仅支持字典类型模式,SCHEMA:{type(schema)}")
10
+
11
+ strict_schema: dict = {}
12
+
13
+ if "$ref" in schema:
14
+ strict_schema["$ref"] = schema["$ref"]
15
+ return strict_schema
16
+
17
+ if "$defs" in schema:
18
+ strict_schema["$defs"] = {
19
+ key: transform_schema(value, gemini=gemini)
20
+ for key, value in schema["$defs"].items()
21
+ }
22
+
23
+ if "title" in schema:
24
+ strict_schema["title"] = schema["title"]
25
+
26
+ if "description" in schema:
27
+ strict_schema["description"] = schema["description"]
28
+
29
+ if "enum" in schema:
30
+ strict_schema["enum"] = schema["enum"]
31
+
32
+ if "const" in schema:
33
+ strict_schema["const"] = schema["const"]
34
+
35
+ if "default" in schema:
36
+ strict_schema["default"] = schema["default"]
37
+
38
+ if "anyOf" in schema:
39
+ strict_schema["anyOf"] = [
40
+ transform_schema(cast(dict, item), gemini=gemini)
41
+ for item in schema["anyOf"]
42
+ ]
43
+
44
+ if "type" not in schema:
45
+ return strict_schema
46
+
47
+ strict_schema["type"] = schema["type"]
48
+
49
+ match strict_schema["type"]:
50
+ case "object":
51
+ strict_schema["properties"] = {
52
+ key: transform_schema(value, gemini=gemini)
53
+ for key, value in schema["properties"].items()
54
+ }
55
+ strict_schema["required"] = list(strict_schema["properties"].keys())
56
+ if not gemini:
57
+ strict_schema["additionalProperties"] = False
58
+ case "array":
59
+ strict_schema["items"] = transform_schema(schema["items"], gemini=gemini)
60
+ case "string" | "number" | "integer" | "boolean" | "null":
61
+ pass
62
+
63
+ return strict_schema
@@ -0,0 +1,90 @@
1
+ import hashlib
2
+ from collections.abc import AsyncIterator
3
+ from contextlib import asynccontextmanager
4
+ from pathlib import Path
5
+ from typing import override
6
+
7
+ from pydantic import PrivateAttr
8
+ from sqlalchemy import (
9
+ Column,
10
+ CursorResult,
11
+ Delete,
12
+ LargeBinary,
13
+ MetaData,
14
+ Row,
15
+ Select,
16
+ String,
17
+ Table,
18
+ delete,
19
+ select,
20
+ )
21
+ from sqlalchemy.dialects.sqlite import Insert, insert
22
+
23
+ from basekit.cache.schema import Cache
24
+ from basekit.database import DatabaseClient
25
+
26
+
27
class SQLiteCache(Cache):
    """``Cache`` implementation stored in a single SQLite table named ``cache``.

    Rows are addressed by the SHA-256 hex digest of the key. When ``path`` is
    not set, the database URL points at ``:memory:``.
    """

    path: Path | None = None

    _client: DatabaseClient = PrivateAttr()
    _metadata: MetaData = PrivateAttr()
    _table: Table = PrivateAttr()

    @override
    def model_post_init(self, context) -> None:
        """Normalise ``path`` and build the database client and table definition."""
        self.path = self.path.expanduser().resolve() if self.path else None
        path: str = self.path.as_posix() if self.path else ":memory:"
        url: str = f"sqlite+aiosqlite:///{path}"
        self._client = DatabaseClient(url=url)
        self._metadata = MetaData()
        self._table = Table(
            "cache",
            self._metadata,
            Column("sha256", String(length=64), nullable=False, primary_key=True),
            Column("key", LargeBinary(), nullable=False),
            Column("value", LargeBinary(), nullable=False),
        )

    @override
    @asynccontextmanager
    async def lifespan(self) -> AsyncIterator[None]:
        """Open the database client and create the schema before yielding."""
        async with self._client.lifespan():
            await self._client.create_schema(self._metadata)
            yield

    def _sha256(self, key: bytes, /) -> str:
        """Return the SHA-256 hex digest of ``key`` (the row's primary key)."""
        # Pass the data positionally: that is the documented call form, while
        # the ``string=`` keyword is an undocumented spelling of the parameter.
        return hashlib.sha256(key).hexdigest()

    @override
    async def set(self, key: bytes, value: bytes, /) -> None:
        """Insert ``value`` for ``key``, overwriting the row on a digest conflict."""
        async with self._client.connection() as connection:
            statement: Insert = (
                insert(table=self._table)
                .values(sha256=self._sha256(key), key=key, value=value)
                .on_conflict_do_update(
                    index_elements=[self._table.c.sha256],
                    set_={"key": key, "value": value},
                )
            )
            await connection.execute(statement=statement)
            await connection.commit()

    @override
    async def get(self, key: bytes, /) -> bytes | None:
        """Return the value stored for ``key``, or ``None`` when no row matches."""
        async with self._client.connection() as connection:
            statement: Select = select(self._table.c.value).where(
                self._table.c.sha256 == self._sha256(key)
            )
            result: CursorResult = await connection.execute(statement=statement)
            row: Row | None = result.one_or_none()
            return row.value if row else None

    @override
    async def delete(self, key: bytes, /) -> None:
        """Delete the row stored for ``key``, if any, and commit."""
        async with self._client.connection() as connection:
            statement: Delete = delete(table=self._table).where(
                self._table.c.sha256 == self._sha256(key)
            )
            await connection.execute(statement=statement)
            await connection.commit()