vellum-ai 0.9.16rc2__py3-none-any.whl → 0.10.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (245)
  1. vellum/plugins/__init__.py +0 -0
  2. vellum/plugins/pydantic.py +74 -0
  3. vellum/plugins/utils.py +19 -0
  4. vellum/plugins/vellum_mypy.py +639 -3
  5. vellum/workflows/README.md +90 -0
  6. vellum/workflows/__init__.py +5 -0
  7. vellum/workflows/constants.py +43 -0
  8. vellum/workflows/descriptors/__init__.py +0 -0
  9. vellum/workflows/descriptors/base.py +339 -0
  10. vellum/workflows/descriptors/tests/test_utils.py +83 -0
  11. vellum/workflows/descriptors/utils.py +90 -0
  12. vellum/workflows/edges/__init__.py +5 -0
  13. vellum/workflows/edges/edge.py +23 -0
  14. vellum/workflows/emitters/__init__.py +5 -0
  15. vellum/workflows/emitters/base.py +14 -0
  16. vellum/workflows/environment/__init__.py +5 -0
  17. vellum/workflows/environment/environment.py +7 -0
  18. vellum/workflows/errors/__init__.py +6 -0
  19. vellum/workflows/errors/types.py +20 -0
  20. vellum/workflows/events/__init__.py +31 -0
  21. vellum/workflows/events/node.py +125 -0
  22. vellum/workflows/events/tests/__init__.py +0 -0
  23. vellum/workflows/events/tests/test_event.py +216 -0
  24. vellum/workflows/events/types.py +52 -0
  25. vellum/workflows/events/utils.py +5 -0
  26. vellum/workflows/events/workflow.py +139 -0
  27. vellum/workflows/exceptions.py +15 -0
  28. vellum/workflows/expressions/__init__.py +0 -0
  29. vellum/workflows/expressions/accessor.py +52 -0
  30. vellum/workflows/expressions/and_.py +32 -0
  31. vellum/workflows/expressions/begins_with.py +31 -0
  32. vellum/workflows/expressions/between.py +38 -0
  33. vellum/workflows/expressions/coalesce_expression.py +41 -0
  34. vellum/workflows/expressions/contains.py +30 -0
  35. vellum/workflows/expressions/does_not_begin_with.py +31 -0
  36. vellum/workflows/expressions/does_not_contain.py +30 -0
  37. vellum/workflows/expressions/does_not_end_with.py +31 -0
  38. vellum/workflows/expressions/does_not_equal.py +25 -0
  39. vellum/workflows/expressions/ends_with.py +31 -0
  40. vellum/workflows/expressions/equals.py +25 -0
  41. vellum/workflows/expressions/greater_than.py +33 -0
  42. vellum/workflows/expressions/greater_than_or_equal_to.py +33 -0
  43. vellum/workflows/expressions/in_.py +31 -0
  44. vellum/workflows/expressions/is_blank.py +24 -0
  45. vellum/workflows/expressions/is_not_blank.py +24 -0
  46. vellum/workflows/expressions/is_not_null.py +21 -0
  47. vellum/workflows/expressions/is_not_undefined.py +22 -0
  48. vellum/workflows/expressions/is_null.py +21 -0
  49. vellum/workflows/expressions/is_undefined.py +22 -0
  50. vellum/workflows/expressions/less_than.py +33 -0
  51. vellum/workflows/expressions/less_than_or_equal_to.py +33 -0
  52. vellum/workflows/expressions/not_between.py +38 -0
  53. vellum/workflows/expressions/not_in.py +31 -0
  54. vellum/workflows/expressions/or_.py +32 -0
  55. vellum/workflows/graph/__init__.py +3 -0
  56. vellum/workflows/graph/graph.py +131 -0
  57. vellum/workflows/graph/tests/__init__.py +0 -0
  58. vellum/workflows/graph/tests/test_graph.py +437 -0
  59. vellum/workflows/inputs/__init__.py +5 -0
  60. vellum/workflows/inputs/base.py +55 -0
  61. vellum/workflows/logging.py +14 -0
  62. vellum/workflows/nodes/__init__.py +46 -0
  63. vellum/workflows/nodes/bases/__init__.py +7 -0
  64. vellum/workflows/nodes/bases/base.py +332 -0
  65. vellum/workflows/nodes/bases/base_subworkflow_node/__init__.py +5 -0
  66. vellum/workflows/nodes/bases/base_subworkflow_node/node.py +10 -0
  67. vellum/workflows/nodes/bases/tests/__init__.py +0 -0
  68. vellum/workflows/nodes/bases/tests/test_base_node.py +125 -0
  69. vellum/workflows/nodes/core/__init__.py +16 -0
  70. vellum/workflows/nodes/core/error_node/__init__.py +5 -0
  71. vellum/workflows/nodes/core/error_node/node.py +26 -0
  72. vellum/workflows/nodes/core/inline_subworkflow_node/__init__.py +5 -0
  73. vellum/workflows/nodes/core/inline_subworkflow_node/node.py +73 -0
  74. vellum/workflows/nodes/core/map_node/__init__.py +5 -0
  75. vellum/workflows/nodes/core/map_node/node.py +147 -0
  76. vellum/workflows/nodes/core/map_node/tests/__init__.py +0 -0
  77. vellum/workflows/nodes/core/map_node/tests/test_node.py +65 -0
  78. vellum/workflows/nodes/core/retry_node/__init__.py +5 -0
  79. vellum/workflows/nodes/core/retry_node/node.py +106 -0
  80. vellum/workflows/nodes/core/retry_node/tests/__init__.py +0 -0
  81. vellum/workflows/nodes/core/retry_node/tests/test_node.py +93 -0
  82. vellum/workflows/nodes/core/templating_node/__init__.py +5 -0
  83. vellum/workflows/nodes/core/templating_node/custom_filters.py +12 -0
  84. vellum/workflows/nodes/core/templating_node/exceptions.py +2 -0
  85. vellum/workflows/nodes/core/templating_node/node.py +123 -0
  86. vellum/workflows/nodes/core/templating_node/render.py +55 -0
  87. vellum/workflows/nodes/core/templating_node/tests/test_templating_node.py +21 -0
  88. vellum/workflows/nodes/core/try_node/__init__.py +5 -0
  89. vellum/workflows/nodes/core/try_node/node.py +110 -0
  90. vellum/workflows/nodes/core/try_node/tests/__init__.py +0 -0
  91. vellum/workflows/nodes/core/try_node/tests/test_node.py +82 -0
  92. vellum/workflows/nodes/displayable/__init__.py +31 -0
  93. vellum/workflows/nodes/displayable/api_node/__init__.py +5 -0
  94. vellum/workflows/nodes/displayable/api_node/node.py +44 -0
  95. vellum/workflows/nodes/displayable/bases/__init__.py +11 -0
  96. vellum/workflows/nodes/displayable/bases/api_node/__init__.py +5 -0
  97. vellum/workflows/nodes/displayable/bases/api_node/node.py +70 -0
  98. vellum/workflows/nodes/displayable/bases/base_prompt_node/__init__.py +5 -0
  99. vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +60 -0
  100. vellum/workflows/nodes/displayable/bases/inline_prompt_node/__init__.py +5 -0
  101. vellum/workflows/nodes/displayable/bases/inline_prompt_node/constants.py +13 -0
  102. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +118 -0
  103. vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +98 -0
  104. vellum/workflows/nodes/displayable/bases/search_node.py +90 -0
  105. vellum/workflows/nodes/displayable/code_execution_node/__init__.py +5 -0
  106. vellum/workflows/nodes/displayable/code_execution_node/node.py +197 -0
  107. vellum/workflows/nodes/displayable/code_execution_node/tests/__init__.py +0 -0
  108. vellum/workflows/nodes/displayable/code_execution_node/tests/fixtures/__init__.py +0 -0
  109. vellum/workflows/nodes/displayable/code_execution_node/tests/fixtures/main.py +3 -0
  110. vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +111 -0
  111. vellum/workflows/nodes/displayable/code_execution_node/utils.py +10 -0
  112. vellum/workflows/nodes/displayable/conditional_node/__init__.py +5 -0
  113. vellum/workflows/nodes/displayable/conditional_node/node.py +25 -0
  114. vellum/workflows/nodes/displayable/final_output_node/__init__.py +5 -0
  115. vellum/workflows/nodes/displayable/final_output_node/node.py +43 -0
  116. vellum/workflows/nodes/displayable/guardrail_node/__init__.py +5 -0
  117. vellum/workflows/nodes/displayable/guardrail_node/node.py +97 -0
  118. vellum/workflows/nodes/displayable/inline_prompt_node/__init__.py +5 -0
  119. vellum/workflows/nodes/displayable/inline_prompt_node/node.py +41 -0
  120. vellum/workflows/nodes/displayable/merge_node/__init__.py +5 -0
  121. vellum/workflows/nodes/displayable/merge_node/node.py +10 -0
  122. vellum/workflows/nodes/displayable/prompt_deployment_node/__init__.py +5 -0
  123. vellum/workflows/nodes/displayable/prompt_deployment_node/node.py +45 -0
  124. vellum/workflows/nodes/displayable/search_node/__init__.py +5 -0
  125. vellum/workflows/nodes/displayable/search_node/node.py +26 -0
  126. vellum/workflows/nodes/displayable/subworkflow_deployment_node/__init__.py +5 -0
  127. vellum/workflows/nodes/displayable/subworkflow_deployment_node/node.py +156 -0
  128. vellum/workflows/nodes/displayable/tests/__init__.py +0 -0
  129. vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +148 -0
  130. vellum/workflows/nodes/displayable/tests/test_search_node_wth_text_output.py +134 -0
  131. vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +80 -0
  132. vellum/workflows/nodes/utils.py +27 -0
  133. vellum/workflows/outputs/__init__.py +6 -0
  134. vellum/workflows/outputs/base.py +196 -0
  135. vellum/workflows/ports/__init__.py +7 -0
  136. vellum/workflows/ports/node_ports.py +75 -0
  137. vellum/workflows/ports/port.py +75 -0
  138. vellum/workflows/ports/utils.py +40 -0
  139. vellum/workflows/references/__init__.py +17 -0
  140. vellum/workflows/references/environment_variable.py +20 -0
  141. vellum/workflows/references/execution_count.py +20 -0
  142. vellum/workflows/references/external_input.py +49 -0
  143. vellum/workflows/references/input.py +7 -0
  144. vellum/workflows/references/lazy.py +55 -0
  145. vellum/workflows/references/node.py +43 -0
  146. vellum/workflows/references/output.py +78 -0
  147. vellum/workflows/references/state_value.py +23 -0
  148. vellum/workflows/references/vellum_secret.py +15 -0
  149. vellum/workflows/references/workflow_input.py +41 -0
  150. vellum/workflows/resolvers/__init__.py +5 -0
  151. vellum/workflows/resolvers/base.py +15 -0
  152. vellum/workflows/runner/__init__.py +5 -0
  153. vellum/workflows/runner/runner.py +588 -0
  154. vellum/workflows/runner/types.py +18 -0
  155. vellum/workflows/state/__init__.py +5 -0
  156. vellum/workflows/state/base.py +327 -0
  157. vellum/workflows/state/context.py +18 -0
  158. vellum/workflows/state/encoder.py +57 -0
  159. vellum/workflows/state/store.py +28 -0
  160. vellum/workflows/state/tests/__init__.py +0 -0
  161. vellum/workflows/state/tests/test_state.py +113 -0
  162. vellum/workflows/types/__init__.py +0 -0
  163. vellum/workflows/types/core.py +91 -0
  164. vellum/workflows/types/generics.py +14 -0
  165. vellum/workflows/types/stack.py +39 -0
  166. vellum/workflows/types/tests/__init__.py +0 -0
  167. vellum/workflows/types/tests/test_utils.py +76 -0
  168. vellum/workflows/types/utils.py +164 -0
  169. vellum/workflows/utils/__init__.py +0 -0
  170. vellum/workflows/utils/names.py +13 -0
  171. vellum/workflows/utils/tests/__init__.py +0 -0
  172. vellum/workflows/utils/tests/test_names.py +15 -0
  173. vellum/workflows/utils/tests/test_vellum_variables.py +25 -0
  174. vellum/workflows/utils/vellum_variables.py +81 -0
  175. vellum/workflows/vellum_client.py +18 -0
  176. vellum/workflows/workflows/__init__.py +5 -0
  177. vellum/workflows/workflows/base.py +365 -0
  178. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.10.0.dist-info}/METADATA +2 -1
  179. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.10.0.dist-info}/RECORD +245 -7
  180. vellum_cli/__init__.py +72 -0
  181. vellum_cli/aliased_group.py +103 -0
  182. vellum_cli/config.py +96 -0
  183. vellum_cli/image_push.py +112 -0
  184. vellum_cli/logger.py +36 -0
  185. vellum_cli/pull.py +73 -0
  186. vellum_cli/push.py +121 -0
  187. vellum_cli/tests/test_config.py +100 -0
  188. vellum_cli/tests/test_pull.py +152 -0
  189. vellum_ee/workflows/__init__.py +0 -0
  190. vellum_ee/workflows/display/__init__.py +0 -0
  191. vellum_ee/workflows/display/base.py +73 -0
  192. vellum_ee/workflows/display/nodes/__init__.py +4 -0
  193. vellum_ee/workflows/display/nodes/base_node_display.py +116 -0
  194. vellum_ee/workflows/display/nodes/base_node_vellum_display.py +36 -0
  195. vellum_ee/workflows/display/nodes/get_node_display_class.py +25 -0
  196. vellum_ee/workflows/display/nodes/tests/__init__.py +0 -0
  197. vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +47 -0
  198. vellum_ee/workflows/display/nodes/types.py +18 -0
  199. vellum_ee/workflows/display/nodes/utils.py +33 -0
  200. vellum_ee/workflows/display/nodes/vellum/__init__.py +32 -0
  201. vellum_ee/workflows/display/nodes/vellum/api_node.py +205 -0
  202. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +71 -0
  203. vellum_ee/workflows/display/nodes/vellum/conditional_node.py +217 -0
  204. vellum_ee/workflows/display/nodes/vellum/final_output_node.py +61 -0
  205. vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +49 -0
  206. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +170 -0
  207. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +99 -0
  208. vellum_ee/workflows/display/nodes/vellum/map_node.py +100 -0
  209. vellum_ee/workflows/display/nodes/vellum/merge_node.py +48 -0
  210. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +68 -0
  211. vellum_ee/workflows/display/nodes/vellum/search_node.py +193 -0
  212. vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +58 -0
  213. vellum_ee/workflows/display/nodes/vellum/templating_node.py +67 -0
  214. vellum_ee/workflows/display/nodes/vellum/tests/__init__.py +0 -0
  215. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +106 -0
  216. vellum_ee/workflows/display/nodes/vellum/try_node.py +38 -0
  217. vellum_ee/workflows/display/nodes/vellum/utils.py +76 -0
  218. vellum_ee/workflows/display/tests/__init__.py +0 -0
  219. vellum_ee/workflows/display/tests/workflow_serialization/__init__.py +0 -0
  220. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +426 -0
  221. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +607 -0
  222. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +1175 -0
  223. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +235 -0
  224. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +511 -0
  225. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +372 -0
  226. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +272 -0
  227. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +289 -0
  228. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +354 -0
  229. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +123 -0
  230. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_try_node_serialization.py +84 -0
  231. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +233 -0
  232. vellum_ee/workflows/display/types.py +46 -0
  233. vellum_ee/workflows/display/utils/__init__.py +0 -0
  234. vellum_ee/workflows/display/utils/tests/__init__.py +0 -0
  235. vellum_ee/workflows/display/utils/tests/test_uuids.py +16 -0
  236. vellum_ee/workflows/display/utils/uuids.py +24 -0
  237. vellum_ee/workflows/display/utils/vellum.py +121 -0
  238. vellum_ee/workflows/display/vellum.py +357 -0
  239. vellum_ee/workflows/display/workflows/__init__.py +5 -0
  240. vellum_ee/workflows/display/workflows/base_workflow_display.py +302 -0
  241. vellum_ee/workflows/display/workflows/get_vellum_workflow_display_class.py +32 -0
  242. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +386 -0
  243. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.10.0.dist-info}/LICENSE +0 -0
  244. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.10.0.dist-info}/WHEEL +0 -0
  245. {vellum_ai-0.9.16rc2.dist-info → vellum_ai-0.10.0.dist-info}/entry_points.txt +0 -0
