zonix 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zonix/__init__.py +44 -0
- zonix/engine.py +335 -0
- zonix/events.py +106 -0
- zonix/exceptions.py +43 -0
- zonix/hitl.py +47 -0
- zonix/memory/__init__.py +118 -0
- zonix/models/__init__.py +14 -0
- zonix/models/anthropic.py +154 -0
- zonix/models/base.py +77 -0
- zonix/models/fake.py +34 -0
- zonix/models/openai.py +222 -0
- zonix/multi/__init__.py +10 -0
- zonix/multi/team.py +96 -0
- zonix/multi/workflow.py +127 -0
- zonix/obs.py +18 -0
- zonix/py.typed +1 -0
- zonix/runtime.py +151 -0
- zonix/serialization.py +34 -0
- zonix/spec.py +171 -0
- zonix/tools.py +109 -0
- zonix/types.py +162 -0
- zonix/wire/__init__.py +3 -0
- zonix/wire/ai_sdk.py +86 -0
- zonix-0.2.1.dist-info/METADATA +314 -0
- zonix-0.2.1.dist-info/RECORD +27 -0
- zonix-0.2.1.dist-info/WHEEL +4 -0
- zonix-0.2.1.dist-info/licenses/LICENSE +21 -0
zonix/__init__.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from .events import (
|
|
2
|
+
ApprovalRequired,
|
|
3
|
+
ErrorEvent,
|
|
4
|
+
Event,
|
|
5
|
+
Finish,
|
|
6
|
+
ReasoningDelta,
|
|
7
|
+
TextDelta,
|
|
8
|
+
TextEnd,
|
|
9
|
+
TextStart,
|
|
10
|
+
ToolInputAvailable,
|
|
11
|
+
ToolInputDelta,
|
|
12
|
+
ToolInputStart,
|
|
13
|
+
ToolOutputAvailable,
|
|
14
|
+
)
|
|
15
|
+
from .spec import Agent, agent, router, team, workflow
|
|
16
|
+
from .types import Message, PendingApproval, Route, RunResult, RunState, Span, ToolCall, Usage
|
|
17
|
+
|
|
18
|
+
# Public API of the package: the names exposed by ``from zonix import *``.
# Every entry is imported above from ``.events``, ``.spec`` or ``.types``;
# keep this list in sync when adding or removing re-exports.
__all__ = [
    "Agent",
    "ApprovalRequired",
    "ErrorEvent",
    "Event",
    "Finish",
    "Message",
    "PendingApproval",
    "ReasoningDelta",
    "Route",
    "RunResult",
    "RunState",
    "Span",
    "TextDelta",
    "TextEnd",
    "TextStart",
    "ToolCall",
    "ToolInputAvailable",
    "ToolInputDelta",
    "ToolInputStart",
    "ToolOutputAvailable",
    "Usage",
    "agent",
    "router",
    "team",
    "workflow",
]
|
zonix/engine.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import inspect
|
|
5
|
+
import json
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, TypeAdapter
|
|
9
|
+
|
|
10
|
+
from .events import (
|
|
11
|
+
ApprovalRequired,
|
|
12
|
+
TextDelta,
|
|
13
|
+
TextEnd,
|
|
14
|
+
TextStart,
|
|
15
|
+
ToolInputAvailable,
|
|
16
|
+
ToolInputStart,
|
|
17
|
+
ToolOutputAvailable,
|
|
18
|
+
)
|
|
19
|
+
from .exceptions import (
|
|
20
|
+
MaxToolRoundsExceeded,
|
|
21
|
+
OutputValidationError,
|
|
22
|
+
RunPaused,
|
|
23
|
+
ToolApprovalRejected,
|
|
24
|
+
)
|
|
25
|
+
from .hitl import approval_key, pending_from_call
|
|
26
|
+
from .models import ModelRequest, ModelResponse
|
|
27
|
+
from .serialization import to_jsonable
|
|
28
|
+
from .tools import ToolContext, ToolDefinition
|
|
29
|
+
from .types import Message, RunState, ToolCall
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class RunEngine:
    """Execute one agent: build the prompt, call the model, run tools, validate output.

    ``agent`` is duck-typed. The engine reads these attributes from it: ``name``,
    ``role``, ``prompts``, ``tools``, ``model``, ``memory``, ``output_type``,
    ``output_repair_attempts``, ``max_tool_rounds``, ``retry_attempts``,
    ``retry_on``, ``timeout_seconds`` and ``fallback_node``.
    """

    def __init__(self, agent: Any) -> None:
        self.agent = agent

    async def invoke(self, task: Any, st: RunState) -> Any:
        """Run the agent with retries, an optional timeout and an optional fallback.

        Each attempt calls ``_invoke_once`` (wrapped in ``asyncio.wait_for`` when
        ``agent.timeout_seconds`` is set). Exceptions matching ``agent.retry_on`` are
        recorded on ``st.trace`` and retried, for ``agent.retry_attempts + 1``
        attempts in total. When every attempt fails, ``agent.fallback_node`` (if any)
        is invoked with a state scoped to its name; otherwise the last error is
        re-raised.

        ``RunPaused`` and ``ToolApprovalRejected`` always propagate immediately, even
        when ``retry_on`` would match them.

        Raises:
            RuntimeError: no attempt ran (``retry_attempts < 0``) and there is no
                fallback node.
        """
        attempts = self.agent.retry_attempts + 1
        last_error: Exception | None = None
        for attempt in range(attempts):
            try:
                if self.agent.timeout_seconds is None:
                    return await self._invoke_once(task, st)
                return await asyncio.wait_for(
                    self._invoke_once(task, st),
                    timeout=self.agent.timeout_seconds,
                )
            except (RunPaused, ToolApprovalRejected):
                # Human-in-the-loop signals are not failures. Retrying would re-call
                # the model, and exhausting retries could route to the fallback node
                # instead of surfacing the pause/rejection to the caller.
                raise
            except self.agent.retry_on as exc:
                last_error = exc
                st.trace.record({"attempt": attempt + 1, "error": str(exc)})
        if self.agent.fallback_node is not None:
            return await self.agent.fallback_node.invoke(task, st.scoped(self.agent.fallback_node.name))
        if last_error is not None:
            raise last_error
        raise RuntimeError("Agent invocation failed without an exception.")

    async def _invoke_once(self, task: Any, st: RunState) -> Any:
        """Run the model/tool loop once and return the validated output.

        On success the conversation is stored in ``st.messages``, the output in
        ``st.scratch[agent.name]``, and the exchange is remembered in the session.

        Raises:
            MaxToolRoundsExceeded: the model requested tools after
                ``agent.max_tool_rounds`` tool rounds had already been executed.
            OutputValidationError: the final answer still fails validation after
                ``agent.output_repair_attempts`` repair prompts.
        """
        messages = await self._build_messages(task, st)
        tools = {tool.name: tool for tool in self.agent.tools}
        tool_rounds = 0
        repair_rounds = 0

        # For max_tool_rounds >= 0 this loop only exits via return/raise, because
        # tool_rounds never exceeds the budget; repair rounds are bounded separately.
        while tool_rounds <= self.agent.max_tool_rounds:
            request = ModelRequest(
                messages=messages,
                tools=[tool.model_tool_schema() for tool in tools.values()],
                output_schema=self._output_schema(),
                output_name=self._output_name(),
                metadata={"agent": self.agent.name, "role": self.agent.role},
            )
            response = await self._call_model(request, st)
            st.usage += response.usage

            if response.tool_calls:
                if tool_rounds >= self.agent.max_tool_rounds:
                    # Refuse before executing: running another round of (possibly
                    # side-effecting) tools only to discard the results is wrong.
                    raise self._tool_rounds_exceeded()
                messages.append(
                    Message(
                        role="assistant",
                        content=response.text,
                        data={
                            "tool_calls": [
                                call.model_dump(mode="json") for call in response.tool_calls
                            ]
                        },
                    )
                )
                tool_rounds += 1
                for call in response.tool_calls:
                    await self._run_tool(call, tools, messages, st)
                continue

            if response.text:
                messages.append(Message(role="assistant", content=response.text))

            try:
                output = self._validate_output(response)
            except OutputValidationError as exc:
                if repair_rounds >= self.agent.output_repair_attempts:
                    raise
                repair_rounds += 1
                # Feed the validation error back so the model can correct itself.
                messages.append(
                    Message(
                        role="user",
                        content=(
                            "The previous response did not validate. "
                            "Return a corrected JSON value only, with no Markdown or prose. "
                            f"Validation error: {exc}"
                        ),
                        data={"kind": "output_repair", "attempt": repair_rounds},
                    )
                )
                continue
            st.messages = messages
            st.scratch[self.agent.name] = output
            await self._remember(task, output, st)
            return output

        # Only reachable when max_tool_rounds is negative.
        raise self._tool_rounds_exceeded()

    def _tool_rounds_exceeded(self) -> MaxToolRoundsExceeded:
        """Build the error raised when the agent's tool-round budget is exhausted."""
        return MaxToolRoundsExceeded(
            f"Agent {self.agent.name!r} exceeded {self.agent.max_tool_rounds} tool rounds."
        )

    async def _build_messages(self, task: Any, st: RunState) -> list[Message]:
        """Assemble the initial conversation for one invocation.

        Order: system instructions (if non-empty), history recalled from
        ``st.session`` (if any), the task as a user message, and ``st.extra`` as an
        additional user message tagged ``{"kind": "extra"}`` (if truthy).
        """
        messages: list[Message] = []
        instructions = await self._instructions(task, st)
        if instructions:
            messages.append(Message(role="system", content=instructions))
        if st.session is not None:
            history = await st.session.recall(task, st.ctx, memory=self.agent.memory)
            messages.extend(history)
        messages.append(Message(role="user", content=str(task)))
        if st.extra:
            messages.append(Message(role="user", content=st.extra, data={"kind": "extra"}))
        return messages

    async def _instructions(self, task: Any, st: RunState) -> str:
        """Build the system prompt text, parts joined by blank lines.

        Parts: a ``Role:`` line, each string prompt verbatim, each callable prompt's
        (awaited if awaitable) result when truthy, and JSON-output instructions with
        the output schema when the agent has a concrete ``output_type``.
        """
        parts: list[str] = []
        if self.agent.role:
            parts.append(f"Role: {self.agent.role}")
        for prompt in self.agent.prompts:
            if isinstance(prompt, str):
                parts.append(prompt)
                continue
            value = self._call_prompt(prompt, task, st)
            if inspect.isawaitable(value):
                value = await value
            if value:
                parts.append(str(value))
        schema = self._output_schema()
        if schema is not None:
            parts.append(
                "Return exactly one valid JSON value that validates against this output schema. "
                "Do not include Markdown, prose, comments, or trailing text. "
                "Use double quotes and valid JSON escaping.\n"
                + json.dumps(schema, ensure_ascii=False)
            )
        return "\n\n".join(parts)

    def _call_prompt(self, prompt: Any, task: Any, st: RunState) -> Any:
        """Call a dynamic prompt with as many arguments as its signature declares.

        Zero parameters: ``prompt()``; one: ``prompt(st.ctx)``; otherwise
        ``prompt(st.ctx, task)``.
        """
        signature = inspect.signature(prompt)
        params = list(signature.parameters)
        if len(params) == 0:
            return prompt()
        if len(params) == 1:
            return prompt(st.ctx)
        return prompt(st.ctx, task)

    async def _call_model(self, request: ModelRequest, st: RunState) -> ModelResponse:
        """Call the model, streaming through the event bus when it has an emitter.

        When ``st.bus._emit`` is set the model's ``stream_complete`` publishes events
        itself; otherwise ``complete`` is used and the response is replayed as events.
        """
        if st.bus._emit is not None:
            return await self.agent.model.stream_complete(request, st.bus.publish, st.path)
        response = await self.agent.model.complete(request)
        await self._emit_response(response, st)
        return response

    async def _emit_response(self, response: ModelResponse, st: RunState) -> None:
        """Replay a non-streamed response as events.

        Publishes one start/delta/end triple with id ``"text_0"`` if there is text,
        then an input-start and input-available event per tool call.
        """
        if response.text:
            await st.bus.publish(TextStart(st.path, "text_0"))
            await st.bus.publish(TextDelta(st.path, "text_0", response.text))
            await st.bus.publish(TextEnd(st.path, "text_0"))
        for call in response.tool_calls:
            await st.bus.publish(ToolInputStart(st.path, call.call_id, call.tool))
            await st.bus.publish(ToolInputAvailable(st.path, call.call_id, call.tool, call.input))

    async def _run_tool(
        self,
        call: ToolCall,
        tools: dict[str, ToolDefinition],
        messages: list[Message],
        st: RunState,
    ) -> None:
        """Execute one tool call (after approval if required) and record its result.

        Appends a ``role="tool"`` message with the JSON-encoded output to
        ``messages``, increments ``st.usage.tool_calls`` and publishes a
        ``ToolOutputAvailable`` event.

        Raises:
            KeyError: the agent has no tool named ``call.tool``.
            RunPaused / ToolApprovalRejected: propagated from ``_approval_for``.
        """
        if call.tool not in tools:
            raise KeyError(f"Agent {self.agent.name!r} has no tool named {call.tool!r}.")

        tool = tools[call.tool]
        parsed_input = tool.parse_input(call.input)
        normalized_call = ToolCall(call_id=call.call_id, tool=call.tool, input=parsed_input)
        approved_input = await self._approval_for(tool, normalized_call, st)
        if isinstance(approved_input, dict):
            # A dict decision supplies replacement arguments; re-parse them.
            parsed_input = tool.parse_input(approved_input)

        ctx = ToolContext(deps=st.ctx, usage=st.usage, state=st, agent=self.agent)
        output = await tool.invoke(ctx, parsed_input)
        st.usage.tool_calls += 1
        await st.bus.publish(ToolOutputAvailable(st.path, call.call_id, to_jsonable(output)))
        messages.append(
            Message(
                role="tool",
                name=tool.name,
                tool_call_id=call.call_id,
                content=json.dumps(to_jsonable(output), ensure_ascii=False),
            )
        )

    async def _approval_for(
        self,
        tool: ToolDefinition,
        call: ToolCall,
        st: RunState,
    ) -> dict[str, Any] | bool | None:
        """Resolve the approval decision for a tool call.

        Returns None when the tool needs no approval, or the recorded decision
        (``True`` or a dict of replacement input). Decisions are looked up by call
        id first, then by ``approval_key(tool, input)``.

        Raises:
            ToolApprovalRejected: the recorded decision is ``False``.
            RunPaused: no decision is recorded yet; carries the pending call and a
                snapshot of the run state.
        """
        if not tool.approval:
            return None
        key = approval_key(call.tool, call.input)
        decision = st.approvals.get(call.call_id, st.approvals.get(key))
        if decision is False:
            raise ToolApprovalRejected(f"Tool call {call.call_id} was rejected.")
        if decision is True or isinstance(decision, dict):
            return decision

        pending = pending_from_call(call)
        await st.bus.publish(ApprovalRequired(st.path, call.call_id, call.tool, call.input))
        # NOTE(review): ``st.messages`` is only assigned after a successful
        # validation in ``_invoke_once``, so this snapshot does not contain the
        # in-progress conversation of the current invocation. Confirm whether
        # resume relies on these messages.
        snapshot = {
            "run_id": st.run_id,
            "path": st.path,
            "pending": pending,
            "messages": messages_dump(st.messages),
            "scratch": to_jsonable(st.scratch),
            "trace": to_jsonable(st.trace),
            "usage": to_jsonable(st.usage),
        }
        raise RunPaused(pending=pending, snapshot=snapshot)

    def _output_schema(self) -> dict[str, Any] | None:
        """Return the JSON schema for ``agent.output_type``, or None for None/Any."""
        output_type = self.agent.output_type
        if output_type is None or output_type is Any:
            return None
        return TypeAdapter(output_type).json_schema()

    def _output_name(self) -> str | None:
        """Return a display name for ``agent.output_type`` (``__name__`` or repr)."""
        output_type = self.agent.output_type
        if output_type is None:
            return None
        return getattr(output_type, "__name__", repr(output_type))

    def _validate_output(self, response: ModelResponse) -> Any:
        """Coerce the model's final answer to ``agent.output_type``.

        Uses ``response.output`` when set, else ``response.text``. Without a concrete
        output type the raw value is returned. Text is parsed as JSON (tolerating
        surrounding prose) unless the output type is ``str``.

        Raises:
            OutputValidationError: the text is not JSON, or the data does not
                validate; the message includes the underlying validation error.
        """
        output_type = self.agent.output_type
        raw = response.output if response.output is not None else response.text
        if output_type is None or output_type is Any:
            return raw
        try:
            if isinstance(raw, output_type):
                return raw
        except TypeError:
            # Parameterized generics such as ``list[int]`` reject isinstance checks.
            pass
        data = raw
        if isinstance(raw, str) and output_type is not str:
            try:
                data = _load_json_text(raw)
            except json.JSONDecodeError as exc:
                raise OutputValidationError(
                    f"Agent {self.agent.name!r} expected {self._output_name()} JSON, got text."
                ) from exc
        try:
            is_model = isinstance(output_type, type) and issubclass(output_type, BaseModel)
        except TypeError:
            # Generic aliases can pass ``isinstance(..., type)`` on older Pythons
            # but make ``issubclass`` raise.
            is_model = False
        try:
            if is_model:
                return output_type.model_validate(data)
            return TypeAdapter(output_type).validate_python(data)
        except Exception as exc:
            # Wrap every failure (pydantic ValidationError included): the repair loop
            # in _invoke_once only catches OutputValidationError.
            raise OutputValidationError(
                f"Agent {self.agent.name!r} could not validate model output as "
                f"{self._output_name()}: {exc}"
            ) from exc

    async def _remember(self, task: Any, output: Any, st: RunState) -> None:
        """Append the task and the JSON-encoded output to the session, if any."""
        if st.session is None:
            return
        await st.session.remember(Message(role="user", content=str(task)))
        await st.session.remember(
            Message(role="assistant", content=json.dumps(to_jsonable(output), ensure_ascii=False))
        )
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def messages_dump(messages: list[Message]) -> list[dict[str, Any]]:
    """Serialize messages in order via ``model_dump(mode="json")``."""
    dumped: list[dict[str, Any]] = []
    for msg in messages:
        dumped.append(msg.model_dump(mode="json"))
    return dumped
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def _load_json_text(text: str) -> Any:
|
|
297
|
+
stripped = text.strip()
|
|
298
|
+
try:
|
|
299
|
+
return json.loads(stripped)
|
|
300
|
+
except json.JSONDecodeError:
|
|
301
|
+
candidate = _extract_json_candidate(stripped)
|
|
302
|
+
if candidate is None:
|
|
303
|
+
raise
|
|
304
|
+
return json.loads(candidate)
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def _extract_json_candidate(text: str) -> str | None:
|
|
308
|
+
starts = [index for index in (text.find("{"), text.find("[")) if index >= 0]
|
|
309
|
+
if not starts:
|
|
310
|
+
return None
|
|
311
|
+
start = min(starts)
|
|
312
|
+
opener = text[start]
|
|
313
|
+
closer = "}" if opener == "{" else "]"
|
|
314
|
+
depth = 0
|
|
315
|
+
in_string = False
|
|
316
|
+
escaped = False
|
|
317
|
+
|
|
318
|
+
for index, char in enumerate(text[start:], start=start):
|
|
319
|
+
if in_string:
|
|
320
|
+
if escaped:
|
|
321
|
+
escaped = False
|
|
322
|
+
elif char == "\\":
|
|
323
|
+
escaped = True
|
|
324
|
+
elif char == '"':
|
|
325
|
+
in_string = False
|
|
326
|
+
continue
|
|
327
|
+
if char == '"':
|
|
328
|
+
in_string = True
|
|
329
|
+
elif char == opener:
|
|
330
|
+
depth += 1
|
|
331
|
+
elif char == closer:
|
|
332
|
+
depth -= 1
|
|
333
|
+
if depth == 0:
|
|
334
|
+
return text[start : index + 1]
|
|
335
|
+
return None
|
zonix/events.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from .serialization import to_jsonable
|
|
7
|
+
from .types import Usage
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(frozen=True)
class Event:
    """Base class for all run events; ``path`` identifies the emitting node."""

    path: tuple[str, ...]

    @property
    def type(self) -> str:
        """Kebab-case form of the class name, e.g. ``ToolInputStart`` -> ``tool-input-start``."""
        class_name = self.__class__.__name__
        return "".join(
            f"-{ch.lower()}" if ch.isupper() and pos else ch.lower()
            for pos, ch in enumerate(class_name)
        )

    def dump(self) -> dict[str, Any]:
        """Return a JSON-friendly dict of the event, tagged with its ``type``.

        Non-dict serializations are wrapped as ``{"type": ..., "value": ...}``.
        """
        payload = to_jsonable(self)
        if not isinstance(payload, dict):
            return {"type": self.type, "value": payload}
        payload["type"] = self.type
        return payload
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass(frozen=True)
class NodeStart(Event):
    """Event carrying the ``name`` of a node that started.

    NOTE(review): not published by engine.py; presumably emitted by the runtime or
    multi-agent layers when a node begins — confirm.
    """

    name: str
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass(frozen=True)
class NodeEnd(Event):
    """Event carrying the ``name`` and completion ``status`` of a node.

    NOTE(review): not published by engine.py; presumably emitted when a node
    finishes, with a non-"ok" status on failure — confirm.
    """

    name: str
    status: str = "ok"  # defaults to "ok"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass(frozen=True)
