snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/linear_model/sgd_regressor.py:

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class SGDRegressor(BaseTransformer):
     r"""Linear model fitted by minimizing a regularized empirical loss with SGD
     For more details on this class, see [sklearn.linear_model.SGDRegressor]
@@ -415,20 +409,17 @@ class SGDRegressor(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -446,9 +437,7 @@ class SGDRegressor(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -496,7 +485,8 @@ class SGDRegressor(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -579,10 +569,8 @@ class SGDRegressor(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -649,16 +637,40 @@ class SGDRegressor(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -749,10 +761,8 @@ class SGDRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -817,10 +827,8 @@ class SGDRegressor(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -882,10 +890,8 @@ class SGDRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -951,10 +957,8 @@ class SGDRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -1018,17 +1022,15 @@ class SGDRegressor(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1093,11 +1095,8 @@ class SGDRegressor(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/linear_model/theil_sen_regressor.py:

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class TheilSenRegressor(BaseTransformer):
     r"""Theil-Sen Estimator: robust multivariate regression model
     For more details on this class, see [sklearn.linear_model.TheilSenRegressor]
@@ -317,20 +311,17 @@ class TheilSenRegressor(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -348,9 +339,7 @@ class TheilSenRegressor(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -398,7 +387,8 @@ class TheilSenRegressor(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -481,10 +471,8 @@ class TheilSenRegressor(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -551,16 +539,40 @@ class TheilSenRegressor(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -651,10 +663,8 @@ class TheilSenRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -719,10 +729,8 @@ class TheilSenRegressor(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -784,10 +792,8 @@ class TheilSenRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -853,10 +859,8 @@ class TheilSenRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -920,17 +924,15 @@ class TheilSenRegressor(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -995,11 +997,8 @@ class TheilSenRegressor(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/linear_model/tweedie_regressor.py:

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class TweedieRegressor(BaseTransformer):
     r"""Generalized Linear Model with a Tweedie distribution
     For more details on this class, see [sklearn.linear_model.TweedieRegressor]
@@ -343,20 +337,17 @@ class TweedieRegressor(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
            dataset: snowpark dataframe
           inference_method: the inference method such as predict, score...
-
+
         Raises:
            SnowflakeMLException: If the estimator is not fitted, raise error
            SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -374,9 +365,7 @@ class TweedieRegressor(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -424,7 +413,8 @@ class TweedieRegressor(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -507,10 +497,8 @@ class TweedieRegressor(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -577,16 +565,40 @@ class TweedieRegressor(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -677,10 +689,8 @@ class TweedieRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -745,10 +755,8 @@ class TweedieRegressor(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -810,10 +818,8 @@ class TweedieRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -879,10 +885,8 @@ class TweedieRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -946,17 +950,15 @@ class TweedieRegressor(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
            )
         elif isinstance(dataset, pd.DataFrame):
@@ -1021,11 +1023,8 @@ class TweedieRegressor(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,