@@ -1,10 +1,646 @@
1
- from typing import Type
1
+ import re
2
+ from typing import Callable, Dict, List, Optional, Set, Type
2
3
 
3
- from mypy.plugin import Plugin
4
+ from mypy.nodes import (
5
+ AssignmentStmt,
6
+ CallExpr,
7
+ Decorator,
8
+ MemberExpr,
9
+ NameExpr,
10
+ OverloadedFuncDef,
11
+ SymbolTableNode,
12
+ TypeAlias as MypyTypeAlias,
13
+ TypeInfo,
14
+ Var,
15
+ )
16
+ from mypy.options import Options
17
+ from mypy.plugin import AttributeContext, ClassDefContext, FunctionSigContext, MethodContext, Plugin
18
+ from mypy.types import AnyType, CallableType, FunctionLike, Instance, Type as MypyType, TypeAliasType, UnionType
19
+
20
# Builds a mypy type from a fully-qualified class name and its type arguments,
# e.g. ("builtins.dict", [str_type, int_type]) -> dict[str, int].
TypeResolver = Callable[[str, list[MypyType]], MypyType]
21
+
22
# Each entry is (owner class path, descriptor class path, attribute-name regex).
# The class-attribute hook walks these entries in order and, for the first one whose
# regex matches the attribute name and whose owner the accessed class subclasses,
# types the attribute as an instance of the descriptor class.
DESCRIPTOR_PATHS: list[tuple[str, str, str]] = [
    # Node outputs: every public attribute.
    (
        "vellum.workflows.outputs.base.BaseOutputs",
        "vellum.workflows.references.output.OutputReference",
        r"^[^_].*$",
    ),
    # External inputs declared on a node: every public attribute.
    (
        "vellum.workflows.nodes.bases.base.BaseNode.ExternalInputs",
        "vellum.workflows.references.external_input.ExternalInputReference",
        r"^[^_].*$",
    ),
    # Node execution metadata: only `count`.
    (
        "vellum.workflows.nodes.bases.base.BaseNode.Execution",
        "vellum.workflows.references.execution_count.ExecutionCountReference",
        r"^count$",
    ),
    # Workflow state: every public attribute.
    (
        "vellum.workflows.state.base.BaseState",
        "vellum.workflows.references.state_value.StateValueReference",
        r"^[^_].*$",
    ),
    # Workflow inputs: every public attribute.
    (
        "vellum.workflows.inputs.base.BaseInputs",
        "vellum.workflows.references.workflow_input.WorkflowInputReference",
        r"^[^_].*$",
    ),
    # Nodes themselves: attributes starting with a lowercase letter.
    (
        "vellum.workflows.nodes.bases.base.BaseNode",
        "vellum.workflows.references.node.NodeReference",
        r"^[a-z].*$",
    ),
]
54
+
55
+
56
+ def _is_subclass(type_info: Optional[TypeInfo], fullname: str) -> bool:
57
+ if not type_info:
58
+ return False
59
+
60
+ if type_info.fullname == fullname:
61
+ return True
62
+
63
+ return any(base.type.fullname == fullname or _is_subclass(base.type, fullname) for base in type_info.bases)
64
+
65
+
66
+ def _get_attribute_mypy_type(type_info: TypeInfo, attribute_name: str) -> Optional[MypyType]:
67
+ type_node = type_info.names.get(attribute_name)
68
+ if type_node:
69
+ return type_node.type
70
+
71
+ bases = type_info.bases
72
+ for base in bases:
73
+ mypy_type = _get_attribute_mypy_type(base.type, attribute_name)
74
+ if mypy_type:
75
+ return mypy_type
76
+
77
+ return None
4
78
 
