zonix 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zonix/__init__.py ADDED
@@ -0,0 +1,44 @@
1
+ from .events import (
2
+ ApprovalRequired,
3
+ ErrorEvent,
4
+ Event,
5
+ Finish,
6
+ ReasoningDelta,
7
+ TextDelta,
8
+ TextEnd,
9
+ TextStart,
10
+ ToolInputAvailable,
11
+ ToolInputDelta,
12
+ ToolInputStart,
13
+ ToolOutputAvailable,
14
+ )
15
+ from .spec import Agent, agent, router, team, workflow
16
+ from .types import Message, PendingApproval, Route, RunResult, RunState, Span, ToolCall, Usage
17
+
18
# Public API of the package: names re-exported from the submodule imports above.
# Kept in ASCII order (PascalCase names first, then the lowercase helpers).
__all__ = [
    "Agent",
    "ApprovalRequired",
    "ErrorEvent",
    "Event",
    "Finish",
    "Message",
    "PendingApproval",
    "ReasoningDelta",
    "Route",
    "RunResult",
    "RunState",
    "Span",
    "TextDelta",
    "TextEnd",
    "TextStart",
    "ToolCall",
    "ToolInputAvailable",
    "ToolInputDelta",
    "ToolInputStart",
    "ToolOutputAvailable",
    "Usage",
    "agent",
    "router",
    "team",
    "workflow",
]
zonix/engine.py ADDED
@@ -0,0 +1,335 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import inspect
5
+ import json
6
+ from typing import Any
7
+
8
+ from pydantic import BaseModel, TypeAdapter
9
+
10
+ from .events import (
11
+ ApprovalRequired,
12
+ TextDelta,
13
+ TextEnd,
14
+ TextStart,
15
+ ToolInputAvailable,
16
+ ToolInputStart,
17
+ ToolOutputAvailable,
18
+ )
19
+ from .exceptions import (
20
+ MaxToolRoundsExceeded,
21
+ OutputValidationError,
22
+ RunPaused,
23
+ ToolApprovalRejected,
24
+ )
25
+ from .hitl import approval_key, pending_from_call
26
+ from .models import ModelRequest, ModelResponse
27
+ from .serialization import to_jsonable
28
+ from .tools import ToolContext, ToolDefinition
29
+ from .types import Message, RunState, ToolCall
30
+
31
+
32
class RunEngine:
    """Execute one agent: prompt assembly, model calls, tool rounds and output validation.

    ``agent`` is duck-typed. The engine reads these attributes from it: ``name``,
    ``role``, ``prompts``, ``tools``, ``model``, ``memory``, ``output_type``,
    ``max_tool_rounds``, ``output_repair_attempts``, ``retry_attempts``,
    ``retry_on``, ``timeout_seconds`` and ``fallback_node``.
    """

    def __init__(self, agent: Any) -> None:
        self.agent = agent

    async def invoke(self, task: Any, st: RunState) -> Any:
        """Run the agent on ``task`` with retries, an optional timeout and an optional fallback.

        Each attempt runs :meth:`_invoke_once`, bounded by ``agent.timeout_seconds``
        when it is not ``None``. Exceptions matching ``agent.retry_on`` are recorded on
        ``st.trace`` and retried, for at most ``agent.retry_attempts + 1`` attempts in
        total. When every attempt fails, ``agent.fallback_node`` (if set) is invoked
        with a state scoped to its name; otherwise the last error is re-raised.

        ``RunPaused`` is never retried and never triggers the fallback. It signals that
        a tool call is waiting for human approval. Retrying would re-prompt for the
        same approval, and falling back would run another node without it.

        Raises:
            RunPaused: a tool call requires approval.
            Exception: the last retryable error when attempts are exhausted and no
                fallback is configured; non-retryable errors propagate immediately.
        """
        attempts = self.agent.retry_attempts + 1
        last_error: Exception | None = None
        for attempt in range(attempts):
            try:
                if self.agent.timeout_seconds is None:
                    return await self._invoke_once(task, st)
                return await asyncio.wait_for(
                    self._invoke_once(task, st),
                    timeout=self.agent.timeout_seconds,
                )
            except RunPaused:
                # Pausing for approval is control flow, not a failure: a broad
                # ``retry_on`` (e.g. Exception) must not swallow it.
                raise
            except self.agent.retry_on as exc:
                last_error = exc
                st.trace.record({"attempt": attempt + 1, "error": str(exc)})
        if self.agent.fallback_node is not None:
            return await self.agent.fallback_node.invoke(task, st.scoped(self.agent.fallback_node.name))
        if last_error is not None:
            raise last_error
        raise RuntimeError("Agent invocation failed without an exception.")

    async def _invoke_once(self, task: Any, st: RunState) -> Any:
        """Run one model/tool loop and return the validated output.

        Model calls alternate with tool execution until a response without tool calls
        validates against ``agent.output_type``. A response that fails validation is
        answered with a repair prompt, up to ``agent.output_repair_attempts`` times.
        On success the conversation is stored in ``st.messages``, the output in
        ``st.scratch[agent.name]``, and the exchange is remembered in the session.

        Raises:
            MaxToolRoundsExceeded: the model requests another tool round after
                ``agent.max_tool_rounds`` rounds already ran. The extra calls are
                not executed.
            OutputValidationError: the output still fails validation after all
                repair attempts.
        """
        messages = await self._build_messages(task, st)
        tools = {tool.name: tool for tool in self.agent.tools}
        tool_rounds = 0
        repair_rounds = 0

        while True:
            request = ModelRequest(
                messages=messages,
                tools=[tool.model_tool_schema() for tool in tools.values()],
                output_schema=self._output_schema(),
                output_name=self._output_name(),
                metadata={"agent": self.agent.name, "role": self.agent.role},
            )
            response = await self._call_model(request, st)
            st.usage += response.usage

            if response.tool_calls:
                # Enforce the limit *before* running tools, so calls beyond the
                # budget never produce side effects.
                if tool_rounds >= self.agent.max_tool_rounds:
                    raise MaxToolRoundsExceeded(
                        f"Agent {self.agent.name!r} exceeded {self.agent.max_tool_rounds} tool rounds."
                    )
                messages.append(
                    Message(
                        role="assistant",
                        content=response.text,
                        data={
                            "tool_calls": [
                                call.model_dump(mode="json") for call in response.tool_calls
                            ]
                        },
                    )
                )
                tool_rounds += 1
                for call in response.tool_calls:
                    await self._run_tool(call, tools, messages, st)
                continue

            if response.text:
                messages.append(Message(role="assistant", content=response.text))

            try:
                output = self._validate_output(response)
            except OutputValidationError as exc:
                if repair_rounds >= self.agent.output_repair_attempts:
                    raise
                repair_rounds += 1
                messages.append(
                    Message(
                        role="user",
                        content=(
                            "The previous response did not validate. "
                            "Return a corrected JSON value only, with no Markdown or prose. "
                            f"Validation error: {exc}"
                        ),
                        data={"kind": "output_repair", "attempt": repair_rounds},
                    )
                )
                continue
            st.messages = messages
            st.scratch[self.agent.name] = output
            await self._remember(task, output, st)
            return output

    async def _build_messages(self, task: Any, st: RunState) -> list[Message]:
        """Assemble the initial conversation for ``task``.

        Order: system instructions (if non-empty), recalled session history (if a
        session is attached), the task as a user message, and ``st.extra`` as an
        additional user message when it is truthy.
        """
        messages: list[Message] = []
        instructions = await self._instructions(task, st)
        if instructions:
            messages.append(Message(role="system", content=instructions))
        if st.session is not None:
            history = await st.session.recall(task, st.ctx, memory=self.agent.memory)
            messages.extend(history)
        messages.append(Message(role="user", content=str(task)))
        if st.extra:
            messages.append(Message(role="user", content=st.extra, data={"kind": "extra"}))
        return messages

    async def _instructions(self, task: Any, st: RunState) -> str:
        """Build the system prompt and join its parts with blank lines.

        The parts are: the role line, each static or callable prompt (callables may
        be async; falsy results are skipped), and a JSON-only directive with the
        output schema when one exists.
        """
        parts: list[str] = []
        if self.agent.role:
            parts.append(f"Role: {self.agent.role}")
        for prompt in self.agent.prompts:
            if isinstance(prompt, str):
                parts.append(prompt)
                continue
            value = self._call_prompt(prompt, task, st)
            if inspect.isawaitable(value):
                value = await value
            if value:
                parts.append(str(value))
        schema = self._output_schema()
        if schema is not None:
            parts.append(
                "Return exactly one valid JSON value that validates against this output schema. "
                "Do not include Markdown, prose, comments, or trailing text. "
                "Use double quotes and valid JSON escaping.\n"
                + json.dumps(schema, ensure_ascii=False)
            )
        return "\n\n".join(parts)

    def _call_prompt(self, prompt: Any, task: Any, st: RunState) -> Any:
        """Call a prompt callable as ``()``, ``(ctx)`` or ``(ctx, task)``.

        The form depends on how many parameters its signature declares.
        """
        signature = inspect.signature(prompt)
        params = list(signature.parameters)
        if len(params) == 0:
            return prompt()
        if len(params) == 1:
            return prompt(st.ctx)
        return prompt(st.ctx, task)

    async def _call_model(self, request: ModelRequest, st: RunState) -> ModelResponse:
        """Get a model response, streaming through the bus when it has an emitter.

        Without an emitter, the request is completed in one shot and equivalent
        events are published afterwards.
        """
        if st.bus._emit is not None:
            return await self.agent.model.stream_complete(request, st.bus.publish, st.path)
        response = await self.agent.model.complete(request)
        await self._emit_response(response, st)
        return response

    async def _emit_response(self, response: ModelResponse, st: RunState) -> None:
        """Publish text start/delta/end and tool-input events for a non-streamed response."""
        if response.text:
            await st.bus.publish(TextStart(st.path, "text_0"))
            await st.bus.publish(TextDelta(st.path, "text_0", response.text))
            await st.bus.publish(TextEnd(st.path, "text_0"))
        for call in response.tool_calls:
            await st.bus.publish(ToolInputStart(st.path, call.call_id, call.tool))
            await st.bus.publish(ToolInputAvailable(st.path, call.call_id, call.tool, call.input))

    async def _run_tool(
        self,
        call: ToolCall,
        tools: dict[str, ToolDefinition],
        messages: list[Message],
        st: RunState,
    ) -> None:
        """Execute one tool call and append its JSON result as a ``tool`` message.

        Raises:
            KeyError: the agent has no tool named ``call.tool``.
            RunPaused / ToolApprovalRejected: propagated from :meth:`_approval_for`.
        """
        if call.tool not in tools:
            raise KeyError(f"Agent {self.agent.name!r} has no tool named {call.tool!r}.")

        tool = tools[call.tool]
        parsed_input = tool.parse_input(call.input)
        normalized_call = ToolCall(call_id=call.call_id, tool=call.tool, input=parsed_input)
        approved_input = await self._approval_for(tool, normalized_call, st)
        # An approval may carry edited arguments (a dict); ``True`` keeps the original input.
        if isinstance(approved_input, dict):
            parsed_input = tool.parse_input(approved_input)

        ctx = ToolContext(deps=st.ctx, usage=st.usage, state=st, agent=self.agent)
        output = await tool.invoke(ctx, parsed_input)
        st.usage.tool_calls += 1
        await st.bus.publish(ToolOutputAvailable(st.path, call.call_id, to_jsonable(output)))
        messages.append(
            Message(
                role="tool",
                name=tool.name,
                tool_call_id=call.call_id,
                content=json.dumps(to_jsonable(output), ensure_ascii=False),
            )
        )

    async def _approval_for(
        self,
        tool: ToolDefinition,
        call: ToolCall,
        st: RunState,
    ) -> dict[str, Any] | bool | None:
        """Resolve the approval decision for ``call``.

        The decision is looked up first by call id, then by approval key.

        Returns:
            ``None`` if the tool needs no approval; ``True`` or an edited-input dict if
            it was approved.

        Raises:
            ToolApprovalRejected: the recorded decision is ``False``.
            RunPaused: no decision exists yet. An ``ApprovalRequired`` event is
                published and a snapshot is attached to the exception.
        """
        if not tool.approval:
            return None
        key = approval_key(call.tool, call.input)
        decision = st.approvals.get(call.call_id, st.approvals.get(key))
        if decision is False:
            raise ToolApprovalRejected(f"Tool call {call.call_id} was rejected.")
        if decision is True or isinstance(decision, dict):
            return decision

        pending = pending_from_call(call)
        await st.bus.publish(ApprovalRequired(st.path, call.call_id, call.tool, call.input))
        # NOTE(review): st.messages is only assigned when _invoke_once returns, so this
        # snapshot does not contain the in-progress conversation of this agent. Confirm
        # that resuming does not rely on it.
        snapshot = {
            "run_id": st.run_id,
            "path": st.path,
            "pending": pending,
            "messages": messages_dump(st.messages),
            "scratch": to_jsonable(st.scratch),
            "trace": to_jsonable(st.trace),
            "usage": to_jsonable(st.usage),
        }
        raise RunPaused(pending=pending, snapshot=snapshot)

    def _output_schema(self) -> dict[str, Any] | None:
        """JSON schema for ``agent.output_type``, or ``None`` when it is unset or ``Any``."""
        output_type = self.agent.output_type
        if output_type is None or output_type is Any:
            return None
        return TypeAdapter(output_type).json_schema()

    def _output_name(self) -> str | None:
        """Display name of ``agent.output_type`` (``__name__`` or its repr), or ``None``."""
        output_type = self.agent.output_type
        if output_type is None:
            return None
        return getattr(output_type, "__name__", repr(output_type))

    def _validate_output(self, response: ModelResponse) -> Any:
        """Convert a model response into ``agent.output_type``.

        Structured ``response.output`` is preferred over ``response.text``. Text for a
        non-``str`` output type is parsed as JSON first; a JSON value embedded in
        prose is also accepted.

        Raises:
            OutputValidationError: the text is not JSON, or the data does not validate.
                This covers ``BaseModel`` outputs, whose pydantic validation errors
                previously escaped unwrapped and skipped output repair.
        """
        output_type = self.agent.output_type
        raw = response.output if response.output is not None else response.text
        if output_type is None or output_type is Any:
            return raw
        try:
            if isinstance(raw, output_type):
                return raw
        except TypeError:
            # Parameterized generics and unions are not valid isinstance() targets.
            pass
        data = raw
        if isinstance(raw, str) and output_type is not str:
            try:
                data = _load_json_text(raw)
            except json.JSONDecodeError as exc:
                raise OutputValidationError(
                    f"Agent {self.agent.name!r} expected {self._output_name()} JSON, got text."
                ) from exc
        try:
            is_model = isinstance(output_type, type) and issubclass(output_type, BaseModel)
        except TypeError:
            # Some generic aliases pass isinstance(..., type) but reject issubclass().
            is_model = False
        try:
            if is_model:
                return output_type.model_validate(data)
            return TypeAdapter(output_type).validate_python(data)
        except Exception as exc:
            raise OutputValidationError(
                f"Agent {self.agent.name!r} could not validate model output as {self._output_name()}."
            ) from exc

    async def _remember(self, task: Any, output: Any, st: RunState) -> None:
        """Store the task and the JSON-encoded output in the session, if one is attached."""
        if st.session is None:
            return
        await st.session.remember(Message(role="user", content=str(task)))
        await st.session.remember(
            Message(role="assistant", content=json.dumps(to_jsonable(output), ensure_ascii=False))
        )
