snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class RidgeCV(BaseTransformer):
|
70
64
|
r"""Ridge regression with built-in cross-validation
|
71
65
|
For more details on this class, see [sklearn.linear_model.RidgeCV]
|
@@ -330,20 +324,17 @@ class RidgeCV(BaseTransformer):
|
|
330
324
|
self,
|
331
325
|
dataset: DataFrame,
|
332
326
|
inference_method: str,
|
333
|
-
) ->
|
334
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
335
|
-
return the available package that exists in the snowflake anaconda channel
|
327
|
+
) -> None:
|
328
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
336
329
|
|
337
330
|
Args:
|
338
331
|
dataset: snowpark dataframe
|
339
332
|
inference_method: the inference method such as predict, score...
|
340
|
-
|
333
|
+
|
341
334
|
Raises:
|
342
335
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
343
336
|
SnowflakeMLException: If the session is None, raise error
|
344
337
|
|
345
|
-
Returns:
|
346
|
-
A list of available package that exists in the snowflake anaconda channel
|
347
338
|
"""
|
348
339
|
if not self._is_fitted:
|
349
340
|
raise exceptions.SnowflakeMLException(
|
@@ -361,9 +352,7 @@ class RidgeCV(BaseTransformer):
|
|
361
352
|
"Session must not specified for snowpark dataset."
|
362
353
|
),
|
363
354
|
)
|
364
|
-
|
365
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
366
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
355
|
+
|
367
356
|
|
368
357
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
369
358
|
@telemetry.send_api_usage_telemetry(
|
@@ -411,7 +400,8 @@ class RidgeCV(BaseTransformer):
|
|
411
400
|
|
412
401
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
413
402
|
|
414
|
-
self.
|
403
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
404
|
+
self._deps = self._get_dependencies()
|
415
405
|
assert isinstance(
|
416
406
|
dataset._session, Session
|
417
407
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -494,10 +484,8 @@ class RidgeCV(BaseTransformer):
|
|
494
484
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
495
485
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
496
486
|
|
497
|
-
self.
|
498
|
-
|
499
|
-
inference_method=inference_method,
|
500
|
-
)
|
487
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
488
|
+
self._deps = self._get_dependencies()
|
501
489
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
502
490
|
|
503
491
|
transform_kwargs = dict(
|
@@ -564,16 +552,40 @@ class RidgeCV(BaseTransformer):
|
|
564
552
|
self._is_fitted = True
|
565
553
|
return output_result
|
566
554
|
|
555
|
+
|
556
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
557
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
558
|
+
""" Method not supported for this class.
|
567
559
|
|
568
|
-
|
569
|
-
|
570
|
-
|
560
|
+
|
561
|
+
Raises:
|
562
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
563
|
+
|
564
|
+
Args:
|
565
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
566
|
+
Snowpark or Pandas DataFrame.
|
567
|
+
output_cols_prefix: Prefix for the response columns
|
571
568
|
Returns:
|
572
569
|
Transformed dataset.
|
573
570
|
"""
|
574
|
-
self.
|
575
|
-
|
576
|
-
|
571
|
+
self._infer_input_output_cols(dataset)
|
572
|
+
super()._check_dataset_type(dataset)
|
573
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
574
|
+
estimator=self._sklearn_object,
|
575
|
+
dataset=dataset,
|
576
|
+
input_cols=self.input_cols,
|
577
|
+
label_cols=self.label_cols,
|
578
|
+
sample_weight_col=self.sample_weight_col,
|
579
|
+
autogenerated=self._autogenerated,
|
580
|
+
subproject=_SUBPROJECT,
|
581
|
+
)
|
582
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
583
|
+
drop_input_cols=self._drop_input_cols,
|
584
|
+
expected_output_cols_list=self.output_cols,
|
585
|
+
)
|
586
|
+
self._sklearn_object = fitted_estimator
|
587
|
+
self._is_fitted = True
|
588
|
+
return output_result
|
577
589
|
|
578
590
|
|
579
591
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -664,10 +676,8 @@ class RidgeCV(BaseTransformer):
|
|
664
676
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
665
677
|
|
666
678
|
if isinstance(dataset, DataFrame):
|
667
|
-
self.
|
668
|
-
|
669
|
-
inference_method=inference_method,
|
670
|
-
)
|
679
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
680
|
+
self._deps = self._get_dependencies()
|
671
681
|
assert isinstance(
|
672
682
|
dataset._session, Session
|
673
683
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -732,10 +742,8 @@ class RidgeCV(BaseTransformer):
|
|
732
742
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
733
743
|
|
734
744
|
if isinstance(dataset, DataFrame):
|
735
|
-
self.
|
736
|
-
|
737
|
-
inference_method=inference_method,
|
738
|
-
)
|
745
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
746
|
+
self._deps = self._get_dependencies()
|
739
747
|
assert isinstance(
|
740
748
|
dataset._session, Session
|
741
749
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -797,10 +805,8 @@ class RidgeCV(BaseTransformer):
|
|
797
805
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
798
806
|
|
799
807
|
if isinstance(dataset, DataFrame):
|
800
|
-
self.
|
801
|
-
|
802
|
-
inference_method=inference_method,
|
803
|
-
)
|
808
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
809
|
+
self._deps = self._get_dependencies()
|
804
810
|
assert isinstance(
|
805
811
|
dataset._session, Session
|
806
812
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -866,10 +872,8 @@ class RidgeCV(BaseTransformer):
|
|
866
872
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
867
873
|
|
868
874
|
if isinstance(dataset, DataFrame):
|
869
|
-
self.
|
870
|
-
|
871
|
-
inference_method=inference_method,
|
872
|
-
)
|
875
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
876
|
+
self._deps = self._get_dependencies()
|
873
877
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
874
878
|
transform_kwargs = dict(
|
875
879
|
session=dataset._session,
|
@@ -933,17 +937,15 @@ class RidgeCV(BaseTransformer):
|
|
933
937
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
934
938
|
|
935
939
|
if isinstance(dataset, DataFrame):
|
936
|
-
self.
|
937
|
-
|
938
|
-
inference_method="score",
|
939
|
-
)
|
940
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
941
|
+
self._deps = self._get_dependencies()
|
940
942
|
selected_cols = self._get_active_columns()
|
941
943
|
if len(selected_cols) > 0:
|
942
944
|
dataset = dataset.select(selected_cols)
|
943
945
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
944
946
|
transform_kwargs = dict(
|
945
947
|
session=dataset._session,
|
946
|
-
dependencies=
|
948
|
+
dependencies=self._deps,
|
947
949
|
score_sproc_imports=['sklearn'],
|
948
950
|
)
|
949
951
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1008,11 +1010,8 @@ class RidgeCV(BaseTransformer):
|
|
1008
1010
|
|
1009
1011
|
if isinstance(dataset, DataFrame):
|
1010
1012
|
|
1011
|
-
self.
|
1012
|
-
|
1013
|
-
inference_method=inference_method,
|
1014
|
-
|
1015
|
-
)
|
1013
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1014
|
+
self._deps = self._get_dependencies()
|
1016
1015
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1017
1016
|
transform_kwargs = dict(
|
1018
1017
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class SGDClassifier(BaseTransformer):
|
70
64
|
r"""Linear classifiers (SVM, logistic regression, etc
|
71
65
|
For more details on this class, see [sklearn.linear_model.SGDClassifier]
|
@@ -449,20 +443,17 @@ class SGDClassifier(BaseTransformer):
|
|
449
443
|
self,
|
450
444
|
dataset: DataFrame,
|
451
445
|
inference_method: str,
|
452
|
-
) ->
|
453
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
454
|
-
return the available package that exists in the snowflake anaconda channel
|
446
|
+
) -> None:
|
447
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
455
448
|
|
456
449
|
Args:
|
457
450
|
dataset: snowpark dataframe
|
458
451
|
inference_method: the inference method such as predict, score...
|
459
|
-
|
452
|
+
|
460
453
|
Raises:
|
461
454
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
462
455
|
SnowflakeMLException: If the session is None, raise error
|
463
456
|
|
464
|
-
Returns:
|
465
|
-
A list of available package that exists in the snowflake anaconda channel
|
466
457
|
"""
|
467
458
|
if not self._is_fitted:
|
468
459
|
raise exceptions.SnowflakeMLException(
|
@@ -480,9 +471,7 @@ class SGDClassifier(BaseTransformer):
|
|
480
471
|
"Session must not specified for snowpark dataset."
|
481
472
|
),
|
482
473
|
)
|
483
|
-
|
484
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
485
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
474
|
+
|
486
475
|
|
487
476
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
488
477
|
@telemetry.send_api_usage_telemetry(
|
@@ -530,7 +519,8 @@ class SGDClassifier(BaseTransformer):
|
|
530
519
|
|
531
520
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
532
521
|
|
533
|
-
self.
|
522
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
523
|
+
self._deps = self._get_dependencies()
|
534
524
|
assert isinstance(
|
535
525
|
dataset._session, Session
|
536
526
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -613,10 +603,8 @@ class SGDClassifier(BaseTransformer):
|
|
613
603
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
614
604
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
615
605
|
|
616
|
-
self.
|
617
|
-
|
618
|
-
inference_method=inference_method,
|
619
|
-
)
|
606
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
607
|
+
self._deps = self._get_dependencies()
|
620
608
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
621
609
|
|
622
610
|
transform_kwargs = dict(
|
@@ -683,16 +671,40 @@ class SGDClassifier(BaseTransformer):
|
|
683
671
|
self._is_fitted = True
|
684
672
|
return output_result
|
685
673
|
|
674
|
+
|
675
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
676
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
677
|
+
""" Method not supported for this class.
|
686
678
|
|
687
|
-
|
688
|
-
|
689
|
-
|
679
|
+
|
680
|
+
Raises:
|
681
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
682
|
+
|
683
|
+
Args:
|
684
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
685
|
+
Snowpark or Pandas DataFrame.
|
686
|
+
output_cols_prefix: Prefix for the response columns
|
690
687
|
Returns:
|
691
688
|
Transformed dataset.
|
692
689
|
"""
|
693
|
-
self.
|
694
|
-
|
695
|
-
|
690
|
+
self._infer_input_output_cols(dataset)
|
691
|
+
super()._check_dataset_type(dataset)
|
692
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
693
|
+
estimator=self._sklearn_object,
|
694
|
+
dataset=dataset,
|
695
|
+
input_cols=self.input_cols,
|
696
|
+
label_cols=self.label_cols,
|
697
|
+
sample_weight_col=self.sample_weight_col,
|
698
|
+
autogenerated=self._autogenerated,
|
699
|
+
subproject=_SUBPROJECT,
|
700
|
+
)
|
701
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
702
|
+
drop_input_cols=self._drop_input_cols,
|
703
|
+
expected_output_cols_list=self.output_cols,
|
704
|
+
)
|
705
|
+
self._sklearn_object = fitted_estimator
|
706
|
+
self._is_fitted = True
|
707
|
+
return output_result
|
696
708
|
|
697
709
|
|
698
710
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -785,10 +797,8 @@ class SGDClassifier(BaseTransformer):
|
|
785
797
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
786
798
|
|
787
799
|
if isinstance(dataset, DataFrame):
|
788
|
-
self.
|
789
|
-
|
790
|
-
inference_method=inference_method,
|
791
|
-
)
|
800
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
801
|
+
self._deps = self._get_dependencies()
|
792
802
|
assert isinstance(
|
793
803
|
dataset._session, Session
|
794
804
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -855,10 +865,8 @@ class SGDClassifier(BaseTransformer):
|
|
855
865
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
856
866
|
|
857
867
|
if isinstance(dataset, DataFrame):
|
858
|
-
self.
|
859
|
-
|
860
|
-
inference_method=inference_method,
|
861
|
-
)
|
868
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
869
|
+
self._deps = self._get_dependencies()
|
862
870
|
assert isinstance(
|
863
871
|
dataset._session, Session
|
864
872
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -922,10 +930,8 @@ class SGDClassifier(BaseTransformer):
|
|
922
930
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
923
931
|
|
924
932
|
if isinstance(dataset, DataFrame):
|
925
|
-
self.
|
926
|
-
|
927
|
-
inference_method=inference_method,
|
928
|
-
)
|
933
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
934
|
+
self._deps = self._get_dependencies()
|
929
935
|
assert isinstance(
|
930
936
|
dataset._session, Session
|
931
937
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -991,10 +997,8 @@ class SGDClassifier(BaseTransformer):
|
|
991
997
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
992
998
|
|
993
999
|
if isinstance(dataset, DataFrame):
|
994
|
-
self.
|
995
|
-
|
996
|
-
inference_method=inference_method,
|
997
|
-
)
|
1000
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1001
|
+
self._deps = self._get_dependencies()
|
998
1002
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
999
1003
|
transform_kwargs = dict(
|
1000
1004
|
session=dataset._session,
|
@@ -1058,17 +1062,15 @@ class SGDClassifier(BaseTransformer):
|
|
1058
1062
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1059
1063
|
|
1060
1064
|
if isinstance(dataset, DataFrame):
|
1061
|
-
self.
|
1062
|
-
|
1063
|
-
inference_method="score",
|
1064
|
-
)
|
1065
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1066
|
+
self._deps = self._get_dependencies()
|
1065
1067
|
selected_cols = self._get_active_columns()
|
1066
1068
|
if len(selected_cols) > 0:
|
1067
1069
|
dataset = dataset.select(selected_cols)
|
1068
1070
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1069
1071
|
transform_kwargs = dict(
|
1070
1072
|
session=dataset._session,
|
1071
|
-
dependencies=
|
1073
|
+
dependencies=self._deps,
|
1072
1074
|
score_sproc_imports=['sklearn'],
|
1073
1075
|
)
|
1074
1076
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1133,11 +1135,8 @@ class SGDClassifier(BaseTransformer):
|
|
1133
1135
|
|
1134
1136
|
if isinstance(dataset, DataFrame):
|
1135
1137
|
|
1136
|
-
self.
|
1137
|
-
|
1138
|
-
inference_method=inference_method,
|
1139
|
-
|
1140
|
-
)
|
1138
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1139
|
+
self._deps = self._get_dependencies()
|
1141
1140
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1142
1141
|
transform_kwargs = dict(
|
1143
1142
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class SGDOneClassSVM(BaseTransformer):
|
70
64
|
r"""Solves linear One-Class SVM using Stochastic Gradient Descent
|
71
65
|
For more details on this class, see [sklearn.linear_model.SGDOneClassSVM]
|
@@ -347,20 +341,17 @@ class SGDOneClassSVM(BaseTransformer):
|
|
347
341
|
self,
|
348
342
|
dataset: DataFrame,
|
349
343
|
inference_method: str,
|
350
|
-
) ->
|
351
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
352
|
-
return the available package that exists in the snowflake anaconda channel
|
344
|
+
) -> None:
|
345
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
353
346
|
|
354
347
|
Args:
|
355
348
|
dataset: snowpark dataframe
|
356
349
|
inference_method: the inference method such as predict, score...
|
357
|
-
|
350
|
+
|
358
351
|
Raises:
|
359
352
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
360
353
|
SnowflakeMLException: If the session is None, raise error
|
361
354
|
|
362
|
-
Returns:
|
363
|
-
A list of available package that exists in the snowflake anaconda channel
|
364
355
|
"""
|
365
356
|
if not self._is_fitted:
|
366
357
|
raise exceptions.SnowflakeMLException(
|
@@ -378,9 +369,7 @@ class SGDOneClassSVM(BaseTransformer):
|
|
378
369
|
"Session must not specified for snowpark dataset."
|
379
370
|
),
|
380
371
|
)
|
381
|
-
|
382
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
383
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
372
|
+
|
384
373
|
|
385
374
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
386
375
|
@telemetry.send_api_usage_telemetry(
|
@@ -428,7 +417,8 @@ class SGDOneClassSVM(BaseTransformer):
|
|
428
417
|
|
429
418
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
430
419
|
|
431
|
-
self.
|
420
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
421
|
+
self._deps = self._get_dependencies()
|
432
422
|
assert isinstance(
|
433
423
|
dataset._session, Session
|
434
424
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -511,10 +501,8 @@ class SGDOneClassSVM(BaseTransformer):
|
|
511
501
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
512
502
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
513
503
|
|
514
|
-
self.
|
515
|
-
|
516
|
-
inference_method=inference_method,
|
517
|
-
)
|
504
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
505
|
+
self._deps = self._get_dependencies()
|
518
506
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
519
507
|
|
520
508
|
transform_kwargs = dict(
|
@@ -583,16 +571,40 @@ class SGDOneClassSVM(BaseTransformer):
|
|
583
571
|
self._is_fitted = True
|
584
572
|
return output_result
|
585
573
|
|
574
|
+
|
575
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
576
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
577
|
+
""" Method not supported for this class.
|
578
|
+
|
586
579
|
|
587
|
-
|
588
|
-
|
589
|
-
|
580
|
+
Raises:
|
581
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
582
|
+
|
583
|
+
Args:
|
584
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
585
|
+
Snowpark or Pandas DataFrame.
|
586
|
+
output_cols_prefix: Prefix for the response columns
|
590
587
|
Returns:
|
591
588
|
Transformed dataset.
|
592
589
|
"""
|
593
|
-
self.
|
594
|
-
|
595
|
-
|
590
|
+
self._infer_input_output_cols(dataset)
|
591
|
+
super()._check_dataset_type(dataset)
|
592
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
593
|
+
estimator=self._sklearn_object,
|
594
|
+
dataset=dataset,
|
595
|
+
input_cols=self.input_cols,
|
596
|
+
label_cols=self.label_cols,
|
597
|
+
sample_weight_col=self.sample_weight_col,
|
598
|
+
autogenerated=self._autogenerated,
|
599
|
+
subproject=_SUBPROJECT,
|
600
|
+
)
|
601
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
602
|
+
drop_input_cols=self._drop_input_cols,
|
603
|
+
expected_output_cols_list=self.output_cols,
|
604
|
+
)
|
605
|
+
self._sklearn_object = fitted_estimator
|
606
|
+
self._is_fitted = True
|
607
|
+
return output_result
|
596
608
|
|
597
609
|
|
598
610
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -683,10 +695,8 @@ class SGDOneClassSVM(BaseTransformer):
|
|
683
695
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
684
696
|
|
685
697
|
if isinstance(dataset, DataFrame):
|
686
|
-
self.
|
687
|
-
|
688
|
-
inference_method=inference_method,
|
689
|
-
)
|
698
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
699
|
+
self._deps = self._get_dependencies()
|
690
700
|
assert isinstance(
|
691
701
|
dataset._session, Session
|
692
702
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -751,10 +761,8 @@ class SGDOneClassSVM(BaseTransformer):
|
|
751
761
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
752
762
|
|
753
763
|
if isinstance(dataset, DataFrame):
|
754
|
-
self.
|
755
|
-
|
756
|
-
inference_method=inference_method,
|
757
|
-
)
|
764
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
765
|
+
self._deps = self._get_dependencies()
|
758
766
|
assert isinstance(
|
759
767
|
dataset._session, Session
|
760
768
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -818,10 +826,8 @@ class SGDOneClassSVM(BaseTransformer):
|
|
818
826
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
819
827
|
|
820
828
|
if isinstance(dataset, DataFrame):
|
821
|
-
self.
|
822
|
-
|
823
|
-
inference_method=inference_method,
|
824
|
-
)
|
829
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
830
|
+
self._deps = self._get_dependencies()
|
825
831
|
assert isinstance(
|
826
832
|
dataset._session, Session
|
827
833
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -889,10 +895,8 @@ class SGDOneClassSVM(BaseTransformer):
|
|
889
895
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
890
896
|
|
891
897
|
if isinstance(dataset, DataFrame):
|
892
|
-
self.
|
893
|
-
|
894
|
-
inference_method=inference_method,
|
895
|
-
)
|
898
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
899
|
+
self._deps = self._get_dependencies()
|
896
900
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
897
901
|
transform_kwargs = dict(
|
898
902
|
session=dataset._session,
|
@@ -954,17 +958,15 @@ class SGDOneClassSVM(BaseTransformer):
|
|
954
958
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
955
959
|
|
956
960
|
if isinstance(dataset, DataFrame):
|
957
|
-
self.
|
958
|
-
|
959
|
-
inference_method="score",
|
960
|
-
)
|
961
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
962
|
+
self._deps = self._get_dependencies()
|
961
963
|
selected_cols = self._get_active_columns()
|
962
964
|
if len(selected_cols) > 0:
|
963
965
|
dataset = dataset.select(selected_cols)
|
964
966
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
965
967
|
transform_kwargs = dict(
|
966
968
|
session=dataset._session,
|
967
|
-
dependencies=
|
969
|
+
dependencies=self._deps,
|
968
970
|
score_sproc_imports=['sklearn'],
|
969
971
|
)
|
970
972
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1029,11 +1031,8 @@ class SGDOneClassSVM(BaseTransformer):
|
|
1029
1031
|
|
1030
1032
|
if isinstance(dataset, DataFrame):
|
1031
1033
|
|
1032
|
-
self.
|
1033
|
-
|
1034
|
-
inference_method=inference_method,
|
1035
|
-
|
1036
|
-
)
|
1034
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1035
|
+
self._deps = self._get_dependencies()
|
1037
1036
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1038
1037
|
transform_kwargs = dict(
|
1039
1038
|
session = dataset._session,
|