class TextStart(Event):
    """Opens a text segment; later TextDelta/TextEnd events reuse the same ``id``."""

    id: str  # segment id (RunEngine uses "text_0" for non-streamed replies)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass(frozen=True)
class TextDelta(Event):
    """A chunk of text appended to the segment ``id``."""

    id: str
    delta: str  # text chunk; RunEngine sends the whole reply as one delta when not streaming
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass(frozen=True)
class TextEnd(Event):
    """Closes the text segment ``id``."""

    id: str
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass(frozen=True)
class ReasoningDelta(Event):
    """A chunk of model reasoning text for segment ``id``.

    NOTE(review): not published by engine.py; presumably emitted by streaming
    model adapters — confirm.
    """

    id: str
    delta: str
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@dataclass(frozen=True)
class ToolInputStart(Event):
    """Announces a tool call (``call_id``) to tool ``tool`` before its input is known."""

    call_id: str
    tool: str  # tool name
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass(frozen=True)
class ToolInputDelta(Event):
    """A chunk of a tool call's input text for ``call_id``.

    NOTE(review): not published by engine.py; presumably emitted by streaming
    model adapters while tool arguments arrive — confirm.
    """

    call_id: str
    delta: str
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@dataclass(frozen=True)
class ToolInputAvailable(Event):
    """The complete input for tool call ``call_id`` to tool ``tool``."""

    call_id: str
    tool: str
    input: dict[str, Any]  # tool arguments as produced by the model
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass(frozen=True)
class ToolOutputAvailable(Event):
    """The result of tool call ``call_id``."""

    call_id: str
    output: Any  # RunEngine publishes the to_jsonable() form of the tool's return value
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@dataclass(frozen=True)
class ApprovalRequired(Event):
    """Published just before a run pauses for human approval of a tool call."""

    call_id: str
    tool: str
    input: dict[str, Any]  # the (parsed) input awaiting approval
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@dataclass(frozen=True)
class ErrorEvent(Event):
    """An error report with a human-readable ``message``.

    NOTE(review): not published by engine.py; presumably emitted by the runtime
    when a run fails — confirm.
    """

    message: str
    error_type: str = "error"  # category label; defaults to "error"
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@dataclass(frozen=True)
class Finish(Event):
    """Carries a run's final ``output`` and its accumulated ``usage``.

    NOTE(review): not published by engine.py; presumably the last event of a run —
    confirm.
    """

    output: Any
    usage: Usage
