pydantic-graph-studio 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_graph_studio/__init__.py +55 -0
- pydantic_graph_studio/cli.py +201 -0
- pydantic_graph_studio/introspection.py +209 -0
- pydantic_graph_studio/runtime.py +433 -0
- pydantic_graph_studio/schemas.py +99 -0
- pydantic_graph_studio/server.py +175 -0
- pydantic_graph_studio/ui/__init__.py +1 -0
- pydantic_graph_studio/ui/_build/tailwind.input.css +11 -0
- pydantic_graph_studio/ui/assets/app.js +783 -0
- pydantic_graph_studio/ui/assets/dagre.min.js +3809 -0
- pydantic_graph_studio/ui/assets/react-dom.production.min.js +267 -0
- pydantic_graph_studio/ui/assets/react.production.min.js +31 -0
- pydantic_graph_studio/ui/assets/reactflow.css +406 -0
- pydantic_graph_studio/ui/assets/reactflow.min.js +10 -0
- pydantic_graph_studio/ui/assets/tailwind.css +1 -0
- pydantic_graph_studio/ui/assets/theme.css +145 -0
- pydantic_graph_studio/ui/index.html +19 -0
- pydantic_graph_studio-0.1.0.dist-info/METADATA +42 -0
- pydantic_graph_studio-0.1.0.dist-info/RECORD +21 -0
- pydantic_graph_studio-0.1.0.dist-info/WHEEL +4 -0
- pydantic_graph_studio-0.1.0.dist-info/entry_points.txt +4 -0
|
@@ -0,0 +1,433 @@
|
|
|
1
|
+
"""Runtime instrumentation utilities for pydantic_graph execution."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import inspect
|
|
7
|
+
import types
|
|
8
|
+
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
|
|
9
|
+
from contextlib import asynccontextmanager, suppress
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import Any
|
|
12
|
+
from uuid import uuid4
|
|
13
|
+
|
|
14
|
+
from pydantic_graph import Graph
|
|
15
|
+
from pydantic_graph.graph import GraphRun, GraphRunResult
|
|
16
|
+
from pydantic_graph.nodes import BaseNode, End
|
|
17
|
+
|
|
18
|
+
from pydantic_graph_studio.schemas import (
|
|
19
|
+
EdgeTakenEvent,
|
|
20
|
+
ErrorEvent,
|
|
21
|
+
Event,
|
|
22
|
+
NodeEndEvent,
|
|
23
|
+
NodeStartEvent,
|
|
24
|
+
RunEndEvent,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
BetaGraph: type[Any] | None = None
|
|
28
|
+
BetaEndMarker: type[Any] = object
|
|
29
|
+
BetaJoinItem: type[Any] = object
|
|
30
|
+
try: # pragma: no cover - optional beta support
|
|
31
|
+
from pydantic_graph.beta.graph import EndMarker as _BetaEndMarker
|
|
32
|
+
from pydantic_graph.beta.graph import Graph as _BetaGraph
|
|
33
|
+
from pydantic_graph.beta.graph import JoinItem as _BetaJoinItem
|
|
34
|
+
except ModuleNotFoundError: # pragma: no cover
|
|
35
|
+
pass
|
|
36
|
+
else:
|
|
37
|
+
BetaGraph = _BetaGraph
|
|
38
|
+
BetaEndMarker = _BetaEndMarker
|
|
39
|
+
BetaJoinItem = _BetaJoinItem
|
|
40
|
+
|
|
41
|
+
# A hook may be a plain function (returning None) or return an awaitable.
HookReturn = Awaitable[None] | None

# Shorthands for the fully-generic pydantic_graph types used in every hook signature.
_AnyRun = GraphRun[Any, Any, Any]
_AnyNode = BaseNode[Any, Any, Any]

# (graph_run, node)
NodeStartHook = Callable[[_AnyRun, _AnyNode], HookReturn]
# (graph_run, node, next_node_or_end)
NodeEndHook = Callable[[_AnyRun, _AnyNode, _AnyNode | End[Any]], HookReturn]
# (graph_run, source_node, target_node)
EdgeTakenHook = Callable[[_AnyRun, _AnyNode, _AnyNode], HookReturn]
# (graph_run, end)
RunEndHook = Callable[[_AnyRun, End[Any]], HookReturn]
# (graph_run, node, exception)
ErrorHook = Callable[[_AnyRun, _AnyNode, BaseException], HookReturn]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass(slots=True)
class RunHooks:
    """Optional observer callbacks fired while an instrumented graph run executes.

    Every field defaults to ``None`` (no callback). A callback can be an ordinary
    function or one that returns an awaitable; awaitable results are awaited
    before the run continues.
    """

    # Fired before a node executes: (graph_run, node).
    on_node_start: NodeStartHook | None = None
    # Fired after a node returns: (graph_run, node, next_node_or_end).
    on_node_end: NodeEndHook | None = None
    # Fired when a node hands control to another node: (graph_run, source, target).
    on_edge_taken: EdgeTakenHook | None = None
    # Fired when a node returns End: (graph_run, end).
    on_run_end: RunEndHook | None = None
    # Fired when executing a node raises: (graph_run, node, exception).
    on_error: ErrorHook | None = None
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
async def _maybe_await(func: Callable[..., HookReturn] | None, *args: Any) -> None:
    """Call an optional hook with ``args``, awaiting the result if it is awaitable.

    A ``None`` hook is a no-op. The hook's return value is otherwise discarded.
    """
    if func is not None:
        outcome = func(*args)
        # Supports both sync hooks and coroutine functions / awaitable-returning callables.
        if inspect.isawaitable(outcome):
            await outcome
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def instrument_graph_run(graph_run: GraphRun[Any, Any, Any], hooks: RunHooks) -> GraphRun[Any, Any, Any]:
    """Patch ``graph_run.next`` in place so each step is reported through ``hooks``.

    The same run object is returned. Instrumenting an already-instrumented run only
    swaps its active hooks; the wrapper is never stacked twice.
    """
    if getattr(graph_run, "_pgraph_instrumented", False):
        # The wrapper reads the hooks lazily, so replacing the attribute is enough.
        setattr(graph_run, "_pgraph_run_hooks", hooks)  # noqa: B010
        return graph_run

    wrapped_next = graph_run.next
    for attr_name, attr_value in (
        ("_pgraph_instrumented", True),
        ("_pgraph_run_hooks", hooks),
        ("_pgraph_original_next", wrapped_next),
    ):
        setattr(graph_run, attr_name, attr_value)  # noqa: B010

    async def _instrumented_next(
        self: GraphRun[Any, Any, Any],
        node: BaseNode[Any, Any, Any] | None = None,
    ) -> BaseNode[Any, Any, Any] | End[Any]:
        current_hooks: RunHooks | None = getattr(self, "_pgraph_run_hooks", None)
        if current_hooks is None:
            return await wrapped_next(node)

        # The node about to execute: the explicit argument, else the run's pending node.
        stepping = node if node is not None else self.next_node
        is_real_node = isinstance(stepping, BaseNode)

        if is_real_node:
            await _maybe_await(current_hooks.on_node_start, self, stepping)

        try:
            outcome = await wrapped_next(node)
        except BaseException as exc:
            if is_real_node:
                await _maybe_await(current_hooks.on_error, self, stepping, exc)
            raise

        if not is_real_node:
            return outcome

        await _maybe_await(current_hooks.on_node_end, self, stepping, outcome)
        if isinstance(outcome, BaseNode):
            await _maybe_await(current_hooks.on_edge_taken, self, stepping, outcome)
        elif isinstance(outcome, End):
            await _maybe_await(current_hooks.on_run_end, self, outcome)
        return outcome

    # Bind as an instance attribute so only this run object is affected.
    setattr(graph_run, "next", types.MethodType(_instrumented_next, graph_run))  # noqa: B010
    return graph_run
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@asynccontextmanager
async def iter_instrumented(
    graph: Graph[Any, Any, Any],
    start_node: BaseNode[Any, Any, Any],
    *,
    state: Any = None,
    deps: Any = None,
    persistence: Any = None,
    hooks: RunHooks,
) -> AsyncIterator[GraphRun[Any, Any, Any]]:
    """Open ``graph.iter`` and yield the run with ``hooks`` already attached.

    The caller drives the run (e.g. ``async for``); every step fires the hooks.
    """
    run_context = graph.iter(start_node, state=state, deps=deps, persistence=persistence, infer_name=True)
    async with run_context as graph_run:
        # instrument_graph_run patches and returns the very same run object.
        yield instrument_graph_run(graph_run, hooks)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
async def run_instrumented(
    graph: Graph[Any, Any, Any],
    start_node: BaseNode[Any, Any, Any],
    *,
    state: Any = None,
    deps: Any = None,
    persistence: Any = None,
    hooks: RunHooks,
) -> GraphRunResult[Any, Any]:
    """Run a graph to completion with instrumentation.

    Args:
        graph: The graph to execute.
        start_node: Node the run starts from.
        state: Optional graph state, forwarded to ``graph.iter``.
        deps: Optional dependencies, forwarded to ``graph.iter``.
        persistence: Optional persistence backend, forwarded to ``graph.iter``.
        hooks: Callbacks fired for every executed step.

    Returns:
        The run's ``GraphRunResult``.

    Raises:
        RuntimeError: If the run finished iterating without producing a result.
    """
    async with iter_instrumented(
        graph,
        start_node,
        state=state,
        deps=deps,
        persistence=persistence,
        hooks=hooks,
    ) as graph_run:
        # Driving the iterator executes each node; the patched ``next`` fires the hooks.
        async for _node in graph_run:
            pass
        result = graph_run.result
        # An explicit check rather than ``assert``: asserts vanish under ``python -O``
        # and this function would then silently return None.
        if result is None:
            raise RuntimeError("GraphRun should have a result")
        return result
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
async def iter_run_events(
    graph: Graph[Any, Any, Any],
    start_node: BaseNode[Any, Any, Any] | None = None,
    *,
    state: Any = None,
    deps: Any = None,
    persistence: Any = None,
    inputs: Any = None,
    run_id: str | None = None,
) -> AsyncIterator[Event]:
    """Yield an ordered stream of runtime events for a graph run.

    Beta graphs are delegated to ``_iter_run_events_beta``. For v1 graphs the run is
    driven by ``run_instrumented`` in a background task whose hooks push events onto
    a queue; this generator drains that queue until the task has fully finished.

    Args:
        graph: The graph to run (v1 ``Graph`` or, if installed, a beta graph).
        start_node: Node to start from; required for v1 graphs, unused for beta graphs.
        state: Forwarded to the graph run.
        deps: Forwarded to the graph run.
        persistence: Forwarded to v1 graph runs only.
        inputs: Forwarded to beta graph runs only.
        run_id: Identifier stamped on every event; a random hex id when omitted.

    Yields:
        Node start/end and edge events as the run progresses, followed by a
        ``RunEndEvent`` on success or an ``ErrorEvent`` on failure.

    Raises:
        ValueError: If ``graph`` is a v1 graph and ``start_node`` is None.
    """
    if _is_beta_graph(graph):
        async for event in _iter_run_events_beta(
            graph,
            state=state,
            deps=deps,
            inputs=inputs,
            run_id=run_id,
        ):
            yield event
        return

    if start_node is None:
        raise ValueError("start_node is required for v1 graphs")

    if run_id is None:
        run_id = uuid4().hex
    # ``None`` is an end-of-stream sentinel enqueued only after the background task
    # has completely finished. Polling a "done" flag before a blocking ``get()`` is
    # racy: the flag can flip while the reader is already parked on an empty queue.
    queue: asyncio.Queue[Event | None] = asyncio.Queue()
    # Set once a terminal event (run_end / error) was emitted, so the generic
    # fallback error in ``_run`` is not reported a second time.
    done = asyncio.Event()

    async def emit(event: Event) -> None:
        await queue.put(event)

    async def on_node_start(
        _run: GraphRun[Any, Any, Any],
        node: BaseNode[Any, Any, Any],
    ) -> None:
        await emit(
            NodeStartEvent(
                run_id=run_id,
                event_type="node_start",
                node_id=node.get_node_id(),
            )
        )

    async def on_node_end(
        _run: GraphRun[Any, Any, Any],
        node: BaseNode[Any, Any, Any],
        _result: BaseNode[Any, Any, Any] | End[Any],
    ) -> None:
        await emit(
            NodeEndEvent(
                run_id=run_id,
                event_type="node_end",
                node_id=node.get_node_id(),
            )
        )

    async def on_edge_taken(
        _run: GraphRun[Any, Any, Any],
        source: BaseNode[Any, Any, Any],
        target: BaseNode[Any, Any, Any],
    ) -> None:
        await emit(
            EdgeTakenEvent(
                run_id=run_id,
                event_type="edge_taken",
                source_node_id=source.get_node_id(),
                target_node_id=target.get_node_id(),
            )
        )

    async def on_run_end(
        _run: GraphRun[Any, Any, Any],
        _end: End[Any],
    ) -> None:
        await emit(
            RunEndEvent(
                run_id=run_id,
                event_type="run_end",
            )
        )
        done.set()

    async def on_error(
        _run: GraphRun[Any, Any, Any],
        node: BaseNode[Any, Any, Any],
        exc: BaseException,
    ) -> None:
        await emit(
            ErrorEvent(
                run_id=run_id,
                event_type="error",
                message=str(exc),
                node_id=node.get_node_id(),
            )
        )
        done.set()

    hooks = RunHooks(
        on_node_start=on_node_start,
        on_node_end=on_node_end,
        on_edge_taken=on_edge_taken,
        on_run_end=on_run_end,
        on_error=on_error,
    )

    async def _run() -> None:
        try:
            await run_instrumented(
                graph,
                start_node,
                state=state,
                deps=deps,
                persistence=persistence,
                hooks=hooks,
            )
        except asyncio.CancelledError:
            # Cancellation comes from the consumer going away; it is not a run failure.
            raise
        except BaseException as exc:
            # Failures outside a node (e.g. inside a hook or graph setup/teardown).
            if not done.is_set():
                await emit(
                    ErrorEvent(
                        run_id=run_id,
                        event_type="error",
                        message=str(exc),
                        node_id=None,
                    )
                )
        finally:
            done.set()
            queue.put_nowait(None)

    task = asyncio.create_task(_run())
    try:
        while (event := await queue.get()) is not None:
            yield event
    finally:
        # Only reached with a live task if the consumer stopped iterating early.
        if not task.done():
            task.cancel()
        with suppress(asyncio.CancelledError):
            await task
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def _is_beta_graph(graph: Any) -> bool:
    """Return True if ``graph`` is a ``pydantic_graph.beta`` graph.

    Always False when the beta API could not be imported.
    """
    if BetaGraph is None:
        return False
    return isinstance(graph, BetaGraph)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
async def _iter_run_events_beta(
    graph: Any,
    *,
    state: Any = None,
    deps: Any = None,
    inputs: Any = None,
    run_id: str | None = None,
) -> AsyncIterator[Event]:
    """Yield runtime events for a ``pydantic_graph.beta`` graph run.

    The beta run is observed by wrapping ``graph_run._iterator_instance._run_task``,
    which emits node start/end events and an edge event for each join item or
    follow-up task it returns.
    NOTE(review): this relies on private pydantic_graph.beta internals and may
    break when that package changes.

    Args:
        graph: A beta graph instance.
        state: Forwarded to ``graph.iter``.
        deps: Forwarded to ``graph.iter``.
        inputs: Forwarded to ``graph.iter``.
        run_id: Identifier stamped on every event; a random hex id when omitted.

    Yields:
        Events in emission order until the background run has finished.
    """
    if run_id is None:
        run_id = uuid4().hex
    # ``None`` is an end-of-stream sentinel enqueued after the run task finishes.
    # Previously the reader polled a "done" flag and then blocked on ``get()``; if the
    # run ended without any task returning EndMarker, ``done`` was set with nothing
    # enqueued and the reader hung forever.
    queue: asyncio.Queue[Event | None] = asyncio.Queue()
    # Set once a terminal event was emitted, to avoid a duplicate fallback error.
    done = asyncio.Event()

    async def emit(event: Event) -> None:
        await queue.put(event)

    async def _run() -> None:
        try:
            async with graph.iter(state=state, deps=deps, inputs=inputs, infer_name=True) as graph_run:
                iterator = graph_run._iterator_instance
                original_run_task = iterator._run_task

                async def instrumented_run_task(task: Any) -> Any:
                    node_id = str(task.node_id)
                    await emit(NodeStartEvent(run_id=run_id, event_type="node_start", node_id=node_id))
                    try:
                        result = await original_run_task(task)
                    except asyncio.CancelledError:
                        # A cancelled task did not fail; don't report it as a node error.
                        raise
                    except BaseException as exc:
                        await emit(
                            ErrorEvent(
                                run_id=run_id,
                                event_type="error",
                                message=str(exc),
                                node_id=node_id,
                            )
                        )
                        raise
                    await emit(NodeEndEvent(run_id=run_id, event_type="node_end", node_id=node_id))

                    if isinstance(result, BetaEndMarker):
                        await emit(RunEndEvent(run_id=run_id, event_type="run_end"))
                        done.set()
                    elif isinstance(result, BetaJoinItem):
                        await emit(
                            EdgeTakenEvent(
                                run_id=run_id,
                                event_type="edge_taken",
                                source_node_id=node_id,
                                target_node_id=str(result.join_id),
                            )
                        )
                    elif isinstance(result, Sequence):
                        # One edge per follow-up task scheduled by this node.
                        for new_task in result:
                            await emit(
                                EdgeTakenEvent(
                                    run_id=run_id,
                                    event_type="edge_taken",
                                    source_node_id=node_id,
                                    target_node_id=str(new_task.node_id),
                                )
                            )
                    return result

                iterator._run_task = instrumented_run_task

                async for _item in graph_run:
                    pass
        except asyncio.CancelledError:
            # Cancellation comes from the consumer going away; it is not a run failure.
            raise
        except BaseException as exc:
            if not done.is_set():
                await emit(
                    ErrorEvent(
                        run_id=run_id,
                        event_type="error",
                        message=str(exc),
                        node_id=None,
                    )
                )
        finally:
            done.set()
            queue.put_nowait(None)

    task = asyncio.create_task(_run())
    try:
        while (event := await queue.get()) is not None:
            yield event
    finally:
        # Only reached with a live task if the consumer stopped iterating early.
        if not task.done():
            task.cancel()
        with suppress(asyncio.CancelledError):
            await task
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def run_instrumented_sync(
    graph: Graph[Any, Any, Any],
    start_node: BaseNode[Any, Any, Any],
    *,
    state: Any = None,
    deps: Any = None,
    persistence: Any = None,
    hooks: RunHooks,
) -> GraphRunResult[Any, Any]:
    """Blocking counterpart of ``run_instrumented``.

    The run is executed with ``run_until_complete`` on the loop from
    ``_get_event_loop``. Because ``run_until_complete`` refuses to run on a loop
    that is already running, this is meant for synchronous callers.
    """
    loop = _get_event_loop()
    pending_run = run_instrumented(
        graph,
        start_node,
        state=state,
        deps=deps,
        persistence=persistence,
        hooks=hooks,
    )
    return loop.run_until_complete(pending_run)
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def _get_event_loop() -> asyncio.AbstractEventLoop:
|
|
428
|
+
try:
|
|
429
|
+
event_loop = asyncio.get_event_loop()
|
|
430
|
+
except RuntimeError:
|
|
431
|
+
event_loop = asyncio.new_event_loop()
|
|
432
|
+
asyncio.set_event_loop(event_loop)
|
|
433
|
+
return event_loop
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Annotated, Any, Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field, TypeAdapter
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class GraphNode(BaseModel):
    """Represents a node in the introspected graph."""

    # NOTE: pydantic publishes class docstrings as the JSON-schema "description",
    # so the docstring above is part of the exported schema and must stay verbatim.

    # Identifier that edges and entry/terminal lists refer to.
    node_id: str
    # Optional label for the node; None when not provided.
    label: str | None = None
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class GraphEdge(BaseModel):
    """Represents a directed edge between nodes."""

    # Docstring is exported as the JSON-schema description; keep it verbatim.

    # Node the edge starts from.
    source_node_id: str
    # Node the edge points to; may be None (presumably when the target is not
    # statically known — TODO confirm against the introspection module).
    target_node_id: str | None = None
    # Flag for edges that are not statically resolved; defaults to False.
    dynamic: bool = False
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class GraphModel(BaseModel):
    """Container for the full graph payload."""

    # Docstring is exported as the JSON-schema description; keep it verbatim.

    nodes: list[GraphNode]
    edges: list[GraphEdge]
    # Node ids where a run may begin.
    entry_nodes: list[str]
    # Node ids after which a run may end.
    terminal_nodes: list[str]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class EventBase(BaseModel):
    """Base event fields shared across all runtime events."""

    # Docstring is exported as the JSON-schema description; keep it verbatim.

    # Run this event belongs to.
    run_id: str
    # Narrowed to a Literal in each subclass; it is the discriminator of ``Event``.
    event_type: str
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class NodeStartEvent(EventBase):
    """Emitted when a node begins execution."""

    # Docstring is exported as the JSON-schema description; keep it verbatim.

    event_type: Literal["node_start"]
    node_id: str
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class NodeEndEvent(EventBase):
    """Emitted when a node finishes execution."""

    # Docstring is exported as the JSON-schema description; keep it verbatim.

    event_type: Literal["node_end"]
    node_id: str
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class EdgeTakenEvent(EventBase):
    """Emitted when an edge is traversed during execution."""

    # Docstring is exported as the JSON-schema description; keep it verbatim.

    event_type: Literal["edge_taken"]
    source_node_id: str
    # Optional; defaults to None.
    target_node_id: str | None = None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class RunEndEvent(EventBase):
    """Emitted when a run completes successfully."""

    # Docstring is exported as the JSON-schema description; keep it verbatim.

    event_type: Literal["run_end"]
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class ErrorEvent(EventBase):
    """Emitted when a run terminates due to an error."""

    # Docstring is exported as the JSON-schema description; keep it verbatim.

    event_type: Literal["error"]
    # Human-readable error text.
    message: str
    # Node that failed, or None when the failure is not tied to a node.
    node_id: str | None = None
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# Discriminated union of every runtime event; pydantic selects the concrete model
# from the ``event_type`` literal. Member order is kept as-is because it determines
# the order of alternatives in the exported JSON schema.
Event = Annotated[
    NodeStartEvent | NodeEndEvent | EdgeTakenEvent | RunEndEvent | ErrorEvent,
    Field(discriminator="event_type"),
]
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def graph_schema() -> dict[str, Any]:
    """Build the JSON Schema describing the ``GraphModel`` payload."""
    schema = GraphModel.model_json_schema()
    return schema
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def event_schema() -> dict[str, Any]:
    """Build the JSON Schema for the discriminated ``Event`` union."""
    # ``Event`` is an Annotated union rather than a model, so it needs a TypeAdapter.
    adapter = TypeAdapter(Event)
    return adapter.json_schema()
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def export_schemas() -> dict[str, dict[str, Any]]:
    """Collect every published JSON Schema, keyed by payload name ("graph", "event")."""
    builders = (("graph", graph_schema), ("event", event_schema))
    return {name: build() for name, build in builders}
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
from collections.abc import AsyncIterator
|
|
6
|
+
from contextlib import asynccontextmanager
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from importlib import resources
|
|
9
|
+
from typing import Any
|
|
10
|
+
from uuid import uuid4
|
|
11
|
+
|
|
12
|
+
from fastapi import FastAPI, HTTPException
|
|
13
|
+
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
|
|
14
|
+
from fastapi.staticfiles import StaticFiles
|
|
15
|
+
from pydantic_graph import Graph
|
|
16
|
+
from pydantic_graph.nodes import BaseNode
|
|
17
|
+
|
|
18
|
+
from pydantic_graph_studio.introspection import serialize_graph
|
|
19
|
+
from pydantic_graph_studio.runtime import iter_run_events
|
|
20
|
+
from pydantic_graph_studio.schemas import Event
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass(slots=True)
class RunState:
    """Book-keeping for one run started through the HTTP API."""

    # Identifier handed back to the client by the run endpoint.
    run_id: str
    # Events pushed by the background producer and read by the SSE endpoint.
    queue: asyncio.Queue[Event]
    # Set by the producer once it stops producing events.
    done: asyncio.Event
    # Background task that drives the run.
    task: asyncio.Task[None]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class RunRegistry:
    """In-memory map of run ids to their live ``RunState``, guarded by an asyncio lock."""

    def __init__(self) -> None:
        """Initialize the run registry."""
        self._runs: dict[str, RunState] = {}
        # Every method touching ``_runs`` takes this lock first.
        self._lock = asyncio.Lock()

    async def start_run(
        self,
        graph: Graph[Any, Any, Any],
        start_node: BaseNode[Any, Any, Any] | None,
        *,
        state: Any = None,
        deps: Any = None,
        persistence: Any = None,
        inputs: Any = None,
    ) -> str:
        """Start a graph run in a background task and return the run id.

        The task copies every event from ``iter_run_events`` onto the run's queue.
        It sets ``done`` when it stops, whether the run finished, failed or was
        cancelled.
        """
        # Local import: only needed for the fallback error event below.
        from pydantic_graph_studio.schemas import ErrorEvent

        run_id = uuid4().hex
        queue: asyncio.Queue[Event] = asyncio.Queue()
        done = asyncio.Event()

        async def producer() -> None:
            try:
                async for event in iter_run_events(
                    graph,
                    start_node,
                    state=state,
                    deps=deps,
                    persistence=persistence,
                    inputs=inputs,
                    run_id=run_id,
                ):
                    await queue.put(event)
            except Exception as exc:
                # iter_run_events reports run failures as events itself. Anything that
                # still escapes (e.g. its ValueError for a v1 graph without a start node)
                # used to vanish inside this task, leaving the client an empty stream.
                await queue.put(
                    ErrorEvent(
                        run_id=run_id,
                        event_type="error",
                        message=str(exc),
                        node_id=None,
                    )
                )
            finally:
                done.set()

        task = asyncio.create_task(producer())
        async with self._lock:
            self._runs[run_id] = RunState(run_id=run_id, queue=queue, done=done, task=task)
        return run_id

    async def get(self, run_id: str) -> RunState | None:
        """Fetch the run state for a run id."""
        async with self._lock:
            return self._runs.get(run_id)

    async def remove(self, run_id: str) -> None:
        """Remove a run state from the registry."""
        async with self._lock:
            self._runs.pop(run_id, None)

    async def shutdown(self) -> None:
        """Cancel any in-flight runs, wait for them to unwind, and clear the registry."""
        async with self._lock:
            runs = list(self._runs.values())
            self._runs.clear()
        tasks = [run.task for run in runs]
        for task in tasks:
            if not task.done():
                task.cancel()
        if tasks:
            # Awaiting lets cancellation cleanup (``finally`` blocks, inner-task
            # cancellation) finish before the loop shuts down. ``return_exceptions``
            # also retrieves stored exceptions so none are logged as "never retrieved".
            await asyncio.gather(*tasks, return_exceptions=True)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
async def _next_event(run_state: RunState) -> Event | None:
    """Return the next queued event for ``run_state``, or None once it is finished and drained.

    Blocking on ``queue.get()`` alone can hang forever: the producer sets ``done`` in
    a ``finally`` block without enqueuing anything, so a reader already parked in
    ``get()`` would never wake up. Racing the getter against ``done.wait()`` makes
    completion observable.
    """
    queue = run_state.queue
    done = run_state.done
    while True:
        if not queue.empty():
            return queue.get_nowait()
        if done.is_set():
            return None
        getter = asyncio.ensure_future(queue.get())
        finished = asyncio.ensure_future(done.wait())
        try:
            await asyncio.wait({getter, finished}, return_when=asyncio.FIRST_COMPLETED)
        finally:
            finished.cancel()
            if not getter.done():
                # Cancelling a pending Queue.get() leaves any item in the queue.
                getter.cancel()
        if getter.done() and not getter.cancelled():
            return getter.result()


def create_app(
    graph: Graph[Any, Any, Any],
    start_node: BaseNode[Any, Any, Any] | None,
    *,
    state: Any = None,
    deps: Any = None,
    persistence: Any = None,
    inputs: Any = None,
) -> FastAPI:
    """Create the FastAPI app bound to a graph and start node.

    Routes:
        GET  /api/graph   serialized graph structure.
        POST /api/run     start a run; returns ``{"run_id": ...}``.
        GET  /api/events  Server-Sent Events stream for ``run_id``.
        GET  /            bundled studio UI (static files under ``/assets``).

    The run arguments are stored on ``app.state`` at startup and reused for every run.
    """
    ui_root = resources.files("pydantic_graph_studio.ui")
    index_html = (ui_root / "index.html").read_text(encoding="utf-8")
    assets_dir = ui_root / "assets"

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncIterator[None]:
        """Initialize and tear down shared server state."""
        registry = RunRegistry()
        app.state.graph = graph
        app.state.start_node = start_node
        app.state.state = state
        app.state.deps = deps
        app.state.persistence = persistence
        app.state.inputs = inputs
        app.state.registry = registry
        try:
            yield
        finally:
            await registry.shutdown()

    app = FastAPI(lifespan=lifespan)

    @app.get("/api/graph")
    async def get_graph() -> JSONResponse:
        """Return the serialized graph model."""
        payload = serialize_graph(app.state.graph)
        return JSONResponse(payload)

    @app.post("/api/run")
    async def start_run() -> dict[str, str]:
        """Start a new run and return its identifier."""
        run_id = await app.state.registry.start_run(
            app.state.graph,
            app.state.start_node,
            state=app.state.state,
            deps=app.state.deps,
            persistence=app.state.persistence,
            inputs=app.state.inputs,
        )
        return {"run_id": run_id}

    @app.get("/api/events")
    async def stream_events(run_id: str) -> StreamingResponse:
        """Stream events for a run as Server-Sent Events."""
        run_state = await app.state.registry.get(run_id)
        if run_state is None:
            raise HTTPException(status_code=404, detail="Unknown run_id")

        async def event_stream() -> AsyncIterator[bytes]:
            """Yield SSE-formatted event payloads until the run is finished and drained."""
            try:
                while (event := await _next_event(run_state)) is not None:
                    payload = json.dumps(event.model_dump(mode="json"))
                    yield f"data: {payload}\n\n".encode()
            finally:
                if not run_state.task.done():
                    # The stream ended before the run did (e.g. client disconnect).
                    # The run is dropped from the registry below, so nobody could read
                    # its events any more; stop it instead of leaving it orphaned.
                    run_state.task.cancel()
                await app.state.registry.remove(run_id)

        headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
        return StreamingResponse(event_stream(), media_type="text/event-stream", headers=headers)

    app.mount("/assets", StaticFiles(directory=str(assets_dir)), name="assets")

    @app.get("/")
    async def studio_index() -> HTMLResponse:
        """Serve the bundled studio UI."""
        return HTMLResponse(index_html)

    return app
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Bundled UI assets for the studio frontend."""
|