selectools 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,245 @@
1
+ """
2
+ Provider implementations for Anthropic, Gemini, and a local fallback.
3
+
4
+ These adapters validate configuration and can be mocked or monkeypatched by
5
+ callers. They surface clear errors when SDKs or API keys are missing.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import os
11
+ from typing import Iterable, List
12
+
13
+ from ..env import load_default_env
14
+ from ..types import Message, Role
15
+ from .base import Provider, ProviderError
16
+
17
+
18
class AnthropicProvider(Provider):
    """Anthropic Messages API adapter.

    Converts the library's ``Message`` objects into the Messages API payload
    and re-raises SDK failures (including mid-stream ones) as ``ProviderError``.
    """

    name = "anthropic"
    supports_streaming = True

    def __init__(
        self,
        api_key: str | None = None,
        default_model: str = "claude-3-5-sonnet-20240620",
        base_url: str | None = None,
    ):
        """Create the adapter and its SDK client.

        Args:
            api_key: Explicit API key. When falsy, ``ANTHROPIC_API_KEY`` is read
                from the environment after calling ``load_default_env()``.
            default_model: Model used when a call passes a falsy ``model``.
            base_url: Optional API base URL forwarded to ``anthropic.Anthropic``.

        Raises:
            ProviderError: If no API key is available or the ``anthropic``
                package cannot be imported.
        """
        load_default_env()
        self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
        if not self.api_key:
            raise ProviderError("ANTHROPIC_API_KEY is not set. Set it in env or pass api_key.")

        try:
            from anthropic import Anthropic
        except ImportError as exc:  # noqa: BLE001
            raise ProviderError("anthropic package not installed. Install with `pip install anthropic`.") from exc

        self._client = Anthropic(api_key=self.api_key, base_url=base_url)
        self.default_model = default_model

    def complete(
        self,
        *,
        model: str,
        system_prompt: str,
        messages: List[Message],
        temperature: float = 0.0,
        max_tokens: int = 1000,
        timeout: float | None = None,
    ) -> str:
        """Run one non-streaming completion and return the concatenated text.

        Raises:
            ProviderError: If the SDK call fails.
        """
        request_args = self._build_request(
            model=model,
            system_prompt=system_prompt,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            timeout=timeout,
        )
        try:
            response = self._client.messages.create(**request_args)
        except Exception as exc:  # noqa: BLE001
            raise ProviderError(f"Anthropic completion failed: {exc}") from exc

        # Content blocks without a ``text`` attribute are skipped.
        return "".join(block.text for block in response.content if hasattr(block, "text"))

    def stream(
        self,
        *,
        model: str,
        system_prompt: str,
        messages: List[Message],
        temperature: float = 0.0,
        max_tokens: int = 1000,
        timeout: float | None = None,
    ) -> Iterable[str]:
        """Yield text deltas from a streaming completion.

        Only ``content_block_delta`` events with non-empty ``delta.text`` are
        yielded.

        Raises:
            ProviderError: If opening or iterating the stream fails.
        """
        request_args = self._build_request(
            model=model,
            system_prompt=system_prompt,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            timeout=timeout,
            stream=True,
        )
        try:
            events = self._client.messages.create(**request_args)
        except Exception as exc:  # noqa: BLE001
            raise ProviderError(f"Anthropic streaming failed: {exc}") from exc

        try:
            iterator = iter(events)
            while True:
                # Only the SDK's own iteration is guarded, so exceptions thrown
                # into this generator by the consumer are not re-wrapped.
                try:
                    event = next(iterator)
                except StopIteration:
                    return
                except Exception as exc:  # noqa: BLE001
                    raise ProviderError(f"Anthropic streaming failed: {exc}") from exc

                if getattr(event, "type", None) != "content_block_delta":
                    continue
                delta = getattr(event, "delta", None)
                text = getattr(delta, "text", None) if delta else None
                if text:
                    yield text
        finally:
            # Release the underlying response even if the consumer stops early.
            close = getattr(events, "close", None)
            if callable(close):
                close()

    def _build_request(
        self,
        *,
        model: str,
        system_prompt: str,
        messages: List[Message],
        temperature: float,
        max_tokens: int,
        timeout: float | None,
        stream: bool = False,
    ) -> dict:
        """Assemble the keyword arguments for ``messages.create``."""
        request_args = {
            "model": model or self.default_model,
            "messages": self._format_messages(messages),
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        system_text = self._system_text(system_prompt, messages)
        if system_text:
            # An empty system prompt is omitted rather than sent as "".
            request_args["system"] = system_text
        if stream:
            request_args["stream"] = True
        if timeout is not None:
            request_args["timeout"] = timeout
        return request_args

    @staticmethod
    def _system_text(system_prompt: str, messages: List[Message]) -> str:
        """Combine the system prompt with any SYSTEM-role message contents."""
        parts = [system_prompt] if system_prompt else []
        parts.extend(m.content for m in messages if m.role == Role.SYSTEM and m.content)
        return "\n\n".join(parts)

    def _format_messages(self, messages: List[Message]):
        """Convert messages into Messages API turns with a single text block each.

        SYSTEM messages are skipped here: the API's ``messages`` list only takes
        user/assistant turns, so their text is sent via ``system`` instead.
        """
        formatted = []
        for message in messages:
            if message.role == Role.SYSTEM:
                continue
            role = message.role.value
            if role == Role.TOOL.value:
                # NOTE(review): tool output is sent as an assistant turn; if it is
                # the final message the API may treat it as a prefill — confirm.
                role = Role.ASSISTANT.value
            formatted.append({"role": role, "content": [{"type": "text", "text": message.content}]})
        return formatted
112
+
113
+
114
class GeminiProvider(Provider):
    """Google Gemini adapter using the ``google-generativeai`` SDK.

    The conversation is flattened into a single text prompt. SDK failures,
    including mid-stream ones, are re-raised as ``ProviderError``.
    """

    name = "gemini"
    supports_streaming = True

    def __init__(self, api_key: str | None = None, default_model: str = "gemini-1.5-flash"):
        """Configure the SDK.

        Args:
            api_key: Explicit API key. When falsy, ``GEMINI_API_KEY`` is read from
                the environment after calling ``load_default_env()``.
            default_model: Model used when a call passes a falsy ``model``.

        Raises:
            ProviderError: If no API key is available or the SDK cannot be imported.
        """
        load_default_env()
        self.api_key = api_key or os.getenv("GEMINI_API_KEY")
        if not self.api_key:
            raise ProviderError("GEMINI_API_KEY is not set. Set it in env or pass api_key.")

        try:
            import google.generativeai as genai
        except ImportError as exc:  # noqa: BLE001
            raise ProviderError(
                "google-generativeai package not installed. Install with `pip install google-generativeai`."
            ) from exc

        genai.configure(api_key=self.api_key)
        self._genai = genai
        self.default_model = default_model

    def complete(
        self,
        *,
        model: str,
        system_prompt: str,
        messages: List[Message],
        temperature: float = 0.0,
        max_tokens: int = 1000,
        timeout: float | None = None,
    ) -> str:
        """Run one non-streaming completion and return its text.

        Raises:
            ProviderError: If the SDK call fails or the response has no text part.
        """
        model_obj, prompt, options = self._prepare(
            model, system_prompt, messages, temperature, max_tokens, timeout
        )
        try:
            response = model_obj.generate_content(prompt, **options)
        except Exception as exc:  # noqa: BLE001
            raise ProviderError(f"Gemini completion failed: {exc}") from exc

        try:
            return response.text or ""
        except ValueError as exc:
            # ``.text`` raises ValueError when the response carries no text part.
            raise ProviderError(f"Gemini completion failed: {exc}") from exc

    def stream(
        self,
        *,
        model: str,
        system_prompt: str,
        messages: List[Message],
        temperature: float = 0.0,
        max_tokens: int = 1000,
        timeout: float | None = None,
    ) -> Iterable[str]:
        """Yield non-empty text chunks from a streaming completion.

        Raises:
            ProviderError: If opening or iterating the stream fails.
        """
        model_obj, prompt, options = self._prepare(
            model, system_prompt, messages, temperature, max_tokens, timeout
        )
        try:
            chunks = iter(model_obj.generate_content(prompt, stream=True, **options))
        except Exception as exc:  # noqa: BLE001
            raise ProviderError(f"Gemini streaming failed: {exc}") from exc

        while True:
            try:
                chunk = next(chunks)
            except StopIteration:
                return
            except Exception as exc:  # noqa: BLE001
                raise ProviderError(f"Gemini streaming failed: {exc}") from exc
            text = self._chunk_text(chunk)
            if text:
                yield text

    def _prepare(self, model, system_prompt, messages, temperature, max_tokens, timeout):
        """Build the model handle, the prompt text and ``generate_content`` kwargs."""
        model_obj = self._genai.GenerativeModel(model or self.default_model)
        prompt = self._build_prompt(system_prompt, messages)
        options = {
            # Sampling settings must be passed via ``generation_config``;
            # ``generate_content`` has no temperature/max_output_tokens kwargs.
            "generation_config": {"temperature": temperature, "max_output_tokens": max_tokens},
            "request_options": {"timeout": timeout} if timeout is not None else None,
        }
        return model_obj, prompt, options

    @staticmethod
    def _chunk_text(chunk) -> str:
        """Return a stream chunk's text, or "" when it has none."""
        try:
            return getattr(chunk, "text", None) or ""
        except ValueError:
            # The ``.text`` accessor raises (not AttributeError) for chunks
            # without text parts, so getattr's default does not cover it.
            return ""

    def _build_prompt(self, system_prompt: str, messages: List[Message]):
        """Flatten the conversation into newline-separated text.

        The system prompt (when non-empty) comes first, then one
        ``"<Role>: <content>"`` line per message, with the role value capitalised.
        """
        conversation = [system_prompt] if system_prompt else []
        conversation.extend(f"{m.role.value.capitalize()}: {m.content}" for m in messages)
        return "\n".join(conversation)
