toolstream 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- toolstream/__init__.py +44 -0
- toolstream/_agent.py +215 -0
- toolstream/_builtin_tools.py +115 -0
- toolstream/_context.py +14 -0
- toolstream/_direct.py +292 -0
- toolstream/_invoke.py +126 -0
- toolstream/_protocol.py +87 -0
- toolstream/_schema.py +259 -0
- toolstream/_session.py +176 -0
- toolstream/_tools.py +109 -0
- toolstream/config.py +26 -0
- toolstream/events.py +63 -0
- toolstream/py.typed +0 -0
- toolstream-0.1.0.dist-info/METADATA +7 -0
- toolstream-0.1.0.dist-info/RECORD +16 -0
- toolstream-0.1.0.dist-info/WHEEL +4 -0
toolstream/_invoke.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""Agent invocation helpers.
|
|
2
|
+
|
|
3
|
+
Build a SessionConfig from an AgentDefinition and yield a ready-to-use
|
|
4
|
+
session (async or sync).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import contextlib
|
|
10
|
+
from collections.abc import AsyncIterator, Iterator
|
|
11
|
+
from typing import Callable
|
|
12
|
+
|
|
13
|
+
from ._agent import AgentDefinition, ToolRef, resolve_prompt
|
|
14
|
+
from ._schema import _generate_schema
|
|
15
|
+
from ._session import AsyncSession, SyncSession
|
|
16
|
+
from ._tools import Tool
|
|
17
|
+
from .config import SessionConfig
|
|
18
|
+
|
|
19
|
+
__all__ = ["invoke_agent", "invoke_agent_sync"]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _filter_tools(
|
|
23
|
+
definition: AgentDefinition,
|
|
24
|
+
available_tools: dict[str, Callable] | None,
|
|
25
|
+
) -> list[Tool] | None:
|
|
26
|
+
"""Match ToolRefs in *definition* against *available_tools* handlers.
|
|
27
|
+
|
|
28
|
+
Returns None when the definition declares no tools or no tools dict
|
|
29
|
+
is provided. Raises ValueError if any declared tool name is missing
|
|
30
|
+
from *available_tools*.
|
|
31
|
+
"""
|
|
32
|
+
if definition.tools is None or available_tools is None:
|
|
33
|
+
return None
|
|
34
|
+
|
|
35
|
+
missing: list[str] = []
|
|
36
|
+
matched: list[Tool] = []
|
|
37
|
+
|
|
38
|
+
for ref in definition.tools:
|
|
39
|
+
handler = available_tools.get(ref.name)
|
|
40
|
+
if handler is None:
|
|
41
|
+
missing.append(ref.name)
|
|
42
|
+
continue
|
|
43
|
+
|
|
44
|
+
if hasattr(handler, "_tool"):
|
|
45
|
+
matched.append(handler._tool)
|
|
46
|
+
else:
|
|
47
|
+
# Generate schema on the fly for plain callables.
|
|
48
|
+
input_schema = _generate_schema(handler, inject=set())
|
|
49
|
+
doc = handler.__doc__
|
|
50
|
+
description = doc.strip().split("\n")[0].strip() if doc else ""
|
|
51
|
+
matched.append(
|
|
52
|
+
Tool(
|
|
53
|
+
name=ref.name,
|
|
54
|
+
description=description,
|
|
55
|
+
input_schema=input_schema,
|
|
56
|
+
handler=handler,
|
|
57
|
+
inject=[],
|
|
58
|
+
)
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
if missing:
|
|
62
|
+
raise ValueError(
|
|
63
|
+
f"Missing tool handlers: {', '.join(missing)}"
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
return matched
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _build_invocation_config(
    definition: AgentDefinition,
    config: SessionConfig,
    *,
    variables: dict[str, str] | None = None,
    available_tools: dict[str, Callable] | None = None,
) -> SessionConfig:
    """Derive the SessionConfig used to run *definition*.

    Agent-specific values come from *definition*:

    - ``system_prompt`` is ``definition.prompt_template`` rendered by
      ``resolve_prompt`` with *variables* (an empty dict when None).
    - ``tools`` is whatever ``_filter_tools`` resolves from *available_tools*.
    - ``model`` is ``definition.model`` when truthy, else ``config.model``.

    The other keyword arguments passed below are copied from *config*.

    Raises:
        ValueError: Propagated from ``_filter_tools`` when a declared tool has
            no handler.
    """
    rendered_prompt = resolve_prompt(definition.prompt_template, variables or {})
    agent_tools = _filter_tools(definition, available_tools)
    chosen_model = definition.model or config.model

    return SessionConfig(
        model=chosen_model,
        api_key=config.api_key,
        base_url=config.base_url,
        system_prompt=rendered_prompt,
        cwd=config.cwd,
        tools=agent_tools,
        tool_context=config.tool_context,
        tool_env=config.tool_env,
        max_completion_tokens=config.max_completion_tokens,
        sandbox=config.sandbox,
        metadata=config.metadata,
    )
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@contextlib.asynccontextmanager
async def invoke_agent(
    definition: AgentDefinition,
    config: SessionConfig,
    *,
    variables: dict[str, str] | None = None,
    available_tools: dict[str, Callable] | None = None,
) -> AsyncIterator[AsyncSession]:
    """Open an AsyncSession configured for *definition*.

    The session config is built before the session is created, so a
    ``ValueError`` for a missing tool handler surfaces when the context is
    entered. The session is used as an async context manager and is therefore
    closed when the ``async with`` block exits.

    Example::

        async with invoke_agent(defn, cfg, variables={"name": "x"}) as session:
            async for event in session.send("hi"):
                ...
    """
    agent_config = _build_invocation_config(
        definition,
        config,
        variables=variables,
        available_tools=available_tools,
    )
    agent_session = AsyncSession(agent_config)
    async with agent_session:
        yield agent_session
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@contextlib.contextmanager
def invoke_agent_sync(
    definition: AgentDefinition,
    config: SessionConfig,
    *,
    variables: dict[str, str] | None = None,
    available_tools: dict[str, Callable] | None = None,
) -> Iterator[SyncSession]:
    """Open a SyncSession configured for *definition*.

    This is the blocking counterpart of ``invoke_agent``. The config is built
    first, so a missing tool handler raises ``ValueError`` on entry. The
    session is entered with ``with`` and closed when the block exits.

    Example::

        with invoke_agent_sync(defn, cfg) as session:
            for event in session.send("hi"):
                ...
    """
    agent_config = _build_invocation_config(
        definition,
        config,
        variables=variables,
        available_tools=available_tools,
    )
    agent_session = SyncSession(agent_config)
    with agent_session:
        yield agent_session
|
toolstream/_protocol.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
from .events import Error, StepFinish, StepStart, Text, ToolUse
|
|
6
|
+
|
|
7
|
+
# Union of every typed event parse_event() can return (it may also return None).
Event = (
    StepStart
    | Text
    | ToolUse
    | StepFinish
    | Error
)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def parse_event(line: str) -> Event | None:
|
|
11
|
+
"""Parse a single NDJSON line into a typed event.
|
|
12
|
+
|
|
13
|
+
Returns None for unparseable or unrecognized lines.
|
|
14
|
+
"""
|
|
15
|
+
line = line.strip()
|
|
16
|
+
if not line:
|
|
17
|
+
return None
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
data = json.loads(line)
|
|
21
|
+
except json.JSONDecodeError:
|
|
22
|
+
return None
|
|
23
|
+
|
|
24
|
+
if not isinstance(data, dict):
|
|
25
|
+
return None
|
|
26
|
+
|
|
27
|
+
event_type = data.get("type")
|
|
28
|
+
session_id = data.get("sessionID", "")
|
|
29
|
+
timestamp = data.get("timestamp", 0)
|
|
30
|
+
part = data.get("part", {})
|
|
31
|
+
|
|
32
|
+
if event_type == "step_start":
|
|
33
|
+
return StepStart(
|
|
34
|
+
session_id=session_id,
|
|
35
|
+
message_id=part.get("messageID", ""),
|
|
36
|
+
timestamp=timestamp,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
if event_type == "text":
|
|
40
|
+
return Text(
|
|
41
|
+
session_id=session_id,
|
|
42
|
+
message_id=part.get("messageID", ""),
|
|
43
|
+
text=part.get("text", ""),
|
|
44
|
+
timestamp=timestamp,
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
if event_type == "tool_use":
|
|
48
|
+
state = part.get("state", {})
|
|
49
|
+
return ToolUse(
|
|
50
|
+
session_id=session_id,
|
|
51
|
+
message_id=part.get("messageID", ""),
|
|
52
|
+
tool=part.get("tool", ""),
|
|
53
|
+
call_id=part.get("callID", ""),
|
|
54
|
+
status=state.get("status", ""),
|
|
55
|
+
input=state.get("input", {}),
|
|
56
|
+
output=state.get("output", ""),
|
|
57
|
+
title=part.get("title", ""),
|
|
58
|
+
timestamp=timestamp,
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
if event_type == "step_finish":
|
|
62
|
+
tokens = part.get("tokens", {})
|
|
63
|
+
cache = tokens.get("cache", {})
|
|
64
|
+
return StepFinish(
|
|
65
|
+
session_id=session_id,
|
|
66
|
+
message_id=part.get("messageID", ""),
|
|
67
|
+
reason=part.get("reason", ""),
|
|
68
|
+
input_tokens=tokens.get("input", 0),
|
|
69
|
+
output_tokens=tokens.get("output", 0),
|
|
70
|
+
reasoning_tokens=tokens.get("reasoning", 0),
|
|
71
|
+
cache_read_tokens=cache.get("read", 0),
|
|
72
|
+
cache_write_tokens=cache.get("write", 0),
|
|
73
|
+
cost=part.get("cost", 0.0),
|
|
74
|
+
timestamp=timestamp,
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
if event_type == "error":
|
|
78
|
+
error = data.get("error", {})
|
|
79
|
+
return Error(
|
|
80
|
+
session_id=session_id,
|
|
81
|
+
name=error.get("name", ""),
|
|
82
|
+
message=error.get("message", ""),
|
|
83
|
+
data=error.get("data", {}),
|
|
84
|
+
timestamp=timestamp,
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
return None
|
toolstream/_schema.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
"""Generate JSON Schema from Python function type hints.
|
|
2
|
+
|
|
3
|
+
Foundation for the tool registration system. Converts function signatures
|
|
4
|
+
(parameters, type annotations, docstrings) into JSON Schema objects
|
|
5
|
+
suitable for LLM tool definitions.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import dataclasses
|
|
11
|
+
import enum
|
|
12
|
+
import inspect
|
|
13
|
+
import re
|
|
14
|
+
import typing
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _type_to_schema(annotation: type) -> dict:
|
|
18
|
+
"""Convert a Python type hint to a JSON Schema fragment."""
|
|
19
|
+
|
|
20
|
+
# Handle None / NoneType directly
|
|
21
|
+
if annotation is type(None):
|
|
22
|
+
return {"type": "null"}
|
|
23
|
+
|
|
24
|
+
# Handle missing annotation
|
|
25
|
+
if annotation is inspect.Parameter.empty:
|
|
26
|
+
return {"type": "object"}
|
|
27
|
+
|
|
28
|
+
origin = typing.get_origin(annotation)
|
|
29
|
+
args = typing.get_args(annotation)
|
|
30
|
+
|
|
31
|
+
# Optional[T] is Union[T, None]; T | None has origin types.UnionType
|
|
32
|
+
if origin is typing.Union or _is_union_type(annotation):
|
|
33
|
+
non_none = [a for a in args if a is not type(None)]
|
|
34
|
+
if len(non_none) == 1 and len(args) == 2:
|
|
35
|
+
# Optional[T] -- produce nullable schema
|
|
36
|
+
inner = _type_to_schema(non_none[0])
|
|
37
|
+
if "type" in inner:
|
|
38
|
+
t = inner["type"]
|
|
39
|
+
if isinstance(t, list):
|
|
40
|
+
if "null" not in t:
|
|
41
|
+
inner["type"] = t + ["null"]
|
|
42
|
+
else:
|
|
43
|
+
inner["type"] = [t, "null"]
|
|
44
|
+
elif "enum" in inner:
|
|
45
|
+
inner["enum"] = inner["enum"] + [None]
|
|
46
|
+
else:
|
|
47
|
+
inner["type"] = ["object", "null"]
|
|
48
|
+
return inner
|
|
49
|
+
# General Union -- not handled beyond Optional, fall through
|
|
50
|
+
return {"type": "object"}
|
|
51
|
+
|
|
52
|
+
# Literal["a", "b"]
|
|
53
|
+
if origin is typing.Literal:
|
|
54
|
+
values = list(args)
|
|
55
|
+
# Infer JSON Schema type from the literal values
|
|
56
|
+
value_types = {type(v) for v in values}
|
|
57
|
+
if value_types == {int}:
|
|
58
|
+
schema_type = "integer"
|
|
59
|
+
elif value_types == {float}:
|
|
60
|
+
schema_type = "number"
|
|
61
|
+
elif value_types == {bool}:
|
|
62
|
+
schema_type = "boolean"
|
|
63
|
+
else:
|
|
64
|
+
schema_type = "string"
|
|
65
|
+
return {"type": schema_type, "enum": values}
|
|
66
|
+
|
|
67
|
+
# Enum subclass
|
|
68
|
+
if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
|
|
69
|
+
return {"enum": [member.value for member in annotation]}
|
|
70
|
+
|
|
71
|
+
# list[T]
|
|
72
|
+
if origin is list:
|
|
73
|
+
if args:
|
|
74
|
+
return {"type": "array", "items": _type_to_schema(args[0])}
|
|
75
|
+
return {"type": "array"}
|
|
76
|
+
|
|
77
|
+
# dict[str, T]
|
|
78
|
+
if origin is dict:
|
|
79
|
+
if args and len(args) == 2:
|
|
80
|
+
return {
|
|
81
|
+
"type": "object",
|
|
82
|
+
"additionalProperties": _type_to_schema(args[1]),
|
|
83
|
+
}
|
|
84
|
+
return {"type": "object"}
|
|
85
|
+
|
|
86
|
+
# dataclass
|
|
87
|
+
if dataclasses.is_dataclass(annotation) and isinstance(annotation, type):
|
|
88
|
+
properties = {}
|
|
89
|
+
required = []
|
|
90
|
+
# Resolve stringified annotations for dataclass fields
|
|
91
|
+
try:
|
|
92
|
+
resolved = typing.get_type_hints(annotation)
|
|
93
|
+
except Exception:
|
|
94
|
+
resolved = {}
|
|
95
|
+
for field in dataclasses.fields(annotation):
|
|
96
|
+
field_type = resolved.get(field.name, field.type)
|
|
97
|
+
prop = _type_to_schema(field_type)
|
|
98
|
+
properties[field.name] = prop
|
|
99
|
+
# Fields with default or default_factory are optional
|
|
100
|
+
has_default = (
|
|
101
|
+
field.default is not dataclasses.MISSING
|
|
102
|
+
or field.default_factory is not dataclasses.MISSING
|
|
103
|
+
)
|
|
104
|
+
if not has_default:
|
|
105
|
+
required.append(field.name)
|
|
106
|
+
schema: dict = {"type": "object", "properties": properties}
|
|
107
|
+
if required:
|
|
108
|
+
schema["required"] = required
|
|
109
|
+
return schema
|
|
110
|
+
|
|
111
|
+
# TypedDict
|
|
112
|
+
if _is_typed_dict(annotation):
|
|
113
|
+
properties = {}
|
|
114
|
+
hints = typing.get_type_hints(annotation)
|
|
115
|
+
req_keys = getattr(annotation, "__required_keys__", frozenset())
|
|
116
|
+
for name, hint in hints.items():
|
|
117
|
+
properties[name] = _type_to_schema(hint)
|
|
118
|
+
schema = {"type": "object", "properties": properties}
|
|
119
|
+
req = [k for k in hints if k in req_keys]
|
|
120
|
+
if req:
|
|
121
|
+
schema["required"] = req
|
|
122
|
+
return schema
|
|
123
|
+
|
|
124
|
+
# Primitive types
|
|
125
|
+
primitives = {
|
|
126
|
+
str: {"type": "string"},
|
|
127
|
+
int: {"type": "integer"},
|
|
128
|
+
float: {"type": "number"},
|
|
129
|
+
bool: {"type": "boolean"},
|
|
130
|
+
dict: {"type": "object"},
|
|
131
|
+
list: {"type": "array"},
|
|
132
|
+
}
|
|
133
|
+
if annotation in primitives:
|
|
134
|
+
return dict(primitives[annotation])
|
|
135
|
+
|
|
136
|
+
return {"type": "object"}
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _is_union_type(annotation: type) -> bool:
|
|
140
|
+
"""Check if annotation is a PEP 604 union (X | Y)."""
|
|
141
|
+
import types
|
|
142
|
+
|
|
143
|
+
return isinstance(annotation, types.UnionType)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _is_typed_dict(annotation: type) -> bool:
|
|
147
|
+
"""Check if annotation is a TypedDict subclass."""
|
|
148
|
+
return (
|
|
149
|
+
isinstance(annotation, type)
|
|
150
|
+
and issubclass(annotation, dict)
|
|
151
|
+
and hasattr(annotation, "__annotations__")
|
|
152
|
+
and hasattr(annotation, "__required_keys__")
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
# Regex for Google-style docstring param lines:
|
|
157
|
+
# Args:
|
|
158
|
+
# name: description text
|
|
159
|
+
# name (type): description text
|
|
160
|
+
_GOOGLE_ARGS_RE = re.compile(
|
|
161
|
+
r"^\s{2,}(\w+)(?:\s*\([^)]*\))?\s*:\s*(.+)", re.MULTILINE
|
|
162
|
+
)
|
|
163
|
+
|
|
164
|
+
# Regex for reST-style docstring param lines:
|
|
165
|
+
# :param name: description text
|
|
166
|
+
_REST_PARAM_RE = re.compile(
|
|
167
|
+
r"^\s*:param\s+(\w+)\s*:\s*(.+)", re.MULTILINE
|
|
168
|
+
)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _parse_param_descriptions(fn: typing.Callable) -> dict[str, str]:
|
|
172
|
+
"""Extract parameter descriptions from a function's docstring.
|
|
173
|
+
|
|
174
|
+
Handles Google-style (Args: section) and reST-style (:param:) formats.
|
|
175
|
+
Follows __wrapped__ for functools.wraps-decorated functions.
|
|
176
|
+
"""
|
|
177
|
+
# Follow __wrapped__ chain to find the original docstring
|
|
178
|
+
target = fn
|
|
179
|
+
while hasattr(target, "__wrapped__"):
|
|
180
|
+
target = target.__wrapped__
|
|
181
|
+
|
|
182
|
+
doc = inspect.getdoc(target)
|
|
183
|
+
if not doc:
|
|
184
|
+
return {}
|
|
185
|
+
|
|
186
|
+
descriptions: dict[str, str] = {}
|
|
187
|
+
|
|
188
|
+
# Try Google-style: find the Args: section
|
|
189
|
+
args_match = re.search(r"^\s*Args?\s*:\s*$", doc, re.MULTILINE)
|
|
190
|
+
if args_match:
|
|
191
|
+
# Extract the block after "Args:" until the next section or end
|
|
192
|
+
after_args = doc[args_match.end() :]
|
|
193
|
+
# Stop at next section header (word followed by colon at start of line,
|
|
194
|
+
# or end of string)
|
|
195
|
+
section_end = re.search(r"^\S", after_args, re.MULTILINE)
|
|
196
|
+
args_block = after_args[: section_end.start()] if section_end else after_args
|
|
197
|
+
for m in _GOOGLE_ARGS_RE.finditer(args_block):
|
|
198
|
+
descriptions[m.group(1)] = m.group(2).strip()
|
|
199
|
+
|
|
200
|
+
# Try reST-style
|
|
201
|
+
for m in _REST_PARAM_RE.finditer(doc):
|
|
202
|
+
# Don't overwrite Google-style if both present
|
|
203
|
+
if m.group(1) not in descriptions:
|
|
204
|
+
descriptions[m.group(1)] = m.group(2).strip()
|
|
205
|
+
|
|
206
|
+
return descriptions
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _generate_schema(
    fn: typing.Callable,
    inject: set[str] | None = None,
) -> dict:
    """Generate a JSON Schema object from a function's signature.

    These parameters are omitted:

    - parameters named ``self`` or ``cls``;
    - parameters listed in *inject*;
    - variadic ``*args`` / ``**kwargs`` parameters.

    Each remaining parameter becomes a property. Its schema comes from its
    type hint (resolved via ``typing.get_type_hints`` when possible), plus a
    ``description`` when the docstring documents it. Parameters without a
    default are required.

    Args:
        fn: The function to generate a schema for.
        inject: Parameter names to exclude (they will be injected at call time).

    Returns:
        A JSON Schema dict with "type" and "properties" keys, plus a
        "required" list when at least one parameter is required.
    """
    inject = inject or set()
    sig = inspect.signature(fn)
    descriptions = _parse_param_descriptions(fn)

    # Resolve stringified annotations (from `from __future__ import annotations`).
    # If resolution fails, fall back to the raw Parameter.annotation values.
    try:
        resolved_hints = typing.get_type_hints(fn)
    except Exception:
        resolved_hints = {}

    properties: dict[str, dict] = {}
    required: list[str] = []

    for name, param in sig.parameters.items():
        # Skip self/cls
        if name in ("self", "cls"):
            continue
        # Skip injected parameters
        if name in inject:
            continue
        # *args / **kwargs cannot be supplied as one named JSON property. They
        # also never have a default, so they would otherwise be marked required.
        if param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            continue

        annotation = resolved_hints.get(name, param.annotation)
        prop = _type_to_schema(annotation)

        # Add description if available
        if name in descriptions:
            prop["description"] = descriptions[name]

        properties[name] = prop

        # Parameters with no default are required
        if param.default is inspect.Parameter.empty:
            required.append(name)

    schema: dict = {"type": "object", "properties": properties}
    if required:
        schema["required"] = required
    return schema
|
toolstream/_session.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import queue
|
|
5
|
+
import threading
|
|
6
|
+
from collections.abc import AsyncIterator, Iterator
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from ._direct import DirectClient
|
|
10
|
+
from .config import SessionConfig
|
|
11
|
+
from .events import Result, StepFinish
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AsyncSession:
    """Async session wrapping the direct LLM API client.

    Use it as ``async with AsyncSession(config) as session:``; leaving the
    block awaits ``close()``. Each ``send()`` call counts as one turn. The
    session keeps cumulative input/output token counts and cost, summed from
    the ``StepFinish`` events of every turn, including turns that were
    abandoned or failed part-way.
    """

    def __init__(self, config: SessionConfig):
        self._config = config
        self._direct = DirectClient(
            config,
            tools=config.tools,
            tool_context=config.tool_context,
            max_completion_tokens=config.max_completion_tokens,
        )
        self._turn_count = 0
        self._total_cost = 0.0
        self._total_input_tokens = 0
        self._total_output_tokens = 0

    async def __aenter__(self) -> AsyncSession:
        return self

    async def __aexit__(self, *args: Any) -> None:
        await self.close()

    async def send(self, message: str) -> AsyncIterator[Any]:
        """Send a message and yield events. Yields a Result summary at the end.

        The Result is only yielded when the turn completes normally.
        """
        stream = self._send_direct(message)
        try:
            async for event in stream:
                yield event
        finally:
            # If the caller stops iterating early, close the inner generator
            # now rather than at garbage collection, so its usage bookkeeping
            # runs immediately. This is a no-op when it has already finished.
            await stream.aclose()

    async def _send_direct(self, message: str) -> AsyncIterator[Any]:
        """Send via direct API client.

        Yields every event from the direct client. Usage from ``StepFinish``
        events is added to the session totals even if the turn is abandoned
        or raises. On normal completion a ``Result`` is yielded last; its
        ``total_*`` fields hold this turn's usage, not the session totals.
        """
        self._turn_count += 1

        turn_input_tokens = 0
        turn_output_tokens = 0
        turn_cost = 0.0
        steps = 0

        try:
            async for event in self._direct.send(message):
                if isinstance(event, StepFinish):
                    turn_input_tokens += event.input_tokens
                    turn_output_tokens += event.output_tokens
                    turn_cost += event.cost
                    steps += 1
                yield event
        finally:
            # Record usage from steps that already finished, even if the
            # consumer stopped early or the client raised mid-turn.
            self._total_input_tokens += turn_input_tokens
            self._total_output_tokens += turn_output_tokens
            self._total_cost += turn_cost

        # Yield result summary
        yield Result(
            session_id=self._direct.session_id,
            total_input_tokens=turn_input_tokens,
            total_output_tokens=turn_output_tokens,
            total_cost=turn_cost,
            steps=steps,
        )

    async def close(self) -> None:
        """Close the session and clean up."""
        await self._direct.close()

    @property
    def session_id(self) -> str | None:
        return self._direct.session_id

    @property
    def total_cost(self) -> float:
        return self._total_cost

    @property
    def turn_count(self) -> int:
        return self._turn_count
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class SyncSession:
    """Sync wrapper around AsyncSession using a dedicated event loop thread.

    ``__enter__`` starts a private event loop on a daemon thread and creates
    the AsyncSession. Coroutines are then submitted to that loop with
    ``asyncio.run_coroutine_threadsafe``. ``close()`` (also called by
    ``__exit__``) shuts down both the session and the loop.
    """

    def __init__(self, config: SessionConfig):
        self._config = config
        self._async_session: AsyncSession | None = None
        self._loop: asyncio.AbstractEventLoop | None = None
        self._thread: threading.Thread | None = None

    def _start_loop(self) -> None:
        """Run the event loop in a background thread."""
        assert self._loop is not None
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

    def __enter__(self) -> SyncSession:
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._start_loop, daemon=True)
        self._thread.start()
        self._async_session = AsyncSession(self._config)
        return self

    def __exit__(self, *args: Any) -> None:
        self.close()

    def _run_coroutine(self, coro: Any) -> Any:
        """Run a coroutine on the background event loop and return the result."""
        assert self._loop is not None
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        return future.result()

    def send(self, message: str) -> Iterator[Any]:
        """Send a message and yield events incrementally via a queue bridge.

        A producer coroutine on the loop thread pushes events into a
        thread-safe queue, and this generator pulls them in the caller's
        thread. An ``Exception`` raised by the producer is re-raised here.

        NOTE(review): if the caller stops iterating early, the producer keeps
        running on the loop thread until the turn ends, and its remaining
        events are discarded.
        """
        assert self._async_session is not None
        assert self._loop is not None

        q: queue.Queue = queue.Queue()
        sentinel = object()

        async def _produce() -> None:
            try:
                async for event in self._async_session.send(message):  # type: ignore[union-attr]
                    q.put(event)
            except Exception as e:
                q.put(e)
            finally:
                # Always signal completion so the consumer loop terminates.
                q.put(sentinel)

        asyncio.run_coroutine_threadsafe(_produce(), self._loop)

        while True:
            item = q.get()
            if item is sentinel:
                break
            if isinstance(item, Exception):
                raise item
            yield item

    def close(self) -> None:
        """Close the session and shut down the event loop.

        Safe to call repeatedly, and before ``__enter__``. The loop and its
        thread are torn down even when closing the AsyncSession raises; that
        exception is re-raised afterwards.
        """
        session = self._async_session
        self._async_session = None
        try:
            if session is not None:
                self._run_coroutine(session.close())
        finally:
            loop = self._loop
            if loop is not None:
                loop.call_soon_threadsafe(loop.stop)
                if self._thread is not None:
                    self._thread.join(timeout=5.0)
                # If the join timed out, the loop is still running and
                # loop.close() would raise RuntimeError (masking any error
                # above). In that case leave it to the daemon thread.
                if not loop.is_running():
                    loop.close()
                self._loop = None
                self._thread = None

    @property
    def session_id(self) -> str | None:
        return self._async_session.session_id if self._async_session else None

    @property
    def total_cost(self) -> float:
        return self._async_session.total_cost if self._async_session else 0.0

    @property
    def turn_count(self) -> int:
        return self._async_session.turn_count if self._async_session else 0
|