soprano-sdk 0.1.94__py3-none-any.whl → 0.1.96__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- soprano_sdk/__init__.py +10 -0
- soprano_sdk/agents/__init__.py +30 -0
- soprano_sdk/agents/adaptor.py +91 -0
- soprano_sdk/agents/factory.py +228 -0
- soprano_sdk/agents/structured_output.py +97 -0
- soprano_sdk/core/__init__.py +0 -0
- soprano_sdk/core/constants.py +59 -0
- soprano_sdk/core/engine.py +225 -0
- soprano_sdk/core/rollback_strategies.py +259 -0
- soprano_sdk/core/state.py +71 -0
- soprano_sdk/engine.py +381 -0
- soprano_sdk/nodes/__init__.py +0 -0
- soprano_sdk/nodes/base.py +57 -0
- soprano_sdk/nodes/call_function.py +108 -0
- soprano_sdk/nodes/collect_input.py +526 -0
- soprano_sdk/nodes/factory.py +46 -0
- soprano_sdk/routing/__init__.py +0 -0
- soprano_sdk/routing/router.py +97 -0
- soprano_sdk/tools.py +219 -0
- soprano_sdk/utils/__init__.py +0 -0
- soprano_sdk/utils/function.py +35 -0
- soprano_sdk/utils/logger.py +6 -0
- soprano_sdk/utils/template.py +27 -0
- soprano_sdk/utils/tool.py +60 -0
- soprano_sdk/utils/tracing.py +71 -0
- soprano_sdk/validation/__init__.py +13 -0
- soprano_sdk/validation/schema.py +302 -0
- soprano_sdk/validation/validator.py +173 -0
- {soprano_sdk-0.1.94.dist-info → soprano_sdk-0.1.96.dist-info}/METADATA +1 -1
- soprano_sdk-0.1.96.dist-info/RECORD +32 -0
- soprano_sdk-0.1.94.dist-info/RECORD +0 -4
- {soprano_sdk-0.1.94.dist-info → soprano_sdk-0.1.96.dist-info}/WHEEL +0 -0
- {soprano_sdk-0.1.94.dist-info → soprano_sdk-0.1.96.dist-info}/licenses/LICENSE +0 -0
soprano_sdk/engine.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Workflow Engine - Parses YAML workflow definitions and builds LangGraph execution graphs.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import yaml
|
|
6
|
+
import importlib
|
|
7
|
+
from typing import TypedDict, Optional, Dict
|
|
8
|
+
from langgraph.graph import StateGraph
|
|
9
|
+
from langgraph.types import interrupt
|
|
10
|
+
from langgraph.constants import START, END
|
|
11
|
+
from langgraph.checkpoint.memory import InMemorySaver
|
|
12
|
+
from agno.agent import Agent
|
|
13
|
+
from agno.models.openai import OpenAIChat
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class WorkflowEngine:
    """Parse a YAML workflow definition and build a LangGraph execution graph.

    The YAML document must provide the top-level keys ``name``,
    ``description``, ``data``, ``steps`` and ``outcomes``; all of them are
    read unconditionally in ``__init__``.
    """

    # YAML field type -> Python type used in the generated state TypedDict.
    # Unknown types fall back to Optional[str] in _build_state_type.
    _TYPE_MAPPING = {
        'text': Optional[str],
        'number': Optional[int],
        'boolean': Optional[bool],
    }

    def __init__(self, yaml_path: str):
        """Load the workflow definition stored at ``yaml_path``.

        Raises:
            KeyError: if one of the required top-level keys is missing.
        """
        self.yaml_path = yaml_path
        # YAML is a Unicode format; don't depend on the platform locale encoding.
        with open(yaml_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        self.workflow_name = self.config['name']
        self.workflow_description = self.config['description']
        self.data_fields = self.config['data']
        self.steps = self.config['steps']
        self.outcomes = self.config['outcomes']

        # Build the state type dynamically from the declared data fields.
        self.StateType = self._build_state_type()

        # Index steps and outcomes by id for routing lookups.
        self.step_map = {step['id']: step for step in self.steps}
        self.outcome_map = {outcome['id']: outcome for outcome in self.outcomes}

    def _build_state_type(self) -> type:
        """Build the ``WorkflowState`` TypedDict for this workflow.

        Each declared data field becomes a key typed via ``_TYPE_MAPPING``
        (defaulting to ``Optional[str]``). Internal bookkeeping keys, all
        prefixed with ``_``, are added afterwards.
        """
        fields = {
            field['name']: self._TYPE_MAPPING.get(field['type'], Optional[str])
            for field in self.data_fields
        }

        # Internal workflow fields.
        fields['_status'] = Optional[str]
        fields['_attempt_counts'] = Optional[Dict[str, int]]
        fields['_outcome_id'] = Optional[str]
        fields['_messages'] = Optional[list]  # For message accumulation
        fields['_conversations'] = Optional[Dict[str, list]]  # Conversation histories per field

        return TypedDict('WorkflowState', fields)

    def _apply_transition(self, state: dict, step_id: str, next_step: str) -> None:
        """Record a transition from ``step_id`` to ``next_step`` in ``state``.

        Sets ``_status`` to ``"{step_id}_{next_step}"``, which the routing
        function decodes. If ``next_step`` is an outcome id, also sets
        ``_outcome_id``.
        """
        state['_status'] = f'{step_id}_{next_step}'
        if next_step in self.outcome_map:
            state['_outcome_id'] = next_step

    def _create_collect_input_with_agent_node(self, step_config: dict):
        """Create a node that collects input using an AI agent with conversation history.

        The node keeps a per-field conversation in
        ``_conversations["{field}_conversation"]``. It interrupts the graph to
        get the user's reply, then runs the agent and scans the agent's reply
        for each transition's ``pattern``.

        On the first matching pattern:
        - the text after the pattern's first occurrence is stored in the field
          (as ``int`` for ``number`` fields);
        - the transition is recorded.

        If nothing matches, ``_status`` becomes ``"{step_id}_collecting"`` so
        the router loops back to this step.
        """
        step_id = step_config['id']
        field = step_config['field']
        max_attempts = step_config.get('max_attempts', 5)
        agent_config = step_config['agent']
        transitions = step_config.get('transitions', [])

        # Resolve the field's declared type once, not on every execution.
        field_def = next((f for f in self.data_fields if f['name'] == field), None)
        is_number = bool(field_def) and field_def['type'] == 'number'

        def node_fn(state: dict) -> dict:
            if state.get('_conversations') is None:
                state['_conversations'] = {}
            if state.get('_messages') is None:
                state['_messages'] = []

            # Field pre-populated by an external orchestrator: skip collection
            # and follow the first transition (the success path).
            if state.get(field) is not None:
                if transitions:
                    self._apply_transition(state, step_id, transitions[0]['next'])
                return state

            conv_key = f'{field}_conversation'
            conversation = state['_conversations'].setdefault(conv_key, [])

            # Every user message in the conversation counts as one attempt.
            attempt_count = sum(1 for m in conversation if m['role'] == 'user')
            if attempt_count >= max_attempts:
                state['_status'] = f'{step_id}_max_attempts'
                state['_messages'] = [f"I'm having trouble understanding your {field}. Please contact customer service for assistance."]
                return state

            agent = Agent(
                name=agent_config.get('name', f'{field}Collector'),
                model=OpenAIChat(id=agent_config.get('model', 'gpt-4o-mini')),
                instructions=agent_config['instructions'],
            )

            # NOTE(review): LangGraph re-executes a node from the top when it
            # resumes after interrupt(). Because of this, the first-turn agent
            # call below may run again on resume. Confirm this is intended.
            if not conversation:
                # First iteration: let the agent generate the initial prompt.
                prompt = agent.run(conversation).content
                conversation.append({"role": "assistant", "content": prompt})
            else:
                # Re-use the agent's most recent message as the prompt.
                prompt = conversation[-1]['content']

            # Pause for user input; the resume value is recorded as the user's turn.
            user_input = interrupt(prompt)
            conversation.append({"role": "user", "content": user_input})

            agent_response = agent.run(conversation).content
            conversation.append({"role": "assistant", "content": agent_response})

            for transition in transitions:
                pattern = transition['pattern']
                if pattern not in agent_response:
                    continue

                # Everything after the first occurrence of the pattern.
                # split(pattern)[1] would stop at a second occurrence.
                value = agent_response.partition(pattern)[2].strip()
                if value:
                    # int() raises ValueError if the agent emits a non-integer.
                    state[field] = int(value) if is_number else value
                    state['_messages'] = [f"✓ {field.replace('_', ' ').title()} collected: {value}"]
                else:
                    # Pattern with no trailing value (e.g., "ORDER_ID_FAILED:").
                    state['_messages'] = []

                self._apply_transition(state, step_id, transition['next'])
                return state

            # No pattern matched: self-loop. The agent's reply is already the
            # last conversation entry and becomes the next interrupt prompt,
            # so it is not duplicated into _messages.
            state['_status'] = f'{step_id}_collecting'
            state['_messages'] = []
            return state

        return node_fn

    def _create_call_function_node(self, step_config: dict):
        """Create a node that calls the function named by ``step_config['function']``.

        ``function`` is a dotted path (``"package.module.func"``), imported when
        the node runs.

        Inputs:
        - ``inputs`` values of the form ``"{field}"`` are replaced with
          ``state.get(field)``;
        - other values are passed literally as keyword arguments.

        The return value is stored in ``state[output]``.

        Routing:
        - With ``transitions``, the first one whose ``condition`` equals the
          result is recorded; if none matches, ``_status`` is left untouched.
        - Without ``transitions``, ``_status`` becomes ``"{step_id}_success"``
          and ``_outcome_id`` is set when ``next`` is an outcome.
        """
        step_id = step_config['id']
        function_path = step_config['function']
        inputs = step_config['inputs']
        output_field = step_config['output']
        next_step = step_config.get('next')
        transitions = step_config.get('transitions', [])

        def node_fn(state: dict) -> dict:
            module_name, function_name = function_path.rsplit('.', 1)
            func = getattr(importlib.import_module(module_name), function_name)

            kwargs = {}
            for key, value in inputs.items():
                if isinstance(value, str) and value.startswith('{') and value.endswith('}'):
                    # Placeholder: substitute the named state value.
                    kwargs[key] = state.get(value[1:-1])
                else:
                    kwargs[key] = value

            result = func(**kwargs)
            state[output_field] = result

            if transitions:
                for transition in transitions:
                    if result == transition['condition']:
                        self._apply_transition(state, step_id, transition['next'])
                        break
            else:
                state['_status'] = f'{step_id}_success'
                if next_step and next_step in self.outcome_map:
                    state['_outcome_id'] = next_step

            return state

        return node_fn

    def _create_routing_function(self, step_config: dict):
        """Create the conditional-edge router for a step.

        The router decodes ``_status`` (format ``"{step_id}_{target}"``) into
        the next node name. Outcome ids are mapped to ``END``.
        """
        step_id = step_config['id']
        next_step = step_config.get('next')
        transitions = step_config.get('transitions', [])
        prefix = f'{step_id}_'

        def route_fn(state: dict) -> str:
            # _status is Optional in the state schema: treat None like unset
            # instead of crashing on None.startswith below.
            status = state.get('_status') or ''

            if transitions:
                if status == f'{step_id}_collecting':
                    return step_id  # self-loop: keep collecting
                if status == f'{step_id}_max_attempts':
                    return END
                if status.startswith(prefix):
                    target = status[len(prefix):]
                    return END if target in self.outcome_map else target
                return END

            # call_function without transitions: simple "next" routing.
            if status == f'{step_id}_success' and next_step:
                # _outcome_id was already set by the node if next_step is an outcome.
                return END if next_step in self.outcome_map else next_step
            return END

        return route_fn

    def build_graph(self, checkpointer=None):
        """Build the LangGraph execution graph.

        Args:
            checkpointer: Optional checkpointer for state persistence.
                Defaults to InMemorySaver() if not provided.

        Raises:
            ValueError: if a step has an unknown ``action``.
        """
        builder = StateGraph(self.StateType)

        node_factories = {
            'collect_input_with_agent': self._create_collect_input_with_agent_node,
            'call_function': self._create_call_function_node,
        }
        for step in self.steps:
            step_id = step['id']
            action = step['action']
            factory = node_factories.get(action)
            if factory is None:
                raise ValueError(f"Unknown action type: {action}")
            builder.add_node(step_id, factory(step))

        # The first declared step is the entry point.
        builder.add_edge(START, self.steps[0]['id'])

        for step in self.steps:
            step_id = step['id']
            route_fn = self._create_routing_function(step)

            # Every value route_fn can return must appear in the routing map.
            routing_map = {}
            transitions = step.get('transitions', [])
            if transitions:
                if step['action'] == 'collect_input_with_agent':
                    routing_map[step_id] = step_id  # self-loop
                for transition in transitions:
                    dest = transition['next']
                    routing_map[dest] = dest if dest in self.step_map else END

            next_step = step.get('next')
            if next_step:
                routing_map[next_step] = next_step if next_step in self.step_map else END

            routing_map[END] = END
            builder.add_conditional_edges(step_id, route_fn, routing_map)

        if checkpointer is None:
            checkpointer = InMemorySaver()
        return builder.compile(checkpointer=checkpointer)

    def get_outcome_message(self, state: dict) -> str:
        """Return the message of the outcome recorded in ``state``.

        ``{field}`` placeholders for declared data fields are replaced with the
        field's state value; placeholders whose value is None are left as-is.
        Returns "Workflow completed." if no known outcome is recorded.
        """
        outcome_id = state.get('_outcome_id')
        if not outcome_id or outcome_id not in self.outcome_map:
            return "Workflow completed."

        message = self.outcome_map[outcome_id]['message']
        for field in self.data_fields:
            field_name = field['name']
            value = state.get(field_name)
            if value is not None:
                message = message.replace(f'{{{field_name}}}', str(value))
        return message
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def load_workflow(yaml_path: str, checkpointer=None):
    """Construct a workflow engine from YAML and compile its graph.

    Args:
        yaml_path: Location of the workflow YAML document.
        checkpointer: Optional checkpointer used for state persistence,
            for example MongoDBSaver in production. When omitted, the
            engine falls back to InMemorySaver().

    Returns:
        A ``(graph, engine)`` pair:
        - ``graph`` is the compiled LangGraph;
        - ``engine`` is the WorkflowEngine that produced it.
    """
    workflow_engine = WorkflowEngine(yaml_path)
    compiled = workflow_engine.build_graph(checkpointer=checkpointer)
    return compiled, workflow_engine
|
|
File without changes
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Dict, Any, Callable, List
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
from langgraph.errors import GraphInterrupt
|
|
6
|
+
|
|
7
|
+
from ..core.constants import WorkflowKeys
|
|
8
|
+
from ..utils.logger import logger
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ActionStrategy(ABC):
    """Base class for the action executed by a single workflow step.

    Subclasses implement :meth:`execute`. :meth:`get_node_function` wraps it
    into a graph node that logs execution and turns unexpected exceptions
    into an error outcome recorded on the state.
    """

    def __init__(self, step_config: Dict[str, Any], engine_context: Any):
        """Store the step configuration and the owning engine context.

        Args:
            step_config: The step's configuration. ``id`` and ``action`` are
                read with ``.get`` and are None when absent.
            engine_context: The owning engine. Subclasses read shared data
                from it (e.g. ``outcome_map``).
        """
        self.step_config = step_config
        self.engine_context = engine_context
        self.step_id = step_config.get('id')
        self.action = step_config.get('action')

    @abstractmethod
    def execute(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Run this step against ``state`` and return the resulting state."""

    def get_node_function(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
        """Return a node callable that wraps :meth:`execute`.

        After a successful ``execute``:
        - if ``state`` carries ``WorkflowKeys.ERROR``, the outcome and status
          are set to the error marker;
        - the result of ``execute`` is then returned.

        ``GraphInterrupt`` is re-raised untouched. Any other exception is
        logged and recorded on ``state`` as an error outcome, and ``state``
        is returned instead of raising.
        """
        def node_fn(state: Dict[str, Any]) -> Dict[str, Any]:
            logger.info(f"Executing node: {self.step_id} (action: {self.action})")
            try:
                result = self.execute(state)
                logger.info(f"Node {self.step_id} completed successfully")
                logger.info(f"Result: {result}")

                # execute() may flag an error on the state without raising.
                if state.get(WorkflowKeys.ERROR):
                    state[WorkflowKeys.OUTCOME_ID] = WorkflowKeys.ERROR
                    self._set_status(state, WorkflowKeys.ERROR)
                return result
            except GraphInterrupt:
                # Interrupts are graph control flow, not failures: propagate them.
                raise
            except Exception as e:
                logger.error(f"Node {self.step_id} failed: {e}", exc_info=True)
                self._set_status(state, WorkflowKeys.ERROR)
                state[WorkflowKeys.ERROR] = {"error": f"Unable to complete the request: {str(e)}"}
                state[WorkflowKeys.OUTCOME_ID] = WorkflowKeys.ERROR
                return state

        return node_fn

    def _get_transitions(self) -> List[Dict[str, Any]]:
        """Return the step's ``transitions`` list, or ``[]`` when absent."""
        return self.step_config.get('transitions', [])

    def _get_next_step(self) -> str:
        """Return the step's ``next`` target; None when the config has no ``next``."""
        return self.step_config.get('next')

    def _set_status(self, state: Dict[str, Any], status_suffix: str):
        """Set the state's status to ``"{step_id}_{status_suffix}"``."""
        # WorkflowKeys comes from the module-level import; no local re-import needed.
        state[WorkflowKeys.STATUS] = f'{self.step_id}_{status_suffix}'

    def _set_outcome(self, state: Dict[str, Any], outcome_id: str):
        """Record ``outcome_id`` as the workflow outcome on ``state``."""
        state[WorkflowKeys.OUTCOME_ID] = outcome_id
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
from typing import Dict, Any
|
|
2
|
+
|
|
3
|
+
from .base import ActionStrategy
|
|
4
|
+
from ..core.state import set_state_value, get_state_value
|
|
5
|
+
from ..utils.logger import logger
|
|
6
|
+
from ..utils.template import get_nested_value
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class CallFunctionStrategy(ActionStrategy):
    """Step action that calls a Python function and routes on its result.

    Config keys:
    - ``function`` (required): loaded through
      ``engine_context.function_repository``.
    - ``output`` (required): the state key that receives the return value.
    - ``transitions`` or ``next`` (optional): routing.
    - ``inputs``: read (default ``{}``) but not used by :meth:`execute`,
      which passes the whole state to the function.
    """

    def __init__(self, step_config: Dict[str, Any], engine_context: Any):
        """Read and validate the step configuration.

        Raises:
            RuntimeError: if ``function`` or ``output`` is missing or empty.
        """
        super().__init__(step_config, engine_context)
        self.function_path = step_config.get('function')
        self.output_field = step_config.get('output')
        self.inputs = step_config.get('inputs', {})
        self.transitions = self._get_transitions()
        self.next_step = self._get_next_step()

        if not self.function_path:
            raise RuntimeError(f"Step '{self.step_id}' missing required 'function' property")

        if not self.output_field:
            raise RuntimeError(f"Step '{self.step_id}' missing required 'output' property")

    def execute(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Load and call the function, store its result, and set routing status.

        The function is called with the full ``state``. Failures are recorded
        on the tracing span.

        Raises:
            RuntimeError: if the function cannot be loaded or raises. The
                original exception is chained as ``__cause__``.
        """
        from ..utils.tracing import trace_node_execution

        with trace_node_execution(
            node_id=self.step_id,
            node_type="call_function",
            function=self.function_path,
            output_field=self.output_field
        ) as span:
            try:
                logger.info(f"Loading function: {self.function_path}")
                func = self.engine_context.function_repository.load(self.function_path)
            except Exception as e:
                span.set_attribute("error", True)
                span.set_attribute("error.type", "LoadError")
                span.set_attribute("error.message", str(e))
                raise RuntimeError(
                    f"Failed to load function '{self.function_path}' in step '{self.step_id}': {e}"
                ) from e

            try:
                logger.info(f"Calling function: {self.function_path}")
                result = func(state)
                logger.info(f"Function {self.function_path} returned: {result}")

                span.add_event("function.executed", {
                    "result_type": type(result).__name__,
                    "result": str(result)
                })
            except Exception as e:
                span.set_attribute("error", True)
                span.set_attribute("error.type", type(e).__name__)
                span.set_attribute("error.message", str(e))
                raise RuntimeError(f"Function '{self.function_path}' failed in step '{self.step_id}': {e}") from e

            self._record_output(state, result)

            if self.transitions:
                return self._handle_transition_routing(state, result)

            return self._handle_simple_routing(state)

    def _record_output(self, state: Dict[str, Any], result: Any) -> None:
        """Store ``result`` under ``output_field`` and add the field to COMPUTED_FIELDS."""
        if not self.output_field:
            return

        set_state_value(state, self.output_field, result)

        from ..core.constants import WorkflowKeys

        # Build a fresh list instead of mutating the one held in state.
        # Also tolerate a stored None value.
        computed_fields = list(get_state_value(state, WorkflowKeys.COMPUTED_FIELDS, []) or [])
        if self.output_field not in computed_fields:
            computed_fields.append(self.output_field)
        set_state_value(state, WorkflowKeys.COMPUTED_FIELDS, computed_fields)

    def _handle_transition_routing(
        self,
        state: Dict[str, Any],
        result: Any
    ) -> Dict[str, Any]:
        """Follow the first transition whose ``condition`` equals the checked value.

        With a ``path`` key, the value is ``get_nested_value(result, path)``;
        otherwise it is ``result`` itself. If no transition matches, the
        status is set to ``"{step_id}_failed"``.
        """
        for transition in self.transitions:
            check_value = result
            if 'path' in transition:
                check_value = get_nested_value(result, transition['path'])

            if check_value != transition['condition']:
                continue

            next_dest = transition['next']
            logger.info(f"Found matching transition, transitioning to {next_dest}")
            self._set_status(state, next_dest)

            if next_dest in self.engine_context.outcome_map:
                self._set_outcome(state, next_dest)
            return state

        logger.warning(
            f"No matching transition for result '{result}' in step '{self.step_id}'"
        )
        self._set_status(state, 'failed')
        return state

    def _handle_simple_routing(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Route without transitions.

        The status becomes ``"{step_id}_{next}"`` when ``next`` is set, and
        ``"{step_id}_success"`` otherwise. The outcome is recorded if ``next``
        is an outcome id.
        """
        self._set_status(state, 'success')

        if self.next_step:
            self._set_status(state, self.next_step)

            if self.next_step in self.engine_context.outcome_map:
                self._set_outcome(state, self.next_step)
        return state
|