snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class Perceptron(BaseTransformer):
|
70
64
|
r"""Linear perceptron classifier
|
71
65
|
For more details on this class, see [sklearn.linear_model.Perceptron]
|
@@ -361,20 +355,17 @@ class Perceptron(BaseTransformer):
|
|
361
355
|
self,
|
362
356
|
dataset: DataFrame,
|
363
357
|
inference_method: str,
|
364
|
-
) ->
|
365
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
366
|
-
return the available package that exists in the snowflake anaconda channel
|
358
|
+
) -> None:
|
359
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
367
360
|
|
368
361
|
Args:
|
369
362
|
dataset: snowpark dataframe
|
370
363
|
inference_method: the inference method such as predict, score...
|
371
|
-
|
364
|
+
|
372
365
|
Raises:
|
373
366
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
374
367
|
SnowflakeMLException: If the session is None, raise error
|
375
368
|
|
376
|
-
Returns:
|
377
|
-
A list of available package that exists in the snowflake anaconda channel
|
378
369
|
"""
|
379
370
|
if not self._is_fitted:
|
380
371
|
raise exceptions.SnowflakeMLException(
|
@@ -392,9 +383,7 @@ class Perceptron(BaseTransformer):
|
|
392
383
|
"Session must not specified for snowpark dataset."
|
393
384
|
),
|
394
385
|
)
|
395
|
-
|
396
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
397
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
386
|
+
|
398
387
|
|
399
388
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
400
389
|
@telemetry.send_api_usage_telemetry(
|
@@ -442,7 +431,8 @@ class Perceptron(BaseTransformer):
|
|
442
431
|
|
443
432
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
444
433
|
|
445
|
-
self.
|
434
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
435
|
+
self._deps = self._get_dependencies()
|
446
436
|
assert isinstance(
|
447
437
|
dataset._session, Session
|
448
438
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -525,10 +515,8 @@ class Perceptron(BaseTransformer):
|
|
525
515
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
526
516
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
527
517
|
|
528
|
-
self.
|
529
|
-
|
530
|
-
inference_method=inference_method,
|
531
|
-
)
|
518
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
519
|
+
self._deps = self._get_dependencies()
|
532
520
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
533
521
|
|
534
522
|
transform_kwargs = dict(
|
@@ -595,16 +583,40 @@ class Perceptron(BaseTransformer):
|
|
595
583
|
self._is_fitted = True
|
596
584
|
return output_result
|
597
585
|
|
586
|
+
|
587
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
588
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
589
|
+
""" Method not supported for this class.
|
598
590
|
|
599
|
-
|
600
|
-
|
601
|
-
|
591
|
+
|
592
|
+
Raises:
|
593
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
594
|
+
|
595
|
+
Args:
|
596
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
597
|
+
Snowpark or Pandas DataFrame.
|
598
|
+
output_cols_prefix: Prefix for the response columns
|
602
599
|
Returns:
|
603
600
|
Transformed dataset.
|
604
601
|
"""
|
605
|
-
self.
|
606
|
-
|
607
|
-
|
602
|
+
self._infer_input_output_cols(dataset)
|
603
|
+
super()._check_dataset_type(dataset)
|
604
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
605
|
+
estimator=self._sklearn_object,
|
606
|
+
dataset=dataset,
|
607
|
+
input_cols=self.input_cols,
|
608
|
+
label_cols=self.label_cols,
|
609
|
+
sample_weight_col=self.sample_weight_col,
|
610
|
+
autogenerated=self._autogenerated,
|
611
|
+
subproject=_SUBPROJECT,
|
612
|
+
)
|
613
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
614
|
+
drop_input_cols=self._drop_input_cols,
|
615
|
+
expected_output_cols_list=self.output_cols,
|
616
|
+
)
|
617
|
+
self._sklearn_object = fitted_estimator
|
618
|
+
self._is_fitted = True
|
619
|
+
return output_result
|
608
620
|
|
609
621
|
|
610
622
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -695,10 +707,8 @@ class Perceptron(BaseTransformer):
|
|
695
707
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
696
708
|
|
697
709
|
if isinstance(dataset, DataFrame):
|
698
|
-
self.
|
699
|
-
|
700
|
-
inference_method=inference_method,
|
701
|
-
)
|
710
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
711
|
+
self._deps = self._get_dependencies()
|
702
712
|
assert isinstance(
|
703
713
|
dataset._session, Session
|
704
714
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -763,10 +773,8 @@ class Perceptron(BaseTransformer):
|
|
763
773
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
764
774
|
|
765
775
|
if isinstance(dataset, DataFrame):
|
766
|
-
self.
|
767
|
-
|
768
|
-
inference_method=inference_method,
|
769
|
-
)
|
776
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
777
|
+
self._deps = self._get_dependencies()
|
770
778
|
assert isinstance(
|
771
779
|
dataset._session, Session
|
772
780
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -830,10 +838,8 @@ class Perceptron(BaseTransformer):
|
|
830
838
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
831
839
|
|
832
840
|
if isinstance(dataset, DataFrame):
|
833
|
-
self.
|
834
|
-
|
835
|
-
inference_method=inference_method,
|
836
|
-
)
|
841
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
842
|
+
self._deps = self._get_dependencies()
|
837
843
|
assert isinstance(
|
838
844
|
dataset._session, Session
|
839
845
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -899,10 +905,8 @@ class Perceptron(BaseTransformer):
|
|
899
905
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
900
906
|
|
901
907
|
if isinstance(dataset, DataFrame):
|
902
|
-
self.
|
903
|
-
|
904
|
-
inference_method=inference_method,
|
905
|
-
)
|
908
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
909
|
+
self._deps = self._get_dependencies()
|
906
910
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
907
911
|
transform_kwargs = dict(
|
908
912
|
session=dataset._session,
|
@@ -966,17 +970,15 @@ class Perceptron(BaseTransformer):
|
|
966
970
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
967
971
|
|
968
972
|
if isinstance(dataset, DataFrame):
|
969
|
-
self.
|
970
|
-
|
971
|
-
inference_method="score",
|
972
|
-
)
|
973
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
974
|
+
self._deps = self._get_dependencies()
|
973
975
|
selected_cols = self._get_active_columns()
|
974
976
|
if len(selected_cols) > 0:
|
975
977
|
dataset = dataset.select(selected_cols)
|
976
978
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
977
979
|
transform_kwargs = dict(
|
978
980
|
session=dataset._session,
|
979
|
-
dependencies=
|
981
|
+
dependencies=self._deps,
|
980
982
|
score_sproc_imports=['sklearn'],
|
981
983
|
)
|
982
984
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1041,11 +1043,8 @@ class Perceptron(BaseTransformer):
|
|
1041
1043
|
|
1042
1044
|
if isinstance(dataset, DataFrame):
|
1043
1045
|
|
1044
|
-
self.
|
1045
|
-
|
1046
|
-
inference_method=inference_method,
|
1047
|
-
|
1048
|
-
)
|
1046
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1047
|
+
self._deps = self._get_dependencies()
|
1049
1048
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1050
1049
|
transform_kwargs = dict(
|
1051
1050
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class PoissonRegressor(BaseTransformer):
|
70
64
|
r"""Generalized Linear Model with a Poisson distribution
|
71
65
|
For more details on this class, see [sklearn.linear_model.PoissonRegressor]
|
@@ -310,20 +304,17 @@ class PoissonRegressor(BaseTransformer):
|
|
310
304
|
self,
|
311
305
|
dataset: DataFrame,
|
312
306
|
inference_method: str,
|
313
|
-
) ->
|
314
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
315
|
-
return the available package that exists in the snowflake anaconda channel
|
307
|
+
) -> None:
|
308
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
316
309
|
|
317
310
|
Args:
|
318
311
|
dataset: snowpark dataframe
|
319
312
|
inference_method: the inference method such as predict, score...
|
320
|
-
|
313
|
+
|
321
314
|
Raises:
|
322
315
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
323
316
|
SnowflakeMLException: If the session is None, raise error
|
324
317
|
|
325
|
-
Returns:
|
326
|
-
A list of available package that exists in the snowflake anaconda channel
|
327
318
|
"""
|
328
319
|
if not self._is_fitted:
|
329
320
|
raise exceptions.SnowflakeMLException(
|
@@ -341,9 +332,7 @@ class PoissonRegressor(BaseTransformer):
|
|
341
332
|
"Session must not specified for snowpark dataset."
|
342
333
|
),
|
343
334
|
)
|
344
|
-
|
345
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
346
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
335
|
+
|
347
336
|
|
348
337
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
349
338
|
@telemetry.send_api_usage_telemetry(
|
@@ -391,7 +380,8 @@ class PoissonRegressor(BaseTransformer):
|
|
391
380
|
|
392
381
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
393
382
|
|
394
|
-
self.
|
383
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
384
|
+
self._deps = self._get_dependencies()
|
395
385
|
assert isinstance(
|
396
386
|
dataset._session, Session
|
397
387
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -474,10 +464,8 @@ class PoissonRegressor(BaseTransformer):
|
|
474
464
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
475
465
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
476
466
|
|
477
|
-
self.
|
478
|
-
|
479
|
-
inference_method=inference_method,
|
480
|
-
)
|
467
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
468
|
+
self._deps = self._get_dependencies()
|
481
469
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
482
470
|
|
483
471
|
transform_kwargs = dict(
|
@@ -544,16 +532,40 @@ class PoissonRegressor(BaseTransformer):
|
|
544
532
|
self._is_fitted = True
|
545
533
|
return output_result
|
546
534
|
|
535
|
+
|
536
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
537
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
538
|
+
""" Method not supported for this class.
|
547
539
|
|
548
|
-
|
549
|
-
|
550
|
-
|
540
|
+
|
541
|
+
Raises:
|
542
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
543
|
+
|
544
|
+
Args:
|
545
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
546
|
+
Snowpark or Pandas DataFrame.
|
547
|
+
output_cols_prefix: Prefix for the response columns
|
551
548
|
Returns:
|
552
549
|
Transformed dataset.
|
553
550
|
"""
|
554
|
-
self.
|
555
|
-
|
556
|
-
|
551
|
+
self._infer_input_output_cols(dataset)
|
552
|
+
super()._check_dataset_type(dataset)
|
553
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
554
|
+
estimator=self._sklearn_object,
|
555
|
+
dataset=dataset,
|
556
|
+
input_cols=self.input_cols,
|
557
|
+
label_cols=self.label_cols,
|
558
|
+
sample_weight_col=self.sample_weight_col,
|
559
|
+
autogenerated=self._autogenerated,
|
560
|
+
subproject=_SUBPROJECT,
|
561
|
+
)
|
562
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
563
|
+
drop_input_cols=self._drop_input_cols,
|
564
|
+
expected_output_cols_list=self.output_cols,
|
565
|
+
)
|
566
|
+
self._sklearn_object = fitted_estimator
|
567
|
+
self._is_fitted = True
|
568
|
+
return output_result
|
557
569
|
|
558
570
|
|
559
571
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -644,10 +656,8 @@ class PoissonRegressor(BaseTransformer):
|
|
644
656
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
645
657
|
|
646
658
|
if isinstance(dataset, DataFrame):
|
647
|
-
self.
|
648
|
-
|
649
|
-
inference_method=inference_method,
|
650
|
-
)
|
659
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
660
|
+
self._deps = self._get_dependencies()
|
651
661
|
assert isinstance(
|
652
662
|
dataset._session, Session
|
653
663
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -712,10 +722,8 @@ class PoissonRegressor(BaseTransformer):
|
|
712
722
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
713
723
|
|
714
724
|
if isinstance(dataset, DataFrame):
|
715
|
-
self.
|
716
|
-
|
717
|
-
inference_method=inference_method,
|
718
|
-
)
|
725
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
726
|
+
self._deps = self._get_dependencies()
|
719
727
|
assert isinstance(
|
720
728
|
dataset._session, Session
|
721
729
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -777,10 +785,8 @@ class PoissonRegressor(BaseTransformer):
|
|
777
785
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
778
786
|
|
779
787
|
if isinstance(dataset, DataFrame):
|
780
|
-
self.
|
781
|
-
|
782
|
-
inference_method=inference_method,
|
783
|
-
)
|
788
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
789
|
+
self._deps = self._get_dependencies()
|
784
790
|
assert isinstance(
|
785
791
|
dataset._session, Session
|
786
792
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -846,10 +852,8 @@ class PoissonRegressor(BaseTransformer):
|
|
846
852
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
847
853
|
|
848
854
|
if isinstance(dataset, DataFrame):
|
849
|
-
self.
|
850
|
-
|
851
|
-
inference_method=inference_method,
|
852
|
-
)
|
855
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
856
|
+
self._deps = self._get_dependencies()
|
853
857
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
854
858
|
transform_kwargs = dict(
|
855
859
|
session=dataset._session,
|
@@ -913,17 +917,15 @@ class PoissonRegressor(BaseTransformer):
|
|
913
917
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
914
918
|
|
915
919
|
if isinstance(dataset, DataFrame):
|
916
|
-
self.
|
917
|
-
|
918
|
-
inference_method="score",
|
919
|
-
)
|
920
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
921
|
+
self._deps = self._get_dependencies()
|
920
922
|
selected_cols = self._get_active_columns()
|
921
923
|
if len(selected_cols) > 0:
|
922
924
|
dataset = dataset.select(selected_cols)
|
923
925
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
924
926
|
transform_kwargs = dict(
|
925
927
|
session=dataset._session,
|
926
|
-
dependencies=
|
928
|
+
dependencies=self._deps,
|
927
929
|
score_sproc_imports=['sklearn'],
|
928
930
|
)
|
929
931
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -988,11 +990,8 @@ class PoissonRegressor(BaseTransformer):
|
|
988
990
|
|
989
991
|
if isinstance(dataset, DataFrame):
|
990
992
|
|
991
|
-
self.
|
992
|
-
|
993
|
-
inference_method=inference_method,
|
994
|
-
|
995
|
-
)
|
993
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
994
|
+
self._deps = self._get_dependencies()
|
996
995
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
997
996
|
transform_kwargs = dict(
|
998
997
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class RANSACRegressor(BaseTransformer):
|
70
64
|
r"""RANSAC (RANdom SAmple Consensus) algorithm
|
71
65
|
For more details on this class, see [sklearn.linear_model.RANSACRegressor]
|
@@ -366,20 +360,17 @@ class RANSACRegressor(BaseTransformer):
|
|
366
360
|
self,
|
367
361
|
dataset: DataFrame,
|
368
362
|
inference_method: str,
|
369
|
-
) ->
|
370
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
371
|
-
return the available package that exists in the snowflake anaconda channel
|
363
|
+
) -> None:
|
364
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
372
365
|
|
373
366
|
Args:
|
374
367
|
dataset: snowpark dataframe
|
375
368
|
inference_method: the inference method such as predict, score...
|
376
|
-
|
369
|
+
|
377
370
|
Raises:
|
378
371
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
379
372
|
SnowflakeMLException: If the session is None, raise error
|
380
373
|
|
381
|
-
Returns:
|
382
|
-
A list of available package that exists in the snowflake anaconda channel
|
383
374
|
"""
|
384
375
|
if not self._is_fitted:
|
385
376
|
raise exceptions.SnowflakeMLException(
|
@@ -397,9 +388,7 @@ class RANSACRegressor(BaseTransformer):
|
|
397
388
|
"Session must not specified for snowpark dataset."
|
398
389
|
),
|
399
390
|
)
|
400
|
-
|
401
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
402
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
391
|
+
|
403
392
|
|
404
393
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
405
394
|
@telemetry.send_api_usage_telemetry(
|
@@ -447,7 +436,8 @@ class RANSACRegressor(BaseTransformer):
|
|
447
436
|
|
448
437
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
449
438
|
|
450
|
-
self.
|
439
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
440
|
+
self._deps = self._get_dependencies()
|
451
441
|
assert isinstance(
|
452
442
|
dataset._session, Session
|
453
443
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -530,10 +520,8 @@ class RANSACRegressor(BaseTransformer):
|
|
530
520
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
531
521
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
532
522
|
|
533
|
-
self.
|
534
|
-
|
535
|
-
inference_method=inference_method,
|
536
|
-
)
|
523
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
524
|
+
self._deps = self._get_dependencies()
|
537
525
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
538
526
|
|
539
527
|
transform_kwargs = dict(
|
@@ -600,16 +588,40 @@ class RANSACRegressor(BaseTransformer):
|
|
600
588
|
self._is_fitted = True
|
601
589
|
return output_result
|
602
590
|
|
591
|
+
|
592
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
593
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
594
|
+
""" Method not supported for this class.
|
603
595
|
|
604
|
-
|
605
|
-
|
606
|
-
|
596
|
+
|
597
|
+
Raises:
|
598
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
599
|
+
|
600
|
+
Args:
|
601
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
602
|
+
Snowpark or Pandas DataFrame.
|
603
|
+
output_cols_prefix: Prefix for the response columns
|
607
604
|
Returns:
|
608
605
|
Transformed dataset.
|
609
606
|
"""
|
610
|
-
self.
|
611
|
-
|
612
|
-
|
607
|
+
self._infer_input_output_cols(dataset)
|
608
|
+
super()._check_dataset_type(dataset)
|
609
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
610
|
+
estimator=self._sklearn_object,
|
611
|
+
dataset=dataset,
|
612
|
+
input_cols=self.input_cols,
|
613
|
+
label_cols=self.label_cols,
|
614
|
+
sample_weight_col=self.sample_weight_col,
|
615
|
+
autogenerated=self._autogenerated,
|
616
|
+
subproject=_SUBPROJECT,
|
617
|
+
)
|
618
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
619
|
+
drop_input_cols=self._drop_input_cols,
|
620
|
+
expected_output_cols_list=self.output_cols,
|
621
|
+
)
|
622
|
+
self._sklearn_object = fitted_estimator
|
623
|
+
self._is_fitted = True
|
624
|
+
return output_result
|
613
625
|
|
614
626
|
|
615
627
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -700,10 +712,8 @@ class RANSACRegressor(BaseTransformer):
|
|
700
712
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
701
713
|
|
702
714
|
if isinstance(dataset, DataFrame):
|
703
|
-
self.
|
704
|
-
|
705
|
-
inference_method=inference_method,
|
706
|
-
)
|
715
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
716
|
+
self._deps = self._get_dependencies()
|
707
717
|
assert isinstance(
|
708
718
|
dataset._session, Session
|
709
719
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -768,10 +778,8 @@ class RANSACRegressor(BaseTransformer):
|
|
768
778
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
769
779
|
|
770
780
|
if isinstance(dataset, DataFrame):
|
771
|
-
self.
|
772
|
-
|
773
|
-
inference_method=inference_method,
|
774
|
-
)
|
781
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
782
|
+
self._deps = self._get_dependencies()
|
775
783
|
assert isinstance(
|
776
784
|
dataset._session, Session
|
777
785
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -833,10 +841,8 @@ class RANSACRegressor(BaseTransformer):
|
|
833
841
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
834
842
|
|
835
843
|
if isinstance(dataset, DataFrame):
|
836
|
-
self.
|
837
|
-
|
838
|
-
inference_method=inference_method,
|
839
|
-
)
|
844
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
845
|
+
self._deps = self._get_dependencies()
|
840
846
|
assert isinstance(
|
841
847
|
dataset._session, Session
|
842
848
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -902,10 +908,8 @@ class RANSACRegressor(BaseTransformer):
|
|
902
908
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
903
909
|
|
904
910
|
if isinstance(dataset, DataFrame):
|
905
|
-
self.
|
906
|
-
|
907
|
-
inference_method=inference_method,
|
908
|
-
)
|
911
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
912
|
+
self._deps = self._get_dependencies()
|
909
913
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
910
914
|
transform_kwargs = dict(
|
911
915
|
session=dataset._session,
|
@@ -969,17 +973,15 @@ class RANSACRegressor(BaseTransformer):
|
|
969
973
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
970
974
|
|
971
975
|
if isinstance(dataset, DataFrame):
|
972
|
-
self.
|
973
|
-
|
974
|
-
inference_method="score",
|
975
|
-
)
|
976
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
977
|
+
self._deps = self._get_dependencies()
|
976
978
|
selected_cols = self._get_active_columns()
|
977
979
|
if len(selected_cols) > 0:
|
978
980
|
dataset = dataset.select(selected_cols)
|
979
981
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
980
982
|
transform_kwargs = dict(
|
981
983
|
session=dataset._session,
|
982
|
-
dependencies=
|
984
|
+
dependencies=self._deps,
|
983
985
|
score_sproc_imports=['sklearn'],
|
984
986
|
)
|
985
987
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1044,11 +1046,8 @@ class RANSACRegressor(BaseTransformer):
|
|
1044
1046
|
|
1045
1047
|
if isinstance(dataset, DataFrame):
|
1046
1048
|
|
1047
|
-
self.
|
1048
|
-
|
1049
|
-
inference_method=inference_method,
|
1050
|
-
|
1051
|
-
)
|
1049
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1050
|
+
self._deps = self._get_dependencies()
|
1052
1051
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1053
1052
|
transform_kwargs = dict(
|
1054
1053
|
session = dataset._session,
|