snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +156 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +88 -9
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +100 -41
- snowflake/ml/feature_store/feature_view.py +149 -5
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +186 -20
- snowflake/ml/model/_client/ops/model_ops.py +144 -30
- snowflake/ml/model/_client/ops/service_ops.py +312 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +196 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +5 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
- snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
- snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
- snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
- snowflake/ml/model/_packager/model_packager.py +4 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +10 -4
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +20 -2
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,6 @@
|
|
1
1
|
#!/usr/bin/env python3
|
2
2
|
import inspect
|
3
3
|
from abc import abstractmethod
|
4
|
-
from collections import defaultdict
|
5
4
|
from datetime import datetime
|
6
5
|
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union, overload
|
7
6
|
|
@@ -18,6 +17,7 @@ from snowflake.ml._internal.exceptions import (
|
|
18
17
|
)
|
19
18
|
from snowflake.ml._internal.lineage import lineage_utils
|
20
19
|
from snowflake.ml._internal.utils import identifier, parallelize
|
20
|
+
from snowflake.ml.data import data_source
|
21
21
|
from snowflake.ml.modeling.framework import _utils
|
22
22
|
from snowflake.snowpark import functions as F
|
23
23
|
|
@@ -246,7 +246,7 @@ class Base:
|
|
246
246
|
|
247
247
|
def get_params(self, deep: bool = True) -> Dict[str, Any]:
|
248
248
|
"""
|
249
|
-
Get parameters for this transformer.
|
249
|
+
Get the snowflake-ml parameters for this transformer.
|
250
250
|
|
251
251
|
Args:
|
252
252
|
deep: If True, will return the parameters for this transformer and
|
@@ -265,13 +265,13 @@ class Base:
|
|
265
265
|
out[key] = value
|
266
266
|
return out
|
267
267
|
|
268
|
-
def set_params(self, **params:
|
268
|
+
def set_params(self, **params: Any) -> None:
|
269
269
|
"""
|
270
270
|
Set the parameters of this transformer.
|
271
271
|
|
272
|
-
The method works on simple transformers as well as on nested
|
273
|
-
|
274
|
-
so that it's possible to update each component of a nested object.
|
272
|
+
The method works on simple transformers as well as on sklearn compatible pipelines with nested
|
273
|
+
objects, once the transformer has been fit. Nested objects have parameters of the form
|
274
|
+
``<component>__<parameter>`` so that it's possible to update each component of a nested object.
|
275
275
|
|
276
276
|
Args:
|
277
277
|
**params: Transformer parameter names mapped to their values.
|
@@ -283,12 +283,28 @@ class Base:
|
|
283
283
|
# simple optimization to gain speed (inspect is slow)
|
284
284
|
return
|
285
285
|
valid_params = self.get_params(deep=True)
|
286
|
+
valid_skl_params = {}
|
287
|
+
if hasattr(self, "_sklearn_object") and self._sklearn_object is not None:
|
288
|
+
valid_skl_params = self._sklearn_object.get_params()
|
286
289
|
|
287
|
-
nested_params: Dict[str, Any] = defaultdict(dict) # grouped by prefix
|
288
290
|
for key, value in params.items():
|
289
|
-
|
290
|
-
|
291
|
-
|
291
|
+
if valid_params.get("steps"):
|
292
|
+
# Recurse through pipeline steps
|
293
|
+
key, _, sub_key = key.partition("__")
|
294
|
+
for name, nested_object in valid_params["steps"]:
|
295
|
+
if name == key:
|
296
|
+
nested_object.set_params(**{sub_key: value})
|
297
|
+
|
298
|
+
elif key in valid_params:
|
299
|
+
setattr(self, key, value)
|
300
|
+
valid_params[key] = value
|
301
|
+
elif key in valid_skl_params:
|
302
|
+
# This dictionary would be empty if the following assert were not true, as specified above.
|
303
|
+
assert hasattr(self, "_sklearn_object") and self._sklearn_object is not None
|
304
|
+
setattr(self._sklearn_object, key, value)
|
305
|
+
valid_skl_params[key] = value
|
306
|
+
else:
|
307
|
+
local_valid_params = self._get_param_names() + list(valid_skl_params.keys())
|
292
308
|
raise exceptions.SnowflakeMLException(
|
293
309
|
error_code=error_codes.INVALID_ARGUMENT,
|
294
310
|
original_exception=ValueError(
|
@@ -298,15 +314,6 @@ class Base:
|
|
298
314
|
),
|
299
315
|
)
|
300
316
|
|
301
|
-
if delim:
|
302
|
-
nested_params[key][sub_key] = value
|
303
|
-
else:
|
304
|
-
setattr(self, key, value)
|
305
|
-
valid_params[key] = value
|
306
|
-
|
307
|
-
for key, sub_params in nested_params.items():
|
308
|
-
valid_params[key].set_params(**sub_params)
|
309
|
-
|
310
317
|
def get_sklearn_args(
|
311
318
|
self,
|
312
319
|
default_sklearn_obj: Optional[object] = None,
|
@@ -427,6 +434,8 @@ class BaseEstimator(Base):
|
|
427
434
|
def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> "BaseEstimator":
|
428
435
|
"""Runs universal logics for all fit implementations."""
|
429
436
|
data_sources = lineage_utils.get_data_sources(dataset)
|
437
|
+
if not data_sources and isinstance(dataset, snowpark.DataFrame):
|
438
|
+
data_sources = [data_source.DataFrameInfo(dataset.queries["queries"][-1])]
|
430
439
|
lineage_utils.set_data_sources(self, data_sources)
|
431
440
|
return self._fit(dataset)
|
432
441
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -564,12 +561,23 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
564
561
|
autogenerated=self._autogenerated,
|
565
562
|
subproject=_SUBPROJECT,
|
566
563
|
)
|
567
|
-
|
568
|
-
|
569
|
-
expected_output_cols_list=(
|
570
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
571
|
-
),
|
564
|
+
expected_output_cols = (
|
565
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
572
566
|
)
|
567
|
+
if isinstance(dataset, DataFrame):
|
568
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
569
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
570
|
+
)
|
571
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
572
|
+
drop_input_cols=self._drop_input_cols,
|
573
|
+
expected_output_cols_list=expected_output_cols,
|
574
|
+
example_output_pd_df=example_output_pd_df,
|
575
|
+
)
|
576
|
+
else:
|
577
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
578
|
+
drop_input_cols=self._drop_input_cols,
|
579
|
+
expected_output_cols_list=expected_output_cols,
|
580
|
+
)
|
573
581
|
self._sklearn_object = fitted_estimator
|
574
582
|
self._is_fitted = True
|
575
583
|
return output_result
|
@@ -648,12 +656,41 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
648
656
|
|
649
657
|
return rv
|
650
658
|
|
651
|
-
def
|
652
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
653
|
-
) -> List[str]:
|
659
|
+
def _align_expected_output(
|
660
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
661
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
662
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
663
|
+
and output dataframe with 1 line.
|
664
|
+
If the method is fit_predict, run 2 lines of data.
|
665
|
+
"""
|
654
666
|
# in case the inferred output column names dimension is different
|
655
667
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
656
|
-
|
668
|
+
|
669
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
670
|
+
# so change the minimum of number of rows to 2
|
671
|
+
num_examples = 2
|
672
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
673
|
+
project=_PROJECT,
|
674
|
+
subproject=_SUBPROJECT,
|
675
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
676
|
+
inspect.currentframe(), GaussianProcessClassifier.__class__.__name__
|
677
|
+
),
|
678
|
+
api_calls=[Session.call],
|
679
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
680
|
+
)
|
681
|
+
if output_cols_prefix == "fit_predict_":
|
682
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
683
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
684
|
+
num_examples = self._sklearn_object.n_clusters
|
685
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
686
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
687
|
+
num_examples = self._sklearn_object.min_samples
|
688
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
689
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
690
|
+
num_examples = self._sklearn_object.n_neighbors
|
691
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
692
|
+
else:
|
693
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
657
694
|
|
658
695
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
659
696
|
# seen during the fit.
|
@@ -665,12 +702,14 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
665
702
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
666
703
|
if self.sample_weight_col:
|
667
704
|
output_df_columns_set -= set(self.sample_weight_col)
|
705
|
+
|
668
706
|
# if the dimension of inferred output column names is correct; use it
|
669
707
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
670
|
-
return expected_output_cols_list
|
708
|
+
return expected_output_cols_list, output_df_pd
|
671
709
|
# otherwise, use the sklearn estimator's output
|
672
710
|
else:
|
673
|
-
|
711
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
712
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
674
713
|
|
675
714
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
676
715
|
@telemetry.send_api_usage_telemetry(
|
@@ -718,7 +757,7 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
718
757
|
drop_input_cols=self._drop_input_cols,
|
719
758
|
expected_output_cols_type="float",
|
720
759
|
)
|
721
|
-
expected_output_cols = self.
|
760
|
+
expected_output_cols, _ = self._align_expected_output(
|
722
761
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
723
762
|
)
|
724
763
|
|
@@ -786,7 +825,7 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
786
825
|
drop_input_cols=self._drop_input_cols,
|
787
826
|
expected_output_cols_type="float",
|
788
827
|
)
|
789
|
-
expected_output_cols = self.
|
828
|
+
expected_output_cols, _ = self._align_expected_output(
|
790
829
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
791
830
|
)
|
792
831
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -849,7 +888,7 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
849
888
|
drop_input_cols=self._drop_input_cols,
|
850
889
|
expected_output_cols_type="float",
|
851
890
|
)
|
852
|
-
expected_output_cols = self.
|
891
|
+
expected_output_cols, _ = self._align_expected_output(
|
853
892
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
854
893
|
)
|
855
894
|
|
@@ -914,7 +953,7 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
914
953
|
drop_input_cols = self._drop_input_cols,
|
915
954
|
expected_output_cols_type="float",
|
916
955
|
)
|
917
|
-
expected_output_cols = self.
|
956
|
+
expected_output_cols, _ = self._align_expected_output(
|
918
957
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
919
958
|
)
|
920
959
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -555,12 +552,23 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
555
552
|
autogenerated=self._autogenerated,
|
556
553
|
subproject=_SUBPROJECT,
|
557
554
|
)
|
558
|
-
|
559
|
-
|
560
|
-
expected_output_cols_list=(
|
561
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
562
|
-
),
|
555
|
+
expected_output_cols = (
|
556
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
563
557
|
)
|
558
|
+
if isinstance(dataset, DataFrame):
|
559
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
560
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
561
|
+
)
|
562
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
563
|
+
drop_input_cols=self._drop_input_cols,
|
564
|
+
expected_output_cols_list=expected_output_cols,
|
565
|
+
example_output_pd_df=example_output_pd_df,
|
566
|
+
)
|
567
|
+
else:
|
568
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
569
|
+
drop_input_cols=self._drop_input_cols,
|
570
|
+
expected_output_cols_list=expected_output_cols,
|
571
|
+
)
|
564
572
|
self._sklearn_object = fitted_estimator
|
565
573
|
self._is_fitted = True
|
566
574
|
return output_result
|
@@ -639,12 +647,41 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
639
647
|
|
640
648
|
return rv
|
641
649
|
|
642
|
-
def
|
643
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
644
|
-
) -> List[str]:
|
650
|
+
def _align_expected_output(
|
651
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
652
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
653
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
654
|
+
and output dataframe with 1 line.
|
655
|
+
If the method is fit_predict, run 2 lines of data.
|
656
|
+
"""
|
645
657
|
# in case the inferred output column names dimension is different
|
646
658
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
647
|
-
|
659
|
+
|
660
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
661
|
+
# so change the minimum of number of rows to 2
|
662
|
+
num_examples = 2
|
663
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
664
|
+
project=_PROJECT,
|
665
|
+
subproject=_SUBPROJECT,
|
666
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
667
|
+
inspect.currentframe(), GaussianProcessRegressor.__class__.__name__
|
668
|
+
),
|
669
|
+
api_calls=[Session.call],
|
670
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
671
|
+
)
|
672
|
+
if output_cols_prefix == "fit_predict_":
|
673
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
674
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
675
|
+
num_examples = self._sklearn_object.n_clusters
|
676
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
677
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
678
|
+
num_examples = self._sklearn_object.min_samples
|
679
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
680
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
681
|
+
num_examples = self._sklearn_object.n_neighbors
|
682
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
683
|
+
else:
|
684
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
648
685
|
|
649
686
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
650
687
|
# seen during the fit.
|
@@ -656,12 +693,14 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
656
693
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
657
694
|
if self.sample_weight_col:
|
658
695
|
output_df_columns_set -= set(self.sample_weight_col)
|
696
|
+
|
659
697
|
# if the dimension of inferred output column names is correct; use it
|
660
698
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
661
|
-
return expected_output_cols_list
|
699
|
+
return expected_output_cols_list, output_df_pd
|
662
700
|
# otherwise, use the sklearn estimator's output
|
663
701
|
else:
|
664
|
-
|
702
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
703
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
665
704
|
|
666
705
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
667
706
|
@telemetry.send_api_usage_telemetry(
|
@@ -707,7 +746,7 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
707
746
|
drop_input_cols=self._drop_input_cols,
|
708
747
|
expected_output_cols_type="float",
|
709
748
|
)
|
710
|
-
expected_output_cols = self.
|
749
|
+
expected_output_cols, _ = self._align_expected_output(
|
711
750
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
712
751
|
)
|
713
752
|
|
@@ -773,7 +812,7 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
773
812
|
drop_input_cols=self._drop_input_cols,
|
774
813
|
expected_output_cols_type="float",
|
775
814
|
)
|
776
|
-
expected_output_cols = self.
|
815
|
+
expected_output_cols, _ = self._align_expected_output(
|
777
816
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
778
817
|
)
|
779
818
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -836,7 +875,7 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
836
875
|
drop_input_cols=self._drop_input_cols,
|
837
876
|
expected_output_cols_type="float",
|
838
877
|
)
|
839
|
-
expected_output_cols = self.
|
878
|
+
expected_output_cols, _ = self._align_expected_output(
|
840
879
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
841
880
|
)
|
842
881
|
|
@@ -901,7 +940,7 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
901
940
|
drop_input_cols = self._drop_input_cols,
|
902
941
|
expected_output_cols_type="float",
|
903
942
|
)
|
904
|
-
expected_output_cols = self.
|
943
|
+
expected_output_cols, _ = self._align_expected_output(
|
905
944
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
906
945
|
)
|
907
946
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -25,12 +23,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
25
23
|
from snowflake.ml._internal import telemetry
|
26
24
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
27
25
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
28
|
-
from snowflake.ml._internal.utils import
|
26
|
+
from snowflake.ml._internal.utils import identifier
|
29
27
|
from snowflake.snowpark import DataFrame, Session
|
30
28
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
31
29
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
32
30
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
33
|
-
ModelTransformHandlers,
|
34
31
|
BatchInferenceKwargsTypedDict,
|
35
32
|
ScoreKwargsTypedDict
|
36
33
|
)
|
@@ -599,12 +596,23 @@ class IterativeImputer(BaseTransformer):
|
|
599
596
|
autogenerated=self._autogenerated,
|
600
597
|
subproject=_SUBPROJECT,
|
601
598
|
)
|
602
|
-
|
603
|
-
|
604
|
-
expected_output_cols_list=(
|
605
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
606
|
-
),
|
599
|
+
expected_output_cols = (
|
600
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
607
601
|
)
|
602
|
+
if isinstance(dataset, DataFrame):
|
603
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
604
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
605
|
+
)
|
606
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
607
|
+
drop_input_cols=self._drop_input_cols,
|
608
|
+
expected_output_cols_list=expected_output_cols,
|
609
|
+
example_output_pd_df=example_output_pd_df,
|
610
|
+
)
|
611
|
+
else:
|
612
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
613
|
+
drop_input_cols=self._drop_input_cols,
|
614
|
+
expected_output_cols_list=expected_output_cols,
|
615
|
+
)
|
608
616
|
self._sklearn_object = fitted_estimator
|
609
617
|
self._is_fitted = True
|
610
618
|
return output_result
|
@@ -685,12 +693,41 @@ class IterativeImputer(BaseTransformer):
|
|
685
693
|
|
686
694
|
return rv
|
687
695
|
|
688
|
-
def
|
689
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
690
|
-
) -> List[str]:
|
696
|
+
def _align_expected_output(
|
697
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
698
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
699
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
700
|
+
and output dataframe with 1 line.
|
701
|
+
If the method is fit_predict, run 2 lines of data.
|
702
|
+
"""
|
691
703
|
# in case the inferred output column names dimension is different
|
692
704
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
693
|
-
|
705
|
+
|
706
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
707
|
+
# so change the minimum of number of rows to 2
|
708
|
+
num_examples = 2
|
709
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
710
|
+
project=_PROJECT,
|
711
|
+
subproject=_SUBPROJECT,
|
712
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
713
|
+
inspect.currentframe(), IterativeImputer.__class__.__name__
|
714
|
+
),
|
715
|
+
api_calls=[Session.call],
|
716
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
717
|
+
)
|
718
|
+
if output_cols_prefix == "fit_predict_":
|
719
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
720
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
721
|
+
num_examples = self._sklearn_object.n_clusters
|
722
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
723
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
724
|
+
num_examples = self._sklearn_object.min_samples
|
725
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
726
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
727
|
+
num_examples = self._sklearn_object.n_neighbors
|
728
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
729
|
+
else:
|
730
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
694
731
|
|
695
732
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
696
733
|
# seen during the fit.
|
@@ -702,12 +739,14 @@ class IterativeImputer(BaseTransformer):
|
|
702
739
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
703
740
|
if self.sample_weight_col:
|
704
741
|
output_df_columns_set -= set(self.sample_weight_col)
|
742
|
+
|
705
743
|
# if the dimension of inferred output column names is correct; use it
|
706
744
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
707
|
-
return expected_output_cols_list
|
745
|
+
return expected_output_cols_list, output_df_pd
|
708
746
|
# otherwise, use the sklearn estimator's output
|
709
747
|
else:
|
710
|
-
|
748
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
749
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
711
750
|
|
712
751
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
713
752
|
@telemetry.send_api_usage_telemetry(
|
@@ -753,7 +792,7 @@ class IterativeImputer(BaseTransformer):
|
|
753
792
|
drop_input_cols=self._drop_input_cols,
|
754
793
|
expected_output_cols_type="float",
|
755
794
|
)
|
756
|
-
expected_output_cols = self.
|
795
|
+
expected_output_cols, _ = self._align_expected_output(
|
757
796
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
758
797
|
)
|
759
798
|
|
@@ -819,7 +858,7 @@ class IterativeImputer(BaseTransformer):
|
|
819
858
|
drop_input_cols=self._drop_input_cols,
|
820
859
|
expected_output_cols_type="float",
|
821
860
|
)
|
822
|
-
expected_output_cols = self.
|
861
|
+
expected_output_cols, _ = self._align_expected_output(
|
823
862
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
824
863
|
)
|
825
864
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -882,7 +921,7 @@ class IterativeImputer(BaseTransformer):
|
|
882
921
|
drop_input_cols=self._drop_input_cols,
|
883
922
|
expected_output_cols_type="float",
|
884
923
|
)
|
885
|
-
expected_output_cols = self.
|
924
|
+
expected_output_cols, _ = self._align_expected_output(
|
886
925
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
887
926
|
)
|
888
927
|
|
@@ -947,7 +986,7 @@ class IterativeImputer(BaseTransformer):
|
|
947
986
|
drop_input_cols = self._drop_input_cols,
|
948
987
|
expected_output_cols_type="float",
|
949
988
|
)
|
950
|
-
expected_output_cols = self.
|
989
|
+
expected_output_cols, _ = self._align_expected_output(
|
951
990
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
952
991
|
)
|
953
992
|
|