snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/covariance/elliptic_envelope.py (+51 -52)

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("skl

 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class EllipticEnvelope(BaseTransformer):
     r"""An object for detecting outliers in a Gaussian distributed dataset
     For more details on this class, see [sklearn.covariance.EllipticEnvelope]

@@ -287,20 +281,17 @@ class EllipticEnvelope(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.

         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error

-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(

@@ -318,9 +309,7 @@ class EllipticEnvelope(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(

@@ -368,7 +357,8 @@ class EllipticEnvelope(BaseTransformer):

                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -451,10 +441,8 @@ class EllipticEnvelope(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])

-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()

             transform_kwargs = dict(
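Across these autogenerated estimator wrappers the same refactor repeats: `_batch_inference_validate_snowpark` now only validates and returns `None`, and each inference entry point resolves its dependency list separately and caches it on `self._deps`. A minimal sketch of that pattern, with toy names standing in for the snowflake-ml-python internals (an illustration, not the library's actual code):

```python
from typing import Any, List


class BatchInferenceMixin:
    """Toy illustration of the 1.5.0 split between validation and dependency resolution."""

    _deps: List[str] = []
    _is_fitted: bool = False

    def _batch_inference_validate_snowpark(self, dataset: Any, inference_method: str) -> None:
        # In 1.5.0 this only checks preconditions; it no longer returns a package list.
        if not self._is_fitted:
            raise RuntimeError(f"Estimator must be fitted before calling {inference_method!r}")

    def _get_dependencies(self) -> List[str]:
        # Stand-in for resolving packages against the Snowflake Anaconda channel.
        return ["scikit-learn", "numpy"]

    def predict(self, dataset: Any) -> Any:
        # The new caller pattern repeated throughout this diff:
        self._batch_inference_validate_snowpark(dataset=dataset, inference_method="predict")
        self._deps = self._get_dependencies()
        # ...dispatch batch inference with dependencies=self._deps...
        return dataset
```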
@@ -523,16 +511,40 @@ class EllipticEnvelope(BaseTransformer):
         self._is_fitted = True
         return output_result

+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
+

-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result


     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
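The other recurring change is a concrete `fit_transform` implementation routed through `ModelTrainerBuilder.build_fit_transform`, replacing the always-disabled `_is_fit_transform_method_enabled` stub removed above. The method is only exposed when the wrapped scikit-learn estimator itself has `fit_transform`, so a transformer such as PCA (regenerated with the same change in this release, per the file list) demonstrates it better than the covariance estimators shown here. A hedged usage sketch; the column names and data are made up:

```python
import pandas as pd

from snowflake.ml.modeling.decomposition import PCA  # receives the same fit_transform wiring in 1.5.0

df = pd.DataFrame({"F1": [1.0, 2.0, 3.0, 4.0], "F2": [4.0, 3.0, 2.0, 1.0]})

pca = PCA(n_components=1, input_cols=["F1", "F2"], output_cols=["PC1"])

# Per the signature added in this diff, fit_transform accepts a Snowpark or pandas
# DataFrame and returns the transformed dataset of the same kind.
result = pca.fit_transform(df)
print(result)
```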
@@ -623,10 +635,8 @@ class EllipticEnvelope(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -691,10 +701,8 @@ class EllipticEnvelope(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -758,10 +766,8 @@ class EllipticEnvelope(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -829,10 +835,8 @@ class EllipticEnvelope(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,

@@ -896,17 +900,15 @@ class EllipticEnvelope(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
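All of the batch-inference entry points share this validate-then-resolve-dependencies preamble before running inference inside Snowflake, so the caller-facing flow is unchanged. A hedged end-to-end sketch on a Snowpark DataFrame; the connection parameters, table, and column names are placeholders:

```python
from snowflake.snowpark import Session
from snowflake.ml.modeling.covariance import EllipticEnvelope

# Placeholder connection parameters.
session = Session.builder.configs({
    "account": "<account>",
    "user": "<user>",
    "password": "<password>",
    "warehouse": "<warehouse>",
    "database": "<database>",
    "schema": "<schema>",
}).create()

features = session.table("MY_FEATURES")  # hypothetical table with FEATURE_1, FEATURE_2

detector = EllipticEnvelope(
    input_cols=["FEATURE_1", "FEATURE_2"],
    output_cols=["OUTLIER_LABEL"],
)
detector.fit(features)

# predict() goes through the Snowpark path shown above: it validates the dataset,
# caches self._deps, and then runs batch inference inside Snowflake.
predictions = detector.predict(features)
predictions.show()
```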
@@ -971,11 +973,8 @@ class EllipticEnvelope(BaseTransformer):

         if isinstance(dataset, DataFrame):

-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/covariance/empirical_covariance.py (+51 -52)

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("skl

 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class EmpiricalCovariance(BaseTransformer):
     r"""Maximum likelihood covariance estimator
     For more details on this class, see [sklearn.covariance.EmpiricalCovariance]

@@ -263,20 +257,17 @@ class EmpiricalCovariance(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.

         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error

-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(

@@ -294,9 +285,7 @@ class EmpiricalCovariance(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(

@@ -342,7 +331,8 @@ class EmpiricalCovariance(BaseTransformer):

                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -425,10 +415,8 @@ class EmpiricalCovariance(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])

-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()

             transform_kwargs = dict(

@@ -495,16 +483,40 @@ class EmpiricalCovariance(BaseTransformer):
         self._is_fitted = True
         return output_result

+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
+

-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result


     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:

@@ -595,10 +607,8 @@ class EmpiricalCovariance(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -663,10 +673,8 @@ class EmpiricalCovariance(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -728,10 +736,8 @@ class EmpiricalCovariance(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -797,10 +803,8 @@ class EmpiricalCovariance(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,

@@ -864,17 +868,15 @@ class EmpiricalCovariance(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):

@@ -939,11 +941,8 @@ class EmpiricalCovariance(BaseTransformer):

         if isinstance(dataset, DataFrame):

-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/covariance/graphical_lasso.py (+51 -52)

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("skl

 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class GraphicalLasso(BaseTransformer):
     r"""Sparse inverse covariance estimation with an l1-penalized estimator
     For more details on this class, see [sklearn.covariance.GraphicalLasso]

@@ -311,20 +305,17 @@ class GraphicalLasso(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.

         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error

-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(

@@ -342,9 +333,7 @@ class GraphicalLasso(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(

@@ -390,7 +379,8 @@ class GraphicalLasso(BaseTransformer):

                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -473,10 +463,8 @@ class GraphicalLasso(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])

-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()

             transform_kwargs = dict(

@@ -543,16 +531,40 @@ class GraphicalLasso(BaseTransformer):
         self._is_fitted = True
         return output_result

+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
+

-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result


     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:

@@ -643,10 +655,8 @@ class GraphicalLasso(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -711,10 +721,8 @@ class GraphicalLasso(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -776,10 +784,8 @@ class GraphicalLasso(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -845,10 +851,8 @@ class GraphicalLasso(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,

@@ -912,17 +916,15 @@ class GraphicalLasso(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):

@@ -987,11 +989,8 @@ class GraphicalLasso(BaseTransformer):

         if isinstance(dataset, DataFrame):

-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,