penguiflow-2.2.1-py3-none-any.whl → penguiflow-2.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- penguiflow/__init__.py +1 -1
- penguiflow/core.py +24 -1
- penguiflow/planner/__init__.py +27 -0
- penguiflow/planner/prompts.py +243 -0
- penguiflow/planner/react.py +1339 -0
- {penguiflow-2.2.1.dist-info → penguiflow-2.2.3.dist-info}/METADATA +1 -1
- {penguiflow-2.2.1.dist-info → penguiflow-2.2.3.dist-info}/RECORD +11 -8
- {penguiflow-2.2.1.dist-info → penguiflow-2.2.3.dist-info}/WHEEL +0 -0
- {penguiflow-2.2.1.dist-info → penguiflow-2.2.3.dist-info}/entry_points.txt +0 -0
- {penguiflow-2.2.1.dist-info → penguiflow-2.2.3.dist-info}/licenses/LICENSE +0 -0
- {penguiflow-2.2.1.dist-info → penguiflow-2.2.3.dist-info}/top_level.txt +0 -0
penguiflow/planner/react.py
@@ -0,0 +1,1339 @@
"""JSON-only ReAct planner loop with pause/resume and summarisation."""

from __future__ import annotations

import asyncio
import inspect
import json
import time
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any, Literal, Protocol
from uuid import uuid4

from pydantic import BaseModel, Field, ValidationError

from ..catalog import NodeSpec, build_catalog
from ..node import Node
from ..registry import ModelRegistry
from . import prompts


class JSONLLMClient(Protocol):
    async def complete(
        self,
        *,
        messages: Sequence[Mapping[str, str]],
        response_format: Mapping[str, Any] | None = None,
    ) -> str:
        ...


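# Editor's sketch (not part of the 2.2.3 diff): a minimal in-memory client.
# Any object with this async `complete` signature satisfies the Protocol
# above, which lets tests bypass LiteLLM entirely; the class name and canned
# reply are illustrative assumptions.
class _CannedJSONClient:
    async def complete(
        self,
        *,
        messages: Sequence[Mapping[str, str]],
        response_format: Mapping[str, Any] | None = None,
    ) -> str:
        # Always tell the planner to finish with a fixed payload.
        return json.dumps(
            {"thought": "demo", "next_node": None, "args": {"answer": "ok"}}
        )

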
class ParallelCall(BaseModel):
    node: str
    args: dict[str, Any] = Field(default_factory=dict)


class ParallelJoin(BaseModel):
    node: str
    args: dict[str, Any] = Field(default_factory=dict)


class PlannerAction(BaseModel):
    thought: str
    next_node: str | None = None
    args: dict[str, Any] | None = None
    plan: list[ParallelCall] | None = None
    join: ParallelJoin | None = None


PlannerPauseReason = Literal[
    "approval_required",
    "await_input",
    "external_event",
    "constraints_conflict",
]


class PlannerPause(BaseModel):
    reason: PlannerPauseReason
    payload: dict[str, Any] = Field(default_factory=dict)
    resume_token: str


class PlannerFinish(BaseModel):
    reason: Literal["answer_complete", "no_path", "budget_exhausted"]
    payload: Any = None
    metadata: dict[str, Any] = Field(default_factory=dict)


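# Editor's sketch: the JSON wire format PlannerAction accepts, as the strings
# the LLM must emit. One single-tool step and one parallel fan-out with a
# join; node names are illustrative assumptions.
_EXAMPLE_ACTIONS = (
    '{"thought": "look up the user", "next_node": "fetch_user",'
    ' "args": {"user_id": 7}}',
    '{"thought": "fan out the searches", "plan":'
    ' [{"node": "search_a"}, {"node": "search_b"}],'
    ' "join": {"node": "merge_results"}}',
)

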
class TrajectorySummary(BaseModel):
    goals: list[str] = Field(default_factory=list)
    facts: dict[str, Any] = Field(default_factory=dict)
    pending: list[str] = Field(default_factory=list)
    last_output_digest: str | None = None
    note: str | None = None

    def compact(self) -> dict[str, Any]:
        payload = {
            "goals": list(self.goals),
            "facts": dict(self.facts),
            "pending": list(self.pending),
            "last_output_digest": self.last_output_digest,
        }
        if self.note:
            payload["note"] = self.note
        return payload


@dataclass(slots=True)
class TrajectoryStep:
    action: PlannerAction
    observation: Any | None = None
    error: str | None = None
    failure: Mapping[str, Any] | None = None

    def dump(self) -> dict[str, Any]:
        return {
            "action": self.action.model_dump(mode="json"),
            "observation": self._serialise_observation(),
            "error": self.error,
            "failure": dict(self.failure) if self.failure else None,
        }

    def _serialise_observation(self) -> Any:
        if isinstance(self.observation, BaseModel):
            return self.observation.model_dump(mode="json")
        return self.observation


@dataclass(slots=True)
class Trajectory:
    query: str
    context_meta: Mapping[str, Any] | None = None
    steps: list[TrajectoryStep] = field(default_factory=list)
    summary: TrajectorySummary | None = None
    hint_state: dict[str, Any] = field(default_factory=dict)
    resume_user_input: str | None = None

    def to_history(self) -> list[dict[str, Any]]:
        return [step.dump() for step in self.steps]

    def serialise(self) -> dict[str, Any]:
        return {
            "query": self.query,
            "context_meta": dict(self.context_meta or {}),
            "steps": self.to_history(),
            "summary": self.summary.model_dump(mode="json")
            if self.summary
            else None,
            "hint_state": dict(self.hint_state),
            "resume_user_input": self.resume_user_input,
        }

    @classmethod
    def from_serialised(cls, payload: Mapping[str, Any]) -> Trajectory:
        trajectory = cls(
            query=payload["query"],
            context_meta=payload.get("context_meta"),
        )
        for step_data in payload.get("steps", []):
            action = PlannerAction.model_validate(step_data["action"])
            step = TrajectoryStep(
                action=action,
                observation=step_data.get("observation"),
                error=step_data.get("error"),
                failure=step_data.get("failure"),
            )
            trajectory.steps.append(step)
        summary_data = payload.get("summary")
        if summary_data:
            trajectory.summary = TrajectorySummary.model_validate(summary_data)
        trajectory.hint_state.update(payload.get("hint_state", {}))
        trajectory.resume_user_input = payload.get("resume_user_input")
        return trajectory

    def compress(self) -> TrajectorySummary:
        facts: dict[str, Any] = {}
        pending: list[str] = []
        last_observation = None
        if self.steps:
            last_step = self.steps[-1]
            if last_step.observation is not None:
                last_observation = last_step._serialise_observation()
                facts["last_observation"] = last_observation
            if last_step.error:
                facts["last_error"] = last_step.error
        for step in self.steps:
            if step.error:
                pending.append(
                    f"retry {step.action.next_node or 'finish'}"
                )
        digest = None
        if last_observation is not None:
            digest_raw = json.dumps(last_observation, ensure_ascii=False)
            digest = digest_raw if len(digest_raw) <= 120 else f"{digest_raw[:117]}..."
        summary = TrajectorySummary(
            goals=[self.query],
            facts=facts,
            pending=pending,
            last_output_digest=digest,
            note="rule_based",
        )
        self.summary = summary
        return summary


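# Editor's sketch: serialise/from_serialised form a JSON-safe round trip,
# which the pause/resume path below depends on; the query is illustrative.
def _demo_trajectory_roundtrip() -> None:
    trajectory = Trajectory(query="find flights")
    restored = Trajectory.from_serialised(
        json.loads(json.dumps(trajectory.serialise()))
    )
    assert restored.query == trajectory.query

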
@dataclass(slots=True)
class _PauseRecord:
    trajectory: Trajectory
    reason: str
    payload: dict[str, Any]
    constraints: dict[str, Any] | None = None


@dataclass(slots=True)
class _PlanningHints:
    ordering_hints: tuple[str, ...]
    parallel_groups: tuple[tuple[str, ...], ...]
    sequential_only: set[str]
    disallow_nodes: set[str]
    prefer_nodes: tuple[str, ...]
    max_parallel: int | None
    budget_hints: dict[str, Any]

    @classmethod
    def from_mapping(cls, payload: Mapping[str, Any] | None) -> _PlanningHints:
        if not payload:
            return cls((), (), set(), set(), (), None, {})
        ordering = tuple(str(item) for item in payload.get("ordering_hints", ()))
        parallel_groups = tuple(
            tuple(str(node) for node in group)
            for group in payload.get("parallel_groups", ())
        )
        sequential = {str(item) for item in payload.get("sequential_only", ())}
        disallow = {str(item) for item in payload.get("disallow_nodes", ())}
        prefer = tuple(str(item) for item in payload.get("prefer_nodes", ()))
        budget_raw = dict(payload.get("budget_hints", {}))
        max_parallel_value = payload.get("max_parallel")
        if not isinstance(max_parallel_value, int):
            candidate = budget_raw.get("max_parallel")
            max_parallel_value = candidate if isinstance(candidate, int) else None
        return cls(
            ordering_hints=ordering,
            parallel_groups=parallel_groups,
            sequential_only=sequential,
            disallow_nodes=disallow,
            prefer_nodes=prefer,
            max_parallel=max_parallel_value,
            budget_hints=budget_raw,
        )

    def to_prompt_payload(self) -> dict[str, Any]:
        payload: dict[str, Any] = {}
        constraints: list[str] = []
        if self.max_parallel is not None:
            constraints.append(f"max_parallel={self.max_parallel}")
        if self.sequential_only:
            constraints.append(
                "sequential_only=" + ",".join(sorted(self.sequential_only))
            )
        if constraints:
            payload["constraints"] = "; ".join(constraints)
        if self.ordering_hints:
            payload["preferred_order"] = list(self.ordering_hints)
        if self.parallel_groups:
            payload["parallel_groups"] = [list(group) for group in self.parallel_groups]
        if self.disallow_nodes:
            payload["disallow_nodes"] = sorted(self.disallow_nodes)
        if self.prefer_nodes:
            payload["preferred_nodes"] = list(self.prefer_nodes)
        if self.budget_hints:
            payload["budget"] = dict(self.budget_hints)
        return payload

    def empty(self) -> bool:
        return not (
            self.ordering_hints
            or self.parallel_groups
            or self.sequential_only
            or self.disallow_nodes
            or self.prefer_nodes
            or self.max_parallel is not None
            or self.budget_hints
        )


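# Editor's sketch: the mapping shape _PlanningHints.from_mapping accepts;
# node names are illustrative. Note that max_parallel falls back to the
# value nested in budget_hints when not given at the top level.
def _demo_planning_hints() -> None:
    hints = _PlanningHints.from_mapping(
        {
            "ordering_hints": ["fetch", "rank"],
            "disallow_nodes": ["delete_all"],
            "budget_hints": {"max_parallel": 2},
        }
    )
    assert hints.max_parallel == 2

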
class _ConstraintTracker:
    __slots__ = (
        "_deadline_at",
        "_hop_budget",
        "_hops_used",
        "_time_source",
        "deadline_triggered",
        "hop_exhausted",
    )

    def __init__(
        self,
        *,
        deadline_s: float | None,
        hop_budget: int | None,
        time_source: Callable[[], float],
    ) -> None:
        now = time_source()
        self._deadline_at = now + deadline_s if deadline_s is not None else None
        self._hop_budget = hop_budget
        self._hops_used = 0
        self._time_source = time_source
        self.deadline_triggered = False
        self.hop_exhausted = hop_budget == 0 and hop_budget is not None

    def check_deadline(self) -> str | None:
        if self._deadline_at is None:
            return None
        if self._time_source() >= self._deadline_at:
            self.deadline_triggered = True
            return prompts.render_deadline_exhausted()
        return None

    def has_budget_for_next_tool(self) -> bool:
        if self._hop_budget is None:
            return True
        return self._hops_used < self._hop_budget

    def record_hop(self) -> None:
        if self._hop_budget is None:
            return
        self._hops_used += 1
        if self._hops_used >= self._hop_budget:
            self.hop_exhausted = True

    def snapshot(self) -> dict[str, Any]:
        remaining: float | None = None
        if self._deadline_at is not None:
            remaining = max(self._deadline_at - self._time_source(), 0.0)
        return {
            "deadline_at": self._deadline_at,
            "deadline_remaining_s": remaining,
            "hop_budget": self._hop_budget,
            "hops_used": self._hops_used,
            "deadline_triggered": self.deadline_triggered,
            "hop_exhausted": self.hop_exhausted,
        }

    @classmethod
    def from_snapshot(
        cls, snapshot: Mapping[str, Any], *, time_source: Callable[[], float]
    ) -> _ConstraintTracker:
        deadline_remaining = snapshot.get("deadline_remaining_s")
        hop_budget = snapshot.get("hop_budget")
        tracker = cls(
            deadline_s=deadline_remaining,
            hop_budget=hop_budget,
            time_source=time_source,
        )
        tracker._hops_used = int(snapshot.get("hops_used", 0))
        tracker._hop_budget = hop_budget
        if deadline_remaining is None and snapshot.get("deadline_at") is None:
            tracker._deadline_at = None
        elif deadline_remaining is not None:
            tracker._deadline_at = time_source() + max(float(deadline_remaining), 0.0)
        else:
            tracker._deadline_at = snapshot.get("deadline_at")
        tracker.deadline_triggered = bool(snapshot.get("deadline_triggered", False))
        tracker.hop_exhausted = bool(snapshot.get("hop_exhausted", False))
        if (
            tracker._hop_budget is not None
            and tracker._hops_used >= tracker._hop_budget
        ):
            tracker.hop_exhausted = True
        return tracker


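# Editor's sketch: a two-hop budget exhausts after the second recorded hop;
# deadlines work the same way against the injected time source.
def _demo_constraint_tracker() -> None:
    tracker = _ConstraintTracker(
        deadline_s=None, hop_budget=2, time_source=time.monotonic
    )
    tracker.record_hop()
    assert tracker.has_budget_for_next_tool()
    tracker.record_hop()
    assert tracker.hop_exhausted

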
class _PlannerPauseSignal(Exception):
    def __init__(self, pause: PlannerPause) -> None:
        super().__init__(pause.reason)
        self.pause = pause


@dataclass(slots=True)
class _BranchExecutionResult:
    observation: BaseModel | None = None
    error: str | None = None
    failure: Mapping[str, Any] | None = None
    pause: PlannerPause | None = None


class _LiteLLMJSONClient:
    def __init__(
        self,
        llm: str | Mapping[str, Any],
        *,
        temperature: float,
        json_schema_mode: bool,
    ) -> None:
        self._llm = llm
        self._temperature = temperature
        self._json_schema_mode = json_schema_mode

    async def complete(
        self,
        *,
        messages: Sequence[Mapping[str, str]],
        response_format: Mapping[str, Any] | None = None,
    ) -> str:
        try:
            import litellm
        except ModuleNotFoundError as exc:  # pragma: no cover - import guard
            raise RuntimeError(
                "LiteLLM is not installed. Install penguiflow[planner] or provide "
                "a custom llm_client."
            ) from exc

        params: dict[str, Any]
        if isinstance(self._llm, str):
            params = {"model": self._llm}
        else:
            params = dict(self._llm)
        params.setdefault("temperature", self._temperature)
        params["messages"] = list(messages)
        if self._json_schema_mode and response_format is not None:
            params["response_format"] = response_format

        response = await litellm.acompletion(**params)
        choice = response["choices"][0]
        content = choice["message"]["content"]
        if content is None:
            raise RuntimeError("LiteLLM returned empty content")
        return content


class _PlannerContext:
    __slots__ = ("meta", "_planner", "_trajectory")

    def __init__(self, planner: ReactPlanner, trajectory: Trajectory) -> None:
        self.meta = dict(trajectory.context_meta or {})
        self._planner = planner
        self._trajectory = trajectory

    async def pause(
        self,
        reason: PlannerPauseReason,
        payload: Mapping[str, Any] | None = None,
    ) -> PlannerPause:
        return await self._planner._pause_from_context(
            reason,
            dict(payload or {}),
            self._trajectory,
        )


class ReactPlanner:
    """Minimal JSON-only ReAct loop."""

    def __init__(
        self,
        llm: str | Mapping[str, Any] | None = None,
        *,
        nodes: Sequence[Node] | None = None,
        catalog: Sequence[NodeSpec] | None = None,
        registry: ModelRegistry | None = None,
        llm_client: JSONLLMClient | None = None,
        max_iters: int = 8,
        temperature: float = 0.0,
        json_schema_mode: bool = True,
        system_prompt_extra: str | None = None,
        token_budget: int | None = None,
        pause_enabled: bool = True,
        state_store: Any | None = None,
        summarizer_llm: str | Mapping[str, Any] | None = None,
        planning_hints: Mapping[str, Any] | None = None,
        repair_attempts: int = 3,
        deadline_s: float | None = None,
        hop_budget: int | None = None,
        time_source: Callable[[], float] | None = None,
    ) -> None:
        if catalog is None:
            if nodes is None or registry is None:
                raise ValueError(
                    "Either catalog or (nodes and registry) must be provided"
                )
            catalog = build_catalog(nodes, registry)

        self._specs = list(catalog)
        self._spec_by_name = {spec.name: spec for spec in self._specs}
        self._catalog_records = [spec.to_tool_record() for spec in self._specs]
        self._planning_hints = _PlanningHints.from_mapping(planning_hints)
        hints_payload = (
            self._planning_hints.to_prompt_payload()
            if not self._planning_hints.empty()
            else None
        )
        self._system_prompt = prompts.build_system_prompt(
            self._catalog_records,
            extra=system_prompt_extra,
            planning_hints=hints_payload,
        )
        self._max_iters = max_iters
        self._repair_attempts = repair_attempts
        self._json_schema_mode = json_schema_mode
        self._token_budget = token_budget
        self._pause_enabled = pause_enabled
        self._state_store = state_store
        self._pause_records: dict[str, _PauseRecord] = {}
        self._active_trajectory: Trajectory | None = None
        self._active_tracker: _ConstraintTracker | None = None
        self._deadline_s = deadline_s
        self._hop_budget = hop_budget
        self._time_source = time_source or time.monotonic
        self._response_format = (
            {
                "type": "json_schema",
                "json_schema": {
                    "name": "planner_action",
                    "schema": PlannerAction.model_json_schema(),
                },
            }
            if json_schema_mode
            else None
        )
        self._summarizer_client: JSONLLMClient | None = None
        if llm_client is not None:
            self._client = llm_client
        else:
            if llm is None:
                raise ValueError("llm or llm_client must be provided")
            self._client = _LiteLLMJSONClient(
                llm,
                temperature=temperature,
                json_schema_mode=json_schema_mode,
            )
        if summarizer_llm is not None:
            self._summarizer_client = _LiteLLMJSONClient(
                summarizer_llm,
                temperature=temperature,
                json_schema_mode=True,
            )

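    # Editor's note: two construction paths are supported. Pass a LiteLLM
    # model name or params dict as `llm` (wrapped by _LiteLLMJSONClient), or
    # pass any object satisfying JSONLLMClient as `llm_client`. The tool
    # catalog may be given directly, or derived from `nodes` + `registry`
    # via build_catalog.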
    async def run(
        self,
        query: str,
        *,
        context_meta: Mapping[str, Any] | None = None,
    ) -> PlannerFinish | PlannerPause:
        trajectory = Trajectory(query=query, context_meta=context_meta)
        return await self._run_loop(trajectory, tracker=None)

    async def resume(
        self,
        token: str,
        user_input: str | None = None,
    ) -> PlannerFinish | PlannerPause:
        record = await self._load_pause_record(token)
        trajectory = record.trajectory
        trajectory.context_meta = trajectory.context_meta or {}
        if user_input is not None:
            trajectory.resume_user_input = user_input
        tracker: _ConstraintTracker | None = None
        if record.constraints is not None:
            tracker = _ConstraintTracker.from_snapshot(
                record.constraints,
                time_source=self._time_source,
            )
        return await self._run_loop(trajectory, tracker=tracker)

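    # Editor's note: `run` starts a fresh Trajectory, while `resume` revives
    # the snapshot stored under the pause token, optionally appending
    # `user_input` to the conversation; both funnel into _run_loop. A full
    # pause/resume round-trip is sketched after __all__ at the end of this
    # file.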
    async def _run_loop(
        self,
        trajectory: Trajectory,
        *,
        tracker: _ConstraintTracker | None,
    ) -> PlannerFinish | PlannerPause:
        last_observation: Any | None = None
        self._active_trajectory = trajectory
        if tracker is None:
            tracker = _ConstraintTracker(
                deadline_s=self._deadline_s,
                hop_budget=self._hop_budget,
                time_source=self._time_source,
            )
        self._active_tracker = tracker
        try:
            while len(trajectory.steps) < self._max_iters:
                deadline_message = tracker.check_deadline()
                if deadline_message is not None:
                    return self._finish(
                        trajectory,
                        reason="budget_exhausted",
                        payload=last_observation,
                        thought=deadline_message,
                        constraints=tracker,
                    )

                action = await self.step(trajectory)

                if action.plan:
                    parallel_observation, pause = await self._execute_parallel_plan(
                        action, trajectory, tracker
                    )
                    if pause is not None:
                        return pause
                    trajectory.summary = None
                    last_observation = parallel_observation
                    trajectory.resume_user_input = None
                    continue

                if action.next_node is None:
                    payload = action.args or last_observation
                    return self._finish(
                        trajectory,
                        reason="answer_complete",
                        payload=payload,
                        thought=action.thought,
                        constraints=tracker,
                    )

                constraint_error = self._check_action_constraints(
                    action, trajectory, tracker
                )
                if constraint_error is not None:
                    trajectory.steps.append(
                        TrajectoryStep(action=action, error=constraint_error)
                    )
                    trajectory.summary = None
                    continue

                spec = self._spec_by_name.get(action.next_node)
                if spec is None:
                    error = prompts.render_invalid_node(
                        action.next_node,
                        list(self._spec_by_name.keys()),
                    )
                    trajectory.steps.append(TrajectoryStep(action=action, error=error))
                    trajectory.summary = None
                    continue

                try:
                    parsed_args = spec.args_model.model_validate(action.args or {})
                except ValidationError as exc:
                    error = prompts.render_validation_error(
                        spec.name,
                        json.dumps(exc.errors(), ensure_ascii=False),
                    )
                    trajectory.steps.append(TrajectoryStep(action=action, error=error))
                    trajectory.summary = None
                    continue

                ctx = _PlannerContext(self, trajectory)
                try:
                    result = await spec.node.func(parsed_args, ctx)
                except _PlannerPauseSignal as signal:
                    tracker.record_hop()
                    trajectory.steps.append(
                        TrajectoryStep(
                            action=action,
                            observation={
                                "pause": signal.pause.reason,
                                "payload": signal.pause.payload,
                            },
                        )
                    )
                    trajectory.summary = None
                    await self._record_pause(signal.pause, trajectory, tracker)
                    return signal.pause
                except Exception as exc:
                    failure_payload = self._build_failure_payload(
                        spec, parsed_args, exc
                    )
                    error = (
                        f"tool '{spec.name}' raised {exc.__class__.__name__}: {exc}"
                    )
                    trajectory.steps.append(
                        TrajectoryStep(
                            action=action,
                            error=error,
                            failure=failure_payload,
                        )
                    )
                    tracker.record_hop()
                    trajectory.summary = None
                    last_observation = None
                    continue

                try:
                    observation = spec.out_model.model_validate(result)
                except ValidationError as exc:
                    error = prompts.render_output_validation_error(
                        spec.name,
                        json.dumps(exc.errors(), ensure_ascii=False),
                    )
                    tracker.record_hop()
                    trajectory.steps.append(TrajectoryStep(action=action, error=error))
                    trajectory.summary = None
                    last_observation = None
                    continue

                trajectory.steps.append(
                    TrajectoryStep(action=action, observation=observation)
                )
                tracker.record_hop()
                trajectory.summary = None
                last_observation = observation.model_dump(mode="json")
                self._record_hint_progress(spec.name, trajectory)
                trajectory.resume_user_input = None

                if tracker.deadline_triggered or tracker.hop_exhausted:
                    thought = (
                        prompts.render_deadline_exhausted()
                        if tracker.deadline_triggered
                        else prompts.render_hop_budget_violation(self._hop_budget or 0)
                    )
                    return self._finish(
                        trajectory,
                        reason="budget_exhausted",
                        payload=last_observation,
                        thought=thought,
                        constraints=tracker,
                    )
            return self._finish(
                trajectory,
                reason="no_path",
                payload=last_observation,
                thought="iteration limit reached",
                constraints=tracker,
            )
        finally:
            self._active_trajectory = None
            self._active_tracker = None

    async def step(self, trajectory: Trajectory) -> PlannerAction:
        base_messages = await self._build_messages(trajectory)
        messages: list[dict[str, str]] = list(base_messages)
        last_error: str | None = None

        for _ in range(self._repair_attempts):
            if last_error is not None:
                messages = list(base_messages) + [
                    {
                        "role": "system",
                        "content": prompts.render_repair_message(last_error),
                    }
                ]

            raw = await self._client.complete(
                messages=messages,
                response_format=self._response_format,
            )

            try:
                return PlannerAction.model_validate_json(raw)
            except ValidationError as exc:
                last_error = json.dumps(exc.errors(), ensure_ascii=False)
                continue

        raise RuntimeError("Planner failed to produce valid JSON after repair attempts")

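    # Editor's note: `step` is the lone LLM call site for planning. On
    # invalid JSON it retries up to `repair_attempts` times, appending a
    # system message from prompts.render_repair_message(last_error) so the
    # model sees its own validation errors before re-emitting the action.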
    async def _execute_parallel_plan(
        self,
        action: PlannerAction,
        trajectory: Trajectory,
        tracker: _ConstraintTracker,
    ) -> tuple[Any | None, PlannerPause | None]:
        if action.next_node is not None:
            error = prompts.render_parallel_with_next_node(action.next_node)
            trajectory.steps.append(TrajectoryStep(action=action, error=error))
            trajectory.summary = None
            return None, None

        if not action.plan:
            error = prompts.render_empty_parallel_plan()
            trajectory.steps.append(TrajectoryStep(action=action, error=error))
            trajectory.summary = None
            return None, None

        validation_errors: list[str] = []
        entries: list[tuple[ParallelCall, NodeSpec, BaseModel]] = []
        for plan_item in action.plan:
            spec = self._spec_by_name.get(plan_item.node)
            if spec is None:
                validation_errors.append(
                    prompts.render_invalid_node(
                        plan_item.node, list(self._spec_by_name.keys())
                    )
                )
                continue
            try:
                parsed_args = spec.args_model.model_validate(plan_item.args or {})
            except ValidationError as exc:
                validation_errors.append(
                    prompts.render_validation_error(
                        spec.name,
                        json.dumps(exc.errors(), ensure_ascii=False),
                    )
                )
                continue
            entries.append((plan_item, spec, parsed_args))

        if validation_errors:
            error = prompts.render_parallel_setup_error(validation_errors)
            trajectory.steps.append(TrajectoryStep(action=action, error=error))
            trajectory.summary = None
            return None, None

        ctx = _PlannerContext(self, trajectory)
        results = await asyncio.gather(
            *(
                self._run_parallel_branch(spec, parsed_args, ctx)
                for (_, spec, parsed_args) in entries
            )
        )

        branch_payloads: list[dict[str, Any]] = []
        success_payloads: list[Any] = []
        failure_entries: list[dict[str, Any]] = []
        pause_result: PlannerPause | None = None

        for (_, spec, parsed_args), outcome in zip(
            entries, results, strict=False
        ):
            tracker.record_hop()
            payload: dict[str, Any] = {
                "node": spec.name,
                "args": parsed_args.model_dump(mode="json"),
            }
            if outcome.pause is not None and pause_result is None:
                pause_result = outcome.pause
                payload["pause"] = {
                    "reason": outcome.pause.reason,
                    "payload": dict(outcome.pause.payload),
                }
            elif outcome.observation is not None:
                obs_json = outcome.observation.model_dump(mode="json")
                payload["observation"] = obs_json
                success_payloads.append(obs_json)
                self._record_hint_progress(spec.name, trajectory)
            else:
                error_text = outcome.error or prompts.render_parallel_unknown_failure(
                    spec.name
                )
                payload["error"] = error_text
                if outcome.failure is not None:
                    payload["failure"] = dict(outcome.failure)
                    failure_entries.append(
                        {
                            "node": spec.name,
                            "error": error_text,
                            "failure": dict(outcome.failure),
                        }
                    )
                else:
                    failure_entries.append(
                        {"node": spec.name, "error": error_text}
                    )
            branch_payloads.append(payload)

        stats = {"success": len(success_payloads), "failed": len(failure_entries)}
        observation: dict[str, Any] = {
            "branches": branch_payloads,
            "stats": stats,
        }

        if pause_result is not None:
            observation["join"] = {
                "status": "skipped",
                "reason": "pause",
            }
            trajectory.steps.append(
                TrajectoryStep(action=action, observation=observation)
            )
            trajectory.summary = None
            await self._record_pause(pause_result, trajectory, tracker)
            return observation, pause_result

        join_payload: dict[str, Any] | None = None
        join_error: str | None = None
        join_failure: Mapping[str, Any] | None = None
        join_spec: NodeSpec | None = None
        join_args_template: dict[str, Any] | None = None

        if action.join is not None:
            join_spec = self._spec_by_name.get(action.join.node)
            if join_spec is None:
                join_error = prompts.render_invalid_node(
                    action.join.node, list(self._spec_by_name.keys())
                )
            elif failure_entries:
                join_payload = {
                    "node": join_spec.name,
                    "status": "skipped",
                    "reason": "branch_failures",
                    "failures": list(failure_entries),
                }
            else:
                join_args_template = dict(action.join.args or {})
                join_fields = join_spec.args_model.model_fields
                if "expect" in join_fields and "expect" not in join_args_template:
                    join_args_template["expect"] = len(entries)
                if "results" in join_fields and "results" not in join_args_template:
                    join_args_template["results"] = list(success_payloads)
                if "branches" in join_fields and "branches" not in join_args_template:
                    join_args_template["branches"] = list(branch_payloads)
                if "failures" in join_fields and "failures" not in join_args_template:
                    join_args_template["failures"] = []
                if (
                    "success_count" in join_fields
                    and "success_count" not in join_args_template
                ):
                    join_args_template["success_count"] = len(success_payloads)
                if (
                    "failure_count" in join_fields
                    and "failure_count" not in join_args_template
                ):
                    join_args_template["failure_count"] = len(failure_entries)

                try:
                    join_args = join_spec.args_model.model_validate(join_args_template)
                except ValidationError as exc:
                    join_error = prompts.render_validation_error(
                        join_spec.name,
                        json.dumps(exc.errors(), ensure_ascii=False),
                    )
                else:
                    join_ctx = _PlannerContext(self, trajectory)
                    join_ctx.meta.update(
                        {
                            "parallel_results": branch_payloads,
                            "parallel_success_count": len(success_payloads),
                            "parallel_failure_count": len(failure_entries),
                        }
                    )
                    if failure_entries:
                        join_ctx.meta["parallel_failures"] = list(failure_entries)
                    join_ctx.meta["parallel_input"] = dict(join_args_template)

                    try:
                        join_raw = await join_spec.node.func(join_args, join_ctx)
                    except _PlannerPauseSignal as signal:
                        tracker.record_hop()
                        join_payload = {
                            "node": join_spec.name,
                            "pause": {
                                "reason": signal.pause.reason,
                                "payload": dict(signal.pause.payload),
                            },
                        }
                        observation["join"] = join_payload
                        trajectory.steps.append(
                            TrajectoryStep(action=action, observation=observation)
                        )
                        trajectory.summary = None
                        await self._record_pause(signal.pause, trajectory, tracker)
                        return observation, signal.pause
                    except Exception as exc:
                        tracker.record_hop()
                        join_error = (
                            f"tool '{join_spec.name}' raised "
                            f"{exc.__class__.__name__}: {exc}"
                        )
                        join_failure = self._build_failure_payload(
                            join_spec, join_args, exc
                        )
                    else:
                        try:
                            join_model = join_spec.out_model.model_validate(join_raw)
                        except ValidationError as exc:
                            tracker.record_hop()
                            join_error = prompts.render_output_validation_error(
                                join_spec.name,
                                json.dumps(exc.errors(), ensure_ascii=False),
                            )
                        else:
                            tracker.record_hop()
                            self._record_hint_progress(join_spec.name, trajectory)
                            join_payload = {
                                "node": join_spec.name,
                                "observation": join_model.model_dump(mode="json"),
                            }

        if action.join is not None and "join" not in observation:
            if join_payload is not None:
                observation["join"] = join_payload
            else:
                join_name = (
                    join_spec.name
                    if join_spec is not None
                    else action.join.node
                    if action.join is not None
                    else "join"
                )
                join_entry: dict[str, Any] = {"node": join_name}
                if join_error is not None:
                    join_entry["error"] = join_error
                if join_failure is not None:
                    join_entry["failure"] = dict(join_failure)
                if "error" in join_entry or "failure" in join_entry:
                    observation["join"] = join_entry
                elif action.join is not None and join_spec is None:
                    observation["join"] = join_entry

        trajectory.steps.append(
            TrajectoryStep(action=action, observation=observation)
        )
        trajectory.summary = None
        return observation, None

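    # Editor's sketch of the observation _execute_parallel_plan records
    # (node names and values illustrative):
    #   {"branches": [{"node": "search_a", "args": {...}, "observation": {...}},
    #                 {"node": "search_b", "args": {...}, "error": "..."}],
    #    "stats": {"success": 1, "failed": 1},
    #    "join": {"node": "merge_results", "status": "skipped",
    #             "reason": "branch_failures", "failures": [...]}}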
    async def _run_parallel_branch(
        self, spec: NodeSpec, args: BaseModel, ctx: _PlannerContext
    ) -> _BranchExecutionResult:
        try:
            raw = await spec.node.func(args, ctx)
        except _PlannerPauseSignal as signal:
            return _BranchExecutionResult(pause=signal.pause)
        except Exception as exc:
            failure_payload = self._build_failure_payload(spec, args, exc)
            error = (
                f"tool '{spec.name}' raised {exc.__class__.__name__}: {exc}"
            )
            return _BranchExecutionResult(error=error, failure=failure_payload)

        try:
            observation = spec.out_model.model_validate(raw)
        except ValidationError as exc:
            error = prompts.render_output_validation_error(
                spec.name,
                json.dumps(exc.errors(), ensure_ascii=False),
            )
            return _BranchExecutionResult(error=error)

        return _BranchExecutionResult(observation=observation)

    async def _build_messages(self, trajectory: Trajectory) -> list[dict[str, str]]:
        messages: list[dict[str, str]] = [
            {"role": "system", "content": self._system_prompt},
            {
                "role": "user",
                "content": prompts.build_user_prompt(
                    trajectory.query,
                    trajectory.context_meta,
                ),
            },
        ]

        history_messages: list[dict[str, str]] = []
        for step in trajectory.steps:
            action_payload = json.dumps(
                step.action.model_dump(mode="json"),
                ensure_ascii=False,
                sort_keys=True,
            )
            history_messages.append({"role": "assistant", "content": action_payload})
            history_messages.append(
                {
                    "role": "user",
                    "content": prompts.render_observation(
                        observation=step._serialise_observation(),
                        error=step.error,
                        failure=step.failure,
                    ),
                }
            )

        if trajectory.resume_user_input:
            history_messages.append(
                {
                    "role": "user",
                    "content": prompts.render_resume_user_input(
                        trajectory.resume_user_input
                    ),
                }
            )

        if self._token_budget is None:
            return messages + history_messages

        candidate = messages + history_messages
        if self._estimate_size(candidate) <= self._token_budget:
            return candidate

        summary = await self._summarise_trajectory(trajectory)
        summary_message = {
            "role": "system",
            "content": prompts.render_summary(summary.compact()),
        }
        condensed: list[dict[str, str]] = messages + [summary_message]
        if trajectory.steps:
            last_step = trajectory.steps[-1]
            condensed.append(
                {
                    "role": "assistant",
                    "content": json.dumps(
                        last_step.action.model_dump(mode="json"),
                        ensure_ascii=False,
                        sort_keys=True,
                    ),
                }
            )
            condensed.append(
                {
                    "role": "user",
                    "content": prompts.render_observation(
                        observation=last_step._serialise_observation(),
                        error=last_step.error,
                        failure=last_step.failure,
                    ),
                }
            )
        if trajectory.resume_user_input:
            condensed.append(
                {
                    "role": "user",
                    "content": prompts.render_resume_user_input(
                        trajectory.resume_user_input
                    ),
                }
            )
        return condensed

    def _estimate_size(self, messages: Sequence[Mapping[str, str]]) -> int:
        return sum(len(item.get("content", "")) for item in messages)

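    # Editor's note: _estimate_size counts characters, not model tokens, so
    # `token_budget` acts as a rough character budget; the summarisation
    # fallback below triggers once the rendered history exceeds it.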
    async def _summarise_trajectory(
        self, trajectory: Trajectory
    ) -> TrajectorySummary:
        if trajectory.summary is not None:
            return trajectory.summary

        base_summary = trajectory.compress()
        summary_text = prompts.render_summary(base_summary.compact())
        if (
            self._summarizer_client is not None
            and self._token_budget is not None
            and len(summary_text) > self._token_budget
        ):
            messages = prompts.build_summarizer_messages(
                trajectory.query,
                trajectory.to_history(),
                base_summary.compact(),
            )
            response_format = {
                "type": "json_schema",
                "json_schema": {
                    "name": "trajectory_summary",
                    "schema": TrajectorySummary.model_json_schema(),
                },
            }
            try:
                raw = await self._summarizer_client.complete(
                    messages=messages,
                    response_format=response_format,
                )
                summary = TrajectorySummary.model_validate_json(raw)
                summary.note = summary.note or "llm"
                trajectory.summary = summary
                return summary
            except Exception:
                base_summary.note = "rule_based_fallback"
                trajectory.summary = base_summary
                return base_summary
        # Without an LLM summarizer (or when the rule-based text already
        # fits), fall back to the compressed summary rather than None.
        return base_summary

    def _check_action_constraints(
        self,
        action: PlannerAction,
        trajectory: Trajectory,
        tracker: _ConstraintTracker,
    ) -> str | None:
        hints = self._planning_hints
        node_name = action.next_node
        if node_name and not tracker.has_budget_for_next_tool():
            limit = self._hop_budget if self._hop_budget is not None else 0
            return prompts.render_hop_budget_violation(limit)
        if node_name and node_name in hints.disallow_nodes:
            return prompts.render_disallowed_node(node_name)
        if hints.max_parallel is not None and action.plan:
            if len(action.plan) > hints.max_parallel:
                return prompts.render_parallel_limit(hints.max_parallel)
        if hints.sequential_only and action.plan:
            for item in action.plan:
                candidate = item.node
                if candidate in hints.sequential_only:
                    return prompts.render_sequential_only(candidate)
        if hints.ordering_hints and node_name is not None:
            state = trajectory.hint_state.setdefault(
                "ordering_state",
                {"completed": [], "warned": False},
            )
            completed = state.setdefault("completed", [])
            expected_index = len(completed)
            if expected_index < len(hints.ordering_hints):
                expected_node = hints.ordering_hints[expected_index]
                if node_name != expected_node:
                    if (
                        node_name in hints.ordering_hints
                        and not state.get("warned", False)
                    ):
                        state["warned"] = True
                        return prompts.render_ordering_hint_violation(
                            hints.ordering_hints,
                            node_name,
                        )
        return None

    def _record_hint_progress(self, node_name: str, trajectory: Trajectory) -> None:
        hints = self._planning_hints
        if not hints.ordering_hints:
            return
        state = trajectory.hint_state.setdefault(
            "ordering_state",
            {"completed": [], "warned": False},
        )
        completed = state.setdefault("completed", [])
        expected_index = len(completed)
        if (
            expected_index < len(hints.ordering_hints)
            and node_name == hints.ordering_hints[expected_index]
        ):
            completed.append(node_name)
            state["warned"] = False

    def _build_failure_payload(
        self, spec: NodeSpec, args: BaseModel, exc: Exception
    ) -> dict[str, Any]:
        suggestion = getattr(exc, "suggestion", None)
        if suggestion is None:
            suggestion = getattr(exc, "remedy", None)
        payload: dict[str, Any] = {
            "node": spec.name,
            "args": args.model_dump(mode="json"),
            "error_code": exc.__class__.__name__,
            "message": str(exc),
        }
        if suggestion:
            payload["suggestion"] = str(suggestion)
        return payload

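    # Editor's note: exceptions that expose a `suggestion` (or `remedy`)
    # attribute get it surfaced as failure["suggestion"], giving tools a
    # hook to steer the model's next action after a failure.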
    async def pause(
        self, reason: PlannerPauseReason, payload: Mapping[str, Any] | None = None
    ) -> PlannerPause:
        if self._active_trajectory is None:
            raise RuntimeError("pause() requires an active planner run")
        try:
            await self._pause_from_context(
                reason,
                dict(payload or {}),
                self._active_trajectory,
            )
        except _PlannerPauseSignal as signal:
            return signal.pause
        raise RuntimeError("pause request did not trigger")

    async def _pause_from_context(
        self,
        reason: PlannerPauseReason,
        payload: dict[str, Any],
        trajectory: Trajectory,
    ) -> PlannerPause:
        if not self._pause_enabled:
            raise RuntimeError("Pause/resume is disabled for this planner")
        pause = PlannerPause(
            reason=reason,
            payload=dict(payload),
            resume_token=uuid4().hex,
        )
        await self._record_pause(pause, trajectory, self._active_tracker)
        raise _PlannerPauseSignal(pause)

    async def _record_pause(
        self,
        pause: PlannerPause,
        trajectory: Trajectory,
        tracker: _ConstraintTracker | None,
    ) -> None:
        snapshot = Trajectory.from_serialised(trajectory.serialise())
        record = _PauseRecord(
            trajectory=snapshot,
            reason=pause.reason,
            payload=dict(pause.payload),
            constraints=tracker.snapshot() if tracker is not None else None,
        )
        await self._store_pause_record(pause.resume_token, record)

    async def _store_pause_record(self, token: str, record: _PauseRecord) -> None:
        self._pause_records[token] = record
        if self._state_store is None:
            return
        saver = getattr(self._state_store, "save_planner_state", None)
        if saver is None:
            return
        payload = self._serialise_pause_record(record)
        result = saver(token, payload)
        if inspect.isawaitable(result):
            await result

    async def _load_pause_record(self, token: str) -> _PauseRecord:
        record = self._pause_records.pop(token, None)
        if record is not None:
            return record
        if self._state_store is not None:
            loader = getattr(self._state_store, "load_planner_state", None)
            if loader is not None:
                result = loader(token)
                if inspect.isawaitable(result):
                    result = await result
                if result is None:
                    raise KeyError(token)
                trajectory = Trajectory.from_serialised(result["trajectory"])
                payload = dict(result.get("payload", {}))
                reason = result.get("reason", "await_input")
                constraints = result.get("constraints")
                return _PauseRecord(
                    trajectory=trajectory,
                    reason=reason,
                    payload=payload,
                    constraints=constraints,
                )
        raise KeyError(token)

    def _serialise_pause_record(self, record: _PauseRecord) -> dict[str, Any]:
        return {
            "trajectory": record.trajectory.serialise(),
            "reason": record.reason,
            "payload": dict(record.payload),
            "constraints": dict(record.constraints)
            if record.constraints is not None
            else None,
        }

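    # Editor's note: the state store is duck-typed; only `save_planner_state`
    # and `load_planner_state` are probed via getattr, and each may be sync
    # or async. An in-memory sketch appears after __all__ at the end of this
    # file.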
    def _finish(
        self,
        trajectory: Trajectory,
        *,
        reason: Literal["answer_complete", "no_path", "budget_exhausted"],
        payload: Any,
        thought: str,
        constraints: _ConstraintTracker | None = None,
        error: str | None = None,
    ) -> PlannerFinish:
        metadata = {
            "reason": reason,
            "thought": thought,
            "steps": trajectory.to_history(),
            "step_count": len(trajectory.steps),
        }
        if constraints is not None:
            metadata["constraints"] = constraints.snapshot()
        if error is not None:
            metadata["error"] = error
        return PlannerFinish(reason=reason, payload=payload, metadata=metadata)


__all__ = [
    "ParallelCall",
    "ParallelJoin",
    "PlannerAction",
    "PlannerFinish",
    "PlannerPause",
    "ReactPlanner",
    "Trajectory",
    "TrajectoryStep",
    "TrajectorySummary",
]
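

# --- Editor's sketches (not part of the 2.2.3 diff; names illustrative) ---


class _InMemoryPlannerStateStore:
    """Duck-typed store matching the getattr probes in ReactPlanner."""

    def __init__(self) -> None:
        self._states: dict[str, dict[str, Any]] = {}

    def save_planner_state(self, token: str, payload: dict[str, Any]) -> None:
        self._states[token] = payload

    def load_planner_state(self, token: str) -> dict[str, Any] | None:
        return self._states.get(token)


async def _demo_run_and_resume(planner: ReactPlanner) -> None:
    """A run either finishes or pauses; a pause carries a resume token."""
    outcome = await planner.run("summarise the latest report")
    if isinstance(outcome, PlannerPause):
        # A tool (or the planner itself) requested a pause; a human can
        # approve and revive the saved trajectory with the token.
        outcome = await planner.resume(outcome.resume_token, user_input="approved")
    assert isinstance(outcome, PlannerFinish)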