python-basekit 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
basekit/__init__.py ADDED
File without changes
@@ -0,0 +1,192 @@
1
+ from collections.abc import Callable
2
+ from typing import ClassVar, override
3
+
4
+ import logfire
5
+ from httpx import Response
6
+ from pydantic import BaseModel
7
+
8
+ from basekit.ai.schema import AIClient, Tool, Usage, ValidationError
9
+ from basekit.ai.utils import transform_schema
10
+ from basekit.utils.misc import recursive_merge
11
+
12
+
13
+ class AnthropicClient(AIClient):
14
+ cache_key_fields: ClassVar[tuple[str, ...]] = ("max_tokens",)
15
+
16
+ api_key: str
17
+ base_url: str = "https://api.anthropic.com"
18
+ max_tokens: int = 16000
19
+
20
+ def _headers(self) -> dict:
21
+ return {
22
+ "X-Api-Key": self.api_key,
23
+ "Anthropic-Version": "2023-06-01",
24
+ "Content-Type": "application/json",
25
+ }
26
+
27
+ def _message_url(self) -> str:
28
+ return f"{self.base_url}/v1/messages"
29
+
30
+ def _message_params(self, system: str, prompt: str, /) -> dict:
31
+ return {
32
+ "model": self.model.request_name,
33
+ "system": system,
34
+ "messages": [{"role": "user", "content": prompt}],
35
+ "max_tokens": self.max_tokens,
36
+ }
37
+
38
+ def _structured_message_params(self, schema: dict, /) -> dict:
39
+ return {"output_config": {"format": {"type": "json_schema", "schema": schema}}}
40
+
41
+ def _tool_use_params(self, tool: dict, /) -> dict:
42
+ return {"tools": [tool], "tool_choice": {"type": "any"}}
43
+
44
+ def _get_block(
45
+ self, data: dict, stop_reason: str, filter_: Callable[[dict], bool], /
46
+ ) -> dict:
47
+ if data["stop_reason"] != stop_reason:
48
+ raise ValueError(f"STOP REASON:{data['stop_reason']}")
49
+
50
+ blocks: list[dict] = [c for c in data["content"] if filter_(c)]
51
+ if len(blocks) != 1:
52
+ raise ValueError(f"LEN(BLOCKS):{len(blocks)}")
53
+
54
+ return blocks[0]
55
+
56
+ def _get_usage(self, data: dict, /) -> Usage:
57
+ usage: dict = data["usage"]
58
+ cache_creation_tokens: int = usage.get("cache_creation_input_tokens", 0)
59
+ cache_read_tokens: int = usage.get("cache_read_input_tokens", 0)
60
+ return Usage(
61
+ input_tokens=usage["input_tokens"]
62
+ + cache_creation_tokens
63
+ + cache_read_tokens,
64
+ cache_creation_tokens=cache_creation_tokens,
65
+ cache_read_tokens=cache_read_tokens,
66
+ output_tokens=usage["output_tokens"],
67
+ )
68
+
69
+ @override
70
+ async def _create_message_once(
71
+ self, system: str, prompt: str, extra: dict | None
72
+ ) -> str:
73
+ with logfire.span("anthropic client | create message") as span:
74
+ span.set_attribute(key="model", value=self.model)
75
+ span.set_attribute(key="system", value=system)
76
+ span.set_attribute(key="prompt", value=prompt)
77
+ span.set_attribute(key="extra", value=extra)
78
+
79
+ async with self._limiter:
80
+ response: Response = await self._http_client.post(
81
+ self._message_url(),
82
+ headers=self._headers(),
83
+ data=recursive_merge(
84
+ self._message_params(system, prompt),
85
+ self.model.extra_params,
86
+ extra,
87
+ ),
88
+ )
89
+
90
+ try:
91
+ data: dict = response.json()
92
+ block: dict = self._get_block(
93
+ data, "end_turn", lambda x: x["type"] == "text"
94
+ )
95
+ message: str = block["text"]
96
+ usage: Usage = self._get_usage(data)
97
+
98
+ span.set_attribute(key="return", value=message)
99
+ span.set_attribute(key="usage", value=usage)
100
+ span.message = f"{span.message} -> {usage.format()} tokens"
101
+ return message
102
+ except Exception as exc:
103
+ raise ValidationError("生成消息失败") from exc
104
+
105
+ @override
106
+ async def _create_structured_message_once[T: BaseModel](
107
+ self, system: str, prompt: str, schema: type[T], extra: dict | None
108
+ ) -> T:
109
+ schema_name: str = schema.__name__
110
+ with logfire.span(
111
+ f"anthropic client | create structured message | {schema_name}"
112
+ ) as span:
113
+ schema_data: dict = transform_schema(schema.model_json_schema())
114
+
115
+ span.set_attribute(key="model", value=self.model)
116
+ span.set_attribute(key="system", value=system)
117
+ span.set_attribute(key="prompt", value=prompt)
118
+ span.set_attribute(key="schema", value=schema_data)
119
+ span.set_attribute(key="extra", value=extra)
120
+
121
+ async with self._limiter:
122
+ response: Response = await self._http_client.post(
123
+ self._message_url(),
124
+ headers=self._headers(),
125
+ data=recursive_merge(
126
+ self._message_params(system, prompt),
127
+ self._structured_message_params(schema_data),
128
+ self.model.extra_params,
129
+ extra,
130
+ ),
131
+ )
132
+
133
+ try:
134
+ data: dict = response.json()
135
+ block: dict = self._get_block(
136
+ data, "end_turn", lambda x: x["type"] == "text"
137
+ )
138
+ result: T = schema.model_validate_json(json_data=block["text"])
139
+ usage: Usage = self._get_usage(data)
140
+
141
+ span.set_attribute(key="return", value=result)
142
+ span.set_attribute(key="usage", value=usage)
143
+ span.message = f"{span.message} -> {usage.format()} tokens"
144
+ return result
145
+ except Exception as exc:
146
+ raise ValidationError("生成结构化消息失败") from exc
147
+
148
+ @override
149
+ async def _create_tool_use_once[T: BaseModel](
150
+ self, system: str, prompt: str, tool: Tool[T], extra: dict | None
151
+ ) -> T:
152
+ tool_name: str = tool.name
153
+ with logfire.span(f"anthropic client | create tool use | {tool_name}") as span:
154
+ tool_: dict = {
155
+ "name": tool.name,
156
+ "description": tool.description,
157
+ "input_schema": transform_schema(tool.input_schema.model_json_schema()),
158
+ "strict": True,
159
+ }
160
+
161
+ span.set_attribute(key="model", value=self.model)
162
+ span.set_attribute(key="system", value=system)
163
+ span.set_attribute(key="prompt", value=prompt)
164
+ span.set_attribute(key="tool", value=tool_)
165
+ span.set_attribute(key="extra", value=extra)
166
+
167
+ async with self._limiter:
168
+ response: Response = await self._http_client.post(
169
+ self._message_url(),
170
+ headers=self._headers(),
171
+ data=recursive_merge(
172
+ self._message_params(system, prompt),
173
+ self._tool_use_params(tool_),
174
+ self.model.extra_params,
175
+ extra,
176
+ ),
177
+ )
178
+
179
+ try:
180
+ data: dict = response.json()
181
+ block: dict = self._get_block(
182
+ data, "tool_use", lambda x: x["type"] == "tool_use"
183
+ )
184
+ result: T = tool.input_schema.model_validate(obj=block["input"])
185
+ usage: Usage = self._get_usage(data)
186
+
187
+ span.set_attribute(key="return", value=result)
188
+ span.set_attribute(key="usage", value=usage)
189
+ span.message = f"{span.message} -> {usage.format()} tokens"
190
+ return result
191
+ except Exception as exc:
192
+ raise ValidationError("生成工具调用失败") from exc
@@ -0,0 +1,133 @@
1
+ from collections.abc import Callable
2
+ from typing import ClassVar, override
3
+
4
+ import logfire
5
+ from httpx import Response
6
+ from pydantic import BaseModel
7
+
8
+ from basekit.ai.schema import (
9
+ AIClient,
10
+ ImageRatio,
11
+ ImageSize,
12
+ NotSupportedError,
13
+ ValidationError,
14
+ )
15
+ from basekit.utils.misc import recursive_merge
16
+
17
+
18
class ImageUsage(BaseModel):
    """Pixel dimensions of one generated image, used for usage reporting."""

    width: int
    height: int

    def format(self) -> str:
        """Render the dimensions as ``"<width>x<height>"`` (e.g. ``"1024x768"``)."""
        return "x".join(str(side) for side in (self.width, self.height))
