vellum-ai 0.13.0__py3-none-any.whl → 0.13.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/core/pydantic_utilities.py +5 -0
- vellum/client/resources/workflows/client.py +8 -0
- vellum/client/types/logical_operator.py +2 -0
- vellum/workflows/descriptors/base.py +1 -1
- vellum/workflows/descriptors/tests/test_utils.py +3 -0
- vellum/workflows/expressions/accessor.py +8 -2
- vellum/workflows/nodes/core/map_node/node.py +49 -24
- vellum/workflows/nodes/core/map_node/tests/test_node.py +4 -4
- vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +1 -1
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +5 -3
- vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +3 -0
- vellum/workflows/nodes/displayable/bases/search_node.py +37 -2
- vellum/workflows/nodes/displayable/bases/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/bases/tests/test_utils.py +61 -0
- vellum/workflows/nodes/displayable/bases/types.py +42 -0
- vellum/workflows/nodes/displayable/bases/utils.py +112 -0
- vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +0 -1
- vellum/workflows/nodes/displayable/search_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/search_node/tests/test_node.py +164 -0
- vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +2 -3
- vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +0 -1
- vellum/workflows/runner/runner.py +37 -4
- vellum/workflows/types/tests/test_utils.py +5 -2
- vellum/workflows/types/utils.py +4 -0
- vellum/workflows/workflows/base.py +14 -0
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.2.dist-info}/METADATA +1 -1
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.2.dist-info}/RECORD +53 -42
- vellum_cli/__init__.py +24 -0
- vellum_cli/ping.py +28 -0
- vellum_cli/push.py +62 -12
- vellum_cli/tests/test_ping.py +47 -0
- vellum_cli/tests/test_push.py +76 -0
- vellum_ee/workflows/display/nodes/vellum/base_node.py +59 -11
- vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +3 -0
- vellum_ee/workflows/display/nodes/vellum/map_node.py +1 -1
- vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +14 -10
- vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/utils.py +8 -1
- vellum_ee/workflows/display/tests/test_vellum_workflow_display.py +48 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +67 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +286 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_outputs_serialization.py +177 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_ports_serialization.py +666 -14
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_trigger_serialization.py +7 -8
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +35 -2
- vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +29 -2
- vellum_ee/workflows/display/utils/vellum.py +4 -42
- vellum_ee/workflows/display/vellum.py +7 -36
- vellum_ee/workflows/display/workflows/vellum_workflow_display.py +5 -2
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.2.dist-info}/LICENSE +0 -0
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.2.dist-info}/WHEEL +0 -0
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.2.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,112 @@
|
|
1
|
+
import enum
|
2
|
+
import json
|
3
|
+
from typing import Any, List, Union, cast
|
4
|
+
|
5
|
+
from vellum.client.types.array_vellum_value import ArrayVellumValue
|
6
|
+
from vellum.client.types.array_vellum_value_request import ArrayVellumValueRequest
|
7
|
+
from vellum.client.types.audio_vellum_value import AudioVellumValue
|
8
|
+
from vellum.client.types.audio_vellum_value_request import AudioVellumValueRequest
|
9
|
+
from vellum.client.types.chat_history_vellum_value import ChatHistoryVellumValue
|
10
|
+
from vellum.client.types.chat_history_vellum_value_request import ChatHistoryVellumValueRequest
|
11
|
+
from vellum.client.types.chat_message import ChatMessage
|
12
|
+
from vellum.client.types.error_vellum_value import ErrorVellumValue
|
13
|
+
from vellum.client.types.error_vellum_value_request import ErrorVellumValueRequest
|
14
|
+
from vellum.client.types.function_call_vellum_value import FunctionCallVellumValue
|
15
|
+
from vellum.client.types.function_call_vellum_value_request import FunctionCallVellumValueRequest
|
16
|
+
from vellum.client.types.image_vellum_value import ImageVellumValue
|
17
|
+
from vellum.client.types.image_vellum_value_request import ImageVellumValueRequest
|
18
|
+
from vellum.client.types.json_vellum_value import JsonVellumValue
|
19
|
+
from vellum.client.types.json_vellum_value_request import JsonVellumValueRequest
|
20
|
+
from vellum.client.types.number_vellum_value import NumberVellumValue
|
21
|
+
from vellum.client.types.number_vellum_value_request import NumberVellumValueRequest
|
22
|
+
from vellum.client.types.search_result import SearchResult
|
23
|
+
from vellum.client.types.search_result_request import SearchResultRequest
|
24
|
+
from vellum.client.types.search_results_vellum_value import SearchResultsVellumValue
|
25
|
+
from vellum.client.types.search_results_vellum_value_request import SearchResultsVellumValueRequest
|
26
|
+
from vellum.client.types.string_vellum_value import StringVellumValue
|
27
|
+
from vellum.client.types.string_vellum_value_request import StringVellumValueRequest
|
28
|
+
from vellum.client.types.vellum_value import VellumValue
|
29
|
+
from vellum.client.types.vellum_value_request import VellumValueRequest
|
30
|
+
|
31
|
+
# Request-side counterparts of the VellumValue types.
# - primitive_to_vellum_value returns instances of these classes unchanged.
# - primitive_to_vellum_value_request picks its target class from this tuple: the
#   first entry whose class name starts with the converted VellumValue's class name.
VELLUM_VALUE_REQUEST_TUPLE = (
    StringVellumValueRequest,
    NumberVellumValueRequest,
    JsonVellumValueRequest,
    ImageVellumValueRequest,
    AudioVellumValueRequest,
    FunctionCallVellumValueRequest,
    ErrorVellumValueRequest,
    ArrayVellumValueRequest,
    ChatHistoryVellumValueRequest,
    SearchResultsVellumValueRequest,
)
|
43
|
+
|
44
|
+
|
45
|
+
def primitive_to_vellum_value(value: Any) -> VellumValue:
    """Convert a Python primitive into a ``VellumValue``.

    Conversion rules, applied in this order:

    * ``str`` -> ``StringVellumValue``
    * ``enum.Enum`` member -> ``StringVellumValue`` holding the member's ``.value``
    * ``int`` / ``float`` -> ``NumberVellumValue``
    * ``list`` whose items are all ``ChatMessage`` -> ``ChatHistoryVellumValue``
    * ``list`` whose items are all ``SearchResultRequest`` or all ``SearchResult``
      -> ``SearchResultsVellumValue``
    * an existing ``XVellumValue`` or ``XVellumValueRequest`` instance -> returned as-is
    * anything else -> round-tripped through JSON into a ``JsonVellumValue``

    An empty list satisfies the chat-history check, because ``all()`` of an empty
    iterable is ``True``. It therefore becomes an empty ``ChatHistoryVellumValue``.

    Args:
        value: The Python value to convert.

    Returns:
        The corresponding ``VellumValue``.

    Raises:
        ValueError: If ``value`` matches none of the typed rules above and cannot be
            serialized with ``json.dumps``.
    """
    if isinstance(value, str):
        return StringVellumValue(value=value)
    elif isinstance(value, enum.Enum):
        return StringVellumValue(value=value.value)
    elif isinstance(value, (int, float)):
        # NOTE(review): bool is a subclass of int, so True/False are also converted to
        # numbers here rather than to JSON booleans.
        return NumberVellumValue(value=value)
    elif isinstance(value, list) and all(isinstance(message, ChatMessage) for message in value):
        return ChatHistoryVellumValue(value=cast(List[ChatMessage], value))
    elif isinstance(value, list) and (
        all(isinstance(search_result, SearchResultRequest) for search_result in value)
        or all(isinstance(search_result, SearchResult) for search_result in value)
    ):
        search_results = cast(Union[List[SearchResultRequest], List[SearchResult]], value)
        return SearchResultsVellumValue(value=search_results)
    elif isinstance(
        value,
        (
            StringVellumValue,
            NumberVellumValue,
            JsonVellumValue,
            ImageVellumValue,
            AudioVellumValue,
            FunctionCallVellumValue,
            ErrorVellumValue,
            ArrayVellumValue,
            ChatHistoryVellumValue,
            SearchResultsVellumValue,
        ),
    ):
        # Already a VellumValue: nothing to convert.
        return value
    elif isinstance(value, VELLUM_VALUE_REQUEST_TUPLE):
        # This type ignore is safe because consumers of this function won't care the difference between
        # XVellumValue and XVellumValueRequest. Hopefully in the near future, neither will we
        return value  # type: ignore

    try:
        json_value = json.dumps(value)
    except (TypeError, ValueError) as exc:
        # json.dumps raises TypeError for objects it cannot serialize and ValueError for
        # circular references; json.JSONDecodeError is only raised when *decoding*.
        raise ValueError(f"Unsupported variable type: {value.__class__.__name__}") from exc

    # Round-trip through JSON so the stored value contains only plain JSON types.
    return JsonVellumValue(value=json.loads(json_value))