|
zonix/exceptions.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from .types import PendingApproval
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ZonixError(Exception):
    """Base error for Zonix; every library exception in this module derives from it."""
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ModelError(ZonixError):
    """Raised when a model adapter cannot complete a request."""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ToolError(ZonixError):
    """Raised when a tool fails."""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class OutputValidationError(ZonixError):
    """Raised when a model response cannot be validated as the requested output.

    RunEngine catches it to send a repair prompt, up to the agent's
    ``output_repair_attempts``, before letting it propagate.
    """
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MaxToolRoundsExceeded(ZonixError):
    """Raised when an agent keeps asking for tools beyond its configured limit (``max_tool_rounds``)."""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class MaxStepsExceeded(ZonixError):
    """Raised when a team router exceeds its step budget."""
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ToolApprovalRejected(ZonixError):
    """Raised when a paused tool call is resumed with approval=False (the recorded decision is ``False``)."""
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass(eq=False)
class RunPaused(ZonixError):
    """Raised to pause a run until a human approves or rejects a tool call.

    Attributes:
        pending: the tool call awaiting a decision.
        snapshot: JSON-friendly run state captured when the run paused.

    ``eq=False`` keeps normal exception semantics: identity equality and
    hashability. The dataclass default (``eq=True``) would set ``__hash__`` to None.
    """

    pending: PendingApproval
    snapshot: dict[str, Any]

    def __post_init__(self) -> None:
        # The dataclass-generated __init__ never calls Exception.__init__, so
        # ``args`` would stay empty for keyword construction. Copy, deepcopy and
        # pickle rebuild exceptions as ``cls(*args)``, which would then fail.
        super().__init__(self.pending, self.snapshot)

    def __str__(self) -> str:
        return f"Run paused for approval: {self.pending.tool}"
