snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +64 -31
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +41 -5
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +40 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +12 -8
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/constants.py +2 -4
  58. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  59. snowflake/ml/jobs/_utils/payload_utils.py +86 -62
  60. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  64. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  65. snowflake/ml/jobs/_utils/spec_utils.py +22 -36
  66. snowflake/ml/jobs/_utils/types.py +8 -2
  67. snowflake/ml/jobs/decorators.py +7 -8
  68. snowflake/ml/jobs/job.py +158 -26
  69. snowflake/ml/jobs/manager.py +78 -30
  70. snowflake/ml/lineage/lineage_node.py +5 -5
  71. snowflake/ml/model/_client/model/model_impl.py +3 -3
  72. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  73. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  74. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  75. snowflake/ml/model/_client/ops/service_ops.py +230 -50
  76. snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
  77. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  78. snowflake/ml/model/_client/sql/model.py +8 -8
  79. snowflake/ml/model/_client/sql/model_version.py +26 -26
  80. snowflake/ml/model/_client/sql/service.py +22 -18
  81. snowflake/ml/model/_client/sql/stage.py +2 -2
  82. snowflake/ml/model/_client/sql/tag.py +6 -6
  83. snowflake/ml/model/_model_composer/model_composer.py +46 -25
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  85. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  86. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  87. snowflake/ml/model/_packager/model_env/model_env.py +35 -26
  88. snowflake/ml/model/_packager/model_handler.py +4 -4
  89. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  90. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  91. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  92. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  93. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  94. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  95. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  96. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  99. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  100. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  101. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  102. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  103. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  104. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  105. snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
  106. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  107. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  109. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  110. snowflake/ml/model/_packager/model_packager.py +12 -8
  111. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  112. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  113. snowflake/ml/model/_signatures/core.py +16 -24
  114. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  115. snowflake/ml/model/_signatures/utils.py +6 -6
  116. snowflake/ml/model/custom_model.py +8 -8
  117. snowflake/ml/model/model_signature.py +9 -20
  118. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  119. snowflake/ml/model/type_hints.py +5 -3
  120. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  122. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  123. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  124. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  125. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  126. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  128. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  129. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  130. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  131. snowflake/ml/modeling/framework/_utils.py +10 -10
  132. snowflake/ml/modeling/framework/base.py +32 -32
  133. snowflake/ml/modeling/impute/__init__.py +1 -1
  134. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  135. snowflake/ml/modeling/metrics/__init__.py +1 -1
  136. snowflake/ml/modeling/metrics/classification.py +39 -39
  137. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  138. snowflake/ml/modeling/metrics/ranking.py +7 -7
  139. snowflake/ml/modeling/metrics/regression.py +13 -13
  140. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  143. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  144. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  145. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  146. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  147. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  148. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  149. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  150. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  151. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  152. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  153. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  154. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  155. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  156. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  157. snowflake/ml/registry/_manager/model_manager.py +50 -29
  158. snowflake/ml/registry/registry.py +34 -23
  159. snowflake/ml/utils/authentication.py +2 -2
  160. snowflake/ml/utils/connection_params.py +5 -5
  161. snowflake/ml/utils/sparse.py +5 -4
  162. snowflake/ml/utils/sql_client.py +1 -2
  163. snowflake/ml/version.py +2 -1
  164. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
  165. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
  166. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  167. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  168. snowflake/ml/modeling/_internal/constants.py +0 -2
  169. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  170. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@ import collections
2
2
  import logging
3
3
  import pathlib
4
4
  import warnings
5
- from typing import Dict, List, Optional, cast
5
+ from typing import Optional, cast
6
6
 
7
7
  import yaml
8
8
 
@@ -45,10 +45,10 @@ class ModelManifest:
45
45
  self,
46
46
  model_meta: model_meta_api.ModelMetadata,
47
47
  model_rel_path: pathlib.PurePosixPath,
48
- user_files: Optional[Dict[str, List[str]]] = None,
48
+ user_files: Optional[dict[str, list[str]]] = None,
49
49
  options: Optional[type_hints.ModelSaveOption] = None,
50
- data_sources: Optional[List[data_source.DataSource]] = None,
51
- target_platforms: Optional[List[type_hints.TargetPlatform]] = None,
50
+ data_sources: Optional[list[data_source.DataSource]] = None,
51
+ target_platforms: Optional[list[type_hints.TargetPlatform]] = None,
52
52
  ) -> None:
53
53
  if options is None:
54
54
  options = {}
@@ -78,12 +78,13 @@ class ModelManifest:
78
78
  logger.info(f"Conda dependencies: {runtime_to_use.runtime_env.conda_dependencies}")
79
79
  logger.info(f"Pip requirements: {runtime_to_use.runtime_env.pip_requirements}")
80
80
  logger.info(f"artifact_repository_map: {runtime_to_use.runtime_env.artifact_repository_map}")
81
+ logger.info(f"resource_constraint: {runtime_to_use.runtime_env.resource_constraint}")
81
82
  runtime_dict = runtime_to_use.save(
82
83
  self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
83
84
  )
84
85
 
85
86
  self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
86
- self.methods: List[model_method.ModelMethod] = []
87
+ self.methods: list[model_method.ModelMethod] = []
87
88
 
88
89
  for target_method in model_meta.signatures.keys():
89
90
  method = model_method.ModelMethod(
@@ -100,7 +101,7 @@ class ModelManifest:
100
101
 
101
102
  self.methods.append(method)
102
103
 
103
- self.user_files: List[model_user_file.ModelUserFile] = []
104
+ self.user_files: list[model_user_file.ModelUserFile] = []
104
105
 
105
106
  if user_files is not None:
106
107
  for subdirectory, paths in user_files.items():
@@ -127,16 +128,19 @@ class ModelManifest:
127
128
  if model_meta.env.artifact_repository_map:
128
129
  dependencies["artifact_repository_map"] = runtime_dict["dependencies"]["artifact_repository_map"]
129
130
 
131
+ runtime = model_manifest_schema.ModelRuntimeDict(
132
+ language="PYTHON",
133
+ version=runtime_to_use.runtime_env.python_version,
134
+ imports=runtime_dict["imports"],
135
+ dependencies=dependencies,
136
+ )
137
+
138
+ if runtime_dict["resource_constraint"]:
139
+ runtime["resource_constraint"] = runtime_dict["resource_constraint"]
140
+
130
141
  manifest_dict = model_manifest_schema.ModelManifestDict(
131
142
  manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
132
- runtimes={
133
- self._DEFAULT_RUNTIME_NAME: model_manifest_schema.ModelRuntimeDict(
134
- language="PYTHON",
135
- version=runtime_to_use.runtime_env.python_version,
136
- imports=runtime_dict["imports"],
137
- dependencies=dependencies,
138
- )
139
- },
143
+ runtimes={self._DEFAULT_RUNTIME_NAME: runtime},
140
144
  methods=[
141
145
  method.save(
142
146
  self.workspace_path,
@@ -178,8 +182,8 @@ class ModelManifest:
178
182
  return res
179
183
 
180
184
  def _extract_lineage_info(
181
- self, data_sources: Optional[List[data_source.DataSource]]
182
- ) -> List[model_manifest_schema.LineageSourceDict]:
185
+ self, data_sources: Optional[list[data_source.DataSource]]
186
+ ) -> list[model_manifest_schema.LineageSourceDict]:
183
187
  result = []
184
188
  if data_sources:
185
189
  for source in data_sources:
@@ -1,6 +1,6 @@
1
1
  # This files contains schema definition of what will be written into MANIFEST.yml
2
2
  import enum
3
- from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
3
+ from typing import Any, Literal, Optional, TypedDict, Union
4
4
 
5
5
  from typing_extensions import NotRequired, Required
6
6
 
@@ -20,14 +20,15 @@ class ModelMethodFunctionTypes(enum.Enum):
20
20
  class ModelRuntimeDependenciesDict(TypedDict):
21
21
  conda: NotRequired[str]
22
22
  pip: NotRequired[str]
23
- artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
23
+ artifact_repository_map: NotRequired[Optional[dict[str, str]]]
24
24
 
25
25
 
26
26
  class ModelRuntimeDict(TypedDict):
27
27
  language: Required[Literal["PYTHON"]]
28
28
  version: Required[str]
29
- imports: Required[List[str]]
29
+ imports: Required[list[str]]
30
30
  dependencies: Required[ModelRuntimeDependenciesDict]
31
+ resource_constraint: NotRequired[Optional[dict[str, str]]]
31
32
 
32
33
 
33
34
  class ModelMethodSignatureField(TypedDict):
@@ -43,8 +44,8 @@ class ModelFunctionMethodDict(TypedDict):
43
44
  runtime: Required[str]
44
45
  type: Required[str]
45
46
  handler: Required[str]
46
- inputs: Required[List[ModelMethodSignatureFieldWithName]]
47
- outputs: Required[Union[List[ModelMethodSignatureField], List[ModelMethodSignatureFieldWithName]]]
47
+ inputs: Required[list[ModelMethodSignatureFieldWithName]]
48
+ outputs: Required[Union[list[ModelMethodSignatureField], list[ModelMethodSignatureFieldWithName]]]
48
49
 
49
50
 
50
51
  ModelMethodDict = ModelFunctionMethodDict
@@ -71,12 +72,12 @@ class ModelFunctionInfo(TypedDict):
71
72
  class ModelFunctionInfoDict(TypedDict):
72
73
  name: Required[str]
73
74
  target_method: Required[str]
74
- signature: Required[Dict[str, Any]]
75
+ signature: Required[dict[str, Any]]
75
76
 
76
77
 
77
78
  class SnowparkMLDataDict(TypedDict):
78
79
  schema_version: Required[str]
79
- functions: Required[List[ModelFunctionInfoDict]]
80
+ functions: Required[list[ModelFunctionInfoDict]]
80
81
 
81
82
 
82
83
  class LineageSourceTypes(enum.Enum):
@@ -92,9 +93,9 @@ class LineageSourceDict(TypedDict):
92
93
 
93
94
  class ModelManifestDict(TypedDict):
94
95
  manifest_version: Required[str]
95
- runtimes: Required[Dict[str, ModelRuntimeDict]]
96
- methods: Required[List[ModelMethodDict]]
97
- user_data: NotRequired[Dict[str, Any]]
98
- user_files: NotRequired[List[str]]
99
- lineage_sources: NotRequired[List[LineageSourceDict]]
100
- target_platforms: NotRequired[List[str]]
96
+ runtimes: Required[dict[str, ModelRuntimeDict]]
97
+ methods: Required[list[ModelMethodDict]]
98
+ user_data: NotRequired[dict[str, Any]]
99
+ user_files: NotRequired[list[str]]
100
+ lineage_sources: NotRequired[list[LineageSourceDict]]
101
+ target_platforms: NotRequired[list[str]]
@@ -1,6 +1,6 @@
1
1
  import collections
2
2
  import pathlib
3
- from typing import List, Optional, TypedDict, Union
3
+ from typing import Optional, TypedDict, Union
4
4
 
5
5
  from typing_extensions import NotRequired
6
6
 
@@ -137,8 +137,8 @@ class ModelMethod:
137
137
  )
138
138
 
139
139
  outputs: Union[
140
- List[model_manifest_schema.ModelMethodSignatureField],
141
- List[model_manifest_schema.ModelMethodSignatureFieldWithName],
140
+ list[model_manifest_schema.ModelMethodSignatureField],
141
+ list[model_manifest_schema.ModelMethodSignatureFieldWithName],
142
142
  ]
143
143
  if self.function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
144
144
  outputs = [
@@ -3,10 +3,11 @@ import itertools
3
3
  import os
4
4
  import pathlib
5
5
  import warnings
6
- from typing import DefaultDict, Dict, List, Optional
6
+ from typing import DefaultDict, Optional
7
7
 
8
8
  from packaging import requirements, version
9
9
 
10
+ from snowflake.ml import version as snowml_version
10
11
  from snowflake.ml._internal import env as snowml_env, env_utils
11
12
  from snowflake.ml.model._packager.model_meta import model_meta_schema
12
13
 
@@ -19,9 +20,8 @@ _DEFAULT_CONDA_ENV_FILENAME = "conda.yml"
19
20
  _DEFAULT_PIP_REQUIREMENTS_FILENAME = "requirements.txt"
20
21
 
21
22
  # The default CUDA version is chosen based on the driver availability in SPCS.
22
- # If changing this version, we need also change the version of default PyTorch in HuggingFace pipeline handler to
23
- # make sure they are compatible.
24
- DEFAULT_CUDA_VERSION = "11.8"
23
+ # Make sure they are aligned with default CUDA version in inference server.
24
+ DEFAULT_CUDA_VERSION = "12.4"
25
25
 
26
26
 
27
27
  class ModelEnv:
@@ -29,22 +29,25 @@ class ModelEnv:
29
29
  self,
30
30
  conda_env_rel_path: Optional[str] = None,
31
31
  pip_requirements_rel_path: Optional[str] = None,
32
+ prefer_pip: bool = False,
32
33
  ) -> None:
33
34
  if conda_env_rel_path is None:
34
35
  conda_env_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_CONDA_ENV_FILENAME)
35
36
  if pip_requirements_rel_path is None:
36
37
  pip_requirements_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_PIP_REQUIREMENTS_FILENAME)
