simple-state-flow 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simple_state_flow-0.1.0/LICENSE +21 -0
- simple_state_flow-0.1.0/PKG-INFO +11 -0
- simple_state_flow-0.1.0/README.md +388 -0
- simple_state_flow-0.1.0/pyproject.toml +25 -0
- simple_state_flow-0.1.0/setup.cfg +4 -0
- simple_state_flow-0.1.0/src/simple_state_flow.egg-info/PKG-INFO +11 -0
- simple_state_flow-0.1.0/src/simple_state_flow.egg-info/SOURCES.txt +13 -0
- simple_state_flow-0.1.0/src/simple_state_flow.egg-info/dependency_links.txt +1 -0
- simple_state_flow-0.1.0/src/simple_state_flow.egg-info/requires.txt +1 -0
- simple_state_flow-0.1.0/src/simple_state_flow.egg-info/top_level.txt +1 -0
- simple_state_flow-0.1.0/src/state_flow/__init__.py +4 -0
- simple_state_flow-0.1.0/src/state_flow/flows.py +185 -0
- simple_state_flow-0.1.0/src/state_flow/nodes.py +112 -0
- simple_state_flow-0.1.0/tests/test_async_flow.py +72 -0
- simple_state_flow-0.1.0/tests/test_random_flow.py +72 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Marc Nealer
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: simple-state-flow
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A graph-based state flow library using Pydantic
|
|
5
|
+
Author: Marc Nealer
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: repository, https://gitlab.com/mystuff6091446/state-flow.git
|
|
8
|
+
Requires-Python: >=3.10
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Requires-Dist: pydantic>=2.12.5
|
|
11
|
+
Dynamic: license-file
|
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
# Python State Flow
|
|
2
|
+
|
|
3
|
+
This is a lightweight workflow orchestrator. You define a State object that is passed from task
|
|
4
|
+
to task, where you save any changes you make. At the end of the workflow, the State object is passed back to
|
|
5
|
+
you with the relevant results.
|
|
6
|
+
|
|
7
|
+
The library is based on LangGraph's StateGraph, with a number of differences. Tasks (or Nodes, in my case) are
|
|
8
|
+
classes and not functions. I have also enforced the use of Pydantic to create the State object.
|
|
9
|
+
|
|
10
|
+
This library allows you to define workflows similar to a DAG, except it allows for circular flows. Use this to split
|
|
11
|
+
your complex workflows into Nodes or tasks so they are straightforward to maintain and update
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
## Features
|
|
15
|
+
|
|
16
|
+
- **Pydantic Integration**: Use Pydantic models for type-safe state management.
|
|
17
|
+
- **Synchronous & Asynchronous Support**: Define both sync and async nodes and flows.
|
|
18
|
+
- **Graph-Based Routing**: Easily define edges and conditional routing between nodes.
|
|
19
|
+
- **Type Safety**: Full type hint support for state objects.
|
|
20
|
+
|
|
21
|
+
## Installation
|
|
22
|
+
|
|
23
|
+
This project uses `uv` for dependency management.
|
|
24
|
+
|
|
25
|
+
```bash
|
|
26
|
+
uv add simple-state-flow
|
|
27
|
+
```
|
|
28
|
+
|
|
29
|
+
Or with pip:
|
|
30
|
+
|
|
31
|
+
```bash
|
|
32
|
+
pip install simple-state-flow
|
|
33
|
+
```
|
|
34
|
+
|
|
35
|
+
### Setup
|
|
36
|
+
|
|
37
|
+
```bash
|
|
38
|
+
uv sync
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
Alternatively, you can install the project in editable mode using pip:
|
|
42
|
+
|
|
43
|
+
```bash
|
|
44
|
+
pip install -e .
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
## Core Concepts
|
|
48
|
+
|
|
49
|
+
### 1. State
|
|
50
|
+
|
|
51
|
+
The State object is a predefined Pydantic `BaseModel`, which behaves similarly to a Python dataclass. It serves as the central data structure that is passed between nodes in your workflow, accumulating data and changes as the process executes.
|
|
52
|
+
|
|
53
|
+
We chose Pydantic for this role to leverage its robust validation and data transformation capabilities. By using Pydantic, you can define functions that automatically validate or transform fields as they are updated. This ensures that your state remains consistent and correct throughout the workflow execution.
|
|
54
|
+
|
|
55
|
+
For example, if a field is expected to be a Python `date` object but receives a string representation (e.g., from an API response or user input), you can use Pydantic's `BeforeValidator` to automatically parse that string into a proper date object before it is assigned to the state. This removes the need for manual parsing logic inside your nodes and keeps your business logic clean.
|
|
56
|
+
|
|
57
|
+
#### Example: Complex State with Validation
|
|
58
|
+
|
|
59
|
+
Here is an example of a more complex state object that demonstrates required fields, automatic type conversion using `BeforeValidator`, and fields for error handling and logging.
|
|
60
|
+
|
|
61
|
+
```python
|
|
62
|
+
from datetime import datetime
|
|
63
|
+
from typing import Annotated, List, Optional
|
|
64
|
+
from pydantic import BaseModel, Field, BeforeValidator
|
|
65
|
+
|
|
66
|
+
def parse_datetime(v: str | datetime) -> datetime:
|
|
67
|
+
if isinstance(v, str):
|
|
68
|
+
return datetime.fromisoformat(v)
|
|
69
|
+
return v
|
|
70
|
+
|
|
71
|
+
class WorkflowState(BaseModel):
|
|
72
|
+
# Required fields that must be supplied when creating the state
|
|
73
|
+
user_id: str
|
|
74
|
+
request_id: str
|
|
75
|
+
|
|
76
|
+
# Fields with automatic conversion using BeforeValidator
|
|
77
|
+
# This will convert string inputs like "2023-10-27T10:00:00" into datetime objects
|
|
78
|
+
created_at: Annotated[datetime, BeforeValidator(parse_datetime)]
|
|
79
|
+
updated_at: Annotated[datetime, BeforeValidator(parse_datetime)] = Field(default_factory=datetime.now)
|
|
80
|
+
|
|
81
|
+
# Fields for tracking execution history and errors
|
|
82
|
+
logs: List[str] = Field(default_factory=list)
|
|
83
|
+
errors: List[str] = Field(default_factory=list)
|
|
84
|
+
|
|
85
|
+
# Optional data that might be populated during the workflow
|
|
86
|
+
processed_data: Optional[dict] = None
|
|
87
|
+
|
|
88
|
+
def log(self, message: str):
|
|
89
|
+
"""Helper method to add log entries with timestamps"""
|
|
90
|
+
self.logs.append(f"{datetime.now().isoformat()}: {message}")
|
|
91
|
+
|
|
92
|
+
def add_error(self, error_msg: str):
|
|
93
|
+
"""Helper method to record errors"""
|
|
94
|
+
self.errors.append(error_msg)
|
|
95
|
+
|
|
96
|
+
# Usage
|
|
97
|
+
# Even though we pass strings for dates, Pydantic converts them to datetime objects
|
|
98
|
+
state = WorkflowState(
|
|
99
|
+
user_id="user_123",
|
|
100
|
+
request_id="req_abc",
|
|
101
|
+
created_at="2023-10-27T10:00:00"
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
print(type(state.created_at)) # <class 'datetime.datetime'>
|
|
105
|
+
state.log("Workflow started")
|
|
106
|
+
```
|
|
107
|
+
|
|
108
|
+
### 2. Nodes
|
|
109
|
+
|
|
110
|
+
Nodes represent a single step in your workflow. Unlike some other libraries where nodes can be simple functions, in State Flow, a Node is a class.
|
|
111
|
+
|
|
112
|
+
```python
|
|
113
|
+
from state_flow import Node, START, END
|
|
114
|
+
|
|
115
|
+
class MyNode(Node):
|
|
116
|
+
def exec(self, state: MyState):
|
|
117
|
+
state.value += 1
|
|
118
|
+
# Optional: set self.result for conditional routing
|
|
119
|
+
self.result = "success"
|
|
120
|
+
```
|
|
121
|
+
|
|
122
|
+
To create a node, you must inherit from `Node` (for synchronous workflows) or `AsyncNode` (for asynchronous workflows) and override the `exec` method.
|
|
123
|
+
|
|
124
|
+
The `exec` method receives the current state as an argument, and the same object is also available as `self.state`. Inside this method, you can modify the state object directly.
|
|
125
|
+
|
|
126
|
+
After execution, the node returns the modified state object along with a `result` string. You don't need to return these manually; the base class handles it. However, you can set the `self.result` attribute within your `exec` method. This `result` string is crucial for **conditional edges**, as it determines which node will be executed next in the workflow.
|
|
127
|
+
|
|
128
|
+
```python
|
|
129
|
+
class MyNode(Node[MyState]):
|
|
130
|
+
def exec(self) -> None:
|
|
131
|
+
# Modify the state
|
|
132
|
+
self.state.some_field = "new value"
|
|
133
|
+
|
|
134
|
+
# Set the result for conditional routing
|
|
135
|
+
if self.state.some_value > 10:
|
|
136
|
+
self.result = "high"
|
|
137
|
+
else:
|
|
138
|
+
self.result = "low"
|
|
139
|
+
```
|
|
140
|
+
|
|
141
|
+
### 3. Flows
|
|
142
|
+
|
|
143
|
+
Flows define the structure and execution logic of your workflow graph. To create a flow, you inherit from `StateFlow` (for synchronous workflows) or `AsyncStateFlow` (for asynchronous workflows).
|
|
144
|
+
|
|
145
|
+
The most critical part of defining a flow is overriding the `setup_graph` method (or `_setup_graph` for async flows). This is where you register your nodes and define the connections (edges) between them.
|
|
146
|
+
|
|
147
|
+
#### `add_node(name: str, node: Node)`
|
|
148
|
+
Registers a node in the graph.
|
|
149
|
+
- `name`: A unique string identifier for the node.
|
|
150
|
+
- `node`: An instance of your node class.
|
|
151
|
+
|
|
152
|
+
```python
|
|
153
|
+
self.add_node("process_data", ProcessDataNode())
|
|
154
|
+
```
|
|
155
|
+
|
|
156
|
+
#### `add_edge(from_node: str, to_node: str)`
|
|
157
|
+
Creates a direct connection between two nodes. When `from_node` finishes execution, `to_node` is executed next.
|
|
158
|
+
- `from_node`: The name of the source node.
|
|
159
|
+
- `to_node`: The name of the destination node.
|
|
160
|
+
|
|
161
|
+
```python
|
|
162
|
+
self.add_edge("step_1", "step_2")
|
|
163
|
+
```
|
|
164
|
+
|
|
165
|
+
#### `add_conditional_edges(from_node: str, path_map: dict[str, str])`
|
|
166
|
+
Creates dynamic routing based on the result of the `from_node`.
|
|
167
|
+
- `from_node`: The name of the source node.
|
|
168
|
+
- `path_map`: A dictionary mapping the `result` string (set in the node's `exec` method) to the name of the next node.
|
|
169
|
+
|
|
170
|
+
```python
|
|
171
|
+
# If "decision_node" sets self.result = "success", go to "success_handler"
|
|
172
|
+
# If "decision_node" sets self.result = "failure", go to "error_handler"
|
|
173
|
+
self.add_conditional_edges("decision_node", {
|
|
174
|
+
"success": "success_handler",
|
|
175
|
+
"failure": "error_handler"
|
|
176
|
+
})
|
|
177
|
+
```
|
|
178
|
+
|
|
179
|
+
#### `START` and `END`
|
|
180
|
+
These are special constants used to define the entry and exit points of your graph.
|
|
181
|
+
- `START`: Represents the beginning of the flow. You must add an edge from `START` to your first node.
|
|
182
|
+
- `END`: Represents the completion of the flow. Edges pointing to `END` signify that the workflow should terminate.
|
|
183
|
+
|
|
184
|
+
```python
|
|
185
|
+
from state_flow.nodes import START, END
|
|
186
|
+
|
|
187
|
+
self.add_edge(START, "first_node")
|
|
188
|
+
self.add_edge("last_node", END)
|
|
189
|
+
```
|
|
190
|
+
|
|
191
|
+
## Usage
|
|
192
|
+
|
|
193
|
+
### Synchronous Flow Example
|
|
194
|
+
|
|
195
|
+
```python
|
|
196
|
+
from typing import List
|
|
197
|
+
from pydantic import BaseModel, Field
|
|
198
|
+
from state_flow.nodes import Node, START, END
|
|
199
|
+
from state_flow.flows import StateFlow
|
|
200
|
+
|
|
201
|
+
# 1. Define your state
|
|
202
|
+
class MyState(BaseModel):
|
|
203
|
+
history: List[str] = Field(default_factory=list)
|
|
204
|
+
|
|
205
|
+
# 2. Define your nodes
|
|
206
|
+
class NodeA(Node[MyState]):
|
|
207
|
+
def exec(self) -> None:
|
|
208
|
+
self.state.history.append("A")
|
|
209
|
+
# You can set self.result to control conditional routing
|
|
210
|
+
self.result = "next"
|
|
211
|
+
|
|
212
|
+
class NodeB(Node[MyState]):
|
|
213
|
+
def exec(self) -> None:
|
|
214
|
+
self.state.history.append("B")
|
|
215
|
+
|
|
216
|
+
# 3. Define your flow
|
|
217
|
+
class SimpleFlow(StateFlow[MyState]):
|
|
218
|
+
def setup_graph(self) -> None:
|
|
219
|
+
self.add_node("A", NodeA())
|
|
220
|
+
self.add_node("B", NodeB())
|
|
221
|
+
|
|
222
|
+
self.add_edge(START, "A")
|
|
223
|
+
self.add_edge("A", "B")
|
|
224
|
+
self.add_edge("B", END)
|
|
225
|
+
|
|
226
|
+
# 4. Run the flow
|
|
227
|
+
flow = SimpleFlow()
|
|
228
|
+
final_state = flow.run(MyState())
|
|
229
|
+
print(final_state.history) # Output: ['A', 'B']
|
|
230
|
+
```
|
|
231
|
+
|
|
232
|
+
### Asynchronous Flow Example
|
|
233
|
+
|
|
234
|
+
```python
|
|
235
|
+
import asyncio
|
|
236
|
+
from state_flow.nodes import AsyncNode, START, END
|
|
237
|
+
from state_flow.flows import AsyncStateFlow
|
|
238
|
+
|
|
239
|
+
class AsyncWorker(AsyncNode[MyState]):
|
|
240
|
+
async def exec(self) -> None:
|
|
241
|
+
await asyncio.sleep(0.1)
|
|
242
|
+
self.state.history.append("AsyncWork")
|
|
243
|
+
|
|
244
|
+
class AsyncFlow(AsyncStateFlow[MyState]):
|
|
245
|
+
def _setup_graph(self) -> None:
|
|
246
|
+
self.add_node("worker", AsyncWorker())
|
|
247
|
+
self.add_edge(START, "worker")
|
|
248
|
+
self.add_edge("worker", END)
|
|
249
|
+
|
|
250
|
+
async def main():
|
|
251
|
+
flow = AsyncFlow()
|
|
252
|
+
final_state = await flow.run(MyState())
|
|
253
|
+
print(final_state.history)
|
|
254
|
+
|
|
255
|
+
if __name__ == "__main__":
|
|
256
|
+
asyncio.run(main())
|
|
257
|
+
```
|
|
258
|
+
|
|
259
|
+
### Conditional Edges
|
|
260
|
+
|
|
261
|
+
You can route to different nodes based on the `result` attribute set within a node.
|
|
262
|
+
|
|
263
|
+
```python
|
|
264
|
+
class DecisionNode(Node[MyState]):
|
|
265
|
+
def exec(self) -> None:
|
|
266
|
+
if len(self.state.history) > 5:
|
|
267
|
+
self.result = "long"
|
|
268
|
+
else:
|
|
269
|
+
self.result = "short"
|
|
270
|
+
|
|
271
|
+
class MyFlow(StateFlow[MyState]):
|
|
272
|
+
def setup_graph(self) -> None:
|
|
273
|
+
self.add_node("decision", DecisionNode())
|
|
274
|
+
self.add_node("process_long", LongNode())
|
|
275
|
+
self.add_node("process_short", ShortNode())
|
|
276
|
+
|
|
277
|
+
self.add_edge(START, "decision")
|
|
278
|
+
self.add_conditional_edges("decision", {
|
|
279
|
+
"long": "process_long",
|
|
280
|
+
"short": "process_short"
|
|
281
|
+
})
|
|
282
|
+
self.add_edge("process_long", END)
|
|
283
|
+
self.add_edge("process_short", END)
|
|
284
|
+
```
|
|
285
|
+
|
|
286
|
+
## Complex Example: API Pagination to DataFrame
|
|
287
|
+
|
|
288
|
+
This example demonstrates a workflow that fetches data from a paginated API, accumulates the results, and finally converts them into a Pandas DataFrame.
|
|
289
|
+
|
|
290
|
+
### Workflow Diagram
|
|
291
|
+
|
|
292
|
+
```mermaid
|
|
293
|
+
graph TD
|
|
294
|
+
START((Start)) --> Init[Initialize DataFrame]
|
|
295
|
+
Init --> Fetch[Fetch Page]
|
|
296
|
+
Fetch --> Update[Update DataFrame]
|
|
297
|
+
Update --> Check{More Pages?}
|
|
298
|
+
Check -- Yes --> Fetch
|
|
299
|
+
Check -- No --> END((End))
|
|
300
|
+
```
|
|
301
|
+
|
|
302
|
+
### Code
|
|
303
|
+
|
|
304
|
+
```python
|
|
305
|
+
import pandas as pd
|
|
306
|
+
from typing import List, Any, Optional
|
|
307
|
+
from pydantic import BaseModel, Field, ConfigDict
|
|
308
|
+
from state_flow.nodes import Node, START, END
|
|
309
|
+
from state_flow.flows import StateFlow
|
|
310
|
+
|
|
311
|
+
# 1. Define State
|
|
312
|
+
class ApiState(BaseModel):
|
|
313
|
+
# Allow arbitrary types for pandas DataFrame
|
|
314
|
+
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
315
|
+
|
|
316
|
+
current_page: int = 1
|
|
317
|
+
total_pages: int = 1
|
|
318
|
+
current_data: List[dict] = Field(default_factory=list)
|
|
319
|
+
final_dataframe: Optional[pd.DataFrame] = None
|
|
320
|
+
base_url: str
|
|
321
|
+
|
|
322
|
+
# 2. Define Nodes
|
|
323
|
+
|
|
324
|
+
class InitDataFrameNode(Node[ApiState]):
|
|
325
|
+
def exec(self) -> None:
|
|
326
|
+
print("Initializing empty DataFrame...")
|
|
327
|
+
self.state.final_dataframe = pd.DataFrame()
|
|
328
|
+
|
|
329
|
+
class FetchPageNode(Node[ApiState]):
|
|
330
|
+
def exec(self) -> None:
|
|
331
|
+
# Simulate API call
|
|
332
|
+
print(f"Fetching page {self.state.current_page}...")
|
|
333
|
+
|
|
334
|
+
# Mock response data
|
|
335
|
+
mock_response = {
|
|
336
|
+
"data": [{"id": i, "value": f"val_{i}"} for i in range(self.state.current_page * 10, (self.state.current_page + 1) * 10)],
|
|
337
|
+
"total_pages": 3
|
|
338
|
+
}
|
|
339
|
+
|
|
340
|
+
# Store current page data
|
|
341
|
+
self.state.current_data = mock_response["data"]
|
|
342
|
+
self.state.total_pages = mock_response["total_pages"]
|
|
343
|
+
|
|
344
|
+
class UpdateDataFrameNode(Node[ApiState]):
|
|
345
|
+
def exec(self) -> None:
|
|
346
|
+
print("Updating DataFrame with new data...")
|
|
347
|
+
new_df = pd.DataFrame(self.state.current_data)
|
|
348
|
+
if self.state.final_dataframe is None:
|
|
349
|
+
self.state.final_dataframe = new_df
|
|
350
|
+
else:
|
|
351
|
+
self.state.final_dataframe = pd.concat([self.state.final_dataframe, new_df], ignore_index=True)
|
|
352
|
+
|
|
353
|
+
# Determine next step
|
|
354
|
+
if self.state.current_page < self.state.total_pages:
|
|
355
|
+
self.state.current_page += 1
|
|
356
|
+
self.result = "next_page"
|
|
357
|
+
else:
|
|
358
|
+
self.result = "done"
|
|
359
|
+
|
|
360
|
+
# 3. Define Flow
|
|
361
|
+
|
|
362
|
+
class ApiPaginationFlow(StateFlow[ApiState]):
|
|
363
|
+
def setup_graph(self) -> None:
|
|
364
|
+
self.add_node("init_df", InitDataFrameNode())
|
|
365
|
+
self.add_node("fetch_page", FetchPageNode())
|
|
366
|
+
self.add_node("update_df", UpdateDataFrameNode())
|
|
367
|
+
|
|
368
|
+
# Start by initializing DataFrame
|
|
369
|
+
self.add_edge(START, "init_df")
|
|
370
|
+
self.add_edge("init_df", "fetch_page")
|
|
371
|
+
self.add_edge("fetch_page", "update_df")
|
|
372
|
+
|
|
373
|
+
# Loop back to fetch_page if there are more pages, otherwise end
|
|
374
|
+
self.add_conditional_edges("update_df", {
|
|
375
|
+
"next_page": "fetch_page",
|
|
376
|
+
"done": END
|
|
377
|
+
})
|
|
378
|
+
|
|
379
|
+
# 4. Run
|
|
380
|
+
if __name__ == "__main__":
|
|
381
|
+
initial_state = ApiState(base_url="https://api.example.com/items")
|
|
382
|
+
flow = ApiPaginationFlow()
|
|
383
|
+
result_state = flow.run(initial_state)
|
|
384
|
+
|
|
385
|
+
print("\nFinal DataFrame:")
|
|
386
|
+
print(result_state.final_dataframe.head())
|
|
387
|
+
print(f"Total rows: {len(result_state.final_dataframe)}")
|
|
388
|
+
```
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "simple-state-flow"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "A graph-based state flow library using Pydantic"
|
|
5
|
+
authors = [
|
|
6
|
+
{ name = "Marc Nealer" }
|
|
7
|
+
]
|
|
8
|
+
license = "MIT"
|
|
9
|
+
urls = { repository = "https://gitlab.com/mystuff6091446/state-flow.git" }
|
|
10
|
+
requires-python = ">=3.10"
|
|
11
|
+
dependencies = [
|
|
12
|
+
"pydantic>=2.12.5",
|
|
13
|
+
]
|
|
14
|
+
|
|
15
|
+
[build-system]
|
|
16
|
+
requires = ["setuptools>=61.0"]
|
|
17
|
+
build-backend = "setuptools.build_meta"
|
|
18
|
+
|
|
19
|
+
[tool.setuptools.packages.find]
|
|
20
|
+
where = ["src"]
|
|
21
|
+
|
|
22
|
+
[dependency-groups]
|
|
23
|
+
dev = [
|
|
24
|
+
"build>=1.4.0",
|
|
25
|
+
]
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: simple-state-flow
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A graph-based state flow library using Pydantic
|
|
5
|
+
Author: Marc Nealer
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: repository, https://gitlab.com/mystuff6091446/state-flow.git
|
|
8
|
+
Requires-Python: >=3.10
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Requires-Dist: pydantic>=2.12.5
|
|
11
|
+
Dynamic: license-file
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
src/simple_state_flow.egg-info/PKG-INFO
|
|
5
|
+
src/simple_state_flow.egg-info/SOURCES.txt
|
|
6
|
+
src/simple_state_flow.egg-info/dependency_links.txt
|
|
7
|
+
src/simple_state_flow.egg-info/requires.txt
|
|
8
|
+
src/simple_state_flow.egg-info/top_level.txt
|
|
9
|
+
src/state_flow/__init__.py
|
|
10
|
+
src/state_flow/flows.py
|
|
11
|
+
src/state_flow/nodes.py
|
|
12
|
+
tests/test_async_flow.py
|
|
13
|
+
tests/test_random_flow.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
pydantic>=2.12.5
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
state_flow
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import Any, Callable, Dict, List, Optional, Type, Union, TypeVar, Generic, Coroutine
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
from state_flow.nodes import START, END, Node, AsyncNode
|
|
6
|
+
|
|
7
|
+
T = TypeVar("T", bound=BaseModel)
|
|
8
|
+
|
|
9
|
+
class StateFlow(ABC, Generic[T]):
    """Standard synchronous state flow.

    Subclasses override ``setup_graph`` to register nodes and edges.
    ``run`` then walks the graph from START to END, threading a single
    Pydantic state object through every node it visits.
    """

    def __init__(self) -> None:
        # name -> Node instance
        self.nodes: Dict[str, Node[T]] = {}
        # unconditional routing: source node name -> destination node name
        self.edges: Dict[str, str] = {}
        # conditional routing: source node name -> {node result -> destination}
        self.conditional_edges: Dict[str, Dict[str, str]] = {}
        self.setup_graph()

    @abstractmethod
    def setup_graph(self) -> None:
        """Override this method to add nodes and edges to the flow."""
        pass

    def add_node(self, name: str, node: Node[T]) -> None:
        """Add a node to the flow.

        Args:
            name: The name of the node.
            node: The Node instance to execute for this node.
        """
        self.nodes[name] = node

    def add_edge(self, start_node: str, end_node: str) -> None:
        """Add a directed edge between two nodes.

        Args:
            start_node: The name of the starting node.
            end_node: The name of the ending node.
        """
        self.edges[start_node] = end_node

    def add_conditional_edges(
        self,
        start_node: str,
        mapping: Dict[str, str]
    ) -> None:
        """Add conditional edges from a node.

        Args:
            start_node: The name of the starting node.
            mapping: A dictionary mapping node execution results to destination nodes.
        """
        self.conditional_edges[start_node] = mapping

    def _resolve_next(self, current_node: str, result: str) -> str:
        """Return the name of the node to execute after ``current_node``.

        Conditional edges take priority; a plain edge (if one exists)
        is the fallback when the node's result is not in the mapping.

        Args:
            current_node: The node that just finished executing.
            result: The result string that node produced.

        Returns:
            The name of the next node (possibly END).

        Raises:
            ValueError: If no matching edge exists.
        """
        if current_node in self.conditional_edges:
            mapping = self.conditional_edges[current_node]
            if result in mapping:
                return mapping[result]
            if current_node in self.edges:
                return self.edges[current_node]
            raise ValueError(f"Result '{result}' not found in mapping for node {current_node} and no default edge found.")
        if current_node in self.edges:
            return self.edges[current_node]
        # The loop only reaches here for non-END nodes, so this is always an error.
        raise ValueError(f"Node {current_node} has no outgoing edges and is not END.")

    def run(self, state: T) -> T:
        """Run the flow starting from the START node.

        Args:
            state: The initial Pydantic state object.

        Returns:
            The final state object.

        Raises:
            ValueError: If a required node or edge is missing from the graph.
        """
        current_node = START

        while current_node != END:
            # START is virtual: it only routes to the first real node.
            if current_node == START:
                next_node = self.edges.get(START)
                if not next_node:
                    raise ValueError("START node must have an outgoing edge.")
                current_node = next_node
                continue

            if current_node not in self.nodes:
                raise ValueError(f"Node {current_node} not found in flow.")

            state, result = self.nodes[current_node].run(state)
            current_node = self._resolve_next(current_node, result)

        return state
|
|
97
|
+
|
|
98
|
+
class AsyncStateFlow(ABC, Generic[T]):
    """Asynchronous state flow that only accepts async nodes.

    Subclasses override ``_setup_graph`` to register nodes and edges.
    ``run`` is a coroutine that walks the graph from START to END,
    awaiting each node in turn.
    """

    def __init__(self) -> None:
        # name -> AsyncNode instance
        self.nodes: Dict[str, AsyncNode[T]] = {}
        # unconditional routing: source node name -> destination node name
        self.edges: Dict[str, str] = {}
        # conditional routing: source node name -> {node result -> destination}
        self.conditional_edges: Dict[str, Dict[str, str]] = {}
        self._setup_graph()

    @abstractmethod
    def _setup_graph(self) -> None:
        """Override this method to add nodes and edges to the flow."""
        pass

    def add_node(self, name: str, node: AsyncNode[T]) -> None:
        """Add an async node to the flow.

        Args:
            name: The name of the node.
            node: The AsyncNode instance to execute for this node.
        """
        self.nodes[name] = node

    def add_edge(self, start_node: str, end_node: str) -> None:
        """Add a directed edge between two nodes.

        Args:
            start_node: The name of the starting node.
            end_node: The name of the ending node.
        """
        self.edges[start_node] = end_node

    def add_conditional_edges(
        self,
        start_node: str,
        mapping: Dict[str, str]
    ) -> None:
        """Add conditional edges from a node.

        Args:
            start_node: The name of the starting node.
            mapping: A dictionary mapping node execution results to destination nodes.
        """
        self.conditional_edges[start_node] = mapping

    def _resolve_next(self, current_node: str, result: str) -> str:
        """Return the name of the node to execute after ``current_node``.

        Conditional edges take priority; a plain edge (if one exists)
        is the fallback when the node's result is not in the mapping.

        Args:
            current_node: The node that just finished executing.
            result: The result string that node produced.

        Returns:
            The name of the next node (possibly END).

        Raises:
            ValueError: If no matching edge exists.
        """
        if current_node in self.conditional_edges:
            mapping = self.conditional_edges[current_node]
            if result in mapping:
                return mapping[result]
            if current_node in self.edges:
                return self.edges[current_node]
            raise ValueError(f"Result '{result}' not found in mapping for node {current_node} and no default edge found.")
        if current_node in self.edges:
            return self.edges[current_node]
        # The loop only reaches here for non-END nodes, so this is always an error.
        raise ValueError(f"Node {current_node} has no outgoing edges and is not END.")

    async def run(self, state: T) -> T:
        """Run the flow starting from the START node.

        Args:
            state: The initial Pydantic state object.

        Returns:
            The final state object.

        Raises:
            ValueError: If a required node or edge is missing from the graph.
        """
        current_node = START

        while current_node != END:
            # START is virtual: it only routes to the first real node.
            if current_node == START:
                next_node = self.edges.get(START)
                if not next_node:
                    raise ValueError("START node must have an outgoing edge.")
                current_node = next_node
                continue

            if current_node not in self.nodes:
                raise ValueError(f"Node {current_node} not found in flow.")

            state, result = await self.nodes[current_node].run(state)
            current_node = self._resolve_next(current_node, result)

        return state
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Tuple, TypeVar, Generic
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
T = TypeVar("T", bound=BaseModel)
|
|
6
|
+
|
|
7
|
+
START = "__start__"
|
|
8
|
+
END = "__end__"
|
|
9
|
+
|
|
10
|
+
class Node(ABC, Generic[T]):
    """Abstract base class for synchronous nodes.

    A node holds the current state in ``self.state`` while it executes,
    and communicates its routing decision through ``self.result``.
    """

    def __init__(self) -> None:
        # ``state`` is populated by _pre_run immediately before execution;
        # until then it is intentionally unset.
        self.state: T = None  # type: ignore
        # Default routing result; exec() may overwrite it.
        self.result: str = "done"

    @abstractmethod
    def exec(self, state: T) -> None:
        """Perform this node's work. Must be implemented by subclasses.

        Args:
            state: The current state object (also available as ``self.state``).

        Returns:
            None. Mutate the state in place and set ``self.result`` to
            steer conditional routing.
        """
        pass

    def _pre_run(self, state: T) -> None:
        """Validate the incoming state and reset per-run attributes.

        Args:
            state: The current state object.

        Raises:
            TypeError: If state is not a Pydantic BaseModel.
        """
        if isinstance(state, BaseModel):
            self.state = state
            self.result = "done"
        else:
            raise TypeError("State must be a Pydantic BaseModel.")

    def _post_run(self, state: T, result: str) -> None:
        """Overwrite state and/or result after execution, ignoring Nones.

        Args:
            state: The updated state object, or None to keep the current one.
            result: The execution result, or None to keep the current one.
        """
        self.state = state if state is not None else self.state
        self.result = result if result is not None else self.result

    def run(self, state: T) -> Tuple[T, str]:
        """Execute this node against the given state.

        Args:
            state: The current state object.

        Returns:
            A tuple of (updated state, next status).
        """
        self._pre_run(state)
        self.exec(self.state)
        return self.state, self.result
|
|
67
|
+
|
|
68
|
+
class AsyncNode(ABC, Generic[T]):
    """Abstract base class for asynchronous nodes.

    Mirrors :class:`Node`, but ``exec`` and ``run`` are coroutines so the
    node can await I/O while it works on the state.
    """

    def __init__(self) -> None:
        # ``state`` is populated by _pre_run immediately before execution;
        # until then it is intentionally unset.
        self.state: T = None  # type: ignore
        # Default routing result; exec() may overwrite it.
        self.result: str = "done"

    @abstractmethod
    async def exec(self, state: T) -> None:
        """Perform this node's work. Must be implemented by subclasses.

        Args:
            state: The current state object (also available as ``self.state``).

        Returns:
            None. Mutate the state in place and set ``self.result`` to
            steer conditional routing.
        """
        pass

    def _pre_run(self, state: T) -> None:
        """Validate the incoming state and reset per-run attributes.

        Args:
            state: The current state object.

        Raises:
            TypeError: If state is not a Pydantic BaseModel.
        """
        if isinstance(state, BaseModel):
            self.state = state
            self.result = "done"
        else:
            raise TypeError("State must be a Pydantic BaseModel.")

    async def run(self, state: T) -> Tuple[T, str]:
        """Execute this node against the given state.

        Args:
            state: The current state object.

        Returns:
            A tuple of (updated state, next status).
        """
        self._pre_run(state)
        await self.exec(self.state)
        return self.state, self.result
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import random
|
|
3
|
+
from typing import List
|
|
4
|
+
from pydantic import BaseModel, Field
|
|
5
|
+
from state_flow.nodes import AsyncNode, START, END
|
|
6
|
+
from state_flow.flows import AsyncStateFlow
|
|
7
|
+
|
|
8
|
+
class AsyncFlowState(BaseModel):
    # Ordered record of node names visited during a flow run; each node
    # appends its own label in exec().
    history: List[str] = Field(default_factory=list)
|
|
10
|
+
|
|
11
|
+
class AsyncNodeA(AsyncNode):
    """Entry node: yields briefly, then records its visit."""

    async def exec(self, state: AsyncFlowState) -> None:
        # Simulate asynchronous work before logging the visit.
        await asyncio.sleep(0.01)
        state.history.append("A")
|
|
15
|
+
|
|
16
|
+
class AsyncNodeB(AsyncNode):
    """Branching node: records its visit and randomly routes to C or D."""

    async def exec(self, state: AsyncFlowState) -> None:
        await asyncio.sleep(0.01)
        state.history.append("B")
        # Pick the outgoing conditional edge at random.
        self.result = random.choice(("to_c", "to_d"))
|
|
21
|
+
|
|
22
|
+
class AsyncNodeC(AsyncNode):
    """Terminal node C: yields briefly, then records its visit."""

    async def exec(self, state: AsyncFlowState) -> None:
        await asyncio.sleep(0.01)
        state.history.append("C")
|
|
26
|
+
|
|
27
|
+
class AsyncNodeD(AsyncNode):
    """Terminal node D: yields briefly, then records its visit."""

    async def exec(self, state: AsyncFlowState) -> None:
        await asyncio.sleep(0.01)
        state.history.append("D")
|
|
31
|
+
|
|
32
|
+
class SimpleAsyncFlow(AsyncStateFlow[AsyncFlowState]):
    """Flow A -> B, then randomly B -> C or B -> D; both branches end."""

    def _setup_graph(self) -> None:
        # Register every node, then wire the fixed and conditional edges.
        for label, node in (
            ("A", AsyncNodeA()),
            ("B", AsyncNodeB()),
            ("C", AsyncNodeC()),
            ("D", AsyncNodeD()),
        ):
            self.add_node(label, node)

        self.add_edge(START, "A")
        self.add_edge("A", "B")
        self.add_conditional_edges("B", {"to_c": "C", "to_d": "D"})
        self.add_edge("C", END)
        self.add_edge("D", END)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
async def test_async_flow():
    """Run the async flow repeatedly and verify both random branches occur.

    Each run must visit A then B, then exactly one of C or D; over many
    runs both branches should be observed.
    """
    flow = SimpleAsyncFlow()

    seen_c = False
    seen_d = False
    final_state = None

    # Cap at 100 runs so a broken flow cannot loop forever; stop as soon
    # as both branches have been observed (the original always ran all
    # 100 iterations).
    for _ in range(100):
        state = AsyncFlowState()
        final_state = await flow.run(state)

        assert final_state.history[0] == "A"
        assert final_state.history[1] == "B"

        last_node = final_state.history[2]
        assert last_node in ["C", "D"]

        if last_node == "C":
            seen_c = True
        else:
            seen_d = True

        if seen_c and seen_d:
            break

    assert seen_c, "Node C was never reached"
    assert seen_d, "Node D was never reached"
    print(f"Async flow history: {final_state.history}")

if __name__ == "__main__":
    asyncio.run(test_async_flow())
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from typing import List
|
|
3
|
+
from pydantic import BaseModel, Field
|
|
4
|
+
from state_flow.nodes import Node, START, END
|
|
5
|
+
from state_flow.flows import StateFlow
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class FlowState(BaseModel):
    # Ordered record of node names visited during a flow run; each node
    # appends its own label in exec().
    history: List[str] = Field(default_factory=list)
|
|
11
|
+
|
|
12
|
+
class NodeA(Node):
    """Entry node: records its visit in the shared history."""

    def exec(self, state: FlowState) -> None:
        state.history.append("A")
|
|
15
|
+
|
|
16
|
+
class NodeB(Node):
    """Branching node: records its visit and randomly routes to C or D."""

    def exec(self, state: FlowState) -> None:
        state.history.append("B")
        # Pick the outgoing conditional edge at random.
        self.result = random.choice(("to_c", "to_d"))
|
|
20
|
+
|
|
21
|
+
class NodeC(Node):
    """Terminal node C: records its visit in the shared history."""

    def exec(self, state: FlowState) -> None:
        state.history.append("C")
|
|
24
|
+
|
|
25
|
+
class NodeD(Node):
    """Terminal node D: records its visit in the shared history."""

    def exec(self, state: FlowState) -> None:
        state.history.append("D")
|
|
28
|
+
|
|
29
|
+
class RandomFlow(StateFlow[FlowState]):
    """Flow A -> B, then randomly B -> C or B -> D; both branches end."""

    # NOTE(review): this overrides ``setup_graph`` while the async test
    # flow overrides ``_setup_graph`` — confirm which hook the base
    # classes actually invoke.
    def setup_graph(self) -> None:
        # Register every node, then wire the fixed and conditional edges.
        for label, node in (
            ("A", NodeA()),
            ("B", NodeB()),
            ("C", NodeC()),
            ("D", NodeD()),
        ):
            self.add_node(label, node)

        self.add_edge(START, "A")
        self.add_edge("A", "B")
        self.add_conditional_edges("B", {"to_c": "C", "to_d": "D"})
        self.add_edge("C", END)
        self.add_edge("D", END)
|
|
41
|
+
|
|
42
|
+
def test_random_flow():
    """Run the flow repeatedly and verify both random branches occur.

    Each run must visit A then B, then exactly one of C or D; over many
    runs both branches should be observed.
    """
    flow = RandomFlow()

    seen_c = False
    seen_d = False

    # Run up to 100 times to avoid infinite loop if something is wrong,
    # but stop as soon as both branches have been observed (the original
    # always ran all 100 iterations and printed the state every pass —
    # leftover debug output, removed here).
    for _ in range(100):
        state = FlowState()
        final_state = flow.run(state)

        assert final_state.history[0] == "A"
        assert final_state.history[1] == "B"

        last_node = final_state.history[2]
        assert last_node in ["C", "D"]

        if last_node == "C":
            seen_c = True
        else:
            seen_d = True

        if seen_c and seen_d:
            break

    assert seen_c, "Node C was never reached"
    assert seen_d, "Node D was never reached"
    print("Test passed: Both nodes C and D were reached randomly.")

if __name__ == "__main__":
    test_random_flow()
|