|
zonix/hitl.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from .serialization import to_jsonable
|
|
10
|
+
from .types import PendingApproval, ToolCall
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def approval_key(tool: str, input: dict[str, Any]) -> str:
    """Return a deterministic ``"<tool>:<sha256 hex>"`` key for a tool call.

    The input goes through ``to_jsonable`` and is serialized with sorted keys, so
    equal inputs map to the same key regardless of dict ordering.
    """
    hasher = hashlib.sha256()
    hasher.update(
        json.dumps(to_jsonable(input), ensure_ascii=False, sort_keys=True).encode("utf-8")
    )
    return "{}:{}".format(tool, hasher.hexdigest())
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def pending_from_call(call: ToolCall) -> PendingApproval:
    """Build the PendingApproval record for ``call``, including its approval key."""
    tool_name = call.tool
    tool_input = call.input
    key = approval_key(tool_name, tool_input)
    return PendingApproval(
        call_id=call.call_id,
        tool=tool_name,
        input=tool_input,
        approval_key=key,
    )
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class CheckpointStore:
    """File-backed run snapshots: one ``<run_id>.json`` file per run under ``root``."""

    root: Path

    def __post_init__(self) -> None:
        # Accept str/PathLike roots too; the methods below rely on Path operators.
        self.root = Path(self.root)
        self.root.mkdir(parents=True, exist_ok=True)

    def path_for(self, run_id: str) -> Path:
        """Return the checkpoint file path for ``run_id``.

        Raises:
            ValueError: ``run_id`` contains a path separator or NUL byte, which
                could place the file outside ``root``.
        """
        name = f"{run_id}"
        if any(bad in name for bad in ("/", "\\", "\x00")):
            raise ValueError(f"Invalid run id for checkpoint storage: {run_id!r}")
        return self.root / f"{name}.json"

    def save(self, run_id: str, snapshot: dict[str, Any]) -> Path:
        """Write ``snapshot`` as indented UTF-8 JSON and return the file path.

        The JSON goes to a sibling ``.tmp`` file that is then moved over the target,
        so an interrupted write never leaves a truncated checkpoint. Concurrent
        saves of the same run id are not coordinated.
        """
        path = self.path_for(run_id)
        payload = json.dumps(to_jsonable(snapshot), ensure_ascii=False, indent=2)
        tmp_path = path.with_name(path.name + ".tmp")
        tmp_path.write_text(payload, encoding="utf-8")
        tmp_path.replace(path)
        return path

    def load(self, run_id: str) -> dict[str, Any]:
        """Read and parse the checkpoint saved for ``run_id``."""
        return json.loads(self.path_for(run_id).read_text(encoding="utf-8"))
