python-basekit 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- basekit/__init__.py +0 -0
- basekit/ai/clients/anthropic.py +192 -0
- basekit/ai/clients/dashscope.py +133 -0
- basekit/ai/clients/gemini.py +277 -0
- basekit/ai/clients/openai.py +384 -0
- basekit/ai/schema.py +160 -0
- basekit/ai/utils.py +63 -0
- basekit/cache/clients/sqlite.py +90 -0
- basekit/cache/schema.py +53 -0
- basekit/cache/utils.py +113 -0
- basekit/database.py +33 -0
- basekit/http/clients/curl_cffi.py +99 -0
- basekit/http/clients/httpx.py +100 -0
- basekit/http/schema.py +111 -0
- basekit/http/utils.py +31 -0
- basekit/limiter.py +179 -0
- basekit/py.typed +0 -0
- basekit/utils/batch.py +21 -0
- basekit/utils/console.py +26 -0
- basekit/utils/html.py +54 -0
- basekit/utils/jinja.py +20 -0
- basekit/utils/markdown.py +38 -0
- basekit/utils/mime.py +39 -0
- basekit/utils/misc.py +35 -0
- python_basekit-0.0.11.dist-info/METADATA +46 -0
- python_basekit-0.0.11.dist-info/RECORD +27 -0
- python_basekit-0.0.11.dist-info/WHEEL +4 -0
basekit/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from typing import ClassVar, override
|
|
3
|
+
|
|
4
|
+
import logfire
|
|
5
|
+
from httpx import Response
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from basekit.ai.schema import AIClient, Tool, Usage, ValidationError
|
|
9
|
+
from basekit.ai.utils import transform_schema
|
|
10
|
+
from basekit.utils.misc import recursive_merge
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AnthropicClient(AIClient):
    """Client for the Anthropic Messages API (``/v1/messages``).

    Implements the ``AIClient`` hooks for plain messages, structured
    (JSON-schema constrained) messages, and forced tool use.  Relies on
    ``self.model``, ``self._limiter`` and ``self._http_client`` supplied by
    the ``AIClient`` base — assumed to be the configured model descriptor,
    an async rate limiter, and an async HTTP client (TODO confirm against
    basekit.ai.schema).
    """

    # Fields (beyond the base ones) that participate in the response-cache key.
    cache_key_fields: ClassVar[tuple[str, ...]] = ("max_tokens",)

    api_key: str
    base_url: str = "https://api.anthropic.com"
    # Upper bound passed as the API's required ``max_tokens`` parameter.
    max_tokens: int = 16000

    def _headers(self) -> dict:
        """Return the standard request headers for every Messages call."""
        return {
            "X-Api-Key": self.api_key,
            "Anthropic-Version": "2023-06-01",
            "Content-Type": "application/json",
        }

    def _message_url(self) -> str:
        """Return the Messages endpoint URL."""
        return f"{self.base_url}/v1/messages"

    def _message_params(self, system: str, prompt: str, /) -> dict:
        """Build the base request body shared by all three request kinds."""
        return {
            "model": self.model.request_name,
            "system": system,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": self.max_tokens,
        }

    def _structured_message_params(self, schema: dict, /) -> dict:
        """Extra body fields that constrain the output to *schema*.

        NOTE(review): ``output_config.format`` — confirm this matches the
        Anthropic structured-output request shape for the targeted API version.
        """
        return {"output_config": {"format": {"type": "json_schema", "schema": schema}}}

    def _tool_use_params(self, tool: dict, /) -> dict:
        """Extra body fields that register *tool* and force the model to use it."""
        return {"tools": [tool], "tool_choice": {"type": "any"}}

    def _get_block(
        self, data: dict, stop_reason: str, filter_: Callable[[dict], bool], /
    ) -> dict:
        """Return the single content block matching *filter_*.

        Raises ValueError if the response stopped for a different reason or
        if exactly one matching block is not present.
        """
        if data["stop_reason"] != stop_reason:
            raise ValueError(f"STOP REASON:{data['stop_reason']}")

        blocks: list[dict] = [c for c in data["content"] if filter_(c)]
        if len(blocks) != 1:
            raise ValueError(f"LEN(BLOCKS):{len(blocks)}")

        return blocks[0]

    def _get_usage(self, data: dict, /) -> Usage:
        """Convert the response ``usage`` section into a ``Usage`` record.

        ``input_tokens`` is reported as the total of fresh, cache-creation and
        cache-read input tokens (the API reports them separately).
        """
        usage: dict = data["usage"]
        cache_creation_tokens: int = usage.get("cache_creation_input_tokens", 0)
        cache_read_tokens: int = usage.get("cache_read_input_tokens", 0)
        return Usage(
            input_tokens=usage["input_tokens"]
            + cache_creation_tokens
            + cache_read_tokens,
            cache_creation_tokens=cache_creation_tokens,
            cache_read_tokens=cache_read_tokens,
            output_tokens=usage["output_tokens"],
        )

    @override
    async def _create_message_once(
        self, system: str, prompt: str, extra: dict | None
    ) -> str:
        """Single (non-retried) plain-text completion.

        Returns the text of the sole ``text`` block; raises ValidationError
        (chained to the underlying error) if the response cannot be parsed.
        """
        with logfire.span("anthropic client | create message") as span:
            span.set_attribute(key="model", value=self.model)
            span.set_attribute(key="system", value=system)
            span.set_attribute(key="prompt", value=prompt)
            span.set_attribute(key="extra", value=extra)

            async with self._limiter:
                # NOTE(review): payload is passed via ``data=`` — assumes the
                # project HTTP wrapper serializes dicts as JSON (headers claim
                # application/json); confirm in basekit.http.clients.
                response: Response = await self._http_client.post(
                    self._message_url(),
                    headers=self._headers(),
                    data=recursive_merge(
                        self._message_params(system, prompt),
                        self.model.extra_params,
                        extra,
                    ),
                )

            try:
                data: dict = response.json()
                block: dict = self._get_block(
                    data, "end_turn", lambda x: x["type"] == "text"
                )
                message: str = block["text"]
                usage: Usage = self._get_usage(data)

                span.set_attribute(key="return", value=message)
                span.set_attribute(key="usage", value=usage)
                span.message = f"{span.message} -> {usage.format()} tokens"
                return message
            except Exception as exc:
                raise ValidationError("生成消息失败") from exc

    @override
    async def _create_structured_message_once[T: BaseModel](
        self, system: str, prompt: str, schema: type[T], extra: dict | None
    ) -> T:
        """Single completion constrained to *schema*; returns the parsed model.

        The text block is expected to contain JSON validating against
        *schema*; any failure is re-raised as ValidationError.
        """
        schema_name: str = schema.__name__
        with logfire.span(
            f"anthropic client | create structured message | {schema_name}"
        ) as span:
            # transform_schema adapts pydantic's JSON schema to the provider's
            # accepted subset (see basekit.ai.utils).
            schema_data: dict = transform_schema(schema.model_json_schema())

            span.set_attribute(key="model", value=self.model)
            span.set_attribute(key="system", value=system)
            span.set_attribute(key="prompt", value=prompt)
            span.set_attribute(key="schema", value=schema_data)
            span.set_attribute(key="extra", value=extra)

            async with self._limiter:
                response: Response = await self._http_client.post(
                    self._message_url(),
                    headers=self._headers(),
                    data=recursive_merge(
                        self._message_params(system, prompt),
                        self._structured_message_params(schema_data),
                        self.model.extra_params,
                        extra,
                    ),
                )

            try:
                data: dict = response.json()
                block: dict = self._get_block(
                    data, "end_turn", lambda x: x["type"] == "text"
                )
                result: T = schema.model_validate_json(json_data=block["text"])
                usage: Usage = self._get_usage(data)

                span.set_attribute(key="return", value=result)
                span.set_attribute(key="usage", value=usage)
                span.message = f"{span.message} -> {usage.format()} tokens"
                return result
            except Exception as exc:
                raise ValidationError("生成结构化消息失败") from exc

    @override
    async def _create_tool_use_once[T: BaseModel](
        self, system: str, prompt: str, tool: Tool[T], extra: dict | None
    ) -> T:
        """Single forced tool call; returns the tool input parsed as *T*.

        The model is forced (``tool_choice: any``) to emit a ``tool_use``
        block, whose ``input`` is validated against ``tool.input_schema``.
        """
        tool_name: str = tool.name
        with logfire.span(f"anthropic client | create tool use | {tool_name}") as span:
            tool_: dict = {
                "name": tool.name,
                "description": tool.description,
                "input_schema": transform_schema(tool.input_schema.model_json_schema()),
                # NOTE(review): confirm the Messages API accepts a ``strict``
                # field on tool definitions for this API version.
                "strict": True,
            }

            span.set_attribute(key="model", value=self.model)
            span.set_attribute(key="system", value=system)
            span.set_attribute(key="prompt", value=prompt)
            span.set_attribute(key="tool", value=tool_)
            span.set_attribute(key="extra", value=extra)

            async with self._limiter:
                response: Response = await self._http_client.post(
                    self._message_url(),
                    headers=self._headers(),
                    data=recursive_merge(
                        self._message_params(system, prompt),
                        self._tool_use_params(tool_),
                        self.model.extra_params,
                        extra,
                    ),
                )

            try:
                data: dict = response.json()
                block: dict = self._get_block(
                    data, "tool_use", lambda x: x["type"] == "tool_use"
                )
                result: T = tool.input_schema.model_validate(obj=block["input"])
                usage: Usage = self._get_usage(data)

                span.set_attribute(key="return", value=result)
                span.set_attribute(key="usage", value=usage)
                span.message = f"{span.message} -> {usage.format()} tokens"
                return result
            except Exception as exc:
                raise ValidationError("生成工具调用失败") from exc
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from typing import ClassVar, override
|
|
3
|
+
|
|
4
|
+
import logfire
|
|
5
|
+
from httpx import Response
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from basekit.ai.schema import (
|
|
9
|
+
AIClient,
|
|
10
|
+
ImageRatio,
|
|
11
|
+
ImageSize,
|
|
12
|
+
NotSupportedError,
|
|
13
|
+
ValidationError,
|
|
14
|
+
)
|
|
15
|
+
from basekit.utils.misc import recursive_merge
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ImageUsage(BaseModel):
    """Pixel dimensions reported for one generated image."""

    width: int
    height: int

    def format(self) -> str:
        """Render the dimensions as ``"<width>x<height>"``."""
        return "{}x{}".format(self.width, self.height)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DashScopeClient(AIClient):
    """Client for DashScope's multimodal image-generation endpoint.

    Implements the ``AIClient`` image hook only.  Relies on ``self.model``,
    ``self._limiter`` and ``self._http_client`` supplied by the ``AIClient``
    base — assumed to be the configured model descriptor, an async rate
    limiter, and an async HTTP client (TODO confirm against basekit.ai.schema).
    """

    # No extra fields participate in the response-cache key.
    cache_key_fields: ClassVar[tuple[str, ...]] = ()

    api_key: str
    base_url: str = "https://dashscope.aliyuncs.com/api/v1"

    def _headers(self) -> dict:
        """Return the standard bearer-auth request headers."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _image_url(self) -> str:
        """Return the multimodal generation endpoint URL."""
        return f"{self.base_url}/services/aigc/multimodal-generation/generation"

    def _image_resolution(self, ratio: ImageRatio, size: ImageSize, /) -> str:
        """Map (*ratio*, *size*) to the provider's ``"W*H"`` resolution string.

        The 1K base resolutions per aspect ratio are scaled linearly by the
        size multiplier (1K/2K/4K).  Raises NotSupportedError when the
        configured model does not support the requested size, or when the
        model is not a recognized image model.
        """
        base_resolution_map: dict[ImageRatio, tuple[int, int]] = {
            "9:16": (768, 1344),
            "2:3": (832, 1248),
            "3:4": (864, 1184),
            "4:5": (896, 1152),
            "1:1": (1024, 1024),
            "5:4": (1152, 896),
            "4:3": (1184, 864),
            "3:2": (1248, 832),
            "16:9": (1344, 768),
        }
        resolution_ratio_map: dict[ImageSize, int] = {"1K": 1, "2K": 2, "4K": 4}

        width: int = base_resolution_map[ratio][0] * resolution_ratio_map[size]
        height: int = base_resolution_map[ratio][1] * resolution_ratio_map[size]
        resolution: str = f"{width}*{height}"

        match self.model.name:
            case "qwen-image-2.0" | "qwen-image-2.0-pro":
                if size not in {"1K", "2K"}:
                    raise NotSupportedError(
                        f"模型 {self.model.name} 仅支持 1K 和 2K 分辨率"
                    )
                return resolution
            case _:
                raise NotSupportedError(f"模型 {self.model.name} 不合法")

    def _image_params(self, prompt: str, resolution: str, /) -> dict:
        """Build the request body for one text-to-image generation."""
        return {
            "model": self.model.request_name,
            "input": {"messages": [{"role": "user", "content": [{"text": prompt}]}]},
            "parameters": {"size": resolution},
        }

    def _get_block(self, data: dict, filter_: Callable[[dict], bool], /) -> dict:
        """Return the first content block of the sole choice matching *filter_*.

        Raises ValueError when there is not exactly one choice, the choice did
        not finish with ``stop``, or no block matches.  Unlike the other
        clients, multiple matching blocks are tolerated and the first wins.
        """
        choices: list = data["output"]["choices"]
        if len(choices) != 1:
            raise ValueError(f"LEN(CHOICES):{len(choices)}")

        choice: dict = choices[0]
        if choice["finish_reason"] != "stop":
            raise ValueError(f"FINISH REASON:{choice['finish_reason']}")

        blocks: list[dict] = [c for c in choice["message"]["content"] if filter_(c)]
        if not blocks:
            raise ValueError(f"LEN(BLOCKS):{len(blocks)}")

        return blocks[0]

    def _get_usage(self, data: dict, /) -> ImageUsage:
        """Extract the generated image's pixel dimensions from ``usage``."""
        usage: dict = data["usage"]
        return ImageUsage(width=usage["width"], height=usage["height"])

    @override
    async def _create_image_once(
        self, prompt: str, ratio: ImageRatio, size: ImageSize, extra: dict | None
    ) -> bytes:
        """Single (non-retried) image generation; returns the raw image bytes.

        The API responds with a URL, which is fetched with a second GET.
        Any parsing/fetching failure is re-raised as ValidationError.
        """
        with logfire.span("dashscope client | create image") as span:
            resolution: str = self._image_resolution(ratio, size)

            span.set_attribute(key="model", value=self.model)
            span.set_attribute(key="prompt", value=prompt)
            span.set_attribute(key="ratio", value=ratio)
            span.set_attribute(key="size", value=size)
            span.set_attribute(key="extra", value=extra)
            span.set_attribute(key="resolution", value=resolution)

            async with self._limiter:
                # NOTE(review): payload is passed via ``data=`` — assumes the
                # project HTTP wrapper serializes dicts as JSON (headers claim
                # application/json); confirm in basekit.http.clients.
                response: Response = await self._http_client.post(
                    self._image_url(),
                    headers=self._headers(),
                    data=recursive_merge(
                        self._image_params(prompt, resolution),
                        self.model.extra_params,
                        extra,
                    ),
                )

            try:
                data: dict = response.json()
                block: dict = self._get_block(data, lambda c: "image" in c)
                image_url: str = block["image"]
                # The generation response carries a URL, not inline bytes.
                image: bytes = (await self._http_client.get(image_url)).content
                usage: ImageUsage = self._get_usage(data)

                span.set_attribute(key="image_url", value=image_url)
                span.set_attribute(key="filesize", value=len(image))
                span.set_attribute(key="usage", value=usage)
                span.message = f"{span.message} -> {usage.format()} pixels"
                return image
            except Exception as exc:
                raise ValidationError("生成图像失败") from exc
