snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
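The smallest entry in the list above, snowflake/ml/version.py (+1 -1), is the version bump itself. As a quick sanity check after upgrading, the constant exported by that module can be printed; the snippet below is a hedged sketch that assumes the module keeps exposing a plain VERSION string, as in earlier releases.

    # Assumption: snowflake/ml/version.py exposes a VERSION constant (consistent with the +1 -1 bump above).
    # pip install "snowflake-ml-python==1.5.0"
    from snowflake.ml import version

    print(version.VERSION)  # expected to print "1.5.0" once the 1.4.1 wheel has been replaced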
snowflake/ml/modeling/svm/nu_svr.py CHANGED
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.",
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class NuSVR(BaseTransformer):
     r"""Nu Support Vector Regression
     For more details on this class, see [sklearn.svm.NuSVR]
@@ -321,20 +315,17 @@ class NuSVR(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -352,9 +343,7 @@ class NuSVR(BaseTransformer):
                    "Session must not specified for snowpark dataset."
                ),
            )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -402,7 +391,8 @@ class NuSVR(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -485,10 +475,8 @@ class NuSVR(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -555,16 +543,40 @@ class NuSVR(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -655,10 +667,8 @@ class NuSVR(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -723,10 +733,8 @@ class NuSVR(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -788,10 +796,8 @@ class NuSVR(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -857,10 +863,8 @@ class NuSVR(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -924,17 +928,15 @@ class NuSVR(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -999,11 +1001,8 @@ class NuSVR(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/svm/svc.py CHANGED
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.",
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class SVC(BaseTransformer):
     r"""C-Support Vector Classification
     For more details on this class, see [sklearn.svm.SVC]
@@ -363,20 +357,17 @@ class SVC(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -394,9 +385,7 @@ class SVC(BaseTransformer):
                    "Session must not specified for snowpark dataset."
                ),
            )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -444,7 +433,8 @@ class SVC(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -527,10 +517,8 @@ class SVC(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -597,16 +585,40 @@ class SVC(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -699,10 +711,8 @@ class SVC(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -769,10 +779,8 @@ class SVC(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -836,10 +844,8 @@ class SVC(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -905,10 +911,8 @@ class SVC(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -972,17 +976,15 @@ class SVC(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1047,11 +1049,8 @@ class SVC(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/svm/svr.py CHANGED
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.",
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class SVR(BaseTransformer):
     r"""Epsilon-Support Vector Regression
     For more details on this class, see [sklearn.svm.SVR]
@@ -324,20 +318,17 @@ class SVR(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
            dataset: snowpark dataframe
            inference_method: the inference method such as predict, score...
-
+
         Raises:
            SnowflakeMLException: If the estimator is not fitted, raise error
            SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -355,9 +346,7 @@ class SVR(BaseTransformer):
                    "Session must not specified for snowpark dataset."
                ),
            )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -405,7 +394,8 @@ class SVR(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -488,10 +478,8 @@ class SVR(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -558,16 +546,40 @@ class SVR(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -658,10 +670,8 @@ class SVR(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -726,10 +736,8 @@ class SVR(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -791,10 +799,8 @@ class SVR(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -860,10 +866,8 @@ class SVR(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -927,17 +931,15 @@ class SVR(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1002,11 +1004,8 @@ class SVR(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,