snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170) hide show
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +64 -31
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +41 -5
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +40 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +12 -8
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/constants.py +2 -4
  58. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  59. snowflake/ml/jobs/_utils/payload_utils.py +86 -62
  60. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  64. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  65. snowflake/ml/jobs/_utils/spec_utils.py +22 -36
  66. snowflake/ml/jobs/_utils/types.py +8 -2
  67. snowflake/ml/jobs/decorators.py +7 -8
  68. snowflake/ml/jobs/job.py +158 -26
  69. snowflake/ml/jobs/manager.py +78 -30
  70. snowflake/ml/lineage/lineage_node.py +5 -5
  71. snowflake/ml/model/_client/model/model_impl.py +3 -3
  72. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  73. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  74. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  75. snowflake/ml/model/_client/ops/service_ops.py +230 -50
  76. snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
  77. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  78. snowflake/ml/model/_client/sql/model.py +8 -8
  79. snowflake/ml/model/_client/sql/model_version.py +26 -26
  80. snowflake/ml/model/_client/sql/service.py +22 -18
  81. snowflake/ml/model/_client/sql/stage.py +2 -2
  82. snowflake/ml/model/_client/sql/tag.py +6 -6
  83. snowflake/ml/model/_model_composer/model_composer.py +46 -25
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  85. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  86. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  87. snowflake/ml/model/_packager/model_env/model_env.py +35 -26
  88. snowflake/ml/model/_packager/model_handler.py +4 -4
  89. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  90. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  91. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  92. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  93. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  94. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  95. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  96. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  99. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  100. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  101. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  102. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  103. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  104. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  105. snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
  106. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  107. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  109. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  110. snowflake/ml/model/_packager/model_packager.py +12 -8
  111. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  112. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  113. snowflake/ml/model/_signatures/core.py +16 -24
  114. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  115. snowflake/ml/model/_signatures/utils.py +6 -6
  116. snowflake/ml/model/custom_model.py +8 -8
  117. snowflake/ml/model/model_signature.py +9 -20
  118. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  119. snowflake/ml/model/type_hints.py +5 -3
  120. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  122. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  123. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  124. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  125. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  126. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  128. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  129. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  130. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  131. snowflake/ml/modeling/framework/_utils.py +10 -10
  132. snowflake/ml/modeling/framework/base.py +32 -32
  133. snowflake/ml/modeling/impute/__init__.py +1 -1
  134. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  135. snowflake/ml/modeling/metrics/__init__.py +1 -1
  136. snowflake/ml/modeling/metrics/classification.py +39 -39
  137. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  138. snowflake/ml/modeling/metrics/ranking.py +7 -7
  139. snowflake/ml/modeling/metrics/regression.py +13 -13
  140. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  143. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  144. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  145. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  146. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  147. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  148. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  149. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  150. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  151. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  152. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  153. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  154. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  155. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  156. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  157. snowflake/ml/registry/_manager/model_manager.py +50 -29
  158. snowflake/ml/registry/registry.py +34 -23
  159. snowflake/ml/utils/authentication.py +2 -2
  160. snowflake/ml/utils/connection_params.py +5 -5
  161. snowflake/ml/utils/sparse.py +5 -4
  162. snowflake/ml/utils/sql_client.py +1 -2
  163. snowflake/ml/version.py +2 -1
  164. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
  165. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
  166. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  167. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  168. snowflake/ml/modeling/_internal/constants.py +0 -2
  169. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  170. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -1,16 +1,6 @@
1
1
  import os
2
2
  import warnings
3
- from typing import (
4
- TYPE_CHECKING,
5
- Any,
6
- Callable,
7
- Dict,
8
- Optional,
9
- Type,
10
- Union,
11
- cast,
12
- final,
13
- )
3
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
14
4
 
15
5
  import cloudpickle
16
6
  import numpy as np
@@ -41,7 +31,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
41
31
  HANDLER_TYPE = "lightgbm"
42
32
  HANDLER_VERSION = "2024-03-19"
43
33
  _MIN_SNOWPARK_ML_VERSION = "1.3.1"
44
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
34
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
45
35
 
