snowflake-cli-labs 2.7.0rc3__py3-none-any.whl → 2.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34):
  1. snowflake/cli/__about__.py +1 -1
  2. snowflake/cli/api/feature_flags.py +1 -2
  3. snowflake/cli/api/project/definition.py +3 -36
  4. snowflake/cli/api/project/errors.py +16 -1
  5. snowflake/cli/api/project/schemas/entities/application_entity.py +5 -11
  6. snowflake/cli/api/project/schemas/entities/application_package_entity.py +5 -2
  7. snowflake/cli/api/project/schemas/entities/common.py +15 -22
  8. snowflake/cli/api/project/schemas/native_app/application.py +10 -2
  9. snowflake/cli/api/project/schemas/native_app/native_app.py +13 -2
  10. snowflake/cli/api/project/schemas/native_app/package.py +24 -1
  11. snowflake/cli/api/project/schemas/project_definition.py +23 -40
  12. snowflake/cli/api/project/schemas/snowpark/callable.py +1 -3
  13. snowflake/cli/api/project/schemas/updatable_model.py +148 -5
  14. snowflake/cli/api/project/util.py +55 -7
  15. snowflake/cli/api/rendering/jinja.py +1 -0
  16. snowflake/cli/api/rendering/project_templates.py +8 -7
  17. snowflake/cli/api/rendering/sql_templates.py +8 -4
  18. snowflake/cli/api/utils/definition_rendering.py +50 -11
  19. snowflake/cli/api/utils/models.py +10 -7
  20. snowflake/cli/api/utils/templating_functions.py +144 -0
  21. snowflake/cli/app/build_and_push.sh +8 -0
  22. snowflake/cli/app/snow_connector.py +14 -10
  23. snowflake/cli/plugins/init/commands.py +13 -7
  24. snowflake/cli/plugins/nativeapp/manager.py +93 -10
  25. snowflake/cli/plugins/nativeapp/project_model.py +13 -3
  26. snowflake/cli/plugins/nativeapp/run_processor.py +22 -51
  27. snowflake/cli/plugins/nativeapp/v2_conversions/v2_to_v1_decorator.py +7 -18
  28. snowflake/cli/plugins/nativeapp/version/version_processor.py +4 -0
  29. snowflake/cli/plugins/snowpark/commands.py +6 -3
  30. {snowflake_cli_labs-2.7.0rc3.dist-info → snowflake_cli_labs-2.8.0.dist-info}/METADATA +1 -1
  31. {snowflake_cli_labs-2.7.0rc3.dist-info → snowflake_cli_labs-2.8.0.dist-info}/RECORD +34 -32
  32. {snowflake_cli_labs-2.7.0rc3.dist-info → snowflake_cli_labs-2.8.0.dist-info}/WHEEL +0 -0
  33. {snowflake_cli_labs-2.7.0rc3.dist-info → snowflake_cli_labs-2.8.0.dist-info}/entry_points.txt +0 -0
  34. {snowflake_cli_labs-2.7.0rc3.dist-info → snowflake_cli_labs-2.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -14,4 +14,4 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
- VERSION = "2.7.0rc3"
17
+ VERSION = "2.8.0"
@@ -52,5 +52,4 @@ class FeatureFlag(FeatureFlagMixin):
52
52
  ENABLE_STREAMLIT_VERSIONED_STAGE = BooleanFlag(
53
53
  "ENABLE_STREAMLIT_VERSIONED_STAGE", False
54
54
  )
55
- # TODO: remove in 3.0
56
- ENABLE_PROJECT_DEFINITION_V2 = BooleanFlag("ENABLE_PROJECT_DEFINITION_V2", True)
55
+ ENABLE_PROJECT_DEFINITION_V2 = BooleanFlag("ENABLE_PROJECT_DEFINITION_V2", False)
@@ -21,13 +21,12 @@ import yaml
21
21
  from snowflake.cli.api.cli_global_context import cli_context
22
22
  from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB
23
23
  from snowflake.cli.api.project.schemas.project_definition import (
24
- ProjectDefinition,
25
24
  ProjectProperties,
26
25
  )
27
26
  from snowflake.cli.api.project.util import (
28
27
  append_to_identifier,
29
- clean_identifier,
30
28
  get_env_username,
29
+ sanitize_identifier,
31
30
  to_identifier,
32
31
  )
33
32
  from snowflake.cli.api.secure_path import SecurePath
@@ -70,40 +69,8 @@ def load_project(
70
69
  return render_definition_template(merged_definitions, context_overrides or {})
71
70
 
72
71
 
73
- def generate_local_override_yml(
74
- project: ProjectDefinition,
75
- ) -> ProjectDefinition:
76
- """
77
- Generates defaults for optional keys in the same YAML structure as the project
78
- schema. The returned YAML object can be saved directly to a file, if desired.
79
- A connection is made using global context to resolve current role and warehouse.
80
- """
81
- conn = cli_context.connection
82
- user = clean_identifier(get_env_username() or DEFAULT_USERNAME)
83
- role = conn.role
84
- warehouse = conn.warehouse
85
-
86
- local: dict = {}
87
- if project.native_app:
88
- name = clean_identifier(project.native_app.name)
89
- app_identifier = to_identifier(name)
90
- user_app_identifier = append_to_identifier(app_identifier, f"_{user}")
91
- package_identifier = append_to_identifier(app_identifier, f"_pkg_{user}")
92
- local["native_app"] = {
93
- "application": {
94
- "name": user_app_identifier,
95
- "role": role,
96
- "debug": True,
97
- "warehouse": warehouse,
98
- },
99
- "package": {"name": package_identifier, "role": role},
100
- }
101
-
102
- return project.update_from_dict(local)
103
-
104
-
105
72
  def default_app_package(project_name: str):
106
- user = clean_identifier(get_env_username() or DEFAULT_USERNAME)
73
+ user = sanitize_identifier(get_env_username() or DEFAULT_USERNAME).lower()
107
74
  return append_to_identifier(to_identifier(project_name), f"_pkg_{user}")
108
75
 
109
76
 
@@ -113,5 +80,5 @@ def default_role():
113
80
 
114
81
 
115
82
  def default_application(project_name: str):
116
- user = clean_identifier(get_env_username() or DEFAULT_USERNAME)
83
+ user = sanitize_identifier(get_env_username() or DEFAULT_USERNAME).lower()
117
84
  return append_to_identifier(to_identifier(project_name), f"_{user}")
@@ -29,10 +29,25 @@ class SchemaValidationError(ClickException):
29
29
  def __init__(self, error: ValidationError):
30
30
  errors = error.errors()
31
31
  message = f"During evaluation of {error.title} in project definition following errors were encountered:\n"
32
+
33
+ def calculate_location(e):
34
+ if e["loc"] is None:
35
+ return None
36
+
37
+ # show numbers as list indexes and strings as dictionary keys. Example: key1[0].key2
38
+ result = "".join(
39
+ f"[{item}]" if isinstance(item, int) else f".{item}"
40
+ for item in e["loc"]
41
+ )
42
+
43
+ # remove leading dot from the string if any:
44
+ return result[1:] if result.startswith(".") else result
45
+
32
46
  message += "\n".join(
33
47
  [
34
48
  self.message_templates.get(e["type"], self.generic_message).format(
35
- **e, location=".".join(e["loc"]) if e["loc"] is not None else None
49
+ **e,
50
+ location=calculate_location(e),
36
51
  )
37
52
  for e in errors
38
53
  ]
@@ -16,7 +16,7 @@ from __future__ import annotations
16
16
 
17
17
  from typing import Literal, Optional
18
18
 
19
- from pydantic import AliasChoices, Field
19
+ from pydantic import Field
20
20
  from snowflake.cli.api.project.schemas.entities.application_package_entity import (
21
21
  ApplicationPackageEntity,
22
22
  )
@@ -25,26 +25,20 @@ from snowflake.cli.api.project.schemas.entities.common import (
25
25
  TargetField,
26
26
  )
27
27
  from snowflake.cli.api.project.schemas.updatable_model import (
28
- UpdatableModel,
28
+ DiscriminatorField,
29
29
  )
30
30
 
31
31
 
32
32
  class ApplicationEntity(EntityBase):
33
- type: Literal["application"] # noqa: A003
33
+ type: Literal["application"] = DiscriminatorField() # noqa A003
34
34
  name: str = Field(
35
35
  title="Name of the application created when this entity is deployed"
36
36
  )
37
- from_: ApplicationFromField = Field(
38
- validation_alias=AliasChoices("from"),
37
+ from_: TargetField[ApplicationPackageEntity] = Field(
38
+ alias="from",
39
39
  title="An application package this entity should be created from",
40
40
  )
41
41
  debug: Optional[bool] = Field(
42
42
  title="Whether to enable debug mode when using a named stage to create an application object",
43
43
  default=None,
44
44
  )
45
-
46
-
47
- class ApplicationFromField(UpdatableModel):
48
- target: TargetField[ApplicationPackageEntity] = Field(
49
- title="Reference to an application package entity",
50
- )
@@ -23,11 +23,14 @@ from snowflake.cli.api.project.schemas.entities.common import (
23
23
  )
24
24
  from snowflake.cli.api.project.schemas.native_app.package import DistributionOptions
25
25
  from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping
26
- from snowflake.cli.api.project.schemas.updatable_model import IdentifierField
26
+ from snowflake.cli.api.project.schemas.updatable_model import (
27
+ DiscriminatorField,
28
+ IdentifierField,
29
+ )
27
30
 
28
31
 
29
32
  class ApplicationPackageEntity(EntityBase):
30
- type: Literal["application package"] # noqa: A003
33
+ type: Literal["application package"] = DiscriminatorField() # noqa: A003
31
34
  name: str = Field(
32
35
  title="Name of the application package created when this entity is deployed"
33
36
  )
@@ -17,10 +17,9 @@ from __future__ import annotations
17
17
  from abc import ABC
18
18
  from typing import Generic, List, Optional, TypeVar
19
19
 
20
- from pydantic import AliasChoices, Field, GetCoreSchemaHandler, ValidationInfo
21
- from pydantic_core import core_schema
20
+ from pydantic import Field
22
21
  from snowflake.cli.api.project.schemas.native_app.application import (
23
- ApplicationPostDeployHook,
22
+ PostDeployHook,
24
23
  )
25
24
  from snowflake.cli.api.project.schemas.updatable_model import (
26
25
  IdentifierField,
@@ -36,7 +35,7 @@ class MetaField(UpdatableModel):
36
35
  title="Role to use when creating the entity object",
37
36
  default=None,
38
37
  )
39
- post_deploy: Optional[List[ApplicationPostDeployHook]] = Field(
38
+ post_deploy: Optional[List[PostDeployHook]] = Field(
40
39
  title="Actions that will be executed after the application object is created/upgraded",
41
40
  default=None,
42
41
  )
@@ -45,7 +44,7 @@ class MetaField(UpdatableModel):
45
44
  class DefaultsField(UpdatableModel):
46
45
  schema_: Optional[str] = Field(
47
46
  title="Schema.",
48
- validation_alias=AliasChoices("schema"),
47
+ alias="schema",
49
48
  default=None,
50
49
  )
51
50
  stage: Optional[str] = Field(
@@ -65,21 +64,15 @@ class EntityBase(ABC, UpdatableModel):
65
64
  TargetType = TypeVar("TargetType")
66
65
 
67
66
 
68
- class TargetField(Generic[TargetType]):
69
- def __init__(self, entity_target_key: str):
70
- self.value = entity_target_key
71
-
72
- def __repr__(self):
73
- return self.value
74
-
75
- @classmethod
76
- def validate(cls, value: str, info: ValidationInfo) -> TargetField:
77
- return cls(value)
67
+ class TargetField(UpdatableModel, Generic[TargetType]):
68
+ target: str = Field(
69
+ title="Reference to a target entity",
70
+ )
78
71
 
79
- @classmethod
80
- def __get_pydantic_core_schema__(
81
- cls, source_type, handler: GetCoreSchemaHandler
82
- ) -> core_schema.CoreSchema:
83
- return core_schema.with_info_after_validator_function(
84
- cls.validate, handler(str), field_name=handler.field_name
85
- )
72
+ def get_type(self) -> type:
73
+ """
74
+ Returns the generic type of this class, indicating the entity type.
75
+ Pydantic extracts Generic annotations, and populates
76
+ them in __pydantic_generic_metadata__
77
+ """
78
+ return self.__pydantic_generic_metadata__["args"][0]
@@ -28,7 +28,7 @@ class SqlScriptHookType(UpdatableModel):
28
28
 
29
29
 
30
30
  # Currently sql_script is the only supported hook type. Change to a Union once other hook types are added
31
- ApplicationPostDeployHook = SqlScriptHookType
31
+ PostDeployHook = SqlScriptHookType
32
32
 
33
33
 
34
34
  class Application(UpdatableModel):
@@ -48,7 +48,15 @@ class Application(UpdatableModel):
48
48
  title="When set, forces debug_mode on/off for the deployed application object",
49
49
  default=None,
50
50
  )
51
- post_deploy: Optional[List[ApplicationPostDeployHook]] = Field(
51
+ post_deploy: Optional[List[PostDeployHook]] = Field(
52
52
  title="Actions that will be executed after the application object is created/upgraded",
53
53
  default=None,
54
54
  )
55
+
56
+
57
+ class ApplicationV11(Application):
58
+ # Templated defaults only supported in v1.1+
59
+ name: Optional[str] = Field(
60
+ title="Name of the application object created when you run the snow app run command",
61
+ default="<% fn.concat_ids(ctx.native_app.name, '_', fn.sanitize_id(fn.get_username('unknown_user')) | lower) %>",
62
+ )
@@ -18,8 +18,11 @@ import re
18
18
  from typing import List, Optional, Union
19
19
 
20
20
  from pydantic import Field, field_validator
21
- from snowflake.cli.api.project.schemas.native_app.application import Application
22
- from snowflake.cli.api.project.schemas.native_app.package import Package
21
+ from snowflake.cli.api.project.schemas.native_app.application import (
22
+ Application,
23
+ ApplicationV11,
24
+ )
25
+ from snowflake.cli.api.project.schemas.native_app.package import Package, PackageV11
23
26
  from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping
24
27
  from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel
25
28
  from snowflake.cli.api.project.util import (
@@ -80,3 +83,11 @@ class NativeApp(UpdatableModel):
80
83
  transformed_artifacts.append(PathMapping(src=artifact))
81
84
 
82
85
  return transformed_artifacts
86
+
87
+
88
+ class NativeAppV11(NativeApp):
89
+ # templated defaults are only supported with version 1.1+
90
+ package: Optional[PackageV11] = Field(title="PackageSchema", default=PackageV11())
91
+ application: Optional[ApplicationV11] = Field(
92
+ title="Application info", default=ApplicationV11()
93
+ )
@@ -16,7 +16,8 @@ from __future__ import annotations
16
16
 
17
17
  from typing import List, Literal, Optional
18
18
 
19
- from pydantic import Field, field_validator
19
+ from pydantic import Field, field_validator, model_validator
20
+ from snowflake.cli.api.project.schemas.native_app.application import PostDeployHook
20
21
  from snowflake.cli.api.project.schemas.updatable_model import (
21
22
  IdentifierField,
22
23
  UpdatableModel,
@@ -44,6 +45,10 @@ class Package(UpdatableModel):
44
45
  title="Distribution of the application package created by the Snowflake CLI",
45
46
  default="internal",
46
47
  )
48
+ post_deploy: Optional[List[PostDeployHook]] = Field(
49
+ title="Actions that will be executed after the application package object is created/updated",
50
+ default=None,
51
+ )
47
52
 
48
53
  @field_validator("scripts")
49
54
  @classmethod
@@ -53,3 +58,21 @@ class Package(UpdatableModel):
53
58
  "package.scripts field should contain unique values. Check the list for duplicates and try again"
54
59
  )
55
60
  return input_list
61
+
62
+ @model_validator(mode="after")
63
+ @classmethod
64
+ def validate_no_scripts_and_post_deploy(cls, value: Package):
65
+ if value.scripts and value.post_deploy:
66
+ raise ValueError(
67
+ "package.scripts and package.post_deploy fields cannot be used together. "
68
+ "We recommend using package.post_deploy for all post package deploy scripts"
69
+ )
70
+ return value
71
+
72
+
73
+ class PackageV11(Package):
74
+ # Templated defaults only supported in v1.1+
75
+ name: Optional[str] = IdentifierField(
76
+ title="Name of the application package created when you run the snow app run command",
77
+ default="<% fn.concat_ids(ctx.native_app.name, '_pkg_', fn.sanitize_id(fn.get_username('unknown_user')) | lower) %>",
78
+ )
@@ -32,11 +32,13 @@ from snowflake.cli.api.project.schemas.entities.entities import (
32
32
  Entity,
33
33
  v2_entity_types_map,
34
34
  )
35
- from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp
35
+ from snowflake.cli.api.project.schemas.native_app.native_app import (
36
+ NativeApp,
37
+ NativeAppV11,
38
+ )
36
39
  from snowflake.cli.api.project.schemas.snowpark.snowpark import Snowpark
37
40
  from snowflake.cli.api.project.schemas.streamlit.streamlit import Streamlit
38
41
  from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel
39
- from snowflake.cli.api.utils.models import ProjectEnvironment
40
42
  from snowflake.cli.api.utils.types import Context
41
43
  from typing_extensions import Annotated
42
44
 
@@ -99,22 +101,14 @@ class DefinitionV10(_ProjectDefinitionBase):
99
101
 
100
102
 
101
103
  class DefinitionV11(DefinitionV10):
102
- env: Union[Dict[str, str], ProjectEnvironment, None] = Field(
103
- title="Environment specification for this project.",
104
+ native_app: Optional[NativeAppV11] = Field(
105
+ title="Native app definitions for the project", default=None
106
+ )
107
+ env: Optional[Dict[str, Union[str, int, bool]]] = Field(
108
+ title="Default environment specification for this project.",
104
109
  default=None,
105
- validation_alias="env",
106
- union_mode="smart",
107
110
  )
108
111
 
109
- @field_validator("env")
110
- @classmethod
111
- def _convert_env(
112
- cls, env: Union[Dict, ProjectEnvironment, None]
113
- ) -> ProjectEnvironment:
114
- if isinstance(env, ProjectEnvironment):
115
- return env
116
- return ProjectEnvironment(default_env=(env or {}), override_env={})
117
-
118
112
 
119
113
  class DefinitionV20(_ProjectDefinitionBase):
120
114
  entities: Dict[str, Annotated[Entity, Field(discriminator="type")]] = Field(
@@ -147,10 +141,10 @@ class DefinitionV20(_ProjectDefinitionBase):
147
141
  for key, entity in entities.items():
148
142
  # TODO Automatically detect TargetFields to validate
149
143
  if entity.type == ApplicationEntity.get_type():
150
- if isinstance(entity.from_.target, TargetField):
151
- target_key = str(entity.from_.target)
152
- target_class = entity.from_.__class__.model_fields["target"]
153
- target_type = target_class.annotation.__args__[0]
144
+ if isinstance(entity.from_, TargetField):
145
+ target_key = entity.from_.target
146
+ target_object = entity.from_
147
+ target_type = target_object.get_type()
154
148
  cls._validate_target_field(target_key, target_type, entities)
155
149
  return entities
156
150
 
@@ -160,37 +154,26 @@ class DefinitionV20(_ProjectDefinitionBase):
160
154
  ):
161
155
  if target_key not in entities:
162
156
  raise ValueError(f"No such target: {target_key}")
163
- else:
164
- # Validate the target type
165
- actual_target_type = entities[target_key].__class__
166
- if target_type and target_type is not actual_target_type:
167
- raise ValueError(
168
- f"Target type mismatch. Expected {target_type.__name__}, got {actual_target_type.__name__}"
169
- )
157
+
158
+ # Validate the target type
159
+ actual_target_type = entities[target_key].__class__
160
+ if target_type and target_type is not actual_target_type:
161
+ raise ValueError(
162
+ f"Target type mismatch. Expected {target_type.__name__}, got {actual_target_type.__name__}"
163
+ )
170
164
 
171
165
  defaults: Optional[DefaultsField] = Field(
172
166
  title="Default key/value entity values that are merged recursively for each entity.",
173
167
  default=None,
174
168
  )
175
169
 
176
- env: Union[Dict[str, str], ProjectEnvironment, None] = Field(
177
- title="Environment specification for this project.",
170
+ env: Optional[Dict[str, Union[str, int, bool]]] = Field(
171
+ title="Default environment specification for this project.",
178
172
  default=None,
179
- validation_alias="env",
180
- union_mode="smart",
181
173
  )
182
174
 
183
- @field_validator("env")
184
- @classmethod
185
- def _convert_env(
186
- cls, env: Union[Dict, ProjectEnvironment, None]
187
- ) -> ProjectEnvironment:
188
- if isinstance(env, ProjectEnvironment):
189
- return env
190
- return ProjectEnvironment(default_env=(env or {}), override_env={})
191
-
192
175
 
193
- def build_project_definition(**data):
176
+ def build_project_definition(**data) -> ProjectDefinition:
194
177
  """
195
178
  Returns a ProjectDefinition instance with a version matching the provided definition_version value
196
179
  """
@@ -19,9 +19,7 @@ from typing import Dict, List, Optional, Union
19
19
  from pydantic import Field, field_validator
20
20
  from snowflake.cli.api.project.schemas.identifier_model import ObjectIdentifierModel
21
21
  from snowflake.cli.api.project.schemas.snowpark.argument import Argument
22
- from snowflake.cli.api.project.schemas.updatable_model import (
23
- UpdatableModel,
24
- )
22
+ from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel
25
23
 
26
24
 
27
25
  class _CallableBase(UpdatableModel):
@@ -14,17 +14,149 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
- from typing import Any, Dict
17
+ from contextlib import contextmanager
18
+ from contextvars import ContextVar
19
+ from typing import Any, Dict, Iterator, Optional
18
20
 
19
- from pydantic import BaseModel, ConfigDict, Field
21
+ from pydantic import (
22
+ BaseModel,
23
+ ConfigDict,
24
+ Field,
25
+ ValidationInfo,
26
+ field_validator,
27
+ )
28
+ from pydantic.fields import FieldInfo
20
29
  from snowflake.cli.api.project.util import IDENTIFIER_NO_LENGTH
21
30
 
31
+ PROJECT_TEMPLATE_START = "<%"
32
+
33
+
34
+ def _is_templated(info: ValidationInfo, value: Any) -> bool:
35
+ return (
36
+ info.context
37
+ and info.context.get("skip_validation_on_templates", False)
38
+ and isinstance(value, str)
39
+ and PROJECT_TEMPLATE_START in value
40
+ )
41
+
42
+
43
+ _initial_context: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
44
+ "_init_context_var", default=None
45
+ )
46
+
47
+
48
+ @contextmanager
49
+ def context(value: Dict[str, Any]) -> Iterator[None]:
50
+ """
51
+ Thread safe context for Pydantic.
52
+ By using `with context()`, you ensure context changes apply
53
+ to the with block only
54
+ """
55
+ token = _initial_context.set(value)
56
+ try:
57
+ yield
58
+ finally:
59
+ _initial_context.reset(token)
60
+
22
61
 
23
62
  class UpdatableModel(BaseModel):
24
63
  model_config = ConfigDict(validate_assignment=True, extra="forbid")
25
64
 
26
- def __init__(self, *args, **kwargs):
27
- super().__init__(**kwargs)
65
+ def __init__(self, /, **data: Any) -> None:
66
+ """
67
+ Pydantic provides 2 options to pass in context:
68
+ 1) Through `model_validate()` as a second argument.
69
+ 2) Through a custom init method and the use of ContextVar
70
+
71
+ We decided not to use 1) because it silently stops working
72
+ if someone adds a pass through __init__ to any of the Pydantic models.
73
+
74
+ We decided to go with 2) as the safer approach.
75
+ Calling validate_python() in the __init__ is how we can pass context
76
+ on initialization according to Pydantic's documentation:
77
+ https://docs.pydantic.dev/latest/concepts/validators/#using-validation-context-with-basemodel-initialization
78
+ """
79
+ self.__pydantic_validator__.validate_python(
80
+ data,
81
+ self_instance=self,
82
+ context=_initial_context.get(),
83
+ )
84
+
85
+ @classmethod
86
+ def _is_entity_type_field(cls, field: Any) -> bool:
87
+ """
88
+ Checks if a field is of type `DiscriminatorField`
89
+ """
90
+ if not isinstance(field, FieldInfo) or not field.json_schema_extra:
91
+ return False
92
+
93
+ return (
94
+ "is_discriminator_field" in field.json_schema_extra
95
+ and field.json_schema_extra["is_discriminator_field"]
96
+ )
97
+
98
+ @classmethod
99
+ def __init_subclass__(cls, **kwargs):
100
+ """
101
+ This method will collect all the Pydantic annotations for the class
102
+ currently being initialized (any subclass of `UpdatableModel`).
103
+
104
+ It will add a field validator wrapper for every Pydantic field
105
+ in order to skip validation when templates are found.
106
+
107
+ It will apply this to all Pydantic fields, except for fields
108
+ marked as `DiscriminatorField`. These will be skipped because
109
+ Pydantic does not support validators for discriminator field types.
110
+ """
111
+
112
+ super().__init_subclass__(**kwargs)
113
+
114
+ field_annotations = {}
115
+ field_values = {}
116
+ # Go through the inheritance classes and collect all the annotations and
117
+ # all the values of the class attributes. We go in reverse order so that
118
+ # values in subclasses overrides values from parent classes in case of field overrides.
119
+
120
+ for class_ in reversed(cls.__mro__):
121
+ class_dict = class_.__dict__
122
+ field_annotations.update(class_dict.get("__annotations__", {}))
123
+
124
+ if "model_fields" in class_dict:
125
+ # This means the class dict has already been processed by Pydantic
126
+ # All fields should properly be populated in model_fields
127
+ field_values.update(class_dict["model_fields"])
128
+ else:
129
+ # If Pydantic did not process this class yet, get the values from class_dict directly
130
+ field_values.update(class_dict)
131
+
132
+ # Add Pydantic validation wrapper around all fields except `DiscriminatorField`s
133
+ for field_name in field_annotations:
134
+ if not cls._is_entity_type_field(field_values.get(field_name)):
135
+ cls._add_validator(field_name)
136
+
137
+ @classmethod
138
+ def _add_validator(cls, field_name: str):
139
+ """
140
+ Adds a Pydantic validator with mode=wrap for the provided `field_name`.
141
+ During validation, this will check if the field is templated (not expanded yet)
142
+ and in that case, it will skip all the remaining Pydantic validation on that field.
143
+
144
+ Since this validator is added last, it will skip all the other field validators
145
+ defined in the subclasses when templates are found.
146
+
147
+ This logic on templates only applies when context contains `skip_validation_on_templates` flag.
148
+ """
149
+
150
+ def validator_skipping_templated_str(cls, value, handler, info: ValidationInfo):
151
+ if _is_templated(info, value):
152
+ return value
153
+ return handler(value)
154
+
155
+ setattr(
156
+ cls,
157
+ f"_field_validator_with_verbose_name_to_avoid_name_conflict_{field_name}",
158
+ field_validator(field_name, mode="wrap")(validator_skipping_templated_str),
159
+ )
28
160
 
29
161
  def update_from_dict(self, update_values: Dict[str, Any]):
30
162
  """
@@ -47,5 +179,16 @@ class UpdatableModel(BaseModel):
47
179
  return self
48
180
 
49
181
 
50
- def IdentifierField(*args, **kwargs): # noqa
182
+ def DiscriminatorField(*args, **kwargs): # noqa N802
183
+ """
184
+ Use this type for discriminator fields used for differentiating
185
+ between different entity types.
186
+
187
+ When this `DiscriminatorField` is used on a pydantic attribute,
188
+ we will not allow templating on it.
189
+ """
190
+ return Field(is_discriminator_field=True, *args, **kwargs)
191
+
192
+
193
+ def IdentifierField(*args, **kwargs): # noqa N802
51
194
  return Field(max_length=254, pattern=IDENTIFIER_NO_LENGTH, *args, **kwargs)