snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/covariance/graphical_lasso_cv.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("skl

 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class GraphicalLassoCV(BaseTransformer):
     r"""Sparse inverse covariance w/ cross-validated choice of the l1 penalty
     For more details on this class, see [sklearn.covariance.GraphicalLassoCV]
@@ -337,20 +331,17 @@ class GraphicalLassoCV(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.

         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error

-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -368,9 +359,7 @@ class GraphicalLassoCV(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -416,7 +405,8 @@ class GraphicalLassoCV(BaseTransformer):

                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -499,10 +489,8 @@ class GraphicalLassoCV(BaseTransformer):
                 if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                     expected_dtype = convert_sp_to_sf_type(output_types[0])

-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()

             transform_kwargs = dict(
@@ -569,16 +557,40 @@ class GraphicalLassoCV(BaseTransformer):
         self._is_fitted = True
         return output_result

+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.

-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result


     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -669,10 +681,8 @@ class GraphicalLassoCV(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -737,10 +747,8 @@ class GraphicalLassoCV(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -802,10 +810,8 @@ class GraphicalLassoCV(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -871,10 +877,8 @@ class GraphicalLassoCV(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -938,17 +942,15 @@ class GraphicalLassoCV(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1013,11 +1015,8 @@ class GraphicalLassoCV(BaseTransformer):

         if isinstance(dataset, DataFrame):

-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
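Across the regenerated modeling classes, the call-site change is the same two-step pattern: `_batch_inference_validate_snowpark()` now only validates (it returns `None`), and the dependency list is resolved separately via `self._get_dependencies()`. For orientation, the sketch below shows the batch-inference path these hunks touch from the user's side. It is not taken from the package: the connection values, table name, and column names are placeholders, and `LinearRegression` (also regenerated in this release, per the file list) is used because the covariance estimators above do not expose `predict`.

```python
# Hypothetical sketch of the predict path affected by this diff.
# Connection values, MY_TRAINING_TABLE, and FEATURE_*/TARGET are placeholders.
from snowflake.snowpark import Session
from snowflake.ml.modeling.linear_model import LinearRegression

connection_parameters = {
    "account": "<account>",
    "user": "<user>",
    "password": "<password>",
    "warehouse": "<warehouse>",
    "database": "<database>",
    "schema": "<schema>",
}
session = Session.builder.configs(connection_parameters).create()

train_df = session.table("MY_TRAINING_TABLE")

regressor = LinearRegression(
    input_cols=["FEATURE_1", "FEATURE_2"],
    label_cols=["TARGET"],
    output_cols=["PREDICTED_TARGET"],
)
regressor.fit(train_df)

# On a Snowpark DataFrame, predict() first runs
# _batch_inference_validate_snowpark(dataset=..., inference_method="predict")
# and then resolves conda dependencies via self._get_dependencies(),
# as shown in the hunks above.
predictions = regressor.predict(train_df)
predictions.show()
```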
snowflake/ml/modeling/covariance/ledoit_wolf.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("skl

 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class LedoitWolf(BaseTransformer):
     r"""LedoitWolf Estimator
     For more details on this class, see [sklearn.covariance.LedoitWolf]
@@ -270,20 +264,17 @@ class LedoitWolf(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.

         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error

-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -301,9 +292,7 @@ class LedoitWolf(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -349,7 +338,8 @@ class LedoitWolf(BaseTransformer):

                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -432,10 +422,8 @@ class LedoitWolf(BaseTransformer):
                 if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                     expected_dtype = convert_sp_to_sf_type(output_types[0])

-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()

             transform_kwargs = dict(
@@ -502,16 +490,40 @@ class LedoitWolf(BaseTransformer):
         self._is_fitted = True
         return output_result

+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.

-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result


     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -602,10 +614,8 @@ class LedoitWolf(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -670,10 +680,8 @@ class LedoitWolf(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -735,10 +743,8 @@ class LedoitWolf(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -804,10 +810,8 @@ class LedoitWolf(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -871,17 +875,15 @@ class LedoitWolf(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -946,11 +948,8 @@ class LedoitWolf(BaseTransformer):

         if isinstance(dataset, DataFrame):

-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/covariance/min_cov_det.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("skl

 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class MinCovDet(BaseTransformer):
     r"""Minimum Covariance Determinant (MCD): robust estimator of covariance
     For more details on this class, see [sklearn.covariance.MinCovDet]
@@ -282,20 +276,17 @@ class MinCovDet(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.

         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error

-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -313,9 +304,7 @@ class MinCovDet(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -361,7 +350,8 @@ class MinCovDet(BaseTransformer):

                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -444,10 +434,8 @@ class MinCovDet(BaseTransformer):
                 if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                     expected_dtype = convert_sp_to_sf_type(output_types[0])

-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()

             transform_kwargs = dict(
@@ -514,16 +502,40 @@ class MinCovDet(BaseTransformer):
         self._is_fitted = True
         return output_result

+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.

-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result


     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -614,10 +626,8 @@ class MinCovDet(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -682,10 +692,8 @@ class MinCovDet(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -747,10 +755,8 @@ class MinCovDet(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -816,10 +822,8 @@ class MinCovDet(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -883,17 +887,15 @@ class MinCovDet(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -958,11 +960,8 @@ class MinCovDet(BaseTransformer):

         if isinstance(dataset, DataFrame):

-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
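The `fit_transform` method added in these hunks is generated for every modeling class: the old module-level guard `_is_fit_transform_method_enabled()` (which always returned `False`) is removed, and the new implementation delegates to `ModelTrainerBuilder.build_fit_transform()`. The covariance classes above still document the method as unsupported, so the hypothetical sketch below uses `PCA` instead (regenerated with the same +53/-52 change, per the file list). It reuses the `session` object from the earlier sketch; the table and column names are placeholders.

```python
# Hypothetical sketch of the new fit_transform() path; assumes the `session`
# from the previous example. Table and column names are placeholders.
from snowflake.ml.modeling.decomposition import PCA

features_df = session.table("MY_FEATURES_TABLE")

pca = PCA(
    n_components=2,
    input_cols=["FEATURE_1", "FEATURE_2", "FEATURE_3"],
    output_cols=["PC_1", "PC_2"],
)

# fit_transform() builds a trainer via ModelTrainerBuilder.build_fit_transform()
# and calls train_fit_transform(), returning the transformed dataset while
# keeping the fitted scikit-learn object on the wrapper.
components = pca.fit_transform(features_df)
# `components` is a DataFrame (Snowpark or pandas, matching the input type).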