5
79
 
6
80
  class VellumMypyPlugin(Plugin):
7
- pass
81
+ """
82
+ This plugin is responsible for properly supporting types for all of the magic we
83
+ do with Descriptors and this library in general.
84
+ """
85
+
86
def __init__(self, options: Options) -> None:
    """
    Set up state that must survive from mypy's semantic-analysis pass into type checking.

    - `_calls_with_nested_descriptor_expressions`: every `CallExpr` found in an assignment in
      the body of a class handled by `_redefine_class_attributes_with_descriptors` that passes
      a descriptor reference as one of its own arguments. Calls nested inside arguments are
      recorded individually. Classes such as dataclasses do not accept descriptors on their
      own, so only these calls have their signatures widened later.

    - `_nested_descriptor_expressions`: maps each descriptor `MemberExpr` found in those calls
      to the `TypeResolver` used to build the widened parameter type when the enclosing call's
      signature is checked.
    """
    super().__init__(options)

    self._calls_with_nested_descriptor_expressions: Set[CallExpr] = set()
    self._nested_descriptor_expressions: Dict[MemberExpr, TypeResolver] = {}
106
+
107
def get_class_attribute_hook(self, fullname: str) -> Optional[Callable[[AttributeContext], MypyType]]:
    """
    Return the hook mypy runs when an attribute is read off a class object, e.g. `MyClass.my_attribute`.

    The same hook is returned for every `fullname`. `_class_attribute_hook` decides whether the
    access should be retyped as one of the descriptor references listed in `DESCRIPTOR_PATHS`.

    TODO: support all other descriptors besides Outputs — https://app.shortcut.com/vellum/story/4768
    NOTE(review): DESCRIPTOR_PATHS already lists owners other than BaseOutputs; confirm whether this TODO is still open.
    """
    hook = self._class_attribute_hook
    return hook
