snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sk
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class ComplementNB(BaseTransformer):
|
70
64
|
r"""The Complement Naive Bayes classifier described in Rennie et al
|
71
65
|
For more details on this class, see [sklearn.naive_bayes.ComplementNB]
|
@@ -283,20 +277,17 @@ class ComplementNB(BaseTransformer):
|
|
283
277
|
self,
|
284
278
|
dataset: DataFrame,
|
285
279
|
inference_method: str,
|
286
|
-
) ->
|
287
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
288
|
-
return the available package that exists in the snowflake anaconda channel
|
280
|
+
) -> None:
|
281
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
289
282
|
|
290
283
|
Args:
|
291
284
|
dataset: snowpark dataframe
|
292
285
|
inference_method: the inference method such as predict, score...
|
293
|
-
|
286
|
+
|
294
287
|
Raises:
|
295
288
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
296
289
|
SnowflakeMLException: If the session is None, raise error
|
297
290
|
|
298
|
-
Returns:
|
299
|
-
A list of available package that exists in the snowflake anaconda channel
|
300
291
|
"""
|
301
292
|
if not self._is_fitted:
|
302
293
|
raise exceptions.SnowflakeMLException(
|
@@ -314,9 +305,7 @@ class ComplementNB(BaseTransformer):
|
|
314
305
|
"Session must not specified for snowpark dataset."
|
315
306
|
),
|
316
307
|
)
|
317
|
-
|
318
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
319
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
308
|
+
|
320
309
|
|
321
310
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
322
311
|
@telemetry.send_api_usage_telemetry(
|
@@ -364,7 +353,8 @@ class ComplementNB(BaseTransformer):
|
|
364
353
|
|
365
354
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
366
355
|
|
367
|
-
self.
|
356
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
357
|
+
self._deps = self._get_dependencies()
|
368
358
|
assert isinstance(
|
369
359
|
dataset._session, Session
|
370
360
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -447,10 +437,8 @@ class ComplementNB(BaseTransformer):
|
|
447
437
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
448
438
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
449
439
|
|
450
|
-
self.
|
451
|
-
|
452
|
-
inference_method=inference_method,
|
453
|
-
)
|
440
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
441
|
+
self._deps = self._get_dependencies()
|
454
442
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
455
443
|
|
456
444
|
transform_kwargs = dict(
|
@@ -517,16 +505,40 @@ class ComplementNB(BaseTransformer):
|
|
517
505
|
self._is_fitted = True
|
518
506
|
return output_result
|
519
507
|
|
508
|
+
|
509
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
510
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
511
|
+
""" Method not supported for this class.
|
520
512
|
|
521
|
-
|
522
|
-
|
523
|
-
|
513
|
+
|
514
|
+
Raises:
|
515
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
516
|
+
|
517
|
+
Args:
|
518
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
519
|
+
Snowpark or Pandas DataFrame.
|
520
|
+
output_cols_prefix: Prefix for the response columns
|
524
521
|
Returns:
|
525
522
|
Transformed dataset.
|
526
523
|
"""
|
527
|
-
self.
|
528
|
-
|
529
|
-
|
524
|
+
self._infer_input_output_cols(dataset)
|
525
|
+
super()._check_dataset_type(dataset)
|
526
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
527
|
+
estimator=self._sklearn_object,
|
528
|
+
dataset=dataset,
|
529
|
+
input_cols=self.input_cols,
|
530
|
+
label_cols=self.label_cols,
|
531
|
+
sample_weight_col=self.sample_weight_col,
|
532
|
+
autogenerated=self._autogenerated,
|
533
|
+
subproject=_SUBPROJECT,
|
534
|
+
)
|
535
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
536
|
+
drop_input_cols=self._drop_input_cols,
|
537
|
+
expected_output_cols_list=self.output_cols,
|
538
|
+
)
|
539
|
+
self._sklearn_object = fitted_estimator
|
540
|
+
self._is_fitted = True
|
541
|
+
return output_result
|
530
542
|
|
531
543
|
|
532
544
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -619,10 +631,8 @@ class ComplementNB(BaseTransformer):
|
|
619
631
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
620
632
|
|
621
633
|
if isinstance(dataset, DataFrame):
|
622
|
-
self.
|
623
|
-
|
624
|
-
inference_method=inference_method,
|
625
|
-
)
|
634
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
635
|
+
self._deps = self._get_dependencies()
|
626
636
|
assert isinstance(
|
627
637
|
dataset._session, Session
|
628
638
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -689,10 +699,8 @@ class ComplementNB(BaseTransformer):
|
|
689
699
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
690
700
|
|
691
701
|
if isinstance(dataset, DataFrame):
|
692
|
-
self.
|
693
|
-
|
694
|
-
inference_method=inference_method,
|
695
|
-
)
|
702
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
703
|
+
self._deps = self._get_dependencies()
|
696
704
|
assert isinstance(
|
697
705
|
dataset._session, Session
|
698
706
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -754,10 +762,8 @@ class ComplementNB(BaseTransformer):
|
|
754
762
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
755
763
|
|
756
764
|
if isinstance(dataset, DataFrame):
|
757
|
-
self.
|
758
|
-
|
759
|
-
inference_method=inference_method,
|
760
|
-
)
|
765
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
766
|
+
self._deps = self._get_dependencies()
|
761
767
|
assert isinstance(
|
762
768
|
dataset._session, Session
|
763
769
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -823,10 +829,8 @@ class ComplementNB(BaseTransformer):
|
|
823
829
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
824
830
|
|
825
831
|
if isinstance(dataset, DataFrame):
|
826
|
-
self.
|
827
|
-
|
828
|
-
inference_method=inference_method,
|
829
|
-
)
|
832
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
833
|
+
self._deps = self._get_dependencies()
|
830
834
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
831
835
|
transform_kwargs = dict(
|
832
836
|
session=dataset._session,
|
@@ -890,17 +894,15 @@ class ComplementNB(BaseTransformer):
|
|
890
894
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
891
895
|
|
892
896
|
if isinstance(dataset, DataFrame):
|
893
|
-
self.
|
894
|
-
|
895
|
-
inference_method="score",
|
896
|
-
)
|
897
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
898
|
+
self._deps = self._get_dependencies()
|
897
899
|
selected_cols = self._get_active_columns()
|
898
900
|
if len(selected_cols) > 0:
|
899
901
|
dataset = dataset.select(selected_cols)
|
900
902
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
901
903
|
transform_kwargs = dict(
|
902
904
|
session=dataset._session,
|
903
|
-
dependencies=
|
905
|
+
dependencies=self._deps,
|
904
906
|
score_sproc_imports=['sklearn'],
|
905
907
|
)
|
906
908
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -965,11 +967,8 @@ class ComplementNB(BaseTransformer):
|
|
965
967
|
|
966
968
|
if isinstance(dataset, DataFrame):
|
967
969
|
|
968
|
-
self.
|
969
|
-
|
970
|
-
inference_method=inference_method,
|
971
|
-
|
972
|
-
)
|
970
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
971
|
+
self._deps = self._get_dependencies()
|
973
972
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
974
973
|
transform_kwargs = dict(
|
975
974
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sk
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class GaussianNB(BaseTransformer):
|
70
64
|
r"""Gaussian Naive Bayes (GaussianNB)
|
71
65
|
For more details on this class, see [sklearn.naive_bayes.GaussianNB]
|
@@ -264,20 +258,17 @@ class GaussianNB(BaseTransformer):
|
|
264
258
|
self,
|
265
259
|
dataset: DataFrame,
|
266
260
|
inference_method: str,
|
267
|
-
) ->
|
268
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
269
|
-
return the available package that exists in the snowflake anaconda channel
|
261
|
+
) -> None:
|
262
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
270
263
|
|
271
264
|
Args:
|
272
265
|
dataset: snowpark dataframe
|
273
266
|
inference_method: the inference method such as predict, score...
|
274
|
-
|
267
|
+
|
275
268
|
Raises:
|
276
269
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
277
270
|
SnowflakeMLException: If the session is None, raise error
|
278
271
|
|
279
|
-
Returns:
|
280
|
-
A list of available package that exists in the snowflake anaconda channel
|
281
272
|
"""
|
282
273
|
if not self._is_fitted:
|
283
274
|
raise exceptions.SnowflakeMLException(
|
@@ -295,9 +286,7 @@ class GaussianNB(BaseTransformer):
|
|
295
286
|
"Session must not specified for snowpark dataset."
|
296
287
|
),
|
297
288
|
)
|
298
|
-
|
299
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
300
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
289
|
+
|
301
290
|
|
302
291
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
303
292
|
@telemetry.send_api_usage_telemetry(
|
@@ -345,7 +334,8 @@ class GaussianNB(BaseTransformer):
|
|
345
334
|
|
346
335
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
347
336
|
|
348
|
-
self.
|
337
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
338
|
+
self._deps = self._get_dependencies()
|
349
339
|
assert isinstance(
|
350
340
|
dataset._session, Session
|
351
341
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -428,10 +418,8 @@ class GaussianNB(BaseTransformer):
|
|
428
418
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
429
419
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
430
420
|
|
431
|
-
self.
|
432
|
-
|
433
|
-
inference_method=inference_method,
|
434
|
-
)
|
421
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
422
|
+
self._deps = self._get_dependencies()
|
435
423
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
436
424
|
|
437
425
|
transform_kwargs = dict(
|
@@ -498,16 +486,40 @@ class GaussianNB(BaseTransformer):
|
|
498
486
|
self._is_fitted = True
|
499
487
|
return output_result
|
500
488
|
|
489
|
+
|
490
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
491
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
492
|
+
""" Method not supported for this class.
|
501
493
|
|
502
|
-
|
503
|
-
|
504
|
-
|
494
|
+
|
495
|
+
Raises:
|
496
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
497
|
+
|
498
|
+
Args:
|
499
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
500
|
+
Snowpark or Pandas DataFrame.
|
501
|
+
output_cols_prefix: Prefix for the response columns
|
505
502
|
Returns:
|
506
503
|
Transformed dataset.
|
507
504
|
"""
|
508
|
-
self.
|
509
|
-
|
510
|
-
|
505
|
+
self._infer_input_output_cols(dataset)
|
506
|
+
super()._check_dataset_type(dataset)
|
507
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
508
|
+
estimator=self._sklearn_object,
|
509
|
+
dataset=dataset,
|
510
|
+
input_cols=self.input_cols,
|
511
|
+
label_cols=self.label_cols,
|
512
|
+
sample_weight_col=self.sample_weight_col,
|
513
|
+
autogenerated=self._autogenerated,
|
514
|
+
subproject=_SUBPROJECT,
|
515
|
+
)
|
516
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
517
|
+
drop_input_cols=self._drop_input_cols,
|
518
|
+
expected_output_cols_list=self.output_cols,
|
519
|
+
)
|
520
|
+
self._sklearn_object = fitted_estimator
|
521
|
+
self._is_fitted = True
|
522
|
+
return output_result
|
511
523
|
|
512
524
|
|
513
525
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -600,10 +612,8 @@ class GaussianNB(BaseTransformer):
|
|
600
612
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
601
613
|
|
602
614
|
if isinstance(dataset, DataFrame):
|
603
|
-
self.
|
604
|
-
|
605
|
-
inference_method=inference_method,
|
606
|
-
)
|
615
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
616
|
+
self._deps = self._get_dependencies()
|
607
617
|
assert isinstance(
|
608
618
|
dataset._session, Session
|
609
619
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -670,10 +680,8 @@ class GaussianNB(BaseTransformer):
|
|
670
680
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
671
681
|
|
672
682
|
if isinstance(dataset, DataFrame):
|
673
|
-
self.
|
674
|
-
|
675
|
-
inference_method=inference_method,
|
676
|
-
)
|
683
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
684
|
+
self._deps = self._get_dependencies()
|
677
685
|
assert isinstance(
|
678
686
|
dataset._session, Session
|
679
687
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -735,10 +743,8 @@ class GaussianNB(BaseTransformer):
|
|
735
743
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
736
744
|
|
737
745
|
if isinstance(dataset, DataFrame):
|
738
|
-
self.
|
739
|
-
|
740
|
-
inference_method=inference_method,
|
741
|
-
)
|
746
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
747
|
+
self._deps = self._get_dependencies()
|
742
748
|
assert isinstance(
|
743
749
|
dataset._session, Session
|
744
750
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -804,10 +810,8 @@ class GaussianNB(BaseTransformer):
|
|
804
810
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
805
811
|
|
806
812
|
if isinstance(dataset, DataFrame):
|
807
|
-
self.
|
808
|
-
|
809
|
-
inference_method=inference_method,
|
810
|
-
)
|
813
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
814
|
+
self._deps = self._get_dependencies()
|
811
815
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
812
816
|
transform_kwargs = dict(
|
813
817
|
session=dataset._session,
|
@@ -871,17 +875,15 @@ class GaussianNB(BaseTransformer):
|
|
871
875
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
872
876
|
|
873
877
|
if isinstance(dataset, DataFrame):
|
874
|
-
self.
|
875
|
-
|
876
|
-
inference_method="score",
|
877
|
-
)
|
878
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
879
|
+
self._deps = self._get_dependencies()
|
878
880
|
selected_cols = self._get_active_columns()
|
879
881
|
if len(selected_cols) > 0:
|
880
882
|
dataset = dataset.select(selected_cols)
|
881
883
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
882
884
|
transform_kwargs = dict(
|
883
885
|
session=dataset._session,
|
884
|
-
dependencies=
|
886
|
+
dependencies=self._deps,
|
885
887
|
score_sproc_imports=['sklearn'],
|
886
888
|
)
|
887
889
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -946,11 +948,8 @@ class GaussianNB(BaseTransformer):
|
|
946
948
|
|
947
949
|
if isinstance(dataset, DataFrame):
|
948
950
|
|
949
|
-
self.
|
950
|
-
|
951
|
-
inference_method=inference_method,
|
952
|
-
|
953
|
-
)
|
951
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
952
|
+
self._deps = self._get_dependencies()
|
954
953
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
955
954
|
transform_kwargs = dict(
|
956
955
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sk
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class MultinomialNB(BaseTransformer):
|
70
64
|
r"""Naive Bayes classifier for multinomial models
|
71
65
|
For more details on this class, see [sklearn.naive_bayes.MultinomialNB]
|
@@ -277,20 +271,17 @@ class MultinomialNB(BaseTransformer):
|
|
277
271
|
self,
|
278
272
|
dataset: DataFrame,
|
279
273
|
inference_method: str,
|
280
|
-
) ->
|
281
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
282
|
-
return the available package that exists in the snowflake anaconda channel
|
274
|
+
) -> None:
|
275
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
283
276
|
|
284
277
|
Args:
|
285
278
|
dataset: snowpark dataframe
|
286
279
|
inference_method: the inference method such as predict, score...
|
287
|
-
|
280
|
+
|
288
281
|
Raises:
|
289
282
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
290
283
|
SnowflakeMLException: If the session is None, raise error
|
291
284
|
|
292
|
-
Returns:
|
293
|
-
A list of available package that exists in the snowflake anaconda channel
|
294
285
|
"""
|
295
286
|
if not self._is_fitted:
|
296
287
|
raise exceptions.SnowflakeMLException(
|
@@ -308,9 +299,7 @@ class MultinomialNB(BaseTransformer):
|
|
308
299
|
"Session must not specified for snowpark dataset."
|
309
300
|
),
|
310
301
|
)
|
311
|
-
|
312
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
313
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
302
|
+
|
314
303
|
|
315
304
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
316
305
|
@telemetry.send_api_usage_telemetry(
|
@@ -358,7 +347,8 @@ class MultinomialNB(BaseTransformer):
|
|
358
347
|
|
359
348
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
360
349
|
|
361
|
-
self.
|
350
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
351
|
+
self._deps = self._get_dependencies()
|
362
352
|
assert isinstance(
|
363
353
|
dataset._session, Session
|
364
354
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -441,10 +431,8 @@ class MultinomialNB(BaseTransformer):
|
|
441
431
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
442
432
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
443
433
|
|
444
|
-
self.
|
445
|
-
|
446
|
-
inference_method=inference_method,
|
447
|
-
)
|
434
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
435
|
+
self._deps = self._get_dependencies()
|
448
436
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
449
437
|
|
450
438
|
transform_kwargs = dict(
|
@@ -511,16 +499,40 @@ class MultinomialNB(BaseTransformer):
|
|
511
499
|
self._is_fitted = True
|
512
500
|
return output_result
|
513
501
|
|
502
|
+
|
503
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
504
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
505
|
+
""" Method not supported for this class.
|
514
506
|
|
515
|
-
|
516
|
-
|
517
|
-
|
507
|
+
|
508
|
+
Raises:
|
509
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
510
|
+
|
511
|
+
Args:
|
512
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
513
|
+
Snowpark or Pandas DataFrame.
|
514
|
+
output_cols_prefix: Prefix for the response columns
|
518
515
|
Returns:
|
519
516
|
Transformed dataset.
|
520
517
|
"""
|
521
|
-
self.
|
522
|
-
|
523
|
-
|
518
|
+
self._infer_input_output_cols(dataset)
|
519
|
+
super()._check_dataset_type(dataset)
|
520
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
521
|
+
estimator=self._sklearn_object,
|
522
|
+
dataset=dataset,
|
523
|
+
input_cols=self.input_cols,
|
524
|
+
label_cols=self.label_cols,
|
525
|
+
sample_weight_col=self.sample_weight_col,
|
526
|
+
autogenerated=self._autogenerated,
|
527
|
+
subproject=_SUBPROJECT,
|
528
|
+
)
|
529
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
530
|
+
drop_input_cols=self._drop_input_cols,
|
531
|
+
expected_output_cols_list=self.output_cols,
|
532
|
+
)
|
533
|
+
self._sklearn_object = fitted_estimator
|
534
|
+
self._is_fitted = True
|
535
|
+
return output_result
|
524
536
|
|
525
537
|
|
526
538
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -613,10 +625,8 @@ class MultinomialNB(BaseTransformer):
|
|
613
625
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
614
626
|
|
615
627
|
if isinstance(dataset, DataFrame):
|
616
|
-
self.
|
617
|
-
|
618
|
-
inference_method=inference_method,
|
619
|
-
)
|
628
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
629
|
+
self._deps = self._get_dependencies()
|
620
630
|
assert isinstance(
|
621
631
|
dataset._session, Session
|
622
632
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -683,10 +693,8 @@ class MultinomialNB(BaseTransformer):
|
|
683
693
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
684
694
|
|
685
695
|
if isinstance(dataset, DataFrame):
|
686
|
-
self.
|
687
|
-
|
688
|
-
inference_method=inference_method,
|
689
|
-
)
|
696
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
697
|
+
self._deps = self._get_dependencies()
|
690
698
|
assert isinstance(
|
691
699
|
dataset._session, Session
|
692
700
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -748,10 +756,8 @@ class MultinomialNB(BaseTransformer):
|
|
748
756
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
749
757
|
|
750
758
|
if isinstance(dataset, DataFrame):
|
751
|
-
self.
|
752
|
-
|
753
|
-
inference_method=inference_method,
|
754
|
-
)
|
759
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
760
|
+
self._deps = self._get_dependencies()
|
755
761
|
assert isinstance(
|
756
762
|
dataset._session, Session
|
757
763
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -817,10 +823,8 @@ class MultinomialNB(BaseTransformer):
|
|
817
823
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
818
824
|
|
819
825
|
if isinstance(dataset, DataFrame):
|
820
|
-
self.
|
821
|
-
|
822
|
-
inference_method=inference_method,
|
823
|
-
)
|
826
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
827
|
+
self._deps = self._get_dependencies()
|
824
828
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
825
829
|
transform_kwargs = dict(
|
826
830
|
session=dataset._session,
|
@@ -884,17 +888,15 @@ class MultinomialNB(BaseTransformer):
|
|
884
888
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
885
889
|
|
886
890
|
if isinstance(dataset, DataFrame):
|
887
|
-
self.
|
888
|
-
|
889
|
-
inference_method="score",
|
890
|
-
)
|
891
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
892
|
+
self._deps = self._get_dependencies()
|
891
893
|
selected_cols = self._get_active_columns()
|
892
894
|
if len(selected_cols) > 0:
|
893
895
|
dataset = dataset.select(selected_cols)
|
894
896
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
895
897
|
transform_kwargs = dict(
|
896
898
|
session=dataset._session,
|
897
|
-
dependencies=
|
899
|
+
dependencies=self._deps,
|
898
900
|
score_sproc_imports=['sklearn'],
|
899
901
|
)
|
900
902
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -959,11 +961,8 @@ class MultinomialNB(BaseTransformer):
|
|
959
961
|
|
960
962
|
if isinstance(dataset, DataFrame):
|
961
963
|
|
962
|
-
self.
|
963
|
-
|
964
|
-
inference_method=inference_method,
|
965
|
-
|
966
|
-
)
|
964
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
965
|
+
self._deps = self._get_dependencies()
|
967
966
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
968
967
|
transform_kwargs = dict(
|
969
968
|
session = dataset._session,
|