ai-lib-python 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ai_lib_python/__init__.py CHANGED
@@ -27,7 +27,7 @@ from ai_lib_python.types.message import (
27
27
  )
28
28
  from ai_lib_python.types.tool import ToolCall, ToolDefinition
29
29
 
30
- __version__ = "0.6.0"
30
+ __version__ = "0.7.0"
31
31
 
32
32
  __all__ = [
33
33
  # Client
@@ -0,0 +1,228 @@
1
+ """Computer Use 抽象层 — 提供跨厂商的 GUI 自动化操作标准化和安全控制。
2
+
3
+ Computer Use abstraction layer for AI-Protocol. Provides:
4
+ - Normalized action types across providers (screen_based, tool_based)
5
+ - Safety policy enforcement (confirmation, sandbox, logging, domain allowlist)
6
+ - Provider-specific configuration extraction
7
+ - Action validation before execution
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from dataclasses import dataclass, field
13
+ from enum import Enum
14
+ from typing import Any
15
+ from urllib.parse import urlparse
16
+
17
+
18
+ # ─── Normalized Action Types ────────────────────────────────────────────────
19
+
20
+
21
class ActionType(str, Enum):
    """Provider-agnostic computer use action kinds.

    Members subclass ``str``, so each compares equal to its value
    (e.g. ``ActionType.SCREENSHOT == "screenshot"``).
    """

    # Screen capture
    SCREENSHOT = "screenshot"
    # Mouse
    MOUSE_CLICK = "mouse_click"
    MOUSE_DOUBLE_CLICK = "mouse_double_click"
    MOUSE_DRAG = "mouse_drag"
    SCROLL = "scroll"
    MOUSE_MOVE = "mouse_move"
    # Keyboard
    KEYBOARD_TYPE = "keyboard_type"
    KEYBOARD_SHORTCUT = "keyboard_shortcut"
    # Browser
    BROWSER_NAVIGATE = "browser_navigate"
    BROWSER_CLICK_ELEMENT = "browser_click_element"
    BROWSER_FILL_FORM = "browser_fill_form"
    # Other
    ZOOM_REGION = "zoom_region"
    FILE_READ = "file_read"
    FILE_WRITE = "file_write"


class MouseButton(str, Enum):
    """Mouse button used by click actions (values are the wire strings)."""

    LEFT = "left"
    RIGHT = "right"
    MIDDLE = "middle"


@dataclass
class ComputerAction:
    """A normalized, provider-agnostic computer use action.

    ``params`` holds the action's arguments as plain JSON-style values; the
    classmethod factories below build the parameter dict for common actions.
    """

    action_type: ActionType
    params: dict[str, Any] = field(default_factory=dict)

    # -- convenience factories --

    @classmethod
    def screenshot(cls, fmt: str = "png") -> ComputerAction:
        """Capture the screen in image format *fmt*."""
        return cls(ActionType.SCREENSHOT, {"format": fmt})

    @classmethod
    def mouse_click(
        cls, x: float, y: float, button: MouseButton | str = MouseButton.LEFT
    ) -> ComputerAction:
        """Click at (*x*, *y*) with *button*.

        *button* may be a :class:`MouseButton` or its string value
        (``"left"``, ``"right"``, ``"middle"``).

        Raises:
            ValueError: If *button* is not a known mouse button.
        """
        # MouseButton(...) accepts both members and raw values, so plain
        # strings no longer fail with AttributeError on ``.value``.
        return cls(
            ActionType.MOUSE_CLICK,
            {"x": x, "y": y, "button": MouseButton(button).value},
        )

    @classmethod
    def keyboard_type(cls, text: str) -> ComputerAction:
        """Type *text* as keyboard input."""
        return cls(ActionType.KEYBOARD_TYPE, {"text": text})

    @classmethod
    def keyboard_shortcut(cls, keys: list[str]) -> ComputerAction:
        """Press the key combination *keys* (e.g. ``["ctrl", "c"]``)."""
        # Copy so later mutation of the caller's list cannot alter the action.
        return cls(ActionType.KEYBOARD_SHORTCUT, {"keys": list(keys)})

    @classmethod
    def browser_navigate(cls, url: str) -> ComputerAction:
        """Navigate the browser to *url*."""
        return cls(ActionType.BROWSER_NAVIGATE, {"url": url})

    @classmethod
    def file_read(cls, path: str) -> ComputerAction:
        """Read the file at *path*."""
        return cls(ActionType.FILE_READ, {"path": path})

    @classmethod
    def file_write(cls, path: str, content: str) -> ComputerAction:
        """Write *content* to the file at *path*."""
        return cls(ActionType.FILE_WRITE, {"path": path, "content": content})
+ return cls(ActionType.FILE_WRITE, {"path": path, "content": content})
84
+
85
+
86
class ImplementationStyle(str, Enum):
    """How a provider implements computer use.

    The value is read from the manifest's ``computer_use.implementation`` key.
    """

    # The model works from screenshots and emits screen-level actions.
    SCREEN_BASED = "screen_based"
    # The model drives the computer through discrete tools.
    TOOL_BASED = "tool_based"
    # A mix of both approaches.
    HYBRID = "hybrid"
92
+
93
+
94
class SandboxMode(str, Enum):
    """How strongly a sandboxed execution environment is expected.

    Parsed from the manifest's ``computer_use.safety.sandbox_mode`` value.
    """

    # Actions must run inside a sandbox.
    REQUIRED = "required"
    # A sandbox is advised but not mandatory.
    RECOMMENDED = "recommended"
    # Running without a sandbox is acceptable.
    OPTIONAL = "optional"
98
+
99
+
100
+ # ─── Safety Policy ──────────────────────────────────────────────────────────
101
+
102
+
103
class SafetyViolation(Exception):
    """Signal that a computer use action was rejected by the safety policy.

    The exception message names the rule that was broken, so it can be
    surfaced directly to the caller or logged.
    """
105
+
106
+
107
@dataclass
class SafetyPolicy:
    """Safety rules for computer use actions, checked before dispatch.

    Usually built from a manifest's ``computer_use.safety`` section via
    :meth:`from_config`. :meth:`validate_action` enforces the rule-based checks
    (per-turn action budget, navigation domain allowlist, sensitive file
    paths); the other fields are carried as configuration for the caller.
    """

    # Whether actions should be confirmed before execution (not checked here).
    confirmation_required: bool = True
    # How strongly a sandbox is expected (not checked here).
    sandbox_mode: SandboxMode = SandboxMode.RECOMMENDED
    # Whether actions should be logged (not checked here).
    action_logging: bool = True
    # Hostnames allowed for BROWSER_NAVIGATE; empty means no domain check.
    # Matching is exact per hostname (no implicit subdomains), case-insensitive.
    domain_allowlist: set[str] = field(default_factory=set)
    # Block FILE_READ / FILE_WRITE on paths that look sensitive.
    sensitive_data_protection: bool = True
    # Maximum actions per turn; 0 (or negative) disables the limit.
    max_actions_per_turn: int = 0
    # Per-action timeout in milliseconds (not checked here).
    action_timeout_ms: int = 30_000

    @classmethod
    def from_config(cls, safety_dict: dict[str, Any] | None) -> SafetyPolicy:
        """Build a policy from a manifest's ``computer_use.safety`` dict.

        Missing keys and keys explicitly set to null (e.g. an empty YAML
        value) both fall back to the field defaults, so a null
        ``confirmation_required`` cannot silently disable confirmation.

        Raises:
            ValueError: If ``sandbox_mode`` is not a valid :class:`SandboxMode`.
        """
        if not safety_dict:
            return cls()
        cfg: dict[str, Any] = safety_dict

        def setting(key: str, default: Any) -> Any:
            # Treat an explicit null like a missing key.
            value = cfg.get(key)
            return default if value is None else value

        entries = setting("domain_allowlist_entries", [])
        if isinstance(entries, str):
            # set("a.com") would yield single characters; wrap a bare string.
            entries = [entries]

        return cls(
            confirmation_required=setting("confirmation_required", True),
            sandbox_mode=SandboxMode(setting("sandbox_mode", "recommended")),
            action_logging=setting("action_logging", True),
            domain_allowlist=set(entries),
            sensitive_data_protection=setting("sensitive_data_protection", True),
            max_actions_per_turn=setting("max_actions_per_turn", 0),
            action_timeout_ms=setting("action_timeout_ms", 30_000),
        )

    def validate_action(
        self,
        action: ComputerAction,
        actions_this_turn: int = 0,
    ) -> None:
        """Check *action* against this policy.

        Args:
            action: The action about to be dispatched.
            actions_this_turn: Number of actions already executed in the
                current turn, not counting *action*.

        Raises:
            SafetyViolation: If the per-turn budget is exhausted, a browser
                navigation target is not allowlisted, or a file action
                targets a sensitive path.
        """
        if self.max_actions_per_turn > 0 and actions_this_turn >= self.max_actions_per_turn:
            raise SafetyViolation(
                f"Max actions per turn exceeded: limit={self.max_actions_per_turn}, "
                f"attempted={actions_this_turn + 1}"
            )

        if action.action_type == ActionType.BROWSER_NAVIGATE and self.domain_allowlist:
            url = action.params.get("url", "")
            # Hostnames are case-insensitive; normalize both sides so an entry
            # such as "Example.com" matches the lower-cased parsed hostname.
            domain = _extract_domain(url).lower()
            allowed = {str(entry).strip().lower() for entry in self.domain_allowlist}
            # An empty entry must never match the "" returned for unparsable URLs.
            allowed.discard("")
            if domain not in allowed:
                raise SafetyViolation(
                    f"Domain '{domain}' is not in the allowlist: {sorted(self.domain_allowlist)}"
                )

        if self.sensitive_data_protection and action.action_type in (
            ActionType.FILE_READ,
            ActionType.FILE_WRITE,
        ):
            path = action.params.get("path", "")
            if _is_sensitive_path(path):
                raise SafetyViolation(f"Access to sensitive path '{path}' is blocked")
165
+
166
+
167
+ # ─── Provider Configuration ─────────────────────────────────────────────────
168
+
169
+
170
+ @dataclass
171
+ class CuProviderConfig:
172
+ """Provider-specific computer use configuration."""
173
+
174
+ tool_type: str = "computer_use"
175
+ beta_header: str | None = None
176
+ implementation: ImplementationStyle = ImplementationStyle.SCREEN_BASED
177
+ model_requirement: str | None = None
178
+
179
+
180
def extract_provider_config(cu_config: dict[str, Any] | None) -> CuProviderConfig | None:
    """Extract provider-specific computer use configuration from a manifest section.

    Args:
        cu_config: The manifest's ``computer_use`` mapping, or ``None``.

    Returns:
        A :class:`CuProviderConfig`, or ``None`` when the section is missing
        or does not declare ``supported``.

    Raises:
        ValueError: If ``implementation`` is not a valid :class:`ImplementationStyle`.
    """
    if not cu_config or not cu_config.get("supported"):
        return None

    # ``or`` makes an explicit null (e.g. an empty YAML key) behave like a
    # missing key instead of crashing in ImplementationStyle(None) or
    # None.get(...).
    implementation = ImplementationStyle(cu_config.get("implementation") or "screen_based")
    mapping = cu_config.get("provider_mapping") or {}

    return CuProviderConfig(
        tool_type=mapping.get("tool_type", "computer_use"),
        beta_header=mapping.get("beta_header"),
        implementation=implementation,
        model_requirement=mapping.get("model_requirement"),
    )
195
+
196
+
197
+ # ─── Helpers ────────────────────────────────────────────────────────────────
198
+
199
+ _SENSITIVE_PATTERNS = (
200
+ ".ssh", ".gnupg", ".aws", "credentials", "secrets",
201
+ ".env", "password", "token", ".kube/config",
202
+ )
203
+
204
+
205
+ def _extract_domain(url: str) -> str:
206
+ try:
207
+ parsed = urlparse(url)
208
+ return parsed.hostname or ""
209
+ except Exception:
210
+ return url.split("//")[-1].split("/")[0].split(":")[0]
211
+
212
+
213
+ def _is_sensitive_path(path: str) -> bool:
214
+ lower = path.lower()
215
+ return any(p in lower for p in _SENSITIVE_PATTERNS)
216
+
217
+
218
+ __all__ = [
219
+ "ActionType",
220
+ "ComputerAction",
221
+ "CuProviderConfig",
222
+ "ImplementationStyle",
223
+ "MouseButton",
224
+ "SafetyPolicy",
225
+ "SafetyViolation",
226
+ "SandboxMode",
227
+ "extract_provider_config",
228
+ ]
@@ -0,0 +1,140 @@
1
+ """Provider 驱动抽象层 — 通过 ABC 实现多厂商 API 适配的动态分发。
2
+
3
+ Provider driver abstraction layer implementing the ProviderContract specification.
4
+ Uses abstract base class + factory for runtime polymorphism, enabling the same
5
+ client code to work with OpenAI, Anthropic, Gemini, and any compatible provider.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from abc import ABC, abstractmethod
11
+ from dataclasses import dataclass, field
12
+ from typing import Any
13
+
14
+ from ai_lib_python.protocol.v2.capabilities import Capability
15
+ from ai_lib_python.protocol.v2.manifest import ApiStyle
16
+ from ai_lib_python.types.events import StreamingEvent
17
+ from ai_lib_python.types.message import Message
18
+
19
+
20
@dataclass
class DriverRequest:
    """Provider HTTP request as produced by a driver.

    Every field has a default, so a driver sets only what it needs;
    ``headers`` and ``body`` get a fresh dict per instance.
    """

    # Endpoint URL; empty unless the driver fills it in.
    url: str = ""
    # HTTP method.
    method: str = "POST"
    # Extra HTTP headers to send.
    headers: dict[str, str] = field(default_factory=dict)
    # Request payload mapping.
    body: dict[str, Any] = field(default_factory=dict)
    # Whether a streaming response was requested.
    stream: bool = False
29
+
30
+
31
@dataclass
class DriverResponse:
    """Non-streaming chat response normalized across providers.

    ``raw`` keeps the provider's original response body for anything the
    normalized fields do not cover.
    """

    # Assistant text, or None when the response had no text.
    content: str | None = None
    # Normalized finish reason, if the provider reported one.
    finish_reason: str | None = None
    # Token accounting, if the provider reported it.
    usage: UsageInfo | None = None
    # Tool call entries extracted by the driver.
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    # Unmodified provider response body.
    raw: dict[str, Any] = field(default_factory=dict)
40
+
41
+
42
@dataclass
class UsageInfo:
    """Token counts reported for one request."""

    # Tokens consumed by the input/prompt.
    prompt_tokens: int = 0
    # Tokens generated in the completion.
    completion_tokens: int = 0
    # Overall total as supplied by the driver.
    total_tokens: int = 0
49
+
50
+
51
class ProviderDriver(ABC):
    """Abstract interface that adapts one provider API style to the unified model.

    There is one concrete subclass per API style (OpenAI-compatible,
    Anthropic, Gemini). ``create_driver`` picks the subclass from the
    manifest's API style.
    """

    @property
    @abstractmethod
    def provider_id(self) -> str:
        """Provider identifier; expected to equal the manifest ``id``."""

    @property
    @abstractmethod
    def api_style(self) -> ApiStyle:
        """The API style implemented by this driver."""

    @abstractmethod
    def build_request(
        self,
        messages: list[Message],
        model: str,
        *,
        temperature: float | None = None,
        max_tokens: int | None = None,
        stream: bool = False,
        extra: dict[str, Any] | None = None,
    ) -> DriverRequest:
        """Translate unified chat parameters into a provider HTTP request."""

    @abstractmethod
    def parse_response(self, body: dict[str, Any]) -> DriverResponse:
        """Normalize a complete (non-streaming) provider response body."""

    @abstractmethod
    def parse_stream_event(self, data: str) -> StreamingEvent | None:
        """Normalize one raw SSE/NDJSON payload; return None to skip it."""

    @abstractmethod
    def supported_capabilities(self) -> list[Capability]:
        """Return the capabilities this driver advertises."""

    @abstractmethod
    def is_stream_done(self, data: str) -> bool:
        """Return True if *data* is the provider's end-of-stream signal."""
97
+
98
+
99
+ # ---------------------------------------------------------------------------
100
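+ # Concrete drivers. These are imported here at module bottom, not lazily: each driver module imports the base classes above from this package, so those classes must already be defined to avoid a circular-import failure.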
101
+ # ---------------------------------------------------------------------------
102
+
103
+ from ai_lib_python.drivers.anthropic import AnthropicDriver # noqa: E402
104
+ from ai_lib_python.drivers.gemini import GeminiDriver # noqa: E402
105
+ from ai_lib_python.drivers.openai import OpenAiDriver # noqa: E402
106
+
107
+
108
def create_driver(
    api_style: ApiStyle,
    provider_id: str,
    capabilities: list[Capability] | None = None,
) -> ProviderDriver:
    """Instantiate the driver that matches *api_style*.

    Anthropic Messages and Gemini styles get their dedicated drivers. Every
    other style (OpenAI-compatible, ``Custom`` and any unrecognized value)
    uses the OpenAI-compatible driver, because most providers (DeepSeek,
    Moonshot, Zhipu, etc.) follow the OpenAI chat completions format.
    """
    caps = capabilities or []
    if api_style == ApiStyle.ANTHROPIC_MESSAGES:
        return AnthropicDriver(provider_id=provider_id, capabilities=caps)
    if api_style == ApiStyle.GEMINI_GENERATE:
        return GeminiDriver(provider_id=provider_id, capabilities=caps)
    # OPENAI_COMPATIBLE, CUSTOM and anything unrecognized.
    return OpenAiDriver(provider_id=provider_id, capabilities=caps)
129
+
130
+
131
+ __all__ = [
132
+ "AnthropicDriver",
133
+ "DriverRequest",
134
+ "DriverResponse",
135
+ "GeminiDriver",
136
+ "OpenAiDriver",
137
+ "ProviderDriver",
138
+ "UsageInfo",
139
+ "create_driver",
140
+ ]
@@ -0,0 +1,173 @@
1
+ """Anthropic Messages API 驱动 — 实现 Anthropic 特有的请求/响应格式转换。
2
+
3
+ Anthropic Messages API driver. Key differences from OpenAI:
4
+ - System messages are a top-level ``system`` parameter, not part of ``messages``.
5
+ - Content uses typed blocks: ``[{"type": "text", "text": "..."}]``.
6
+ - Streaming uses ``event: content_block_delta`` with ``delta.text``.
7
+ - Response uses ``content[0].text`` instead of ``choices[0].message.content``.
8
+ - ``max_tokens`` is required, not optional.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ from typing import Any
15
+
16
+ from ai_lib_python.drivers import (
17
+ DriverRequest,
18
+ DriverResponse,
19
+ ProviderDriver,
20
+ UsageInfo,
21
+ )
22
+ from ai_lib_python.protocol.v2.capabilities import Capability
23
+ from ai_lib_python.protocol.v2.manifest import ApiStyle
24
+ from ai_lib_python.types.events import StreamingEvent
25
+ from ai_lib_python.types.message import Message
26
+
27
+ _DEFAULT_MAX_TOKENS = 4096
28
+
29
+ # Anthropic stop_reason → AI-Protocol normalized finish_reason
30
+ _STOP_REASON_MAP: dict[str, str] = {
31
+ "end_turn": "stop",
32
+ "max_tokens": "length",
33
+ "tool_use": "tool_calls",
34
+ }
35
+
36
+
37
class AnthropicDriver(ProviderDriver):
    """Provider driver for the Anthropic Messages API.

    Converts unified chat parameters into an Anthropic request body and
    normalizes Anthropic responses and streaming events into the unified
    ``DriverResponse`` / ``StreamingEvent`` types.
    """

    def __init__(
        self,
        provider_id: str,
        capabilities: list[Capability] | None = None,
    ) -> None:
        """Create a driver for *provider_id*.

        Args:
            provider_id: Identifier returned by :attr:`provider_id`.
            capabilities: Capabilities returned by :meth:`supported_capabilities`.
                The list is copied, so later changes to the caller's list do
                not affect this driver.
        """
        self._provider_id = provider_id
        self._capabilities = list(capabilities or [])

    @property
    def provider_id(self) -> str:
        """Identifier of the provider this driver was created for."""
        return self._provider_id

    @property
    def api_style(self) -> ApiStyle:
        """Always ``ApiStyle.ANTHROPIC_MESSAGES``."""
        return ApiStyle.ANTHROPIC_MESSAGES

    def build_request(
        self,
        messages: list[Message],
        model: str,
        *,
        temperature: float | None = None,
        max_tokens: int | None = None,
        stream: bool = False,
        extra: dict[str, Any] | None = None,
    ) -> DriverRequest:
        """Build an Anthropic Messages request.

        System messages are moved into the top-level ``system`` field.
        ``max_tokens`` is required by Anthropic and falls back to
        ``_DEFAULT_MAX_TOKENS``. Keys in *extra* are merged last and override
        any field set here.
        """
        system_text, msgs = self._split_system(messages)

        body: dict[str, Any] = {
            "model": model,
            "messages": msgs,
            "max_tokens": max_tokens or _DEFAULT_MAX_TOKENS,
            "stream": stream,
        }
        if system_text:
            body["system"] = system_text
        if temperature is not None:
            body["temperature"] = temperature
        if extra:
            body.update(extra)

        headers = {"anthropic-version": "2023-06-01"}
        return DriverRequest(body=body, stream=stream, headers=headers)

    def parse_response(self, body: dict[str, Any]) -> DriverResponse:
        """Normalize a non-streaming Anthropic response.

        ``content`` is the concatenation of all ``text`` blocks (None if there
        are none). ``stop_reason`` is mapped through ``_STOP_REASON_MAP``.
        ``tool_use`` blocks are returned unchanged as ``tool_calls``.
        """
        content_blocks = body.get("content") or []

        # One answer can span several text blocks (for example around
        # citations); join them all instead of keeping only the first.
        text_parts = [b["text"] for b in content_blocks if b.get("type") == "text"]
        text = "".join(text_parts) if text_parts else None

        raw_reason = body.get("stop_reason") or ""
        finish_reason = _STOP_REASON_MAP.get(raw_reason, raw_reason) or None

        usage = None
        usage_raw = body.get("usage")
        if usage_raw:
            # ``or 0`` also covers explicit nulls, which would break the sum.
            inp = usage_raw.get("input_tokens") or 0
            out = usage_raw.get("output_tokens") or 0
            usage = UsageInfo(prompt_tokens=inp, completion_tokens=out, total_tokens=inp + out)

        # tool_use blocks are passed through in Anthropic's native shape.
        tool_calls = [b for b in content_blocks if b.get("type") == "tool_use"]

        return DriverResponse(
            content=text,
            finish_reason=finish_reason,
            usage=usage,
            tool_calls=tool_calls,
            raw=body,
        )

    def parse_stream_event(self, data: str) -> StreamingEvent | None:
        """Normalize one streaming payload (the JSON ``data`` of an SSE event).

        Returns None for blank payloads and for event types this driver
        ignores.

        Raises:
            json.JSONDecodeError: If *data* is not valid JSON.
        """
        stripped = data.strip()
        if not stripped:
            return None

        chunk = json.loads(stripped)
        event_type = chunk.get("type", "")

        if event_type == "content_block_delta":
            delta = chunk.get("delta", {})
            if text := delta.get("text"):
                # The content block index doubles as the sequence id.
                return StreamingEvent.content_delta(text, sequence_id=chunk.get("index"))
            if thinking := delta.get("thinking"):
                return StreamingEvent.thinking_delta(thinking)
            return None

        if event_type == "message_delta":
            reason = chunk.get("delta", {}).get("stop_reason")
            if reason:
                return StreamingEvent.stream_end(_STOP_REASON_MAP.get(reason, reason))
            return None

        if event_type == "message_stop":
            # NOTE(review): if the preceding message_delta carried a
            # stop_reason, a stream_end was already emitted above, so this
            # yields a second one whose reason is always "stop" (even after
            # "tool_calls"). Confirm which event consumers rely on before
            # changing either.
            return StreamingEvent.stream_end("stop")

        if event_type == "error":
            return StreamingEvent.stream_error(chunk.get("error"))

        return None

    def supported_capabilities(self) -> list[Capability]:
        """Return a copy of the configured capabilities."""
        return list(self._capabilities)

    def is_stream_done(self, data: str) -> bool:  # noqa: ARG002
        """Always False.

        Anthropic signals the end of a stream through an event type handled
        in :meth:`parse_stream_event`, not through a sentinel payload. The
        parameter is named ``data`` to match the base-class signature.
        """
        return False

    # -- internal helpers ------------------------------------------------

    @staticmethod
    def _split_system(messages: list[Message]) -> tuple[str | None, list[dict[str, Any]]]:
        """Separate system text from the conversation.

        Returns:
            ``(system_text, msgs)``. ``system_text`` is all system text joined
            by blank lines, or None if there is none. ``msgs`` holds the
            non-system messages as ``{"role", "content"}`` dicts with
            content-block lists.
        """
        system_parts: list[str] = []
        msgs: list[dict[str, Any]] = []

        for m in messages:
            role = m.role if isinstance(m.role, str) else m.role.value

            if role == "system":
                if isinstance(m.content, str):
                    system_parts.append(m.content)
                elif m.content:
                    # Block-list system content used to be dropped silently.
                    # Assumes dumped blocks use the {"type": "text", "text": ...}
                    # shape that the non-system branch sends verbatim. TODO confirm.
                    dumped = (b.model_dump(by_alias=True) for b in m.content)
                    system_parts.extend(
                        d["text"] for d in dumped if d.get("type") == "text" and d.get("text")
                    )
                continue

            if isinstance(m.content, str):
                content: Any = [{"type": "text", "text": m.content}]
            else:
                content = [b.model_dump(by_alias=True) for b in m.content]

            msgs.append({"role": role, "content": content})

        system_text = "\n\n".join(system_parts) if system_parts else None
        return system_text, msgs