46
36
  MODEL_BLOB_FILE_OR_DIR = "model.pkl"
47
37
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
@@ -215,7 +205,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
215
205
  def _create_custom_model(
216
206
  raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
217
207
  model_meta: model_meta_api.ModelMetadata,
218
- ) -> Type[custom_model.CustomModel]:
208
+ ) -> type[custom_model.CustomModel]:
219
209
  def fn_factory(
220
210
  raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
221
211
  signature: model_signature.ModelSignature,
@@ -250,7 +240,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
250
240
 
251
241
  return fn
252
242
 
253
- type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
243
+ type_method_dict: dict[str, Any] = {"_raw_model": raw_model}
254
244
  for target_method_name, sig in model_meta.signatures.items():
255
245
  type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
256
246
 
@@ -1,7 +1,7 @@
1
1
  import os
2
2
  import pathlib
3
3
  import tempfile
4
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
4
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
5
5
 
6
6
  import pandas as pd
7
7
  from typing_extensions import TypeGuard, Unpack
@@ -61,7 +61,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
61
61
  HANDLER_TYPE = "mlflow"
62
62
  HANDLER_VERSION = "2023-12-01"
63
63
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
64
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
64
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
65
65
 
66
66
  MODEL_BLOB_FILE_OR_DIR = "model"
67
67
  _DEFAULT_TARGET_METHOD = "predict"
@@ -204,7 +204,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
204
204
  def _create_custom_model(
205
205
  raw_model: "mlflow.pyfunc.PyFuncModel",
206
206
  model_meta: model_meta_api.ModelMetadata,
207
- ) -> Type[custom_model.CustomModel]:
207
+ ) -> type[custom_model.CustomModel]:
208
208
  def fn_factory(
209
209
  raw_model: "mlflow.pyfunc.PyFuncModel",
210
210
  signature: model_signature.ModelSignature,
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  import sys
3
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
3
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
4
4
 
5
5
  import cloudpickle
6
6
  import pandas as pd
@@ -38,7 +38,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
38
38
  HANDLER_TYPE = "pytorch"
39
39
  HANDLER_VERSION = "2025-03-01"
40
40
  _MIN_SNOWPARK_ML_VERSION = "1.8.0"
41
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
41
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
42
42
  "2023-12-01": pytorch_migrator_2023_12_01.PyTorchHandlerMigrator20231201
43
43
  }
44
44
 
@@ -151,7 +151,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
151
151
  model_meta.env.include_if_absent(
152
152
  [model_env.ModelDependency(requirement="pytorch", pip_name="torch")], check_local_version=True
153
153
  )
154
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
154
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
155
155
 
156
156
  @classmethod
157
157
  def load_model(
@@ -188,7 +188,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
188
188
  def _create_custom_model(
189
189
  raw_model: "torch.nn.Module",
190
190
  model_meta: model_meta_api.ModelMetadata,
191
- ) -> Type[custom_model.CustomModel]:
191
+ ) -> type[custom_model.CustomModel]:
192
192
  multiple_inputs = cast(
193
193
  model_meta_schema.PyTorchModelBlobOptions, model_meta.models[model_meta.name].options
194
194
  )["multiple_inputs"]
@@ -1,7 +1,7 @@
1
1
  import inspect
2
2
  import logging
3
3
  import os
4
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
4
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
5
5
 
6
6
  import pandas as pd
7
7
  from typing_extensions import TypeGuard, Unpack
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
24
24
  logger = logging.getLogger(__name__)
25
25
 
26
26
 
27
- def _validate_sentence_transformers_signatures(sigs: Dict[str, model_signature.ModelSignature]) -> None:
27
+ def _validate_sentence_transformers_signatures(sigs: dict[str, model_signature.ModelSignature]) -> None:
28
28
  if list(sigs.keys()) != ["encode"]:
29
29
  raise ValueError("target_methods can only be ['encode']")
30
30
 
@@ -48,7 +48,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
48
48
  HANDLER_TYPE = "sentence_transformers"
49
49
  HANDLER_VERSION = "2024-03-15"
