snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166):
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +23 -24
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +6 -6
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +15 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +7 -7
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  58. snowflake/ml/jobs/_utils/payload_utils.py +6 -16
  59. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +7 -4
  60. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  61. snowflake/ml/jobs/_utils/spec_utils.py +17 -28
  62. snowflake/ml/jobs/_utils/types.py +2 -2
  63. snowflake/ml/jobs/decorators.py +4 -5
  64. snowflake/ml/jobs/job.py +24 -14
  65. snowflake/ml/jobs/manager.py +37 -41
  66. snowflake/ml/lineage/lineage_node.py +5 -5
  67. snowflake/ml/model/_client/model/model_impl.py +3 -3
  68. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  69. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  70. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  71. snowflake/ml/model/_client/ops/service_ops.py +199 -26
  72. snowflake/ml/model/_client/service/model_deployment_spec.py +171 -47
  73. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  74. snowflake/ml/model/_client/sql/model.py +8 -8
  75. snowflake/ml/model/_client/sql/model_version.py +26 -26
  76. snowflake/ml/model/_client/sql/service.py +13 -13
  77. snowflake/ml/model/_client/sql/stage.py +2 -2
  78. snowflake/ml/model/_client/sql/tag.py +6 -6
  79. snowflake/ml/model/_model_composer/model_composer.py +17 -14
  80. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  81. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  82. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  83. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  84. snowflake/ml/model/_packager/model_handler.py +4 -4
  85. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  86. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  87. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  88. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  89. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  90. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  91. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  92. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  93. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  95. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  96. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  99. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  100. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  101. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -37
  102. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  103. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  104. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  105. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  106. snowflake/ml/model/_packager/model_packager.py +11 -9
  107. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  108. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  109. snowflake/ml/model/_signatures/core.py +16 -24
  110. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  111. snowflake/ml/model/_signatures/utils.py +6 -6
  112. snowflake/ml/model/custom_model.py +8 -8
  113. snowflake/ml/model/model_signature.py +9 -20
  114. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  115. snowflake/ml/model/type_hints.py +3 -3
  116. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  117. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  118. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  119. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  120. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  121. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  122. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  123. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  124. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  125. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  126. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  127. snowflake/ml/modeling/framework/_utils.py +10 -10
  128. snowflake/ml/modeling/framework/base.py +32 -32
  129. snowflake/ml/modeling/impute/__init__.py +1 -1
  130. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  131. snowflake/ml/modeling/metrics/__init__.py +1 -1
  132. snowflake/ml/modeling/metrics/classification.py +39 -39
  133. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  134. snowflake/ml/modeling/metrics/ranking.py +7 -7
  135. snowflake/ml/modeling/metrics/regression.py +13 -13
  136. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  137. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  138. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  139. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  140. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  141. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  142. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  143. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  144. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  145. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  146. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  147. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  148. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  149. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  150. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  151. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  152. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  153. snowflake/ml/registry/_manager/model_manager.py +33 -31
  154. snowflake/ml/registry/registry.py +29 -22
  155. snowflake/ml/utils/authentication.py +2 -2
  156. snowflake/ml/utils/connection_params.py +5 -5
  157. snowflake/ml/utils/sparse.py +5 -4
  158. snowflake/ml/utils/sql_client.py +1 -2
  159. snowflake/ml/version.py +2 -1
  160. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +16 -7
  161. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +164 -166
  162. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  163. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  164. snowflake/ml/modeling/_internal/constants.py +0 -2
  165. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  166. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@ import inspect
2
2
  import os
3
3
  import pathlib
4
4
  import sys
5
- from typing import Dict, Optional, Type, cast, final
5
+ from typing import Optional, cast, final
6
6
 
7
7
  import anyio
8
8
  import cloudpickle
@@ -28,7 +28,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
28
28
  HANDLER_TYPE = "custom"
29
29
  HANDLER_VERSION = "2023-12-01"
30
30
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
31
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
31
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
32
32
 
33
33
  @classmethod
34
34
  def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["custom_model.CustomModel"]:
@@ -99,7 +99,11 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
99
99
  for sub_name, model_ref in model.context.model_refs.items():
100
100
  handler = model_handler.find_handler(model_ref.model)
101
101
  if handler is None:
102
- raise TypeError("Your input type to custom model is not currently supported")
102
+ raise TypeError(
103
+ f"Model {sub_name} in model context is not a supported model type. See "
104
+ "https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/"
105
+ "bring-your-own-model-types for more details."
106
+ )
103
107
  sub_model = handler.cast_model(model_ref.model)
