soprano-sdk 0.1.94__py3-none-any.whl → 0.1.96__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- soprano_sdk/__init__.py +10 -0
- soprano_sdk/agents/__init__.py +30 -0
- soprano_sdk/agents/adaptor.py +91 -0
- soprano_sdk/agents/factory.py +228 -0
- soprano_sdk/agents/structured_output.py +97 -0
- soprano_sdk/core/__init__.py +0 -0
- soprano_sdk/core/constants.py +59 -0
- soprano_sdk/core/engine.py +225 -0
- soprano_sdk/core/rollback_strategies.py +259 -0
- soprano_sdk/core/state.py +71 -0
- soprano_sdk/engine.py +381 -0
- soprano_sdk/nodes/__init__.py +0 -0
- soprano_sdk/nodes/base.py +57 -0
- soprano_sdk/nodes/call_function.py +108 -0
- soprano_sdk/nodes/collect_input.py +526 -0
- soprano_sdk/nodes/factory.py +46 -0
- soprano_sdk/routing/__init__.py +0 -0
- soprano_sdk/routing/router.py +97 -0
- soprano_sdk/tools.py +219 -0
- soprano_sdk/utils/__init__.py +0 -0
- soprano_sdk/utils/function.py +35 -0
- soprano_sdk/utils/logger.py +6 -0
- soprano_sdk/utils/template.py +27 -0
- soprano_sdk/utils/tool.py +60 -0
- soprano_sdk/utils/tracing.py +71 -0
- soprano_sdk/validation/__init__.py +13 -0
- soprano_sdk/validation/schema.py +302 -0
- soprano_sdk/validation/validator.py +173 -0
- {soprano_sdk-0.1.94.dist-info → soprano_sdk-0.1.96.dist-info}/METADATA +1 -1
- soprano_sdk-0.1.96.dist-info/RECORD +32 -0
- soprano_sdk-0.1.94.dist-info/RECORD +0 -4
- {soprano_sdk-0.1.94.dist-info → soprano_sdk-0.1.96.dist-info}/WHEEL +0 -0
- {soprano_sdk-0.1.94.dist-info → soprano_sdk-0.1.96.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,526 @@
|
|
|
1
|
+
from typing import Dict, Any, List, Optional, Tuple
|
|
2
|
+
|
|
3
|
+
from jinja2 import Environment
|
|
4
|
+
from langgraph.types import interrupt
|
|
5
|
+
|
|
6
|
+
from .base import ActionStrategy
|
|
7
|
+
from ..agents.factory import AgentFactory, AgentAdapter
|
|
8
|
+
from ..agents.structured_output import create_structured_output_model, validate_field_definitions
|
|
9
|
+
from ..core.constants import (
|
|
10
|
+
WorkflowKeys,
|
|
11
|
+
DEFAULT_MAX_ATTEMPTS,
|
|
12
|
+
MAX_ATTEMPTS_MESSAGE,
|
|
13
|
+
TransitionPattern
|
|
14
|
+
)
|
|
15
|
+
from ..core.rollback_strategies import (
|
|
16
|
+
RollbackStrategy,
|
|
17
|
+
HistoryBasedRollback,
|
|
18
|
+
DependencyBasedRollback
|
|
19
|
+
)
|
|
20
|
+
from ..core.state import initialize_state
|
|
21
|
+
from ..utils.logger import logger
|
|
22
|
+
from ..utils.tracing import trace_node_execution, trace_agent_invocation, add_node_result
|
|
23
|
+
|
|
24
|
+
# Fallback texts used by CollectInputStrategy below:
#   VALIDATION_ERROR_MESSAGE   - default message for _handle_validation_failure.
#   COLLECTION_FAILURE_MESSAGE - appended by _handle_collection_failure.
#   INVALID_INPUT_MESSAGE      - not referenced in this module's visible code;
#                                presumably imported elsewhere (verify).
VALIDATION_ERROR_MESSAGE = (
    "validation failed for the provided input, please enter valid input"
)
INVALID_INPUT_MESSAGE = (
    "Looks like the input is invalid. "
    "Please double-check and re-enter it."
)
COLLECTION_FAILURE_MESSAGE = (
    "I couldn't understand your response. "
    "Please try again and provide the required information."
)
|
|
27
|
+
|
|
28
|
+
def _wrap_instructions_with_intent_detection(
|
|
29
|
+
instructions: str,
|
|
30
|
+
collector_nodes: Dict[str, str],
|
|
31
|
+
with_structured_output: bool
|
|
32
|
+
) -> str:
|
|
33
|
+
if not collector_nodes:
|
|
34
|
+
return instructions
|
|
35
|
+
|
|
36
|
+
collector_nodes_str = "\n".join(f"{node_name}: {description}" for node_name, description in collector_nodes.items())
|
|
37
|
+
return f"""
|
|
38
|
+
{instructions}
|
|
39
|
+
|
|
40
|
+
AVAILABLE CONVERSATION INTENTS:
|
|
41
|
+
{collector_nodes_str}
|
|
42
|
+
|
|
43
|
+
Format: <node_name>: <intent_description>
|
|
44
|
+
|
|
45
|
+
CRITICAL INTENT DETECTION RULES:
|
|
46
|
+
1. ONLY check for intent changes against the EXACT node names listed above
|
|
47
|
+
2. The node name MUST appear in the list above to be valid
|
|
48
|
+
3. Do NOT infer, guess, or create new intent names
|
|
49
|
+
4. Tool calls are NOT intent changes
|
|
50
|
+
5. If no intents are listed above, NEVER trigger an intent change
|
|
51
|
+
|
|
52
|
+
Before responding, analyze if the user's query matches a DIFFERENT intent from the list above:
|
|
53
|
+
|
|
54
|
+
IF the user's query clearly matches a DIFFERENT intent that EXISTS in the list above:
|
|
55
|
+
- {"Respond ONLY with: INTENT_CHANGE: <node_name>" if not with_structured_output else "modify intent_change value <node_name>"}
|
|
56
|
+
- Use the EXACT node_name from the list above
|
|
57
|
+
- Do NOT provide any other response
|
|
58
|
+
- Do NOT answer the user's question
|
|
59
|
+
|
|
60
|
+
IF the user's query continues with the SAME intent OR does not match any intent in the list above:
|
|
61
|
+
- Proceed with your normal response
|
|
62
|
+
- Do NOT mention intent detection
|
|
63
|
+
- Answer the user's question as configured
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
def _create_rollback_strategy(strategy_name: str) -> RollbackStrategy:
    """Instantiate the rollback strategy named by the engine configuration.

    ``"dependency_based"`` and ``"history_based"`` are recognised; any other
    value logs a warning and falls back to a history-based rollback.
    """
    candidates = (
        ("dependency_based", DependencyBasedRollback),
        ("history_based", HistoryBasedRollback),
    )
    for known_name, strategy_cls in candidates:
        if strategy_name == known_name:
            return strategy_cls()

    logger.warning(f"Unknown rollback strategy '{strategy_name}', using history_based")
    return HistoryBasedRollback()
|
|
74
|
+
|
|
75
|
+
def _get_agent_response(agent: AgentAdapter, conversation: List[Dict[str, str]]) -> Any:
    """Invoke ``agent`` on ``conversation`` and record the reply.

    The raw response is returned unchanged; callers treat it as text, or as a
    dict-like object when structured output is enabled. Its ``str()`` form is
    appended to ``conversation`` in place as an assistant turn.
    """
    response = agent.invoke(conversation)
    conversation.append(dict(role="assistant", content=str(response)))
    return response
|
|
81
|
+
|
|
82
|
+
class CollectInputStrategy(ActionStrategy):
    """Collect one workflow field from the user with the help of an LLM agent.

    Every execution does one of two things:
      * Short-circuits, when the field is already populated or the attempt
        budget is spent.
      * Interrupts the graph with a prompt, hands the user's reply to the
        agent, and derives the next status from the agent's answer.

    A status of ``"<step_id>_collecting"`` marks a self-loop (see
    ``_is_self_loop``). Statuses are written through ``_set_status`` and
    outcomes through ``_set_outcome``; both are defined on ``ActionStrategy``
    and are not shown here.
    """

    def __init__(self, step_config: Dict[str, Any], engine_context: Any):
        """Read the step configuration and prepare rollback/validation helpers.

        Raises:
            RuntimeError: if the step has no ``field`` or no ``agent`` config.
        """
        super().__init__(step_config, engine_context)
        self.field = step_config.get('field')
        self.agent_config = step_config.get('agent', {})

        # Fail fast on missing required config, before anything below
        # dereferences it (an explicit None agent config would otherwise
        # surface as an AttributeError instead of this message).
        if not self.field:
            raise RuntimeError(f"Step '{self.step_id}' missing required 'field' property")

        if not self.agent_config:
            raise RuntimeError(f"Step '{self.step_id}' missing required 'agent' configuration")

        # A falsy step-level retry_limit (missing or 0) falls back to the engine value.
        self.max_attempts = step_config.get('retry_limit') or engine_context.get_config_value("max_retry_limit", DEFAULT_MAX_ATTEMPTS)
        self.transitions = self._get_transitions()
        self.next_step = self.step_config.get("next", None)
        # `or {}` tolerates an explicit None for structured_output.
        self.is_structured_output = (self.agent_config.get("structured_output") or {}).get("enabled", False)

        rollback_strategy_name = engine_context.get_config_value("rollback_strategy", "history_based")
        self.rollback_strategy = _create_rollback_strategy(rollback_strategy_name)
        logger.info(f"Using rollback strategy: {self.rollback_strategy.get_strategy_name()}")

        self.validator = None
        if validator_function_path := self.step_config.get("validator"):
            self.validator = self.engine_context.function_repository.load(validator_function_path)

    @property
    def _conversation_key(self) -> str:
        """Key under ``state[CONVERSATIONS]`` that holds this field's chat history."""
        return f'{self.field}_conversation'

    @property
    def _formatted_field_name(self) -> str:
        """Human-readable field name, e.g. ``order_id`` -> ``Order Id``."""
        return self.field.replace('_', ' ').title()

    def execute(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Run one collection turn for this step and return the updated state.

        The steps, in order:
          1. Initialise the state and apply any context-provided value.
          2. Snapshot for rollback on a non-self-loop entry, if the strategy
             asks for it.
          3. Short-circuit when the field is already populated or the attempt
             limit is reached.
          4. Otherwise prompt via ``interrupt``, pass the reply to the agent,
             and route on the agent's response.
        """
        with trace_node_execution(
            node_id=self.step_id,
            node_type="collect_input_with_agent",
            output_field=self.field
        ) as span:
            state = initialize_state(state)

            self._apply_context_value(state, span)

            is_self_loop = self._is_self_loop(state)

            if not is_self_loop:
                if self.rollback_strategy.should_save_snapshot():
                    self._save_snapshot_before_execution(state)

            conversation = self._get_or_create_conversation(state)

            if self._is_field_pre_populated(state):
                span.add_event("field.pre_populated", {"value": str(state.get(self.field))})
                state = self._handle_pre_populated_field(state, conversation)
                add_node_result(span, self.field, state.get(self.field), state.get(WorkflowKeys.STATUS))
                return self._add_node_to_execution_order(state)

            if self._max_attempts_reached(state):
                span.add_event("max_attempts.reached")
                return self._handle_max_attempts(state)

            agent = self._create_agent(state)

            prompt = self._generate_prompt(agent, conversation, state)

            # Suspends the graph; the resumed value is the user's reply.
            user_input = interrupt(prompt)

            conversation.append({"role": "user", "content": user_input})
            span.add_event("user.input_received", {"input_length": len(user_input)})

            with trace_agent_invocation(
                agent_name=self.agent_config.get('name', self.field),
                model=self.agent_config.get('model', 'default')
            ):
                agent_response = _get_agent_response(agent, conversation)

            if self.is_structured_output:
                state = self._handle_structured_output_transition(state, conversation, agent_response)
                add_node_result(span, self.field, state.get(self.field), state.get(WorkflowKeys.STATUS))
                # NOTE(review): unlike the text-mode intent-change branch below,
                # this path also registers the node after a rollback; confirm intended.
                return self._add_node_to_execution_order(state)

            if agent_response.startswith(TransitionPattern.INTENT_CHANGE):
                span.add_event("intent.change_detected")
                return self._handle_intent_change(agent_response, state)

            state = self._process_transitions(state, conversation, agent_response)

            self._update_conversation(state, conversation)

            add_node_result(span, self.field, state.get(self.field), state.get(WorkflowKeys.STATUS))

            return self._add_node_to_execution_order(state)

    def _render_template_string(self, template_str: str, state: Dict[str, Any]) -> str:
        """Render ``template_str`` as a Jinja template with ``state`` as context.

        Empty or None input renders to "". The engine-configured
        ``template_loader`` is used when present, else a default ``Environment``.
        """
        if not template_str:
            return ""
        template_loader = self.engine_context.get_config_value('template_loader', Environment())
        return template_loader.from_string(template_str).render(state)

    def _apply_context_value(self, state: Dict[str, Any], span) -> None:
        """Copy a truthy engine-context value for this field into ``state``."""
        if not (context_value := self.engine_context.get_context_value(self.field)):
            return
        logger.info(f"Using context value for '{self.field}': {context_value}")
        state[self.field] = context_value
        span.add_event("context.value_used", {"field": self.field, "value": str(context_value)})

    def _add_node_to_execution_order(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Register this node once it has stopped collecting.

        While the step is still looping (status ``"<step_id>_collecting"``)
        nothing is recorded. The exact self-loop check replaces a substring
        test that misfired for step ids containing "collecting" and crashed
        on a missing status.
        """
        if self._is_self_loop(state):
            return state

        self._register_node_execution(state)
        self._register_collector_node(state)

        return state

    def _is_self_loop(self, state: Dict[str, Any]) -> bool:
        """True when the current status asks the router to re-run this step."""
        return state.get(WorkflowKeys.STATUS) == f'{self.step_id}_collecting'

    def _save_snapshot_before_execution(self, state: Dict[str, Any]):
        """Ask the rollback strategy to snapshot ``state`` at the current history index."""
        state_history = state.get(WorkflowKeys.STATE_HISTORY, [])
        execution_index = len(state_history)
        self.rollback_strategy.save_snapshot(state, self.step_id, execution_index)

    def _register_node_execution(self, state: Dict[str, Any]):
        """Append this step to ``NODE_EXECUTION_ORDER`` if it is not listed yet."""
        execution_order = state.get(WorkflowKeys.NODE_EXECUTION_ORDER, [])
        if self.step_id not in execution_order:
            execution_order.append(self.step_id)
        state[WorkflowKeys.NODE_EXECUTION_ORDER] = execution_order

    def _register_collector_node(self, state: Dict[str, Any]):
        """Advertise this step as an intent target and map it to its field."""
        collector_nodes = state.get(WorkflowKeys.COLLECTOR_NODES, {})
        description = self.agent_config.get('description', f"Collecting {self.field}")
        collector_nodes[self.step_id] = description
        state[WorkflowKeys.COLLECTOR_NODES] = collector_nodes

        node_field_map = state.get(WorkflowKeys.NODE_FIELD_MAP, {})
        node_field_map[self.step_id] = self.field
        state[WorkflowKeys.NODE_FIELD_MAP] = node_field_map

    def _is_field_pre_populated(self, state: Dict[str, Any]) -> bool:
        """True when the field already holds a non-None value."""
        return state.get(self.field) is not None

    def _validate_collected_input(self, state) -> Tuple[bool, Optional[str]]:
        """Run the configured validator with the whole state as keyword arguments.

        Returns:
            ``(is_valid, message)``. A validator may return a tuple, which is
            passed through unchanged, or a bare value, which becomes
            ``(value, None)``. Without a validator the result is ``(True, None)``.
        """
        if not self.validator:
            return True, None
        result = self.validator(**state)
        if isinstance(result, tuple):
            return result
        return result, None

    def _handle_pre_populated_field(self, state: Dict[str, Any], conversation: List) -> Dict[str, Any]:
        """Accept a value that is already present in ``state``.

        On validation failure, the value is echoed into the conversation as a
        user turn, cleared, and the step goes back to collecting.

        On success, the status is set to the first transition's target and
        then to ``next``; ``next`` wins when both exist because it is applied
        last. Targets found in the outcome map are also recorded as outcomes.
        """
        logger.info(f"Field '{self.field}' is populated, skipping collection")

        is_valid_input, _ = self._validate_collected_input(state)
        if not is_valid_input:
            self._set_status(state, "collecting")
            return self._handle_validation_failure(state, conversation, message=f"{state[self.field]}", role="user")

        if self.transitions:
            first_transition = self.transitions[0]
            next_step = first_transition['next']
            self._set_status(state, next_step)

            if next_step in self.engine_context.outcome_map:
                self._set_outcome(state, next_step)

        if self.next_step:
            self._set_status(state, self.next_step)

            if self.next_step in self.engine_context.outcome_map:
                self._set_outcome(state, self.next_step)

        return state

    def _max_attempts_reached(self, state: Dict[str, Any]) -> bool:
        """True once the stored conversation holds ``max_attempts`` user turns."""
        conversation = state.get(WorkflowKeys.CONVERSATIONS, {}).get(self._conversation_key, [])
        attempt_count = len([m for m in conversation if m['role'] == 'user'])
        return attempt_count >= self.max_attempts

    def _handle_max_attempts(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Mark the step as ``max_attempts`` and replace messages with the limit notice."""
        logger.warning(f"Max attempts reached for field '{self.field}'")
        self._set_status(state, 'max_attempts')
        message = MAX_ATTEMPTS_MESSAGE.format(field=self.field)
        state[WorkflowKeys.MESSAGES] = [message]
        return state

    def _get_or_create_conversation(self, state: Dict[str, Any]) -> List[Dict[str, str]]:
        """Return this field's conversation list stored in ``state``, creating it if absent.

        The returned list is the object held in state, so appends are visible
        through ``state``.
        """
        conversations = state.get(WorkflowKeys.CONVERSATIONS, {})

        if self._conversation_key not in conversations:
            conversations[self._conversation_key] = []
            state[WorkflowKeys.CONVERSATIONS] = conversations

        return conversations[self._conversation_key]

    def _get_model_config(self) -> Dict[str, Any]:
        """Return the engine model config, with ``model_name`` overridden per agent.

        The config is copied before the override, so the engine-level dict is
        not mutated.

        Raises:
            ValueError: if the engine has no ``model_config``.
        """
        model_config = self.engine_context.get_config_value('model_config')
        if not model_config:
            raise ValueError("Model config not found in engine context")

        if model_id := self.agent_config.get("model"):
            model_config = model_config.copy()
            model_config["model_name"] = model_id

        return model_config

    def _get_instructions(self, state: Dict[str, Any], collector_nodes: Dict[str, str]) -> str:
        """Render the agent instructions, adding intent-detection rules when intents exist."""
        instructions = self.agent_config.get("instructions")

        instructions = self._render_template_string(instructions, state)

        if collector_nodes:
            instructions = _wrap_instructions_with_intent_detection(instructions, collector_nodes, self.is_structured_output)
        return instructions

    def _load_agent_tools(self, state: Dict[str, Any]) -> List:
        """Load every tool named in the agent config from the tool repository."""
        return [
            self.engine_context.tool_repository.load(tool_name, state)
            for tool_name in self.agent_config.get('tools', [])
        ]

    def _create_structured_output_model(self, collector_nodes: Dict[str, str]) -> Any:
        """Build the structured-output model, or return None when disabled or fieldless.

        An ``intent_change`` slot is requested only when collector nodes exist.
        """
        structured_output_config = self.agent_config.get('structured_output')
        if not structured_output_config or not structured_output_config.get('enabled'):
            return None

        fields = structured_output_config.get('fields', [])
        if not fields:
            return None

        validate_field_definitions(fields)
        model_name = f"{self.field.title().replace('_', '')}StructuredOutput"

        return create_structured_output_model(
            fields=fields,
            model_name=model_name,
            needs_intent_change=len(collector_nodes) > 0
        )

    def _create_agent(self, state: Dict[str, Any]) -> AgentAdapter:
        """Create the collection agent for the configured framework.

        Raises:
            RuntimeError: wrapping any failure, with the original exception
                chained as ``__cause__``.
        """
        try:
            model_config = self._get_model_config()
            agent_tools = self._load_agent_tools(state)
            collector_nodes = state.get(WorkflowKeys.COLLECTOR_NODES, {})

            instructions = self._get_instructions(state, collector_nodes)
            structured_output_model = self._create_structured_output_model(collector_nodes)
            framework = self.engine_context.get_config_value('agent_framework', 'langgraph')

            return AgentFactory.create_agent(
                framework=framework,
                name=self.agent_config.get('name', f'{self.field}Collector'),
                model_config=model_config,
                tools=agent_tools,
                system_prompt=instructions,
                structured_output_model=structured_output_model
            )

        except Exception as e:
            raise RuntimeError(f"Failed to create agent for step '{self.step_id}': {e}") from e

    def _generate_prompt(
        self,
        agent: AgentAdapter,
        conversation: List[Dict[str, str]],
        state: Dict[str, Any]
    ) -> str:
        """Return the text to show the user for this turn.

        On the first turn, ``initial_message`` is used, or, if absent, a
        message the agent generates from an empty user turn. It is rendered as
        a template and recorded as an assistant turn. Later turns re-show the
        last conversation entry.
        """
        if len(conversation) == 0:
            if not (prompt := self.agent_config.get('initial_message')):
                prompt = agent.invoke([{"role": "user", "content": ""}])

            prompt = self._render_template_string(prompt, state)
            conversation.append({"role": "assistant", "content": prompt})

            return prompt

        return conversation[-1]['content']

    def _update_conversation(self, state: Dict[str, Any], conversation: List[Dict[str, str]]):
        """Store ``conversation`` back under this field's conversation key."""
        state[WorkflowKeys.CONVERSATIONS][self._conversation_key] = conversation

    def _handle_intent_change(self, target_node_or_response, state: Dict[str, Any]) -> Dict[str, Any]:
        """Roll the state back to ``target_node`` and return the restored state.

        Accepts either a bare node name or a text response containing the
        ``INTENT_CHANGE`` marker followed by the node name.

        Raises:
            RuntimeError: if the rollback produced no state.
        """
        if isinstance(target_node_or_response, str) and TransitionPattern.INTENT_CHANGE in target_node_or_response:
            target_node = target_node_or_response.split(TransitionPattern.INTENT_CHANGE)[1].strip()
        else:
            target_node = target_node_or_response

        logger.info(f"Intent change detected: {self.step_id} -> {target_node}")

        rollback_state = self._rollback_state_to_node(state, target_node)

        # _rollback_state_to_node signals failure with an empty dict, not None,
        # so test for emptiness; an `is None` check would never fire.
        if not rollback_state:
            logger.error(f"Failed to rollback to node '{target_node}'")
            raise RuntimeError(f"Unable to process intent change to '{target_node}'")

        return rollback_state

    def _rollback_state_to_node(
        self,
        state: Dict[str, Any],
        target_node: str
    ) -> Dict[str, Any]:
        """Ask the rollback strategy for the state as of ``target_node``.

        Returns:
            The restored state with status ``"<step_id>_<target_node>"``, so
            the router jumps to the target, or ``{}`` if the strategy
            returned nothing.
        """
        node_execution_order = state.get(WorkflowKeys.NODE_EXECUTION_ORDER, [])
        node_field_map = state.get(WorkflowKeys.NODE_FIELD_MAP, {})
        workflow_steps = self.engine_context.steps

        restored_state = self.rollback_strategy.rollback_to_node(
            state=state,
            target_node=target_node,
            node_execution_order=node_execution_order,
            node_field_map=node_field_map,
            workflow_steps=workflow_steps
        )

        if not restored_state:
            logger.warning(f"Rollback strategy returned empty state for node '{target_node}'")
            return {}

        restored_state[WorkflowKeys.STATUS] = f"{self.step_id}_{target_node}"

        return restored_state

    def _process_transitions(
        self,
        state: Dict[str, Any],
        conversation: List,
        agent_response: str
    ) -> Dict[str, Any]:
        """Route a text agent response through the configured transitions.

        The first transition whose ``pattern`` occurs in the response wins.
        The text after that occurrence (up to any second occurrence), stripped,
        is the collected value. A non-empty value is stored and validated.

        Without a match, the status stays ``collecting`` and the messages are
        cleared.
        """
        matched = False
        self._set_status(state, 'collecting')

        for transition in self.transitions:
            pattern = transition['pattern']
            if pattern not in agent_response:
                continue

            matched = True
            next_step = transition['next']

            logger.info(f"Matched transition: {transition}")

            value = agent_response.split(pattern)[1].strip()
            if value:
                self._store_field_value(state, value)
                is_valid_input, message = self._validate_collected_input(state)
                if not is_valid_input:
                    return self._handle_validation_failure(state, conversation, message=message)
                state[WorkflowKeys.MESSAGES] = [f"✓ {self._formatted_field_name} collected: {value}" ]
            else:
                state[WorkflowKeys.MESSAGES] = []

            self._set_status(state, next_step)

            if next_step in self.engine_context.outcome_map:
                self._set_outcome(state, next_step)

            break

        if not matched:
            logger.info(f"No transition matched for response in step '{self.step_id}', continuing collection")
            state[WorkflowKeys.MESSAGES] = []

        return state

    def _handle_structured_output_transition(self, state: Dict[str, Any], conversation: List, agent_response: Any) -> Dict[str, Any]:
        """Route a structured (dict-like) agent response.

        The checks, in order:
          1. ``intent_change`` -> roll back to that node.
          2. ``bot_response`` -> keep collecting and show that reply.
          3. Otherwise store and validate the response, then route by matching
             transition or ``next``; fall back to a collection-failure message.
        """
        if target_node := agent_response.get("intent_change"):
            return self._handle_intent_change(target_node, state)

        self._set_status(state, "collecting")

        if bot_response := agent_response.get("bot_response"):
            conversation.append({"role": "assistant", "content": bot_response})
            return state

        self._store_field_value(state, agent_response)

        is_valid_input, validation_error_message = self._validate_collected_input(state)
        if not is_valid_input:
            return self._handle_validation_failure(state, conversation, message=validation_error_message)

        if next_node := self._find_matching_transition(agent_response):
            return self._complete_collection(state, next_node, agent_response)

        if self.next_step:
            return self._complete_collection(state, self.next_step, agent_response)

        return self._handle_collection_failure(state, conversation)

    def _handle_validation_failure(self, state: Dict[str, Any], conversation: List, message: Optional[str]=VALIDATION_ERROR_MESSAGE, role="assistant") -> Dict[str, Any]:
        """Clear the field in state and context and record ``message`` as a ``role`` turn.

        A ``None`` message, which is what a validator returning a bare ``False``
        produces, falls back to VALIDATION_ERROR_MESSAGE. Otherwise the
        conversation would gain a None turn that is later re-used as the prompt.
        """
        if message is None:
            message = VALIDATION_ERROR_MESSAGE
        self._store_field_value(state, None)
        self.engine_context.update_context({self.field: None})
        conversation.append({"role": role, "content": message})
        return state

    def _find_matching_transition(self, agent_response: Any) -> Optional[str]:
        """Return the ``next`` of the first transition matching ``agent_response``, else None.

        For dict responses, a transition matches when the response's
        ``ref`` field is truthy and equals ``match``. For other responses, it
        matches when ``pattern`` is a substring of the response.

        Raises:
            RuntimeError: if a dict response meets a transition that lacks
                ``next``, ``match`` or ``ref``.
        """
        is_structured_output = isinstance(agent_response, dict)

        for transition in self.transitions:
            if is_structured_output:
                next_node = transition.get("next")
                match_value = transition.get("match")
                ref_field = transition.get("ref")

                if not all([next_node, match_value, ref_field]):
                    raise RuntimeError(f"Transition in step '{self.step_id}' missing required properties for structured output routing")

                if field_value := agent_response.get(ref_field):
                    if field_value == match_value:
                        return next_node
            else:
                next_node = transition.get("next")
                pattern = transition.get("pattern")
                if pattern in agent_response:
                    return next_node

        return None

    def _complete_collection(self, state: Dict[str, Any], next_node: str, agent_response: Any) -> Dict[str, Any]:
        """Route to ``next_node`` (recording it as an outcome if it is one) and set a success message."""
        self._set_status(state, next_node)

        if next_node in self.engine_context.outcome_map:
            self._set_outcome(state, next_node)

        state[WorkflowKeys.MESSAGES] = [
            f"✓ {self._formatted_field_name} collected: {str(agent_response)}"
        ]

        return state

    def _handle_collection_failure(self, state: Dict[str, Any], conversation: List) -> Dict[str, Any]:
        """Append the generic failure message and clear the field."""
        conversation.append({"role": "assistant", "content": COLLECTION_FAILURE_MESSAGE})
        self._store_field_value(state, None)
        return state

    def _store_field_value(self, state: Dict[str, Any], value: Any):
        """Write ``value`` into ``state[self.field]``, coercing ``number`` fields to int.

        Coercion is best effort. Values ``int()`` rejects are stored unchanged
        instead of raising; these include non-numeric strings, the ``None``
        used to clear the field, and structured-output dicts.

        NOTE(review): nothing is stored when ``self.field`` has no entry in
        ``engine_context.data_fields``; confirm that silent skip is intended.
        """
        field_def = next((f for f in self.engine_context.data_fields if f['name'] == self.field), None)
        if not field_def:
            return

        if field_def.get('type') == 'number':
            try:
                value = int(value)
            except (TypeError, ValueError):
                # int(None) / int(dict) raise TypeError, int("abc") ValueError:
                # keep the raw value in both cases.
                pass

        state[self.field] = value
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from typing import Dict, Any, Type, Callable
|
|
2
|
+
|
|
3
|
+
from .base import ActionStrategy
|
|
4
|
+
from .call_function import CallFunctionStrategy
|
|
5
|
+
from .collect_input import CollectInputStrategy
|
|
6
|
+
from ..core.constants import ActionType
|
|
7
|
+
from ..utils.logger import logger
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class NodeFactory:
    """Registry mapping a step's ``action`` name to an ActionStrategy class."""

    # Class-level registry shared by every caller: action name -> strategy class.
    _strategies: Dict[str, Type[ActionStrategy]] = {}

    @classmethod
    def register(cls, action_type: str, strategy_class: Type[ActionStrategy]):
        """Register ``strategy_class`` as the handler for ``action_type``.

        Registering an existing ``action_type`` again replaces the previous class.

        Raises:
            RuntimeError: if ``strategy_class`` is not an ActionStrategy subclass.
        """
        # issubclass() raises TypeError for non-class arguments; check for a
        # class first so every invalid input gets the same RuntimeError.
        if not (isinstance(strategy_class, type) and issubclass(strategy_class, ActionStrategy)):
            class_name = getattr(strategy_class, "__name__", repr(strategy_class))
            raise RuntimeError(f"Strategy class {class_name} must inherit from ActionStrategy")

        logger.info(f"Registering node strategy: {action_type} -> {strategy_class.__name__}")
        cls._strategies[action_type] = strategy_class

    @classmethod
    def create(
        cls,
        step_config: Dict[str, Any],
        engine_context: Any
    ) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
        """Instantiate the strategy for ``step_config['action']`` and return its node function.

        Raises:
            RuntimeError: if the step has no ``action`` or the action is not registered.
        """
        action = step_config.get('action')

        if not action:
            raise RuntimeError(f"Step '{step_config.get('id', 'unknown')}' is missing 'action' property")

        if action not in cls._strategies:
            raise RuntimeError(f"Unknown action type: '{action}'.")

        strategy_class = cls._strategies[action]
        strategy = strategy_class(step_config, engine_context)

        return strategy.get_node_function()

    @classmethod
    def is_registered(cls, action_type: str) -> bool:
        """True when a strategy has been registered for ``action_type``."""
        return action_type in cls._strategies
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Built-in strategies, registered at import time (collect first, then call).
NodeFactory.register(
    action_type=ActionType.COLLECT_INPUT_WITH_AGENT.value,
    strategy_class=CollectInputStrategy,
)
NodeFactory.register(
    action_type=ActionType.CALL_FUNCTION.value,
    strategy_class=CallFunctionStrategy,
)
|
|
File without changes
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
from typing import Dict, Any, List, Optional
|
|
2
|
+
|
|
3
|
+
from langgraph.constants import END
|
|
4
|
+
|
|
5
|
+
from ..core.constants import WorkflowKeys
|
|
6
|
+
from ..utils.logger import logger
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class WorkflowRouter:
    """Build the conditional routing function for a single workflow step.

    The router reads the status string from state and maps it to the next
    graph node. A status of the form ``"<step_id>_<suffix>"`` is handled by
    suffix:
      * ``collecting`` loops back to this step.
      * ``error`` and ``max_attempts`` end the workflow.
      * Any other suffix names an outcome (end) or a step to go to.
    """

    def __init__(self, step_config: Dict[str, Any], step_map: Dict[str, Any], outcome_map: Dict[str, Any]):
        self.step_id = step_config['id']
        self.action = step_config['action']
        self.step_map = step_map
        self.outcome_map = outcome_map
        self.transitions = step_config.get('transitions', [])
        self.next_step = step_config.get('next')

    def create_route_function(self):
        """Return a ``route_fn(state) -> str`` closure bound to this step.

        Any routing exception is logged and re-raised as RuntimeError, with the
        original exception chained as its cause.
        """
        def route_fn(state: Dict[str, Any]) -> str:
            try:
                return self._route(state)
            except Exception as e:
                logger.error(f"Routing error in step '{self.step_id}': {e}")
                raise RuntimeError(f"Failed to route from step '{self.step_id}': {e}") from e

        return route_fn

    def _route(self, state: Dict[str, Any]) -> str:
        """Return the next node name (or END) for the current status."""
        # `or ''` also covers a status key that is present but set to None.
        status = state.get(WorkflowKeys.STATUS) or ''

        if status == f'{self.step_id}_collecting':
            logger.info(f"Self-loop: {self.step_id} (collecting)")
            return self.step_id

        if status == f'{self.step_id}_error':
            logger.info(f"Error encountered in {self.step_id}, ending workflow")
            return END

        if status == f'{self.step_id}_max_attempts':
            logger.info(f"Max attempts reached in {self.step_id}, ending workflow")
            return END

        # _route_with_transitions returns None unless the status carries this
        # step's prefix, so no separate guard is needed here.
        next_node = self._route_with_transitions(state, status)
        if next_node:
            return next_node

        if self.next_step:
            is_outcome = self.next_step in self.outcome_map
            logger.info(f"Simple routing: {self.step_id} -> {self.next_step}")
            return END if is_outcome else self.next_step

        logger.info(f"No routing match for status '{status}', ending workflow")
        return END

    def _route_with_transitions(self, state: Dict[str, Any], status: str) -> Optional[str]:
        """Resolve ``"<step_id>_<target>"`` to END (outcome or unknown) or a step id.

        Returns None when ``status`` does not start with this step's prefix.
        """
        if not status.startswith(f'{self.step_id}_'):
            return None

        target = status[len(self.step_id) + 1:]

        if target in self.outcome_map:
            logger.info(f"Routing to outcome: {self.step_id} -> {target}")
            return END

        if target in self.step_map:
            logger.info(f"Routing to step: {self.step_id} -> {target}")
            return target

        logger.warning(f"Unknown routing target '{target}' from step '{self.step_id}'")
        return END

    def get_routing_map(self, collector_nodes: List[str]) -> Dict[str, str]:
        """Map every value ``route_fn`` may return to the destination node.

        The map covers:
          * Transition and ``next`` targets: steps map to themselves and
            non-steps to END.
          * Collector nodes, which are possible intent-change targets.
          * END itself.
          * This step, when it is an input-collecting step that can self-loop.
        """
        routing_map = {}

        if self.action == 'collect_input_with_agent':
            routing_map[self.step_id] = self.step_id

        for transition in self.transitions:
            next_dest = transition['next']
            if next_dest in self.step_map:
                routing_map[next_dest] = next_dest
            else:
                routing_map[next_dest] = END

        if self.next_step:
            if self.next_step in self.step_map:
                routing_map[self.next_step] = self.next_step
            else:
                routing_map[self.next_step] = END

        for collector_node in collector_nodes:
            routing_map[collector_node] = collector_node

        routing_map[END] = END

        return routing_map
|