snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that public registry.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/cluster/spectral_coclustering.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class SpectralCoclustering(BaseTransformer):
     r"""Spectral Co-Clustering algorithm (Dhillon, 2001)
     For more details on this class, see [sklearn.cluster.SpectralCoclustering]

@@ -301,20 +295,17 @@ class SpectralCoclustering(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(

@@ -332,9 +323,7 @@ class SpectralCoclustering(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(

@@ -380,7 +369,8 @@ class SpectralCoclustering(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -463,10 +453,8 @@ class SpectralCoclustering(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(

@@ -533,16 +521,40 @@ class SpectralCoclustering(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:

@@ -633,10 +645,8 @@ class SpectralCoclustering(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -701,10 +711,8 @@ class SpectralCoclustering(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -766,10 +774,8 @@ class SpectralCoclustering(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -835,10 +841,8 @@ class SpectralCoclustering(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,

@@ -900,17 +904,15 @@ class SpectralCoclustering(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):

@@ -975,11 +977,8 @@ class SpectralCoclustering(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
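The SpectralCoclustering hunks above repeat two edits at every batch-inference call site (predict, transform, the predict_* helpers, and score): `_batch_inference_validate_snowpark` no longer resolves and returns the usable package versions from the Snowflake Anaconda channel (the `pkg_version_utils` call is removed and the method now returns `None`), and the dependency list is taken from `_get_dependencies()` instead. A condensed, self-contained sketch of the new call-site pattern follows; the class and stub bodies are stand-ins for illustration, not library code:

```python
from typing import Any, List


class _CallSiteSketch:
    """Stand-in illustrating the 1.5.0 call-site pattern from the hunks above; not library code."""

    def _batch_inference_validate_snowpark(self, dataset: Any, inference_method: str) -> None:
        # In 1.5.0 this only validates the Snowpark dataset (raising on failure) and
        # returns None; in 1.4.1 it also resolved and returned the package list.
        pass

    def _get_dependencies(self) -> List[str]:
        # Stub standing in for the wrapper's package requirements.
        return ["scikit-learn", "snowflake-snowpark-python"]

    def run_inference(self, dataset: Any, inference_method: str) -> None:
        # The pattern now repeated in every predict/transform/score hunk above.
        self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
        self._deps = self._get_dependencies()
```

The same pair of hunks appears in every other autogenerated estimator listed above with a +51/-52 or +53/-52 delta.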
snowflake/ml/modeling/compose/column_transformer.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.compose".replace("sklear
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class ColumnTransformer(BaseTransformer):
     r"""Applies transformers to columns of an array or pandas DataFrame
     For more details on this class, see [sklearn.compose.ColumnTransformer]

@@ -331,20 +325,17 @@ class ColumnTransformer(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(

@@ -362,9 +353,7 @@ class ColumnTransformer(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(

@@ -410,7 +399,8 @@ class ColumnTransformer(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -495,10 +485,8 @@ class ColumnTransformer(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(

@@ -565,16 +553,42 @@ class ColumnTransformer(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit all transformers, transform the data and concatenate results
+        For more details on this function, see [sklearn.compose.ColumnTransformer.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.compose.ColumnTransformer.html#sklearn.compose.ColumnTransformer.fit_transform)
+
 
-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:

@@ -665,10 +679,8 @@ class ColumnTransformer(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -733,10 +745,8 @@ class ColumnTransformer(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -798,10 +808,8 @@ class ColumnTransformer(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -867,10 +875,8 @@ class ColumnTransformer(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,

@@ -932,17 +938,15 @@ class ColumnTransformer(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):

@@ -1007,11 +1011,8 @@ class ColumnTransformer(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
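The ColumnTransformer hunks show the second half of the change: `fit_transform` is now generated as a real method (the always-disabled `_is_fit_transform_method_enabled` gate is removed) and routes through `ModelTrainerBuilder.build_fit_transform`. A minimal usage sketch with a pandas DataFrame follows; the method signature comes from the diff, but the constructor arguments and column names are illustrative assumptions, not taken from it:

```python
# Hypothetical usage of the fit_transform method added in 1.5.0.
# Signature per the diff: fit_transform(dataset, output_cols_prefix="fit_transform_").
# Constructor arguments and column names below are illustrative assumptions.
import pandas as pd
from sklearn.preprocessing import StandardScaler

from snowflake.ml.modeling.compose import ColumnTransformer

df = pd.DataFrame({"AGE": [25.0, 40.0, 31.0], "INCOME": [50_000.0, 82_000.0, 64_000.0]})

ct = ColumnTransformer(
    transformers=[("scale", StandardScaler(), ["AGE", "INCOME"])],  # passed through to sklearn
    input_cols=["AGE", "INCOME"],
    output_cols=["AGE_SCALED", "INCOME_SCALED"],
)

# Accepts a Snowpark DataFrame or a pandas DataFrame and returns the transformed
# dataset of the same kind, instead of requiring fit() followed by transform().
scaled = ct.fit_transform(df)
```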
snowflake/ml/modeling/compose/transformed_target_regressor.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.compose".replace("sklear
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class TransformedTargetRegressor(BaseTransformer):
     r"""Meta-estimator to regress on a transformed target
     For more details on this class, see [sklearn.compose.TransformedTargetRegressor]

@@ -292,20 +286,17 @@ class TransformedTargetRegressor(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(

@@ -323,9 +314,7 @@ class TransformedTargetRegressor(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(

@@ -373,7 +362,8 @@ class TransformedTargetRegressor(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -456,10 +446,8 @@ class TransformedTargetRegressor(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(

@@ -526,16 +514,40 @@ class TransformedTargetRegressor(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:

@@ -626,10 +638,8 @@ class TransformedTargetRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -694,10 +704,8 @@ class TransformedTargetRegressor(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -759,10 +767,8 @@ class TransformedTargetRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -828,10 +834,8 @@ class TransformedTargetRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,

@@ -895,17 +899,15 @@ class TransformedTargetRegressor(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
            )
         elif isinstance(dataset, pd.DataFrame):

@@ -970,11 +972,8 @@ class TransformedTargetRegressor(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,