|
96
|
+
|
97
|
+
|
98
|
+
def primitive_to_vellum_value_request(value: Any) -> VellumValueRequest:
    """Convert a Python primitive into the matching ``VellumValueRequest``.

    Steps:

    1. Convert the value with ``primitive_to_vellum_value``.
    2. Pick the request class from ``VELLUM_VALUE_REQUEST_TUPLE``: the first class
       whose name begins with the converted value's class name (for example
       ``StringVellumValue`` -> ``StringVellumValueRequest``).
    3. Populate that class from the converted value's ``model_dump()``.

    Raises:
        ValueError: If no request class name starts with the converted value's class name.
    """
    converted = primitive_to_vellum_value(value)
    source_name = converted.__class__.__name__

    for request_cls in VELLUM_VALUE_REQUEST_TUPLE:
        if request_cls.__name__.startswith(source_name):
            return request_cls.model_validate(converted.model_dump())

    raise ValueError(f"Unsupported variable type: {source_name}")
|
@@ -76,7 +76,6 @@ def test_inline_prompt_node__function_definitions(vellum_adhoc_prompt_client):
|
|
76
76
|
class MyNode(InlinePromptNode):
|
77
77
|
ml_model = "gpt-4o"
|
78
78
|
functions = [my_function]
|
79
|
-
prompt_inputs = {}
|
80
79
|
blocks = []
|
81
80
|
|
82
81
|
# AND a known response from invoking an inline prompt
|
File without changes
|
@@ -0,0 +1,164 @@
|
|
1
|
+
from vellum import SearchResponse, SearchResult, SearchResultDocument
|
2
|
+
from vellum.client.types.search_filters_request import SearchFiltersRequest
|
3
|
+
from vellum.client.types.search_request_options_request import SearchRequestOptionsRequest
|
4
|
+
from vellum.client.types.search_result_merging_request import SearchResultMergingRequest
|
5
|
+
from vellum.client.types.search_weights_request import SearchWeightsRequest
|
6
|
+
from vellum.client.types.string_vellum_value_request import StringVellumValueRequest
|
7
|
+
from vellum.client.types.vellum_value_logical_condition_group_request import VellumValueLogicalConditionGroupRequest
|
8
|
+
from vellum.client.types.vellum_value_logical_condition_request import VellumValueLogicalConditionRequest
|
9
|
+
from vellum.workflows.nodes.displayable.bases.types import (
|
10
|
+
MetadataLogicalCondition,
|
11
|
+
MetadataLogicalConditionGroup,
|
12
|
+
SearchFilters,
|
13
|
+
)
|
14
|
+
from vellum.workflows.nodes.displayable.search_node.node import SearchNode
|
15
|
+
|
16
|
+
|
17
|
+
def test_run_workflow__happy_path(vellum_client):
    """Confirm that a single Search Node configured via the individual option attributes runs successfully.

    The node sets ``limit``, ``weights``, ``result_merging`` and ``filters`` directly. The test checks
    that these are combined into the ``options`` payload sent to ``vellum_client.search``. It also checks
    that the metadata condition variables are wrapped as ``StringVellumValueRequest``.
    """

    # GIVEN a Search Node configured with the individual option attributes
    class MySearchNode(SearchNode):
        query = "Search query"
        document_index = "document_index"
        limit = 1
        weights = SearchWeightsRequest(
            semantic_similarity=0.8,
            keywords=0.2,
        )
        result_merging = SearchResultMergingRequest(
            enabled=True,
        )
        filters = SearchFilters(
            external_ids=["external_id"],
            metadata=MetadataLogicalConditionGroup(
                combinator="AND",
                negated=False,
                conditions=[
                    MetadataLogicalCondition(
                        lhs_variable="TYPE",
                        operator="=",
                        rhs_variable="COMPANY",
                    )
                ],
            ),
        )

    # AND a Search request that will return a 200 OK response
    search_response = SearchResponse(
        results=[
            SearchResult(
                text="Search query", score=0.0, keywords=["keywords"], document=SearchResultDocument(label="label")
            )
        ]
    )

    vellum_client.search.return_value = search_response

    # WHEN we run the node
    outputs = MySearchNode().run()

    # THEN the node should have completed successfully
    assert outputs.text == "Search query"

    # AND the options should be as expected
    assert vellum_client.search.call_args.kwargs["options"] == SearchRequestOptionsRequest(
        limit=1,
        weights=SearchWeightsRequest(
            semantic_similarity=0.8,
            keywords=0.2,
        ),
        result_merging=SearchResultMergingRequest(
            enabled=True,
        ),
        filters=SearchFiltersRequest(
            external_ids=["external_id"],
            metadata=VellumValueLogicalConditionGroupRequest(
                combinator="AND",
                negated=False,
                conditions=[
                    VellumValueLogicalConditionRequest(
                        lhs_variable=StringVellumValueRequest(value="TYPE"),
                        operator="=",
                        rhs_variable=StringVellumValueRequest(value="COMPANY"),
                    )
                ],
            ),
        ),
    )