290
+
291
+
292
def messages_dump(messages: list[Message]) -> list[dict[str, Any]]:
    """Serialize each message with ``model_dump(mode="json")``, keeping the input order."""
    dumped: list[dict[str, Any]] = []
    for msg in messages:
        dumped.append(msg.model_dump(mode="json"))
    return dumped
294
+
295
+
296
+ def _load_json_text(text: str) -> Any:
297
+ stripped = text.strip()
298
+ try:
299
+ return json.loads(stripped)
300
+ except json.JSONDecodeError:
301
+ candidate = _extract_json_candidate(stripped)
302
+ if candidate is None:
303
+ raise
304
+ return json.loads(candidate)
305
+
306
+
307
+ def _extract_json_candidate(text: str) -> str | None:
308
+ starts = [index for index in (text.find("{"), text.find("[")) if index >= 0]
309
+ if not starts:
310
+ return None
311
+ start = min(starts)
312
+ opener = text[start]
313
+ closer = "}" if opener == "{" else "]"
314
+ depth = 0
315
+ in_string = False
316
+ escaped = False
317
+
318
+ for index, char in enumerate(text[start:], start=start):
319
+ if in_string:
320
+ if escaped:
321
+ escaped = False
322
+ elif char == "\\":
323
+ escaped = True
324
+ elif char == '"':
325
+ in_string = False
326
+ continue
327
+ if char == '"':
328
+ in_string = True
329
+ elif char == opener:
330
+ depth += 1
331
+ elif char == closer:
332
+ depth -= 1
333
+ if depth == 0:
334
+ return text[start : index + 1]
335
+ return None
zonix/events.py ADDED
@@ -0,0 +1,106 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+ from .serialization import to_jsonable
7
+ from .types import Usage
8
+
9
+
10
@dataclass(frozen=True)
class Event:
    """Base class for the events published on a run's event bus.

    ``path`` is the tuple of strings taken from the publishing state's ``path``.
    """

    path: tuple[str, ...]

    @property
    def type(self) -> str:
        """Kebab-case class name, e.g. ``ToolInputStart`` -> ``"tool-input-start"``."""
        class_name = type(self).__name__
        return "".join(
            f"-{letter.lower()}" if position and letter.isupper() else letter.lower()
            for position, letter in enumerate(class_name)
        )

    def dump(self) -> dict[str, Any]:
        """Serialize with ``to_jsonable`` and tag the result with its ``"type"``.

        A non-dict serialization is wrapped as ``{"type": ..., "value": ...}``.
        """
        payload = to_jsonable(self)
        if not isinstance(payload, dict):
            return {"type": self.type, "value": payload}
        payload["type"] = self.type
        return payload
