snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
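The additions under snowflake/ml/dataset and snowflake/ml/_internal/lineage in the list above point to a new Dataset API surface (a factory, a reader, and lineage-aware dataframes). A minimal, hedged sketch of how that surface might be used follows; the function names, parameters, and reader attribute are assumptions inferred from the added file names, not API confirmed by this diff.

    # Hedged sketch only: names below are inferred from the added module names
    # (dataset_factory.py, dataset_reader.py) and are assumptions, not confirmed by this diff.
    from snowflake.ml import dataset
    from snowflake.snowpark import Session

    connection_parameters: dict = {}  # hypothetical credentials for an existing account
    session = Session.builder.configs(connection_parameters).create()
    source_df = session.table("MY_DB.MY_SCHEMA.TRAINING_ROWS")  # hypothetical source table

    ds = dataset.create_from_dataframe(            # assumed factory helper from dataset_factory.py
        session, "MY_DATASET", "v1", input_dataframe=source_df
    )
    train_df = ds.read.to_snowpark_dataframe()     # assumed reader surface from dataset_reader.py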
snowflake/ml/modeling/ensemble/ada_boost_regressor.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class AdaBoostRegressor(BaseTransformer):
     r"""An AdaBoost regressor
     For more details on this class, see [sklearn.ensemble.AdaBoostRegressor]

@@ -302,20 +296,17 @@ class AdaBoostRegressor(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-            return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(

@@ -333,9 +324,7 @@ class AdaBoostRegressor(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(

@@ -383,7 +372,8 @@ class AdaBoostRegressor(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -466,10 +456,8 @@ class AdaBoostRegressor(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(

@@ -536,16 +524,40 @@ class AdaBoostRegressor(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:

@@ -636,10 +648,8 @@ class AdaBoostRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -704,10 +714,8 @@ class AdaBoostRegressor(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -769,10 +777,8 @@ class AdaBoostRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -838,10 +844,8 @@ class AdaBoostRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,

@@ -905,17 +909,15 @@ class AdaBoostRegressor(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):

@@ -980,11 +982,8 @@ class AdaBoostRegressor(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
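The same two-step replacement recurs in every predict/transform/score path above and in the bagging_classifier.py and bagging_regressor.py diffs that follow: the old single call that both validated the Snowpark dataset and returned the resolved package list is split into a validation call that returns None plus a separate dependency lookup. A self-contained illustration of that call order is sketched below; the Estimator class, its attributes, and the hard-coded package list are stand-ins for demonstration only, not snowflake-ml source.

    # Stand-in illustration of the 1.5.0 call sequence shown in the diff above.
    from typing import Any, List

    class Estimator:
        def __init__(self) -> None:
            self._deps: List[str] = []

        def _get_dependencies(self) -> List[str]:
            # Stand-in for the estimator's declared package dependencies.
            return ["scikit-learn", "snowflake-snowpark-python"]

        def _batch_inference_validate_snowpark(self, dataset: Any, inference_method: str) -> None:
            # In 1.5.0 this only validates (raising on error) and returns None; it no longer
            # resolves and returns the packages available in the Snowflake Anaconda channel.
            pass

        def predict(self, dataset: Any) -> List[str]:
            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="predict")
            self._deps = self._get_dependencies()  # dependencies are now fetched separately
            return self._deps

    print(Estimator().predict(dataset=None))  # ['scikit-learn', 'snowflake-snowpark-python']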
snowflake/ml/modeling/ensemble/bagging_classifier.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class BaggingClassifier(BaseTransformer):
     r"""A Bagging classifier
     For more details on this class, see [sklearn.ensemble.BaggingClassifier]

@@ -337,20 +331,17 @@ class BaggingClassifier(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-            return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(

@@ -368,9 +359,7 @@ class BaggingClassifier(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(

@@ -418,7 +407,8 @@ class BaggingClassifier(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -501,10 +491,8 @@ class BaggingClassifier(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(

@@ -571,16 +559,40 @@ class BaggingClassifier(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:

@@ -673,10 +685,8 @@ class BaggingClassifier(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -743,10 +753,8 @@ class BaggingClassifier(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -810,10 +818,8 @@ class BaggingClassifier(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -879,10 +885,8 @@ class BaggingClassifier(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,

@@ -946,17 +950,15 @@ class BaggingClassifier(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
            )
         elif isinstance(dataset, pd.DataFrame):

@@ -1021,11 +1023,8 @@ class BaggingClassifier(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/ensemble/bagging_regressor.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class BaggingRegressor(BaseTransformer):
     r"""A Bagging regressor
     For more details on this class, see [sklearn.ensemble.BaggingRegressor]

@@ -337,20 +331,17 @@ class BaggingRegressor(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-            return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(

@@ -368,9 +359,7 @@ class BaggingRegressor(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(

@@ -418,7 +407,8 @@ class BaggingRegressor(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -501,10 +491,8 @@ class BaggingRegressor(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(

@@ -571,16 +559,40 @@ class BaggingRegressor(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:

@@ -671,10 +683,8 @@ class BaggingRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -739,10 +749,8 @@ class BaggingRegressor(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -804,10 +812,8 @@ class BaggingRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -873,10 +879,8 @@ class BaggingRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,

@@ -940,17 +944,15 @@ class BaggingRegressor(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):

@@ -1015,11 +1017,8 @@ class BaggingRegressor(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
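Each of the modeling diffs above also adds a concrete fit_transform entry point with the signature fit_transform(dataset, output_cols_prefix="fit_transform_") that delegates to ModelTrainerBuilder.build_fit_transform and train_fit_transform. A hedged usage sketch follows; whether a given class supports the call varies (the ensemble docstrings above mark it as unsupported), and PCA, the data, and the column names below are assumptions chosen from the file list for illustration, not taken from this diff.

    # Hedged usage sketch of the new fit_transform signature on an autogenerated transformer.
    import pandas as pd
    from snowflake.ml.modeling.decomposition import PCA  # transformer listed in the file changes above

    df = pd.DataFrame({"F1": [1.0, 2.0, 3.0, 4.0], "F2": [4.0, 3.0, 2.0, 1.0]})  # hypothetical data
    pca = PCA(n_components=1, input_cols=["F1", "F2"], output_cols=["PC1"])       # hypothetical columns
    transformed = pca.fit_transform(df)  # returns the transformed dataset in the input's format
    print(transformed)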