snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that public registry.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
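The per-file hunks below cover three of the `snowflake/ml/modeling/svm` estimators (linear_svc.py, linear_svr.py and nu_svc.py); the same autogenerated template change accounts for the near-identical +51/-52 and +53/-52 counts across the modeling files listed above. Two changes repeat in every file: `_batch_inference_validate_snowpark` now returns `None` and only validates, with each batch-inference call site resolving its own dependency list, and a concrete `fit_transform` implementation replaces the previously hard-disabled `_is_fit_transform_method_enabled` helper. A condensed sketch of the new call-site pattern follows; `EstimatorSketch` is not library code, and only the two method names it calls are taken from the added lines.

```python
from typing import Any, List


class EstimatorSketch:
    """Illustrative stand-in for an autogenerated snowflake-ml estimator (not library code)."""

    def _batch_inference_validate_snowpark(self, dataset: Any, inference_method: str) -> None:
        # 1.5.0 behavior per the diffs: validate only (raise if the estimator is unfitted
        # or the Snowpark session is missing); no longer returns the resolved package list.
        ...

    def _get_dependencies(self) -> List[str]:
        # Stand-in for the estimator's declared dependency list.
        return []

    def _some_batch_inference_call_site(self, dataset: Any, inference_method: str) -> None:
        # The pattern repeated at every predict/transform/score call site in the diffs:
        self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
        self._deps = self._get_dependencies()
        # ...self._deps is then forwarded, e.g. dependencies=self._deps in score().
```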
snowflake/ml/modeling/svm/linear_svc.py

```diff
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.",
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class LinearSVC(BaseTransformer):
     r"""Linear Support Vector Classification
     For more details on this class, see [sklearn.svm.LinearSVC]
@@ -354,20 +348,17 @@ class LinearSVC(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -385,9 +376,7 @@ class LinearSVC(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -435,7 +424,8 @@ class LinearSVC(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -518,10 +508,8 @@ class LinearSVC(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -588,16 +576,40 @@ class LinearSVC(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -688,10 +700,8 @@ class LinearSVC(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -756,10 +766,8 @@ class LinearSVC(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -823,10 +831,8 @@ class LinearSVC(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -892,10 +898,8 @@ class LinearSVC(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -959,17 +963,15 @@ class LinearSVC(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1034,11 +1036,8 @@ class LinearSVC(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
```
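The added `fit_transform` is guarded by `available_if(original_estimator_has_callable("fit_transform"))`, so on LinearSVC (whose underlying sklearn estimator has no `fit_transform`) it stays unavailable, while transformers such as PCA pick it up. A minimal usage sketch, assuming the autogenerated PCA wrapper in snowflake/ml/modeling/decomposition/pca.py exposes the same signature shown above; the column names and data are made up:

```python
import pandas as pd

# Assumes snowflake-ml-python 1.5.0; PCA is one of the autogenerated wrappers in the file list.
from snowflake.ml.modeling.decomposition import PCA

df = pd.DataFrame(
    {"F1": [1.0, 2.0, 3.0, 4.0], "F2": [2.0, 1.0, 0.0, -1.0], "F3": [0.5, 0.5, 1.5, 2.5]}
)

pca = PCA(n_components=2, input_cols=["F1", "F2", "F3"], output_cols=["PC1", "PC2"])

# New in 1.5.0: fit_transform(dataset, output_cols_prefix="fit_transform_") per the diff above;
# it fits the estimator and returns the transformed dataset (Snowpark or pandas DataFrame).
transformed = pca.fit_transform(df)
print(transformed)
```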
snowflake/ml/modeling/svm/linear_svr.py

```diff
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.",
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class LinearSVR(BaseTransformer):
     r"""Linear Support Vector Regression
     For more details on this class, see [sklearn.svm.LinearSVR]
@@ -326,20 +320,17 @@ class LinearSVR(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -357,9 +348,7 @@ class LinearSVR(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -407,7 +396,8 @@ class LinearSVR(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -490,10 +480,8 @@ class LinearSVR(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -560,16 +548,40 @@ class LinearSVR(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -660,10 +672,8 @@ class LinearSVR(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -728,10 +738,8 @@ class LinearSVR(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -793,10 +801,8 @@ class LinearSVR(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -862,10 +868,8 @@ class LinearSVR(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -929,17 +933,15 @@ class LinearSVR(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1004,11 +1006,8 @@ class LinearSVR(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
```
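In 1.4.1 the module-level `_is_fit_transform_method_enabled` helper removed in the first hunk always returned False (`False and callable(...)`), so `fit_transform` was never exposed; 1.5.0 instead gates the method with `original_estimator_has_callable("fit_transform")`. A small sketch of that guard mechanism using scikit-learn's `available_if`; `_has_fit_transform` below is only a simplified stand-in for the library's `original_estimator_has_callable` helper:

```python
from sklearn.decomposition import PCA
from sklearn.svm import LinearSVR
from sklearn.utils.metaestimators import available_if


def _has_fit_transform(self) -> bool:
    # Simplified stand-in: expose the wrapper method only when the wrapped
    # sklearn estimator actually implements fit_transform.
    return callable(getattr(self._sklearn_object, "fit_transform", None))


class WrapperSketch:
    """Toy wrapper demonstrating the conditional-method pattern (not library code)."""

    def __init__(self, sklearn_object):
        self._sklearn_object = sklearn_object

    @available_if(_has_fit_transform)
    def fit_transform(self, X):
        return self._sklearn_object.fit_transform(X)


print(hasattr(WrapperSketch(PCA()), "fit_transform"))        # True: method is exposed
print(hasattr(WrapperSketch(LinearSVR()), "fit_transform"))  # False: attribute lookup raises
```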
snowflake/ml/modeling/svm/nu_svc.py

```diff
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.",
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class NuSVC(BaseTransformer):
     r"""Nu-Support Vector Classification
     For more details on this class, see [sklearn.svm.NuSVC]
@@ -360,20 +354,17 @@ class NuSVC(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
            dataset: snowpark dataframe
            inference_method: the inference method such as predict, score...
-
+
         Raises:
            SnowflakeMLException: If the estimator is not fitted, raise error
            SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -391,9 +382,7 @@ class NuSVC(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -441,7 +430,8 @@ class NuSVC(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -524,10 +514,8 @@ class NuSVC(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -594,16 +582,40 @@ class NuSVC(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -696,10 +708,8 @@ class NuSVC(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -766,10 +776,8 @@ class NuSVC(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -833,10 +841,8 @@ class NuSVC(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -902,10 +908,8 @@ class NuSVC(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -969,17 +973,15 @@ class NuSVC(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1044,11 +1046,8 @@ class NuSVC(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
```