38
+ self.prefer_pip: bool = prefer_pip
37
39
  self.conda_env_rel_path = pathlib.PurePosixPath(pathlib.Path(conda_env_rel_path).as_posix())
38
40
  self.pip_requirements_rel_path = pathlib.PurePosixPath(pathlib.Path(pip_requirements_rel_path).as_posix())
39
- self.artifact_repository_map: Optional[Dict[str, str]] = None
40
- self._conda_dependencies: DefaultDict[str, List[requirements.Requirement]] = collections.defaultdict(list)
41
- self._pip_requirements: List[requirements.Requirement] = []
41
+ self.artifact_repository_map: Optional[dict[str, str]] = None
42
+ self.resource_constraint: Optional[dict[str, str]] = None
43
+ self._conda_dependencies: DefaultDict[str, list[requirements.Requirement]] = collections.defaultdict(list)
44
+ self._pip_requirements: list[requirements.Requirement] = []
42
45
  self._python_version: version.Version = version.parse(snowml_env.PYTHON_VERSION)
43
46
  self._cuda_version: Optional[version.Version] = None
44
- self._snowpark_ml_version: version.Version = version.parse(snowml_env.VERSION)
47
+ self._snowpark_ml_version: version.Version = version.parse(snowml_version.VERSION)
45
48
 
