soprano-sdk 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- soprano_sdk/__init__.py +12 -0
- soprano_sdk/agents/__init__.py +30 -0
- soprano_sdk/agents/adaptor.py +90 -0
- soprano_sdk/agents/factory.py +228 -0
- soprano_sdk/agents/structured_output.py +97 -0
- soprano_sdk/authenticators/__init__.py +0 -0
- soprano_sdk/authenticators/mfa.py +205 -0
- soprano_sdk/core/__init__.py +0 -0
- soprano_sdk/core/constants.py +125 -0
- soprano_sdk/core/engine.py +315 -0
- soprano_sdk/core/rollback_strategies.py +258 -0
- soprano_sdk/core/state.py +80 -0
- soprano_sdk/engine.py +381 -0
- soprano_sdk/nodes/__init__.py +0 -0
- soprano_sdk/nodes/async_function.py +237 -0
- soprano_sdk/nodes/base.py +61 -0
- soprano_sdk/nodes/call_function.py +139 -0
- soprano_sdk/nodes/collect_input.py +573 -0
- soprano_sdk/nodes/factory.py +48 -0
- soprano_sdk/routing/__init__.py +0 -0
- soprano_sdk/routing/router.py +102 -0
- soprano_sdk/tools.py +232 -0
- soprano_sdk/utils/__init__.py +0 -0
- soprano_sdk/utils/function.py +35 -0
- soprano_sdk/utils/logger.py +6 -0
- soprano_sdk/utils/template.py +27 -0
- soprano_sdk/utils/tool.py +60 -0
- soprano_sdk/utils/tracing.py +69 -0
- soprano_sdk/validation/__init__.py +13 -0
- soprano_sdk/validation/schema.py +332 -0
- soprano_sdk/validation/validator.py +227 -0
- soprano_sdk-0.2.10.dist-info/METADATA +420 -0
- soprano_sdk-0.2.10.dist-info/RECORD +35 -0
- soprano_sdk-0.2.10.dist-info/WHEEL +4 -0
- soprano_sdk-0.2.10.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from pydantic import Field
|
|
4
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class WorkflowKeys:
    """String keys for engine-managed entries in the workflow state dict.

    Every key except ``ERROR`` starts with a leading underscore.
    """

    STEP_ID = "_step_id"
    STATUS = "_status"
    OUTCOME_ID = "_outcome_id"

    MESSAGES = "_messages"
    CONVERSATIONS = "_conversations"
    STATE_HISTORY = "_state_history"

    COLLECTOR_NODES = "_collector_nodes"
    ATTEMPT_COUNTS = "_attempt_counts"
    NODE_EXECUTION_ORDER = "_node_execution_order"
    NODE_FIELD_MAP = "_node_field_map"
    COMPUTED_FIELDS = "_computed_fields"

    ERROR = "error"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ActionType(Enum):
    """Allowed values of a step's ``action`` key."""

    COLLECT_INPUT_WITH_AGENT = "collect_input_with_agent"
    CALL_FUNCTION = "call_function"
    CALL_ASYNC_FUNCTION = "call_async_function"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class InterruptType:
    """Marker strings identifying the kind of pause in a workflow run."""

    USER_INPUT = "__WORKFLOW_INTERRUPT__"
    ASYNC = "__ASYNC_INTERRUPT__"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class DataType(Enum):
    """Type names accepted for declared workflow data fields."""

    TEXT = "text"
    NUMBER = "number"
    DOUBLE = "double"
    BOOLEAN = "boolean"
    LIST = "list"
    DICT = "dict"
    ANY = "any"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class OutcomeType(Enum):
    """Kinds of terminal outcome a workflow can declare."""

    SUCCESS = "success"
    FAILURE = "failure"
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class StatusPattern:
    """``str.format`` templates for step status strings.

    Note: ``NEXT_STEP`` and ``INTENT_CHANGE`` render to the same shape
    (``<step_id>_<other id>``); only the placeholder name differs.
    """

    COLLECTING = "{step_id}_collecting"
    MAX_ATTEMPTS = "{step_id}_max_attempts"
    NEXT_STEP = "{step_id}_{next_step}"
    SUCCESS = "{step_id}_success"
    FAILED = "{step_id}_failed"
    INTENT_CHANGE = "{step_id}_{target_node}"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class TransitionPattern:
    """Colon-terminated marker templates; the field variants take ``{field}``."""

    CAPTURED = "{field}_CAPTURED:"
    FAILED = "{field}_FAILED:"
    INTENT_CHANGE = "INTENT_CHANGE:"
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Module-level defaults.
DEFAULT_MAX_ATTEMPTS: int = 3
DEFAULT_MODEL: str = "gpt-4o-mini"
DEFAULT_TIMEOUT: int = 300  # NOTE(review): unit not shown here — presumably seconds; confirm

