soprano-sdk 0.1.94__py3-none-any.whl → 0.1.96__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
soprano_sdk/engine.py ADDED
@@ -0,0 +1,381 @@
1
+ """
2
+ Workflow Engine - Parses YAML workflow definitions and builds LangGraph execution graphs.
3
+ """
4
+
5
+ import yaml
6
+ import importlib
7
+ from typing import TypedDict, Optional, Dict
8
+ from langgraph.graph import StateGraph
9
+ from langgraph.types import interrupt
10
+ from langgraph.constants import START, END
11
+ from langgraph.checkpoint.memory import InMemorySaver
12
+ from agno.agent import Agent
13
+ from agno.models.openai import OpenAIChat
14
+
15
+
16
class WorkflowEngine:
    """Load a YAML workflow definition and turn it into a LangGraph graph.

    The YAML document must contain the top-level keys ``name``,
    ``description``, ``data``, ``steps`` and ``outcomes``. Each step becomes
    one graph node; routing between nodes is driven by the ``_status`` value
    the node writes into the state (format ``"{step_id}_{suffix}"``).
    """

    # YAML field type -> state annotation. Unknown types are treated as text.
    _TYPE_MAPPING = {
        'text': Optional[str],
        'number': Optional[int],
        'boolean': Optional[bool],
    }

    def __init__(self, yaml_path: str):
        """Read and index the workflow configuration.

        Args:
            yaml_path: Path to the workflow YAML file.

        Raises:
            KeyError: If a required top-level key is missing from the YAML.
        """
        self.yaml_path = yaml_path
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(yaml_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        self.workflow_name = self.config['name']
        self.workflow_description = self.config['description']
        self.data_fields = self.config['data']
        self.steps = self.config['steps']
        self.outcomes = self.config['outcomes']

        # Build state type dynamically
        self.StateType = self._build_state_type()

        # Track steps and outcomes by ID
        self.step_map = {step['id']: step for step in self.steps}
        self.outcome_map = {outcome['id']: outcome for outcome in self.outcomes}

    def _build_state_type(self) -> type:
        """Create the ``WorkflowState`` TypedDict for this workflow.

        The result has one optional entry per configured data field plus the
        internal bookkeeping keys (``_status``, ``_attempt_counts``,
        ``_outcome_id``, ``_messages``, ``_conversations``).
        """
        import logging

        fields = {
            field['name']: self._TYPE_MAPPING.get(field['type'], Optional[str])
            for field in self.data_fields
        }

        # Internal workflow fields
        fields['_status'] = Optional[str]
        fields['_attempt_counts'] = Optional[Dict[str, int]]
        fields['_outcome_id'] = Optional[str]
        fields['_messages'] = Optional[list]  # For message accumulation
        fields['_conversations'] = Optional[Dict[str, list]]  # Conversation histories per field

        # Diagnostics go through logging instead of printing to stdout.
        logging.getLogger(__name__).debug("Workflow state fields: %s", fields)
        return TypedDict('WorkflowState', fields)

    def _create_collect_input_with_agent_node(self, step_config: dict):
        """Create a node that collects one field through an agent conversation.

        Args:
            step_config: Step definition; reads ``id``, ``field``, ``agent``
                and the optional ``max_attempts`` (default 5) and
                ``transitions`` (each with ``pattern`` and ``next``).

        Returns:
            A ``node_fn(state) -> state`` callable that mutates and returns
            the state dict it is given.
        """
        step_id = step_config['id']
        field = step_config['field']
        max_attempts = step_config.get('max_attempts', 5)
        agent_config = step_config['agent']
        transitions = step_config.get('transitions', [])

        def node_fn(state: dict) -> dict:
            # Make sure the bookkeeping containers exist.
            if state.get('_conversations') is None:
                state['_conversations'] = {}
            if state.get('_messages') is None:
                state['_messages'] = []

            # Field pre-populated by an external orchestrator: skip the
            # conversation and route via the first transition (success path).
            if state.get(field) is not None:
                if transitions:
                    next_step = transitions[0]['next']
                    state['_status'] = f'{step_id}_{next_step}'
                    if next_step in self.outcome_map:
                        state['_outcome_id'] = next_step
                return state

            # Conversation history for this field
            conv_key = f'{field}_conversation'
            conversation = state['_conversations'].setdefault(conv_key, [])

            # Attempts are counted as user turns already in the conversation.
            attempt_count = sum(1 for m in conversation if m['role'] == 'user')
            if attempt_count >= max_attempts:
                state['_status'] = f'{step_id}_max_attempts'
                state['_messages'] = [f"I'm having trouble understanding your {field}. Please contact customer service for assistance."]
                return state

            agent = Agent(
                name=agent_config.get('name', f'{field}Collector'),
                model=OpenAIChat(id=agent_config.get('model', 'gpt-4o-mini')),
                instructions=agent_config['instructions']
            )

            if not conversation:
                # First iteration: let the agent generate the initial greeting.
                # NOTE(review): LangGraph documents that a node restarts from
                # its first line when resumed after interrupt(); presumably
                # this greeting is regenerated on resume - confirm.
                response = agent.run(conversation)
                prompt = response.content
                conversation.append({"role": "assistant", "content": prompt})
                state['_conversations'][conv_key] = conversation
            else:
                # Use the last assistant message as the prompt.
                prompt = conversation[-1]['content']

            # Pause the graph until the caller supplies the user's reply.
            user_input = interrupt(prompt)
            conversation.append({"role": "user", "content": user_input})

            response = agent.run(conversation)
            agent_response = response.content
            conversation.append({"role": "assistant", "content": agent_response})
            state['_conversations'][conv_key] = conversation

            # The first transition whose pattern occurs in the reply wins.
            for transition in transitions:
                pattern = transition['pattern']
                if pattern not in agent_response:
                    continue

                next_step = transition['next']

                # The value is the text following the pattern.
                value = agent_response.split(pattern)[1].strip()
                if value:
                    field_def = next((f for f in self.data_fields if f['name'] == field), None)
                    if field_def and field_def['type'] == 'number':
                        # NOTE(review): int() raises ValueError when the agent
                        # emits a non-numeric value; nothing here recovers.
                        state[field] = int(value)
                    else:
                        state[field] = value
                    state['_messages'] = [f"✓ {field.replace('_', ' ').title()} collected: {value}"]
                else:
                    # No value extracted (e.g., "ORDER_ID_FAILED:")
                    state['_messages'] = []

                state['_status'] = f'{step_id}_{next_step}'
                if next_step in self.outcome_map:
                    state['_outcome_id'] = next_step
                break
            else:
                # Nothing matched: keep collecting (the router self-loops).
                # agent_response is not stored in _messages because it
                # becomes the next interrupt prompt.
                state['_status'] = f'{step_id}_collecting'
                state['_messages'] = []

            return state

        return node_fn

    def _create_call_function_node(self, step_config: dict):
        """Create a node that imports and calls a Python function.

        Args:
            step_config: Step definition; reads ``id``, ``function`` (dotted
                ``module.attr`` path), ``inputs`` (kwarg name -> literal or
                ``"{state_key}"`` placeholder), ``output`` and the optional
                ``next`` / ``transitions`` (each with ``condition``, ``next``).

        Returns:
            A ``node_fn(state) -> state`` callable that stores the function
            result under ``output`` and writes the routing status.
        """
        step_id = step_config['id']
        function_path = step_config['function']
        inputs = step_config['inputs']
        output_field = step_config['output']
        next_step = step_config.get('next')
        transitions = step_config.get('transitions', [])

        def node_fn(state: dict) -> dict:
            # Resolve the dotted path at call time.
            module_name, function_name = function_path.rsplit('.', 1)
            func = getattr(importlib.import_module(module_name), function_name)

            # "{name}" placeholders are looked up in state; other values pass through.
            kwargs = {}
            for key, value in inputs.items():
                if isinstance(value, str) and value.startswith('{') and value.endswith('}'):
                    kwargs[key] = state.get(value[1:-1])
                else:
                    kwargs[key] = value

            result = func(**kwargs)
            state[output_field] = result

            if transitions:
                # The first transition whose condition equals the result wins.
                # When none matches, _status is left untouched.
                for transition in transitions:
                    next_dest = transition['next']
                    if result == transition['condition']:
                        state['_status'] = f'{step_id}_{next_dest}'
                        if next_dest in self.outcome_map:
                            state['_outcome_id'] = next_dest
                        break
            else:
                # No transitions, use simple success routing
                state['_status'] = f'{step_id}_success'
                if next_step and next_step in self.outcome_map:
                    state['_outcome_id'] = next_step

            return state

        return node_fn

    def _create_routing_function(self, step_config: dict):
        """Create the conditional-edge router for one step.

        Args:
            step_config: Step definition; reads ``id`` and the optional
                ``next`` / ``transitions``.

        Returns:
            A ``route_fn(state) -> str`` callable returning the next node id
            or ``END``.
        """
        step_id = step_config['id']
        next_step = step_config.get('next')
        transitions = step_config.get('transitions', [])

        def route_fn(state: dict) -> str:
            # `_status` is Optional in the state type, so the key may be
            # present with a None value; normalise to '' before string ops.
            status = state.get('_status') or ''

            # Transition-based routing (collect_input_with_agent and call_function)
            if transitions:
                # Self-loop while still collecting (collect_input_with_agent only)
                if status == f'{step_id}_collecting':
                    return step_id

                # Attempts exhausted (collect_input_with_agent only)
                if status == f'{step_id}_max_attempts':
                    return END

                # Status format: {step_id}_{next_step}
                if status.startswith(f'{step_id}_'):
                    target = status[len(step_id) + 1:]
                    return END if target in self.outcome_map else target

            # call_function with a simple next step (no transitions)
            if status == f'{step_id}_success':
                if next_step:
                    # outcome_id already set in node if it's an outcome
                    return END if next_step in self.outcome_map else next_step
                return END

            return END

        return route_fn

    def build_graph(self, checkpointer=None):
        """Build the LangGraph execution graph

        Args:
            checkpointer: Optional checkpointer for state persistence.
                Defaults to InMemorySaver() if not provided.

        Returns:
            The compiled graph.

        Raises:
            ValueError: If the workflow has no steps or a step uses an
                unknown action type.
        """
        if not self.steps:
            raise ValueError("Workflow has no steps")

        builder = StateGraph(self.StateType)

        node_factories = {
            'collect_input_with_agent': self._create_collect_input_with_agent_node,
            'call_function': self._create_call_function_node,
        }

        # One node per step
        for step in self.steps:
            action = step['action']
            factory = node_factories.get(action)
            if factory is None:
                raise ValueError(f"Unknown action type: {action}")
            builder.add_node(step['id'], factory(step))

        # The first step is the entry point.
        builder.add_edge(START, self.steps[0]['id'])

        # Routing edges
        for step in self.steps:
            step_id = step['id']
            route_fn = self._create_routing_function(step)

            # Every value route_fn may return must appear as a key here.
            routing_map = {}

            transitions = step.get('transitions', [])
            if transitions:
                # Self-loop for collect_input_with_agent
                if step['action'] == 'collect_input_with_agent':
                    routing_map[step_id] = step_id

                # Destinations that are not steps (i.e. outcomes) end the graph.
                for transition in transitions:
                    next_dest = transition['next']
                    routing_map[next_dest] = next_dest if next_dest in self.step_map else END

            # Next step for call_function without transitions
            next_step = step.get('next')
            if next_step:
                routing_map[next_step] = next_step if next_step in self.step_map else END

            # Always include END
            routing_map[END] = END

            builder.add_conditional_edges(step_id, route_fn, routing_map)

        # Compile with checkpointer (default to InMemorySaver if not provided)
        if checkpointer is None:
            checkpointer = InMemorySaver()
        return builder.compile(checkpointer=checkpointer)

    def get_outcome_message(self, state: dict) -> str:
        """Get the outcome message from final state

        Looks up the outcome named by ``state['_outcome_id']`` and replaces
        each ``{field_name}`` placeholder whose state value is not None.
        Returns ``"Workflow completed."`` when no known outcome is set.
        """
        outcome_id = state.get('_outcome_id')
        if outcome_id and outcome_id in self.outcome_map:
            message = self.outcome_map[outcome_id]['message']

            # Placeholders for fields that are missing or None stay as-is.
            for field in self.data_fields:
                field_name = field['name']
                value = state.get(field_name)
                if value is not None:
                    message = message.replace(f'{{{field_name}}}', str(value))

            return message

        return "Workflow completed."
364
+
365
+
366
def load_workflow(yaml_path: str, checkpointer=None):
    """Build a compiled workflow graph from a YAML configuration file.

    Args:
        yaml_path: Path to the workflow YAML file
        checkpointer: Optional checkpointer for state persistence; it is
            passed straight through to ``WorkflowEngine.build_graph``, which
            falls back to InMemorySaver() when it is None.
            Example: MongoDBSaver for production persistence.

    Returns:
        Tuple of (graph, engine) where graph is the compiled LangGraph
        and engine is the WorkflowEngine instance
    """
    workflow_engine = WorkflowEngine(yaml_path)
    return workflow_engine.build_graph(checkpointer=checkpointer), workflow_engine
File without changes
@@ -0,0 +1,57 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Any, Callable, List
3
+ import logging
4
+
5
+ from langgraph.errors import GraphInterrupt
6
+
7
+ from ..core.constants import WorkflowKeys
8
+ from ..utils.logger import logger
9
+
10
+
11
class ActionStrategy(ABC):
    """Base class for workflow step actions.

    A concrete strategy implements ``execute``; ``get_node_function`` wraps
    it with logging and error capture so it can be used as a graph node.
    """

    def __init__(self, step_config: Dict[str, Any], engine_context: Any):
        """Store the step configuration and cache its ``id`` and ``action``."""
        self.step_config = step_config
        self.engine_context = engine_context
        self.step_id = step_config.get('id')
        self.action = step_config.get('action')

    @abstractmethod
    def execute(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Run the action against ``state`` and return the resulting state."""

    def get_node_function(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
        """Return a node callable that runs ``execute`` with error handling.

        ``GraphInterrupt`` is re-raised untouched. Any other exception is
        logged and recorded on the state (status, error payload, outcome id),
        and the state is returned instead of propagating the exception.
        """
        def node_fn(state: Dict[str, Any]) -> Dict[str, Any]:
            error_key = WorkflowKeys.ERROR
            logger.info(f"Executing node: {self.step_id} (action: {self.action})")
            try:
                outcome = self.execute(state)
                logger.info(f"Node {self.step_id} completed successfully")
                logger.info(f"Result: {outcome}")

                # The action itself may have flagged an error on the state.
                # NOTE(review): the flags are written to `state` while
                # `outcome` is returned - presumably the same object; confirm.
                if state.get(error_key):
                    state[WorkflowKeys.OUTCOME_ID] = error_key
                    self._set_status(state, error_key)
                return outcome
            except GraphInterrupt:
                # Interrupts are control flow, not failures.
                raise
            except Exception as e:
                logger.error(f"Node {self.step_id} failed: {e}", exc_info=True)
                self._set_status(state, error_key)
                state[error_key] = {"error": f"Unable to complete the request: {str(e)}"}
                state[WorkflowKeys.OUTCOME_ID] = error_key
                return state

        return node_fn

    def _get_transitions(self) -> List[Dict[str, Any]]:
        """Return the step's ``transitions`` list, or ``[]`` when absent."""
        return self.step_config.get('transitions', [])

    def _get_next_step(self) -> str:
        """Return the step's ``next`` value (None when the key is absent)."""
        return self.step_config.get('next')

    def _set_status(self, state: Dict[str, Any], status_suffix: str):
        """Write ``"{step_id}_{status_suffix}"`` under the status key."""
        state[WorkflowKeys.STATUS] = f'{self.step_id}_{status_suffix}'

    def _set_outcome(self, state: Dict[str, Any], outcome_id: str):
        """Write ``outcome_id`` under the outcome-id key."""
        state[WorkflowKeys.OUTCOME_ID] = outcome_id
@@ -0,0 +1,108 @@
1
+ from typing import Dict, Any
2
+
3
+ from .base import ActionStrategy
4
+ from ..core.state import set_state_value, get_state_value
5
+ from ..utils.logger import logger
6
+ from ..utils.template import get_nested_value
7
+
8
+
9
class CallFunctionStrategy(ActionStrategy):
    """Action strategy that loads a function and calls it with the state.

    The function is resolved through ``engine_context.function_repository``
    and invoked as ``func(state)``. Its result is stored under the step's
    ``output`` field, which is also recorded in the computed-fields list.
    """

    def __init__(self, step_config: Dict[str, Any], engine_context: Any):
        """Read the step configuration.

        Raises:
            RuntimeError: If the step lacks ``function`` or ``output``.
        """
        super().__init__(step_config, engine_context)
        self.function_path = step_config.get('function')
        self.output_field = step_config.get('output')
        # Stored from the config; execute() below does not read it.
        self.inputs = step_config.get('inputs', {})
        self.transitions = self._get_transitions()
        self.next_step = self._get_next_step()

        if not self.function_path:
            raise RuntimeError(f"Step '{self.step_id}' missing required 'function' property")

        if not self.output_field:
            raise RuntimeError(f"Step '{self.step_id}' missing required 'output' property")

    def execute(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Load and call the configured function, then route.

        Args:
            state: Workflow state; passed to the function and updated in place.

        Returns:
            The same ``state`` with the output, computed-fields list and
            routing status updated.

        Raises:
            RuntimeError: If the function cannot be loaded or raises; the
                original exception is attached as ``__cause__``.
        """
        # Imported here, as in the original module.
        from ..utils.tracing import trace_node_execution
        from ..core.constants import WorkflowKeys

        with trace_node_execution(
            node_id=self.step_id,
            node_type="call_function",
            function=self.function_path,
            output_field=self.output_field
        ) as span:
            try:
                logger.info(f"Loading function: {self.function_path}")
                func = self.engine_context.function_repository.load(self.function_path)
            except Exception as e:
                span.set_attribute("error", True)
                span.set_attribute("error.type", "LoadError")
                span.set_attribute("error.message", str(e))
                # Chain explicitly so the traceback shows the root cause.
                raise RuntimeError(
                    f"Failed to load function '{self.function_path}' in step '{self.step_id}': {e}"
                ) from e

            try:
                logger.info(f"Calling function: {self.function_path}")
                result = func(state)
                logger.info(f"Function {self.function_path} returned: {result}")

                span.add_event("function.executed", {
                    "result_type": type(result).__name__,
                    "result": str(result)
                })
            except Exception as e:
                span.set_attribute("error", True)
                span.set_attribute("error.type", type(e).__name__)
                span.set_attribute("error.message", str(e))
                raise RuntimeError(
                    f"Function '{self.function_path}' failed in step '{self.step_id}': {e}"
                ) from e

            # output_field is guaranteed non-empty by __init__.
            set_state_value(state, self.output_field, result)

            # Record the output field once in the computed-fields list.
            computed_fields = get_state_value(state, WorkflowKeys.COMPUTED_FIELDS, [])
            if self.output_field not in computed_fields:
                computed_fields.append(self.output_field)
            set_state_value(state, WorkflowKeys.COMPUTED_FIELDS, computed_fields)

            if self.transitions:
                return self._handle_transition_routing(state, result)

            return self._handle_simple_routing(state)

    def _handle_transition_routing(
        self,
        state: Dict[str, Any],
        result: Any
    ) -> Dict[str, Any]:
        """Route via the first transition whose ``condition`` equals the result.

        When a transition has a ``path``, the value at that path inside the
        result is compared instead of the whole result. If nothing matches,
        the status suffix is set to ``'failed'``.
        """
        for transition in self.transitions:
            check_value = result
            if 'path' in transition:
                check_value = get_nested_value(result, transition['path'])

            if check_value != transition['condition']:
                continue

            next_dest = transition['next']
            logger.info(f"Found matching transition, transitioning to {next_dest}")
            self._set_status(state, next_dest)

            if next_dest in self.engine_context.outcome_map:
                self._set_outcome(state, next_dest)
            return state

        logger.warning(
            f"No matching transition for result '{result}' in step '{self.step_id}'"
        )
        self._set_status(state, 'failed')
        return state

    def _handle_simple_routing(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Route to ``next`` when configured, otherwise mark ``'success'``."""
        # Single status write: the step's `next` wins over plain 'success'.
        self._set_status(state, self.next_step if self.next_step else 'success')

        if self.next_step in self.engine_context.outcome_map:
            self._set_outcome(state, self.next_step)
        return state