104
108
  handler.save_model(
105
109
  name=sub_name,
@@ -161,7 +165,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
161
165
  name: str(pathlib.PurePath(model_blob_path) / pathlib.PurePosixPath(rel_path))
162
166
  for name, rel_path in artifacts_meta.items()
163
167
  }
164
- models: Dict[str, model_types.SupportedModelType] = dict()
168
+ models: dict[str, model_types.SupportedModelType] = dict()
165
169
  for sub_model_name, _ref in context.model_refs.items():
166
170
  model_type = model_meta.models[sub_model_name].model_type
167
171
  handler = model_handler.load_handler(model_type)
@@ -1,18 +1,7 @@
1
1
  import json
2
2
  import os
3
3
  import warnings
4
- from typing import (
5
- TYPE_CHECKING,
6
- Any,
7
- Callable,
8
- Dict,
9
- List,
10
- Optional,
11
- Type,
12
- Union,
13
- cast,
14
- final,
15
- )
4
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
16
5
 
17
6
  import cloudpickle
18
7
  import numpy as np
@@ -38,7 +27,7 @@ if TYPE_CHECKING:
38
27
  import transformers
39
28
 
40
29
 
41
- def get_requirements_from_task(task: str, spcs_only: bool = False) -> List[model_env.ModelDependency]:
30
+ def get_requirements_from_task(task: str, spcs_only: bool = False) -> list[model_env.ModelDependency]:
42
31
  # Text