# User-facing messages; MAX_ATTEMPTS_MESSAGE is a str.format template taking {field}.
MAX_ATTEMPTS_MESSAGE: str = "I'm having trouble understanding your {field}. Please contact customer service for assistance."
WORKFLOW_COMPLETE_MESSAGE: str = "Workflow completed."
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class MFAConfig(BaseSettings):
    """
    Settings for the MFA REST API endpoints.

    Any field not supplied to the constructor is read from an environment
    variable with the same name. Name matching is case-insensitive
    (``case_sensitive=False``) and extra inputs are ignored (``extra='ignore'``).

    Example:
        # Load from environment variables
        config = MFAConfig()

        # Or provide specific values
        config = MFAConfig(
            generate_token_base_url="https://api.example.com",
            generate_token_path="/v1/mfa/generate"
        )
    """

    model_config = SettingsConfigDict(case_sensitive=False, extra='ignore')

    # Token generation endpoint.
    generate_token_base_url: Optional[str] = Field(None, description="Base URL for the generate token endpoint")
    generate_token_path: Optional[str] = Field(None, description="Path for the generate token endpoint")

    # Token validation endpoint.
    validate_token_base_url: Optional[str] = Field(None, description="Base URL for the validate token endpoint")
    validate_token_path: Optional[str] = Field(None, description="Path for the validate token endpoint")

    # Token authorization endpoint.
    authorize_token_base_url: Optional[str] = Field(None, description="Base URL for the authorize token endpoint")
    authorize_token_path: Optional[str] = Field(None, description="Path for the authorize token endpoint")

    # Request timeout and the user-facing cancellation message.
    api_timeout: int = Field(30, description="API request timeout in seconds")
    mfa_cancelled_message: str = Field(
        "Authentication has been cancelled.",
        description="Message to display when user cancels MFA authentication",
    )
