snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218):
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class LogisticRegression(BaseTransformer):
70
64
  r"""Logistic Regression (aka logit, MaxEnt) classifier
71
65
  For more details on this class, see [sklearn.linear_model.LogisticRegression]
@@ -394,20 +388,17 @@ class LogisticRegression(BaseTransformer):
394
388
  self,
395
389
  dataset: DataFrame,
396
390
  inference_method: str,
397
- ) -> List[str]:
398
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
399
- return the available package that exists in the snowflake anaconda channel
391
+ ) -> None:
392
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
400
393
 
401
394
  Args:
402
395
  dataset: snowpark dataframe
403
396
  inference_method: the inference method such as predict, score...
404
-
397
+
405
398
  Raises:
406
399
  SnowflakeMLException: If the estimator is not fitted, raise error
407
400
  SnowflakeMLException: If the session is None, raise error
408
401
 
409
- Returns:
410
- A list of available package that exists in the snowflake anaconda channel
411
402
  """
412
403
  if not self._is_fitted:
413
404
  raise exceptions.SnowflakeMLException(
@@ -425,9 +416,7 @@ class LogisticRegression(BaseTransformer):
425
416
  "Session must not specified for snowpark dataset."
426
417
  ),
427
418
  )
428
- # Validate that key package version in user workspace are supported in snowflake conda channel
429
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
430
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
419
+
431
420
 