43
32
  if task in [
44
33
  "conversational",
@@ -84,7 +73,7 @@ class HuggingFacePipelineHandler(
84
73
  HANDLER_TYPE = "huggingface_pipeline"
85
74
  HANDLER_VERSION = "2023-12-01"
86
75
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
87
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
76
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
88
77
 
89
78
  MODEL_BLOB_FILE_OR_DIR = "model"
90
79
  ADDITIONAL_CONFIG_FILE = "pipeline_config.pt"
@@ -250,20 +239,17 @@ class HuggingFacePipelineHandler(
250
239
  task, spcs_only=(not type_utils.LazyType("transformers.Pipeline").isinstance(model))
251
240
  )
252
241
  if framework is None or framework == "pt":
253
- # Since we set default cuda version to be 11.8, to make sure it works with GPU, we need to have a default
254
- # Pytorch version that works with CUDA 11.8 as well. This is required for huggingface pipelines only as
255
- # users are not required to install pytorch locally if they are using the wrapper.
256
242
  pkgs_requirements.append(model_env.ModelDependency(requirement="pytorch", pip_name="torch"))
257
243
  elif framework == "tf":
258
244
  pkgs_requirements.append(model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"))
259
245
  model_meta.env.include_if_absent(
260
246
  pkgs_requirements, check_local_version=(type_utils.LazyType("transformers.Pipeline").isinstance(model))
261
247
  )
262
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
248
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
263
249
 
264
250
  @staticmethod
265
- def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) -> Dict[str, str]:
266
- device_config: Dict[str, Any] = {}
251
+ def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) -> dict[str, str]:
252
+ device_config: dict[str, Any] = {}
267
253
  cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
268
254
  gpu_nums = 0
269
255
  if cuda_visible_devices is not None:
@@ -369,7 +355,7 @@ class HuggingFacePipelineHandler(
369
355
  def _create_custom_model(
370
356
  raw_model: "transformers.Pipeline",
371
357
  model_meta: model_meta_api.ModelMetadata,
372
- ) -> Type[custom_model.CustomModel]:
358
+ ) -> type[custom_model.CustomModel]:
373
359
  def fn_factory(
374
360
  raw_model: "transformers.Pipeline",
375
361
  signature: model_signature.ModelSignature,
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
2
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
3
3
 
4
4
  import cloudpickle
5
5
  import numpy as np
@@ -32,7 +32,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
32
32
  HANDLER_TYPE = "keras"
33
33
  HANDLER_VERSION = "2025-01-01"
34
34
  _MIN_SNOWPARK_ML_VERSION = "1.7.5"
35
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
35
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
36
36
 
37
37
  MODEL_BLOB_FILE_OR_DIR = "model.keras"
38
38
  CUSTOM_OBJECT_SAVE_PATH = "custom_objects.pkl"
@@ -146,7 +146,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
146
146
  dependencies,
147
147
  check_local_version=True,
148
148
  )
149
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
149
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
150
150
 
151
151
  @classmethod
152
152
  def load_model(
@@ -185,7 +185,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
185
185
  def _create_custom_model(
186
186
  raw_model: "keras.Model",
187
187
  model_meta: model_meta_api.ModelMetadata,
188
- ) -> Type[custom_model.CustomModel]:
188
+ ) -> type[custom_model.CustomModel]:
189
189
  def fn_factory(
190
190
  raw_model: "keras.Model",
191
191
  signature: model_signature.ModelSignature,
@@ -1,16 +1,6 @@
1
1
  import os
2
2
  import warnings
3
- from typing import (
4
- TYPE_CHECKING,
5
- Any,
6
- Callable,
7
- Dict,
8
- Optional,
9
- Type,
10
- Union,
11
- cast,
12
- final,
13
- )
3
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
14
4
 
15
5
  import cloudpickle
16
6
  import numpy as np
@@ -41,7 +31,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
41
31
  HANDLER_TYPE = "lightgbm"
42
32
  HANDLER_VERSION = "2024-03-19"
43
33
  _MIN_SNOWPARK_ML_VERSION = "1.3.1"
44
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
34
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
45
35
 
46
36
  MODEL_BLOB_FILE_OR_DIR = "model.pkl"
47
37
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
@@ -215,7 +205,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
215
205
  def _create_custom_model(
216
206
  raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
217
207
  model_meta: model_meta_api.ModelMetadata,
218
- ) -> Type[custom_model.CustomModel]:
208
+ ) -> type[custom_model.CustomModel]:
219
209
  def fn_factory(
220
210
  raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
221
211
  signature: model_signature.ModelSignature,
@@ -250,7 +240,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
250
240
 
251
241
  return fn
252
242
 
253
- type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
243
+ type_method_dict: dict[str, Any] = {"_raw_model": raw_model}
254
244
  for target_method_name, sig in model_meta.signatures.items():
255
245
  type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
256
246
 
@@ -1,7 +1,7 @@
1
1
  import os
2
2
  import pathlib
3
3
  import tempfile
4
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
4
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
5
5
 
6
6
  import pandas as pd
7
7
  from typing_extensions import TypeGuard, Unpack
@@ -61,7 +61,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
61
61
  HANDLER_TYPE = "mlflow"
62
62
  HANDLER_VERSION = "2023-12-01"
63
63
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
64
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
64
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
65
65
 
66
66
  MODEL_BLOB_FILE_OR_DIR = "model"
67
67
  _DEFAULT_TARGET_METHOD = "predict"
@@ -204,7 +204,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
204
204
  def _create_custom_model(
205
205
  raw_model: "mlflow.pyfunc.PyFuncModel",
206
206
  model_meta: model_meta_api.ModelMetadata,
207
- ) -> Type[custom_model.CustomModel]:
207
+ ) -> type[custom_model.CustomModel]:
208
208
  def fn_factory(
209
209
  raw_model: "mlflow.pyfunc.PyFuncModel",
210
210
  signature: model_signature.ModelSignature,
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  import sys
3
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
3
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
4
4
 
5
5
  import cloudpickle
6
6
  import pandas as pd
@@ -38,7 +38,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
38
38
  HANDLER_TYPE = "pytorch"
39
39
  HANDLER_VERSION = "2025-03-01"
40
40
  _MIN_SNOWPARK_ML_VERSION = "1.8.0"
41
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
41
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
42
42
  "2023-12-01": pytorch_migrator_2023_12_01.PyTorchHandlerMigrator20231201
43
43
  }
44
44
 
@@ -151,7 +151,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
151
151
  model_meta.env.include_if_absent(
152
152
  [model_env.ModelDependency(requirement="pytorch", pip_name="torch")], check_local_version=True
153
153
  )
154
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
154
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
155
155
 
156
156
  @classmethod
157
157
  def load_model(
@@ -188,7 +188,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
188
188
  def _create_custom_model(
189
189
  raw_model: "torch.nn.Module",
190
190
  model_meta: model_meta_api.ModelMetadata,
191
- ) -> Type[custom_model.CustomModel]:
191
+ ) -> type[custom_model.CustomModel]:
192
192
  multiple_inputs = cast(
193
193
  model_meta_schema.PyTorchModelBlobOptions, model_meta.models[model_meta.name].options
194
194
  )["multiple_inputs"]
@@ -1,7 +1,7 @@
1
1
  import inspect
2
2
  import logging
3
3
  import os
4
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
4
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
5
5
 
6
6
  import pandas as pd
7
7
  from typing_extensions import TypeGuard, Unpack
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
24
24
  logger = logging.getLogger(__name__)
25
25
 
26
26
 
27
- def _validate_sentence_transformers_signatures(sigs: Dict[str, model_signature.ModelSignature]) -> None:
27
+ def _validate_sentence_transformers_signatures(sigs: dict[str, model_signature.ModelSignature]) -> None:
28
28
  if list(sigs.keys()) != ["encode"]:
29
29
  raise ValueError("target_methods can only be ['encode']")
30
30
 
@@ -48,7 +48,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
48
48
  HANDLER_TYPE = "sentence_transformers"
49
49
  HANDLER_VERSION = "2024-03-15"
50
50
  _MIN_SNOWPARK_ML_VERSION = "1.3.1"
51
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
51
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
52
52
 
53
53
  MODEL_BLOB_FILE_OR_DIR = "model"
54
54
  DEFAULT_TARGET_METHODS = ["encode"]
@@ -166,7 +166,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
166
166
  ],