|
89
|
+
|
90
|
+
|
91
|
+
def test_run_workflow__happy_path__options_attribute(vellum_client):
    """Confirm that a single Search Node configured via the legacy ``options`` attribute runs successfully.

    The node supplies a complete ``SearchRequestOptionsRequest``. The test checks that it is forwarded
    unchanged as the ``options`` keyword argument of ``vellum_client.search``.
    """

    # GIVEN a Search Node configured with the legacy options attribute
    class MySearchNode(SearchNode):
        query = "Search query"
        document_index = "document_index"
        options = SearchRequestOptionsRequest(
            limit=1,
            weights=SearchWeightsRequest(
                semantic_similarity=0.8,
                keywords=0.2,
            ),
            result_merging=SearchResultMergingRequest(
                enabled=True,
            ),
            filters=SearchFiltersRequest(
                external_ids=["external_id"],
                metadata=VellumValueLogicalConditionGroupRequest(
                    combinator="AND",
                    negated=False,
                    conditions=[
                        VellumValueLogicalConditionRequest(
                            lhs_variable=StringVellumValueRequest(value="TYPE"),
                            operator="=",
                            rhs_variable=StringVellumValueRequest(value="COMPANY"),
                        )
                    ],
                ),
            ),
        )

    # AND a Search request that will return a 200 OK response
    search_response = SearchResponse(
        results=[
            SearchResult(
                text="Search query", score=0.0, keywords=["keywords"], document=SearchResultDocument(label="label")
            )
        ]
    )

    vellum_client.search.return_value = search_response

    # WHEN we run the node
    outputs = MySearchNode().run()

    # THEN the node should have completed successfully
    assert outputs.text == "Search query"

    # AND the options should be as expected
    assert vellum_client.search.call_args.kwargs["options"] == SearchRequestOptionsRequest(
        limit=1,
        weights=SearchWeightsRequest(
            semantic_similarity=0.8,
            keywords=0.2,
        ),
        result_merging=SearchResultMergingRequest(
            enabled=True,
        ),
        filters=SearchFiltersRequest(
            external_ids=["external_id"],
            metadata=VellumValueLogicalConditionGroupRequest(
                combinator="AND",
                negated=False,
                conditions=[
                    VellumValueLogicalConditionRequest(
                        lhs_variable=StringVellumValueRequest(value="TYPE"),
                        operator="=",
                        rhs_variable=StringVellumValueRequest(value="COMPANY"),
                    )
                ],
            ),
        ),
    )
