soprano-sdk 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,125 @@
1
+ from enum import Enum
2
+ from typing import Optional
3
+ from pydantic import Field
4
+ from pydantic_settings import BaseSettings, SettingsConfigDict
5
+
6
+
7
class WorkflowKeys:
    """Reserved state-dict keys used for internal workflow bookkeeping.

    All keys except ERROR carry a leading underscore so they cannot collide
    with user-declared data field names.
    """

    STEP_ID = "_step_id"
    STATUS = "_status"
    OUTCOME_ID = "_outcome_id"
    MESSAGES = "_messages"
    CONVERSATIONS = "_conversations"
    STATE_HISTORY = "_state_history"
    COLLECTOR_NODES = "_collector_nodes"
    ATTEMPT_COUNTS = "_attempt_counts"
    NODE_EXECUTION_ORDER = "_node_execution_order"
    NODE_FIELD_MAP = "_node_field_map"
    COMPUTED_FIELDS = "_computed_fields"
    ERROR = "error"
20
+
21
+
22
class ActionType(Enum):
    """The kinds of action a workflow step may declare."""

    COLLECT_INPUT_WITH_AGENT = "collect_input_with_agent"
    CALL_FUNCTION = "call_function"
    CALL_ASYNC_FUNCTION = "call_async_function"
26
+
27
+
28
class InterruptType:
    """Sentinel marker strings identifying why a running workflow paused."""

    # Pause that waits for a human reply.
    USER_INPUT = "__WORKFLOW_INTERRUPT__"
    # Pause that waits for an asynchronous function to complete.
    ASYNC = "__ASYNC_INTERRUPT__"
32
+
33
+
34
class DataType(Enum):
    """Primitive types a workflow data field may declare in YAML."""

    TEXT = "text"
    NUMBER = "number"
    DOUBLE = "double"
    BOOLEAN = "boolean"
    LIST = "list"
    DICT = "dict"
    ANY = "any"
42
+
43
+
44
class OutcomeType(Enum):
    """Terminal classification for a workflow outcome."""

    SUCCESS = "success"
    FAILURE = "failure"
47
+
48
+
49
class StatusPattern:
    """`str.format` templates used to compose workflow status identifiers."""

    COLLECTING = "{step_id}_collecting"
    MAX_ATTEMPTS = "{step_id}_max_attempts"
    NEXT_STEP = "{step_id}_{next_step}"
    SUCCESS = "{step_id}_success"
    FAILED = "{step_id}_failed"
    INTENT_CHANGE = "{step_id}_{target_node}"
56
+
57
+
58
class TransitionPattern:
    """Prefix templates that mark transition signals emitted by nodes."""

    CAPTURED = "{field}_CAPTURED:"
    FAILED = "{field}_FAILED:"
    INTENT_CHANGE = "INTENT_CHANGE:"
62
+
63
+
64
# Fallback number of attempts for a collection step before giving up.
DEFAULT_MAX_ATTEMPTS = 3
# Default LLM model identifier.
DEFAULT_MODEL = 'gpt-4o-mini'
# Default timeout value — presumably seconds; confirm at call sites.
DEFAULT_TIMEOUT = 300

# User-facing template; '{field}' is replaced with the field being collected.
MAX_ATTEMPTS_MESSAGE = "I'm having trouble understanding your {field}. Please contact customer service for assistance."
WORKFLOW_COMPLETE_MESSAGE = "Workflow completed."
70
+
71
+
72
class MFAConfig(BaseSettings):
    """
    Configuration for MFA REST API endpoints.

    Values can be provided during initialization or will be automatically
    loaded from environment variables with the same name (uppercase).

    Example:
        # Load from environment variables
        config = MFAConfig()

        # Or provide specific values
        config = MFAConfig(
            generate_token_base_url="https://api.example.com",
            generate_token_path="/v1/mfa/generate"
        )
    """
    # Endpoint that issues a new MFA token (base URL + path kept separate).
    generate_token_base_url: Optional[str] = Field(
        default=None,
        description="Base URL for the generate token endpoint"
    )
    generate_token_path: Optional[str] = Field(
        default=None,
        description="Path for the generate token endpoint"
    )
    # Endpoint that checks a token entered by the user.
    validate_token_base_url: Optional[str] = Field(
        default=None,
        description="Base URL for the validate token endpoint"
    )
    validate_token_path: Optional[str] = Field(
        default=None,
        description="Path for the validate token endpoint"
    )
    # Endpoint that authorizes a validated token.
    authorize_token_base_url: Optional[str] = Field(
        default=None,
        description="Base URL for the authorize token endpoint"
    )
    authorize_token_path: Optional[str] = Field(
        default=None,
        description="Path for the authorize token endpoint"
    )
    api_timeout: int = Field(
        default=30,
        description="API request timeout in seconds"
    )
    mfa_cancelled_message: str = Field(
        default="Authentication has been cancelled.",
        description="Message to display when user cancels MFA authentication"
    )

    # Env-var names are matched case-insensitively; unknown env keys ignored.
    model_config = SettingsConfigDict(
        case_sensitive=False,
        extra='ignore'
    )