117
+
118
def _class_attribute_hook(self, ctx: AttributeContext) -> MypyType:
    """
    Retype a class-level attribute access as a descriptor reference when `DESCRIPTOR_PATHS` says so.

    The hook only applies when `ctx.type` is a class object (a `CallableType` returning an
    `Instance`) and the access is a `MemberExpr`. The first `DESCRIPTOR_PATHS` entry that
    matches wins: its regex must match the attribute name, and the accessed class must subclass
    its owner path. The attribute is then typed as that descriptor class, parameterised with
    the attribute's own type after `_resolve_descriptor_type` unwraps descriptor union members.

    Attributes that are decorated functions or overloads on the class itself are skipped. In
    every other case mypy's default attribute type is returned.
    """
    fallback = ctx.default_attr_type
    class_object = ctx.type
    access = ctx.context
    if not (
        isinstance(class_object, CallableType)
        and isinstance(class_object.ret_type, Instance)
        and isinstance(access, MemberExpr)
    ):
        return fallback

    owner = class_object.ret_type.type
    attribute = access.name
    for base_path, descriptor_path, attribute_regex in DESCRIPTOR_PATHS:
        if re.match(attribute_regex, attribute) is None:
            continue
        if not _is_subclass(owner, base_path):
            continue

        existing = owner.names.get(attribute)
        if existing and isinstance(existing.node, (Decorator, OverloadedFuncDef)):
            continue

        descriptor = self.lookup_fully_qualified(descriptor_path)
        if descriptor and isinstance(descriptor.node, TypeInfo):
            return Instance(descriptor.node, [self._resolve_descriptor_type(fallback)])

    return fallback
