yamlgraph 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of yamlgraph has been flagged for review; consult the registry's advisory page for details.
- examples/__init__.py +1 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +10 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +238 -0
- examples/storyboard/retry_images.py +118 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +281 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +200 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +270 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +60 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +150 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_loader.py +299 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_langsmith.py +319 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +225 -0
- tests/unit/test_prompts.py +166 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_nodes.py +129 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +139 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +232 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +382 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +66 -0
- yamlgraph/error_handlers.py +226 -0
- yamlgraph/executor.py +275 -0
- yamlgraph/executor_async.py +122 -0
- yamlgraph/graph_loader.py +337 -0
- yamlgraph/map_compiler.py +138 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +141 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +240 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +160 -0
- yamlgraph/storage/__init__.py +17 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +235 -0
- yamlgraph/tools/nodes.py +124 -0
- yamlgraph/tools/python_tool.py +178 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/utils/__init__.py +47 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +111 -0
- yamlgraph/utils/langsmith.py +308 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +127 -0
- yamlgraph/utils/prompts.py +116 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.1.1.dist-info/METADATA +854 -0
- yamlgraph-0.1.1.dist-info/RECORD +111 -0
- yamlgraph-0.1.1.dist-info/WHEEL +5 -0
- yamlgraph-0.1.1.dist-info/entry_points.txt +2 -0
- yamlgraph-0.1.1.dist-info/licenses/LICENSE +21 -0
- yamlgraph-0.1.1.dist-info/top_level.txt +3 -0
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""Dynamic Pydantic model generation from YAML schema definitions.
|
|
2
|
+
|
|
3
|
+
This module enables defining output schemas in YAML prompt files,
|
|
4
|
+
making prompts fully self-contained with their expected output structure.
|
|
5
|
+
|
|
6
|
+
Example YAML schema:
|
|
7
|
+
schema:
|
|
8
|
+
name: MyOutputModel
|
|
9
|
+
fields:
|
|
10
|
+
title:
|
|
11
|
+
type: str
|
|
12
|
+
description: "The output title"
|
|
13
|
+
confidence:
|
|
14
|
+
type: float
|
|
15
|
+
constraints: {ge: 0.0, le: 1.0}
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
import re
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
import yaml
|
|
23
|
+
from pydantic import Field, create_model
|
|
24
|
+
|
|
25
|
+
# =============================================================================
|
|
26
|
+
# Type Resolution
|
|
27
|
+
# =============================================================================
|
|
28
|
+
|
|
29
|
+
# Mapping from type strings to Python types
|
|
30
|
+
TYPE_MAP: dict[str, type] = {
|
|
31
|
+
"str": str,
|
|
32
|
+
"int": int,
|
|
33
|
+
"float": float,
|
|
34
|
+
"bool": bool,
|
|
35
|
+
"Any": Any,
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def resolve_type(type_str: str, field_name: str | None = None) -> type:
|
|
40
|
+
"""Resolve a type string to a Python type.
|
|
41
|
+
|
|
42
|
+
Supports:
|
|
43
|
+
- Basic types: str, int, float, bool, Any
|
|
44
|
+
- Generic types: list[str], list[int], dict[str, str]
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
type_str: Type string like "str", "list[str]", "dict[str, Any]"
|
|
48
|
+
field_name: Optional field name for better error messages
|
|
49
|
+
|
|
50
|
+
Returns:
|
|
51
|
+
Python type annotation
|
|
52
|
+
|
|
53
|
+
Raises:
|
|
54
|
+
ValueError: If type string is not recognized
|
|
55
|
+
"""
|
|
56
|
+
# Check basic types first
|
|
57
|
+
if type_str in TYPE_MAP:
|
|
58
|
+
return TYPE_MAP[type_str]
|
|
59
|
+
|
|
60
|
+
# Handle list[T] pattern
|
|
61
|
+
list_match = re.match(r"list\[(\w+)\]", type_str)
|
|
62
|
+
if list_match:
|
|
63
|
+
inner_type = resolve_type(list_match.group(1), field_name)
|
|
64
|
+
return list[inner_type]
|
|
65
|
+
|
|
66
|
+
# Handle dict[K, V] pattern
|
|
67
|
+
dict_match = re.match(r"dict\[(\w+),\s*(\w+)\]", type_str)
|
|
68
|
+
if dict_match:
|
|
69
|
+
key_type = resolve_type(dict_match.group(1), field_name)
|
|
70
|
+
value_type = resolve_type(dict_match.group(2), field_name)
|
|
71
|
+
return dict[key_type, value_type]
|
|
72
|
+
|
|
73
|
+
# Provide helpful error with supported types
|
|
74
|
+
supported = ", ".join(TYPE_MAP.keys())
|
|
75
|
+
context = f" for field '{field_name}'" if field_name else ""
|
|
76
|
+
raise ValueError(
|
|
77
|
+
f"Unknown type: '{type_str}'{context}. "
|
|
78
|
+
f"Supported types: {supported}, list[T], dict[K, V]"
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# =============================================================================
|
|
83
|
+
# Model Building
|
|
84
|
+
# =============================================================================
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def build_pydantic_model(schema: dict) -> type:
    """Create a Pydantic model class from a declarative schema dict.

    Args:
        schema: Mapping with a 'name' key (the model class name) and a
            'fields' key mapping field names to definitions, e.g.::

                {
                    "name": "MyOutputModel",
                    "fields": {
                        "title": {"type": "str", "description": "..."},
                        "score": {"type": "float", "constraints": {"ge": 0}},
                    },
                }

    Returns:
        Dynamically created Pydantic model class
    """
    definitions: dict[str, Any] = {}

    for name, spec in schema["fields"].items():
        # Resolve the declared type; pass the field name through so a bad
        # type string produces a targeted error message.
        annotation = resolve_type(spec["type"], name)

        optional = spec.get("optional", False)
        if optional:
            annotation = annotation | None

        # Collect Field(...) keyword arguments from the definition.
        kwargs: dict[str, Any] = {}
        if "description" in spec:
            kwargs["description"] = spec["description"]
        if "default" in spec:
            kwargs["default"] = spec["default"]
        elif optional:
            # Optional fields without an explicit default fall back to None.
            kwargs["default"] = None
        constraints = spec.get("constraints")
        if constraints:
            # Validation constraints (ge, le, min_length, max_length, ...).
            kwargs.update(constraints)

        # A field with no metadata is declared required via Ellipsis.
        definitions[name] = (annotation, Field(**kwargs) if kwargs else ...)

    return create_model(schema["name"], **definitions)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
# =============================================================================
|
|
141
|
+
# YAML Loading
|
|
142
|
+
# =============================================================================
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def load_schema_from_yaml(yaml_path: str | Path) -> type | None:
    """Load a Pydantic model from a prompt YAML file's schema block.

    Args:
        yaml_path: Path to the YAML prompt file

    Returns:
        Dynamically created Pydantic model, or None if no schema defined
    """
    with open(yaml_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty file; guard before the membership
    # test, which would otherwise raise TypeError on None.
    if not config or "schema" not in config:
        return None

    return build_pydantic_model(config["schema"])
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""Storage utilities for persistence and export."""
|
|
2
|
+
|
|
3
|
+
from yamlgraph.storage.database import YamlGraphDB
|
|
4
|
+
from yamlgraph.storage.export import (
|
|
5
|
+
export_state,
|
|
6
|
+
export_summary,
|
|
7
|
+
list_exports,
|
|
8
|
+
load_export,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"YamlGraphDB",
|
|
13
|
+
"export_state",
|
|
14
|
+
"export_summary",
|
|
15
|
+
"list_exports",
|
|
16
|
+
"load_export",
|
|
17
|
+
]
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
"""LangGraph native checkpointer integration.
|
|
2
|
+
|
|
3
|
+
Provides SQLite-based checkpointing for graph state persistence,
|
|
4
|
+
enabling time travel, replay, and resume from any checkpoint.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import sqlite3
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from langgraph.checkpoint.sqlite import SqliteSaver
|
|
12
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
13
|
+
|
|
14
|
+
from yamlgraph.config import DATABASE_PATH
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_checkpointer(db_path: str | Path | None = None) -> SqliteSaver:
    """Build a SQLite-backed checkpointer for graph compilation.

    A checkpointer provides:
    - Automatic state persistence after each node
    - Time travel via get_state_history()
    - Resume from any checkpoint
    - Fault tolerance with pending writes

    Args:
        db_path: Path to SQLite database file.
            Defaults to outputs/yamlgraph.db

    Returns:
        SqliteSaver instance for use with graph.compile()

    Example:
        >>> checkpointer = get_checkpointer()
        >>> graph = workflow.compile(checkpointer=checkpointer)
        >>> result = graph.invoke(input, {"configurable": {"thread_id": "abc"}})
    """
    target = Path(DATABASE_PATH if db_path is None else db_path)
    # Make sure the parent directory exists before SQLite opens the file.
    target.parent.mkdir(parents=True, exist_ok=True)

    # check_same_thread=False allows the saver to be used from threads
    # other than the one that opened the connection.
    connection = sqlite3.connect(str(target), check_same_thread=False)
    return SqliteSaver(connection)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_state_history(
    graph: CompiledStateGraph,
    thread_id: str,
) -> list[Any]:
    """Return all checkpoints recorded for a thread.

    Snapshots come back in reverse chronological order (most recent
    first), mirroring the ordering of graph.get_state_history().

    Args:
        graph: Compiled graph with checkpointer
        thread_id: Thread identifier to query

    Returns:
        List of StateSnapshot objects, or empty list if thread doesn't exist

    Example:
        >>> history = get_state_history(graph, "my-thread")
        >>> for snapshot in history:
        ...     print(f"Step {snapshot.metadata.get('step')}: {snapshot.values}")
    """
    try:
        snapshots = graph.get_state_history(
            {"configurable": {"thread_id": thread_id}}
        )
        return list(snapshots)
    except Exception:
        # Best-effort lookup: an unknown thread (or a backend error while
        # reading history) is reported as "no checkpoints".
        return []
|
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
"""SQLite Storage - Simple persistence for pipeline state.
|
|
2
|
+
|
|
3
|
+
Provides a lightweight wrapper around SQLite for storing
|
|
4
|
+
and retrieving pipeline execution state.
|
|
5
|
+
|
|
6
|
+
Supports optional connection pooling for high-throughput scenarios.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import sqlite3
|
|
11
|
+
import threading
|
|
12
|
+
from contextlib import contextmanager
|
|
13
|
+
from datetime import datetime
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from queue import Empty, Full, Queue
|
|
16
|
+
from typing import Iterator
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
|
|
20
|
+
from yamlgraph.config import DATABASE_PATH
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ConnectionPool:
    """Thread-safe SQLite connection pool.

    Maintains a pool of reusable connections for high-throughput scenarios.
    Connections are returned to the pool after use instead of being closed.
    """

    def __init__(self, db_path: Path, pool_size: int = 5):
        """Initialize connection pool.

        Args:
            db_path: Path to SQLite database
            pool_size: Maximum number of connections to maintain
        """
        self._db_path = db_path
        self._pool_size = pool_size
        self._pool: Queue[sqlite3.Connection] = Queue(maxsize=pool_size)
        # Guards _total_connections (created-but-not-necessarily-pooled count).
        self._lock = threading.Lock()
        self._total_connections = 0

    def _create_connection(self) -> sqlite3.Connection:
        """Create a new database connection."""
        # check_same_thread=False: pooled connections migrate between
        # threads, so SQLite's same-thread guard must be disabled.
        conn = sqlite3.connect(self._db_path, check_same_thread=False)
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def get_connection(self) -> Iterator[sqlite3.Connection]:
        """Get a connection from the pool.

        Creates a new connection if the pool is empty and under the size
        limit; once the limit is reached, blocks until another caller
        releases a connection.

        Yields:
            Database connection (returned to pool on exit)
        """
        conn: sqlite3.Connection | None = None
        try:
            # Fast path: reuse an idle pooled connection.
            try:
                conn = self._pool.get_nowait()
            except Empty:
                # Pool empty - create new connection if under limit.
                with self._lock:
                    if self._total_connections < self._pool_size:
                        conn = self._create_connection()
                        self._total_connections += 1

            if conn is None:
                # At limit - block waiting for a connection to be released.
                conn = self._pool.get()

            yield conn

        finally:
            # Return connection to pool.
            # NOTE(review): the connection is recycled even if the caller's
            # block raised mid-transaction; callers own commit/rollback
            # discipline.
            if conn is not None:
                try:
                    self._pool.put_nowait(conn)
                except Full:
                    # Narrowed from `except Exception`: only a full pool is
                    # expected here, and anything else (e.g. a close error)
                    # should propagate rather than silently skew the count.
                    conn.close()
                    with self._lock:
                        self._total_connections -= 1

    def close_all(self) -> None:
        """Close all connections in the pool and reset the count."""
        while True:
            try:
                self._pool.get_nowait().close()
            except Empty:
                break
        with self._lock:
            self._total_connections = 0
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class YamlGraphDB:
    """SQLite-backed persistence for yamlgraph pipeline state.

    Two connection strategies are available:
    - Default: open a fresh connection per operation (simple, safe)
    - Pooled: recycle connections through a pool (high-throughput)

    Example:
        # Default mode (simple)
        db = YamlGraphDB()

        # Pooled mode (high-throughput)
        db = YamlGraphDB(use_pool=True, pool_size=10)
    """

    def __init__(
        self,
        db_path: str | Path | None = None,
        use_pool: bool = False,
        pool_size: int = 5,
    ):
        """Initialize database connection.

        Args:
            db_path: Path to SQLite database file (default: outputs/yamlgraph.db)
            use_pool: Enable connection pooling for high-throughput scenarios
            pool_size: Maximum connections in pool (only used if use_pool=True)
        """
        self.db_path = Path(DATABASE_PATH if db_path is None else db_path)
        # Ensure the containing directory exists before SQLite opens the file.
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

        self._use_pool = use_pool
        self._pool: ConnectionPool | None = (
            ConnectionPool(self.db_path, pool_size) if use_pool else None
        )

        self._init_db()

    @contextmanager
    def _get_connection(self) -> Iterator[sqlite3.Connection]:
        """Yield a database connection.

        Uses the pool when pooling is enabled; otherwise a fresh
        connection is opened per call and closed on exit.

        Yields:
            Database connection
        """
        if self._pool is not None:
            with self._pool.get_connection() as pooled:
                yield pooled
            return

        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        try:
            yield conn
        finally:
            conn.close()

    def close(self) -> None:
        """Close database connections.

        For pooled mode, closes all connections in the pool; in default
        mode there is nothing to release.
        """
        if self._pool is not None:
            self._pool.close_all()

    def _init_db(self):
        """Create the pipeline_runs table and its thread_id index if missing."""
        with self._get_connection() as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS pipeline_runs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    thread_id TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    status TEXT NOT NULL DEFAULT 'running',
                    state_json TEXT NOT NULL
                )
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_thread_id
                ON pipeline_runs(thread_id)
            """)
            conn.commit()

    def save_state(self, thread_id: str, state: dict, status: str = "running") -> int:
        """Save pipeline state (insert on first save, update thereafter).

        Args:
            thread_id: Unique identifier for this run
            state: State dictionary to persist
            status: Current status (running, completed, failed)

        Returns:
            Row ID of the saved state
        """
        timestamp = datetime.now().isoformat()
        payload = json.dumps(self._serialize_state(state), default=str)

        with self._get_connection() as conn:
            row = conn.execute(
                "SELECT id FROM pipeline_runs WHERE thread_id = ?", (thread_id,)
            ).fetchone()

            if row is None:
                cursor = conn.execute(
                    """INSERT INTO pipeline_runs
                    (thread_id, created_at, updated_at, status, state_json)
                    VALUES (?, ?, ?, ?, ?)""",
                    (thread_id, timestamp, timestamp, status, payload),
                )
                conn.commit()
                return cursor.lastrowid

            conn.execute(
                """UPDATE pipeline_runs
                SET updated_at = ?, status = ?, state_json = ?
                WHERE thread_id = ?""",
                (timestamp, status, payload, thread_id),
            )
            conn.commit()
            return row["id"]

    def load_state(self, thread_id: str) -> dict | None:
        """Load pipeline state by thread ID.

        Args:
            thread_id: Unique identifier for the run

        Returns:
            State dictionary or None if not found
        """
        with self._get_connection() as conn:
            row = conn.execute(
                "SELECT state_json FROM pipeline_runs WHERE thread_id = ?", (thread_id,)
            ).fetchone()

        return json.loads(row["state_json"]) if row else None

    def get_run_info(self, thread_id: str) -> dict | None:
        """Get run metadata without the full state payload.

        Args:
            thread_id: Unique identifier for the run

        Returns:
            Dictionary with id, thread_id, created_at, updated_at, status
        """
        with self._get_connection() as conn:
            row = conn.execute(
                """SELECT id, thread_id, created_at, updated_at, status
                FROM pipeline_runs WHERE thread_id = ?""",
                (thread_id,),
            ).fetchone()

        return dict(row) if row else None

    def list_runs(self, limit: int = 10) -> list[dict]:
        """List recent pipeline runs, most recently updated first.

        Args:
            limit: Maximum number of runs to return

        Returns:
            List of run metadata dictionaries
        """
        with self._get_connection() as conn:
            rows = conn.execute(
                """SELECT id, thread_id, created_at, updated_at, status
                FROM pipeline_runs
                ORDER BY updated_at DESC
                LIMIT ?""",
                (limit,),
            ).fetchall()

        return [dict(record) for record in rows]

    def delete_run(self, thread_id: str) -> bool:
        """Delete a pipeline run.

        Args:
            thread_id: Unique identifier for the run

        Returns:
            True if deleted, False if not found
        """
        with self._get_connection() as conn:
            cursor = conn.execute(
                "DELETE FROM pipeline_runs WHERE thread_id = ?", (thread_id,)
            )
            conn.commit()
            return cursor.rowcount > 0

    def _serialize_state(self, state: dict) -> dict:
        """Convert state to a JSON-serializable format.

        Handles Pydantic models and other complex types; anything still
        unserializable is stringified by json.dumps(default=str) upstream.

        Args:
            state: State dictionary

        Returns:
            JSON-serializable dictionary
        """

        def _to_jsonable(value):
            # One-value conversion: Pydantic model -> dict, plain object
            # with attributes -> its __dict__, everything else unchanged.
            if isinstance(value, BaseModel):
                return value.model_dump()
            if hasattr(value, "__dict__"):
                return vars(value)
            return value

        return {key: _to_jsonable(value) for key, value in state.items()}
|