snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("skl
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class OAS(BaseTransformer):
|
70
64
|
r"""Oracle Approximating Shrinkage Estimator as proposed in [1]_
|
71
65
|
For more details on this class, see [sklearn.covariance.OAS]
|
@@ -263,20 +257,17 @@ class OAS(BaseTransformer):
|
|
263
257
|
self,
|
264
258
|
dataset: DataFrame,
|
265
259
|
inference_method: str,
|
266
|
-
) ->
|
267
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
268
|
-
return the available package that exists in the snowflake anaconda channel
|
260
|
+
) -> None:
|
261
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
269
262
|
|
270
263
|
Args:
|
271
264
|
dataset: snowpark dataframe
|
272
265
|
inference_method: the inference method such as predict, score...
|
273
|
-
|
266
|
+
|
274
267
|
Raises:
|
275
268
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
276
269
|
SnowflakeMLException: If the session is None, raise error
|
277
270
|
|
278
|
-
Returns:
|
279
|
-
A list of available package that exists in the snowflake anaconda channel
|
280
271
|
"""
|
281
272
|
if not self._is_fitted:
|
282
273
|
raise exceptions.SnowflakeMLException(
|
@@ -294,9 +285,7 @@ class OAS(BaseTransformer):
|
|
294
285
|
"Session must not specified for snowpark dataset."
|
295
286
|
),
|
296
287
|
)
|
297
|
-
|
298
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
299
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
288
|
+
|
300
289
|
|
301
290
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
302
291
|
@telemetry.send_api_usage_telemetry(
|
@@ -342,7 +331,8 @@ class OAS(BaseTransformer):
|
|
342
331
|
|
343
332
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
344
333
|
|
345
|
-
self.
|
334
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
335
|
+
self._deps = self._get_dependencies()
|
346
336
|
assert isinstance(
|
347
337
|
dataset._session, Session
|
348
338
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -425,10 +415,8 @@ class OAS(BaseTransformer):
|
|
425
415
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
426
416
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
427
417
|
|
428
|
-
self.
|
429
|
-
|
430
|
-
inference_method=inference_method,
|
431
|
-
)
|
418
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
419
|
+
self._deps = self._get_dependencies()
|
432
420
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
433
421
|
|
434
422
|
transform_kwargs = dict(
|
@@ -495,16 +483,40 @@ class OAS(BaseTransformer):
|
|
495
483
|
self._is_fitted = True
|
496
484
|
return output_result
|
497
485
|
|
486
|
+
|
487
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
488
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
489
|
+
""" Method not supported for this class.
|
498
490
|
|
499
|
-
|
500
|
-
|
501
|
-
|
491
|
+
|
492
|
+
Raises:
|
493
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
494
|
+
|
495
|
+
Args:
|
496
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
497
|
+
Snowpark or Pandas DataFrame.
|
498
|
+
output_cols_prefix: Prefix for the response columns
|
502
499
|
Returns:
|
503
500
|
Transformed dataset.
|
504
501
|
"""
|
505
|
-
self.
|
506
|
-
|
507
|
-
|
502
|
+
self._infer_input_output_cols(dataset)
|
503
|
+
super()._check_dataset_type(dataset)
|
504
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
505
|
+
estimator=self._sklearn_object,
|
506
|
+
dataset=dataset,
|
507
|
+
input_cols=self.input_cols,
|
508
|
+
label_cols=self.label_cols,
|
509
|
+
sample_weight_col=self.sample_weight_col,
|
510
|
+
autogenerated=self._autogenerated,
|
511
|
+
subproject=_SUBPROJECT,
|
512
|
+
)
|
513
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
514
|
+
drop_input_cols=self._drop_input_cols,
|
515
|
+
expected_output_cols_list=self.output_cols,
|
516
|
+
)
|
517
|
+
self._sklearn_object = fitted_estimator
|
518
|
+
self._is_fitted = True
|
519
|
+
return output_result
|
508
520
|
|
509
521
|
|
510
522
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -595,10 +607,8 @@ class OAS(BaseTransformer):
|
|
595
607
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
596
608
|
|
597
609
|
if isinstance(dataset, DataFrame):
|
598
|
-
self.
|
599
|
-
|
600
|
-
inference_method=inference_method,
|
601
|
-
)
|
610
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
611
|
+
self._deps = self._get_dependencies()
|
602
612
|
assert isinstance(
|
603
613
|
dataset._session, Session
|
604
614
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -663,10 +673,8 @@ class OAS(BaseTransformer):
|
|
663
673
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
664
674
|
|
665
675
|
if isinstance(dataset, DataFrame):
|
666
|
-
self.
|
667
|
-
|
668
|
-
inference_method=inference_method,
|
669
|
-
)
|
676
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
677
|
+
self._deps = self._get_dependencies()
|
670
678
|
assert isinstance(
|
671
679
|
dataset._session, Session
|
672
680
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -728,10 +736,8 @@ class OAS(BaseTransformer):
|
|
728
736
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
729
737
|
|
730
738
|
if isinstance(dataset, DataFrame):
|
731
|
-
self.
|
732
|
-
|
733
|
-
inference_method=inference_method,
|
734
|
-
)
|
739
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
740
|
+
self._deps = self._get_dependencies()
|
735
741
|
assert isinstance(
|
736
742
|
dataset._session, Session
|
737
743
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -797,10 +803,8 @@ class OAS(BaseTransformer):
|
|
797
803
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
798
804
|
|
799
805
|
if isinstance(dataset, DataFrame):
|
800
|
-
self.
|
801
|
-
|
802
|
-
inference_method=inference_method,
|
803
|
-
)
|
806
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
807
|
+
self._deps = self._get_dependencies()
|
804
808
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
805
809
|
transform_kwargs = dict(
|
806
810
|
session=dataset._session,
|
@@ -864,17 +868,15 @@ class OAS(BaseTransformer):
|
|
864
868
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
865
869
|
|
866
870
|
if isinstance(dataset, DataFrame):
|
867
|
-
self.
|
868
|
-
|
869
|
-
inference_method="score",
|
870
|
-
)
|
871
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
872
|
+
self._deps = self._get_dependencies()
|
871
873
|
selected_cols = self._get_active_columns()
|
872
874
|
if len(selected_cols) > 0:
|
873
875
|
dataset = dataset.select(selected_cols)
|
874
876
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
875
877
|
transform_kwargs = dict(
|
876
878
|
session=dataset._session,
|
877
|
-
dependencies=
|
879
|
+
dependencies=self._deps,
|
878
880
|
score_sproc_imports=['sklearn'],
|
879
881
|
)
|
880
882
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -939,11 +941,8 @@ class OAS(BaseTransformer):
|
|
939
941
|
|
940
942
|
if isinstance(dataset, DataFrame):
|
941
943
|
|
942
|
-
self.
|
943
|
-
|
944
|
-
inference_method=inference_method,
|
945
|
-
|
946
|
-
)
|
944
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
945
|
+
self._deps = self._get_dependencies()
|
947
946
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
948
947
|
transform_kwargs = dict(
|
949
948
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("skl
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class ShrunkCovariance(BaseTransformer):
|
70
64
|
r"""Covariance estimator with shrinkage
|
71
65
|
For more details on this class, see [sklearn.covariance.ShrunkCovariance]
|
@@ -269,20 +263,17 @@ class ShrunkCovariance(BaseTransformer):
|
|
269
263
|
self,
|
270
264
|
dataset: DataFrame,
|
271
265
|
inference_method: str,
|
272
|
-
) ->
|
273
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
274
|
-
return the available package that exists in the snowflake anaconda channel
|
266
|
+
) -> None:
|
267
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
275
268
|
|
276
269
|
Args:
|
277
270
|
dataset: snowpark dataframe
|
278
271
|
inference_method: the inference method such as predict, score...
|
279
|
-
|
272
|
+
|
280
273
|
Raises:
|
281
274
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
282
275
|
SnowflakeMLException: If the session is None, raise error
|
283
276
|
|
284
|
-
Returns:
|
285
|
-
A list of available package that exists in the snowflake anaconda channel
|
286
277
|
"""
|
287
278
|
if not self._is_fitted:
|
288
279
|
raise exceptions.SnowflakeMLException(
|
@@ -300,9 +291,7 @@ class ShrunkCovariance(BaseTransformer):
|
|
300
291
|
"Session must not specified for snowpark dataset."
|
301
292
|
),
|
302
293
|
)
|
303
|
-
|
304
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
305
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
294
|
+
|
306
295
|
|
307
296
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
308
297
|
@telemetry.send_api_usage_telemetry(
|
@@ -348,7 +337,8 @@ class ShrunkCovariance(BaseTransformer):
|
|
348
337
|
|
349
338
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
350
339
|
|
351
|
-
self.
|
340
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
341
|
+
self._deps = self._get_dependencies()
|
352
342
|
assert isinstance(
|
353
343
|
dataset._session, Session
|
354
344
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -431,10 +421,8 @@ class ShrunkCovariance(BaseTransformer):
|
|
431
421
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
432
422
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
433
423
|
|
434
|
-
self.
|
435
|
-
|
436
|
-
inference_method=inference_method,
|
437
|
-
)
|
424
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
425
|
+
self._deps = self._get_dependencies()
|
438
426
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
439
427
|
|
440
428
|
transform_kwargs = dict(
|
@@ -501,16 +489,40 @@ class ShrunkCovariance(BaseTransformer):
|
|
501
489
|
self._is_fitted = True
|
502
490
|
return output_result
|
503
491
|
|
492
|
+
|
493
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
494
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
495
|
+
""" Method not supported for this class.
|
504
496
|
|
505
|
-
|
506
|
-
|
507
|
-
|
497
|
+
|
498
|
+
Raises:
|
499
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
500
|
+
|
501
|
+
Args:
|
502
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
503
|
+
Snowpark or Pandas DataFrame.
|
504
|
+
output_cols_prefix: Prefix for the response columns
|
508
505
|
Returns:
|
509
506
|
Transformed dataset.
|
510
507
|
"""
|
511
|
-
self.
|
512
|
-
|
513
|
-
|
508
|
+
self._infer_input_output_cols(dataset)
|
509
|
+
super()._check_dataset_type(dataset)
|
510
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
511
|
+
estimator=self._sklearn_object,
|
512
|
+
dataset=dataset,
|
513
|
+
input_cols=self.input_cols,
|
514
|
+
label_cols=self.label_cols,
|
515
|
+
sample_weight_col=self.sample_weight_col,
|
516
|
+
autogenerated=self._autogenerated,
|
517
|
+
subproject=_SUBPROJECT,
|
518
|
+
)
|
519
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
520
|
+
drop_input_cols=self._drop_input_cols,
|
521
|
+
expected_output_cols_list=self.output_cols,
|
522
|
+
)
|
523
|
+
self._sklearn_object = fitted_estimator
|
524
|
+
self._is_fitted = True
|
525
|
+
return output_result
|
514
526
|
|
515
527
|
|
516
528
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -601,10 +613,8 @@ class ShrunkCovariance(BaseTransformer):
|
|
601
613
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
602
614
|
|
603
615
|
if isinstance(dataset, DataFrame):
|
604
|
-
self.
|
605
|
-
|
606
|
-
inference_method=inference_method,
|
607
|
-
)
|
616
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
617
|
+
self._deps = self._get_dependencies()
|
608
618
|
assert isinstance(
|
609
619
|
dataset._session, Session
|
610
620
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -669,10 +679,8 @@ class ShrunkCovariance(BaseTransformer):
|
|
669
679
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
670
680
|
|
671
681
|
if isinstance(dataset, DataFrame):
|
672
|
-
self.
|
673
|
-
|
674
|
-
inference_method=inference_method,
|
675
|
-
)
|
682
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
683
|
+
self._deps = self._get_dependencies()
|
676
684
|
assert isinstance(
|
677
685
|
dataset._session, Session
|
678
686
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -734,10 +742,8 @@ class ShrunkCovariance(BaseTransformer):
|
|
734
742
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
735
743
|
|
736
744
|
if isinstance(dataset, DataFrame):
|
737
|
-
self.
|
738
|
-
|
739
|
-
inference_method=inference_method,
|
740
|
-
)
|
745
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
746
|
+
self._deps = self._get_dependencies()
|
741
747
|
assert isinstance(
|
742
748
|
dataset._session, Session
|
743
749
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -803,10 +809,8 @@ class ShrunkCovariance(BaseTransformer):
|
|
803
809
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
804
810
|
|
805
811
|
if isinstance(dataset, DataFrame):
|
806
|
-
self.
|
807
|
-
|
808
|
-
inference_method=inference_method,
|
809
|
-
)
|
812
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
813
|
+
self._deps = self._get_dependencies()
|
810
814
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
811
815
|
transform_kwargs = dict(
|
812
816
|
session=dataset._session,
|
@@ -870,17 +874,15 @@ class ShrunkCovariance(BaseTransformer):
|
|
870
874
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
871
875
|
|
872
876
|
if isinstance(dataset, DataFrame):
|
873
|
-
self.
|
874
|
-
|
875
|
-
inference_method="score",
|
876
|
-
)
|
877
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
878
|
+
self._deps = self._get_dependencies()
|
877
879
|
selected_cols = self._get_active_columns()
|
878
880
|
if len(selected_cols) > 0:
|
879
881
|
dataset = dataset.select(selected_cols)
|
880
882
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
881
883
|
transform_kwargs = dict(
|
882
884
|
session=dataset._session,
|
883
|
-
dependencies=
|
885
|
+
dependencies=self._deps,
|
884
886
|
score_sproc_imports=['sklearn'],
|
885
887
|
)
|
886
888
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -945,11 +947,8 @@ class ShrunkCovariance(BaseTransformer):
|
|
945
947
|
|
946
948
|
if isinstance(dataset, DataFrame):
|
947
949
|
|
948
|
-
self.
|
949
|
-
|
950
|
-
inference_method=inference_method,
|
951
|
-
|
952
|
-
)
|
950
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
951
|
+
self._deps = self._get_dependencies()
|
953
952
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
954
953
|
transform_kwargs = dict(
|
955
954
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class DictionaryLearning(BaseTransformer):
|
70
64
|
r"""Dictionary learning
|
71
65
|
For more details on this class, see [sklearn.decomposition.DictionaryLearning]
|
@@ -375,20 +369,17 @@ class DictionaryLearning(BaseTransformer):
|
|
375
369
|
self,
|
376
370
|
dataset: DataFrame,
|
377
371
|
inference_method: str,
|
378
|
-
) ->
|
379
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
380
|
-
return the available package that exists in the snowflake anaconda channel
|
372
|
+
) -> None:
|
373
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
381
374
|
|
382
375
|
Args:
|
383
376
|
dataset: snowpark dataframe
|
384
377
|
inference_method: the inference method such as predict, score...
|
385
|
-
|
378
|
+
|
386
379
|
Raises:
|
387
380
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
388
381
|
SnowflakeMLException: If the session is None, raise error
|
389
382
|
|
390
|
-
Returns:
|
391
|
-
A list of available package that exists in the snowflake anaconda channel
|
392
383
|
"""
|
393
384
|
if not self._is_fitted:
|
394
385
|
raise exceptions.SnowflakeMLException(
|
@@ -406,9 +397,7 @@ class DictionaryLearning(BaseTransformer):
|
|
406
397
|
"Session must not specified for snowpark dataset."
|
407
398
|
),
|
408
399
|
)
|
409
|
-
|
410
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
411
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
400
|
+
|
412
401
|
|
413
402
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
414
403
|
@telemetry.send_api_usage_telemetry(
|
@@ -454,7 +443,8 @@ class DictionaryLearning(BaseTransformer):
|
|
454
443
|
|
455
444
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
456
445
|
|
457
|
-
self.
|
446
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
447
|
+
self._deps = self._get_dependencies()
|
458
448
|
assert isinstance(
|
459
449
|
dataset._session, Session
|
460
450
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -539,10 +529,8 @@ class DictionaryLearning(BaseTransformer):
|
|
539
529
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
540
530
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
541
531
|
|
542
|
-
self.
|
543
|
-
|
544
|
-
inference_method=inference_method,
|
545
|
-
)
|
532
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
533
|
+
self._deps = self._get_dependencies()
|
546
534
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
547
535
|
|
548
536
|
transform_kwargs = dict(
|
@@ -609,16 +597,42 @@ class DictionaryLearning(BaseTransformer):
|
|
609
597
|
self._is_fitted = True
|
610
598
|
return output_result
|
611
599
|
|
600
|
+
|
601
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
602
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
603
|
+
""" Fit the model from data in X and return the transformed data
|
604
|
+
For more details on this function, see [sklearn.decomposition.DictionaryLearning.fit_transform]
|
605
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.DictionaryLearning.html#sklearn.decomposition.DictionaryLearning.fit_transform)
|
606
|
+
|
612
607
|
|
613
|
-
|
614
|
-
|
615
|
-
|
608
|
+
Raises:
|
609
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
610
|
+
|
611
|
+
Args:
|
612
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
613
|
+
Snowpark or Pandas DataFrame.
|
614
|
+
output_cols_prefix: Prefix for the response columns
|
616
615
|
Returns:
|
617
616
|
Transformed dataset.
|
618
617
|
"""
|
619
|
-
self.
|
620
|
-
|
621
|
-
|
618
|
+
self._infer_input_output_cols(dataset)
|
619
|
+
super()._check_dataset_type(dataset)
|
620
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
621
|
+
estimator=self._sklearn_object,
|
622
|
+
dataset=dataset,
|
623
|
+
input_cols=self.input_cols,
|
624
|
+
label_cols=self.label_cols,
|
625
|
+
sample_weight_col=self.sample_weight_col,
|
626
|
+
autogenerated=self._autogenerated,
|
627
|
+
subproject=_SUBPROJECT,
|
628
|
+
)
|
629
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
630
|
+
drop_input_cols=self._drop_input_cols,
|
631
|
+
expected_output_cols_list=self.output_cols,
|
632
|
+
)
|
633
|
+
self._sklearn_object = fitted_estimator
|
634
|
+
self._is_fitted = True
|
635
|
+
return output_result
|
622
636
|
|
623
637
|
|
624
638
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -709,10 +723,8 @@ class DictionaryLearning(BaseTransformer):
|
|
709
723
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
710
724
|
|
711
725
|
if isinstance(dataset, DataFrame):
|
712
|
-
self.
|
713
|
-
|
714
|
-
inference_method=inference_method,
|
715
|
-
)
|
726
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
727
|
+
self._deps = self._get_dependencies()
|
716
728
|
assert isinstance(
|
717
729
|
dataset._session, Session
|
718
730
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -777,10 +789,8 @@ class DictionaryLearning(BaseTransformer):
|
|
777
789
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
778
790
|
|
779
791
|
if isinstance(dataset, DataFrame):
|
780
|
-
self.
|
781
|
-
|
782
|
-
inference_method=inference_method,
|
783
|
-
)
|
792
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
793
|
+
self._deps = self._get_dependencies()
|
784
794
|
assert isinstance(
|
785
795
|
dataset._session, Session
|
786
796
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -842,10 +852,8 @@ class DictionaryLearning(BaseTransformer):
|
|
842
852
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
843
853
|
|
844
854
|
if isinstance(dataset, DataFrame):
|
845
|
-
self.
|
846
|
-
|
847
|
-
inference_method=inference_method,
|
848
|
-
)
|
855
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
856
|
+
self._deps = self._get_dependencies()
|
849
857
|
assert isinstance(
|
850
858
|
dataset._session, Session
|
851
859
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -911,10 +919,8 @@ class DictionaryLearning(BaseTransformer):
|
|
911
919
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
912
920
|
|
913
921
|
if isinstance(dataset, DataFrame):
|
914
|
-
self.
|
915
|
-
|
916
|
-
inference_method=inference_method,
|
917
|
-
)
|
922
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
923
|
+
self._deps = self._get_dependencies()
|
918
924
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
919
925
|
transform_kwargs = dict(
|
920
926
|
session=dataset._session,
|
@@ -976,17 +982,15 @@ class DictionaryLearning(BaseTransformer):
|
|
976
982
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
977
983
|
|
978
984
|
if isinstance(dataset, DataFrame):
|
979
|
-
self.
|
980
|
-
|
981
|
-
inference_method="score",
|
982
|
-
)
|
985
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
986
|
+
self._deps = self._get_dependencies()
|
983
987
|
selected_cols = self._get_active_columns()
|
984
988
|
if len(selected_cols) > 0:
|
985
989
|
dataset = dataset.select(selected_cols)
|
986
990
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
987
991
|
transform_kwargs = dict(
|
988
992
|
session=dataset._session,
|
989
|
-
dependencies=
|
993
|
+
dependencies=self._deps,
|
990
994
|
score_sproc_imports=['sklearn'],
|
991
995
|
)
|
992
996
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1051,11 +1055,8 @@ class DictionaryLearning(BaseTransformer):
|
|
1051
1055
|
|
1052
1056
|
if isinstance(dataset, DataFrame):
|
1053
1057
|
|
1054
|
-
self.
|
1055
|
-
|
1056
|
-
inference_method=inference_method,
|
1057
|
-
|
1058
|
-
)
|
1058
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1059
|
+
self._deps = self._get_dependencies()
|
1059
1060
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1060
1061
|
transform_kwargs = dict(
|
1061
1062
|
session = dataset._session,
|