snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +64 -31
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +41 -5
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +40 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +12 -8
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +2 -4
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +86 -62
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +22 -36
- snowflake/ml/jobs/_utils/types.py +8 -2
- snowflake/ml/jobs/decorators.py +7 -8
- snowflake/ml/jobs/job.py +158 -26
- snowflake/ml/jobs/manager.py +78 -30
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +230 -50
- snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +22 -18
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +46 -25
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +35 -26
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +12 -8
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +5 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +50 -29
- snowflake/ml/registry/registry.py +34 -23
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
|
|
1
1
|
import inspect
|
2
2
|
import os
|
3
3
|
import tempfile
|
4
|
-
from typing import Any,
|
4
|
+
from typing import Any, Optional
|
5
5
|
|
6
6
|
import cloudpickle as cp
|
7
7
|
import pandas as pd
|
@@ -41,13 +41,13 @@ _PROJECT = "ModelDevelopment"
|
|
41
41
|
|
42
42
|
|
43
43
|
def get_data_iterator(
|
44
|
-
file_paths:
|
44
|
+
file_paths: list[str],
|
45
45
|
batch_size: int,
|
46
|
-
input_cols:
|
47
|
-
label_cols:
|
46
|
+
input_cols: list[str],
|
47
|
+
label_cols: list[str],
|
48
48
|
sample_weight_col: Optional[str] = None,
|
49
49
|
) -> Any:
|
50
|
-
from typing import
|
50
|
+
from typing import Optional
|
51
51
|
|
52
52
|
import xgboost
|
53
53
|
|
@@ -60,10 +60,10 @@ def get_data_iterator(
|
|
60
60
|
|
61
61
|
def __init__(
|
62
62
|
self,
|
63
|
-
file_paths:
|
63
|
+
file_paths: list[str],
|
64
64
|
batch_size: int,
|
65
|
-
input_cols:
|
66
|
-
label_cols:
|
65
|
+
input_cols: list[str],
|
66
|
+
label_cols: list[str],
|
67
67
|
sample_weight_col: Optional[str] = None,
|
68
68
|
) -> None:
|
69
69
|
"""
|
@@ -151,10 +151,10 @@ def get_data_iterator(
|
|
151
151
|
|
152
152
|
def train_xgboost_model(
|
153
153
|
estimator: object,
|
154
|
-
file_paths:
|
154
|
+
file_paths: list[str],
|
155
155
|
batch_size: int,
|
156
|
-
input_cols:
|
157
|
-
label_cols:
|
156
|
+
input_cols: list[str],
|
157
|
+
label_cols: list[str],
|
158
158
|
sample_weight_col: Optional[str] = None,
|
159
159
|
) -> object:
|
160
160
|
"""
|
@@ -247,8 +247,8 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
|
|
247
247
|
estimator: object,
|
248
248
|
dataset: DataFrame,
|
249
249
|
session: Session,
|
250
|
-
input_cols:
|
251
|
-
label_cols: Optional[
|
250
|
+
input_cols: list[str],
|
251
|
+
label_cols: Optional[list[str]],
|
252
252
|
sample_weight_col: Optional[str],
|
253
253
|
autogenerated: bool = False,
|
254
254
|
subproject: str = "",
|
@@ -285,8 +285,8 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
|
|
285
285
|
self,
|
286
286
|
model_spec: ModelSpecifications,
|
287
287
|
session: Session,
|
288
|
-
statement_params:
|
289
|
-
import_file_paths:
|
288
|
+
statement_params: dict[str, str],
|
289
|
+
import_file_paths: list[str],
|
290
290
|
) -> Any:
|
291
291
|
fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
|
292
292
|
|
@@ -308,10 +308,10 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
|
|
308
308
|
session: Session,
|
309
309
|
dataset_stage_name: str,
|
310
310
|
batch_size: int,
|
311
|
-
input_cols:
|
312
|
-
label_cols:
|
311
|
+
input_cols: list[str],
|
312
|
+
label_cols: list[str],
|
313
313
|
sample_weight_col: Optional[str],
|
314
|
-
statement_params:
|
314
|
+
statement_params: dict[str, str],
|
315
315
|
) -> str:
|
316
316
|
import os
|
317
317
|
import sys
|
@@ -365,7 +365,7 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
|
|
365
365
|
|
366
366
|
return fit_wrapper_sproc
|
367
367
|
|
368
|
-
def _write_training_data_to_stage(self, dataset_stage_name: str) ->
|
368
|
+
def _write_training_data_to_stage(self, dataset_stage_name: str) -> list[str]:
|
369
369
|
"""
|
370
370
|
Materializes the training to the specified stage and returns the list of stage file paths.
|
371
371
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any,
|
1
|
+
from typing import Any, Optional, Protocol, TypedDict, Union
|
2
2
|
|
3
3
|
import pandas as pd
|
4
4
|
|
@@ -29,9 +29,9 @@ class LocalModelTransformHandlers(Protocol):
|
|
29
29
|
def batch_inference(
|
30
30
|
self,
|
31
31
|
inference_method: str,
|
32
|
-
input_cols:
|
33
|
-
expected_output_cols:
|
34
|
-
snowpark_input_cols: Optional[
|
32
|
+
input_cols: list[str],
|
33
|
+
expected_output_cols: list[str],
|
34
|
+
snowpark_input_cols: Optional[list[str]],
|
35
35
|
drop_input_cols: Optional[bool] = False,
|
36
36
|
*args: Any,
|
37
37
|
**kwargs: Any,
|
@@ -57,8 +57,8 @@ class LocalModelTransformHandlers(Protocol):
|
|
57
57
|
|
58
58
|
def score(
|
59
59
|
self,
|
60
|
-
input_cols:
|
61
|
-
label_cols:
|
60
|
+
input_cols: list[str],
|
61
|
+
label_cols: list[str],
|
62
62
|
sample_weight_col: Optional[str],
|
63
63
|
*args: Any,
|
64
64
|
**kwargs: Any,
|
@@ -105,10 +105,10 @@ class RemoteModelTransformHandlers(Protocol):
|
|
105
105
|
def batch_inference(
|
106
106
|
self,
|
107
107
|
inference_method: str,
|
108
|
-
input_cols:
|
109
|
-
expected_output_cols:
|
108
|
+
input_cols: list[str],
|
109
|
+
expected_output_cols: list[str],
|
110
110
|
session: snowpark.Session,
|
111
|
-
dependencies:
|
111
|
+
dependencies: list[str],
|
112
112
|
drop_input_cols: Optional[bool] = False,
|
113
113
|
expected_output_cols_type: Optional[str] = "",
|
114
114
|
*args: Any,
|
@@ -137,11 +137,11 @@ class RemoteModelTransformHandlers(Protocol):
|
|
137
137
|
|
138
138
|
def score(
|
139
139
|
self,
|
140
|
-
input_cols:
|
141
|
-
label_cols:
|
140
|
+
input_cols: list[str],
|
141
|
+
label_cols: list[str],
|
142
142
|
session: snowpark.Session,
|
143
|
-
dependencies:
|
144
|
-
score_sproc_imports:
|
143
|
+
dependencies: list[str],
|
144
|
+
score_sproc_imports: list[str],
|
145
145
|
sample_weight_col: Optional[str] = None,
|
146
146
|
*args: Any,
|
147
147
|
**kwargs: Any,
|
@@ -173,10 +173,10 @@ ModelTransformHandlers = Union[LocalModelTransformHandlers, RemoteModelTransform
|
|
173
173
|
class BatchInferenceKwargsTypedDict(TypedDict, total=False):
|
174
174
|
"""A typed dict specifying all possible optional keyword args accepted by batch_inference() methods."""
|
175
175
|
|
176
|
-
snowpark_input_cols: Optional[
|
176
|
+
snowpark_input_cols: Optional[list[str]]
|
177
177
|
drop_input_cols: Optional[bool]
|
178
178
|
session: snowpark.Session
|
179
|
-
dependencies:
|
179
|
+
dependencies: list[str]
|
180
180
|
expected_output_cols_type: str
|
181
181
|
n_neighbors: Optional[int]
|
182
182
|
return_distance: bool
|
@@ -186,5 +186,5 @@ class ScoreKwargsTypedDict(TypedDict, total=False):
|
|
186
186
|
"""A typed dict specifying all possible optional keyword args accepted by score() methods."""
|
187
187
|
|
188
188
|
session: snowpark.Session
|
189
|
-
dependencies:
|
190
|
-
score_sproc_imports:
|
189
|
+
dependencies: list[str]
|
190
|
+
score_sproc_imports: list[str]
|
@@ -3,7 +3,7 @@
|
|
3
3
|
import inspect
|
4
4
|
import warnings
|
5
5
|
from enum import Enum
|
6
|
-
from typing import Any, Callable,
|
6
|
+
from typing import Any, Callable, Iterable, Optional, Union
|
7
7
|
|
8
8
|
import numpy as np
|
9
9
|
import sklearn
|
@@ -62,7 +62,7 @@ class BasicStatistics(str, Enum):
|
|
62
62
|
MODE = "mode"
|
63
63
|
|
64
64
|
|
65
|
-
def get_default_args(func: Callable[..., None]) ->
|
65
|
+
def get_default_args(func: Callable[..., None]) -> dict[str, Any]:
|
66
66
|
signature = inspect.signature(func)
|
67
67
|
return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
|
68
68
|
|
@@ -72,16 +72,16 @@ def generate_value_with_prefix(prefix: str) -> str:
|
|
72
72
|
|
73
73
|
|
74
74
|
def get_filtered_valid_sklearn_args(
|
75
|
-
args:
|
76
|
-
default_sklearn_args:
|
75
|
+
args: dict[str, Any],
|
76
|
+
default_sklearn_args: dict[str, Any],
|
77
77
|
sklearn_initial_keywords: Optional[Union[str, Iterable[str]]] = None,
|
78
78
|
sklearn_unused_keywords: Optional[Union[str, Iterable[str]]] = None,
|
79
79
|
snowml_only_keywords: Optional[Union[str, Iterable[str]]] = None,
|
80
|
-
sklearn_added_keyword_to_version_dict: Optional[
|
81
|
-
sklearn_added_kwarg_value_to_version_dict: Optional[
|
82
|
-
sklearn_deprecated_keyword_to_version_dict: Optional[
|
83
|
-
sklearn_removed_keyword_to_version_dict: Optional[
|
84
|
-
) ->
|
80
|
+
sklearn_added_keyword_to_version_dict: Optional[dict[str, str]] = None,
|
81
|
+
sklearn_added_kwarg_value_to_version_dict: Optional[dict[str, dict[str, str]]] = None,
|
82
|
+
sklearn_deprecated_keyword_to_version_dict: Optional[dict[str, str]] = None,
|
83
|
+
sklearn_removed_keyword_to_version_dict: Optional[dict[str, str]] = None,
|
84
|
+
) -> dict[str, Any]:
|
85
85
|
"""
|
86
86
|
Get valid sklearn keyword arguments with non-default values.
|
87
87
|
|
@@ -241,7 +241,7 @@ def to_native_format(obj: Any) -> Any:
|
|
241
241
|
return obj.to_sklearn()
|
242
242
|
|
243
243
|
|
244
|
-
def table_exists(session: snowpark.Session, table_name: str, statement_params:
|
244
|
+
def table_exists(session: snowpark.Session, table_name: str, statement_params: dict[str, Any]) -> bool:
|
245
245
|
try:
|
246
246
|
session.table(table_name).limit(0).collect(statement_params=statement_params)
|
247
247
|
return True
|
@@ -2,7 +2,7 @@
|
|
2
2
|
import inspect
|
3
3
|
from abc import abstractmethod
|
4
4
|
from datetime import datetime
|
5
|
-
from typing import Any,
|
5
|
+
from typing import Any, Iterable, Mapping, Optional, Union, overload
|
6
6
|
|
7
7
|
import numpy as np
|
8
8
|
import numpy.typing as npt
|
@@ -28,9 +28,9 @@ SKLEARN_SUPERVISED_ESTIMATORS = ["regressor", "classifier"]
|
|
28
28
|
SKLEARN_SINGLE_OUTPUT_ESTIMATORS = ["DensityEstimator", "clusterer", "outlier_detector"]
|
29
29
|
|
30
30
|
|
31
|
-
def _process_cols(cols: Optional[Union[str, Iterable[str]]]) ->
|
31
|
+
def _process_cols(cols: Optional[Union[str, Iterable[str]]]) -> list[str]:
|
32
32
|
"""Convert cols to a list."""
|
33
|
-
col_list:
|
33
|
+
col_list: list[str] = []
|
34
34
|
if cols is None:
|
35
35
|
return col_list
|
36
36
|
elif type(cols) is list:
|
@@ -55,10 +55,10 @@ class Base:
|
|
55
55
|
passthrough_cols: List columns not to be used or modified by the estimator/transformers.
|
56
56
|
These columns will be passed through all the estimator/transformer operations without any modifications.
|
57
57
|
"""
|
58
|
-
self.input_cols:
|
59
|
-
self.output_cols:
|
60
|
-
self.label_cols:
|
61
|
-
self.passthrough_cols:
|
58
|
+
self.input_cols: list[str] = []
|
59
|
+
self.output_cols: list[str] = []
|
60
|
+
self.label_cols: list[str] = []
|
61
|
+
self.passthrough_cols: list[str] = []
|
62
62
|
|
63
63
|
def _create_unfitted_sklearn_object(self) -> Any:
|
64
64
|
raise NotImplementedError()
|
@@ -66,7 +66,7 @@ class Base:
|
|
66
66
|
def _create_sklearn_object(self) -> Any:
|
67
67
|
raise NotImplementedError()
|
68
68
|
|
69
|
-
def get_input_cols(self) ->
|
69
|
+
def get_input_cols(self) -> list[str]:
|
70
70
|
"""
|
71
71
|
Input columns getter.
|
72
72
|
|
@@ -88,7 +88,7 @@ class Base:
|
|
88
88
|
self.input_cols = _process_cols(input_cols)
|
89
89
|
return self
|
90
90
|
|
91
|
-
def get_output_cols(self) ->
|
91
|
+
def get_output_cols(self) -> list[str]:
|
92
92
|
"""
|
93
93
|
Output columns getter.
|
94
94
|
|
@@ -110,7 +110,7 @@ class Base:
|
|
110
110
|
self.output_cols = _process_cols(output_cols)
|
111
111
|
return self
|
112
112
|
|
113
|
-
def get_label_cols(self) ->
|
113
|
+
def get_label_cols(self) -> list[str]:
|
114
114
|
"""
|
115
115
|
Label column getter.
|
116
116
|
|
@@ -132,7 +132,7 @@ class Base:
|
|
132
132
|
self.label_cols = _process_cols(label_cols)
|
133
133
|
return self
|
134
134
|
|
135
|
-
def get_passthrough_cols(self) ->
|
135
|
+
def get_passthrough_cols(self) -> list[str]:
|
136
136
|
"""
|
137
137
|
Passthrough columns getter.
|
138
138
|
|
@@ -215,7 +215,7 @@ class Base:
|
|
215
215
|
)
|
216
216
|
|
217
217
|
@classmethod
|
218
|
-
def _get_param_names(cls) ->
|
218
|
+
def _get_param_names(cls) -> list[str]:
|
219
219
|
"""Get parameter names for the transformer"""
|
220
220
|
# fetch the constructor or the original constructor before
|
221
221
|
# deprecation wrapping if any
|
@@ -244,7 +244,7 @@ class Base:
|
|
244
244
|
# Extract and sort argument names excluding 'self'
|
245
245
|
return sorted(p.name for p in parameters)
|
246
246
|
|
247
|
-
def get_params(self, deep: bool = True) ->
|
247
|
+
def get_params(self, deep: bool = True) -> dict[str, Any]:
|
248
248
|
"""
|
249
249
|
Get the snowflake-ml parameters for this transformer.
|
250
250
|
|
@@ -255,7 +255,7 @@ class Base:
|
|
255
255
|
Returns:
|
256
256
|
Parameter names mapped to their values.
|
257
257
|
"""
|
258
|
-
out:
|
258
|
+
out: dict[str, Any] = dict()
|
259
259
|
for key in self._get_param_names():
|
260
260
|
if hasattr(self, key):
|
261
261
|
value = getattr(self, key)
|
@@ -320,11 +320,11 @@ class Base:
|
|
320
320
|
sklearn_initial_keywords: Optional[Union[str, Iterable[str]]] = None,
|
321
321
|
sklearn_unused_keywords: Optional[Union[str, Iterable[str]]] = None,
|
322
322
|
snowml_only_keywords: Optional[Union[str, Iterable[str]]] = None,
|
323
|
-
sklearn_added_keyword_to_version_dict: Optional[
|
324
|
-
sklearn_added_kwarg_value_to_version_dict: Optional[
|
325
|
-
sklearn_deprecated_keyword_to_version_dict: Optional[
|
326
|
-
sklearn_removed_keyword_to_version_dict: Optional[
|
327
|
-
) ->
|
323
|
+
sklearn_added_keyword_to_version_dict: Optional[dict[str, str]] = None,
|
324
|
+
sklearn_added_kwarg_value_to_version_dict: Optional[dict[str, dict[str, str]]] = None,
|
325
|
+
sklearn_deprecated_keyword_to_version_dict: Optional[dict[str, str]] = None,
|
326
|
+
sklearn_removed_keyword_to_version_dict: Optional[dict[str, str]] = None,
|
327
|
+
) -> dict[str, Any]:
|
328
328
|
"""
|
329
329
|
Get sklearn keyword arguments.
|
330
330
|
|
@@ -350,7 +350,7 @@ class Base:
|
|
350
350
|
"""
|
351
351
|
default_sklearn_args = _utils.get_default_args(default_sklearn_obj.__class__.__init__)
|
352
352
|
given_args = self.get_params()
|
353
|
-
sklearn_args:
|
353
|
+
sklearn_args: dict[str, Any] = _utils.get_filtered_valid_sklearn_args(
|
354
354
|
args=given_args,
|
355
355
|
default_sklearn_args=default_sklearn_args,
|
356
356
|
sklearn_initial_keywords=sklearn_initial_keywords,
|
@@ -368,8 +368,8 @@ class BaseEstimator(Base):
|
|
368
368
|
def __init__(
|
369
369
|
self,
|
370
370
|
*,
|
371
|
-
file_names: Optional[
|
372
|
-
custom_states: Optional[
|
371
|
+
file_names: Optional[list[str]] = None,
|
372
|
+
custom_states: Optional[list[str]] = None,
|
373
373
|
sample_weight_col: Optional[str] = None,
|
374
374
|
) -> None:
|
375
375
|
"""
|
@@ -418,7 +418,7 @@ class BaseEstimator(Base):
|
|
418
418
|
self.sample_weight_col = sample_weight_col
|
419
419
|
return self
|
420
420
|
|
421
|
-
def _get_dependencies(self) ->
|
421
|
+
def _get_dependencies(self) -> list[str]:
|
422
422
|
"""
|
423
423
|
Return the list of conda dependencies required to work with the object.
|
424
424
|
|
@@ -458,8 +458,8 @@ class BaseEstimator(Base):
|
|
458
458
|
return dataset[self.input_cols]
|
459
459
|
|
460
460
|
def _compute(
|
461
|
-
self, dataset: snowpark.DataFrame, cols:
|
462
|
-
) ->
|
461
|
+
self, dataset: snowpark.DataFrame, cols: list[str], states: list[str]
|
462
|
+
) -> dict[str, dict[str, Union[int, float, str]]]:
|
463
463
|
"""
|
464
464
|
Compute required states of the columns.
|
465
465
|
|
@@ -474,7 +474,7 @@ class BaseEstimator(Base):
|
|
474
474
|
A dict of {column_name: {state: value}} of each column.
|
475
475
|
"""
|
476
476
|
|
477
|
-
def _compute_on_partition(df: snowpark.DataFrame, cols_subset:
|
477
|
+
def _compute_on_partition(df: snowpark.DataFrame, cols_subset: list[str]) -> snowpark.DataFrame:
|
478
478
|
"""Returns a DataFrame with the desired computation on the specified column subset."""
|
479
479
|
exprs = []
|
480
480
|
sql_prefix = "SQL>>>"
|
@@ -499,7 +499,7 @@ class BaseEstimator(Base):
|
|
499
499
|
statement_params=telemetry.get_statement_params(PROJECT, SUBPROJECT, self.__class__.__name__),
|
500
500
|
)
|
501
501
|
|
502
|
-
computed_dict:
|
502
|
+
computed_dict: dict[str, dict[str, Union[int, float, str]]] = {}
|
503
503
|
for idx, val in enumerate(_results[0]):
|
504
504
|
col_name = cols[idx // len(states)]
|
505
505
|
if col_name not in computed_dict:
|
@@ -516,8 +516,8 @@ class BaseTransformer(BaseEstimator):
|
|
516
516
|
self,
|
517
517
|
*,
|
518
518
|
drop_input_cols: Optional[bool] = False,
|
519
|
-
file_names: Optional[
|
520
|
-
custom_states: Optional[
|
519
|
+
file_names: Optional[list[str]] = None,
|
520
|
+
custom_states: Optional[list[str]] = None,
|
521
521
|
sample_weight_col: Optional[str] = None,
|
522
522
|
) -> None:
|
523
523
|
"""Base class for all transformers."""
|
@@ -551,7 +551,7 @@ class BaseTransformer(BaseEstimator):
|
|
551
551
|
),
|
552
552
|
)
|
553
553
|
|
554
|
-
def _infer_input_cols(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) ->
|
554
|
+
def _infer_input_cols(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> list[str]:
|
555
555
|
"""
|
556
556
|
Infer input_cols from the dataset. Input column are all columns in the input dataset that are not
|
557
557
|
designated as label, passthrough, or sample weight columns.
|
@@ -569,7 +569,7 @@ class BaseTransformer(BaseEstimator):
|
|
569
569
|
]
|
570
570
|
return cols
|
571
571
|
|
572
|
-
def _infer_output_cols(self) ->
|
572
|
+
def _infer_output_cols(self) -> list[str]:
|
573
573
|
"""Infer output column names from based on the estimator.
|
574
574
|
|
575
575
|
Returns:
|
@@ -624,7 +624,7 @@ class BaseTransformer(BaseEstimator):
|
|
624
624
|
cols = self._infer_output_cols()
|
625
625
|
self.set_output_cols(output_cols=cols)
|
626
626
|
|
627
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[
|
627
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[list[str]] = None) -> list[str]:
|
628
628
|
"""Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
629
629
|
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
630
630
|
|
@@ -2,7 +2,7 @@ import os
|
|
2
2
|
|
3
3
|
from snowflake.ml._internal import init_utils
|
4
4
|
|
5
|
-
pkg_dir = os.path.dirname(
|
5
|
+
pkg_dir = os.path.dirname(__file__)
|
6
6
|
pkg_name = __name__
|
7
7
|
exportable_classes = init_utils.fetch_classes_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
|
8
8
|
for k, v in exportable_classes.items():
|
@@ -1,7 +1,7 @@
|
|
1
1
|
#!/usr/bin/env python3
|
2
2
|
import copy
|
3
3
|
import warnings
|
4
|
-
from typing import Any,
|
4
|
+
from typing import Any, Iterable, Optional, Union
|
5
5
|
|
6
6
|
import numpy as np
|
7
7
|
import numpy.typing as npt
|
@@ -25,7 +25,7 @@ STRATEGY_TO_STATE_DICT = {
|
|
25
25
|
"most_frequent": _utils.BasicStatistics.MODE,
|
26
26
|
}
|
27
27
|
|
28
|
-
SNOWFLAKE_DATATYPE_TO_NUMPY_DTYPE_MAP:
|
28
|
+
SNOWFLAKE_DATATYPE_TO_NUMPY_DTYPE_MAP: dict[type[T.DataType], npt.DTypeLike] = {
|
29
29
|
T.ByteType: np.dtype("int8"),
|
30
30
|
T.ShortType: np.dtype("int16"),
|
31
31
|
T.IntegerType: np.dtype("int32"),
|
@@ -164,7 +164,7 @@ class SimpleImputer(base.BaseTransformer):
|
|
164
164
|
|
165
165
|
self.fill_value = fill_value
|
166
166
|
self.missing_values = missing_values
|
167
|
-
self.statistics_:
|
167
|
+
self.statistics_: dict[str, Any] = {}
|
168
168
|
# TODO(hayu): [SNOW-752265] Support SimpleImputer keep_empty_features.
|
169
169
|
# Add back when `keep_empty_features` is supported.
|
170
170
|
# self.keep_empty_features = keep_empty_features
|
@@ -195,7 +195,7 @@ class SimpleImputer(base.BaseTransformer):
|
|
195
195
|
del self.feature_names_in_
|
196
196
|
del self._sklearn_fit_dtype
|
197
197
|
|
198
|
-
def _get_dataset_input_col_datatypes(self, dataset: snowpark.DataFrame) ->
|
198
|
+
def _get_dataset_input_col_datatypes(self, dataset: snowpark.DataFrame) -> dict[str, T.DataType]:
|
199
199
|
"""
|
200
200
|
Checks that the input columns are all the same datatype category(except for most_frequent strategy) and
|
201
201
|
returns the datatype.
|
@@ -211,7 +211,7 @@ class SimpleImputer(base.BaseTransformer):
|
|
211
211
|
supported.
|
212
212
|
"""
|
213
213
|
|
214
|
-
def check_type_consistency(col_types:
|
214
|
+
def check_type_consistency(col_types: dict[str, T.DataType]) -> None:
|
215
215
|
is_numeric_type = None
|
216
216
|
for col_name, col_type in col_types.items():
|
217
217
|
if is_numeric_type is None:
|
@@ -5,7 +5,7 @@ import cloudpickle
|
|
5
5
|
from snowflake.ml._internal import init_utils
|
6
6
|
from snowflake.ml._internal.utils import result
|
7
7
|
|
8
|
-
pkg_dir = os.path.dirname(
|
8
|
+
pkg_dir = os.path.dirname(__file__)
|
9
9
|
pkg_name = __name__
|
10
10
|
exportable_functions = init_utils.fetch_functions_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
|
11
11
|
for k, v in exportable_functions.items():
|