simple-state-flow 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Marc Nealer
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,11 @@
1
+ Metadata-Version: 2.4
2
+ Name: simple-state-flow
3
+ Version: 0.1.0
4
+ Summary: A graph-based state flow library using Pydantic
5
+ Author: Marc Nealer
6
+ License-Expression: MIT
7
+ Project-URL: repository, https://gitlab.com/mystuff6091446/state-flow.git
8
+ Requires-Python: >=3.10
9
+ License-File: LICENSE
10
+ Requires-Dist: pydantic>=2.12.5
11
+ Dynamic: license-file
@@ -0,0 +1,388 @@
1
+ # Python State Flow
2
+
3
+ This is a lightweight workflow orchestrator. You define a State object that is passed from task
4
+ to task, where you save any changes you make. At the end of the workflow, the State object is passed back to
5
+ you with the relevant results.
6
+
7
+ The library is based on LangGraph's StateGraph, with a few differences: tasks (called Nodes here) are
8
+ classes rather than functions, and the State object must be a Pydantic model.
9
+
10
+ This library lets you define workflows as graphs, much like a DAG except that cycles (loops) are allowed. Use it to split
11
+ complex workflows into nodes (tasks) that are straightforward to maintain and update.
12
+
13
+
14
+ ## Features
15
+
16
+ - **Pydantic Integration**: Use Pydantic models for type-safe state management.
17
+ - **Synchronous & Asynchronous Support**: Define both sync and async nodes and flows.
18
+ - **Graph-Based Routing**: Easily define edges and conditional routing between nodes.
19
+ - **Type Safety**: Full type hint support for state objects.
20
+
21
+ ## Installation
22
+
23
+ This project uses `uv` for dependency management.
24
+
25
+ ```bash
26
+ uv add simple-state-flow
27
+ ```
28
+
29
+ Or with pip:
30
+
31
+ ```bash
32
+ pip install simple-state-flow
33
+ ```
34
+
35
+ ### Setup
36
+
37
+ ```bash
38
+ uv sync
39
+ ```
40
+
41
+ Alternatively, you can install the project in editable mode using pip:
42
+
43
+ ```bash
44
+ pip install -e .
45
+ ```
46
+
47
+ ## Core Concepts
48
+
49
+ ### 1. State
50
+
51
+ The State object is an instance of a Pydantic `BaseModel` subclass that you define; it behaves much like a Python dataclass. It serves as the central data structure that is passed between nodes in your workflow, accumulating data and changes as the process executes.
52
+
53
+ We chose Pydantic for this role to leverage its robust validation and data transformation capabilities. By using Pydantic, you can define functions that automatically validate or transform fields as they are updated. This ensures that your state remains consistent and correct throughout the workflow execution.
54
+
55
+ For example, if a field is expected to be a Python `date` object but receives a string representation (e.g., from an API response or user input), you can use Pydantic's `BeforeValidator` to automatically parse that string into a proper date object before it is assigned to the state. This removes the need for manual parsing logic inside your nodes and keeps your business logic clean.
56
+
57
+ #### Example: Complex State with Validation
58
+
59
+ Here is an example of a more complex state object that demonstrates required fields, automatic type conversion using `BeforeValidator`, and fields for error handling and logging.
60
+
61
+ ```python
62
+ from datetime import datetime
63
+ from typing import Annotated, List, Optional
64
+ from pydantic import BaseModel, Field, BeforeValidator
65
+
66
+ def parse_datetime(v: str | datetime) -> datetime:
67
+ if isinstance(v, str):
68
+ return datetime.fromisoformat(v)
69
+ return v
70
+
71
+ class WorkflowState(BaseModel):
72
+ # Required fields that must be supplied when creating the state
73
+ user_id: str
74
+ request_id: str
75
+
76
+ # Fields with automatic conversion using BeforeValidator
77
+ # This will convert string inputs like "2023-10-27T10:00:00" into datetime objects
78
+ created_at: Annotated[datetime, BeforeValidator(parse_datetime)]
79
+ updated_at: Annotated[datetime, BeforeValidator(parse_datetime)] = Field(default_factory=datetime.now)
80
+
81
+ # Fields for tracking execution history and errors
82
+ logs: List[str] = Field(default_factory=list)
83
+ errors: List[str] = Field(default_factory=list)
84
+
85
+ # Optional data that might be populated during the workflow
86
+ processed_data: Optional[dict] = None
87
+
88
+ def log(self, message: str):
89
+ """Helper method to add log entries with timestamps"""
90
+ self.logs.append(f"{datetime.now().isoformat()}: {message}")
91
+
92
+ def add_error(self, error_msg: str):
93
+ """Helper method to record errors"""
94
+ self.errors.append(error_msg)
95
+
96
+ # Usage
97
+ # Even though we pass strings for dates, Pydantic converts them to datetime objects
98
+ state = WorkflowState(
99
+ user_id="user_123",
100
+ request_id="req_abc",
101
+ created_at="2023-10-27T10:00:00"
102
+ )
103
+
104
+ print(type(state.created_at)) # <class 'datetime.datetime'>
105
+ state.log("Workflow started")
106
+ ```
107
+
108
+ ### 2. Nodes
109
+
110
+ Nodes represent a single step in your workflow. Unlike some other libraries where nodes can be simple functions, in State Flow, a Node is a class.
111
+
112
+ ```python
113
+ from state_flow import Node, START, END
114
+
115
+ class MyNode(Node):
116
+ def exec(self, state: MyState):
117
+ state.value += 1
118
+ # Optional: set self.result for conditional routing
119
+ self.result = "success"
120
+ ```
121
+
122
+ To create a node, you must inherit from `Node` (for synchronous workflows) or `AsyncNode` (for asynchronous workflows) and override the `exec` method.
123
+
124
+ The `exec` method does not receive arguments. Instead, you access the current state via `self.state`. Inside this method, you can modify the state object directly.
125
+
126
+ After execution, the node returns the modified state object along with a `result` string. You don't need to return these manually; the base class handles it. However, you can set the `self.result` attribute within your `exec` method. This `result` string is crucial for **conditional edges**, as it determines which node will be executed next in the workflow.
127
+
128
+ ```python
129
+ class MyNode(Node[MyState]):
130
+ def exec(self) -> None:
131
+ # Modify the state
132
+ self.state.some_field = "new value"
133
+
134
+ # Set the result for conditional routing
135
+ if self.state.some_value > 10:
136
+ self.result = "high"
137
+ else:
138
+ self.result = "low"
139
+ ```
140
+
141
+ ### 3. Flows
142
+
143
+ Flows define the structure and execution logic of your workflow graph. To create a flow, you inherit from `StateFlow` (for synchronous workflows) or `AsyncStateFlow` (for asynchronous workflows).
144
+
145
+ The most critical part of defining a flow is overriding the `setup_graph` method (or `_setup_graph` for async flows). This is where you register your nodes and define the connections (edges) between them.
146
+
147
+ #### `add_node(name: str, node: Node)`
148
+ Registers a node in the graph.
149
+ - `name`: A unique string identifier for the node.
150
+ - `node`: An instance of your node class.
151
+
152
+ ```python
153
+ self.add_node("process_data", ProcessDataNode())
154
+ ```
155
+
156
+ #### `add_edge(start_node: str, end_node: str)`
157
+ Creates a direct connection between two nodes. When `start_node` finishes execution, `end_node` is executed next.
158
+ - `start_node`: The name of the source node.
159
+ - `end_node`: The name of the destination node.
160
+
161
+ ```python
162
+ self.add_edge("step_1", "step_2")
163
+ ```
164
+
165
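+ #### `add_conditional_edges(start_node: str, mapping: dict[str, str])`
166
+ Creates dynamic routing based on the result of `start_node`.
167
+ - `start_node`: The name of the source node.
168
+ - `mapping`: A dictionary mapping the `result` string (set in the node's `exec` method) to the name of the next node. If the result is not in the mapping, the flow follows the regular edge from `start_node` (added with `add_edge`) if there is one; otherwise it raises `ValueError`.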
169
+
170
+ ```python
171
+ # If "decision_node" sets self.result = "success", go to "success_handler"
172
+ # If "decision_node" sets self.result = "failure", go to "error_handler"
173
+ self.add_conditional_edges("decision_node", {
174
+ "success": "success_handler",
175
+ "failure": "error_handler"
176
+ })
177
+ ```
178
+
179
+ #### `START` and `END`
180
+ These are special constants used to define the entry and exit points of your graph.
181
+ - `START`: Represents the beginning of the flow. You must add an edge from `START` to your first node.
182
+ - `END`: Represents the completion of the flow. Edges pointing to `END` signify that the workflow should terminate.
183
+
184
+ ```python
185
+ from state_flow import START, END
186
+
187
+ self.add_edge(START, "first_node")
188
+ self.add_edge("last_node", END)
189
+ ```
190
+
191
+ ## Usage
192
+
193
+ ### Synchronous Flow Example
194
+
195
+ ```python
196
+ from typing import List
197
+ from pydantic import BaseModel, Field
198
+ from state_flow import Node, START, END
198
+ from state_flow import StateFlow
200
+
201
+ # 1. Define your state
202
+ class MyState(BaseModel):
203
+ history: List[str] = Field(default_factory=list)
204
+
205
+ # 2. Define your nodes
206
+ class NodeA(Node[MyState]):
207
+ def exec(self) -> None:
208
+ self.state.history.append("A")
209
+ # You can set self.result to control conditional routing
210
+ self.result = "next"
211
+
212
+ class NodeB(Node[MyState]):
213
+ def exec(self) -> None:
214
+ self.state.history.append("B")
215
+
216
+ # 3. Define your flow
217
+ class SimpleFlow(StateFlow[MyState]):
218
+ def setup_graph(self) -> None:
219
+ self.add_node("A", NodeA())
220
+ self.add_node("B", NodeB())
221
+
222
+ self.add_edge(START, "A")
223
+ self.add_edge("A", "B")
224
+ self.add_edge("B", END)
225
+
226
+ # 4. Run the flow
227
+ flow = SimpleFlow()
228
+ final_state = flow.run(MyState())
229
+ print(final_state.history) # Output: ['A', 'B']
230
+ ```
231
+
232
+ ### Asynchronous Flow Example
233
+
234
+ ```python
235
+ import asyncio
236
+ from state_flow import AsyncNode, START, END
237
+ from state_flow import AsyncStateFlow
238
+
239
+ class AsyncWorker(AsyncNode[MyState]):
240
+ async def exec(self) -> None:
241
+ await asyncio.sleep(0.1)
242
+ self.state.history.append("AsyncWork")
243
+
244
+ class AsyncFlow(AsyncStateFlow[MyState]):
245
+ def _setup_graph(self) -> None:
246
+ self.add_node("worker", AsyncWorker())
247
+ self.add_edge(START, "worker")
248
+ self.add_edge("worker", END)
249
+
250
+ async def main():
251
+ flow = AsyncFlow()
252
+ final_state = await flow.run(MyState())
253
+ print(final_state.history)
254
+
255
+ if __name__ == "__main__":
256
+ asyncio.run(main())
257
+ ```
258
+
259
+ ### Conditional Edges
260
+
261
+ You can route to different nodes based on the `result` attribute set within a node.
262
+
263
+ ```python
264
+ class DecisionNode(Node[MyState]):
265
+ def exec(self) -> None:
266
+ if len(self.state.history) > 5:
267
+ self.result = "long"
268
+ else:
269
+ self.result = "short"
270
+
271
+ class MyFlow(StateFlow[MyState]):
272
+ def setup_graph(self) -> None:
273
+ self.add_node("decision", DecisionNode())
274
+ self.add_node("process_long", LongNode())
275
+ self.add_node("process_short", ShortNode())
276
+
277
+ self.add_edge(START, "decision")
278
+ self.add_conditional_edges("decision", {
279
+ "long": "process_long",
280
+ "short": "process_short"
281
+ })
282
+ self.add_edge("process_long", END)
283
+ self.add_edge("process_short", END)
284
+ ```
285
+
286
+ ## Complex Example: API Pagination to DataFrame
287
+
288
+ This example demonstrates a workflow that fetches data from a paginated API, accumulates the results, and finally converts them into a Pandas DataFrame.
289
+
290
+ ### Workflow Diagram
291
+
292
+ ```mermaid
293
+ graph TD
294
+ START((Start)) --> Init[Initialize DataFrame]
295
+ Init --> Fetch[Fetch Page]
296
+ Fetch --> Update[Update DataFrame]
297
+ Update --> Check{More Pages?}
298
+ Check -- Yes --> Fetch
299
+ Check -- No --> END((End))
300
+ ```
301
+
302
+ ### Code
303
+
304
+ ```python
305
+ import pandas as pd
306
+ from typing import List, Any, Optional
307
+ from pydantic import BaseModel, Field, ConfigDict
308
+ from state_flow import Node, START, END
309
+ from state_flow import StateFlow
310
+
311
+ # 1. Define State
312
+ class ApiState(BaseModel):
313
+ # Allow arbitrary types for pandas DataFrame
314
+ model_config = ConfigDict(arbitrary_types_allowed=True)
315
+
316
+ current_page: int = 1
317
+ total_pages: int = 1
318
+ current_data: List[dict] = Field(default_factory=list)
319
+ final_dataframe: Optional[pd.DataFrame] = None
320
+ base_url: str
321
+
322
+ # 2. Define Nodes
323
+
324
+ class InitDataFrameNode(Node[ApiState]):
325
+ def exec(self) -> None:
326
+ print("Initializing empty DataFrame...")
327
+ self.state.final_dataframe = pd.DataFrame()
328
+
329
+ class FetchPageNode(Node[ApiState]):
330
+ def exec(self) -> None:
331
+ # Simulate API call
332
+ print(f"Fetching page {self.state.current_page}...")
333
+
334
+ # Mock response data
335
+ mock_response = {
336
+ "data": [{"id": i, "value": f"val_{i}"} for i in range(self.state.current_page * 10, (self.state.current_page + 1) * 10)],
337
+ "total_pages": 3
338
+ }
339
+
340
+ # Store current page data
341
+ self.state.current_data = mock_response["data"]
342
+ self.state.total_pages = mock_response["total_pages"]
343
+
344
+ class UpdateDataFrameNode(Node[ApiState]):
345
+ def exec(self) -> None:
346
+ print("Updating DataFrame with new data...")
347
+ new_df = pd.DataFrame(self.state.current_data)
348
+ if self.state.final_dataframe is None:
349
+ self.state.final_dataframe = new_df
350
+ else:
351
+ self.state.final_dataframe = pd.concat([self.state.final_dataframe, new_df], ignore_index=True)
352
+
353
+ # Determine next step
354
+ if self.state.current_page < self.state.total_pages:
355
+ self.state.current_page += 1
356
+ self.result = "next_page"
357
+ else:
358
+ self.result = "done"
359
+
360
+ # 3. Define Flow
361
+
362
+ class ApiPaginationFlow(StateFlow[ApiState]):
363
+ def setup_graph(self) -> None:
364
+ self.add_node("init_df", InitDataFrameNode())
365
+ self.add_node("fetch_page", FetchPageNode())
366
+ self.add_node("update_df", UpdateDataFrameNode())
367
+
368
+ # Start by initializing DataFrame
369
+ self.add_edge(START, "init_df")
370
+ self.add_edge("init_df", "fetch_page")
371
+ self.add_edge("fetch_page", "update_df")
372
+
373
+ # Loop back to fetch_page if there are more pages, otherwise end
374
+ self.add_conditional_edges("update_df", {
375
+ "next_page": "fetch_page",
376
+ "done": END
377
+ })
378
+
379
+ # 4. Run
380
+ if __name__ == "__main__":
381
+ initial_state = ApiState(base_url="https://api.example.com/items")
382
+ flow = ApiPaginationFlow()
383
+ result_state = flow.run(initial_state)
384
+
385
+ print("\nFinal DataFrame:")
386
+ print(result_state.final_dataframe.head())
387
+ print(f"Total rows: {len(result_state.final_dataframe)}")
388
+ ```
@@ -0,0 +1,25 @@
1
+ [project]
2
+ name = "simple-state-flow"
3
+ version = "0.1.0"
4
+ description = "A graph-based state flow library using Pydantic"
5
+ authors = [
6
+ { name = "Marc Nealer" }
7
+ ]
8
+ license = "MIT"
9
+ urls = { repository = "https://gitlab.com/mystuff6091446/state-flow.git" }
10
+ requires-python = ">=3.10"
11
+ dependencies = [
12
+ "pydantic>=2.12.5",
13
+ ]
14
+
15
+ [build-system]
16
+ requires = ["setuptools>=61.0"]
17
+ build-backend = "setuptools.build_meta"
18
+
19
+ [tool.setuptools.packages.find]
20
+ where = ["src"]
21
+
22
+ [dependency-groups]
23
+ dev = [
24
+ "build>=1.4.0",
25
+ ]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,11 @@
1
+ Metadata-Version: 2.4
2
+ Name: simple-state-flow
3
+ Version: 0.1.0
4
+ Summary: A graph-based state flow library using Pydantic
5
+ Author: Marc Nealer
6
+ License-Expression: MIT
7
+ Project-URL: repository, https://gitlab.com/mystuff6091446/state-flow.git
8
+ Requires-Python: >=3.10
9
+ License-File: LICENSE
10
+ Requires-Dist: pydantic>=2.12.5
11
+ Dynamic: license-file
@@ -0,0 +1,13 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ src/simple_state_flow.egg-info/PKG-INFO
5
+ src/simple_state_flow.egg-info/SOURCES.txt
6
+ src/simple_state_flow.egg-info/dependency_links.txt
7
+ src/simple_state_flow.egg-info/requires.txt
8
+ src/simple_state_flow.egg-info/top_level.txt
9
+ src/state_flow/__init__.py
10
+ src/state_flow/flows.py
11
+ src/state_flow/nodes.py
12
+ tests/test_async_flow.py
13
+ tests/test_random_flow.py
@@ -0,0 +1 @@
1
+ pydantic>=2.12.5
@@ -0,0 +1,4 @@
1
"""Public API of the ``state_flow`` package.

Re-exports the flow classes, the node base classes and the ``START``/``END``
sentinels so callers can write ``from state_flow import StateFlow, Node, START, END``.
"""

from state_flow.flows import AsyncStateFlow, StateFlow
from state_flow.nodes import END, START, AsyncNode, Node

__all__ = [
    "StateFlow",
    "AsyncStateFlow",
    "Node",
    "AsyncNode",
    "START",
    "END",
]
@@ -0,0 +1,185 @@
1
+ import asyncio
2
+ from abc import ABC, abstractmethod
3
+ from typing import Any, Callable, Dict, List, Optional, Type, Union, TypeVar, Generic, Coroutine
4
+ from pydantic import BaseModel
5
+ from state_flow.nodes import START, END, Node, AsyncNode
6
+
7
T = TypeVar("T", bound=BaseModel)


class StateFlow(ABC, Generic[T]):
    """Standard synchronous state flow.

    Subclasses override :meth:`setup_graph` to register nodes and edges; it is
    called once from ``__init__``. :meth:`run` then walks the graph from
    ``START`` until it reaches ``END``.
    """

    def __init__(self) -> None:
        self.nodes: Dict[str, Node[T]] = {}
        # Unconditional routing: source node name -> destination node name.
        self.edges: Dict[str, str] = {}
        # Conditional routing: source node name -> {node result -> destination}.
        self.conditional_edges: Dict[str, Dict[str, str]] = {}
        self.setup_graph()

    @abstractmethod
    def setup_graph(self) -> None:
        """Override this method to add nodes and edges to the flow."""

    def add_node(self, name: str, node: Node[T]) -> None:
        """Add a node to the flow.

        Args:
            name: The name of the node.
            node: The Node instance to execute for this node.

        Raises:
            TypeError: If ``node`` is an ``AsyncNode``. Its ``run`` returns a
                coroutine, which this synchronous flow cannot execute.
        """
        if isinstance(node, AsyncNode):
            raise TypeError(
                f"Node '{name}' is an AsyncNode; use AsyncStateFlow for async nodes."
            )
        self.nodes[name] = node

    def add_edge(self, start_node: str, end_node: str) -> None:
        """Add a directed edge between two nodes.

        Adding a second edge from the same ``start_node`` replaces the first.

        Args:
            start_node: The name of the starting node.
            end_node: The name of the ending node.
        """
        self.edges[start_node] = end_node

    def add_conditional_edges(
        self,
        start_node: str,
        mapping: Dict[str, str]
    ) -> None:
        """Add conditional edges from a node.

        The mapping is copied, so later changes to the caller's dict do not
        rewire the flow. When a node's result is missing from the mapping, the
        plain edge added with :meth:`add_edge` (if any) is followed instead.

        Args:
            start_node: The name of the starting node.
            mapping: A dictionary mapping node execution results to destination nodes.
        """
        self.conditional_edges[start_node] = dict(mapping)

    def _next_node(self, current_node: str, result: str) -> str:
        """Return the name of the node that follows ``current_node``.

        Conditional edges take priority; a plain edge is the fallback.

        Raises:
            ValueError: If no edge leads out of ``current_node`` for ``result``.
        """
        mapping = self.conditional_edges.get(current_node)
        if mapping is not None:
            if result in mapping:
                return mapping[result]
            if current_node in self.edges:
                return self.edges[current_node]
            raise ValueError(
                f"Result '{result}' not found in mapping for node {current_node} "
                "and no default edge found."
            )
        if current_node in self.edges:
            return self.edges[current_node]
        raise ValueError(f"Node {current_node} has no outgoing edges and is not END.")

    def run(self, state: T) -> T:
        """Run the flow starting from the START node.

        Cycles are allowed, so a graph whose nodes never route to ``END`` will
        keep running indefinitely.

        Args:
            state: The initial Pydantic state object.

        Returns:
            The final state object.

        Raises:
            ValueError: If START has no outgoing edge, a routed-to node is not
                registered, or a node has no matching outgoing edge.
        """
        current_node = START

        while current_node != END:
            # START is a pseudo-node: jump straight to its successor.
            if current_node == START:
                next_node = self.edges.get(START)
                if not next_node:
                    raise ValueError("START node must have an outgoing edge.")
                current_node = next_node
                continue

            if current_node not in self.nodes:
                raise ValueError(f"Node {current_node} not found in flow.")

            state, result = self.nodes[current_node].run(state)
            current_node = self._next_node(current_node, result)

        return state
97
+
98
class AsyncStateFlow(ABC, Generic[T]):
    """Asynchronous state flow that only accepts async nodes.

    Build the graph by overriding :meth:`setup_graph`, the same hook
    ``StateFlow`` uses. For backward compatibility, overriding
    :meth:`_setup_graph` also works. The graph is built once, from ``__init__``.
    """

    def __init__(self) -> None:
        self.nodes: Dict[str, AsyncNode[T]] = {}
        # Unconditional routing: source node name -> destination node name.
        self.edges: Dict[str, str] = {}
        # Conditional routing: source node name -> {node result -> destination}.
        self.conditional_edges: Dict[str, Dict[str, str]] = {}
        self._setup_graph()

    def _setup_graph(self) -> None:
        """Build the graph by delegating to :meth:`setup_graph`.

        Older subclasses override this method directly and keep working.
        """
        self.setup_graph()

    def setup_graph(self) -> None:
        """Override this method to add nodes and edges to the flow.

        Raises:
            TypeError: If neither ``setup_graph`` nor ``_setup_graph`` is overridden.
        """
        raise TypeError(
            f"{type(self).__name__} must override setup_graph() "
            "(or the legacy _setup_graph())."
        )

    def add_node(self, name: str, node: AsyncNode[T]) -> None:
        """Add an async node to the flow.

        Args:
            name: The name of the node.
            node: The AsyncNode instance to execute for this node.

        Raises:
            TypeError: If ``node`` is a synchronous ``Node``. Its ``run``
                returns a plain tuple, which cannot be awaited.
        """
        if isinstance(node, Node):
            raise TypeError(
                f"Node '{name}' is a synchronous Node; use StateFlow or an AsyncNode."
            )
        self.nodes[name] = node

    def add_edge(self, start_node: str, end_node: str) -> None:
        """Add a directed edge between two nodes.

        Adding a second edge from the same ``start_node`` replaces the first.

        Args:
            start_node: The name of the starting node.
            end_node: The name of the ending node.
        """
        self.edges[start_node] = end_node

    def add_conditional_edges(
        self,
        start_node: str,
        mapping: Dict[str, str]
    ) -> None:
        """Add conditional edges from a node.

        The mapping is copied, so later changes to the caller's dict do not
        rewire the flow. When a node's result is missing from the mapping, the
        plain edge added with :meth:`add_edge` (if any) is followed instead.

        Args:
            start_node: The name of the starting node.
            mapping: A dictionary mapping node execution results to destination nodes.
        """
        self.conditional_edges[start_node] = dict(mapping)

    def _next_node(self, current_node: str, result: str) -> str:
        """Return the name of the node that follows ``current_node``.

        Conditional edges take priority; a plain edge is the fallback.

        Raises:
            ValueError: If no edge leads out of ``current_node`` for ``result``.
        """
        mapping = self.conditional_edges.get(current_node)
        if mapping is not None:
            if result in mapping:
                return mapping[result]
            if current_node in self.edges:
                return self.edges[current_node]
            raise ValueError(
                f"Result '{result}' not found in mapping for node {current_node} "
                "and no default edge found."
            )
        if current_node in self.edges:
            return self.edges[current_node]
        raise ValueError(f"Node {current_node} has no outgoing edges and is not END.")

    async def run(self, state: T) -> T:
        """Run the flow starting from the START node.

        Cycles are allowed, so a graph whose nodes never route to ``END`` will
        keep running indefinitely.

        Args:
            state: The initial Pydantic state object.

        Returns:
            The final state object.

        Raises:
            ValueError: If START has no outgoing edge, a routed-to node is not
                registered, or a node has no matching outgoing edge.
        """
        current_node = START

        while current_node != END:
            # START is a pseudo-node: jump straight to its successor.
            if current_node == START:
                next_node = self.edges.get(START)
                if not next_node:
                    raise ValueError("START node must have an outgoing edge.")
                current_node = next_node
                continue

            if current_node not in self.nodes:
                raise ValueError(f"Node {current_node} not found in flow.")

            state, result = await self.nodes[current_node].run(state)
            current_node = self._next_node(current_node, result)

        return state
@@ -0,0 +1,112 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Tuple, TypeVar, Generic
3
+ from pydantic import BaseModel
4
+
5
T = TypeVar("T", bound=BaseModel)

# Reserved node names marking the entry and exit points of a flow graph.
START = "__start__"
END = "__end__"


class Node(ABC, Generic[T]):
    """Abstract base class for synchronous nodes.

    Subclasses implement :meth:`exec` in one of two forms:

    - ``exec(self, state)``: the state is passed in.
    - ``exec(self)``: read the state from ``self.state``.

    Either way, mutate the state in place (or reassign ``self.state``). Set
    ``self.result`` to choose a conditional edge.
    """

    def __init__(self) -> None:
        # Set by _pre_run() at the start of every run().
        self.state: T = None  # type: ignore
        # Routing key returned by run(); reset to "done" before each exec().
        self.result: str = "done"

    @abstractmethod
    def exec(self, state: T) -> None:
        """Perform this node's work. Must be implemented by subclasses.

        Args:
            state: The current state object (the same object as ``self.state``).
                Overrides may omit this parameter and use ``self.state`` instead.

        Returns:
            None. Results are communicated by mutating the state and by setting
            ``self.result``.
        """

    def _pre_run(self, state: T) -> None:
        """Internal method to set the state before execution.

        Args:
            state: The current state object.

        Raises:
            TypeError: If state is not a Pydantic BaseModel.
        """
        if not isinstance(state, BaseModel):
            raise TypeError("State must be a Pydantic BaseModel.")
        self.state = state
        self.result = "done"

    def _post_run(self, state: T, result: str) -> None:
        """Internal method to update the state after execution.

        Args:
            state: The updated state object; ignored when None.
            result: The execution result; ignored when None.
        """
        if state is not None:
            self.state = state
        if result is not None:
            self.result = result

    def _exec_accepts_state(self) -> bool:
        """Return True if the concrete ``exec`` can take the state positionally."""
        import inspect

        try:
            params = inspect.signature(self.exec).parameters.values()
        except (TypeError, ValueError):
            # No introspectable signature: keep the historical call form.
            return True
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.VAR_POSITIONAL,
        )
        return any(param.kind in positional_kinds for param in params)

    def run(self, state: T) -> Tuple[T, str]:
        """Internal method to execute the node.

        Args:
            state: The current state object.

        Returns:
            A tuple of (updated state, result string used for routing).

        Raises:
            TypeError: If state is not a Pydantic BaseModel.
        """
        self._pre_run(state)
        if self._exec_accepts_state():
            self.exec(self.state)
        else:
            # ``exec(self)`` style: the node reads ``self.state`` itself.
            self.exec()  # type: ignore[call-arg]
        return self.state, self.result
67
+
68
class AsyncNode(ABC, Generic[T]):
    """Abstract base class for asynchronous nodes.

    Implement :meth:`exec` as a coroutine in one of two forms:

    - ``async def exec(self, state)``: the state is passed in.
    - ``async def exec(self)``: read the state from ``self.state``.

    Mutate the state in place and set ``self.result`` to choose a conditional
    edge.

    NOTE: the current state and result are stored on the node instance. If the
    same node instance runs in several overlapping coroutines (for example, one
    flow object driven via ``asyncio.gather``), one run can overwrite another's
    ``self.state``/``self.result`` while ``exec`` is awaiting.
    """

    def __init__(self) -> None:
        # Set by _pre_run() at the start of every run().
        self.state: T = None  # type: ignore
        # Routing key returned by run(); reset to "done" before each exec().
        self.result: str = "done"

    @abstractmethod
    async def exec(self, state: T) -> None:
        """Perform this node's work. Must be implemented by subclasses.

        Args:
            state: The current state object (the same object as ``self.state``).
                Overrides may omit this parameter and use ``self.state`` instead.

        Returns:
            None. Results are communicated by mutating the state and by setting
            ``self.result``.
        """

    def _pre_run(self, state: T) -> None:
        """Internal method to set the state before execution.

        Args:
            state: The current state object.

        Raises:
            TypeError: If state is not a Pydantic BaseModel.
        """
        if not isinstance(state, BaseModel):
            raise TypeError("State must be a Pydantic BaseModel.")
        self.state = state
        self.result = "done"

    def _exec_accepts_state(self) -> bool:
        """Return True if the concrete ``exec`` can take the state positionally."""
        import inspect

        try:
            params = inspect.signature(self.exec).parameters.values()
        except (TypeError, ValueError):
            # No introspectable signature: keep the historical call form.
            return True
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.VAR_POSITIONAL,
        )
        return any(param.kind in positional_kinds for param in params)

    async def run(self, state: T) -> Tuple[T, str]:
        """Internal method to execute the node.

        Args:
            state: The current state object.

        Returns:
            A tuple of (updated state, result string used for routing).

        Raises:
            TypeError: If state is not a Pydantic BaseModel.
        """
        self._pre_run(state)
        if self._exec_accepts_state():
            await self.exec(self.state)
        else:
            # ``exec(self)`` style: the node reads ``self.state`` itself.
            await self.exec()  # type: ignore[call-arg]
        return self.state, self.result
@@ -0,0 +1,72 @@
1
+ import asyncio
2
+ import random
3
+ from typing import List
4
+ from pydantic import BaseModel, Field
5
+ from state_flow.nodes import AsyncNode, START, END
6
+ from state_flow.flows import AsyncStateFlow
7
+
8
class AsyncFlowState(BaseModel):
    """State for SimpleAsyncFlow: the names of the nodes that ran, in order."""

    history: List[str] = Field(default_factory=list)
10
+
11
async def _pause_then_record(state: AsyncFlowState, label: str) -> None:
    """Yield to the event loop briefly, then append ``label`` to the history."""
    await asyncio.sleep(0.01)
    state.history.append(label)


class AsyncNodeA(AsyncNode):
    """Records "A"."""

    async def exec(self, state: AsyncFlowState) -> None:
        await _pause_then_record(state, "A")


class AsyncNodeB(AsyncNode):
    """Records "B", then picks the C or D branch at random."""

    async def exec(self, state: AsyncFlowState) -> None:
        await _pause_then_record(state, "B")
        self.result = random.choice(["to_c", "to_d"])


class AsyncNodeC(AsyncNode):
    """Records "C"."""

    async def exec(self, state: AsyncFlowState) -> None:
        await _pause_then_record(state, "C")


class AsyncNodeD(AsyncNode):
    """Records "D"."""

    async def exec(self, state: AsyncFlowState) -> None:
        await _pause_then_record(state, "D")
31
+
32
class SimpleAsyncFlow(AsyncStateFlow[AsyncFlowState]):
    """A -> B -> (C | D) -> END, where B chooses the branch at random."""

    def _setup_graph(self) -> None:
        node_table = (
            ("A", AsyncNodeA),
            ("B", AsyncNodeB),
            ("C", AsyncNodeC),
            ("D", AsyncNodeD),
        )
        for node_name, node_cls in node_table:
            self.add_node(node_name, node_cls())

        self.add_edge(START, "A")
        self.add_edge("A", "B")
        self.add_conditional_edges("B", {"to_c": "C", "to_d": "D"})
        for leaf in ("C", "D"):
            self.add_edge(leaf, END)
44
+
45
+
46
async def _run_async_flow() -> AsyncFlowState:
    """Run SimpleAsyncFlow repeatedly until both branches have been observed.

    Returns:
        The final state of the last run.
    """
    flow = SimpleAsyncFlow()
    seen = set()
    final_state = None

    # Bounded retries: each run picks C or D at random, so 100 attempts make a
    # missed branch vanishingly unlikely without risking an endless loop.
    for _ in range(100):
        final_state = await flow.run(AsyncFlowState())
        history = final_state.history

        # Exactly A, B and one branch node run before the flow reaches END.
        assert len(history) == 3
        assert history[0] == "A"
        assert history[1] == "B"
        assert history[2] in ["C", "D"]

        seen.add(history[2])
        if seen == {"C", "D"}:
            break

    assert "C" in seen, "Node C was never reached"
    assert "D" in seen, "Node D was never reached"
    return final_state


def test_async_flow():
    """Pytest entry point for the async flow.

    Plain pytest does not await ``async def`` test functions, so the coroutine
    is driven explicitly with ``asyncio.run``.
    """
    final_state = asyncio.run(_run_async_flow())
    print(f"Async flow history: {final_state.history}")


if __name__ == "__main__":
    test_async_flow()
@@ -0,0 +1,72 @@
1
+ import random
2
+ from typing import List
3
+ from pydantic import BaseModel, Field
4
+ from state_flow.nodes import Node, START, END
5
+ from state_flow.flows import StateFlow
6
+
7
+
8
+
9
class FlowState(BaseModel):
    """State for RandomFlow: the names of the nodes that ran, in order."""

    history: List[str] = Field(default_factory=list)
11
+
12
def _record(state: FlowState, label: str) -> None:
    """Append ``label`` to the state's history."""
    state.history.append(label)


class NodeA(Node):
    """Records "A"."""

    def exec(self, state: FlowState) -> None:
        _record(state, "A")


class NodeB(Node):
    """Records "B", then picks the C or D branch at random."""

    def exec(self, state: FlowState) -> None:
        _record(state, "B")
        self.result = random.choice(["to_c", "to_d"])


class NodeC(Node):
    """Records "C"."""

    def exec(self, state: FlowState) -> None:
        _record(state, "C")


class NodeD(Node):
    """Records "D"."""

    def exec(self, state: FlowState) -> None:
        _record(state, "D")
28
+
29
class RandomFlow(StateFlow[FlowState]):
    """A -> B -> (C | D) -> END, where B chooses the branch at random."""

    def setup_graph(self) -> None:
        node_table = (
            ("A", NodeA),
            ("B", NodeB),
            ("C", NodeC),
            ("D", NodeD),
        )
        for node_name, node_cls in node_table:
            self.add_node(node_name, node_cls())

        self.add_edge(START, "A")
        self.add_edge("A", "B")
        self.add_conditional_edges("B", {"to_c": "C", "to_d": "D"})
        for leaf in ("C", "D"):
            self.add_edge(leaf, END)
41
+
42
def test_random_flow():
    """RandomFlow always runs A then B, and B's random choice reaches both C and D."""
    flow = RandomFlow()
    seen = set()

    # Run up to 100 times to avoid an infinite loop if something is wrong, but
    # stop as soon as both branches have been observed.
    for _ in range(100):
        final_state = flow.run(FlowState())
        history = final_state.history

        # Exactly A, B and one branch node run before the flow reaches END.
        assert len(history) == 3
        assert history[0] == "A"
        assert history[1] == "B"
        assert history[2] in ["C", "D"]

        seen.add(history[2])
        if seen == {"C", "D"}:
            break

    assert "C" in seen, "Node C was never reached"
    assert "D" in seen, "Node D was never reached"
    print("Test passed: Both nodes C and D were reached randomly.")


if __name__ == "__main__":
    test_random_flow()