yamlgraph-0.3.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. examples/__init__.py +1 -0
  2. examples/codegen/__init__.py +5 -0
  3. examples/codegen/models/__init__.py +13 -0
  4. examples/codegen/models/schemas.py +76 -0
  5. examples/codegen/tests/__init__.py +1 -0
  6. examples/codegen/tests/test_ai_helpers.py +235 -0
  7. examples/codegen/tests/test_ast_analysis.py +174 -0
  8. examples/codegen/tests/test_code_analysis.py +134 -0
  9. examples/codegen/tests/test_code_context.py +301 -0
  10. examples/codegen/tests/test_code_nav.py +89 -0
  11. examples/codegen/tests/test_dependency_tools.py +119 -0
  12. examples/codegen/tests/test_example_tools.py +185 -0
  13. examples/codegen/tests/test_git_tools.py +112 -0
  14. examples/codegen/tests/test_impl_agent_schemas.py +193 -0
  15. examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
  16. examples/codegen/tests/test_jedi_analysis.py +226 -0
  17. examples/codegen/tests/test_meta_tools.py +250 -0
  18. examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
  19. examples/codegen/tests/test_syntax_tools.py +85 -0
  20. examples/codegen/tests/test_synthesize_prompt.py +94 -0
  21. examples/codegen/tests/test_template_tools.py +244 -0
  22. examples/codegen/tools/__init__.py +80 -0
  23. examples/codegen/tools/ai_helpers.py +420 -0
  24. examples/codegen/tools/ast_analysis.py +92 -0
  25. examples/codegen/tools/code_context.py +180 -0
  26. examples/codegen/tools/code_nav.py +52 -0
  27. examples/codegen/tools/dependency_tools.py +120 -0
  28. examples/codegen/tools/example_tools.py +188 -0
  29. examples/codegen/tools/git_tools.py +151 -0
  30. examples/codegen/tools/impl_executor.py +614 -0
  31. examples/codegen/tools/jedi_analysis.py +311 -0
  32. examples/codegen/tools/meta_tools.py +202 -0
  33. examples/codegen/tools/syntax_tools.py +26 -0
  34. examples/codegen/tools/template_tools.py +356 -0
  35. examples/fastapi_interview.py +167 -0
  36. examples/npc/api/__init__.py +1 -0
  37. examples/npc/api/app.py +100 -0
  38. examples/npc/api/routes/__init__.py +5 -0
  39. examples/npc/api/routes/encounter.py +182 -0
  40. examples/npc/api/session.py +330 -0
  41. examples/npc/demo.py +387 -0
  42. examples/npc/nodes/__init__.py +5 -0
  43. examples/npc/nodes/image_node.py +92 -0
  44. examples/npc/run_encounter.py +230 -0
  45. examples/shared/__init__.py +0 -0
  46. examples/shared/replicate_tool.py +238 -0
  47. examples/storyboard/__init__.py +1 -0
  48. examples/storyboard/generate_videos.py +335 -0
  49. examples/storyboard/nodes/__init__.py +12 -0
  50. examples/storyboard/nodes/animated_character_node.py +248 -0
  51. examples/storyboard/nodes/animated_image_node.py +138 -0
  52. examples/storyboard/nodes/character_node.py +162 -0
  53. examples/storyboard/nodes/image_node.py +118 -0
  54. examples/storyboard/nodes/replicate_tool.py +49 -0
  55. examples/storyboard/retry_images.py +118 -0
  56. scripts/demo_async_executor.py +212 -0
  57. scripts/demo_interview_e2e.py +200 -0
  58. scripts/demo_streaming.py +140 -0
  59. scripts/run_interview_demo.py +94 -0
  60. scripts/test_interrupt_fix.py +26 -0
  61. tests/__init__.py +1 -0
  62. tests/conftest.py +178 -0
  63. tests/integration/__init__.py +1 -0
  64. tests/integration/test_animated_storyboard.py +63 -0
  65. tests/integration/test_cli_commands.py +242 -0
  66. tests/integration/test_colocated_prompts.py +139 -0
  67. tests/integration/test_map_demo.py +50 -0
  68. tests/integration/test_memory_demo.py +283 -0
  69. tests/integration/test_npc_api/__init__.py +1 -0
  70. tests/integration/test_npc_api/test_routes.py +357 -0
  71. tests/integration/test_npc_api/test_session.py +216 -0
  72. tests/integration/test_pipeline_flow.py +105 -0
  73. tests/integration/test_providers.py +163 -0
  74. tests/integration/test_resume.py +75 -0
  75. tests/integration/test_subgraph_integration.py +295 -0
  76. tests/integration/test_subgraph_interrupt.py +106 -0
  77. tests/unit/__init__.py +1 -0
  78. tests/unit/test_agent_nodes.py +355 -0
  79. tests/unit/test_async_executor.py +346 -0
  80. tests/unit/test_checkpointer.py +212 -0
  81. tests/unit/test_checkpointer_factory.py +212 -0
  82. tests/unit/test_cli.py +121 -0
  83. tests/unit/test_cli_package.py +81 -0
  84. tests/unit/test_compile_graph_map.py +132 -0
  85. tests/unit/test_conditions_routing.py +253 -0
  86. tests/unit/test_config.py +93 -0
  87. tests/unit/test_conversation_memory.py +276 -0
  88. tests/unit/test_database.py +145 -0
  89. tests/unit/test_deprecation.py +104 -0
  90. tests/unit/test_executor.py +172 -0
  91. tests/unit/test_executor_async.py +179 -0
  92. tests/unit/test_export.py +149 -0
  93. tests/unit/test_expressions.py +178 -0
  94. tests/unit/test_feature_brainstorm.py +194 -0
  95. tests/unit/test_format_prompt.py +145 -0
  96. tests/unit/test_generic_report.py +200 -0
  97. tests/unit/test_graph_commands.py +327 -0
  98. tests/unit/test_graph_linter.py +627 -0
  99. tests/unit/test_graph_loader.py +357 -0
  100. tests/unit/test_graph_schema.py +193 -0
  101. tests/unit/test_inline_schema.py +151 -0
  102. tests/unit/test_interrupt_node.py +182 -0
  103. tests/unit/test_issues.py +164 -0
  104. tests/unit/test_jinja2_prompts.py +85 -0
  105. tests/unit/test_json_extract.py +134 -0
  106. tests/unit/test_langsmith.py +600 -0
  107. tests/unit/test_langsmith_tools.py +204 -0
  108. tests/unit/test_llm_factory.py +109 -0
  109. tests/unit/test_llm_factory_async.py +118 -0
  110. tests/unit/test_loops.py +403 -0
  111. tests/unit/test_map_node.py +144 -0
  112. tests/unit/test_no_backward_compat.py +56 -0
  113. tests/unit/test_node_factory.py +348 -0
  114. tests/unit/test_passthrough_node.py +126 -0
  115. tests/unit/test_prompts.py +324 -0
  116. tests/unit/test_python_nodes.py +198 -0
  117. tests/unit/test_reliability.py +298 -0
  118. tests/unit/test_result_export.py +234 -0
  119. tests/unit/test_router.py +296 -0
  120. tests/unit/test_sanitize.py +99 -0
  121. tests/unit/test_schema_loader.py +295 -0
  122. tests/unit/test_shell_tools.py +229 -0
  123. tests/unit/test_state_builder.py +331 -0
  124. tests/unit/test_state_builder_map.py +104 -0
  125. tests/unit/test_state_config.py +197 -0
  126. tests/unit/test_streaming.py +307 -0
  127. tests/unit/test_subgraph.py +596 -0
  128. tests/unit/test_template.py +190 -0
  129. tests/unit/test_tool_call_integration.py +164 -0
  130. tests/unit/test_tool_call_node.py +178 -0
  131. tests/unit/test_tool_nodes.py +129 -0
  132. tests/unit/test_websearch.py +234 -0
  133. yamlgraph/__init__.py +35 -0
  134. yamlgraph/builder.py +110 -0
  135. yamlgraph/cli/__init__.py +159 -0
  136. yamlgraph/cli/__main__.py +6 -0
  137. yamlgraph/cli/commands.py +231 -0
  138. yamlgraph/cli/deprecation.py +92 -0
  139. yamlgraph/cli/graph_commands.py +541 -0
  140. yamlgraph/cli/validators.py +37 -0
  141. yamlgraph/config.py +67 -0
  142. yamlgraph/constants.py +70 -0
  143. yamlgraph/error_handlers.py +227 -0
  144. yamlgraph/executor.py +290 -0
  145. yamlgraph/executor_async.py +288 -0
  146. yamlgraph/graph_loader.py +451 -0
  147. yamlgraph/map_compiler.py +150 -0
  148. yamlgraph/models/__init__.py +36 -0
  149. yamlgraph/models/graph_schema.py +181 -0
  150. yamlgraph/models/schemas.py +124 -0
  151. yamlgraph/models/state_builder.py +236 -0
  152. yamlgraph/node_factory.py +768 -0
  153. yamlgraph/routing.py +87 -0
  154. yamlgraph/schema_loader.py +240 -0
  155. yamlgraph/storage/__init__.py +20 -0
  156. yamlgraph/storage/checkpointer.py +72 -0
  157. yamlgraph/storage/checkpointer_factory.py +123 -0
  158. yamlgraph/storage/database.py +320 -0
  159. yamlgraph/storage/export.py +269 -0
  160. yamlgraph/tools/__init__.py +1 -0
  161. yamlgraph/tools/agent.py +320 -0
  162. yamlgraph/tools/graph_linter.py +388 -0
  163. yamlgraph/tools/langsmith_tools.py +125 -0
  164. yamlgraph/tools/nodes.py +126 -0
  165. yamlgraph/tools/python_tool.py +179 -0
  166. yamlgraph/tools/shell.py +205 -0
  167. yamlgraph/tools/websearch.py +242 -0
  168. yamlgraph/utils/__init__.py +48 -0
  169. yamlgraph/utils/conditions.py +157 -0
  170. yamlgraph/utils/expressions.py +245 -0
  171. yamlgraph/utils/json_extract.py +104 -0
  172. yamlgraph/utils/langsmith.py +416 -0
  173. yamlgraph/utils/llm_factory.py +118 -0
  174. yamlgraph/utils/llm_factory_async.py +105 -0
  175. yamlgraph/utils/logging.py +104 -0
  176. yamlgraph/utils/prompts.py +171 -0
  177. yamlgraph/utils/sanitize.py +98 -0
  178. yamlgraph/utils/template.py +102 -0
  179. yamlgraph/utils/validators.py +181 -0
  180. yamlgraph-0.3.9.dist-info/METADATA +1105 -0
  181. yamlgraph-0.3.9.dist-info/RECORD +185 -0
  182. yamlgraph-0.3.9.dist-info/WHEEL +5 -0
  183. yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
  184. yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
  185. yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
@@ -0,0 +1,181 @@
+ """Pydantic schemas for YAML graph configuration validation.
+
+ Provides structured validation for graph YAML files with clear error messages.
+ """
+
+ from typing import Any, Literal
+
+ from pydantic import BaseModel, Field, field_validator, model_validator
+
+ from yamlgraph.constants import ErrorHandler, NodeType
+
+
+ class SubgraphNodeConfig(BaseModel):
+     """Configuration for a subgraph node."""
+
+     type: Literal["subgraph"]
+     graph: str = Field(
+         ..., description="Path to subgraph YAML file (relative to parent)"
+     )
+     mode: Literal["invoke", "direct"] = Field(
+         default="invoke",
+         description="invoke: explicit state mapping; direct: shared schema",
+     )
+     input_mapping: dict[str, str] | Literal["auto", "*"] = Field(
+         default_factory=dict,
+         description="Map parent state fields to child input (mode=invoke only)",
+     )
+     output_mapping: dict[str, str] | Literal["auto", "*"] = Field(
+         default_factory=dict,
+         description="Map child output fields to parent state (mode=invoke only)",
+     )
+     interrupt_output_mapping: dict[str, str] | Literal["auto", "*"] = Field(
+         default_factory=dict,
+         description="Map child state to parent when subgraph interrupts (FR-006)",
+     )
+     checkpointer: str | None = Field(
+         default=None,
+         description="Override parent checkpointer",
+     )
+
+     model_config = {"extra": "allow"}
+
+     @model_validator(mode="after")
+     def validate_config(self) -> "SubgraphNodeConfig":
+         """Validate subgraph configuration."""
+         if not self.graph.endswith((".yaml", ".yml")):
+             raise ValueError(f"Subgraph must be a YAML file: {self.graph}")
+         if self.mode == "direct" and (self.input_mapping or self.output_mapping):
+             raise ValueError("mode=direct does not support input/output mappings")
+         return self
+
+
+ class NodeConfig(BaseModel):
+     """Configuration for a single graph node."""
+
+     type: str = Field(default=NodeType.LLM, description="Node type")
+     prompt: str | None = Field(default=None, description="Prompt template name")
+     state_key: str | None = Field(default=None, description="State key for output")
+     temperature: float | None = Field(default=None, ge=0, le=2)
+     provider: str | None = Field(default=None)
+     on_error: str | None = Field(default=None)
+     fallback: dict[str, Any] | None = Field(default=None)
+     variables: dict[str, str] = Field(default_factory=dict)
+     requires: list[str] = Field(default_factory=list)
+     routes: dict[str, str] | None = Field(default=None, description="Router routes")
+
+     # Map node fields
+     over: str | None = Field(default=None, description="Map over expression")
+     # 'as' is reserved in Python, handled specially
+     item_var: str | None = Field(default=None, alias="as")
+     node: dict[str, Any] | None = Field(default=None, description="Map sub-node")
+     collect: str | None = Field(default=None, description="Map collect key")
+
+     # Tool/Agent fields
+     tools: list[str] = Field(default_factory=list)
+     max_iterations: int = Field(default=10, ge=1)
+
+     model_config = {"extra": "allow", "populate_by_name": True}
+
+     @field_validator("on_error")
+     @classmethod
+     def validate_on_error(cls, v: str | None) -> str | None:
+         """Validate on_error is a known handler."""
+         if v is not None and v not in ErrorHandler.all_values():
+             valid = ", ".join(ErrorHandler.all_values())
+             raise ValueError(f"Invalid on_error '{v}'. Valid: {valid}")
+         return v
+
+     @model_validator(mode="after")
+     def validate_node_requirements(self) -> "NodeConfig":
+         """Validate node has required fields based on type."""
+         if NodeType.requires_prompt(self.type) and not self.prompt:
+             raise ValueError(f"Node type '{self.type}' requires 'prompt' field")
+
+         if self.type == NodeType.ROUTER and not self.routes:
+             raise ValueError("Router node requires 'routes' field")
+
+         if self.type == NodeType.MAP:
+             if not self.over:
+                 raise ValueError("Map node requires 'over' field")
+             if not self.item_var:
+                 raise ValueError("Map node requires 'as' field")
+             if not self.node:
+                 raise ValueError("Map node requires 'node' field")
+             if not self.collect:
+                 raise ValueError("Map node requires 'collect' field")
+
+         return self
+
+
+ class EdgeConfig(BaseModel):
+     """Configuration for a graph edge."""
+
+     from_node: str = Field(..., alias="from", description="Source node")
+     to: str | list[str] = Field(..., description="Target node(s)")
+     condition: str | None = Field(default=None, description="Condition expression")
+
+     model_config = {"populate_by_name": True}
+
+
+ class GraphConfigSchema(BaseModel):
+     """Full YAML graph configuration schema.
+
+     Use this for validating graph YAML files with Pydantic.
+     """
+
+     version: str = Field(default="1.0")
+     name: str = Field(default="unnamed")
+     description: str = Field(default="")
+     defaults: dict[str, Any] = Field(default_factory=dict)
+     nodes: dict[str, NodeConfig] = Field(...)
+     edges: list[EdgeConfig] = Field(...)
+     tools: dict[str, Any] = Field(default_factory=dict)
+     state_class: str = Field(default="")
+     loop_limits: dict[str, int] = Field(default_factory=dict)
+
+     model_config = {"extra": "allow"}
+
+     @model_validator(mode="after")
+     def validate_router_targets(self) -> "GraphConfigSchema":
+         """Validate router routes point to existing nodes."""
+         for node_name, node in self.nodes.items():
+             if node.type == NodeType.ROUTER and node.routes:
+                 for route_key, target in node.routes.items():
+                     if target not in self.nodes:
+                         raise ValueError(
+                             f"Router '{node_name}' route '{route_key}' "
+                             f"targets nonexistent node '{target}'"
+                         )
+         return self
+
+     @model_validator(mode="after")
+     def validate_edge_nodes(self) -> "GraphConfigSchema":
+         """Validate edge sources and targets exist."""
+         valid_nodes = set(self.nodes.keys()) | {"START", "END"}
+
+         for edge in self.edges:
+             if edge.from_node not in valid_nodes:
+                 raise ValueError(f"Edge 'from' node '{edge.from_node}' not found")
+
+             targets = edge.to if isinstance(edge.to, list) else [edge.to]
+             for target in targets:
+                 if target not in valid_nodes:
+                     raise ValueError(f"Edge 'to' node '{target}' not found")
+
+         return self
+
+
+ def validate_graph_schema(config: dict[str, Any]) -> GraphConfigSchema:
+     """Validate a graph configuration dict using Pydantic.
+
+     Args:
+         config: Raw parsed YAML configuration
+
+     Returns:
+         Validated GraphConfigSchema
+
+     Raises:
+         pydantic.ValidationError: If validation fails
+     """
+     return GraphConfigSchema.model_validate(config)
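Below is a minimal sketch of how the schemas in this hunk might be exercised. It is illustrative only: the import path yamlgraph.models.graph_schema is assumed from the file list, and the node and edge values are invented; only validate_graph_schema, GraphConfigSchema, NodeConfig, and EdgeConfig come from the code above.

# Hypothetical config dict, as if parsed from a graph YAML file.
from yamlgraph.models.graph_schema import validate_graph_schema

config = {
    "name": "demo",
    "nodes": {
        "greet": {"type": "llm", "prompt": "greeting", "state_key": "greeting"},
    },
    "edges": [
        {"from": "START", "to": "greet"},
        {"from": "greet", "to": "END"},
    ],
}

graph = validate_graph_schema(config)   # returns a validated GraphConfigSchema
print(graph.nodes["greet"].prompt)      # -> "greeting"

# An edge pointing at a node that does not exist fails validation:
bad = {**config, "edges": [{"from": "START", "to": "missing"}]}
try:
    validate_graph_schema(bad)
except Exception as err:                # pydantic.ValidationError
    print(err)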
@@ -0,0 +1,124 @@
+ """Pydantic models for structured LLM outputs.
+
+ This module contains FRAMEWORK models only - models used by the framework itself.
+ Demo-specific output schemas are defined inline in graph YAML files.
+ """
+
+ from datetime import datetime
+ from enum import Enum
+ from typing import Any
+
+ from pydantic import BaseModel, Field
+
+ # =============================================================================
+ # Error Types
+ # =============================================================================
+
+
+ class ErrorType(str, Enum):
+     """Types of errors that can occur in the pipeline."""
+
+     LLM_ERROR = "llm_error"  # LLM API errors (rate limit, timeout, etc.)
+     VALIDATION_ERROR = "validation_error"  # Pydantic validation failures
+     PROMPT_ERROR = "prompt_error"  # Missing prompt, template errors
+     STATE_ERROR = "state_error"  # Missing required state data
+     UNKNOWN_ERROR = "unknown_error"  # Catch-all
+
+
+ class PipelineError(BaseModel):
+     """Structured error information for pipeline failures."""
+
+     type: ErrorType = Field(description="Category of error")
+     message: str = Field(description="Human-readable error message")
+     node: str = Field(description="Node where error occurred")
+     timestamp: datetime = Field(default_factory=datetime.now)
+     retryable: bool = Field(
+         default=False, description="Whether this error can be retried"
+     )
+     details: dict[str, Any] = Field(
+         default_factory=dict, description="Additional error context"
+     )
+
+     @classmethod
+     def from_exception(
+         cls, e: Exception, node: str, error_type: ErrorType | None = None
+     ) -> "PipelineError":
+         """Create a PipelineError from an exception.
+
+         Args:
+             e: The exception that occurred
+             node: Name of the node where error occurred
+             error_type: Optional explicit error type
+
+         Returns:
+             PipelineError instance
+         """
+         # Infer error type from exception
+         if error_type is None:
+             exc_name = type(e).__name__.lower()
+             if "rate" in exc_name or "timeout" in exc_name or "api" in exc_name:
+                 error_type = ErrorType.LLM_ERROR
+                 retryable = True
+             elif "validation" in exc_name:
+                 error_type = ErrorType.VALIDATION_ERROR
+                 retryable = False
+             elif "file" in exc_name or "prompt" in exc_name:
+                 error_type = ErrorType.PROMPT_ERROR
+                 retryable = False
+             else:
+                 error_type = ErrorType.UNKNOWN_ERROR
+                 retryable = False
+         else:
+             retryable = error_type == ErrorType.LLM_ERROR
+
+         return cls(
+             type=error_type,
+             message=str(e),
+             node=node,
+             retryable=retryable,
+             details={"exception_type": type(e).__name__},
+         )
+
+
+ # =============================================================================
+ # Generic Report Model (Flexible for Any Use Case)
+ # =============================================================================
+
+
+ class GenericReport(BaseModel):
+     """Flexible report structure for any use case.
+
+     Use this when you don't need a custom schema - works for most
+     analysis and summary tasks. The LLM can populate any combination
+     of the optional fields as needed.
+
+     Example usage in graph YAML:
+         nodes:
+           analyze:
+             type: llm
+             prompt: my_analysis
+             output_model: yamlgraph.models.GenericReport
+
+     Example prompts can request specific sections:
+         "Analyze the repository and provide:
+         - A summary of findings
+         - Key findings as bullet points
+         - Recommendations for improvement"
+     """
+
+     title: str = Field(description="Report title")
+     summary: str = Field(description="Executive summary")
+     sections: dict[str, Any] = Field(
+         default_factory=dict,
+         description="Named sections with any content (strings, dicts, lists)",
+     )
+     findings: list[str] = Field(
+         default_factory=list, description="Key findings or bullet points"
+     )
+     recommendations: list[str] = Field(
+         default_factory=list, description="Suggested actions or areas to focus on"
+     )
+     metadata: dict[str, Any] = Field(
+         default_factory=dict,
+         description="Additional key-value data (author, version, tags, etc.)",
+     )
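A short sketch of how these framework models might be used downstream. The import path yamlgraph.models.schemas is assumed from the file list, and the report contents are made up for illustration.

from yamlgraph.models.schemas import ErrorType, GenericReport, PipelineError

# Error type is inferred from the exception class name:
# "TimeoutError" contains "timeout", so this maps to LLM_ERROR and is retryable.
err = PipelineError.from_exception(TimeoutError("provider timed out"), node="generate")
assert err.type == ErrorType.LLM_ERROR and err.retryable

# A generic report can be built (or LLM-populated) without a custom schema.
report = GenericReport(
    title="Repo review",
    summary="Small, well-tested codebase.",
    findings=["Good test coverage", "Docs are thin"],
    recommendations=["Expand the README"],
)
print(report.model_dump_json(indent=2))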
@@ -0,0 +1,236 @@
+ """Dynamic state class generation from graph configuration.
+
+ Builds TypedDict programmatically from YAML graph config, eliminating
+ the need for state_class coupling between YAML and Python.
+ """
+
+ import logging
+ from operator import add
+ from typing import Annotated, Any, TypedDict
+
+ logger = logging.getLogger(__name__)
+
+
+ def sorted_add(existing: list, new: list) -> list:
+     """Reducer that adds items and sorts by _map_index if present.
+
+     Used for map node fan-in to guarantee order regardless of
+     parallel execution timing.
+
+     Args:
+         existing: Current list in state
+         new: New items to add
+
+     Returns:
+         Combined list sorted by _map_index (if items have it)
+     """
+     combined = (existing or []) + (new or [])
+
+     # Sort by _map_index if items have it
+     if combined and isinstance(combined[0], dict) and "_map_index" in combined[0]:
+         combined = sorted(combined, key=lambda x: x.get("_map_index", 0))
+
+     return combined
+
+
+ # =============================================================================
+ # Base Fields - Always included in generated state
+ # =============================================================================
+
+ # Infrastructure fields present in all graphs
+ BASE_FIELDS: dict[str, type] = {
+     # Core tracking
+     "thread_id": str,
+     "current_step": str,
+     # Error handling - singular for current error
+     "error": Any,
+     # Error handling with reducer (accumulates)
+     "errors": Annotated[list, add],
+     # Messages with reducer (accumulates)
+     "messages": Annotated[list, add],
+     # Loop tracking
+     "_loop_counts": dict,
+     "_loop_limit_reached": bool,
+     "_agent_iterations": int,
+     "_agent_limit_reached": bool,
+     # Timestamps
+     "started_at": Any,
+     "completed_at": Any,
+ }
+
+ # Common input fields used across graph types
+ # These are always included to support --var inputs
+ COMMON_INPUT_FIELDS: dict[str, type] = {
+     "input": str,  # Agent prompt input
+     "topic": str,  # Content generation topic
+     "style": str,  # Writing style
+     "word_count": int,  # Target word count
+     "message": str,  # Router message input
+ }
+
+ # Type mapping for YAML state config
+ TYPE_MAP: dict[str, type] = {
+     "str": str,
+     "string": str,
+     "int": int,
+     "integer": int,
+     "float": float,
+     "bool": bool,
+     "boolean": bool,
+     "list": list,
+     "dict": dict,
+     "any": Any,
+ }
+
+
+ def parse_state_config(state_config: dict) -> dict[str, type]:
+     """Parse YAML state section into field types.
+
+     Supports simple type strings:
+         state:
+           concept: str
+           count: int
+
+     Args:
+         state_config: Dict from YAML 'state' section
+
+     Returns:
+         Dict of field_name -> Python type
+     """
+     fields: dict[str, type] = {}
+
+     for field_name, type_spec in state_config.items():
+         if isinstance(type_spec, str):
+             # Simple type: "str", "int", etc.
+             normalized = type_spec.lower()
+             if normalized not in TYPE_MAP:
+                 supported = ", ".join(sorted(set(TYPE_MAP.keys())))
+                 logger.warning(
+                     f"Unknown type '{type_spec}' for state field '{field_name}'. "
+                     f"Supported types: {supported}. Defaulting to Any."
+                 )
+             python_type = TYPE_MAP.get(normalized, Any)
+             fields[field_name] = python_type
+         else:
+             # Unknown format, use Any
+             logger.warning(
+                 f"Invalid type specification for state field '{field_name}': "
+                 f"expected string, got {type(type_spec).__name__}. Defaulting to Any."
+             )
+             fields[field_name] = Any
+
+     return fields
+
+
+ def build_state_class(config: dict) -> type:
+     """Build TypedDict state class from graph configuration.
+
+     Dynamically generates a TypedDict with:
+     - Base infrastructure fields (errors, messages, thread_id, etc.)
+     - Common input fields (topic, style, input, message, etc.)
+     - Custom fields from YAML 'state' section
+     - Fields extracted from node state_key
+     - Special fields for agent/router node types
+
+     Args:
+         config: Parsed YAML graph configuration dict
+
+     Returns:
+         TypedDict class with total=False (all fields optional)
+     """
+     # Start with base and common fields
+     fields: dict[str, type] = {}
+     fields.update(BASE_FIELDS)
+     fields.update(COMMON_INPUT_FIELDS)
+
+     # Add custom state fields from YAML 'state' section
+     state_config = config.get("state", {})
+     custom_fields = parse_state_config(state_config)
+     fields.update(custom_fields)
+
+     # Extract fields from nodes
+     nodes = config.get("nodes", {})
+     node_fields = extract_node_fields(nodes)
+     fields.update(node_fields)
+
+     # Build TypedDict programmatically
+     return TypedDict("GraphState", fields, total=False)
+
+
+ def extract_node_fields(nodes: dict) -> dict[str, type]:
+     """Extract state fields from node configurations.
+
+     Analyzes node configs to determine required state fields:
+     - state_key: Where node stores its output
+     - type: agent → adds input, _tool_results
+     - type: router → adds _route
+
+     Args:
+         nodes: Dict of node_name -> node_config
+
+     Returns:
+         Dict of field_name -> type for the state
+     """
+     fields: dict[str, type] = {}
+
+     for _node_name, node_config in nodes.items():
+         if not isinstance(node_config, dict):
+             continue
+
+         # state_key → Any (accepts Pydantic models)
+         if state_key := node_config.get("state_key"):
+             fields[state_key] = Any
+
+         # Node type-specific fields
+         node_type = node_config.get("type", "llm")
+
+         if node_type == "agent":
+             fields["input"] = str
+             fields["_tool_results"] = list
+
+         elif node_type == "router":
+             fields["_route"] = str
+
+         elif node_type == "map":
+             # Map node collect field needs sorted reducer for ordered fan-in
+             if collect_key := node_config.get("collect"):
+                 fields[collect_key] = Annotated[list, sorted_add]
+
+     return fields
+
+
+ def create_initial_state(
+     topic: str = "",
+     style: str = "informative",
+     word_count: int = 300,
+     thread_id: str | None = None,
+     **kwargs: Any,
+ ) -> dict[str, Any]:
+     """Create an initial state for a new pipeline run.
+
+     Args:
+         topic: The topic to generate content about
+         style: Writing style (default: informative)
+         word_count: Target word count (default: 300)
+         thread_id: Optional thread ID (auto-generated if not provided)
+         **kwargs: Additional state fields (e.g., input for agents)
+
+     Returns:
+         Initialized state dictionary
+     """
+     import uuid
+     from datetime import datetime
+
+     return {
+         "thread_id": thread_id or uuid.uuid4().hex[:16],
+         "topic": topic,
+         "style": style,
+         "word_count": word_count,
+         "current_step": "init",
+         "error": None,
+         "errors": [],
+         "messages": [],
+         "started_at": datetime.now(),
+         "completed_at": None,
+         **kwargs,
+     }
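A small illustrative sketch of the dynamic state construction above. The config dict is invented (including the "over"/"as" values), and the import path yamlgraph.models.state_builder is assumed from the file list; only build_state_class, create_initial_state, and sorted_add come from the code in this hunk.

from yamlgraph.models.state_builder import build_state_class, create_initial_state

config = {
    "state": {"concept": "str"},  # custom field from the YAML 'state' section
    "nodes": {
        "brainstorm": {"type": "llm", "prompt": "ideas", "state_key": "ideas"},
        "expand": {
            "type": "map",
            "over": "ideas",
            "as": "idea",
            "collect": "drafts",
            "node": {"type": "llm", "prompt": "draft"},
        },
    },
}

GraphState = build_state_class(config)

# Base fields, common inputs, YAML state fields, node state_key fields, and the
# map collect key all end up in the generated TypedDict.
anns = GraphState.__annotations__
assert "thread_id" in anns and "concept" in anns
assert "ideas" in anns and "drafts" in anns  # "drafts" gets the sorted_add reducer

state = create_initial_state(topic="underwater caves", style="playful")
print(state["thread_id"], state["current_step"])  # 16-char hex id, "init"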