soprano-sdk 0.1.94__py3-none-any.whl → 0.1.96__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,225 @@
1
+ from typing import Optional, Dict, Any, Tuple
2
+
3
+ import yaml
4
+ from jinja2 import Environment
5
+ from langgraph.checkpoint.memory import InMemorySaver
6
+ from langgraph.constants import START
7
+ from langgraph.graph import StateGraph
8
+ from langgraph.graph.state import CompiledStateGraph
9
+
10
+ from .constants import WorkflowKeys
11
+ from .state import create_state_model
12
+ from ..nodes.factory import NodeFactory
13
+ from ..routing.router import WorkflowRouter
14
+ from ..utils.function import FunctionRepository
15
+ from ..utils.logger import logger
16
+ from ..utils.tool import ToolRepository
17
+ from ..validation import validate_workflow
18
+
19
+
20
class WorkflowEngine:
    """Load a YAML workflow definition and compile it into a LangGraph graph.

    On construction the YAML file is parsed with ``yaml.safe_load``,
    validated with ``validate_workflow`` and indexed into ``step_map`` /
    ``outcome_map``. ``build_graph`` turns the steps into graph nodes and
    routing edges. The engine also owns a small in-memory ``context_store``
    that callers can update and query.
    """

    def __init__(self, yaml_path: str, configs: Optional[dict] = None):
        """Parse and validate the workflow at *yaml_path*.

        Args:
            yaml_path: Path to the workflow YAML file.
            configs: Runtime configuration consulted by
                :meth:`get_config_value` before the YAML itself. ``None``
                (what ``load_workflow`` passes by default) is treated as an
                empty mapping.

        Raises:
            OSError: If the file cannot be opened.
            UnicodeDecodeError: If the file is not valid UTF-8.
            KeyError: If a required top-level key (``name``, ``description``,
                ``version``, ``data``, ``steps``, ``outcomes``) is missing and
                validation did not already reject the file.
            Exception: Whatever ``yaml.safe_load``, ``validate_workflow`` or
                ``ToolRepository`` raise for malformed input.
        """
        self.yaml_path = yaml_path
        # load_workflow() passes config=None by default; normalise so that
        # get_config_value() never calls .get() on None.
        self.configs = configs if configs is not None else {}
        logger.info(f"Loading workflow from: {yaml_path}")

        # YAML is a Unicode format; don't depend on the platform locale.
        with open(yaml_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        logger.info("Validating workflow configuration")
        validate_workflow(self.config)

        self.workflow_name = self.config['name']
        self.workflow_description = self.config['description']
        self.workflow_version = self.config['version']
        self.data_fields = self.config['data']
        self.steps = self.config['steps']
        self.outcomes = self.config['outcomes']
        self.metadata = self.config.get('metadata', {})

        # State schema (a TypedDict class) derived from the declared data fields.
        self.StateType = create_state_model(self.data_fields)

        self.step_map = {step['id']: step for step in self.steps}
        self.outcome_map = {outcome['id']: outcome for outcome in self.outcomes}

        self.function_repository = FunctionRepository()
        self.tool_repository = None
        if tool_config := self.config.get("tool_config"):
            self.tool_repository = ToolRepository(tool_config)

        self.context_store = {}
        self.collect_input_fields = self._get_collect_input_fields()

        logger.info(
            f"Workflow loaded: {self.workflow_name} v{self.workflow_version} "
            f"({len(self.steps)} steps, {len(self.outcomes)} outcomes)"
        )

    def get_config_value(self, key, default_value: Optional[Any] = None):
        """Look up *key* in the runtime configs first, then in the workflow YAML.

        A value counts as present when it is not ``None``. Falsy values such
        as ``False``, ``0`` or ``""`` are returned as-is rather than falling
        through to the next source.

        Args:
            key: Configuration key to look up.
            default_value: Returned when neither source provides a value.
        """
        for source in (self.configs, self.config):
            if not source:
                continue
            value = source.get(key)
            if value is not None:
                return value
        return default_value

    def _get_collect_input_fields(self) -> set:
        """Return the ``field`` names of all ``collect_input_with_agent`` steps."""
        fields = set()
        for step in self.steps:
            if step.get('action') == 'collect_input_with_agent' and (field := step.get('field')):
                fields.add(field)
        return fields

    def update_context(self, context: Dict[str, Any]):
        """Merge *context* into the engine's in-memory context store."""
        self.context_store.update(context)
        logger.info(f"Context updated: {context}")

    def remove_context_field(self, field_name: str):
        """Remove *field_name* from the context store; a missing key is ignored."""
        if field_name in self.context_store:
            del self.context_store[field_name]
            logger.info(f"Removed context field: {field_name}")

    def get_context_value(self, field_name: str):
        """Return the stored context value for *field_name*, or ``None``."""
        value = self.context_store.get(field_name)
        if value is not None:
            logger.info(f"Retrieved context value for '{field_name}': {value}")
        return value

    def build_graph(self, checkpointer=None):
        """Compile the workflow steps into a LangGraph graph.

        Adds one node per step (built by ``NodeFactory``), an edge from
        ``START`` to the first declared step, and conditional routing edges
        for every step computed by ``WorkflowRouter``.

        Args:
            checkpointer: Checkpointer passed to ``compile``. An
                ``InMemorySaver`` is used when ``None``.

        Returns:
            The compiled graph.

        Raises:
            RuntimeError: Wrapping any error raised while building or
                compiling the graph (including an empty ``steps`` list). The
                original exception is chained as ``__cause__``.
        """
        logger.info("Building workflow graph")

        try:
            builder = StateGraph(self.StateType)

            # Collector step ids are handed to the router when building routing maps.
            collector_nodes = []

            logger.info("Adding nodes to graph")
            for step in self.steps:
                step_id = step['id']
                action = step['action']

                if action == 'collect_input_with_agent':
                    collector_nodes.append(step_id)

                node_fn = NodeFactory.create(step, engine_context=self)
                builder.add_node(step_id, node_fn)

                logger.info(f"Added node: {step_id} (action: {action})")

            first_step_id = self.steps[0]['id']
            builder.add_edge(START, first_step_id)
            logger.info(f"Set entry point: {first_step_id}")

            logger.info("Adding routing edges")
            for step in self.steps:
                step_id = step['id']

                router = WorkflowRouter(step, self.step_map, self.outcome_map)
                route_fn = router.create_route_function()
                routing_map = router.get_routing_map(collector_nodes)

                builder.add_conditional_edges(step_id, route_fn, routing_map)

                logger.info(
                    f"Added routing for {step_id}: {len(routing_map)} destinations"
                )

            if checkpointer is None:
                checkpointer = InMemorySaver()
                logger.info("Using InMemorySaver for state persistence")
            else:
                logger.info(f"Using custom checkpointer: {type(checkpointer).__name__}")

            graph = builder.compile(checkpointer=checkpointer)

            logger.info("Workflow graph built successfully")
            return graph

        except Exception as e:
            raise RuntimeError(f"Failed to build workflow graph: {e}") from e

    def get_outcome_message(self, state: Dict[str, Any]) -> str:
        """Return the user-facing message for a workflow run.

        Resolution order:
        1. the ``message`` of the outcome referenced by the state's outcome id,
           rendered as a Jinja template with *state* as context;
        2. ``state["error"]``;
        3. the state's messages entry;
        4. a fixed error string.
        """
        outcome_id = state.get(WorkflowKeys.OUTCOME_ID)
        step_id = state.get(WorkflowKeys.STEP_ID)

        outcome = self.outcome_map.get(outcome_id)
        if outcome and 'message' in outcome:
            template_loader = self.get_config_value("template_loader")
            if template_loader is None:
                # Only build a default Jinja environment when none is configured.
                template_loader = Environment()
            message = template_loader.from_string(outcome['message']).render(state)
            logger.info(f"Outcome message generated in step {step_id}: {message}")
            return message

        if error := state.get("error"):
            logger.info(f"Outcome error found in step {step_id}: {error}")
            return f"{error}"

        if message := state.get(WorkflowKeys.MESSAGES):
            logger.info(f"Outcome message found in step {step_id}: {message}")
            return f"{message}"

        logger.error(f"No outcome message found in step {step_id}")
        return "{'error': 'Unable to complete the request'}"

    def get_step_info(self, step_id: str) -> Optional[Dict[str, Any]]:
        """Return the step definition for *step_id*, or ``None``."""
        return self.step_map.get(step_id)

    def get_outcome_info(self, outcome_id: str) -> Optional[Dict[str, Any]]:
        """Return the outcome definition for *outcome_id*, or ``None``."""
        return self.outcome_map.get(outcome_id)

    def list_steps(self) -> list:
        """Return the step ids in declaration order."""
        return [step['id'] for step in self.steps]

    def list_outcomes(self) -> list:
        """Return the outcome ids in declaration order."""
        return [outcome['id'] for outcome in self.outcomes]

    def get_workflow_info(self) -> Dict[str, Any]:
        """Return a summary of the workflow (name, version, counts, fields, metadata)."""
        return {
            'name': self.workflow_name,
            'description': self.workflow_description,
            'version': self.workflow_version,
            'steps': len(self.steps),
            'outcomes': len(self.outcomes),
            'data_fields': [f['name'] for f in self.data_fields],
            'metadata': self.metadata
        }

    def get_tool_policy(self) -> Optional[str]:
        """Return ``tool_config.usage_policy`` from the YAML (``None`` if unset).

        Raises:
            ValueError: If ``tool_config`` is absent or empty.
        """
        tool_config = self.config.get('tool_config')
        if not tool_config:
            raise ValueError("Tool config is not provided in the YAML")
        return tool_config.get('usage_policy')
197
+
198
+
199
def load_workflow(
    yaml_path: str,
    checkpointer=None,
    config: Optional[Dict[str, Any]] = None,
) -> Tuple[CompiledStateGraph, WorkflowEngine]:
    """
    Load a workflow from YAML configuration.

    This is the main entry point for using the framework.

    Args:
        yaml_path: Path to the workflow YAML file
        checkpointer: Optional checkpointer for state persistence.
            Defaults to InMemorySaver() if not provided.
            Example: MongoDBSaver for production persistence.
        config: Optional runtime configuration for the engine (e.g. a
            ``template_loader``). Values here take precedence over the same
            keys in the YAML. ``None`` is treated as an empty mapping. This
            is unrelated to the per-invocation ``config`` passed to
            ``graph.invoke``.

    Returns:
        Tuple of (compiled_graph, engine) where:
        - compiled_graph: LangGraph ready for execution
        - engine: WorkflowEngine instance for introspection

    Example:
        ```python
        graph, engine = load_workflow("workflow.yaml")
        result = graph.invoke({}, config={"configurable": {"thread_id": "123"}})
        message = engine.get_outcome_message(result)
        ```
    """
    # Never hand None to the engine: its config lookups call .get() on it.
    engine = WorkflowEngine(yaml_path, configs=config if config is not None else {})
    graph = engine.build_graph(checkpointer=checkpointer)
    return graph, engine
@@ -0,0 +1,259 @@
1
+ import copy
2
+ import logging
3
+ import uuid
4
+ from abc import ABC, abstractmethod
5
+ from datetime import datetime
6
+ from typing import Dict, Any, List
7
+
8
+ from soprano_sdk.core.constants import WorkflowKeys, ActionType
9
+ from ..utils.logger import logger
10
+
11
class RollbackStrategy(ABC):
    """Interface for strategies that undo workflow progress back to a node.

    A strategy decides whether it records state snapshots and how the
    state is reconstructed when rolling back to an earlier step.
    """

    @abstractmethod
    def get_strategy_name(self) -> str:
        """Return a short identifier for this strategy."""

    @abstractmethod
    def should_save_snapshot(self) -> bool:
        """Return whether this strategy records snapshots via ``save_snapshot``."""

    @abstractmethod
    def save_snapshot(self, state: Dict[str, Any], node_id: str, execution_index: int) -> None:
        """Record *state* ahead of running *node_id* (may be a no-op)."""

    @abstractmethod
    def rollback_to_node(
        self,
        state: Dict[str, Any],
        target_node: str,
        node_execution_order: List[str],
        node_field_map: Dict[str, str],
        workflow_steps: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Return the state to resume from when rolling back to *target_node*."""
34
+
35
+
36
+ def _restore_from_snapshot(snapshot: Dict[str, Any]) -> Dict[str, Any]:
37
+ return copy.deepcopy(snapshot['state'])
38
+
39
+
40
def _clear_future_executions(
    state: Dict[str, Any],
    target_node: str,
    workflow_steps: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Reset values produced by *target_node* and every step declared after it.

    "Future" is determined by declaration order in *workflow_steps*, not by
    the order in which steps actually ran.

    - ``collect_input_with_agent`` steps: the collected ``field`` is set to
      ``None`` and its ``"<field>_conversation"`` entry is discarded. Steps
      without a ``field`` are skipped.
    - ``call_function`` steps: the ``output`` field is set to ``None``.

    *state* is mutated in place and also returned. If *target_node* is not a
    known step id, a warning is logged and *state* is returned unchanged.
    """
    target_step_index = next(
        (i for i, step in enumerate(workflow_steps) if step['id'] == target_node),
        None
    )

    if target_step_index is None:
        logger.warning(f"Target node {target_node} not found in workflow steps")
        return state

    future_steps = workflow_steps[target_step_index:]

    logger.info(f"Future steps to clear: {[s['id'] for s in future_steps]}")

    for step in future_steps:
        action = step.get('action')

        if action == ActionType.COLLECT_INPUT_WITH_AGENT.value:
            field_name = step.get('field')
            if not field_name:
                # No field means no value and no "<field>_conversation" to clear.
                continue
            state[field_name] = None
            # Shared helper keeps conversation-key handling in one place.
            _clear_field_conversation(state, field_name)

        elif action == ActionType.CALL_FUNCTION.value:
            output_field = step.get('output')
            if output_field:
                state[output_field] = None
                logger.info(f"Cleared computed field: {output_field}")

    return state
79
+
80
+
81
class HistoryBasedRollback(RollbackStrategy):
    """Roll back by restoring a state snapshot recorded before a node ran.

    ``save_snapshot`` appends a snapshot to ``state[WorkflowKeys.STATE_HISTORY]``.
    ``rollback_to_node`` restores the earliest snapshot recorded for the
    target node. It then clears values produced by that node and the steps
    declared after it.
    """

    def get_strategy_name(self) -> str:
        return "history_based"

    def should_save_snapshot(self) -> bool:
        return True

    def save_snapshot(self, state: Dict[str, Any], node_id: str, execution_index: int) -> None:
        """Append a deep copy of *state* (minus the history) to the state history.

        The history itself is deliberately excluded from the copied state.
        Including it would make every snapshot embed copies of all previous
        snapshots, so the history would double in size with every step.
        ``rollback_to_node`` re-attaches the correct history slice after
        restoring, so the omitted key is never needed.

        Args:
            state: Current workflow state; its history list is updated in place.
            node_id: Id of the node about to execute.
            execution_index: Caller-supplied execution counter stored with the snapshot.
        """
        history_key = WorkflowKeys.STATE_HISTORY
        # `or []` also covers a history key that is present but None.
        state_history = state.get(history_key) or []

        # A single deepcopy over a shallow filtered view keeps shared references
        # between state values intact (one memo for the whole copy).
        state_without_history = {k: v for k, v in state.items() if k != history_key}

        snapshot = {
            'snapshot_id': str(uuid.uuid4()),
            'node_about_to_execute': node_id,
            'execution_index': execution_index,
            'timestamp': datetime.now().isoformat(),
            'state': copy.deepcopy(state_without_history),
        }

        state_history.append(snapshot)
        state[history_key] = state_history

        logger.info(f"Saved snapshot #{len(state_history)-1} before executing {node_id}")

    def rollback_to_node(
        self,
        state: Dict[str, Any],
        target_node: str,
        node_execution_order: List[str],
        node_field_map: Dict[str, str],
        workflow_steps: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Restore the state captured before the first run of *target_node*.

        Args:
            state: Current workflow state holding the snapshot history.
            target_node: Step id to roll back to.
            node_execution_order: Unused by this strategy.
            node_field_map: Unused by this strategy.
            workflow_steps: Step definitions, used to clear "future" values.

        Returns:
            A new state dict, or ``{}`` if there is no history or no snapshot
            for *target_node*.
        """
        state_history = state.get(WorkflowKeys.STATE_HISTORY, [])

        if not state_history:
            logger.warning("No state history available for rollback")
            return {}

        logger.info(f"Looking for snapshot before node '{target_node}'")

        target_snapshot = None
        target_index = None

        # Earliest snapshot wins when a node has executed more than once.
        for i, snapshot in enumerate(state_history):
            if snapshot.get('node_about_to_execute') == target_node:
                target_snapshot = snapshot
                target_index = i
                break

        if target_snapshot is None:
            logger.warning(f"No snapshot found before node '{target_node}'")
            return {}

        logger.info(f"Found snapshot at index {target_index}")
        restored_state = _restore_from_snapshot(target_snapshot)

        restored_state = _clear_future_executions(
            restored_state,
            target_node,
            workflow_steps
        )

        # Snapshots no longer embed the history, so re-attach the prefix that
        # led up to (and includes) the target snapshot.
        restored_state[WorkflowKeys.STATE_HISTORY] = state_history[:target_index + 1]

        logger.info(f"Successfully rolled back to {target_node}")

        return restored_state
147
+
148
+
149
+ def _build_dependency_graph(
150
+ workflow_steps: List[Dict[str, Any]]
151
+ ) -> Dict[str, List[str]]:
152
+ graph = {}
153
+
154
+ for step in workflow_steps:
155
+ field = step.get('field') or step.get('output')
156
+
157
+ if not field:
158
+ continue
159
+
160
+ depends_on = step.get('depends_on')
161
+
162
+ if depends_on:
163
+ if isinstance(depends_on, str):
164
+ depends_on_list = [depends_on]
165
+ elif isinstance(depends_on, list):
166
+ depends_on_list = depends_on
167
+ else:
168
+ logger.warning(f"Invalid depends_on type for field '{field}': {type(depends_on)}")
169
+ depends_on_list = []
170
+
171
+ for parent_field in depends_on_list:
172
+ if parent_field not in graph:
173
+ graph[parent_field] = []
174
+ if field not in graph[parent_field]:
175
+ graph[parent_field].append(field)
176
+
177
+ if field not in graph:
178
+ graph[field] = []
179
+
180
+ return graph
181
+
182
+
183
+ def _find_all_dependents(
184
+ field: str,
185
+ dependency_graph: Dict[str, List[str]]
186
+ ) -> set:
187
+ all_dependents = set()
188
+ visited = set()
189
+
190
+ def _recurse(current_field: str):
191
+ if current_field in visited:
192
+ return
193
+ visited.add(current_field)
194
+
195
+ direct_dependents = dependency_graph.get(current_field, [])
196
+
197
+ for dependent in direct_dependents:
198
+ all_dependents.add(dependent)
199
+ _recurse(dependent)
200
+
201
+ _recurse(field)
202
+
203
+ return all_dependents
204
+
205
+
206
def _clear_field_conversation(state: Dict[str, Any], field: str) -> None:
    """Drop the stored agent conversation for *field*, if one exists.

    Conversations are held in ``state[WorkflowKeys.CONVERSATIONS]`` under the
    key ``"<field>_conversation"``, and that mapping is modified in place. A
    missing or ``None`` conversations entry is treated as empty instead of
    raising ``TypeError``.
    """
    conv_key = f"{field}_conversation"
    conversations = state.get(WorkflowKeys.CONVERSATIONS) or {}

    if conv_key in conversations:
        del conversations[conv_key]
        logger.info(f"Cleared conversation: {conv_key}")
213
+
214
+
215
class DependencyBasedRollback(RollbackStrategy):
    """Rollback strategy that resets a field and everything derived from it.

    No snapshots are kept. The ``depends_on`` declarations of the workflow
    steps are turned into a dependency graph. The target node's field and
    every field transitively downstream of it are set to ``None``.
    """

    def get_strategy_name(self) -> str:
        return "dependency_based"

    def should_save_snapshot(self) -> bool:
        # Rollback is computed from the dependency graph, so no history is kept.
        return False

    def save_snapshot(self, state: Dict[str, Any], node_id: str, execution_index: int) -> None:
        # Deliberate no-op: this strategy records nothing.
        return None

    def rollback_to_node(
        self,
        state: Dict[str, Any],
        target_node: str,
        node_execution_order: List[str],
        node_field_map: Dict[str, str],
        workflow_steps: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Clear the target node's field and all fields that depend on it.

        Args:
            state: Current workflow state; modified in place.
            target_node: Step id to roll back to.
            node_execution_order: Unused by this strategy.
            node_field_map: Looked up with the step id to find the field to
                reset (presumably step id -> populated field name).
            workflow_steps: Step definitions used to build the dependency graph.

        Returns:
            The same *state* object. If *target_node* has no field in
            *node_field_map*, a warning is logged and *state* is returned
            untouched.
        """
        root_field = node_field_map.get(target_node)
        if not root_field:
            logger.warning(f"No field found for target node '{target_node}'")
            return state

        logger.info(f"Rolling back to node '{target_node}' (field: '{root_field}')")

        graph = _build_dependency_graph(workflow_steps)
        logger.info(f"Dependency graph: {graph}")

        downstream = _find_all_dependents(root_field, graph)
        logger.info(f"Fields dependent on '{root_field}': {downstream}")

        def reset(name: str) -> None:
            # Forget the value and any agent conversation that produced it.
            state[name] = None
            _clear_field_conversation(state, name)

        reset(root_field)
        for name in downstream:
            reset(name)
            logger.info(f"Cleared dependent field: {name}")

        logger.info(f"Successfully rolled back to {target_node} using dependency graph")

        return state
@@ -0,0 +1,71 @@
1
+ import types
2
+ from typing import Annotated, Optional, Dict, List, Any, Type
3
+
4
+ from typing_extensions import TypedDict
5
+
6
+ from .constants import DataType
7
+
8
+
9
def replace(_, right):
    """Reducer that discards the previous value and keeps the new one."""
    newest = right
    return newest
11
+
12
+
13
def get_state_value(state: Dict[str, Any], key: str, default: Any = None) -> Any:
    """Return ``state[key]`` if present, otherwise *default* (``dict.get`` semantics)."""
    looked_up = state.get(key, default)
    return looked_up
15
+
16
def set_state_value(state: Dict[str, Any], key: str, value: Any) -> None:
    """Store *value* under *key* in *state*, mutating it in place."""
    state[key] = value
    return None
18
+
19
def create_state_model(data_fields: List[dict]):
    """Build a ``WorkflowState`` TypedDict class for a workflow.

    Each entry of *data_fields* must provide ``name`` and ``type``. The type
    string is mapped to an optional Python type, and unknown types fall back
    to ``Optional[str]``. Framework-internal bookkeeping keys are added
    afterwards and override any same-named data field. Every key is wrapped
    in ``Annotated[..., replace]``; ``replace`` is intended as the reducer,
    so a new value overwrites the old one.

    Returns:
        A dynamically created TypedDict subclass named ``WorkflowState``.
    """
    python_types = {
        DataType.TEXT.value: Optional[str],
        DataType.NUMBER.value: Optional[int],
        DataType.DOUBLE.value: Optional[float],
        DataType.BOOLEAN.value: Optional[bool],
        DataType.LIST.value: Optional[List[Any]],
        DataType.DICT.value: Optional[Dict[str, Any]],
        DataType.ANY.value: Optional[Any]
    }

    def with_reducer(tp):
        return Annotated[tp, replace]

    annotations = {
        spec['name']: with_reducer(python_types.get(spec['type'], Optional[str]))
        for spec in data_fields
    }

    # Internal keys are applied last, so they win over user-declared names.
    annotations.update({
        '_step_id': with_reducer(Optional[str]),
        '_status': with_reducer(Optional[str]),
        '_outcome_id': with_reducer(Optional[str]),
        '_messages': with_reducer(List[str]),
        '_conversations': with_reducer(Dict[str, List[Dict[str, str]]]),
        '_state_history': with_reducer(List[Dict[str, Any]]),
        '_collector_nodes': with_reducer(Dict[str, str]),
        '_attempt_counts': with_reducer(Dict[str, int]),
        '_node_execution_order': with_reducer(List[str]),
        '_node_field_map': with_reducer(Dict[str, str]),
        '_computed_fields': with_reducer(List[str]),
        'error': with_reducer(Optional[Dict[str, str]]),
    })

    def populate_namespace(namespace):
        namespace.update({'__annotations__': annotations})

    return types.new_class('WorkflowState', (TypedDict,), {}, populate_namespace)
52
+
53
+
54
def initialize_state(state: Dict[str, Any]) -> Dict[str, Any]:
    """Fill framework bookkeeping keys that are missing or falsy in *state*.

    Each such key receives a fresh empty container (or ``None`` for
    ``error``), so containers are never shared between states. Existing
    truthy values are left untouched. *state* is mutated in place and
    returned.
    """
    default_factories = {
        '_state_history': list,
        '_collector_nodes': dict,
        '_conversations': dict,
        '_messages': list,
        '_attempt_counts': dict,
        '_node_execution_order': list,
        '_node_field_map': dict,
        '_computed_fields': list,
        'error': lambda: None,
    }

    for key, make_default in default_factories.items():
        if get_state_value(state, key):
            continue
        set_state_value(state, key, make_default())

    return state