@@ -0,0 +1,315 @@
1
+ from typing import Optional, Dict, Any, Tuple
2
+
3
+ import yaml
4
+ from jinja2 import Environment
5
+ from langgraph.checkpoint.memory import InMemorySaver
6
+ from langgraph.constants import START
7
+ from langgraph.graph import StateGraph
8
+ from langgraph.graph.state import CompiledStateGraph
9
+
10
+ from .constants import WorkflowKeys, MFAConfig
11
+ from .state import create_state_model
12
+ from ..nodes.factory import NodeFactory
13
+ from ..routing.router import WorkflowRouter
14
+ from ..utils.function import FunctionRepository
15
+ from ..utils.logger import logger
16
+ from ..utils.tool import ToolRepository
17
+ from ..validation import validate_workflow
18
+ from soprano_sdk.authenticators.mfa import MFANodeConfig
19
+
20
class WorkflowEngine:
    """Loads a YAML workflow definition and compiles it into a LangGraph graph.

    Responsibilities:
      * parse and validate the YAML file,
      * expand MFA-protected steps into start/validator sub-steps,
      * build the typed state model and the routed ``StateGraph``,
      * expose introspection helpers (steps, outcomes, metadata) and a
        per-engine context store that nodes consult at runtime.
    """

    def __init__(self, yaml_path: str, configs: dict, mfa_config: Optional[MFAConfig] = None):
        """Parse, validate and index the workflow at *yaml_path*.

        Args:
            yaml_path: Path to the workflow YAML file.
            configs: Runtime configuration overrides (may be None).
            mfa_config: Optional MFA settings; when omitted, values are read
                from environment variables via ``MFAConfig()``.

        Raises:
            OSError / yaml.YAMLError: if the file cannot be read or parsed.
            Whatever ``validate_workflow`` raises for an invalid definition.
        """
        self.yaml_path = yaml_path
        self.configs = configs or {}
        logger.info(f"Loading workflow from: {yaml_path}")

        with open(yaml_path, 'r') as f:
            self.config = yaml.safe_load(f)

        logger.info("Validating workflow configuration")
        validate_workflow(self.config, mfa_config=mfa_config or MFAConfig())

        self.workflow_name = self.config['name']
        self.workflow_description = self.config['description']
        self.workflow_version = self.config['version']

        # Filled by load_steps() with the ids of auto-generated MFA validator
        # steps; must exist before load_steps() runs.
        self.mfa_validator_steps: set[str] = set()
        self.steps: list = self.load_steps()
        self.step_map = {step['id']: step for step in self.steps}

        # Only retain MFA settings when the workflow actually uses MFA steps.
        self.mfa_config = (mfa_config or MFAConfig()) if self.mfa_validator_steps else None

        self.data_fields = self.load_data()
        self.outcomes = self.load_outcomes()
        self.metadata = self.config.get('metadata', {})

        self.StateType = create_state_model(self.data_fields)
        self.outcome_map = {outcome['id']: outcome for outcome in self.outcomes}

        self.function_repository = FunctionRepository()
        self.tool_repository = None
        if tool_config := self.config.get("tool_config"):
            self.tool_repository = ToolRepository(tool_config)

        # Arbitrary key/value context shared with nodes at runtime.
        self.context_store = {}
        self.collect_input_fields = self._get_collect_input_fields()

        logger.info(
            f"Workflow loaded: {self.workflow_name} v{self.workflow_version} "
            f"({len(self.steps)} steps, {len(self.outcomes)} outcomes)"
        )

    def get_config_value(self, key, default_value: Optional[Any] = None):
        """Resolve *key* from runtime configs first, then the workflow YAML.

        Falsy-but-set values (0, False, "") are honoured; an explicit None is
        treated as "unset" and falls through to the next source. (The previous
        truthiness check silently discarded legitimate falsy settings.)
        """
        if (value := self.configs.get(key)) is not None:
            return value

        if (value := self.config.get(key)) is not None:
            return value

        return default_value

    def _get_collect_input_fields(self) -> set:
        """Return the names of all fields gathered by agent-collection steps."""
        fields = set()
        for step in self.steps:
            if step.get('action') == 'collect_input_with_agent' and (field := step.get('field')):
                fields.add(field)
        return fields

    def update_context(self, context: Dict[str, Any]):
        """Merge *context* into the shared runtime context store."""
        self.context_store.update(context)
        logger.info(f"Context updated: {context}")

    def remove_context_field(self, field_name: str):
        """Delete *field_name* from the context store if present (no-op otherwise)."""
        if field_name in self.context_store:
            del self.context_store[field_name]
            logger.info(f"Removed context field: {field_name}")

    def get_context_value(self, field_name: str):
        """Return the context value for *field_name*, or None when absent.

        NOTE(review): the retrieved value is written to the log — confirm the
        context never carries sensitive data before relying on this in prod.
        """
        value = self.context_store.get(field_name, None)
        if value is not None:
            logger.info(f"Retrieved context value for '{field_name}': {value}")
        return value

    def build_graph(self, checkpointer=None):
        """Compile the workflow into an executable LangGraph.

        Args:
            checkpointer: Optional LangGraph checkpointer; defaults to an
                in-memory saver when not supplied.

        Returns:
            The compiled graph.

        Raises:
            RuntimeError: if any stage of graph construction fails (the
                original exception is chained as the cause).
        """
        logger.info("Building workflow graph")

        try:
            builder = StateGraph(self.StateType)

            # Collector nodes are tracked so the router can generate the
            # self-loop/interrupt destinations they require.
            collector_nodes = []

            logger.info("Adding nodes to graph")
            for step in self.steps:
                step_id = step['id']
                action = step['action']

                if action == 'collect_input_with_agent':
                    collector_nodes.append(step_id)

                node_fn = NodeFactory.create(step, engine_context=self)
                builder.add_node(step_id, node_fn)

                logger.info(f"Added node: {step_id} (action: {action})")

            # The first declared step is the entry point.
            first_step_id = self.steps[0]['id']
            builder.add_edge(START, first_step_id)
            logger.info(f"Set entry point: {first_step_id}")

            logger.info("Adding routing edges")
            for step in self.steps:
                step_id = step['id']

                router = WorkflowRouter(step, self.step_map, self.outcome_map)
                route_fn = router.create_route_function()
                routing_map = router.get_routing_map(collector_nodes)

                builder.add_conditional_edges(step_id, route_fn, routing_map)

                logger.info(
                    f"Added routing for {step_id}: {len(routing_map)} destinations"
                )

            if checkpointer is None:
                checkpointer = InMemorySaver()
                logger.info("Using InMemorySaver for state persistence")
            else:
                logger.info(f"Using custom checkpointer: {type(checkpointer).__name__}")

            graph = builder.compile(checkpointer=checkpointer)

            logger.info("Workflow graph built successfully")
            return graph

        except Exception as e:
            # Chain the cause so the underlying failure stays debuggable.
            raise RuntimeError(f"Failed to build workflow graph: {e}") from e

    def get_outcome_message(self, state: Dict[str, Any]) -> str:
        """Render the user-facing message for a finished workflow run.

        Resolution order: outcome template (rendered with Jinja2 against the
        final state) -> state['error'] -> accumulated messages -> generic
        error string.
        """
        outcome_id = state.get(WorkflowKeys.OUTCOME_ID)
        step_id = state.get(WorkflowKeys.STEP_ID)

        outcome = self.outcome_map.get(outcome_id)
        if outcome and 'message' in outcome:
            message = outcome['message']
            template_loader = self.get_config_value("template_loader", Environment())
            message = template_loader.from_string(message).render(state)
            logger.info(f"Outcome message generated in step {step_id}: {message}")
            return message

        if error := state.get("error"):
            logger.info(f"Outcome error found in step {step_id}: {error}")
            return f"{error}"

        if message := state.get(WorkflowKeys.MESSAGES):
            logger.info(f"Outcome message found in step {step_id}: {message}")
            return f"{message}"

        logger.error(f"No outcome message found in step {step_id}")
        return "{'error': 'Unable to complete the request'}"

    def get_step_info(self, step_id: str) -> Optional[Dict[str, Any]]:
        """Return the step definition for *step_id*, or None if unknown."""
        return self.step_map.get(step_id)

    def get_outcome_info(self, outcome_id: str) -> Optional[Dict[str, Any]]:
        """Return the outcome definition for *outcome_id*, or None if unknown."""
        return self.outcome_map.get(outcome_id)

    def list_steps(self) -> list:
        """Return all step ids in declaration order."""
        return [step['id'] for step in self.steps]

    def list_outcomes(self) -> list:
        """Return all outcome ids in declaration order."""
        return [outcome['id'] for outcome in self.outcomes]

    def get_workflow_info(self) -> Dict[str, Any]:
        """Return a summary dict describing the loaded workflow."""
        return {
            'name': self.workflow_name,
            'description': self.workflow_description,
            'version': self.workflow_version,
            'steps': len(self.steps),
            'outcomes': len(self.outcomes),
            'data_fields': [f['name'] for f in self.data_fields],
            'metadata': self.metadata
        }

    def get_tool_policy(self) -> str:
        """Return the tool usage policy from the YAML tool_config.

        Raises:
            ValueError: when the workflow declares no tool_config at all.
        """
        tool_config = self.config.get('tool_config')
        if not tool_config:
            raise ValueError("Tool config is not provided in the YAML")
        return tool_config.get('usage_policy')

    def load_steps(self):
        """Expand the raw YAML steps, inserting MFA sub-steps where requested.

        A step declaring an ``mfa`` block is expanded into:
            mfa_start (call function) -> mfa_data_collector (validate input)
            -> original step
        Any transition or ``next`` pointer elsewhere that targeted the
        original step is redirected to its mfa_start node so MFA runs first.
        """
        prepared_steps: list = []
        # Maps original step id -> id of the MFA start node that fronts it.
        mfa_redirects: Dict[str, str] = {}

        for step in self.config['steps']:
            step_id = step['id']

            if mfa_config := step.get('mfa'):
                mfa_data_collector = MFANodeConfig.get_validate_user_input(
                    next_node=step_id,
                    source_node=step_id,
                    mfa_config=mfa_config
                )
                mfa_start = MFANodeConfig.get_call_function_template(
                    source_node=step_id,
                    next_node=mfa_data_collector['id'],
                    mfa=mfa_config
                )

                prepared_steps.append(mfa_start)
                prepared_steps.append(mfa_data_collector)
                self.mfa_validator_steps.add(mfa_data_collector['id'])

                mfa_redirects[step_id] = mfa_start['id']

                del step['mfa']

            prepared_steps.append(step)

        # Second pass: reroute edges that pointed at MFA-protected steps.
        # Auto-generated MFA nodes already point at the right targets.
        for step in prepared_steps:
            if step['id'] in self.mfa_validator_steps:  # MFA Validator
                continue

            elif 'mfa' in step:  # MFA Start
                continue

            elif step.get('transitions'):
                for transition in step.get('transitions'):
                    next_step = transition.get('next')
                    if next_step in mfa_redirects:
                        transition['next'] = mfa_redirects[next_step]

            elif step.get('next') in mfa_redirects:
                step['next'] = mfa_redirects[step['next']]

        return prepared_steps

    def load_data(self):
        """Return the workflow's data fields plus one synthetic text field per
        auto-generated MFA validator step.

        Works on a copy so the parsed YAML config is not mutated. (Also fixes
        a nested same-quote f-string that was a SyntaxError before Python
        3.12, and the "Recieved" typo in the generated description.)
        """
        data: list = list(self.config['data'])
        for step_id in self.mfa_validator_steps:
            step_details = self.step_map[step_id]
            data.append(
                dict(
                    name=step_details['field'],
                    type='text',
                    description='Input Received from the user during MFA'
                )
            )
        return data

    def load_outcomes(self):
        """Return the workflow outcomes, adding an auto-generated
        ``mfa_cancelled`` failure outcome when MFA is in use.

        Works on a copy so the parsed YAML config is not mutated.
        """
        outcomes: list = list(self.config['outcomes'])

        if self.mfa_config:
            mfa_cancelled_outcome = {
                'id': 'mfa_cancelled',
                'type': 'failure',
                'message': self.mfa_config.mfa_cancelled_message
            }
            outcomes.append(mfa_cancelled_outcome)
            logger.info(f"Auto-generated 'mfa_cancelled' outcome with message: {self.mfa_config.mfa_cancelled_message}")

        return outcomes
