snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +64 -31
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +41 -5
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +40 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +12 -8
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +2 -4
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +86 -62
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +22 -36
- snowflake/ml/jobs/_utils/types.py +8 -2
- snowflake/ml/jobs/decorators.py +7 -8
- snowflake/ml/jobs/job.py +158 -26
- snowflake/ml/jobs/manager.py +78 -30
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +230 -50
- snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +22 -18
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +46 -25
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +35 -26
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +12 -8
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +5 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +50 -29
- snowflake/ml/registry/registry.py +34 -23
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -1,16 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
import warnings
|
3
|
-
from typing import
|
4
|
-
TYPE_CHECKING,
|
5
|
-
Any,
|
6
|
-
Callable,
|
7
|
-
Dict,
|
8
|
-
Optional,
|
9
|
-
Type,
|
10
|
-
Union,
|
11
|
-
cast,
|
12
|
-
final,
|
13
|
-
)
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
|
14
4
|
|
15
5
|
import cloudpickle
|
16
6
|
import numpy as np
|
@@ -41,7 +31,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
41
31
|
HANDLER_TYPE = "lightgbm"
|
42
32
|
HANDLER_VERSION = "2024-03-19"
|
43
33
|
_MIN_SNOWPARK_ML_VERSION = "1.3.1"
|
44
|
-
_HANDLER_MIGRATOR_PLANS:
|
34
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
45
35
|
|
46
36
|
MODEL_BLOB_FILE_OR_DIR = "model.pkl"
|
47
37
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
@@ -215,7 +205,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
215
205
|
def _create_custom_model(
|
216
206
|
raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
|
217
207
|
model_meta: model_meta_api.ModelMetadata,
|
218
|
-
) ->
|
208
|
+
) -> type[custom_model.CustomModel]:
|
219
209
|
def fn_factory(
|
220
210
|
raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
|
221
211
|
signature: model_signature.ModelSignature,
|
@@ -250,7 +240,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
250
240
|
|
251
241
|
return fn
|
252
242
|
|
253
|
-
type_method_dict:
|
243
|
+
type_method_dict: dict[str, Any] = {"_raw_model": raw_model}
|
254
244
|
for target_method_name, sig in model_meta.signatures.items():
|
255
245
|
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
256
246
|
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import os
|
2
2
|
import pathlib
|
3
3
|
import tempfile
|
4
|
-
from typing import TYPE_CHECKING, Callable,
|
4
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
5
5
|
|
6
6
|
import pandas as pd
|
7
7
|
from typing_extensions import TypeGuard, Unpack
|
@@ -61,7 +61,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
61
61
|
HANDLER_TYPE = "mlflow"
|
62
62
|
HANDLER_VERSION = "2023-12-01"
|
63
63
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
64
|
-
_HANDLER_MIGRATOR_PLANS:
|
64
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
65
65
|
|
66
66
|
MODEL_BLOB_FILE_OR_DIR = "model"
|
67
67
|
_DEFAULT_TARGET_METHOD = "predict"
|
@@ -204,7 +204,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
204
204
|
def _create_custom_model(
|
205
205
|
raw_model: "mlflow.pyfunc.PyFuncModel",
|
206
206
|
model_meta: model_meta_api.ModelMetadata,
|
207
|
-
) ->
|
207
|
+
) -> type[custom_model.CustomModel]:
|
208
208
|
def fn_factory(
|
209
209
|
raw_model: "mlflow.pyfunc.PyFuncModel",
|
210
210
|
signature: model_signature.ModelSignature,
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
import sys
|
3
|
-
from typing import TYPE_CHECKING, Callable,
|
3
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
4
4
|
|
5
5
|
import cloudpickle
|
6
6
|
import pandas as pd
|
@@ -38,7 +38,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
38
38
|
HANDLER_TYPE = "pytorch"
|
39
39
|
HANDLER_VERSION = "2025-03-01"
|
40
40
|
_MIN_SNOWPARK_ML_VERSION = "1.8.0"
|
41
|
-
_HANDLER_MIGRATOR_PLANS:
|
41
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
|
42
42
|
"2023-12-01": pytorch_migrator_2023_12_01.PyTorchHandlerMigrator20231201
|
43
43
|
}
|
44
44
|
|
@@ -151,7 +151,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
151
151
|
model_meta.env.include_if_absent(
|
152
152
|
[model_env.ModelDependency(requirement="pytorch", pip_name="torch")], check_local_version=True
|
153
153
|
)
|
154
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
154
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
155
155
|
|
156
156
|
@classmethod
|
157
157
|
def load_model(
|
@@ -188,7 +188,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
188
188
|
def _create_custom_model(
|
189
189
|
raw_model: "torch.nn.Module",
|
190
190
|
model_meta: model_meta_api.ModelMetadata,
|
191
|
-
) ->
|
191
|
+
) -> type[custom_model.CustomModel]:
|
192
192
|
multiple_inputs = cast(
|
193
193
|
model_meta_schema.PyTorchModelBlobOptions, model_meta.models[model_meta.name].options
|
194
194
|
)["multiple_inputs"]
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import inspect
|
2
2
|
import logging
|
3
3
|
import os
|
4
|
-
from typing import TYPE_CHECKING, Callable,
|
4
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
5
5
|
|
6
6
|
import pandas as pd
|
7
7
|
from typing_extensions import TypeGuard, Unpack
|
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
|
|
24
24
|
logger = logging.getLogger(__name__)
|
25
25
|
|
26
26
|
|
27
|
-
def _validate_sentence_transformers_signatures(sigs:
|
27
|
+
def _validate_sentence_transformers_signatures(sigs: dict[str, model_signature.ModelSignature]) -> None:
|
28
28
|
if list(sigs.keys()) != ["encode"]:
|
29
29
|
raise ValueError("target_methods can only be ['encode']")
|
30
30
|
|
@@ -48,7 +48,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
48
48
|
HANDLER_TYPE = "sentence_transformers"
|
49
49
|
HANDLER_VERSION = "2024-03-15"
|
50
50
|
_MIN_SNOWPARK_ML_VERSION = "1.3.1"
|
51
|
-
_HANDLER_MIGRATOR_PLANS:
|
51
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
52
52
|
|
53
53
|
MODEL_BLOB_FILE_OR_DIR = "model"
|
54
54
|
DEFAULT_TARGET_METHODS = ["encode"]
|
@@ -166,7 +166,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
166
166
|
],
|
167
167
|
check_local_version=True,
|
168
168
|
)
|
169
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
169
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
170
170
|
|
171
171
|
@staticmethod
|
172
172
|
def _get_device_config(**kwargs: Unpack[model_types.SentenceTransformersLoadOptions]) -> Optional[str]:
|
@@ -224,7 +224,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
224
224
|
def _create_custom_model(
|
225
225
|
raw_model: "sentence_transformers.SentenceTransformer",
|
226
226
|
model_meta: model_meta_api.ModelMetadata,
|
227
|
-
) ->
|
227
|
+
) -> type[custom_model.CustomModel]:
|
228
228
|
batch_size = cast(
|
229
229
|
model_meta_schema.SentenceTransformersModelBlobOptions, model_meta.models[model_meta.name].options
|
230
230
|
).get("batch_size", None)
|
@@ -1,13 +1,13 @@
|
|
1
1
|
import os
|
2
2
|
import warnings
|
3
|
-
from typing import TYPE_CHECKING, Callable,
|
3
|
+
from typing import TYPE_CHECKING, Callable, Optional, Union, cast, final
|
4
4
|
|
5
5
|
import cloudpickle
|
6
6
|
import numpy as np
|
7
7
|
import pandas as pd
|
8
8
|
from typing_extensions import TypeGuard, Unpack
|
9
9
|
|
10
|
-
from snowflake.ml._internal import type_utils
|
10
|
+
from snowflake.ml._internal import env, type_utils
|
11
11
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
12
12
|
from snowflake.ml.model._packager.model_env import model_env
|
13
13
|
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
@@ -19,7 +19,6 @@ from snowflake.ml.model._packager.model_meta import (
|
|
19
19
|
)
|
20
20
|
from snowflake.ml.model._packager.model_task import model_task_utils
|
21
21
|
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
22
|
-
from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR
|
23
22
|
|
24
23
|
if TYPE_CHECKING:
|
25
24
|
import sklearn.base
|
@@ -49,7 +48,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
49
48
|
HANDLER_TYPE = "sklearn"
|
50
49
|
HANDLER_VERSION = "2023-12-01"
|
51
50
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
52
|
-
_HANDLER_MIGRATOR_PLANS:
|
51
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
53
52
|
|
54
53
|
DEFAULT_TARGET_METHODS = [
|
55
54
|
"predict",
|
@@ -113,7 +112,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
113
112
|
raise ValueError("Sample input data is required to enable explainability.")
|
114
113
|
|
115
114
|
# If this is a pipeline and we are in the container runtime, check for distributed estimator.
|
116
|
-
if
|
115
|
+
if env.IN_ML_RUNTIME and isinstance(model, sklearn.pipeline.Pipeline):
|
117
116
|
model = _unpack_container_runtime_pipeline(model)
|
118
117
|
|
119
118
|
if not is_sub_model:
|
@@ -265,7 +264,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
265
264
|
def _create_custom_model(
|
266
265
|
raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
|
267
266
|
model_meta: model_meta_api.ModelMetadata,
|
268
|
-
) ->
|
267
|
+
) -> type[custom_model.CustomModel]:
|
269
268
|
def fn_factory(
|
270
269
|
raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
|
271
270
|
signature: model_signature.ModelSignature,
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
import warnings
|
3
|
-
from typing import TYPE_CHECKING, Any, Callable,
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, cast, final
|
4
4
|
|
5
5
|
import cloudpickle
|
6
6
|
import numpy as np
|
@@ -36,7 +36,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
36
36
|
HANDLER_TYPE = "snowml"
|
37
37
|
HANDLER_VERSION = "2023-12-01"
|
38
38
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
39
|
-
_HANDLER_MIGRATOR_PLANS:
|
39
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
40
40
|
|
41
41
|
DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
|
42
42
|
EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
|
@@ -264,7 +264,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
264
264
|
def _create_custom_model(
|
265
265
|
raw_model: "BaseEstimator",
|
266
266
|
model_meta: model_meta_api.ModelMetadata,
|
267
|
-
) ->
|
267
|
+
) -> type[custom_model.CustomModel]:
|
268
268
|
def fn_factory(
|
269
269
|
raw_model: "BaseEstimator",
|
270
270
|
signature: model_signature.ModelSignature,
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import os
|
2
|
-
from typing import TYPE_CHECKING, Callable,
|
2
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
3
3
|
|
4
4
|
import pandas as pd
|
5
5
|
from packaging import version
|
@@ -38,7 +38,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
38
38
|
HANDLER_TYPE = "tensorflow"
|
39
39
|
HANDLER_VERSION = "2025-03-01"
|
40
40
|
_MIN_SNOWPARK_ML_VERSION = "1.8.0"
|
41
|
-
_HANDLER_MIGRATOR_PLANS:
|
41
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
|
42
42
|
"2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201,
|
43
43
|
"2025-01-01": tensorflow_migrator_2025_01_01.TensorflowHandlerMigrator20250101,
|
44
44
|
}
|
@@ -188,7 +188,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
188
188
|
dependencies,
|
189
189
|
check_local_version=True,
|
190
190
|
)
|
191
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
191
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
192
192
|
|
193
193
|
@classmethod
|
194
194
|
def load_model(
|
@@ -230,7 +230,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
230
230
|
def _create_custom_model(
|
231
231
|
raw_model: "tensorflow.Module",
|
232
232
|
model_meta: model_meta_api.ModelMetadata,
|
233
|
-
) ->
|
233
|
+
) -> type[custom_model.CustomModel]:
|
234
234
|
multiple_inputs = cast(
|
235
235
|
model_meta_schema.TensorflowModelBlobOptions, model_meta.models[model_meta.name].options
|
236
236
|
)["multiple_inputs"]
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import os
|
2
|
-
from typing import TYPE_CHECKING, Callable,
|
2
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
3
3
|
|
4
4
|
import pandas as pd
|
5
5
|
from typing_extensions import TypeGuard, Unpack
|
@@ -36,7 +36,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
36
36
|
HANDLER_TYPE = "torchscript"
|
37
37
|
HANDLER_VERSION = "2025-03-01"
|
38
38
|
_MIN_SNOWPARK_ML_VERSION = "1.8.0"
|
39
|
-
_HANDLER_MIGRATOR_PLANS:
|
39
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {
|
40
40
|
"2023-12-01": torchscript_migrator_2023_12_01.TorchScriptHandlerMigrator20231201
|
41
41
|
}
|
42
42
|
|
@@ -141,7 +141,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
141
141
|
model_meta.env.include_if_absent(
|
142
142
|
[model_env.ModelDependency(requirement="pytorch", pip_name="torch")], check_local_version=True
|
143
143
|
)
|
144
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
144
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
145
145
|
|
146
146
|
@classmethod
|
147
147
|
def load_model(
|
@@ -181,7 +181,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
181
181
|
def _create_custom_model(
|
182
182
|
raw_model: "torch.jit.ScriptModule",
|
183
183
|
model_meta: model_meta_api.ModelMetadata,
|
184
|
-
) ->
|
184
|
+
) -> type[custom_model.CustomModel]:
|
185
185
|
def fn_factory(
|
186
186
|
raw_model: "torch.jit.ScriptModule",
|
187
187
|
signature: model_signature.ModelSignature,
|
@@ -1,17 +1,7 @@
|
|
1
1
|
# mypy: disable-error-code="import"
|
2
2
|
import os
|
3
3
|
import warnings
|
4
|
-
from typing import
|
5
|
-
TYPE_CHECKING,
|
6
|
-
Any,
|
7
|
-
Callable,
|
8
|
-
Dict,
|
9
|
-
Optional,
|
10
|
-
Type,
|
11
|
-
Union,
|
12
|
-
cast,
|
13
|
-
final,
|
14
|
-
)
|
4
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
|
15
5
|
|
16
6
|
import numpy as np
|
17
7
|
import pandas as pd
|
@@ -44,7 +34,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
44
34
|
HANDLER_TYPE = "xgboost"
|
45
35
|
HANDLER_VERSION = "2023-12-01"
|
46
36
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
47
|
-
_HANDLER_MIGRATOR_PLANS:
|
37
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
48
38
|
|
49
39
|
MODEL_BLOB_FILE_OR_DIR = "model.ubj"
|
50
40
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
@@ -175,7 +165,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
175
165
|
if enable_explainability:
|
176
166
|
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
|
177
167
|
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
178
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
168
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
179
169
|
|
180
170
|
@classmethod
|
181
171
|
def load_model(
|
@@ -227,7 +217,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
227
217
|
def _create_custom_model(
|
228
218
|
raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
|
229
219
|
model_meta: model_meta_api.ModelMetadata,
|
230
|
-
) ->
|
220
|
+
) -> type[custom_model.CustomModel]:
|
231
221
|
def fn_factory(
|
232
222
|
raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
|
233
223
|
signature: model_signature.ModelSignature,
|
@@ -261,7 +251,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
261
251
|
return explain_fn
|
262
252
|
return fn
|
263
253
|
|
264
|
-
type_method_dict:
|
254
|
+
type_method_dict: dict[str, Any] = {"_raw_model": raw_model}
|
265
255
|
for target_method_name, sig in model_meta.signatures.items():
|
266
256
|
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
267
257
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import cast
|
2
2
|
|
3
3
|
from typing_extensions import Unpack
|
4
4
|
|
@@ -25,7 +25,7 @@ class ModelBlobMeta:
|
|
25
25
|
self.handler_version = kwargs["handler_version"]
|
26
26
|
self.function_properties = kwargs.get("function_properties", {})
|
27
27
|
|
28
|
-
self.artifacts:
|
28
|
+
self.artifacts: dict[str, str] = {}
|
29
29
|
artifacts = kwargs.get("artifacts", None)
|
30
30
|
if artifacts:
|
31
31
|
self.artifacts = artifacts
|
@@ -6,31 +6,26 @@ import zipfile
|
|
6
6
|
from contextlib import contextmanager
|
7
7
|
from datetime import datetime
|
8
8
|
from types import ModuleType
|
9
|
-
from typing import Any,
|
9
|
+
from typing import Any, Generator, Optional, TypedDict
|
10
10
|
|
11
11
|
import cloudpickle
|
12
12
|
import yaml
|
13
13
|
from packaging import requirements, version
|
14
14
|
from typing_extensions import Required
|
15
15
|
|
16
|
-
from snowflake.ml
|
16
|
+
from snowflake.ml import version as snowml_version
|
17
|
+
from snowflake.ml._internal import env_utils, file_utils
|
17
18
|
from snowflake.ml.model import model_signature, type_hints as model_types
|
18
19
|
from snowflake.ml.model._packager.model_env import model_env
|
19
|
-
from snowflake.ml.model._packager.model_meta import
|
20
|
-
_packaging_requirements,
|
21
|
-
model_blob_meta,
|
22
|
-
model_meta_schema,
|
23
|
-
)
|
20
|
+
from snowflake.ml.model._packager.model_meta import model_blob_meta, model_meta_schema
|
24
21
|
from snowflake.ml.model._packager.model_meta_migrator import migrator_plans
|
25
22
|
from snowflake.ml.model._packager.model_runtime import model_runtime
|
26
23
|
|
27
24
|
MODEL_METADATA_FILE = "model.yaml"
|
28
25
|
MODEL_CODE_DIR = "code"
|
29
26
|
|
30
|
-
_PACKAGING_REQUIREMENTS = [
|
31
|
-
|
32
|
-
for r in _packaging_requirements.REQUIREMENTS
|
33
|
-
]
|
27
|
+
_PACKAGING_REQUIREMENTS = ["cloudpickle"]
|
28
|
+
|
34
29
|
_SNOWFLAKE_PKG_NAME = "snowflake"
|
35
30
|
_SNOWFLAKE_ML_PKG_NAME = f"{_SNOWFLAKE_PKG_NAME}.ml"
|
36
31
|
|
@@ -41,14 +36,16 @@ def create_model_metadata(
|
|
41
36
|
model_dir_path: str,
|
42
37
|
name: str,
|
43
38
|
model_type: model_types.SupportedModelHandlerType,
|
44
|
-
signatures: Optional[
|
45
|
-
function_properties: Optional[
|
46
|
-
metadata: Optional[
|
47
|
-
code_paths: Optional[
|
48
|
-
ext_modules: Optional[
|
49
|
-
conda_dependencies: Optional[
|
50
|
-
pip_requirements: Optional[
|
51
|
-
artifact_repository_map: Optional[
|
39
|
+
signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
|
40
|
+
function_properties: Optional[dict[str, dict[str, Any]]] = None,
|
41
|
+
metadata: Optional[dict[str, str]] = None,
|
42
|
+
code_paths: Optional[list[str]] = None,
|
43
|
+
ext_modules: Optional[list[ModuleType]] = None,
|
44
|
+
conda_dependencies: Optional[list[str]] = None,
|
45
|
+
pip_requirements: Optional[list[str]] = None,
|
46
|
+
artifact_repository_map: Optional[dict[str, str]] = None,
|
47
|
+
resource_constraint: Optional[dict[str, str]] = None,
|
48
|
+
target_platforms: Optional[list[model_types.TargetPlatform]] = None,
|
52
49
|
python_version: Optional[str] = None,
|
53
50
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
54
51
|
**kwargs: Any,
|
@@ -69,6 +66,8 @@ def create_model_metadata(
|
|
69
66
|
conda_dependencies: List of conda requirements for running the model. Defaults to None.
|
70
67
|
pip_requirements: List of pip Python packages requirements for running the model. Defaults to None.
|
71
68
|
artifact_repository_map: A dict mapping from package channel to artifact repository name.
|
69
|
+
resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
|
70
|
+
target_platforms: List of target platforms to run the model.
|
72
71
|
python_version: A string of python version where model is run. Used for user override. If specified as None,
|
73
72
|
current version would be captured. Defaults to None.
|
74
73
|
task: The task of the Model Version. It is an enum class Task with values TABULAR_REGRESSION,
|
@@ -101,16 +100,19 @@ def create_model_metadata(
|
|
101
100
|
else:
|
102
101
|
raise ValueError("`snowflake.ml` is imported via a way that embedding local ML library is not supported.")
|
103
102
|
|
103
|
+
prefer_pip = target_platforms == [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
|
104
104
|
env = _create_env_for_model_metadata(
|
105
105
|
conda_dependencies=conda_dependencies,
|
106
106
|
pip_requirements=pip_requirements,
|
107
107
|
artifact_repository_map=artifact_repository_map,
|
108
|
+
resource_constraint=resource_constraint,
|
108
109
|
python_version=python_version,
|
109
110
|
embed_local_ml_library=embed_local_ml_library,
|
111
|
+
prefer_pip=prefer_pip,
|
110
112
|
)
|
111
113
|
|
112
114
|
if embed_local_ml_library:
|
113
|
-
env.snowpark_ml_version = f"{
|
115
|
+
env.snowpark_ml_version = f"{snowml_version.VERSION}+{file_utils.hash_directory(path_to_copy)}"
|
114
116
|
|
115
117
|
model_meta = ModelMetadata(
|
116
118
|
name=name,
|
@@ -152,20 +154,23 @@ def create_model_metadata(
|
|
152
154
|
|
153
155
|
def _create_env_for_model_metadata(
|
154
156
|
*,
|
155
|
-
conda_dependencies: Optional[
|
156
|
-
pip_requirements: Optional[
|
157
|
-
artifact_repository_map: Optional[
|
157
|
+
conda_dependencies: Optional[list[str]] = None,
|
158
|
+
pip_requirements: Optional[list[str]] = None,
|
159
|
+
artifact_repository_map: Optional[dict[str, str]] = None,
|
160
|
+
resource_constraint: Optional[dict[str, str]] = None,
|
158
161
|
python_version: Optional[str] = None,
|
159
162
|
embed_local_ml_library: bool = False,
|
163
|
+
prefer_pip: bool = False,
|
160
164
|
) -> model_env.ModelEnv:
|
161
|
-
env = model_env.ModelEnv()
|
165
|
+
env = model_env.ModelEnv(prefer_pip=prefer_pip)
|
162
166
|
|
163
167
|
# Mypy doesn't like getter and setter have different types. See python/mypy #3004
|
164
168
|
env.conda_dependencies = conda_dependencies # type: ignore[assignment]
|
165
169
|
env.pip_requirements = pip_requirements # type: ignore[assignment]
|
166
170
|
env.artifact_repository_map = artifact_repository_map
|
171
|
+
env.resource_constraint = resource_constraint
|
167
172
|
env.python_version = python_version # type: ignore[assignment]
|
168
|
-
env.snowpark_ml_version =
|
173
|
+
env.snowpark_ml_version = snowml_version.VERSION
|
169
174
|
|
170
175
|
requirements_to_add = _PACKAGING_REQUIREMENTS
|
171
176
|
|
@@ -237,20 +242,20 @@ class ModelMetadata:
|
|
237
242
|
name: str,
|
238
243
|
env: model_env.ModelEnv,
|
239
244
|
model_type: model_types.SupportedModelHandlerType,
|
240
|
-
runtimes: Optional[
|
241
|
-
signatures: Optional[
|
242
|
-
function_properties: Optional[
|
243
|
-
user_files: Optional[
|
244
|
-
metadata: Optional[
|
245
|
+
runtimes: Optional[dict[str, model_runtime.ModelRuntime]] = None,
|
246
|
+
signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
|
247
|
+
function_properties: Optional[dict[str, dict[str, Any]]] = None,
|
248
|
+
user_files: Optional[dict[str, list[str]]] = None,
|
249
|
+
metadata: Optional[dict[str, str]] = None,
|
245
250
|
creation_timestamp: Optional[str] = None,
|
246
251
|
min_snowpark_ml_version: Optional[str] = None,
|
247
|
-
models: Optional[
|
252
|
+
models: Optional[dict[str, model_blob_meta.ModelBlobMeta]] = None,
|
248
253
|
original_metadata_version: Optional[str] = model_meta_schema.MODEL_METADATA_VERSION,
|
249
254
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
250
255
|
explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = None,
|
251
256
|
) -> None:
|
252
257
|
self.name = name
|
253
|
-
self.signatures:
|
258
|
+
self.signatures: dict[str, model_signature.ModelSignature] = dict()
|
254
259
|
if signatures:
|
255
260
|
self.signatures = signatures
|
256
261
|
self.function_properties = function_properties or {}
|
@@ -265,7 +270,7 @@ class ModelMetadata:
|
|
265
270
|
else model_meta_schema.MODEL_METADATA_MIN_SNOWPARK_ML_VERSION
|
266
271
|
)
|
267
272
|
|
268
|
-
self.models:
|
273
|
+
self.models: dict[str, model_blob_meta.ModelBlobMeta] = dict()
|
269
274
|
if models:
|
270
275
|
self.models = models
|
271
276
|
|
@@ -286,7 +291,7 @@ class ModelMetadata:
|
|
286
291
|
self._min_snowpark_ml_version = max(self._min_snowpark_ml_version, parsed_min_snowpark_ml_version)
|
287
292
|
|
288
293
|
@property
|
289
|
-
def runtimes(self) ->
|
294
|
+
def runtimes(self) -> dict[str, model_runtime.ModelRuntime]:
|
290
295
|
if self._runtimes and "cpu" in self._runtimes:
|
291
296
|
return self._runtimes
|
292
297
|
runtimes = {
|
@@ -353,11 +358,11 @@ class ModelMetadata:
|
|
353
358
|
|
354
359
|
loaded_meta_min_snowpark_ml_version = loaded_meta.get("min_snowpark_ml_version", None)
|
355
360
|
if not loaded_meta_min_snowpark_ml_version or (
|
356
|
-
version.parse(loaded_meta_min_snowpark_ml_version) > version.parse(
|
361
|
+
version.parse(loaded_meta_min_snowpark_ml_version) > version.parse(snowml_version.VERSION)
|
357
362
|
):
|
358
363
|
raise RuntimeError(
|
359
364
|
f"The minimal version required to load the model is {loaded_meta_min_snowpark_ml_version}, "
|
360
|
-
f"while current version of Snowpark ML library is {
|
365
|
+
f"while current version of Snowpark ML library is {snowml_version.VERSION}."
|
361
366
|
)
|
362
367
|
return model_meta_schema.ModelMetadataDict(
|
363
368
|
creation_timestamp=loaded_meta["creation_timestamp"],
|
@@ -400,7 +405,7 @@ class ModelMetadata:
|
|
400
405
|
env = model_env.ModelEnv()
|
401
406
|
env.load_from_dict(pathlib.Path(model_dir_path), model_dict["env"])
|
402
407
|
|
403
|
-
runtimes: Optional[
|
408
|
+
runtimes: Optional[dict[str, model_runtime.ModelRuntime]]
|
404
409
|
if model_dict.get("runtimes", None):
|
405
410
|
runtimes = {
|
406
411
|
name: model_runtime.ModelRuntime.load(pathlib.Path(model_dir_path), name, env, runtime_dict)
|
@@ -1,7 +1,7 @@
|
|
1
1
|
# This files contains schema definition of what will be written into model.yml
|
2
2
|
# Changing this file should lead to a change of the schema version.
|
3
3
|
from enum import Enum
|
4
|
-
from typing import Any,
|
4
|
+
from typing import Any, Optional, TypedDict, Union
|
5
5
|
|
6
6
|
from typing_extensions import NotRequired, Required
|
7
7
|
|
@@ -18,18 +18,20 @@ class FunctionProperties(Enum):
|
|
18
18
|
class ModelRuntimeDependenciesDict(TypedDict):
|
19
19
|
conda: Required[str]
|
20
20
|
pip: Required[str]
|
21
|
-
artifact_repository_map: NotRequired[Optional[
|
21
|
+
artifact_repository_map: NotRequired[Optional[dict[str, str]]]
|
22
22
|
|
23
23
|
|
24
24
|
class ModelRuntimeDict(TypedDict):
|
25
|
-
imports: Required[
|
25
|
+
imports: Required[list[str]]
|
26
26
|
dependencies: Required[ModelRuntimeDependenciesDict]
|
27
|
+
resource_constraint: NotRequired[Optional[dict[str, str]]]
|
27
28
|
|
28
29
|
|
29
30
|
class ModelEnvDict(TypedDict):
|
30
31
|
conda: Required[str]
|
31
32
|
pip: Required[str]
|
32
|
-
artifact_repository_map: NotRequired[Optional[
|
33
|
+
artifact_repository_map: NotRequired[Optional[dict[str, str]]]
|
34
|
+
resource_constraint: NotRequired[Optional[dict[str, str]]]
|
33
35
|
python_version: Required[str]
|
34
36
|
cuda_version: NotRequired[Optional[str]]
|
35
37
|
snowpark_ml_version: Required[str]
|
@@ -102,25 +104,25 @@ class ModelBlobMetadataDict(TypedDict):
|
|
102
104
|
model_type: Required[type_hints.SupportedModelHandlerType]
|
103
105
|
path: Required[str]
|
104
106
|
handler_version: Required[str]
|
105
|
-
function_properties: NotRequired[
|
106
|
-
artifacts: NotRequired[
|
107
|
+
function_properties: NotRequired[dict[str, dict[str, Any]]]
|
108
|
+
artifacts: NotRequired[dict[str, str]]
|
107
109
|
options: NotRequired[ModelBlobOptions]
|
108
110
|
|
109
111
|
|
110
112
|
class ModelMetadataDict(TypedDict):
|
111
113
|
creation_timestamp: Required[str]
|
112
114
|
env: Required[ModelEnvDict]
|
113
|
-
runtimes: NotRequired[
|
114
|
-
metadata: NotRequired[Optional[
|
115
|
+
runtimes: NotRequired[dict[str, ModelRuntimeDict]]
|
116
|
+
metadata: NotRequired[Optional[dict[str, str]]]
|
115
117
|
model_type: Required[type_hints.SupportedModelHandlerType]
|
116
|
-
models: Required[
|
118
|
+
models: Required[dict[str, ModelBlobMetadataDict]]
|
117
119
|
name: Required[str]
|
118
|
-
signatures: Required[
|
120
|
+
signatures: Required[dict[str, dict[str, Any]]]
|
119
121
|
version: Required[str]
|
120
122
|
min_snowpark_ml_version: Required[str]
|
121
123
|
task: Required[str]
|
122
124
|
explainability: NotRequired[Optional[ExplainabilityMetadataDict]]
|
123
|
-
function_properties: NotRequired[
|
125
|
+
function_properties: NotRequired[dict[str, dict[str, Any]]]
|
124
126
|
|
125
127
|
|
126
128
|
class ModelExplainAlgorithm(Enum):
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import copy
|
2
2
|
from abc import abstractmethod
|
3
|
-
from typing import Any,
|
3
|
+
from typing import Any, Protocol, final
|
4
4
|
|
5
5
|
from snowflake.ml._internal import migrator_utils
|
6
6
|
|
@@ -11,13 +11,13 @@ class _BaseModelMetaMigratorProtocol(Protocol):
|
|
11
11
|
|
12
12
|
@staticmethod
|
13
13
|
@abstractmethod
|
14
|
-
def upgrade(original_meta_dict:
|
14
|
+
def upgrade(original_meta_dict: dict[str, Any]) -> dict[str, Any]:
|
15
15
|
raise NotImplementedError
|
16
16
|
|
17
17
|
|
18
18
|
class BaseModelMetaMigrator(_BaseModelMetaMigratorProtocol):
|
19
19
|
@final
|
20
|
-
def try_upgrade(self, original_meta_dict:
|
20
|
+
def try_upgrade(self, original_meta_dict: dict[str, Any]) -> dict[str, Any]:
|
21
21
|
loaded_meta_version = original_meta_dict.get("version", None)
|
22
22
|
if not loaded_meta_version or str(loaded_meta_version) != self.source_version:
|
23
23
|
raise NotImplementedError(
|