snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class OrthogonalMatchingPursuit(BaseTransformer):
|
70
64
|
r"""Orthogonal Matching Pursuit model (OMP)
|
71
65
|
For more details on this class, see [sklearn.linear_model.OrthogonalMatchingPursuit]
|
@@ -288,20 +282,17 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
288
282
|
self,
|
289
283
|
dataset: DataFrame,
|
290
284
|
inference_method: str,
|
291
|
-
) ->
|
292
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
293
|
-
return the available package that exists in the snowflake anaconda channel
|
285
|
+
) -> None:
|
286
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
294
287
|
|
295
288
|
Args:
|
296
289
|
dataset: snowpark dataframe
|
297
290
|
inference_method: the inference method such as predict, score...
|
298
|
-
|
291
|
+
|
299
292
|
Raises:
|
300
293
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
301
294
|
SnowflakeMLException: If the session is None, raise error
|
302
295
|
|
303
|
-
Returns:
|
304
|
-
A list of available package that exists in the snowflake anaconda channel
|
305
296
|
"""
|
306
297
|
if not self._is_fitted:
|
307
298
|
raise exceptions.SnowflakeMLException(
|
@@ -319,9 +310,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
319
310
|
"Session must not specified for snowpark dataset."
|
320
311
|
),
|
321
312
|
)
|
322
|
-
|
323
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
324
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
313
|
+
|
325
314
|
|
326
315
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
327
316
|
@telemetry.send_api_usage_telemetry(
|
@@ -369,7 +358,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
369
358
|
|
370
359
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
371
360
|
|
372
|
-
self.
|
361
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
362
|
+
self._deps = self._get_dependencies()
|
373
363
|
assert isinstance(
|
374
364
|
dataset._session, Session
|
375
365
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -452,10 +442,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
452
442
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
453
443
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
454
444
|
|
455
|
-
self.
|
456
|
-
|
457
|
-
inference_method=inference_method,
|
458
|
-
)
|
445
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
446
|
+
self._deps = self._get_dependencies()
|
459
447
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
460
448
|
|
461
449
|
transform_kwargs = dict(
|
@@ -522,16 +510,40 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
522
510
|
self._is_fitted = True
|
523
511
|
return output_result
|
524
512
|
|
513
|
+
|
514
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
515
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
516
|
+
""" Method not supported for this class.
|
525
517
|
|
526
|
-
|
527
|
-
|
528
|
-
|
518
|
+
|
519
|
+
Raises:
|
520
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
521
|
+
|
522
|
+
Args:
|
523
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
524
|
+
Snowpark or Pandas DataFrame.
|
525
|
+
output_cols_prefix: Prefix for the response columns
|
529
526
|
Returns:
|
530
527
|
Transformed dataset.
|
531
528
|
"""
|
532
|
-
self.
|
533
|
-
|
534
|
-
|
529
|
+
self._infer_input_output_cols(dataset)
|
530
|
+
super()._check_dataset_type(dataset)
|
531
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
532
|
+
estimator=self._sklearn_object,
|
533
|
+
dataset=dataset,
|
534
|
+
input_cols=self.input_cols,
|
535
|
+
label_cols=self.label_cols,
|
536
|
+
sample_weight_col=self.sample_weight_col,
|
537
|
+
autogenerated=self._autogenerated,
|
538
|
+
subproject=_SUBPROJECT,
|
539
|
+
)
|
540
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
541
|
+
drop_input_cols=self._drop_input_cols,
|
542
|
+
expected_output_cols_list=self.output_cols,
|
543
|
+
)
|
544
|
+
self._sklearn_object = fitted_estimator
|
545
|
+
self._is_fitted = True
|
546
|
+
return output_result
|
535
547
|
|
536
548
|
|
537
549
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -622,10 +634,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
622
634
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
623
635
|
|
624
636
|
if isinstance(dataset, DataFrame):
|
625
|
-
self.
|
626
|
-
|
627
|
-
inference_method=inference_method,
|
628
|
-
)
|
637
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
638
|
+
self._deps = self._get_dependencies()
|
629
639
|
assert isinstance(
|
630
640
|
dataset._session, Session
|
631
641
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -690,10 +700,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
690
700
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
691
701
|
|
692
702
|
if isinstance(dataset, DataFrame):
|
693
|
-
self.
|
694
|
-
|
695
|
-
inference_method=inference_method,
|
696
|
-
)
|
703
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
704
|
+
self._deps = self._get_dependencies()
|
697
705
|
assert isinstance(
|
698
706
|
dataset._session, Session
|
699
707
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -755,10 +763,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
755
763
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
756
764
|
|
757
765
|
if isinstance(dataset, DataFrame):
|
758
|
-
self.
|
759
|
-
|
760
|
-
inference_method=inference_method,
|
761
|
-
)
|
766
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
767
|
+
self._deps = self._get_dependencies()
|
762
768
|
assert isinstance(
|
763
769
|
dataset._session, Session
|
764
770
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -824,10 +830,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
824
830
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
825
831
|
|
826
832
|
if isinstance(dataset, DataFrame):
|
827
|
-
self.
|
828
|
-
|
829
|
-
inference_method=inference_method,
|
830
|
-
)
|
833
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
834
|
+
self._deps = self._get_dependencies()
|
831
835
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
832
836
|
transform_kwargs = dict(
|
833
837
|
session=dataset._session,
|
@@ -891,17 +895,15 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
891
895
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
892
896
|
|
893
897
|
if isinstance(dataset, DataFrame):
|
894
|
-
self.
|
895
|
-
|
896
|
-
inference_method="score",
|
897
|
-
)
|
898
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
899
|
+
self._deps = self._get_dependencies()
|
898
900
|
selected_cols = self._get_active_columns()
|
899
901
|
if len(selected_cols) > 0:
|
900
902
|
dataset = dataset.select(selected_cols)
|
901
903
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
902
904
|
transform_kwargs = dict(
|
903
905
|
session=dataset._session,
|
904
|
-
dependencies=
|
906
|
+
dependencies=self._deps,
|
905
907
|
score_sproc_imports=['sklearn'],
|
906
908
|
)
|
907
909
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -966,11 +968,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
966
968
|
|
967
969
|
if isinstance(dataset, DataFrame):
|
968
970
|
|
969
|
-
self.
|
970
|
-
|
971
|
-
inference_method=inference_method,
|
972
|
-
|
973
|
-
)
|
971
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
972
|
+
self._deps = self._get_dependencies()
|
974
973
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
975
974
|
transform_kwargs = dict(
|
976
975
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class PassiveAggressiveClassifier(BaseTransformer):
|
70
64
|
r"""Passive Aggressive Classifier
|
71
65
|
For more details on this class, see [sklearn.linear_model.PassiveAggressiveClassifier]
|
@@ -362,20 +356,17 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
362
356
|
self,
|
363
357
|
dataset: DataFrame,
|
364
358
|
inference_method: str,
|
365
|
-
) ->
|
366
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
367
|
-
return the available package that exists in the snowflake anaconda channel
|
359
|
+
) -> None:
|
360
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
368
361
|
|
369
362
|
Args:
|
370
363
|
dataset: snowpark dataframe
|
371
364
|
inference_method: the inference method such as predict, score...
|
372
|
-
|
365
|
+
|
373
366
|
Raises:
|
374
367
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
375
368
|
SnowflakeMLException: If the session is None, raise error
|
376
369
|
|
377
|
-
Returns:
|
378
|
-
A list of available package that exists in the snowflake anaconda channel
|
379
370
|
"""
|
380
371
|
if not self._is_fitted:
|
381
372
|
raise exceptions.SnowflakeMLException(
|
@@ -393,9 +384,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
393
384
|
"Session must not specified for snowpark dataset."
|
394
385
|
),
|
395
386
|
)
|
396
|
-
|
397
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
398
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
387
|
+
|
399
388
|
|
400
389
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
401
390
|
@telemetry.send_api_usage_telemetry(
|
@@ -443,7 +432,8 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
443
432
|
|
444
433
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
445
434
|
|
446
|
-
self.
|
435
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
436
|
+
self._deps = self._get_dependencies()
|
447
437
|
assert isinstance(
|
448
438
|
dataset._session, Session
|
449
439
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -526,10 +516,8 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
526
516
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
527
517
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
528
518
|
|
529
|
-
self.
|
530
|
-
|
531
|
-
inference_method=inference_method,
|
532
|
-
)
|
519
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
520
|
+
self._deps = self._get_dependencies()
|
533
521
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
534
522
|
|
535
523
|
transform_kwargs = dict(
|
@@ -596,16 +584,40 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
596
584
|
self._is_fitted = True
|
597
585
|
return output_result
|
598
586
|
|
587
|
+
|
588
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
589
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
590
|
+
""" Method not supported for this class.
|
599
591
|
|
600
|
-
|
601
|
-
|
602
|
-
|
592
|
+
|
593
|
+
Raises:
|
594
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
595
|
+
|
596
|
+
Args:
|
597
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
598
|
+
Snowpark or Pandas DataFrame.
|
599
|
+
output_cols_prefix: Prefix for the response columns
|
603
600
|
Returns:
|
604
601
|
Transformed dataset.
|
605
602
|
"""
|
606
|
-
self.
|
607
|
-
|
608
|
-
|
603
|
+
self._infer_input_output_cols(dataset)
|
604
|
+
super()._check_dataset_type(dataset)
|
605
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
606
|
+
estimator=self._sklearn_object,
|
607
|
+
dataset=dataset,
|
608
|
+
input_cols=self.input_cols,
|
609
|
+
label_cols=self.label_cols,
|
610
|
+
sample_weight_col=self.sample_weight_col,
|
611
|
+
autogenerated=self._autogenerated,
|
612
|
+
subproject=_SUBPROJECT,
|
613
|
+
)
|
614
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
615
|
+
drop_input_cols=self._drop_input_cols,
|
616
|
+
expected_output_cols_list=self.output_cols,
|
617
|
+
)
|
618
|
+
self._sklearn_object = fitted_estimator
|
619
|
+
self._is_fitted = True
|
620
|
+
return output_result
|
609
621
|
|
610
622
|
|
611
623
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -696,10 +708,8 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
696
708
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
697
709
|
|
698
710
|
if isinstance(dataset, DataFrame):
|
699
|
-
self.
|
700
|
-
|
701
|
-
inference_method=inference_method,
|
702
|
-
)
|
711
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
712
|
+
self._deps = self._get_dependencies()
|
703
713
|
assert isinstance(
|
704
714
|
dataset._session, Session
|
705
715
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -764,10 +774,8 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
764
774
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
765
775
|
|
766
776
|
if isinstance(dataset, DataFrame):
|
767
|
-
self.
|
768
|
-
|
769
|
-
inference_method=inference_method,
|
770
|
-
)
|
777
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
778
|
+
self._deps = self._get_dependencies()
|
771
779
|
assert isinstance(
|
772
780
|
dataset._session, Session
|
773
781
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -831,10 +839,8 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
831
839
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
832
840
|
|
833
841
|
if isinstance(dataset, DataFrame):
|
834
|
-
self.
|
835
|
-
|
836
|
-
inference_method=inference_method,
|
837
|
-
)
|
842
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
843
|
+
self._deps = self._get_dependencies()
|
838
844
|
assert isinstance(
|
839
845
|
dataset._session, Session
|
840
846
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -900,10 +906,8 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
900
906
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
901
907
|
|
902
908
|
if isinstance(dataset, DataFrame):
|
903
|
-
self.
|
904
|
-
|
905
|
-
inference_method=inference_method,
|
906
|
-
)
|
909
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
910
|
+
self._deps = self._get_dependencies()
|
907
911
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
908
912
|
transform_kwargs = dict(
|
909
913
|
session=dataset._session,
|
@@ -967,17 +971,15 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
967
971
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
968
972
|
|
969
973
|
if isinstance(dataset, DataFrame):
|
970
|
-
self.
|
971
|
-
|
972
|
-
inference_method="score",
|
973
|
-
)
|
974
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
975
|
+
self._deps = self._get_dependencies()
|
974
976
|
selected_cols = self._get_active_columns()
|
975
977
|
if len(selected_cols) > 0:
|
976
978
|
dataset = dataset.select(selected_cols)
|
977
979
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
978
980
|
transform_kwargs = dict(
|
979
981
|
session=dataset._session,
|
980
|
-
dependencies=
|
982
|
+
dependencies=self._deps,
|
981
983
|
score_sproc_imports=['sklearn'],
|
982
984
|
)
|
983
985
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1042,11 +1044,8 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
1042
1044
|
|
1043
1045
|
if isinstance(dataset, DataFrame):
|
1044
1046
|
|
1045
|
-
self.
|
1046
|
-
|
1047
|
-
inference_method=inference_method,
|
1048
|
-
|
1049
|
-
)
|
1047
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1048
|
+
self._deps = self._get_dependencies()
|
1050
1049
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1051
1050
|
transform_kwargs = dict(
|
1052
1051
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class PassiveAggressiveRegressor(BaseTransformer):
|
70
64
|
r"""Passive Aggressive Regressor
|
71
65
|
For more details on this class, see [sklearn.linear_model.PassiveAggressiveRegressor]
|
@@ -348,20 +342,17 @@ class PassiveAggressiveRegressor(BaseTransformer):
|
|
348
342
|
self,
|
349
343
|
dataset: DataFrame,
|
350
344
|
inference_method: str,
|
351
|
-
) ->
|
352
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
353
|
-
return the available package that exists in the snowflake anaconda channel
|
345
|
+
) -> None:
|
346
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
354
347
|
|
355
348
|
Args:
|
356
349
|
dataset: snowpark dataframe
|
357
350
|
inference_method: the inference method such as predict, score...
|
358
|
-
|
351
|
+
|
359
352
|
Raises:
|
360
353
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
361
354
|
SnowflakeMLException: If the session is None, raise error
|
362
355
|
|
363
|
-
Returns:
|
364
|
-
A list of available package that exists in the snowflake anaconda channel
|
365
356
|
"""
|
366
357
|
if not self._is_fitted:
|
367
358
|
raise exceptions.SnowflakeMLException(
|
@@ -379,9 +370,7 @@ class PassiveAggressiveRegressor(BaseTransformer):
|
|
379
370
|
"Session must not specified for snowpark dataset."
|
380
371
|
),
|
381
372
|
)
|
382
|
-
|
383
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
384
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
373
|
+
|
385
374
|
|
386
375
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
387
376
|
@telemetry.send_api_usage_telemetry(
|
@@ -429,7 +418,8 @@ class PassiveAggressiveRegressor(BaseTransformer):
|
|
429
418
|
|
430
419
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
431
420
|
|
432
|
-
self.
|
421
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
422
|
+
self._deps = self._get_dependencies()
|
433
423
|
assert isinstance(
|
434
424
|
dataset._session, Session
|
435
425
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -512,10 +502,8 @@ class PassiveAggressiveRegressor(BaseTransformer):
|
|
512
502
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
513
503
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
514
504
|
|
515
|
-
self.
|
516
|
-
|
517
|
-
inference_method=inference_method,
|
518
|
-
)
|
505
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
506
|
+
self._deps = self._get_dependencies()
|
519
507
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
520
508
|
|
521
509
|
transform_kwargs = dict(
|
@@ -582,16 +570,40 @@ class PassiveAggressiveRegressor(BaseTransformer):
|
|
582
570
|
self._is_fitted = True
|
583
571
|
return output_result
|
584
572
|
|
573
|
+
|
574
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
575
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
576
|
+
""" Method not supported for this class.
|
585
577
|
|
586
|
-
|
587
|
-
|
588
|
-
|
578
|
+
|
579
|
+
Raises:
|
580
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
581
|
+
|
582
|
+
Args:
|
583
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
584
|
+
Snowpark or Pandas DataFrame.
|
585
|
+
output_cols_prefix: Prefix for the response columns
|
589
586
|
Returns:
|
590
587
|
Transformed dataset.
|
591
588
|
"""
|
592
|
-
self.
|
593
|
-
|
594
|
-
|
589
|
+
self._infer_input_output_cols(dataset)
|
590
|
+
super()._check_dataset_type(dataset)
|
591
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
592
|
+
estimator=self._sklearn_object,
|
593
|
+
dataset=dataset,
|
594
|
+
input_cols=self.input_cols,
|
595
|
+
label_cols=self.label_cols,
|
596
|
+
sample_weight_col=self.sample_weight_col,
|
597
|
+
autogenerated=self._autogenerated,
|
598
|
+
subproject=_SUBPROJECT,
|
599
|
+
)
|
600
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
601
|
+
drop_input_cols=self._drop_input_cols,
|
602
|
+
expected_output_cols_list=self.output_cols,
|
603
|
+
)
|
604
|
+
self._sklearn_object = fitted_estimator
|
605
|
+
self._is_fitted = True
|
606
|
+
return output_result
|
595
607
|
|
596
608
|
|
597
609
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -682,10 +694,8 @@ class PassiveAggressiveRegressor(BaseTransformer):
|
|
682
694
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
683
695
|
|
684
696
|
if isinstance(dataset, DataFrame):
|
685
|
-
self.
|
686
|
-
|
687
|
-
inference_method=inference_method,
|
688
|
-
)
|
697
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
698
|
+
self._deps = self._get_dependencies()
|
689
699
|
assert isinstance(
|
690
700
|
dataset._session, Session
|
691
701
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -750,10 +760,8 @@ class PassiveAggressiveRegressor(BaseTransformer):
|
|
750
760
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
751
761
|
|
752
762
|
if isinstance(dataset, DataFrame):
|
753
|
-
self.
|
754
|
-
|
755
|
-
inference_method=inference_method,
|
756
|
-
)
|
763
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
764
|
+
self._deps = self._get_dependencies()
|
757
765
|
assert isinstance(
|
758
766
|
dataset._session, Session
|
759
767
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -815,10 +823,8 @@ class PassiveAggressiveRegressor(BaseTransformer):
|
|
815
823
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
816
824
|
|
817
825
|
if isinstance(dataset, DataFrame):
|
818
|
-
self.
|
819
|
-
|
820
|
-
inference_method=inference_method,
|
821
|
-
)
|
826
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
827
|
+
self._deps = self._get_dependencies()
|
822
828
|
assert isinstance(
|
823
829
|
dataset._session, Session
|
824
830
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -884,10 +890,8 @@ class PassiveAggressiveRegressor(BaseTransformer):
|
|
884
890
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
885
891
|
|
886
892
|
if isinstance(dataset, DataFrame):
|
887
|
-
self.
|
888
|
-
|
889
|
-
inference_method=inference_method,
|
890
|
-
)
|
893
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
894
|
+
self._deps = self._get_dependencies()
|
891
895
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
892
896
|
transform_kwargs = dict(
|
893
897
|
session=dataset._session,
|
@@ -951,17 +955,15 @@ class PassiveAggressiveRegressor(BaseTransformer):
|
|
951
955
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
952
956
|
|
953
957
|
if isinstance(dataset, DataFrame):
|
954
|
-
self.
|
955
|
-
|
956
|
-
inference_method="score",
|
957
|
-
)
|
958
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
959
|
+
self._deps = self._get_dependencies()
|
958
960
|
selected_cols = self._get_active_columns()
|
959
961
|
if len(selected_cols) > 0:
|
960
962
|
dataset = dataset.select(selected_cols)
|
961
963
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
962
964
|
transform_kwargs = dict(
|
963
965
|
session=dataset._session,
|
964
|
-
dependencies=
|
966
|
+
dependencies=self._deps,
|
965
967
|
score_sproc_imports=['sklearn'],
|
966
968
|
)
|
967
969
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1026,11 +1028,8 @@ class PassiveAggressiveRegressor(BaseTransformer):
|
|
1026
1028
|
|
1027
1029
|
if isinstance(dataset, DataFrame):
|
1028
1030
|
|
1029
|
-
self.
|
1030
|
-
|
1031
|
-
inference_method=inference_method,
|
1032
|
-
|
1033
|
-
)
|
1031
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1032
|
+
self._deps = self._get_dependencies()
|
1034
1033
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1035
1034
|
transform_kwargs = dict(
|
1036
1035
|
session = dataset._session,
|