275
+
276
+
277
def load_workflow(yaml_path: str, checkpointer=None, config=None, mfa_config: Optional[MFAConfig] = None) -> Tuple[CompiledStateGraph, WorkflowEngine]:
    """Build a runnable workflow from a YAML definition.

    This is the framework's main entry point: it constructs a
    WorkflowEngine from *yaml_path* and compiles its LangGraph.

    Args:
        yaml_path: Path to the workflow YAML file.
        checkpointer: Optional state-persistence backend; when omitted the
            engine falls back to InMemorySaver(). Pass e.g. a MongoDBSaver
            for durable production persistence.
        config: Optional runtime configuration dictionary.
        mfa_config: Optional MFA configuration; when omitted, settings are
            read from environment variables.

    Returns:
        A ``(compiled_graph, engine)`` pair — the LangGraph ready for
        execution plus the WorkflowEngine instance for introspection.

    Example:
        ```python
        # Load with environment variables
        graph, engine = load_workflow("workflow.yaml")

        # Or provide MFA configuration explicitly
        from soprano_sdk.core.constants import MFAConfig
        mfa_config = MFAConfig(
            generate_token_base_url="https://api.example.com",
            generate_token_path="/v1/mfa/generate"
        )
        graph, engine = load_workflow("workflow.yaml", mfa_config=mfa_config)

        result = graph.invoke({}, config={"configurable": {"thread_id": "123"}})
        message = engine.get_outcome_message(result)
        ```
    """
    engine = WorkflowEngine(yaml_path, configs=config, mfa_config=mfa_config)
    return engine.build_graph(checkpointer=checkpointer), engine