144
+
145
def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
    """
    Return the hook mypy runs while defining a class with a base, e.g. `class MyClass(BaseNode): ...`.

    The same hook is returned whatever the base's `fullname` is. `_base_class_hook` inspects the
    new class and dispatches to the handlers for this project's special base classes.
    """
    hook = self._base_class_hook
    return hook
153
+
154
def _base_class_hook(self, ctx: ClassDefContext) -> None:
    """
    Dispatch a class definition to the handlers for this project's special base classes.

    Two independent steps run in order:

    1. Nodes whose output type is a generic parameter (TemplatingNode, CodeExecutionNode,
       FinalOutputNode) are passed to `_dynamic_output_node_class_hook`, together with the
       name of their output attribute. Only the first matching family is used.
    2. The class then gets at most one of the BaseNode, BaseWorkflow or BaseOutputs handlers,
       checked in that order.
    """
    class_info = ctx.cls.info

    dynamic_output_nodes = (
        ("vellum.workflows.nodes.core.templating_node.node.TemplatingNode", "result"),
        ("vellum.workflows.nodes.displayable.code_execution_node.node.CodeExecutionNode", "result"),
        ("vellum.workflows.nodes.displayable.final_output_node.node.FinalOutputNode", "value"),
    )
    for node_path, output_attribute in dynamic_output_nodes:
        if _is_subclass(class_info, node_path):
            self._dynamic_output_node_class_hook(ctx, output_attribute)
            break

    if _is_subclass(class_info, "vellum.workflows.nodes.bases.base.BaseNode"):
        self._base_node_class_hook(ctx)
    elif _is_subclass(class_info, "vellum.workflows.workflows.base.BaseWorkflow"):
        self._base_workflow_class_hook(ctx)
    elif _is_subclass(class_info, "vellum.workflows.outputs.base.BaseOutputs"):
        self._base_node_outputs_class_hook(ctx)
170
+
171
def _dynamic_output_node_class_hook(self, ctx: ClassDefContext, attribute_name: str) -> None:
    """
    Type a node's `Outputs.<attribute_name>` from the output generic of its direct base.

    `_base_class_hook` routes subclasses of TemplatingNode and CodeExecutionNode here with
    `"result"`, and subclasses of FinalOutputNode with `"value"`. The attribute is typed with
    the base's second type argument when all of the following hold:

    - the class's first base is one of those node families;
    - that base is parameterised with exactly two type arguments, e.g.
      `class MyNode(TemplatingNode[State, int])`;
    - the class does not declare its own `Outputs`.

    In that case the base node's `Outputs` symbol is copied onto the class before the
    attribute is typed.

    An unspecified (`Any`) output type falls back to `str` for TemplatingNode subclasses. For
    the other node families the class is left untouched in that case.

    Args:
        ctx: the class-definition context mypy passes to base-class hooks.
        attribute_name: name of the `Outputs` attribute that carries the node's output type.
    """
    templating_node_path = "vellum.workflows.nodes.core.templating_node.node.TemplatingNode"
    # Accept every node family `_base_class_hook` dispatches here; otherwise those
    # dispatches would silently do nothing.
    dynamic_output_node_paths = (
        templating_node_path,
        "vellum.workflows.nodes.displayable.code_execution_node.node.CodeExecutionNode",
        "vellum.workflows.nodes.displayable.final_output_node.node.FinalOutputNode",
    )

    node_info = ctx.cls.info
    node_bases = node_info.bases
    if not node_bases or not isinstance(node_bases[0], Instance):
        return

    base_node = node_bases[0].type
    base_node_args = node_bases[0].args
    if not any(_is_subclass(base_node, path) for path in dynamic_output_node_paths):
        return

    # Only the `Node[StateType, OutputType]` shape is handled; anything else keeps mypy's default.
    if len(base_node_args) != 2:
        return

    resolved_output_type = base_node_args[1]
    if isinstance(resolved_output_type, AnyType):
        if not _is_subclass(base_node, templating_node_path):
            return
        resolved_output_type = ctx.api.named_type("builtins.str")

    base_outputs = base_node.names.get("Outputs")
    if not base_outputs:
        return

    # A class that declares its own Outputs keeps it as written.
    if node_info.names.get("Outputs"):
        return

    # NOTE(review): `SymbolTableNode.copy()` appears to be shallow: the copy points at the same
    # Outputs TypeInfo as the base node. If so, the assignment below also changes the attribute
    # type seen by the base node and by every sibling subclass. Confirm this; a per-class
    # Outputs TypeInfo may be needed.
    node_info.names["Outputs"] = base_outputs.copy()
    outputs_info = node_info.names["Outputs"].node
    if not isinstance(outputs_info, TypeInfo):
        return

    # Use `.get` rather than indexing so a missing attribute cannot crash mypy with a KeyError.
    output_symbol = outputs_info.names.get(attribute_name)
    if output_symbol is not None and isinstance(output_symbol.node, Var):
        output_symbol.node.type = resolved_output_type
208
+
209
+ def _base_node_class_hook(self, ctx: ClassDefContext) -> None:
210
+ """
211
+ Special handling of BaseNode class definitions
212
+ """
213
+ self._redefine_class_attributes_with_descriptors(ctx)
214
+
215
+ def _base_node_outputs_class_hook(self, ctx: ClassDefContext) -> None:
216
+ """
217
+ Special handling of BaseNode.Outputs class definitions
218
+ """
219
+ self._redefine_class_attributes_with_descriptors(ctx)
220
+
221
def _redefine_class_attributes_with_descriptors(self, ctx: ClassDefContext) -> None:
    """
    Let a class's attributes hold either their declared type or a descriptor of it.

    1. Every typed `Var` in the class's symbol table is widened by `_get_resolvable_type`.
    2. Every assignment in the class body whose right-hand side is a call is scanned by
       `_collect_descriptor_expressions`. This lets calls such as `Foo(bar=Other.Outputs.baz)`
       have their signatures widened later, during type checking.
    """

    def resolve_named_type(fullname: str, types: List[MypyType]) -> MypyType:
        return ctx.api.named_type(fullname, types)

    for symbol in ctx.cls.info.names.values():
        variable = symbol.node
        if isinstance(variable, Var) and variable.type:
            variable.type = self._get_resolvable_type(resolve_named_type, variable.type)

    # Supports descriptors passed into calls that are assigned as class attributes.
    for statement in ctx.cls.defs.body:
        if isinstance(statement, AssignmentStmt) and isinstance(statement.rvalue, CallExpr):
            self._collect_descriptor_expressions(resolve_named_type, statement.rvalue)
