vellum-ai 0.10.7__py3-none-any.whl → 0.10.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/types/logical_operator.py +2 -0
- vellum/workflows/descriptors/utils.py +27 -0
- vellum/workflows/events/__init__.py +0 -2
- vellum/workflows/events/tests/test_event.py +2 -1
- vellum/workflows/events/types.py +36 -30
- vellum/workflows/events/workflow.py +14 -7
- vellum/workflows/nodes/bases/base.py +100 -38
- vellum/workflows/nodes/core/inline_subworkflow_node/node.py +1 -0
- vellum/workflows/nodes/core/templating_node/node.py +5 -0
- vellum/workflows/nodes/core/try_node/node.py +22 -4
- vellum/workflows/nodes/core/try_node/tests/test_node.py +15 -0
- vellum/workflows/nodes/displayable/api_node/node.py +1 -1
- vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +1 -2
- vellum/workflows/nodes/displayable/code_execution_node/node.py +1 -2
- vellum/workflows/nodes/displayable/code_execution_node/utils.py +13 -2
- vellum/workflows/nodes/displayable/inline_prompt_node/node.py +10 -3
- vellum/workflows/nodes/displayable/prompt_deployment_node/node.py +6 -1
- vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +1 -2
- vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +1 -2
- vellum/workflows/runner/runner.py +141 -32
- vellum/workflows/state/base.py +55 -21
- vellum/workflows/state/context.py +26 -3
- vellum/workflows/types/__init__.py +5 -0
- vellum/workflows/types/core.py +1 -1
- vellum/workflows/workflows/base.py +51 -17
- vellum/workflows/workflows/event_filters.py +61 -0
- {vellum_ai-0.10.7.dist-info → vellum_ai-0.10.9.dist-info}/METADATA +1 -1
- {vellum_ai-0.10.7.dist-info → vellum_ai-0.10.9.dist-info}/RECORD +40 -38
- vellum_cli/__init__.py +23 -4
- vellum_cli/pull.py +28 -13
- vellum_cli/tests/test_pull.py +45 -2
- vellum_ee/workflows/display/nodes/base_node_display.py +1 -1
- vellum_ee/workflows/display/nodes/vellum/__init__.py +6 -4
- vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +17 -2
- vellum_ee/workflows/display/nodes/vellum/error_node.py +49 -0
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +203 -0
- vellum/workflows/events/utils.py +0 -5
- {vellum_ai-0.10.7.dist-info → vellum_ai-0.10.9.dist-info}/LICENSE +0 -0
- {vellum_ai-0.10.7.dist-info → vellum_ai-0.10.9.dist-info}/WHEEL +0 -0
- {vellum_ai-0.10.7.dist-info → vellum_ai-0.10.9.dist-info}/entry_points.txt +0 -0
@@ -67,9 +67,8 @@ class CodeExecutionNode(BaseNode[StateType], Generic[StateType, _OutputType], me
|
|
67
67
|
|
68
68
|
filepath: str - The path to the script to execute.
|
69
69
|
code_inputs: EntityInputsInterface - The inputs for the custom script.
|
70
|
-
output_type: VellumVariableType = "STRING" - The type of the output from the custom script.
|
71
70
|
runtime: CodeExecutionRuntime = "PYTHON_3_12" - The runtime to use for the custom script.
|
72
|
-
packages: Optional[Sequence[
|
71
|
+
packages: Optional[Sequence[CodeExecutionPackage]] = None - The packages to use for the custom script.
|
73
72
|
request_options: Optional[RequestOptions] = None - The request options to use for the custom script.
|
74
73
|
"""
|
75
74
|
|
@@ -2,9 +2,20 @@ import os
|
|
2
2
|
from typing import Union
|
3
3
|
|
4
4
|
|
5
|
+
def get_project_root() -> str:
|
6
|
+
current_dir = os.getcwd()
|
7
|
+
while current_dir != '/':
|
8
|
+
if ".git" in os.listdir(current_dir):
|
9
|
+
return current_dir
|
10
|
+
current_dir = os.path.dirname(current_dir)
|
11
|
+
raise FileNotFoundError("Project root not found.")
|
12
|
+
|
5
13
|
def read_file_from_path(filepath: str) -> Union[str, None]:
|
6
|
-
|
14
|
+
project_root = get_project_root()
|
15
|
+
relative_filepath = os.path.join(project_root, filepath)
|
16
|
+
|
17
|
+
if not os.path.exists(relative_filepath):
|
7
18
|
return None
|
8
19
|
|
9
|
-
with open(
|
20
|
+
with open(relative_filepath, 'r') as file:
|
10
21
|
return file.read()
|
@@ -9,16 +9,23 @@ from vellum.workflows.types.generics import StateType
|
|
9
9
|
|
10
10
|
class InlinePromptNode(BaseInlinePromptNode[StateType]):
|
11
11
|
"""
|
12
|
-
Used to execute
|
12
|
+
Used to execute a Prompt defined inline.
|
13
13
|
|
14
14
|
prompt_inputs: EntityInputsInterface - The inputs for the Prompt
|
15
15
|
ml_model: str - Either the ML Model's UUID or its name.
|
16
|
-
blocks: List[
|
16
|
+
blocks: List[PromptBlock] - The blocks that make up the Prompt
|
17
|
+
functions: Optional[List[FunctionDefinition]] - The functions to include in the Prompt
|
17
18
|
parameters: PromptParameters - The parameters for the Prompt
|
18
|
-
expand_meta: Optional[
|
19
|
+
expand_meta: Optional[AdHocExpandMeta] - Expandable execution fields to include in the response
|
20
|
+
request_options: Optional[RequestOptions] - The request options to use for the Prompt Execution
|
19
21
|
"""
|
20
22
|
|
21
23
|
class Outputs(BaseInlinePromptNode.Outputs):
|
24
|
+
"""
|
25
|
+
The outputs of the InlinePromptNode.
|
26
|
+
|
27
|
+
text: str - The result of the Prompt Execution
|
28
|
+
"""
|
22
29
|
text: str
|
23
30
|
|
24
31
|
def run(self) -> Iterator[BaseOutput]:
|
@@ -14,7 +14,7 @@ class PromptDeploymentNode(BasePromptDeploymentNode[StateType]):
|
|
14
14
|
prompt_inputs: EntityInputsInterface - The inputs for the Prompt
|
15
15
|
deployment: Union[UUID, str] - Either the Prompt Deployment's UUID or its name.
|
16
16
|
release_tag: str - The release tag to use for the Prompt Execution
|
17
|
-
external_id: Optional[str] -
|
17
|
+
external_id: Optional[str] - Optionally include a unique identifier for tracking purposes. Must be unique within a given Prompt Deployment.
|
18
18
|
expand_meta: Optional[PromptDeploymentExpandMetaRequest] - Expandable execution fields to include in the response
|
19
19
|
raw_overrides: Optional[RawPromptExecutionOverridesRequest] - The raw overrides to use for the Prompt Execution
|
20
20
|
expand_raw: Optional[Sequence[str]] - Expandable raw fields to include in the response
|
@@ -23,6 +23,11 @@ class PromptDeploymentNode(BasePromptDeploymentNode[StateType]):
|
|
23
23
|
"""
|
24
24
|
|
25
25
|
class Outputs(BasePromptDeploymentNode.Outputs):
|
26
|
+
"""
|
27
|
+
The outputs of the PromptDeploymentNode.
|
28
|
+
|
29
|
+
text: str - The result of the Prompt Execution
|
30
|
+
"""
|
26
31
|
text: str
|
27
32
|
|
28
33
|
def run(self) -> Iterator[BaseOutput]:
|
@@ -12,7 +12,6 @@ from vellum import (
|
|
12
12
|
WorkflowRequestStringInputRequest,
|
13
13
|
)
|
14
14
|
from vellum.core import RequestOptions
|
15
|
-
|
16
15
|
from vellum.workflows.constants import LATEST_RELEASE_TAG, OMIT
|
17
16
|
from vellum.workflows.errors import VellumErrorCode
|
18
17
|
from vellum.workflows.exceptions import NodeException
|
@@ -28,7 +27,7 @@ class SubworkflowDeploymentNode(BaseSubworkflowNode[StateType], Generic[StateTyp
|
|
28
27
|
subworkflow_inputs: EntityInputsInterface - The inputs for the Subworkflow
|
29
28
|
deployment: Union[UUID, str] - Either the Workflow Deployment's UUID or its name.
|
30
29
|
release_tag: str = LATEST_RELEASE_TAG - The release tag to use for the Workflow Execution
|
31
|
-
external_id: Optional[str] = OMIT -
|
30
|
+
external_id: Optional[str] = OMIT - Optionally include a unique identifier for tracking purposes. Must be unique within a given Workflow Deployment.
|
32
31
|
expand_meta: Optional[WorkflowExpandMetaRequest] = OMIT - Expandable execution fields to include in the respownse
|
33
32
|
metadata: Optional[Dict[str, Optional[Any]]] = OMIT - The metadata to use for the Workflow Execution
|
34
33
|
request_options: Optional[RequestOptions] = None - The request options to use for the Workflow Execution
|
@@ -8,7 +8,6 @@ from vellum import (
|
|
8
8
|
PromptOutput,
|
9
9
|
StringVellumValue,
|
10
10
|
)
|
11
|
-
|
12
11
|
from vellum.workflows.constants import OMIT
|
13
12
|
from vellum.workflows.inputs import BaseInputs
|
14
13
|
from vellum.workflows.nodes import PromptDeploymentNode
|
@@ -65,7 +64,7 @@ def test_text_prompt_deployment_node__basic(vellum_client):
|
|
65
64
|
assert text_output.name == "text"
|
66
65
|
assert text_output.value == "Hello, world!"
|
67
66
|
|
68
|
-
# AND we should have made the expected call to
|
67
|
+
# AND we should have made the expected call to stream the prompt execution
|
69
68
|
vellum_client.execute_prompt_stream.assert_called_once_with(
|
70
69
|
expand_meta=OMIT,
|
71
70
|
expand_raw=OMIT,
|
@@ -3,7 +3,7 @@ from copy import deepcopy
|
|
3
3
|
import logging
|
4
4
|
from queue import Empty, Queue
|
5
5
|
from threading import Event as ThreadingEvent, Thread
|
6
|
-
from uuid import UUID
|
6
|
+
from uuid import UUID
|
7
7
|
from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Optional, Sequence, Set, Type, Union
|
8
8
|
|
9
9
|
from vellum.workflows.constants import UNDEF
|
@@ -28,8 +28,7 @@ from vellum.workflows.events.node import (
|
|
28
28
|
NodeExecutionRejectedBody,
|
29
29
|
NodeExecutionStreamingBody,
|
30
30
|
)
|
31
|
-
from vellum.workflows.events.types import BaseEvent
|
32
|
-
from vellum.workflows.events.utils import is_terminal_event
|
31
|
+
from vellum.workflows.events.types import BaseEvent, ParentContext, WorkflowParentContext
|
33
32
|
from vellum.workflows.events.workflow import (
|
34
33
|
WorkflowExecutionFulfilledBody,
|
35
34
|
WorkflowExecutionInitiatedBody,
|
@@ -71,9 +70,12 @@ class WorkflowRunner(Generic[StateType]):
|
|
71
70
|
entrypoint_nodes: Optional[RunFromNodeArg] = None,
|
72
71
|
external_inputs: Optional[ExternalInputsArg] = None,
|
73
72
|
cancel_signal: Optional[ThreadingEvent] = None,
|
73
|
+
parent_context: Optional[ParentContext] = None,
|
74
74
|
):
|
75
75
|
if state and external_inputs:
|
76
|
-
raise ValueError(
|
76
|
+
raise ValueError(
|
77
|
+
"Can only run a Workflow providing one of state or external inputs, not both"
|
78
|
+
)
|
77
79
|
|
78
80
|
self.workflow = workflow
|
79
81
|
if entrypoint_nodes:
|
@@ -99,7 +101,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
99
101
|
if issubclass(ei.inputs_class.__parent_class__, BaseNode)
|
100
102
|
]
|
101
103
|
else:
|
102
|
-
normalized_inputs =
|
104
|
+
normalized_inputs = (
|
105
|
+
deepcopy(inputs) if inputs else self.workflow.get_default_inputs()
|
106
|
+
)
|
103
107
|
if state:
|
104
108
|
self._initial_state = deepcopy(state)
|
105
109
|
self._initial_state.meta.workflow_inputs = normalized_inputs
|
@@ -115,12 +119,14 @@ class WorkflowRunner(Generic[StateType]):
|
|
115
119
|
|
116
120
|
self._active_nodes_by_execution_id: Dict[UUID, BaseNode[StateType]] = {}
|
117
121
|
self._cancel_signal = cancel_signal
|
122
|
+
self._parent_context = parent_context
|
118
123
|
|
119
124
|
setattr(
|
120
125
|
self._initial_state,
|
121
126
|
"__snapshot_callback__",
|
122
127
|
lambda s: self._snapshot_state(s),
|
123
128
|
)
|
129
|
+
self.workflow.context._register_event_queue(self._workflow_event_queue)
|
124
130
|
|
125
131
|
def _snapshot_state(self, state: StateType) -> StateType:
|
126
132
|
self.workflow._store.append_state_snapshot(state)
|
@@ -143,6 +149,12 @@ class WorkflowRunner(Generic[StateType]):
|
|
143
149
|
node_definition=node.__class__,
|
144
150
|
inputs=node._inputs,
|
145
151
|
),
|
152
|
+
parent=WorkflowParentContext(
|
153
|
+
span_id=span_id,
|
154
|
+
workflow_definition=self.workflow.__class__,
|
155
|
+
parent=self._parent_context,
|
156
|
+
type="WORKFLOW",
|
157
|
+
),
|
146
158
|
),
|
147
159
|
)
|
148
160
|
)
|
@@ -178,7 +190,11 @@ class WorkflowRunner(Generic[StateType]):
|
|
178
190
|
instance=None,
|
179
191
|
outputs_class=node.Outputs,
|
180
192
|
)
|
181
|
-
node.state.meta.node_outputs[output_descriptor] =
|
193
|
+
node.state.meta.node_outputs[output_descriptor] = (
|
194
|
+
streaming_output_queues[output.name]
|
195
|
+
)
|
196
|
+
initiated_output: BaseOutput = BaseOutput(name=output.name)
|
197
|
+
initiated_ports = initiated_output > ports
|
182
198
|
self._work_item_event_queue.put(
|
183
199
|
WorkItemEvent(
|
184
200
|
node=node,
|
@@ -187,8 +203,13 @@ class WorkflowRunner(Generic[StateType]):
|
|
187
203
|
span_id=span_id,
|
188
204
|
body=NodeExecutionStreamingBody(
|
189
205
|
node_definition=node.__class__,
|
190
|
-
output=
|
191
|
-
invoked_ports=
|
206
|
+
output=initiated_output,
|
207
|
+
invoked_ports=initiated_ports,
|
208
|
+
),
|
209
|
+
parent=WorkflowParentContext(
|
210
|
+
span_id=span_id,
|
211
|
+
workflow_definition=self.workflow.__class__,
|
212
|
+
parent=self._parent_context,
|
192
213
|
),
|
193
214
|
),
|
194
215
|
)
|
@@ -214,6 +235,11 @@ class WorkflowRunner(Generic[StateType]):
|
|
214
235
|
output=output,
|
215
236
|
invoked_ports=invoked_ports,
|
216
237
|
),
|
238
|
+
parent=WorkflowParentContext(
|
239
|
+
span_id=span_id,
|
240
|
+
workflow_definition=self.workflow.__class__,
|
241
|
+
parent=self._parent_context,
|
242
|
+
),
|
217
243
|
),
|
218
244
|
)
|
219
245
|
)
|
@@ -233,6 +259,11 @@ class WorkflowRunner(Generic[StateType]):
|
|
233
259
|
output=output,
|
234
260
|
invoked_ports=invoked_ports,
|
235
261
|
),
|
262
|
+
parent=WorkflowParentContext(
|
263
|
+
span_id=span_id,
|
264
|
+
workflow_definition=self.workflow.__class__,
|
265
|
+
parent=self._parent_context,
|
266
|
+
),
|
236
267
|
),
|
237
268
|
)
|
238
269
|
)
|
@@ -246,7 +277,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
246
277
|
node.state.meta.node_outputs[descriptor] = output_value
|
247
278
|
|
248
279
|
invoked_ports = ports(outputs, node.state)
|
249
|
-
node.state.meta.node_execution_cache.fulfill_node_execution(
|
280
|
+
node.state.meta.node_execution_cache.fulfill_node_execution(
|
281
|
+
node.__class__, span_id
|
282
|
+
)
|
250
283
|
|
251
284
|
self._work_item_event_queue.put(
|
252
285
|
WorkItemEvent(
|
@@ -259,6 +292,11 @@ class WorkflowRunner(Generic[StateType]):
|
|
259
292
|
outputs=outputs,
|
260
293
|
invoked_ports=invoked_ports,
|
261
294
|
),
|
295
|
+
parent=WorkflowParentContext(
|
296
|
+
span_id=span_id,
|
297
|
+
workflow_definition=self.workflow.__class__,
|
298
|
+
parent=self._parent_context,
|
299
|
+
),
|
262
300
|
),
|
263
301
|
)
|
264
302
|
)
|
@@ -273,11 +311,18 @@ class WorkflowRunner(Generic[StateType]):
|
|
273
311
|
node_definition=node.__class__,
|
274
312
|
error=e.error,
|
275
313
|
),
|
314
|
+
parent=WorkflowParentContext(
|
315
|
+
span_id=span_id,
|
316
|
+
workflow_definition=self.workflow.__class__,
|
317
|
+
parent=self._parent_context,
|
318
|
+
),
|
276
319
|
),
|
277
320
|
)
|
278
321
|
)
|
279
322
|
except Exception as e:
|
280
|
-
logger.exception(
|
323
|
+
logger.exception(
|
324
|
+
f"An unexpected error occurred while running node {node.__class__.__name__}"
|
325
|
+
)
|
281
326
|
|
282
327
|
self._work_item_event_queue.put(
|
283
328
|
WorkItemEvent(
|
@@ -292,13 +337,20 @@ class WorkflowRunner(Generic[StateType]):
|
|
292
337
|
code=VellumErrorCode.INTERNAL_ERROR,
|
293
338
|
),
|
294
339
|
),
|
340
|
+
parent=WorkflowParentContext(
|
341
|
+
span_id=span_id,
|
342
|
+
workflow_definition=self.workflow.__class__,
|
343
|
+
parent=self._parent_context,
|
344
|
+
),
|
295
345
|
),
|
296
346
|
)
|
297
347
|
)
|
298
348
|
|
299
349
|
logger.debug(f"Finished running node: {node.__class__.__name__}")
|
300
350
|
|
301
|
-
def _handle_invoked_ports(
|
351
|
+
def _handle_invoked_ports(
|
352
|
+
self, state: StateType, ports: Optional[Iterable[Port]]
|
353
|
+
) -> None:
|
302
354
|
if not ports:
|
303
355
|
return
|
304
356
|
|
@@ -313,7 +365,10 @@ class WorkflowRunner(Generic[StateType]):
|
|
313
365
|
self._run_node_if_ready(next_state, edge.to_node, edge)
|
314
366
|
|
315
367
|
def _run_node_if_ready(
|
316
|
-
self,
|
368
|
+
self,
|
369
|
+
state: StateType,
|
370
|
+
node_class: Type[BaseNode],
|
371
|
+
invoked_by: Optional[Edge] = None,
|
317
372
|
) -> None:
|
318
373
|
with state.__lock__:
|
319
374
|
for descriptor in node_class.ExternalInputs:
|
@@ -325,18 +380,27 @@ class WorkflowRunner(Generic[StateType]):
|
|
325
380
|
return
|
326
381
|
|
327
382
|
all_deps = self._dependencies[node_class]
|
328
|
-
|
383
|
+
node_span_id = state.meta.node_execution_cache.queue_node_execution(
|
384
|
+
node_class, all_deps, invoked_by
|
385
|
+
)
|
386
|
+
if not node_class.Trigger.should_initiate(state, all_deps, node_span_id):
|
329
387
|
return
|
330
388
|
|
331
389
|
node = node_class(state=state, context=self.workflow.context)
|
332
|
-
|
333
|
-
|
390
|
+
state.meta.node_execution_cache.initiate_node_execution(
|
391
|
+
node_class, node_span_id
|
392
|
+
)
|
334
393
|
self._active_nodes_by_execution_id[node_span_id] = node
|
335
394
|
|
336
|
-
worker_thread = Thread(
|
395
|
+
worker_thread = Thread(
|
396
|
+
target=self._run_work_item,
|
397
|
+
kwargs={"node": node, "span_id": node_span_id},
|
398
|
+
)
|
337
399
|
worker_thread.start()
|
338
400
|
|
339
|
-
def _handle_work_item_event(
|
401
|
+
def _handle_work_item_event(
|
402
|
+
self, work_item_event: WorkItemEvent[StateType]
|
403
|
+
) -> Optional[VellumError]:
|
340
404
|
node = work_item_event.node
|
341
405
|
event = work_item_event.event
|
342
406
|
|
@@ -352,7 +416,10 @@ class WorkflowRunner(Generic[StateType]):
|
|
352
416
|
node_output_descriptor = workflow_output_descriptor.instance
|
353
417
|
if not isinstance(node_output_descriptor, OutputReference):
|
354
418
|
continue
|
355
|
-
if
|
419
|
+
if (
|
420
|
+
node_output_descriptor.outputs_class
|
421
|
+
!= event.node_definition.Outputs
|
422
|
+
):
|
356
423
|
continue
|
357
424
|
if node_output_descriptor.name != event.output.name:
|
358
425
|
continue
|
@@ -387,9 +454,12 @@ class WorkflowRunner(Generic[StateType]):
|
|
387
454
|
workflow_definition=self.workflow.__class__,
|
388
455
|
inputs=self._initial_state.meta.workflow_inputs,
|
389
456
|
),
|
457
|
+
parent=self._parent_context,
|
390
458
|
)
|
391
459
|
|
392
|
-
def _stream_workflow_event(
|
460
|
+
def _stream_workflow_event(
|
461
|
+
self, output: BaseOutput
|
462
|
+
) -> WorkflowExecutionStreamingEvent:
|
393
463
|
return WorkflowExecutionStreamingEvent(
|
394
464
|
trace_id=self._initial_state.meta.trace_id,
|
395
465
|
span_id=self._initial_state.meta.span_id,
|
@@ -397,9 +467,12 @@ class WorkflowRunner(Generic[StateType]):
|
|
397
467
|
workflow_definition=self.workflow.__class__,
|
398
468
|
output=output,
|
399
469
|
),
|
470
|
+
parent=self._parent_context,
|
400
471
|
)
|
401
472
|
|
402
|
-
def _fulfill_workflow_event(
|
473
|
+
def _fulfill_workflow_event(
|
474
|
+
self, outputs: OutputsType
|
475
|
+
) -> WorkflowExecutionFulfilledEvent:
|
403
476
|
return WorkflowExecutionFulfilledEvent(
|
404
477
|
trace_id=self._initial_state.meta.trace_id,
|
405
478
|
span_id=self._initial_state.meta.span_id,
|
@@ -407,9 +480,12 @@ class WorkflowRunner(Generic[StateType]):
|
|
407
480
|
workflow_definition=self.workflow.__class__,
|
408
481
|
outputs=outputs,
|
409
482
|
),
|
483
|
+
parent=self._parent_context,
|
410
484
|
)
|
411
485
|
|
412
|
-
def _reject_workflow_event(
|
486
|
+
def _reject_workflow_event(
|
487
|
+
self, error: VellumError
|
488
|
+
) -> WorkflowExecutionRejectedEvent:
|
413
489
|
return WorkflowExecutionRejectedEvent(
|
414
490
|
trace_id=self._initial_state.meta.trace_id,
|
415
491
|
span_id=self._initial_state.meta.span_id,
|
@@ -417,6 +493,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
417
493
|
workflow_definition=self.workflow.__class__,
|
418
494
|
error=error,
|
419
495
|
),
|
496
|
+
parent=self._parent_context,
|
420
497
|
)
|
421
498
|
|
422
499
|
def _resume_workflow_event(self) -> WorkflowExecutionResumedEvent:
|
@@ -428,7 +505,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
428
505
|
),
|
429
506
|
)
|
430
507
|
|
431
|
-
def _pause_workflow_event(
|
508
|
+
def _pause_workflow_event(
|
509
|
+
self, external_inputs: Iterable[ExternalInputReference]
|
510
|
+
) -> WorkflowExecutionPausedEvent:
|
432
511
|
return WorkflowExecutionPausedEvent(
|
433
512
|
trace_id=self._initial_state.meta.trace_id,
|
434
513
|
span_id=self._initial_state.meta.span_id,
|
@@ -436,6 +515,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
436
515
|
workflow_definition=self.workflow.__class__,
|
437
516
|
external_inputs=external_inputs,
|
438
517
|
),
|
518
|
+
parent=self._parent_context,
|
439
519
|
)
|
440
520
|
|
441
521
|
def _stream(self) -> None:
|
@@ -444,7 +524,10 @@ class WorkflowRunner(Generic[StateType]):
|
|
444
524
|
if not self._entrypoints:
|
445
525
|
self._workflow_event_queue.put(
|
446
526
|
self._reject_workflow_event(
|
447
|
-
VellumError(
|
527
|
+
VellumError(
|
528
|
+
message="No entrypoints defined",
|
529
|
+
code=VellumErrorCode.INVALID_WORKFLOW,
|
530
|
+
)
|
448
531
|
)
|
449
532
|
)
|
450
533
|
return
|
@@ -463,7 +546,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
463
546
|
logger.exception(err_message)
|
464
547
|
self._workflow_event_queue.put(
|
465
548
|
self._reject_workflow_event(
|
466
|
-
VellumError(
|
549
|
+
VellumError(
|
550
|
+
code=VellumErrorCode.INTERNAL_ERROR, message=err_message
|
551
|
+
),
|
467
552
|
)
|
468
553
|
)
|
469
554
|
return
|
@@ -519,7 +604,11 @@ class WorkflowRunner(Generic[StateType]):
|
|
519
604
|
if isinstance(value, BaseDescriptor):
|
520
605
|
setattr(fulfilled_outputs, descriptor.name, value.resolve(final_state))
|
521
606
|
elif isinstance(descriptor.instance, BaseDescriptor):
|
522
|
-
setattr(
|
607
|
+
setattr(
|
608
|
+
fulfilled_outputs,
|
609
|
+
descriptor.name,
|
610
|
+
descriptor.instance.resolve(final_state),
|
611
|
+
)
|
523
612
|
|
524
613
|
self._workflow_event_queue.put(self._fulfill_workflow_event(fulfilled_outputs))
|
525
614
|
|
@@ -544,24 +633,41 @@ class WorkflowRunner(Generic[StateType]):
|
|
544
633
|
self._cancel_signal.wait()
|
545
634
|
self._workflow_event_queue.put(
|
546
635
|
self._reject_workflow_event(
|
547
|
-
VellumError(
|
636
|
+
VellumError(
|
637
|
+
code=VellumErrorCode.WORKFLOW_CANCELLED,
|
638
|
+
message="Workflow run cancelled",
|
639
|
+
)
|
548
640
|
)
|
549
641
|
)
|
550
642
|
|
643
|
+
def _is_terminal_event(self, event: WorkflowEvent) -> bool:
|
644
|
+
if (
|
645
|
+
event.name == "workflow.execution.fulfilled"
|
646
|
+
or event.name == "workflow.execution.rejected"
|
647
|
+
or event.name == "workflow.execution.paused"
|
648
|
+
):
|
649
|
+
return event.workflow_definition == self.workflow.__class__
|
650
|
+
return False
|
651
|
+
|
551
652
|
def stream(self) -> WorkflowEventStream:
|
552
653
|
background_thread = Thread(
|
553
|
-
target=self._run_background_thread,
|
654
|
+
target=self._run_background_thread,
|
655
|
+
name=f"{self.workflow.__class__.__name__}.background_thread",
|
554
656
|
)
|
555
657
|
background_thread.start()
|
556
658
|
|
557
659
|
if self._cancel_signal:
|
558
660
|
cancel_thread = Thread(
|
559
|
-
target=self._run_cancel_thread,
|
661
|
+
target=self._run_cancel_thread,
|
662
|
+
name=f"{self.workflow.__class__.__name__}.cancel_thread",
|
560
663
|
)
|
561
664
|
cancel_thread.start()
|
562
665
|
|
563
666
|
event: WorkflowEvent
|
564
|
-
if
|
667
|
+
if (
|
668
|
+
self._initial_state.meta.is_terminated
|
669
|
+
or self._initial_state.meta.is_terminated is None
|
670
|
+
):
|
565
671
|
event = self._initiate_workflow_event()
|
566
672
|
else:
|
567
673
|
event = self._resume_workflow_event()
|
@@ -570,7 +676,10 @@ class WorkflowRunner(Generic[StateType]):
|
|
570
676
|
self._initial_state.meta.is_terminated = False
|
571
677
|
|
572
678
|
# The extra level of indirection prevents the runner from waiting on the caller to consume the event stream
|
573
|
-
stream_thread = Thread(
|
679
|
+
stream_thread = Thread(
|
680
|
+
target=self._stream,
|
681
|
+
name=f"{self.workflow.__class__.__name__}.stream_thread",
|
682
|
+
)
|
574
683
|
stream_thread.start()
|
575
684
|
|
576
685
|
while stream_thread.is_alive():
|
@@ -581,7 +690,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
581
690
|
|
582
691
|
yield self._emit_event(event)
|
583
692
|
|
584
|
-
if
|
693
|
+
if self._is_terminal_event(event):
|
585
694
|
break
|
586
695
|
|
587
696
|
try:
|
@@ -590,7 +699,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
590
699
|
except Empty:
|
591
700
|
pass
|
592
701
|
|
593
|
-
if not
|
702
|
+
if not self._is_terminal_event(event):
|
594
703
|
yield self._reject_workflow_event(
|
595
704
|
VellumError(
|
596
705
|
code=VellumErrorCode.INTERNAL_ERROR,
|
vellum/workflows/state/base.py
CHANGED
@@ -5,15 +5,15 @@ from datetime import datetime
|
|
5
5
|
from queue import Queue
|
6
6
|
from threading import Lock
|
7
7
|
from uuid import UUID, uuid4
|
8
|
-
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional, Sequence, Set, Tuple, Type, cast
|
8
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, cast
|
9
9
|
from typing_extensions import dataclass_transform
|
10
10
|
|
11
11
|
from pydantic import GetCoreSchemaHandler, field_serializer
|
12
12
|
from pydantic_core import core_schema
|
13
13
|
|
14
14
|
from vellum.core.pydantic_utilities import UniversalBaseModel
|
15
|
-
|
16
15
|
from vellum.workflows.constants import UNDEF
|
16
|
+
from vellum.workflows.edges.edge import Edge
|
17
17
|
from vellum.workflows.inputs.base import BaseInputs
|
18
18
|
from vellum.workflows.references import ExternalInputReference, OutputReference, StateValueReference
|
19
19
|
from vellum.workflows.types.generics import StateType
|
@@ -71,58 +71,92 @@ def _make_snapshottable(value: Any, snapshot_callback: Callable[[], None]) -> An
|
|
71
71
|
|
72
72
|
|
73
73
|
class NodeExecutionCache:
|
74
|
-
|
74
|
+
_node_executions_fulfilled: Dict[Type["BaseNode"], Stack[UUID]]
|
75
75
|
_node_executions_initiated: Dict[Type["BaseNode"], Set[UUID]]
|
76
|
-
|
76
|
+
_node_executions_queued: Dict[Type["BaseNode"], List[UUID]]
|
77
|
+
_dependencies_invoked: Dict[UUID, Set[Type["BaseNode"]]]
|
77
78
|
|
78
79
|
def __init__(
|
79
80
|
self,
|
80
81
|
dependencies_invoked: Optional[Dict[str, Sequence[str]]] = None,
|
81
|
-
|
82
|
+
node_executions_fulfilled: Optional[Dict[str, Sequence[str]]] = None,
|
82
83
|
node_executions_initiated: Optional[Dict[str, Sequence[str]]] = None,
|
84
|
+
node_executions_queued: Optional[Dict[str, Sequence[str]]] = None,
|
83
85
|
) -> None:
|
84
86
|
self._dependencies_invoked = defaultdict(set)
|
85
|
-
self.
|
87
|
+
self._node_executions_fulfilled = defaultdict(Stack[UUID])
|
86
88
|
self._node_executions_initiated = defaultdict(set)
|
89
|
+
self._node_executions_queued = defaultdict(list)
|
87
90
|
|
88
|
-
for
|
89
|
-
self._dependencies_invoked[
|
91
|
+
for execution_id, dependencies in (dependencies_invoked or {}).items():
|
92
|
+
self._dependencies_invoked[UUID(execution_id)] = {get_class_by_qualname(dep) for dep in dependencies}
|
90
93
|
|
91
|
-
for node, execution_ids in (
|
94
|
+
for node, execution_ids in (node_executions_fulfilled or {}).items():
|
92
95
|
node_class = get_class_by_qualname(node)
|
93
|
-
self.
|
96
|
+
self._node_executions_fulfilled[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
|
94
97
|
|
95
98
|
for node, execution_ids in (node_executions_initiated or {}).items():
|
96
99
|
node_class = get_class_by_qualname(node)
|
97
100
|
self._node_executions_initiated[node_class].update({UUID(execution_id) for execution_id in execution_ids})
|
98
101
|
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
+
for node, execution_ids in (node_executions_queued or {}).items():
|
103
|
+
node_class = get_class_by_qualname(node)
|
104
|
+
self._node_executions_queued[node_class].extend(UUID(execution_id) for execution_id in execution_ids)
|
102
105
|
|
103
|
-
def
|
104
|
-
|
106
|
+
def _invoke_dependency(
|
107
|
+
self,
|
108
|
+
execution_id: UUID,
|
109
|
+
node: Type["BaseNode"],
|
110
|
+
dependency: Type["BaseNode"],
|
111
|
+
dependencies: Set["Type[BaseNode]"],
|
112
|
+
) -> None:
|
113
|
+
self._dependencies_invoked[execution_id].add(dependency)
|
114
|
+
if all(dep in self._dependencies_invoked[execution_id] for dep in dependencies):
|
115
|
+
self._node_executions_queued[node].remove(execution_id)
|
116
|
+
|
117
|
+
def queue_node_execution(
|
118
|
+
self, node: Type["BaseNode"], dependencies: Set["Type[BaseNode]"], invoked_by: Optional[Edge] = None
|
119
|
+
) -> UUID:
|
120
|
+
execution_id = uuid4()
|
121
|
+
if not invoked_by:
|
122
|
+
return execution_id
|
123
|
+
|
124
|
+
source_node = invoked_by.from_port.node_class
|
125
|
+
for queued_node_execution_id in self._node_executions_queued[node]:
|
126
|
+
if source_node not in self._dependencies_invoked[queued_node_execution_id]:
|
127
|
+
self._invoke_dependency(queued_node_execution_id, node, source_node, dependencies)
|
128
|
+
return queued_node_execution_id
|
129
|
+
|
130
|
+
self._node_executions_queued[node].append(execution_id)
|
131
|
+
self._invoke_dependency(execution_id, node, source_node, dependencies)
|
132
|
+
return execution_id
|
133
|
+
|
134
|
+
def is_node_execution_initiated(self, node: Type["BaseNode"], execution_id: UUID) -> bool:
|
135
|
+
return execution_id in self._node_executions_initiated[node]
|
105
136
|
|
106
137
|
def initiate_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
|
107
138
|
self._node_executions_initiated[node].add(execution_id)
|
108
139
|
|
109
140
|
def fulfill_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
|
110
|
-
self.
|
111
|
-
self._node_execution_ids[node].push(execution_id)
|
141
|
+
self._node_executions_fulfilled[node].push(execution_id)
|
112
142
|
|
113
143
|
def get_execution_count(self, node: Type["BaseNode"]) -> int:
|
114
|
-
return self.
|
144
|
+
return self._node_executions_fulfilled[node].size()
|
115
145
|
|
116
146
|
def dump(self) -> Dict[str, Any]:
|
117
147
|
return {
|
118
148
|
"dependencies_invoked": {
|
119
|
-
|
149
|
+
str(execution_id): [str(dep) for dep in dependencies]
|
150
|
+
for execution_id, dependencies in self._dependencies_invoked.items()
|
120
151
|
},
|
121
152
|
"node_executions_initiated": {
|
122
153
|
str(node): list(execution_ids) for node, execution_ids in self._node_executions_initiated.items()
|
123
154
|
},
|
124
|
-
"
|
125
|
-
str(node): execution_ids.dump() for node, execution_ids in self.
|
155
|
+
"node_executions_fulfilled": {
|
156
|
+
str(node): execution_ids.dump() for node, execution_ids in self._node_executions_fulfilled.items()
|
157
|
+
},
|
158
|
+
"node_executions_queued": {
|
159
|
+
str(node): execution_ids for node, execution_ids in self._node_executions_queued.items()
|
126
160
|
},
|
127
161
|
}
|
128
162
|
|