yamlgraph 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/__init__.py +1 -0
- examples/codegen/__init__.py +5 -0
- examples/codegen/models/__init__.py +13 -0
- examples/codegen/models/schemas.py +76 -0
- examples/codegen/tests/__init__.py +1 -0
- examples/codegen/tests/test_ai_helpers.py +235 -0
- examples/codegen/tests/test_ast_analysis.py +174 -0
- examples/codegen/tests/test_code_analysis.py +134 -0
- examples/codegen/tests/test_code_context.py +301 -0
- examples/codegen/tests/test_code_nav.py +89 -0
- examples/codegen/tests/test_dependency_tools.py +119 -0
- examples/codegen/tests/test_example_tools.py +185 -0
- examples/codegen/tests/test_git_tools.py +112 -0
- examples/codegen/tests/test_impl_agent_schemas.py +193 -0
- examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
- examples/codegen/tests/test_jedi_analysis.py +226 -0
- examples/codegen/tests/test_meta_tools.py +250 -0
- examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
- examples/codegen/tests/test_syntax_tools.py +85 -0
- examples/codegen/tests/test_synthesize_prompt.py +94 -0
- examples/codegen/tests/test_template_tools.py +244 -0
- examples/codegen/tools/__init__.py +80 -0
- examples/codegen/tools/ai_helpers.py +420 -0
- examples/codegen/tools/ast_analysis.py +92 -0
- examples/codegen/tools/code_context.py +180 -0
- examples/codegen/tools/code_nav.py +52 -0
- examples/codegen/tools/dependency_tools.py +120 -0
- examples/codegen/tools/example_tools.py +188 -0
- examples/codegen/tools/git_tools.py +151 -0
- examples/codegen/tools/impl_executor.py +614 -0
- examples/codegen/tools/jedi_analysis.py +311 -0
- examples/codegen/tools/meta_tools.py +202 -0
- examples/codegen/tools/syntax_tools.py +26 -0
- examples/codegen/tools/template_tools.py +356 -0
- examples/fastapi_interview.py +167 -0
- examples/npc/api/__init__.py +1 -0
- examples/npc/api/app.py +100 -0
- examples/npc/api/routes/__init__.py +5 -0
- examples/npc/api/routes/encounter.py +182 -0
- examples/npc/api/session.py +330 -0
- examples/npc/demo.py +387 -0
- examples/npc/nodes/__init__.py +5 -0
- examples/npc/nodes/image_node.py +92 -0
- examples/npc/run_encounter.py +230 -0
- examples/shared/__init__.py +0 -0
- examples/shared/replicate_tool.py +238 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +12 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +49 -0
- examples/storyboard/retry_images.py +118 -0
- scripts/demo_async_executor.py +212 -0
- scripts/demo_interview_e2e.py +200 -0
- scripts/demo_streaming.py +140 -0
- scripts/run_interview_demo.py +94 -0
- scripts/test_interrupt_fix.py +26 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_colocated_prompts.py +139 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +283 -0
- tests/integration/test_npc_api/__init__.py +1 -0
- tests/integration/test_npc_api/test_routes.py +357 -0
- tests/integration/test_npc_api/test_session.py +216 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/integration/test_subgraph_integration.py +295 -0
- tests/integration/test_subgraph_interrupt.py +106 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +355 -0
- tests/unit/test_async_executor.py +346 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_checkpointer_factory.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +276 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +172 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +149 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_feature_brainstorm.py +194 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_linter.py +627 -0
- tests/unit/test_graph_loader.py +357 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_interrupt_node.py +182 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_json_extract.py +134 -0
- tests/unit/test_langsmith.py +600 -0
- tests/unit/test_langsmith_tools.py +204 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +348 -0
- tests/unit/test_passthrough_node.py +126 -0
- tests/unit/test_prompts.py +324 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_streaming.py +307 -0
- tests/unit/test_subgraph.py +596 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_call_integration.py +164 -0
- tests/unit/test_tool_call_node.py +178 -0
- tests/unit/test_tool_nodes.py +129 -0
- tests/unit/test_websearch.py +234 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +159 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +231 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +541 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +70 -0
- yamlgraph/error_handlers.py +227 -0
- yamlgraph/executor.py +290 -0
- yamlgraph/executor_async.py +288 -0
- yamlgraph/graph_loader.py +451 -0
- yamlgraph/map_compiler.py +150 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +181 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +768 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +240 -0
- yamlgraph/storage/__init__.py +20 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/checkpointer_factory.py +123 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +320 -0
- yamlgraph/tools/graph_linter.py +388 -0
- yamlgraph/tools/langsmith_tools.py +125 -0
- yamlgraph/tools/nodes.py +126 -0
- yamlgraph/tools/python_tool.py +179 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/tools/websearch.py +242 -0
- yamlgraph/utils/__init__.py +48 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +245 -0
- yamlgraph/utils/json_extract.py +104 -0
- yamlgraph/utils/langsmith.py +416 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +104 -0
- yamlgraph/utils/prompts.py +171 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.3.9.dist-info/METADATA +1105 -0
- yamlgraph-0.3.9.dist-info/RECORD +185 -0
- yamlgraph-0.3.9.dist-info/WHEEL +5 -0
- yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
- yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
- yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
"""Pydantic schemas for YAML graph configuration validation.
|
|
2
|
+
|
|
3
|
+
Provides structured validation for graph YAML files with clear error messages.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Any, Literal
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
9
|
+
|
|
10
|
+
from yamlgraph.constants import ErrorHandler, NodeType
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SubgraphNodeConfig(BaseModel):
    """Configuration for a subgraph node.

    A subgraph node embeds another YAML graph as a single node of the
    parent graph. In ``invoke`` mode the parent passes state through the
    explicit input/output mappings below; in ``direct`` mode parent and
    child share a state schema and mappings are rejected (enforced by
    ``validate_config``).
    """

    type: Literal["subgraph"]
    # Must point at a .yaml/.yml file; checked in validate_config.
    graph: str = Field(
        ..., description="Path to subgraph YAML file (relative to parent)"
    )
    mode: Literal["invoke", "direct"] = Field(
        default="invoke",
        description="invoke: explicit state mapping; direct: shared schema",
    )
    # "auto"/"*" request automatic mapping; a dict maps fields explicitly.
    input_mapping: dict[str, str] | Literal["auto", "*"] = Field(
        default_factory=dict,
        description="Map parent state fields to child input (mode=invoke only)",
    )
    output_mapping: dict[str, str] | Literal["auto", "*"] = Field(
        default_factory=dict,
        description="Map child output fields to parent state (mode=invoke only)",
    )
    interrupt_output_mapping: dict[str, str] | Literal["auto", "*"] = Field(
        default_factory=dict,
        description="Map child state to parent when subgraph interrupts (FR-006)",
    )
    checkpointer: str | None = Field(
        default=None,
        description="Override parent checkpointer",
    )

    # Unknown keys are kept so node configs can carry extra metadata.
    model_config = {"extra": "allow"}

    @model_validator(mode="after")
    def validate_config(self) -> "SubgraphNodeConfig":
        """Validate subgraph configuration.

        Raises:
            ValueError: If ``graph`` is not a YAML file, or if input/output
                mappings are supplied while ``mode`` is ``direct``.
        """
        if not self.graph.endswith((".yaml", ".yml")):
            raise ValueError(f"Subgraph must be a YAML file: {self.graph}")
        if self.mode == "direct" and (self.input_mapping or self.output_mapping):
            raise ValueError("mode=direct does not support input/output mappings")
        return self
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class NodeConfig(BaseModel):
    """Configuration for a single graph node.

    Which fields are required depends on ``type`` (see
    ``validate_node_requirements``): prompt-driven node types need
    ``prompt``, router nodes need ``routes``, and map nodes need
    ``over``/``as``/``node``/``collect``.
    """

    type: str = Field(default=NodeType.LLM, description="Node type")
    prompt: str | None = Field(default=None, description="Prompt template name")
    state_key: str | None = Field(default=None, description="State key for output")
    # Temperature is constrained to the usual LLM API range [0, 2].
    temperature: float | None = Field(default=None, ge=0, le=2)
    provider: str | None = Field(default=None)
    # Error handler name; validated against ErrorHandler.all_values().
    on_error: str | None = Field(default=None)
    fallback: dict[str, Any] | None = Field(default=None)
    variables: dict[str, str] = Field(default_factory=dict)
    requires: list[str] = Field(default_factory=list)
    routes: dict[str, str] | None = Field(default=None, description="Router routes")

    # Map node fields
    over: str | None = Field(default=None, description="Map over expression")
    # 'as' is reserved in Python, handled specially
    item_var: str | None = Field(default=None, alias="as")
    node: dict[str, Any] | None = Field(default=None, description="Map sub-node")
    collect: str | None = Field(default=None, description="Map collect key")

    # Tool/Agent fields
    tools: list[str] = Field(default_factory=list)
    max_iterations: int = Field(default=10, ge=1)

    # extra=allow keeps unknown keys; populate_by_name accepts 'item_var'
    # in addition to its YAML alias 'as'.
    model_config = {"extra": "allow", "populate_by_name": True}

    @field_validator("on_error")
    @classmethod
    def validate_on_error(cls, v: str | None) -> str | None:
        """Validate on_error is a known handler.

        Raises:
            ValueError: If ``v`` is not one of ``ErrorHandler.all_values()``.
        """
        if v is not None and v not in ErrorHandler.all_values():
            valid = ", ".join(ErrorHandler.all_values())
            raise ValueError(f"Invalid on_error '{v}'. Valid: {valid}")
        return v

    @model_validator(mode="after")
    def validate_node_requirements(self) -> "NodeConfig":
        """Validate node has required fields based on type.

        Raises:
            ValueError: If a field required by the node's type is missing.
        """
        if NodeType.requires_prompt(self.type) and not self.prompt:
            raise ValueError(f"Node type '{self.type}' requires 'prompt' field")

        if self.type == NodeType.ROUTER and not self.routes:
            raise ValueError("Router node requires 'routes' field")

        if self.type == NodeType.MAP:
            if not self.over:
                raise ValueError("Map node requires 'over' field")
            if not self.item_var:
                raise ValueError("Map node requires 'as' field")
            if not self.node:
                raise ValueError("Map node requires 'node' field")
            if not self.collect:
                raise ValueError("Map node requires 'collect' field")

        return self
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class EdgeConfig(BaseModel):
    """Configuration for a graph edge.

    The YAML key is ``from`` (a Python keyword), so it is exposed here as
    ``from_node`` via a field alias.
    """

    from_node: str = Field(..., alias="from", description="Source node")
    # 'to' may be a single node name or a list for fan-out edges.
    to: str | list[str] = Field(..., description="Target node(s)")
    condition: str | None = Field(default=None, description="Condition expression")

    # Allow construction with either 'from' (alias) or 'from_node'.
    model_config = {"populate_by_name": True}
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class GraphConfigSchema(BaseModel):
    """Full YAML graph configuration schema.

    Use this for validating graph YAML files with Pydantic. Beyond field
    typing, two model validators enforce referential integrity: router
    routes must target declared nodes, and edge endpoints must be declared
    nodes or the virtual START/END markers.
    """

    version: str = Field(default="1.0")
    name: str = Field(default="unnamed")
    description: str = Field(default="")
    defaults: dict[str, Any] = Field(default_factory=dict)
    # nodes and edges are the only required top-level sections.
    nodes: dict[str, NodeConfig] = Field(...)
    edges: list[EdgeConfig] = Field(...)
    tools: dict[str, Any] = Field(default_factory=dict)
    state_class: str = Field(default="")
    loop_limits: dict[str, int] = Field(default_factory=dict)

    # Unknown top-level keys (e.g. a 'state' section) are preserved.
    model_config = {"extra": "allow"}

    @model_validator(mode="after")
    def validate_router_targets(self) -> "GraphConfigSchema":
        """Validate router routes point to existing nodes.

        Raises:
            ValueError: If any route targets an undeclared node.
        """
        for node_name, node in self.nodes.items():
            if node.type == NodeType.ROUTER and node.routes:
                for route_key, target in node.routes.items():
                    if target not in self.nodes:
                        raise ValueError(
                            f"Router '{node_name}' route '{route_key}' "
                            f"targets nonexistent node '{target}'"
                        )
        return self

    @model_validator(mode="after")
    def validate_edge_nodes(self) -> "GraphConfigSchema":
        """Validate edge sources and targets exist.

        START and END are implicit virtual nodes and always valid.

        Raises:
            ValueError: If an edge references an undeclared node.
        """
        valid_nodes = set(self.nodes.keys()) | {"START", "END"}

        for edge in self.edges:
            if edge.from_node not in valid_nodes:
                raise ValueError(f"Edge 'from' node '{edge.from_node}' not found")

            # 'to' may be a single node name or a fan-out list.
            targets = edge.to if isinstance(edge.to, list) else [edge.to]
            for target in targets:
                if target not in valid_nodes:
                    raise ValueError(f"Edge 'to' node '{target}' not found")

        return self
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def validate_graph_schema(config: dict[str, Any]) -> GraphConfigSchema:
    """Parse and validate a raw graph configuration dictionary.

    Thin convenience wrapper around ``GraphConfigSchema.model_validate``
    so callers don't need to import the schema class directly.

    Args:
        config: Graph configuration as parsed from YAML.

    Returns:
        A fully validated ``GraphConfigSchema`` instance.

    Raises:
        pydantic.ValidationError: If the configuration is invalid.
    """
    return GraphConfigSchema.model_validate(config)
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"""Pydantic models for structured LLM outputs.
|
|
2
|
+
|
|
3
|
+
This module contains FRAMEWORK models only - models used by the framework itself.
|
|
4
|
+
Demo-specific output schemas are defined inline in graph YAML files.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from pydantic import BaseModel, Field
|
|
12
|
+
|
|
13
|
+
# =============================================================================
|
|
14
|
+
# Error Types
|
|
15
|
+
# =============================================================================
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ErrorType(str, Enum):
    """Types of errors that can occur in the pipeline.

    Inherits from ``str`` so members compare equal to their string values
    and serialize cleanly into JSON/state dicts.
    """

    LLM_ERROR = "llm_error"  # LLM API errors (rate limit, timeout, etc.)
    VALIDATION_ERROR = "validation_error"  # Pydantic validation failures
    PROMPT_ERROR = "prompt_error"  # Missing prompt, template errors
    STATE_ERROR = "state_error"  # Missing required state data
    UNKNOWN_ERROR = "unknown_error"  # Catch-all
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class PipelineError(BaseModel):
    """Structured error record describing a pipeline failure."""

    type: ErrorType = Field(description="Category of error")
    message: str = Field(description="Human-readable error message")
    node: str = Field(description="Node where error occurred")
    timestamp: datetime = Field(default_factory=datetime.now)
    retryable: bool = Field(
        default=False, description="Whether this error can be retried"
    )
    details: dict[str, Any] = Field(
        default_factory=dict, description="Additional error context"
    )

    @classmethod
    def from_exception(
        cls, e: Exception, node: str, error_type: ErrorType | None = None
    ) -> "PipelineError":
        """Build a PipelineError from a raised exception.

        When ``error_type`` is not supplied it is inferred from the
        exception class name: rate/timeout/api -> LLM_ERROR,
        validation -> VALIDATION_ERROR, file/prompt -> PROMPT_ERROR,
        anything else -> UNKNOWN_ERROR.

        Args:
            e: The exception that occurred.
            node: Name of the node where the error occurred.
            error_type: Optional explicit error category.

        Returns:
            A populated PipelineError instance.
        """
        if error_type is None:
            lowered = type(e).__name__.lower()
            if any(token in lowered for token in ("rate", "timeout", "api")):
                error_type = ErrorType.LLM_ERROR
            elif "validation" in lowered:
                error_type = ErrorType.VALIDATION_ERROR
            elif "file" in lowered or "prompt" in lowered:
                error_type = ErrorType.PROMPT_ERROR
            else:
                error_type = ErrorType.UNKNOWN_ERROR

        # Only LLM-side failures (rate limits, timeouts, API hiccups)
        # are considered transient and therefore retryable.
        return cls(
            type=error_type,
            message=str(e),
            node=node,
            retryable=error_type == ErrorType.LLM_ERROR,
            details={"exception_type": type(e).__name__},
        )
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
# =============================================================================
|
|
84
|
+
# Generic Report Model (Flexible for Any Use Case)
|
|
85
|
+
# =============================================================================
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class GenericReport(BaseModel):
    """Flexible report structure for any use case.

    Use this when you don't need a custom schema - works for most
    analysis and summary tasks. The LLM can populate any combination
    of the optional fields as needed; only title and summary are
    required.

    Example usage in graph YAML:
        nodes:
          analyze:
            type: llm
            prompt: my_analysis
            output_model: yamlgraph.models.GenericReport

    Example prompts can request specific sections:
        "Analyze the repository and provide:
        - A summary of findings
        - Key findings as bullet points
        - Recommendations for improvement"
    """

    # Required fields - the minimal useful report.
    title: str = Field(description="Report title")
    summary: str = Field(description="Executive summary")
    # Optional structured content; all default to empty containers.
    sections: dict[str, Any] = Field(
        default_factory=dict,
        description="Named sections with any content (strings, dicts, lists)",
    )
    findings: list[str] = Field(
        default_factory=list, description="Key findings or bullet points"
    )
    recommendations: list[str] = Field(
        default_factory=list, description="Suggested actions or areas to focus on"
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional key-value data (author, version, tags, etc.)",
    )
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
"""Dynamic state class generation from graph configuration.
|
|
2
|
+
|
|
3
|
+
Builds TypedDict programmatically from YAML graph config, eliminating
|
|
4
|
+
the need for state_class coupling between YAML and Python.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from operator import add
|
|
9
|
+
from typing import Annotated, Any, TypedDict
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def sorted_add(existing: list, new: list) -> list:
    """Reducer for map-node fan-in: concatenate, then restore order.

    Parallel branches may finish in any order; when items are dicts
    tagged with ``_map_index`` the merged list is re-sorted by that index
    so results always match the original input order.

    Args:
        existing: List currently held in state (may be None or empty).
        new: Incoming items to merge (may be None or empty).

    Returns:
        A new merged list, ordered by ``_map_index`` when applicable.
    """
    merged = list(existing or [])
    merged.extend(new or [])

    # Only re-sort when the payload looks like index-tagged dicts; the
    # first element serves as a representative sample of the whole list.
    if merged:
        head = merged[0]
        if isinstance(head, dict) and "_map_index" in head:
            merged.sort(key=lambda item: item.get("_map_index", 0))

    return merged
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# =============================================================================
|
|
37
|
+
# Base Fields - Always included in generated state
|
|
38
|
+
# =============================================================================
|
|
39
|
+
|
|
40
|
+
# Infrastructure fields present in all graphs
|
|
41
|
+
BASE_FIELDS: dict[str, type] = {
    # Core tracking
    "thread_id": str,
    "current_step": str,
    # Error handling - singular for current error
    "error": Any,
    # Error handling with reducer (accumulates)
    "errors": Annotated[list, add],
    # Messages with reducer (accumulates)
    "messages": Annotated[list, add],
    # Loop tracking
    "_loop_counts": dict,
    "_loop_limit_reached": bool,
    "_agent_iterations": int,
    "_agent_limit_reached": bool,
    # Timestamps
    "started_at": Any,
    "completed_at": Any,
}

# Common input fields used across graph types
# These are always included to support --var inputs
COMMON_INPUT_FIELDS: dict[str, type] = {
    "input": str,  # Agent prompt input
    "topic": str,  # Content generation topic
    "style": str,  # Writing style
    "word_count": int,  # Target word count
    "message": str,  # Router message input
}

# Type mapping for YAML state config.
# Maps (lowercased) type names accepted in a YAML 'state' section to their
# Python equivalents; unknown names fall back to Any in parse_state_config.
TYPE_MAP: dict[str, type] = {
    "str": str,
    "string": str,
    "int": int,
    "integer": int,
    "float": float,
    "bool": bool,
    "boolean": bool,
    "list": list,
    "dict": dict,
    "any": Any,
}
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def parse_state_config(state_config: dict) -> dict[str, type]:
    """Translate the YAML ``state`` section into Python field types.

    Supports simple type strings:
        state:
          concept: str
          count: int

    Names are matched case-insensitively against TYPE_MAP; unknown names
    and non-string specs fall back to Any (with a logged warning).

    Args:
        state_config: Dict from the YAML 'state' section.

    Returns:
        Dict of field_name -> Python type.
    """
    parsed: dict[str, type] = {}

    for name, spec in state_config.items():
        if not isinstance(spec, str):
            # Non-string spec (list, dict, number, ...): warn and degrade.
            logger.warning(
                f"Invalid type specification for state field '{name}': "
                f"expected string, got {type(spec).__name__}. Defaulting to Any."
            )
            parsed[name] = Any
            continue

        key = spec.lower()
        if key not in TYPE_MAP:
            supported = ", ".join(sorted(set(TYPE_MAP.keys())))
            logger.warning(
                f"Unknown type '{spec}' for state field '{name}'. "
                f"Supported types: {supported}. Defaulting to Any."
            )
        parsed[name] = TYPE_MAP.get(key, Any)

    return parsed
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def build_state_class(config: dict) -> type:
    """Construct the TypedDict state class for a graph configuration.

    The generated class merges, in order (later wins on key collision):
      - infrastructure fields shared by all graphs (BASE_FIELDS)
      - common ``--var`` input fields (COMMON_INPUT_FIELDS)
      - custom fields declared in the YAML 'state' section
      - fields inferred from each node's configuration

    Args:
        config: Parsed YAML graph configuration dict.

    Returns:
        A TypedDict class named "GraphState" built with total=False, so
        every field is optional.
    """
    merged: dict[str, type] = {
        **BASE_FIELDS,
        **COMMON_INPUT_FIELDS,
        **parse_state_config(config.get("state", {})),
        **extract_node_fields(config.get("nodes", {})),
    }

    # Build the TypedDict programmatically from the merged field map.
    return TypedDict("GraphState", merged, total=False)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def extract_node_fields(nodes: dict) -> dict[str, type]:
    """Infer state fields from node configurations.

    Per node:
      - ``state_key`` becomes an Any-typed output slot
      - type 'agent' adds ``input`` and ``_tool_results``
      - type 'router' adds ``_route``
      - type 'map' adds its ``collect`` key with a sorted reducer so
        fan-in results arrive in deterministic order

    Args:
        nodes: Dict of node_name -> node_config.

    Returns:
        Dict of field_name -> type for the state.
    """
    fields: dict[str, type] = {}

    for node_cfg in nodes.values():
        # Skip malformed entries (e.g. a bare string instead of a mapping).
        if not isinstance(node_cfg, dict):
            continue

        # Output slot is typed Any so Pydantic model instances fit too.
        state_key = node_cfg.get("state_key")
        if state_key:
            fields[state_key] = Any

        kind = node_cfg.get("type", "llm")
        if kind == "agent":
            fields["input"] = str
            fields["_tool_results"] = list
        elif kind == "router":
            fields["_route"] = str
        elif kind == "map":
            collect_key = node_cfg.get("collect")
            if collect_key:
                fields[collect_key] = Annotated[list, sorted_add]

    return fields
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def create_initial_state(
|
|
203
|
+
topic: str = "",
|
|
204
|
+
style: str = "informative",
|
|
205
|
+
word_count: int = 300,
|
|
206
|
+
thread_id: str | None = None,
|
|
207
|
+
**kwargs: Any,
|
|
208
|
+
) -> dict[str, Any]:
|
|
209
|
+
"""Create an initial state for a new pipeline run.
|
|
210
|
+
|
|
211
|
+
Args:
|
|
212
|
+
topic: The topic to generate content about
|
|
213
|
+
style: Writing style (default: informative)
|
|
214
|
+
word_count: Target word count (default: 300)
|
|
215
|
+
thread_id: Optional thread ID (auto-generated if not provided)
|
|
216
|
+
**kwargs: Additional state fields (e.g., input for agents)
|
|
217
|
+
|
|
218
|
+
Returns:
|
|
219
|
+
Initialized state dictionary
|
|
220
|
+
"""
|
|
221
|
+
import uuid
|
|
222
|
+
from datetime import datetime
|
|
223
|
+
|
|
224
|
+
return {
|
|
225
|
+
"thread_id": thread_id or uuid.uuid4().hex[:16],
|
|
226
|
+
"topic": topic,
|
|
227
|
+
"style": style,
|
|
228
|
+
"word_count": word_count,
|
|
229
|
+
"current_step": "init",
|
|
230
|
+
"error": None,
|
|
231
|
+
"errors": [],
|
|
232
|
+
"messages": [],
|
|
233
|
+
"started_at": datetime.now(),
|
|
234
|
+
"completed_at": None,
|
|
235
|
+
**kwargs,
|
|
236
|
+
}
|