50
50
  _MIN_SNOWPARK_ML_VERSION = "1.3.1"
51
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
51
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
52
52
 
53
53
  MODEL_BLOB_FILE_OR_DIR = "model"
54
54
  DEFAULT_TARGET_METHODS = ["encode"]
@@ -166,7 +166,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
166
166
  ],
167
167
  check_local_version=True,
168
168
  )
169
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
169
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
170
170
 
171
171
  @staticmethod
172
172
  def _get_device_config(**kwargs: Unpack[model_types.SentenceTransformersLoadOptions]) -> Optional[str]:
@@ -224,7 +224,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
224
224
  def _create_custom_model(
225
225
  raw_model: "sentence_transformers.SentenceTransformer",
226
226
  model_meta: model_meta_api.ModelMetadata,
227
- ) -> Type[custom_model.CustomModel]:
227
+ ) -> type[custom_model.CustomModel]:
228
228
  batch_size = cast(
229
229
  model_meta_schema.SentenceTransformersModelBlobOptions, model_meta.models[model_meta.name].options
230
230
  ).get("batch_size", None)
@@ -1,13 +1,13 @@
1
1
  import os
2
2
  import warnings
3
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union, cast, final
3
+ from typing import TYPE_CHECKING, Callable, Optional, Union, cast, final
4
4
 
5
5
  import cloudpickle
6
6
  import numpy as np
7
7
  import pandas as pd
8
8
  from typing_extensions import TypeGuard, Unpack
9
9
 
10
- from snowflake.ml._internal import type_utils
10
+ from snowflake.ml._internal import env, type_utils
11
11
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
12
12
  from snowflake.ml.model._packager.model_env import model_env
13
13
  from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
@@ -19,7 +19,6 @@ from snowflake.ml.model._packager.model_meta import (
19
19
  )
20
20
  from snowflake.ml.model._packager.model_task import model_task_utils
21
21
  from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
22
- from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR
23
22
 
24
23
  if TYPE_CHECKING:
25
24
  import sklearn.base
@@ -49,7 +48,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
49
48
  HANDLER_TYPE = "sklearn"
50
49
  HANDLER_VERSION = "2023-12-01"
51
50
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
52
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
51
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
53
52
 
54
53
  DEFAULT_TARGET_METHODS = [
55
54
  "predict",
@@ -113,7 +112,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
113
112
  raise ValueError("Sample input data is required to enable explainability.")
114
113
 
115
114
  # If this is a pipeline and we are in the container runtime, check for distributed estimator.
116
- if os.getenv(IN_ML_RUNTIME_ENV_VAR) and isinstance(model, sklearn.pipeline.Pipeline):
115
+ if env.IN_ML_RUNTIME and isinstance(model, sklearn.pipeline.Pipeline):
117
116
  model = _unpack_container_runtime_pipeline(model)
118
117
 
119
118
  if not is_sub_model:
@@ -265,7 +264,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
265
264
  def _create_custom_model(
266
265
  raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
267
266
  model_meta: model_meta_api.ModelMetadata,
268
- ) -> Type[custom_model.CustomModel]:
267
+ ) -> type[custom_model.CustomModel]:
269
268
  def fn_factory(
270
269
  raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
271
270
  signature: model_signature.ModelSignature,
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  import warnings
3
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
3
+ from typing import TYPE_CHECKING, Any, Callable, Optional, cast, final
4
4
 
5
5
  import cloudpickle
6
6
  import numpy as np
@@ -36,7 +36,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
36
36
  HANDLER_TYPE = "snowml"
37
37
  HANDLER_VERSION = "2023-12-01"
38
38
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
39
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
39
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
40
40
 
41
41
  DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
42
42
  EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
@@ -264,7 +264,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
264
264
  def _create_custom_model(
265
265
  raw_model: "BaseEstimator",
266
266
  model_meta: model_meta_api.ModelMetadata,
267
- ) -> Type[custom_model.CustomModel]:
267
+ ) -> type[custom_model.CustomModel]:
268
268
  def fn_factory(
269
269
  raw_model: "BaseEstimator",
270
270
  signature: model_signature.ModelSignature,
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
2
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
3
3
 
4
4
  import pandas as pd
5
5
  from packaging import version
@@ -38,7 +38,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
38
38
  HANDLER_TYPE = "tensorflow"
39
39
  HANDLER_VERSION = "2025-03-01"
40
40
  _MIN_SNOWPARK_ML_VERSION = "1.8.0"
41
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
41
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
42
42
  "2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201,
43
43
  "2025-01-01": tensorflow_migrator_2025_01_01.TensorflowHandlerMigrator20250101,
44
44
  }