|
zonix/memory/__init__.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import uuid
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, Protocol
|
|
7
|
+
|
|
8
|
+
from zonix.types import Message
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Memory(Protocol):
    """Structural interface for a memory strategy.

    ``apply`` receives the conversation history, the current task and the run
    context, and returns the list of messages to use in its place.
    """

    async def apply(self, history: list[Message], current: Any, ctx: Any) -> list[Message]:
        ...
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _as_list(memory: Any) -> list[Memory]:
|
|
17
|
+
if memory is None:
|
|
18
|
+
return []
|
|
19
|
+
if isinstance(memory, MemoryStack):
|
|
20
|
+
return list(memory.items)
|
|
21
|
+
if isinstance(memory, list | tuple):
|
|
22
|
+
return list(memory)
|
|
23
|
+
return [memory]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class MemoryStack:
    """An ordered pipeline of memory strategies, applied first to last."""

    items: list[Memory] = field(default_factory=list)

    def __add__(self, other: Any) -> MemoryStack:
        """Return a new stack with ``other``'s strategies appended after these."""
        combined = list(self.items)
        combined.extend(_as_list(other))
        return MemoryStack(combined)

    async def apply(self, history: list[Message], current: Any, ctx: Any) -> list[Message]:
        """Feed a copy of ``history`` through each strategy in turn."""
        result = [*history]
        for strategy in self.items:
            result = await strategy.apply(result, current, ctx)
        return result
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class Window:
    """Memory strategy that keeps only the most recent ``size`` messages."""

    size: int = 20

    def __add__(self, other: Any) -> MemoryStack:
        return MemoryStack([self]) + other

    async def apply(self, history: list[Message], current: Any, ctx: Any) -> list[Message]:
        """Return a new list with the last ``size`` messages of ``history``.

        A ``size`` of zero or less keeps nothing; a plain ``history[-0:]`` slice
        would return the entire history instead.
        """
        if self.size <= 0:
            return []
        return list(history[-self.size :])
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass
class Summarize:
    """Memory strategy that condenses old messages once history gets too long.

    When the total content length exceeds ``over`` characters, every message
    except the last ``keep`` is replaced by one system message. That message
    concatenates the first 200 characters of each older message.
    """

    over: int
    keep: int = 20

    def __add__(self, other: Any) -> MemoryStack:
        return MemoryStack([self]) + other

    async def apply(self, history: list[Message], current: Any, ctx: Any) -> list[Message]:
        """Return ``history`` unchanged or as ``[summary, *last_keep_messages]``."""
        total_chars = sum(len(message.content or "") for message in history)
        if total_chars <= self.over:
            return list(history)
        # Compute the split explicitly: ``history[-0:]`` would keep everything when
        # keep == 0, and keep >= len(history) would leave nothing to summarize.
        split = max(len(history) - max(self.keep, 0), 0)
        older = history[:split]
        recent = history[split:]
        if not older:
            # Everything is inside the kept window; an empty summary adds nothing.
            return list(history)
        summary = "Earlier conversation summary: " + " ".join(
            (message.content or "")[:200] for message in older
        )
        return [Message(role="system", content=summary), *recent]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass
class Vector:
    """Memory strategy that appends ``k`` store results relevant to the current task."""

    store: Any
    k: int = 6

    def __add__(self, other: Any) -> MemoryStack:
        return MemoryStack([self]) + other

    async def apply(self, history: list[Message], current: Any, ctx: Any) -> list[Message]:
        """Return ``history`` plus one system message listing the search results.

        History is returned as a copy without additions when the search yields nothing.
        """
        results = await self._search(current)
        if not results:
            return list(history)
        body = "\n".join(str(item) for item in results)
        return [*history, Message(role="system", content=f"Relevant long-term memory:\n{body}")]

    async def _search(self, query: Any) -> list[Any]:
        """Query ``store`` for up to ``k`` results related to ``query``.

        Uses the first method the store defines, in this order: ``asearch``,
        ``asimilarity_search``, ``search``, ``similarity_search``, each called as
        ``method(query, k=self.k)``. Results are awaited whenever they are
        awaitable, so an ``async def search`` store works too. Returns [] when the
        store defines none of these methods.
        """
        for name in ("asearch", "asimilarity_search", "search", "similarity_search"):
            method = getattr(self.store, name, None)
            if method is None:
                continue
            value = method(query, k=self.k)
            if inspect.isawaitable(value):
                value = await value
            return list(value)
        return []
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@dataclass
class Session:
    """Conversation state that persists across runs.

    Attributes:
        id: unique id, ``"session_"`` followed by 32 hex characters by default.
        history: remembered messages, oldest first.
        memory: default memory strategy (or stack/list of them) applied on recall.
    """

    id: str = field(default_factory=lambda: f"session_{uuid.uuid4().hex}")
    history: list[Message] = field(default_factory=list)
    memory: Any = None

    async def recall(self, current: Any, ctx: Any, memory: Any = None) -> list[Message]:
        """Return history filtered by the session's memory, then by ``memory``."""
        pipeline = _as_list(self.memory) + _as_list(memory)
        return await MemoryStack(pipeline).apply(list(self.history), current, ctx)

    async def remember(self, message: Message) -> None:
        """Append ``message`` to the history."""
        self.history.append(message)

    def dump(self) -> dict[str, Any]:
        """Return ``{"id": ..., "history": [...]}`` with JSON-mode message dumps."""
        serialized = [entry.model_dump(mode="json") for entry in self.history]
        return {"id": self.id, "history": serialized}
|