30
+
31
+
32
@dataclass(frozen=True)
class NodeStart(Event):
    """Node-start event for the node called ``name``.

    NOTE(review): no publisher is visible alongside this class; presumably emitted
    when a node begins running. Confirm.
    """

    name: str
35
+
36
+
37
@dataclass(frozen=True)
class NodeEnd(Event):
    """Node-end event for the node called ``name``.

    NOTE(review): no publisher is visible alongside this class; presumably emitted
    when a node finishes, with ``status`` describing the outcome. Confirm.
    """

    name: str
    status: str = "ok"  # defaults to "ok"
41
+
42
+
43
@dataclass(frozen=True)
class TextStart(Event):
    """Opens the text segment identified by ``id``.

    ``RunEngine._emit_response`` publishes TextStart, TextDelta and TextEnd, in that
    order and with id ``"text_0"``, for a non-streamed response that has text.
    """

    id: str
46
+
47
+
48
@dataclass(frozen=True)
class TextDelta(Event):
    """A chunk of text (``delta``) belonging to the text segment ``id``."""

    id: str
    delta: str
52
+
53
+
54
@dataclass(frozen=True)
class TextEnd(Event):
    """Closes the text segment identified by ``id``."""

    id: str
57
+
58
+
59
@dataclass(frozen=True)
class ReasoningDelta(Event):
    """A chunk of reasoning text (``delta``) for the segment ``id``.

    NOTE(review): no publisher is visible here; presumably produced by streaming
    model adapters. Confirm.
    """

    id: str
    delta: str
