snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
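Most of the +51/-52 and +53/-52 deltas in the `snowflake/ml/modeling/...` files above come from the same generated change, shown in the hunks below: the module-level `_is_fit_transform_method_enabled` stub is removed, `_batch_inference_validate_snowpark` no longer returns the dependency list (dependencies are now resolved via `self._get_dependencies()`), and each autogenerated estimator gains a `fit_transform` wrapper routed through `ModelTrainerBuilder.build_fit_transform` (documented as unsupported for pure classifiers such as the three shown below). A hypothetical usage sketch of that new method, assuming the PCA wrapper listed above accepts `n_components`, `input_cols`, and `output_cols` like its other snowflake.ml.modeling siblings; the DataFrame and column names are illustrative, not taken from the diff:

```python
# Hypothetical sketch of the fit_transform API added in 1.5.x; data, column
# names, and constructor arguments are assumptions for illustration only.
import pandas as pd
from snowflake.ml.modeling.decomposition import PCA

df = pd.DataFrame({f"FEAT_{i}": [float(j * i) for j in range(10)] for i in range(1, 5)})

pca = PCA(
    n_components=2,
    input_cols=[f"FEAT_{i}" for i in range(1, 5)],
    output_cols=["PC_1", "PC_2"],
)

# New in 1.5.x: fit_transform(dataset, output_cols_prefix="fit_transform_")
# accepts a Snowpark or pandas DataFrame and returns the transformed dataset.
transformed = pca.fit_transform(df)
print(transformed.head())
```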
snowflake/ml/modeling/multiclass/output_code_classifier.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.multiclass".replace("skl
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class OutputCodeClassifier(BaseTransformer):
     r"""(Error-Correcting) Output-Code multiclass strategy
     For more details on this class, see [sklearn.multiclass.OutputCodeClassifier]
@@ -283,20 +277,17 @@ class OutputCodeClassifier(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -314,9 +305,7 @@ class OutputCodeClassifier(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -364,7 +353,8 @@ class OutputCodeClassifier(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -447,10 +437,8 @@ class OutputCodeClassifier(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -517,16 +505,40 @@ class OutputCodeClassifier(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -617,10 +629,8 @@ class OutputCodeClassifier(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -685,10 +695,8 @@ class OutputCodeClassifier(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -750,10 +758,8 @@ class OutputCodeClassifier(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -819,10 +825,8 @@ class OutputCodeClassifier(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -886,17 +890,15 @@ class OutputCodeClassifier(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -961,11 +963,8 @@ class OutputCodeClassifier(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/naive_bayes/bernoulli_nb.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sk
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class BernoulliNB(BaseTransformer):
     r"""Naive Bayes classifier for multivariate Bernoulli models
     For more details on this class, see [sklearn.naive_bayes.BernoulliNB]
@@ -283,20 +277,17 @@ class BernoulliNB(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -314,9 +305,7 @@ class BernoulliNB(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -364,7 +353,8 @@ class BernoulliNB(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -447,10 +437,8 @@ class BernoulliNB(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -517,16 +505,40 @@ class BernoulliNB(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -619,10 +631,8 @@ class BernoulliNB(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -689,10 +699,8 @@ class BernoulliNB(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -754,10 +762,8 @@ class BernoulliNB(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -823,10 +829,8 @@ class BernoulliNB(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -890,17 +894,15 @@ class BernoulliNB(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
            )
         elif isinstance(dataset, pd.DataFrame):
@@ -965,11 +967,8 @@ class BernoulliNB(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/naive_bayes/categorical_nb.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sk
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class CategoricalNB(BaseTransformer):
     r"""Naive Bayes classifier for categorical features
     For more details on this class, see [sklearn.naive_bayes.CategoricalNB]
@@ -289,20 +283,17 @@ class CategoricalNB(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -320,9 +311,7 @@ class CategoricalNB(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -370,7 +359,8 @@ class CategoricalNB(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -453,10 +443,8 @@ class CategoricalNB(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -523,16 +511,40 @@ class CategoricalNB(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -625,10 +637,8 @@ class CategoricalNB(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -695,10 +705,8 @@ class CategoricalNB(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -760,10 +768,8 @@ class CategoricalNB(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -829,10 +835,8 @@ class CategoricalNB(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -896,17 +900,15 @@ class CategoricalNB(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -971,11 +973,8 @@ class CategoricalNB(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,