snowflake-ml-python 1.3.1__py3-none-any.whl → 1.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +11 -1
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/feature_store/feature_store.py +166 -184
- snowflake/ml/feature_store/feature_view.py +12 -24
- snowflake/ml/fileset/sfcfs.py +56 -50
- snowflake/ml/fileset/stage_fs.py +48 -13
- snowflake/ml/model/_client/model/model_version_impl.py +6 -49
- snowflake/ml/model/_client/ops/model_ops.py +78 -29
- snowflake/ml/model/_client/sql/model.py +23 -2
- snowflake/ml/model/_client/sql/model_version.py +22 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -3
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +5 -2
- snowflake/ml/model/_model_composer/model_composer.py +7 -5
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +19 -54
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +8 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +13 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +1 -1
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +36 -6
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -2
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +545 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -5
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +195 -123
- snowflake/ml/modeling/cluster/affinity_propagation.py +195 -123
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +195 -123
- snowflake/ml/modeling/cluster/birch.py +195 -123
- snowflake/ml/modeling/cluster/bisecting_k_means.py +195 -123
- snowflake/ml/modeling/cluster/dbscan.py +195 -123
- snowflake/ml/modeling/cluster/feature_agglomeration.py +195 -123
- snowflake/ml/modeling/cluster/k_means.py +195 -123
- snowflake/ml/modeling/cluster/mean_shift.py +195 -123
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +195 -123
- snowflake/ml/modeling/cluster/optics.py +195 -123
- snowflake/ml/modeling/cluster/spectral_biclustering.py +195 -123
- snowflake/ml/modeling/cluster/spectral_clustering.py +195 -123
- snowflake/ml/modeling/cluster/spectral_coclustering.py +195 -123
- snowflake/ml/modeling/compose/column_transformer.py +195 -123
- snowflake/ml/modeling/compose/transformed_target_regressor.py +195 -123
- snowflake/ml/modeling/covariance/elliptic_envelope.py +195 -123
- snowflake/ml/modeling/covariance/empirical_covariance.py +195 -123
- snowflake/ml/modeling/covariance/graphical_lasso.py +195 -123
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +195 -123
- snowflake/ml/modeling/covariance/ledoit_wolf.py +195 -123
- snowflake/ml/modeling/covariance/min_cov_det.py +195 -123
- snowflake/ml/modeling/covariance/oas.py +195 -123
- snowflake/ml/modeling/covariance/shrunk_covariance.py +195 -123
- snowflake/ml/modeling/decomposition/dictionary_learning.py +195 -123
- snowflake/ml/modeling/decomposition/factor_analysis.py +195 -123
- snowflake/ml/modeling/decomposition/fast_ica.py +195 -123
- snowflake/ml/modeling/decomposition/incremental_pca.py +195 -123
- snowflake/ml/modeling/decomposition/kernel_pca.py +195 -123
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +195 -123
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +195 -123
- snowflake/ml/modeling/decomposition/pca.py +195 -123
- snowflake/ml/modeling/decomposition/sparse_pca.py +195 -123
- snowflake/ml/modeling/decomposition/truncated_svd.py +195 -123
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +195 -123
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +195 -123
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/bagging_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/bagging_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/isolation_forest.py +195 -123
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/stacking_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/voting_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/voting_regressor.py +195 -123
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +195 -123
- snowflake/ml/modeling/feature_selection/select_fdr.py +195 -123
- snowflake/ml/modeling/feature_selection/select_fpr.py +195 -123
- snowflake/ml/modeling/feature_selection/select_fwe.py +195 -123
- snowflake/ml/modeling/feature_selection/select_k_best.py +195 -123
- snowflake/ml/modeling/feature_selection/select_percentile.py +195 -123
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +195 -123
- snowflake/ml/modeling/feature_selection/variance_threshold.py +195 -123
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +24 -6
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +195 -123
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +195 -123
- snowflake/ml/modeling/impute/iterative_imputer.py +195 -123
- snowflake/ml/modeling/impute/knn_imputer.py +195 -123
- snowflake/ml/modeling/impute/missing_indicator.py +195 -123
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +195 -123
- snowflake/ml/modeling/kernel_approximation/nystroem.py +195 -123
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +195 -123
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +195 -123
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +195 -123
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +195 -123
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +198 -125
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +198 -125
- snowflake/ml/modeling/linear_model/ard_regression.py +195 -123
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +195 -123
- snowflake/ml/modeling/linear_model/elastic_net.py +195 -123
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +195 -123
- snowflake/ml/modeling/linear_model/gamma_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/huber_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/lars.py +195 -123
- snowflake/ml/modeling/linear_model/lars_cv.py +195 -123
- snowflake/ml/modeling/linear_model/lasso.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_cv.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_lars.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +195 -123
- snowflake/ml/modeling/linear_model/linear_regression.py +195 -123
- snowflake/ml/modeling/linear_model/logistic_regression.py +195 -123
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +195 -123
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +195 -123
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +195 -123
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/perceptron.py +195 -123
- snowflake/ml/modeling/linear_model/poisson_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/ransac_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/ridge.py +195 -123
- snowflake/ml/modeling/linear_model/ridge_classifier.py +195 -123
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +195 -123
- snowflake/ml/modeling/linear_model/ridge_cv.py +195 -123
- snowflake/ml/modeling/linear_model/sgd_classifier.py +195 -123
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +195 -123
- snowflake/ml/modeling/linear_model/sgd_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +195 -123
- snowflake/ml/modeling/manifold/isomap.py +195 -123
- snowflake/ml/modeling/manifold/mds.py +195 -123
- snowflake/ml/modeling/manifold/spectral_embedding.py +195 -123
- snowflake/ml/modeling/manifold/tsne.py +195 -123
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +195 -123
- snowflake/ml/modeling/mixture/gaussian_mixture.py +195 -123
- snowflake/ml/modeling/model_selection/grid_search_cv.py +42 -18
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +42 -18
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +195 -123
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +195 -123
- snowflake/ml/modeling/multiclass/output_code_classifier.py +195 -123
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/complement_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +195 -123
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +195 -123
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +195 -123
- snowflake/ml/modeling/neighbors/kernel_density.py +195 -123
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +195 -123
- snowflake/ml/modeling/neighbors/nearest_centroid.py +195 -123
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +195 -123
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +195 -123
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +195 -123
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +195 -123
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +195 -123
- snowflake/ml/modeling/neural_network/mlp_classifier.py +195 -123
- snowflake/ml/modeling/neural_network/mlp_regressor.py +195 -123
- snowflake/ml/modeling/pipeline/pipeline.py +4 -4
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +195 -123
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +195 -123
- snowflake/ml/modeling/semi_supervised/label_spreading.py +195 -123
- snowflake/ml/modeling/svm/linear_svc.py +195 -123
- snowflake/ml/modeling/svm/linear_svr.py +195 -123
- snowflake/ml/modeling/svm/nu_svc.py +195 -123
- snowflake/ml/modeling/svm/nu_svr.py +195 -123
- snowflake/ml/modeling/svm/svc.py +195 -123
- snowflake/ml/modeling/svm/svr.py +195 -123
- snowflake/ml/modeling/tree/decision_tree_classifier.py +195 -123
- snowflake/ml/modeling/tree/decision_tree_regressor.py +195 -123
- snowflake/ml/modeling/tree/extra_tree_classifier.py +195 -123
- snowflake/ml/modeling/tree/extra_tree_regressor.py +195 -123
- snowflake/ml/modeling/xgboost/xgb_classifier.py +195 -123
- snowflake/ml/modeling/xgboost/xgb_regressor.py +195 -123
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +195 -123
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +195 -123
- snowflake/ml/registry/_manager/model_manager.py +5 -1
- snowflake/ml/registry/model_registry.py +99 -26
- snowflake/ml/registry/registry.py +3 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/METADATA +94 -55
- {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/RECORD +218 -212
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ from snowflake.ml.model.model_signature import (
|
|
20
20
|
FeatureSpec,
|
21
21
|
ModelSignature,
|
22
22
|
_infer_signature,
|
23
|
+
_rename_signature_with_snowflake_identifiers,
|
23
24
|
)
|
24
25
|
from snowflake.ml.modeling._internal.estimator_utils import (
|
25
26
|
gather_dependencies,
|
@@ -330,7 +331,7 @@ class GridSearchCV(BaseTransformer):
|
|
330
331
|
)
|
331
332
|
self._sklearn_object = model_trainer.train()
|
332
333
|
self._is_fitted = True
|
333
|
-
self.
|
334
|
+
self._generate_model_signatures(dataset)
|
334
335
|
return self
|
335
336
|
|
336
337
|
def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> List[str]:
|
@@ -384,6 +385,9 @@ class GridSearchCV(BaseTransformer):
|
|
384
385
|
|
385
386
|
Returns:
|
386
387
|
Transformed dataset.
|
388
|
+
|
389
|
+
Raises:
|
390
|
+
SnowflakeMLException: when the output column(s) doesn't exist in the model signature, raise error
|
387
391
|
"""
|
388
392
|
super()._check_dataset_type(dataset)
|
389
393
|
|
@@ -396,9 +400,21 @@ class GridSearchCV(BaseTransformer):
|
|
396
400
|
expected_type_inferred = ""
|
397
401
|
# infer the datatype from label columns
|
398
402
|
if "predict" in self.model_signatures:
|
399
|
-
|
400
|
-
|
401
|
-
|
403
|
+
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
404
|
+
label_cols_signatures = [
|
405
|
+
row for row in self.model_signatures["predict"].outputs if row.name in self.output_cols
|
406
|
+
]
|
407
|
+
if len(label_cols_signatures) == 0:
|
408
|
+
error_str = (
|
409
|
+
f"Output columns {self.output_cols} do not match"
|
410
|
+
f"model signatures {self.model_signatures['predict'].outputs}."
|
411
|
+
)
|
412
|
+
raise exceptions.SnowflakeMLException(
|
413
|
+
error_code=error_codes.INVALID_ATTRIBUTE,
|
414
|
+
original_exception=ValueError(error_str),
|
415
|
+
)
|
416
|
+
|
417
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
402
418
|
self._deps = self._batch_inference_validate_snowpark(
|
403
419
|
dataset=dataset,
|
404
420
|
inference_method=inference_method,
|
@@ -785,12 +801,22 @@ class GridSearchCV(BaseTransformer):
|
|
785
801
|
|
786
802
|
return output_score
|
787
803
|
|
788
|
-
def
|
804
|
+
def to_sklearn(self) -> sklearn.model_selection.GridSearchCV:
|
805
|
+
"""
|
806
|
+
Get sklearn.model_selection.GridSearchCV object.
|
807
|
+
"""
|
808
|
+
assert self._sklearn_object is not None
|
809
|
+
return self._sklearn_object
|
810
|
+
|
811
|
+
def _get_dependencies(self) -> List[str]:
|
812
|
+
return self._deps
|
813
|
+
|
814
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
789
815
|
self._model_signature_dict = dict()
|
790
816
|
|
791
817
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
792
818
|
|
793
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
819
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
794
820
|
outputs: List[BaseFeatureSpec] = []
|
795
821
|
if hasattr(self, "predict"):
|
796
822
|
# keep mypy happy
|
@@ -798,18 +824,20 @@ class GridSearchCV(BaseTransformer):
|
|
798
824
|
# For classifier, the type of predict is the same as the type of label
|
799
825
|
if self._sklearn_object._estimator_type == "classifier":
|
800
826
|
# label columns is the desired type for output
|
801
|
-
outputs = list(_infer_signature(dataset[self.label_cols], "output"))
|
827
|
+
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
802
828
|
# rename the output columns
|
803
829
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
804
830
|
self._model_signature_dict["predict"] = ModelSignature(
|
805
831
|
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
806
832
|
)
|
833
|
+
|
807
834
|
# For regressor, the type of predict is float64
|
808
835
|
elif self._sklearn_object._estimator_type == "regressor":
|
809
836
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
810
837
|
self._model_signature_dict["predict"] = ModelSignature(
|
811
838
|
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
812
839
|
)
|
840
|
+
|
813
841
|
for prob_func in PROB_FUNCTIONS:
|
814
842
|
if hasattr(self, prob_func):
|
815
843
|
output_cols_prefix: str = f"{prob_func}_"
|
@@ -819,6 +847,12 @@ class GridSearchCV(BaseTransformer):
|
|
819
847
|
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
820
848
|
)
|
821
849
|
|
850
|
+
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
851
|
+
items = list(self._model_signature_dict.items())
|
852
|
+
for method, signature in items:
|
853
|
+
signature._outputs = _rename_signature_with_snowflake_identifiers(signature._outputs)
|
854
|
+
self._model_signature_dict[method] = signature
|
855
|
+
|
822
856
|
@property
|
823
857
|
def model_signatures(self) -> Dict[str, ModelSignature]:
|
824
858
|
"""Returns model signature of current class.
|
@@ -827,7 +861,7 @@ class GridSearchCV(BaseTransformer):
|
|
827
861
|
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
828
862
|
|
829
863
|
Returns:
|
830
|
-
|
864
|
+
each method and its input output signature
|
831
865
|
"""
|
832
866
|
if self._model_signature_dict is None:
|
833
867
|
raise exceptions.SnowflakeMLException(
|
@@ -835,13 +869,3 @@ class GridSearchCV(BaseTransformer):
|
|
835
869
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
836
870
|
)
|
837
871
|
return self._model_signature_dict
|
838
|
-
|
839
|
-
def to_sklearn(self) -> sklearn.model_selection.GridSearchCV:
|
840
|
-
"""
|
841
|
-
Get sklearn.model_selection.GridSearchCV object.
|
842
|
-
"""
|
843
|
-
assert self._sklearn_object is not None
|
844
|
-
return self._sklearn_object
|
845
|
-
|
846
|
-
def _get_dependencies(self) -> List[str]:
|
847
|
-
return self._deps
|
@@ -17,6 +17,7 @@ from snowflake.ml.model.model_signature import (
|
|
17
17
|
FeatureSpec,
|
18
18
|
ModelSignature,
|
19
19
|
_infer_signature,
|
20
|
+
_rename_signature_with_snowflake_identifiers,
|
20
21
|
)
|
21
22
|
from snowflake.ml.modeling._internal.estimator_utils import (
|
22
23
|
gather_dependencies,
|
@@ -343,7 +344,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
343
344
|
)
|
344
345
|
self._sklearn_object = model_trainer.train()
|
345
346
|
self._is_fitted = True
|
346
|
-
self.
|
347
|
+
self._generate_model_signatures(dataset)
|
347
348
|
return self
|
348
349
|
|
349
350
|
def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> List[str]:
|
@@ -383,6 +384,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
383
384
|
|
384
385
|
Returns:
|
385
386
|
Transformed dataset.
|
387
|
+
|
388
|
+
Raises:
|
389
|
+
SnowflakeMLException: when the output column(s) doesn't exist in the model signature, raise error
|
386
390
|
"""
|
387
391
|
super()._check_dataset_type(dataset)
|
388
392
|
|
@@ -395,9 +399,21 @@ class RandomizedSearchCV(BaseTransformer):
|
|
395
399
|
expected_type_inferred = ""
|
396
400
|
# infer the datatype from label columns
|
397
401
|
if "predict" in self.model_signatures:
|
398
|
-
|
399
|
-
|
400
|
-
|
402
|
+
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
403
|
+
label_cols_signatures = [
|
404
|
+
row for row in self.model_signatures["predict"].outputs if row.name in self.output_cols
|
405
|
+
]
|
406
|
+
if len(label_cols_signatures) == 0:
|
407
|
+
error_str = (
|
408
|
+
f"Output columns {self.output_cols} do not match"
|
409
|
+
f"model signatures {self.model_signatures['predict'].outputs}."
|
410
|
+
)
|
411
|
+
raise exceptions.SnowflakeMLException(
|
412
|
+
error_code=error_codes.INVALID_ATTRIBUTE,
|
413
|
+
original_exception=ValueError(error_str),
|
414
|
+
)
|
415
|
+
|
416
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
401
417
|
self._deps = self._batch_inference_validate_snowpark(
|
402
418
|
dataset=dataset,
|
403
419
|
inference_method=inference_method,
|
@@ -780,12 +796,22 @@ class RandomizedSearchCV(BaseTransformer):
|
|
780
796
|
|
781
797
|
return output_score
|
782
798
|
|
783
|
-
def
|
799
|
+
def to_sklearn(self) -> sklearn.model_selection.RandomizedSearchCV:
|
800
|
+
"""
|
801
|
+
Get sklearn.model_selection.RandomizedSearchCV object.
|
802
|
+
"""
|
803
|
+
assert self._sklearn_object is not None
|
804
|
+
return self._sklearn_object
|
805
|
+
|
806
|
+
def _get_dependencies(self) -> List[str]:
|
807
|
+
return self._deps
|
808
|
+
|
809
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
784
810
|
self._model_signature_dict = dict()
|
785
811
|
|
786
812
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
787
813
|
|
788
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
814
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
789
815
|
outputs: List[BaseFeatureSpec] = []
|
790
816
|
if hasattr(self, "predict"):
|
791
817
|
# keep mypy happy
|
@@ -793,18 +819,20 @@ class RandomizedSearchCV(BaseTransformer):
|
|
793
819
|
# For classifier, the type of predict is the same as the type of label
|
794
820
|
if self._sklearn_object._estimator_type == "classifier":
|
795
821
|
# label columns is the desired type for output
|
796
|
-
outputs = list(_infer_signature(dataset[self.label_cols], "output"))
|
822
|
+
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
797
823
|
# rename the output columns
|
798
824
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
799
825
|
self._model_signature_dict["predict"] = ModelSignature(
|
800
826
|
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
801
827
|
)
|
828
|
+
|
802
829
|
# For regressor, the type of predict is float64
|
803
830
|
elif self._sklearn_object._estimator_type == "regressor":
|
804
831
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
805
832
|
self._model_signature_dict["predict"] = ModelSignature(
|
806
833
|
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
807
834
|
)
|
835
|
+
|
808
836
|
for prob_func in PROB_FUNCTIONS:
|
809
837
|
if hasattr(self, prob_func):
|
810
838
|
output_cols_prefix: str = f"{prob_func}_"
|
@@ -814,6 +842,12 @@ class RandomizedSearchCV(BaseTransformer):
|
|
814
842
|
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
815
843
|
)
|
816
844
|
|
845
|
+
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
846
|
+
items = list(self._model_signature_dict.items())
|
847
|
+
for method, signature in items:
|
848
|
+
signature._outputs = _rename_signature_with_snowflake_identifiers(signature._outputs)
|
849
|
+
self._model_signature_dict[method] = signature
|
850
|
+
|
817
851
|
@property
|
818
852
|
def model_signatures(self) -> Dict[str, ModelSignature]:
|
819
853
|
"""Returns model signature of current class.
|
@@ -822,7 +856,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
822
856
|
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
823
857
|
|
824
858
|
Returns:
|
825
|
-
|
859
|
+
each method and its input output signature
|
826
860
|
"""
|
827
861
|
if self._model_signature_dict is None:
|
828
862
|
raise exceptions.SnowflakeMLException(
|
@@ -830,13 +864,3 @@ class RandomizedSearchCV(BaseTransformer):
|
|
830
864
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
831
865
|
)
|
832
866
|
return self._model_signature_dict
|
833
|
-
|
834
|
-
def to_sklearn(self) -> sklearn.model_selection.RandomizedSearchCV:
|
835
|
-
"""
|
836
|
-
Get sklearn.model_selection.RandomizedSearchCV object.
|
837
|
-
"""
|
838
|
-
assert self._sklearn_object is not None
|
839
|
-
return self._sklearn_object
|
840
|
-
|
841
|
-
def _get_dependencies(self) -> List[str]:
|
842
|
-
return self._deps
|