vellum-ai 0.13.19__py3-none-any.whl → 0.13.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/__init__.py +2 -0
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/resources/test_suites/client.py +44 -8
- vellum/client/types/__init__.py +2 -0
- vellum/client/types/container_image_container_image_tag.py +21 -0
- vellum/client/types/container_image_read.py +2 -1
- vellum/types/container_image_container_image_tag.py +3 -0
- vellum/workflows/events/workflow.py +7 -7
- vellum/workflows/graph/graph.py +6 -0
- vellum/workflows/graph/tests/test_graph.py +24 -0
- vellum/workflows/nodes/bases/base.py +3 -2
- vellum/workflows/nodes/core/inline_subworkflow_node/node.py +52 -7
- vellum/workflows/runner/runner.py +3 -3
- vellum/workflows/types/generics.py +1 -1
- vellum/workflows/workflows/base.py +11 -11
- {vellum_ai-0.13.19.dist-info → vellum_ai-0.13.21.dist-info}/METADATA +1 -1
- {vellum_ai-0.13.19.dist-info → vellum_ai-0.13.21.dist-info}/RECORD +41 -38
- vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +2 -3
- vellum_ee/workflows/display/nodes/vellum/utils.py +1 -1
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +9 -30
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +22 -36
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +35 -70
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_error_node_serialization.py +2 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +2 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +2 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +24 -57
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +4 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +2 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +2 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +2 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +4 -4
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +2 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +2 -2
- vellum_ee/workflows/display/utils/vellum.py +7 -2
- vellum_ee/workflows/display/vellum.py +0 -2
- vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +65 -0
- vellum_ee/workflows/display/workflows/vellum_workflow_display.py +11 -15
- vellum_ee/workflows/tests/local_workflow/display/workflow.py +0 -2
- {vellum_ai-0.13.19.dist-info → vellum_ai-0.13.21.dist-info}/LICENSE +0 -0
- {vellum_ai-0.13.19.dist-info → vellum_ai-0.13.21.dist-info}/WHEEL +0 -0
- {vellum_ai-0.13.19.dist-info → vellum_ai-0.13.21.dist-info}/entry_points.txt +0 -0
vellum/__init__.py
CHANGED
@@ -64,6 +64,7 @@ from .types import (
|
|
64
64
|
ConditionCombinator,
|
65
65
|
ConditionalNodeResult,
|
66
66
|
ConditionalNodeResultData,
|
67
|
+
ContainerImageContainerImageTag,
|
67
68
|
ContainerImageRead,
|
68
69
|
CreateTestSuiteTestCaseRequest,
|
69
70
|
DeploymentHistoryItem,
|
@@ -586,6 +587,7 @@ __all__ = [
|
|
586
587
|
"ConditionCombinator",
|
587
588
|
"ConditionalNodeResult",
|
588
589
|
"ConditionalNodeResultData",
|
590
|
+
"ContainerImageContainerImageTag",
|
589
591
|
"ContainerImageRead",
|
590
592
|
"CreateTestSuiteTestCaseRequest",
|
591
593
|
"DeploymentHistoryItem",
|
@@ -18,7 +18,7 @@ class BaseClientWrapper:
|
|
18
18
|
headers: typing.Dict[str, str] = {
|
19
19
|
"X-Fern-Language": "Python",
|
20
20
|
"X-Fern-SDK-Name": "vellum-ai",
|
21
|
-
"X-Fern-SDK-Version": "0.13.
|
21
|
+
"X-Fern-SDK-Version": "0.13.21",
|
22
22
|
}
|
23
23
|
headers["X_API_KEY"] = self.api_key
|
24
24
|
return headers
|
@@ -139,21 +139,39 @@ class TestSuitesClient:
|
|
139
139
|
|
140
140
|
Examples
|
141
141
|
--------
|
142
|
-
from vellum import
|
142
|
+
from vellum import (
|
143
|
+
NamedTestCaseArrayVariableValueRequest,
|
144
|
+
NamedTestCaseStringVariableValueRequest,
|
145
|
+
StringVellumValueRequest,
|
146
|
+
Vellum,
|
147
|
+
)
|
143
148
|
|
144
149
|
client = Vellum(
|
145
150
|
api_key="YOUR_API_KEY",
|
146
151
|
)
|
147
152
|
client.test_suites.upsert_test_suite_test_case(
|
148
153
|
id_="id",
|
154
|
+
label="Test Case 1",
|
149
155
|
input_values=[
|
150
156
|
NamedTestCaseStringVariableValueRequest(
|
151
|
-
|
157
|
+
value="What are your favorite colors?",
|
158
|
+
name="var_1",
|
152
159
|
)
|
153
160
|
],
|
154
161
|
evaluation_values=[
|
155
|
-
|
156
|
-
|
162
|
+
NamedTestCaseArrayVariableValueRequest(
|
163
|
+
value=[
|
164
|
+
StringVellumValueRequest(
|
165
|
+
value="Red",
|
166
|
+
),
|
167
|
+
StringVellumValueRequest(
|
168
|
+
value="Green",
|
169
|
+
),
|
170
|
+
StringVellumValueRequest(
|
171
|
+
value="Blue",
|
172
|
+
),
|
173
|
+
],
|
174
|
+
name="var_2",
|
157
175
|
)
|
158
176
|
],
|
159
177
|
)
|
@@ -463,7 +481,12 @@ class AsyncTestSuitesClient:
|
|
463
481
|
--------
|
464
482
|
import asyncio
|
465
483
|
|
466
|
-
from vellum import
|
484
|
+
from vellum import (
|
485
|
+
AsyncVellum,
|
486
|
+
NamedTestCaseArrayVariableValueRequest,
|
487
|
+
NamedTestCaseStringVariableValueRequest,
|
488
|
+
StringVellumValueRequest,
|
489
|
+
)
|
467
490
|
|
468
491
|
client = AsyncVellum(
|
469
492
|
api_key="YOUR_API_KEY",
|
@@ -473,14 +496,27 @@ class AsyncTestSuitesClient:
|
|
473
496
|
async def main() -> None:
|
474
497
|
await client.test_suites.upsert_test_suite_test_case(
|
475
498
|
id_="id",
|
499
|
+
label="Test Case 1",
|
476
500
|
input_values=[
|
477
501
|
NamedTestCaseStringVariableValueRequest(
|
478
|
-
|
502
|
+
value="What are your favorite colors?",
|
503
|
+
name="var_1",
|
479
504
|
)
|
480
505
|
],
|
481
506
|
evaluation_values=[
|
482
|
-
|
483
|
-
|
507
|
+
NamedTestCaseArrayVariableValueRequest(
|
508
|
+
value=[
|
509
|
+
StringVellumValueRequest(
|
510
|
+
value="Red",
|
511
|
+
),
|
512
|
+
StringVellumValueRequest(
|
513
|
+
value="Green",
|
514
|
+
),
|
515
|
+
StringVellumValueRequest(
|
516
|
+
value="Blue",
|
517
|
+
),
|
518
|
+
],
|
519
|
+
name="var_2",
|
484
520
|
)
|
485
521
|
],
|
486
522
|
)
|
vellum/client/types/__init__.py
CHANGED
@@ -68,6 +68,7 @@ from .components_schemas_pdf_search_result_meta_source_request import Components
|
|
68
68
|
from .condition_combinator import ConditionCombinator
|
69
69
|
from .conditional_node_result import ConditionalNodeResult
|
70
70
|
from .conditional_node_result_data import ConditionalNodeResultData
|
71
|
+
from .container_image_container_image_tag import ContainerImageContainerImageTag
|
71
72
|
from .container_image_read import ContainerImageRead
|
72
73
|
from .create_test_suite_test_case_request import CreateTestSuiteTestCaseRequest
|
73
74
|
from .deployment_history_item import DeploymentHistoryItem
|
@@ -577,6 +578,7 @@ __all__ = [
|
|
577
578
|
"ConditionCombinator",
|
578
579
|
"ConditionalNodeResult",
|
579
580
|
"ConditionalNodeResultData",
|
581
|
+
"ContainerImageContainerImageTag",
|
580
582
|
"ContainerImageRead",
|
581
583
|
"CreateTestSuiteTestCaseRequest",
|
582
584
|
"DeploymentHistoryItem",
|
@@ -0,0 +1,21 @@
|
|
1
|
+
# This file was auto-generated by Fern from our API Definition.
|
2
|
+
|
3
|
+
from ..core.pydantic_utilities import UniversalBaseModel
|
4
|
+
import datetime as dt
|
5
|
+
from ..core.pydantic_utilities import IS_PYDANTIC_V2
|
6
|
+
import typing
|
7
|
+
import pydantic
|
8
|
+
|
9
|
+
|
10
|
+
class ContainerImageContainerImageTag(UniversalBaseModel):
|
11
|
+
name: str
|
12
|
+
modified: dt.datetime
|
13
|
+
|
14
|
+
if IS_PYDANTIC_V2:
|
15
|
+
model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
|
16
|
+
else:
|
17
|
+
|
18
|
+
class Config:
|
19
|
+
frozen = True
|
20
|
+
smart_union = True
|
21
|
+
extra = pydantic.Extra.allow
|
@@ -4,6 +4,7 @@ from ..core.pydantic_utilities import UniversalBaseModel
|
|
4
4
|
from .entity_visibility import EntityVisibility
|
5
5
|
import datetime as dt
|
6
6
|
import typing
|
7
|
+
from .container_image_container_image_tag import ContainerImageContainerImageTag
|
7
8
|
from ..core.pydantic_utilities import IS_PYDANTIC_V2
|
8
9
|
import pydantic
|
9
10
|
|
@@ -16,7 +17,7 @@ class ContainerImageRead(UniversalBaseModel):
|
|
16
17
|
modified: dt.datetime
|
17
18
|
repository: str
|
18
19
|
sha: str
|
19
|
-
tags: typing.List[
|
20
|
+
tags: typing.List[ContainerImageContainerImageTag]
|
20
21
|
|
21
22
|
if IS_PYDANTIC_V2:
|
22
23
|
model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
|
@@ -6,7 +6,7 @@ from vellum.core.pydantic_utilities import UniversalBaseModel
|
|
6
6
|
from vellum.workflows.errors import WorkflowError
|
7
7
|
from vellum.workflows.outputs.base import BaseOutput
|
8
8
|
from vellum.workflows.references import ExternalInputReference
|
9
|
-
from vellum.workflows.types.generics import OutputsType, StateType
|
9
|
+
from vellum.workflows.types.generics import InputsType, OutputsType, StateType
|
10
10
|
|
11
11
|
from .node import (
|
12
12
|
NodeExecutionFulfilledEvent,
|
@@ -38,20 +38,20 @@ class _BaseWorkflowEvent(BaseEvent):
|
|
38
38
|
return self.body.workflow_definition
|
39
39
|
|
40
40
|
|
41
|
-
class WorkflowExecutionInitiatedBody(_BaseWorkflowExecutionBody, Generic[
|
42
|
-
inputs:
|
41
|
+
class WorkflowExecutionInitiatedBody(_BaseWorkflowExecutionBody, Generic[InputsType]):
|
42
|
+
inputs: InputsType
|
43
43
|
|
44
44
|
@field_serializer("inputs")
|
45
|
-
def serialize_inputs(self, inputs:
|
45
|
+
def serialize_inputs(self, inputs: InputsType, _info: Any) -> Dict[str, Any]:
|
46
46
|
return default_serializer(inputs)
|
47
47
|
|
48
48
|
|
49
|
-
class WorkflowExecutionInitiatedEvent(_BaseWorkflowEvent, Generic[
|
49
|
+
class WorkflowExecutionInitiatedEvent(_BaseWorkflowEvent, Generic[InputsType]):
|
50
50
|
name: Literal["workflow.execution.initiated"] = "workflow.execution.initiated"
|
51
|
-
body: WorkflowExecutionInitiatedBody[
|
51
|
+
body: WorkflowExecutionInitiatedBody[InputsType]
|
52
52
|
|
53
53
|
@property
|
54
|
-
def inputs(self) ->
|
54
|
+
def inputs(self) -> InputsType:
|
55
55
|
return self.body.inputs
|
56
56
|
|
57
57
|
|
vellum/workflows/graph/graph.py
CHANGED
@@ -118,6 +118,12 @@ class Graph:
|
|
118
118
|
self._terminals = {other}
|
119
119
|
return self
|
120
120
|
|
121
|
+
def __rrshift__(cls, other_cls: GraphTarget) -> "Graph":
|
122
|
+
if not isinstance(other_cls, set):
|
123
|
+
other_cls = {other_cls}
|
124
|
+
|
125
|
+
return Graph.from_set(other_cls) >> cls
|
126
|
+
|
121
127
|
@property
|
122
128
|
def entrypoints(self) -> Iterator[Type["BaseNode"]]:
|
123
129
|
return iter(e.node_class for e in self._entrypoints)
|
@@ -460,3 +460,27 @@ def test_graph__node_to_port():
|
|
460
460
|
|
461
461
|
# AND two edges
|
462
462
|
assert len(list(graph.edges)) == 2
|
463
|
+
|
464
|
+
|
465
|
+
def test_graph__set_to_graph():
|
466
|
+
# GIVEN three nodes
|
467
|
+
class SourceNode(BaseNode):
|
468
|
+
pass
|
469
|
+
|
470
|
+
class MiddleNode(BaseNode):
|
471
|
+
pass
|
472
|
+
|
473
|
+
class TargetNode(BaseNode):
|
474
|
+
pass
|
475
|
+
|
476
|
+
# WHEN we create a graph from a set to a graph
|
477
|
+
graph: Graph = {SourceNode, MiddleNode} >> Graph.from_node(TargetNode)
|
478
|
+
|
479
|
+
# THEN the graph has the source node and middle node as the entrypoints
|
480
|
+
assert set(graph.entrypoints) == {SourceNode, MiddleNode}
|
481
|
+
|
482
|
+
# AND three nodes
|
483
|
+
assert len(list(graph.nodes)) == 3
|
484
|
+
|
485
|
+
# AND two edges
|
486
|
+
assert len(list(graph.edges)) == 2
|
@@ -1,3 +1,4 @@
|
|
1
|
+
from dataclasses import field
|
1
2
|
from functools import cached_property, reduce
|
2
3
|
import inspect
|
3
4
|
from types import MappingProxyType
|
@@ -192,7 +193,7 @@ class _BaseNodeTriggerMeta(type):
|
|
192
193
|
|
193
194
|
class _BaseNodeExecutionMeta(type):
|
194
195
|
def __getattribute__(cls, name: str) -> Any:
|
195
|
-
if name
|
196
|
+
if name == "count" and issubclass(cls, BaseNode.Execution):
|
196
197
|
return ExecutionCountReference(cls.node_class)
|
197
198
|
|
198
199
|
return super().__getattribute__(name)
|
@@ -230,7 +231,7 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
|
|
230
231
|
# "Outputs" class inherits from "BaseOutputs" and do so automatically.
|
231
232
|
# https://app.shortcut.com/vellum/story/4008/auto-inherit-basenodeoutputs-in-outputs-classes
|
232
233
|
class Outputs(BaseOutputs):
|
233
|
-
_node_class:
|
234
|
+
_node_class: Type["BaseNode"] = field(init=False)
|
234
235
|
|
235
236
|
class Ports(NodePorts):
|
236
237
|
default = Port(default=True)
|
@@ -1,16 +1,17 @@
|
|
1
|
-
from typing import TYPE_CHECKING, ClassVar, Generic, Iterator, Optional, Set, Type, TypeVar, Union
|
1
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar, Union
|
2
2
|
|
3
3
|
from vellum.workflows.constants import UNDEF
|
4
4
|
from vellum.workflows.context import execution_context, get_parent_context
|
5
5
|
from vellum.workflows.errors.types import WorkflowErrorCode
|
6
6
|
from vellum.workflows.exceptions import NodeException
|
7
7
|
from vellum.workflows.inputs.base import BaseInputs
|
8
|
-
from vellum.workflows.nodes.bases.base import BaseNode
|
8
|
+
from vellum.workflows.nodes.bases.base import BaseNode, BaseNodeMeta
|
9
9
|
from vellum.workflows.outputs.base import BaseOutput, BaseOutputs
|
10
|
+
from vellum.workflows.references import OutputReference
|
10
11
|
from vellum.workflows.state.base import BaseState
|
11
12
|
from vellum.workflows.state.context import WorkflowContext
|
12
13
|
from vellum.workflows.types.core import EntityInputsInterface
|
13
|
-
from vellum.workflows.types.generics import
|
14
|
+
from vellum.workflows.types.generics import InputsType, StateType
|
14
15
|
from vellum.workflows.workflows.event_filters import all_workflow_event_filter
|
15
16
|
|
16
17
|
if TYPE_CHECKING:
|
@@ -19,15 +20,53 @@ if TYPE_CHECKING:
|
|
19
20
|
InnerStateType = TypeVar("InnerStateType", bound=BaseState)
|
20
21
|
|
21
22
|
|
22
|
-
class
|
23
|
+
class _InlineSubworkflowNodeMeta(BaseNodeMeta):
|
24
|
+
def __new__(cls, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any]) -> Any:
|
25
|
+
node_class = super().__new__(cls, name, bases, dct)
|
26
|
+
|
27
|
+
subworkflow_attribute = dct.get("subworkflow")
|
28
|
+
if not subworkflow_attribute:
|
29
|
+
return node_class
|
30
|
+
|
31
|
+
if not issubclass(node_class, InlineSubworkflowNode):
|
32
|
+
raise ValueError("_InlineSubworkflowNodeMeta can only be used on subclasses of InlineSubworkflowNode")
|
33
|
+
|
34
|
+
subworkflow_outputs = getattr(subworkflow_attribute, "Outputs")
|
35
|
+
if not issubclass(subworkflow_outputs, BaseOutputs):
|
36
|
+
raise ValueError("subworkflow.Outputs must be a subclass of BaseOutputs")
|
37
|
+
|
38
|
+
outputs_class = dct.get("Outputs")
|
39
|
+
if not outputs_class:
|
40
|
+
raise ValueError("Outputs class not found in base classes")
|
41
|
+
|
42
|
+
if not issubclass(outputs_class, BaseNode.Outputs):
|
43
|
+
raise ValueError("Outputs class must be a subclass of BaseNode.Outputs")
|
44
|
+
|
45
|
+
for descriptor in subworkflow_outputs:
|
46
|
+
node_class.__annotate_outputs_class__(outputs_class, descriptor)
|
47
|
+
|
48
|
+
return node_class
|
49
|
+
|
50
|
+
def __getattribute__(cls, name: str) -> Any:
|
51
|
+
try:
|
52
|
+
return super().__getattribute__(name)
|
53
|
+
except AttributeError:
|
54
|
+
if name != "__wrapped_node__" and issubclass(cls, InlineSubworkflowNode):
|
55
|
+
return getattr(cls.__wrapped_node__, name)
|
56
|
+
raise
|
57
|
+
|
58
|
+
|
59
|
+
class InlineSubworkflowNode(
|
60
|
+
BaseNode[StateType], Generic[StateType, InputsType, InnerStateType], metaclass=_InlineSubworkflowNodeMeta
|
61
|
+
):
|
23
62
|
"""
|
24
63
|
Used to execute a Subworkflow defined inline.
|
25
64
|
|
26
|
-
subworkflow: Type["BaseWorkflow[
|
65
|
+
subworkflow: Type["BaseWorkflow[InputsType, InnerStateType]"] - The Subworkflow to execute
|
27
66
|
subworkflow_inputs: ClassVar[EntityInputsInterface] = {}
|
28
67
|
"""
|
29
68
|
|
30
|
-
subworkflow: Type["BaseWorkflow[
|
69
|
+
subworkflow: Type["BaseWorkflow[InputsType, InnerStateType]"]
|
31
70
|
subworkflow_inputs: ClassVar[Union[EntityInputsInterface, BaseInputs, Type[UNDEF]]] = UNDEF
|
32
71
|
|
33
72
|
def run(self) -> Iterator[BaseOutput]:
|
@@ -70,7 +109,7 @@ class InlineSubworkflowNode(BaseNode[StateType], Generic[StateType, WorkflowInpu
|
|
70
109
|
value=output_value,
|
71
110
|
)
|
72
111
|
|
73
|
-
def _compile_subworkflow_inputs(self) ->
|
112
|
+
def _compile_subworkflow_inputs(self) -> InputsType:
|
74
113
|
inputs_class = self.subworkflow.get_inputs_class()
|
75
114
|
if self.subworkflow_inputs is UNDEF:
|
76
115
|
inputs_dict = {}
|
@@ -85,3 +124,9 @@ class InlineSubworkflowNode(BaseNode[StateType], Generic[StateType, WorkflowInpu
|
|
85
124
|
return self.subworkflow_inputs
|
86
125
|
else:
|
87
126
|
raise ValueError(f"Invalid subworkflow inputs type: {type(self.subworkflow_inputs)}")
|
127
|
+
|
128
|
+
@classmethod
|
129
|
+
def __annotate_outputs_class__(cls, outputs_class: Type[BaseOutputs], reference: OutputReference) -> None:
|
130
|
+
# Subclasses of InlineSubworkflowNode can override this method to provider their own
|
131
|
+
# approach to annotating the outputs class based on the `subworkflow.Outputs`
|
132
|
+
setattr(outputs_class, reference.name, reference)
|
@@ -65,7 +65,7 @@ from vellum.workflows.ports.port import Port
|
|
65
65
|
from vellum.workflows.references import ExternalInputReference, OutputReference
|
66
66
|
from vellum.workflows.state.base import BaseState
|
67
67
|
from vellum.workflows.types.cycle_map import CycleMap
|
68
|
-
from vellum.workflows.types.generics import OutputsType, StateType
|
68
|
+
from vellum.workflows.types.generics import InputsType, OutputsType, StateType
|
69
69
|
|
70
70
|
if TYPE_CHECKING:
|
71
71
|
from vellum.workflows import BaseWorkflow
|
@@ -82,8 +82,8 @@ class WorkflowRunner(Generic[StateType]):
|
|
82
82
|
|
83
83
|
def __init__(
|
84
84
|
self,
|
85
|
-
workflow: "BaseWorkflow[
|
86
|
-
inputs: Optional[
|
85
|
+
workflow: "BaseWorkflow[InputsType, StateType]",
|
86
|
+
inputs: Optional[InputsType] = None,
|
87
87
|
state: Optional[StateType] = None,
|
88
88
|
entrypoint_nodes: Optional[RunFromNodeArg] = None,
|
89
89
|
external_inputs: Optional[ExternalInputsArg] = None,
|
@@ -10,5 +10,5 @@ if TYPE_CHECKING:
|
|
10
10
|
NodeType = TypeVar("NodeType", bound="BaseNode")
|
11
11
|
StateType = TypeVar("StateType", bound="BaseState")
|
12
12
|
WorkflowType = TypeVar("WorkflowType", bound="BaseWorkflow")
|
13
|
-
|
13
|
+
InputsType = TypeVar("InputsType", bound="BaseInputs")
|
14
14
|
OutputsType = TypeVar("OutputsType", bound="BaseOutputs")
|
@@ -68,7 +68,7 @@ from vellum.workflows.runner.runner import ExternalInputsArg, RunFromNodeArg
|
|
68
68
|
from vellum.workflows.state.base import BaseState, StateMeta
|
69
69
|
from vellum.workflows.state.context import WorkflowContext
|
70
70
|
from vellum.workflows.state.store import Store
|
71
|
-
from vellum.workflows.types.generics import
|
71
|
+
from vellum.workflows.types.generics import InputsType, StateType
|
72
72
|
from vellum.workflows.types.utils import get_original_base
|
73
73
|
from vellum.workflows.utils.uuids import uuid4_from_hash
|
74
74
|
from vellum.workflows.workflows.event_filters import workflow_event_filter
|
@@ -88,7 +88,7 @@ class _BaseWorkflowMeta(type):
|
|
88
88
|
GraphAttribute = Union[Type[BaseNode], Graph, Set[Type[BaseNode]], Set[Graph]]
|
89
89
|
|
90
90
|
|
91
|
-
class BaseWorkflow(Generic[
|
91
|
+
class BaseWorkflow(Generic[InputsType, StateType], metaclass=_BaseWorkflowMeta):
|
92
92
|
__id__: UUID = uuid4_from_hash(__qualname__)
|
93
93
|
graph: ClassVar[GraphAttribute]
|
94
94
|
emitters: List[BaseWorkflowEmitter]
|
@@ -99,7 +99,7 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
99
99
|
|
100
100
|
WorkflowEvent = Union[ # type: ignore
|
101
101
|
GenericWorkflowEvent,
|
102
|
-
WorkflowExecutionInitiatedEvent[
|
102
|
+
WorkflowExecutionInitiatedEvent[InputsType], # type: ignore[valid-type]
|
103
103
|
WorkflowExecutionFulfilledEvent[Outputs],
|
104
104
|
WorkflowExecutionSnapshottedEvent[StateType], # type: ignore[valid-type]
|
105
105
|
]
|
@@ -181,7 +181,7 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
181
181
|
|
182
182
|
def run(
|
183
183
|
self,
|
184
|
-
inputs: Optional[
|
184
|
+
inputs: Optional[InputsType] = None,
|
185
185
|
*,
|
186
186
|
state: Optional[StateType] = None,
|
187
187
|
entrypoint_nodes: Optional[RunFromNodeArg] = None,
|
@@ -198,7 +198,7 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
198
198
|
|
199
199
|
Parameters
|
200
200
|
----------
|
201
|
-
inputs: Optional[
|
201
|
+
inputs: Optional[InputsType] = None
|
202
202
|
The Inputs instance used to initiate the Workflow Execution.
|
203
203
|
|
204
204
|
state: Optional[StateType] = None
|
@@ -288,7 +288,7 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
288
288
|
|
289
289
|
def stream(
|
290
290
|
self,
|
291
|
-
inputs: Optional[
|
291
|
+
inputs: Optional[InputsType] = None,
|
292
292
|
*,
|
293
293
|
event_filter: Optional[Callable[[Type["BaseWorkflow"], WorkflowEvent], bool]] = None,
|
294
294
|
state: Optional[StateType] = None,
|
@@ -307,7 +307,7 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
307
307
|
A filter that can be used to filter events based on the Workflow Class and the event itself. If the method
|
308
308
|
returns `False`, the event will not be yielded.
|
309
309
|
|
310
|
-
inputs: Optional[
|
310
|
+
inputs: Optional[InputsType] = None
|
311
311
|
The Inputs instance used to initiate the Workflow Execution.
|
312
312
|
|
313
313
|
state: Optional[StateType] = None
|
@@ -359,7 +359,7 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
359
359
|
@lru_cache
|
360
360
|
def _get_parameterized_classes(
|
361
361
|
cls,
|
362
|
-
) -> Tuple[Type[
|
362
|
+
) -> Tuple[Type[InputsType], Type[StateType]]:
|
363
363
|
original_base = get_original_base(cls)
|
364
364
|
|
365
365
|
inputs_type, state_type = get_args(original_base)
|
@@ -378,17 +378,17 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
378
378
|
return (inputs_type, state_type)
|
379
379
|
|
380
380
|
@classmethod
|
381
|
-
def get_inputs_class(cls) -> Type[
|
381
|
+
def get_inputs_class(cls) -> Type[InputsType]:
|
382
382
|
return cls._get_parameterized_classes()[0]
|
383
383
|
|
384
384
|
@classmethod
|
385
385
|
def get_state_class(cls) -> Type[StateType]:
|
386
386
|
return cls._get_parameterized_classes()[1]
|
387
387
|
|
388
|
-
def get_default_inputs(self) ->
|
388
|
+
def get_default_inputs(self) -> InputsType:
|
389
389
|
return self.get_inputs_class()()
|
390
390
|
|
391
|
-
def get_default_state(self, workflow_inputs: Optional[
|
391
|
+
def get_default_state(self, workflow_inputs: Optional[InputsType] = None) -> StateType:
|
392
392
|
return self.get_state_class()(
|
393
393
|
meta=StateMeta(
|
394
394
|
parent=self._parent_state,
|