63
+
64
+
65
@dataclass(frozen=True)
class ToolInputStart(Event):
    """Announces tool call ``call_id`` targeting the tool named ``tool``.

    ``RunEngine._emit_response`` publishes it just before ``ToolInputAvailable``.
    """

    call_id: str
    tool: str
69
+
70
+
71
@dataclass(frozen=True)
class ToolInputDelta(Event):
    """A chunk (``delta``) of the raw input for tool call ``call_id``.

    NOTE(review): no publisher is visible here; presumably produced while streaming
    tool arguments. Confirm.
    """

    call_id: str
    delta: str
75
+
76
+
77
@dataclass(frozen=True)
class ToolInputAvailable(Event):
    """The complete ``input`` arguments for tool call ``call_id`` on tool ``tool``."""

    call_id: str
    tool: str
    input: dict[str, Any]
82
+
83
+
84
@dataclass(frozen=True)
class ToolOutputAvailable(Event):
    """The result of tool call ``call_id``.

    ``RunEngine._run_tool`` publishes it with ``output`` already passed through
    ``to_jsonable``.
    """

    call_id: str
    output: Any
88
+
89
+
90
@dataclass(frozen=True)
class ApprovalRequired(Event):
    """Tool call ``call_id`` on ``tool`` with ``input`` is waiting for human approval.

    ``RunEngine._approval_for`` publishes it right before raising ``RunPaused``.
    """

    call_id: str
    tool: str
    input: dict[str, Any]