250
+
251
def _collect_descriptor_expressions(self, type_resolver: TypeResolver, call_expr: CallExpr) -> None:
    """
    Record the descriptor references passed as arguments to `call_expr`.

    An argument counts as a descriptor reference when it is a `MemberExpr` and either:

    (a) its name matches a `DESCRIPTOR_PATHS` regex and its owner expression names a class
        that subclasses that entry's owner path; or
    (b) it is a public attribute of `<BaseNode subclass>.Outputs`. This case exists for
        outputs inherited from a parent node, which case (a) presumably does not resolve.

    For each match, the call is added to `_calls_with_nested_descriptor_expressions` and the
    argument is mapped to `type_resolver`. Nested calls are scanned recursively and recorded
    as calls in their own right.
    """
    for argument in call_expr.args:
        if isinstance(argument, CallExpr):
            self._collect_descriptor_expressions(type_resolver, argument)
            continue
        if not isinstance(argument, MemberExpr):
            continue

        owner = argument.expr
        is_descriptor = any(
            re.match(attribute_regex, argument.name)
            and isinstance(owner, (NameExpr, MemberExpr))
            and isinstance(owner.node, TypeInfo)
            and _is_subclass(owner.node, base_path)
            for base_path, _, attribute_regex in DESCRIPTOR_PATHS
        )

        if not is_descriptor:
            # Need to check node outputs that are inherited: `SomeNode.Outputs.some_output`.
            owner_class = owner.expr if isinstance(owner, MemberExpr) else None
            is_descriptor = bool(
                re.match(r"^[^_].*$", argument.name)
                and isinstance(owner, MemberExpr)
                and isinstance(owner_class, NameExpr)
                and isinstance(owner_class.node, TypeInfo)
                and _is_subclass(owner_class.node, "vellum.workflows.nodes.bases.base.BaseNode")
                and owner.name == "Outputs"
            )

        if is_descriptor:
            self._calls_with_nested_descriptor_expressions.add(call_expr)
            self._nested_descriptor_expressions[argument] = type_resolver
291
+
292
+ def _base_workflow_class_hook(self, ctx: ClassDefContext) -> None:
293
+ """
294
+ Placeholder for any special type logic we want to add to the Workflow class.
295
+ """
296
+
297
+ pass
298
+
299
def get_function_signature_hook(self, fullname: str) -> Optional[Callable[[FunctionSigContext], FunctionLike]]:
    """
    Return the hook mypy runs when type checking the signature of a call, e.g. `f(a, b)`.

    Class instantiation counts as a call, so this is how descriptors nested inside instances
    of classes this project does not control (such as dataclasses) are supported inside a node.
    The same hook is returned for every `fullname`. `_function_signature_hook` only alters
    calls recorded during semantic analysis.
    """
    hook = self._function_signature_hook
    return hook
309
+
310
def _function_signature_hook(self, ctx: FunctionSigContext) -> FunctionLike:
    """
    Widen the parameters of a recorded call that receive descriptor arguments.

    Only calls recorded in `_calls_with_nested_descriptor_expressions` during semantic analysis
    are touched. For each formal parameter, the hook inspects the actual argument expressions
    mypy has already mapped to it (`ctx.args`). If any of them is a recorded descriptor
    `MemberExpr`, the parameter type `T` is widened via `_get_resolvable_type`
    (e.g. to `T | BaseDescriptor[T]`).

    Relying on mypy's actual-to-formal mapping means keywords bind to the right parameter
    whatever their order: for `Foo(b=Node.Outputs.x, a="hi")` it is `b` that is widened.
    The same holds for positional-or-keyword parameters such as dataclass fields.

    Returns the default signature unchanged when nothing needs widening.
    """
    default_signature = ctx.default_signature
    if not isinstance(ctx.context, CallExpr):
        return default_signature
    if ctx.context not in self._calls_with_nested_descriptor_expressions:
        return default_signature

    new_arg_types: List[MypyType] = []
    should_copy_new_signature = False
    for formal_index, formal_type in enumerate(default_signature.arg_types):
        # ctx.args[i] holds the actual expressions mypy bound to formal parameter i,
        # honouring keywords, defaults and *args/**kwargs.
        actuals = ctx.args[formal_index] if formal_index < len(ctx.args) else []
        type_resolver = next(
            (
                self._nested_descriptor_expressions[actual]
                for actual in actuals
                if isinstance(actual, MemberExpr) and actual in self._nested_descriptor_expressions
            ),
            None,
        )
        if type_resolver is None:
            new_arg_types.append(formal_type)
            continue

        should_copy_new_signature = True
        new_arg_types.append(self._get_resolvable_type(type_resolver, formal_type))

    if not should_copy_new_signature:
        return default_signature

    return default_signature.copy_modified(arg_types=new_arg_types)
361
+
362
def _get_resolvable_type(
    self, get_named_type: Callable[[str, List[MypyType]], MypyType], type_: MypyType
) -> MypyType:
    """
    Widen `type_` so that a descriptor resolving to it is accepted as well.

    `get_named_type(fullname, args)` builds an instance of the named class. Rules, in order:

    - The `Json` alias becomes `Json | BaseDescriptor[Json]` without being expanded, which
      avoids infinite recursion on this recursive alias.
    - Any other type alias (e.g. `Foo = str`) is resolved through its target.
    - A union is widened member by member, so a descriptor may point at any single member.
    - A `dict[K, V]` keeps `K` and widens `V`, so a descriptor may point at any mapped value.
    - Anything else `T` becomes `T | BaseDescriptor[T]`.
    """

    def with_descriptor(inner: MypyType) -> MypyType:
        return UnionType(
            [inner, get_named_type("vellum.workflows.descriptors.base.BaseDescriptor", [inner])]
        )

    if isinstance(type_, TypeAliasType) and type_.alias:
        if type_.alias.fullname == "vellum.workflows.types.core.Json":
            return with_descriptor(type_)
        return self._get_resolvable_type(get_named_type, type_.alias.target)

    if isinstance(type_, UnionType):
        return UnionType([self._get_resolvable_type(get_named_type, member) for member in type_.items])

    if isinstance(type_, Instance) and type_.type.fullname == "builtins.dict":
        key_type, value_type = type_.args[0], type_.args[1]
        widened_value = self._get_resolvable_type(get_named_type, value_type)
        return get_named_type(type_.type.fullname, [key_type, widened_value])

    return with_descriptor(type_)