@@ -188,7 +188,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
188
188
  dependencies,
189
189
  check_local_version=True,
190
190
  )
191
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
191
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
192
192
 
193
193
  @classmethod
194
194
  def load_model(
@@ -230,7 +230,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
230
230
  def _create_custom_model(
231
231
  raw_model: "tensorflow.Module",
232
232
  model_meta: model_meta_api.ModelMetadata,
233
- ) -> Type[custom_model.CustomModel]:
233
+ ) -> type[custom_model.CustomModel]:
234
234
  multiple_inputs = cast(
235
235
  model_meta_schema.TensorflowModelBlobOptions, model_meta.models[model_meta.name].options
236
236
  )["multiple_inputs"]
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
2
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
3
3
 
4
4
  import pandas as pd
5
5
  from typing_extensions import TypeGuard, Unpack
@@ -36,7 +36,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
36
36
  HANDLER_TYPE = "torchscript"
37
37
  HANDLER_VERSION = "2025-03-01"
38
38
  _MIN_SNOWPARK_ML_VERSION = "1.8.0"
39
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
39
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
40
40
  "2023-12-01": torchscript_migrator_2023_12_01.TorchScriptHandlerMigrator20231201
41
41
  }
42
42
 
@@ -141,7 +141,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
141
141
  model_meta.env.include_if_absent(
142
142
  [model_env.ModelDependency(requirement="pytorch", pip_name="torch")], check_local_version=True
143
143
  )
144
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
144
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
145
145
 
146
146
  @classmethod
