snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff shows the content changes between publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the package versions as they appear in their public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0

--- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py (1.4.1)
+++ snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py (1.5.0)

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".re
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class SkewedChi2Sampler(BaseTransformer):
     r"""Approximate feature map for "skewed chi-squared" kernel
     For more details on this class, see [sklearn.kernel_approximation.SkewedChi2Sampler]

@@ -269,20 +263,17 @@ class SkewedChi2Sampler(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(

@@ -300,9 +291,7 @@ class SkewedChi2Sampler(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(

@@ -348,7 +337,8 @@ class SkewedChi2Sampler(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -433,10 +423,8 @@ class SkewedChi2Sampler(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(

@@ -503,16 +491,42 @@ class SkewedChi2Sampler(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit to data, then transform it
+        For more details on this function, see [sklearn.kernel_approximation.SkewedChi2Sampler.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.SkewedChi2Sampler.html#sklearn.kernel_approximation.SkewedChi2Sampler.fit_transform)
+
 
-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
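
The hunk above replaces the previously disabled fit_transform stub (its 1.4.1 availability check always returned False, as shown in the first hunk) with a working implementation routed through ModelTrainerBuilder.build_fit_transform. A minimal usage sketch of the new method follows; the column names and toy data are illustrative assumptions, and skewedness and n_components are the underlying sklearn estimator's parameters:

import pandas as pd
from snowflake.ml.modeling.kernel_approximation import SkewedChi2Sampler

# Toy local data; column names are illustrative assumptions.
df = pd.DataFrame({"X1": [0.5, 1.0, 1.5, 2.0], "X2": [0.1, 0.2, 0.3, 0.4]})

sampler = SkewedChi2Sampler(
    skewedness=1.0,
    n_components=4,
    input_cols=["X1", "X2"],
    output_cols=[f"OUTPUT_{i}" for i in range(4)],
)

# fit_transform now fits the wrapped sklearn estimator and returns the
# transformed dataset in a single call, for either a Snowpark DataFrame
# or a pandas DataFrame.
transformed = sampler.fit_transform(df)
print(transformed.head())

With a Snowpark DataFrame as input, the same call would presumably dispatch training through the Snowpark trainer path (see the changes to snowpark_trainer.py in the file list) rather than running locally.
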
@@ -603,10 +617,8 @@ class SkewedChi2Sampler(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -671,10 +683,8 @@ class SkewedChi2Sampler(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -736,10 +746,8 @@ class SkewedChi2Sampler(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -805,10 +813,8 @@ class SkewedChi2Sampler(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,

@@ -870,17 +876,15 @@ class SkewedChi2Sampler(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):

@@ -945,11 +949,8 @@ class SkewedChi2Sampler(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,

--- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py (1.4.1)
+++ snowflake/ml/modeling/kernel_ridge/kernel_ridge.py (1.5.0)

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_ridge".replace("s
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class KernelRidge(BaseTransformer):
     r"""Kernel ridge regression
     For more details on this class, see [sklearn.kernel_ridge.KernelRidge]

@@ -305,20 +299,17 @@ class KernelRidge(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(

@@ -336,9 +327,7 @@ class KernelRidge(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(

@@ -386,7 +375,8 @@ class KernelRidge(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -469,10 +459,8 @@ class KernelRidge(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(

@@ -539,16 +527,40 @@ class KernelRidge(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:

@@ -639,10 +651,8 @@ class KernelRidge(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -707,10 +717,8 @@ class KernelRidge(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -772,10 +780,8 @@ class KernelRidge(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -841,10 +847,8 @@ class KernelRidge(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,

@@ -908,17 +912,15 @@ class KernelRidge(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):

@@ -983,11 +985,8 @@ class KernelRidge(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,

--- snowflake/ml/modeling/lightgbm/lgbm_classifier.py (1.4.1)
+++ snowflake/ml/modeling/lightgbm/lgbm_classifier.py (1.5.0)

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "lightgbm".replace("sklearn.", ""
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class LGBMClassifier(BaseTransformer):
     r"""LightGBM classifier
     For more details on this class, see [lightgbm.LGBMClassifier]

@@ -294,20 +288,17 @@ class LGBMClassifier(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(

@@ -325,9 +316,7 @@ class LGBMClassifier(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(

@@ -375,7 +364,8 @@ class LGBMClassifier(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -458,10 +448,8 @@ class LGBMClassifier(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(

@@ -528,16 +516,40 @@ class LGBMClassifier(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
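
For estimators whose generated docstring marks fit_transform as not supported, as in the LGBMClassifier hunk above, the usual fit/predict path is unchanged. A hedged sketch with made-up column names and toy data (not taken from the package):

import pandas as pd
from snowflake.ml.modeling.lightgbm import LGBMClassifier

# Toy training data; column names are illustrative assumptions.
train_df = pd.DataFrame({
    "F1": [float(i) for i in range(40)],
    "F2": [float(i % 5) for i in range(40)],
    "LABEL": [i % 2 for i in range(40)],
})

clf = LGBMClassifier(
    input_cols=["F1", "F2"],
    label_cols=["LABEL"],
    output_cols=["PREDICTION"],
)
clf.fit(train_df)                    # trains locally because the input is a pandas DataFrame
predictions = clf.predict(train_df)  # returns a pandas DataFrame with a PREDICTION column
print(predictions.head())
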
@@ -630,10 +642,8 @@ class LGBMClassifier(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -700,10 +710,8 @@ class LGBMClassifier(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -765,10 +773,8 @@ class LGBMClassifier(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -834,10 +840,8 @@ class LGBMClassifier(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,

@@ -901,17 +905,15 @@ class LGBMClassifier(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['lightgbm', 'sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):

@@ -976,11 +978,8 @@ class LGBMClassifier(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,