46
49
  @property
47
- def conda_dependencies(self) -> List[str]:
50
+ def conda_dependencies(self) -> list[str]:
48
51
  """List of conda channel and dependencies from that to run the model"""
49
52
  return sorted(
50
53
  f"{chan}::{str(req)}" if chan else str(req)
@@ -55,24 +58,24 @@ class ModelEnv:
55
58
  @conda_dependencies.setter
56
59
  def conda_dependencies(
57
60
  self,
58
- conda_dependencies: Optional[List[str]] = None,
61
+ conda_dependencies: Optional[list[str]] = None,
59
62
  ) -> None:
60
63
  self._conda_dependencies = env_utils.validate_conda_dependency_string_list(
61
- conda_dependencies if conda_dependencies else []
64
+ conda_dependencies if conda_dependencies else [], add_local_version_specifier=True
62
65
  )
63
66
 
64
67
  @property
65
- def pip_requirements(self) -> List[str]:
68
+ def pip_requirements(self) -> list[str]:
66
69
  """List of pip Python packages requirements for running the model."""
67
70
  return sorted(list(map(str, self._pip_requirements)))
68
71
 
69
72
  @pip_requirements.setter
70
73
  def pip_requirements(
71
74
  self,
72
- pip_requirements: Optional[List[str]] = None,
75
+ pip_requirements: Optional[list[str]] = None,
73
76
  ) -> None:
74
77
  self._pip_requirements = env_utils.validate_pip_requirement_string_list(
75
- pip_requirements if pip_requirements else []
78
+ pip_requirements if pip_requirements else [], add_local_version_specifier=True
76
79
  )
77
80
 
78
81
  @property
@@ -113,7 +116,11 @@ class ModelEnv:
113
116
  if snowpark_ml_version:
114
117
  self._snowpark_ml_version = version.parse(snowpark_ml_version)
115
118
 
116
- def include_if_absent(self, pkgs: List[ModelDependency], check_local_version: bool = False) -> None:
119
+ def include_if_absent(
120
+ self,
121
+ pkgs: list[ModelDependency],
122
+ check_local_version: bool = False,
123
+ ) -> None:
117
124
  """Append requirements into model env if absent. Depending on the environment, requirements may be added
118
125
  to either the pip requirements or conda dependencies.
119
126
 
@@ -121,8 +128,8 @@ class ModelEnv:
121
128
  pkgs: A list of ModelDependency namedtuple to be appended.
122
129
  check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
123
130
  """
124
- if self.pip_requirements and not self.conda_dependencies and pkgs:
125
- pip_pkg_reqs: List[str] = []
131
+ if (self.pip_requirements or self.prefer_pip) and not self.conda_dependencies and pkgs:
132
+ pip_pkg_reqs: list[str] = []
126
133
  warnings.warn(
127
134
  (
128
135
  "Dependencies specified from pip requirements."
@@ -139,7 +146,7 @@ class ModelEnv:
139
146
  else:
140
147
  self._include_if_absent_conda(pkgs, check_local_version)
141
148
 
142
- def _include_if_absent_conda(self, pkgs: List[ModelDependency], check_local_version: bool = False) -> None:
149
+ def _include_if_absent_conda(self, pkgs: list[ModelDependency], check_local_version: bool = False) -> None:
143
150
  """Append requirements into model env conda dependencies if absent.
144
151
 
145
152
  Args:
@@ -184,7 +191,7 @@ class ModelEnv:
184
191
  stacklevel=2,
185
192
  )
186
193
 
187
- def _include_if_absent_pip(self, pkgs: List[str], check_local_version: bool = False) -> None:
194
+ def _include_if_absent_pip(self, pkgs: list[str], check_local_version: bool = False) -> None:
188
195
  """Append pip requirements into model env pip requirements if absent.
189
196
 
190
197
  Args:
@@ -201,7 +208,7 @@ class ModelEnv:
201
208
  except env_utils.DuplicateDependencyError:
202
209
  pass
203
210
 
204
- def remove_if_present_conda(self, conda_pkgs: List[str]) -> None:
211
+ def remove_if_present_conda(self, conda_pkgs: list[str]) -> None:
205
212
  """Remove conda requirements from model env if present.
206
213
 
207
214
  Args:
@@ -346,13 +353,14 @@ class ModelEnv:
346
353
  def load_from_dict(self, base_dir: pathlib.Path, env_dict: model_meta_schema.ModelEnvDict) -> None:
347
354
  self.conda_env_rel_path = pathlib.PurePosixPath(env_dict["conda"])
348
355
  self.pip_requirements_rel_path = pathlib.PurePosixPath(env_dict["pip"])
349
- self.artifact_repository_map = env_dict.get("artifact_repository_map", None)
356
+ self.artifact_repository_map = env_dict.get("artifact_repository_map")
357
+ self.resource_constraint = env_dict.get("resource_constraint")
350
358
 
351
359
  self.load_from_conda_file(base_dir / self.conda_env_rel_path)
352
360
  self.load_from_pip_file(base_dir / self.pip_requirements_rel_path)
353
361
 
354
362
  self.python_version = env_dict["python_version"]
355
- self.cuda_version = env_dict.get("cuda_version", None)
363
+ self.cuda_version = env_dict.get("cuda_version")
356
364
  self.snowpark_ml_version = env_dict["snowpark_ml_version"]
357
365
 
358
366
  def save_as_dict(
@@ -375,7 +383,8 @@ class ModelEnv:
375
383
  return {
376
384
  "conda": self.conda_env_rel_path.as_posix(),
377
385
  "pip": self.pip_requirements_rel_path.as_posix(),
378
- "artifact_repository_map": self.artifact_repository_map if self.artifact_repository_map is not None else {},
386
+ "artifact_repository_map": self.artifact_repository_map or {},
387
+ "resource_constraint": self.resource_constraint or {},
379
388
  "python_version": self.python_version,
380
389
  "cuda_version": self.cuda_version,
381
390
  "snowpark_ml_version": self.snowpark_ml_version,
@@ -383,7 +392,7 @@ class ModelEnv:
383
392
 
384
393
  def validate_with_local_env(
385
394
  self, check_snowpark_ml_version: bool = False
386
- ) -> List[env_utils.IncorrectLocalEnvironmentError]:
395
+ ) -> list[env_utils.IncorrectLocalEnvironmentError]:
387
396
  errors = []
388
397
  try:
389
398
  env_utils.validate_py_runtime_version(str(self._python_version))
@@ -407,10 +416,10 @@ class ModelEnv:
407
416
 
408
417
  if check_snowpark_ml_version:
409
418
  # For Modeling model
410
- if self._snowpark_ml_version.base_version != snowml_env.VERSION:
419
+ if self._snowpark_ml_version.base_version != snowml_version.VERSION:
411
420
  errors.append(
412
421
  env_utils.IncorrectLocalEnvironmentError(
413
- f"The local installed version of Snowpark ML library is {snowml_env.VERSION} "
422
+ f"The local installed version of Snowpark ML library is {snowml_version.VERSION} "
414
423
  f"which differs from required version {self.snowpark_ml_version}."
415
424
  )
416
425
  )
@@ -2,13 +2,13 @@ import functools
2
2
  import importlib
3
3
  import pkgutil
4
4
  from types import ModuleType
5
- from typing import Any, Callable, Dict, Optional, Type, TypeVar, cast
5
+ from typing import Any, Callable, Optional, TypeVar, cast
6
6
 
7
7
  from snowflake.ml.model import type_hints as model_types
8
8
  from snowflake.ml.model._packager.model_handlers import _base
9
9
 
10
10
  _HANDLERS_BASE = "snowflake.ml.model._packager.model_handlers"
11
- _MODEL_HANDLER_REGISTRY: Dict[str, Type[_base.BaseModelHandler[model_types.SupportedModelType]]] = dict()
11
+ _MODEL_HANDLER_REGISTRY: dict[str, type[_base.BaseModelHandler[model_types.SupportedModelType]]] = dict()
12
12
  _IS_HANDLER_LOADED = False
13
13
 
14
14
 
@@ -54,7 +54,7 @@ def ensure_handlers_registration(fn: F) -> F:
54
54
  @ensure_handlers_registration
55
55
  def find_handler(
56
56
  model: model_types.SupportedModelType,
57
- ) -> Optional[Type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
57
+ ) -> Optional[type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
58
58
  for handler in _MODEL_HANDLER_REGISTRY.values():
59
59
  if handler.can_handle(model):
60
60
  return handler
@@ -64,7 +64,7 @@ def find_handler(
64
64
  @ensure_handlers_registration
65
65
  def load_handler(
66
66
  target_model_type: model_types.SupportedModelHandlerType,
67
- ) -> Optional[Type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
67
+ ) -> Optional[type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
68
68
  for model_type, handler in _MODEL_HANDLER_REGISTRY.items():
69
69
  if target_model_type == model_type:
70
70
  return handler
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  from abc import abstractmethod
3
- from typing import Dict, Generic, Optional, Protocol, Type, final
3
+ from typing import Generic, Optional, Protocol, final
4
4
 
5
5
  import pandas as pd
6
6
  from typing_extensions import TypeGuard, Unpack
@@ -14,7 +14,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
14
14
  HANDLER_TYPE: model_types.SupportedModelHandlerType
15
15
  HANDLER_VERSION: str
16
16
  _MIN_SNOWPARK_ML_VERSION: str
17
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]]
17
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]]
18
18
 
19
19
  @classmethod
20
20
  @abstractmethod
@@ -1,8 +1,9 @@
1
+ import importlib
1
2
  import json
2
3
  import os
3
4
  import pathlib
4
5
  import warnings
5
- from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, cast
6
+ from typing import Any, Callable, Iterable, Optional, Sequence, cast
6
7
 
7
8
  import numpy as np
8
9
  import numpy.typing as npt
@@ -10,8 +11,10 @@ import pandas as pd
10
11
  from absl import logging
11
12
 
12
13
  import snowflake.snowpark.dataframe as sp_df
14
+ from snowflake.ml._internal import env
13
15
  from snowflake.ml._internal.utils import identifier
14
16
  from snowflake.ml.model import model_signature, type_hints as model_types
17
+ from snowflake.ml.model._packager.model_env import model_env
15
18
  from snowflake.ml.model._packager.model_meta import model_meta
16
19
  from snowflake.ml.model._signatures import (
17
20
  core,
@@ -231,7 +234,7 @@ def validate_model_task(passed_model_task: model_types.Task, inferred_model_task
231
234
 
232
235
 
233
236
  def get_explain_target_method(
234
- model_metadata: model_meta.ModelMetadata, target_methods_list: List[str]
237
+ model_metadata: model_meta.ModelMetadata, target_methods_list: list[str]
235
238
  ) -> Optional[str]:
236
239
  for method in model_metadata.signatures.keys():
237
240
  if method in target_methods_list:
@@ -248,7 +251,7 @@ def save_transformers_config_with_auto_map(local_model_path: str) -> None:
248
251
  config_dict = json.load(f)
249
252
 
250
253
  # a. get repository and class_path from configs
251
- auto_map_configs = cast(Dict[str, str], config_dict.get("auto_map", {}))
254
+ auto_map_configs = cast(dict[str, str], config_dict.get("auto_map", {}))
252
255
  for config_name, config_value in auto_map_configs.items():
253
256
  repository, _, class_path = config_value.rpartition("--")
254
257
 
@@ -261,3 +264,12 @@ def save_transformers_config_with_auto_map(local_model_path: str) -> None:
261
264
 
262
265
  with open(f_path, "w") as f:
263
266
  json.dump(config_dict, f)
267
+
268
+
269
+ def get_default_cuda_version() -> str:
270
+ # Default to the env cuda version when running in ML runtime
271
+ if env.IN_ML_RUNTIME and importlib.util.find_spec("torch") is not None:
272
+ import torch
273
+
274
+ return torch.version.cuda or model_env.DEFAULT_CUDA_VERSION
275
+ return model_env.DEFAULT_CUDA_VERSION
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  import warnings
3
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
3
+ from typing import TYPE_CHECKING, Any, Callable, Optional, cast, final
4
4
 
5
5
  import numpy as np
6
6
  import pandas as pd
@@ -30,7 +30,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
30
30
  HANDLER_TYPE = "catboost"
31
31
  HANDLER_VERSION = "2024-03-21"
32
32
  _MIN_SNOWPARK_ML_VERSION = "1.3.1"
33
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
33
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
34
34
 
35
35
  MODEL_BLOB_FILE_OR_DIR = "model.bin"
36
36
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
@@ -147,7 +147,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
147
147
  if enable_explainability:
148
148
  model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
149
149
  model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
150
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
150
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
151
151
 
152
152
  return None
153
153
 
@@ -202,7 +202,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
202
202
  def _create_custom_model(
203
203
  raw_model: "catboost.CatBoost",
204
204
  model_meta: model_meta_api.ModelMetadata,
205
- ) -> Type[custom_model.CustomModel]:
205
+ ) -> type[custom_model.CustomModel]:
206
206
  def fn_factory(
207
207
  raw_model: "catboost.CatBoost",
208
208
  signature: model_signature.ModelSignature,
@@ -235,7 +235,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
235
235
 
236
236
  return fn
237
237
 
238
- type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
238
+ type_method_dict: dict[str, Any] = {"_raw_model": raw_model}
239
239
  for target_method_name, sig in model_meta.signatures.items():
240
240
  type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
241
241
 
@@ -2,7 +2,7 @@ import inspect
2
2
  import os
3
3
  import pathlib
4
4
  import sys
5
- from typing import Dict, Optional, Type, cast, final
5
+ from typing import Optional, cast, final
6
6
 
7
7
  import anyio
8
8
  import cloudpickle
@@ -28,7 +28,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
28
28
  HANDLER_TYPE = "custom"
29
29
  HANDLER_VERSION = "2023-12-01"
30
30
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
31
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
31
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
32
32
 
33
33
  @classmethod
34
34
  def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["custom_model.CustomModel"]:
@@ -99,7 +99,11 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
99
99
  for sub_name, model_ref in model.context.model_refs.items():
100
100
  handler = model_handler.find_handler(model_ref.model)
101
101
  if handler is None:
102
- raise TypeError("Your input type to custom model is not currently supported")
102
+ raise TypeError(
103
+ f"Model {sub_name} in model context is not a supported model type. See "
104
+ "https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/"
105
+ "bring-your-own-model-types for more details."
106
+ )
103
107
  sub_model = handler.cast_model(model_ref.model)
104
108
  handler.save_model(
105
109
  name=sub_name,
@@ -161,7 +165,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
161
165
  name: str(pathlib.PurePath(model_blob_path) / pathlib.PurePosixPath(rel_path))
162
166
  for name, rel_path in artifacts_meta.items()
163
167
  }
164
- models: Dict[str, model_types.SupportedModelType] = dict()
168
+ models: dict[str, model_types.SupportedModelType] = dict()
165
169
  for sub_model_name, _ref in context.model_refs.items():
166
170
  model_type = model_meta.models[sub_model_name].model_type
167
171
  handler = model_handler.load_handler(model_type)
@@ -1,18 +1,7 @@
1
1
  import json
2
2
  import os
3
3
  import warnings
4
- from typing import (
5
- TYPE_CHECKING,
6
- Any,
7
- Callable,
8
- Dict,
9
- List,
10
- Optional,
11
- Type,
12
- Union,
13
- cast,
14
- final,
15
- )
4
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
16
5
 
17
6
  import cloudpickle
18
7
  import numpy as np
@@ -38,7 +27,7 @@ if TYPE_CHECKING:
38
27
  import transformers
39
28
 
40
29
 
41
- def get_requirements_from_task(task: str, spcs_only: bool = False) -> List[model_env.ModelDependency]:
30
+ def get_requirements_from_task(task: str, spcs_only: bool = False) -> list[model_env.ModelDependency]:
42
31
  # Text
43
32
  if task in [
44
33
  "conversational",
@@ -84,7 +73,7 @@ class HuggingFacePipelineHandler(
84
73
  HANDLER_TYPE = "huggingface_pipeline"
85
74
  HANDLER_VERSION = "2023-12-01"
86
75
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
87
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
76
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
88
77
 
89
78
  MODEL_BLOB_FILE_OR_DIR = "model"
90
79
  ADDITIONAL_CONFIG_FILE = "pipeline_config.pt"
@@ -250,20 +239,17 @@ class HuggingFacePipelineHandler(
250
239
  task, spcs_only=(not type_utils.LazyType("transformers.Pipeline").isinstance(model))
251
240
  )
252
241
  if framework is None or framework == "pt":
253
- # Since we set default cuda version to be 11.8, to make sure it works with GPU, we need to have a default
254
- # Pytorch version that works with CUDA 11.8 as well. This is required for huggingface pipelines only as
255
- # users are not required to install pytorch locally if they are using the wrapper.
256
242
  pkgs_requirements.append(model_env.ModelDependency(requirement="pytorch", pip_name="torch"))
257
243
  elif framework == "tf":
258
244
  pkgs_requirements.append(model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"))
259
245
  model_meta.env.include_if_absent(
260
246
  pkgs_requirements, check_local_version=(type_utils.LazyType("transformers.Pipeline").isinstance(model))
261
247
  )
262
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
248
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
263
249
 
264
250
  @staticmethod
265
- def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) -> Dict[str, str]:
266
- device_config: Dict[str, Any] = {}
251
+ def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) -> dict[str, str]:
252
+ device_config: dict[str, Any] = {}
267
253
  cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
268
254
  gpu_nums = 0
269
255
  if cuda_visible_devices is not None:
@@ -369,7 +355,7 @@ class HuggingFacePipelineHandler(
369
355
  def _create_custom_model(
370
356
  raw_model: "transformers.Pipeline",
371
357
  model_meta: model_meta_api.ModelMetadata,
372
- ) -> Type[custom_model.CustomModel]:
358
+ ) -> type[custom_model.CustomModel]:
373
359
  def fn_factory(
374
360
  raw_model: "transformers.Pipeline",
375
361
  signature: model_signature.ModelSignature,
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
2
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
3
3
 
4
4
  import cloudpickle
5
5
  import numpy as np
@@ -32,7 +32,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
32
32
  HANDLER_TYPE = "keras"
33
33
  HANDLER_VERSION = "2025-01-01"
34
34
  _MIN_SNOWPARK_ML_VERSION = "1.7.5"
35
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
35
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
36
36
 
37
37
  MODEL_BLOB_FILE_OR_DIR = "model.keras"
38
38
  CUSTOM_OBJECT_SAVE_PATH = "custom_objects.pkl"
@@ -146,7 +146,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
146
146
  dependencies,
147
147
  check_local_version=True,
148
148
  )
149
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
149
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
150
150
 
151
151
  @classmethod
152
152
  def load_model(
@@ -185,7 +185,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
185
185
  def _create_custom_model(
186
186
  raw_model: "keras.Model",
187
187
  model_meta: model_meta_api.ModelMetadata,
188
- ) -> Type[custom_model.CustomModel]:
188
+ ) -> type[custom_model.CustomModel]:
189
189
  def fn_factory(
190
190
  raw_model: "keras.Model",
191
191
  signature: model_signature.ModelSignature,