196
+
197
+
198
class LocalProvider(Provider):
    """
    Local fallback provider.

    This does not call a model. It echoes the latest user content and is useful
    for offline/manual testing or as a safe default.
    """

    name = "local"
    supports_streaming = True

    def complete(
        self,
        *,
        model: str,
        system_prompt: str,
        messages: List[Message],
        temperature: float = 0.0,
        max_tokens: int = 1000,
        timeout: float | None = None,
    ) -> str:
        """Echo the most recent USER message, tagged with *model*.

        The remaining arguments are accepted for interface compatibility and
        ignored.
        """
        last_user = next((m for m in reversed(messages) if m.role == Role.USER), None)
        user_text = last_user.content if last_user else ""
        return f"[local provider: {model}] {user_text or 'No user message provided.'}"

    def stream(
        self,
        *,
        model: str,
        system_prompt: str,
        messages: List[Message],
        temperature: float = 0.0,
        max_tokens: int = 1000,
        timeout: float | None = None,
    ) -> Iterable[str]:
        """Yield ``complete()``'s text in word-sized pieces.

        Each piece is a word followed by the whitespace after it. Joining the
        pieces therefore reproduces ``complete()``'s return value exactly,
        including newlines and runs of spaces.
        """
        import re

        text = self.complete(
            model=model,
            system_prompt=system_prompt,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            timeout=timeout,
        )
        # Zero-width split at every whitespace -> non-whitespace boundary.
        for piece in re.split(r"(?<=\s)(?=\S)", text):
            if piece:
                yield piece