95
+
96
+
97
@dataclass(frozen=True)
class ErrorEvent(Event):
    """An error report carrying a human-readable ``message``.

    NOTE(review): no publisher is visible here; ``error_type`` presumably classifies
    the failure. Confirm the values used by callers.
    """

    message: str
    error_type: str = "error"  # defaults to "error"
101
+
102
+
103
@dataclass(frozen=True)
class Finish(Event):
    """Terminal event carrying the run's ``output`` and accumulated ``usage``.

    NOTE(review): no publisher is visible here; presumably emitted once per run. Confirm.
    """

    output: Any
    usage: Usage
zonix/exceptions.py ADDED
@@ -0,0 +1,43 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+ from .types import PendingApproval
7
+
8
+
9
class ZonixError(Exception):
    """Base error for Zonix.

    Every other exception class in this module derives from it, so
    ``except ZonixError`` catches all of them.
    """
11
+
12
+
13
class ModelError(ZonixError):
    """Raised when a model adapter cannot complete a request.

    NOTE(review): no raise site is visible alongside this class; presumably raised
    by model adapters. Confirm.
    """
15
+
16
+
17
class ToolError(ZonixError):
    """Raised when a tool fails.

    NOTE(review): no raise site is visible alongside this class; presumably raised
    by tool wrappers. Confirm.
    """
