snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.gaussian_process".replac
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class GaussianProcessClassifier(BaseTransformer):
|
70
64
|
r"""Gaussian process classification (GPC) based on Laplace approximation
|
71
65
|
For more details on this class, see [sklearn.gaussian_process.GaussianProcessClassifier]
|
@@ -352,20 +346,17 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
352
346
|
self,
|
353
347
|
dataset: DataFrame,
|
354
348
|
inference_method: str,
|
355
|
-
) ->
|
356
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
357
|
-
return the available package that exists in the snowflake anaconda channel
|
349
|
+
) -> None:
|
350
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
358
351
|
|
359
352
|
Args:
|
360
353
|
dataset: snowpark dataframe
|
361
354
|
inference_method: the inference method such as predict, score...
|
362
|
-
|
355
|
+
|
363
356
|
Raises:
|
364
357
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
365
358
|
SnowflakeMLException: If the session is None, raise error
|
366
359
|
|
367
|
-
Returns:
|
368
|
-
A list of available package that exists in the snowflake anaconda channel
|
369
360
|
"""
|
370
361
|
if not self._is_fitted:
|
371
362
|
raise exceptions.SnowflakeMLException(
|
@@ -383,9 +374,7 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
383
374
|
"Session must not specified for snowpark dataset."
|
384
375
|
),
|
385
376
|
)
|
386
|
-
|
387
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
388
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
377
|
+
|
389
378
|
|
390
379
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
391
380
|
@telemetry.send_api_usage_telemetry(
|
@@ -433,7 +422,8 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
433
422
|
|
434
423
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
435
424
|
|
436
|
-
self.
|
425
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
426
|
+
self._deps = self._get_dependencies()
|
437
427
|
assert isinstance(
|
438
428
|
dataset._session, Session
|
439
429
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -516,10 +506,8 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
516
506
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
517
507
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
518
508
|
|
519
|
-
self.
|
520
|
-
|
521
|
-
inference_method=inference_method,
|
522
|
-
)
|
509
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
510
|
+
self._deps = self._get_dependencies()
|
523
511
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
524
512
|
|
525
513
|
transform_kwargs = dict(
|
@@ -586,16 +574,40 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
586
574
|
self._is_fitted = True
|
587
575
|
return output_result
|
588
576
|
|
577
|
+
|
578
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
579
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
580
|
+
""" Method not supported for this class.
|
589
581
|
|
590
|
-
|
591
|
-
|
592
|
-
|
582
|
+
|
583
|
+
Raises:
|
584
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
585
|
+
|
586
|
+
Args:
|
587
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
588
|
+
Snowpark or Pandas DataFrame.
|
589
|
+
output_cols_prefix: Prefix for the response columns
|
593
590
|
Returns:
|
594
591
|
Transformed dataset.
|
595
592
|
"""
|
596
|
-
self.
|
597
|
-
|
598
|
-
|
593
|
+
self._infer_input_output_cols(dataset)
|
594
|
+
super()._check_dataset_type(dataset)
|
595
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
596
|
+
estimator=self._sklearn_object,
|
597
|
+
dataset=dataset,
|
598
|
+
input_cols=self.input_cols,
|
599
|
+
label_cols=self.label_cols,
|
600
|
+
sample_weight_col=self.sample_weight_col,
|
601
|
+
autogenerated=self._autogenerated,
|
602
|
+
subproject=_SUBPROJECT,
|
603
|
+
)
|
604
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
605
|
+
drop_input_cols=self._drop_input_cols,
|
606
|
+
expected_output_cols_list=self.output_cols,
|
607
|
+
)
|
608
|
+
self._sklearn_object = fitted_estimator
|
609
|
+
self._is_fitted = True
|
610
|
+
return output_result
|
599
611
|
|
600
612
|
|
601
613
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -688,10 +700,8 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
688
700
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
689
701
|
|
690
702
|
if isinstance(dataset, DataFrame):
|
691
|
-
self.
|
692
|
-
|
693
|
-
inference_method=inference_method,
|
694
|
-
)
|
703
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
704
|
+
self._deps = self._get_dependencies()
|
695
705
|
assert isinstance(
|
696
706
|
dataset._session, Session
|
697
707
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -758,10 +768,8 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
758
768
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
759
769
|
|
760
770
|
if isinstance(dataset, DataFrame):
|
761
|
-
self.
|
762
|
-
|
763
|
-
inference_method=inference_method,
|
764
|
-
)
|
771
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
772
|
+
self._deps = self._get_dependencies()
|
765
773
|
assert isinstance(
|
766
774
|
dataset._session, Session
|
767
775
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -823,10 +831,8 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
823
831
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
824
832
|
|
825
833
|
if isinstance(dataset, DataFrame):
|
826
|
-
self.
|
827
|
-
|
828
|
-
inference_method=inference_method,
|
829
|
-
)
|
834
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
835
|
+
self._deps = self._get_dependencies()
|
830
836
|
assert isinstance(
|
831
837
|
dataset._session, Session
|
832
838
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -892,10 +898,8 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
892
898
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
893
899
|
|
894
900
|
if isinstance(dataset, DataFrame):
|
895
|
-
self.
|
896
|
-
|
897
|
-
inference_method=inference_method,
|
898
|
-
)
|
901
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
902
|
+
self._deps = self._get_dependencies()
|
899
903
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
900
904
|
transform_kwargs = dict(
|
901
905
|
session=dataset._session,
|
@@ -959,17 +963,15 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
959
963
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
960
964
|
|
961
965
|
if isinstance(dataset, DataFrame):
|
962
|
-
self.
|
963
|
-
|
964
|
-
inference_method="score",
|
965
|
-
)
|
966
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
967
|
+
self._deps = self._get_dependencies()
|
966
968
|
selected_cols = self._get_active_columns()
|
967
969
|
if len(selected_cols) > 0:
|
968
970
|
dataset = dataset.select(selected_cols)
|
969
971
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
970
972
|
transform_kwargs = dict(
|
971
973
|
session=dataset._session,
|
972
|
-
dependencies=
|
974
|
+
dependencies=self._deps,
|
973
975
|
score_sproc_imports=['sklearn'],
|
974
976
|
)
|
975
977
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1034,11 +1036,8 @@ class GaussianProcessClassifier(BaseTransformer):
|
|
1034
1036
|
|
1035
1037
|
if isinstance(dataset, DataFrame):
|
1036
1038
|
|
1037
|
-
self.
|
1038
|
-
|
1039
|
-
inference_method=inference_method,
|
1040
|
-
|
1041
|
-
)
|
1039
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1040
|
+
self._deps = self._get_dependencies()
|
1042
1041
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1043
1042
|
transform_kwargs = dict(
|
1044
1043
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.gaussian_process".replac
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class GaussianProcessRegressor(BaseTransformer):
|
70
64
|
r"""Gaussian process regression (GPR)
|
71
65
|
For more details on this class, see [sklearn.gaussian_process.GaussianProcessRegressor]
|
@@ -343,20 +337,17 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
343
337
|
self,
|
344
338
|
dataset: DataFrame,
|
345
339
|
inference_method: str,
|
346
|
-
) ->
|
347
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
348
|
-
return the available package that exists in the snowflake anaconda channel
|
340
|
+
) -> None:
|
341
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
349
342
|
|
350
343
|
Args:
|
351
344
|
dataset: snowpark dataframe
|
352
345
|
inference_method: the inference method such as predict, score...
|
353
|
-
|
346
|
+
|
354
347
|
Raises:
|
355
348
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
356
349
|
SnowflakeMLException: If the session is None, raise error
|
357
350
|
|
358
|
-
Returns:
|
359
|
-
A list of available package that exists in the snowflake anaconda channel
|
360
351
|
"""
|
361
352
|
if not self._is_fitted:
|
362
353
|
raise exceptions.SnowflakeMLException(
|
@@ -374,9 +365,7 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
374
365
|
"Session must not specified for snowpark dataset."
|
375
366
|
),
|
376
367
|
)
|
377
|
-
|
378
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
379
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
368
|
+
|
380
369
|
|
381
370
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
382
371
|
@telemetry.send_api_usage_telemetry(
|
@@ -424,7 +413,8 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
424
413
|
|
425
414
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
426
415
|
|
427
|
-
self.
|
416
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
417
|
+
self._deps = self._get_dependencies()
|
428
418
|
assert isinstance(
|
429
419
|
dataset._session, Session
|
430
420
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -507,10 +497,8 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
507
497
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
508
498
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
509
499
|
|
510
|
-
self.
|
511
|
-
|
512
|
-
inference_method=inference_method,
|
513
|
-
)
|
500
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
501
|
+
self._deps = self._get_dependencies()
|
514
502
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
515
503
|
|
516
504
|
transform_kwargs = dict(
|
@@ -577,16 +565,40 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
577
565
|
self._is_fitted = True
|
578
566
|
return output_result
|
579
567
|
|
568
|
+
|
569
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
570
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
571
|
+
""" Method not supported for this class.
|
580
572
|
|
581
|
-
|
582
|
-
|
583
|
-
|
573
|
+
|
574
|
+
Raises:
|
575
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
576
|
+
|
577
|
+
Args:
|
578
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
579
|
+
Snowpark or Pandas DataFrame.
|
580
|
+
output_cols_prefix: Prefix for the response columns
|
584
581
|
Returns:
|
585
582
|
Transformed dataset.
|
586
583
|
"""
|
587
|
-
self.
|
588
|
-
|
589
|
-
|
584
|
+
self._infer_input_output_cols(dataset)
|
585
|
+
super()._check_dataset_type(dataset)
|
586
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
587
|
+
estimator=self._sklearn_object,
|
588
|
+
dataset=dataset,
|
589
|
+
input_cols=self.input_cols,
|
590
|
+
label_cols=self.label_cols,
|
591
|
+
sample_weight_col=self.sample_weight_col,
|
592
|
+
autogenerated=self._autogenerated,
|
593
|
+
subproject=_SUBPROJECT,
|
594
|
+
)
|
595
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
596
|
+
drop_input_cols=self._drop_input_cols,
|
597
|
+
expected_output_cols_list=self.output_cols,
|
598
|
+
)
|
599
|
+
self._sklearn_object = fitted_estimator
|
600
|
+
self._is_fitted = True
|
601
|
+
return output_result
|
590
602
|
|
591
603
|
|
592
604
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -677,10 +689,8 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
677
689
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
678
690
|
|
679
691
|
if isinstance(dataset, DataFrame):
|
680
|
-
self.
|
681
|
-
|
682
|
-
inference_method=inference_method,
|
683
|
-
)
|
692
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
693
|
+
self._deps = self._get_dependencies()
|
684
694
|
assert isinstance(
|
685
695
|
dataset._session, Session
|
686
696
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -745,10 +755,8 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
745
755
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
746
756
|
|
747
757
|
if isinstance(dataset, DataFrame):
|
748
|
-
self.
|
749
|
-
|
750
|
-
inference_method=inference_method,
|
751
|
-
)
|
758
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
759
|
+
self._deps = self._get_dependencies()
|
752
760
|
assert isinstance(
|
753
761
|
dataset._session, Session
|
754
762
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -810,10 +818,8 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
810
818
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
811
819
|
|
812
820
|
if isinstance(dataset, DataFrame):
|
813
|
-
self.
|
814
|
-
|
815
|
-
inference_method=inference_method,
|
816
|
-
)
|
821
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
822
|
+
self._deps = self._get_dependencies()
|
817
823
|
assert isinstance(
|
818
824
|
dataset._session, Session
|
819
825
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -879,10 +885,8 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
879
885
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
880
886
|
|
881
887
|
if isinstance(dataset, DataFrame):
|
882
|
-
self.
|
883
|
-
|
884
|
-
inference_method=inference_method,
|
885
|
-
)
|
888
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
889
|
+
self._deps = self._get_dependencies()
|
886
890
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
887
891
|
transform_kwargs = dict(
|
888
892
|
session=dataset._session,
|
@@ -946,17 +950,15 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
946
950
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
947
951
|
|
948
952
|
if isinstance(dataset, DataFrame):
|
949
|
-
self.
|
950
|
-
|
951
|
-
inference_method="score",
|
952
|
-
)
|
953
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
954
|
+
self._deps = self._get_dependencies()
|
953
955
|
selected_cols = self._get_active_columns()
|
954
956
|
if len(selected_cols) > 0:
|
955
957
|
dataset = dataset.select(selected_cols)
|
956
958
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
957
959
|
transform_kwargs = dict(
|
958
960
|
session=dataset._session,
|
959
|
-
dependencies=
|
961
|
+
dependencies=self._deps,
|
960
962
|
score_sproc_imports=['sklearn'],
|
961
963
|
)
|
962
964
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1021,11 +1023,8 @@ class GaussianProcessRegressor(BaseTransformer):
|
|
1021
1023
|
|
1022
1024
|
if isinstance(dataset, DataFrame):
|
1023
1025
|
|
1024
|
-
self.
|
1025
|
-
|
1026
|
-
inference_method=inference_method,
|
1027
|
-
|
1028
|
-
)
|
1026
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1027
|
+
self._deps = self._get_dependencies()
|
1029
1028
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1030
1029
|
transform_kwargs = dict(
|
1031
1030
|
session = dataset._session,
|
@@ -61,12 +61,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.impute".replace("sklearn
|
|
61
61
|
|
62
62
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
63
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
64
|
class IterativeImputer(BaseTransformer):
|
71
65
|
r"""Multivariate imputer that estimates each feature from all the others
|
72
66
|
For more details on this class, see [sklearn.impute.IterativeImputer]
|
@@ -385,20 +379,17 @@ class IterativeImputer(BaseTransformer):
|
|
385
379
|
self,
|
386
380
|
dataset: DataFrame,
|
387
381
|
inference_method: str,
|
388
|
-
) ->
|
389
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
390
|
-
return the available package that exists in the snowflake anaconda channel
|
382
|
+
) -> None:
|
383
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
391
384
|
|
392
385
|
Args:
|
393
386
|
dataset: snowpark dataframe
|
394
387
|
inference_method: the inference method such as predict, score...
|
395
|
-
|
388
|
+
|
396
389
|
Raises:
|
397
390
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
398
391
|
SnowflakeMLException: If the session is None, raise error
|
399
392
|
|
400
|
-
Returns:
|
401
|
-
A list of available package that exists in the snowflake anaconda channel
|
402
393
|
"""
|
403
394
|
if not self._is_fitted:
|
404
395
|
raise exceptions.SnowflakeMLException(
|
@@ -416,9 +407,7 @@ class IterativeImputer(BaseTransformer):
|
|
416
407
|
"Session must not specified for snowpark dataset."
|
417
408
|
),
|
418
409
|
)
|
419
|
-
|
420
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
421
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
410
|
+
|
422
411
|
|
423
412
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
424
413
|
@telemetry.send_api_usage_telemetry(
|
@@ -464,7 +453,8 @@ class IterativeImputer(BaseTransformer):
|
|
464
453
|
|
465
454
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
466
455
|
|
467
|
-
self.
|
456
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
457
|
+
self._deps = self._get_dependencies()
|
468
458
|
assert isinstance(
|
469
459
|
dataset._session, Session
|
470
460
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -549,10 +539,8 @@ class IterativeImputer(BaseTransformer):
|
|
549
539
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
550
540
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
551
541
|
|
552
|
-
self.
|
553
|
-
|
554
|
-
inference_method=inference_method,
|
555
|
-
)
|
542
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
543
|
+
self._deps = self._get_dependencies()
|
556
544
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
557
545
|
|
558
546
|
transform_kwargs = dict(
|
@@ -619,16 +607,42 @@ class IterativeImputer(BaseTransformer):
|
|
619
607
|
self._is_fitted = True
|
620
608
|
return output_result
|
621
609
|
|
610
|
+
|
611
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
612
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
613
|
+
""" Fit the imputer on `X` and return the transformed `X`
|
614
|
+
For more details on this function, see [sklearn.impute.IterativeImputer.fit_transform]
|
615
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.impute.IterativeImputer.html#sklearn.impute.IterativeImputer.fit_transform)
|
616
|
+
|
622
617
|
|
623
|
-
|
624
|
-
|
625
|
-
|
618
|
+
Raises:
|
619
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
620
|
+
|
621
|
+
Args:
|
622
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
623
|
+
Snowpark or Pandas DataFrame.
|
624
|
+
output_cols_prefix: Prefix for the response columns
|
626
625
|
Returns:
|
627
626
|
Transformed dataset.
|
628
627
|
"""
|
629
|
-
self.
|
630
|
-
|
631
|
-
|
628
|
+
self._infer_input_output_cols(dataset)
|
629
|
+
super()._check_dataset_type(dataset)
|
630
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
631
|
+
estimator=self._sklearn_object,
|
632
|
+
dataset=dataset,
|
633
|
+
input_cols=self.input_cols,
|
634
|
+
label_cols=self.label_cols,
|
635
|
+
sample_weight_col=self.sample_weight_col,
|
636
|
+
autogenerated=self._autogenerated,
|
637
|
+
subproject=_SUBPROJECT,
|
638
|
+
)
|
639
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
640
|
+
drop_input_cols=self._drop_input_cols,
|
641
|
+
expected_output_cols_list=self.output_cols,
|
642
|
+
)
|
643
|
+
self._sklearn_object = fitted_estimator
|
644
|
+
self._is_fitted = True
|
645
|
+
return output_result
|
632
646
|
|
633
647
|
|
634
648
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -719,10 +733,8 @@ class IterativeImputer(BaseTransformer):
|
|
719
733
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
720
734
|
|
721
735
|
if isinstance(dataset, DataFrame):
|
722
|
-
self.
|
723
|
-
|
724
|
-
inference_method=inference_method,
|
725
|
-
)
|
736
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
737
|
+
self._deps = self._get_dependencies()
|
726
738
|
assert isinstance(
|
727
739
|
dataset._session, Session
|
728
740
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -787,10 +799,8 @@ class IterativeImputer(BaseTransformer):
|
|
787
799
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
788
800
|
|
789
801
|
if isinstance(dataset, DataFrame):
|
790
|
-
self.
|
791
|
-
|
792
|
-
inference_method=inference_method,
|
793
|
-
)
|
802
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
803
|
+
self._deps = self._get_dependencies()
|
794
804
|
assert isinstance(
|
795
805
|
dataset._session, Session
|
796
806
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -852,10 +862,8 @@ class IterativeImputer(BaseTransformer):
|
|
852
862
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
853
863
|
|
854
864
|
if isinstance(dataset, DataFrame):
|
855
|
-
self.
|
856
|
-
|
857
|
-
inference_method=inference_method,
|
858
|
-
)
|
865
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
866
|
+
self._deps = self._get_dependencies()
|
859
867
|
assert isinstance(
|
860
868
|
dataset._session, Session
|
861
869
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -921,10 +929,8 @@ class IterativeImputer(BaseTransformer):
|
|
921
929
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
922
930
|
|
923
931
|
if isinstance(dataset, DataFrame):
|
924
|
-
self.
|
925
|
-
|
926
|
-
inference_method=inference_method,
|
927
|
-
)
|
932
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
933
|
+
self._deps = self._get_dependencies()
|
928
934
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
929
935
|
transform_kwargs = dict(
|
930
936
|
session=dataset._session,
|
@@ -986,17 +992,15 @@ class IterativeImputer(BaseTransformer):
|
|
986
992
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
987
993
|
|
988
994
|
if isinstance(dataset, DataFrame):
|
989
|
-
self.
|
990
|
-
|
991
|
-
inference_method="score",
|
992
|
-
)
|
995
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
996
|
+
self._deps = self._get_dependencies()
|
993
997
|
selected_cols = self._get_active_columns()
|
994
998
|
if len(selected_cols) > 0:
|
995
999
|
dataset = dataset.select(selected_cols)
|
996
1000
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
997
1001
|
transform_kwargs = dict(
|
998
1002
|
session=dataset._session,
|
999
|
-
dependencies=
|
1003
|
+
dependencies=self._deps,
|
1000
1004
|
score_sproc_imports=['sklearn'],
|
1001
1005
|
)
|
1002
1006
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1061,11 +1065,8 @@ class IterativeImputer(BaseTransformer):
|
|
1061
1065
|
|
1062
1066
|
if isinstance(dataset, DataFrame):
|
1063
1067
|
|
1064
|
-
self.
|
1065
|
-
|
1066
|
-
inference_method=inference_method,
|
1067
|
-
|
1068
|
-
)
|
1068
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1069
|
+
self._deps = self._get_dependencies()
|
1069
1070
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1070
1071
|
transform_kwargs = dict(
|
1071
1072
|
session = dataset._session,
|