@@ -0,0 +1,258 @@
1
+ import copy
2
+ import uuid
3
+ from abc import ABC, abstractmethod
4
+ from datetime import datetime
5
+ from typing import Dict, Any, List
6
+
7
+ from soprano_sdk.core.constants import WorkflowKeys, ActionType
8
+ from ..utils.logger import logger
9
+
10
class RollbackStrategy(ABC):
    """Abstract interface for rewinding workflow state to an earlier node.

    Implementations choose the mechanism: restoring saved snapshots
    (HistoryBasedRollback) or clearing dependent fields
    (DependencyBasedRollback).
    """

    @abstractmethod
    def rollback_to_node(
        self,
        state: Dict[str, Any],
        target_node: str,
        node_execution_order: List[str],
        node_field_map: Dict[str, str],
        workflow_steps: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Rewind *state* so the workflow can re-enter *target_node*.

        Returns the state to continue from; implementations may return an
        empty dict to signal that rollback was not possible.
        """
        pass

    @abstractmethod
    def should_save_snapshot(self) -> bool:
        """Return True when the engine must call save_snapshot() before each node."""
        pass

    @abstractmethod
    def save_snapshot(self, state: Dict[str, Any], node_id: str, execution_index: int) -> None:
        """Persist whatever this strategy needs before *node_id* executes (may be a no-op)."""
        pass

    @abstractmethod
    def get_strategy_name(self) -> str:
        """Return a short identifier for this strategy (e.g. for logging/config)."""
        pass
33
+
34
+
35
+ def _restore_from_snapshot(snapshot: Dict[str, Any]) -> Dict[str, Any]:
36
+ return copy.deepcopy(snapshot['state'])
37
+
38
+
39
def _clear_future_executions(
    state: Dict[str, Any],
    target_node: str,
    workflow_steps: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Null out fields produced by *target_node* and every later step.

    Collector steps also get their per-field agent conversation removed.
    Returns *state* (mutated in place); unchanged when the target node is
    not found in *workflow_steps*.
    """
    target_step_index = None
    for idx, step in enumerate(workflow_steps):
        if step['id'] == target_node:
            target_step_index = idx
            break

    if target_step_index is None:
        logger.warning(f"Target node {target_node} not found in workflow steps")
        return state

    # The target step itself is included: its field must be re-collected.
    future_steps = workflow_steps[target_step_index:]
    logger.info(f"Future steps to clear: {[s['id'] for s in future_steps]}")

    for step in future_steps:
        action = step.get('action')

        if action == ActionType.COLLECT_INPUT_WITH_AGENT.value:
            field_name = step.get('field')
            if not field_name:
                continue
            state[field_name] = None

            conv_key = f"{field_name}_conversation"
            conversations = state.get(WorkflowKeys.CONVERSATIONS, {})
            if conv_key in conversations:
                del conversations[conv_key]
                logger.info(f"Cleared conversation: {conv_key}")

        elif action == ActionType.CALL_FUNCTION.value:
            output_field = step.get('output')
            if output_field:
                state[output_field] = None
                logger.info(f"Cleared computed field: {output_field}")

    return state
78
+
79
+
80
class HistoryBasedRollback(RollbackStrategy):
    """Rollback strategy that restores previously saved full-state snapshots.

    Before each node executes, save_snapshot() records a deep copy of the
    state; rollback_to_node() later restores the snapshot taken just before
    the target node first ran and discards everything after it.
    """

    def get_strategy_name(self) -> str:
        return "history_based"

    def should_save_snapshot(self) -> bool:
        # This strategy depends on snapshots, so the engine must record one
        # before every node execution.
        return True

    def save_snapshot(self, state: Dict[str, Any], node_id: str, execution_index: int) -> None:
        """Record a deep copy of *state* taken just before *node_id* executes."""
        state_history = state.get(WorkflowKeys.STATE_HISTORY, [])

        # BUGFIX: exclude the history list itself from the copied state.
        # Deep-copying it embedded every earlier snapshot inside each new
        # one, so snapshot size grew combinatorially with the number of
        # executed nodes. rollback_to_node() reattaches a truncated history
        # after a restore anyway, so nothing is lost by omitting it here.
        state_without_history = {
            key: value
            for key, value in state.items()
            if key != WorkflowKeys.STATE_HISTORY
        }

        snapshot = {
            'snapshot_id': str(uuid.uuid4()),
            'node_about_to_execute': node_id,
            'execution_index': execution_index,
            'timestamp': datetime.now().isoformat(),
            'state': copy.deepcopy(state_without_history),
        }

        state_history.append(snapshot)
        state[WorkflowKeys.STATE_HISTORY] = state_history

        logger.info(f"Saved snapshot #{len(state_history)-1} before executing {node_id}")

    def rollback_to_node(
        self,
        state: Dict[str, Any],
        target_node: str,
        node_execution_order: List[str],
        node_field_map: Dict[str, str],
        workflow_steps: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Restore the state saved just before *target_node* executed.

        Returns the restored state, or an empty dict when no suitable
        snapshot exists (callers treat {} as "rollback unavailable").
        """
        state_history = state.get(WorkflowKeys.STATE_HISTORY, [])

        if not state_history:
            logger.warning("No state history available for rollback")
            return {}

        logger.info(f"Looking for snapshot before node '{target_node}'")

        # Use the FIRST matching snapshot so repeated visits to the same
        # node roll back to its first run.
        target_snapshot = None
        target_index = None
        for i, snapshot in enumerate(state_history):
            if snapshot.get('node_about_to_execute') == target_node:
                target_snapshot = snapshot
                target_index = i
                break

        if target_snapshot is None:
            logger.warning(f"No snapshot found before node '{target_node}'")
            return {}

        logger.info(f"Found snapshot at index {target_index}")
        restored_state = _restore_from_snapshot(target_snapshot)

        # Null out fields produced by the target node and everything after it.
        restored_state = _clear_future_executions(
            restored_state,
            target_node,
            workflow_steps
        )

        # Keep history up to and including the snapshot we restored from.
        restored_state[WorkflowKeys.STATE_HISTORY] = state_history[:target_index + 1]

        logger.info(f"Successfully rolled back to {target_node}")

        return restored_state
146
+
147
+
148
+ def _build_dependency_graph(
149
+ workflow_steps: List[Dict[str, Any]]
150
+ ) -> Dict[str, List[str]]:
151
+ graph = {}
152
+
153
+ for step in workflow_steps:
154
+ field = step.get('field') or step.get('output')
155
+
156
+ if not field:
157
+ continue
158
+
159
+ depends_on = step.get('depends_on')
160
+
161
+ if depends_on:
162
+ if isinstance(depends_on, str):
163
+ depends_on_list = [depends_on]
164
+ elif isinstance(depends_on, list):
165
+ depends_on_list = depends_on
166
+ else:
167
+ logger.warning(f"Invalid depends_on type for field '{field}': {type(depends_on)}")
168
+ depends_on_list = []
169
+
170
+ for parent_field in depends_on_list:
171
+ if parent_field not in graph:
172
+ graph[parent_field] = []
173
+ if field not in graph[parent_field]:
174
+ graph[parent_field].append(field)
175
+
176
+ if field not in graph:
177
+ graph[field] = []
178
+
179
+ return graph
180
+
181
+
182
+ def _find_all_dependents(
183
+ field: str,
184
+ dependency_graph: Dict[str, List[str]]
185
+ ) -> set:
186
+ all_dependents = set()
187
+ visited = set()
188
+
189
+ def _recurse(current_field: str):
190
+ if current_field in visited:
191
+ return
192
+ visited.add(current_field)
193
+
194
+ direct_dependents = dependency_graph.get(current_field, [])
195
+
196
+ for dependent in direct_dependents:
197
+ all_dependents.add(dependent)
198
+ _recurse(dependent)
199
+
200
+ _recurse(field)
201
+
202
+ return all_dependents
203
+
204
+
205
def _clear_field_conversation(state: Dict[str, Any], field: str) -> None:
    """Drop the per-field agent conversation for *field*, if one exists."""
    conversations = state.get(WorkflowKeys.CONVERSATIONS, {})
    conv_key = f"{field}_conversation"

    if conv_key in conversations:
        del conversations[conv_key]
        logger.info(f"Cleared conversation: {conv_key}")
212
+
213
+
214
class DependencyBasedRollback(RollbackStrategy):
    """Rollback strategy driven by declared field dependencies.

    Instead of restoring snapshots, it clears the target node's field plus
    every field transitively derived from it (via `depends_on` in the step
    definitions), forcing those values to be recomputed/recollected.
    """

    def get_strategy_name(self) -> str:
        return "dependency_based"

    def should_save_snapshot(self) -> bool:
        # No snapshots needed: what to clear is derived from the dependency
        # graph at rollback time.
        return False

    def save_snapshot(self, state: Dict[str, Any], node_id: str, execution_index: int) -> None:
        # Intentionally a no-op; this strategy never stores snapshots.
        return None

    def rollback_to_node(
        self,
        state: Dict[str, Any],
        target_node: str,
        node_execution_order: List[str],
        node_field_map: Dict[str, str],
        workflow_steps: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Clear the target node's field and every field derived from it.

        Returns *state* (mutated in place); unchanged when *target_node*
        has no associated field in *node_field_map*.
        """
        target_field = node_field_map.get(target_node)

        if not target_field:
            logger.warning(f"No field found for target node '{target_node}'")
            return state

        logger.info(f"Rolling back to node '{target_node}' (field: '{target_field}')")

        dependency_graph = _build_dependency_graph(workflow_steps)
        logger.info(f"Dependency graph: {dependency_graph}")

        dependent_fields = _find_all_dependents(target_field, dependency_graph)
        logger.info(f"Fields dependent on '{target_field}': {dependent_fields}")

        # Reset the target field first, then every transitively dependent one.
        state[target_field] = None
        _clear_field_conversation(state, target_field)

        for dependent in dependent_fields:
            state[dependent] = None
            _clear_field_conversation(state, dependent)
            logger.info(f"Cleared dependent field: {dependent}")

        logger.info(f"Successfully rolled back to {target_node} using dependency graph")

        return state