snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff compares the published contents of two package versions released to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/manifold/isomap.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.manifold".replace("sklea
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return True and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class Isomap(BaseTransformer):
     r"""Isomap Embedding
     For more details on this class, see [sklearn.manifold.Isomap]
@@ -339,20 +333,17 @@ class Isomap(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -370,9 +361,7 @@ class Isomap(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -418,7 +407,8 @@ class Isomap(BaseTransformer):
 
             expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-        self.
+        self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+        self._deps = self._get_dependencies()
         assert isinstance(
             dataset._session, Session
         ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -503,10 +493,8 @@ class Isomap(BaseTransformer):
         if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
             expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-        self.
-
-            inference_method=inference_method,
-        )
+        self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+        self._deps = self._get_dependencies()
         assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
         transform_kwargs = dict(
@@ -573,16 +561,42 @@ class Isomap(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit the model from data in X and transform X
+        For more details on this function, see [sklearn.manifold.Isomap.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.manifold.Isomap.html#sklearn.manifold.Isomap.fit_transform)
+
 
-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -673,10 +687,8 @@ class Isomap(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -741,10 +753,8 @@ class Isomap(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -806,10 +816,8 @@ class Isomap(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -875,10 +883,8 @@ class Isomap(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session=dataset._session,
@@ -940,17 +946,15 @@ class Isomap(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1015,11 +1019,8 @@ class Isomap(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session = dataset._session,
snowflake/ml/modeling/manifold/mds.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.manifold".replace("sklea
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return True and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class MDS(BaseTransformer):
     r"""Multidimensional scaling
     For more details on this class, see [sklearn.manifold.MDS]
@@ -322,20 +316,17 @@ class MDS(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -353,9 +344,7 @@ class MDS(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -401,7 +390,8 @@ class MDS(BaseTransformer):
 
             expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-        self.
+        self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+        self._deps = self._get_dependencies()
         assert isinstance(
             dataset._session, Session
         ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -484,10 +474,8 @@ class MDS(BaseTransformer):
         if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
             expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-        self.
-
-            inference_method=inference_method,
-        )
+        self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+        self._deps = self._get_dependencies()
         assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
         transform_kwargs = dict(
@@ -554,16 +542,42 @@ class MDS(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit the data from `X`, and returns the embedded coordinates
+        For more details on this function, see [sklearn.manifold.MDS.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.manifold.MDS.html#sklearn.manifold.MDS.fit_transform)
+
 
-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -654,10 +668,8 @@ class MDS(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -722,10 +734,8 @@ class MDS(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -787,10 +797,8 @@ class MDS(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -856,10 +864,8 @@ class MDS(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session=dataset._session,
@@ -921,17 +927,15 @@ class MDS(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -996,11 +1000,8 @@ class MDS(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session = dataset._session,
snowflake/ml/modeling/manifold/spectral_embedding.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.manifold".replace("sklea
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return True and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class SpectralEmbedding(BaseTransformer):
     r"""Spectral embedding for non-linear dimensionality reduction
     For more details on this class, see [sklearn.manifold.SpectralEmbedding]
@@ -324,20 +318,17 @@ class SpectralEmbedding(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -355,9 +346,7 @@ class SpectralEmbedding(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -403,7 +392,8 @@ class SpectralEmbedding(BaseTransformer):
 
             expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-        self.
+        self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+        self._deps = self._get_dependencies()
         assert isinstance(
             dataset._session, Session
         ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -486,10 +476,8 @@ class SpectralEmbedding(BaseTransformer):
         if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
             expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-        self.
-
-            inference_method=inference_method,
-        )
+        self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+        self._deps = self._get_dependencies()
         assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
         transform_kwargs = dict(
@@ -556,16 +544,42 @@ class SpectralEmbedding(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit the model from data in X and transform X
+        For more details on this function, see [sklearn.manifold.SpectralEmbedding.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.manifold.SpectralEmbedding.html#sklearn.manifold.SpectralEmbedding.fit_transform)
+
 
-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -656,10 +670,8 @@ class SpectralEmbedding(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -724,10 +736,8 @@ class SpectralEmbedding(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -789,10 +799,8 @@ class SpectralEmbedding(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -858,10 +866,8 @@ class SpectralEmbedding(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session=dataset._session,
@@ -923,17 +929,15 @@ class SpectralEmbedding(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -998,11 +1002,8 @@ class SpectralEmbedding(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session = dataset._session,