snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff compares two package versions that were publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +6 -6
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +6 -16
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +7 -4
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +17 -28
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +4 -5
- snowflake/ml/jobs/job.py +24 -14
- snowflake/ml/jobs/manager.py +37 -41
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +199 -26
- snowflake/ml/model/_client/service/model_deployment_spec.py +171 -47
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +13 -13
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +17 -14
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +33 -31
- snowflake/ml/registry/registry.py +29 -22
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +16 -7
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +164 -166
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
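Most of the per-file hunks that follow apply one mechanical change: type annotations move to the builtin generics standardized by PEP 585 on Python 3.9+ (list[str], dict[str, str], tuple[...]), and the typing imports are trimmed to the names that are still required (Optional, Union, Callable, and so on). The removed (-) lines appear truncated in this rendering because the replaced annotation text was lost during extraction. Below is a minimal sketch of the pattern on the added (+) side; the helper name and body are illustrative only and are not taken from the package source.

    # Illustrative sketch only (hypothetical helper, not from snowflake-ml-python):
    # builtin generics replace typing.List/typing.Dict in the signature, so the
    # typing import keeps only the names that are still needed.
    from typing import Optional

    def validate_dependencies(
        dependencies: list[str],
        statement_params: Optional[dict[str, str]] = None,
    ) -> list[str]:
        # Keep only non-empty dependency specifications; statement_params is
        # accepted here purely to mirror the annotated-signature style.
        return [dep for dep in dependencies if dep]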
@@ -3,7 +3,7 @@ import inspect
 import os
 import posixpath
 import sys
-from typing import Any,
+from typing import Any, Optional
 from uuid import uuid4

 import cloudpickle as cp
@@ -73,10 +73,10 @@ class SnowparkTransformHandlers:
     def batch_inference(
         self,
         inference_method: str,
-        input_cols:
-        expected_output_cols:
+        input_cols: list[str],
+        expected_output_cols: list[str],
         session: Session,
-        dependencies:
+        dependencies: list[str],
         drop_input_cols: Optional[bool] = False,
         expected_output_cols_type: Optional[str] = "",
         *args: Any,
@@ -229,11 +229,11 @@ class SnowparkTransformHandlers:

     def score(
         self,
-        input_cols:
-        label_cols:
+        input_cols: list[str],
+        label_cols: list[str],
         session: Session,
-        dependencies:
-        score_sproc_imports:
+        dependencies: list[str],
+        score_sproc_imports: list[str],
         sample_weight_col: Optional[str] = None,
         *args: Any,
         **kwargs: Any,
@@ -308,12 +308,12 @@ class SnowparkTransformHandlers:
         )
         def score_wrapper_sproc(
             session: Session,
-            sql_queries:
+            sql_queries: list[str],
             stage_score_file_name: str,
-            input_cols:
-            label_cols:
+            input_cols: list[str],
+            label_cols: list[str],
             sample_weight_col: Optional[str],
-            score_statement_params:
+            score_statement_params: dict[str, str],
         ) -> float:
             import inspect
             import os
@@ -382,7 +382,7 @@ class SnowparkTransformHandlers:

         return score

-    def _get_validated_snowpark_dependencies(self, session: Session, dependencies:
+    def _get_validated_snowpark_dependencies(self, session: Session, dependencies: list[str]) -> list[str]:
         """A helper function to validate dependencies and return the available packages that exists
         in the snowflake anaconda channel

@@ -2,7 +2,7 @@ import importlib
 import inspect
 import os
 import posixpath
-from typing import Any, Callable,
+from typing import Any, Callable, Optional, Union

 import cloudpickle as cp
 import pandas as pd
@@ -55,8 +55,8 @@ class SnowparkModelTrainer:
         estimator: object,
         dataset: DataFrame,
         session: Session,
-        input_cols:
-        label_cols: Optional[
+        input_cols: list[str],
+        label_cols: Optional[list[str]],
         sample_weight_col: Optional[str],
         autogenerated: bool = False,
         subproject: str = "",
@@ -84,7 +84,7 @@ class SnowparkModelTrainer:
         self._subproject = subproject
         self._class_name = estimator.__class__.__name__

-    def _fetch_model_from_stage(self, dir_path: str, file_name: str, statement_params:
+    def _fetch_model_from_stage(self, dir_path: str, file_name: str, statement_params: dict[str, str]) -> object:
         """
         Downloads the serialized model from a stage location and unpickles it.

@@ -112,7 +112,7 @@ class SnowparkModelTrainer:
     def _build_fit_wrapper_sproc(
         self,
         model_spec: ModelSpecifications,
-    ) -> Callable[[Any,
+    ) -> Callable[[Any, list[str], str, list[str], list[str], Optional[str], dict[str, str]], str]:
         """
         Constructs and returns a python stored procedure function to be used for training model.

@@ -129,12 +129,12 @@ class SnowparkModelTrainer:

         def fit_wrapper_function(
             session: Session,
-            sql_queries:
+            sql_queries: list[str],
             temp_stage_name: str,
-            input_cols:
-            label_cols:
+            input_cols: list[str],
+            label_cols: list[str],
             sample_weight_col: Optional[str],
-            statement_params:
+            statement_params: dict[str, str],
         ) -> str:
             import inspect
             import os
@@ -218,7 +218,7 @@ class SnowparkModelTrainer:

         return fit_wrapper_function

-    def _get_fit_wrapper_sproc(self, statement_params:
+    def _get_fit_wrapper_sproc(self, statement_params: dict[str, str], anonymous: bool) -> StoredProcedure:
         model_spec = ModelSpecificationsBuilder.build(model=self.estimator)

         fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
@@ -243,7 +243,7 @@ class SnowparkModelTrainer:
     def _build_fit_predict_wrapper_sproc(
         self,
         model_spec: ModelSpecifications,
-    ) -> Callable[[Session,
+    ) -> Callable[[Session, list[str], str, list[str], dict[str, str], bool, list[str], str], str]:
         """
         Constructs and returns a python stored procedure function to be used for training model.

@@ -258,12 +258,12 @@ class SnowparkModelTrainer:

         def fit_predict_wrapper_function(
             session: Session,
-            sql_queries:
+            sql_queries: list[str],
             temp_stage_name: str,
-            input_cols:
-            statement_params:
+            input_cols: list[str],
+            statement_params: dict[str, str],
             drop_input_cols: bool,
-            expected_output_cols_list:
+            expected_output_cols_list: list[str],
             fit_predict_result_name: str,
         ) -> str:
             import os
@@ -346,14 +346,14 @@ class SnowparkModelTrainer:
     ) -> Callable[
         [
             Session,
-
+            list[str],
             str,
-
-            Optional[
+            list[str],
+            Optional[list[str]],
             Optional[str],
-
+            dict[str, str],
             bool,
-
+            list[str],
             str,
         ],
         str,
@@ -372,14 +372,14 @@ class SnowparkModelTrainer:

         def fit_transform_wrapper_function(
             session: Session,
-            sql_queries:
+            sql_queries: list[str],
             temp_stage_name: str,
-            input_cols:
-            label_cols: Optional[
+            input_cols: list[str],
+            label_cols: Optional[list[str]],
             sample_weight_col: Optional[str],
-            statement_params:
+            statement_params: dict[str, str],
             drop_input_cols: bool,
-            expected_output_cols_list:
+            expected_output_cols_list: list[str],
             fit_transform_result_name: str,
         ) -> str:
             import os
@@ -473,7 +473,7 @@ class SnowparkModelTrainer:

         return fit_transform_wrapper_function

-    def _get_fit_predict_wrapper_sproc(self, statement_params:
+    def _get_fit_predict_wrapper_sproc(self, statement_params: dict[str, str], anonymous: bool) -> StoredProcedure:
         model_spec = ModelSpecificationsBuilder.build(model=self.estimator)

         fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
@@ -495,7 +495,7 @@ class SnowparkModelTrainer:

         return fit_predict_wrapper_sproc

-    def _get_fit_transform_wrapper_sproc(self, statement_params:
+    def _get_fit_transform_wrapper_sproc(self, statement_params: dict[str, str], anonymous: bool) -> StoredProcedure:
         model_spec = ModelSpecificationsBuilder.build(model=self.estimator)

         fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
@@ -586,10 +586,10 @@ class SnowparkModelTrainer:

     def train_fit_predict(
         self,
-        expected_output_cols_list:
+        expected_output_cols_list: list[str],
         drop_input_cols: Optional[bool] = False,
         example_output_pd_df: Optional[pd.DataFrame] = None,
-    ) ->
+    ) -> tuple[Union[DataFrame, pd.DataFrame], object]:
         """Trains the model by pushing down the compute into Snowflake using stored procedures.
         This API is different from fit itself because it would also provide the predict
         output.
@@ -682,9 +682,9 @@ class SnowparkModelTrainer:

     def train_fit_transform(
         self,
-        expected_output_cols_list:
+        expected_output_cols_list: list[str],
         drop_input_cols: Optional[bool] = False,
-    ) ->
+    ) -> tuple[Union[DataFrame, pd.DataFrame], object]:
         """Trains the model by pushing down the compute into Snowflake using stored procedures.
         This API is different from fit itself because it would also provide the transform
         output.
@@ -1,7 +1,7 @@
 import inspect
 import os
 import tempfile
-from typing import Any,
+from typing import Any, Optional

 import cloudpickle as cp
 import pandas as pd
@@ -41,13 +41,13 @@ _PROJECT = "ModelDevelopment"


 def get_data_iterator(
-    file_paths:
+    file_paths: list[str],
     batch_size: int,
-    input_cols:
-    label_cols:
+    input_cols: list[str],
+    label_cols: list[str],
     sample_weight_col: Optional[str] = None,
 ) -> Any:
-    from typing import
+    from typing import Optional

     import xgboost

@@ -60,10 +60,10 @@ def get_data_iterator(

        def __init__(
            self,
-           file_paths:
+           file_paths: list[str],
            batch_size: int,
-           input_cols:
-           label_cols:
+           input_cols: list[str],
+           label_cols: list[str],
            sample_weight_col: Optional[str] = None,
        ) -> None:
            """
@@ -151,10 +151,10 @@ def get_data_iterator(

 def train_xgboost_model(
     estimator: object,
-    file_paths:
+    file_paths: list[str],
     batch_size: int,
-    input_cols:
-    label_cols:
+    input_cols: list[str],
+    label_cols: list[str],
     sample_weight_col: Optional[str] = None,
 ) -> object:
     """
@@ -247,8 +247,8 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
         estimator: object,
         dataset: DataFrame,
         session: Session,
-        input_cols:
-        label_cols: Optional[
+        input_cols: list[str],
+        label_cols: Optional[list[str]],
         sample_weight_col: Optional[str],
         autogenerated: bool = False,
         subproject: str = "",
@@ -285,8 +285,8 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
         self,
         model_spec: ModelSpecifications,
         session: Session,
-        statement_params:
-        import_file_paths:
+        statement_params: dict[str, str],
+        import_file_paths: list[str],
     ) -> Any:
         fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)

@@ -308,10 +308,10 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
            session: Session,
            dataset_stage_name: str,
            batch_size: int,
-           input_cols:
-           label_cols:
+           input_cols: list[str],
+           label_cols: list[str],
            sample_weight_col: Optional[str],
-           statement_params:
+           statement_params: dict[str, str],
        ) -> str:
            import os
            import sys
@@ -365,7 +365,7 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):

         return fit_wrapper_sproc

-    def _write_training_data_to_stage(self, dataset_stage_name: str) ->
+    def _write_training_data_to_stage(self, dataset_stage_name: str) -> list[str]:
         """
         Materializes the training to the specified stage and returns the list of stage file paths.

@@ -1,4 +1,4 @@
-from typing import Any,
+from typing import Any, Optional, Protocol, TypedDict, Union

 import pandas as pd

@@ -29,9 +29,9 @@ class LocalModelTransformHandlers(Protocol):
     def batch_inference(
         self,
         inference_method: str,
-        input_cols:
-        expected_output_cols:
-        snowpark_input_cols: Optional[
+        input_cols: list[str],
+        expected_output_cols: list[str],
+        snowpark_input_cols: Optional[list[str]],
         drop_input_cols: Optional[bool] = False,
         *args: Any,
         **kwargs: Any,
@@ -57,8 +57,8 @@ class LocalModelTransformHandlers(Protocol):

     def score(
         self,
-        input_cols:
-        label_cols:
+        input_cols: list[str],
+        label_cols: list[str],
         sample_weight_col: Optional[str],
         *args: Any,
         **kwargs: Any,
@@ -105,10 +105,10 @@ class RemoteModelTransformHandlers(Protocol):
     def batch_inference(
         self,
         inference_method: str,
-        input_cols:
-        expected_output_cols:
+        input_cols: list[str],
+        expected_output_cols: list[str],
         session: snowpark.Session,
-        dependencies:
+        dependencies: list[str],
         drop_input_cols: Optional[bool] = False,
         expected_output_cols_type: Optional[str] = "",
         *args: Any,
@@ -137,11 +137,11 @@ class RemoteModelTransformHandlers(Protocol):

     def score(
         self,
-        input_cols:
-        label_cols:
+        input_cols: list[str],
+        label_cols: list[str],
         session: snowpark.Session,
-        dependencies:
-        score_sproc_imports:
+        dependencies: list[str],
+        score_sproc_imports: list[str],
         sample_weight_col: Optional[str] = None,
         *args: Any,
         **kwargs: Any,
@@ -173,10 +173,10 @@ ModelTransformHandlers = Union[LocalModelTransformHandlers, RemoteModelTransform
 class BatchInferenceKwargsTypedDict(TypedDict, total=False):
     """A typed dict specifying all possible optional keyword args accepted by batch_inference() methods."""

-    snowpark_input_cols: Optional[
+    snowpark_input_cols: Optional[list[str]]
     drop_input_cols: Optional[bool]
     session: snowpark.Session
-    dependencies:
+    dependencies: list[str]
     expected_output_cols_type: str
     n_neighbors: Optional[int]
     return_distance: bool
@@ -186,5 +186,5 @@ class ScoreKwargsTypedDict(TypedDict, total=False):
     """A typed dict specifying all possible optional keyword args accepted by score() methods."""

     session: snowpark.Session
-    dependencies:
-    score_sproc_imports:
+    dependencies: list[str]
+    score_sproc_imports: list[str]
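The two TypedDict hunks directly above declare their classes with total=False, which keeps every listed key optional for callers while still type-checking each key that is supplied. A small sketch of that pattern, reusing two field names visible in the hunk; the surrounding code is an assumption, not the package's:

    # Sketch only: with total=False, callers may pass any subset of the declared keys.
    from typing import TypedDict

    class ScoreKwargs(TypedDict, total=False):
        dependencies: list[str]
        score_sproc_imports: list[str]

    kwargs: ScoreKwargs = {"dependencies": ["xgboost"]}  # score_sproc_imports omitted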
@@ -3,7 +3,7 @@
 import inspect
 import warnings
 from enum import Enum
-from typing import Any, Callable,
+from typing import Any, Callable, Iterable, Optional, Union

 import numpy as np
 import sklearn
@@ -62,7 +62,7 @@ class BasicStatistics(str, Enum):
     MODE = "mode"


-def get_default_args(func: Callable[..., None]) ->
+def get_default_args(func: Callable[..., None]) -> dict[str, Any]:
     signature = inspect.signature(func)
     return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}

@@ -72,16 +72,16 @@ def generate_value_with_prefix(prefix: str) -> str:


 def get_filtered_valid_sklearn_args(
-    args:
-    default_sklearn_args:
+    args: dict[str, Any],
+    default_sklearn_args: dict[str, Any],
     sklearn_initial_keywords: Optional[Union[str, Iterable[str]]] = None,
     sklearn_unused_keywords: Optional[Union[str, Iterable[str]]] = None,
     snowml_only_keywords: Optional[Union[str, Iterable[str]]] = None,
-    sklearn_added_keyword_to_version_dict: Optional[
-    sklearn_added_kwarg_value_to_version_dict: Optional[
-    sklearn_deprecated_keyword_to_version_dict: Optional[
-    sklearn_removed_keyword_to_version_dict: Optional[
-) ->
+    sklearn_added_keyword_to_version_dict: Optional[dict[str, str]] = None,
+    sklearn_added_kwarg_value_to_version_dict: Optional[dict[str, dict[str, str]]] = None,
+    sklearn_deprecated_keyword_to_version_dict: Optional[dict[str, str]] = None,
+    sklearn_removed_keyword_to_version_dict: Optional[dict[str, str]] = None,
+) -> dict[str, Any]:
     """
     Get valid sklearn keyword arguments with non-default values.

@@ -241,7 +241,7 @@ def to_native_format(obj: Any) -> Any:
     return obj.to_sklearn()


-def table_exists(session: snowpark.Session, table_name: str, statement_params:
+def table_exists(session: snowpark.Session, table_name: str, statement_params: dict[str, Any]) -> bool:
     try:
         session.table(table_name).limit(0).collect(statement_params=statement_params)
         return True
@@ -2,7 +2,7 @@
 import inspect
 from abc import abstractmethod
 from datetime import datetime
-from typing import Any,
+from typing import Any, Iterable, Mapping, Optional, Union, overload

 import numpy as np
 import numpy.typing as npt
@@ -28,9 +28,9 @@ SKLEARN_SUPERVISED_ESTIMATORS = ["regressor", "classifier"]
 SKLEARN_SINGLE_OUTPUT_ESTIMATORS = ["DensityEstimator", "clusterer", "outlier_detector"]


-def _process_cols(cols: Optional[Union[str, Iterable[str]]]) ->
+def _process_cols(cols: Optional[Union[str, Iterable[str]]]) -> list[str]:
     """Convert cols to a list."""
-    col_list:
+    col_list: list[str] = []
     if cols is None:
         return col_list
     elif type(cols) is list:
@@ -55,10 +55,10 @@ class Base:
             passthrough_cols: List columns not to be used or modified by the estimator/transformers.
                 These columns will be passed through all the estimator/transformer operations without any modifications.
         """
-        self.input_cols:
-        self.output_cols:
-        self.label_cols:
-        self.passthrough_cols:
+        self.input_cols: list[str] = []
+        self.output_cols: list[str] = []
+        self.label_cols: list[str] = []
+        self.passthrough_cols: list[str] = []

     def _create_unfitted_sklearn_object(self) -> Any:
         raise NotImplementedError()
@@ -66,7 +66,7 @@ class Base:
     def _create_sklearn_object(self) -> Any:
         raise NotImplementedError()

-    def get_input_cols(self) ->
+    def get_input_cols(self) -> list[str]:
         """
         Input columns getter.

@@ -88,7 +88,7 @@ class Base:
         self.input_cols = _process_cols(input_cols)
         return self

-    def get_output_cols(self) ->
+    def get_output_cols(self) -> list[str]:
         """
         Output columns getter.

@@ -110,7 +110,7 @@ class Base:
         self.output_cols = _process_cols(output_cols)
         return self

-    def get_label_cols(self) ->
+    def get_label_cols(self) -> list[str]:
         """
         Label column getter.

@@ -132,7 +132,7 @@ class Base:
         self.label_cols = _process_cols(label_cols)
         return self

-    def get_passthrough_cols(self) ->
+    def get_passthrough_cols(self) -> list[str]:
         """
         Passthrough columns getter.

@@ -215,7 +215,7 @@ class Base:
         )

     @classmethod
-    def _get_param_names(cls) ->
+    def _get_param_names(cls) -> list[str]:
         """Get parameter names for the transformer"""
         # fetch the constructor or the original constructor before
         # deprecation wrapping if any
@@ -244,7 +244,7 @@ class Base:
         # Extract and sort argument names excluding 'self'
         return sorted(p.name for p in parameters)

-    def get_params(self, deep: bool = True) ->
+    def get_params(self, deep: bool = True) -> dict[str, Any]:
         """
         Get the snowflake-ml parameters for this transformer.

@@ -255,7 +255,7 @@ class Base:
         Returns:
             Parameter names mapped to their values.
         """
-        out:
+        out: dict[str, Any] = dict()
         for key in self._get_param_names():
             if hasattr(self, key):
                 value = getattr(self, key)
@@ -320,11 +320,11 @@ class Base:
         sklearn_initial_keywords: Optional[Union[str, Iterable[str]]] = None,
         sklearn_unused_keywords: Optional[Union[str, Iterable[str]]] = None,
         snowml_only_keywords: Optional[Union[str, Iterable[str]]] = None,
-        sklearn_added_keyword_to_version_dict: Optional[
-        sklearn_added_kwarg_value_to_version_dict: Optional[
-        sklearn_deprecated_keyword_to_version_dict: Optional[
-        sklearn_removed_keyword_to_version_dict: Optional[
-    ) ->
+        sklearn_added_keyword_to_version_dict: Optional[dict[str, str]] = None,
+        sklearn_added_kwarg_value_to_version_dict: Optional[dict[str, dict[str, str]]] = None,
+        sklearn_deprecated_keyword_to_version_dict: Optional[dict[str, str]] = None,
+        sklearn_removed_keyword_to_version_dict: Optional[dict[str, str]] = None,
+    ) -> dict[str, Any]:
         """
         Get sklearn keyword arguments.

@@ -350,7 +350,7 @@ class Base:
         """
         default_sklearn_args = _utils.get_default_args(default_sklearn_obj.__class__.__init__)
         given_args = self.get_params()
-        sklearn_args:
+        sklearn_args: dict[str, Any] = _utils.get_filtered_valid_sklearn_args(
             args=given_args,
             default_sklearn_args=default_sklearn_args,
             sklearn_initial_keywords=sklearn_initial_keywords,
@@ -368,8 +368,8 @@ class BaseEstimator(Base):
     def __init__(
         self,
         *,
-        file_names: Optional[
-        custom_states: Optional[
+        file_names: Optional[list[str]] = None,
+        custom_states: Optional[list[str]] = None,
         sample_weight_col: Optional[str] = None,
     ) -> None:
         """
@@ -418,7 +418,7 @@ class BaseEstimator(Base):
         self.sample_weight_col = sample_weight_col
         return self

-    def _get_dependencies(self) ->
+    def _get_dependencies(self) -> list[str]:
         """
         Return the list of conda dependencies required to work with the object.

@@ -458,8 +458,8 @@ class BaseEstimator(Base):
         return dataset[self.input_cols]

     def _compute(
-        self, dataset: snowpark.DataFrame, cols:
-    ) ->
+        self, dataset: snowpark.DataFrame, cols: list[str], states: list[str]
+    ) -> dict[str, dict[str, Union[int, float, str]]]:
         """
         Compute required states of the columns.

@@ -474,7 +474,7 @@ class BaseEstimator(Base):
             A dict of {column_name: {state: value}} of each column.
         """

-        def _compute_on_partition(df: snowpark.DataFrame, cols_subset:
+        def _compute_on_partition(df: snowpark.DataFrame, cols_subset: list[str]) -> snowpark.DataFrame:
             """Returns a DataFrame with the desired computation on the specified column subset."""
             exprs = []
             sql_prefix = "SQL>>>"
@@ -499,7 +499,7 @@ class BaseEstimator(Base):
             statement_params=telemetry.get_statement_params(PROJECT, SUBPROJECT, self.__class__.__name__),
         )

-        computed_dict:
+        computed_dict: dict[str, dict[str, Union[int, float, str]]] = {}
         for idx, val in enumerate(_results[0]):
             col_name = cols[idx // len(states)]
             if col_name not in computed_dict:
@@ -516,8 +516,8 @@ class BaseTransformer(BaseEstimator):
         self,
         *,
         drop_input_cols: Optional[bool] = False,
-        file_names: Optional[
-        custom_states: Optional[
+        file_names: Optional[list[str]] = None,
+        custom_states: Optional[list[str]] = None,
         sample_weight_col: Optional[str] = None,
     ) -> None:
         """Base class for all transformers."""
@@ -551,7 +551,7 @@ class BaseTransformer(BaseEstimator):
             ),
         )

-    def _infer_input_cols(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) ->
+    def _infer_input_cols(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> list[str]:
         """
         Infer input_cols from the dataset. Input column are all columns in the input dataset that are not
         designated as label, passthrough, or sample weight columns.
@@ -569,7 +569,7 @@ class BaseTransformer(BaseEstimator):
         ]
         return cols

-    def _infer_output_cols(self) ->
+    def _infer_output_cols(self) -> list[str]:
         """Infer output column names from based on the estimator.

         Returns:
@@ -624,7 +624,7 @@ class BaseTransformer(BaseEstimator):
         cols = self._infer_output_cols()
         self.set_output_cols(output_cols=cols)

-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[
+    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[list[str]] = None) -> list[str]:
         """Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
         Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.

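The hunks above give _process_cols the return annotation list[str] and show its first two branches (cols is None, type(cols) is list). For orientation, here is a hedged sketch of such a column-name normalizer; the branches beyond the two shown in the diff are assumptions rather than the package's actual code.

    # Sketch under stated assumptions: normalize None, a single name, or an iterable
    # of names into list[str]. Only the first two branches mirror the diff above.
    from typing import Iterable, Optional, Union

    def process_cols(cols: Optional[Union[str, Iterable[str]]]) -> list[str]:
        """Convert cols to a list."""
        col_list: list[str] = []
        if cols is None:
            return col_list
        elif type(cols) is list:
            col_list = cols
        elif isinstance(cols, str):
            col_list = [cols]      # assumed: a bare string becomes a one-element list
        else:
            col_list = list(cols)  # assumed: any other iterable is materialized
        return col_list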
@@ -2,7 +2,7 @@ import os

 from snowflake.ml._internal import init_utils

-pkg_dir = os.path.dirname(
+pkg_dir = os.path.dirname(__file__)
 pkg_name = __name__
 exportable_classes = init_utils.fetch_classes_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
 for k, v in exportable_classes.items():
|