417
+
418
def get_attribute_hook(self, fullname: str) -> Optional[Callable[[AttributeContext], MypyType]]:
    """
    Return the hook mypy runs when an attribute is read off an instance, e.g. `self.foo`.

    The same hook is returned for every `fullname`; `_attribute_hook` decides whether to act.
    """
    hook = self._attribute_hook
    return hook
420
+
421
def _attribute_hook(self, ctx: AttributeContext) -> MypyType:
    """
    Unwrap descriptor types when an attribute is read through a name literally called `self`.

    Attributes widened by `_get_resolvable_type` are typed as `T | BaseDescriptor[T]`. For
    `self.<attr>` accesses, `_resolve_descriptor_type` reduces that back to `T`. Any other
    access keeps mypy's default type.

    TODO: ensure that `self` is a BaseNode — https://app.shortcut.com/vellum/story/5531
    """
    access = ctx.context
    reads_from_self = (
        isinstance(access, MemberExpr)
        and isinstance(access.expr, NameExpr)
        and access.expr.name == "self"
    )
    if not reads_from_self:
        return ctx.default_attr_type
    return self._resolve_descriptor_type(ctx.default_attr_type)
431
+
432
def _resolve_descriptor_type(self, default_type: MypyType) -> MypyType:
    """
    Strip the descriptor side off a union built by `_get_resolvable_type`.

    Each union member that is an `Instance` of a `BaseDescriptor` subclass with at least one
    type argument is replaced by its first type argument, so `str | BaseDescriptor[str]`
    becomes `str`. Duplicates are removed while keeping first-seen order, so the resulting
    union, and the messages mypy prints for it, come out in a stable order.

    Returns:
        The unchanged type for non-unions or empty unions; the single remaining member when
        only one is left; otherwise a new `UnionType`.
    """
    if not isinstance(default_type, UnionType):
        return default_type

    non_descriptor_items = [
        (
            item.args[0]
            if isinstance(item, Instance)
            and _is_subclass(item.type, "vellum.workflows.descriptors.base.BaseDescriptor")
            and len(item.args) > 0
            else item
        )
        for item in default_type.items
    ]

    # `dict.fromkeys` de-duplicates with the same hash/equality as `set()`, but keeps
    # insertion order. Iterating a set follows hash order, which need not match the
    # union's declared order.
    new_items = list(dict.fromkeys(non_descriptor_items))

    if len(new_items) == 0:
        return default_type

    if len(new_items) == 1:
        return new_items[0]

    return UnionType(
        items=new_items,
    )
458
+
459
def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], MypyType]]:
    """
    Pick a return-type hook for a method call, keyed on the method's fully qualified name.

    Calls whose name ends in `.run` go to `_run_method_hook`, and calls ending in `.stream` go to
    `_stream_method_hook`. Both re-type the returned workflow events using the `Outputs` class
    of the workflow subclass being called, which plain class inheritance cannot express.

    The suffix match is broad: any `run`/`stream` method qualifies. The hooks themselves fall back
    to mypy's default return type when the declared type lacks the expected shape. Every other
    method gets no hook (None).
    """
    hooks_by_suffix = (
        (".run", self._run_method_hook),
        (".stream", self._stream_method_hook),
    )
    for suffix, hook in hooks_by_suffix:
        if fullname.endswith(suffix):
            return hook
    return None
def _run_method_hook(self, ctx: MethodContext) -> MypyType:
    """
    Narrow the return type of `Workflow.run()` to the calling workflow's own `Outputs` class.

    `run()` is expected to return a type alias whose target is a union of workflow events. This
    hook finds the `WorkflowExecutionFulfilledEvent` member of that union, wherever it appears,
    and re-parameterizes it with the `Outputs` class found by `_get_outputs_node`. Every other
    union member, and the fulfilled event's position, are left untouched.

    Returns:
        A rebuilt alias type with the narrowed union. Returns mypy's default return type when
        the return type is not such an alias, the union contains no fulfilled event, or no
        suitable `Outputs` class is found.
    """
    default_type = ctx.default_return_type
    if not isinstance(default_type, TypeAliasType):
        return default_type

    alias = default_type.alias
    if not alias:
        return default_type

    alias_target = alias.target
    if not isinstance(alias_target, UnionType) or not alias_target.items:
        return default_type

    # Search the whole union instead of assuming the fulfilled event is its first member, so a
    # reordering of the alias' members does not silently disable the narrowing.
    fulfilled_event_index = -1
    fulfilled_event: Optional[Instance] = None
    for index, item in enumerate(alias_target.items):
        if (
            isinstance(item, Instance)
            and item.type.fullname == "vellum.workflows.events.workflow.WorkflowExecutionFulfilledEvent"
        ):
            fulfilled_event_index = index
            fulfilled_event = item
            break

    if fulfilled_event is None:
        return default_type

    outputs_node = self._get_outputs_node(ctx)
    if not outputs_node:
        return default_type

    new_fulfilled_event = fulfilled_event.copy_modified(args=(Instance(outputs_node, []),))
    new_items = list(alias_target.items)
    new_items[fulfilled_event_index] = new_fulfilled_event
    return TypeAliasType(
        alias=MypyTypeAlias(
            target=UnionType(
                items=new_items,
            ),
            fullname=alias.fullname,
            line=alias.line,
            column=alias.column,
        ),
        args=default_type.args,
        line=default_type.line,
        column=default_type.column,
    )
