vellum-ai 0.14.15__py3-none-any.whl → 0.14.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. vellum/__init__.py +2 -0
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/resources/document_indexes/client.py +0 -55
  4. vellum/client/types/__init__.py +2 -0
  5. vellum/client/types/document_index_read.py +0 -10
  6. vellum/client/types/release.py +21 -0
  7. vellum/client/types/workflow_release_tag_read.py +7 -1
  8. vellum/plugins/pydantic.py +14 -4
  9. vellum/prompts/blocks/compilation.py +14 -0
  10. vellum/types/release.py +3 -0
  11. vellum/workflows/nodes/bases/base.py +7 -7
  12. vellum/workflows/nodes/bases/base_adornment_node.py +2 -0
  13. vellum/workflows/nodes/core/retry_node/node.py +1 -1
  14. vellum/workflows/nodes/core/try_node/node.py +1 -1
  15. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +4 -0
  16. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +27 -1
  17. vellum/workflows/nodes/displayable/bases/inline_prompt_node/tests/__init__.py +0 -0
  18. vellum/workflows/nodes/displayable/bases/inline_prompt_node/tests/test_inline_prompt_node.py +182 -0
  19. vellum/workflows/nodes/displayable/inline_prompt_node/node.py +4 -1
  20. vellum/workflows/nodes/experimental/openai_chat_completion_node/node.py +7 -1
  21. vellum/workflows/utils/tests/test_vellum_variables.py +7 -1
  22. vellum/workflows/utils/vellum_variables.py +4 -0
  23. vellum/workflows/vellum_client.py +9 -5
  24. {vellum_ai-0.14.15.dist-info → vellum_ai-0.14.17.dist-info}/METADATA +1 -1
  25. {vellum_ai-0.14.15.dist-info → vellum_ai-0.14.17.dist-info}/RECORD +40 -35
  26. vellum_cli/image_push.py +76 -42
  27. vellum_cli/tests/test_image_push.py +56 -0
  28. vellum_ee/workflows/display/nodes/base_node_display.py +35 -29
  29. vellum_ee/workflows/display/nodes/get_node_display_class.py +0 -9
  30. vellum_ee/workflows/display/nodes/vellum/base_adornment_node.py +38 -18
  31. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +1 -0
  32. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +29 -33
  33. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +91 -106
  34. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_outputs_serialization.py +33 -38
  35. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_ports_serialization.py +138 -153
  36. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_trigger_serialization.py +23 -26
  37. vellum_ee/workflows/display/workflows/tests/test_workflow_display.py +51 -7
  38. {vellum_ai-0.14.15.dist-info → vellum_ai-0.14.17.dist-info}/LICENSE +0 -0
  39. {vellum_ai-0.14.15.dist-info → vellum_ai-0.14.17.dist-info}/WHEEL +0 -0
  40. {vellum_ai-0.14.15.dist-info → vellum_ai-0.14.17.dist-info}/entry_points.txt +0 -0
vellum/__init__.py CHANGED
@@ -296,6 +296,7 @@ from .types import (
296
296
  RejectedExecuteWorkflowWorkflowResultEvent,
297
297
  RejectedPromptExecutionMeta,
298
298
  RejectedWorkflowNodeResultEvent,
299
+ Release,
299
300
  ReleaseTagSource,
300
301
  ReplaceTestSuiteTestCaseRequest,
301
302
  RichTextChildBlock,
@@ -844,6 +845,7 @@ __all__ = [
844
845
  "RejectedExecuteWorkflowWorkflowResultEvent",
845
846
  "RejectedPromptExecutionMeta",
846
847
  "RejectedWorkflowNodeResultEvent",
848
+ "Release",
847
849
  "ReleaseTagSource",
848
850
  "ReplaceTestSuiteTestCaseRequest",
849
851
  "RichTextChildBlock",
@@ -18,7 +18,7 @@ class BaseClientWrapper:
18
18
  headers: typing.Dict[str, str] = {
19
19
  "X-Fern-Language": "Python",
20
20
  "X-Fern-SDK-Name": "vellum-ai",
21
- "X-Fern-SDK-Version": "0.14.15",
21
+ "X-Fern-SDK-Version": "0.14.17",
22
22
  }
23
23
  headers["X_API_KEY"] = self.api_key
24
24
  return headers
@@ -10,7 +10,6 @@ from json.decoder import JSONDecodeError
10
10
  from ...core.api_error import ApiError
11
11
  from ...types.document_index_indexing_config_request import DocumentIndexIndexingConfigRequest
12
12
  from ...types.entity_status import EntityStatus
13
- from ...types.environment_enum import EnvironmentEnum
14
13
  from ...types.document_index_read import DocumentIndexRead
15
14
  from ...core.serialization import convert_and_respect_annotation_metadata
16
15
  from ...core.jsonable_encoder import jsonable_encoder
@@ -108,7 +107,6 @@ class DocumentIndexesClient:
108
107
  name: str,
109
108
  indexing_config: DocumentIndexIndexingConfigRequest,
110
109
  status: typing.Optional[EntityStatus] = OMIT,
111
- environment: typing.Optional[EnvironmentEnum] = OMIT,
112
110
  copy_documents_from_index_id: typing.Optional[str] = OMIT,
113
111
  request_options: typing.Optional[RequestOptions] = None,
114
112
  ) -> DocumentIndexRead:
@@ -131,13 +129,6 @@ class DocumentIndexesClient:
131
129
  * `ACTIVE` - Active
132
130
  * `ARCHIVED` - Archived
133
131
 
134
- environment : typing.Optional[EnvironmentEnum]
135
- The environment this document index is used in
136
-
137
- * `DEVELOPMENT` - Development
138
- * `STAGING` - Staging
139
- * `PRODUCTION` - Production
140
-
141
132
  copy_documents_from_index_id : typing.Optional[str]
142
133
  Optionally specify the id of a document index from which you'd like to copy and re-index its documents into this newly created index
143
134
 
@@ -191,7 +182,6 @@ class DocumentIndexesClient:
191
182
  "label": label,
192
183
  "name": name,
193
184
  "status": status,
194
- "environment": environment,
195
185
  "indexing_config": convert_and_respect_annotation_metadata(
196
186
  object_=indexing_config, annotation=DocumentIndexIndexingConfigRequest, direction="write"
197
187
  ),
@@ -268,7 +258,6 @@ class DocumentIndexesClient:
268
258
  *,
269
259
  label: str,
270
260
  status: typing.Optional[EntityStatus] = OMIT,
271
- environment: typing.Optional[EnvironmentEnum] = OMIT,
272
261
  request_options: typing.Optional[RequestOptions] = None,
273
262
  ) -> DocumentIndexRead:
274
263
  """
@@ -288,13 +277,6 @@ class DocumentIndexesClient:
288
277
  * `ACTIVE` - Active
289
278
  * `ARCHIVED` - Archived
290
279
 
291
- environment : typing.Optional[EnvironmentEnum]
292
- The environment this document index is used in
293
-
294
- * `DEVELOPMENT` - Development
295
- * `STAGING` - Staging
296
- * `PRODUCTION` - Production
297
-
298
280
  request_options : typing.Optional[RequestOptions]
299
281
  Request-specific configuration.
300
282
 
@@ -322,7 +304,6 @@ class DocumentIndexesClient:
322
304
  json={
323
305
  "label": label,
324
306
  "status": status,
325
- "environment": environment,
326
307
  },
327
308
  request_options=request_options,
328
309
  omit=OMIT,
@@ -388,7 +369,6 @@ class DocumentIndexesClient:
388
369
  *,
389
370
  label: typing.Optional[str] = OMIT,
390
371
  status: typing.Optional[EntityStatus] = OMIT,
391
- environment: typing.Optional[EnvironmentEnum] = OMIT,
392
372
  request_options: typing.Optional[RequestOptions] = None,
393
373
  ) -> DocumentIndexRead:
394
374
  """
@@ -408,13 +388,6 @@ class DocumentIndexesClient:
408
388
  * `ACTIVE` - Active
409
389
  * `ARCHIVED` - Archived
410
390
 
411
- environment : typing.Optional[EnvironmentEnum]
412
- The environment this document index is used in
413
-
414
- * `DEVELOPMENT` - Development
415
- * `STAGING` - Staging
416
- * `PRODUCTION` - Production
417
-
418
391
  request_options : typing.Optional[RequestOptions]
419
392
  Request-specific configuration.
420
393
 
@@ -441,7 +414,6 @@ class DocumentIndexesClient:
441
414
  json={
442
415
  "label": label,
443
416
  "status": status,
444
- "environment": environment,
445
417
  },
446
418
  request_options=request_options,
447
419
  omit=OMIT,
@@ -651,7 +623,6 @@ class AsyncDocumentIndexesClient:
651
623
  name: str,
652
624
  indexing_config: DocumentIndexIndexingConfigRequest,
653
625
  status: typing.Optional[EntityStatus] = OMIT,
654
- environment: typing.Optional[EnvironmentEnum] = OMIT,
655
626
  copy_documents_from_index_id: typing.Optional[str] = OMIT,
656
627
  request_options: typing.Optional[RequestOptions] = None,
657
628
  ) -> DocumentIndexRead:
@@ -674,13 +645,6 @@ class AsyncDocumentIndexesClient:
674
645
  * `ACTIVE` - Active
675
646
  * `ARCHIVED` - Archived
676
647
 
677
- environment : typing.Optional[EnvironmentEnum]
678
- The environment this document index is used in
679
-
680
- * `DEVELOPMENT` - Development
681
- * `STAGING` - Staging
682
- * `PRODUCTION` - Production
683
-
684
648
  copy_documents_from_index_id : typing.Optional[str]
685
649
  Optionally specify the id of a document index from which you'd like to copy and re-index its documents into this newly created index
686
650
 
@@ -742,7 +706,6 @@ class AsyncDocumentIndexesClient:
742
706
  "label": label,
743
707
  "name": name,
744
708
  "status": status,
745
- "environment": environment,
746
709
  "indexing_config": convert_and_respect_annotation_metadata(
747
710
  object_=indexing_config, annotation=DocumentIndexIndexingConfigRequest, direction="write"
748
711
  ),
@@ -827,7 +790,6 @@ class AsyncDocumentIndexesClient:
827
790
  *,
828
791
  label: str,
829
792
  status: typing.Optional[EntityStatus] = OMIT,
830
- environment: typing.Optional[EnvironmentEnum] = OMIT,
831
793
  request_options: typing.Optional[RequestOptions] = None,
832
794
  ) -> DocumentIndexRead:
833
795
  """
@@ -847,13 +809,6 @@ class AsyncDocumentIndexesClient:
847
809
  * `ACTIVE` - Active
848
810
  * `ARCHIVED` - Archived
849
811
 
850
- environment : typing.Optional[EnvironmentEnum]
851
- The environment this document index is used in
852
-
853
- * `DEVELOPMENT` - Development
854
- * `STAGING` - Staging
855
- * `PRODUCTION` - Production
856
-
857
812
  request_options : typing.Optional[RequestOptions]
858
813
  Request-specific configuration.
859
814
 
@@ -889,7 +844,6 @@ class AsyncDocumentIndexesClient:
889
844
  json={
890
845
  "label": label,
891
846
  "status": status,
892
- "environment": environment,
893
847
  },
894
848
  request_options=request_options,
895
849
  omit=OMIT,
@@ -963,7 +917,6 @@ class AsyncDocumentIndexesClient:
963
917
  *,
964
918
  label: typing.Optional[str] = OMIT,
965
919
  status: typing.Optional[EntityStatus] = OMIT,
966
- environment: typing.Optional[EnvironmentEnum] = OMIT,
967
920
  request_options: typing.Optional[RequestOptions] = None,
968
921
  ) -> DocumentIndexRead:
969
922
  """
@@ -983,13 +936,6 @@ class AsyncDocumentIndexesClient:
983
936
  * `ACTIVE` - Active
984
937
  * `ARCHIVED` - Archived
985
938
 
986
- environment : typing.Optional[EnvironmentEnum]
987
- The environment this document index is used in
988
-
989
- * `DEVELOPMENT` - Development
990
- * `STAGING` - Staging
991
- * `PRODUCTION` - Production
992
-
993
939
  request_options : typing.Optional[RequestOptions]
994
940
  Request-specific configuration.
995
941
 
