toolstream 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
toolstream/_invoke.py ADDED
@@ -0,0 +1,126 @@
1
+ """Agent invocation helpers.
2
+
3
+ Build a SessionConfig from an AgentDefinition and yield a ready-to-use
4
+ session (async or sync).
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import contextlib
10
+ from collections.abc import AsyncIterator, Iterator
11
+ from typing import Callable
12
+
13
+ from ._agent import AgentDefinition, ToolRef, resolve_prompt
14
+ from ._schema import _generate_schema
15
+ from ._session import AsyncSession, SyncSession
16
+ from ._tools import Tool
17
+ from .config import SessionConfig
18
+
19
+ __all__ = ["invoke_agent", "invoke_agent_sync"]
20
+
21
+
22
def _filter_tools(
    definition: AgentDefinition,
    available_tools: dict[str, Callable] | None,
) -> list[Tool] | None:
    """Resolve the tool references of *definition* into ``Tool`` objects.

    Each reference is looked up by name in *available_tools*. A handler that
    carries a ``_tool`` attribute contributes that attribute unchanged; any
    other handler is wrapped in a new ``Tool`` whose schema is generated from
    the handler's signature and whose description is the first line of its
    docstring (empty string when there is no docstring).

    Returns:
        The resolved tools in declaration order, or ``None`` when the
        definition declares no tools or *available_tools* is ``None``.

    Raises:
        ValueError: If one or more declared tool names have no handler. All
            missing names are reported together.
    """
    declared = definition.tools
    if declared is None or available_tools is None:
        return None

    resolved: list[Tool] = []
    unresolved: list[str] = []

    for tool_ref in declared:
        tool_name = tool_ref.name
        candidate = available_tools.get(tool_name)

        if candidate is None:
            # Keep scanning so the error below lists every missing name.
            unresolved.append(tool_name)
        elif hasattr(candidate, "_tool"):
            resolved.append(candidate._tool)
        else:
            # Plain callable: derive the schema and description on the fly.
            schema = _generate_schema(candidate, inject=set())
            raw_doc = candidate.__doc__
            if raw_doc:
                summary = raw_doc.strip().split("\n")[0].strip()
            else:
                summary = ""
            resolved.append(
                Tool(
                    name=tool_name,
                    description=summary,
                    input_schema=schema,
                    handler=candidate,
                    inject=[],
                )
            )

    if unresolved:
        raise ValueError(f"Missing tool handlers: {', '.join(unresolved)}")

    return resolved
67
+
68
+
69
def _build_invocation_config(
    definition: AgentDefinition,
    config: SessionConfig,
    *,
    variables: dict[str, str] | None = None,
    available_tools: dict[str, Callable] | None = None,
) -> SessionConfig:
    """Derive a per-agent SessionConfig from a base *config*.

    The system prompt is rendered from the definition's prompt template with
    *variables* (an empty mapping when none are given), the tool list comes
    from matching the definition's tool references against *available_tools*,
    and the model is the definition's model when truthy, otherwise the base
    config's model. The remaining fields listed below are copied from
    *config* unchanged.
    """
    rendered_prompt = resolve_prompt(
        definition.prompt_template,
        variables if variables else {},
    )
    selected_tools = _filter_tools(definition, available_tools)

    settings = {
        "model": definition.model or config.model,
        "api_key": config.api_key,
        "base_url": config.base_url,
        "system_prompt": rendered_prompt,
        "cwd": config.cwd,
        "tools": selected_tools,
        "tool_context": config.tool_context,
        "tool_env": config.tool_env,
        "max_completion_tokens": config.max_completion_tokens,
        "sandbox": config.sandbox,
        "metadata": config.metadata,
    }
    return SessionConfig(**settings)
93
+
94
+
95
@contextlib.asynccontextmanager
async def invoke_agent(
    definition: AgentDefinition,
    config: SessionConfig,
    *,
    variables: dict[str, str] | None = None,
    available_tools: dict[str, Callable] | None = None,
) -> AsyncIterator[AsyncSession]:
    """Open an AsyncSession configured for *definition*.

    Builds the invocation config from *definition* and the base *config*,
    enters a new ``AsyncSession`` created from it, and yields that session.
    The session's async context is exited when the caller's ``async with``
    block ends.
    """
    session_config = _build_invocation_config(
        definition,
        config,
        variables=variables,
        available_tools=available_tools,
    )
    agent_session = AsyncSession(session_config)
    async with agent_session:
        yield agent_session
110
+
111
+
112
@contextlib.contextmanager
def invoke_agent_sync(
    definition: AgentDefinition,
    config: SessionConfig,
    *,
    variables: dict[str, str] | None = None,
    available_tools: dict[str, Callable] | None = None,
) -> Iterator[SyncSession]:
    """Open a SyncSession configured for *definition*.

    Builds the invocation config from *definition* and the base *config*,
    enters a new ``SyncSession`` created from it, and yields that session.
    The session's context is exited when the caller's ``with`` block ends.
    """
    session_config = _build_invocation_config(
        definition,
        config,
        variables=variables,
        available_tools=available_tools,
    )
    agent_session = SyncSession(session_config)
    with agent_session:
        yield agent_session
@@ -0,0 +1,87 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+
5
+ from .events import Error, StepFinish, StepStart, Text, ToolUse
6
+
7
+ Event = StepStart | Text | ToolUse | StepFinish | Error
8
+
9
+
10
def _as_dict(value: object) -> dict:
    """Return *value* if it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def parse_event(line: str) -> Event | None:
    """Parse a single NDJSON line into a typed event.

    The line must decode to a JSON object whose ``"type"`` is one of
    ``step_start``, ``text``, ``tool_use``, ``step_finish`` or ``error``.
    Missing fields fall back to empty/zero defaults. Nested containers
    (``part``, ``state``, ``tokens``, ``cache``, ``error``) that are present
    but not JSON objects (for example ``null``) are treated the same as
    missing ones instead of raising ``AttributeError``.

    Args:
        line: One line of NDJSON text; surrounding whitespace is ignored.

    Returns:
        The matching event instance, or ``None`` when the line is blank, is
        not valid JSON, is not a JSON object, or has an unrecognized type.
    """
    line = line.strip()
    if not line:
        return None

    try:
        data = json.loads(line)
    except json.JSONDecodeError:
        return None

    if not isinstance(data, dict):
        return None

    event_type = data.get("type")
    session_id = data.get("sessionID", "")
    timestamp = data.get("timestamp", 0)
    # A non-object "part" would break every ``part.get`` below.
    part = _as_dict(data.get("part"))

    if event_type == "step_start":
        return StepStart(
            session_id=session_id,
            message_id=part.get("messageID", ""),
            timestamp=timestamp,
        )

    if event_type == "text":
        return Text(
            session_id=session_id,
            message_id=part.get("messageID", ""),
            text=part.get("text", ""),
            timestamp=timestamp,
        )

    if event_type == "tool_use":
        state = _as_dict(part.get("state"))
        return ToolUse(
            session_id=session_id,
            message_id=part.get("messageID", ""),
            tool=part.get("tool", ""),
            call_id=part.get("callID", ""),
            status=state.get("status", ""),
            input=state.get("input", {}),
            output=state.get("output", ""),
            title=part.get("title", ""),
            timestamp=timestamp,
        )

    if event_type == "step_finish":
        tokens = _as_dict(part.get("tokens"))
        cache = _as_dict(tokens.get("cache"))
        return StepFinish(
            session_id=session_id,
            message_id=part.get("messageID", ""),
            reason=part.get("reason", ""),
            input_tokens=tokens.get("input", 0),
            output_tokens=tokens.get("output", 0),
            reasoning_tokens=tokens.get("reasoning", 0),
            cache_read_tokens=cache.get("read", 0),
            cache_write_tokens=cache.get("write", 0),
            cost=part.get("cost", 0.0),
            timestamp=timestamp,
        )

    if event_type == "error":
        # Unlike the other events, the payload lives at the top level.
        error = _as_dict(data.get("error"))
        return Error(
            session_id=session_id,
            name=error.get("name", ""),
            message=error.get("message", ""),
            data=error.get("data", {}),
            timestamp=timestamp,
        )

    return None