|
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
from typing import ClassVar, override
|
|
4
|
+
|
|
5
|
+
import logfire
|
|
6
|
+
from httpx import Response
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
|
|
9
|
+
from basekit.ai.schema import (
|
|
10
|
+
AIClient,
|
|
11
|
+
ImageRatio,
|
|
12
|
+
ImageSize,
|
|
13
|
+
NotSupportedError,
|
|
14
|
+
Tool,
|
|
15
|
+
Usage,
|
|
16
|
+
ValidationError,
|
|
17
|
+
)
|
|
18
|
+
from basekit.ai.utils import transform_schema
|
|
19
|
+
from basekit.utils.misc import recursive_merge
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class GeminiClient(AIClient):
    """Client for the Google Gemini ``generateContent`` API.

    Implements the ``AIClient`` hooks for plain messages, structured
    (JSON-schema constrained) messages, forced function calls, and image
    generation.  Relies on ``self.model``, ``self._limiter`` and
    ``self._http_client`` supplied by the ``AIClient`` base — assumed to be
    the configured model descriptor, an async rate limiter, and an async
    HTTP client (TODO confirm against basekit.ai.schema).
    """

    # No extra fields participate in the response-cache key.
    cache_key_fields: ClassVar[tuple[str, ...]] = ()

    api_key: str
    base_url: str = "https://generativelanguage.googleapis.com"

    def _headers(self) -> dict:
        """Return the standard request headers (API key via X-Goog-Api-Key)."""
        return {"X-Goog-Api-Key": self.api_key, "Content-Type": "application/json"}

    def _message_url(self) -> str:
        """Return the model-specific ``generateContent`` endpoint URL."""
        return (
            f"{self.base_url}/v1beta/models/{self.model.request_name}:generateContent"
        )

    def _message_params(self, system: str, prompt: str, /) -> dict:
        """Build the base text-generation request body (thoughts included)."""
        return {
            "systemInstruction": {"parts": [{"text": system}]},
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {
                "responseModalities": ["TEXT"],
                "thinkingConfig": {"includeThoughts": True},
            },
        }

    def _structured_message_params(self, schema: dict, /) -> dict:
        """Extra generationConfig fields constraining the output to *schema*."""
        return {
            "generationConfig": {
                "responseMimeType": "application/json",
                "responseSchema": schema,
            },
        }

    def _tool_use_params(self, tool: dict, /) -> dict:
        """Extra body fields declaring *tool* and forcing a function call."""
        return {
            "tools": [{"functionDeclarations": [tool]}],
            "toolConfig": {"functionCallingConfig": {"mode": "ANY"}},
        }

    def _image_url(self) -> str:
        """Return the image-generation endpoint URL (same as ``_message_url``)."""
        return (
            f"{self.base_url}/v1beta/models/{self.model.request_name}:generateContent"
        )

    def _image_config(self, ratio: ImageRatio, size: ImageSize, /) -> dict:
        """Map (*ratio*, *size*) to the model-specific ``imageConfig`` dict.

        Raises NotSupportedError when the configured model does not support
        the requested size, or when the model is not a recognized image model.
        """
        match self.model.name:
            case "gemini-3.1-flash-image-preview":
                return {"aspectRatio": ratio, "imageSize": size}
            case "gemini-3-pro-image-preview":
                return {"aspectRatio": ratio, "imageSize": size}
            case "gemini-2.5-flash-image":
                # This model has a fixed resolution; only the ratio is tunable.
                if size != "1K":
                    raise NotSupportedError(f"模型 {self.model.name} 仅支持 1K 分辨率")
                return {"aspectRatio": ratio}
            case _:
                raise NotSupportedError(f"模型 {self.model.name} 不合法")

    def _image_params(self, prompt: str, config: dict, /) -> dict:
        """Build the request body for one text-to-image generation."""
        return {
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {
                "responseModalities": ["IMAGE"],
                "thinkingConfig": {"includeThoughts": True},
                "imageConfig": config,
            },
        }

    def _get_block(self, data: dict, filter_: Callable[[dict], bool], /) -> dict:
        """Return the single non-thought part of the sole candidate.

        Raises ValueError when there is not exactly one candidate, the
        candidate did not finish with ``STOP``, or exactly one matching
        part (thought parts excluded) is not present.
        """
        candidates: list = data["candidates"]
        if len(candidates) != 1:
            raise ValueError(f"LEN(CANDIDATES):{len(candidates)}")

        candidate: dict = candidates[0]
        if candidate["finishReason"] != "STOP":
            raise ValueError(f"FINISH REASON:{candidate['finishReason']}")

        blocks: list[dict] = [
            p
            for p in candidate["content"]["parts"]
            # Thought parts are reasoning traces, not the answer — skip them.
            if filter_(p) and "thought" not in p
        ]
        if len(blocks) != 1:
            raise ValueError(f"LEN(BLOCKS):{len(blocks)}")

        return blocks[0]

    def _get_usage(self, data: dict, /) -> Usage:
        """Convert ``usageMetadata`` into a ``Usage`` record.

        Thought tokens are folded into ``output_tokens``; cache-creation
        accounting is not reported by this API, so it is fixed at 0.
        """
        metadata: dict = data["usageMetadata"]
        return Usage(
            input_tokens=metadata["promptTokenCount"],
            cache_creation_tokens=0,
            cache_read_tokens=metadata.get("cachedContentTokenCount", 0),
            output_tokens=metadata.get("thoughtsTokenCount", 0)
            + metadata["candidatesTokenCount"],
        )

    @override
    async def _create_message_once(
        self, system: str, prompt: str, extra: dict | None
    ) -> str:
        """Single (non-retried) plain-text completion.

        Returns the text of the sole non-thought ``text`` part; raises
        ValidationError (chained) if the response cannot be parsed.
        """
        with logfire.span("gemini client | create message") as span:
            span.set_attribute(key="model", value=self.model)
            span.set_attribute(key="system", value=system)
            span.set_attribute(key="prompt", value=prompt)
            span.set_attribute(key="extra", value=extra)

            async with self._limiter:
                # NOTE(review): payload is passed via ``data=`` — assumes the
                # project HTTP wrapper serializes dicts as JSON (headers claim
                # application/json); confirm in basekit.http.clients.
                response: Response = await self._http_client.post(
                    self._message_url(),
                    headers=self._headers(),
                    data=recursive_merge(
                        self._message_params(system, prompt),
                        self.model.extra_params,
                        extra,
                    ),
                )

            try:
                data: dict = response.json()
                block: dict = self._get_block(data, lambda x: "text" in x)
                message: str = block["text"]
                usage: Usage = self._get_usage(data)

                span.set_attribute(key="return", value=message)
                span.set_attribute(key="usage", value=usage)
                span.message = f"{span.message} -> {usage.format()} tokens"
                return message
            except Exception as exc:
                raise ValidationError("生成消息失败") from exc

    @override
    async def _create_structured_message_once[T: BaseModel](
        self, system: str, prompt: str, schema: type[T], extra: dict | None
    ) -> T:
        """Single completion constrained to *schema*; returns the parsed model.

        The text part is expected to contain JSON validating against
        *schema*; any failure is re-raised as ValidationError.
        """
        schema_name: str = schema.__name__
        with logfire.span(
            f"gemini client | create structured message | {schema_name}"
        ) as span:
            # transform_schema(..., gemini=True) adapts pydantic's JSON schema
            # to Gemini's accepted subset (see basekit.ai.utils).
            schema_data: dict = transform_schema(
                schema.model_json_schema(), gemini=True
            )

            span.set_attribute(key="model", value=self.model)
            span.set_attribute(key="system", value=system)
            span.set_attribute(key="prompt", value=prompt)
            span.set_attribute(key="schema", value=schema_data)
            span.set_attribute(key="extra", value=extra)

            async with self._limiter:
                response: Response = await self._http_client.post(
                    self._message_url(),
                    headers=self._headers(),
                    data=recursive_merge(
                        self._message_params(system, prompt),
                        self._structured_message_params(schema_data),
                        self.model.extra_params,
                        extra,
                    ),
                )

            try:
                data: dict = response.json()
                block: dict = self._get_block(data, lambda x: "text" in x)
                result: T = schema.model_validate_json(json_data=block["text"])
                usage: Usage = self._get_usage(data)

                span.set_attribute(key="return", value=result)
                span.set_attribute(key="usage", value=usage)
                span.message = f"{span.message} -> {usage.format()} tokens"
                return result
            except Exception as exc:
                raise ValidationError("生成结构化消息失败") from exc

    @override
    async def _create_tool_use_once[T: BaseModel](
        self, system: str, prompt: str, tool: Tool[T], extra: dict | None
    ) -> T:
        """Single forced function call; returns the call args parsed as *T*.

        The model is forced (function-calling mode ``ANY``) to emit a
        ``functionCall`` part, whose ``args`` are validated against
        ``tool.input_schema``.
        """
        tool_name: str = tool.name
        with logfire.span(f"gemini client | create tool use | {tool_name}") as span:
            tool_: dict = {
                "name": tool.name,
                "description": tool.description,
                "parameters": transform_schema(
                    tool.input_schema.model_json_schema(), gemini=True
                ),
            }

            span.set_attribute(key="model", value=self.model)
            span.set_attribute(key="system", value=system)
            span.set_attribute(key="prompt", value=prompt)
            span.set_attribute(key="tool", value=tool_)
            span.set_attribute(key="extra", value=extra)

            async with self._limiter:
                response: Response = await self._http_client.post(
                    self._message_url(),
                    headers=self._headers(),
                    data=recursive_merge(
                        self._message_params(system, prompt),
                        self._tool_use_params(tool_),
                        self.model.extra_params,
                        extra,
                    ),
                )

            try:
                data: dict = response.json()
                block: dict = self._get_block(data, lambda x: "functionCall" in x)
                result: T = tool.input_schema.model_validate(
                    obj=block["functionCall"]["args"]
                )
                usage: Usage = self._get_usage(data)

                span.set_attribute(key="return", value=result)
                span.set_attribute(key="usage", value=usage)
                span.message = f"{span.message} -> {usage.format()} tokens"
                return result
            except Exception as exc:
                raise ValidationError("生成工具调用失败") from exc

    @override
    async def _create_image_once(
        self, prompt: str, ratio: ImageRatio, size: ImageSize, extra: dict | None
    ) -> bytes:
        """Single (non-retried) image generation; returns decoded image bytes.

        The image arrives inline as base64 in an ``inlineData`` part.  Any
        parsing failure is re-raised as ValidationError.
        """
        with logfire.span("gemini client | create image") as span:
            config: dict = self._image_config(ratio, size)

            span.set_attribute(key="model", value=self.model)
            span.set_attribute(key="prompt", value=prompt)
            span.set_attribute(key="ratio", value=ratio)
            span.set_attribute(key="size", value=size)
            span.set_attribute(key="extra", value=extra)
            span.set_attribute(key="config", value=config)

            async with self._limiter:
                response: Response = await self._http_client.post(
                    self._image_url(),
                    headers=self._headers(),
                    data=recursive_merge(
                        self._image_params(prompt, config),
                        self.model.extra_params,
                        extra,
                    ),
                )

            try:
                data: dict = response.json()
                block: dict = self._get_block(data, lambda x: "inlineData" in x)
                image: bytes = base64.b64decode(s=block["inlineData"]["data"])
                usage: Usage = self._get_usage(data)

                span.set_attribute(key="filesize", value=len(image))
                span.set_attribute(key="usage", value=usage)
                span.message = f"{span.message} -> {usage.format()} tokens"
                return image
            except Exception as exc:
                raise ValidationError("生成图像失败") from exc
|