snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +64 -31
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +41 -5
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +40 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +12 -8
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +2 -4
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +86 -62
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +22 -36
- snowflake/ml/jobs/_utils/types.py +8 -2
- snowflake/ml/jobs/decorators.py +7 -8
- snowflake/ml/jobs/job.py +158 -26
- snowflake/ml/jobs/manager.py +78 -30
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +230 -50
- snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +22 -18
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +46 -25
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +35 -26
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +12 -8
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +5 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +50 -29
- snowflake/ml/registry/registry.py +34 -23
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,14 @@
|
|
1
|
-
from typing import Any
|
1
|
+
from typing import Any
|
2
2
|
|
3
3
|
from snowflake.ml.model._packager.model_meta import model_meta_schema
|
4
4
|
from snowflake.ml.model._packager.model_meta_migrator import base_migrator, migrator_v1
|
5
5
|
|
6
|
-
MODEL_META_MIGRATOR_PLANS:
|
6
|
+
MODEL_META_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelMetaMigrator]] = {
|
7
7
|
"1": migrator_v1.MetaMigrator_v1,
|
8
8
|
}
|
9
9
|
|
10
10
|
|
11
|
-
def migrate_metadata(loaded_meta:
|
11
|
+
def migrate_metadata(loaded_meta: dict[str, Any]) -> dict[str, Any]:
|
12
12
|
loaded_meta_version = str(loaded_meta.get("version", None))
|
13
13
|
while loaded_meta_version != model_meta_schema.MODEL_METADATA_VERSION:
|
14
14
|
if loaded_meta_version not in MODEL_META_MIGRATOR_PLANS.keys():
|
@@ -1,8 +1,8 @@
|
|
1
|
-
from typing import Any
|
1
|
+
from typing import Any
|
2
2
|
|
3
3
|
from packaging import requirements, version
|
4
4
|
|
5
|
-
from snowflake.ml
|
5
|
+
from snowflake.ml import version as snowml_version
|
6
6
|
from snowflake.ml.model._packager.model_meta_migrator import base_migrator
|
7
7
|
|
8
8
|
|
@@ -11,7 +11,7 @@ class MetaMigrator_v1(base_migrator.BaseModelMetaMigrator):
|
|
11
11
|
target_version = "2023-12-01"
|
12
12
|
|
13
13
|
@staticmethod
|
14
|
-
def upgrade(original_meta_dict:
|
14
|
+
def upgrade(original_meta_dict: dict[str, Any]) -> dict[str, Any]:
|
15
15
|
loaded_python_version = version.parse(original_meta_dict["python_version"])
|
16
16
|
if original_meta_dict.get("local_ml_library_version", None):
|
17
17
|
loaded_lib_version = str(version.parse(original_meta_dict["local_ml_library_version"]))
|
@@ -24,7 +24,7 @@ class MetaMigrator_v1(base_migrator.BaseModelMetaMigrator):
|
|
24
24
|
None,
|
25
25
|
)
|
26
26
|
if lib_spec_str is None:
|
27
|
-
loaded_lib_version =
|
27
|
+
loaded_lib_version = snowml_version.VERSION
|
28
28
|
loaded_lib_version = list(requirements.Requirement(str(lib_spec_str)).specifier)[0].version
|
29
29
|
|
30
30
|
return dict(
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
from types import ModuleType
|
3
|
-
from typing import
|
3
|
+
from typing import Optional
|
4
4
|
|
5
5
|
from absl import logging
|
6
6
|
|
@@ -38,15 +38,17 @@ class ModelPackager:
|
|
38
38
|
*,
|
39
39
|
name: str,
|
40
40
|
model: model_types.SupportedModelType,
|
41
|
-
signatures: Optional[
|
41
|
+
signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
|
42
42
|
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
43
|
-
metadata: Optional[
|
44
|
-
conda_dependencies: Optional[
|
45
|
-
pip_requirements: Optional[
|
46
|
-
artifact_repository_map: Optional[
|
43
|
+
metadata: Optional[dict[str, str]] = None,
|
44
|
+
conda_dependencies: Optional[list[str]] = None,
|
45
|
+
pip_requirements: Optional[list[str]] = None,
|
46
|
+
artifact_repository_map: Optional[dict[str, str]] = None,
|
47
|
+
resource_constraint: Optional[dict[str, str]] = None,
|
48
|
+
target_platforms: Optional[list[model_types.TargetPlatform]] = None,
|
47
49
|
python_version: Optional[str] = None,
|
48
|
-
ext_modules: Optional[
|
49
|
-
code_paths: Optional[
|
50
|
+
ext_modules: Optional[list[ModuleType]] = None,
|
51
|
+
code_paths: Optional[list[str]] = None,
|
50
52
|
options: model_types.ModelSaveOption,
|
51
53
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
52
54
|
) -> model_meta.ModelMetadata:
|
@@ -75,8 +77,10 @@ class ModelPackager:
|
|
75
77
|
conda_dependencies=conda_dependencies,
|
76
78
|
pip_requirements=pip_requirements,
|
77
79
|
artifact_repository_map=artifact_repository_map,
|
80
|
+
resource_constraint=resource_constraint,
|
78
81
|
python_version=python_version,
|
79
82
|
task=task,
|
83
|
+
target_platforms=target_platforms,
|
80
84
|
**options,
|
81
85
|
) as meta:
|
82
86
|
model_blobs_path = os.path.join(self.local_dir_path, ModelPackager.MODEL_BLOBS_DIR)
|
@@ -1 +1,32 @@
|
|
1
|
-
|
1
|
+
# DO NOT EDIT!
|
2
|
+
# Generate by running 'bazel run --config=pre_build //bazel/requirements:sync_requirements'
|
3
|
+
|
4
|
+
REQUIREMENTS = [
|
5
|
+
"absl-py>=0.15,<2",
|
6
|
+
"aiohttp!=4.0.0a0, !=4.0.0a1",
|
7
|
+
"anyio>=3.5.0,<5",
|
8
|
+
"cachetools>=3.1.1,<6",
|
9
|
+
"cloudpickle>=2.0.0,<3",
|
10
|
+
"cryptography",
|
11
|
+
"fsspec>=2024.6.1,<2026",
|
12
|
+
"importlib_resources>=6.1.1, <7",
|
13
|
+
"numpy>=1.23,<2",
|
14
|
+
"packaging>=20.9,<25",
|
15
|
+
"pandas>=1.0.0,<3",
|
16
|
+
"pyarrow",
|
17
|
+
"pydantic>=2.8.2, <3",
|
18
|
+
"pyjwt>=2.0.0, <3",
|
19
|
+
"pytimeparse>=1.1.8,<2",
|
20
|
+
"pyyaml>=6.0,<7",
|
21
|
+
"requests",
|
22
|
+
"retrying>=1.3.3,<2",
|
23
|
+
"s3fs>=2024.6.1,<2026",
|
24
|
+
"scikit-learn>=1.4,<1.6",
|
25
|
+
"scipy>=1.9,<2",
|
26
|
+
"snowflake-connector-python>=3.12.0,<4",
|
27
|
+
"snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
|
28
|
+
"snowflake.core>=1.0.2,<2",
|
29
|
+
"sqlparse>=0.4,<1",
|
30
|
+
"typing-extensions>=4.1.0,<5",
|
31
|
+
"xgboost>=1.7.3,<3",
|
32
|
+
]
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import copy
|
2
2
|
import pathlib
|
3
3
|
import warnings
|
4
|
-
from typing import
|
4
|
+
from typing import Optional
|
5
5
|
|
6
6
|
from packaging import requirements
|
7
7
|
|
@@ -37,7 +37,7 @@ class ModelRuntime:
|
|
37
37
|
self,
|
38
38
|
name: str,
|
39
39
|
env: model_env.ModelEnv,
|
40
|
-
imports: Optional[
|
40
|
+
imports: Optional[list[str]] = None,
|
41
41
|
is_warehouse: bool = False,
|
42
42
|
is_gpu: bool = False,
|
43
43
|
loading_from_file: bool = False,
|
@@ -102,6 +102,7 @@ class ModelRuntime:
|
|
102
102
|
if env_dict.get("artifact_repository_map") is not None
|
103
103
|
else {},
|
104
104
|
),
|
105
|
+
resource_constraint=env_dict["resource_constraint"],
|
105
106
|
)
|
106
107
|
|
107
108
|
@staticmethod
|
@@ -116,6 +117,7 @@ class ModelRuntime:
|
|
116
117
|
env.cuda_version = meta_env.cuda_version
|
117
118
|
env.snowpark_ml_version = meta_env.snowpark_ml_version
|
118
119
|
env.artifact_repository_map = meta_env.artifact_repository_map
|
120
|
+
env.resource_constraint = meta_env.resource_constraint
|
119
121
|
|
120
122
|
conda_env_rel_path = pathlib.PurePosixPath(loaded_dict["dependencies"]["conda"])
|
121
123
|
pip_requirements_rel_path = pathlib.PurePosixPath(loaded_dict["dependencies"]["pip"])
|
@@ -6,12 +6,8 @@ from typing import (
|
|
6
6
|
TYPE_CHECKING,
|
7
7
|
Any,
|
8
8
|
Callable,
|
9
|
-
Dict,
|
10
|
-
List,
|
11
9
|
Optional,
|
12
10
|
Sequence,
|
13
|
-
Tuple,
|
14
|
-
Type,
|
15
11
|
Union,
|
16
12
|
final,
|
17
13
|
get_args,
|
@@ -48,7 +44,7 @@ PandasExtensionTypes = Union[
|
|
48
44
|
|
49
45
|
|
50
46
|
class DataType(Enum):
|
51
|
-
def __init__(self, value: str, snowpark_type:
|
47
|
+
def __init__(self, value: str, snowpark_type: type[spt.DataType], numpy_type: npt.DTypeLike) -> None:
|
52
48
|
self._value = value
|
53
49
|
self._snowpark_type = snowpark_type
|
54
50
|
self._numpy_type = numpy_type
|
@@ -159,7 +155,7 @@ class DataType(Enum):
|
|
159
155
|
else:
|
160
156
|
actual_sp_type = snowpark_type
|
161
157
|
|
162
|
-
snowpark_to_snowml_type_mapping:
|
158
|
+
snowpark_to_snowml_type_mapping: dict[type[spt.DataType], DataType] = {
|
163
159
|
i._snowpark_type: i
|
164
160
|
for i in DataType
|
165
161
|
# We by default infer as signed integer.
|
@@ -199,7 +195,7 @@ class DataType(Enum):
|
|
199
195
|
class BaseFeatureSpec(ABC):
|
200
196
|
"""Abstract Class for specification of a feature."""
|
201
197
|
|
202
|
-
def __init__(self, name: str, shape: Optional[
|
198
|
+
def __init__(self, name: str, shape: Optional[tuple[int, ...]]) -> None:
|
203
199
|
self._name = name
|
204
200
|
|
205
201
|
if shape and not isinstance(shape, tuple):
|
@@ -218,23 +214,19 @@ class BaseFeatureSpec(ABC):
|
|
218
214
|
@abstractmethod
|
219
215
|
def as_snowpark_type(self) -> spt.DataType:
|
220
216
|
"""Convert to corresponding Snowpark Type."""
|
221
|
-
pass
|
222
217
|
|
223
218
|
@abstractmethod
|
224
219
|
def as_dtype(self, force_numpy_dtype: bool = False) -> Union[npt.DTypeLike, str, PandasExtensionTypes]:
|
225
220
|
"""Convert to corresponding local Type."""
|
226
|
-
pass
|
227
221
|
|
228
222
|
@abstractmethod
|
229
|
-
def to_dict(self) ->
|
223
|
+
def to_dict(self) -> dict[str, Any]:
|
230
224
|
"""Serialization"""
|
231
|
-
pass
|
232
225
|
|
233
226
|
@classmethod
|
234
227
|
@abstractmethod
|
235
|
-
def from_dict(self, input_dict:
|
228
|
+
def from_dict(self, input_dict: dict[str, Any]) -> "BaseFeatureSpec":
|
236
229
|
"""Deserialization"""
|
237
|
-
pass
|
238
230
|
|
239
231
|
|
240
232
|
class FeatureSpec(BaseFeatureSpec):
|
@@ -244,7 +236,7 @@ class FeatureSpec(BaseFeatureSpec):
|
|
244
236
|
self,
|
245
237
|
name: str,
|
246
238
|
dtype: DataType,
|
247
|
-
shape: Optional[
|
239
|
+
shape: Optional[tuple[int, ...]] = None,
|
248
240
|
nullable: bool = True,
|
249
241
|
) -> None:
|
250
242
|
"""
|
@@ -330,19 +322,19 @@ class FeatureSpec(BaseFeatureSpec):
|
|
330
322
|
f"name={repr(self._name)}{shape_str}, nullable={repr(self._nullable)})"
|
331
323
|
)
|
332
324
|
|
333
|
-
def to_dict(self) ->
|
325
|
+
def to_dict(self) -> dict[str, Any]:
|
334
326
|
"""Serialize the feature group into a dict.
|
335
327
|
|
336
328
|
Returns:
|
337
329
|
A dict that serializes the feature group.
|
338
330
|
"""
|
339
|
-
base_dict:
|
331
|
+
base_dict: dict[str, Any] = {"type": self._dtype.name, "name": self._name, "nullable": self._nullable}
|
340
332
|
if self._shape is not None:
|
341
333
|
base_dict["shape"] = self._shape
|
342
334
|
return base_dict
|
343
335
|
|
344
336
|
@classmethod
|
345
|
-
def from_dict(cls, input_dict:
|
337
|
+
def from_dict(cls, input_dict: dict[str, Any]) -> "FeatureSpec":
|
346
338
|
"""Deserialize the feature specification from a dict.
|
347
339
|
|
348
340
|
Args:
|
@@ -391,7 +383,7 @@ class FeatureSpec(BaseFeatureSpec):
|
|
391
383
|
class FeatureGroupSpec(BaseFeatureSpec):
|
392
384
|
"""Specification of a group of features in Snowflake native model packaging."""
|
393
385
|
|
394
|
-
def __init__(self, name: str, specs:
|
386
|
+
def __init__(self, name: str, specs: list[BaseFeatureSpec], shape: Optional[tuple[int, ...]] = None) -> None:
|
395
387
|
"""Initialize a feature group.
|
396
388
|
|
397
389
|
Args:
|
@@ -458,19 +450,19 @@ class FeatureGroupSpec(BaseFeatureSpec):
|
|
458
450
|
def as_dtype(self, force_numpy_dtype: bool = False) -> Union[npt.DTypeLike, str, PandasExtensionTypes]:
|
459
451
|
return np.object_
|
460
452
|
|
461
|
-
def to_dict(self) ->
|
453
|
+
def to_dict(self) -> dict[str, Any]:
|
462
454
|
"""Serialize the feature group into a dict.
|
463
455
|
|
464
456
|
Returns:
|
465
457
|
A dict that serializes the feature group.
|
466
458
|
"""
|
467
|
-
base_dict:
|
459
|
+
base_dict: dict[str, Any] = {"name": self._name, "specs": [s.to_dict() for s in self._specs]}
|
468
460
|
if self._shape is not None:
|
469
461
|
base_dict["shape"] = self._shape
|
470
462
|
return base_dict
|
471
463
|
|
472
464
|
@classmethod
|
473
|
-
def from_dict(cls, input_dict:
|
465
|
+
def from_dict(cls, input_dict: dict[str, Any]) -> "FeatureGroupSpec":
|
474
466
|
"""Deserialize the feature group from a dict.
|
475
467
|
|
476
468
|
Args:
|
@@ -520,7 +512,7 @@ class ModelSignature:
|
|
520
512
|
else:
|
521
513
|
return False
|
522
514
|
|
523
|
-
def to_dict(self) ->
|
515
|
+
def to_dict(self) -> dict[str, Any]:
|
524
516
|
"""Generate a dict to represent the whole signature.
|
525
517
|
|
526
518
|
Returns:
|
@@ -533,7 +525,7 @@ class ModelSignature:
|
|
533
525
|
}
|
534
526
|
|
535
527
|
@classmethod
|
536
|
-
def from_dict(cls, loaded:
|
528
|
+
def from_dict(cls, loaded: dict[str, Any]) -> "ModelSignature":
|
537
529
|
"""Create a signature given the dict containing specifications of children features and feature groups.
|
538
530
|
|
539
531
|
Args:
|
@@ -545,7 +537,7 @@ class ModelSignature:
|
|
545
537
|
sig_outs = loaded["outputs"]
|
546
538
|
sig_inputs = loaded["inputs"]
|
547
539
|
|
548
|
-
deserialize_spec: Callable[[
|
540
|
+
deserialize_spec: Callable[[dict[str, Any]], BaseFeatureSpec] = lambda sig_spec: (
|
549
541
|
FeatureGroupSpec.from_dict(sig_spec) if "specs" in sig_spec else FeatureSpec.from_dict(sig_spec)
|
550
542
|
)
|
551
543
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import TYPE_CHECKING,
|
1
|
+
from typing import TYPE_CHECKING, Literal, Optional, Sequence
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
import pandas as pd
|
@@ -45,7 +45,7 @@ class XGBoostDMatrixHandler(base_handler.BaseDataHandler["xgboost.DMatrix"]):
|
|
45
45
|
@staticmethod
|
46
46
|
def infer_signature(data: "xgboost.DMatrix", role: Literal["input", "output"]) -> Sequence[core.BaseFeatureSpec]:
|
47
47
|
feature_prefix = f"{XGBoostDMatrixHandler.FEATURE_PREFIX}_"
|
48
|
-
features:
|
48
|
+
features: list[core.BaseFeatureSpec] = []
|
49
49
|
role_prefix = (
|
50
50
|
XGBoostDMatrixHandler.INPUT_PREFIX if role == "input" else XGBoostDMatrixHandler.OUTPUT_PREFIX
|
51
51
|
) + "_"
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import warnings
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Optional, Sequence
|
3
3
|
|
4
4
|
import numpy as np
|
5
5
|
import numpy.typing as npt
|
@@ -12,7 +12,7 @@ from snowflake.ml._internal.exceptions import (
|
|
12
12
|
from snowflake.ml.model._signatures import core
|
13
13
|
|
14
14
|
|
15
|
-
def convert_list_to_ndarray(data: List[Any]) -> npt.NDArray[Any]:
|
15
|
+
def convert_list_to_ndarray(data: list[Any]) -> npt.NDArray[Any]:
|
16
16
|
"""Create a numpy array from list or nested list. Avoid ragged list and unaligned types.
|
17
17
|
|
18
18
|
Args:
|
@@ -49,7 +49,7 @@ def convert_list_to_ndarray(data: List[Any]) -> npt.NDArray[Any]:
|
|
49
49
|
|
50
50
|
|
51
51
|
def rename_features(
|
52
|
-
features: Sequence[core.BaseFeatureSpec], feature_names: Optional[
|
52
|
+
features: Sequence[core.BaseFeatureSpec], feature_names: Optional[list[str]] = None
|
53
53
|
) -> Sequence[core.BaseFeatureSpec]:
|
54
54
|
"""It renames the feature in features provided optional feature names.
|
55
55
|
|
@@ -104,7 +104,7 @@ def rename_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureSpec
|
|
104
104
|
return data
|
105
105
|
|
106
106
|
|
107
|
-
def huggingface_pipeline_signature_auto_infer(task: str, params:
|
107
|
+
def huggingface_pipeline_signature_auto_infer(task: str, params: dict[str, Any]) -> Optional[core.ModelSignature]:
|
108
108
|
# Text
|
109
109
|
|
110
110
|
# https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ConversationalPipeline
|
@@ -351,7 +351,7 @@ def series_dropna(series: pd.Series) -> pd.Series:
|
|
351
351
|
return series.dropna(inplace=False).reset_index(drop=True).convert_dtypes()
|
352
352
|
|
353
353
|
|
354
|
-
def infer_list(name: str, data: List[Any]) -> core.BaseFeatureSpec:
|
354
|
+
def infer_list(name: str, data: list[Any]) -> core.BaseFeatureSpec:
|
355
355
|
"""Infer the feature specification from a list.
|
356
356
|
|
357
357
|
Args:
|
@@ -382,7 +382,7 @@ def infer_list(name: str, data: List[Any]) -> core.BaseFeatureSpec:
|
|
382
382
|
return core.FeatureSpec(name=name, dtype=arr_dtype, shape=arr.shape)
|
383
383
|
|
384
384
|
|
385
|
-
def infer_dict(name: str, data:
|
385
|
+
def infer_dict(name: str, data: dict[str, Any]) -> core.FeatureGroupSpec:
|
386
386
|
"""Infer the feature specification from a dictionary.
|
387
387
|
|
388
388
|
Args:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import functools
|
2
2
|
import inspect
|
3
|
-
from typing import Any, Callable, Coroutine,
|
3
|
+
from typing import Any, Callable, Coroutine, Generator, Optional, Union
|
4
4
|
|
5
5
|
import anyio
|
6
6
|
import pandas as pd
|
@@ -78,7 +78,7 @@ class ModelRef:
|
|
78
78
|
return MethodRef(self, method_name)
|
79
79
|
raise AttributeError(f"Method {method_name} not found in model {self._name}.")
|
80
80
|
|
81
|
-
def __getstate__(self) ->
|
81
|
+
def __getstate__(self) -> dict[str, Any]:
|
82
82
|
state = self.__dict__.copy()
|
83
83
|
del state["_model"]
|
84
84
|
return state
|
@@ -113,8 +113,8 @@ class ModelContext:
|
|
113
113
|
def __init__(
|
114
114
|
self,
|
115
115
|
*,
|
116
|
-
artifacts: Optional[Union[
|
117
|
-
models: Optional[Union[
|
116
|
+
artifacts: Optional[Union[dict[str, str], str, model_types.SupportedModelType]] = None,
|
117
|
+
models: Optional[Union[dict[str, model_types.SupportedModelType], str, model_types.SupportedModelType]] = None,
|
118
118
|
**kwargs: Optional[Union[str, model_types.SupportedModelType]],
|
119
119
|
) -> None:
|
120
120
|
"""Initialize the model context.
|
@@ -130,8 +130,8 @@ class ModelContext:
|
|
130
130
|
ValueError: Raised when the model name is duplicated.
|
131
131
|
"""
|
132
132
|
|
133
|
-
self.artifacts:
|
134
|
-
self.model_refs:
|
133
|
+
self.artifacts: dict[str, str] = dict()
|
134
|
+
self.model_refs: dict[str, ModelRef] = dict()
|
135
135
|
|
136
136
|
# In case that artifacts is a dictionary, assume the original usage,
|
137
137
|
# which is to pass in a dictionary of artifacts.
|
@@ -185,7 +185,7 @@ class ModelContext:
|
|
185
185
|
return self.model_refs[name]
|
186
186
|
|
187
187
|
def __getitem__(self, key: str) -> Union[str, ModelRef]:
|
188
|
-
combined:
|
188
|
+
combined: dict[str, Union[str, ModelRef]] = {**self.artifacts, **self.model_refs}
|
189
189
|
if key not in combined:
|
190
190
|
raise KeyError(f"Key {key} not found in the kwargs, current available keys are: {combined.keys()}")
|
191
191
|
return combined[key]
|
@@ -226,7 +226,7 @@ class CustomModel:
|
|
226
226
|
else:
|
227
227
|
raise TypeError("A non-method inference API function is not supported.")
|
228
228
|
|
229
|
-
def _get_partitioned_infer_methods(self) ->
|
229
|
+
def _get_partitioned_infer_methods(self) -> list[str]:
|
230
230
|
"""Returns all methods in CLS with `partitioned_inference_api` as the outermost decorator."""
|
231
231
|
rv = []
|
232
232
|
for cls_method_str in dir(self):
|
@@ -1,18 +1,7 @@
|
|
1
1
|
import enum
|
2
2
|
import json
|
3
3
|
import warnings
|
4
|
-
from typing import
|
5
|
-
Any,
|
6
|
-
Dict,
|
7
|
-
List,
|
8
|
-
Literal,
|
9
|
-
Optional,
|
10
|
-
Sequence,
|
11
|
-
Tuple,
|
12
|
-
Type,
|
13
|
-
Union,
|
14
|
-
cast,
|
15
|
-
)
|
4
|
+
from typing import Any, Literal, Optional, Sequence, Union, cast
|
16
5
|
|
17
6
|
import numpy as np
|
18
7
|
import pandas as pd
|
@@ -30,7 +19,7 @@ from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
|
|
30
19
|
from snowflake.ml.model import type_hints as model_types
|
31
20
|
from snowflake.ml.model._signatures import (
|
32
21
|
base_handler,
|
33
|
-
builtins_handler
|
22
|
+
builtins_handler,
|
34
23
|
core,
|
35
24
|
dmatrix_handler,
|
36
25
|
numpy_handler,
|
@@ -48,7 +37,7 @@ FeatureGroupSpec = core.FeatureGroupSpec
|
|
48
37
|
ModelSignature = core.ModelSignature
|
49
38
|
|
50
39
|
|
51
|
-
_LOCAL_DATA_HANDLERS:
|
40
|
+
_LOCAL_DATA_HANDLERS: list[type[base_handler.BaseDataHandler[Any]]] = [
|
52
41
|
pandas_handler.PandasDataFrameHandler,
|
53
42
|
numpy_handler.NumpyArrayHandler,
|
54
43
|
builtins_handler.ListOfBuiltinHandler,
|
@@ -414,7 +403,7 @@ class SnowparkIdentifierRule(enum.Enum):
|
|
414
403
|
|
415
404
|
def _get_dataframe_values_range(
|
416
405
|
df: snowflake.snowpark.DataFrame,
|
417
|
-
) ->
|
406
|
+
) -> dict[str, Union[tuple[int, int], tuple[float, float]]]:
|
418
407
|
columns = [
|
419
408
|
F.array_construct(F.min(field.name), F.max(field.name)).as_(field.name)
|
420
409
|
for field in df.schema.fields
|
@@ -429,7 +418,7 @@ def _get_dataframe_values_range(
|
|
429
418
|
original_exception=ValueError(f"Unable to get the value range of fields {df.columns}"),
|
430
419
|
)
|
431
420
|
return cast(
|
432
|
-
|
421
|
+
dict[str, Union[tuple[int, int], tuple[float, float]]],
|
433
422
|
{
|
434
423
|
sql_identifier.SqlIdentifier(k, case_sensitive=True).identifier(): (json.loads(v)[0], json.loads(v)[1])
|
435
424
|
for k, v in res[0].as_dict().items()
|
@@ -456,7 +445,7 @@ def _validate_snowpark_data(
|
|
456
445
|
- inferred: signature `a` - Snowpark DF `"a"`, use `get_inferred_name`
|
457
446
|
- normalized: signature `a` - Snowpark DF `A`, use `resolve_identifier`
|
458
447
|
"""
|
459
|
-
errors:
|
448
|
+
errors: dict[SnowparkIdentifierRule, list[Exception]] = {
|
460
449
|
SnowparkIdentifierRule.INFERRED: [],
|
461
450
|
SnowparkIdentifierRule.NORMALIZED: [],
|
462
451
|
}
|
@@ -549,7 +538,7 @@ def _validate_snowpark_type_feature(
|
|
549
538
|
field: spt.StructField,
|
550
539
|
ft_type: DataType,
|
551
540
|
ft_name: str,
|
552
|
-
value_range: Optional[Union[
|
541
|
+
value_range: Optional[Union[tuple[int, int], tuple[float, float]]],
|
553
542
|
strict: bool = False,
|
554
543
|
) -> None:
|
555
544
|
field_data_type = field.datatype
|
@@ -716,8 +705,8 @@ def _convert_and_validate_local_data(
|
|
716
705
|
def infer_signature(
|
717
706
|
input_data: model_types.SupportedLocalDataType,
|
718
707
|
output_data: model_types.SupportedLocalDataType,
|
719
|
-
input_feature_names: Optional[
|
720
|
-
output_feature_names: Optional[
|
708
|
+
input_feature_names: Optional[list[str]] = None,
|
709
|
+
output_feature_names: Optional[list[str]] = None,
|
721
710
|
input_data_limit: Optional[int] = 100,
|
722
711
|
output_data_limit: Optional[int] = 100,
|
723
712
|
) -> core.ModelSignature:
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import warnings
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Optional
|
3
3
|
|
4
4
|
from packaging import version
|
5
5
|
|
@@ -13,7 +13,7 @@ class HuggingFacePipelineModel:
|
|
13
13
|
revision: Optional[str] = None,
|
14
14
|
token: Optional[str] = None,
|
15
15
|
trust_remote_code: Optional[bool] = None,
|
16
|
-
model_kwargs: Optional[
|
16
|
+
model_kwargs: Optional[dict[str, Any]] = None,
|
17
17
|
**kwargs: Any,
|
18
18
|
) -> None:
|
19
19
|
"""
|
@@ -65,6 +65,7 @@ class HuggingFacePipelineModel:
|
|
65
65
|
warnings.warn(
|
66
66
|
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.",
|
67
67
|
FutureWarning,
|
68
|
+
stacklevel=2,
|
68
69
|
)
|
69
70
|
if token is not None:
|
70
71
|
raise ValueError(
|
@@ -183,7 +184,8 @@ class HuggingFacePipelineModel:
|
|
183
184
|
warnings.warn(
|
184
185
|
f"No model was supplied, defaulted to {model} and revision"
|
185
186
|
f" {revision} ({transformers.pipelines.HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model}).\n"
|
186
|
-
"Using a pipeline without specifying a model name and revision in production is not recommended."
|
187
|
+
"Using a pipeline without specifying a model name and revision in production is not recommended.",
|
188
|
+
stacklevel=2,
|
187
189
|
)
|
188
190
|
if config is None and isinstance(model, str):
|
189
191
|
config_obj = transformers.AutoConfig.from_pretrained(
|
@@ -200,7 +202,8 @@ class HuggingFacePipelineModel:
|
|
200
202
|
if kwargs.get("device", None) is not None:
|
201
203
|
warnings.warn(
|
202
204
|
"Both `device` and `device_map` are specified. `device` will override `device_map`. You"
|
203
|
-
" will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`."
|
205
|
+
" will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`.",
|
206
|
+
stacklevel=2,
|
204
207
|
)
|
205
208
|
|
206
209
|
# ==== End pipeline logic from transformers ====
|
snowflake/ml/model/type_hints.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
# mypy: disable-error-code="import"
|
2
2
|
from enum import Enum
|
3
|
-
from typing import TYPE_CHECKING,
|
3
|
+
from typing import TYPE_CHECKING, Literal, Sequence, TypedDict, TypeVar, Union
|
4
4
|
|
5
5
|
import numpy.typing as npt
|
6
6
|
from typing_extensions import NotRequired
|
@@ -32,7 +32,7 @@ _SupportedBuiltins = Union[
|
|
32
32
|
bool,
|
33
33
|
str,
|
34
34
|
bytes,
|
35
|
-
|
35
|
+
dict[str, Union["_SupportedBuiltins", "_SupportedBuiltinsList"]],
|
36
36
|
"_SupportedBuiltinsList",
|
37
37
|
]
|
38
38
|
_SupportedNumpyDtype = Union[
|
@@ -147,13 +147,15 @@ class BaseModelSaveOption(TypedDict):
|
|
147
147
|
embed_local_ml_library: Embedding local SnowML into the code directory of the folder.
|
148
148
|
relax_version: Whether or not relax the version constraints of the dependencies if unresolvable in Warehouse.
|
149
149
|
It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
|
150
|
+
save_location: Local directory path to save the model and metadata.
|
150
151
|
"""
|
151
152
|
|
152
153
|
embed_local_ml_library: NotRequired[bool]
|
153
154
|
relax_version: NotRequired[bool]
|
154
155
|
function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
|
155
|
-
method_options: NotRequired[
|
156
|
+
method_options: NotRequired[dict[str, ModelMethodSaveOptions]]
|
156
157
|
enable_explainability: NotRequired[bool]
|
158
|
+
save_location: NotRequired[str]
|
157
159
|
|
158
160
|
|
159
161
|
class CatBoostModelSaveOptions(BaseModelSaveOption):
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import inspect
|
2
2
|
import numbers
|
3
3
|
import os
|
4
|
-
from typing import Any, Callable
|
4
|
+
from typing import Any, Callable
|
5
5
|
|
6
6
|
import cloudpickle as cp
|
7
7
|
import numpy as np
|
@@ -16,7 +16,7 @@ from snowflake.snowpark import Session
|
|
16
16
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
17
17
|
|
18
18
|
|
19
|
-
def validate_sklearn_args(args:
|
19
|
+
def validate_sklearn_args(args: dict[str, tuple[Any, Any, bool]], klass: type) -> dict[str, Any]:
|
20
20
|
"""Validate if all the keyword args are supported by current version of SKLearn/XGBoost object.
|
21
21
|
|
22
22
|
Args:
|
@@ -71,7 +71,7 @@ def transform_snowml_obj_to_sklearn_obj(obj: Any) -> Any:
|
|
71
71
|
return obj
|
72
72
|
|
73
73
|
|
74
|
-
def gather_dependencies(obj: Any) ->
|
74
|
+
def gather_dependencies(obj: Any) -> set[str]:
|
75
75
|
"""Gathers dependencies from the SnowML Estimator and Transformer objects.
|
76
76
|
|
77
77
|
Args:
|
@@ -82,7 +82,7 @@ def gather_dependencies(obj: Any) -> Set[str]:
|
|
82
82
|
"""
|
83
83
|
|
84
84
|
if isinstance(obj, list) or isinstance(obj, tuple):
|
85
|
-
deps:
|
85
|
+
deps: set[str] = set()
|
86
86
|
for elem in obj:
|
87
87
|
deps = deps | set(gather_dependencies(elem))
|
88
88
|
return deps
|
@@ -167,8 +167,8 @@ def get_module_name(model: object) -> str:
|
|
167
167
|
|
168
168
|
|
169
169
|
def handle_inference_result(
|
170
|
-
inference_res: Any, output_cols:
|
171
|
-
) ->
|
170
|
+
inference_res: Any, output_cols: list[str], inference_method: str, within_udf: bool = False
|
171
|
+
) -> tuple[npt.NDArray[Any], list[str]]:
|
172
172
|
if isinstance(inference_res, list) and len(inference_res) > 0 and isinstance(inference_res[0], np.ndarray):
|
173
173
|
# In case of multioutput estimators, predict_proba, decision_function etc., functions return a list of
|
174
174
|
# ndarrays. We need to concatenate them.
|
@@ -248,7 +248,7 @@ def create_temp_stage(session: Session) -> str:
|
|
248
248
|
|
249
249
|
|
250
250
|
def upload_model_to_stage(
|
251
|
-
stage_name: str, estimator: object, session: Session, statement_params:
|
251
|
+
stage_name: str, estimator: object, session: Session, statement_params: dict[str, str]
|
252
252
|
) -> str:
|
253
253
|
"""Util method to pickle and upload the model to a temp Snowflake stage.
|
254
254
|
|