vellum-ai 1.4.2__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +14 -0
- vellum/client/__init__.py +3 -0
- vellum/client/core/client_wrapper.py +2 -2
- vellum/client/reference.md +160 -0
- vellum/client/resources/__init__.py +2 -0
- vellum/client/resources/integrations/__init__.py +4 -0
- vellum/client/resources/integrations/client.py +260 -0
- vellum/client/resources/integrations/raw_client.py +267 -0
- vellum/client/types/__init__.py +12 -0
- vellum/client/types/components_schemas_composio_execute_tool_request.py +5 -0
- vellum/client/types/components_schemas_composio_execute_tool_response.py +5 -0
- vellum/client/types/components_schemas_composio_tool_definition.py +5 -0
- vellum/client/types/composio_execute_tool_request.py +24 -0
- vellum/client/types/composio_execute_tool_response.py +24 -0
- vellum/client/types/composio_tool_definition.py +26 -0
- vellum/client/types/vellum_error_code_enum.py +2 -0
- vellum/client/types/vellum_sdk_error.py +1 -0
- vellum/client/types/workflow_event_error.py +1 -0
- vellum/resources/integrations/__init__.py +3 -0
- vellum/resources/integrations/client.py +3 -0
- vellum/resources/integrations/raw_client.py +3 -0
- vellum/types/components_schemas_composio_execute_tool_request.py +3 -0
- vellum/types/components_schemas_composio_execute_tool_response.py +3 -0
- vellum/types/components_schemas_composio_tool_definition.py +3 -0
- vellum/types/composio_execute_tool_request.py +3 -0
- vellum/types/composio_execute_tool_response.py +3 -0
- vellum/types/composio_tool_definition.py +3 -0
- vellum/workflows/runner/runner.py +132 -110
- vellum/workflows/utils/functions.py +6 -1
- vellum/workflows/utils/tests/test_functions.py +40 -0
- vellum/workflows/workflows/base.py +19 -5
- vellum/workflows/workflows/tests/test_base_workflow.py +54 -0
- {vellum_ai-1.4.2.dist-info → vellum_ai-1.5.0.dist-info}/METADATA +1 -1
- {vellum_ai-1.4.2.dist-info → vellum_ai-1.5.0.dist-info}/RECORD +39 -19
- vellum_ai-1.5.0.dist-info/entry_points.txt +4 -0
- vellum_ee/assets/node-definitions.json +483 -0
- vellum_ee/scripts/generate_node_definitions.py +89 -0
- vellum_ai-1.4.2.dist-info/entry_points.txt +0 -3
- {vellum_ai-1.4.2.dist-info → vellum_ai-1.5.0.dist-info}/LICENSE +0 -0
- {vellum_ai-1.4.2.dist-info → vellum_ai-1.5.0.dist-info}/WHEEL +0 -0
@@ -42,6 +42,7 @@ from vellum.workflows.events import (
|
|
42
42
|
WorkflowExecutionStreamingEvent,
|
43
43
|
)
|
44
44
|
from vellum.workflows.events.node import (
|
45
|
+
NodeEvent,
|
45
46
|
NodeExecutionFulfilledBody,
|
46
47
|
NodeExecutionInitiatedBody,
|
47
48
|
NodeExecutionRejectedBody,
|
@@ -212,6 +213,10 @@ class WorkflowRunner(Generic[StateType]):
|
|
212
213
|
descriptor for descriptor in self.workflow.Outputs if isinstance(descriptor.instance, StateValueReference)
|
213
214
|
]
|
214
215
|
|
216
|
+
self._background_thread: Optional[Thread] = None
|
217
|
+
self._cancel_thread: Optional[Thread] = None
|
218
|
+
self._stream_thread: Optional[Thread] = None
|
219
|
+
|
215
220
|
def _snapshot_state(self, state: StateType, deltas: List[StateDelta]) -> StateType:
|
216
221
|
self._workflow_event_inner_queue.put(
|
217
222
|
WorkflowExecutionSnapshottedEvent(
|
@@ -259,17 +264,36 @@ class WorkflowRunner(Generic[StateType]):
|
|
259
264
|
return event
|
260
265
|
|
261
266
|
def _run_work_item(self, node: BaseNode[StateType], span_id: UUID) -> None:
|
267
|
+
for event in self.run_node(node, span_id):
|
268
|
+
self._workflow_event_inner_queue.put(event)
|
269
|
+
|
270
|
+
def run_node(
|
271
|
+
self,
|
272
|
+
node: "BaseNode[StateType]",
|
273
|
+
span_id: UUID,
|
274
|
+
) -> Generator[NodeEvent, None, None]:
|
275
|
+
"""
|
276
|
+
Execute a single node and yield workflow events.
|
277
|
+
|
278
|
+
Args:
|
279
|
+
node: The node instance to execute
|
280
|
+
span_id: Unique identifier for this node execution
|
281
|
+
|
282
|
+
Yields:
|
283
|
+
NodeExecutionEvent: Events emitted during node execution (initiated, streaming, fulfilled, rejected)
|
284
|
+
"""
|
262
285
|
execution = get_execution_context()
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
)
|
286
|
+
|
287
|
+
node_output_mocks_map = self.workflow.context.node_output_mocks_map
|
288
|
+
|
289
|
+
yield NodeExecutionInitiatedEvent(
|
290
|
+
trace_id=execution.trace_id,
|
291
|
+
span_id=span_id,
|
292
|
+
body=NodeExecutionInitiatedBody(
|
293
|
+
node_definition=node.__class__,
|
294
|
+
inputs=node._inputs,
|
295
|
+
),
|
296
|
+
parent=execution.parent_context,
|
273
297
|
)
|
274
298
|
|
275
299
|
logger.debug(f"Started running node: {node.__class__.__name__}")
|
@@ -282,7 +306,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
282
306
|
)
|
283
307
|
node_run_response: NodeRunResponse
|
284
308
|
was_mocked: Optional[bool] = None
|
285
|
-
mock_candidates =
|
309
|
+
mock_candidates = node_output_mocks_map.get(node.Outputs) or []
|
286
310
|
for mock_candidate in mock_candidates:
|
287
311
|
if mock_candidate.when_condition.resolve(node.state):
|
288
312
|
node_run_response = mock_candidate.then_outputs
|
@@ -312,8 +336,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
312
336
|
streaming_output_queues: Dict[str, Queue] = {}
|
313
337
|
outputs = node.Outputs()
|
314
338
|
|
315
|
-
def initiate_node_streaming_output(
|
316
|
-
|
339
|
+
def initiate_node_streaming_output(
|
340
|
+
output: BaseOutput,
|
341
|
+
) -> Generator[NodeExecutionStreamingEvent, None, None]:
|
317
342
|
streaming_output_queues[output.name] = Queue()
|
318
343
|
output_descriptor = OutputReference(
|
319
344
|
name=output.name,
|
@@ -325,57 +350,51 @@ class WorkflowRunner(Generic[StateType]):
|
|
325
350
|
node.state.meta.node_outputs[output_descriptor] = streaming_output_queues[output.name]
|
326
351
|
initiated_output: BaseOutput = BaseOutput(name=output.name)
|
327
352
|
initiated_ports = initiated_output > ports
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
invoked_ports=initiated_ports,
|
336
|
-
),
|
337
|
-
parent=execution.parent_context,
|
353
|
+
yield NodeExecutionStreamingEvent(
|
354
|
+
trace_id=execution.trace_id,
|
355
|
+
span_id=span_id,
|
356
|
+
body=NodeExecutionStreamingBody(
|
357
|
+
node_definition=node.__class__,
|
358
|
+
output=initiated_output,
|
359
|
+
invoked_ports=initiated_ports,
|
338
360
|
),
|
361
|
+
parent=execution.parent_context,
|
339
362
|
)
|
340
363
|
|
341
364
|
with execution_context(parent_context=updated_parent_context, trace_id=execution.trace_id):
|
342
365
|
for output in node_run_response:
|
343
366
|
invoked_ports = output > ports
|
344
367
|
if output.is_initiated:
|
345
|
-
initiate_node_streaming_output(output)
|
368
|
+
yield from initiate_node_streaming_output(output)
|
346
369
|
elif output.is_streaming:
|
347
370
|
if output.name not in streaming_output_queues:
|
348
|
-
initiate_node_streaming_output(output)
|
371
|
+
yield from initiate_node_streaming_output(output)
|
349
372
|
|
350
373
|
streaming_output_queues[output.name].put(output.delta)
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
invoked_ports=invoked_ports,
|
359
|
-
),
|
360
|
-
parent=execution.parent_context,
|
374
|
+
yield NodeExecutionStreamingEvent(
|
375
|
+
trace_id=execution.trace_id,
|
376
|
+
span_id=span_id,
|
377
|
+
body=NodeExecutionStreamingBody(
|
378
|
+
node_definition=node.__class__,
|
379
|
+
output=output,
|
380
|
+
invoked_ports=invoked_ports,
|
361
381
|
),
|
382
|
+
parent=execution.parent_context,
|
362
383
|
)
|
363
384
|
elif output.is_fulfilled:
|
364
385
|
if output.name in streaming_output_queues:
|
365
386
|
streaming_output_queues[output.name].put(undefined)
|
366
387
|
|
367
388
|
setattr(outputs, output.name, output.value)
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
parent=execution.parent_context,
|
378
|
-
)
|
389
|
+
yield NodeExecutionStreamingEvent(
|
390
|
+
trace_id=execution.trace_id,
|
391
|
+
span_id=span_id,
|
392
|
+
body=NodeExecutionStreamingBody(
|
393
|
+
node_definition=node.__class__,
|
394
|
+
output=output,
|
395
|
+
invoked_ports=invoked_ports,
|
396
|
+
),
|
397
|
+
parent=execution.parent_context,
|
379
398
|
)
|
380
399
|
|
381
400
|
node.state.meta.node_execution_cache.fulfill_node_execution(node.__class__, span_id)
|
@@ -390,66 +409,57 @@ class WorkflowRunner(Generic[StateType]):
|
|
390
409
|
node.state.meta.node_outputs[descriptor] = output_value
|
391
410
|
|
392
411
|
invoked_ports = ports(outputs, node.state)
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
parent=execution.parent_context,
|
404
|
-
)
|
412
|
+
yield NodeExecutionFulfilledEvent(
|
413
|
+
trace_id=execution.trace_id,
|
414
|
+
span_id=span_id,
|
415
|
+
body=NodeExecutionFulfilledBody(
|
416
|
+
node_definition=node.__class__,
|
417
|
+
outputs=outputs,
|
418
|
+
invoked_ports=invoked_ports,
|
419
|
+
mocked=was_mocked,
|
420
|
+
),
|
421
|
+
parent=execution.parent_context,
|
405
422
|
)
|
406
423
|
except NodeException as e:
|
407
424
|
logger.info(e)
|
408
425
|
captured_stacktrace = traceback.format_exc()
|
409
426
|
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
parent=execution.parent_context,
|
420
|
-
)
|
427
|
+
yield NodeExecutionRejectedEvent(
|
428
|
+
trace_id=execution.trace_id,
|
429
|
+
span_id=span_id,
|
430
|
+
body=NodeExecutionRejectedBody(
|
431
|
+
node_definition=node.__class__,
|
432
|
+
error=e.error,
|
433
|
+
stacktrace=captured_stacktrace,
|
434
|
+
),
|
435
|
+
parent=execution.parent_context,
|
421
436
|
)
|
422
437
|
except WorkflowInitializationException as e:
|
423
438
|
logger.info(e)
|
424
439
|
captured_stacktrace = traceback.format_exc()
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
parent=execution.parent_context,
|
435
|
-
)
|
440
|
+
yield NodeExecutionRejectedEvent(
|
441
|
+
trace_id=execution.trace_id,
|
442
|
+
span_id=span_id,
|
443
|
+
body=NodeExecutionRejectedBody(
|
444
|
+
node_definition=node.__class__,
|
445
|
+
error=e.error,
|
446
|
+
stacktrace=captured_stacktrace,
|
447
|
+
),
|
448
|
+
parent=execution.parent_context,
|
436
449
|
)
|
437
450
|
except InvalidExpressionException as e:
|
438
451
|
logger.info(e)
|
439
452
|
captured_stacktrace = traceback.format_exc()
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
parent=execution.parent_context,
|
450
|
-
)
|
453
|
+
yield NodeExecutionRejectedEvent(
|
454
|
+
trace_id=execution.trace_id,
|
455
|
+
span_id=span_id,
|
456
|
+
body=NodeExecutionRejectedBody(
|
457
|
+
node_definition=node.__class__,
|
458
|
+
error=e.error,
|
459
|
+
stacktrace=captured_stacktrace,
|
460
|
+
),
|
461
|
+
parent=execution.parent_context,
|
451
462
|
)
|
452
|
-
|
453
463
|
except Exception as e:
|
454
464
|
error_message = self._parse_error_message(e)
|
455
465
|
if error_message is None:
|
@@ -459,19 +469,17 @@ class WorkflowRunner(Generic[StateType]):
|
|
459
469
|
else:
|
460
470
|
error_code = WorkflowErrorCode.NODE_EXECUTION
|
461
471
|
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
code=error_code,
|
471
|
-
),
|
472
|
+
yield NodeExecutionRejectedEvent(
|
473
|
+
trace_id=execution.trace_id,
|
474
|
+
span_id=span_id,
|
475
|
+
body=NodeExecutionRejectedBody(
|
476
|
+
node_definition=node.__class__,
|
477
|
+
error=WorkflowError(
|
478
|
+
message=error_message,
|
479
|
+
code=error_code,
|
472
480
|
),
|
473
|
-
parent=execution.parent_context,
|
474
481
|
),
|
482
|
+
parent=execution.parent_context,
|
475
483
|
)
|
476
484
|
|
477
485
|
logger.debug(f"Finished running node: {node.__class__.__name__}")
|
@@ -911,20 +919,20 @@ class WorkflowRunner(Generic[StateType]):
|
|
911
919
|
return False
|
912
920
|
|
913
921
|
def _generate_events(self) -> Generator[WorkflowEvent, None, None]:
|
914
|
-
|
922
|
+
self._background_thread = Thread(
|
915
923
|
target=self._run_background_thread,
|
916
924
|
name=f"{self.workflow.__class__.__name__}.background_thread",
|
917
925
|
)
|
918
|
-
|
926
|
+
self._background_thread.start()
|
919
927
|
|
920
928
|
cancel_thread_kill_switch = ThreadingEvent()
|
921
929
|
if self._cancel_signal:
|
922
|
-
|
930
|
+
self._cancel_thread = Thread(
|
923
931
|
target=self._run_cancel_thread,
|
924
932
|
name=f"{self.workflow.__class__.__name__}.cancel_thread",
|
925
933
|
kwargs={"kill_switch": cancel_thread_kill_switch},
|
926
934
|
)
|
927
|
-
|
935
|
+
self._cancel_thread.start()
|
928
936
|
|
929
937
|
event: WorkflowEvent
|
930
938
|
if self._is_resuming:
|
@@ -935,13 +943,13 @@ class WorkflowRunner(Generic[StateType]):
|
|
935
943
|
yield self._emit_event(event)
|
936
944
|
|
937
945
|
# The extra level of indirection prevents the runner from waiting on the caller to consume the event stream
|
938
|
-
|
946
|
+
self._stream_thread = Thread(
|
939
947
|
target=self._stream,
|
940
948
|
name=f"{self.workflow.__class__.__name__}.stream_thread",
|
941
949
|
)
|
942
|
-
|
950
|
+
self._stream_thread.start()
|
943
951
|
|
944
|
-
while
|
952
|
+
while self._stream_thread.is_alive():
|
945
953
|
try:
|
946
954
|
event = self._workflow_event_outer_queue.get(timeout=0.1)
|
947
955
|
except Empty:
|
@@ -971,3 +979,17 @@ class WorkflowRunner(Generic[StateType]):
|
|
971
979
|
|
972
980
|
def stream(self) -> WorkflowEventStream:
|
973
981
|
return WorkflowEventGenerator(self._generate_events(), self._initial_state.meta.span_id)
|
982
|
+
|
983
|
+
def join(self) -> None:
|
984
|
+
"""
|
985
|
+
Wait for all background threads to complete.
|
986
|
+
This ensures all pending work is finished before the runner terminates.
|
987
|
+
"""
|
988
|
+
if self._stream_thread and self._stream_thread.is_alive():
|
989
|
+
self._stream_thread.join()
|
990
|
+
|
991
|
+
if self._background_thread and self._background_thread.is_alive():
|
992
|
+
self._background_thread.join()
|
993
|
+
|
994
|
+
if self._cancel_thread and self._cancel_thread.is_alive():
|
995
|
+
self._cancel_thread.join()
|
@@ -22,7 +22,7 @@ from vellum.workflows.utils.vellum_variables import vellum_variable_type_to_open
|
|
22
22
|
if TYPE_CHECKING:
|
23
23
|
from vellum.workflows.workflows.base import BaseWorkflow
|
24
24
|
|
25
|
-
type_map = {
|
25
|
+
type_map: dict[Any, str] = {
|
26
26
|
str: "string",
|
27
27
|
int: "integer",
|
28
28
|
float: "number",
|
@@ -32,8 +32,13 @@ type_map = {
|
|
32
32
|
None: "null",
|
33
33
|
type(None): "null",
|
34
34
|
inspect._empty: "null",
|
35
|
+
"None": "null",
|
35
36
|
}
|
36
37
|
|
38
|
+
for k, v in list(type_map.items()):
|
39
|
+
if isinstance(k, type):
|
40
|
+
type_map[k.__name__] = v
|
41
|
+
|
37
42
|
|
38
43
|
def compile_annotation(annotation: Optional[Any], defs: dict[str, Any]) -> dict:
|
39
44
|
if annotation is None:
|
@@ -733,3 +733,43 @@ def test_compile_function_definition__tuples(annotation, expected_schema):
|
|
733
733
|
compiled_function = compile_function_definition(my_function)
|
734
734
|
assert isinstance(compiled_function.parameters, dict)
|
735
735
|
assert compiled_function.parameters["properties"]["a"] == expected_schema
|
736
|
+
|
737
|
+
|
738
|
+
def test_compile_function_definition__string_annotations_with_future_imports():
|
739
|
+
"""Test that string annotations work with __future__ import annotations."""
|
740
|
+
# This simulates what happens when using `from __future__ import annotations`
|
741
|
+
# where type annotations become string literals at runtime
|
742
|
+
|
743
|
+
def my_function_with_string_annotations(
|
744
|
+
a: "str",
|
745
|
+
b: "int",
|
746
|
+
c: "float",
|
747
|
+
d: "bool",
|
748
|
+
e: "list",
|
749
|
+
f: "dict",
|
750
|
+
g: "None",
|
751
|
+
):
|
752
|
+
"""Function with string type annotations."""
|
753
|
+
pass
|
754
|
+
|
755
|
+
# WHEN compiling the function
|
756
|
+
compiled_function = compile_function_definition(my_function_with_string_annotations)
|
757
|
+
|
758
|
+
# THEN it should return the compiled function definition with proper types
|
759
|
+
assert compiled_function == FunctionDefinition(
|
760
|
+
name="my_function_with_string_annotations",
|
761
|
+
description="Function with string type annotations.",
|
762
|
+
parameters={
|
763
|
+
"type": "object",
|
764
|
+
"properties": {
|
765
|
+
"a": {"type": "string"},
|
766
|
+
"b": {"type": "integer"},
|
767
|
+
"c": {"type": "number"},
|
768
|
+
"d": {"type": "boolean"},
|
769
|
+
"e": {"type": "array"},
|
770
|
+
"f": {"type": "object"},
|
771
|
+
"g": {"type": "null"},
|
772
|
+
},
|
773
|
+
"required": ["a", "b", "c", "d", "e", "f", "g"],
|
774
|
+
},
|
775
|
+
)
|
@@ -31,6 +31,7 @@ from vellum.workflows.edges import Edge
|
|
31
31
|
from vellum.workflows.emitters.base import BaseWorkflowEmitter
|
32
32
|
from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
|
33
33
|
from vellum.workflows.events.node import (
|
34
|
+
NodeEvent,
|
34
35
|
NodeExecutionFulfilledBody,
|
35
36
|
NodeExecutionFulfilledEvent,
|
36
37
|
NodeExecutionInitiatedBody,
|
@@ -252,6 +253,7 @@ class BaseWorkflow(Generic[InputsType, StateType], BaseExecutable, metaclass=_Ba
|
|
252
253
|
self.resolvers = resolvers or (self.resolvers if hasattr(self, "resolvers") else [])
|
253
254
|
self._store = store or Store()
|
254
255
|
self._execution_context = self._context.execution_context
|
256
|
+
self._current_runner: Optional[WorkflowRunner] = None
|
255
257
|
|
256
258
|
# Register context with all emitters
|
257
259
|
for emitter in self.emitters:
|
@@ -412,7 +414,7 @@ class BaseWorkflow(Generic[InputsType, StateType], BaseExecutable, metaclass=_Ba
|
|
412
414
|
subworkflows or nodes that utilizes threads.
|
413
415
|
"""
|
414
416
|
|
415
|
-
|
417
|
+
runner = WorkflowRunner(
|
416
418
|
self,
|
417
419
|
inputs=inputs,
|
418
420
|
state=state,
|
@@ -423,7 +425,9 @@ class BaseWorkflow(Generic[InputsType, StateType], BaseExecutable, metaclass=_Ba
|
|
423
425
|
node_output_mocks=node_output_mocks,
|
424
426
|
max_concurrency=max_concurrency,
|
425
427
|
init_execution_context=self._execution_context,
|
426
|
-
)
|
428
|
+
)
|
429
|
+
self._current_runner = runner
|
430
|
+
events = runner.stream()
|
427
431
|
first_event: Optional[Union[WorkflowExecutionInitiatedEvent, WorkflowExecutionResumedEvent]] = None
|
428
432
|
last_event = None
|
429
433
|
for event in events:
|
@@ -531,7 +535,7 @@ class BaseWorkflow(Generic[InputsType, StateType], BaseExecutable, metaclass=_Ba
|
|
531
535
|
"""
|
532
536
|
|
533
537
|
should_yield = event_filter or workflow_event_filter
|
534
|
-
|
538
|
+
runner = WorkflowRunner(
|
535
539
|
self,
|
536
540
|
inputs=inputs,
|
537
541
|
state=state,
|
@@ -542,7 +546,9 @@ class BaseWorkflow(Generic[InputsType, StateType], BaseExecutable, metaclass=_Ba
|
|
542
546
|
node_output_mocks=node_output_mocks,
|
543
547
|
max_concurrency=max_concurrency,
|
544
548
|
init_execution_context=self._execution_context,
|
545
|
-
)
|
549
|
+
)
|
550
|
+
self._current_runner = runner
|
551
|
+
runner_stream = runner.stream()
|
546
552
|
|
547
553
|
def _generate_filtered_events() -> Generator[BaseWorkflow.WorkflowEvent, None, None]:
|
548
554
|
for event in runner_stream:
|
@@ -559,6 +565,11 @@ class BaseWorkflow(Generic[InputsType, StateType], BaseExecutable, metaclass=_Ba
|
|
559
565
|
# https://app.shortcut.com/vellum/story/4327
|
560
566
|
pass
|
561
567
|
|
568
|
+
def run_node(self, node: Type[BaseNode]) -> Generator[NodeEvent, None, None]:
|
569
|
+
runner = WorkflowRunner(self)
|
570
|
+
span_id = uuid4()
|
571
|
+
return runner.run_node(node=node(state=self.get_default_state(), context=self._context), span_id=span_id)
|
572
|
+
|
562
573
|
@classmethod
|
563
574
|
@lru_cache
|
564
575
|
def _get_parameterized_classes(
|
@@ -689,9 +700,12 @@ class BaseWorkflow(Generic[InputsType, StateType], BaseExecutable, metaclass=_Ba
|
|
689
700
|
|
690
701
|
def join(self) -> None:
|
691
702
|
"""
|
692
|
-
Wait for all emitters to complete their background work.
|
703
|
+
Wait for all emitters and runner to complete their background work.
|
693
704
|
This ensures all pending events are processed before the workflow terminates.
|
694
705
|
"""
|
706
|
+
if self._current_runner:
|
707
|
+
self._current_runner.join()
|
708
|
+
|
695
709
|
for emitter in self.emitters:
|
696
710
|
emitter.join()
|
697
711
|
|
@@ -3,6 +3,7 @@ import logging
|
|
3
3
|
from uuid import UUID, uuid4
|
4
4
|
|
5
5
|
from vellum.workflows.edges.edge import Edge
|
6
|
+
from vellum.workflows.events.node import NodeExecutionFulfilledEvent, NodeExecutionInitiatedEvent
|
6
7
|
from vellum.workflows.inputs.base import BaseInputs
|
7
8
|
from vellum.workflows.nodes.bases.base import BaseNode
|
8
9
|
from vellum.workflows.nodes.core.inline_subworkflow_node.node import InlineSubworkflowNode
|
@@ -684,3 +685,56 @@ def test_base_workflow__deserialize_state_with_invalid_workflow_definition(raw_w
|
|
684
685
|
|
685
686
|
# AND the workflow definition should be BaseWorkflow
|
686
687
|
assert state.meta.workflow_definition == BaseWorkflow
|
688
|
+
|
689
|
+
|
690
|
+
def test_base_workflow__join_calls_runner_join():
|
691
|
+
"""
|
692
|
+
Test that BaseWorkflow.join() calls runner.join() when runner exists.
|
693
|
+
"""
|
694
|
+
|
695
|
+
# GIVEN a test workflow
|
696
|
+
class TestWorkflow(BaseWorkflow[BaseInputs, BaseState]):
|
697
|
+
pass
|
698
|
+
|
699
|
+
workflow = TestWorkflow()
|
700
|
+
|
701
|
+
# WHEN we run the workflow to create a runner
|
702
|
+
workflow.run()
|
703
|
+
|
704
|
+
workflow.join()
|
705
|
+
|
706
|
+
# THEN the runner should have been joined (verified by no hanging threads)
|
707
|
+
assert workflow._current_runner is not None
|
708
|
+
|
709
|
+
|
710
|
+
def test_base_workflow__run_node_emits_correct_events():
|
711
|
+
"""Test that WorkflowRunner.run_node method emits the expected events."""
|
712
|
+
|
713
|
+
class TestInputs(BaseInputs):
|
714
|
+
pass
|
715
|
+
|
716
|
+
class TestState(BaseState):
|
717
|
+
pass
|
718
|
+
|
719
|
+
class TestNode(BaseNode[TestState]):
|
720
|
+
class Outputs(BaseNode.Outputs):
|
721
|
+
result: str
|
722
|
+
|
723
|
+
def run(self) -> "TestNode.Outputs":
|
724
|
+
return self.Outputs(result="test_output")
|
725
|
+
|
726
|
+
class TestWorkflow(BaseWorkflow[TestInputs, TestState]):
|
727
|
+
graph = TestNode
|
728
|
+
|
729
|
+
class Outputs(BaseWorkflow.Outputs):
|
730
|
+
result: str
|
731
|
+
|
732
|
+
workflow = TestWorkflow()
|
733
|
+
|
734
|
+
events = list(workflow.run_node(node=TestNode))
|
735
|
+
|
736
|
+
assert len(events) == 2
|
737
|
+
assert isinstance(events[0], NodeExecutionInitiatedEvent)
|
738
|
+
assert isinstance(events[1], NodeExecutionFulfilledEvent)
|
739
|
+
assert events[0].span_id == events[1].span_id
|
740
|
+
assert events[1].body.outputs.result == "test_output"
|