soprano-sdk 0.1.94__py3-none-any.whl → 0.1.96__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,526 @@
1
+ from typing import Dict, Any, List, Optional, Tuple
2
+
3
+ from jinja2 import Environment
4
+ from langgraph.types import interrupt
5
+
6
+ from .base import ActionStrategy
7
+ from ..agents.factory import AgentFactory, AgentAdapter
8
+ from ..agents.structured_output import create_structured_output_model, validate_field_definitions
9
+ from ..core.constants import (
10
+ WorkflowKeys,
11
+ DEFAULT_MAX_ATTEMPTS,
12
+ MAX_ATTEMPTS_MESSAGE,
13
+ TransitionPattern
14
+ )
15
+ from ..core.rollback_strategies import (
16
+ RollbackStrategy,
17
+ HistoryBasedRollback,
18
+ DependencyBasedRollback
19
+ )
20
+ from ..core.state import initialize_state
21
+ from ..utils.logger import logger
22
+ from ..utils.tracing import trace_node_execution, trace_agent_invocation, add_node_result
23
+
24
# Fallback user-facing messages used by CollectInputStrategy.

# Default `message` of CollectInputStrategy._handle_validation_failure.
VALIDATION_ERROR_MESSAGE = "validation failed for the provided input, please enter valid input"

# Not referenced by the code in this module.
INVALID_INPUT_MESSAGE = (
    "Looks like the input is invalid. "
    "Please double-check and re-enter it."
)

# Appended as an assistant turn when a structured-output response matches no
# transition and the step has no default `next`.
COLLECTION_FAILURE_MESSAGE = (
    "I couldn't understand your response. "
    "Please try again and provide the required information."
)
27
+
28
+ def _wrap_instructions_with_intent_detection(
29
+ instructions: str,
30
+ collector_nodes: Dict[str, str],
31
+ with_structured_output: bool
32
+ ) -> str:
33
+ if not collector_nodes:
34
+ return instructions
35
+
36
+ collector_nodes_str = "\n".join(f"{node_name}: {description}" for node_name, description in collector_nodes.items())
37
+ return f"""
38
+ {instructions}
39
+
40
+ AVAILABLE CONVERSATION INTENTS:
41
+ {collector_nodes_str}
42
+
43
+ Format: <node_name>: <intent_description>
44
+
45
+ CRITICAL INTENT DETECTION RULES:
46
+ 1. ONLY check for intent changes against the EXACT node names listed above
47
+ 2. The node name MUST appear in the list above to be valid
48
+ 3. Do NOT infer, guess, or create new intent names
49
+ 4. Tool calls are NOT intent changes
50
+ 5. If no intents are listed above, NEVER trigger an intent change
51
+
52
+ Before responding, analyze if the user's query matches a DIFFERENT intent from the list above:
53
+
54
+ IF the user's query clearly matches a DIFFERENT intent that EXISTS in the list above:
55
+ - {"Respond ONLY with: INTENT_CHANGE: <node_name>" if not with_structured_output else "modify intent_change value <node_name>"}
56
+ - Use the EXACT node_name from the list above
57
+ - Do NOT provide any other response
58
+ - Do NOT answer the user's question
59
+
60
+ IF the user's query continues with the SAME intent OR does not match any intent in the list above:
61
+ - Proceed with your normal response
62
+ - Do NOT mention intent detection
63
+ - Answer the user's question as configured
64
+ """
65
+
66
def _create_rollback_strategy(strategy_name: str) -> RollbackStrategy:
    """Instantiate the rollback strategy named by ``strategy_name``.

    ``"dependency_based"`` selects DependencyBasedRollback and
    ``"history_based"`` selects HistoryBasedRollback. Any other name logs a
    warning and falls back to HistoryBasedRollback.
    """
    if strategy_name == "dependency_based":
        return DependencyBasedRollback()

    # The explicit name and the fallback both resolve to the history strategy;
    # only an unrecognised name deserves a warning.
    if strategy_name != "history_based":
        logger.warning(f"Unknown rollback strategy '{strategy_name}', using history_based")
    return HistoryBasedRollback()
