snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ from snowflake.ml.model.model_signature import (
|
|
20
20
|
FeatureSpec,
|
21
21
|
ModelSignature,
|
22
22
|
_infer_signature,
|
23
|
+
_rename_signature_with_snowflake_identifiers,
|
23
24
|
)
|
24
25
|
from snowflake.ml.modeling._internal.estimator_utils import (
|
25
26
|
gather_dependencies,
|
@@ -330,12 +331,15 @@ class GridSearchCV(BaseTransformer):
|
|
330
331
|
)
|
331
332
|
self._sklearn_object = model_trainer.train()
|
332
333
|
self._is_fitted = True
|
333
|
-
self.
|
334
|
+
self._generate_model_signatures(dataset)
|
334
335
|
return self
|
335
336
|
|
336
|
-
def _batch_inference_validate_snowpark(
|
337
|
-
|
338
|
-
|
337
|
+
def _batch_inference_validate_snowpark(
|
338
|
+
self,
|
339
|
+
dataset: DataFrame,
|
340
|
+
inference_method: str,
|
341
|
+
) -> None:
|
342
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
339
343
|
|
340
344
|
Args:
|
341
345
|
dataset: snowpark dataframe
|
@@ -345,8 +349,6 @@ class GridSearchCV(BaseTransformer):
|
|
345
349
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
346
350
|
SnowflakeMLException: If the session is None, raise error
|
347
351
|
|
348
|
-
Returns:
|
349
|
-
A list of available package that exists in the snowflake anaconda channel
|
350
352
|
"""
|
351
353
|
if not self._is_fitted:
|
352
354
|
raise exceptions.SnowflakeMLException(
|
@@ -362,10 +364,6 @@ class GridSearchCV(BaseTransformer):
|
|
362
364
|
error_code=error_codes.NOT_FOUND,
|
363
365
|
original_exception=ValueError("Session must not specified for snowpark dataset."),
|
364
366
|
)
|
365
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
366
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
367
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT
|
368
|
-
)
|
369
367
|
|
370
368
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
371
369
|
@telemetry.send_api_usage_telemetry(
|
@@ -384,6 +382,9 @@ class GridSearchCV(BaseTransformer):
|
|
384
382
|
|
385
383
|
Returns:
|
386
384
|
Transformed dataset.
|
385
|
+
|
386
|
+
Raises:
|
387
|
+
SnowflakeMLException: when the output column(s) doesn't exist in the model signature, raise error
|
387
388
|
"""
|
388
389
|
super()._check_dataset_type(dataset)
|
389
390
|
|
@@ -396,13 +397,23 @@ class GridSearchCV(BaseTransformer):
|
|
396
397
|
expected_type_inferred = ""
|
397
398
|
# infer the datatype from label columns
|
398
399
|
if "predict" in self.model_signatures:
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
400
|
+
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
401
|
+
label_cols_signatures = [
|
402
|
+
row for row in self.model_signatures["predict"].outputs if row.name in self.output_cols
|
403
|
+
]
|
404
|
+
if len(label_cols_signatures) == 0:
|
405
|
+
error_str = (
|
406
|
+
f"Output columns {self.output_cols} do not match"
|
407
|
+
f"model signatures {self.model_signatures['predict'].outputs}."
|
408
|
+
)
|
409
|
+
raise exceptions.SnowflakeMLException(
|
410
|
+
error_code=error_codes.INVALID_ATTRIBUTE,
|
411
|
+
original_exception=ValueError(error_str),
|
412
|
+
)
|
413
|
+
|
414
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
415
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
416
|
+
self._deps = self._get_dependencies()
|
406
417
|
|
407
418
|
assert isinstance(
|
408
419
|
dataset._session, Session
|
@@ -460,7 +471,8 @@ class GridSearchCV(BaseTransformer):
|
|
460
471
|
inference_method = "transform"
|
461
472
|
|
462
473
|
if isinstance(dataset, DataFrame):
|
463
|
-
self.
|
474
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
475
|
+
self._deps = self._get_dependencies()
|
464
476
|
assert isinstance(
|
465
477
|
dataset._session, Session
|
466
478
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -519,7 +531,8 @@ class GridSearchCV(BaseTransformer):
|
|
519
531
|
inference_method = "predict_proba"
|
520
532
|
|
521
533
|
if isinstance(dataset, DataFrame):
|
522
|
-
self.
|
534
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
535
|
+
self._deps = self._get_dependencies()
|
523
536
|
assert isinstance(
|
524
537
|
dataset._session, Session
|
525
538
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -579,7 +592,8 @@ class GridSearchCV(BaseTransformer):
|
|
579
592
|
inference_method = "predict_log_proba"
|
580
593
|
|
581
594
|
if isinstance(dataset, DataFrame):
|
582
|
-
self.
|
595
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
596
|
+
self._deps = self._get_dependencies()
|
583
597
|
assert isinstance(
|
584
598
|
dataset._session, Session
|
585
599
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -639,7 +653,8 @@ class GridSearchCV(BaseTransformer):
|
|
639
653
|
inference_method = "decision_function"
|
640
654
|
|
641
655
|
if isinstance(dataset, DataFrame):
|
642
|
-
self.
|
656
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
657
|
+
self._deps = self._get_dependencies()
|
643
658
|
assert isinstance(
|
644
659
|
dataset._session, Session
|
645
660
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -700,7 +715,8 @@ class GridSearchCV(BaseTransformer):
|
|
700
715
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
701
716
|
|
702
717
|
if isinstance(dataset, DataFrame):
|
703
|
-
self.
|
718
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
719
|
+
self._deps = self._get_dependencies()
|
704
720
|
assert isinstance(
|
705
721
|
dataset._session, Session
|
706
722
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -751,17 +767,15 @@ class GridSearchCV(BaseTransformer):
|
|
751
767
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
752
768
|
|
753
769
|
if isinstance(dataset, DataFrame):
|
754
|
-
self.
|
755
|
-
|
756
|
-
inference_method="score",
|
757
|
-
)
|
770
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
771
|
+
self._deps = self._get_dependencies()
|
758
772
|
selected_cols = self._get_active_columns()
|
759
773
|
if len(selected_cols) > 0:
|
760
774
|
dataset = dataset.select(selected_cols)
|
761
775
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
762
776
|
transform_kwargs = dict(
|
763
777
|
session=dataset._session,
|
764
|
-
dependencies=
|
778
|
+
dependencies=self._deps,
|
765
779
|
score_sproc_imports=["sklearn"],
|
766
780
|
)
|
767
781
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -785,12 +799,22 @@ class GridSearchCV(BaseTransformer):
|
|
785
799
|
|
786
800
|
return output_score
|
787
801
|
|
788
|
-
def
|
802
|
+
def to_sklearn(self) -> sklearn.model_selection.GridSearchCV:
|
803
|
+
"""
|
804
|
+
Get sklearn.model_selection.GridSearchCV object.
|
805
|
+
"""
|
806
|
+
assert self._sklearn_object is not None
|
807
|
+
return self._sklearn_object
|
808
|
+
|
809
|
+
def _get_dependencies(self) -> List[str]:
|
810
|
+
return self._deps
|
811
|
+
|
812
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
789
813
|
self._model_signature_dict = dict()
|
790
814
|
|
791
815
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
792
816
|
|
793
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
817
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
794
818
|
outputs: List[BaseFeatureSpec] = []
|
795
819
|
if hasattr(self, "predict"):
|
796
820
|
# keep mypy happy
|
@@ -798,18 +822,20 @@ class GridSearchCV(BaseTransformer):
|
|
798
822
|
# For classifier, the type of predict is the same as the type of label
|
799
823
|
if self._sklearn_object._estimator_type == "classifier":
|
800
824
|
# label columns is the desired type for output
|
801
|
-
outputs = list(_infer_signature(dataset[self.label_cols], "output"))
|
825
|
+
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
802
826
|
# rename the output columns
|
803
827
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
804
828
|
self._model_signature_dict["predict"] = ModelSignature(
|
805
829
|
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
806
830
|
)
|
831
|
+
|
807
832
|
# For regressor, the type of predict is float64
|
808
833
|
elif self._sklearn_object._estimator_type == "regressor":
|
809
834
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
810
835
|
self._model_signature_dict["predict"] = ModelSignature(
|
811
836
|
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
812
837
|
)
|
838
|
+
|
813
839
|
for prob_func in PROB_FUNCTIONS:
|
814
840
|
if hasattr(self, prob_func):
|
815
841
|
output_cols_prefix: str = f"{prob_func}_"
|
@@ -819,6 +845,12 @@ class GridSearchCV(BaseTransformer):
|
|
819
845
|
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
820
846
|
)
|
821
847
|
|
848
|
+
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
849
|
+
items = list(self._model_signature_dict.items())
|
850
|
+
for method, signature in items:
|
851
|
+
signature._outputs = _rename_signature_with_snowflake_identifiers(signature._outputs)
|
852
|
+
self._model_signature_dict[method] = signature
|
853
|
+
|
822
854
|
@property
|
823
855
|
def model_signatures(self) -> Dict[str, ModelSignature]:
|
824
856
|
"""Returns model signature of current class.
|
@@ -827,7 +859,7 @@ class GridSearchCV(BaseTransformer):
|
|
827
859
|
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
828
860
|
|
829
861
|
Returns:
|
830
|
-
|
862
|
+
each method and its input output signature
|
831
863
|
"""
|
832
864
|
if self._model_signature_dict is None:
|
833
865
|
raise exceptions.SnowflakeMLException(
|
@@ -835,13 +867,3 @@ class GridSearchCV(BaseTransformer):
|
|
835
867
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
836
868
|
)
|
837
869
|
return self._model_signature_dict
|
838
|
-
|
839
|
-
def to_sklearn(self) -> sklearn.model_selection.GridSearchCV:
|
840
|
-
"""
|
841
|
-
Get sklearn.model_selection.GridSearchCV object.
|
842
|
-
"""
|
843
|
-
assert self._sklearn_object is not None
|
844
|
-
return self._sklearn_object
|
845
|
-
|
846
|
-
def _get_dependencies(self) -> List[str]:
|
847
|
-
return self._deps
|
@@ -17,6 +17,7 @@ from snowflake.ml.model.model_signature import (
|
|
17
17
|
FeatureSpec,
|
18
18
|
ModelSignature,
|
19
19
|
_infer_signature,
|
20
|
+
_rename_signature_with_snowflake_identifiers,
|
20
21
|
)
|
21
22
|
from snowflake.ml.modeling._internal.estimator_utils import (
|
22
23
|
gather_dependencies,
|
@@ -343,11 +344,25 @@ class RandomizedSearchCV(BaseTransformer):
|
|
343
344
|
)
|
344
345
|
self._sklearn_object = model_trainer.train()
|
345
346
|
self._is_fitted = True
|
346
|
-
self.
|
347
|
+
self._generate_model_signatures(dataset)
|
347
348
|
return self
|
348
349
|
|
349
|
-
def _batch_inference_validate_snowpark(
|
350
|
-
|
350
|
+
def _batch_inference_validate_snowpark(
|
351
|
+
self,
|
352
|
+
dataset: DataFrame,
|
353
|
+
inference_method: str,
|
354
|
+
) -> None:
|
355
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
356
|
+
|
357
|
+
Args:
|
358
|
+
dataset: snowpark dataframe
|
359
|
+
inference_method: the inference method such as predict, score...
|
360
|
+
|
361
|
+
Raises:
|
362
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
363
|
+
SnowflakeMLException: If the session is None, raise error
|
364
|
+
|
365
|
+
"""
|
351
366
|
if not self._is_fitted:
|
352
367
|
raise exceptions.SnowflakeMLException(
|
353
368
|
error_code=error_codes.METHOD_NOT_ALLOWED,
|
@@ -362,10 +377,6 @@ class RandomizedSearchCV(BaseTransformer):
|
|
362
377
|
error_code=error_codes.NOT_FOUND,
|
363
378
|
original_exception=ValueError("Session must not specified for snowpark dataset."),
|
364
379
|
)
|
365
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
366
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
367
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT
|
368
|
-
)
|
369
380
|
|
370
381
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
371
382
|
@telemetry.send_api_usage_telemetry(
|
@@ -383,6 +394,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
383
394
|
|
384
395
|
Returns:
|
385
396
|
Transformed dataset.
|
397
|
+
|
398
|
+
Raises:
|
399
|
+
SnowflakeMLException: when the output column(s) doesn't exist in the model signature, raise error
|
386
400
|
"""
|
387
401
|
super()._check_dataset_type(dataset)
|
388
402
|
|
@@ -395,13 +409,24 @@ class RandomizedSearchCV(BaseTransformer):
|
|
395
409
|
expected_type_inferred = ""
|
396
410
|
# infer the datatype from label columns
|
397
411
|
if "predict" in self.model_signatures:
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
412
|
+
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
413
|
+
label_cols_signatures = [
|
414
|
+
row for row in self.model_signatures["predict"].outputs if row.name in self.output_cols
|
415
|
+
]
|
416
|
+
if len(label_cols_signatures) == 0:
|
417
|
+
error_str = (
|
418
|
+
f"Output columns {self.output_cols} do not match"
|
419
|
+
f"model signatures {self.model_signatures['predict'].outputs}."
|
420
|
+
)
|
421
|
+
raise exceptions.SnowflakeMLException(
|
422
|
+
error_code=error_codes.INVALID_ATTRIBUTE,
|
423
|
+
original_exception=ValueError(error_str),
|
424
|
+
)
|
425
|
+
|
426
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
427
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
428
|
+
self._deps = self._get_dependencies()
|
429
|
+
|
405
430
|
assert isinstance(
|
406
431
|
dataset._session, Session
|
407
432
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -457,7 +482,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
457
482
|
inference_method = "transform"
|
458
483
|
|
459
484
|
if isinstance(dataset, DataFrame):
|
460
|
-
self.
|
485
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
486
|
+
self._deps = self._get_dependencies()
|
487
|
+
|
461
488
|
assert isinstance(
|
462
489
|
dataset._session, Session
|
463
490
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -515,7 +542,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
515
542
|
inference_method = "predict_proba"
|
516
543
|
|
517
544
|
if isinstance(dataset, DataFrame):
|
518
|
-
self.
|
545
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
546
|
+
self._deps = self._get_dependencies()
|
547
|
+
|
519
548
|
assert isinstance(
|
520
549
|
dataset._session, Session
|
521
550
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -575,7 +604,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
575
604
|
inference_method = "predict_log_proba"
|
576
605
|
|
577
606
|
if isinstance(dataset, DataFrame):
|
578
|
-
self.
|
607
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
608
|
+
self._deps = self._get_dependencies()
|
609
|
+
|
579
610
|
assert isinstance(
|
580
611
|
dataset._session, Session
|
581
612
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -634,7 +665,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
634
665
|
inference_method = "decision_function"
|
635
666
|
|
636
667
|
if isinstance(dataset, DataFrame):
|
637
|
-
self.
|
668
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
669
|
+
self._deps = self._get_dependencies()
|
670
|
+
|
638
671
|
assert isinstance(
|
639
672
|
dataset._session, Session
|
640
673
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -695,7 +728,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
695
728
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
696
729
|
|
697
730
|
if isinstance(dataset, DataFrame):
|
698
|
-
self.
|
731
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
732
|
+
self._deps = self._get_dependencies()
|
733
|
+
|
699
734
|
assert isinstance(
|
700
735
|
dataset._session, Session
|
701
736
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -745,10 +780,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
745
780
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
746
781
|
|
747
782
|
if isinstance(dataset, DataFrame):
|
748
|
-
self.
|
749
|
-
|
750
|
-
|
751
|
-
)
|
783
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
784
|
+
self._deps = self._get_dependencies()
|
785
|
+
|
752
786
|
selected_cols = self._get_active_columns()
|
753
787
|
if len(selected_cols) > 0:
|
754
788
|
dataset = dataset.select(selected_cols)
|
@@ -756,7 +790,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
756
790
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
757
791
|
transform_kwargs = dict(
|
758
792
|
session=dataset._session,
|
759
|
-
dependencies=
|
793
|
+
dependencies=self._deps,
|
760
794
|
score_sproc_imports=["sklearn"],
|
761
795
|
)
|
762
796
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -780,12 +814,22 @@ class RandomizedSearchCV(BaseTransformer):
|
|
780
814
|
|
781
815
|
return output_score
|
782
816
|
|
783
|
-
def
|
817
|
+
def to_sklearn(self) -> sklearn.model_selection.RandomizedSearchCV:
|
818
|
+
"""
|
819
|
+
Get sklearn.model_selection.RandomizedSearchCV object.
|
820
|
+
"""
|
821
|
+
assert self._sklearn_object is not None
|
822
|
+
return self._sklearn_object
|
823
|
+
|
824
|
+
def _get_dependencies(self) -> List[str]:
|
825
|
+
return self._deps
|
826
|
+
|
827
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
784
828
|
self._model_signature_dict = dict()
|
785
829
|
|
786
830
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
787
831
|
|
788
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
832
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
789
833
|
outputs: List[BaseFeatureSpec] = []
|
790
834
|
if hasattr(self, "predict"):
|
791
835
|
# keep mypy happy
|
@@ -793,18 +837,20 @@ class RandomizedSearchCV(BaseTransformer):
|
|
793
837
|
# For classifier, the type of predict is the same as the type of label
|
794
838
|
if self._sklearn_object._estimator_type == "classifier":
|
795
839
|
# label columns is the desired type for output
|
796
|
-
outputs = list(_infer_signature(dataset[self.label_cols], "output"))
|
840
|
+
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
797
841
|
# rename the output columns
|
798
842
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
799
843
|
self._model_signature_dict["predict"] = ModelSignature(
|
800
844
|
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
801
845
|
)
|
846
|
+
|
802
847
|
# For regressor, the type of predict is float64
|
803
848
|
elif self._sklearn_object._estimator_type == "regressor":
|
804
849
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
805
850
|
self._model_signature_dict["predict"] = ModelSignature(
|
806
851
|
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
807
852
|
)
|
853
|
+
|
808
854
|
for prob_func in PROB_FUNCTIONS:
|
809
855
|
if hasattr(self, prob_func):
|
810
856
|
output_cols_prefix: str = f"{prob_func}_"
|
@@ -814,6 +860,12 @@ class RandomizedSearchCV(BaseTransformer):
|
|
814
860
|
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
815
861
|
)
|
816
862
|
|
863
|
+
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
864
|
+
items = list(self._model_signature_dict.items())
|
865
|
+
for method, signature in items:
|
866
|
+
signature._outputs = _rename_signature_with_snowflake_identifiers(signature._outputs)
|
867
|
+
self._model_signature_dict[method] = signature
|
868
|
+
|
817
869
|
@property
|
818
870
|
def model_signatures(self) -> Dict[str, ModelSignature]:
|
819
871
|
"""Returns model signature of current class.
|
@@ -822,7 +874,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
822
874
|
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
823
875
|
|
824
876
|
Returns:
|
825
|
-
|
877
|
+
each method and its input output signature
|
826
878
|
"""
|
827
879
|
if self._model_signature_dict is None:
|
828
880
|
raise exceptions.SnowflakeMLException(
|
@@ -830,13 +882,3 @@ class RandomizedSearchCV(BaseTransformer):
|
|
830
882
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
831
883
|
)
|
832
884
|
return self._model_signature_dict
|
833
|
-
|
834
|
-
def to_sklearn(self) -> sklearn.model_selection.RandomizedSearchCV:
|
835
|
-
"""
|
836
|
-
Get sklearn.model_selection.RandomizedSearchCV object.
|
837
|
-
"""
|
838
|
-
assert self._sklearn_object is not None
|
839
|
-
return self._sklearn_object
|
840
|
-
|
841
|
-
def _get_dependencies(self) -> List[str]:
|
842
|
-
return self._deps
|