19
+
20
+
21
class OutputValidationError(ZonixError):
    """Raised when a model response cannot be validated as the requested output.

    ``RunEngine._validate_output`` raises it in two cases: model text cannot be
    parsed as JSON, or the ``TypeAdapter`` validation path rejects the data.
    ``RunEngine`` then sends a repair prompt and retries, up to
    ``agent.output_repair_attempts`` times, before letting it propagate.
    """
23
+
24
+
25
class MaxToolRoundsExceeded(ZonixError):
    """Raised when an agent keeps asking for tools beyond its configured limit.

    ``RunEngine`` raises it once the model's tool rounds exceed ``agent.max_tool_rounds``.
    """
27
+
28
+
29
class MaxStepsExceeded(ZonixError):
    """Raised when a team router exceeds its step budget.

    NOTE(review): the raise site (team/router code) is not visible here. Confirm
    where the budget is configured.
    """
31
+
32
+
33
class ToolApprovalRejected(ZonixError):
    """Raised when a paused tool call is resumed with approval=False.

    ``RunEngine._approval_for`` raises it when the decision stored in
    ``st.approvals`` for the call (looked up by call id, then by approval key) is
    ``False``.
    """
35
+
36
+
37
@dataclass(eq=False)
class RunPaused(ZonixError):
    """Signal that a run stopped to wait for a human decision on a tool call.

    Attributes:
        pending: the tool call awaiting approval.
        snapshot: state captured at the pause point; its keys are chosen by the
            code raising the exception.

    ``eq=False`` keeps normal exception identity semantics. The dataclass default
    would generate a field-wise ``__eq__`` and set ``__hash__ = None``, which makes
    the exception unhashable.
    """

    pending: PendingApproval
    snapshot: dict[str, Any]

    def __post_init__(self) -> None:
        # The generated __init__ never calls BaseException.__init__, so keyword
        # construction (RunPaused(pending=..., snapshot=...)) left ``args`` empty.
        # BaseException.__reduce__ rebuilds the exception from ``args``, so pickling
        # or copying it then failed with a missing-argument TypeError.
        super().__init__(self.pending, self.snapshot)

    def __str__(self) -> str:
        return f"Run paused for approval: {self.pending.tool}"
