snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class IsolationForest(BaseTransformer):
|
70
64
|
r"""Isolation Forest Algorithm
|
71
65
|
For more details on this class, see [sklearn.ensemble.IsolationForest]
|
@@ -324,20 +318,17 @@ class IsolationForest(BaseTransformer):
|
|
324
318
|
self,
|
325
319
|
dataset: DataFrame,
|
326
320
|
inference_method: str,
|
327
|
-
) ->
|
328
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
329
|
-
return the available package that exists in the snowflake anaconda channel
|
321
|
+
) -> None:
|
322
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
330
323
|
|
331
324
|
Args:
|
332
325
|
dataset: snowpark dataframe
|
333
326
|
inference_method: the inference method such as predict, score...
|
334
|
-
|
327
|
+
|
335
328
|
Raises:
|
336
329
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
337
330
|
SnowflakeMLException: If the session is None, raise error
|
338
331
|
|
339
|
-
Returns:
|
340
|
-
A list of available package that exists in the snowflake anaconda channel
|
341
332
|
"""
|
342
333
|
if not self._is_fitted:
|
343
334
|
raise exceptions.SnowflakeMLException(
|
@@ -355,9 +346,7 @@ class IsolationForest(BaseTransformer):
|
|
355
346
|
"Session must not specified for snowpark dataset."
|
356
347
|
),
|
357
348
|
)
|
358
|
-
|
359
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
360
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
349
|
+
|
361
350
|
|
362
351
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
363
352
|
@telemetry.send_api_usage_telemetry(
|
@@ -405,7 +394,8 @@ class IsolationForest(BaseTransformer):
|
|
405
394
|
|
406
395
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
407
396
|
|
408
|
-
self.
|
397
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
398
|
+
self._deps = self._get_dependencies()
|
409
399
|
assert isinstance(
|
410
400
|
dataset._session, Session
|
411
401
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -488,10 +478,8 @@ class IsolationForest(BaseTransformer):
|
|
488
478
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
489
479
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
490
480
|
|
491
|
-
self.
|
492
|
-
|
493
|
-
inference_method=inference_method,
|
494
|
-
)
|
481
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
482
|
+
self._deps = self._get_dependencies()
|
495
483
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
496
484
|
|
497
485
|
transform_kwargs = dict(
|
@@ -560,16 +548,40 @@ class IsolationForest(BaseTransformer):
|
|
560
548
|
self._is_fitted = True
|
561
549
|
return output_result
|
562
550
|
|
551
|
+
|
552
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
553
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
554
|
+
""" Method not supported for this class.
|
555
|
+
|
563
556
|
|
564
|
-
|
565
|
-
|
566
|
-
|
557
|
+
Raises:
|
558
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
559
|
+
|
560
|
+
Args:
|
561
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
562
|
+
Snowpark or Pandas DataFrame.
|
563
|
+
output_cols_prefix: Prefix for the response columns
|
567
564
|
Returns:
|
568
565
|
Transformed dataset.
|
569
566
|
"""
|
570
|
-
self.
|
571
|
-
|
572
|
-
|
567
|
+
self._infer_input_output_cols(dataset)
|
568
|
+
super()._check_dataset_type(dataset)
|
569
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
570
|
+
estimator=self._sklearn_object,
|
571
|
+
dataset=dataset,
|
572
|
+
input_cols=self.input_cols,
|
573
|
+
label_cols=self.label_cols,
|
574
|
+
sample_weight_col=self.sample_weight_col,
|
575
|
+
autogenerated=self._autogenerated,
|
576
|
+
subproject=_SUBPROJECT,
|
577
|
+
)
|
578
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
579
|
+
drop_input_cols=self._drop_input_cols,
|
580
|
+
expected_output_cols_list=self.output_cols,
|
581
|
+
)
|
582
|
+
self._sklearn_object = fitted_estimator
|
583
|
+
self._is_fitted = True
|
584
|
+
return output_result
|
573
585
|
|
574
586
|
|
575
587
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -660,10 +672,8 @@ class IsolationForest(BaseTransformer):
|
|
660
672
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
661
673
|
|
662
674
|
if isinstance(dataset, DataFrame):
|
663
|
-
self.
|
664
|
-
|
665
|
-
inference_method=inference_method,
|
666
|
-
)
|
675
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
676
|
+
self._deps = self._get_dependencies()
|
667
677
|
assert isinstance(
|
668
678
|
dataset._session, Session
|
669
679
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -728,10 +738,8 @@ class IsolationForest(BaseTransformer):
|
|
728
738
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
729
739
|
|
730
740
|
if isinstance(dataset, DataFrame):
|
731
|
-
self.
|
732
|
-
|
733
|
-
inference_method=inference_method,
|
734
|
-
)
|
741
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
742
|
+
self._deps = self._get_dependencies()
|
735
743
|
assert isinstance(
|
736
744
|
dataset._session, Session
|
737
745
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -795,10 +803,8 @@ class IsolationForest(BaseTransformer):
|
|
795
803
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
796
804
|
|
797
805
|
if isinstance(dataset, DataFrame):
|
798
|
-
self.
|
799
|
-
|
800
|
-
inference_method=inference_method,
|
801
|
-
)
|
806
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
807
|
+
self._deps = self._get_dependencies()
|
802
808
|
assert isinstance(
|
803
809
|
dataset._session, Session
|
804
810
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -866,10 +872,8 @@ class IsolationForest(BaseTransformer):
|
|
866
872
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
867
873
|
|
868
874
|
if isinstance(dataset, DataFrame):
|
869
|
-
self.
|
870
|
-
|
871
|
-
inference_method=inference_method,
|
872
|
-
)
|
875
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
876
|
+
self._deps = self._get_dependencies()
|
873
877
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
874
878
|
transform_kwargs = dict(
|
875
879
|
session=dataset._session,
|
@@ -931,17 +935,15 @@ class IsolationForest(BaseTransformer):
|
|
931
935
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
932
936
|
|
933
937
|
if isinstance(dataset, DataFrame):
|
934
|
-
self.
|
935
|
-
|
936
|
-
inference_method="score",
|
937
|
-
)
|
938
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
939
|
+
self._deps = self._get_dependencies()
|
938
940
|
selected_cols = self._get_active_columns()
|
939
941
|
if len(selected_cols) > 0:
|
940
942
|
dataset = dataset.select(selected_cols)
|
941
943
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
942
944
|
transform_kwargs = dict(
|
943
945
|
session=dataset._session,
|
944
|
-
dependencies=
|
946
|
+
dependencies=self._deps,
|
945
947
|
score_sproc_imports=['sklearn'],
|
946
948
|
)
|
947
949
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1006,11 +1008,8 @@ class IsolationForest(BaseTransformer):
|
|
1006
1008
|
|
1007
1009
|
if isinstance(dataset, DataFrame):
|
1008
1010
|
|
1009
|
-
self.
|
1010
|
-
|
1011
|
-
inference_method=inference_method,
|
1012
|
-
|
1013
|
-
)
|
1011
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1012
|
+
self._deps = self._get_dependencies()
|
1014
1013
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1015
1014
|
transform_kwargs = dict(
|
1016
1015
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class RandomForestClassifier(BaseTransformer):
|
70
64
|
r"""A random forest classifier
|
71
65
|
For more details on this class, see [sklearn.ensemble.RandomForestClassifier]
|
@@ -436,20 +430,17 @@ class RandomForestClassifier(BaseTransformer):
|
|
436
430
|
self,
|
437
431
|
dataset: DataFrame,
|
438
432
|
inference_method: str,
|
439
|
-
) ->
|
440
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
441
|
-
return the available package that exists in the snowflake anaconda channel
|
433
|
+
) -> None:
|
434
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
442
435
|
|
443
436
|
Args:
|
444
437
|
dataset: snowpark dataframe
|
445
438
|
inference_method: the inference method such as predict, score...
|
446
|
-
|
439
|
+
|
447
440
|
Raises:
|
448
441
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
449
442
|
SnowflakeMLException: If the session is None, raise error
|
450
443
|
|
451
|
-
Returns:
|
452
|
-
A list of available package that exists in the snowflake anaconda channel
|
453
444
|
"""
|
454
445
|
if not self._is_fitted:
|
455
446
|
raise exceptions.SnowflakeMLException(
|
@@ -467,9 +458,7 @@ class RandomForestClassifier(BaseTransformer):
|
|
467
458
|
"Session must not specified for snowpark dataset."
|
468
459
|
),
|
469
460
|
)
|
470
|
-
|
471
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
472
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
461
|
+
|
473
462
|
|
474
463
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
475
464
|
@telemetry.send_api_usage_telemetry(
|
@@ -517,7 +506,8 @@ class RandomForestClassifier(BaseTransformer):
|
|
517
506
|
|
518
507
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
519
508
|
|
520
|
-
self.
|
509
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
510
|
+
self._deps = self._get_dependencies()
|
521
511
|
assert isinstance(
|
522
512
|
dataset._session, Session
|
523
513
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -600,10 +590,8 @@ class RandomForestClassifier(BaseTransformer):
|
|
600
590
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
601
591
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
602
592
|
|
603
|
-
self.
|
604
|
-
|
605
|
-
inference_method=inference_method,
|
606
|
-
)
|
593
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
594
|
+
self._deps = self._get_dependencies()
|
607
595
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
608
596
|
|
609
597
|
transform_kwargs = dict(
|
@@ -670,16 +658,40 @@ class RandomForestClassifier(BaseTransformer):
|
|
670
658
|
self._is_fitted = True
|
671
659
|
return output_result
|
672
660
|
|
661
|
+
|
662
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
663
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
664
|
+
""" Method not supported for this class.
|
673
665
|
|
674
|
-
|
675
|
-
|
676
|
-
|
666
|
+
|
667
|
+
Raises:
|
668
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
669
|
+
|
670
|
+
Args:
|
671
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
672
|
+
Snowpark or Pandas DataFrame.
|
673
|
+
output_cols_prefix: Prefix for the response columns
|
677
674
|
Returns:
|
678
675
|
Transformed dataset.
|
679
676
|
"""
|
680
|
-
self.
|
681
|
-
|
682
|
-
|
677
|
+
self._infer_input_output_cols(dataset)
|
678
|
+
super()._check_dataset_type(dataset)
|
679
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
680
|
+
estimator=self._sklearn_object,
|
681
|
+
dataset=dataset,
|
682
|
+
input_cols=self.input_cols,
|
683
|
+
label_cols=self.label_cols,
|
684
|
+
sample_weight_col=self.sample_weight_col,
|
685
|
+
autogenerated=self._autogenerated,
|
686
|
+
subproject=_SUBPROJECT,
|
687
|
+
)
|
688
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
689
|
+
drop_input_cols=self._drop_input_cols,
|
690
|
+
expected_output_cols_list=self.output_cols,
|
691
|
+
)
|
692
|
+
self._sklearn_object = fitted_estimator
|
693
|
+
self._is_fitted = True
|
694
|
+
return output_result
|
683
695
|
|
684
696
|
|
685
697
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -772,10 +784,8 @@ class RandomForestClassifier(BaseTransformer):
|
|
772
784
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
773
785
|
|
774
786
|
if isinstance(dataset, DataFrame):
|
775
|
-
self.
|
776
|
-
|
777
|
-
inference_method=inference_method,
|
778
|
-
)
|
787
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
788
|
+
self._deps = self._get_dependencies()
|
779
789
|
assert isinstance(
|
780
790
|
dataset._session, Session
|
781
791
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -842,10 +852,8 @@ class RandomForestClassifier(BaseTransformer):
|
|
842
852
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
843
853
|
|
844
854
|
if isinstance(dataset, DataFrame):
|
845
|
-
self.
|
846
|
-
|
847
|
-
inference_method=inference_method,
|
848
|
-
)
|
855
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
856
|
+
self._deps = self._get_dependencies()
|
849
857
|
assert isinstance(
|
850
858
|
dataset._session, Session
|
851
859
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -907,10 +915,8 @@ class RandomForestClassifier(BaseTransformer):
|
|
907
915
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
908
916
|
|
909
917
|
if isinstance(dataset, DataFrame):
|
910
|
-
self.
|
911
|
-
|
912
|
-
inference_method=inference_method,
|
913
|
-
)
|
918
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
919
|
+
self._deps = self._get_dependencies()
|
914
920
|
assert isinstance(
|
915
921
|
dataset._session, Session
|
916
922
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -976,10 +982,8 @@ class RandomForestClassifier(BaseTransformer):
|
|
976
982
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
977
983
|
|
978
984
|
if isinstance(dataset, DataFrame):
|
979
|
-
self.
|
980
|
-
|
981
|
-
inference_method=inference_method,
|
982
|
-
)
|
985
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
986
|
+
self._deps = self._get_dependencies()
|
983
987
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
984
988
|
transform_kwargs = dict(
|
985
989
|
session=dataset._session,
|
@@ -1043,17 +1047,15 @@ class RandomForestClassifier(BaseTransformer):
|
|
1043
1047
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1044
1048
|
|
1045
1049
|
if isinstance(dataset, DataFrame):
|
1046
|
-
self.
|
1047
|
-
|
1048
|
-
inference_method="score",
|
1049
|
-
)
|
1050
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1051
|
+
self._deps = self._get_dependencies()
|
1050
1052
|
selected_cols = self._get_active_columns()
|
1051
1053
|
if len(selected_cols) > 0:
|
1052
1054
|
dataset = dataset.select(selected_cols)
|
1053
1055
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1054
1056
|
transform_kwargs = dict(
|
1055
1057
|
session=dataset._session,
|
1056
|
-
dependencies=
|
1058
|
+
dependencies=self._deps,
|
1057
1059
|
score_sproc_imports=['sklearn'],
|
1058
1060
|
)
|
1059
1061
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1118,11 +1120,8 @@ class RandomForestClassifier(BaseTransformer):
|
|
1118
1120
|
|
1119
1121
|
if isinstance(dataset, DataFrame):
|
1120
1122
|
|
1121
|
-
self.
|
1122
|
-
|
1123
|
-
inference_method=inference_method,
|
1124
|
-
|
1125
|
-
)
|
1123
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1124
|
+
self._deps = self._get_dependencies()
|
1126
1125
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1127
1126
|
transform_kwargs = dict(
|
1128
1127
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class RandomForestRegressor(BaseTransformer):
|
70
64
|
r"""A random forest regressor
|
71
65
|
For more details on this class, see [sklearn.ensemble.RandomForestRegressor]
|
@@ -415,20 +409,17 @@ class RandomForestRegressor(BaseTransformer):
|
|
415
409
|
self,
|
416
410
|
dataset: DataFrame,
|
417
411
|
inference_method: str,
|
418
|
-
) ->
|
419
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
420
|
-
return the available package that exists in the snowflake anaconda channel
|
412
|
+
) -> None:
|
413
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
421
414
|
|
422
415
|
Args:
|
423
416
|
dataset: snowpark dataframe
|
424
417
|
inference_method: the inference method such as predict, score...
|
425
|
-
|
418
|
+
|
426
419
|
Raises:
|
427
420
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
428
421
|
SnowflakeMLException: If the session is None, raise error
|
429
422
|
|
430
|
-
Returns:
|
431
|
-
A list of available package that exists in the snowflake anaconda channel
|
432
423
|
"""
|
433
424
|
if not self._is_fitted:
|
434
425
|
raise exceptions.SnowflakeMLException(
|
@@ -446,9 +437,7 @@ class RandomForestRegressor(BaseTransformer):
|
|
446
437
|
"Session must not specified for snowpark dataset."
|
447
438
|
),
|
448
439
|
)
|
449
|
-
|
450
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
451
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
440
|
+
|
452
441
|
|
453
442
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
454
443
|
@telemetry.send_api_usage_telemetry(
|
@@ -496,7 +485,8 @@ class RandomForestRegressor(BaseTransformer):
|
|
496
485
|
|
497
486
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
498
487
|
|
499
|
-
self.
|
488
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
489
|
+
self._deps = self._get_dependencies()
|
500
490
|
assert isinstance(
|
501
491
|
dataset._session, Session
|
502
492
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -579,10 +569,8 @@ class RandomForestRegressor(BaseTransformer):
|
|
579
569
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
580
570
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
581
571
|
|
582
|
-
self.
|
583
|
-
|
584
|
-
inference_method=inference_method,
|
585
|
-
)
|
572
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
573
|
+
self._deps = self._get_dependencies()
|
586
574
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
587
575
|
|
588
576
|
transform_kwargs = dict(
|
@@ -649,16 +637,40 @@ class RandomForestRegressor(BaseTransformer):
|
|
649
637
|
self._is_fitted = True
|
650
638
|
return output_result
|
651
639
|
|
640
|
+
|
641
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
642
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
643
|
+
""" Method not supported for this class.
|
652
644
|
|
653
|
-
|
654
|
-
|
655
|
-
|
645
|
+
|
646
|
+
Raises:
|
647
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
648
|
+
|
649
|
+
Args:
|
650
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
651
|
+
Snowpark or Pandas DataFrame.
|
652
|
+
output_cols_prefix: Prefix for the response columns
|
656
653
|
Returns:
|
657
654
|
Transformed dataset.
|
658
655
|
"""
|
659
|
-
self.
|
660
|
-
|
661
|
-
|
656
|
+
self._infer_input_output_cols(dataset)
|
657
|
+
super()._check_dataset_type(dataset)
|
658
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
659
|
+
estimator=self._sklearn_object,
|
660
|
+
dataset=dataset,
|
661
|
+
input_cols=self.input_cols,
|
662
|
+
label_cols=self.label_cols,
|
663
|
+
sample_weight_col=self.sample_weight_col,
|
664
|
+
autogenerated=self._autogenerated,
|
665
|
+
subproject=_SUBPROJECT,
|
666
|
+
)
|
667
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
668
|
+
drop_input_cols=self._drop_input_cols,
|
669
|
+
expected_output_cols_list=self.output_cols,
|
670
|
+
)
|
671
|
+
self._sklearn_object = fitted_estimator
|
672
|
+
self._is_fitted = True
|
673
|
+
return output_result
|
662
674
|
|
663
675
|
|
664
676
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -749,10 +761,8 @@ class RandomForestRegressor(BaseTransformer):
|
|
749
761
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
750
762
|
|
751
763
|
if isinstance(dataset, DataFrame):
|
752
|
-
self.
|
753
|
-
|
754
|
-
inference_method=inference_method,
|
755
|
-
)
|
764
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
765
|
+
self._deps = self._get_dependencies()
|
756
766
|
assert isinstance(
|
757
767
|
dataset._session, Session
|
758
768
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -817,10 +827,8 @@ class RandomForestRegressor(BaseTransformer):
|
|
817
827
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
818
828
|
|
819
829
|
if isinstance(dataset, DataFrame):
|
820
|
-
self.
|
821
|
-
|
822
|
-
inference_method=inference_method,
|
823
|
-
)
|
830
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
831
|
+
self._deps = self._get_dependencies()
|
824
832
|
assert isinstance(
|
825
833
|
dataset._session, Session
|
826
834
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -882,10 +890,8 @@ class RandomForestRegressor(BaseTransformer):
|
|
882
890
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
883
891
|
|
884
892
|
if isinstance(dataset, DataFrame):
|
885
|
-
self.
|
886
|
-
|
887
|
-
inference_method=inference_method,
|
888
|
-
)
|
893
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
894
|
+
self._deps = self._get_dependencies()
|
889
895
|
assert isinstance(
|
890
896
|
dataset._session, Session
|
891
897
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -951,10 +957,8 @@ class RandomForestRegressor(BaseTransformer):
|
|
951
957
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
952
958
|
|
953
959
|
if isinstance(dataset, DataFrame):
|
954
|
-
self.
|
955
|
-
|
956
|
-
inference_method=inference_method,
|
957
|
-
)
|
960
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
961
|
+
self._deps = self._get_dependencies()
|
958
962
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
959
963
|
transform_kwargs = dict(
|
960
964
|
session=dataset._session,
|
@@ -1018,17 +1022,15 @@ class RandomForestRegressor(BaseTransformer):
|
|
1018
1022
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1019
1023
|
|
1020
1024
|
if isinstance(dataset, DataFrame):
|
1021
|
-
self.
|
1022
|
-
|
1023
|
-
inference_method="score",
|
1024
|
-
)
|
1025
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1026
|
+
self._deps = self._get_dependencies()
|
1025
1027
|
selected_cols = self._get_active_columns()
|
1026
1028
|
if len(selected_cols) > 0:
|
1027
1029
|
dataset = dataset.select(selected_cols)
|
1028
1030
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1029
1031
|
transform_kwargs = dict(
|
1030
1032
|
session=dataset._session,
|
1031
|
-
dependencies=
|
1033
|
+
dependencies=self._deps,
|
1032
1034
|
score_sproc_imports=['sklearn'],
|
1033
1035
|
)
|
1034
1036
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1093,11 +1095,8 @@ class RandomForestRegressor(BaseTransformer):
|
|
1093
1095
|
|
1094
1096
|
if isinstance(dataset, DataFrame):
|
1095
1097
|
|
1096
|
-
self.
|
1097
|
-
|
1098
|
-
inference_method=inference_method,
|
1099
|
-
|
1100
|
-
)
|
1098
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1099
|
+
self._deps = self._get_dependencies()
|
1101
1100
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1102
1101
|
transform_kwargs = dict(
|
1103
1102
|
session = dataset._session,
|