ai-lib-python 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_lib_python/__init__.py +1 -1
- ai_lib_python/computer_use/__init__.py +228 -0
- ai_lib_python/drivers/__init__.py +140 -0
- ai_lib_python/drivers/anthropic.py +173 -0
- ai_lib_python/drivers/gemini.py +177 -0
- ai_lib_python/drivers/openai.py +133 -0
- ai_lib_python/mcp/__init__.py +181 -0
- ai_lib_python/multimodal/__init__.py +138 -0
- ai_lib_python/protocol/v2/__init__.py +22 -0
- ai_lib_python/protocol/v2/capabilities.py +198 -0
- ai_lib_python/protocol/v2/manifest.py +256 -0
- ai_lib_python/registry/__init__.py +174 -0
- {ai_lib_python-0.6.0.dist-info → ai_lib_python-0.7.1.dist-info}/METADATA +25 -4
- {ai_lib_python-0.6.0.dist-info → ai_lib_python-0.7.1.dist-info}/RECORD +17 -6
- {ai_lib_python-0.6.0.dist-info → ai_lib_python-0.7.1.dist-info}/WHEEL +0 -0
- {ai_lib_python-0.6.0.dist-info → ai_lib_python-0.7.1.dist-info}/licenses/LICENSE-APACHE +0 -0
- {ai_lib_python-0.6.0.dist-info → ai_lib_python-0.7.1.dist-info}/licenses/LICENSE-MIT +0 -0
ai_lib_python/computer_use/__init__.py
CHANGED
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
"""Computer Use 抽象层 — 提供跨厂商的 GUI 自动化操作标准化和安全控制。
|
|
2
|
+
|
|
3
|
+
Computer Use abstraction layer for AI-Protocol. Provides:
|
|
4
|
+
- Normalized action types across providers (screen_based, tool_based)
|
|
5
|
+
- Safety policy enforcement (confirmation, sandbox, logging, domain allowlist)
|
|
6
|
+
- Provider-specific configuration extraction
|
|
7
|
+
- Action validation before execution
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from enum import Enum
|
|
14
|
+
from typing import Any
|
|
15
|
+
from urllib.parse import urlparse
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# ─── Normalized Action Types ────────────────────────────────────────────────
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ActionType(str, Enum):
    """Provider-agnostic identifiers for every supported computer use action.

    Each member's value is its lowercase string identifier.
    """

    # Screen capture
    SCREENSHOT = "screenshot"

    # Pointer actions
    MOUSE_CLICK = "mouse_click"
    MOUSE_DOUBLE_CLICK = "mouse_double_click"
    MOUSE_DRAG = "mouse_drag"
    SCROLL = "scroll"
    MOUSE_MOVE = "mouse_move"

    # Keyboard actions
    KEYBOARD_TYPE = "keyboard_type"
    KEYBOARD_SHORTCUT = "keyboard_shortcut"

    # Browser-level actions
    BROWSER_NAVIGATE = "browser_navigate"
    BROWSER_CLICK_ELEMENT = "browser_click_element"
    BROWSER_FILL_FORM = "browser_fill_form"

    # Viewing aids
    ZOOM_REGION = "zoom_region"

    # Filesystem access
    FILE_READ = "file_read"
    FILE_WRITE = "file_write"
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class MouseButton(str, Enum):
    """Mouse button identifiers used by pointer actions."""

    LEFT = "left"  # primary button
    RIGHT = "right"  # secondary button
    MIDDLE = "middle"  # middle button
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class ComputerAction:
    """A normalized computer use action — provider-agnostic.

    Attributes:
        action_type: Which kind of action this is.
        params: Action-specific parameters, e.g. ``x``/``y``/``button`` for
            clicks, ``url`` for navigation, ``path`` for file access.
    """

    action_type: ActionType
    params: dict[str, Any] = field(default_factory=dict)

    # -- convenience factories --

    @classmethod
    def screenshot(cls, fmt: str = "png") -> ComputerAction:
        """Build a screenshot action requesting image format *fmt*."""
        return cls(ActionType.SCREENSHOT, {"format": fmt})

    @classmethod
    def mouse_click(
        cls, x: float, y: float, button: MouseButton | str = MouseButton.LEFT
    ) -> ComputerAction:
        """Build a click action at (*x*, *y*).

        *button* may be a :class:`MouseButton` or its string value
        (``"left"``, ``"right"``, ``"middle"``).

        Raises:
            ValueError: if *button* is a string that names no MouseButton.
        """
        # MouseButton(...) accepts both members and their values, so plain
        # strings no longer fail with AttributeError on ``.value``.
        return cls(
            ActionType.MOUSE_CLICK,
            {"x": x, "y": y, "button": MouseButton(button).value},
        )

    @classmethod
    def keyboard_type(cls, text: str) -> ComputerAction:
        """Build an action that types *text*."""
        return cls(ActionType.KEYBOARD_TYPE, {"text": text})

    @classmethod
    def keyboard_shortcut(cls, keys: list[str]) -> ComputerAction:
        """Build a keyboard shortcut action from the key names in *keys*."""
        # Store a copy so later mutation of the caller's list cannot alter
        # an action that may already have been validated.
        return cls(ActionType.KEYBOARD_SHORTCUT, {"keys": list(keys)})

    @classmethod
    def browser_navigate(cls, url: str) -> ComputerAction:
        """Build a browser navigation action to *url*."""
        return cls(ActionType.BROWSER_NAVIGATE, {"url": url})

    @classmethod
    def file_read(cls, path: str) -> ComputerAction:
        """Build an action that reads the file at *path*."""
        return cls(ActionType.FILE_READ, {"path": path})

    @classmethod
    def file_write(cls, path: str, content: str) -> ComputerAction:
        """Build an action that writes *content* to the file at *path*."""
        return cls(ActionType.FILE_WRITE, {"path": path, "content": content})
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class ImplementationStyle(str, Enum):
    """How a provider implements computer use.

    Parsed from the ``implementation`` key of a manifest's computer use section.
    """

    SCREEN_BASED = "screen_based"
    TOOL_BASED = "tool_based"
    HYBRID = "hybrid"
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class SandboxMode(str, Enum):
    """Sandbox requirement level, parsed from the safety config's ``sandbox_mode``."""

    REQUIRED = "required"
    RECOMMENDED = "recommended"
    OPTIONAL = "optional"
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
# ─── Safety Policy ──────────────────────────────────────────────────────────
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class SafetyViolation(Exception):
    """Signals that a computer use action was rejected by the safety policy."""
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@dataclass
class SafetyPolicy:
    """Safety policy for computer use actions.

    Loaded from the manifest's ``computer_use.safety`` configuration via
    :meth:`from_config`. :meth:`validate_action` is meant to run *before*
    an action is dispatched.

    Only ``max_actions_per_turn``, ``domain_allowlist`` and
    ``sensitive_data_protection`` are checked by :meth:`validate_action`;
    the other fields are not enforced in this class.
    """

    confirmation_required: bool = True
    sandbox_mode: SandboxMode = SandboxMode.RECOMMENDED
    action_logging: bool = True
    # Hostnames allowed for BROWSER_NAVIGATE; an empty set disables the check.
    domain_allowlist: set[str] = field(default_factory=set)
    # When True, FILE_READ / FILE_WRITE on sensitive-looking paths are blocked.
    sensitive_data_protection: bool = True
    # Per-turn action limit; 0 disables the limit.
    max_actions_per_turn: int = 0
    action_timeout_ms: int = 30_000

    @classmethod
    def from_config(cls, safety_dict: dict[str, Any] | None) -> SafetyPolicy:
        """Build a safety policy from a manifest's ``computer_use.safety`` dict.

        Missing keys fall back to the field defaults. An empty or missing
        dict yields the default policy. An explicit ``null`` for
        ``domain_allowlist_entries`` is treated as an empty allowlist.

        Raises:
            ValueError: if ``sandbox_mode`` is not a valid :class:`SandboxMode`.
        """
        if not safety_dict:
            return cls()
        return cls(
            confirmation_required=safety_dict.get("confirmation_required", True),
            sandbox_mode=SandboxMode(safety_dict.get("sandbox_mode", "recommended")),
            action_logging=safety_dict.get("action_logging", True),
            # `or []` guards against an explicit null, which set() rejects.
            domain_allowlist=set(safety_dict.get("domain_allowlist_entries") or []),
            sensitive_data_protection=safety_dict.get("sensitive_data_protection", True),
            max_actions_per_turn=safety_dict.get("max_actions_per_turn", 0),
            action_timeout_ms=safety_dict.get("action_timeout_ms", 30_000),
        )

    def validate_action(
        self,
        action: ComputerAction,
        actions_this_turn: int = 0,
    ) -> None:
        """Validate *action* against this policy.

        Args:
            action: The action about to be dispatched.
            actions_this_turn: Number of actions already taken this turn;
                the action being validated is not included.

        Raises:
            SafetyViolation: if the per-turn limit is reached, a navigation
                target's host is not in the allowlist, or a file action
                touches a sensitive path.
        """
        if self.max_actions_per_turn > 0 and actions_this_turn >= self.max_actions_per_turn:
            raise SafetyViolation(
                f"Max actions per turn exceeded: limit={self.max_actions_per_turn}, "
                f"attempted={actions_this_turn + 1}"
            )

        if action.action_type == ActionType.BROWSER_NAVIGATE and self.domain_allowlist:
            url = action.params.get("url", "")
            domain = _extract_domain(url)
            # Hostnames are case-insensitive and urlparse lowercases them, so
            # compare against a lowercased allowlist; otherwise an entry such
            # as "Example.com" could never match.
            allowed = {entry.lower() for entry in self.domain_allowlist}
            if domain.lower() not in allowed:
                raise SafetyViolation(
                    f"Domain '{domain}' is not in the allowlist: {sorted(self.domain_allowlist)}"
                )

        if self.sensitive_data_protection and action.action_type in (
            ActionType.FILE_READ,
            ActionType.FILE_WRITE,
        ):
            path = action.params.get("path", "")
            if _is_sensitive_path(path):
                raise SafetyViolation(f"Access to sensitive path '{path}' is blocked")
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
# ─── Provider Configuration ─────────────────────────────────────────────────
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
@dataclass
class CuProviderConfig:
    """Computer use settings specific to a single provider.

    Attributes:
        tool_type: Tool type name; ``"computer_use"`` unless overridden.
        beta_header: Optional beta header value from ``provider_mapping``.
        implementation: How the provider implements computer use.
        model_requirement: Optional model constraint from ``provider_mapping``.
    """

    tool_type: str = "computer_use"
    beta_header: str | None = None
    implementation: ImplementationStyle = ImplementationStyle.SCREEN_BASED
    model_requirement: str | None = None
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def extract_provider_config(cu_config: dict[str, Any] | None) -> CuProviderConfig | None:
    """Extract provider-specific CU configuration from a manifest section.

    Args:
        cu_config: The manifest's computer use section, or ``None``.

    Returns:
        A :class:`CuProviderConfig`, or ``None`` when the section is missing
        or its ``supported`` flag is falsy.

    Raises:
        ValueError: if ``implementation`` names no :class:`ImplementationStyle`.
    """
    if not cu_config or not cu_config.get("supported"):
        return None

    # `or` also covers keys present with an explicit null value, which would
    # otherwise reach ImplementationStyle(None) or None.get(...).
    implementation = ImplementationStyle(cu_config.get("implementation") or "screen_based")
    mapping = cu_config.get("provider_mapping") or {}

    return CuProviderConfig(
        tool_type=mapping.get("tool_type", "computer_use"),
        beta_header=mapping.get("beta_header"),
        implementation=implementation,
        model_requirement=mapping.get("model_requirement"),
    )
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
# ─── Helpers ────────────────────────────────────────────────────────────────
|
|
198
|
+
|
|
199
|
+
_SENSITIVE_PATTERNS = (
|
|
200
|
+
".ssh", ".gnupg", ".aws", "credentials", "secrets",
|
|
201
|
+
".env", "password", "token", ".kube/config",
|
|
202
|
+
)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _extract_domain(url: str) -> str:
|
|
206
|
+
try:
|
|
207
|
+
parsed = urlparse(url)
|
|
208
|
+
return parsed.hostname or ""
|
|
209
|
+
except Exception:
|
|
210
|
+
return url.split("//")[-1].split("/")[0].split(":")[0]
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _is_sensitive_path(path: str) -> bool:
|
|
214
|
+
lower = path.lower()
|
|
215
|
+
return any(p in lower for p in _SENSITIVE_PATTERNS)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
# Public API of the computer use package.
__all__ = [
    "ActionType", "ComputerAction", "CuProviderConfig", "ImplementationStyle",
    "MouseButton", "SafetyPolicy", "SafetyViolation", "SandboxMode",
    "extract_provider_config",
]
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
"""Provider 驱动抽象层 — 通过 ABC 实现多厂商 API 适配的动态分发。
|
|
2
|
+
|
|
3
|
+
Provider driver abstraction layer implementing the ProviderContract specification.
|
|
4
|
+
Uses abstract base class + factory for runtime polymorphism, enabling the same
|
|
5
|
+
client code to work with OpenAI, Anthropic, Gemini, and any compatible provider.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from abc import ABC, abstractmethod
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from ai_lib_python.protocol.v2.capabilities import Capability
|
|
15
|
+
from ai_lib_python.protocol.v2.manifest import ApiStyle
|
|
16
|
+
from ai_lib_python.types.events import StreamingEvent
|
|
17
|
+
from ai_lib_python.types.message import Message
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class DriverRequest:
    """Provider-neutral description of an HTTP request built by a driver.

    Attributes:
        url: Endpoint URL (empty by default).
        method: HTTP verb, ``"POST"`` by default.
        headers: Headers the driver wants sent.
        body: Request payload.
        stream: ``True`` when a streaming response is requested.
    """

    url: str = ""
    method: str = "POST"
    headers: dict[str, str] = field(default_factory=dict)
    body: dict[str, Any] = field(default_factory=dict)
    stream: bool = False
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class DriverResponse:
    """Provider-neutral result of a non-streaming chat request.

    Attributes:
        content: Reply text, or ``None`` when there is none.
        finish_reason: Finish reason, if the provider reported one.
        usage: Token usage, if the provider reported it.
        tool_calls: Tool-call payloads taken from the reply.
        raw: The unmodified provider response body.
    """

    content: str | None = None
    finish_reason: str | None = None
    usage: UsageInfo | None = None
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    raw: dict[str, Any] = field(default_factory=dict)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class UsageInfo:
    """Token counts for a single request. All counts default to 0."""

    # Tokens consumed by the prompt / input.
    prompt_tokens: int = 0
    # Tokens produced in the reply.
    completion_tokens: int = 0
    # Combined count as supplied by the caller (not derived here).
    total_tokens: int = 0
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class ProviderDriver(ABC):
    """Abstract adapter between the unified client API and one provider API style.

    Concrete subclasses implement a specific wire format (OpenAI, Anthropic,
    Gemini). The runtime chooses among them from the manifest's
    ``api_style``. Every member below is abstract and must be overridden.
    """

    # -- identity -----------------------------------------------------------

    @property
    @abstractmethod
    def provider_id(self) -> str:
        """Identifier of the served provider; matches the manifest ``id``."""

    @property
    @abstractmethod
    def api_style(self) -> ApiStyle:
        """The :class:`ApiStyle` this driver speaks."""

    # -- request / response translation ------------------------------------

    @abstractmethod
    def build_request(
        self,
        messages: list[Message],
        model: str,
        *,
        temperature: float | None = None,
        max_tokens: int | None = None,
        stream: bool = False,
        extra: dict[str, Any] | None = None,
    ) -> DriverRequest:
        """Translate unified chat parameters into a provider-specific request."""

    @abstractmethod
    def parse_response(self, body: dict[str, Any]) -> DriverResponse:
        """Translate a complete (non-streaming) response body into a DriverResponse."""

    @abstractmethod
    def parse_stream_event(self, data: str) -> StreamingEvent | None:
        """Translate one raw SSE/NDJSON payload into a StreamingEvent, or None."""

    # -- capabilities & stream control -------------------------------------

    @abstractmethod
    def supported_capabilities(self) -> list[Capability]:
        """Return the capabilities this driver advertises."""

    @abstractmethod
    def is_stream_done(self, data: str) -> bool:
        """Return True when *data* is the provider's end-of-stream signal."""
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
# ---------------------------------------------------------------------------
|
|
100
|
+
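# Concrete drivers (imported here, after the base classes above are defined, because each driver module imports them back from this package — a circular import)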
|
|
101
|
+
# ---------------------------------------------------------------------------
|
|
102
|
+
|
|
103
|
+
from ai_lib_python.drivers.anthropic import AnthropicDriver # noqa: E402
|
|
104
|
+
from ai_lib_python.drivers.gemini import GeminiDriver # noqa: E402
|
|
105
|
+
from ai_lib_python.drivers.openai import OpenAiDriver # noqa: E402
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def create_driver(
    api_style: ApiStyle,
    provider_id: str,
    capabilities: list[Capability] | None = None,
) -> ProviderDriver:
    """Factory: instantiate the driver that matches *api_style*.

    Anthropic Messages and Gemini generate styles get their dedicated
    drivers. Every other style (``OPENAI_COMPATIBLE``, ``CUSTOM`` or anything
    unrecognized) gets :class:`OpenAiDriver`. That driver covers most
    providers following the OpenAI chat completions format (DeepSeek,
    Moonshot, Zhipu, etc.).

    Args:
        api_style: API style declared by the provider manifest.
        provider_id: Provider identifier handed to the driver.
        capabilities: Capabilities to advertise; ``None`` means none.

    Returns:
        A newly constructed driver.
    """
    caps = capabilities or []
    if api_style == ApiStyle.ANTHROPIC_MESSAGES:
        driver_cls = AnthropicDriver
    elif api_style == ApiStyle.GEMINI_GENERATE:
        driver_cls = GeminiDriver
    else:
        # OpenAI-compatible is the catch-all wire format.
        driver_cls = OpenAiDriver
    return driver_cls(provider_id=provider_id, capabilities=caps)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
# Public API of the drivers package.
__all__ = [
    "AnthropicDriver", "DriverRequest", "DriverResponse", "GeminiDriver",
    "OpenAiDriver", "ProviderDriver", "UsageInfo", "create_driver",
]
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""Anthropic Messages API 驱动 — 实现 Anthropic 特有的请求/响应格式转换。
|
|
2
|
+
|
|
3
|
+
Anthropic Messages API driver. Key differences from OpenAI:
|
|
4
|
+
- System messages are a top-level ``system`` parameter, not part of ``messages``.
|
|
5
|
+
- Content uses typed blocks: ``[{"type": "text", "text": "..."}]``.
|
|
6
|
+
- Streaming uses ``event: content_block_delta`` with ``delta.text``.
|
|
7
|
+
- Response uses ``content[0].text`` instead of ``choices[0].message.content``.
|
|
8
|
+
- ``max_tokens`` is required, not optional.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import json
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
from ai_lib_python.drivers import (
|
|
17
|
+
DriverRequest,
|
|
18
|
+
DriverResponse,
|
|
19
|
+
ProviderDriver,
|
|
20
|
+
UsageInfo,
|
|
21
|
+
)
|
|
22
|
+
from ai_lib_python.protocol.v2.capabilities import Capability
|
|
23
|
+
from ai_lib_python.protocol.v2.manifest import ApiStyle
|
|
24
|
+
from ai_lib_python.types.events import StreamingEvent
|
|
25
|
+
from ai_lib_python.types.message import Message
|
|
26
|
+
|
|
27
|
+
_DEFAULT_MAX_TOKENS = 4096
|
|
28
|
+
|
|
29
|
+
# Anthropic stop_reason → AI-Protocol normalized finish_reason
|
|
30
|
+
_STOP_REASON_MAP: dict[str, str] = {
|
|
31
|
+
"end_turn": "stop",
|
|
32
|
+
"max_tokens": "length",
|
|
33
|
+
"tool_use": "tool_calls",
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class AnthropicDriver(ProviderDriver):
    """Anthropic Messages API driver.

    Builds requests with a top-level ``system`` field, typed content blocks
    and a mandatory ``max_tokens``. Maps responses and stream events back to
    the unified types.
    """

    def __init__(
        self,
        provider_id: str,
        capabilities: list[Capability] | None = None,
    ) -> None:
        """Create a driver for *provider_id* advertising *capabilities*."""
        self._provider_id = provider_id
        self._capabilities = capabilities or []

    @property
    def provider_id(self) -> str:
        """Provider identifier this driver was created for."""
        return self._provider_id

    @property
    def api_style(self) -> ApiStyle:
        """Always ``ApiStyle.ANTHROPIC_MESSAGES``."""
        return ApiStyle.ANTHROPIC_MESSAGES

    def build_request(
        self,
        messages: list[Message],
        model: str,
        *,
        temperature: float | None = None,
        max_tokens: int | None = None,
        stream: bool = False,
        extra: dict[str, Any] | None = None,
    ) -> DriverRequest:
        """Build a Messages API request.

        System messages are merged into the top-level ``system`` field.
        ``max_tokens`` falls back to ``_DEFAULT_MAX_TOKENS`` when not given
        (or 0). *extra* is merged last, so it can override any body key.
        Only the ``anthropic-version`` header is set; ``url`` is left empty.
        """
        system_text, msgs = self._split_system(messages)

        body: dict[str, Any] = {
            "model": model,
            "messages": msgs,
            "max_tokens": max_tokens or _DEFAULT_MAX_TOKENS,
            "stream": stream,
        }
        if system_text:
            body["system"] = system_text
        if temperature is not None:
            body["temperature"] = temperature
        if extra:
            body.update(extra)

        headers = {"anthropic-version": "2023-06-01"}
        return DriverRequest(body=body, stream=stream, headers=headers)

    def parse_response(self, body: dict[str, Any]) -> DriverResponse:
        """Convert a non-streaming Messages response into a DriverResponse.

        ``content`` is the concatenation of every ``text`` block, or ``None``
        if there are none. The Messages API may split the reply text across
        several blocks (for example when citations are used), so taking only
        the first block would truncate it. ``tool_calls`` holds the raw
        ``tool_use`` blocks unchanged.
        """
        # content: [{type: "text", text: "..."}, {type: "tool_use", ...}, ...]
        content_blocks = body.get("content") or []
        texts = [b.get("text", "") for b in content_blocks if b.get("type") == "text"]
        text = "".join(texts) if texts else None

        # Normalize stop_reason; unknown reasons pass through, empty → None.
        raw_reason = body.get("stop_reason", "")
        finish_reason = _STOP_REASON_MAP.get(raw_reason, raw_reason) or None

        usage_raw = body.get("usage")
        usage = None
        if usage_raw:
            inp = usage_raw.get("input_tokens", 0)
            out = usage_raw.get("output_tokens", 0)
            usage = UsageInfo(prompt_tokens=inp, completion_tokens=out, total_tokens=inp + out)

        tool_calls = [b for b in content_blocks if b.get("type") == "tool_use"]

        return DriverResponse(
            content=text,
            finish_reason=finish_reason,
            usage=usage,
            tool_calls=tool_calls,
            raw=body,
        )

    def parse_stream_event(self, data: str) -> StreamingEvent | None:
        """Parse one streaming event payload (a JSON object).

        Handles ``content_block_delta`` (text and thinking deltas),
        ``message_delta`` (stop reason), ``message_stop`` and ``error``.
        Blank input and every other event type yield ``None``.

        Raises:
            json.JSONDecodeError: if *data* is not valid JSON.
        """
        stripped = data.strip()
        if not stripped:
            return None

        chunk = json.loads(stripped)
        event_type = chunk.get("type", "")

        if event_type == "content_block_delta":
            delta = chunk.get("delta", {})
            if text := delta.get("text"):
                seq = chunk.get("index")
                return StreamingEvent.content_delta(text, sequence_id=seq)
            if thinking := delta.get("thinking"):
                return StreamingEvent.thinking_delta(thinking)
            return None

        if event_type == "message_delta":
            reason = chunk.get("delta", {}).get("stop_reason")
            if reason:
                return StreamingEvent.stream_end(_STOP_REASON_MAP.get(reason, reason))
            return None

        if event_type == "message_stop":
            # NOTE(review): if message_delta already carried a stop_reason, this
            # emits a second stream_end, with "stop" instead of the real reason.
            # Confirm that consumers stop at the first stream_end.
            return StreamingEvent.stream_end("stop")

        if event_type == "error":
            return StreamingEvent.stream_error(chunk.get("error"))

        return None

    def supported_capabilities(self) -> list[Capability]:
        """Return a copy of the capabilities given at construction."""
        return list(self._capabilities)

    def is_stream_done(self, _data: str) -> bool:
        """Always False: Anthropic signals the end via event types, not a sentinel."""
        return False

    # -- internal helpers ------------------------------------------------

    @staticmethod
    def _split_system(messages: list[Message]) -> tuple[str | None, list[dict[str, Any]]]:
        """Extract system text and convert remaining messages to Anthropic format.

        Returns ``(system_text, msgs)``. ``system_text`` is all system text
        joined by a blank line, or ``None`` if there is none. ``msgs`` holds
        the other messages as ``{"role", "content"}`` dicts whose content is
        a list of blocks.
        """
        system_parts: list[str] = []
        msgs: list[dict[str, Any]] = []

        for m in messages:
            role = m.role if isinstance(m.role, str) else m.role.value
            if role == "system":
                if isinstance(m.content, str):
                    system_parts.append(m.content)
                else:
                    # Keep the text of block-list system content instead of
                    # dropping it. Assumes dumped blocks use the Anthropic
                    # {"type": "text", "text": ...} shape, as the non-system
                    # path sends them verbatim; TODO confirm for other blocks.
                    for block in m.content:
                        dumped = block.model_dump(by_alias=True)
                        if dumped.get("type") == "text" and dumped.get("text"):
                            system_parts.append(dumped["text"])
                continue

            if isinstance(m.content, str):
                content: Any = [{"type": "text", "text": m.content}]
            else:
                content = [b.model_dump(by_alias=True) for b in m.content]

            msgs.append({"role": role, "content": content})

        system_text = "\n\n".join(system_parts) if system_parts else None
        return system_text, msgs
|