zonix/hitl.py ADDED
@@ -0,0 +1,47 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import json
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ from .serialization import to_jsonable
10
+ from .types import PendingApproval, ToolCall
11
+
12
+
13
def approval_key(tool: str, input: dict[str, Any]) -> str:
    """Build a stable approval key of the form ``"<tool>:<sha256 hex>"``.

    The input is normalised with ``to_jsonable`` and dumped with sorted keys, so
    inputs that differ only in dict key order produce the same key.
    """
    canonical = json.dumps(to_jsonable(input), ensure_ascii=False, sort_keys=True)
    fingerprint = hashlib.sha256(canonical.encode("utf-8")).hexdigest()
    return f"{tool}:{fingerprint}"
17
+
18
+
19
def pending_from_call(call: ToolCall) -> PendingApproval:
    """Describe ``call`` as a ``PendingApproval``, including its approval key."""
    key = approval_key(call.tool, call.input)
    return PendingApproval(call_id=call.call_id, tool=call.tool, input=call.input, approval_key=key)
26
+
27
+
28
@dataclass
class CheckpointStore:
    """Persist run snapshots as one JSON file per run under ``root``.

    ``root`` may be a ``Path`` or a plain string. It is normalised to ``Path`` and
    created (with parents) when the store is constructed.
    """

    root: Path

    def __post_init__(self) -> None:
        # Accept str (or any os.PathLike) roots; previously a str crashed on .mkdir().
        self.root = Path(self.root)
        self.root.mkdir(parents=True, exist_ok=True)

    def path_for(self, run_id: str) -> Path:
        """Return the checkpoint file path for ``run_id``: ``<root>/<run_id>.json``."""
        # NOTE(review): run_id is used verbatim as a file name. If it can come from
        # untrusted input, validate it (no path separators or "..") before use.
        return self.root / f"{run_id}.json"

    def save(self, run_id: str, snapshot: dict[str, Any]) -> Path:
        """Write ``snapshot`` as indented UTF-8 JSON and return the file path.

        The JSON is written to a sibling ``.tmp`` file and then moved over the target
        with ``Path.replace`` (an atomic rename on POSIX). An interrupted save
        therefore does not leave a truncated checkpoint in place of the previous one.
        """
        path = self.path_for(run_id)
        payload = json.dumps(to_jsonable(snapshot), ensure_ascii=False, indent=2)
        tmp_path = path.with_name(path.name + ".tmp")
        try:
            tmp_path.write_text(payload, encoding="utf-8")
            tmp_path.replace(path)
        except BaseException:
            # Don't leave a half-written temp file behind.
            tmp_path.unlink(missing_ok=True)
            raise
        return path

    def load(self, run_id: str) -> dict[str, Any]:
        """Read and parse the checkpoint for ``run_id``.

        Raises ``FileNotFoundError`` if it does not exist.
        """
        return json.loads(self.path_for(run_id).read_text(encoding="utf-8"))
@@ -0,0 +1,118 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import uuid
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Protocol
7
+
8
+ from zonix.types import Message
9
+
10
+
11
class Memory(Protocol):
    """Structural interface shared by memory strategies such as Window, Summarize and Vector."""

    async def apply(self, history: list[Message], current: Any, ctx: Any) -> list[Message]:
        """Return a (possibly transformed) message list built from ``history``.

        ``current`` is the current input and ``ctx`` the caller's context.
        """
14
+
15
+
16
+ def _as_list(memory: Any) -> list[Memory]:
17
+ if memory is None:
18
+ return []
19
+ if isinstance(memory, MemoryStack):
20
+ return list(memory.items)
21
+ if isinstance(memory, list | tuple):
22
+ return list(memory)
23
+ return [memory]
24
+
25
+
26
@dataclass
class MemoryStack:
    """An ordered pipeline of memory strategies applied one after another.

    ``stack + other`` returns a new stack with ``other``'s strategies appended;
    neither operand is modified.
    """

    items: list[Memory] = field(default_factory=list)

    def __add__(self, other: Any) -> MemoryStack:
        combined = list(self.items)
        combined.extend(_as_list(other))
        return MemoryStack(combined)

    async def apply(self, history: list[Message], current: Any, ctx: Any) -> list[Message]:
        """Pass a copy of ``history`` through every strategy in order and return the result."""
        result = list(history)
        for strategy in self.items:
            result = await strategy.apply(result, current, ctx)
        return result