|
@@ -1,3 +1,4 @@
|
|
1
|
+
from unittest import mock
|
1
2
|
from uuid import uuid4
|
2
3
|
from typing import Any, Iterator, List
|
3
4
|
|
@@ -33,7 +34,6 @@ def test_inline_text_prompt_node__basic(vellum_adhoc_prompt_client):
|
|
33
34
|
|
34
35
|
class MyInlinePromptNode(InlinePromptNode):
|
35
36
|
ml_model = "gpt-4o"
|
36
|
-
prompt_inputs = {}
|
37
37
|
blocks = []
|
38
38
|
|
39
39
|
# AND a known response from invoking an inline prompt
|
@@ -90,7 +90,7 @@ def test_inline_text_prompt_node__basic(vellum_adhoc_prompt_client):
|
|
90
90
|
logit_bias=None,
|
91
91
|
custom_parameters=None,
|
92
92
|
),
|
93
|
-
request_options=
|
93
|
+
request_options=mock.ANY,
|
94
94
|
)
|
95
95
|
|
96
96
|
|
@@ -107,7 +107,6 @@ def test_inline_text_prompt_node__catch_provider_error(vellum_adhoc_prompt_clien
|
|
107
107
|
@TryNode.wrap(on_error_code=WorkflowErrorCode.PROVIDER_ERROR)
|
108
108
|
class MyInlinePromptNode(InlinePromptNode):
|
109
109
|
ml_model = "gpt-4o"
|
110
|
-
prompt_inputs = {}
|
111
110
|
blocks = []
|
112
111
|
|
113
112
|
# AND a known response from invoking an inline prompt that fails
|
@@ -27,7 +27,6 @@ def test_text_prompt_deployment_node__basic(vellum_client):
|
|
27
27
|
|
28
28
|
class MyPromptDeploymentNode(PromptDeploymentNode):
|
29
29
|
deployment = "my-deployment"
|
30
|
-
prompt_inputs = {}
|
31
30
|
|
32
31
|
# AND a known response from invoking a deployed prompt
|
33
32
|
expected_outputs: List[PromptOutput] = [
|
@@ -4,7 +4,21 @@ import logging
|
|
4
4
|
from queue import Empty, Queue
|
5
5
|
from threading import Event as ThreadingEvent, Thread
|
6
6
|
from uuid import UUID
|
7
|
-
from typing import
|
7
|
+
from typing import (
|
8
|
+
TYPE_CHECKING,
|
9
|
+
Any,
|
10
|
+
Dict,
|
11
|
+
Generic,
|
12
|
+
Iterable,
|
13
|
+
Iterator,
|
14
|
+
List,
|
15
|
+
Optional,
|
16
|
+
Sequence,
|
17
|
+
Set,
|
18
|
+
Tuple,
|
19
|
+
Type,
|
20
|
+
Union,
|
21
|
+
)
|
8
22
|
|
9
23
|
from vellum.workflows.constants import UNDEF
|
10
24
|
from vellum.workflows.context import execution_context, get_parent_context
|
@@ -76,6 +90,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
76
90
|
cancel_signal: Optional[ThreadingEvent] = None,
|
77
91
|
node_output_mocks: Optional[List[BaseOutputs]] = None,
|
78
92
|
parent_context: Optional[ParentContext] = None,
|
93
|
+
max_concurrency: Optional[int] = None,
|
79
94
|
):
|
80
95
|
if state and external_inputs:
|
81
96
|
raise ValueError("Can only run a Workflow providing one of state or external inputs, not both")
|
@@ -120,6 +135,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
120
135
|
# This queue is responsible for sending events from the inner worker threads to WorkflowRunner
|
121
136
|
self._workflow_event_inner_queue: Queue[WorkflowEvent] = Queue()
|
122
137
|
|
138
|
+
self._max_concurrency = max_concurrency
|
139
|
+
self._concurrency_queue: Queue[Tuple[StateType, Type[BaseNode], Optional[Edge]]] = Queue()
|
140
|
+
|
123
141
|
# This queue is responsible for sending events from WorkflowRunner to the background thread
|
124
142
|
# for user defined emitters
|
125
143
|
self._background_thread_queue: Queue[BackgroundThreadItem] = Queue()
|
@@ -350,7 +368,19 @@ class WorkflowRunner(Generic[StateType]):
|
|
350
368
|
else:
|
351
369
|
next_state = state
|
352
370
|
|
353
|
-
self.
|
371
|
+
if self._max_concurrency:
|
372
|
+
self._concurrency_queue.put((next_state, edge.to_node, edge))
|
373
|
+
else:
|
374
|
+
self._run_node_if_ready(next_state, edge.to_node, edge)
|
375
|
+
|
376
|
+
if self._max_concurrency:
|
377
|
+
num_nodes_to_run = self._max_concurrency - len(self._active_nodes_by_execution_id)
|
378
|
+
for _ in range(num_nodes_to_run):
|
379
|
+
if self._concurrency_queue.empty():
|
380
|
+
break
|
381
|
+
|
382
|
+
next_state, node_class, invoked_edge = self._concurrency_queue.get()
|
383
|
+
self._run_node_if_ready(next_state, node_class, invoked_edge)
|
354
384
|
|
355
385
|
def _run_node_if_ready(
|
356
386
|
self,
|
@@ -513,8 +543,11 @@ class WorkflowRunner(Generic[StateType]):
|
|
513
543
|
)
|
514
544
|
for node_cls in self._entrypoints:
|
515
545
|
try:
|
516
|
-
|
517
|
-
|
546
|
+
if not self._max_concurrency or len(self._active_nodes_by_execution_id) < self._max_concurrency:
|
547
|
+
with execution_context(parent_context=current_parent):
|
548
|
+
self._run_node_if_ready(self._initial_state, node_cls)
|
549
|
+
else:
|
550
|
+
self._concurrency_queue.put((self._initial_state, node_cls, None))
|
518
551
|
except NodeException as e:
|
519
552
|
self._workflow_event_outer_queue.put(self._reject_workflow_event(e.error))
|
520
553
|
return
|
@@ -21,6 +21,7 @@ class ExampleClass:
|
|
21
21
|
zeta: ClassVar[str]
|
22
22
|
eta: List[str]
|
23
23
|
kappa: Any
|
24
|
+
mu: list[str]
|
24
25
|
|
25
26
|
|
26
27
|
T = TypeVar("T")
|
@@ -56,6 +57,7 @@ class ExampleNode(BaseNode):
|
|
56
57
|
(ExampleInheritedClass, "beta", (int,)),
|
57
58
|
(ExampleNode.Outputs, "iota", (str,)),
|
58
59
|
(ExampleClass, "kappa", (Any,)),
|
60
|
+
(ExampleClass, "mu", (list[str],)),
|
59
61
|
],
|
60
62
|
ids=[
|
61
63
|
"str",
|
@@ -71,6 +73,7 @@ class ExampleNode(BaseNode):
|
|
71
73
|
"inherited_parent_class_var",
|
72
74
|
"try_node_output",
|
73
75
|
"any",
|
76
|
+
"list_str_generic",
|
74
77
|
],
|
75
78
|
)
|
76
79
|
def test_infer_types(cls, attr_name, expected_type):
|
@@ -80,9 +83,9 @@ def test_infer_types(cls, attr_name, expected_type):
|
|
80
83
|
@pytest.mark.parametrize(
|
81
84
|
"cls, expected_attr_names",
|
82
85
|
[
|
83
|
-
(ExampleClass, {"alpha", "beta", "gamma", "epsilon", "zeta", "eta", "kappa"}),
|
86
|
+
(ExampleClass, {"alpha", "beta", "gamma", "epsilon", "zeta", "eta", "kappa", "mu"}),
|
84
87
|
(ExampleGenericClass, {"delta"}),
|
85
|
-
(ExampleInheritedClass, {"alpha", "beta", "gamma", "epsilon", "zeta", "eta", "theta", "kappa"}),
|
88
|
+
(ExampleInheritedClass, {"alpha", "beta", "gamma", "epsilon", "zeta", "eta", "theta", "kappa", "mu"}),
|
86
89
|
],
|
87
90
|
)
|
88
91
|
def test_class_attr_names(cls, expected_attr_names):
|
vellum/workflows/types/utils.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from copy import deepcopy
|
2
2
|
from datetime import datetime
|
3
3
|
import importlib
|
4
|
+
from types import GenericAlias
|
4
5
|
from typing import (
|
5
6
|
Any,
|
6
7
|
ClassVar,
|
@@ -77,6 +78,9 @@ def infer_types(object_: Type, attr_name: str, localns: Optional[Dict[str, Any]]
|
|
77
78
|
return (type_hint,)
|
78
79
|
if isinstance(type_hint, SpecialGenericAlias):
|
79
80
|
return (type_hint,)
|
81
|
+
if isinstance(type_hint, GenericAlias):
|
82
|
+
# In future versions of python, list[str] will be a `GenericAlias`
|
83
|
+
return (cast(Type, type_hint),)
|
80
84
|
if isinstance(type_hint, TypeVar):
|
81
85
|
if type_hint in type_var_mapping:
|
82
86
|
return (type_var_mapping[type_hint],)
|
@@ -188,6 +188,7 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
188
188
|
external_inputs: Optional[ExternalInputsArg] = None,
|
189
189
|
cancel_signal: Optional[ThreadingEvent] = None,
|
190
190
|
node_output_mocks: Optional[List[BaseOutputs]] = None,
|
191
|
+
max_concurrency: Optional[int] = None,
|
191
192
|
) -> TerminalWorkflowEvent:
|
192
193
|
"""
|
193
194
|
Invoke a Workflow, returning the last event emitted, which should be one of:
|
@@ -215,6 +216,11 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
215
216
|
|
216
217
|
node_output_mocks: Optional[List[Outputs]] = None
|
217
218
|
A list of Outputs to mock for Nodes during Workflow Execution.
|
219
|
+
|
220
|
+
max_concurrency: Optional[int] = None
|
221
|
+
The max number of concurrent threads to run the Workflow with. If not provided, the Workflow will run
|
222
|
+
without limiting concurrency. This configuration only applies to the current Workflow and not to any
|
223
|
+
subworkflows or nodes that utilize threads.
|
218
224
|
"""
|
219
225
|
|
220
226
|
events = WorkflowRunner(
|
@@ -226,6 +232,7 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
226
232
|
cancel_signal=cancel_signal,
|
227
233
|
node_output_mocks=node_output_mocks,
|
228
234
|
parent_context=self._context.parent_context,
|
235
|
+
max_concurrency=max_concurrency,
|
229
236
|
).stream()
|
230
237
|
first_event: Optional[Union[WorkflowExecutionInitiatedEvent, WorkflowExecutionResumedEvent]] = None
|
231
238
|
last_event = None
|
@@ -289,6 +296,7 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
289
296
|
external_inputs: Optional[ExternalInputsArg] = None,
|
290
297
|
cancel_signal: Optional[ThreadingEvent] = None,
|
291
298
|
node_output_mocks: Optional[List[BaseOutputs]] = None,
|
299
|
+
max_concurrency: Optional[int] = None,
|
292
300
|
) -> WorkflowEventStream:
|
293
301
|
"""
|
294
302
|
Invoke a Workflow, yielding events as they are emitted.
|
@@ -317,6 +325,11 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
317
325
|
|
318
326
|
node_output_mocks: Optional[List[Outputs]] = None
|
319
327
|
A list of Outputs to mock for Nodes during Workflow Execution.
|
328
|
+
|
329
|
+
max_concurrency: Optional[int] = None
|
330
|
+
The max number of concurrent threads to run the Workflow with. If not provided, the Workflow will run
|
331
|
+
without limiting concurrency. This configuration only applies to the current Workflow and not to any
|
332
|
+
subworkflows or nodes that utilize threads.
|
320
333
|
"""
|
321
334
|
|
322
335
|
should_yield = event_filter or workflow_event_filter
|
@@ -329,6 +342,7 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
329
342
|
cancel_signal=cancel_signal,
|
330
343
|
node_output_mocks=node_output_mocks,
|
331
344
|
parent_context=self.context.parent_context,
|
345
|
+
max_concurrency=max_concurrency,
|
332
346
|
).stream():
|
333
347
|
if should_yield(self.__class__, event):
|
334
348
|
yield event
|