74
+
75
def _get_agent_response(agent: AgentAdapter, conversation: List[Dict[str, str]]) -> Any:
    """Invoke ``agent`` on ``conversation`` and record its reply.

    The reply is appended to ``conversation`` in place as an ``assistant``
    turn whose content is ``str(reply)``, so non-string (e.g. structured)
    replies are stored as their string form.

    Returns:
        The raw value returned by ``agent.invoke``.
    """
    reply = agent.invoke(conversation)
    assistant_turn = {"role": "assistant", "content": str(reply)}
    conversation.append(assistant_turn)
    return reply
81
+
82
class CollectInputStrategy(ActionStrategy):
    """Collect a single workflow field by conversing with an LLM agent.

    Step configuration read here: ``field`` (required), ``agent`` (required;
    uses ``instructions``, ``initial_message``, ``tools``, ``model``, ``name``,
    ``description`` and ``structured_output``), plus optional ``retry_limit``,
    ``validator`` and ``next``. Transitions come from ``_get_transitions()``.

    Each execution either short-circuits (field already populated, or the
    retry limit was reached) or prompts the user through ``interrupt``, sends
    the reply to the agent and derives the next status from its answer.
    """

    def __init__(self, step_config: Dict[str, Any], engine_context: Any):
        super().__init__(step_config, engine_context)
        self.field = step_config.get('field')
        # `or {}` also turns an explicit `agent: null` into the "missing" case.
        self.agent_config = step_config.get('agent') or {}

        # Validate required properties before any lookup that depends on them,
        # so a malformed step fails with a clear message, not an AttributeError.
        if not self.field:
            raise RuntimeError(f"Step '{self.step_id}' missing required 'field' property")

        if not self.agent_config:
            raise RuntimeError(f"Step '{self.step_id}' missing required 'agent' configuration")

        # A falsy retry_limit (missing or 0) falls back to the engine-wide limit.
        self.max_attempts = step_config.get('retry_limit') or engine_context.get_config_value("max_retry_limit", DEFAULT_MAX_ATTEMPTS)
        self.transitions = self._get_transitions()
        self.next_step = self.step_config.get("next", None)
        # `or {}` tolerates `structured_output: null`.
        structured_output_config = self.agent_config.get("structured_output") or {}
        self.is_structured_output = structured_output_config.get("enabled", False)

        rollback_strategy_name = engine_context.get_config_value("rollback_strategy", "history_based")
        self.rollback_strategy = _create_rollback_strategy(rollback_strategy_name)
        logger.info(f"Using rollback strategy: {self.rollback_strategy.get_strategy_name()}")

        self.validator = None
        if validator_function_path := self.step_config.get("validator"):
            self.validator = self.engine_context.function_repository.load(validator_function_path)

    @property
    def _conversation_key(self) -> str:
        """Key of this field's conversation inside ``state[CONVERSATIONS]``."""
        return f'{self.field}_conversation'

    @property
    def _formatted_field_name(self) -> str:
        """Human-readable field name, e.g. ``"order_id"`` -> ``"Order Id"``."""
        return self.field.replace('_', ' ').title()

    def execute(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Run one collection turn for ``self.field`` and return the new state.

        Order of operations: apply a context-provided value; save a rollback
        snapshot when not re-entered through the collecting self-loop;
        short-circuit on a pre-populated field or exhausted retries; otherwise
        prompt the user (``interrupt``), invoke the agent and route on its reply.
        """
        with trace_node_execution(
            node_id=self.step_id,
            node_type="collect_input_with_agent",
            output_field=self.field
        ) as span:
            state = initialize_state(state)

            self._apply_context_value(state, span)

            is_self_loop = self._is_self_loop(state)

            # Snapshot only on first entry into the step, not on every retry.
            if not is_self_loop:
                if self.rollback_strategy.should_save_snapshot():
                    self._save_snapshot_before_execution(state)

            conversation = self._get_or_create_conversation(state)

            if self._is_field_pre_populated(state):
                span.add_event("field.pre_populated", {"value": str(state.get(self.field))})
                state = self._handle_pre_populated_field(state, conversation)
                add_node_result(span, self.field, state.get(self.field), state.get(WorkflowKeys.STATUS))
                return self._add_node_to_execution_order(state)

            if self._max_attempts_reached(state):
                span.add_event("max_attempts.reached")
                return self._handle_max_attempts(state)

            agent = self._create_agent(state)

            prompt = self._generate_prompt(agent, conversation, state)

            # Suspends the graph until the caller resumes with the user's reply.
            user_input = interrupt(prompt)

            conversation.append({"role": "user", "content": user_input})
            span.add_event("user.input_received", {"input_length": len(user_input)})

            with trace_agent_invocation(
                agent_name=self.agent_config.get('name', self.field),
                model=self.agent_config.get('model', 'default')
            ):
                agent_response = _get_agent_response(agent, conversation)

            if self.is_structured_output:
                state = self._handle_structured_output_transition(state, conversation, agent_response)
                add_node_result(span, self.field, state.get(self.field), state.get(WorkflowKeys.STATUS))
                return self._add_node_to_execution_order(state)

            if agent_response.startswith(TransitionPattern.INTENT_CHANGE):
                span.add_event("intent.change_detected")
                return self._handle_intent_change(agent_response, state)

            state = self._process_transitions(state, conversation, agent_response)

            self._update_conversation(state, conversation)

            add_node_result(span, self.field, state.get(self.field), state.get(WorkflowKeys.STATUS))

            return self._add_node_to_execution_order(state)

    def _render_template_string(self, template_str: str, state: Dict[str, Any]) -> str:
        """Render ``template_str`` as a Jinja template against ``state``.

        Uses the engine's ``template_loader`` when configured, otherwise a
        default ``Environment``. Returns ``""`` for an empty/None template.
        """
        if not template_str:
            return ""
        # Build a default Environment only when none is configured, instead of
        # constructing one eagerly on every call as a `get` default.
        template_loader = self.engine_context.get_config_value('template_loader') or Environment()
        return template_loader.from_string(template_str).render(state)

    def _apply_context_value(self, state: Dict[str, Any], span) -> None:
        """Copy a truthy engine-context value for ``self.field`` into ``state``."""
        if not (context_value := self.engine_context.get_context_value(self.field)):
            return
        logger.info(f"Using context value for '{self.field}': {context_value}")
        state[self.field] = context_value
        span.add_event("context.value_used", {"field": self.field, "value": str(context_value)})

    def _add_node_to_execution_order(self, state):
        """Register this step as executed and as an intent-change target.

        Skipped while the step is still collecting, so only steps that have
        finished a turn are recorded.
        """
        # Exact match against `<step_id>_collecting` (see _is_self_loop). A
        # substring test for 'collecting' would also match step ids or targets
        # that merely contain the word, and fails when the status is missing.
        if self._is_self_loop(state):
            return state

        self._register_node_execution(state)
        self._register_collector_node(state)

        return state

    def _is_self_loop(self, state: Dict[str, Any]) -> bool:
        """True when the status says this step is still collecting."""
        return state.get(WorkflowKeys.STATUS) == f'{self.step_id}_collecting'

    def _save_snapshot_before_execution(self, state: Dict[str, Any]):
        """Ask the rollback strategy to snapshot ``state`` before this step runs."""
        state_history = state.get(WorkflowKeys.STATE_HISTORY, [])
        execution_index = len(state_history)
        self.rollback_strategy.save_snapshot(state, self.step_id, execution_index)

    def _register_node_execution(self, state: Dict[str, Any]):
        """Append this step to the execution order once (no duplicates)."""
        execution_order = state.get(WorkflowKeys.NODE_EXECUTION_ORDER, [])
        if self.step_id not in execution_order:
            execution_order.append(self.step_id)
        state[WorkflowKeys.NODE_EXECUTION_ORDER] = execution_order

    def _register_collector_node(self, state: Dict[str, Any]):
        """Record this step's intent description and the field it collects."""
        collector_nodes = state.get(WorkflowKeys.COLLECTOR_NODES, {})
        description = self.agent_config.get('description', f"Collecting {self.field}")
        collector_nodes[self.step_id] = description
        state[WorkflowKeys.COLLECTOR_NODES] = collector_nodes

        node_field_map = state.get(WorkflowKeys.NODE_FIELD_MAP, {})
        node_field_map[self.step_id] = self.field
        state[WorkflowKeys.NODE_FIELD_MAP] = node_field_map

    def _is_field_pre_populated(self, state: Dict[str, Any]) -> bool:
        """True when ``state`` already holds a non-None value for the field."""
        return state.get(self.field) is not None

    def _validate_collected_input(self, state) -> Tuple[bool, Optional[str]]:
        """Run the configured validator, if any, on the whole state.

        The validator is called as ``validator(**state)`` and may return either
        a bool or a ``(bool, message)`` tuple. Without a validator the input is
        considered valid.
        """
        if not self.validator:
            return True, None
        result = self.validator(**state)
        if isinstance(result, tuple):
            return result
        return result, None

    def _handle_pre_populated_field(self, state: Dict[str, Any], conversation: List) -> Dict[str, Any]:
        """Validate an already-present value and route without prompting.

        An invalid value is cleared and recorded as a user turn, and the step
        keeps collecting. A valid value routes to the first transition's
        ``next``; a configured ``next`` step, when present, takes precedence.
        """
        logger.info(f"Field '{self.field}' is populated, skipping collection")

        is_valid_input, _ = self._validate_collected_input(state)
        if not is_valid_input:
            self._set_status(state, "collecting")
            return self._handle_validation_failure(state, conversation, message=f"{state[self.field]}", role="user")

        if self.transitions:
            first_transition = self.transitions[0]
            next_step = first_transition['next']
            self._set_status(state, next_step)

            if next_step in self.engine_context.outcome_map:
                self._set_outcome(state, next_step)

        if self.next_step:
            self._set_status(state, self.next_step)

            if self.next_step in self.engine_context.outcome_map:
                self._set_outcome(state, self.next_step)

        return state

    def _max_attempts_reached(self, state: Dict[str, Any]) -> bool:
        """True when the number of user turns has reached ``max_attempts``."""
        conversation = state.get(WorkflowKeys.CONVERSATIONS, {}).get(self._conversation_key, [])
        attempt_count = len([m for m in conversation if m['role'] == 'user'])
        return attempt_count >= self.max_attempts

    def _handle_max_attempts(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Mark the step as exhausted and set the max-attempts message."""
        logger.warning(f"Max attempts reached for field '{self.field}'")
        self._set_status(state, 'max_attempts')
        message = MAX_ATTEMPTS_MESSAGE.format(field=self.field)
        state[WorkflowKeys.MESSAGES] = [message]
        return state

    def _get_or_create_conversation(self, state: Dict[str, Any]) -> List[Dict[str, str]]:
        """Return this field's conversation list, creating it in ``state`` if absent."""
        conversations = state.get(WorkflowKeys.CONVERSATIONS, {})

        if self._conversation_key not in conversations:
            conversations[self._conversation_key] = []
            state[WorkflowKeys.CONVERSATIONS] = conversations

        return conversations[self._conversation_key]

    def _get_model_config(self) -> Dict[str, Any]:
        """Return the engine model config, overriding ``model_name`` per agent.

        Raises:
            ValueError: if the engine has no ``model_config``.
        """
        model_config = self.engine_context.get_config_value('model_config')
        if not model_config:
            raise ValueError("Model config not found in engine context")

        if model_id := self.agent_config.get("model"):
            # Copy so the shared engine config is not mutated.
            model_config = model_config.copy()
            model_config["model_name"] = model_id

        return model_config

    def _get_instructions(self, state: Dict[str, Any], collector_nodes: Dict[str, str]) -> str:
        """Render the agent instructions and, if needed, add intent detection."""
        instructions = self.agent_config.get("instructions")

        instructions = self._render_template_string(instructions, state)

        if collector_nodes:
            instructions = _wrap_instructions_with_intent_detection(instructions, collector_nodes, self.is_structured_output)
        return instructions

    def _load_agent_tools(self, state: Dict[str, Any]) -> List:
        """Load every tool named in the agent's ``tools`` list."""
        return [
            self.engine_context.tool_repository.load(tool_name, state)
            for tool_name in self.agent_config.get('tools', [])
        ]

    def _create_structured_output_model(self, collector_nodes: Dict[str, str]) -> Any:
        """Build the structured-output model, or None when not configured."""
        structured_output_config = self.agent_config.get('structured_output')
        if not structured_output_config or not structured_output_config.get('enabled'):
            return None

        fields = structured_output_config.get('fields', [])
        if not fields:
            return None

        validate_field_definitions(fields)
        model_name = f"{self.field.title().replace('_', '')}StructuredOutput"

        return create_structured_output_model(
            fields=fields,
            model_name=model_name,
            needs_intent_change=len(collector_nodes) > 0
        )

    def _create_agent(self, state: Dict[str, Any]) -> AgentAdapter:
        """Create the agent for this turn.

        Raises:
            RuntimeError: wrapping (and chained to) any error raised while
                building the agent.
        """
        try:
            model_config = self._get_model_config()
            agent_tools = self._load_agent_tools(state)
            collector_nodes = state.get(WorkflowKeys.COLLECTOR_NODES, {})

            instructions = self._get_instructions(state, collector_nodes)
            structured_output_model = self._create_structured_output_model(collector_nodes)
            framework = self.engine_context.get_config_value('agent_framework', 'langgraph')

            return AgentFactory.create_agent(
                framework=framework,
                name=self.agent_config.get('name', f'{self.field}Collector'),
                model_config=model_config,
                tools=agent_tools,
                system_prompt=instructions,
                structured_output_model=structured_output_model
            )

        except Exception as e:
            raise RuntimeError(f"Failed to create agent for step '{self.step_id}': {e}") from e

    def _generate_prompt(
        self,
        agent: AgentAdapter,
        conversation: List[Dict[str, str]],
        state: Dict[str, Any]
    ) -> str:
        """Return the text to show the user on this turn.

        On the first turn the configured ``initial_message`` is rendered as a
        Jinja template against ``state``; without one, the agent is asked to
        open the conversation. The opening prompt is recorded as an assistant
        turn. On later turns the content of the last conversation entry is
        returned.
        """
        if conversation:
            # NOTE(review): the last entry may be a *user* turn (e.g. an invalid
            # pre-populated value appended by _handle_pre_populated_field), in
            # which case that text is returned as the prompt.
            return conversation[-1]['content']

        if initial_message := self.agent_config.get('initial_message'):
            prompt = self._render_template_string(initial_message, state)
        else:
            # Model output is not a template: rendering it would evaluate any
            # `{{ ... }}` / `{% ... %}` the model emitted and raise on stray
            # braces. `or ""` keeps an empty reply as an empty prompt.
            prompt = agent.invoke([{"role": "user", "content": ""}]) or ""

        conversation.append({"role": "assistant", "content": prompt})
        return prompt

    def _update_conversation(self, state: Dict[str, Any], conversation: List[Dict[str, str]]):
        """Store ``conversation`` under this field's conversation key."""
        state[WorkflowKeys.CONVERSATIONS][self._conversation_key] = conversation

    def _handle_intent_change(self, target_node_or_response, state: Dict[str, Any]) -> Dict[str, Any]:
        """Roll the state back so the workflow resumes at another collector.

        Accepts either a bare target node id or a text response containing the
        ``INTENT_CHANGE`` marker followed by the node id.

        Raises:
            RuntimeError: if the rollback strategy produced no state.
        """
        if isinstance(target_node_or_response, str) and TransitionPattern.INTENT_CHANGE in target_node_or_response:
            target_node = target_node_or_response.split(TransitionPattern.INTENT_CHANGE)[1].strip()
        else:
            target_node = target_node_or_response

        logger.info(f"Intent change detected: {self.step_id} -> {target_node}")

        rollback_state = self._rollback_state_to_node(state, target_node)

        # _rollback_state_to_node signals failure with an empty dict, not None.
        if not rollback_state:
            logger.error(f"Failed to rollback to node '{target_node}'")
            raise RuntimeError(f"Unable to process intent change to '{target_node}'")

        return rollback_state

    def _rollback_state_to_node(
        self,
        state: Dict[str, Any],
        target_node: str
    ) -> Dict[str, Any]:
        """Restore state up to ``target_node``; ``{}`` when the strategy fails.

        On success the status is set to ``<step_id>_<target_node>`` so the
        router sends the workflow to the target step.
        """
        node_execution_order = state.get(WorkflowKeys.NODE_EXECUTION_ORDER, [])
        node_field_map = state.get(WorkflowKeys.NODE_FIELD_MAP, {})
        workflow_steps = self.engine_context.steps

        restored_state = self.rollback_strategy.rollback_to_node(
            state=state,
            target_node=target_node,
            node_execution_order=node_execution_order,
            node_field_map=node_field_map,
            workflow_steps=workflow_steps
        )

        if not restored_state:
            logger.warning(f"Rollback strategy returned empty state for node '{target_node}'")
            return {}

        restored_state[WorkflowKeys.STATUS] = f"{self.step_id}_{target_node}"

        return restored_state

    def _process_transitions(
        self,
        state: Dict[str, Any],
        conversation: List,
        agent_response: str
    ) -> Dict[str, Any]:
        """Route on a text response using the first transition whose pattern occurs in it.

        The text between the first and second occurrence of the pattern, when
        non-empty, is stored as the field value and validated. With no match
        the step keeps collecting.
        """
        matched = False
        self._set_status(state, 'collecting')

        for transition in self.transitions:
            pattern = transition['pattern']
            if pattern not in agent_response:
                continue

            matched = True
            next_step = transition['next']

            logger.info(f"Matched transition: {transition}")

            value = agent_response.split(pattern)[1].strip()
            if value:
                self._store_field_value(state, value)
                is_valid_input, message = self._validate_collected_input(state)
                if not is_valid_input:
                    return self._handle_validation_failure(state, conversation, message=message)
                state[WorkflowKeys.MESSAGES] = [f"✓ {self._formatted_field_name} collected: {value}"]
            else:
                state[WorkflowKeys.MESSAGES] = []

            self._set_status(state, next_step)

            if next_step in self.engine_context.outcome_map:
                self._set_outcome(state, next_step)

            break

        if not matched:
            logger.info(f"No transition matched for response in step '{self.step_id}', continuing collection")
            state[WorkflowKeys.MESSAGES] = []

        return state

    def _handle_structured_output_transition(self, state: Dict[str, Any], conversation: List, agent_response: Any) -> Dict[str, Any]:
        """Route on a structured (dict) agent response.

        Priority: ``intent_change`` -> ``bot_response`` (keep collecting) ->
        store + validate the whole response -> matching transition -> default
        ``next`` -> collection failure.
        """
        if target_node := agent_response.get("intent_change"):
            return self._handle_intent_change(target_node, state)

        self._set_status(state, "collecting")

        # The agent wants to say something (e.g. ask a follow-up): show it and keep collecting.
        if bot_response := agent_response.get("bot_response"):
            conversation.append({"role": "assistant", "content": bot_response})
            return state

        self._store_field_value(state, agent_response)

        is_valid_input, validation_error_message = self._validate_collected_input(state)
        if not is_valid_input:
            return self._handle_validation_failure(state, conversation, message=validation_error_message)

        if next_node := self._find_matching_transition(agent_response):
            return self._complete_collection(state, next_node, agent_response)

        if self.next_step:
            return self._complete_collection(state, self.next_step, agent_response)

        return self._handle_collection_failure(state, conversation)

    def _handle_validation_failure(self, state: Dict[str, Any], conversation: List, message: Optional[str]=VALIDATION_ERROR_MESSAGE, role="assistant") -> Dict[str, Any]:
        """Clear the field (state and engine context) and record ``message`` as a turn."""
        # Validators returning a bare False yield message=None; without this the
        # None turn would become the next interrupt prompt.
        if message is None:
            message = VALIDATION_ERROR_MESSAGE
        self._store_field_value(state, None)
        self.engine_context.update_context({self.field: None})
        conversation.append({"role": role, "content": message})
        return state

    def _find_matching_transition(self, agent_response: Any) -> Optional[str]:
        """Return the ``next`` of the first transition matching ``agent_response``.

        For a dict response a transition matches when
        ``agent_response[ref] == match``; for any other response it matches
        when its ``pattern`` occurs in the response. Returns None otherwise.

        Raises:
            RuntimeError: if a structured-output transition lacks ``next``,
                ``ref`` or ``match``.
        """
        is_structured_output = isinstance(agent_response, dict)

        for transition in self.transitions:
            if is_structured_output:
                next_node = transition.get("next")
                match_value = transition.get("match")
                ref_field = transition.get("ref")

                # `match` may legitimately be falsy (False, 0, ""), so only a
                # missing (None) value counts as absent.
                if not next_node or not ref_field or match_value is None:
                    raise RuntimeError(f"Transition in step '{self.step_id}' missing required properties for structured output routing")

                if agent_response.get(ref_field) == match_value:
                    return next_node
            else:
                next_node = transition.get("next")
                pattern = transition.get("pattern")
                if pattern in agent_response:
                    return next_node

        return None

    def _complete_collection(self, state: Dict[str, Any], next_node: str, agent_response: Any) -> Dict[str, Any]:
        """Route to ``next_node`` (setting the outcome when it is one) and report success."""
        self._set_status(state, next_node)

        if next_node in self.engine_context.outcome_map:
            self._set_outcome(state, next_node)

        state[WorkflowKeys.MESSAGES] = [
            f"✓ {self._formatted_field_name} collected: {str(agent_response)}"
        ]

        return state

    def _handle_collection_failure(self, state: Dict[str, Any], conversation: List) -> Dict[str, Any]:
        """Ask the user to try again and clear the field."""
        conversation.append({"role": "assistant", "content": COLLECTION_FAILURE_MESSAGE})
        self._store_field_value(state, None)
        return state

    def _store_field_value(self, state: Dict[str, Any], value: Any):
        """Store ``value`` for the field, coercing to int for ``number`` fields.

        Only fields declared in ``engine_context.data_fields`` are written.
        Values that cannot be coerced (None, "12.5", dicts, ...) are stored as-is.
        """
        field_def = next((f for f in self.engine_context.data_fields if f['name'] == self.field), None)
        if not field_def:
            return

        if field_def.get('type') == 'number' and value is not None:
            try:
                value = int(value)
            except (TypeError, ValueError):
                # int() raises TypeError for e.g. dicts and ValueError for
                # non-integer strings; keep the raw value in both cases.
                pass

        state[self.field] = value
@@ -0,0 +1,46 @@
1
+ from typing import Dict, Any, Type, Callable
2
+
3
+ from .base import ActionStrategy
4
+ from .call_function import CallFunctionStrategy
5
+ from .collect_input import CollectInputStrategy
6
+ from ..core.constants import ActionType
7
+ from ..utils.logger import logger
8
+
9
+
10
class NodeFactory:
    """Registry mapping workflow ``action`` names to ActionStrategy classes.

    Registrations live on the class and are shared by every caller;
    :meth:`create` instantiates the registered strategy for a step.
    """

    # action type -> strategy class, populated via `register`.
    _strategies: Dict[str, Type[ActionStrategy]] = {}

    @classmethod
    def register(cls, action_type: str, strategy_class: Type[ActionStrategy]):
        """Register ``strategy_class`` as the handler for ``action_type``.

        Registering an existing action type replaces the previous class.

        Raises:
            RuntimeError: if ``strategy_class`` is not an ActionStrategy subclass.
        """
        # issubclass() itself raises TypeError for non-class arguments, so check
        # the type first to report every bad input with the same error.
        if not (isinstance(strategy_class, type) and issubclass(strategy_class, ActionStrategy)):
            class_name = getattr(strategy_class, '__name__', repr(strategy_class))
            raise RuntimeError(f"Strategy class {class_name} must inherit from ActionStrategy")

        logger.info(f"Registering node strategy: {action_type} -> {strategy_class.__name__}")
        cls._strategies[action_type] = strategy_class

    @classmethod
    def create(
        cls,
        step_config: Dict[str, Any],
        engine_context: Any
    ) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
        """Instantiate the strategy for ``step_config['action']`` and return its node function.

        Raises:
            RuntimeError: if the step has no ``action`` or the action is not registered.
        """
        action = step_config.get('action')

        if not action:
            raise RuntimeError(f"Step '{step_config.get('id', 'unknown')}' is missing 'action' property")

        if action not in cls._strategies:
            raise RuntimeError(
                f"Unknown action type: '{action}'. "
                f"Registered action types: {sorted(cls._strategies)}"
            )

        strategy_class = cls._strategies[action]
        strategy = strategy_class(step_config, engine_context)

        return strategy.get_node_function()

    @classmethod
    def is_registered(cls, action_type: str) -> bool:
        """True when a strategy is registered for ``action_type``."""
        return action_type in cls._strategies
43
+
44
+
45
# Built-in strategies, registered at import time in this order.
for _action_type, _strategy_class in (
    (ActionType.COLLECT_INPUT_WITH_AGENT, CollectInputStrategy),
    (ActionType.CALL_FUNCTION, CallFunctionStrategy),
):
    NodeFactory.register(_action_type.value, _strategy_class)
# Keep the loop variables out of the module namespace.
del _action_type, _strategy_class
File without changes
@@ -0,0 +1,97 @@
1
+ from typing import Dict, Any, List, Optional
2
+
3
+ from langgraph.constants import END
4
+
5
+ from ..core.constants import WorkflowKeys
6
+ from ..utils.logger import logger
7
+
8
+
9
class WorkflowRouter:
    """Compute the next graph node after a single workflow step.

    Step nodes report their result in the state's status as
    ``<step_id>_<suffix>``. The route function maps that status to the step
    itself (self-loop while collecting), another step, or ``END`` (errors,
    exhausted retries, outcomes, unknown targets).
    """

    def __init__(self, step_config: Dict[str, Any], step_map: Dict[str, Any], outcome_map: Dict[str, Any]):
        self.step_id = step_config['id']
        self.action = step_config['action']
        self.step_map = step_map
        self.outcome_map = outcome_map
        self.transitions = step_config.get('transitions', [])
        self.next_step = step_config.get('next')

    def create_route_function(self):
        """Return a ``state -> node name`` callable for this step's conditional edge.

        Errors raised while routing are logged and re-raised as RuntimeError,
        chained to the original exception.
        """
        def route_fn(state: Dict[str, Any]) -> str:
            try:
                return self._route(state)
            except Exception as e:
                logger.error(f"Routing error in step '{self.step_id}': {e}")
                raise RuntimeError(f"Failed to route from step '{self.step_id}': {e}") from e

        return route_fn

    def _route(self, state: Dict[str, Any]) -> str:
        """Map the current status to the next node."""
        # `or ''` also covers an explicit None status, which would otherwise
        # break the `startswith` checks.
        status = state.get(WorkflowKeys.STATUS) or ''

        if status == f'{self.step_id}_collecting':
            logger.info(f"Self-loop: {self.step_id} (collecting)")
            return self.step_id

        if status == f'{self.step_id}_error':
            logger.info(f"Error encountered in {self.step_id}, ending workflow")
            return END

        if status == f'{self.step_id}_max_attempts':
            logger.info(f"Max attempts reached in {self.step_id}, ending workflow")
            return END

        # Any status prefixed with this step id is resolved here (this also
        # covers intent-change targets); other statuses yield None.
        next_node = self._route_with_transitions(state, status)
        if next_node:
            return next_node

        if self.next_step:
            is_outcome = self.next_step in self.outcome_map
            logger.info(f"Simple routing: {self.step_id} -> {self.next_step}")
            return END if is_outcome else self.next_step

        logger.info(f"No routing match for status '{status}', ending workflow")
        return END

    def _route_with_transitions(self, state: Dict[str, Any], status: str) -> Optional[str]:
        """Resolve ``<step_id>_<target>`` to a node; None if the status is not prefixed.

        Outcomes and unknown targets both route to ``END``.
        """
        prefix = f'{self.step_id}_'
        if not status.startswith(prefix):
            return None

        target = status[len(prefix):]

        if target in self.outcome_map:
            logger.info(f"Routing to outcome: {self.step_id} -> {target}")
            return END

        if target in self.step_map:
            logger.info(f"Routing to step: {self.step_id} -> {target}")
            return target

        logger.warning(f"Unknown routing target '{target}' from step '{self.step_id}'")
        return END

    def get_routing_map(self, collector_nodes: List[str]) -> Dict[str, str]:
        """Build the path map for this step's conditional edges.

        Includes the self-loop for collector steps, every transition target,
        the default ``next`` step, all collector nodes (intent-change targets)
        and ``END``. Targets that are not workflow steps map to ``END``.
        """
        def destination(node: str) -> str:
            # Non-step targets (e.g. outcomes) terminate the graph.
            return node if node in self.step_map else END

        routing_map = {}

        if self.action == 'collect_input_with_agent':
            routing_map[self.step_id] = self.step_id

        for transition in self.transitions:
            next_dest = transition['next']
            routing_map[next_dest] = destination(next_dest)

        if self.next_step:
            routing_map[self.next_step] = destination(self.next_step)

        for collector_node in collector_nodes:
            routing_map[collector_node] = collector_node

        routing_map[END] = END

        return routing_map