snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -832,6 +832,18 @@ class OneHotEncoder(base.BaseTransformer):
|
|
832
832
|
|
833
833
|
# columns: COLUMN_NAME, CATEGORY, COUNT, FITTED_CATEGORY, ENCODING, N_FEATURES_OUT, ENCODED_VALUE, OUTPUT_CATs
|
834
834
|
assert dataset._session is not None
|
835
|
+
|
836
|
+
def convert_to_string_excluding_nan(item: Any) -> Union[None, str]:
|
837
|
+
if pd.isna(item):
|
838
|
+
return None # or np.nan if you prefer to keep as NaN
|
839
|
+
else:
|
840
|
+
return str(item)
|
841
|
+
|
842
|
+
# In case of fitting with pandas dataframe and transforming with snowpark dataframe
|
843
|
+
# state_pandas cannot recognize the datatype of _CATEGORY and _FITTED_CATEGORY column
|
844
|
+
# Therefore, apply the convert_to_string_excluding_nan function to _CATEGORY and _FITTED_CATEGORY
|
845
|
+
state_pandas[[_CATEGORY]] = state_pandas[[_CATEGORY]].applymap(convert_to_string_excluding_nan)
|
846
|
+
state_pandas[[_FITTED_CATEGORY]] = state_pandas[[_FITTED_CATEGORY]].applymap(convert_to_string_excluding_nan)
|
835
847
|
state_df = dataset._session.create_dataframe(state_pandas)
|
836
848
|
|
837
849
|
transformed_dataset = dataset
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.preprocessing".replace("
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class PolynomialFeatures(BaseTransformer):
|
70
64
|
r"""Generate polynomial and interaction features
|
71
65
|
For more details on this class, see [sklearn.preprocessing.PolynomialFeatures]
|
@@ -283,20 +277,17 @@ class PolynomialFeatures(BaseTransformer):
|
|
283
277
|
self,
|
284
278
|
dataset: DataFrame,
|
285
279
|
inference_method: str,
|
286
|
-
) ->
|
287
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
288
|
-
return the available package that exists in the snowflake anaconda channel
|
280
|
+
) -> None:
|
281
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
289
282
|
|
290
283
|
Args:
|
291
284
|
dataset: snowpark dataframe
|
292
285
|
inference_method: the inference method such as predict, score...
|
293
|
-
|
286
|
+
|
294
287
|
Raises:
|
295
288
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
296
289
|
SnowflakeMLException: If the session is None, raise error
|
297
290
|
|
298
|
-
Returns:
|
299
|
-
A list of available package that exists in the snowflake anaconda channel
|
300
291
|
"""
|
301
292
|
if not self._is_fitted:
|
302
293
|
raise exceptions.SnowflakeMLException(
|
@@ -314,9 +305,7 @@ class PolynomialFeatures(BaseTransformer):
|
|
314
305
|
"Session must not specified for snowpark dataset."
|
315
306
|
),
|
316
307
|
)
|
317
|
-
|
318
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
319
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
308
|
+
|
320
309
|
|
321
310
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
322
311
|
@telemetry.send_api_usage_telemetry(
|
@@ -362,7 +351,8 @@ class PolynomialFeatures(BaseTransformer):
|
|
362
351
|
|
363
352
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
364
353
|
|
365
|
-
self.
|
354
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
355
|
+
self._deps = self._get_dependencies()
|
366
356
|
assert isinstance(
|
367
357
|
dataset._session, Session
|
368
358
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -447,10 +437,8 @@ class PolynomialFeatures(BaseTransformer):
|
|
447
437
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
448
438
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
449
439
|
|
450
|
-
self.
|
451
|
-
|
452
|
-
inference_method=inference_method,
|
453
|
-
)
|
440
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
441
|
+
self._deps = self._get_dependencies()
|
454
442
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
455
443
|
|
456
444
|
transform_kwargs = dict(
|
@@ -517,16 +505,42 @@ class PolynomialFeatures(BaseTransformer):
|
|
517
505
|
self._is_fitted = True
|
518
506
|
return output_result
|
519
507
|
|
508
|
+
|
509
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
510
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
511
|
+
""" Fit to data, then transform it
|
512
|
+
For more details on this function, see [sklearn.preprocessing.PolynomialFeatures.fit_transform]
|
513
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.PolynomialFeatures.html#sklearn.preprocessing.PolynomialFeatures.fit_transform)
|
514
|
+
|
520
515
|
|
521
|
-
|
522
|
-
|
523
|
-
|
516
|
+
Raises:
|
517
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
518
|
+
|
519
|
+
Args:
|
520
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
521
|
+
Snowpark or Pandas DataFrame.
|
522
|
+
output_cols_prefix: Prefix for the response columns
|
524
523
|
Returns:
|
525
524
|
Transformed dataset.
|
526
525
|
"""
|
527
|
-
self.
|
528
|
-
|
529
|
-
|
526
|
+
self._infer_input_output_cols(dataset)
|
527
|
+
super()._check_dataset_type(dataset)
|
528
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
529
|
+
estimator=self._sklearn_object,
|
530
|
+
dataset=dataset,
|
531
|
+
input_cols=self.input_cols,
|
532
|
+
label_cols=self.label_cols,
|
533
|
+
sample_weight_col=self.sample_weight_col,
|
534
|
+
autogenerated=self._autogenerated,
|
535
|
+
subproject=_SUBPROJECT,
|
536
|
+
)
|
537
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
538
|
+
drop_input_cols=self._drop_input_cols,
|
539
|
+
expected_output_cols_list=self.output_cols,
|
540
|
+
)
|
541
|
+
self._sklearn_object = fitted_estimator
|
542
|
+
self._is_fitted = True
|
543
|
+
return output_result
|
530
544
|
|
531
545
|
|
532
546
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -617,10 +631,8 @@ class PolynomialFeatures(BaseTransformer):
|
|
617
631
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
618
632
|
|
619
633
|
if isinstance(dataset, DataFrame):
|
620
|
-
self.
|
621
|
-
|
622
|
-
inference_method=inference_method,
|
623
|
-
)
|
634
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
635
|
+
self._deps = self._get_dependencies()
|
624
636
|
assert isinstance(
|
625
637
|
dataset._session, Session
|
626
638
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -685,10 +697,8 @@ class PolynomialFeatures(BaseTransformer):
|
|
685
697
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
686
698
|
|
687
699
|
if isinstance(dataset, DataFrame):
|
688
|
-
self.
|
689
|
-
|
690
|
-
inference_method=inference_method,
|
691
|
-
)
|
700
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
701
|
+
self._deps = self._get_dependencies()
|
692
702
|
assert isinstance(
|
693
703
|
dataset._session, Session
|
694
704
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -750,10 +760,8 @@ class PolynomialFeatures(BaseTransformer):
|
|
750
760
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
751
761
|
|
752
762
|
if isinstance(dataset, DataFrame):
|
753
|
-
self.
|
754
|
-
|
755
|
-
inference_method=inference_method,
|
756
|
-
)
|
763
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
764
|
+
self._deps = self._get_dependencies()
|
757
765
|
assert isinstance(
|
758
766
|
dataset._session, Session
|
759
767
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -819,10 +827,8 @@ class PolynomialFeatures(BaseTransformer):
|
|
819
827
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
820
828
|
|
821
829
|
if isinstance(dataset, DataFrame):
|
822
|
-
self.
|
823
|
-
|
824
|
-
inference_method=inference_method,
|
825
|
-
)
|
830
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
831
|
+
self._deps = self._get_dependencies()
|
826
832
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
827
833
|
transform_kwargs = dict(
|
828
834
|
session=dataset._session,
|
@@ -884,17 +890,15 @@ class PolynomialFeatures(BaseTransformer):
|
|
884
890
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
885
891
|
|
886
892
|
if isinstance(dataset, DataFrame):
|
887
|
-
self.
|
888
|
-
|
889
|
-
inference_method="score",
|
890
|
-
)
|
893
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
894
|
+
self._deps = self._get_dependencies()
|
891
895
|
selected_cols = self._get_active_columns()
|
892
896
|
if len(selected_cols) > 0:
|
893
897
|
dataset = dataset.select(selected_cols)
|
894
898
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
895
899
|
transform_kwargs = dict(
|
896
900
|
session=dataset._session,
|
897
|
-
dependencies=
|
901
|
+
dependencies=self._deps,
|
898
902
|
score_sproc_imports=['sklearn'],
|
899
903
|
)
|
900
904
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -959,11 +963,8 @@ class PolynomialFeatures(BaseTransformer):
|
|
959
963
|
|
960
964
|
if isinstance(dataset, DataFrame):
|
961
965
|
|
962
|
-
self.
|
963
|
-
|
964
|
-
inference_method=inference_method,
|
965
|
-
|
966
|
-
)
|
966
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
967
|
+
self._deps = self._get_dependencies()
|
967
968
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
968
969
|
transform_kwargs = dict(
|
969
970
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.semi_supervised".replace
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class LabelPropagation(BaseTransformer):
|
70
64
|
r"""Label Propagation classifier
|
71
65
|
For more details on this class, see [sklearn.semi_supervised.LabelPropagation]
|
@@ -289,20 +283,17 @@ class LabelPropagation(BaseTransformer):
|
|
289
283
|
self,
|
290
284
|
dataset: DataFrame,
|
291
285
|
inference_method: str,
|
292
|
-
) ->
|
293
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
294
|
-
return the available package that exists in the snowflake anaconda channel
|
286
|
+
) -> None:
|
287
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
295
288
|
|
296
289
|
Args:
|
297
290
|
dataset: snowpark dataframe
|
298
291
|
inference_method: the inference method such as predict, score...
|
299
|
-
|
292
|
+
|
300
293
|
Raises:
|
301
294
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
302
295
|
SnowflakeMLException: If the session is None, raise error
|
303
296
|
|
304
|
-
Returns:
|
305
|
-
A list of available package that exists in the snowflake anaconda channel
|
306
297
|
"""
|
307
298
|
if not self._is_fitted:
|
308
299
|
raise exceptions.SnowflakeMLException(
|
@@ -320,9 +311,7 @@ class LabelPropagation(BaseTransformer):
|
|
320
311
|
"Session must not specified for snowpark dataset."
|
321
312
|
),
|
322
313
|
)
|
323
|
-
|
324
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
325
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
314
|
+
|
326
315
|
|
327
316
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
328
317
|
@telemetry.send_api_usage_telemetry(
|
@@ -370,7 +359,8 @@ class LabelPropagation(BaseTransformer):
|
|
370
359
|
|
371
360
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
372
361
|
|
373
|
-
self.
|
362
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
363
|
+
self._deps = self._get_dependencies()
|
374
364
|
assert isinstance(
|
375
365
|
dataset._session, Session
|
376
366
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -453,10 +443,8 @@ class LabelPropagation(BaseTransformer):
|
|
453
443
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
454
444
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
455
445
|
|
456
|
-
self.
|
457
|
-
|
458
|
-
inference_method=inference_method,
|
459
|
-
)
|
446
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
447
|
+
self._deps = self._get_dependencies()
|
460
448
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
461
449
|
|
462
450
|
transform_kwargs = dict(
|
@@ -523,16 +511,40 @@ class LabelPropagation(BaseTransformer):
|
|
523
511
|
self._is_fitted = True
|
524
512
|
return output_result
|
525
513
|
|
514
|
+
|
515
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
516
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
517
|
+
""" Method not supported for this class.
|
526
518
|
|
527
|
-
|
528
|
-
|
529
|
-
|
519
|
+
|
520
|
+
Raises:
|
521
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
522
|
+
|
523
|
+
Args:
|
524
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
525
|
+
Snowpark or Pandas DataFrame.
|
526
|
+
output_cols_prefix: Prefix for the response columns
|
530
527
|
Returns:
|
531
528
|
Transformed dataset.
|
532
529
|
"""
|
533
|
-
self.
|
534
|
-
|
535
|
-
|
530
|
+
self._infer_input_output_cols(dataset)
|
531
|
+
super()._check_dataset_type(dataset)
|
532
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
533
|
+
estimator=self._sklearn_object,
|
534
|
+
dataset=dataset,
|
535
|
+
input_cols=self.input_cols,
|
536
|
+
label_cols=self.label_cols,
|
537
|
+
sample_weight_col=self.sample_weight_col,
|
538
|
+
autogenerated=self._autogenerated,
|
539
|
+
subproject=_SUBPROJECT,
|
540
|
+
)
|
541
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
542
|
+
drop_input_cols=self._drop_input_cols,
|
543
|
+
expected_output_cols_list=self.output_cols,
|
544
|
+
)
|
545
|
+
self._sklearn_object = fitted_estimator
|
546
|
+
self._is_fitted = True
|
547
|
+
return output_result
|
536
548
|
|
537
549
|
|
538
550
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -625,10 +637,8 @@ class LabelPropagation(BaseTransformer):
|
|
625
637
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
626
638
|
|
627
639
|
if isinstance(dataset, DataFrame):
|
628
|
-
self.
|
629
|
-
|
630
|
-
inference_method=inference_method,
|
631
|
-
)
|
640
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
641
|
+
self._deps = self._get_dependencies()
|
632
642
|
assert isinstance(
|
633
643
|
dataset._session, Session
|
634
644
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -695,10 +705,8 @@ class LabelPropagation(BaseTransformer):
|
|
695
705
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
696
706
|
|
697
707
|
if isinstance(dataset, DataFrame):
|
698
|
-
self.
|
699
|
-
|
700
|
-
inference_method=inference_method,
|
701
|
-
)
|
708
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
709
|
+
self._deps = self._get_dependencies()
|
702
710
|
assert isinstance(
|
703
711
|
dataset._session, Session
|
704
712
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -760,10 +768,8 @@ class LabelPropagation(BaseTransformer):
|
|
760
768
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
761
769
|
|
762
770
|
if isinstance(dataset, DataFrame):
|
763
|
-
self.
|
764
|
-
|
765
|
-
inference_method=inference_method,
|
766
|
-
)
|
771
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
772
|
+
self._deps = self._get_dependencies()
|
767
773
|
assert isinstance(
|
768
774
|
dataset._session, Session
|
769
775
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -829,10 +835,8 @@ class LabelPropagation(BaseTransformer):
|
|
829
835
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
830
836
|
|
831
837
|
if isinstance(dataset, DataFrame):
|
832
|
-
self.
|
833
|
-
|
834
|
-
inference_method=inference_method,
|
835
|
-
)
|
838
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
839
|
+
self._deps = self._get_dependencies()
|
836
840
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
837
841
|
transform_kwargs = dict(
|
838
842
|
session=dataset._session,
|
@@ -896,17 +900,15 @@ class LabelPropagation(BaseTransformer):
|
|
896
900
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
897
901
|
|
898
902
|
if isinstance(dataset, DataFrame):
|
899
|
-
self.
|
900
|
-
|
901
|
-
inference_method="score",
|
902
|
-
)
|
903
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
904
|
+
self._deps = self._get_dependencies()
|
903
905
|
selected_cols = self._get_active_columns()
|
904
906
|
if len(selected_cols) > 0:
|
905
907
|
dataset = dataset.select(selected_cols)
|
906
908
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
907
909
|
transform_kwargs = dict(
|
908
910
|
session=dataset._session,
|
909
|
-
dependencies=
|
911
|
+
dependencies=self._deps,
|
910
912
|
score_sproc_imports=['sklearn'],
|
911
913
|
)
|
912
914
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -971,11 +973,8 @@ class LabelPropagation(BaseTransformer):
|
|
971
973
|
|
972
974
|
if isinstance(dataset, DataFrame):
|
973
975
|
|
974
|
-
self.
|
975
|
-
|
976
|
-
inference_method=inference_method,
|
977
|
-
|
978
|
-
)
|
976
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
977
|
+
self._deps = self._get_dependencies()
|
979
978
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
980
979
|
transform_kwargs = dict(
|
981
980
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.semi_supervised".replace
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class LabelSpreading(BaseTransformer):
|
70
64
|
r"""LabelSpreading model for semi-supervised learning
|
71
65
|
For more details on this class, see [sklearn.semi_supervised.LabelSpreading]
|
@@ -298,20 +292,17 @@ class LabelSpreading(BaseTransformer):
|
|
298
292
|
self,
|
299
293
|
dataset: DataFrame,
|
300
294
|
inference_method: str,
|
301
|
-
) ->
|
302
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
303
|
-
return the available package that exists in the snowflake anaconda channel
|
295
|
+
) -> None:
|
296
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
304
297
|
|
305
298
|
Args:
|
306
299
|
dataset: snowpark dataframe
|
307
300
|
inference_method: the inference method such as predict, score...
|
308
|
-
|
301
|
+
|
309
302
|
Raises:
|
310
303
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
311
304
|
SnowflakeMLException: If the session is None, raise error
|
312
305
|
|
313
|
-
Returns:
|
314
|
-
A list of available package that exists in the snowflake anaconda channel
|
315
306
|
"""
|
316
307
|
if not self._is_fitted:
|
317
308
|
raise exceptions.SnowflakeMLException(
|
@@ -329,9 +320,7 @@ class LabelSpreading(BaseTransformer):
|
|
329
320
|
"Session must not specified for snowpark dataset."
|
330
321
|
),
|
331
322
|
)
|
332
|
-
|
333
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
334
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
323
|
+
|
335
324
|
|
336
325
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
337
326
|
@telemetry.send_api_usage_telemetry(
|
@@ -379,7 +368,8 @@ class LabelSpreading(BaseTransformer):
|
|
379
368
|
|
380
369
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
381
370
|
|
382
|
-
self.
|
371
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
372
|
+
self._deps = self._get_dependencies()
|
383
373
|
assert isinstance(
|
384
374
|
dataset._session, Session
|
385
375
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -462,10 +452,8 @@ class LabelSpreading(BaseTransformer):
|
|
462
452
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
463
453
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
464
454
|
|
465
|
-
self.
|
466
|
-
|
467
|
-
inference_method=inference_method,
|
468
|
-
)
|
455
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
456
|
+
self._deps = self._get_dependencies()
|
469
457
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
470
458
|
|
471
459
|
transform_kwargs = dict(
|
@@ -532,16 +520,40 @@ class LabelSpreading(BaseTransformer):
|
|
532
520
|
self._is_fitted = True
|
533
521
|
return output_result
|
534
522
|
|
523
|
+
|
524
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
525
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
526
|
+
""" Method not supported for this class.
|
535
527
|
|
536
|
-
|
537
|
-
|
538
|
-
|
528
|
+
|
529
|
+
Raises:
|
530
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
531
|
+
|
532
|
+
Args:
|
533
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
534
|
+
Snowpark or Pandas DataFrame.
|
535
|
+
output_cols_prefix: Prefix for the response columns
|
539
536
|
Returns:
|
540
537
|
Transformed dataset.
|
541
538
|
"""
|
542
|
-
self.
|
543
|
-
|
544
|
-
|
539
|
+
self._infer_input_output_cols(dataset)
|
540
|
+
super()._check_dataset_type(dataset)
|
541
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
542
|
+
estimator=self._sklearn_object,
|
543
|
+
dataset=dataset,
|
544
|
+
input_cols=self.input_cols,
|
545
|
+
label_cols=self.label_cols,
|
546
|
+
sample_weight_col=self.sample_weight_col,
|
547
|
+
autogenerated=self._autogenerated,
|
548
|
+
subproject=_SUBPROJECT,
|
549
|
+
)
|
550
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
551
|
+
drop_input_cols=self._drop_input_cols,
|
552
|
+
expected_output_cols_list=self.output_cols,
|
553
|
+
)
|
554
|
+
self._sklearn_object = fitted_estimator
|
555
|
+
self._is_fitted = True
|
556
|
+
return output_result
|
545
557
|
|
546
558
|
|
547
559
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -634,10 +646,8 @@ class LabelSpreading(BaseTransformer):
|
|
634
646
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
635
647
|
|
636
648
|
if isinstance(dataset, DataFrame):
|
637
|
-
self.
|
638
|
-
|
639
|
-
inference_method=inference_method,
|
640
|
-
)
|
649
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
650
|
+
self._deps = self._get_dependencies()
|
641
651
|
assert isinstance(
|
642
652
|
dataset._session, Session
|
643
653
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -704,10 +714,8 @@ class LabelSpreading(BaseTransformer):
|
|
704
714
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
705
715
|
|
706
716
|
if isinstance(dataset, DataFrame):
|
707
|
-
self.
|
708
|
-
|
709
|
-
inference_method=inference_method,
|
710
|
-
)
|
717
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
718
|
+
self._deps = self._get_dependencies()
|
711
719
|
assert isinstance(
|
712
720
|
dataset._session, Session
|
713
721
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -769,10 +777,8 @@ class LabelSpreading(BaseTransformer):
|
|
769
777
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
770
778
|
|
771
779
|
if isinstance(dataset, DataFrame):
|
772
|
-
self.
|
773
|
-
|
774
|
-
inference_method=inference_method,
|
775
|
-
)
|
780
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
781
|
+
self._deps = self._get_dependencies()
|
776
782
|
assert isinstance(
|
777
783
|
dataset._session, Session
|
778
784
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -838,10 +844,8 @@ class LabelSpreading(BaseTransformer):
|
|
838
844
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
839
845
|
|
840
846
|
if isinstance(dataset, DataFrame):
|
841
|
-
self.
|
842
|
-
|
843
|
-
inference_method=inference_method,
|
844
|
-
)
|
847
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
848
|
+
self._deps = self._get_dependencies()
|
845
849
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
846
850
|
transform_kwargs = dict(
|
847
851
|
session=dataset._session,
|
@@ -905,17 +909,15 @@ class LabelSpreading(BaseTransformer):
|
|
905
909
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
906
910
|
|
907
911
|
if isinstance(dataset, DataFrame):
|
908
|
-
self.
|
909
|
-
|
910
|
-
inference_method="score",
|
911
|
-
)
|
912
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
913
|
+
self._deps = self._get_dependencies()
|
912
914
|
selected_cols = self._get_active_columns()
|
913
915
|
if len(selected_cols) > 0:
|
914
916
|
dataset = dataset.select(selected_cols)
|
915
917
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
916
918
|
transform_kwargs = dict(
|
917
919
|
session=dataset._session,
|
918
|
-
dependencies=
|
920
|
+
dependencies=self._deps,
|
919
921
|
score_sproc_imports=['sklearn'],
|
920
922
|
)
|
921
923
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -980,11 +982,8 @@ class LabelSpreading(BaseTransformer):
|
|
980
982
|
|
981
983
|
if isinstance(dataset, DataFrame):
|
982
984
|
|
983
|
-
self.
|
984
|
-
|
985
|
-
inference_method=inference_method,
|
986
|
-
|
987
|
-
)
|
985
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
986
|
+
self._deps = self._get_dependencies()
|
988
987
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
989
988
|
transform_kwargs = dict(
|
990
989
|
session = dataset._session,
|