snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in their public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
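Among the changes listed above, the dataset module is substantially reworked (dataset.py +454/-129) and gains new factory, metadata, and reader files (dataset_factory.py, dataset_metadata.py, dataset_reader.py), alongside new lineage helpers (data_source.py, dataset_dataframe.py). A minimal usage sketch follows; the factory and reader names (create_from_dataframe, load_dataset, Dataset.read) and their signatures are assumptions inferred from the new file names, not something this diff confirms, so treat them as hypothetical.

# Hypothetical sketch of the reworked snowflake.ml.dataset API in 1.5.0.
# Function and attribute names below are assumptions inferred from the new
# modules (dataset_factory.py, dataset_reader.py); check the 1.5.0 docs before relying on them.
from snowflake.snowpark import Session
from snowflake.ml import dataset

connection_parameters = {"account": "<account>", "user": "<user>", "password": "<password>"}
session = Session.builder.configs(connection_parameters).create()

source_df = session.table("MY_DB.MY_SCHEMA.TRAINING_ROWS")

# dataset_factory.py presumably exposes a constructor along these lines:
ds = dataset.create_from_dataframe(
    session,
    name="MY_DB.MY_SCHEMA.MY_DATASET",
    version="v1",
    input_dataframe=source_df,
)

# dataset_reader.py presumably backs a reader object for materializing the data:
pdf = ds.read.to_pandas()
print(pdf.shape)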
snowflake/ml/modeling/linear_model/elastic_net.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class ElasticNet(BaseTransformer):
     r"""Linear regression with combined L1 and L2 priors as regularizer
     For more details on this class, see [sklearn.linear_model.ElasticNet]
@@ -329,20 +323,17 @@ class ElasticNet(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -360,9 +351,7 @@ class ElasticNet(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -410,7 +399,8 @@ class ElasticNet(BaseTransformer):
 
             expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -493,10 +483,8 @@ class ElasticNet(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -563,16 +551,40 @@ class ElasticNet(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -663,10 +675,8 @@ class ElasticNet(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -731,10 +741,8 @@ class ElasticNet(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -796,10 +804,8 @@ class ElasticNet(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -865,10 +871,8 @@ class ElasticNet(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -932,17 +936,15 @@ class ElasticNet(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1007,11 +1009,8 @@ class ElasticNet(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
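The same refactor repeats across every generated estimator wrapper in this release (the elastic_net_cv.py and gamma_regressor.py hunks below are near-identical): _batch_inference_validate_snowpark() now returns None and only validates, while the dependency list is resolved by a separate _get_dependencies() call and cached on self._deps before batch inference or scoring. A minimal, self-contained sketch of the new call pattern follows; the stub class is illustrative only and is not the library's actual implementation.

# Illustrative stub of the 1.5.0 validate-then-resolve split; not the library's real wrapper.
from typing import Any, List


class _StubWrapper:
    def __init__(self, dependencies: List[str]) -> None:
        self._dependencies = dependencies
        self._is_fitted = True
        self._deps: List[str] = []

    def _batch_inference_validate_snowpark(self, dataset: Any, inference_method: str) -> None:
        # 1.5.0 behaviour: validation only, nothing is returned.
        if not self._is_fitted:
            raise RuntimeError(f"Estimator must be fitted before calling {inference_method!r}.")

    def _get_dependencies(self) -> List[str]:
        # Dependencies are now resolved in a separate step and cached on self._deps.
        return list(self._dependencies)

    def predict(self, dataset: Any) -> Any:
        self._batch_inference_validate_snowpark(dataset=dataset, inference_method="predict")
        self._deps = self._get_dependencies()
        return dataset


print(_StubWrapper(["scikit-learn", "xgboost"]).predict([1, 2, 3]))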
snowflake/ml/modeling/linear_model/elastic_net_cv.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class ElasticNetCV(BaseTransformer):
     r"""Elastic Net model with iterative fitting along a regularization path
     For more details on this class, see [sklearn.linear_model.ElasticNetCV]
@@ -365,20 +359,17 @@ class ElasticNetCV(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -396,9 +387,7 @@ class ElasticNetCV(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -446,7 +435,8 @@ class ElasticNetCV(BaseTransformer):
 
             expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -529,10 +519,8 @@ class ElasticNetCV(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -599,16 +587,40 @@ class ElasticNetCV(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -699,10 +711,8 @@ class ElasticNetCV(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -767,10 +777,8 @@ class ElasticNetCV(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -832,10 +840,8 @@ class ElasticNetCV(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -901,10 +907,8 @@ class ElasticNetCV(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -968,17 +972,15 @@ class ElasticNetCV(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1043,11 +1045,8 @@ class ElasticNetCV(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/linear_model/gamma_regressor.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class GammaRegressor(BaseTransformer):
     r"""Generalized Linear Model with a Gamma distribution
     For more details on this class, see [sklearn.linear_model.GammaRegressor]
@@ -310,20 +304,17 @@ class GammaRegressor(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -341,9 +332,7 @@ class GammaRegressor(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -391,7 +380,8 @@ class GammaRegressor(BaseTransformer):
 
             expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -474,10 +464,8 @@ class GammaRegressor(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -544,16 +532,40 @@ class GammaRegressor(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -644,10 +656,8 @@ class GammaRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -712,10 +722,8 @@ class GammaRegressor(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
        if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -777,10 +785,8 @@ class GammaRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -846,10 +852,8 @@ class GammaRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -913,17 +917,15 @@ class GammaRegressor(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -988,11 +990,8 @@ class GammaRegressor(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
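Beyond the dependency-handling change, each regenerated wrapper also gains a fit_transform() entry point routed through ModelTrainerBuilder.build_fit_transform, backed by the new train_fit_transform support in pandas_trainer.py and snowpark_trainer.py listed above. A hedged usage sketch against one of the regenerated transformer classes follows; the constructor arguments follow the existing input_cols/output_cols conventions, and the exact output-column naming in 1.5.0 is an assumption rather than something this diff confirms.

# Hedged sketch: calling the new fit_transform() on a regenerated transformer wrapper.
# Column-naming behaviour (output_cols vs. the "fit_transform_" prefix) is assumed, not verified.
import pandas as pd
from snowflake.ml.modeling.decomposition import PCA

df = pd.DataFrame({"F1": [1.0, 2.0, 3.0, 4.0], "F2": [0.5, 1.4, 2.6, 3.5]})

pca = PCA(n_components=1, input_cols=["F1", "F2"], output_cols=["PC1"])
projected = pca.fit_transform(df)  # local (pandas) path; a Snowpark DataFrame would also work
print(projected.head())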