24
+
25
+
26
+ class DashScopeClient(AIClient):
27
+ cache_key_fields: ClassVar[tuple[str, ...]] = ()
28
+
29
+ api_key: str
30
+ base_url: str = "https://dashscope.aliyuncs.com/api/v1"
31
+
32
+ def _headers(self) -> dict:
33
+ return {
34
+ "Authorization": f"Bearer {self.api_key}",
35
+ "Content-Type": "application/json",
36
+ }
37
+
38
+ def _image_url(self) -> str:
39
+ return f"{self.base_url}/services/aigc/multimodal-generation/generation"
40
+
41
+ def _image_resolution(self, ratio: ImageRatio, size: ImageSize, /) -> str:
42
+ base_resolution_map: dict[ImageRatio, tuple[int, int]] = {
43
+ "9:16": (768, 1344),
44
+ "2:3": (832, 1248),
45
+ "3:4": (864, 1184),
46
+ "4:5": (896, 1152),
47
+ "1:1": (1024, 1024),
48
+ "5:4": (1152, 896),
49
+ "4:3": (1184, 864),
50
+ "3:2": (1248, 832),
51
+ "16:9": (1344, 768),
52
+ }
53
+ resolution_ratio_map: dict[ImageSize, int] = {"1K": 1, "2K": 2, "4K": 4}
54
+
55
+ width: int = base_resolution_map[ratio][0] * resolution_ratio_map[size]
56
+ height: int = base_resolution_map[ratio][1] * resolution_ratio_map[size]
57
+ resolution: str = f"{width}*{height}"
58
+
59
+ match self.model.name:
60
+ case "qwen-image-2.0" | "qwen-image-2.0-pro":
61
+ if size not in {"1K", "2K"}:
62
+ raise NotSupportedError(
63
+ f"模型 {self.model.name} 仅支持 1K 和 2K 分辨率"
64
+ )
65
+ return resolution
66
+ case _:
67
+ raise NotSupportedError(f"模型 {self.model.name} 不合法")
68
+
69
+ def _image_params(self, prompt: str, resolution: str, /) -> dict:
70
+ return {
71
+ "model": self.model.request_name,
72
+ "input": {"messages": [{"role": "user", "content": [{"text": prompt}]}]},
73
+ "parameters": {"size": resolution},
74
+ }
75
+
76
+ def _get_block(self, data: dict, filter_: Callable[[dict], bool], /) -> dict:
77
+ choices: list = data["output"]["choices"]
78
+ if len(choices) != 1:
79
+ raise ValueError(f"LEN(CHOICES):{len(choices)}")
80
+
81
+ choice: dict = choices[0]
82
+ if choice["finish_reason"] != "stop":
83
+ raise ValueError(f"FINISH REASON:{choice['finish_reason']}")
84
+
85
+ blocks: list[dict] = [c for c in choice["message"]["content"] if filter_(c)]
86
+ if not blocks:
87
+ raise ValueError(f"LEN(BLOCKS):{len(blocks)}")
88
+
89
+ return blocks[0]
90
+
91
+ def _get_usage(self, data: dict, /) -> ImageUsage:
92
+ usage: dict = data["usage"]
93
+ return ImageUsage(width=usage["width"], height=usage["height"])
94
+
95
+ @override
96
+ async def _create_image_once(
97
+ self, prompt: str, ratio: ImageRatio, size: ImageSize, extra: dict | None
98
+ ) -> bytes:
99
+ with logfire.span("dashscope client | create image") as span:
100
+ resolution: str = self._image_resolution(ratio, size)
101
+
102
+ span.set_attribute(key="model", value=self.model)
103
+ span.set_attribute(key="prompt", value=prompt)
104
+ span.set_attribute(key="ratio", value=ratio)
105
+ span.set_attribute(key="size", value=size)
106
+ span.set_attribute(key="extra", value=extra)
107
+ span.set_attribute(key="resolution", value=resolution)
108
+
109
+ async with self._limiter:
110
+ response: Response = await self._http_client.post(
111
+ self._image_url(),
112
+ headers=self._headers(),
113
+ data=recursive_merge(
114
+ self._image_params(prompt, resolution),
115
+ self.model.extra_params,
116
+ extra,
117
+ ),
118
+ )
119
+
120
+ try:
121
+ data: dict = response.json()
122
+ block: dict = self._get_block(data, lambda c: "image" in c)
123
+ image_url: str = block["image"]
124
+ image: bytes = (await self._http_client.get(image_url)).content
125
+ usage: ImageUsage = self._get_usage(data)
126
+
127
+ span.set_attribute(key="image_url", value=image_url)
128
+ span.set_attribute(key="filesize", value=len(image))
129
+ span.set_attribute(key="usage", value=usage)
130
+ span.message = f"{span.message} -> {usage.format()} pixels"
131
+ return image
132
+ except Exception as exc:
133
+ raise ValidationError("生成图像失败") from exc
@@ -0,0 +1,277 @@
1
+ import base64
2
+ from collections.abc import Callable
3
+ from typing import ClassVar, override
4
+
5
+ import logfire
6
+ from httpx import Response
7
+ from pydantic import BaseModel
8
+
9
+ from basekit.ai.schema import (
10
+ AIClient,
11
+ ImageRatio,
12
+ ImageSize,
13
+ NotSupportedError,
14
+ Tool,
15
+ Usage,
16
+ ValidationError,
17
+ )
18
+ from basekit.ai.utils import transform_schema
19
+ from basekit.utils.misc import recursive_merge
20
+
21
+
22
+ class GeminiClient(AIClient):
23
+ cache_key_fields: ClassVar[tuple[str, ...]] = ()
24
+
25
+ api_key: str
26
+ base_url: str = "https://generativelanguage.googleapis.com"
27
+
28
+ def _headers(self) -> dict:
29
+ return {"X-Goog-Api-Key": self.api_key, "Content-Type": "application/json"}
30
+
31
+ def _message_url(self) -> str:
32
+ return (
33
+ f"{self.base_url}/v1beta/models/{self.model.request_name}:generateContent"
34
+ )
35
+
36
+ def _message_params(self, system: str, prompt: str, /) -> dict:
37
+ return {
38
+ "systemInstruction": {"parts": [{"text": system}]},
39
+ "contents": [{"parts": [{"text": prompt}]}],
40
+ "generationConfig": {
41
+ "responseModalities": ["TEXT"],
42
+ "thinkingConfig": {"includeThoughts": True},
43
+ },
44
+ }
45
+
46
+ def _structured_message_params(self, schema: dict, /) -> dict:
47
+ return {
48
+ "generationConfig": {
49
+ "responseMimeType": "application/json",
50
+ "responseSchema": schema,
51
+ },
52
+ }
53
+
54
+ def _tool_use_params(self, tool: dict, /) -> dict:
55
+ return {
56
+ "tools": [{"functionDeclarations": [tool]}],
57
+ "toolConfig": {"functionCallingConfig": {"mode": "ANY"}},
58
+ }
59
+
60
+ def _image_url(self) -> str:
61
+ return (
62
+ f"{self.base_url}/v1beta/models/{self.model.request_name}:generateContent"
63
+ )
64
+
65
+ def _image_config(self, ratio: ImageRatio, size: ImageSize, /) -> dict:
66
+ match self.model.name:
67
+ case "gemini-3.1-flash-image-preview":
68
+ return {"aspectRatio": ratio, "imageSize": size}
69
+ case "gemini-3-pro-image-preview":
70
+ return {"aspectRatio": ratio, "imageSize": size}
71
+ case "gemini-2.5-flash-image":
72
+ if size != "1K":
73
+ raise NotSupportedError(f"模型 {self.model.name} 仅支持 1K 分辨率")
74
+ return {"aspectRatio": ratio}
75
+ case _:
76
+ raise NotSupportedError(f"模型 {self.model.name} 不合法")
77
+
78
+ def _image_params(self, prompt: str, config: dict, /) -> dict:
79
+ return {
80
+ "contents": [{"parts": [{"text": prompt}]}],
81
+ "generationConfig": {
82
+ "responseModalities": ["IMAGE"],
83
+ "thinkingConfig": {"includeThoughts": True},
84
+ "imageConfig": config,
85
+ },
86
+ }
87
+
88
+ def _get_block(self, data: dict, filter_: Callable[[dict], bool], /) -> dict:
89
+ candidates: list = data["candidates"]
90
+ if len(candidates) != 1:
91
+ raise ValueError(f"LEN(CANDIDATES):{len(candidates)}")
92
+
93
+ candidate: dict = candidates[0]
94
+ if candidate["finishReason"] != "STOP":
95
+ raise ValueError(f"FINISH REASON:{candidate['finishReason']}")
96
+
97
+ blocks: list[dict] = [
98
+ p
99
+ for p in candidate["content"]["parts"]
100
+ if filter_(p) and "thought" not in p
101
+ ]
102
+ if len(blocks) != 1:
103
+ raise ValueError(f"LEN(BLOCKS):{len(blocks)}")
104
+
105
+ return blocks[0]
106
+
107
+ def _get_usage(self, data: dict, /) -> Usage:
108
+ metadata: dict = data["usageMetadata"]
109
+ return Usage(
110
+ input_tokens=metadata["promptTokenCount"],
111
+ cache_creation_tokens=0,
112
+ cache_read_tokens=metadata.get("cachedContentTokenCount", 0),
113
+ output_tokens=metadata.get("thoughtsTokenCount", 0)
114
+ + metadata["candidatesTokenCount"],
115
+ )
116
+
117
+ @override
118
+ async def _create_message_once(
119
+ self, system: str, prompt: str, extra: dict | None
120
+ ) -> str:
121
+ with logfire.span("gemini client | create message") as span:
122
+ span.set_attribute(key="model", value=self.model)
123
+ span.set_attribute(key="system", value=system)
124
+ span.set_attribute(key="prompt", value=prompt)
125
+ span.set_attribute(key="extra", value=extra)
126
+
127
+ async with self._limiter:
128
+ response: Response = await self._http_client.post(
129
+ self._message_url(),
130
+ headers=self._headers(),
131
+ data=recursive_merge(
132
+ self._message_params(system, prompt),
133
+ self.model.extra_params,
134
+ extra,
135
+ ),
136
+ )
137
+
138
+ try:
139
+ data: dict = response.json()
140
+ block: dict = self._get_block(data, lambda x: "text" in x)
141
+ message: str = block["text"]
142
+ usage: Usage = self._get_usage(data)
143
+
144
+ span.set_attribute(key="return", value=message)
145
+ span.set_attribute(key="usage", value=usage)
146
+ span.message = f"{span.message} -> {usage.format()} tokens"
147
+ return message
148
+ except Exception as exc:
149
+ raise ValidationError("生成消息失败") from exc
150
+
151
+ @override
152
+ async def _create_structured_message_once[T: BaseModel](
153
+ self, system: str, prompt: str, schema: type[T], extra: dict | None
154
+ ) -> T:
155
+ schema_name: str = schema.__name__
156
+ with logfire.span(
157
+ f"gemini client | create structured message | {schema_name}"
158
+ ) as span:
159
+ schema_data: dict = transform_schema(
160
+ schema.model_json_schema(), gemini=True
161
+ )
162
+
163
+ span.set_attribute(key="model", value=self.model)
164
+ span.set_attribute(key="system", value=system)
165
+ span.set_attribute(key="prompt", value=prompt)
166
+ span.set_attribute(key="schema", value=schema_data)
167
+ span.set_attribute(key="extra", value=extra)
168
+
169
+ async with self._limiter:
170
+ response: Response = await self._http_client.post(
171
+ self._message_url(),
172
+ headers=self._headers(),
173
+ data=recursive_merge(
174
+ self._message_params(system, prompt),
175
+ self._structured_message_params(schema_data),
176
+ self.model.extra_params,
177
+ extra,
178
+ ),
179
+ )
180
+
181
+ try:
182
+ data: dict = response.json()
183
+ block: dict = self._get_block(data, lambda x: "text" in x)
184
+ result: T = schema.model_validate_json(json_data=block["text"])
185
+ usage: Usage = self._get_usage(data)
186
+
187
+ span.set_attribute(key="return", value=result)
188
+ span.set_attribute(key="usage", value=usage)
189
+ span.message = f"{span.message} -> {usage.format()} tokens"
190
+ return result
191
+ except Exception as exc:
192
+ raise ValidationError("生成结构化消息失败") from exc
193
+
194
+ @override
195
+ async def _create_tool_use_once[T: BaseModel](
196
+ self, system: str, prompt: str, tool: Tool[T], extra: dict | None
197
+ ) -> T:
198
+ tool_name: str = tool.name
199
+ with logfire.span(f"gemini client | create tool use | {tool_name}") as span:
200
+ tool_: dict = {
201
+ "name": tool.name,
202
+ "description": tool.description,
203
+ "parameters": transform_schema(
204
+ tool.input_schema.model_json_schema(), gemini=True
205
+ ),
206
+ }
207
+
208
+ span.set_attribute(key="model", value=self.model)
209
+ span.set_attribute(key="system", value=system)
210
+ span.set_attribute(key="prompt", value=prompt)
211
+ span.set_attribute(key="tool", value=tool_)
212
+ span.set_attribute(key="extra", value=extra)
213
+
214
+ async with self._limiter:
215
+ response: Response = await self._http_client.post(
216
+ self._message_url(),
217
+ headers=self._headers(),
218
+ data=recursive_merge(
219
+ self._message_params(system, prompt),
220
+ self._tool_use_params(tool_),
221
+ self.model.extra_params,
222
+ extra,
223
+ ),
224
+ )
225
+
226
+ try:
227
+ data: dict = response.json()
228
+ block: dict = self._get_block(data, lambda x: "functionCall" in x)
229
+ result: T = tool.input_schema.model_validate(
230
+ obj=block["functionCall"]["args"]
231
+ )
232
+ usage: Usage = self._get_usage(data)
233
+
234
+ span.set_attribute(key="return", value=result)
235
+ span.set_attribute(key="usage", value=usage)
236
+ span.message = f"{span.message} -> {usage.format()} tokens"
237
+ return result
238
+ except Exception as exc:
239
+ raise ValidationError("生成工具调用失败") from exc
240
+
241
+ @override
242
+ async def _create_image_once(
243
+ self, prompt: str, ratio: ImageRatio, size: ImageSize, extra: dict | None
244
+ ) -> bytes:
245
+ with logfire.span("gemini client | create image") as span:
246
+ config: dict = self._image_config(ratio, size)
247
+
248
+ span.set_attribute(key="model", value=self.model)
249
+ span.set_attribute(key="prompt", value=prompt)
250
+ span.set_attribute(key="ratio", value=ratio)
251
+ span.set_attribute(key="size", value=size)
252
+ span.set_attribute(key="extra", value=extra)
253
+ span.set_attribute(key="config", value=config)
254
+
255
+ async with self._limiter:
256
+ response: Response = await self._http_client.post(
257
+ self._image_url(),
258
+ headers=self._headers(),
259
+ data=recursive_merge(
260
+ self._image_params(prompt, config),
261
+ self.model.extra_params,
262
+ extra,
263
+ ),
264
+ )
265
+
266
+ try:
267
+ data: dict = response.json()
268
+ block: dict = self._get_block(data, lambda x: "inlineData" in x)
269
+ image: bytes = base64.b64decode(s=block["inlineData"]["data"])
270
+ usage: Usage = self._get_usage(data)
271
+
272
+ span.set_attribute(key="filesize", value=len(image))
273
+ span.set_attribute(key="usage", value=usage)
274
+ span.message = f"{span.message} -> {usage.format()} tokens"
275
+ return image
276
+ except Exception as exc:
277
+ raise ValidationError("生成图像失败") from exc