toolstream 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- toolstream/__init__.py +44 -0
- toolstream/_agent.py +215 -0
- toolstream/_builtin_tools.py +115 -0
- toolstream/_context.py +14 -0
- toolstream/_direct.py +292 -0
- toolstream/_invoke.py +126 -0
- toolstream/_protocol.py +87 -0
- toolstream/_schema.py +259 -0
- toolstream/_session.py +176 -0
- toolstream/_tools.py +109 -0
- toolstream/config.py +26 -0
- toolstream/events.py +63 -0
- toolstream/py.typed +0 -0
- toolstream-0.1.0.dist-info/METADATA +7 -0
- toolstream-0.1.0.dist-info/RECORD +16 -0
- toolstream-0.1.0.dist-info/WHEEL +4 -0
toolstream/__init__.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from ._agent import (
|
|
2
|
+
AgentDefinition,
|
|
3
|
+
AgentSandbox,
|
|
4
|
+
ToolRef,
|
|
5
|
+
discover_agents,
|
|
6
|
+
load_agent,
|
|
7
|
+
resolve_prompt,
|
|
8
|
+
)
|
|
9
|
+
from ._context import ToolContext
|
|
10
|
+
from ._invoke import invoke_agent, invoke_agent_sync
|
|
11
|
+
from ._session import AsyncSession, SyncSession
|
|
12
|
+
from ._tools import Tool, collect_tools, tool
|
|
13
|
+
from .config import SessionConfig
|
|
14
|
+
from .events import Error, Result, StepFinish, StepStart, Text, ToolUse
|
|
15
|
+
|
|
16
|
+
# Names re-exported as the public API of the ``toolstream`` package,
# grouped by the area they belong to.
__all__ = [
    # Agent definitions
    "AgentDefinition",
    "AgentSandbox",
    "ToolRef",
    "load_agent",
    "discover_agents",
    "resolve_prompt",
    # Agent invocation
    "invoke_agent",
    "invoke_agent_sync",
    # Tools
    "tool",
    "Tool",
    "collect_tools",
    # Context
    "ToolContext",
    # Sessions
    "AsyncSession",
    "SyncSession",
    "SessionConfig",
    # Events
    "StepStart",
    "Text",
    "ToolUse",
    "StepFinish",
    "Error",
    "Result",
]
|
toolstream/_agent.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
"""Agent definition loading and discovery."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import importlib.resources
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
import re
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
# Module-level logger; the discovery helpers below use it to warn about
# agent files that are skipped (unparsable, missing, or duplicate names).
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
# Public names of this module: the three data classes followed by the
# loading, discovery and prompt-rendering helpers.
__all__ = [
    "ToolRef",
    "AgentSandbox",
    "AgentDefinition",
    "load_agent",
    "discover_agents",
    "resolve_prompt",
]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(frozen=True)
class ToolRef:
    """Immutable reference to a tool, identified only by its name."""

    # Tool name, taken from the "name" key of an entry in an agent's "tools" list.
    name: str
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True)
class AgentSandbox:
    """Immutable view of the ``sandbox`` section of an agent definition."""

    # Value of the sandbox's "tools" key; None when the key is absent.
    tools: list[str] | None = None
    # Value of the sandbox's "skip_permissions" key; False when absent.
    skip_permissions: bool = False
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass(frozen=True)
class AgentDefinition:
    """Immutable, parsed form of one agent definition.

    ``name``, ``prompt_template`` and ``version`` are required; every other
    field has a default that is used when the definition omits it.
    """

    # Agent name; discovery de-duplicates and sorts definitions by it.
    name: str
    # Prompt text; may contain ``{key}`` placeholders (see resolve_prompt).
    prompt_template: str
    # Version string exactly as given in the definition.
    version: str
    # Free-form description; empty string when omitted.
    description: str = ""
    # Tool references; None when the definition has no "tools" key.
    tools: list[ToolRef] | None = None
    # Sandbox settings; None when the definition has no "sandbox" key.
    sandbox: AgentSandbox | None = None
    # Model identifier; None when omitted.
    model: str | None = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _parse_agent_data(data: dict, source: str) -> AgentDefinition:
    """Parse a dict (from JSON) into an AgentDefinition.

    Args:
        data: Parsed JSON object. ``name``, ``prompt_template`` and
            ``version`` are required; ``description``, ``tools``,
            ``sandbox`` and ``model`` are optional.
        source: Human-readable origin for error messages (file path, package).

    Raises:
        KeyError: If a required field, or a tool entry's ``name``, is missing.
        TypeError: If *data*, ``tools`` or ``sandbox`` does not have the
            expected JSON shape.
    """
    if not isinstance(data, dict):
        raise TypeError(
            f"{source}: expected a JSON object, got {type(data).__name__}"
        )

    tools: list[ToolRef] | None = None
    if "tools" in data:
        try:
            # An explicit null is treated like an empty list.
            tools = [ToolRef(name=entry["name"]) for entry in (data["tools"] or [])]
        except KeyError as exc:
            raise KeyError(f"{source}: tool entry is missing 'name'") from exc
        except TypeError as exc:
            raise TypeError(
                f"{source}: 'tools' must be a list of objects with a 'name'"
            ) from exc

    sandbox: AgentSandbox | None = None
    if "sandbox" in data:
        sandbox_data = data["sandbox"]
        # Fail with TypeError (which the discovery helpers catch) rather than
        # an AttributeError from calling .get() on null/list/str values.
        if not isinstance(sandbox_data, dict):
            raise TypeError(
                f"{source}: 'sandbox' must be an object, "
                f"got {type(sandbox_data).__name__}"
            )
        sandbox = AgentSandbox(
            tools=sandbox_data.get("tools"),
            skip_permissions=sandbox_data.get("skip_permissions", False),
        )

    try:
        name = data["name"]
        prompt_template = data["prompt_template"]
        version = data["version"]
    except KeyError as exc:
        raise KeyError(
            f"{source}: missing required field {exc.args[0]!r}"
        ) from exc

    return AgentDefinition(
        name=name,
        prompt_template=prompt_template,
        version=version,
        description=data.get("description", ""),
        tools=tools,
        sandbox=sandbox,
        model=data.get("model"),
    )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def load_agent(path: str | Path) -> AgentDefinition:
    """Read one agent definition from a JSON file on disk.

    Args:
        path: Location of a ``.agent.json`` file. A plain string that has no
            path separator and does not end in ``.json`` is treated as a
            bare agent name and rejected -- use :func:`discover_agents`
            for name-based lookup.

    Returns:
        The parsed :class:`AgentDefinition`.

    Raises:
        FileNotFoundError: If *path* is a bare name or the file does not exist.
    """
    if isinstance(path, str):
        has_separator = "/" in path or os.sep in path
        if not (has_separator or path.endswith(".json")):
            raise FileNotFoundError(
                f"Agent '{path}' not found (bare name resolution requires discover_agents)"
            )

    agent_file = Path(path)
    raw_text = agent_file.read_text(encoding="utf-8")
    return _parse_agent_data(json.loads(raw_text), source=str(agent_file))
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _try_load(
    file_path: Path,
    seen: dict[str, AgentDefinition],
) -> None:
    """Attempt to load a single agent file, deduplicating by name.

    A file that cannot be read or parsed is logged and skipped so that one
    bad definition never aborts discovery of the others.

    Args:
        file_path: Agent definition file to load.
        seen: Registry of already-loaded agents keyed by name; updated in
            place when *file_path* yields a new name.
    """
    try:
        agent = load_agent(file_path)
    # OSError covers missing/unreadable files and directories; ValueError
    # covers both json.JSONDecodeError and UnicodeDecodeError.
    except (OSError, ValueError, KeyError, TypeError) as exc:
        logger.warning("Skipping %s: %s", file_path, exc)
        return

    if agent.name in seen:
        logger.warning(
            "Duplicate agent '%s' in %s (already loaded); skipping",
            agent.name,
            file_path,
        )
        return

    seen[agent.name] = agent
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def discover_agents(
    cwd: str | Path | None = None,
    paths: list[str | Path] | None = None,
    packages: list[str] | None = None,
) -> list[AgentDefinition]:
    """Discover agent definitions from directories and packages.

    Search order (first occurrence of a name wins):

    1. ``<cwd>/.toolstream/agents/*.agent.json``
    2. Each directory in *paths*: ``<dir>/*.agent.json``
    3. Each Python package in *packages*: resources matching ``*.agent.json``

    Within every location candidates are visited in sorted order, so the
    winner among duplicate names is deterministic. Definitions that cannot
    be read or parsed are logged and skipped.

    Args:
        cwd: Project directory whose ``.toolstream/agents`` folder is
            searched; skipped when ``None`` or when the folder is missing.
        paths: Extra directories to search; non-directories are logged
            and skipped.
        packages: Importable package names whose top-level resources are
            searched; packages that cannot be found are logged and skipped.

    Returns:
        Agent definitions sorted alphabetically by name.
    """
    seen: dict[str, AgentDefinition] = {}

    # 1. cwd-local agents
    if cwd is not None:
        agents_dir = Path(cwd) / ".toolstream" / "agents"
        if agents_dir.is_dir():
            for f in sorted(agents_dir.glob("*.agent.json")):
                _try_load(f, seen)

    # 2. Explicit directory paths
    for dir_path in paths or []:
        d = Path(dir_path)
        if not d.is_dir():
            logger.warning("Agent path %s is not a directory; skipping", d)
            continue
        for f in sorted(d.glob("*.agent.json")):
            _try_load(f, seen)

    # 3. Python packages
    for package_name in packages or []:
        try:
            pkg_files = importlib.resources.files(package_name)
        except ModuleNotFoundError:
            logger.warning("Package '%s' not found; skipping", package_name)
            continue

        # iterdir() order is unspecified; sort so "first wins" is stable.
        resources = sorted(
            (item for item in pkg_files.iterdir() if item.name.endswith(".agent.json")),
            key=lambda item: item.name,
        )
        for item in resources:
            try:
                data = json.loads(item.read_text(encoding="utf-8"))
                agent = _parse_agent_data(data, source=f"{package_name}/{item.name}")
            # ValueError covers JSONDecodeError and UnicodeDecodeError;
            # OSError covers unreadable resources.
            except (OSError, ValueError, KeyError, TypeError) as exc:
                logger.warning(
                    "Skipping %s/%s: %s", package_name, item.name, exc
                )
                continue

            if agent.name in seen:
                logger.warning(
                    "Duplicate agent '%s' in %s/%s (already loaded); skipping",
                    agent.name,
                    package_name,
                    item.name,
                )
                continue

            seen[agent.name] = agent

    return sorted(seen.values(), key=lambda a: a.name)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def resolve_prompt(template: str, variables: dict[str, str]) -> str:
    """Substitute ``{key}`` placeholders in a prompt template.

    Only the supplied *variables* are touched; any other literal braces in
    the template (for example embedded JSON) are preserved. Substitution is
    done in a single pass over the template, so a substituted value that
    itself contains ``{other}`` is inserted verbatim and never re-expanded.

    Args:
        template: Prompt text containing ``{key}`` placeholders.
        variables: Mapping of placeholder name to replacement text.

    Returns:
        The template with every supplied placeholder replaced.

    Raises:
        ValueError: If any original placeholder has no matching variable.
    """
    # Validate against the placeholders of the *original* template.
    original_placeholders = set(re.findall(r"\{(\w+)\}", template))
    unresolved = [p for p in sorted(original_placeholders) if p not in variables]
    if unresolved:
        raise ValueError(f"Unresolved template variables: {', '.join(unresolved)}")

    if not variables:
        return template

    replacements = {f"{{{key}}}": value for key, value in variables.items()}
    # Longest token first so the alternation prefers the most specific match.
    token_pattern = re.compile(
        "|".join(re.escape(token) for token in sorted(replacements, key=len, reverse=True))
    )
    # A callable replacement keeps backslashes in values literal.
    return token_pattern.sub(lambda match: replacements[match.group(0)], template)
|
toolstream/_builtin_tools.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""Standalone implementations of the built-in tools for direct-mode LLM sessions.
|
|
2
|
+
|
|
3
|
+
Each function is an async def decorated with @tool that receives injectable
|
|
4
|
+
parameters (excluded from the JSON Schema sent to the LLM, passed explicitly
|
|
5
|
+
by the DirectClient's dispatch loop).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import asyncio
|
|
11
|
+
import os
|
|
12
|
+
import shlex
|
|
13
|
+
from glob import glob as _stdlib_glob
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
|
|
16
|
+
from ._tools import tool
|
|
17
|
+
|
|
18
|
+
# Maximum number of characters of subprocess output that the bash and grep
# tools return; longer output is cut here and a truncation note is appended.
_MAX_OUTPUT_CHARS = 50_000
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _resolve_path(file_path: str, cwd: str) -> Path:
|
|
22
|
+
"""Resolve *file_path*, making relative paths absolute against *cwd*."""
|
|
23
|
+
path = Path(file_path)
|
|
24
|
+
if not path.is_absolute():
|
|
25
|
+
path = Path(cwd) / path
|
|
26
|
+
return path
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@tool(inject=["cwd"])
async def read(file_path: str, cwd: str, offset: int = 0, limit: int = 2000) -> str:
    """Read a file and return its contents with line numbers.

    Args:
        file_path: File to read; relative paths are resolved against *cwd*.
        cwd: Injected working directory.
        offset: Zero-based index of the first line to return. Negative
            values are treated as 0 so that line numbers stay correct.
        limit: Maximum number of lines to return. Negative values are
            treated as 0.

    Returns:
        The selected lines joined by newlines, each prefixed with its
        1-based line number and ``": "``.
    """
    path = _resolve_path(file_path, cwd)
    first = max(offset, 0)
    count = max(limit, 0)
    # Decode explicitly as UTF-8 so the result does not depend on the locale.
    lines = path.read_text(encoding="utf-8").splitlines()
    selected = lines[first : first + count]
    return "\n".join(
        f"{number}: {line}" for number, line in enumerate(selected, start=first + 1)
    )
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@tool(inject=["cwd"])
async def write(file_path: str, content: str, cwd: str) -> str:
    """Write content to a file, creating parent directories as needed.

    Args:
        file_path: Destination file; relative paths are resolved against *cwd*.
        content: Text to write; replaces any existing file content.
        cwd: Injected working directory.

    Returns:
        A confirmation message stating how many characters were written.
    """
    path = _resolve_path(file_path, cwd)
    path.parent.mkdir(parents=True, exist_ok=True)
    # Encode explicitly as UTF-8 so the file does not depend on the locale.
    path.write_text(content, encoding="utf-8")
    # len(content) counts characters, not bytes, so say so.
    return f"Wrote {len(content)} characters to {file_path}"
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@tool(inject=["cwd", "env"])
async def bash(command: str, cwd: str, env: dict[str, str], timeout: int = 120) -> str:
    """Run a shell command and return stdout+stderr combined.

    Args:
        command: Shell command line to execute.
        cwd: Injected working directory for the subprocess.
        env: Injected environment for the subprocess.
        timeout: Seconds to wait before the subprocess is killed.

    Returns:
        The decoded combined output, truncated to ``_MAX_OUTPUT_CHARS``
        characters with a note appended, or a timeout message.
    """
    proc = await asyncio.create_subprocess_shell(
        command,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
        cwd=cwd,
        env=env,
    )
    try:
        stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=timeout)
    except asyncio.TimeoutError:
        try:
            proc.kill()
        except ProcessLookupError:
            # The process exited in the gap between the timeout and kill().
            pass
        await proc.wait()
        return f"Command timed out after {timeout}s"
    except asyncio.CancelledError:
        # Do not leave the subprocess running when the caller is cancelled.
        if proc.returncode is None:
            try:
                proc.kill()
            except ProcessLookupError:
                pass
        raise

    output = stdout.decode(errors="replace")
    if len(output) > _MAX_OUTPUT_CHARS:
        output = output[:_MAX_OUTPUT_CHARS] + f"\n... (truncated, {len(output)} total chars)"
    return output
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@tool(inject=["cwd"])
async def edit(file_path: str, old_string: str, new_string: str, cwd: str) -> str:
    """Edit a file by replacing the first occurrence of old_string with new_string.

    Args:
        file_path: File to edit; relative paths are resolved against *cwd*.
        old_string: Exact text to search for; must not be empty.
        new_string: Replacement text.
        cwd: Injected working directory.

    Returns:
        ``"Edited <file_path>"`` on success, or a message starting with
        ``"Error:"`` when *old_string* is empty or not present in the file.
    """
    # An empty search string matches everywhere and would silently prepend
    # new_string to the file; reject it instead.
    if not old_string:
        return "Error: old_string must not be empty"

    path = _resolve_path(file_path, cwd)
    # Read and write with the same explicit encoding, independent of locale.
    content = path.read_text(encoding="utf-8")
    if old_string not in content:
        return f"Error: old_string not found in {file_path}"
    content = content.replace(old_string, new_string, 1)
    path.write_text(content, encoding="utf-8")
    return f"Edited {file_path}"
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@tool(inject=["cwd"])
async def grep(pattern: str, path: str, cwd: str, include: str | None = None) -> str:
    """Search for a pattern in files using grep -rn, piped to head -50.

    Args:
        pattern: Pattern passed to grep.
        path: File or directory to search.
        cwd: Injected working directory for the subprocess.
        include: Optional file-name glob passed as ``--include``.

    Returns:
        The decoded combined stdout/stderr (at most 50 lines), truncated to
        ``_MAX_OUTPUT_CHARS`` characters, or a timeout message.
    """
    cmd = "grep -rn"
    if include:
        cmd += f" --include={shlex.quote(include)}"
    # "-e" marks the pattern and "--" ends option parsing, so a pattern or
    # path that starts with "-" cannot be interpreted as a grep option.
    cmd += f" -e {shlex.quote(pattern)} -- {shlex.quote(path)}"
    cmd += " | head -50"
    # Run with a minimal environment: only PATH is passed through.
    grep_env = {"PATH": os.environ.get("PATH", "/usr/bin:/bin")}
    proc = await asyncio.create_subprocess_shell(
        cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
        cwd=cwd,
        env=grep_env,
    )
    try:
        stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=30)
        output = stdout.decode(errors="replace")
        if len(output) > _MAX_OUTPUT_CHARS:
            output = output[:_MAX_OUTPUT_CHARS] + f"\n... (truncated, {len(output)} total chars)"
        return output
    except asyncio.TimeoutError:
        proc.kill()
        await proc.wait()
        return "Command timed out after 30s"
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
@tool(name="glob", inject=["cwd"])
async def glob_files(pattern: str, cwd: str) -> str:
    """Find files matching a glob pattern. Returns up to 100 matches.

    Absolute patterns are expanded as given; relative patterns are expanded
    with *cwd* as the root directory. ``**`` matches recursively.
    """
    search_options = {"recursive": True}
    if not os.path.isabs(pattern):
        search_options["root_dir"] = cwd
    found = _stdlib_glob(pattern, **search_options)
    return "\n".join(found[:100])
|
toolstream/_context.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""ToolContext -- base class for tool dependency injection."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class ToolContext:
    """Base class for the object that supplies injected tool parameters.

    Derive from it and add one attribute per value your tools need.
    A tool names the parameters it wants injected, and each of those names
    is looked up with ``getattr`` on the context object.
    """
|
toolstream/_direct.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
"""Direct LLM API client via AI Gateway."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
import time
|
|
8
|
+
import uuid
|
|
9
|
+
from collections.abc import AsyncIterator
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
import httpx
|
|
13
|
+
|
|
14
|
+
from . import _builtin_tools
|
|
15
|
+
from ._tools import Tool, collect_tools
|
|
16
|
+
from .config import SessionConfig
|
|
17
|
+
from .events import Error, StepFinish, StepStart, Text, ToolUse
|
|
18
|
+
|
|
19
|
+
# Number of additional attempts made after a chat-completion request hits
# httpx.ReadTimeout (total attempts = _MAX_RETRIES + 1).
_MAX_RETRIES = 1
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _strip_provider(model: str) -> str:
|
|
23
|
+
"""Strip provider prefix from model name.
|
|
24
|
+
|
|
25
|
+
'azure-cognitive-services/gpt-5.4' -> 'gpt-5.4'
|
|
26
|
+
'gpt-5.4' -> 'gpt-5.4'
|
|
27
|
+
"""
|
|
28
|
+
if "/" in model:
|
|
29
|
+
return model.rsplit("/", 1)[1]
|
|
30
|
+
return model
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _timestamp_ms() -> int:
|
|
34
|
+
return int(time.time() * 1000)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _build_tool_definitions(tools: dict[str, Tool]) -> list[dict]:
|
|
38
|
+
"""Build OpenAI-format tool definitions from Tool objects."""
|
|
39
|
+
defs = []
|
|
40
|
+
for t in tools.values():
|
|
41
|
+
defs.append({
|
|
42
|
+
"type": "function",
|
|
43
|
+
"function": {
|
|
44
|
+
"name": t.name,
|
|
45
|
+
"description": t.description,
|
|
46
|
+
"parameters": t.input_schema,
|
|
47
|
+
},
|
|
48
|
+
})
|
|
49
|
+
return defs
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class DirectClient:
    """Direct LLM API client (via AI Gateway) with tool calling.

    The client keeps the running conversation in memory, posts it to the
    configured chat-completions URL, executes any tool calls the model
    requests, and feeds the results back until the model answers without
    requesting tools. Built-in tools are always registered; user tools with
    the same name replace them.

    Usable as an async context manager; on exit the HTTP client is closed
    only if this object created it.
    """

    def __init__(
        self,
        config: SessionConfig,
        *,
        tools: list[Tool] | None = None,
        tool_context: object | None = None,
        max_completion_tokens: int = 16384,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a client for one conversation.

        Args:
            config: Session settings; ``api_key`` and ``base_url`` are required.
            tools: User tools; each replaces a built-in of the same name.
            tool_context: Object whose attributes supply the ``inject``
                parameters of user tools.
            max_completion_tokens: Sent as ``max_completion_tokens`` on
                every request.
            http_client: Existing client to reuse. When omitted, a client is
                created here and closed by :meth:`close`.

        Raises:
            ValueError: If ``config.api_key`` or ``config.base_url`` is empty.
        """
        if not config.api_key:
            raise ValueError("api_key is required for direct backend")
        if not config.base_url:
            raise ValueError("base_url is required for direct backend")

        self._config = config
        self._base_url = config.base_url.rstrip("/")
        self._api_key = config.api_key
        self._model = _strip_provider(config.model)
        self._messages: list[dict] = []
        self._session_id = str(uuid.uuid4())
        self._cwd = config.cwd or os.getcwd()
        self._max_completion_tokens = max_completion_tokens
        self._tool_context = tool_context
        self._owns_client = http_client is None

        # Builtin context for inject resolution
        self._builtin_context: dict[str, Any] = {
            "cwd": self._cwd,
            "env": config.tool_env,
        }

        # HTTP client: reuse injected or create one
        if http_client is not None:
            self._client = http_client
        else:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(connect=10, read=120, write=10, pool=10),
            )

        # Build tool registry: start with built-ins, override with user tools
        builtin_tools = collect_tools(_builtin_tools)
        self._builtin_names: set[str] = {t.name for t in builtin_tools}
        self._tools: dict[str, Tool] = {t.name: t for t in builtin_tools}

        if tools is not None:
            for t in tools:
                # A user tool that shadows a built-in is dispatched as a
                # user tool (inject resolved from tool_context).
                self._builtin_names.discard(t.name)
                self._tools[t.name] = t

        # Pre-compute API tool definitions
        self._tool_definitions = _build_tool_definitions(self._tools)

        # The system prompt is always the first message of the conversation.
        self._messages.append({"role": "system", "content": config.system_prompt})

    @property
    def session_id(self) -> str:
        """Random UUID string generated when the client was created."""
        return self._session_id

    @staticmethod
    def _parse_tool_arguments(raw: Any) -> dict[str, Any]:
        """Decode a tool call's JSON ``arguments`` into keyword arguments.

        Malformed JSON, a missing/null value, or JSON that is not an object
        all yield an empty dict, so the tool itself reports what is missing.
        """
        try:
            parsed = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return {}
        return parsed if isinstance(parsed, dict) else {}

    async def send(self, message: str) -> AsyncIterator[StepStart | Text | ToolUse | StepFinish | Error]:
        """Send a message and yield events. Handles the tool-calling loop internally.

        Yields one ``StepStart``, then for each model response an optional
        ``Text``, one ``ToolUse`` per executed tool call, and a
        ``StepFinish`` (reason ``"tool-calls"`` after a tool round,
        ``"stop"`` after the final answer). Token counts in ``StepFinish``
        are cumulative over this call.

        Args:
            message: User message appended to the conversation.
        """
        self._messages.append({"role": "user", "content": message})
        msg_id = str(uuid.uuid4())

        yield StepStart(
            session_id=self._session_id,
            message_id=msg_id,
            timestamp=_timestamp_ms(),
        )

        total_input_tokens = 0
        total_output_tokens = 0

        while True:
            response = await self._chat_completion(self._messages)
            # Tolerate both a missing and an explicit-null usage object.
            usage = response.get("usage") or {}
            total_input_tokens += usage.get("prompt_tokens", 0)
            total_output_tokens += usage.get("completion_tokens", 0)

            choice = response["choices"][0]
            assistant_msg = choice["message"]

            # Build the message dict to append to conversation
            msg_to_append: dict = {"role": "assistant"}
            if assistant_msg.get("content"):
                msg_to_append["content"] = assistant_msg["content"]
            if assistant_msg.get("tool_calls"):
                msg_to_append["tool_calls"] = assistant_msg["tool_calls"]
            self._messages.append(msg_to_append)

            # Yield text if present
            if assistant_msg.get("content"):
                yield Text(
                    session_id=self._session_id,
                    message_id=msg_id,
                    text=assistant_msg["content"],
                    timestamp=_timestamp_ms(),
                )

            # Check for tool calls
            tool_calls = assistant_msg.get("tool_calls", [])
            if not tool_calls:
                # No tool calls -- conversation turn is done
                completion_details = usage.get("completion_tokens_details") or {}
                prompt_details = usage.get("prompt_tokens_details") or {}
                yield StepFinish(
                    session_id=self._session_id,
                    message_id=msg_id,
                    reason="stop",
                    input_tokens=total_input_tokens,
                    output_tokens=total_output_tokens,
                    reasoning_tokens=completion_details.get("reasoning_tokens", 0),
                    cache_read_tokens=prompt_details.get("cached_tokens", 0),
                    cache_write_tokens=0,
                    cost=0.0,
                    timestamp=_timestamp_ms(),
                )
                break

            # Execute tool calls and add results
            for tc in tool_calls:
                func_name = tc["function"]["name"]
                func_args = self._parse_tool_arguments(tc["function"].get("arguments"))

                result = await self._dispatch_tool(func_name, func_args)

                yield ToolUse(
                    session_id=self._session_id,
                    message_id=msg_id,
                    tool=func_name,
                    call_id=tc["id"],
                    status="completed",
                    input=func_args,
                    output=result,
                    title=func_name,
                    timestamp=_timestamp_ms(),
                )

                self._messages.append({
                    "role": "tool",
                    "tool_call_id": tc["id"],
                    "content": result,
                })

            # Yield StepFinish for this tool-calling round
            yield StepFinish(
                session_id=self._session_id,
                message_id=msg_id,
                reason="tool-calls",
                input_tokens=total_input_tokens,
                output_tokens=total_output_tokens,
                reasoning_tokens=0,
                cache_read_tokens=0,
                cache_write_tokens=0,
                cost=0.0,
                timestamp=_timestamp_ms(),
            )

    async def _chat_completion(self, messages: list[dict]) -> dict:
        """Call LLM chat completions via AI Gateway with timeout and retry.

        A request that hits ``httpx.ReadTimeout`` is retried up to
        ``_MAX_RETRIES`` times; HTTP error statuses are not retried.

        Returns:
            The decoded JSON response body.

        Raises:
            httpx.ReadTimeout: If the final attempt also times out.
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        url = self._base_url
        headers = {
            "x-api-key": self._api_key,
            "Content-Type": "application/json",
        }
        body: dict[str, Any] = {
            "model": self._model,
            "messages": messages,
            "tools": self._tool_definitions,
            "max_completion_tokens": self._max_completion_tokens,
        }
        if self._config.metadata:
            body["metadata"] = self._config.metadata

        retries_used = 0
        while True:
            try:
                response = await self._client.post(url, json=body, headers=headers)
            except httpx.ReadTimeout:
                if retries_used >= _MAX_RETRIES:
                    raise
                retries_used += 1
                continue
            response.raise_for_status()
            return response.json()

    async def _dispatch_tool(self, name: str, args: dict[str, Any]) -> str:
        """Dispatch a tool call by name. Returns the result string.

        Errors raised by the tool handler are returned as an ``"Error: ..."``
        string so they can be shown to the model. *args* is never modified,
        so injected context values do not leak to the caller.

        Raises:
            RuntimeError: If a user tool declares inject parameters but no
                ``tool_context`` was supplied.
            AttributeError: If ``tool_context`` lacks an injected attribute.
        """
        tool_obj = self._tools.get(name)
        if tool_obj is None:
            return f"Error: unknown tool '{name}'"

        if name in self._builtin_names:
            # Resolve inject params from builtin context
            kwargs = {p: self._builtin_context[p] for p in tool_obj.inject}
            try:
                result = await tool_obj.handler(**args, **kwargs)
            except Exception as e:
                return f"Error: {e}"
        else:
            # User tools: merge injected context params into a copy of the
            # model-supplied arguments (injected values win), then call.
            call_kwargs = dict(args)
            for param_name in tool_obj.inject:
                if self._tool_context is None:
                    raise RuntimeError(
                        f"Tool '{name}' requires tool_context "
                        f"(inject=['{param_name}']) but tool_context is None"
                    )
                try:
                    call_kwargs[param_name] = getattr(self._tool_context, param_name)
                except AttributeError:
                    raise AttributeError(
                        f"tool_context ({type(self._tool_context).__name__}) "
                        f"has no attribute '{param_name}' "
                        f"required by tool '{name}'"
                    )
            try:
                result = await tool_obj.handler(**call_kwargs)
            except Exception as e:
                return f"Error: {e}"

        return result

    async def close(self) -> None:
        """Close the HTTP client if we own it."""
        if self._owns_client:
            await self._client.aclose()

    async def __aenter__(self) -> DirectClient:
        """Enter the async context; returns this client unchanged."""
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Leave the async context and release the HTTP client via close()."""
        await self.close()
|