snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registry. The information is provided for informational purposes only and reflects the changes between those package versions as they appear in the registry.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
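
The version bump itself is the one-line edit to snowflake/ml/version.py listed above. Below is a minimal post-upgrade sanity check; the distribution name comes from the wheel filenames in the header, while the `VERSION` attribute of `snowflake.ml.version` is an assumption inferred from that file. The detailed hunks that follow show the representative per-estimator changes, using the three snowflake/ml/modeling/neural_network modules as examples.

```python
# Minimal sanity check after upgrading (a sketch, not part of the package diff).
# Assumption: snowflake/ml/version.py exposes a VERSION string constant.
from importlib.metadata import version as dist_version

print(dist_version("snowflake-ml-python"))  # expect "1.5.0" after the upgrade

from snowflake.ml.version import VERSION  # assumed attribute name
print(VERSION)
```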
snowflake/ml/modeling/neural_network/bernoulli_rbm.py

```diff
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neural_network".replace(

 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class BernoulliRBM(BaseTransformer):
     r"""Bernoulli Restricted Boltzmann Machine (RBM)
     For more details on this class, see [sklearn.neural_network.BernoulliRBM]
@@ -293,20 +287,17 @@ class BernoulliRBM(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.

         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error

-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -324,9 +315,7 @@ class BernoulliRBM(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -372,7 +361,8 @@ class BernoulliRBM(BaseTransformer):

                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -457,10 +447,8 @@ class BernoulliRBM(BaseTransformer):
            if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                expected_dtype = convert_sp_to_sf_type(output_types[0])

-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()

            transform_kwargs = dict(
@@ -527,16 +515,42 @@ class BernoulliRBM(BaseTransformer):
        self._is_fitted = True
        return output_result

+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit to data, then transform it
+        For more details on this function, see [sklearn.neural_network.BernoulliRBM.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.BernoulliRBM.html#sklearn.neural_network.BernoulliRBM.fit_transform)
+

-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
        Returns:
            Transformed dataset.
        """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result


    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -627,10 +641,8 @@ class BernoulliRBM(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)

        if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -695,10 +707,8 @@ class BernoulliRBM(BaseTransformer):
        transform_kwargs: BatchInferenceKwargsTypedDict = dict()

        if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -760,10 +770,8 @@ class BernoulliRBM(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)

        if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -831,10 +839,8 @@ class BernoulliRBM(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)

        if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session=dataset._session,
@@ -896,17 +902,15 @@ class BernoulliRBM(BaseTransformer):
        transform_kwargs: ScoreKwargsTypedDict = dict()

        if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
            selected_cols = self._get_active_columns()
            if len(selected_cols) > 0:
                dataset = dataset.select(selected_cols)
            assert isinstance(dataset._session, Session) # keep mypy happy
            transform_kwargs = dict(
                session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                score_sproc_imports=['sklearn'],
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -971,11 +975,8 @@ class BernoulliRBM(BaseTransformer):

        if isinstance(dataset, DataFrame):

-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session = dataset._session,
```
snowflake/ml/modeling/neural_network/mlp_classifier.py

```diff
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neural_network".replace(

 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class MLPClassifier(BaseTransformer):
     r"""Multi-layer Perceptron classifier
     For more details on this class, see [sklearn.neural_network.MLPClassifier]
@@ -448,20 +442,17 @@ class MLPClassifier(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.

         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error

-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -479,9 +470,7 @@ class MLPClassifier(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -529,7 +518,8 @@ class MLPClassifier(BaseTransformer):

                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -612,10 +602,8 @@ class MLPClassifier(BaseTransformer):
            if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                expected_dtype = convert_sp_to_sf_type(output_types[0])

-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()

            transform_kwargs = dict(
@@ -682,16 +670,40 @@ class MLPClassifier(BaseTransformer):
        self._is_fitted = True
        return output_result

+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.

-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
        Returns:
            Transformed dataset.
        """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result


    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -784,10 +796,8 @@ class MLPClassifier(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)

        if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -854,10 +864,8 @@ class MLPClassifier(BaseTransformer):
        transform_kwargs: BatchInferenceKwargsTypedDict = dict()

        if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -919,10 +927,8 @@ class MLPClassifier(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)

        if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -988,10 +994,8 @@ class MLPClassifier(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)

        if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session=dataset._session,
@@ -1055,17 +1059,15 @@ class MLPClassifier(BaseTransformer):
        transform_kwargs: ScoreKwargsTypedDict = dict()

        if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
            selected_cols = self._get_active_columns()
            if len(selected_cols) > 0:
                dataset = dataset.select(selected_cols)
            assert isinstance(dataset._session, Session) # keep mypy happy
            transform_kwargs = dict(
                session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                score_sproc_imports=['sklearn'],
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -1130,11 +1132,8 @@ class MLPClassifier(BaseTransformer):

        if isinstance(dataset, DataFrame):

-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session = dataset._session,
```
snowflake/ml/modeling/neural_network/mlp_regressor.py

```diff
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neural_network".replace(

 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class MLPRegressor(BaseTransformer):
     r"""Multi-layer Perceptron regressor
     For more details on this class, see [sklearn.neural_network.MLPRegressor]
@@ -444,20 +438,17 @@ class MLPRegressor(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.

         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error

-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -475,9 +466,7 @@ class MLPRegressor(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -525,7 +514,8 @@ class MLPRegressor(BaseTransformer):

                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -608,10 +598,8 @@ class MLPRegressor(BaseTransformer):
            if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                expected_dtype = convert_sp_to_sf_type(output_types[0])

-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()

            transform_kwargs = dict(
@@ -678,16 +666,40 @@ class MLPRegressor(BaseTransformer):
        self._is_fitted = True
        return output_result

+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.

-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
        Returns:
            Transformed dataset.
        """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result


    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -778,10 +790,8 @@ class MLPRegressor(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)

        if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -846,10 +856,8 @@ class MLPRegressor(BaseTransformer):
        transform_kwargs: BatchInferenceKwargsTypedDict = dict()

        if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -911,10 +919,8 @@ class MLPRegressor(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)

        if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -980,10 +986,8 @@ class MLPRegressor(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)

        if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session=dataset._session,
@@ -1047,17 +1051,15 @@ class MLPRegressor(BaseTransformer):
        transform_kwargs: ScoreKwargsTypedDict = dict()

        if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
            selected_cols = self._get_active_columns()
            if len(selected_cols) > 0:
                dataset = dataset.select(selected_cols)
            assert isinstance(dataset._session, Session) # keep mypy happy
            transform_kwargs = dict(
                session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                score_sproc_imports=['sklearn'],
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -1122,11 +1124,8 @@ class MLPRegressor(BaseTransformer):

        if isinstance(dataset, DataFrame):

-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session = dataset._session,
```