167
167
  check_local_version=True,
168
168
  )
169
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
169
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
170
170
 
171
171
  @staticmethod
172
172
  def _get_device_config(**kwargs: Unpack[model_types.SentenceTransformersLoadOptions]) -> Optional[str]:
@@ -224,7 +224,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
224
224
  def _create_custom_model(
225
225
  raw_model: "sentence_transformers.SentenceTransformer",
226
226
  model_meta: model_meta_api.ModelMetadata,
227
- ) -> Type[custom_model.CustomModel]:
227
+ ) -> type[custom_model.CustomModel]:
228
228
  batch_size = cast(
229
229
  model_meta_schema.SentenceTransformersModelBlobOptions, model_meta.models[model_meta.name].options
230
230
  ).get("batch_size", None)
@@ -1,13 +1,13 @@
1
1
  import os
2
2
  import warnings
3
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union, cast, final
3
+ from typing import TYPE_CHECKING, Callable, Optional, Union, cast, final
4
4
 
5
5
  import cloudpickle
6
6
  import numpy as np
7
7
  import pandas as pd
8
8
  from typing_extensions import TypeGuard, Unpack
9
9
 
10
- from snowflake.ml._internal import type_utils
10
+ from snowflake.ml._internal import env, type_utils
11
11
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
12
12
  from snowflake.ml.model._packager.model_env import model_env
13
13
  from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
@@ -19,7 +19,6 @@ from snowflake.ml.model._packager.model_meta import (
19
19
  )
20
20
  from snowflake.ml.model._packager.model_task import model_task_utils
21
21
  from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
22
- from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR
23
22
 
24
23
  if TYPE_CHECKING:
25
24
  import sklearn.base
@@ -49,7 +48,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
49
48
  HANDLER_TYPE = "sklearn"
50
49
  HANDLER_VERSION = "2023-12-01"
51
50
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
52
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
51
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
53
52
 
54
53
  DEFAULT_TARGET_METHODS = [
55
54
  "predict",
@@ -113,7 +112,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
113
112
  raise ValueError("Sample input data is required to enable explainability.")
114
113
 
115
114
  # If this is a pipeline and we are in the container runtime, check for distributed estimator.
116
- if os.getenv(IN_ML_RUNTIME_ENV_VAR) and isinstance(model, sklearn.pipeline.Pipeline):
115
+ if env.IN_ML_RUNTIME and isinstance(model, sklearn.pipeline.Pipeline):
117
116
  model = _unpack_container_runtime_pipeline(model)
118
117
 
119
118
  if not is_sub_model:
@@ -265,7 +264,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
265
264
  def _create_custom_model(
266
265
  raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
267
266
  model_meta: model_meta_api.ModelMetadata,
268
- ) -> Type[custom_model.CustomModel]:
267
+ ) -> type[custom_model.CustomModel]:
269
268
  def fn_factory(
270
269
  raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
271
270
  signature: model_signature.ModelSignature,
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  import warnings
3
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
3
+ from typing import TYPE_CHECKING, Any, Callable, Optional, cast, final
4
4
 
5
5
  import cloudpickle
6
6
  import numpy as np
@@ -36,7 +36,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
36
36
  HANDLER_TYPE = "snowml"
37
37
  HANDLER_VERSION = "2023-12-01"
38
38
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
39
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
39
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
40
40
 
41
41
  DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
42
42
  EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
@@ -264,7 +264,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
264
264
  def _create_custom_model(
265
265
  raw_model: "BaseEstimator",
266
266
  model_meta: model_meta_api.ModelMetadata,
267
- ) -> Type[custom_model.CustomModel]:
267
+ ) -> type[custom_model.CustomModel]:
268
268
  def fn_factory(
269
269
  raw_model: "BaseEstimator",
270
270
  signature: model_signature.ModelSignature,
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
2
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
3
3
 
4
4
  import pandas as pd
5
5
  from packaging import version
@@ -38,7 +38,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
38
38
  HANDLER_TYPE = "tensorflow"
39
39
  HANDLER_VERSION = "2025-03-01"
40
40
  _MIN_SNOWPARK_ML_VERSION = "1.8.0"
41
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
41
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
42
42
  "2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201,
43
43
  "2025-01-01": tensorflow_migrator_2025_01_01.TensorflowHandlerMigrator20250101,
44
44
  }