147
147
  def load_model(
@@ -181,7 +181,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
181
181
  def _create_custom_model(
182
182
  raw_model: "torch.jit.ScriptModule",
183
183
  model_meta: model_meta_api.ModelMetadata,
184
- ) -> Type[custom_model.CustomModel]:
184
+ ) -> type[custom_model.CustomModel]:
185
185
  def fn_factory(
186
186
  raw_model: "torch.jit.ScriptModule",
187
187
  signature: model_signature.ModelSignature,
@@ -1,17 +1,7 @@
1
1
  # mypy: disable-error-code="import"
2
2
  import os
3
3
  import warnings
4
- from typing import (
5
- TYPE_CHECKING,
6
- Any,
7
- Callable,
8
- Dict,
9
- Optional,
10
- Type,
11
- Union,
12
- cast,
13
- final,
14
- )
4
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
15
5
 
16
6
  import numpy as np
17
7
  import pandas as pd
@@ -44,7 +34,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
44
34
  HANDLER_TYPE = "xgboost"
45
35
  HANDLER_VERSION = "2023-12-01"
46
36
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
47
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
37
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
48
38
 
49
39
  MODEL_BLOB_FILE_OR_DIR = "model.ubj"
50
40
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
@@ -175,7 +165,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
175
165
  if enable_explainability:
176
166
  model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
177
167
  model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
178
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
168
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
179
169
 
180
170
  @classmethod
181
171
  def load_model(
@@ -227,7 +217,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
227
217
  def _create_custom_model(
228
218
  raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
229
219
  model_meta: model_meta_api.ModelMetadata,
230
- ) -> Type[custom_model.CustomModel]:
220
+ ) -> type[custom_model.CustomModel]:
231
221
  def fn_factory(
232
222
  raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
233
223
  signature: model_signature.ModelSignature,
@@ -261,7 +251,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
261
251
  return explain_fn
262
252
  return fn
263
253
 
264
- type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
254
+ type_method_dict: dict[str, Any] = {"_raw_model": raw_model}
265
255
  for target_method_name, sig in model_meta.signatures.items():
266
256
  type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
267
257
 
@@ -1,4 +1,4 @@
1
- from typing import Dict, cast
1
+ from typing import cast
2
2
 
3
3
  from typing_extensions import Unpack
4
4
 
@@ -25,7 +25,7 @@ class ModelBlobMeta:
25
25
  self.handler_version = kwargs["handler_version"]
26
26
  self.function_properties = kwargs.get("function_properties", {})
27
27
 
28
- self.artifacts: Dict[str, str] = {}
28
+ self.artifacts: dict[str, str] = {}
29
29
  artifacts = kwargs.get("artifacts", None)
30
30
  if artifacts:
31
31
  self.artifacts = artifacts
@@ -6,31 +6,26 @@ import zipfile
6
6
  from contextlib import contextmanager
7
7
  from datetime import datetime
8
8
  from types import ModuleType
9
- from typing import Any, Dict, Generator, List, Optional, TypedDict
9
+ from typing import Any, Generator, Optional, TypedDict
10
10
 
11
11
  import cloudpickle
12
12
  import yaml
13
13
  from packaging import requirements, version
14
14
  from typing_extensions import Required
15
15
 
16
- from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
16
+ from snowflake.ml import version as snowml_version
17
+ from snowflake.ml._internal import env_utils, file_utils
17
18
  from snowflake.ml.model import model_signature, type_hints as model_types
18
19
  from snowflake.ml.model._packager.model_env import model_env
19
- from snowflake.ml.model._packager.model_meta import (
20
- _packaging_requirements,
21
- model_blob_meta,
22
- model_meta_schema,
23
- )
20
+ from snowflake.ml.model._packager.model_meta import model_blob_meta, model_meta_schema
24
21
  from snowflake.ml.model._packager.model_meta_migrator import migrator_plans
25
22
  from snowflake.ml.model._packager.model_runtime import model_runtime
26
23
 
27
24
  MODEL_METADATA_FILE = "model.yaml"
28
25
  MODEL_CODE_DIR = "code"
29
26
 
30
- _PACKAGING_REQUIREMENTS = [
31
- str(env_utils.get_package_spec_with_supported_ops_only(requirements.Requirement(r)))
32
- for r in _packaging_requirements.REQUIREMENTS
33
- ]
27
+ _PACKAGING_REQUIREMENTS = ["cloudpickle"]
28
+
34
29
  _SNOWFLAKE_PKG_NAME = "snowflake"
35
30
  _SNOWFLAKE_ML_PKG_NAME = f"{_SNOWFLAKE_PKG_NAME}.ml"
36
31
 
@@ -41,14 +36,16 @@ def create_model_metadata(
41
36
  model_dir_path: str,
42
37
  name: str,
43
38
  model_type: model_types.SupportedModelHandlerType,
44
- signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
45
- function_properties: Optional[Dict[str, Dict[str, Any]]] = None,
46
- metadata: Optional[Dict[str, str]] = None,
47
- code_paths: Optional[List[str]] = None,
48
- ext_modules: Optional[List[ModuleType]] = None,
49
- conda_dependencies: Optional[List[str]] = None,
50
- pip_requirements: Optional[List[str]] = None,
51
- artifact_repository_map: Optional[Dict[str, str]] = None,
39
+ signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
40
+ function_properties: Optional[dict[str, dict[str, Any]]] = None,
41
+ metadata: Optional[dict[str, str]] = None,
42
+ code_paths: Optional[list[str]] = None,
43
+ ext_modules: Optional[list[ModuleType]] = None,
44
+ conda_dependencies: Optional[list[str]] = None,
45
+ pip_requirements: Optional[list[str]] = None,
46
+ artifact_repository_map: Optional[dict[str, str]] = None,
47
+ resource_constraint: Optional[dict[str, str]] = None,
48
+ target_platforms: Optional[list[model_types.TargetPlatform]] = None,
52
49
  python_version: Optional[str] = None,
53
50
  task: model_types.Task = model_types.Task.UNKNOWN,
54
51
  **kwargs: Any,
@@ -69,6 +66,8 @@ def create_model_metadata(
69
66
  conda_dependencies: List of conda requirements for running the model. Defaults to None.
70
67
  pip_requirements: List of pip Python packages requirements for running the model. Defaults to None.
71
68
  artifact_repository_map: A dict mapping from package channel to artifact repository name.
69
+ resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
70
+ target_platforms: List of target platforms to run the model.
72
71
  python_version: A string of python version where model is run. Used for user override. If specified as None,
73
72
  current version would be captured. Defaults to None.
74
73
  task: The task of the Model Version. It is an enum class Task with values TABULAR_REGRESSION,
@@ -101,16 +100,19 @@ def create_model_metadata(
101
100
  else:
102
101
  raise ValueError("`snowflake.ml` is imported via a way that embedding local ML library is not supported.")
103
102
 
103
+ prefer_pip = target_platforms == [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
104
104
  env = _create_env_for_model_metadata(
105
105
  conda_dependencies=conda_dependencies,
106
106
  pip_requirements=pip_requirements,
107
107
  artifact_repository_map=artifact_repository_map,
108
+ resource_constraint=resource_constraint,
108
109
  python_version=python_version,
109
110
  embed_local_ml_library=embed_local_ml_library,
111
+ prefer_pip=prefer_pip,
110
112
  )
111
113
 
112
114
  if embed_local_ml_library:
113
- env.snowpark_ml_version = f"{snowml_env.VERSION}+{file_utils.hash_directory(path_to_copy)}"
115
+ env.snowpark_ml_version = f"{snowml_version.VERSION}+{file_utils.hash_directory(path_to_copy)}"
114
116
 
115
117
  model_meta = ModelMetadata(
116
118
  name=name,
@@ -152,20 +154,23 @@ def create_model_metadata(
152
154
 
153
155
  def _create_env_for_model_metadata(
154
156
  *,
155
- conda_dependencies: Optional[List[str]] = None,
156
- pip_requirements: Optional[List[str]] = None,
157
- artifact_repository_map: Optional[Dict[str, str]] = None,
157
+ conda_dependencies: Optional[list[str]] = None,
158
+ pip_requirements: Optional[list[str]] = None,
159
+ artifact_repository_map: Optional[dict[str, str]] = None,
160
+ resource_constraint: Optional[dict[str, str]] = None,
158
161
  python_version: Optional[str] = None,
159
162
  embed_local_ml_library: bool = False,
163
+ prefer_pip: bool = False,
160
164
  ) -> model_env.ModelEnv:
161
- env = model_env.ModelEnv()
165
+ env = model_env.ModelEnv(prefer_pip=prefer_pip)
162
166
 
163
167
  # Mypy doesn't like getter and setter have different types. See python/mypy #3004
164
168
  env.conda_dependencies = conda_dependencies # type: ignore[assignment]
165
169
  env.pip_requirements = pip_requirements # type: ignore[assignment]
166
170
  env.artifact_repository_map = artifact_repository_map
171
+ env.resource_constraint = resource_constraint
167
172
  env.python_version = python_version # type: ignore[assignment]
168
- env.snowpark_ml_version = snowml_env.VERSION
173
+ env.snowpark_ml_version = snowml_version.VERSION
169
174
 
170
175
  requirements_to_add = _PACKAGING_REQUIREMENTS
171
176
 
@@ -237,20 +242,20 @@ class ModelMetadata:
237
242
  name: str,
238
243
  env: model_env.ModelEnv,
239
244
  model_type: model_types.SupportedModelHandlerType,
240
- runtimes: Optional[Dict[str, model_runtime.ModelRuntime]] = None,
241
- signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
242
- function_properties: Optional[Dict[str, Dict[str, Any]]] = None,
243
- user_files: Optional[Dict[str, List[str]]] = None,
244
- metadata: Optional[Dict[str, str]] = None,
245
+ runtimes: Optional[dict[str, model_runtime.ModelRuntime]] = None,
246
+ signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
247
+ function_properties: Optional[dict[str, dict[str, Any]]] = None,
248
+ user_files: Optional[dict[str, list[str]]] = None,
249
+ metadata: Optional[dict[str, str]] = None,
245
250
  creation_timestamp: Optional[str] = None,
246
251
  min_snowpark_ml_version: Optional[str] = None,
247
- models: Optional[Dict[str, model_blob_meta.ModelBlobMeta]] = None,
252
+ models: Optional[dict[str, model_blob_meta.ModelBlobMeta]] = None,
248
253
  original_metadata_version: Optional[str] = model_meta_schema.MODEL_METADATA_VERSION,
249
254
  task: model_types.Task = model_types.Task.UNKNOWN,
250
255
  explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = None,
251
256
  ) -> None:
252
257
  self.name = name
253
- self.signatures: Dict[str, model_signature.ModelSignature] = dict()
258
+ self.signatures: dict[str, model_signature.ModelSignature] = dict()
254
259
  if signatures:
255
260
  self.signatures = signatures
256
261
  self.function_properties = function_properties or {}
@@ -265,7 +270,7 @@ class ModelMetadata:
265
270
  else model_meta_schema.MODEL_METADATA_MIN_SNOWPARK_ML_VERSION
266
271
  )
267
272
 
268
- self.models: Dict[str, model_blob_meta.ModelBlobMeta] = dict()
273
+ self.models: dict[str, model_blob_meta.ModelBlobMeta] = dict()
269
274
  if models:
270
275
  self.models = models
271
276
 
@@ -286,7 +291,7 @@ class ModelMetadata:
286
291
  self._min_snowpark_ml_version = max(self._min_snowpark_ml_version, parsed_min_snowpark_ml_version)
287
292
 
288
293
  @property
289
- def runtimes(self) -> Dict[str, model_runtime.ModelRuntime]:
294
+ def runtimes(self) -> dict[str, model_runtime.ModelRuntime]:
290
295
  if self._runtimes and "cpu" in self._runtimes:
291
296
  return self._runtimes
292
297
  runtimes = {
@@ -353,11 +358,11 @@ class ModelMetadata:
353
358
 
354
359
  loaded_meta_min_snowpark_ml_version = loaded_meta.get("min_snowpark_ml_version", None)
355
360
  if not loaded_meta_min_snowpark_ml_version or (
356
- version.parse(loaded_meta_min_snowpark_ml_version) > version.parse(snowml_env.VERSION)
361
+ version.parse(loaded_meta_min_snowpark_ml_version) > version.parse(snowml_version.VERSION)
357
362
  ):
358
363
  raise RuntimeError(
359
364
  f"The minimal version required to load the model is {loaded_meta_min_snowpark_ml_version}, "
360
- f"while current version of Snowpark ML library is {snowml_env.VERSION}."
365
+ f"while current version of Snowpark ML library is {snowml_version.VERSION}."
361
366
  )
362
367
  return model_meta_schema.ModelMetadataDict(
363
368
  creation_timestamp=loaded_meta["creation_timestamp"],
@@ -400,7 +405,7 @@ class ModelMetadata:
400
405
  env = model_env.ModelEnv()
401
406
  env.load_from_dict(pathlib.Path(model_dir_path), model_dict["env"])
402
407
 
403
- runtimes: Optional[Dict[str, model_runtime.ModelRuntime]]
408
+ runtimes: Optional[dict[str, model_runtime.ModelRuntime]]
404
409
  if model_dict.get("runtimes", None):
405
410
  runtimes = {
406
411
  name: model_runtime.ModelRuntime.load(pathlib.Path(model_dir_path), name, env, runtime_dict)
@@ -1,7 +1,7 @@
1
1
  # This files contains schema definition of what will be written into model.yml
2
2
  # Changing this file should lead to a change of the schema version.
3
3
  from enum import Enum
4
- from typing import Any, Dict, List, Optional, TypedDict, Union
4
+ from typing import Any, Optional, TypedDict, Union
5
5
 
6
6
  from typing_extensions import NotRequired, Required
7
7
 
@@ -18,18 +18,20 @@ class FunctionProperties(Enum):
18
18
  class ModelRuntimeDependenciesDict(TypedDict):
19
19
  conda: Required[str]
20
20
  pip: Required[str]
21
- artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
21
+ artifact_repository_map: NotRequired[Optional[dict[str, str]]]
22
22
 
23
23
 
24
24
  class ModelRuntimeDict(TypedDict):
25
- imports: Required[List[str]]
25
+ imports: Required[list[str]]
26
26
  dependencies: Required[ModelRuntimeDependenciesDict]
27
+ resource_constraint: NotRequired[Optional[dict[str, str]]]
27
28
 
28
29
 
29
30
  class ModelEnvDict(TypedDict):
30
31
  conda: Required[str]
31
32
  pip: Required[str]
32
- artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
33
+ artifact_repository_map: NotRequired[Optional[dict[str, str]]]
34
+ resource_constraint: NotRequired[Optional[dict[str, str]]]
33
35
  python_version: Required[str]
34
36
  cuda_version: NotRequired[Optional[str]]
35
37
  snowpark_ml_version: Required[str]
@@ -102,25 +104,25 @@ class ModelBlobMetadataDict(TypedDict):
102
104
  model_type: Required[type_hints.SupportedModelHandlerType]
103
105
  path: Required[str]
104
106
  handler_version: Required[str]
105
- function_properties: NotRequired[Dict[str, Dict[str, Any]]]
106
- artifacts: NotRequired[Dict[str, str]]
107
+ function_properties: NotRequired[dict[str, dict[str, Any]]]
108
+ artifacts: NotRequired[dict[str, str]]
107
109
  options: NotRequired[ModelBlobOptions]
108
110
 
109
111
 
110
112
  class ModelMetadataDict(TypedDict):
111
113
  creation_timestamp: Required[str]
112
114
  env: Required[ModelEnvDict]
113
- runtimes: NotRequired[Dict[str, ModelRuntimeDict]]
114
- metadata: NotRequired[Optional[Dict[str, str]]]
115
+ runtimes: NotRequired[dict[str, ModelRuntimeDict]]
116
+ metadata: NotRequired[Optional[dict[str, str]]]
115
117
  model_type: Required[type_hints.SupportedModelHandlerType]
116
- models: Required[Dict[str, ModelBlobMetadataDict]]
118
+ models: Required[dict[str, ModelBlobMetadataDict]]
117
119
  name: Required[str]
118
- signatures: Required[Dict[str, Dict[str, Any]]]
120
+ signatures: Required[dict[str, dict[str, Any]]]
119
121
  version: Required[str]
120
122
  min_snowpark_ml_version: Required[str]
121
123
  task: Required[str]
122
124
  explainability: NotRequired[Optional[ExplainabilityMetadataDict]]
123
- function_properties: NotRequired[Dict[str, Dict[str, Any]]]
125
+ function_properties: NotRequired[dict[str, dict[str, Any]]]
124
126
 
125
127
 
126
128
  class ModelExplainAlgorithm(Enum):
@@ -1,6 +1,6 @@
1
1
  import copy
2
2
  from abc import abstractmethod
3
- from typing import Any, Dict, Protocol, final
3
+ from typing import Any, Protocol, final
4
4
 
5
5
  from snowflake.ml._internal import migrator_utils
6
6
 
@@ -11,13 +11,13 @@ class _BaseModelMetaMigratorProtocol(Protocol):
11
11
 
12
12
  @staticmethod
13
13
  @abstractmethod
14
- def upgrade(original_meta_dict: Dict[str, Any]) -> Dict[str, Any]:
14
+ def upgrade(original_meta_dict: dict[str, Any]) -> dict[str, Any]:
15
15
  raise NotImplementedError
16
16
 
17
17
 
18
18
  class BaseModelMetaMigrator(_BaseModelMetaMigratorProtocol):
19
19
  @final
20
- def try_upgrade(self, original_meta_dict: Dict[str, Any]) -> Dict[str, Any]:
20
+ def try_upgrade(self, original_meta_dict: dict[str, Any]) -> dict[str, Any]:
21
21
  loaded_meta_version = original_meta_dict.get("version", None)
22
22
  if not loaded_meta_version or str(loaded_meta_version) != self.source_version:
23
23
  raise NotImplementedError(