soprano-sdk 0.1.93__py3-none-any.whl → 0.1.95__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- soprano_sdk/__init__.py +10 -0
- soprano_sdk/agents/__init__.py +30 -0
- soprano_sdk/agents/adaptor.py +91 -0
- soprano_sdk/agents/factory.py +228 -0
- soprano_sdk/agents/structured_output.py +97 -0
- soprano_sdk/core/__init__.py +0 -0
- soprano_sdk/core/constants.py +59 -0
- soprano_sdk/core/engine.py +225 -0
- soprano_sdk/core/rollback_strategies.py +259 -0
- soprano_sdk/core/state.py +71 -0
- soprano_sdk/engine.py +381 -0
- soprano_sdk/nodes/__init__.py +0 -0
- soprano_sdk/nodes/base.py +57 -0
- soprano_sdk/nodes/call_function.py +108 -0
- soprano_sdk/nodes/collect_input.py +526 -0
- soprano_sdk/nodes/factory.py +46 -0
- soprano_sdk/routing/__init__.py +0 -0
- soprano_sdk/routing/router.py +97 -0
- soprano_sdk/tools.py +219 -0
- soprano_sdk/utils/__init__.py +0 -0
- soprano_sdk/utils/data.py +1 -0
- soprano_sdk/utils/function.py +35 -0
- soprano_sdk/utils/logger.py +6 -0
- soprano_sdk/utils/template.py +27 -0
- soprano_sdk/utils/tool.py +60 -0
- soprano_sdk/utils/tracing.py +71 -0
- soprano_sdk/validation/__init__.py +13 -0
- soprano_sdk/validation/schema.py +302 -0
- soprano_sdk/validation/validator.py +173 -0
- {soprano_sdk-0.1.93.dist-info → soprano_sdk-0.1.95.dist-info}/METADATA +1 -1
- soprano_sdk-0.1.95.dist-info/RECORD +33 -0
- soprano_sdk-0.1.93.dist-info/RECORD +0 -4
- {soprano_sdk-0.1.93.dist-info → soprano_sdk-0.1.95.dist-info}/WHEEL +0 -0
- {soprano_sdk-0.1.93.dist-info → soprano_sdk-0.1.95.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
from typing import Optional, Dict, Any, Tuple
|
|
2
|
+
|
|
3
|
+
import yaml
|
|
4
|
+
from jinja2 import Environment
|
|
5
|
+
from langgraph.checkpoint.memory import InMemorySaver
|
|
6
|
+
from langgraph.constants import START
|
|
7
|
+
from langgraph.graph import StateGraph
|
|
8
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
9
|
+
|
|
10
|
+
from .constants import WorkflowKeys
|
|
11
|
+
from .state import create_state_model
|
|
12
|
+
from ..nodes.factory import NodeFactory
|
|
13
|
+
from ..routing.router import WorkflowRouter
|
|
14
|
+
from ..utils.function import FunctionRepository
|
|
15
|
+
from ..utils.logger import logger
|
|
16
|
+
from ..utils.tool import ToolRepository
|
|
17
|
+
from ..validation import validate_workflow
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class WorkflowEngine:
    """Loads a YAML workflow definition and compiles it into a LangGraph graph.

    Responsibilities:
      - parse and validate the workflow YAML
      - build the typed state model for the declared data fields
      - construct nodes/routing edges and compile the executable graph
      - expose introspection helpers (steps, outcomes, messages, context)
    """

    def __init__(self, yaml_path: str, configs: dict):
        """Load and validate the workflow at *yaml_path*.

        Args:
            yaml_path: Path to the workflow YAML file.
            configs: Runtime overrides consulted first by get_config_value.
                May be None; treated as an empty mapping.

        Raises:
            Any parsing/validation error (OSError, YAML errors, validation
            failures) propagates unchanged.
        """
        self.yaml_path = yaml_path
        # Normalize None to {}: load_workflow defaults config to None, and
        # get_config_value would otherwise crash on None.get(...).
        self.configs = configs or {}
        logger.info(f"Loading workflow from: {yaml_path}")

        # NOTE: the previous try/except here only re-raised the exception
        # unchanged, so it has been dropped; behavior is identical.
        with open(yaml_path, 'r') as f:
            self.config = yaml.safe_load(f)

        logger.info("Validating workflow configuration")
        validate_workflow(self.config)

        self.workflow_name = self.config['name']
        self.workflow_description = self.config['description']
        self.workflow_version = self.config['version']
        self.data_fields = self.config['data']
        self.steps = self.config['steps']
        self.outcomes = self.config['outcomes']
        self.metadata = self.config.get('metadata', {})

        # Dynamically-built TypedDict describing the workflow state.
        self.StateType = create_state_model(self.data_fields)

        self.step_map = {step['id']: step for step in self.steps}
        self.outcome_map = {outcome['id']: outcome for outcome in self.outcomes}

        self.function_repository = FunctionRepository()
        self.tool_repository = None
        if tool_config := self.config.get("tool_config"):
            self.tool_repository = ToolRepository(tool_config)

        self.context_store = {}
        self.collect_input_fields = self._get_collect_input_fields()

        logger.info(
            f"Workflow loaded: {self.workflow_name} v{self.workflow_version} "
            f"({len(self.steps)} steps, {len(self.outcomes)} outcomes)"
        )

    def get_config_value(self, key, default_value: Optional[Any] = None):
        """Look up *key* in runtime configs first, then in the workflow YAML.

        A stored value of None counts as missing; other falsy values
        (0, False, "") are returned as-is rather than skipped.
        """
        for source in (self.configs, self.config):
            value = source.get(key)
            if value is not None:
                return value
        return default_value

    def _get_collect_input_fields(self) -> set:
        """Names of all fields populated by collect_input_with_agent steps."""
        fields = set()
        for step in self.steps:
            if step.get('action') == 'collect_input_with_agent' and (field := step.get('field')):
                fields.add(field)
        return fields

    def update_context(self, context: Dict[str, Any]):
        """Merge *context* into the engine-level context store."""
        self.context_store.update(context)
        logger.info(f"Context updated: {context}")

    def remove_context_field(self, field_name: str):
        """Delete *field_name* from the context store if present."""
        if field_name in self.context_store:
            del self.context_store[field_name]
            logger.info(f"Removed context field: {field_name}")

    def get_context_value(self, field_name: str):
        """Return the stored context value for *field_name*, or None."""
        value = self.context_store.get(field_name, None)
        if value is not None:
            logger.info(f"Retrieved context value for '{field_name}': {value}")
        return value

    def build_graph(self, checkpointer=None):
        """Compile the workflow into an executable LangGraph graph.

        Args:
            checkpointer: Optional LangGraph checkpointer for state
                persistence; defaults to InMemorySaver.

        Returns:
            The compiled graph.

        Raises:
            RuntimeError: If any part of graph construction fails (original
                exception attached via __cause__).
        """
        logger.info("Building workflow graph")

        try:
            builder = StateGraph(self.StateType)

            collector_nodes = []

            logger.info("Adding nodes to graph")
            for step in self.steps:
                step_id = step['id']
                action = step['action']

                if action == 'collect_input_with_agent':
                    collector_nodes.append(step_id)

                node_fn = NodeFactory.create(step, engine_context=self)
                builder.add_node(step_id, node_fn)

                logger.info(f"Added node: {step_id} (action: {action})")

            # The first step in YAML order is the entry point.
            first_step_id = self.steps[0]['id']
            builder.add_edge(START, first_step_id)
            logger.info(f"Set entry point: {first_step_id}")

            logger.info("Adding routing edges")
            for step in self.steps:
                step_id = step['id']

                router = WorkflowRouter(step, self.step_map, self.outcome_map)
                route_fn = router.create_route_function()
                routing_map = router.get_routing_map(collector_nodes)

                builder.add_conditional_edges(step_id, route_fn, routing_map)

                logger.info(
                    f"Added routing for {step_id}: {len(routing_map)} destinations"
                )

            if checkpointer is None:
                checkpointer = InMemorySaver()
                logger.info("Using InMemorySaver for state persistence")
            else:
                logger.info(f"Using custom checkpointer: {type(checkpointer).__name__}")

            graph = builder.compile(checkpointer=checkpointer)

            logger.info("Workflow graph built successfully")
            return graph

        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError(f"Failed to build workflow graph: {e}") from e

    def get_outcome_message(self, state: Dict[str, Any]) -> str:
        """Render the user-facing message for a finished workflow state.

        Resolution order: configured outcome message (rendered as a Jinja
        template against the state) -> state error -> messages value ->
        generic fallback string.
        """
        outcome_id = state.get(WorkflowKeys.OUTCOME_ID)
        step_id = state.get(WorkflowKeys.STEP_ID)

        outcome = self.outcome_map.get(outcome_id)
        if outcome and 'message' in outcome:
            message = outcome['message']
            template_loader = self.get_config_value("template_loader", Environment())
            message = template_loader.from_string(message).render(state)
            logger.info(f"Outcome message generated in step {step_id}: {message}")
            return message

        if error := state.get("error"):
            logger.info(f"Outcome error found in step {step_id}: {error}")
            return f"{error}"

        if message := state.get(WorkflowKeys.MESSAGES):
            logger.info(f"Outcome message found in step {step_id}: {message}")
            return f"{message}"

        logger.error(f"No outcome message found in step {step_id}")
        return "{'error': 'Unable to complete the request'}"

    def get_step_info(self, step_id: str) -> Optional[Dict[str, Any]]:
        """Step definition for *step_id*, or None if unknown."""
        return self.step_map.get(step_id)

    def get_outcome_info(self, outcome_id: str) -> Optional[Dict[str, Any]]:
        """Outcome definition for *outcome_id*, or None if unknown."""
        return self.outcome_map.get(outcome_id)

    def list_steps(self) -> list:
        """IDs of all steps, in YAML order."""
        return [step['id'] for step in self.steps]

    def list_outcomes(self) -> list:
        """IDs of all outcomes, in YAML order."""
        return [outcome['id'] for outcome in self.outcomes]

    def get_workflow_info(self) -> Dict[str, Any]:
        """Summary metadata about the loaded workflow."""
        return {
            'name': self.workflow_name,
            'description': self.workflow_description,
            'version': self.workflow_version,
            'steps': len(self.steps),
            'outcomes': len(self.outcomes),
            'data_fields': [f['name'] for f in self.data_fields],
            'metadata': self.metadata
        }

    def get_tool_policy(self) -> str:
        """Usage policy string from the YAML tool_config section.

        Raises:
            ValueError: If the YAML has no tool_config section.
        """
        tool_config = self.config.get('tool_config')
        if not tool_config:
            raise ValueError("Tool config is not provided in the YAML")
        return tool_config.get('usage_policy')
197
|
+
|
|
198
|
+
|
|
199
|
+
def load_workflow(yaml_path: str, checkpointer=None, config=None) -> Tuple[CompiledStateGraph, WorkflowEngine]:
    """
    Load a workflow from YAML configuration.

    This is the main entry point for using the framework.

    Args:
        yaml_path: Path to the workflow YAML file
        checkpointer: Optional checkpointer for state persistence.
            Defaults to InMemorySaver() if not provided.
            Example: MongoDBSaver for production persistence.
        config: Optional runtime config overrides consulted by
            WorkflowEngine.get_config_value. Defaults to an empty mapping.

    Returns:
        Tuple of (compiled_graph, engine) where:
        - compiled_graph: LangGraph ready for execution
        - engine: WorkflowEngine instance for introspection

    Example:
        ```python
        graph, engine = load_workflow("workflow.yaml")
        result = graph.invoke({}, config={"configurable": {"thread_id": "123"}})
        message = engine.get_outcome_message(result)
        ```
    """
    # Normalize config so the engine never receives None; the previous
    # pass-through made get_config_value fail with AttributeError when the
    # caller relied on the default.
    engine = WorkflowEngine(yaml_path, configs=config or {})
    graph = engine.build_graph(checkpointer=checkpointer)
    return graph, engine
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import logging
|
|
3
|
+
import uuid
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Dict, Any, List
|
|
7
|
+
|
|
8
|
+
from soprano_sdk.core.constants import WorkflowKeys, ActionType
|
|
9
|
+
from ..utils.logger import logger
|
|
10
|
+
|
|
11
|
+
class RollbackStrategy(ABC):
    """Contract for workflow rollback implementations.

    A strategy decides whether per-node snapshots are needed, records them
    when they are, and rewinds a workflow state so a target node can run
    again.
    """

    @abstractmethod
    def rollback_to_node(
        self,
        state: Dict[str, Any],
        target_node: str,
        node_execution_order: List[str],
        node_field_map: Dict[str, str],
        workflow_steps: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Return a state rewound so that *target_node* executes next."""

    @abstractmethod
    def should_save_snapshot(self) -> bool:
        """Whether the engine should call save_snapshot before each node."""

    @abstractmethod
    def save_snapshot(self, state: Dict[str, Any], node_id: str, execution_index: int) -> None:
        """Record *state* just before *node_id* runs, if the strategy needs it."""

    @abstractmethod
    def get_strategy_name(self) -> str:
        """Short identifier for this strategy (for logs/config)."""
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _restore_from_snapshot(snapshot: Dict[str, Any]) -> Dict[str, Any]:
|
|
37
|
+
return copy.deepcopy(snapshot['state'])
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _clear_future_executions(
    state: Dict[str, Any],
    target_node: str,
    workflow_steps: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Reset every field produced at or after *target_node*.

    Collected fields are set back to None and their stored conversations are
    dropped; computed (call_function) outputs are set back to None. The
    mutated *state* is returned.
    """
    matches = [i for i, step in enumerate(workflow_steps) if step['id'] == target_node]

    if not matches:
        logger.warning(f"Target node {target_node} not found in workflow steps")
        return state

    # Everything from the target step onward (inclusive) gets cleared.
    future_steps = workflow_steps[matches[0]:]

    logger.info(f"Future steps to clear: {[s['id'] for s in future_steps]}")

    for step in future_steps:
        action = step.get('action')

        if action == ActionType.COLLECT_INPUT_WITH_AGENT.value:
            if field_name := step.get('field'):
                state[field_name] = None

                conv_key = f"{field_name}_conversation"
                conversations = state.get(WorkflowKeys.CONVERSATIONS, {})
                if conv_key in conversations:
                    del conversations[conv_key]
                    logger.info(f"Cleared conversation: {conv_key}")

        elif action == ActionType.CALL_FUNCTION.value:
            if output_field := step.get('output'):
                state[output_field] = None
                logger.info(f"Cleared computed field: {output_field}")

    return state
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class HistoryBasedRollback(RollbackStrategy):
    """Rollback strategy that snapshots the state before every node runs.

    Rollback restores the snapshot taken just before the target node's
    first execution, clears all downstream results, and truncates the
    snapshot history.
    """

    def get_strategy_name(self) -> str:
        """Identifier used in logs/config."""
        return "history_based"

    def should_save_snapshot(self) -> bool:
        """Snapshots are required for this strategy."""
        return True

    def save_snapshot(self, state: Dict[str, Any], node_id: str, execution_index: int) -> None:
        """Append a snapshot of *state* to its own STATE_HISTORY list.

        The STATE_HISTORY key itself is excluded from the copied state:
        deep-copying it would embed every earlier snapshot (and their own
        embedded copies) into each new snapshot, making memory grow
        explosively with workflow length. rollback_to_node reinstates the
        history explicitly, so the key is never needed inside a snapshot.
        """
        state_history = state.get(WorkflowKeys.STATE_HISTORY, [])

        snapshot = {
            'snapshot_id': str(uuid.uuid4()),
            'node_about_to_execute': node_id,
            'execution_index': execution_index,
            'timestamp': datetime.now().isoformat(),
            'state': {
                key: copy.deepcopy(value)
                for key, value in state.items()
                if key != WorkflowKeys.STATE_HISTORY
            },
        }

        state_history.append(snapshot)
        state[WorkflowKeys.STATE_HISTORY] = state_history

        logger.info(f"Saved snapshot #{len(state_history)-1} before executing {node_id}")

    def rollback_to_node(
        self,
        state: Dict[str, Any],
        target_node: str,
        node_execution_order: List[str],
        node_field_map: Dict[str, str],
        workflow_steps: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Restore the snapshot taken just before *target_node* first ran.

        Returns {} when no history or no matching snapshot exists; callers
        should treat that as "rollback unavailable".
        """
        state_history = state.get(WorkflowKeys.STATE_HISTORY, [])

        if not state_history:
            logger.warning("No state history available for rollback")
            return {}

        logger.info(f"Looking for snapshot before node '{target_node}'")

        target_snapshot = None
        target_index = None

        # First (earliest) snapshot recorded for the target node wins.
        for i, snapshot in enumerate(state_history):
            if snapshot.get('node_about_to_execute') == target_node:
                target_snapshot = snapshot
                target_index = i
                break

        if target_snapshot is None:
            logger.warning(f"No snapshot found before node '{target_node}'")
            return {}

        logger.info(f"Found snapshot at index {target_index}")
        restored_state = _restore_from_snapshot(target_snapshot)

        restored_state = _clear_future_executions(
            restored_state,
            target_node,
            workflow_steps
        )

        # Reinstate the truncated history explicitly; snapshots themselves do
        # not carry the STATE_HISTORY key (see save_snapshot).
        restored_state[WorkflowKeys.STATE_HISTORY] = state_history[:target_index + 1]

        logger.info(f"Successfully rolled back to {target_node}")

        return restored_state
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _build_dependency_graph(
|
|
150
|
+
workflow_steps: List[Dict[str, Any]]
|
|
151
|
+
) -> Dict[str, List[str]]:
|
|
152
|
+
graph = {}
|
|
153
|
+
|
|
154
|
+
for step in workflow_steps:
|
|
155
|
+
field = step.get('field') or step.get('output')
|
|
156
|
+
|
|
157
|
+
if not field:
|
|
158
|
+
continue
|
|
159
|
+
|
|
160
|
+
depends_on = step.get('depends_on')
|
|
161
|
+
|
|
162
|
+
if depends_on:
|
|
163
|
+
if isinstance(depends_on, str):
|
|
164
|
+
depends_on_list = [depends_on]
|
|
165
|
+
elif isinstance(depends_on, list):
|
|
166
|
+
depends_on_list = depends_on
|
|
167
|
+
else:
|
|
168
|
+
logger.warning(f"Invalid depends_on type for field '{field}': {type(depends_on)}")
|
|
169
|
+
depends_on_list = []
|
|
170
|
+
|
|
171
|
+
for parent_field in depends_on_list:
|
|
172
|
+
if parent_field not in graph:
|
|
173
|
+
graph[parent_field] = []
|
|
174
|
+
if field not in graph[parent_field]:
|
|
175
|
+
graph[parent_field].append(field)
|
|
176
|
+
|
|
177
|
+
if field not in graph:
|
|
178
|
+
graph[field] = []
|
|
179
|
+
|
|
180
|
+
return graph
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _find_all_dependents(
|
|
184
|
+
field: str,
|
|
185
|
+
dependency_graph: Dict[str, List[str]]
|
|
186
|
+
) -> set:
|
|
187
|
+
all_dependents = set()
|
|
188
|
+
visited = set()
|
|
189
|
+
|
|
190
|
+
def _recurse(current_field: str):
|
|
191
|
+
if current_field in visited:
|
|
192
|
+
return
|
|
193
|
+
visited.add(current_field)
|
|
194
|
+
|
|
195
|
+
direct_dependents = dependency_graph.get(current_field, [])
|
|
196
|
+
|
|
197
|
+
for dependent in direct_dependents:
|
|
198
|
+
all_dependents.add(dependent)
|
|
199
|
+
_recurse(dependent)
|
|
200
|
+
|
|
201
|
+
_recurse(field)
|
|
202
|
+
|
|
203
|
+
return all_dependents
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _clear_field_conversation(state: Dict[str, Any], field: str) -> None:
    """Drop the stored conversation history for *field*, if one exists."""
    conversations = state.get(WorkflowKeys.CONVERSATIONS, {})
    conv_key = f"{field}_conversation"

    if conv_key not in conversations:
        return

    del conversations[conv_key]
    logger.info(f"Cleared conversation: {conv_key}")
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class DependencyBasedRollback(RollbackStrategy):
    """Rollback strategy driven by declared field dependencies.

    Instead of restoring snapshots, it nulls the target node's field and
    every field transitively derived from it, so those steps re-run.
    """

    def get_strategy_name(self) -> str:
        """Identifier used in logs/config."""
        return "dependency_based"

    def should_save_snapshot(self) -> bool:
        # No snapshots needed: rollback is recomputed from the graph.
        return False

    def save_snapshot(self, state: Dict[str, Any], node_id: str, execution_index: int) -> None:
        # Deliberate no-op; see should_save_snapshot.
        return None

    def rollback_to_node(
        self,
        state: Dict[str, Any],
        target_node: str,
        node_execution_order: List[str],
        node_field_map: Dict[str, str],
        workflow_steps: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Null out the target node's field and all of its dependents.

        Returns *state* unchanged when the node maps to no field.
        """
        target_field = node_field_map.get(target_node)

        if not target_field:
            logger.warning(f"No field found for target node '{target_node}'")
            return state

        logger.info(f"Rolling back to node '{target_node}' (field: '{target_field}')")

        dependency_graph = _build_dependency_graph(workflow_steps)
        logger.info(f"Dependency graph: {dependency_graph}")

        dependent_fields = _find_all_dependents(target_field, dependency_graph)
        logger.info(f"Fields dependent on '{target_field}': {dependent_fields}")

        # Clear the target first, then everything derived from it.
        state[target_field] = None
        _clear_field_conversation(state, target_field)

        for field in dependent_fields:
            state[field] = None
            _clear_field_conversation(state, field)
            logger.info(f"Cleared dependent field: {field}")

        logger.info(f"Successfully rolled back to {target_node} using dependency graph")

        return state
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import types
|
|
2
|
+
from typing import Annotated, Optional, Dict, List, Any, Type
|
|
3
|
+
|
|
4
|
+
from typing_extensions import TypedDict
|
|
5
|
+
|
|
6
|
+
from .constants import DataType
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def replace(_, right):
    """State reducer: discard the current value and keep the update."""
    return right
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_state_value(state: Dict[str, Any], key: str, default: Any = None) -> Any:
    """Read *key* from the workflow state, falling back to *default*."""
    if key in state:
        return state[key]
    return default
|
|
15
|
+
|
|
16
|
+
def set_state_value(state: Dict[str, Any], key: str, value: Any) -> None:
    """Store *value* under *key* in the workflow state (mutates in place)."""
    state.update({key: value})
|
|
18
|
+
|
|
19
|
+
def create_state_model(data_fields: List[dict]):
    """Dynamically build the TypedDict state model for a workflow.

    Each entry in *data_fields* ({'name': ..., 'type': ...}) becomes an
    annotated optional field; unknown types fall back to Optional[str].
    Engine bookkeeping fields (underscore-prefixed) and 'error' are always
    added. Every annotation uses the `replace` reducer (last write wins).
    """
    type_mapping = {
        DataType.TEXT.value: Optional[str],
        DataType.NUMBER.value: Optional[int],
        DataType.DOUBLE.value: Optional[float],
        DataType.BOOLEAN.value: Optional[bool],
        DataType.LIST.value: Optional[List[Any]],
        DataType.DICT.value: Optional[Dict[str, Any]],
        DataType.ANY.value: Optional[Any]
    }

    # User-declared data fields.
    annotations = {
        field_def['name']: Annotated[type_mapping.get(field_def['type'], Optional[str]), replace]
        for field_def in data_fields
    }

    # Engine bookkeeping fields.
    annotations['_step_id'] = Annotated[Optional[str], replace]
    annotations['_status'] = Annotated[Optional[str], replace]
    annotations['_outcome_id'] = Annotated[Optional[str], replace]

    annotations['_messages'] = Annotated[List[str], replace]
    annotations['_conversations'] = Annotated[Dict[str, List[Dict[str, str]]], replace]
    annotations['_state_history'] = Annotated[List[Dict[str, Any]], replace]
    annotations['_collector_nodes'] = Annotated[Dict[str, str], replace]
    annotations['_attempt_counts'] = Annotated[Dict[str, int], replace]
    annotations['_node_execution_order'] = Annotated[List[str], replace]
    annotations['_node_field_map'] = Annotated[Dict[str, str], replace]
    annotations['_computed_fields'] = Annotated[List[str], replace]
    annotations['error'] = Annotated[Optional[Dict[str, str]], replace]

    return types.new_class(
        'WorkflowState',
        (TypedDict,),
        {},
        lambda ns: ns.update({'__annotations__': annotations}),
    )
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def initialize_state(state: Dict[str, Any]) -> Dict[str, Any]:
    """Ensure every engine bookkeeping key exists in *state*.

    Keys that are missing or falsy get a fresh default; existing truthy
    values are left untouched. Mutates and returns *state*.
    """
    defaults = {
        '_state_history': [],
        '_collector_nodes': {},
        '_conversations': {},
        '_messages': [],
        '_attempt_counts': {},
        '_node_execution_order': [],
        '_node_field_map': {},
        '_computed_fields': [],
        'error': None
    }

    for key, default_value in defaults.items():
        if not state.get(key):
            # Copy mutable defaults so distinct states never share containers.
            state[key] = default_value.copy() if isinstance(default_value, (list, dict)) else default_value

    return state
|