|
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
from typing import Optional, Dict, Any, Tuple
|
|
2
|
+
|
|
3
|
+
import yaml
|
|
4
|
+
from jinja2 import Environment
|
|
5
|
+
from langgraph.checkpoint.memory import InMemorySaver
|
|
6
|
+
from langgraph.constants import START
|
|
7
|
+
from langgraph.graph import StateGraph
|
|
8
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
9
|
+
|
|
10
|
+
from .constants import WorkflowKeys, MFAConfig
|
|
11
|
+
from .state import create_state_model
|
|
12
|
+
from ..nodes.factory import NodeFactory
|
|
13
|
+
from ..routing.router import WorkflowRouter
|
|
14
|
+
from ..utils.function import FunctionRepository
|
|
15
|
+
from ..utils.logger import logger
|
|
16
|
+
from ..utils.tool import ToolRepository
|
|
17
|
+
from ..validation import validate_workflow
|
|
18
|
+
from soprano_sdk.authenticators.mfa import MFANodeConfig
|
|
19
|
+
|
|
20
|
+
class WorkflowEngine:
    """Load a workflow YAML definition and build an executable LangGraph graph.

    Construction does the following:
    - reads and validates the YAML;
    - expands steps that declare an ``mfa`` block into generated MFA steps;
    - prepares the state model, outcome map, function/tool repositories and a
      per-engine context store.

    ``build_graph`` then compiles the steps into a ``StateGraph``.
    """

    def __init__(self, yaml_path: str, configs: Optional[dict] = None, mfa_config: Optional[MFAConfig] = None):
        """Load, validate and pre-process the workflow at ``yaml_path``.

        Args:
            yaml_path: Path to the workflow YAML file (read as UTF-8).
            configs: Runtime overrides consulted by ``get_config_value``
                before the YAML itself. ``None`` means no overrides.
            mfa_config: MFA endpoint settings. If ``None``, ``MFAConfig()``
                is constructed, which reads environment variables.
                ``self.mfa_config`` keeps these settings only when at least
                one step uses MFA; otherwise it is ``None``.

        Errors raised while opening, parsing or validating the YAML
        propagate unchanged.
        """
        self.yaml_path = yaml_path
        self.configs = configs or {}
        logger.info(f"Loading workflow from: {yaml_path}")

        # YAML streams are Unicode; don't depend on the platform locale encoding.
        with open(yaml_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        # Build the settings once. MFAConfig() reads the environment, and the
        # same object is used both for validation and at runtime.
        resolved_mfa_config = mfa_config or MFAConfig()

        logger.info("Validating workflow configuration")
        validate_workflow(self.config, mfa_config=resolved_mfa_config)

        self.workflow_name = self.config['name']
        self.workflow_description = self.config['description']
        self.workflow_version = self.config['version']

        # Filled by load_steps() with the ids of generated MFA collector steps.
        self.mfa_validator_steps: set[str] = set()
        self.steps: list = self.load_steps()
        self.step_map = {step['id']: step for step in self.steps}
        # Must be set before load_outcomes(), which checks it to decide whether
        # to add the auto-generated 'mfa_cancelled' outcome.
        self.mfa_config = resolved_mfa_config if self.mfa_validator_steps else None
        self.data_fields = self.load_data()
        self.outcomes = self.load_outcomes()
        self.metadata = self.config.get('metadata', {})

        self.StateType = create_state_model(self.data_fields)

        self.outcome_map = {outcome['id']: outcome for outcome in self.outcomes}

        self.function_repository = FunctionRepository()
        self.tool_repository = None
        if tool_config := self.config.get("tool_config"):
            self.tool_repository = ToolRepository(tool_config)

        self.context_store = {}
        self.collect_input_fields = self._get_collect_input_fields()

        logger.info(
            f"Workflow loaded: {self.workflow_name} v{self.workflow_version} "
            f"({len(self.steps)} steps, {len(self.outcomes)} outcomes)"
        )

    def get_config_value(self, key, default_value: Optional[Any] = None):
        """Look up ``key`` in the runtime ``configs`` first, then in the YAML.

        A value counts as present when it is not ``None``. Falsy values such
        as ``False``, ``0`` or ``""`` are returned instead of falling through
        to the next source.

        Args:
            key: Configuration key to resolve.
            default_value: Returned when neither source has a non-``None``
                value.
        """
        value = self.configs.get(key)
        if value is not None:
            return value

        value = self.config.get(key)
        if value is not None:
            return value

        return default_value

    def _get_collect_input_fields(self) -> set:
        """Return the ``field`` names of all ``collect_input_with_agent`` steps."""
        fields = set()
        for step in self.steps:
            if step.get('action') == 'collect_input_with_agent' and (field := step.get('field')):
                fields.add(field)
        return fields

    def update_context(self, context: Dict[str, Any]):
        """Merge ``context`` into the engine-level context store."""
        self.context_store.update(context)
        logger.info(f"Context updated: {context}")

    def remove_context_field(self, field_name: str):
        """Remove ``field_name`` from the context store if present; otherwise do nothing."""
        if field_name in self.context_store:
            del self.context_store[field_name]
            logger.info(f"Removed context field: {field_name}")

    def get_context_value(self, field_name: str):
        """Return the context value for ``field_name``, or ``None`` if absent."""
        value = self.context_store.get(field_name, None)
        if value is not None:
            logger.info(f"Retrieved context value for '{field_name}': {value}")
        return value

    def build_graph(self, checkpointer=None):
        """Compile the workflow steps into a LangGraph graph.

        Each step becomes a node created by ``NodeFactory.create``, with
        conditional outgoing edges produced by ``WorkflowRouter``. The first
        entry of ``self.steps`` is connected to ``START``.

        Args:
            checkpointer: Checkpointer passed to ``compile``. When ``None``,
                a new ``InMemorySaver`` is used.

        Returns:
            The compiled graph.

        Raises:
            RuntimeError: wraps any error raised while building or compiling.
                The original exception is chained as ``__cause__``.
        """
        logger.info("Building workflow graph")

        try:
            builder = StateGraph(self.StateType)

            # Every router receives the full list of collector node ids,
            # so collect it before any routing edges are added.
            collector_nodes = []

            logger.info("Adding nodes to graph")
            for step in self.steps:
                step_id = step['id']
                action = step['action']

                if action == 'collect_input_with_agent':
                    collector_nodes.append(step_id)

                node_fn = NodeFactory.create(step, engine_context=self)
                builder.add_node(step_id, node_fn)

                logger.info(f"Added node: {step_id} (action: {action})")

            first_step_id = self.steps[0]['id']
            builder.add_edge(START, first_step_id)
            logger.info(f"Set entry point: {first_step_id}")

            logger.info("Adding routing edges")
            for step in self.steps:
                step_id = step['id']

                router = WorkflowRouter(step, self.step_map, self.outcome_map)
                route_fn = router.create_route_function()
                routing_map = router.get_routing_map(collector_nodes)

                builder.add_conditional_edges(step_id, route_fn, routing_map)

                logger.info(
                    f"Added routing for {step_id}: {len(routing_map)} destinations"
                )

            if checkpointer is None:
                checkpointer = InMemorySaver()
                logger.info("Using InMemorySaver for state persistence")
            else:
                logger.info(f"Using custom checkpointer: {type(checkpointer).__name__}")

            graph = builder.compile(checkpointer=checkpointer)

            logger.info("Workflow graph built successfully")
            return graph

        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Failed to build workflow graph: {e}") from e

    def get_outcome_message(self, state: Dict[str, Any]) -> str:
        """Return the user-facing message for a run's final ``state``.

        The first applicable source is used:
        1. The ``message`` of the outcome named by ``state[_outcome_id]``,
           rendered as a Jinja template with ``state`` as context. The
           ``template_loader`` config value is used as the Jinja
           environment when set; otherwise a default ``Environment()``
           is created.
        2. ``state['error']``, stringified.
        3. ``state['_messages']``, stringified.
        4. A fixed fallback string; this case is also logged as an error.
        """
        outcome_id = state.get(WorkflowKeys.OUTCOME_ID)
        step_id = state.get(WorkflowKeys.STEP_ID)

        outcome = self.outcome_map.get(outcome_id)
        if outcome and 'message' in outcome:
            template_loader = self.get_config_value("template_loader")
            if template_loader is None:
                # Only build a default environment when none is configured.
                template_loader = Environment()
            message = template_loader.from_string(outcome['message']).render(state)
            logger.info(f"Outcome message generated in step {step_id}: {message}")
            return message

        if error := state.get(WorkflowKeys.ERROR):
            logger.info(f"Outcome error found in step {step_id}: {error}")
            return f"{error}"

        if message := state.get(WorkflowKeys.MESSAGES):
            logger.info(f"Outcome message found in step {step_id}: {message}")
            return f"{message}"

        logger.error(f"No outcome message found in step {step_id}")
        return "{'error': 'Unable to complete the request'}"

    def get_step_info(self, step_id: str) -> Optional[Dict[str, Any]]:
        """Return the step definition with id ``step_id``, or ``None``."""
        return self.step_map.get(step_id)

    def get_outcome_info(self, outcome_id: str) -> Optional[Dict[str, Any]]:
        """Return the outcome definition with id ``outcome_id``, or ``None``."""
        return self.outcome_map.get(outcome_id)

    def list_steps(self) -> list:
        """Return step ids in order, including generated MFA steps."""
        return [step['id'] for step in self.steps]

    def list_outcomes(self) -> list:
        """Return outcome ids in order, including an auto-generated ``mfa_cancelled``."""
        return [outcome['id'] for outcome in self.outcomes]

    def get_workflow_info(self) -> Dict[str, Any]:
        """Return a summary dict of the workflow for introspection."""
        return {
            'name': self.workflow_name,
            'description': self.workflow_description,
            'version': self.workflow_version,
            'steps': len(self.steps),
            'outcomes': len(self.outcomes),
            'data_fields': [f['name'] for f in self.data_fields],
            'metadata': self.metadata
        }

    def get_tool_policy(self) -> Optional[str]:
        """Return ``tool_config.usage_policy`` from the YAML.

        Returns ``None`` when the key is missing.

        Raises:
            ValueError: if the YAML has no (or an empty) ``tool_config``.
        """
        tool_config = self.config.get('tool_config')
        if not tool_config:
            raise ValueError("Tool config is not provided in the YAML")
        return tool_config.get('usage_policy')

    def load_steps(self):
        """Expand MFA-protected steps and return the final ordered step list.

        For every step with an ``mfa`` block, two generated steps are
        inserted immediately before it:
        - a start step from ``MFANodeConfig.get_call_function_template``;
        - a collector step from ``MFANodeConfig.get_validate_user_input``,
          called with ``next_node`` set to the protected step.

        The ``mfa`` key is deleted from the original step dict (mutating
        ``self.config['steps']``). The collector's id is recorded in
        ``self.mfa_validator_steps``.

        A second pass rewrites ``transitions[*].next`` or ``next`` values that
        point at a protected step so they point at its generated start step.
        Two kinds of step are not rewritten:
        - collector steps;
        - steps that still contain an ``mfa`` key (presumably the generated
          start steps — TODO confirm against ``MFANodeConfig``).
        """
        prepared_steps: list = []
        mfa_redirects: Dict[str, str] = {}

        for step in self.config['steps']:
            step_id = step['id']

            if mfa_config := step.get('mfa'):
                mfa_data_collector = MFANodeConfig.get_validate_user_input(
                    next_node=step_id,
                    source_node=step_id,
                    mfa_config=mfa_config
                )
                mfa_start = MFANodeConfig.get_call_function_template(
                    source_node=step_id,
                    next_node=mfa_data_collector['id'],
                    mfa=mfa_config
                )

                prepared_steps.append(mfa_start)
                prepared_steps.append(mfa_data_collector)
                self.mfa_validator_steps.add(mfa_data_collector['id'])

                mfa_redirects[step_id] = mfa_start['id']

                del step['mfa']

            prepared_steps.append(step)

        for step in prepared_steps:
            if step['id'] in self.mfa_validator_steps:  # MFA Validator
                continue

            elif 'mfa' in step:  # MFA Start
                continue

            elif step.get('transitions'):
                for transition in step.get('transitions'):
                    next_step = transition.get('next')
                    if next_step in mfa_redirects:
                        transition['next'] = mfa_redirects[next_step]

            elif step.get('next') in mfa_redirects:
                step['next'] = mfa_redirects[step['next']]

        return prepared_steps

    def load_data(self):
        """Return the declared data fields plus one text field per MFA collector.

        The list from ``self.config['data']`` is extended in place. Generated
        fields are appended in step order, so the resulting state model is
        identical across processes. Iterating the ``mfa_validator_steps`` set
        directly would give a hash-randomized order.
        """
        data: list = self.config['data']
        # dict.fromkeys keeps first-seen order and de-duplicates ids.
        ordered_validator_ids = dict.fromkeys(
            step['id'] for step in self.steps if step['id'] in self.mfa_validator_steps
        )
        for step_id in ordered_validator_ids:
            step_details = self.step_map[step_id]
            data.append(
                dict(
                    # Double quotes outside: nested same-type quotes in an
                    # f-string are a SyntaxError before Python 3.12 (PEP 701).
                    name=f"{step_details['field']}",
                    type='text',
                    description='Input Recieved from the user during MFA'
                )
            )
        return data

    def load_outcomes(self):
        """Return the declared outcomes, plus ``mfa_cancelled`` when MFA is used.

        The list from ``self.config['outcomes']`` is extended in place. The
        generated outcome is a failure that carries
        ``self.mfa_config.mfa_cancelled_message``.
        """
        outcomes: list = self.config['outcomes']

        if self.mfa_config:
            mfa_cancelled_outcome = {
                'id': 'mfa_cancelled',
                'type': 'failure',
                'message': self.mfa_config.mfa_cancelled_message
            }
            outcomes.append(mfa_cancelled_outcome)
            logger.info(f"Auto-generated 'mfa_cancelled' outcome with message: {self.mfa_config.mfa_cancelled_message}")

        return outcomes
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def load_workflow(yaml_path: str, checkpointer=None, config=None, mfa_config: Optional[MFAConfig] = None) -> Tuple[CompiledStateGraph, WorkflowEngine]:
    """
    Build a ``WorkflowEngine`` from a YAML file and compile its graph.

    This is the framework's main entry point.

    Args:
        yaml_path: Path to the workflow YAML file.
        checkpointer: Optional checkpointer for state persistence. When
            omitted, ``build_graph`` uses ``InMemorySaver()``. A persistent
            saver (e.g. MongoDBSaver) can be passed for production use.
        config: Optional runtime configuration dict. It is forwarded to the
            engine as ``configs``.
        mfa_config: Optional MFA settings. When omitted, they are loaded from
            environment variables.

    Returns:
        A ``(compiled_graph, engine)`` pair:
        - ``compiled_graph`` is the LangGraph graph, ready to invoke;
        - ``engine`` is the ``WorkflowEngine`` for introspection.

    Example:
        ```python
        # Load with environment variables
        graph, engine = load_workflow("workflow.yaml")

        # Or provide MFA configuration explicitly
        from soprano_sdk.core.constants import MFAConfig
        mfa_config = MFAConfig(
            generate_token_base_url="https://api.example.com",
            generate_token_path="/v1/mfa/generate"
        )
        graph, engine = load_workflow("workflow.yaml", mfa_config=mfa_config)

        result = graph.invoke({}, config={"configurable": {"thread_id": "123"}})
        message = engine.get_outcome_message(result)
        ```
    """
    workflow_engine = WorkflowEngine(yaml_path, configs=config, mfa_config=mfa_config)
    compiled_graph = workflow_engine.build_graph(checkpointer=checkpointer)
    return compiled_graph, workflow_engine
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import uuid
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Dict, Any, List
|
|
6
|
+
|
|
7
|
+
from soprano_sdk.core.constants import WorkflowKeys, ActionType
|
|
8
|
+
from ..utils.logger import logger
|
|
9
|
+
|
|
10
|
+
class RollbackStrategy(ABC):
    """Interface for strategies that rewind workflow state to an earlier node.

    A concrete subclass must implement all four abstract methods. The base
    implementations only document the contract and return ``None``.
    """

    @abstractmethod
    def rollback_to_node(
        self,
        state: Dict[str, Any],
        target_node: str,
        node_execution_order: List[str],
        node_field_map: Dict[str, str],
        workflow_steps: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Return the state to continue from after rewinding to ``target_node``."""

    @abstractmethod
    def should_save_snapshot(self) -> bool:
        """Report whether this strategy records snapshots via ``save_snapshot``."""

    @abstractmethod
    def save_snapshot(self, state: Dict[str, Any], node_id: str, execution_index: int) -> None:
        """Record whatever this strategy needs before ``node_id`` runs."""

    @abstractmethod
    def get_strategy_name(self) -> str:
        """Return a short identifier for this strategy."""
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _restore_from_snapshot(snapshot: Dict[str, Any]) -> Dict[str, Any]:
    """Return an independent deep copy of the state stored under ``snapshot['state']``.

    Raises:
        KeyError: if ``snapshot`` has no ``'state'`` entry.
    """
    saved_state = snapshot['state']
    return copy.deepcopy(saved_state)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _clear_future_executions(
    state: Dict[str, Any],
    target_node: str,
    workflow_steps: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Reset the outputs of ``target_node`` and of every step declared after it.

    Steps are ordered by their position in ``workflow_steps`` (declaration
    order), not by runtime execution order. Handling by action:
    - collect-input steps: the collected ``field`` is set to ``None`` and its
      ``<field>_conversation`` entry is removed;
    - call-function steps: the ``output`` field is set to ``None``;
    - any other action is left untouched.

    ``state`` is modified in place and also returned. If ``target_node`` is
    not a known step id, a warning is logged and ``state`` is returned
    unchanged.
    """
    start = None
    for position, candidate in enumerate(workflow_steps):
        if candidate['id'] == target_node:
            start = position
            break

    if start is None:
        logger.warning(f"Target node {target_node} not found in workflow steps")
        return state

    remaining = workflow_steps[start:]
    logger.info(f"Future steps to clear: {[s['id'] for s in remaining]}")

    collect_action = ActionType.COLLECT_INPUT_WITH_AGENT.value
    call_action = ActionType.CALL_FUNCTION.value

    for entry in remaining:
        kind = entry.get('action')

        if kind == collect_action:
            collected = entry.get('field')
            if not collected:
                continue
            state[collected] = None
            _clear_field_conversation(state, collected)

        elif kind == call_action:
            computed = entry.get('output')
            if not computed:
                continue
            state[computed] = None
            logger.info(f"Cleared computed field: {computed}")

    return state
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class HistoryBasedRollback(RollbackStrategy):
    """Rollback strategy that restores full-state snapshots.

    ``save_snapshot`` appends a snapshot to
    ``state[WorkflowKeys.STATE_HISTORY]``. ``rollback_to_node`` restores the
    first snapshot recorded for the target node.
    """

    def get_strategy_name(self) -> str:
        return "history_based"

    def should_save_snapshot(self) -> bool:
        return True

    def save_snapshot(self, state: Dict[str, Any], node_id: str, execution_index: int) -> None:
        """Append a deep-copied snapshot of ``state`` to its history list.

        The history list itself is excluded from the copied state. Copying
        it would embed every earlier snapshot inside each new one, roughly
        doubling snapshot size (and deepcopy time) with every saved node.
        ``rollback_to_node`` sets the history explicitly after restoring, so
        the nested copy is never needed.
        """
        state_history = state.get(WorkflowKeys.STATE_HISTORY, [])

        state_without_history = {
            key: value
            for key, value in state.items()
            if key != WorkflowKeys.STATE_HISTORY
        }

        snapshot = {
            'snapshot_id': str(uuid.uuid4()),
            'node_about_to_execute': node_id,
            'execution_index': execution_index,
            'timestamp': datetime.now().isoformat(),
            'state': copy.deepcopy(state_without_history),
        }

        state_history.append(snapshot)
        state[WorkflowKeys.STATE_HISTORY] = state_history

        logger.info(f"Saved snapshot #{len(state_history)-1} before executing {node_id}")

    def rollback_to_node(
        self,
        state: Dict[str, Any],
        target_node: str,
        node_execution_order: List[str],
        node_field_map: Dict[str, str],
        workflow_steps: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Restore the state captured just before ``target_node`` first ran.

        The earliest snapshot whose ``node_about_to_execute`` equals
        ``target_node`` is deep-copied. Outputs of ``target_node`` and of
        later-declared steps are then cleared. The history is truncated to
        end at (and include) that snapshot.

        Returns:
            The restored state dict, or ``{}`` when there is no history or
            no snapshot for ``target_node``.
        """
        state_history = state.get(WorkflowKeys.STATE_HISTORY, [])

        if not state_history:
            logger.warning("No state history available for rollback")
            return {}

        logger.info(f"Looking for snapshot before node '{target_node}'")

        target_snapshot = None
        target_index = None

        for i, snapshot in enumerate(state_history):
            if snapshot.get('node_about_to_execute') == target_node:
                target_snapshot = snapshot
                target_index = i
                break

        if target_snapshot is None:
            logger.warning(f"No snapshot found before node '{target_node}'")
            return {}

        logger.info(f"Found snapshot at index {target_index}")
        restored_state = _restore_from_snapshot(target_snapshot)

        restored_state = _clear_future_executions(
            restored_state,
            target_node,
            workflow_steps
        )

        restored_state[WorkflowKeys.STATE_HISTORY] = state_history[:target_index + 1]

        logger.info(f"Successfully rolled back to {target_node}")

        return restored_state
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _build_dependency_graph(
    workflow_steps: List[Dict[str, Any]]
) -> Dict[str, List[str]]:
    """Map each field to the fields that directly declare ``depends_on`` it.

    A step's field is its ``field`` value, or its ``output`` value if
    ``field`` is falsy; steps with neither are skipped. ``depends_on`` may be
    a string or a list of strings. Any other truthy type is logged as a
    warning and ignored.

    Every field appears as a key, with an empty list when nothing depends on
    it. Dependents are de-duplicated and kept in first-seen order.
    """
    graph: Dict[str, List[str]] = {}

    for step in workflow_steps:
        field = step.get('field') or step.get('output')
        if not field:
            continue

        raw = step.get('depends_on')
        if not raw:
            parents = []
        elif isinstance(raw, str):
            parents = [raw]
        elif isinstance(raw, list):
            parents = raw
        else:
            logger.warning(f"Invalid depends_on type for field '{field}': {type(raw)}")
            parents = []

        for parent in parents:
            children = graph.setdefault(parent, [])
            if field not in children:
                children.append(field)

        graph.setdefault(field, [])

    return graph
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _find_all_dependents(
    field: str,
    dependency_graph: Dict[str, List[str]]
) -> set:
    """Return every field that transitively depends on ``field``.

    ``dependency_graph`` maps a field to its direct dependents (see
    ``_build_dependency_graph``). It is walked with an explicit stack rather
    than recursion, so long dependency chains cannot raise ``RecursionError``.
    Each node is expanded once, so cycles terminate. ``field`` itself appears
    in the result only if a cycle leads back to it.
    """
    all_dependents: set = set()
    visited: set = set()
    pending = [field]

    while pending:
        current = pending.pop()
        if current in visited:
            continue
        visited.add(current)

        for dependent in dependency_graph.get(current, []):
            all_dependents.add(dependent)
            pending.append(dependent)

    return all_dependents
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _clear_field_conversation(state: Dict[str, Any], field: str) -> None:
    """Remove the ``<field>_conversation`` entry from the state's conversations map.

    The map stored under ``WorkflowKeys.CONVERSATIONS`` is mutated in place.
    Nothing happens when the map or the entry is absent.
    """
    conversations = state.get(WorkflowKeys.CONVERSATIONS, {})
    key = f"{field}_conversation"

    if key not in conversations:
        return

    del conversations[key]
    logger.info(f"Cleared conversation: {key}")
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class DependencyBasedRollback(RollbackStrategy):
    """Rollback strategy that clears a field and everything depending on it.

    No snapshots are kept. Rolling back to a node does the following:
    - looks up the node's field in ``node_field_map``;
    - resets that field, and every field reachable from it through
      ``depends_on`` declarations, to ``None``;
    - drops their conversation entries.

    ``state`` is modified in place and returned.
    """

    def get_strategy_name(self) -> str:
        return "dependency_based"

    def should_save_snapshot(self) -> bool:
        return False

    def save_snapshot(self, state: Dict[str, Any], node_id: str, execution_index: int) -> None:
        """No-op: this strategy keeps no snapshots."""
        return None

    def rollback_to_node(
        self,
        state: Dict[str, Any],
        target_node: str,
        node_execution_order: List[str],
        node_field_map: Dict[str, str],
        workflow_steps: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Clear ``target_node``'s field and its transitive dependents in ``state``.

        If ``target_node`` has no mapped field, a warning is logged and
        ``state`` is returned untouched. ``node_execution_order`` is not used.
        """
        root_field = node_field_map.get(target_node)
        if not root_field:
            logger.warning(f"No field found for target node '{target_node}'")
            return state

        logger.info(f"Rolling back to node '{target_node}' (field: '{root_field}')")

        graph = _build_dependency_graph(workflow_steps)
        logger.info(f"Dependency graph: {graph}")

        downstream = _find_all_dependents(root_field, graph)
        logger.info(f"Fields dependent on '{root_field}': {downstream}")

        state[root_field] = None
        _clear_field_conversation(state, root_field)

        for name in downstream:
            state[name] = None
            _clear_field_conversation(state, name)
            logger.info(f"Cleared dependent field: {name}")

        logger.info(f"Successfully rolled back to {target_node} using dependency graph")
        return state
|