38
+
39
+
40
@dataclass
class Window:
    """Keep only the ``size`` most recent messages.

    A ``size`` of zero or less keeps nothing.
    """

    size: int = 20

    def __add__(self, other: Any) -> MemoryStack:
        return MemoryStack([self]) + other

    async def apply(self, history: list[Message], current: Any, ctx: Any) -> list[Message]:
        """Return a new list with the last ``size`` messages of ``history``."""
        # history[-0:] is the *whole* list (and a negative size would slice from the
        # front), so non-positive sizes need an explicit branch.
        if self.size <= 0:
            return []
        return list(history[-self.size :])
49
+
50
+
51
@dataclass
class Summarize:
    """Collapse older messages into a single system summary once the history is long.

    This applies when the total character count of message contents exceeds
    ``over``. Every message except the ``keep`` most recent ones is then replaced by
    one system message, built from the first 200 characters of each older
    message's content.
    """

    over: int
    keep: int = 20

    def __add__(self, other: Any) -> MemoryStack:
        return MemoryStack([self]) + other

    async def apply(self, history: list[Message], current: Any, ctx: Any) -> list[Message]:
        """Return ``history`` (copied), summarizing the older part if it is too long."""
        total_chars = sum(len(message.content or "") for message in history)
        if total_chars <= self.over:
            return list(history)
        # Split explicitly: with keep=0, history[:-0] is empty and history[-0:] is the
        # whole list, so nothing was summarized. A negative keep is treated as 0.
        split = max(len(history) - max(self.keep, 0), 0)
        older = history[:split]
        recent = list(history[split:])
        if not older:
            # Everything already fits in the kept window; an empty summary adds noise.
            return recent
        summary = "Earlier conversation summary: " + " ".join(
            (message.content or "")[:200] for message in older
        )
        return [Message(role="system", content=summary)] + recent
69
+
70
+
71
@dataclass
class Vector:
    """Long-term memory backed by a duck-typed vector store.

    The first of ``asearch``, ``asimilarity_search``, ``search`` or
    ``similarity_search`` that ``store`` defines is called as
    ``method(query, k=self.k)``. Only the ``a``-prefixed methods have their result
    awaited (when it is awaitable).
    """

    store: Any
    k: int = 6

    def __add__(self, other: Any) -> MemoryStack:
        return MemoryStack([self]) + other

    async def apply(self, history: list[Message], current: Any, ctx: Any) -> list[Message]:
        """Append a system message listing search hits for ``current``.

        Returns a plain copy of ``history`` when the search finds nothing.
        """
        hits = await self._search(current)
        if hits:
            listing = "\n".join(map(str, hits))
            return [*history, Message(role="system", content=f"Relevant long-term memory:\n{listing}")]
        return list(history)

    async def _search(self, query: Any) -> list[Any]:
        for attr in ("asearch", "asimilarity_search", "search", "similarity_search"):
            finder = getattr(self.store, attr, None)
            if finder is None:
                continue
            found = finder(query, k=self.k)
            if attr.startswith("a") and inspect.isawaitable(found):
                found = await found
            return list(found)
        return []
98
+ return []
99
+
100
+
101
@dataclass
class Session:
    """Conversation history plus optional session-level memory strategies."""

    id: str = field(default_factory=lambda: "session_" + uuid.uuid4().hex)
    history: list[Message] = field(default_factory=list)
    memory: Any = None

    async def recall(self, current: Any, ctx: Any, memory: Any = None) -> list[Message]:
        """Return a copy of the history after the session memory, then the per-call ``memory``."""
        pipeline = MemoryStack(_as_list(self.memory) + _as_list(memory))
        return await pipeline.apply(list(self.history), current, ctx)

    async def remember(self, message: Message) -> None:
        """Append ``message`` to the stored history."""
        self.history.append(message)

    def dump(self) -> dict[str, Any]:
        """Return the session id and the history serialized with ``model_dump(mode="json")``."""
        serialized = [entry.model_dump(mode="json") for entry in self.history]
        return {"id": self.id, "history": serialized}
+ }