pydantic-graph-studio 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,55 @@
1
+ """Pydantic Graph Studio entrypoint."""
2
+
3
+ from pydantic_graph_studio.cli import main
4
+ from pydantic_graph_studio.introspection import build_graph_model, serialize_graph
5
+ from pydantic_graph_studio.runtime import (
6
+ RunHooks,
7
+ instrument_graph_run,
8
+ iter_instrumented,
9
+ iter_run_events,
10
+ run_instrumented,
11
+ run_instrumented_sync,
12
+ )
13
+ from pydantic_graph_studio.schemas import (
14
+ EdgeTakenEvent,
15
+ ErrorEvent,
16
+ Event,
17
+ EventBase,
18
+ GraphEdge,
19
+ GraphModel,
20
+ GraphNode,
21
+ NodeEndEvent,
22
+ NodeStartEvent,
23
+ RunEndEvent,
24
+ event_schema,
25
+ export_schemas,
26
+ graph_schema,
27
+ )
28
+ from pydantic_graph_studio.server import RunRegistry, create_app
29
+
30
# Public API of the pydantic_graph_studio package: every name re-exported by
# the imports above from the cli, introspection, runtime, schemas and server
# submodules. Keep this a literal list (type checkers and linters only
# understand a statically written ``__all__``) and keep it in sync with those
# imports when adding or removing exports.
__all__ = [
    "EdgeTakenEvent",
    "ErrorEvent",
    "Event",
    "EventBase",
    "GraphEdge",
    "GraphModel",
    "GraphNode",
    "NodeEndEvent",
    "NodeStartEvent",
    "RunEndEvent",
    "RunHooks",
    "RunRegistry",
    "build_graph_model",
    "create_app",
    "event_schema",
    "export_schemas",
    "graph_schema",
    "instrument_graph_run",
    "iter_instrumented",
    "iter_run_events",
    "main",
    "run_instrumented",
    "run_instrumented_sync",
    "serialize_graph",
]
@@ -0,0 +1,201 @@
1
+ """Command-line entrypoint for launching the studio server."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import importlib
7
+ import importlib.util
8
+ import sys
9
+ from dataclasses import dataclass
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ from pydantic_graph import Graph
14
+ from pydantic_graph.nodes import BaseNode
15
+
16
+ from pydantic_graph_studio.introspection import build_graph_model
17
+ from pydantic_graph_studio.server import create_app
18
+
19
+ BetaGraph: type[Any] | None = None
20
+ try: # pragma: no cover - optional beta support
21
+ from pydantic_graph.beta.graph import Graph as _BetaGraph
22
+ except ModuleNotFoundError: # pragma: no cover
23
+ pass
24
+ else:
25
+ BetaGraph = _BetaGraph
26
+
27
+
28
+ @dataclass(slots=True)
29
+ class GraphRef:
30
+ target: str
31
+ attribute: str
32
+
33
+
34
class CLIError(RuntimeError):
    """A failure that should be shown to the CLI user without a traceback.

    ``main`` catches it, prints ``error: <message>`` to stderr and exits with
    status 2; other exception types are not caught there.
    """
36
+
37
+
38
def main(argv: list[str] | None = None) -> None:
    """Run the ``pgraph`` command line.

    Parses *argv* (argparse falls back to ``sys.argv[1:]`` when it is ``None``),
    loads the referenced graph, resolves its start node and starts the server.

    Raises:
        SystemExit: status 2 when loading the graph, resolving the start node
            or starting the server raises ``CLIError``; the message is printed
            to stderr as ``error: <message>`` first.
    """
    options = _parse_args(argv)
    try:
        loaded_graph = _load_graph(options.graph_ref)
        entry_node = _resolve_start_node(loaded_graph, options.start)
        _run_server(loaded_graph, entry_node, host=options.host, port=options.port)
    except CLIError as failure:
        print(f"error: {failure}", file=sys.stderr)
        raise SystemExit(2) from failure
49
+
50
+
51
+ def _parse_args(argv: list[str] | None) -> argparse.Namespace:
52
+ parser = argparse.ArgumentParser(
53
+ prog="pgraph",
54
+ description="Launch the local Pydantic Graph Studio for a graph reference.",
55
+ )
56
+ parser.add_argument(
57
+ "graph_ref",
58
+ help="Graph reference in the form module:var or path.py:var",
59
+ )
60
+ parser.add_argument(
61
+ "--host",
62
+ default="127.0.0.1",
63
+ help="Host to bind the local server (default: 127.0.0.1)",
64
+ )
65
+ parser.add_argument(
66
+ "--port",
67
+ type=int,
68
+ default=8000,
69
+ help="Port to bind the local server (default: 8000)",
70
+ )
71
+ parser.add_argument(
72
+ "--start",
73
+ help="Explicit node id to use as the entry point",
74
+ )
75
+ return parser.parse_args(argv)
76
+
77
+
78
def _load_graph(graph_ref: str) -> Any:
    """Import the object named by *graph_ref* and check that it is a graph.

    *graph_ref* is ``module:attr`` or ``path.py:attr`` (``attr`` may be dotted).
    Both ``pydantic_graph.Graph`` instances and, when the beta API is
    importable, beta graph instances are accepted.

    Raises:
        CLIError: if the reference is malformed, the module/file or attribute
            cannot be found, or the object is not a graph. Exceptions raised
            while executing the user's module itself propagate unchanged.
    """
    ref = _parse_graph_ref(graph_ref)
    candidate = _resolve_attribute(_load_module(ref.target), ref.attribute)
    if isinstance(candidate, Graph) or _is_beta_graph(candidate):
        return candidate
    raise CLIError(
        "Graph reference did not resolve to a Graph instance. "
        "Ensure the reference points to a pydantic_graph.Graph object."
    )
88
+
89
+
90
def _parse_graph_ref(graph_ref: str) -> GraphRef:
    """Split *graph_ref* at its last ``:`` into target and attribute.

    Splitting on the last colon keeps a Windows drive letter inside the target,
    e.g. ``C:/work/flow.py:graph`` -> (``C:/work/flow.py``, ``graph``).

    Raises:
        CLIError: if *graph_ref* contains no ``:`` or either side is empty.
    """
    target, separator, attribute = graph_ref.rpartition(":")
    if not separator:
        raise CLIError("Graph reference must be in the form module:var or path.py:var")
    if not (target and attribute):
        raise CLIError("Graph reference must include both target and attribute")
    return GraphRef(target=target, attribute=attribute)
97
+
98
+
99
def _load_module(target: str) -> Any:
    """Import the module named by *target*.

    *target* is treated as a file path when ``_looks_like_path`` says so and is
    executed from that file; otherwise it is imported by dotted module name.

    Raises:
        CLIError: if the file does not exist, is not a regular ``.py`` file, or
            the named module (or one of its parent packages) cannot be found.
            A ``ModuleNotFoundError`` for some *other* module (a dependency
            the target itself imports) is reported as a missing dependency
            instead of the misleading "Module not found: <target>".
    """
    if _looks_like_path(target):
        path = Path(target)
        if not path.exists():
            raise CLIError(f"File not found: {path}")
        if path.suffix != ".py":
            raise CLIError("File reference must point to a .py file")
        if not path.is_file():
            # e.g. a directory whose name ends in ".py"
            raise CLIError(f"Not a file: {path}")
        return _load_module_from_file(path)
    try:
        return importlib.import_module(target)
    except ModuleNotFoundError as exc:
        missing = exc.name
        # exc.name is the module the import system failed to find. If it is
        # neither the target nor one of its parent packages, the target was
        # found but one of its own imports failed.
        if missing and missing != target and not target.startswith(f"{missing}."):
            raise CLIError(
                f"Module '{target}' could not be imported: "
                f"dependency '{missing}' could not be found"
            ) from exc
        raise CLIError(f"Module not found: {target}") from exc
111
+
112
+
113
def _looks_like_path(target: str) -> bool:
    """Return True when *target* should be treated as a file path.

    Heuristic: a ``.py`` suffix, or any ``/`` or ``\\`` separator.
    """
    if target.endswith(".py"):
        return True
    return any(separator in target for separator in ("/", "\\"))
115
+
116
+
117
def _load_module_from_file(path: Path) -> Any:
    """Execute the Python file at *path* as a fresh module and return it.

    The module is registered in ``sys.modules`` under a synthetic name
    (``pgraph_user_<stem>_<hash>``) before its body runs, as in the importlib
    "importing a source file directly" recipe. If executing the file raises,
    that half-initialised module is removed from ``sys.modules`` again and the
    original exception propagates.

    Raises:
        CLIError: if no import spec/loader can be created for *path*.
    """
    module_name = f"pgraph_user_{path.stem}_{abs(hash(path))}"
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise CLIError(f"Unable to load module from file: {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Mirror the import system: a module whose body failed must not stay
        # registered in a half-initialised state.
        sys.modules.pop(module_name, None)
        raise
    return module
126
+
127
+
128
def _resolve_attribute(module: Any, attribute: str) -> Any:
    """Walk the dotted *attribute* path starting from *module* and return the result.

    Each segment is fetched with a single ``getattr`` (instead of ``hasattr``
    followed by ``getattr``), so properties or module-level ``__getattr__``
    hooks along the path are evaluated only once.

    Raises:
        CLIError: naming the first segment that cannot be found.
    """
    current: Any = module
    for segment in attribute.split("."):
        try:
            current = getattr(current, segment)
        except AttributeError as exc:
            raise CLIError(
                f"Attribute '{segment}' not found while resolving '{attribute}'"
            ) from exc
    return current
135
+
136
+
137
def _resolve_start_node(
    graph: Any,
    start_node_id: str | None,
) -> BaseNode[Any, Any, Any] | None:
    """Return the node instance runs of *graph* should start from.

    Beta graphs have a fixed start node, so ``None`` is returned for them and
    passing *start_node_id* is an error. For classic graphs the node is
    *start_node_id* when given, otherwise the single inferred entry node; a node
    class is instantiated with no arguments.

    Raises:
        CLIError: for an empty graph, an unknown or ambiguous start node, or a
            start node that cannot be turned into a ``BaseNode`` instance.
    """
    if _is_beta_graph(graph):
        if start_node_id:
            raise CLIError("Beta graphs use a fixed start node; --start is not supported.")
        return None

    node_defs = graph.node_defs
    if not node_defs:
        raise CLIError("Graph contains no nodes")

    if start_node_id:
        chosen = node_defs.get(start_node_id)
        if chosen is None:
            listing = ", ".join(sorted(node_defs.keys()))
            raise CLIError(f"Unknown start node '{start_node_id}'. Available nodes: {listing}")
    else:
        chosen = node_defs[_infer_single_entry_id(graph)]

    return _instantiate_start_node(chosen.node)


def _infer_single_entry_id(graph: Any) -> str:
    """Return the only entry node id inferred for *graph*, or raise CLIError."""
    candidates = build_graph_model(graph).entry_nodes
    if not candidates:
        raise CLIError("Unable to infer an entry node. Use --start to specify one.")
    if len(candidates) > 1:
        joined = ", ".join(candidates)
        raise CLIError(f"Multiple entry nodes found: {joined}. Use --start to choose one.")
    return candidates[0]


def _instantiate_start_node(candidate: Any) -> BaseNode[Any, Any, Any]:
    """Return *candidate* as a BaseNode instance, constructing it if it is a BaseNode class."""
    if isinstance(candidate, BaseNode):
        return candidate
    if not (isinstance(candidate, type) and issubclass(candidate, BaseNode)):
        raise CLIError("Start node did not resolve to a BaseNode instance")
    try:
        return candidate()
    except TypeError as exc:
        # Typically a node class with required constructor arguments.
        raise CLIError(
            "Failed to instantiate the start node. "
            "Ensure it can be constructed with no arguments or provide a different entry node."
        ) from exc
178
+
179
+
180
def _run_server(
    graph: Any,
    start_node: BaseNode[Any, Any, Any] | None,
    *,
    host: str,
    port: int,
) -> None:
    """Build the studio app with ``create_app`` and serve it via ``uvicorn.run``.

    Blocks while uvicorn is serving.

    Raises:
        CLIError: if *port* is outside 1-65535 or uvicorn is not installed.
    """
    if not 1 <= port <= 65535:
        raise CLIError("Port must be between 1 and 65535")

    # Check for the optional server dependency before doing any app setup.
    try:
        import uvicorn
    except ModuleNotFoundError as exc:
        raise CLIError("uvicorn is required to run the server") from exc

    app = create_app(graph, start_node)

    # IPv6 literals (e.g. "::1") must be wrapped in brackets inside a URL.
    display_host = f"[{host}]" if ":" in host and not host.startswith("[") else host
    print(f"Studio running at http://{display_host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="info")
198
+
199
+
200
def _is_beta_graph(graph: Any) -> bool:
    """Tell whether *graph* is a beta ``pydantic_graph.beta.graph.Graph``.

    Always ``False`` when the beta API could not be imported.
    """
    if BetaGraph is None:
        return False
    return isinstance(graph, BetaGraph)
@@ -0,0 +1,209 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Iterable, Mapping
4
+ from typing import Any
5
+
6
+ from pydantic_graph import Graph
7
+ from pydantic_graph.nodes import NodeDef
8
+
9
+ from pydantic_graph_studio.schemas import GraphEdge, GraphModel, GraphNode
10
+
11
+ BetaGraph: type[Any] | None = None
12
+ try: # pragma: no cover - optional beta support
13
+ from pydantic_graph.beta.decision import Decision as BetaDecision
14
+ from pydantic_graph.beta.graph import Graph as _BetaGraph
15
+ from pydantic_graph.beta.join import Join as BetaJoin
16
+ from pydantic_graph.beta.node import EndNode as BetaEndNode
17
+ from pydantic_graph.beta.node import Fork as BetaFork
18
+ from pydantic_graph.beta.node import StartNode as BetaStartNode
19
+ from pydantic_graph.beta.paths import DestinationMarker as BetaDestinationMarker
20
+ from pydantic_graph.beta.paths import Path as BetaPath
21
+ from pydantic_graph.beta.step import NodeStep as BetaNodeStep
22
+ from pydantic_graph.beta.step import Step as BetaStep
23
+ except ModuleNotFoundError: # pragma: no cover
24
+ pass
25
+ else:
26
+ BetaGraph = _BetaGraph
27
+
28
+
29
def build_graph_model(graph: Any) -> GraphModel:
    """Describe *graph* as a ``GraphModel`` (nodes, edges, entry and terminal ids).

    Beta graphs are handed to the beta builder. For classic graphs the node
    definitions are processed in sorted node-id order, entry nodes are inferred
    from the edge set, and terminal nodes are those whose definition has an
    end edge.
    """
    if _is_beta_graph(graph):
        return _build_beta_graph_model(graph)

    ordered_defs = _sorted_node_defs(graph.node_defs)
    graph_edges = _build_edges(ordered_defs)
    return GraphModel(
        nodes=_build_nodes(ordered_defs),
        edges=graph_edges,
        entry_nodes=_infer_entry_nodes(ordered_defs, graph_edges),
        terminal_nodes=_infer_terminal_nodes(ordered_defs),
    )
46
+
47
+
48
def serialize_graph(graph: Graph[Any, Any, Any]) -> dict[str, Any]:
    """Return the graph model of *graph* dumped to JSON-compatible Python data.

    Equivalent to ``build_graph_model(graph).model_dump(mode="json")``; beta
    graphs are handled by ``build_graph_model`` too, despite the annotation.
    """
    model = build_graph_model(graph)
    payload = model.model_dump(mode="json")
    return payload
52
+
53
+
54
def _sorted_node_defs(node_defs: Mapping[str, NodeDef[Any, Any, Any]]) -> list[NodeDef[Any, Any, Any]]:
    """Return the values of *node_defs* ordered by their node-id key."""
    pairs = sorted(node_defs.items(), key=lambda pair: pair[0])
    return [definition for _node_id, definition in pairs]
56
+
57
+
58
def _build_nodes(node_defs: Iterable[NodeDef[Any, Any, Any]]) -> list[GraphNode]:
    """Create one ``GraphNode`` per definition, preserving input order."""

    def to_graph_node(definition: NodeDef[Any, Any, Any]) -> GraphNode:
        # The label is ``__name__`` of ``definition.node`` (presumably the node class).
        return GraphNode(node_id=definition.node_id, label=definition.node.__name__)

    return list(map(to_graph_node, node_defs))
66
+
67
+
68
def _build_edges(node_defs: Iterable[NodeDef[Any, Any, Any]]) -> list[GraphEdge]:
    """Collect the outgoing edges of every node definition.

    For each node, in input order: one static edge per ``next_node_edges``
    target (sorted by target id), then a single dynamic edge with no target when
    ``returns_base_node`` is set.
    """
    collected: list[GraphEdge] = []
    for definition in node_defs:
        source = definition.node_id
        collected.extend(
            GraphEdge(source_node_id=source, target_node_id=target, dynamic=False)
            for target in sorted(definition.next_node_edges.keys())
        )
        if definition.returns_base_node:
            collected.append(GraphEdge(source_node_id=source, target_node_id=None, dynamic=True))
    return collected
88
+
89
+
90
def _infer_entry_nodes(
    node_defs: Iterable[NodeDef[Any, Any, Any]],
    edges: Iterable[GraphEdge],
) -> list[str]:
    """Return, sorted, the ids of nodes that no *other* node has an edge to.

    Dynamic edges (``target_node_id is None``) have no target and are ignored.
    Self-loops are ignored too: a node that can re-enter itself (for example a
    retry step) can only be reached from outside the first time, so that edge
    must not cost it its entry-node status.
    """
    node_ids = {node_def.node_id for node_def in node_defs}
    inbound = {
        edge.target_node_id
        for edge in edges
        if edge.target_node_id is not None and edge.target_node_id != edge.source_node_id
    }
    return sorted(node_ids - inbound)
101
+
102
+
103
def _infer_terminal_nodes(node_defs: Iterable[NodeDef[Any, Any, Any]]) -> list[str]:
    """Return, sorted, the ids of nodes whose definition has an ``end_edge``."""
    terminal_ids = [definition.node_id for definition in node_defs if definition.end_edge is not None]
    terminal_ids.sort()
    return terminal_ids
106
+
107
+
108
def _is_beta_graph(graph: Any) -> bool:
    """Return True only if beta support loaded and *graph* is a beta ``Graph``."""
    beta_cls = BetaGraph
    return False if beta_cls is None else isinstance(graph, beta_cls)
110
+
111
+
112
def _build_beta_graph_model(graph: Any) -> GraphModel:
    """Describe a beta graph as a ``GraphModel``.

    Edges come from ``graph.edges_by_source`` plus the branches of any decision
    nodes in ``graph.nodes``; entry/terminal ids are the beta start/end nodes
    when present.
    """
    return GraphModel(
        nodes=_build_beta_nodes(graph),
        edges=_build_beta_edges(graph.edges_by_source, graph.nodes),
        entry_nodes=_infer_beta_entry_nodes(graph),
        terminal_nodes=_infer_beta_terminal_nodes(graph),
    )
123
+
124
+
125
def _build_beta_nodes(graph: Any) -> list[GraphNode]:
    """Create a ``GraphNode`` for every beta node, ordered by ``str(node_id)``."""
    node_map = graph.nodes
    return [
        GraphNode(node_id=str(key), label=_beta_node_label(node_map[key]))
        for key in sorted(node_map, key=str)
    ]
134
+
135
+
136
+ def _beta_node_label(node: Any) -> str | None:
137
+ if BetaGraph is None:
138
+ return None
139
+ if isinstance(node, BetaStartNode):
140
+ return "Start"
141
+ if isinstance(node, BetaEndNode):
142
+ return "Done"
143
+ if isinstance(node, BetaFork):
144
+ return "Map Fork" if node.is_map else "Fork"
145
+ if isinstance(node, BetaJoin):
146
+ return "Join"
147
+ if isinstance(node, BetaDecision):
148
+ return "Decision"
149
+ if isinstance(node, BetaNodeStep):
150
+ return node.node_type.__name__
151
+ if isinstance(node, BetaStep):
152
+ return node.label or str(node.id)
153
+ raw = getattr(node, "label", None) or getattr(node, "id", None)
154
+ return str(raw) if raw is not None else None
155
+
156
+
157
def _build_beta_edges(
    edges_by_source: Mapping[Any, list[Any]], nodes: Mapping[Any, Any] | None = None
) -> list[GraphEdge]:
    """Flatten beta paths into de-duplicated, sorted static edges.

    Every destination reached by a path in *edges_by_source* yields an edge from
    that source. When *nodes* is given, each decision node also contributes an
    edge to every destination of each of its branches. Ids are stringified,
    duplicates dropped, and the result is sorted by (source, target).
    """
    pairs: set[tuple[str, str]] = {
        (str(source), str(destination))
        for source, source_paths in edges_by_source.items()
        for path in source_paths
        for destination in _beta_path_destinations(path)
    }

    if nodes:
        for node_id, node in nodes.items():
            if BetaGraph is not None and isinstance(node, BetaDecision):
                pairs.update(
                    (str(node_id), str(destination))
                    for branch in node.branches
                    for destination in _beta_path_destinations(branch.path)
                )

    return [
        GraphEdge(source_node_id=src, target_node_id=dst, dynamic=False)
        for src, dst in sorted(pairs)
    ]
184
+
185
+
186
def _beta_path_destinations(path: Any) -> list[Any]:
    """Return the ``destination_id`` of every destination marker in a beta *path*.

    Returns an empty list when beta support is unavailable or *path* is not a
    beta ``Path``.
    """
    if BetaGraph is None or not isinstance(path, BetaPath):
        return []
    return [step.destination_id for step in path.items if isinstance(step, BetaDestinationMarker)]
196
+
197
+
198
def _infer_beta_entry_nodes(graph: Any) -> list[str]:
    """Return ``[str(BetaStartNode.id)]`` if that id is among ``graph.nodes``, else ``[]``.

    Node ids are compared as strings; always ``[]`` without beta support.
    """
    if BetaGraph is None:
        return []
    start_id = str(BetaStartNode.id)
    if any(str(node_id) == start_id for node_id in graph.nodes.keys()):
        return [start_id]
    return []
203
+
204
+
205
def _infer_beta_terminal_nodes(graph: Any) -> list[str]:
    """Return ``[str(BetaEndNode.id)]`` if that id is among ``graph.nodes``, else ``[]``.

    Node ids are compared as strings; always ``[]`` without beta support.
    """
    if BetaGraph is None:
        return []
    end_id = str(BetaEndNode.id)
    if any(str(node_id) == end_id for node_id in graph.nodes.keys()):
        return [end_id]
    return []
+ return [end_id] if end_id in {str(node_id) for node_id in graph.nodes.keys()} else []