snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +64 -31
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +41 -5
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +40 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +12 -8
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +2 -4
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +86 -62
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +22 -36
- snowflake/ml/jobs/_utils/types.py +8 -2
- snowflake/ml/jobs/decorators.py +7 -8
- snowflake/ml/jobs/job.py +158 -26
- snowflake/ml/jobs/manager.py +78 -30
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +230 -50
- snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +22 -18
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +46 -25
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +35 -26
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +12 -8
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +5 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +50 -29
- snowflake/ml/registry/registry.py +34 -23
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@ import collections
|
|
2
2
|
import logging
|
3
3
|
import pathlib
|
4
4
|
import warnings
|
5
|
-
from typing import
|
5
|
+
from typing import Optional, cast
|
6
6
|
|
7
7
|
import yaml
|
8
8
|
|
@@ -45,10 +45,10 @@ class ModelManifest:
|
|
45
45
|
self,
|
46
46
|
model_meta: model_meta_api.ModelMetadata,
|
47
47
|
model_rel_path: pathlib.PurePosixPath,
|
48
|
-
user_files: Optional[
|
48
|
+
user_files: Optional[dict[str, list[str]]] = None,
|
49
49
|
options: Optional[type_hints.ModelSaveOption] = None,
|
50
|
-
data_sources: Optional[
|
51
|
-
target_platforms: Optional[
|
50
|
+
data_sources: Optional[list[data_source.DataSource]] = None,
|
51
|
+
target_platforms: Optional[list[type_hints.TargetPlatform]] = None,
|
52
52
|
) -> None:
|
53
53
|
if options is None:
|
54
54
|
options = {}
|
@@ -78,12 +78,13 @@ class ModelManifest:
|
|
78
78
|
logger.info(f"Conda dependencies: {runtime_to_use.runtime_env.conda_dependencies}")
|
79
79
|
logger.info(f"Pip requirements: {runtime_to_use.runtime_env.pip_requirements}")
|
80
80
|
logger.info(f"artifact_repository_map: {runtime_to_use.runtime_env.artifact_repository_map}")
|
81
|
+
logger.info(f"resource_constraint: {runtime_to_use.runtime_env.resource_constraint}")
|
81
82
|
runtime_dict = runtime_to_use.save(
|
82
83
|
self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
83
84
|
)
|
84
85
|
|
85
86
|
self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
|
86
|
-
self.methods:
|
87
|
+
self.methods: list[model_method.ModelMethod] = []
|
87
88
|
|
88
89
|
for target_method in model_meta.signatures.keys():
|
89
90
|
method = model_method.ModelMethod(
|
@@ -100,7 +101,7 @@ class ModelManifest:
|
|
100
101
|
|
101
102
|
self.methods.append(method)
|
102
103
|
|
103
|
-
self.user_files:
|
104
|
+
self.user_files: list[model_user_file.ModelUserFile] = []
|
104
105
|
|
105
106
|
if user_files is not None:
|
106
107
|
for subdirectory, paths in user_files.items():
|
@@ -127,16 +128,19 @@ class ModelManifest:
|
|
127
128
|
if model_meta.env.artifact_repository_map:
|
128
129
|
dependencies["artifact_repository_map"] = runtime_dict["dependencies"]["artifact_repository_map"]
|
129
130
|
|
131
|
+
runtime = model_manifest_schema.ModelRuntimeDict(
|
132
|
+
language="PYTHON",
|
133
|
+
version=runtime_to_use.runtime_env.python_version,
|
134
|
+
imports=runtime_dict["imports"],
|
135
|
+
dependencies=dependencies,
|
136
|
+
)
|
137
|
+
|
138
|
+
if runtime_dict["resource_constraint"]:
|
139
|
+
runtime["resource_constraint"] = runtime_dict["resource_constraint"]
|
140
|
+
|
130
141
|
manifest_dict = model_manifest_schema.ModelManifestDict(
|
131
142
|
manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
|
132
|
-
runtimes={
|
133
|
-
self._DEFAULT_RUNTIME_NAME: model_manifest_schema.ModelRuntimeDict(
|
134
|
-
language="PYTHON",
|
135
|
-
version=runtime_to_use.runtime_env.python_version,
|
136
|
-
imports=runtime_dict["imports"],
|
137
|
-
dependencies=dependencies,
|
138
|
-
)
|
139
|
-
},
|
143
|
+
runtimes={self._DEFAULT_RUNTIME_NAME: runtime},
|
140
144
|
methods=[
|
141
145
|
method.save(
|
142
146
|
self.workspace_path,
|
@@ -178,8 +182,8 @@ class ModelManifest:
|
|
178
182
|
return res
|
179
183
|
|
180
184
|
def _extract_lineage_info(
|
181
|
-
self, data_sources: Optional[
|
182
|
-
) ->
|
185
|
+
self, data_sources: Optional[list[data_source.DataSource]]
|
186
|
+
) -> list[model_manifest_schema.LineageSourceDict]:
|
183
187
|
result = []
|
184
188
|
if data_sources:
|
185
189
|
for source in data_sources:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# This files contains schema definition of what will be written into MANIFEST.yml
|
2
2
|
import enum
|
3
|
-
from typing import Any,
|
3
|
+
from typing import Any, Literal, Optional, TypedDict, Union
|
4
4
|
|
5
5
|
from typing_extensions import NotRequired, Required
|
6
6
|
|
@@ -20,14 +20,15 @@ class ModelMethodFunctionTypes(enum.Enum):
|
|
20
20
|
class ModelRuntimeDependenciesDict(TypedDict):
|
21
21
|
conda: NotRequired[str]
|
22
22
|
pip: NotRequired[str]
|
23
|
-
artifact_repository_map: NotRequired[Optional[
|
23
|
+
artifact_repository_map: NotRequired[Optional[dict[str, str]]]
|
24
24
|
|
25
25
|
|
26
26
|
class ModelRuntimeDict(TypedDict):
|
27
27
|
language: Required[Literal["PYTHON"]]
|
28
28
|
version: Required[str]
|
29
|
-
imports: Required[
|
29
|
+
imports: Required[list[str]]
|
30
30
|
dependencies: Required[ModelRuntimeDependenciesDict]
|
31
|
+
resource_constraint: NotRequired[Optional[dict[str, str]]]
|
31
32
|
|
32
33
|
|
33
34
|
class ModelMethodSignatureField(TypedDict):
|
@@ -43,8 +44,8 @@ class ModelFunctionMethodDict(TypedDict):
|
|
43
44
|
runtime: Required[str]
|
44
45
|
type: Required[str]
|
45
46
|
handler: Required[str]
|
46
|
-
inputs: Required[
|
47
|
-
outputs: Required[Union[
|
47
|
+
inputs: Required[list[ModelMethodSignatureFieldWithName]]
|
48
|
+
outputs: Required[Union[list[ModelMethodSignatureField], list[ModelMethodSignatureFieldWithName]]]
|
48
49
|
|
49
50
|
|
50
51
|
ModelMethodDict = ModelFunctionMethodDict
|
@@ -71,12 +72,12 @@ class ModelFunctionInfo(TypedDict):
|
|
71
72
|
class ModelFunctionInfoDict(TypedDict):
|
72
73
|
name: Required[str]
|
73
74
|
target_method: Required[str]
|
74
|
-
signature: Required[
|
75
|
+
signature: Required[dict[str, Any]]
|
75
76
|
|
76
77
|
|
77
78
|
class SnowparkMLDataDict(TypedDict):
|
78
79
|
schema_version: Required[str]
|
79
|
-
functions: Required[
|
80
|
+
functions: Required[list[ModelFunctionInfoDict]]
|
80
81
|
|
81
82
|
|
82
83
|
class LineageSourceTypes(enum.Enum):
|
@@ -92,9 +93,9 @@ class LineageSourceDict(TypedDict):
|
|
92
93
|
|
93
94
|
class ModelManifestDict(TypedDict):
|
94
95
|
manifest_version: Required[str]
|
95
|
-
runtimes: Required[
|
96
|
-
methods: Required[
|
97
|
-
user_data: NotRequired[
|
98
|
-
user_files: NotRequired[
|
99
|
-
lineage_sources: NotRequired[
|
100
|
-
target_platforms: NotRequired[
|
96
|
+
runtimes: Required[dict[str, ModelRuntimeDict]]
|
97
|
+
methods: Required[list[ModelMethodDict]]
|
98
|
+
user_data: NotRequired[dict[str, Any]]
|
99
|
+
user_files: NotRequired[list[str]]
|
100
|
+
lineage_sources: NotRequired[list[LineageSourceDict]]
|
101
|
+
target_platforms: NotRequired[list[str]]
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import collections
|
2
2
|
import pathlib
|
3
|
-
from typing import
|
3
|
+
from typing import Optional, TypedDict, Union
|
4
4
|
|
5
5
|
from typing_extensions import NotRequired
|
6
6
|
|
@@ -137,8 +137,8 @@ class ModelMethod:
|
|
137
137
|
)
|
138
138
|
|
139
139
|
outputs: Union[
|
140
|
-
|
141
|
-
|
140
|
+
list[model_manifest_schema.ModelMethodSignatureField],
|
141
|
+
list[model_manifest_schema.ModelMethodSignatureFieldWithName],
|
142
142
|
]
|
143
143
|
if self.function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
|
144
144
|
outputs = [
|
@@ -3,10 +3,11 @@ import itertools
|
|
3
3
|
import os
|
4
4
|
import pathlib
|
5
5
|
import warnings
|
6
|
-
from typing import DefaultDict,
|
6
|
+
from typing import DefaultDict, Optional
|
7
7
|
|
8
8
|
from packaging import requirements, version
|
9
9
|
|
10
|
+
from snowflake.ml import version as snowml_version
|
10
11
|
from snowflake.ml._internal import env as snowml_env, env_utils
|
11
12
|
from snowflake.ml.model._packager.model_meta import model_meta_schema
|
12
13
|
|
@@ -19,9 +20,8 @@ _DEFAULT_CONDA_ENV_FILENAME = "conda.yml"
|
|
19
20
|
_DEFAULT_PIP_REQUIREMENTS_FILENAME = "requirements.txt"
|
20
21
|
|
21
22
|
# The default CUDA version is chosen based on the driver availability in SPCS.
|
22
|
-
#
|
23
|
-
|
24
|
-
DEFAULT_CUDA_VERSION = "11.8"
|
23
|
+
# Make sure they are aligned with default CUDA version in inference server.
|
24
|
+
DEFAULT_CUDA_VERSION = "12.4"
|
25
25
|
|
26
26
|
|
27
27
|
class ModelEnv:
|
@@ -29,22 +29,25 @@ class ModelEnv:
|
|
29
29
|
self,
|
30
30
|
conda_env_rel_path: Optional[str] = None,
|
31
31
|
pip_requirements_rel_path: Optional[str] = None,
|
32
|
+
prefer_pip: bool = False,
|
32
33
|
) -> None:
|
33
34
|
if conda_env_rel_path is None:
|
34
35
|
conda_env_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_CONDA_ENV_FILENAME)
|
35
36
|
if pip_requirements_rel_path is None:
|
36
37
|
pip_requirements_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_PIP_REQUIREMENTS_FILENAME)
|
38
|
+
self.prefer_pip: bool = prefer_pip
|
37
39
|
self.conda_env_rel_path = pathlib.PurePosixPath(pathlib.Path(conda_env_rel_path).as_posix())
|
38
40
|
self.pip_requirements_rel_path = pathlib.PurePosixPath(pathlib.Path(pip_requirements_rel_path).as_posix())
|
39
|
-
self.artifact_repository_map: Optional[
|
40
|
-
self.
|
41
|
-
self.
|
41
|
+
self.artifact_repository_map: Optional[dict[str, str]] = None
|
42
|
+
self.resource_constraint: Optional[dict[str, str]] = None
|
43
|
+
self._conda_dependencies: DefaultDict[str, list[requirements.Requirement]] = collections.defaultdict(list)
|
44
|
+
self._pip_requirements: list[requirements.Requirement] = []
|
42
45
|
self._python_version: version.Version = version.parse(snowml_env.PYTHON_VERSION)
|
43
46
|
self._cuda_version: Optional[version.Version] = None
|
44
|
-
self._snowpark_ml_version: version.Version = version.parse(
|
47
|
+
self._snowpark_ml_version: version.Version = version.parse(snowml_version.VERSION)
|
45
48
|
|
46
49
|
@property
|
47
|
-
def conda_dependencies(self) ->
|
50
|
+
def conda_dependencies(self) -> list[str]:
|
48
51
|
"""List of conda channel and dependencies from that to run the model"""
|
49
52
|
return sorted(
|
50
53
|
f"{chan}::{str(req)}" if chan else str(req)
|
@@ -55,24 +58,24 @@ class ModelEnv:
|
|
55
58
|
@conda_dependencies.setter
|
56
59
|
def conda_dependencies(
|
57
60
|
self,
|
58
|
-
conda_dependencies: Optional[
|
61
|
+
conda_dependencies: Optional[list[str]] = None,
|
59
62
|
) -> None:
|
60
63
|
self._conda_dependencies = env_utils.validate_conda_dependency_string_list(
|
61
|
-
conda_dependencies if conda_dependencies else []
|
64
|
+
conda_dependencies if conda_dependencies else [], add_local_version_specifier=True
|
62
65
|
)
|
63
66
|
|
64
67
|
@property
|
65
|
-
def pip_requirements(self) ->
|
68
|
+
def pip_requirements(self) -> list[str]:
|
66
69
|
"""List of pip Python packages requirements for running the model."""
|
67
70
|
return sorted(list(map(str, self._pip_requirements)))
|
68
71
|
|
69
72
|
@pip_requirements.setter
|
70
73
|
def pip_requirements(
|
71
74
|
self,
|
72
|
-
pip_requirements: Optional[
|
75
|
+
pip_requirements: Optional[list[str]] = None,
|
73
76
|
) -> None:
|
74
77
|
self._pip_requirements = env_utils.validate_pip_requirement_string_list(
|
75
|
-
pip_requirements if pip_requirements else []
|
78
|
+
pip_requirements if pip_requirements else [], add_local_version_specifier=True
|
76
79
|
)
|
77
80
|
|
78
81
|
@property
|
@@ -113,7 +116,11 @@ class ModelEnv:
|
|
113
116
|
if snowpark_ml_version:
|
114
117
|
self._snowpark_ml_version = version.parse(snowpark_ml_version)
|
115
118
|
|
116
|
-
def include_if_absent(
|
119
|
+
def include_if_absent(
|
120
|
+
self,
|
121
|
+
pkgs: list[ModelDependency],
|
122
|
+
check_local_version: bool = False,
|
123
|
+
) -> None:
|
117
124
|
"""Append requirements into model env if absent. Depending on the environment, requirements may be added
|
118
125
|
to either the pip requirements or conda dependencies.
|
119
126
|
|
@@ -121,8 +128,8 @@ class ModelEnv:
|
|
121
128
|
pkgs: A list of ModelDependency namedtuple to be appended.
|
122
129
|
check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
|
123
130
|
"""
|
124
|
-
if self.pip_requirements and not self.conda_dependencies and pkgs:
|
125
|
-
pip_pkg_reqs:
|
131
|
+
if (self.pip_requirements or self.prefer_pip) and not self.conda_dependencies and pkgs:
|
132
|
+
pip_pkg_reqs: list[str] = []
|
126
133
|
warnings.warn(
|
127
134
|
(
|
128
135
|
"Dependencies specified from pip requirements."
|
@@ -139,7 +146,7 @@ class ModelEnv:
|
|
139
146
|
else:
|
140
147
|
self._include_if_absent_conda(pkgs, check_local_version)
|
141
148
|
|
142
|
-
def _include_if_absent_conda(self, pkgs:
|
149
|
+
def _include_if_absent_conda(self, pkgs: list[ModelDependency], check_local_version: bool = False) -> None:
|
143
150
|
"""Append requirements into model env conda dependencies if absent.
|
144
151
|
|
145
152
|
Args:
|
@@ -184,7 +191,7 @@ class ModelEnv:
|
|
184
191
|
stacklevel=2,
|
185
192
|
)
|
186
193
|
|
187
|
-
def _include_if_absent_pip(self, pkgs:
|
194
|
+
def _include_if_absent_pip(self, pkgs: list[str], check_local_version: bool = False) -> None:
|
188
195
|
"""Append pip requirements into model env pip requirements if absent.
|
189
196
|
|
190
197
|
Args:
|
@@ -201,7 +208,7 @@ class ModelEnv:
|
|
201
208
|
except env_utils.DuplicateDependencyError:
|
202
209
|
pass
|
203
210
|
|
204
|
-
def remove_if_present_conda(self, conda_pkgs:
|
211
|
+
def remove_if_present_conda(self, conda_pkgs: list[str]) -> None:
|
205
212
|
"""Remove conda requirements from model env if present.
|
206
213
|
|
207
214
|
Args:
|
@@ -346,13 +353,14 @@ class ModelEnv:
|
|
346
353
|
def load_from_dict(self, base_dir: pathlib.Path, env_dict: model_meta_schema.ModelEnvDict) -> None:
|
347
354
|
self.conda_env_rel_path = pathlib.PurePosixPath(env_dict["conda"])
|
348
355
|
self.pip_requirements_rel_path = pathlib.PurePosixPath(env_dict["pip"])
|
349
|
-
self.artifact_repository_map = env_dict.get("artifact_repository_map"
|
356
|
+
self.artifact_repository_map = env_dict.get("artifact_repository_map")
|
357
|
+
self.resource_constraint = env_dict.get("resource_constraint")
|
350
358
|
|
351
359
|
self.load_from_conda_file(base_dir / self.conda_env_rel_path)
|
352
360
|
self.load_from_pip_file(base_dir / self.pip_requirements_rel_path)
|
353
361
|
|
354
362
|
self.python_version = env_dict["python_version"]
|
355
|
-
self.cuda_version = env_dict.get("cuda_version"
|
363
|
+
self.cuda_version = env_dict.get("cuda_version")
|
356
364
|
self.snowpark_ml_version = env_dict["snowpark_ml_version"]
|
357
365
|
|
358
366
|
def save_as_dict(
|
@@ -375,7 +383,8 @@ class ModelEnv:
|
|
375
383
|
return {
|
376
384
|
"conda": self.conda_env_rel_path.as_posix(),
|
377
385
|
"pip": self.pip_requirements_rel_path.as_posix(),
|
378
|
-
"artifact_repository_map": self.artifact_repository_map
|
386
|
+
"artifact_repository_map": self.artifact_repository_map or {},
|
387
|
+
"resource_constraint": self.resource_constraint or {},
|
379
388
|
"python_version": self.python_version,
|
380
389
|
"cuda_version": self.cuda_version,
|
381
390
|
"snowpark_ml_version": self.snowpark_ml_version,
|
@@ -383,7 +392,7 @@ class ModelEnv:
|
|
383
392
|
|
384
393
|
def validate_with_local_env(
|
385
394
|
self, check_snowpark_ml_version: bool = False
|
386
|
-
) ->
|
395
|
+
) -> list[env_utils.IncorrectLocalEnvironmentError]:
|
387
396
|
errors = []
|
388
397
|
try:
|
389
398
|
env_utils.validate_py_runtime_version(str(self._python_version))
|
@@ -407,10 +416,10 @@ class ModelEnv:
|
|
407
416
|
|
408
417
|
if check_snowpark_ml_version:
|
409
418
|
# For Modeling model
|
410
|
-
if self._snowpark_ml_version.base_version !=
|
419
|
+
if self._snowpark_ml_version.base_version != snowml_version.VERSION:
|
411
420
|
errors.append(
|
412
421
|
env_utils.IncorrectLocalEnvironmentError(
|
413
|
-
f"The local installed version of Snowpark ML library is {
|
422
|
+
f"The local installed version of Snowpark ML library is {snowml_version.VERSION} "
|
414
423
|
f"which differs from required version {self.snowpark_ml_version}."
|
415
424
|
)
|
416
425
|
)
|
@@ -2,13 +2,13 @@ import functools
|
|
2
2
|
import importlib
|
3
3
|
import pkgutil
|
4
4
|
from types import ModuleType
|
5
|
-
from typing import Any, Callable,
|
5
|
+
from typing import Any, Callable, Optional, TypeVar, cast
|
6
6
|
|
7
7
|
from snowflake.ml.model import type_hints as model_types
|
8
8
|
from snowflake.ml.model._packager.model_handlers import _base
|
9
9
|
|
10
10
|
_HANDLERS_BASE = "snowflake.ml.model._packager.model_handlers"
|
11
|
-
_MODEL_HANDLER_REGISTRY:
|
11
|
+
_MODEL_HANDLER_REGISTRY: dict[str, type[_base.BaseModelHandler[model_types.SupportedModelType]]] = dict()
|
12
12
|
_IS_HANDLER_LOADED = False
|
13
13
|
|
14
14
|
|
@@ -54,7 +54,7 @@ def ensure_handlers_registration(fn: F) -> F:
|
|
54
54
|
@ensure_handlers_registration
|
55
55
|
def find_handler(
|
56
56
|
model: model_types.SupportedModelType,
|
57
|
-
) -> Optional[
|
57
|
+
) -> Optional[type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
|
58
58
|
for handler in _MODEL_HANDLER_REGISTRY.values():
|
59
59
|
if handler.can_handle(model):
|
60
60
|
return handler
|
@@ -64,7 +64,7 @@ def find_handler(
|
|
64
64
|
@ensure_handlers_registration
|
65
65
|
def load_handler(
|
66
66
|
target_model_type: model_types.SupportedModelHandlerType,
|
67
|
-
) -> Optional[
|
67
|
+
) -> Optional[type[_base.BaseModelHandler[model_types.SupportedModelType]]]:
|
68
68
|
for model_type, handler in _MODEL_HANDLER_REGISTRY.items():
|
69
69
|
if target_model_type == model_type:
|
70
70
|
return handler
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
from abc import abstractmethod
|
3
|
-
from typing import
|
3
|
+
from typing import Generic, Optional, Protocol, final
|
4
4
|
|
5
5
|
import pandas as pd
|
6
6
|
from typing_extensions import TypeGuard, Unpack
|
@@ -14,7 +14,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
|
|
14
14
|
HANDLER_TYPE: model_types.SupportedModelHandlerType
|
15
15
|
HANDLER_VERSION: str
|
16
16
|
_MIN_SNOWPARK_ML_VERSION: str
|
17
|
-
_HANDLER_MIGRATOR_PLANS:
|
17
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]]
|
18
18
|
|
19
19
|
@classmethod
|
20
20
|
@abstractmethod
|
@@ -1,8 +1,9 @@
|
|
1
|
+
import importlib
|
1
2
|
import json
|
2
3
|
import os
|
3
4
|
import pathlib
|
4
5
|
import warnings
|
5
|
-
from typing import Any, Callable,
|
6
|
+
from typing import Any, Callable, Iterable, Optional, Sequence, cast
|
6
7
|
|
7
8
|
import numpy as np
|
8
9
|
import numpy.typing as npt
|
@@ -10,8 +11,10 @@ import pandas as pd
|
|
10
11
|
from absl import logging
|
11
12
|
|
12
13
|
import snowflake.snowpark.dataframe as sp_df
|
14
|
+
from snowflake.ml._internal import env
|
13
15
|
from snowflake.ml._internal.utils import identifier
|
14
16
|
from snowflake.ml.model import model_signature, type_hints as model_types
|
17
|
+
from snowflake.ml.model._packager.model_env import model_env
|
15
18
|
from snowflake.ml.model._packager.model_meta import model_meta
|
16
19
|
from snowflake.ml.model._signatures import (
|
17
20
|
core,
|
@@ -231,7 +234,7 @@ def validate_model_task(passed_model_task: model_types.Task, inferred_model_task
|
|
231
234
|
|
232
235
|
|
233
236
|
def get_explain_target_method(
|
234
|
-
model_metadata: model_meta.ModelMetadata, target_methods_list:
|
237
|
+
model_metadata: model_meta.ModelMetadata, target_methods_list: list[str]
|
235
238
|
) -> Optional[str]:
|
236
239
|
for method in model_metadata.signatures.keys():
|
237
240
|
if method in target_methods_list:
|
@@ -248,7 +251,7 @@ def save_transformers_config_with_auto_map(local_model_path: str) -> None:
|
|
248
251
|
config_dict = json.load(f)
|
249
252
|
|
250
253
|
# a. get repository and class_path from configs
|
251
|
-
auto_map_configs = cast(
|
254
|
+
auto_map_configs = cast(dict[str, str], config_dict.get("auto_map", {}))
|
252
255
|
for config_name, config_value in auto_map_configs.items():
|
253
256
|
repository, _, class_path = config_value.rpartition("--")
|
254
257
|
|
@@ -261,3 +264,12 @@ def save_transformers_config_with_auto_map(local_model_path: str) -> None:
|
|
261
264
|
|
262
265
|
with open(f_path, "w") as f:
|
263
266
|
json.dump(config_dict, f)
|
267
|
+
|
268
|
+
|
269
|
+
def get_default_cuda_version() -> str:
|
270
|
+
# Default to the env cuda version when running in ML runtime
|
271
|
+
if env.IN_ML_RUNTIME and importlib.util.find_spec("torch") is not None:
|
272
|
+
import torch
|
273
|
+
|
274
|
+
return torch.version.cuda or model_env.DEFAULT_CUDA_VERSION
|
275
|
+
return model_env.DEFAULT_CUDA_VERSION
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
import warnings
|
3
|
-
from typing import TYPE_CHECKING, Any, Callable,
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, cast, final
|
4
4
|
|
5
5
|
import numpy as np
|
6
6
|
import pandas as pd
|
@@ -30,7 +30,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
30
30
|
HANDLER_TYPE = "catboost"
|
31
31
|
HANDLER_VERSION = "2024-03-21"
|
32
32
|
_MIN_SNOWPARK_ML_VERSION = "1.3.1"
|
33
|
-
_HANDLER_MIGRATOR_PLANS:
|
33
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
34
34
|
|
35
35
|
MODEL_BLOB_FILE_OR_DIR = "model.bin"
|
36
36
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
@@ -147,7 +147,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
147
147
|
if enable_explainability:
|
148
148
|
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
149
149
|
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
150
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
150
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
151
151
|
|
152
152
|
return None
|
153
153
|
|
@@ -202,7 +202,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
202
202
|
def _create_custom_model(
|
203
203
|
raw_model: "catboost.CatBoost",
|
204
204
|
model_meta: model_meta_api.ModelMetadata,
|
205
|
-
) ->
|
205
|
+
) -> type[custom_model.CustomModel]:
|
206
206
|
def fn_factory(
|
207
207
|
raw_model: "catboost.CatBoost",
|
208
208
|
signature: model_signature.ModelSignature,
|
@@ -235,7 +235,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
235
235
|
|
236
236
|
return fn
|
237
237
|
|
238
|
-
type_method_dict:
|
238
|
+
type_method_dict: dict[str, Any] = {"_raw_model": raw_model}
|
239
239
|
for target_method_name, sig in model_meta.signatures.items():
|
240
240
|
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
241
241
|
|
@@ -2,7 +2,7 @@ import inspect
|
|
2
2
|
import os
|
3
3
|
import pathlib
|
4
4
|
import sys
|
5
|
-
from typing import
|
5
|
+
from typing import Optional, cast, final
|
6
6
|
|
7
7
|
import anyio
|
8
8
|
import cloudpickle
|
@@ -28,7 +28,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
28
28
|
HANDLER_TYPE = "custom"
|
29
29
|
HANDLER_VERSION = "2023-12-01"
|
30
30
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
31
|
-
_HANDLER_MIGRATOR_PLANS:
|
31
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
32
32
|
|
33
33
|
@classmethod
|
34
34
|
def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["custom_model.CustomModel"]:
|
@@ -99,7 +99,11 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
99
99
|
for sub_name, model_ref in model.context.model_refs.items():
|
100
100
|
handler = model_handler.find_handler(model_ref.model)
|
101
101
|
if handler is None:
|
102
|
-
raise TypeError(
|
102
|
+
raise TypeError(
|
103
|
+
f"Model {sub_name} in model context is not a supported model type. See "
|
104
|
+
"https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/"
|
105
|
+
"bring-your-own-model-types for more details."
|
106
|
+
)
|
103
107
|
sub_model = handler.cast_model(model_ref.model)
|
104
108
|
handler.save_model(
|
105
109
|
name=sub_name,
|
@@ -161,7 +165,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
161
165
|
name: str(pathlib.PurePath(model_blob_path) / pathlib.PurePosixPath(rel_path))
|
162
166
|
for name, rel_path in artifacts_meta.items()
|
163
167
|
}
|
164
|
-
models:
|
168
|
+
models: dict[str, model_types.SupportedModelType] = dict()
|
165
169
|
for sub_model_name, _ref in context.model_refs.items():
|
166
170
|
model_type = model_meta.models[sub_model_name].model_type
|
167
171
|
handler = model_handler.load_handler(model_type)
|
@@ -1,18 +1,7 @@
|
|
1
1
|
import json
|
2
2
|
import os
|
3
3
|
import warnings
|
4
|
-
from typing import
|
5
|
-
TYPE_CHECKING,
|
6
|
-
Any,
|
7
|
-
Callable,
|
8
|
-
Dict,
|
9
|
-
List,
|
10
|
-
Optional,
|
11
|
-
Type,
|
12
|
-
Union,
|
13
|
-
cast,
|
14
|
-
final,
|
15
|
-
)
|
4
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
|
16
5
|
|
17
6
|
import cloudpickle
|
18
7
|
import numpy as np
|
@@ -38,7 +27,7 @@ if TYPE_CHECKING:
|
|
38
27
|
import transformers
|
39
28
|
|
40
29
|
|
41
|
-
def get_requirements_from_task(task: str, spcs_only: bool = False) ->
|
30
|
+
def get_requirements_from_task(task: str, spcs_only: bool = False) -> list[model_env.ModelDependency]:
|
42
31
|
# Text
|
43
32
|
if task in [
|
44
33
|
"conversational",
|
@@ -84,7 +73,7 @@ class HuggingFacePipelineHandler(
|
|
84
73
|
HANDLER_TYPE = "huggingface_pipeline"
|
85
74
|
HANDLER_VERSION = "2023-12-01"
|
86
75
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
87
|
-
_HANDLER_MIGRATOR_PLANS:
|
76
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
88
77
|
|
89
78
|
MODEL_BLOB_FILE_OR_DIR = "model"
|
90
79
|
ADDITIONAL_CONFIG_FILE = "pipeline_config.pt"
|
@@ -250,20 +239,17 @@ class HuggingFacePipelineHandler(
|
|
250
239
|
task, spcs_only=(not type_utils.LazyType("transformers.Pipeline").isinstance(model))
|
251
240
|
)
|
252
241
|
if framework is None or framework == "pt":
|
253
|
-
# Since we set default cuda version to be 11.8, to make sure it works with GPU, we need to have a default
|
254
|
-
# Pytorch version that works with CUDA 11.8 as well. This is required for huggingface pipelines only as
|
255
|
-
# users are not required to install pytorch locally if they are using the wrapper.
|
256
242
|
pkgs_requirements.append(model_env.ModelDependency(requirement="pytorch", pip_name="torch"))
|
257
243
|
elif framework == "tf":
|
258
244
|
pkgs_requirements.append(model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"))
|
259
245
|
model_meta.env.include_if_absent(
|
260
246
|
pkgs_requirements, check_local_version=(type_utils.LazyType("transformers.Pipeline").isinstance(model))
|
261
247
|
)
|
262
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
248
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
263
249
|
|
264
250
|
@staticmethod
|
265
|
-
def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) ->
|
266
|
-
device_config:
|
251
|
+
def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) -> dict[str, str]:
|
252
|
+
device_config: dict[str, Any] = {}
|
267
253
|
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
268
254
|
gpu_nums = 0
|
269
255
|
if cuda_visible_devices is not None:
|
@@ -369,7 +355,7 @@ class HuggingFacePipelineHandler(
|
|
369
355
|
def _create_custom_model(
|
370
356
|
raw_model: "transformers.Pipeline",
|
371
357
|
model_meta: model_meta_api.ModelMetadata,
|
372
|
-
) ->
|
358
|
+
) -> type[custom_model.CustomModel]:
|
373
359
|
def fn_factory(
|
374
360
|
raw_model: "transformers.Pipeline",
|
375
361
|
signature: model_signature.ModelSignature,
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import os
|
2
|
-
from typing import TYPE_CHECKING, Callable,
|
2
|
+
from typing import TYPE_CHECKING, Callable, Optional, cast, final
|
3
3
|
|
4
4
|
import cloudpickle
|
5
5
|
import numpy as np
|
@@ -32,7 +32,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
|
|
32
32
|
HANDLER_TYPE = "keras"
|
33
33
|
HANDLER_VERSION = "2025-01-01"
|
34
34
|
_MIN_SNOWPARK_ML_VERSION = "1.7.5"
|
35
|
-
_HANDLER_MIGRATOR_PLANS:
|
35
|
+
_HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
|
36
36
|
|
37
37
|
MODEL_BLOB_FILE_OR_DIR = "model.keras"
|
38
38
|
CUSTOM_OBJECT_SAVE_PATH = "custom_objects.pkl"
|
@@ -146,7 +146,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
|
|
146
146
|
dependencies,
|
147
147
|
check_local_version=True,
|
148
148
|
)
|
149
|
-
model_meta.env.cuda_version = kwargs.get("cuda_version",
|
149
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
|
150
150
|
|
151
151
|
@classmethod
|
152
152
|
def load_model(
|
@@ -185,7 +185,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
|
|
185
185
|
def _create_custom_model(
|
186
186
|
raw_model: "keras.Model",
|
187
187
|
model_meta: model_meta_api.ModelMetadata,
|
188
|
-
) ->
|
188
|
+
) -> type[custom_model.CustomModel]:
|
189
189
|
def fn_factory(
|
190
190
|
raw_model: "keras.Model",
|
191
191
|
signature: model_signature.ModelSignature,
|