243
+
244
+
245
# Public API of this module: the provider adapters.
__all__ = [
    "AnthropicProvider",
    "GeminiProvider",
    "LocalProvider",
]
toolcalling/tools.py ADDED
@@ -0,0 +1,233 @@
1
+ """
2
+ Tool metadata, schemas, and runtime validation.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from dataclasses import dataclass
8
+ import inspect
9
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
10
+
11
+
12
# A JSON-schema fragment represented as a plain dict.
JsonSchema = Dict[str, Any]
# Value types accepted for a tool parameter.
ParameterValue = Union[str, int, float, bool, dict, list]
# Per-parameter metadata for @tool (keys read: "description", "enum").
ParamMetadata = Dict[str, Any]
15
+
16
+
17
# Exact-type lookup table; subclasses are not matched.
_JSON_TYPE_NAMES = {
    str: "string",
    int: "integer",
    float: "number",
    bool: "boolean",
    list: "array",
    dict: "object",
}


def _python_type_to_json(param_type: type) -> str:
    """Return the JSON-schema type name for *param_type*.

    Types outside the table fall back to ``"string"``.
    """
    try:
        return _JSON_TYPE_NAMES[param_type]
    except KeyError:
        return "string"
28
+
29
+
30
@dataclass
class ToolParameter:
    """Declaration of one named argument a tool accepts."""

    name: str  # key used in the schema and in the params dict
    param_type: type  # Python type; mapped to a JSON-schema type in to_schema()
    description: str
    required: bool = True
    enum: Optional[List[str]] = None  # allowed values; emitted only when non-empty

    def to_schema(self) -> JsonSchema:
        """Render this parameter as a JSON-schema property definition."""
        schema: JsonSchema = dict(
            type=_python_type_to_json(self.param_type),
            description=self.description,
        )
        if self.enum:
            schema.update(enum=self.enum)
        return schema
49
+
50
+
51
class Tool:
    """
    Encapsulates a callable tool with validation and schema generation.

    Args:
        name: Tool name exposed in ``schema()``.
        description: Human-readable description exposed in ``schema()``.
        parameters: Parameter definitions used for both schema and validation.
        function: Callable invoked by ``execute`` with keyword arguments.
        injected_kwargs: Keyword arguments merged into every call; they
            override caller-supplied values of the same name.
        config_injector: Optional zero-argument callable whose returned dict
            is merged last (highest precedence) into every call.
    """

    def __init__(
        self,
        name: str,
        description: str,
        parameters: List[ToolParameter],
        function: Callable[..., str],
        *,
        injected_kwargs: Optional[Dict[str, Any]] = None,
        config_injector: Optional[Callable[[], Dict[str, Any]]] = None,
    ):
        self.name = name
        self.description = description
        self.parameters = parameters
        self.function = function
        self.injected_kwargs = injected_kwargs or {}
        self.config_injector = config_injector

    def schema(self) -> JsonSchema:
        """Return a JSON-schema style dict describing this tool."""
        properties = {param.name: param.to_schema() for param in self.parameters}
        required = [param.name for param in self.parameters if param.required]

        return {
            "name": self.name,
            "description": self.description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        }

    def _validate_single(self, param: ToolParameter, value: ParameterValue) -> Optional[str]:
        """Validate one supplied value against *param*.

        Returns:
            An error message, or ``None`` if the value is acceptable.
        """
        if value is None:
            return f"Parameter '{param.name}' is None"

        expected = param.param_type
        # ``bool`` subclasses ``int`` in Python, but a boolean is not a JSON-schema
        # "integer" or "number", so it is rejected for int/float parameters.
        is_bool = isinstance(value, bool)
        if expected is float:
            if is_bool or not isinstance(value, (float, int)):
                return f"Parameter '{param.name}' must be a number"
        elif (is_bool and expected is int) or not isinstance(value, expected):
            return (
                f"Parameter '{param.name}' must be of type {expected.__name__}, "
                f"got {type(value).__name__}"
            )

        # Enforce the same constraint that to_schema() advertises.
        if param.enum and value not in param.enum:
            return f"Parameter '{param.name}' must be one of {param.enum!r}, got {value!r}"
        return None

    def validate(self, params: Dict[str, ParameterValue]) -> Tuple[bool, Optional[str]]:
        """
        Validate a parameter dictionary against this tool's schema.

        Keys not declared in ``parameters`` are not checked.

        Returns (is_valid, error_message); the message describes the first
        problem found.
        """
        for param in self.parameters:
            if param.required and param.name not in params:
                return False, f"Missing required parameter: {param.name}"
            if param.name not in params:
                continue
            error = self._validate_single(param, params[param.name])
            if error:
                return False, error
        return True, None

    def execute(self, params: Dict[str, ParameterValue]) -> str:
        """Validate parameters, merge injected arguments, then call the function.

        Raises:
            ValueError: If ``validate`` rejects *params*.
        """
        is_valid, error = self.validate(params)
        if not is_valid:
            raise ValueError(f"Invalid parameters for tool '{self.name}': {error}")

        call_args: Dict[str, Any] = dict(params)
        call_args.update(self.injected_kwargs)
        if self.config_injector:
            call_args.update(self.config_injector() or {})

        return self.function(**call_args)
130
+
131
+
132
# Builtin type names recognised when an annotation is still an unresolved string.
_BUILTIN_TYPE_NAMES: Dict[str, type] = {
    "str": str,
    "int": int,
    "float": float,
    "bool": bool,
    "list": list,
    "dict": dict,
}


def _resolve_param_type(annotation: Any) -> type:
    """Reduce an annotation to a runtime type usable with ``isinstance`` (default ``str``).

    ``Optional[X]`` / ``X | None`` become ``X``. Parameterised generics such as
    ``List[int]`` become their origin (``list``). Anything else that is not a
    plain class falls back to ``str``.
    """
    import types as _types
    from typing import get_args, get_origin

    if annotation is inspect.Parameter.empty:
        return str
    if isinstance(annotation, str):
        return _BUILTIN_TYPE_NAMES.get(annotation.strip(), str)

    # Check generics before ``isinstance(.., type)``: on Python 3.9/3.10,
    # ``list[int]`` passes that check but cannot be used with isinstance().
    origin = get_origin(annotation)
    if origin is not None:
        if origin is Union or origin is getattr(_types, "UnionType", None):
            members = [arg for arg in get_args(annotation) if arg is not type(None)]
            return _resolve_param_type(members[0]) if len(members) == 1 else str
        return origin if isinstance(origin, type) else str
    return annotation if isinstance(annotation, type) else str


def _infer_parameters_from_callable(
    func: Callable[..., Any],
    param_metadata: Optional[Dict[str, ParamMetadata]] = None,
) -> List[ToolParameter]:
    """Create ToolParameter objects from a callable signature and annotations.

    Parameters whose names start with ``_`` are skipped, as are ``*args`` and
    ``**kwargs``. A parameter is required when it has no default.
    ``param_metadata[name]`` may supply ``description`` and ``enum``.
    """
    from typing import get_type_hints

    param_metadata = param_metadata or {}
    signature = inspect.signature(func)
    try:
        # Resolves string annotations, e.g. from ``from __future__ import annotations``.
        hints = get_type_hints(func)
    except Exception:  # noqa: BLE001 - unresolvable names / unsupported callables
        hints = {}

    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    parameters: List[ToolParameter] = []
    for name, param in signature.parameters.items():
        if name.startswith("_") or param.kind in variadic:
            continue
        meta = param_metadata.get(name, {})
        description = meta.get("description", "")
        parameters.append(
            ToolParameter(
                name=name,
                param_type=_resolve_param_type(hints.get(name, param.annotation)),
                description=description or f"Parameter '{name}'",
                required=param.default is inspect.Parameter.empty,
                enum=meta.get("enum"),
            )
        )
    return parameters
158
+
159
+
160
def tool(
    *,
    name: Optional[str] = None,
    description: Optional[str] = None,
    param_metadata: Optional[Dict[str, ParamMetadata]] = None,
    injected_kwargs: Optional[Dict[str, Any]] = None,
    config_injector: Optional[Callable[[], Dict[str, Any]]] = None,
):
    """
    Decorator to register a function as a Tool with schema inference.

    Args:
        name: Tool name; defaults to the function's ``__name__``.
        description: Tool description; defaults to the stripped docstring, or
            ``"Tool <name>"`` when there is none.
        param_metadata: Per-parameter ``description`` / ``enum`` overrides.
        injected_kwargs: Arguments supplied on every call. Parameters with
            these names are left out of the inferred schema, so they are never
            required from the caller.
        config_injector: Zero-argument callable whose dict is merged into
            every call.

    Returns:
        A decorator that converts the function into a ``Tool``.

    Example:
        @tool(name="search")
        def search(query: str, count: int = 3) -> str:
            ...
    """

    def wrapper(func: Callable[..., str]) -> Tool:
        injected_names = set(injected_kwargs or ())
        inferred = _infer_parameters_from_callable(func, param_metadata=param_metadata)
        # Tool.execute overwrites these names with the injected values, so
        # advertising (and requiring) them in the schema would be misleading.
        params = [param for param in inferred if param.name not in injected_names]
        return Tool(
            name=name or func.__name__,
            description=description or (func.__doc__ or "").strip() or f"Tool {func.__name__}",
            parameters=params,
            function=func,
            injected_kwargs=injected_kwargs,
            config_injector=config_injector,
        )

    return wrapper
190
+
191
+
192
class ToolRegistry:
    """Name-keyed collection of ``Tool`` objects.

    Registering a tool whose name is already present replaces the earlier entry.
    """

    def __init__(self):
        # tool name -> Tool
        self._tools: Dict[str, Tool] = dict()

    def register(self, tool_obj: Tool) -> Tool:
        """Store *tool_obj* under its ``name`` and return it unchanged."""
        key = tool_obj.name
        self._tools[key] = tool_obj
        return tool_obj

    def tool(
        self,
        *,
        name: Optional[str] = None,
        description: Optional[str] = None,
        param_metadata: Optional[Dict[str, ParamMetadata]] = None,
        injected_kwargs: Optional[Dict[str, Any]] = None,
        config_injector: Optional[Callable[[], Dict[str, Any]]] = None,
    ):
        """Like the module-level ``tool`` decorator, but also registers the result here."""

        def decorator(func: Callable[..., str]) -> Tool:
            # ``tool`` here is the module-level decorator, not this method.
            built = tool(
                name=name,
                description=description,
                param_metadata=param_metadata,
                injected_kwargs=injected_kwargs,
                config_injector=config_injector,
            )(func)
            return self.register(built)

        return decorator

    def get(self, name: str) -> Optional[Tool]:
        """Return the tool registered as *name*, or ``None`` if absent."""
        return self._tools.get(name)

    def all(self) -> List[Tool]:
        """Return a new list of registered tools, in order of first registration."""
        return [*self._tools.values()]
231
+
232
+
233
# Public API of this module.
__all__ = [
    "Tool",
    "ToolParameter",
    "ToolRegistry",
    "tool",
]
toolcalling/types.py ADDED
@@ -0,0 +1,76 @@
1
+ """
2
+ Core message and role types for the tool-calling library.
3
+
4
+ These primitives are provider-agnostic and are reused across adapters,
5
+ the agent loop, and tests.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass, field
11
+ from enum import Enum
12
+ from pathlib import Path
13
+ from typing import Any, Dict, Optional
14
+ import base64
15
+
16
+
17
class Role(str, Enum):
    """Author role of a conversation message.

    Members are also ``str`` instances whose value is the lowercase role name.
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    TOOL = "tool"
24
+
25
+
26
def _encode_image(image_path: str) -> str:
    """Read the file at *image_path* and return its bytes as base64 text.

    ``~`` is expanded and the path is made absolute before reading.

    Raises:
        FileNotFoundError: If nothing exists at the resolved path.
    """
    resolved = Path(image_path).expanduser().resolve()
    if resolved.exists():
        return base64.b64encode(resolved.read_bytes()).decode("utf-8")
    raise FileNotFoundError(f"Image file not found: {resolved}")
34
+
35
+
36
@dataclass
class Message:
    """
    One conversation turn with optional image payload and tool metadata.

    When ``image_path`` is truthy, ``__post_init__`` reads and base64-encodes
    the file into ``image_base64`` (raising ``FileNotFoundError`` if it is
    missing), so adapters can forward vision content without re-encoding.
    """

    role: Role
    content: str
    image_path: Optional[str] = None
    tool_name: Optional[str] = None
    tool_result: Optional[str] = None
    # Derived from image_path; not accepted by the constructor.
    image_base64: Optional[str] = field(init=False, default=None)

    def __post_init__(self) -> None:
        if not self.image_path:
            return
        self.image_base64 = _encode_image(self.image_path)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict (role as its string value) for logging or debugging."""
        result: Dict[str, Any] = {"role": self.role.value}
        for key in ("content", "image_base64", "tool_name", "tool_result"):
            result[key] = getattr(self, key)
        return result
+ }
66
+
67
+
68
@dataclass
class ToolCall:
    """A parsed tool call: the tool's name plus the arguments to pass to it."""

    tool_name: str  # name of the tool to invoke
    parameters: Dict[str, Any]  # argument name -> value
74
+
75
+
76
# Public API of this module.
__all__ = [
    "Role",
    "Message",
    "ToolCall",
]