432
421
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
433
422
  @telemetry.send_api_usage_telemetry(
@@ -475,7 +464,8 @@ class LogisticRegression(BaseTransformer):
475
464
 
476
465
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
477
466
 
478
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
467
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
468
+ self._deps = self._get_dependencies()
479
469
  assert isinstance(
480
470
  dataset._session, Session
481
471
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -558,10 +548,8 @@ class LogisticRegression(BaseTransformer):
558
548
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
559
549
  expected_dtype = convert_sp_to_sf_type(output_types[0])
560
550
 
561
- self._deps = self._batch_inference_validate_snowpark(
562
- dataset=dataset,
563
- inference_method=inference_method,
564
- )
551
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
552
+ self._deps = self._get_dependencies()
565
553
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
566
554
 
567
555
  transform_kwargs = dict(
@@ -628,16 +616,40 @@ class LogisticRegression(BaseTransformer):
628
616
  self._is_fitted = True
629
617
  return output_result
630
618
 
619
+
620
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
621
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
622
+ """ Method not supported for this class.
631
623
 
632
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
633
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
634
- """
624
+
625
+ Raises:
626
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
627
+
628
+ Args:
629
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
630
+ Snowpark or Pandas DataFrame.
631
+ output_cols_prefix: Prefix for the response columns
635
632
  Returns:
636
633
  Transformed dataset.
637
634
  """
638
- self.fit(dataset)
639
- assert self._sklearn_object is not None
640
- return self._sklearn_object.embedding_
635
+ self._infer_input_output_cols(dataset)
636
+ super()._check_dataset_type(dataset)
637
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
638
+ estimator=self._sklearn_object,
639
+ dataset=dataset,
640
+ input_cols=self.input_cols,
641
+ label_cols=self.label_cols,
642
+ sample_weight_col=self.sample_weight_col,
643
+ autogenerated=self._autogenerated,
644
+ subproject=_SUBPROJECT,
645
+ )
646
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
647
+ drop_input_cols=self._drop_input_cols,
648
+ expected_output_cols_list=self.output_cols,
649
+ )
650
+ self._sklearn_object = fitted_estimator
651
+ self._is_fitted = True
652
+ return output_result
641
653
 
642
654
 
643
655
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -730,10 +742,8 @@ class LogisticRegression(BaseTransformer):
730
742
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
731
743
 
732
744
  if isinstance(dataset, DataFrame):
733
- self._deps = self._batch_inference_validate_snowpark(
734
- dataset=dataset,
735
- inference_method=inference_method,
736
- )
745
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
746
+ self._deps = self._get_dependencies()
737
747
  assert isinstance(
738
748
  dataset._session, Session
739
749
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -800,10 +810,8 @@ class LogisticRegression(BaseTransformer):
800
810
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
801
811
 
802
812
  if isinstance(dataset, DataFrame):
803
- self._deps = self._batch_inference_validate_snowpark(
804
- dataset=dataset,
805
- inference_method=inference_method,
806
- )
813
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
814
+ self._deps = self._get_dependencies()
807
815
  assert isinstance(
808
816
  dataset._session, Session
809
817
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -867,10 +875,8 @@ class LogisticRegression(BaseTransformer):
867
875
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
868
876
 
869
877
  if isinstance(dataset, DataFrame):
870
- self._deps = self._batch_inference_validate_snowpark(
871
- dataset=dataset,
872
- inference_method=inference_method,
873
- )
878
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
879
+ self._deps = self._get_dependencies()
874
880
  assert isinstance(
875
881
  dataset._session, Session
876
882
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -936,10 +942,8 @@ class LogisticRegression(BaseTransformer):
936
942
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
937
943
 
938
944
  if isinstance(dataset, DataFrame):
939
- self._deps = self._batch_inference_validate_snowpark(
940
- dataset=dataset,
941
- inference_method=inference_method,
942
- )
945
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
946
+ self._deps = self._get_dependencies()
943
947
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
944
948
  transform_kwargs = dict(
945
949
  session=dataset._session,
@@ -1003,17 +1007,15 @@ class LogisticRegression(BaseTransformer):
1003
1007
  transform_kwargs: ScoreKwargsTypedDict = dict()
1004
1008
 
1005
1009
  if isinstance(dataset, DataFrame):
1006
- self._deps = self._batch_inference_validate_snowpark(
1007
- dataset=dataset,
1008
- inference_method="score",
1009
- )
1010
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1011
+ self._deps = self._get_dependencies()
1010
1012
  selected_cols = self._get_active_columns()
1011
1013
  if len(selected_cols) > 0:
1012
1014
  dataset = dataset.select(selected_cols)
1013
1015
  assert isinstance(dataset._session, Session) # keep mypy happy
1014
1016
  transform_kwargs = dict(
1015
1017
  session=dataset._session,
1016
- dependencies=["snowflake-snowpark-python"] + self._deps,
1018
+ dependencies=self._deps,
1017
1019
  score_sproc_imports=['sklearn'],
1018
1020
  )
1019
1021
  elif isinstance(dataset, pd.DataFrame):
@@ -1078,11 +1080,8 @@ class LogisticRegression(BaseTransformer):
1078
1080
 
1079
1081
  if isinstance(dataset, DataFrame):
1080
1082
 
1081
- self._deps = self._batch_inference_validate_snowpark(
1082
- dataset=dataset,
1083
- inference_method=inference_method,
1084
-
1085
- )
1083
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1084
+ self._deps = self._get_dependencies()
1086
1085
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1087
1086
  transform_kwargs = dict(
1088
1087
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class LogisticRegressionCV(BaseTransformer):
70
64
  r"""Logistic Regression CV (aka logit, MaxEnt) classifier
71
65
  For more details on this class, see [sklearn.linear_model.LogisticRegressionCV]
@@ -415,20 +409,17 @@ class LogisticRegressionCV(BaseTransformer):
415
409
  self,
416
410
  dataset: DataFrame,
417
411
  inference_method: str,
418
- ) -> List[str]:
419
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
420
- return the available package that exists in the snowflake anaconda channel
412
+ ) -> None:
413
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
421
414
 
422
415
  Args:
423
416
  dataset: snowpark dataframe
424
417
  inference_method: the inference method such as predict, score...
425
-
418
+
426
419
  Raises:
427
420
  SnowflakeMLException: If the estimator is not fitted, raise error
428
421
  SnowflakeMLException: If the session is None, raise error
429
422
 
430
- Returns:
431
- A list of available package that exists in the snowflake anaconda channel
432
423
  """
433
424
  if not self._is_fitted:
434
425
  raise exceptions.SnowflakeMLException(
@@ -446,9 +437,7 @@ class LogisticRegressionCV(BaseTransformer):
446
437
  "Session must not specified for snowpark dataset."
447
438
  ),
448
439
  )
449
- # Validate that key package version in user workspace are supported in snowflake conda channel
450
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
451
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
440
+
452
441
 
453
442
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
454
443
  @telemetry.send_api_usage_telemetry(
@@ -496,7 +485,8 @@ class LogisticRegressionCV(BaseTransformer):
496
485
 
497
486
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
498
487
 
499
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
488
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
489
+ self._deps = self._get_dependencies()
500
490
  assert isinstance(
501
491
  dataset._session, Session
502
492
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -579,10 +569,8 @@ class LogisticRegressionCV(BaseTransformer):
579
569
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
580
570
  expected_dtype = convert_sp_to_sf_type(output_types[0])
581
571
 
582
- self._deps = self._batch_inference_validate_snowpark(
583
- dataset=dataset,
584
- inference_method=inference_method,
585
- )
572
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
573
+ self._deps = self._get_dependencies()
586
574
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
587
575
 
588
576
  transform_kwargs = dict(
@@ -649,16 +637,40 @@ class LogisticRegressionCV(BaseTransformer):
649
637
  self._is_fitted = True
650
638
  return output_result
651
639
 
640
+
641
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
642
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
643
+ """ Method not supported for this class.
652
644
 
653
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
654
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
655
- """
645
+
646
+ Raises:
647
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
648
+
649
+ Args:
650
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
651
+ Snowpark or Pandas DataFrame.
652
+ output_cols_prefix: Prefix for the response columns
656
653
  Returns:
657
654
  Transformed dataset.
658
655
  """
659
- self.fit(dataset)
660
- assert self._sklearn_object is not None
661
- return self._sklearn_object.embedding_
656
+ self._infer_input_output_cols(dataset)
657
+ super()._check_dataset_type(dataset)
658
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
659
+ estimator=self._sklearn_object,
660
+ dataset=dataset,
661
+ input_cols=self.input_cols,
662
+ label_cols=self.label_cols,
663
+ sample_weight_col=self.sample_weight_col,
664
+ autogenerated=self._autogenerated,
665
+ subproject=_SUBPROJECT,
666
+ )
667
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
668
+ drop_input_cols=self._drop_input_cols,
669
+ expected_output_cols_list=self.output_cols,
670
+ )
671
+ self._sklearn_object = fitted_estimator
672
+ self._is_fitted = True
673
+ return output_result
662
674
 