toolstream/_schema.py ADDED
@@ -0,0 +1,259 @@
1
+ """Generate JSON Schema from Python function type hints.
2
+
3
+ Foundation for the tool registration system. Converts function signatures
4
+ (parameters, type annotations, docstrings) into JSON Schema objects
5
+ suitable for LLM tool definitions.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import dataclasses
11
+ import enum
12
+ import inspect
13
+ import re
14
+ import typing
15
+
16
+
17
+ def _type_to_schema(annotation: type) -> dict:
18
+ """Convert a Python type hint to a JSON Schema fragment."""
19
+
20
+ # Handle None / NoneType directly
21
+ if annotation is type(None):
22
+ return {"type": "null"}
23
+
24
+ # Handle missing annotation
25
+ if annotation is inspect.Parameter.empty:
26
+ return {"type": "object"}
27
+
28
+ origin = typing.get_origin(annotation)
29
+ args = typing.get_args(annotation)
30
+
31
+ # Optional[T] is Union[T, None]; T | None has origin types.UnionType
32
+ if origin is typing.Union or _is_union_type(annotation):
33
+ non_none = [a for a in args if a is not type(None)]
34
+ if len(non_none) == 1 and len(args) == 2:
35
+ # Optional[T] -- produce nullable schema
36
+ inner = _type_to_schema(non_none[0])
37
+ if "type" in inner:
38
+ t = inner["type"]
39
+ if isinstance(t, list):
40
+ if "null" not in t:
41
+ inner["type"] = t + ["null"]
42
+ else:
43
+ inner["type"] = [t, "null"]
44
+ elif "enum" in inner:
45
+ inner["enum"] = inner["enum"] + [None]
46
+ else:
47
+ inner["type"] = ["object", "null"]
48
+ return inner
49
+ # General Union -- not handled beyond Optional, fall through
50
+ return {"type": "object"}
51
+
52
+ # Literal["a", "b"]
53
+ if origin is typing.Literal:
54
+ values = list(args)
55
+ # Infer JSON Schema type from the literal values
56
+ value_types = {type(v) for v in values}
57
+ if value_types == {int}:
58
+ schema_type = "integer"
59
+ elif value_types == {float}:
60
+ schema_type = "number"
61
+ elif value_types == {bool}:
62
+ schema_type = "boolean"
63
+ else:
64
+ schema_type = "string"
65
+ return {"type": schema_type, "enum": values}
66
+
67
+ # Enum subclass
68
+ if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
69
+ return {"enum": [member.value for member in annotation]}
70
+
71
+ # list[T]
72
+ if origin is list:
73
+ if args:
74
+ return {"type": "array", "items": _type_to_schema(args[0])}
75
+ return {"type": "array"}
76
+
77
+ # dict[str, T]
78
+ if origin is dict:
79
+ if args and len(args) == 2:
80
+ return {
81
+ "type": "object",
82
+ "additionalProperties": _type_to_schema(args[1]),
83
+ }
84
+ return {"type": "object"}
85
+
86
+ # dataclass
87
+ if dataclasses.is_dataclass(annotation) and isinstance(annotation, type):
88
+ properties = {}
89
+ required = []
90
+ # Resolve stringified annotations for dataclass fields
91
+ try:
92
+ resolved = typing.get_type_hints(annotation)
93
+ except Exception:
94
+ resolved = {}
95
+ for field in dataclasses.fields(annotation):
96
+ field_type = resolved.get(field.name, field.type)
97
+ prop = _type_to_schema(field_type)
98
+ properties[field.name] = prop
99
+ # Fields with default or default_factory are optional
100
+ has_default = (
101
+ field.default is not dataclasses.MISSING
102
+ or field.default_factory is not dataclasses.MISSING
103
+ )
104
+ if not has_default:
105
+ required.append(field.name)
106
+ schema: dict = {"type": "object", "properties": properties}
107
+ if required:
108
+ schema["required"] = required
109
+ return schema
110
+
111
+ # TypedDict
112
+ if _is_typed_dict(annotation):
113
+ properties = {}
114
+ hints = typing.get_type_hints(annotation)
115
+ req_keys = getattr(annotation, "__required_keys__", frozenset())
116
+ for name, hint in hints.items():
117
+ properties[name] = _type_to_schema(hint)
118
+ schema = {"type": "object", "properties": properties}
119
+ req = [k for k in hints if k in req_keys]
120
+ if req:
121
+ schema["required"] = req
122
+ return schema
123
+
124
+ # Primitive types
125
+ primitives = {
126
+ str: {"type": "string"},
127
+ int: {"type": "integer"},
128
+ float: {"type": "number"},
129
+ bool: {"type": "boolean"},
130
+ dict: {"type": "object"},
131
+ list: {"type": "array"},
132
+ }
133
+ if annotation in primitives:
134
+ return dict(primitives[annotation])
135
+
136
+ return {"type": "object"}
137
+
138
+
139
+ def _is_union_type(annotation: type) -> bool:
140
+ """Check if annotation is a PEP 604 union (X | Y)."""
141
+ import types
142
+
143
+ return isinstance(annotation, types.UnionType)
144
+
145
+
146
+ def _is_typed_dict(annotation: type) -> bool:
147
+ """Check if annotation is a TypedDict subclass."""
148
+ return (
149
+ isinstance(annotation, type)
150
+ and issubclass(annotation, dict)
151
+ and hasattr(annotation, "__annotations__")
152
+ and hasattr(annotation, "__required_keys__")
153
+ )
154
+
155
+
156
# Regex for Google-style docstring param lines:
#   Args:
#     name: description text
#     name (type): description text
_GOOGLE_ARGS_RE = re.compile(
    r"^\s{2,}(\w+)(?:\s*\([^)]*\))?\s*:\s*(.+)", re.MULTILINE
)

# Regex for reST-style docstring param lines:
#   :param name: description text
#   :param type name: description text
# Optional type tokens before the name are skipped; they may not contain
# ":" so the match cannot run past the terminating colon.
_REST_PARAM_RE = re.compile(
    r"^\s*:param\s+(?:[^\s:]+[ \t]+)*?(\w+)\s*:\s*(.+)", re.MULTILINE
)


def _parse_param_descriptions(fn: typing.Callable) -> dict[str, str]:
    """Extract parameter descriptions from a function's docstring.

    Handles Google-style (``Args:`` section) and reST-style (``:param:``)
    formats. Only the first line of each description is captured. Follows
    ``__wrapped__`` for functools.wraps-decorated functions.

    Args:
        fn: The callable whose docstring is inspected.

    Returns:
        A mapping of parameter name to description; empty when there is no
        docstring or no parameter documentation is found.
    """
    # Follow the __wrapped__ chain to find the original docstring. A cyclic
    # chain makes inspect.unwrap raise ValueError; use *fn* itself then.
    try:
        target = inspect.unwrap(fn)
    except ValueError:
        target = fn

    doc = inspect.getdoc(target)
    if not doc:
        return {}

    descriptions: dict[str, str] = {}

    # Try Google-style: find the Args: section
    args_match = re.search(r"^\s*Args?\s*:\s*$", doc, re.MULTILINE)
    if args_match:
        after_args = doc[args_match.end():]
        # The section ends at the first line starting in column 0 (the next
        # section header), or at the end of the docstring.
        section_end = re.search(r"^\S", after_args, re.MULTILINE)
        args_block = after_args[: section_end.start()] if section_end else after_args
        for m in _GOOGLE_ARGS_RE.finditer(args_block):
            descriptions[m.group(1)] = m.group(2).strip()

    # Try reST-style; Google-style entries win when both are present.
    for m in _REST_PARAM_RE.finditer(doc):
        if m.group(1) not in descriptions:
            descriptions[m.group(1)] = m.group(2).strip()

    return descriptions
207
+
208
+
209
def _generate_schema(
    fn: typing.Callable,
    inject: set[str] | None = None,
) -> dict:
    """Generate a JSON Schema object from a function's signature.

    Every parameter becomes a property whose schema is derived from its type
    hint, with a ``description`` taken from the docstring when available.
    Parameters named ``self`` or ``cls``, names listed in *inject*, and
    variadic ``*args`` / ``**kwargs`` parameters are left out; variadics
    cannot be expressed as a single named property and must never be
    reported as required.

    Args:
        fn: The function to generate a schema for.
        inject: Parameter names to exclude (they will be injected at call time).

    Returns:
        A JSON Schema dict with "type" and "properties" keys, plus a
        "required" key listing the parameters without a default when there
        is at least one.
    """
    inject = inject or set()
    sig = inspect.signature(fn)
    descriptions = _parse_param_descriptions(fn)

    # Resolve stringified annotations (from `from __future__ import annotations`).
    # Best effort: on failure the raw annotations from the signature are used.
    try:
        resolved_hints = typing.get_type_hints(fn)
    except Exception:
        resolved_hints = {}

    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )

    properties: dict[str, dict] = {}
    required: list[str] = []

    for name, param in sig.parameters.items():
        # Skip self/cls and injected parameters
        if name in ("self", "cls") or name in inject:
            continue
        # Skip *args / **kwargs
        if param.kind in variadic_kinds:
            continue

        annotation = resolved_hints.get(name, param.annotation)
        prop = _type_to_schema(annotation)

        # Add description if available
        if name in descriptions:
            prop["description"] = descriptions[name]

        properties[name] = prop

        # Parameters with no default are required
        if param.default is inspect.Parameter.empty:
            required.append(name)

    schema: dict = {"type": "object", "properties": properties}
    if required:
        schema["required"] = required
    return schema
toolstream/_session.py ADDED
@@ -0,0 +1,176 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import queue
5
+ import threading
6
+ from collections.abc import AsyncIterator, Iterator
7
+ from typing import Any
8
+
9
+ from ._direct import DirectClient
10
+ from .config import SessionConfig
11
+ from .events import Result, StepFinish
12
+
13
+
14
class AsyncSession:
    """Async session wrapping the direct LLM API client.

    Usable as an async context manager; leaving the context closes the
    underlying client. Running totals for cost and token usage are kept
    across turns.
    """

    def __init__(self, config: SessionConfig):
        self._config = config
        self._direct = DirectClient(
            config,
            tools=config.tools,
            tool_context=config.tool_context,
            max_completion_tokens=config.max_completion_tokens,
        )
        self._turn_count = 0
        self._total_cost = 0.0
        self._total_input_tokens = 0
        self._total_output_tokens = 0

    async def __aenter__(self) -> AsyncSession:
        return self

    async def __aexit__(self, *args: Any) -> None:
        await self.close()

    async def send(self, message: str) -> AsyncIterator[Any]:
        """Send a message and yield events. Yields a Result summary at the end."""
        async for event in self._send_direct(message):
            yield event

    async def _send_direct(self, message: str) -> AsyncIterator[Any]:
        """Send via the direct API client, tracking usage for this turn.

        Every event from the client is passed through unchanged. Usage
        reported by ``StepFinish`` events is added to the session totals as
        soon as it is seen, so the totals stay accurate even if the stream
        raises or the caller stops iterating before the end. The final
        ``Result`` carries the figures for this turn only.
        """
        self._turn_count += 1

        turn_input_tokens = 0
        turn_output_tokens = 0
        turn_cost = 0.0
        steps = 0

        async for event in self._direct.send(message):
            if isinstance(event, StepFinish):
                turn_input_tokens += event.input_tokens
                turn_output_tokens += event.output_tokens
                turn_cost += event.cost
                steps += 1
                # Update session totals incrementally (see docstring).
                self._total_input_tokens += event.input_tokens
                self._total_output_tokens += event.output_tokens
                self._total_cost += event.cost
            yield event

        # Yield result summary for this turn
        yield Result(
            session_id=self._direct.session_id,
            total_input_tokens=turn_input_tokens,
            total_output_tokens=turn_output_tokens,
            total_cost=turn_cost,
            steps=steps,
        )

    async def close(self) -> None:
        """Close the session and clean up."""
        await self._direct.close()

    @property
    def session_id(self) -> str | None:
        """Session identifier as reported by the underlying client."""
        return self._direct.session_id

    @property
    def total_cost(self) -> float:
        """Accumulated cost of all steps seen so far in this session."""
        return self._total_cost

    @property
    def turn_count(self) -> int:
        """Number of messages sent through this session."""
        return self._turn_count
93
+
94
+
95
class SyncSession:
    """Sync wrapper around AsyncSession using a dedicated event loop thread.

    Entering the context starts a daemon thread running a private event
    loop and creates the wrapped ``AsyncSession``; leaving it closes the
    session and shuts the loop down.
    """

    def __init__(self, config: SessionConfig):
        self._config = config
        self._async_session: AsyncSession | None = None
        self._loop: asyncio.AbstractEventLoop | None = None
        self._thread: threading.Thread | None = None

    def _start_loop(self) -> None:
        """Run the event loop in a background thread."""
        assert self._loop is not None
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

    def __enter__(self) -> SyncSession:
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._start_loop, daemon=True)
        self._thread.start()
        try:
            self._async_session = AsyncSession(self._config)
        except BaseException:
            # __exit__ is not called when __enter__ fails, so stop the
            # already-running loop thread here before propagating.
            self.close()
            raise
        return self

    def __exit__(self, *args: Any) -> None:
        self.close()

    def _run_coroutine(self, coro: Any) -> Any:
        """Run a coroutine on the background event loop and return the result."""
        assert self._loop is not None
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        return future.result()

    def send(self, message: str) -> Iterator[Any]:
        """Send a message and yield events incrementally via a queue bridge.

        A producer coroutine on the background loop forwards each event of
        the async stream into a thread-safe queue that this generator
        drains. An exception raised by the async stream is re-raised here.
        If the generator is closed before the stream ends, the producer is
        cancelled instead of being left running.
        """
        assert self._async_session is not None
        assert self._loop is not None

        session = self._async_session
        loop = self._loop
        q: queue.Queue = queue.Queue()
        sentinel = object()

        async def _produce() -> None:
            try:
                async for event in session.send(message):
                    # Tag items so an event that happens to be an Exception
                    # instance is still delivered as an event.
                    q.put((True, event))
            except Exception as e:
                q.put((False, e))
            finally:
                q.put(sentinel)

        future = asyncio.run_coroutine_threadsafe(_produce(), loop)

        drained = False
        try:
            while True:
                item = q.get()
                if item is sentinel:
                    drained = True
                    break
                is_event, payload = item
                if not is_event:
                    raise payload
                yield payload
        finally:
            # Reached without the sentinel when the consumer stopped early
            # or an error was raised: stop the producer.
            if not drained and not future.done() and not loop.is_closed():
                future.cancel()

    def close(self) -> None:
        """Close the session and shut down the event loop.

        Safe to call more than once. The loop and its thread are shut down
        even when closing the async session raises; that exception is then
        propagated after the cleanup.
        """
        try:
            if self._async_session is not None:
                session, self._async_session = self._async_session, None
                self._run_coroutine(session.close())
        finally:
            loop, thread = self._loop, self._thread
            self._loop = None
            self._thread = None
            if loop is not None:
                loop.call_soon_threadsafe(loop.stop)
                if thread is not None:
                    thread.join(timeout=5.0)
                # Closing a loop that is still running raises RuntimeError;
                # skip it if the thread did not stop within the timeout.
                if not loop.is_running():
                    loop.close()

    @property
    def session_id(self) -> str | None:
        """Session id of the wrapped session, or None when not open."""
        return self._async_session.session_id if self._async_session else None

    @property
    def total_cost(self) -> float:
        """Total cost of the wrapped session, or 0.0 when not open."""
        return self._async_session.total_cost if self._async_session else 0.0

    @property
    def turn_count(self) -> int:
        """Turn count of the wrapped session, or 0 when not open."""
        return self._async_session.turn_count if self._async_session else 0