snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +6 -6
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +6 -16
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +7 -4
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +17 -28
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +4 -5
- snowflake/ml/jobs/job.py +24 -14
- snowflake/ml/jobs/manager.py +37 -41
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +199 -26
- snowflake/ml/model/_client/service/model_deployment_spec.py +171 -47
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +13 -13
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +17 -14
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +33 -31
- snowflake/ml/registry/registry.py +29 -22
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +16 -7
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +164 -166
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@ import inspect
|
|
2
2
|
import os
|
3
3
|
import pathlib
|
4
4
|
import sys
|
5
|
-
from typing import
|
5
|
+
from typing import Optional, cast, final
|
6
6
|
|
7
7
|
import anyio
|
8
8
|
import cloudpickle
|
@@ -28,7 +28,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
28
28
|
HANDLER_TYPE = "custom"
|
29
29
|
HANDLER_VERSION = "2023-12-01"
|
30
30
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
31
|
-
_HANDLER_MIGRATOR_PLANS:
|
31
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
32
32
|
|
33
33
|
@classmethod
|
34
34
|
def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["custom_model.CustomModel"]:
|
@@ -99,7 +99,11 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
99
99
|
for sub_name, model_ref in model.context.model_refs.items():
|
100
100
|
handler = model_handler.find_handler(model_ref.model)
|
101
101
|
if handler is None:
|
102
|
-
raise TypeError(
|
102
|
+
raise TypeError(
|
103
|
+
f"Model {sub_name} in model context is not a supported model type. See "
|
104
|
+
"https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/"
|
105
|
+
"bring-your-own-model-types for more details."
|
106
|
+
)
|
103
107
|
sub_model = handler.cast_model(model_ref.model)
|
104
108
|
handler.save_model(
|
105
109
|
name=sub_name,
|
@@ -161,7 +165,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
161
165
|
name: str(pathlib.PurePath(model_blob_path) / pathlib.PurePosixPath(rel_path))
|
162
166
|
for name, rel_path in artifacts_meta.items()
|
163
167
|
}
|
164
|
-
models:
|
168
|
+
models: dict[str, model_types.SupportedModelType] = dict()
|
165
169
|
for sub_model_name, _ref in context.model_refs.items():
|
166
170
|
model_type = model_meta.models[sub_model_name].model_type
|
167
171
|
handler = model_handler.load_handler(model_type)
|
@@ -1,18 +1,7 @@
|
|
1
1
|
import json
|
2
2
|
import os
|
3
3
|
import warnings
|
4
|
-
from typing import
|
5
|
-
TYPE_CHECKING,
|
6
|
-
Any,
|
7
|
-
Callable,
|
8
|
-
Dict,
|
9
|
-
List,
|
10
|
-
Optional,
|
11
|
-
Type,
|
12
|
-
Union,
|
13
|
-
cast,
|
14
|
-
final,
|
15
|
-
)
|
4
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
|
16
5
|
|
17
6
|
import cloudpickle
|
18
7
|
import numpy as np
|
@@ -38,7 +27,7 @@ if TYPE_CHECKING:
|
|
38
27
|
import transformers
|
39
28
|
|
40
29
|
|
41
|
-
def get_requirements_from_task(task: str, spcs_only: bool = False) ->
|
30
|
+
def get_requirements_from_task(task: str, spcs_only: bool = False) -> list[model_env.ModelDependency]:
|
42
31
|
# Text
|
43
32
|
if task in [
|
44
33
|
"conversational",
|
@@ -84,7 +73,7 @@ class HuggingFacePipelineHandler(
|
|
84
73
|
HANDLER_TYPE = "huggingface_pipeline"
|
85
74
|
HANDLER_VERSION = "2023-12-01"
|
86
75
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
87
|
-
_HANDLER_MIGRATOR_PLANS:
|
76
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
88
77
|
|
89
78
|
MODEL_BLOB_FILE_OR_DIR = "model"
|
90
79
|
ADDITIONAL_CONFIG_FILE = "pipeline_config.pt"
|
@@ -250,20 +239,17 @@ class HuggingFacePipelineHandler(
|
|
250
239
|
task, spcs_only=(not type_utils.LazyType("transformers.Pipeline").isinstance(model))
|
251
240
|
)
|
252
241
|
if framework is None or framework == "pt":
|
253
|
-
# Since we set default cuda version to be 11.8, to make sure it works with GPU, we need to have a default
|
254
|
-
# Pytorch version that works with CUDA 11.8 as well. This is required for huggingface pipelines only as
|
255
|
-
# users are not required to install pytorch locally if they are using the wrapper.
|
256
242
|
pkgs_requirements.append(model_env.ModelDependency(requirement="pytorch", pip_name="torch"))
|
257
243
|
elif framework == "tf":
|
258
244
|
pkgs_requirements.append(model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"))
|
259
245
|
model_meta.env.include_if_absent(
|
260
246
|
pkgs_requirements, check_local_version=(type_utils.LazyType("transformers.Pipeline").isinstance(model))
|
261
247
|
)
|
262
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
248
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
263
249
|
|
264
250
|
@staticmethod
|
265
|
-
def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) ->
|
266
|
-
device_config:
|
251
|
+
def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) -> dict[str, str]:
|
252
|
+
device_config: dict[str, Any] = {}
|
267
253
|
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
268
254
|
gpu_nums = 0
|
269
255
|
if cuda_visible_devices is not None:
|
@@ -369,7 +355,7 @@ class HuggingFacePipelineHandler(
|
|
369
355
|
def _create_custom_model(
|
370
356
|
raw_model: "transformers.Pipeline",
|
371
357
|
model_meta: model_meta_api.ModelMetadata,
|
372
|
-
) ->
|
358
|
+
) -> type[custom_model.CustomModel]:
|
373
359
|
def fn_factory(
|
374
360
|
raw_model: "transformers.Pipeline",
|
375
361
|
signature: model_signature.ModelSignature,
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import os
|
2
|
-
from typing import TYPE_CHECKING, Callable,
|
2
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
3
3
|
|
4
4
|
import cloudpickle
|
5
5
|
import numpy as np
|
@@ -32,7 +32,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
|
|
32
32
|
HANDLER_TYPE = "keras"
|
33
33
|
HANDLER_VERSION = "2025-01-01"
|
34
34
|
_MIN_SNOWPARK_ML_VERSION = "1.7.5"
|
35
|
-
_HANDLER_MIGRATOR_PLANS:
|
35
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
36
36
|
|
37
37
|
MODEL_BLOB_FILE_OR_DIR = "model.keras"
|
38
38
|
CUSTOM_OBJECT_SAVE_PATH = "custom_objects.pkl"
|
@@ -146,7 +146,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
|
|
146
146
|
dependencies,
|
147
147
|
check_local_version=True,
|
148
148
|
)
|
149
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
149
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
150
150
|
|
151
151
|
@classmethod
|
152
152
|
def load_model(
|
@@ -185,7 +185,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
|
|
185
185
|
def _create_custom_model(
|
186
186
|
raw_model: "keras.Model",
|
187
187
|
model_meta: model_meta_api.ModelMetadata,
|
188
|
-
) ->
|
188
|
+
) -> type[custom_model.CustomModel]:
|
189
189
|
def fn_factory(
|
190
190
|
raw_model: "keras.Model",
|
191
191
|
signature: model_signature.ModelSignature,
|
@@ -1,16 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
import warnings
|
3
|
-
from typing import
|
4
|
-
TYPE_CHECKING,
|
5
|
-
Any,
|
6
|
-
Callable,
|
7
|
-
Dict,
|
8
|
-
Optional,
|
9
|
-
Type,
|
10
|
-
Union,
|
11
|
-
cast,
|
12
|
-
final,
|
13
|
-
)
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
|
14
4
|
|
15
5
|
import cloudpickle
|
16
6
|
import numpy as np
|
@@ -41,7 +31,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
41
31
|
HANDLER_TYPE = "lightgbm"
|
42
32
|
HANDLER_VERSION = "2024-03-19"
|
43
33
|
_MIN_SNOWPARK_ML_VERSION = "1.3.1"
|
44
|
-
_HANDLER_MIGRATOR_PLANS:
|
34
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
45
35
|
|
46
36
|
MODEL_BLOB_FILE_OR_DIR = "model.pkl"
|
47
37
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
@@ -215,7 +205,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
215
205
|
def _create_custom_model(
|
216
206
|
raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
|
217
207
|
model_meta: model_meta_api.ModelMetadata,
|
218
|
-
) ->
|
208
|
+
) -> type[custom_model.CustomModel]:
|
219
209
|
def fn_factory(
|
220
210
|
raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
|
221
211
|
signature: model_signature.ModelSignature,
|
@@ -250,7 +240,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
250
240
|
|
251
241
|
return fn
|
252
242
|
|
253
|
-
type_method_dict:
|
243
|
+
type_method_dict: dict[str, Any] = {"_raw_model": raw_model}
|
254
244
|
for target_method_name, sig in model_meta.signatures.items():
|
255
245
|
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
256
246
|
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import os
|
2
2
|
import pathlib
|
3
3
|
import tempfile
|
4
|
-
from typing import TYPE_CHECKING, Callable,
|
4
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
5
5
|
|
6
6
|
import pandas as pd
|
7
7
|
from typing_extensions import TypeGuard, Unpack
|
@@ -61,7 +61,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
61
61
|
HANDLER_TYPE = "mlflow"
|
62
62
|
HANDLER_VERSION = "2023-12-01"
|
63
63
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
64
|
-
_HANDLER_MIGRATOR_PLANS:
|
64
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
65
65
|
|
66
66
|
MODEL_BLOB_FILE_OR_DIR = "model"
|
67
67
|
_DEFAULT_TARGET_METHOD = "predict"
|
@@ -204,7 +204,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
204
204
|
def _create_custom_model(
|
205
205
|
raw_model: "mlflow.pyfunc.PyFuncModel",
|
206
206
|
model_meta: model_meta_api.ModelMetadata,
|
207
|
-
) ->
|
207
|
+
) -> type[custom_model.CustomModel]:
|
208
208
|
def fn_factory(
|
209
209
|
raw_model: "mlflow.pyfunc.PyFuncModel",
|
210
210
|
signature: model_signature.ModelSignature,
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
import sys
|
3
|
-
from typing import TYPE_CHECKING, Callable,
|
3
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
4
4
|
|
5
5
|
import cloudpickle
|
6
6
|
import pandas as pd
|
@@ -38,7 +38,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
38
38
|
HANDLER_TYPE = "pytorch"
|
39
39
|
HANDLER_VERSION = "2025-03-01"
|
40
40
|
_MIN_SNOWPARK_ML_VERSION = "1.8.0"
|
41
|
-
_HANDLER_MIGRATOR_PLANS:
|
41
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
|
42
42
|
"2023-12-01": pytorch_migrator_2023_12_01.PyTorchHandlerMigrator20231201
|
43
43
|
}
|
44
44
|
|
@@ -151,7 +151,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
151
151
|
model_meta.env.include_if_absent(
|
152
152
|
[model_env.ModelDependency(requirement="pytorch", pip_name="torch")], check_local_version=True
|
153
153
|
)
|
154
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
154
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
155
155
|
|
156
156
|
@classmethod
|
157
157
|
def load_model(
|
@@ -188,7 +188,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
188
188
|
def _create_custom_model(
|
189
189
|
raw_model: "torch.nn.Module",
|
190
190
|
model_meta: model_meta_api.ModelMetadata,
|
191
|
-
) ->
|
191
|
+
) -> type[custom_model.CustomModel]:
|
192
192
|
multiple_inputs = cast(
|
193
193
|
model_meta_schema.PyTorchModelBlobOptions, model_meta.models[model_meta.name].options
|
194
194
|
)["multiple_inputs"]
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import inspect
|
2
2
|
import logging
|
3
3
|
import os
|
4
|
-
from typing import TYPE_CHECKING, Callable,
|
4
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
5
5
|
|
6
6
|
import pandas as pd
|
7
7
|
from typing_extensions import TypeGuard, Unpack
|
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
|
|
24
24
|
logger = logging.getLogger(__name__)
|
25
25
|
|
26
26
|
|
27
|
-
def _validate_sentence_transformers_signatures(sigs:
|
27
|
+
def _validate_sentence_transformers_signatures(sigs: dict[str, model_signature.ModelSignature]) -> None:
|
28
28
|
if list(sigs.keys()) != ["encode"]:
|
29
29
|
raise ValueError("target_methods can only be ['encode']")
|
30
30
|
|
@@ -48,7 +48,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
48
48
|
HANDLER_TYPE = "sentence_transformers"
|
49
49
|
HANDLER_VERSION = "2024-03-15"
|
50
50
|
_MIN_SNOWPARK_ML_VERSION = "1.3.1"
|
51
|
-
_HANDLER_MIGRATOR_PLANS:
|
51
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
52
52
|
|
53
53
|
MODEL_BLOB_FILE_OR_DIR = "model"
|
54
54
|
DEFAULT_TARGET_METHODS = ["encode"]
|
@@ -166,7 +166,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
166
166
|
],
|
167
167
|
check_local_version=True,
|
168
168
|
)
|
169
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
169
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
170
170
|
|
171
171
|
@staticmethod
|
172
172
|
def _get_device_config(**kwargs: Unpack[model_types.SentenceTransformersLoadOptions]) -> Optional[str]:
|
@@ -224,7 +224,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
224
224
|
def _create_custom_model(
|
225
225
|
raw_model: "sentence_transformers.SentenceTransformer",
|
226
226
|
model_meta: model_meta_api.ModelMetadata,
|
227
|
-
) ->
|
227
|
+
) -> type[custom_model.CustomModel]:
|
228
228
|
batch_size = cast(
|
229
229
|
model_meta_schema.SentenceTransformersModelBlobOptions, model_meta.models[model_meta.name].options
|
230
230
|
).get("batch_size", None)
|
@@ -1,13 +1,13 @@
|
|
1
1
|
import os
|
2
2
|
import warnings
|
3
|
-
from typing import TYPE_CHECKING, Callable,
|
3
|
+
from typing import TYPE_CHECKING, Callable, Optional, Union, cast, final
|
4
4
|
|
5
5
|
import cloudpickle
|
6
6
|
import numpy as np
|
7
7
|
import pandas as pd
|
8
8
|
from typing_extensions import TypeGuard, Unpack
|
9
9
|
|
10
|
-
from snowflake.ml._internal import type_utils
|
10
|
+
from snowflake.ml._internal import env, type_utils
|
11
11
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
12
12
|
from snowflake.ml.model._packager.model_env import model_env
|
13
13
|
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
@@ -19,7 +19,6 @@ from snowflake.ml.model._packager.model_meta import (
|
|
19
19
|
)
|
20
20
|
from snowflake.ml.model._packager.model_task import model_task_utils
|
21
21
|
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
22
|
-
from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR
|
23
22
|
|
24
23
|
if TYPE_CHECKING:
|
25
24
|
import sklearn.base
|
@@ -49,7 +48,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
49
48
|
HANDLER_TYPE = "sklearn"
|
50
49
|
HANDLER_VERSION = "2023-12-01"
|
51
50
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
52
|
-
_HANDLER_MIGRATOR_PLANS:
|
51
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
53
52
|
|
54
53
|
DEFAULT_TARGET_METHODS = [
|
55
54
|
"predict",
|
@@ -113,7 +112,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
113
112
|
raise ValueError("Sample input data is required to enable explainability.")
|
114
113
|
|
115
114
|
# If this is a pipeline and we are in the container runtime, check for distributed estimator.
|
116
|
-
if
|
115
|
+
if env.IN_ML_RUNTIME and isinstance(model, sklearn.pipeline.Pipeline):
|
117
116
|
model = _unpack_container_runtime_pipeline(model)
|
118
117
|
|
119
118
|
if not is_sub_model:
|
@@ -265,7 +264,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
265
264
|
def _create_custom_model(
|
266
265
|
raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
|
267
266
|
model_meta: model_meta_api.ModelMetadata,
|
268
|
-
) ->
|
267
|
+
) -> type[custom_model.CustomModel]:
|
269
268
|
def fn_factory(
|
270
269
|
raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
|
271
270
|
signature: model_signature.ModelSignature,
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
import warnings
|
3
|
-
from typing import TYPE_CHECKING, Any, Callable,
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, cast, final
|
4
4
|
|
5
5
|
import cloudpickle
|
6
6
|
import numpy as np
|
@@ -36,7 +36,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
36
36
|
HANDLER_TYPE = "snowml"
|
37
37
|
HANDLER_VERSION = "2023-12-01"
|
38
38
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
39
|
-
_HANDLER_MIGRATOR_PLANS:
|
39
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
40
40
|
|
41
41
|
DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
|
42
42
|
EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
|
@@ -264,7 +264,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
264
264
|
def _create_custom_model(
|
265
265
|
raw_model: "BaseEstimator",
|
266
266
|
model_meta: model_meta_api.ModelMetadata,
|
267
|
-
) ->
|
267
|
+
) -> type[custom_model.CustomModel]:
|
268
268
|
def fn_factory(
|
269
269
|
raw_model: "BaseEstimator",
|
270
270
|
signature: model_signature.ModelSignature,
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import os
|
2
|
-
from typing import TYPE_CHECKING, Callable,
|
2
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
3
3
|
|
4
4
|
import pandas as pd
|
5
5
|
from packaging import version
|
@@ -38,7 +38,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
38
38
|
HANDLER_TYPE = "tensorflow"
|
39
39
|
HANDLER_VERSION = "2025-03-01"
|
40
40
|
_MIN_SNOWPARK_ML_VERSION = "1.8.0"
|
41
|
-
_HANDLER_MIGRATOR_PLANS:
|
41
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
|
42
42
|
"2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201,
|
43
43
|
"2025-01-01": tensorflow_migrator_2025_01_01.TensorflowHandlerMigrator20250101,
|
44
44
|
}
|
@@ -188,7 +188,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
188
188
|
dependencies,
|
189
189
|
check_local_version=True,
|
190
190
|
)
|
191
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
191
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
192
192
|
|
193
193
|
@classmethod
|
194
194
|
def load_model(
|
@@ -230,7 +230,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
230
230
|
def _create_custom_model(
|
231
231
|
raw_model: "tensorflow.Module",
|
232
232
|
model_meta: model_meta_api.ModelMetadata,
|
233
|
-
) ->
|
233
|
+
) -> type[custom_model.CustomModel]:
|
234
234
|
multiple_inputs = cast(
|
235
235
|
model_meta_schema.TensorflowModelBlobOptions, model_meta.models[model_meta.name].options
|
236
236
|
)["multiple_inputs"]
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import os
|
2
|
-
from typing import TYPE_CHECKING, Callable,
|
2
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
3
3
|
|
4
4
|
import pandas as pd
|
5
5
|
from typing_extensions import TypeGuard, Unpack
|
@@ -36,7 +36,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
36
36
|
HANDLER_TYPE = "torchscript"
|
37
37
|
HANDLER_VERSION = "2025-03-01"
|
38
38
|
_MIN_SNOWPARK_ML_VERSION = "1.8.0"
|
39
|
-
_HANDLER_MIGRATOR_PLANS:
|
39
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
|
40
40
|
"2023-12-01": torchscript_migrator_2023_12_01.TorchScriptHandlerMigrator20231201
|
41
41
|
}
|
42
42
|
|
@@ -141,7 +141,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
141
141
|
model_meta.env.include_if_absent(
|
142
142
|
[model_env.ModelDependency(requirement="pytorch", pip_name="torch")], check_local_version=True
|
143
143
|
)
|
144
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
144
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
145
145
|
|
146
146
|
@classmethod
|
147
147
|
def load_model(
|
@@ -181,7 +181,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
181
181
|
def _create_custom_model(
|
182
182
|
raw_model: "torch.jit.ScriptModule",
|
183
183
|
model_meta: model_meta_api.ModelMetadata,
|
184
|
-
) ->
|
184
|
+
) -> type[custom_model.CustomModel]:
|
185
185
|
def fn_factory(
|
186
186
|
raw_model: "torch.jit.ScriptModule",
|
187
187
|
signature: model_signature.ModelSignature,
|
@@ -1,17 +1,7 @@
|
|
1
1
|
# mypy: disable-error-code="import"
|
2
2
|
import os
|
3
3
|
import warnings
|
4
|
-
from typing import
|
5
|
-
TYPE_CHECKING,
|
6
|
-
Any,
|
7
|
-
Callable,
|
8
|
-
Dict,
|
9
|
-
Optional,
|
10
|
-
Type,
|
11
|
-
Union,
|
12
|
-
cast,
|
13
|
-
final,
|
14
|
-
)
|
4
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
|
15
5
|
|
16
6
|
import numpy as np
|
17
7
|
import pandas as pd
|
@@ -44,7 +34,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
44
34
|
HANDLER_TYPE = "xgboost"
|
45
35
|
HANDLER_VERSION = "2023-12-01"
|
46
36
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
47
|
-
_HANDLER_MIGRATOR_PLANS:
|
37
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
48
38
|
|
49
39
|
MODEL_BLOB_FILE_OR_DIR = "model.ubj"
|
50
40
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
@@ -175,7 +165,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
175
165
|
if enable_explainability:
|
176
166
|
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
|
177
167
|
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
178
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
168
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
179
169
|
|
180
170
|
@classmethod
|
181
171
|
def load_model(
|
@@ -227,7 +217,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
227
217
|
def _create_custom_model(
|
228
218
|
raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
|
229
219
|
model_meta: model_meta_api.ModelMetadata,
|
230
|
-
) ->
|
220
|
+
) -> type[custom_model.CustomModel]:
|
231
221
|
def fn_factory(
|
232
222
|
raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
|
233
223
|
signature: model_signature.ModelSignature,
|
@@ -261,7 +251,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
261
251
|
return explain_fn
|
262
252
|
return fn
|
263
253
|
|
264
|
-
type_method_dict:
|
254
|
+
type_method_dict: dict[str, Any] = {"_raw_model": raw_model}
|
265
255
|
for target_method_name, sig in model_meta.signatures.items():
|
266
256
|
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
267
257
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import cast
|
2
2
|
|
3
3
|
from typing_extensions import Unpack
|
4
4
|
|
@@ -25,7 +25,7 @@ class ModelBlobMeta:
|
|
25
25
|
self.handler_version = kwargs["handler_version"]
|
26
26
|
self.function_properties = kwargs.get("function_properties", {})
|
27
27
|
|
28
|
-
self.artifacts:
|
28
|
+
self.artifacts: dict[str, str] = {}
|
29
29
|
artifacts = kwargs.get("artifacts", None)
|
30
30
|
if artifacts:
|
31
31
|
self.artifacts = artifacts
|