663
675
 
664
676
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -751,10 +763,8 @@ class LogisticRegressionCV(BaseTransformer):
751
763
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
752
764
 
753
765
  if isinstance(dataset, DataFrame):
754
- self._deps = self._batch_inference_validate_snowpark(
755
- dataset=dataset,
756
- inference_method=inference_method,
757
- )
766
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
767
+ self._deps = self._get_dependencies()
758
768
  assert isinstance(
759
769
  dataset._session, Session
760
770
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -821,10 +831,8 @@ class LogisticRegressionCV(BaseTransformer):
821
831
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
822
832
 
823
833
  if isinstance(dataset, DataFrame):
824
- self._deps = self._batch_inference_validate_snowpark(
825
- dataset=dataset,
826
- inference_method=inference_method,
827
- )
834
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
835
+ self._deps = self._get_dependencies()
828
836
  assert isinstance(
829
837
  dataset._session, Session
830
838
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -888,10 +896,8 @@ class LogisticRegressionCV(BaseTransformer):
888
896
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
889
897
 
890
898
  if isinstance(dataset, DataFrame):
891
- self._deps = self._batch_inference_validate_snowpark(
892
- dataset=dataset,
893
- inference_method=inference_method,
894
- )
899
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
900
+ self._deps = self._get_dependencies()
895
901
  assert isinstance(
896
902
  dataset._session, Session
897
903
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -957,10 +963,8 @@ class LogisticRegressionCV(BaseTransformer):
957
963
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
958
964
 
959
965
  if isinstance(dataset, DataFrame):
960
- self._deps = self._batch_inference_validate_snowpark(
961
- dataset=dataset,
962
- inference_method=inference_method,
963
- )
966
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
967
+ self._deps = self._get_dependencies()
964
968
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
965
969
  transform_kwargs = dict(
966
970
  session=dataset._session,
@@ -1024,17 +1028,15 @@ class LogisticRegressionCV(BaseTransformer):
1024
1028
  transform_kwargs: ScoreKwargsTypedDict = dict()
1025
1029
 
1026
1030
  if isinstance(dataset, DataFrame):
1027
- self._deps = self._batch_inference_validate_snowpark(
1028
- dataset=dataset,
1029
- inference_method="score",
1030
- )
1031
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1032
+ self._deps = self._get_dependencies()
1031
1033
  selected_cols = self._get_active_columns()
1032
1034
  if len(selected_cols) > 0:
1033
1035
  dataset = dataset.select(selected_cols)
1034
1036
  assert isinstance(dataset._session, Session) # keep mypy happy
1035
1037
  transform_kwargs = dict(
1036
1038
  session=dataset._session,
1037
- dependencies=["snowflake-snowpark-python"] + self._deps,
1039
+ dependencies=self._deps,
1038
1040
  score_sproc_imports=['sklearn'],
1039
1041
  )
1040
1042
  elif isinstance(dataset, pd.DataFrame):
@@ -1099,11 +1101,8 @@ class LogisticRegressionCV(BaseTransformer):
1099
1101
 
1100
1102
  if isinstance(dataset, DataFrame):
1101
1103
 
1102
- self._deps = self._batch_inference_validate_snowpark(
1103
- dataset=dataset,
1104
- inference_method=inference_method,
1105
-
1106
- )
1104
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1105
+ self._deps = self._get_dependencies()
1107
1106
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1108
1107
  transform_kwargs = dict(
1109
1108
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class MultiTaskElasticNet(BaseTransformer):
70
64
  r"""Multi-task ElasticNet model trained with L1/L2 mixed-norm as regularizer
71
65
  For more details on this class, see [sklearn.linear_model.MultiTaskElasticNet]
@@ -313,20 +307,17 @@ class MultiTaskElasticNet(BaseTransformer):
313
307
  self,
314
308
  dataset: DataFrame,
315
309
  inference_method: str,
316
- ) -> List[str]:
317
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
318
- return the available package that exists in the snowflake anaconda channel
310
+ ) -> None:
311
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
319
312
 
320
313
  Args:
321
314
  dataset: snowpark dataframe
322
315
  inference_method: the inference method such as predict, score...
323
-
316
+
324
317
  Raises:
325
318
  SnowflakeMLException: If the estimator is not fitted, raise error
326
319
  SnowflakeMLException: If the session is None, raise error
327
320
 
328
- Returns:
329
- A list of available package that exists in the snowflake anaconda channel
330
321
  """
331
322
  if not self._is_fitted:
332
323
  raise exceptions.SnowflakeMLException(
@@ -344,9 +335,7 @@ class MultiTaskElasticNet(BaseTransformer):
344
335
  "Session must not specified for snowpark dataset."
345
336
  ),
346
337
  )
347
- # Validate that key package version in user workspace are supported in snowflake conda channel
348
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
349
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
338
+
350
339
 
351
340
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
352
341
  @telemetry.send_api_usage_telemetry(
@@ -394,7 +383,8 @@ class MultiTaskElasticNet(BaseTransformer):
394
383
 
395
384
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
396
385
 
397
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
386
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
387
+ self._deps = self._get_dependencies()
398
388
  assert isinstance(
399
389
  dataset._session, Session
400
390
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -477,10 +467,8 @@ class MultiTaskElasticNet(BaseTransformer):
477
467
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
478
468
  expected_dtype = convert_sp_to_sf_type(output_types[0])
479
469
 
480
- self._deps = self._batch_inference_validate_snowpark(
481
- dataset=dataset,
482
- inference_method=inference_method,
483
- )
470
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
471
+ self._deps = self._get_dependencies()
484
472
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
485
473
 
486
474
  transform_kwargs = dict(
@@ -547,16 +535,40 @@ class MultiTaskElasticNet(BaseTransformer):
547
535
  self._is_fitted = True
548
536
  return output_result
549
537
 
538
+
539
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
540
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
541
+ """ Method not supported for this class.
550
542
 
551
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
552
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
553
- """
543
+
544
+ Raises:
545
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
546
+
547
+ Args:
548
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
549
+ Snowpark or Pandas DataFrame.
550
+ output_cols_prefix: Prefix for the response columns
554
551
  Returns:
555
552
  Transformed dataset.
556
553
  """
557
- self.fit(dataset)
558
- assert self._sklearn_object is not None
559
- return self._sklearn_object.embedding_
554
+ self._infer_input_output_cols(dataset)
555
+ super()._check_dataset_type(dataset)
556
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
557
+ estimator=self._sklearn_object,
558
+ dataset=dataset,
559
+ input_cols=self.input_cols,
560
+ label_cols=self.label_cols,
561
+ sample_weight_col=self.sample_weight_col,
562
+ autogenerated=self._autogenerated,
563
+ subproject=_SUBPROJECT,
564
+ )
565
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
566
+ drop_input_cols=self._drop_input_cols,
567
+ expected_output_cols_list=self.output_cols,
568
+ )
569
+ self._sklearn_object = fitted_estimator
570
+ self._is_fitted = True
571
+ return output_result
560
572
 
561
573
 
562
574
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -647,10 +659,8 @@ class MultiTaskElasticNet(BaseTransformer):
647
659
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
648
660
 
649
661
  if isinstance(dataset, DataFrame):
650
- self._deps = self._batch_inference_validate_snowpark(
651
- dataset=dataset,
652
- inference_method=inference_method,
653
- )
662
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
663
+ self._deps = self._get_dependencies()
654
664
  assert isinstance(
655
665
  dataset._session, Session
656
666
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -715,10 +725,8 @@ class MultiTaskElasticNet(BaseTransformer):
715
725
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
716
726
 
717
727
  if isinstance(dataset, DataFrame):
718
- self._deps = self._batch_inference_validate_snowpark(
719
- dataset=dataset,
720
- inference_method=inference_method,
721
- )
728
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
729
+ self._deps = self._get_dependencies()
722
730
  assert isinstance(
723
731
  dataset._session, Session
724
732
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -780,10 +788,8 @@ class MultiTaskElasticNet(BaseTransformer):
780
788
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
781
789
 
782
790
  if isinstance(dataset, DataFrame):
783
- self._deps = self._batch_inference_validate_snowpark(
784
- dataset=dataset,
785
- inference_method=inference_method,
786
- )
791
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
792
+ self._deps = self._get_dependencies()
787
793
  assert isinstance(
788
794
  dataset._session, Session
789
795
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -849,10 +855,8 @@ class MultiTaskElasticNet(BaseTransformer):
849
855
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
850
856
 
851
857
  if isinstance(dataset, DataFrame):
852
- self._deps = self._batch_inference_validate_snowpark(
853
- dataset=dataset,
854
- inference_method=inference_method,
855
- )
858
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
859
+ self._deps = self._get_dependencies()
856
860
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
857
861
  transform_kwargs = dict(
858
862
  session=dataset._session,
@@ -916,17 +920,15 @@ class MultiTaskElasticNet(BaseTransformer):
916
920
  transform_kwargs: ScoreKwargsTypedDict = dict()
917
921
 
918
922
  if isinstance(dataset, DataFrame):
919
- self._deps = self._batch_inference_validate_snowpark(
920
- dataset=dataset,
921
- inference_method="score",
922
- )
923
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
924
+ self._deps = self._get_dependencies()
923
925
  selected_cols = self._get_active_columns()
924
926
  if len(selected_cols) > 0:
925
927
  dataset = dataset.select(selected_cols)
926
928
  assert isinstance(dataset._session, Session) # keep mypy happy
927
929
  transform_kwargs = dict(
928
930
  session=dataset._session,
929
- dependencies=["snowflake-snowpark-python"] + self._deps,
931
+ dependencies=self._deps,
930
932
  score_sproc_imports=['sklearn'],
931
933
  )
932
934
  elif isinstance(dataset, pd.DataFrame):
@@ -991,11 +993,8 @@ class MultiTaskElasticNet(BaseTransformer):
991
993
 
992
994
  if isinstance(dataset, DataFrame):
993
995
 
994
- self._deps = self._batch_inference_validate_snowpark(
995
- dataset=dataset,
996
- inference_method=inference_method,
997
-
998
- )
996
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
997
+ self._deps = self._get_dependencies()
999
998
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1000
999
  transform_kwargs = dict(
1001
1000
  session = dataset._session,