@@ -1024,7 +970,6 @@ class AsyncDocumentIndexesClient:
1024
970
  json={
1025
971
  "label": label,
1026
972
  "status": status,
1027
- "environment": environment,
1028
973
  },
1029
974
  request_options=request_options,
1030
975
  omit=OMIT,
@@ -304,6 +304,7 @@ from .rejected_execute_prompt_response import RejectedExecutePromptResponse
304
304
  from .rejected_execute_workflow_workflow_result_event import RejectedExecuteWorkflowWorkflowResultEvent
305
305
  from .rejected_prompt_execution_meta import RejectedPromptExecutionMeta
306
306
  from .rejected_workflow_node_result_event import RejectedWorkflowNodeResultEvent
307
+ from .release import Release
307
308
  from .release_tag_source import ReleaseTagSource
308
309
  from .replace_test_suite_test_case_request import ReplaceTestSuiteTestCaseRequest
309
310
  from .rich_text_child_block import RichTextChildBlock
@@ -825,6 +826,7 @@ __all__ = [
825
826
  "RejectedExecuteWorkflowWorkflowResultEvent",
826
827
  "RejectedPromptExecutionMeta",
827
828
  "RejectedWorkflowNodeResultEvent",
829
+ "Release",
828
830
  "ReleaseTagSource",
829
831
  "ReplaceTestSuiteTestCaseRequest",
830
832
  "RichTextChildBlock",
@@ -5,7 +5,6 @@ import datetime as dt
5
5
  import pydantic
6
6
  import typing
7
7
  from .entity_status import EntityStatus
8
- from .environment_enum import EnvironmentEnum
9
8
  from .document_index_indexing_config import DocumentIndexIndexingConfig
10
9
  from ..core.pydantic_utilities import IS_PYDANTIC_V2
11
10
 
@@ -31,15 +30,6 @@ class DocumentIndexRead(UniversalBaseModel):
31
30
  * `ARCHIVED` - Archived
32
31
  """
33
32
 
34
- environment: typing.Optional[EnvironmentEnum] = pydantic.Field(default=None)
35
- """
36
- The environment this document index is used in
37
-
38
- * `DEVELOPMENT` - Development
39
- * `STAGING` - Staging
40
- * `PRODUCTION` - Production
41
- """
42
-
43
33
  indexing_config: DocumentIndexIndexingConfig
44
34
 
45
35
  if IS_PYDANTIC_V2:
@@ -0,0 +1,21 @@
1
+ # This file was auto-generated by Fern from our API Definition.
2
+
3
+ from ..core.pydantic_utilities import UniversalBaseModel
4
+ import datetime as dt
5
+ from ..core.pydantic_utilities import IS_PYDANTIC_V2
6
+ import typing
7
+ import pydantic
8
+
9
+
10
+ class Release(UniversalBaseModel):
11
+ id: str
12
+ timestamp: dt.datetime
13
+
14
+ if IS_PYDANTIC_V2:
15
+ model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
16
+ else:
17
+
18
+ class Config:
19
+ frozen = True
20
+ smart_union = True
21
+ extra = pydantic.Extra.allow
@@ -4,6 +4,7 @@ from ..core.pydantic_utilities import UniversalBaseModel
4
4
  import pydantic
5
5
  from .release_tag_source import ReleaseTagSource
6
6
  from .workflow_release_tag_workflow_deployment_history_item import WorkflowReleaseTagWorkflowDeploymentHistoryItem
7
+ from .release import Release
7
8
  from ..core.pydantic_utilities import IS_PYDANTIC_V2
8
9
  import typing
9
10
 
@@ -24,7 +25,12 @@ class WorkflowReleaseTagRead(UniversalBaseModel):
24
25
 
25
26
  history_item: WorkflowReleaseTagWorkflowDeploymentHistoryItem = pydantic.Field()
26
27
  """
27
- The Workflow Deployment History Item that this Release Tag is associated with
28
+ Deprecated. Reference the `release` field instead.
29
+ """
30
+
31
+ release: Release = pydantic.Field()
32
+ """
33
+ The Release that this Release Tag points to.
28
34
  """
29
35
 
30
36
  if IS_PYDANTIC_V2:
@@ -1,17 +1,27 @@
1
1
  from functools import lru_cache
2
- from typing import Any, Dict, Literal, Optional, Tuple, Union
2
+ from typing import Any, Dict, Literal, NamedTuple, Optional, Tuple, Union
3
+ from typing_extensions import TypeAlias
3
4
 
4
5
  from pydantic.fields import FieldInfo
5
6
  from pydantic.plugin import (
6
7
  PydanticPluginProtocol,
7
- SchemaKind,
8
- SchemaTypePath,
9
8
  ValidateJsonHandlerProtocol,
10
9
  ValidatePythonHandlerProtocol,
11
10
  ValidateStringsHandlerProtocol,
12
11
  )
13
12
  from pydantic_core import CoreSchema
14
13
 
14
+ # Redefined manually instead of imported from pydantic to support versions < 2.5
15
+ SchemaKind: TypeAlias = Literal["BaseModel", "TypeAdapter", "dataclass", "create_model", "validate_call"]
16
+
17
+
18
+ # Redefined manually instead of imported from pydantic to support versions < 2.5
19
+ class SchemaTypePath(NamedTuple):
20
+ """Path defining where `schema_type` was defined, or where `TypeAdapter` was called."""
21
+
22
+ module: str
23
+ name: str
24
+
15
25
 
16
26
  @lru_cache(maxsize=1)
17
27
  def import_base_descriptor():
@@ -81,7 +91,7 @@ class VellumPydanticPlugin(PydanticPluginProtocol):
81
91
  self,
82
92
  schema: CoreSchema,
83
93
  schema_type: Any,
84
- schema_type_path: SchemaTypePath,
94
+ schema_type_path: SchemaTypePath, # type: ignore
85
95
  schema_kind: SchemaKind,
86
96
  config: Any,
87
97
  plugin_settings: Dict[str, Any],
@@ -3,11 +3,13 @@ from typing import Sequence, Union, cast
3
3
 
4
4
  from vellum import (
5
5
  ChatMessage,
6
+ DocumentVellumValue,
6
7
  JsonVellumValue,
7
8
  PromptBlock,
8
9
  PromptRequestInput,
9
10
  RichTextPromptBlock,
10
11
  StringVellumValue,
12
+ VellumDocument,
11
13
  VellumVariable,
12
14
  )
13
15
  from vellum.client.types.audio_vellum_value import AudioVellumValue
@@ -159,6 +161,18 @@ def compile_prompt_blocks(
159
161
  cache_config=block.cache_config,
160
162
  )
161
163
  compiled_blocks.append(audio_block)
164
+
165
+ elif block.block_type == "DOCUMENT":
166
+ document_block = CompiledValuePromptBlock(
167
+ content=DocumentVellumValue(
168
+ value=VellumDocument(
169
+ src=block.src,
170
+ metadata=block.metadata,
171
+ )
172
+ ),
173
+ cache_config=block.cache_config,
174
+ )
175
+ compiled_blocks.append(document_block)
162
176
  else:
163
177
  raise PromptCompilationError(f"Unknown block_type: {block.block_type}")
164
178
 
@@ -0,0 +1,3 @@
1
+ # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
2
+
3
+ from vellum.client.types.release import *
@@ -293,14 +293,14 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
293
293
 
294
294
  if cls.merge_behavior == MergeBehavior.AWAIT_ALL:
295
295
  """
296
- A node utilizing an AWAIT_ALL merge strategy will only be considered ready for the Nth time
297
- when all of its dependencies have been executed N times.
296
+ A node utilizing an AWAIT_ALL merge strategy will only be considered ready
297
+ when all of its dependencies have invoked this node.
298
298
  """
299
- current_node_execution_count = state.meta.node_execution_cache.get_execution_count(cls.node_class)
300
- return all(
301
- state.meta.node_execution_cache.get_execution_count(dep) == current_node_execution_count + 1
302
- for dep in dependencies
303
- )
299
+ # Check if all dependencies have invoked this node
300
+ dependencies_invoked = state.meta.node_execution_cache._dependencies_invoked.get(node_span_id, set())
301
+ all_deps_invoked = all(dep in dependencies_invoked for dep in dependencies)
302
+
303
+ return all_deps_invoked
304
304
 
305
305
  raise NodeException(
306
306
  message="Invalid Trigger Node Specification",
@@ -73,3 +73,5 @@ class BaseAdornmentNode(
73
73
  # Subclasses of BaseAdornableNode can override this method to provider their own
74
74
  # approach to annotating the outputs class based on the `subworkflow.Outputs`
75
75
  setattr(outputs_class, reference.name, reference)
76
+ if cls.__wrapped_node__:
77
+ cls.__output_ids__[reference.name] = cls.__wrapped_node__.__output_ids__[reference.name]
@@ -85,7 +85,7 @@ Message: {terminal_event.error.message}""",
85
85
  @classmethod
86
86
  def wrap(
87
87
  cls,
88
- max_attempts: int,
88
+ max_attempts: int = 3,
89
89
  delay: Optional[float] = None,
90
90
  retry_on_error_code: Optional[WorkflowErrorCode] = None,
91
91
  retry_on_condition: Optional[BaseDescriptor] = None,
@@ -102,4 +102,4 @@ Message: {event.error.message}""",
102
102
  if reference.name == "error":
103
103
  raise ValueError("`error` is a reserved name for TryNode.Outputs")
104
104
 
105
- setattr(outputs_class, reference.name, reference)
105
+ super().__annotate_outputs_class__(outputs_class, reference)
@@ -28,6 +28,9 @@ class BasePromptNode(BaseNode, Generic[StateType]):
28
28
  def _get_prompt_event_stream(self) -> Union[Iterator[AdHocExecutePromptEvent], Iterator[ExecutePromptEvent]]:
29
29
  pass
30
30
 
31
+ def _validate(self) -> None:
32
+ pass
33
+
31
34
  def run(self) -> Iterator[BaseOutput]:
32
35
  outputs = yield from self._process_prompt_event_stream()
33
36
  if outputs is None:
@@ -37,6 +40,7 @@ class BasePromptNode(BaseNode, Generic[StateType]):
37
40
  )
38
41
 
39
42
  def _process_prompt_event_stream(self) -> Generator[BaseOutput, None, Optional[List[PromptOutput]]]:
43
+ self._validate()
40
44
  try:
41
45
  prompt_event_stream = self._get_prompt_event_stream()
42
46
  except ApiError as e:
@@ -1,6 +1,6 @@
1
1
  import json
2
2
  from uuid import uuid4
3
- from typing import Callable, ClassVar, Generic, Iterator, List, Optional, Tuple, Union
3
+ from typing import Callable, ClassVar, Generic, Iterator, List, Optional, Set, Tuple, Union
4
4
 
5
5
  from vellum import (
6
6
  AdHocExecutePromptEvent,
@@ -18,6 +18,7 @@ from vellum import (
18
18
  from vellum.client import RequestOptions
19
19
  from vellum.client.types.chat_message_request import ChatMessageRequest
20
20
  from vellum.client.types.prompt_settings import PromptSettings
21
+ from vellum.client.types.rich_text_child_block import RichTextChildBlock
21
22
  from vellum.workflows.constants import OMIT
22
23
  from vellum.workflows.context import get_execution_context
23
24
  from vellum.workflows.errors import WorkflowErrorCode
@@ -59,6 +60,31 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
59
60
  class Trigger(BasePromptNode.Trigger):
60
61
  merge_behavior = MergeBehavior.AWAIT_ANY
61
62
 
63
+ def _extract_required_input_variables(self, blocks: Union[List[PromptBlock], List[RichTextChildBlock]]) -> Set[str]:
64
+ required_variables = set()
65
+
66
+ for block in blocks:
67
+ if block.block_type == "VARIABLE":
68
+ required_variables.add(block.input_variable)
69
+ elif block.block_type == "CHAT_MESSAGE" and block.blocks:
70
+ required_variables.update(self._extract_required_input_variables(block.blocks))
71
+ elif block.block_type == "RICH_TEXT" and block.blocks:
72
+ required_variables.update(self._extract_required_input_variables(block.blocks))
73
+
74
+ return required_variables
75
+
76
+ def _validate(self) -> None:
77
+ required_variables = self._extract_required_input_variables(self.blocks)
78
+ provided_variables = set(self.prompt_inputs.keys() if self.prompt_inputs else set())
79
+
80
+ missing_variables = required_variables - provided_variables
81
+ if missing_variables:
82
+ missing_vars_str = ", ".join(f"'{var}'" for var in missing_variables)
83
+ raise NodeException(
84
+ message=f"Missing required input variables by VariablePromptBlock: {missing_vars_str}",
85
+ code=WorkflowErrorCode.INVALID_INPUTS,
86
+ )
87
+
62
88
  def _get_prompt_event_stream(self) -> Iterator[AdHocExecutePromptEvent]:
63
89
  input_variables, input_values = self._compile_prompt_inputs()
64
90
  current_context = get_execution_context()
@@ -0,0 +1,182 @@
1
+ import pytest
2
+ from uuid import uuid4
3
+ from typing import Any, Iterator, List
4
+
5
+ from vellum import (
6
+ ChatMessagePromptBlock,
7
+ JinjaPromptBlock,
8
+ PlainTextPromptBlock,
9
+ PromptBlock,
10
+ RichTextPromptBlock,
11
+ VariablePromptBlock,
12
+ )
13
+ from vellum.client.types.execute_prompt_event import ExecutePromptEvent
14
+ from vellum.client.types.fulfilled_execute_prompt_event import FulfilledExecutePromptEvent
15
+ from vellum.client.types.initiated_execute_prompt_event import InitiatedExecutePromptEvent
16
+ from vellum.client.types.prompt_output import PromptOutput
17
+ from vellum.client.types.prompt_request_string_input import PromptRequestStringInput
18
+ from vellum.client.types.string_vellum_value import StringVellumValue
19
+ from vellum.workflows.errors import WorkflowErrorCode
20
+ from vellum.workflows.exceptions import NodeException
21
+ from vellum.workflows.nodes.displayable.bases.inline_prompt_node import BaseInlinePromptNode
22
+
23
+
24
+ def test_validation_with_missing_variables():
25
+ """Test that validation correctly identifies missing variables."""
26
+ test_blocks: List[PromptBlock] = [
27
+ VariablePromptBlock(input_variable="required_var1"),
28
+ VariablePromptBlock(input_variable="required_var2"),
29
+ RichTextPromptBlock(
30
+ blocks=[
31
+ PlainTextPromptBlock(text="Some text"),
32
+ VariablePromptBlock(input_variable="required_var3"),
33
+ ],
34
+ ),
35
+ JinjaPromptBlock(template="Template without variables"),
36
+ ChatMessagePromptBlock(
37
+ chat_role="USER",
38
+ blocks=[
39
+ RichTextPromptBlock(
40
+ blocks=[
41
+ PlainTextPromptBlock(text="Nested text"),
42
+ VariablePromptBlock(input_variable="required_var4"),
43
+ ],
44
+ ),
45
+ ],
46
+ ),
47
+ ]
48
+
49
+ # GIVEN a BaseInlinePromptNode
50
+ class TestNode(BaseInlinePromptNode):
51
+ ml_model = "test-model"
52
+ blocks = test_blocks
53
+ prompt_inputs = {
54
+ "required_var1": "value1",
55
+ # required_var2 is missing
56
+ # required_var3 is missing
57
+ # required_var4 is missing
58
+ }
59
+
60
+ # WHEN the node is run
61
+ node = TestNode()
62
+ with pytest.raises(NodeException) as excinfo:
63
+ list(node.run())
64
+
65
+ # THEN the node raises the correct NodeException
66
+ assert excinfo.value.code == WorkflowErrorCode.INVALID_INPUTS
67
+ assert "required_var2" in str(excinfo.value)
68
+ assert "required_var3" in str(excinfo.value)
69
+ assert "required_var4" in str(excinfo.value)
70
+
71
+
72
+ def test_validation_with_all_variables_provided(vellum_adhoc_prompt_client):
73
+ """Test that validation passes when all variables are provided."""
74
+ test_blocks: List[PromptBlock] = [
75
+ VariablePromptBlock(input_variable="required_var1"),
76
+ VariablePromptBlock(input_variable="required_var2"),
77
+ RichTextPromptBlock(
78
+ blocks=[
79
+ PlainTextPromptBlock(text="Some text"),
80
+ VariablePromptBlock(input_variable="required_var3"),
81
+ ],
82
+ ),
83
+ JinjaPromptBlock(template="Template without variables"),
84
+ ChatMessagePromptBlock(
85
+ chat_role="USER",
86
+ blocks=[
87
+ RichTextPromptBlock(
88
+ blocks=[
89
+ PlainTextPromptBlock(text="Nested text"),
90
+ VariablePromptBlock(input_variable="required_var4"),
91
+ ],
92
+ ),
93
+ ],
94
+ ),
95
+ ]
96
+
97
+ # GIVEN a BaseInlinePromptNode
98
+ class TestNode(BaseInlinePromptNode):
99
+ ml_model = "test-model"
100
+ blocks = test_blocks
101
+ prompt_inputs = {
102
+ "required_var1": "value1",
103
+ "required_var2": "value2",
104
+ "required_var3": "value3",
105
+ "required_var4": "value4",
106
+ }
107
+
108
+ expected_outputs: List[PromptOutput] = [
109
+ StringVellumValue(value="Test response"),
110
+ ]
111
+
112
+ def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
113
+ execution_id = str(uuid4())
114
+ events: List[ExecutePromptEvent] = [
115
+ InitiatedExecutePromptEvent(execution_id=execution_id),
116
+ FulfilledExecutePromptEvent(
117
+ execution_id=execution_id,
118
+ outputs=expected_outputs,
119
+ ),
120
+ ]
121
+ yield from events
122
+
123
+ vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_prompt_events
124
+
125
+ # WHEN the node is run
126
+ node = TestNode()
127
+ list(node.run())
128
+
129
+ # THEN the prompt is executed with the correct inputs
130
+ mock_api = vellum_adhoc_prompt_client.adhoc_execute_prompt_stream
131
+ assert mock_api.call_count == 1
132
+ assert mock_api.call_args.kwargs["input_values"] == [
133
+ PromptRequestStringInput(key="required_var1", type="STRING", value="value1"),
134
+ PromptRequestStringInput(key="required_var2", type="STRING", value="value2"),
135
+ PromptRequestStringInput(key="required_var3", type="STRING", value="value3"),
136
+ PromptRequestStringInput(key="required_var4", type="STRING", value="value4"),
137
+ ]
138
+
139
+
140
+ def test_validation_with_extra_variables(vellum_adhoc_prompt_client):
141
+ """Test that validation passes when extra variables are provided."""
142
+ test_blocks: List[PromptBlock] = [
143
+ VariablePromptBlock(input_variable="required_var"),
144
+ ]
145
+
146
+ # GIVEN a BaseInlinePromptNode
147
+ class TestNode(BaseInlinePromptNode):
148
+ ml_model = "test-model"
149
+ blocks = test_blocks
150
+ prompt_inputs = {
151
+ "required_var": "value",
152
+ "extra_var": "extra_value", # This is not required
153
+ }
154
+
155
+ expected_outputs: List[PromptOutput] = [
156
+ StringVellumValue(value="Test response"),
157
+ ]
158
+
159
+ def generate_prompt_events(*args: Any, **kwargs: Any) -> Iterator[ExecutePromptEvent]:
160
+ execution_id = str(uuid4())
161
+ events: List[ExecutePromptEvent] = [
162
+ InitiatedExecutePromptEvent(execution_id=execution_id),
163
+ FulfilledExecutePromptEvent(
164
+ execution_id=execution_id,
165
+ outputs=expected_outputs,
166
+ ),
167
+ ]
168
+ yield from events
169
+
170
+ vellum_adhoc_prompt_client.adhoc_execute_prompt_stream.side_effect = generate_prompt_events
171
+
172
+ # WHEN the node is run
173
+ node = TestNode()
174
+ list(node.run())
175
+
176
+ # THEN the prompt is executed with the correct inputs
177
+ mock_api = vellum_adhoc_prompt_client.adhoc_execute_prompt_stream
178
+ assert mock_api.call_count == 1
179
+ assert mock_api.call_args.kwargs["input_values"] == [
180
+ PromptRequestStringInput(key="required_var", type="STRING", value="value"),
181
+ PromptRequestStringInput(key="extra_var", type="STRING", value="extra_value"),
182
+ ]