@@ -188,7 +188,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
188
188
  dependencies,
189
189
  check_local_version=True,
190
190
  )
191
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
191
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
192
192
 
193
193
  @classmethod
194
194
  def load_model(
@@ -230,7 +230,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
230
230
  def _create_custom_model(
231
231
  raw_model: "tensorflow.Module",
232
232
  model_meta: model_meta_api.ModelMetadata,
233
- ) -> Type[custom_model.CustomModel]:
233
+ ) -> type[custom_model.CustomModel]:
234
234
  multiple_inputs = cast(
235
235
  model_meta_schema.TensorflowModelBlobOptions, model_meta.models[model_meta.name].options
236
236
  )["multiple_inputs"]
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
2
+ from typing import TYPE_CHECKING, Callable, Optional, cast, final
3
3
 
4
4
  import pandas as pd
5
5
  from typing_extensions import TypeGuard, Unpack
@@ -36,7 +36,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
36
36
  HANDLER_TYPE = "torchscript"
37
37
  HANDLER_VERSION = "2025-03-01"
38
38
  _MIN_SNOWPARK_ML_VERSION = "1.8.0"
39
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
39
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
40
40
  "2023-12-01": torchscript_migrator_2023_12_01.TorchScriptHandlerMigrator20231201
41
41
  }
42
42
 
@@ -141,7 +141,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
141
141
  model_meta.env.include_if_absent(
142
142
  [model_env.ModelDependency(requirement="pytorch", pip_name="torch")], check_local_version=True
143
143
  )
144
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
144
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
145
145
 
146
146
  @classmethod
147
147
  def load_model(
@@ -181,7 +181,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
181
181
  def _create_custom_model(
182
182
  raw_model: "torch.jit.ScriptModule",
183
183
  model_meta: model_meta_api.ModelMetadata,
184
- ) -> Type[custom_model.CustomModel]:
184
+ ) -> type[custom_model.CustomModel]:
185
185
  def fn_factory(
186
186
  raw_model: "torch.jit.ScriptModule",
187
187
  signature: model_signature.ModelSignature,
@@ -1,17 +1,7 @@
1
1
  # mypy: disable-error-code="import"
2
2
  import os
3
3
  import warnings
4
- from typing import (
5
- TYPE_CHECKING,
6
- Any,
7
- Callable,
8
- Dict,
9
- Optional,
10
- Type,
11
- Union,
12
- cast,
13
- final,
14
- )
4
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
15
5
 
16
6
  import numpy as np
17
7
  import pandas as pd
@@ -44,7 +34,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
44
34
  HANDLER_TYPE = "xgboost"
45
35
  HANDLER_VERSION = "2023-12-01"
46
36
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
47
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
37
+ _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
48
38
 
49
39
  MODEL_BLOB_FILE_OR_DIR = "model.ubj"
50
40
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
@@ -175,7 +165,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
175
165
  if enable_explainability:
176
166
  model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
177
167
  model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
178
- model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
168
+ model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
179
169
 
180
170
  @classmethod
181
171
  def load_model(
@@ -227,7 +217,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
227
217
  def _create_custom_model(
228
218
  raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
229
219
  model_meta: model_meta_api.ModelMetadata,
230
- ) -> Type[custom_model.CustomModel]:
220
+ ) -> type[custom_model.CustomModel]:
231
221
  def fn_factory(
232
222
  raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
233
223
  signature: model_signature.ModelSignature,
@@ -261,7 +251,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
261
251
  return explain_fn
262
252
  return fn
263
253
 
264
- type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
254
+ type_method_dict: dict[str, Any] = {"_raw_model": raw_model}
265
255
  for target_method_name, sig in model_meta.signatures.items():
266
256
  type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
267
257
 
@@ -1,4 +1,4 @@
1
- from typing import Dict, cast
1
+ from typing import cast
2
2
 
3
3
  from typing_extensions import Unpack
4
4
 
@@ -25,7 +25,7 @@ class ModelBlobMeta:
25
25
  self.handler_version = kwargs["handler_version"]
26
26
  self.function_properties = kwargs.get("function_properties", {})
27
27
 
28
- self.artifacts: Dict[str, str] = {}
28
+ self.artifacts: dict[str, str] = {}
29
29
  artifacts = kwargs.get("artifacts", None)
30
30
  if artifacts:
31
31
  self.artifacts = artifacts