snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in the public registry.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/cluster/optics.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear

 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class OPTICS(BaseTransformer):
     r"""Estimate clustering structure from vector array
     For more details on this class, see [sklearn.cluster.OPTICS]
@@ -384,20 +378,17 @@ class OPTICS(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.

         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error

-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -415,9 +406,7 @@ class OPTICS(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -463,7 +452,8 @@ class OPTICS(BaseTransformer):

             expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -546,10 +536,8 @@ class OPTICS(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])

-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()

         transform_kwargs = dict(
@@ -618,16 +606,40 @@ class OPTICS(BaseTransformer):
         self._is_fitted = True
         return output_result

+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
+

-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result


     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -718,10 +730,8 @@ class OPTICS(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -786,10 +796,8 @@ class OPTICS(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -851,10 +859,8 @@ class OPTICS(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -920,10 +926,8 @@ class OPTICS(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -985,17 +989,15 @@ class OPTICS(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1060,11 +1062,8 @@ class OPTICS(BaseTransformer):

         if isinstance(dataset, DataFrame):

-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
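The optics.py hunks above remove the always-disabled _is_fit_transform_method_enabled gate and generate a concrete fit_transform body that delegates to ModelTrainerBuilder.build_fit_transform. For OPTICS itself the method stays hidden, because sklearn.cluster.OPTICS has no fit_transform and the @available_if(original_estimator_has_callable("fit_transform")) guard only exposes it when the wrapped estimator provides one (hence the generated "Method not supported for this class." docstring). For wrappers whose scikit-learn class does implement fit_transform, such as snowflake.ml.modeling.decomposition.PCA from the same release, a minimal usage sketch looks like the following; the column names and data are made up, and only the fit_transform signature comes from the diff above.

# Sketch only: column names and data are hypothetical; the
# fit_transform(dataset, output_cols_prefix="fit_transform_") signature is
# taken from the regenerated code shown in the diff above.
import pandas as pd

from snowflake.ml.modeling.decomposition import PCA

pdf = pd.DataFrame(
    {
        "FT1": [1.0, 2.0, 3.0, 4.0],
        "FT2": [0.5, 0.1, 0.9, 0.7],
    }
)

pca = PCA(n_components=2, input_cols=["FT1", "FT2"], output_cols=["PC1", "PC2"])

# New in 1.5.0: fits the underlying estimator and returns the transformed
# dataset in one call; accepts either a Snowpark DataFrame or a pandas
# DataFrame.
transformed = pca.fit_transform(pdf)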
snowflake/ml/modeling/cluster/spectral_biclustering.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear

 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class SpectralBiclustering(BaseTransformer):
     r"""Spectral biclustering (Kluger, 2003)
     For more details on this class, see [sklearn.cluster.SpectralBiclustering]
@@ -322,20 +316,17 @@ class SpectralBiclustering(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.

         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error

-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -353,9 +344,7 @@ class SpectralBiclustering(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -401,7 +390,8 @@ class SpectralBiclustering(BaseTransformer):

             expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -484,10 +474,8 @@ class SpectralBiclustering(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])

-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()

         transform_kwargs = dict(
@@ -554,16 +542,40 @@ class SpectralBiclustering(BaseTransformer):
         self._is_fitted = True
         return output_result

+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
+

-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result


     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -654,10 +666,8 @@ class SpectralBiclustering(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -722,10 +732,8 @@ class SpectralBiclustering(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -787,10 +795,8 @@ class SpectralBiclustering(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -856,10 +862,8 @@ class SpectralBiclustering(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -921,17 +925,15 @@ class SpectralBiclustering(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -996,11 +998,8 @@ class SpectralBiclustering(BaseTransformer):

         if isinstance(dataset, DataFrame):

-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
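The identical fit_transform block is generated for SpectralBiclustering, and it likewise stays unreachable because sklearn.cluster.SpectralBiclustering implements no fit_transform. A small sketch of how the available_if guard behaves at attribute access; the data, column names, and n_clusters value below are arbitrary.

# Sketch only: illustrates the @available_if guard added in the diff above;
# the data, column names, and n_clusters value are made up.
import pandas as pd

from snowflake.ml.modeling.cluster import SpectralBiclustering

pdf = pd.DataFrame(
    {
        "FT1": [1.0, 2.0, 3.0, 4.0],
        "FT2": [4.0, 3.0, 2.0, 1.0],
    }
)

model = SpectralBiclustering(n_clusters=2, input_cols=["FT1", "FT2"])

try:
    model.fit_transform(pdf)
except AttributeError:
    # sklearn.cluster.SpectralBiclustering exposes no fit_transform, so
    # original_estimator_has_callable("fit_transform") fails the check and the
    # generated method is hidden, matching its "Method not supported for this
    # class." docstring.
    print("fit_transform is not available on SpectralBiclustering")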
snowflake/ml/modeling/cluster/spectral_clustering.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear

 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class SpectralClustering(BaseTransformer):
     r"""Apply clustering to a projection of the normalized Laplacian
     For more details on this class, see [sklearn.cluster.SpectralClustering]
@@ -380,20 +374,17 @@ class SpectralClustering(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.

         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error

-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -411,9 +402,7 @@ class SpectralClustering(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -459,7 +448,8 @@ class SpectralClustering(BaseTransformer):

             expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -542,10 +532,8 @@ class SpectralClustering(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])

-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()

         transform_kwargs = dict(
@@ -614,16 +602,40 @@ class SpectralClustering(BaseTransformer):
         self._is_fitted = True
         return output_result

+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
+

-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result


     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -714,10 +726,8 @@ class SpectralClustering(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -782,10 +792,8 @@ class SpectralClustering(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -847,10 +855,8 @@ class SpectralClustering(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -916,10 +922,8 @@ class SpectralClustering(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -981,17 +985,15 @@ class SpectralClustering(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1056,11 +1058,8 @@ class SpectralClustering(BaseTransformer):

         if isinstance(dataset, DataFrame):

-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
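Across optics.py, spectral_biclustering.py, spectral_clustering.py, and the other regenerated estimators in the file list, every Snowpark code path changes the same way: _batch_inference_validate_snowpark now only validates and returns None, and the dependency list is resolved in a separate step and passed along (dependencies=self._deps in the score path). A condensed sketch of that pattern follows; prepare_batch_inference is a hypothetical helper used only for illustration, not a function in the package.

# Condensed sketch of the 1.5.0 pattern shown in the hunks above;
# "prepare_batch_inference" is a made-up helper name.
from snowflake.snowpark import DataFrame, Session


def prepare_batch_inference(estimator, dataset: DataFrame, inference_method: str) -> dict:
    # 1.5.0: validation no longer returns the package list...
    estimator._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
    # ...dependencies are resolved explicitly afterwards and cached on the
    # estimator, then passed straight into the transform/score kwargs.
    estimator._deps = estimator._get_dependencies()
    assert isinstance(dataset._session, Session)
    return dict(session=dataset._session, dependencies=estimator._deps)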