def _stream_method_hook(self, ctx: MethodContext) -> MypyType:
    """
    Narrow the return type of `Workflow.stream()` to the calling workflow's own `Outputs` class.

    The expected shape is two nested type aliases:
      * the outer alias, which targets a `typing.Iterator` subclass instance;
      * the iterator's first type argument, which is another alias targeting a union of events.

    The hook takes the `WorkflowExecutionFulfilledEvent` member of that union; if several match,
    the last one is used. It re-parameterizes that member with the `Outputs` class found by
    `_get_outputs_node`, then rebuilds both aliases around the new union, keeping all other
    union members and iterator arguments.

    Returns mypy's default return type whenever any part of this shape is missing.
    """
    default_type = ctx.default_return_type
    if not isinstance(default_type, TypeAliasType):
        return default_type

    outer_alias = default_type.alias
    if not outer_alias:
        return default_type

    iterator_type = outer_alias.target
    if not isinstance(iterator_type, Instance):
        return default_type
    if not _is_subclass(iterator_type.type, "typing.Iterator"):
        return default_type
    if not iterator_type.args:
        return default_type

    event_alias_type = iterator_type.args[0]
    if not isinstance(event_alias_type, TypeAliasType) or not event_alias_type.alias:
        return default_type

    event_union = event_alias_type.alias.target
    if not isinstance(event_union, UnionType) or not event_union.items:
        return default_type

    # Scan from the end and stop at the first hit, so the last matching member wins.
    target_index = -1
    target_event = None
    for position, candidate in reversed(list(enumerate(event_union.items))):
        if (
            isinstance(candidate, Instance)
            and candidate.type.fullname == "vellum.workflows.events.workflow.WorkflowExecutionFulfilledEvent"
        ):
            target_index = position
            target_event = candidate
            break

    if target_index < 0 or not target_event:
        return default_type

    outputs_node = self._get_outputs_node(ctx)
    if not outputs_node:
        return default_type

    narrowed_event = target_event.copy_modified(args=(Instance(outputs_node, []),))
    narrowed_union = UnionType(
        items=[
            narrowed_event if position == target_index else member
            for position, member in enumerate(event_union.items)
        ],
    )
    narrowed_event_alias = TypeAliasType(
        alias=MypyTypeAlias(
            target=narrowed_union,
            fullname=event_alias_type.alias.fullname,
            line=event_alias_type.alias.line,
            column=event_alias_type.alias.column,
        ),
        args=event_alias_type.args,
        line=event_alias_type.line,
        column=event_alias_type.column,
    )
    narrowed_iterator = iterator_type.copy_modified(
        args=[narrowed_event_alias, *iterator_type.args[1:]],
    )
    return TypeAliasType(
        alias=MypyTypeAlias(
            target=narrowed_iterator,
            fullname=outer_alias.fullname,
            line=outer_alias.line,
            column=outer_alias.column,
        ),
        args=default_type.args,
        line=default_type.line,
        column=default_type.column,
    )
def _get_outputs_node(self, ctx: MethodContext) -> Optional[TypeInfo]:
    """
    Find the `Outputs` class of the workflow whose method is being called.

    The call must have the form `<receiver>.<method>(...)`, and the receiver's type must be an
    instance of a `BaseWorkflow` subclass. `Outputs` is looked up in that class's own symbol
    table (`TypeInfo.names`) and must be a `BaseOutputs` subclass.

    Returns:
        The TypeInfo returned by `_resolve_descriptors_in_outputs` for that `Outputs` class,
        or None if any of these conditions fails.
    """
    call = ctx.context
    if not isinstance(call, CallExpr) or not isinstance(call.callee, MemberExpr):
        return None

    receiver_type = ctx.api.get_expression_type(call.callee.expr)
    if not isinstance(receiver_type, Instance):
        return None
    if not _is_subclass(receiver_type.type, "vellum.workflows.workflows.base.BaseWorkflow"):
        return None

    outputs_symbol = receiver_type.type.names.get("Outputs")
    if not outputs_symbol:
        return None

    outputs_info = outputs_symbol.node
    if not isinstance(outputs_info, TypeInfo):
        return None
    if not _is_subclass(outputs_info, "vellum.workflows.outputs.base.BaseOutputs"):
        return None

    resolved_info = self._resolve_descriptors_in_outputs(outputs_symbol).node
    return resolved_info if isinstance(resolved_info, TypeInfo) else None
def _resolve_descriptors_in_outputs(self, type_info: SymbolTableNode) -> SymbolTableNode:
    """
    Rewrite descriptor-typed variables of an `Outputs` class to the types they point at.

    Starts from `type_info.copy()`. For every `Var` in the referenced class's symbol table whose
    declared type is an `Instance` of a `BaseDescriptor` subclass with at least one type argument,
    the `Var`'s type is replaced by that first argument. If the node is not a TypeInfo, the copy
    is returned untouched.

    NOTE(review): `SymbolTableNode.copy()` appears to be a shallow copy — the copy still references
    the same `TypeInfo`, so the `Var` types are rewritten on the user's actual `Outputs` class
    rather than on a private copy. Those edits persist for the rest of the mypy run. Confirm
    whether that mutation is intended.
    """
    copied = type_info.copy()
    outputs_info = copied.node
    if not isinstance(outputs_info, TypeInfo):
        return copied

    for symbol in outputs_info.names.values():
        variable = symbol.node
        if not isinstance(variable, Var):
            continue

        declared_type = variable.type
        if not isinstance(declared_type, Instance):
            continue
        if not _is_subclass(declared_type.type, "vellum.workflows.descriptors.base.BaseDescriptor"):
            continue

        if declared_type.args:
            variable.type = declared_type.args[0]

    return copied
8
644
 
9
645
 
10
646
  def plugin(version: str) -> Type[VellumMypyPlugin]: