snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class KernelPCA(BaseTransformer):
70
64
  r"""Kernel Principal component analysis (KPCA) [1]_
71
65
  For more details on this class, see [sklearn.decomposition.KernelPCA]
@@ -378,20 +372,17 @@ class KernelPCA(BaseTransformer):
378
372
  self,
379
373
  dataset: DataFrame,
380
374
  inference_method: str,
381
- ) -> List[str]:
382
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
383
- return the available package that exists in the snowflake anaconda channel
375
+ ) -> None:
376
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
384
377
 
385
378
  Args:
386
379
  dataset: snowpark dataframe
387
380
  inference_method: the inference method such as predict, score...
388
-
381
+
389
382
  Raises:
390
383
  SnowflakeMLException: If the estimator is not fitted, raise error
391
384
  SnowflakeMLException: If the session is None, raise error
392
385
 
393
- Returns:
394
- A list of available package that exists in the snowflake anaconda channel
395
386
  """
396
387
  if not self._is_fitted:
397
388
  raise exceptions.SnowflakeMLException(
@@ -409,9 +400,7 @@ class KernelPCA(BaseTransformer):
409
400
  "Session must not specified for snowpark dataset."
410
401
  ),
411
402
  )
412
- # Validate that key package version in user workspace are supported in snowflake conda channel
413
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
414
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
403
+
415
404
 
416
405
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
417
406
  @telemetry.send_api_usage_telemetry(
@@ -457,7 +446,8 @@ class KernelPCA(BaseTransformer):
457
446
 
458
447
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
459
448
 
460
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
449
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
450
+ self._deps = self._get_dependencies()
461
451
  assert isinstance(
462
452
  dataset._session, Session
463
453
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -542,10 +532,8 @@ class KernelPCA(BaseTransformer):
542
532
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
543
533
  expected_dtype = convert_sp_to_sf_type(output_types[0])
544
534
 
545
- self._deps = self._batch_inference_validate_snowpark(
546
- dataset=dataset,
547
- inference_method=inference_method,
548
- )
535
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
536
+ self._deps = self._get_dependencies()
549
537
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
550
538
 
551
539
  transform_kwargs = dict(
@@ -612,16 +600,42 @@ class KernelPCA(BaseTransformer):
612
600
  self._is_fitted = True
613
601
  return output_result
614
602
 
603
+
604
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
605
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
606
+ """ Fit the model from data in X and transform X
607
+ For more details on this function, see [sklearn.decomposition.KernelPCA.fit_transform]
608
+ (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.KernelPCA.html#sklearn.decomposition.KernelPCA.fit_transform)
609
+
615
610
 
616
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
617
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
618
- """
611
+ Raises:
612
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
613
+
614
+ Args:
615
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
616
+ Snowpark or Pandas DataFrame.
617
+ output_cols_prefix: Prefix for the response columns
619
618
  Returns:
620
619
  Transformed dataset.
621
620
  """
622
- self.fit(dataset)
623
- assert self._sklearn_object is not None
624
- return self._sklearn_object.embedding_
621
+ self._infer_input_output_cols(dataset)
622
+ super()._check_dataset_type(dataset)
623
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
624
+ estimator=self._sklearn_object,
625
+ dataset=dataset,
626
+ input_cols=self.input_cols,
627
+ label_cols=self.label_cols,
628
+ sample_weight_col=self.sample_weight_col,
629
+ autogenerated=self._autogenerated,
630
+ subproject=_SUBPROJECT,
631
+ )
632
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
633
+ drop_input_cols=self._drop_input_cols,
634
+ expected_output_cols_list=self.output_cols,
635
+ )
636
+ self._sklearn_object = fitted_estimator
637
+ self._is_fitted = True
638
+ return output_result
625
639
 
626
640
 
627
641
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -712,10 +726,8 @@ class KernelPCA(BaseTransformer):
712
726
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
713
727
 
714
728
  if isinstance(dataset, DataFrame):
715
- self._deps = self._batch_inference_validate_snowpark(
716
- dataset=dataset,
717
- inference_method=inference_method,
718
- )
729
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
730
+ self._deps = self._get_dependencies()
719
731
  assert isinstance(
720
732
  dataset._session, Session
721
733
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -780,10 +792,8 @@ class KernelPCA(BaseTransformer):
780
792
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
781
793
 
782
794
  if isinstance(dataset, DataFrame):
783
- self._deps = self._batch_inference_validate_snowpark(
784
- dataset=dataset,
785
- inference_method=inference_method,
786
- )
795
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
796
+ self._deps = self._get_dependencies()
787
797
  assert isinstance(
788
798
  dataset._session, Session
789
799
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -845,10 +855,8 @@ class KernelPCA(BaseTransformer):
845
855
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
846
856
 
847
857
  if isinstance(dataset, DataFrame):
848
- self._deps = self._batch_inference_validate_snowpark(
849
- dataset=dataset,
850
- inference_method=inference_method,
851
- )
858
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
859
+ self._deps = self._get_dependencies()
852
860
  assert isinstance(
853
861
  dataset._session, Session
854
862
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -914,10 +922,8 @@ class KernelPCA(BaseTransformer):
914
922
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
915
923
 
916
924
  if isinstance(dataset, DataFrame):
917
- self._deps = self._batch_inference_validate_snowpark(
918
- dataset=dataset,
919
- inference_method=inference_method,
920
- )
925
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
926
+ self._deps = self._get_dependencies()
921
927
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
922
928
  transform_kwargs = dict(
923
929
  session=dataset._session,
@@ -979,17 +985,15 @@ class KernelPCA(BaseTransformer):
979
985
  transform_kwargs: ScoreKwargsTypedDict = dict()
980
986
 
981
987
  if isinstance(dataset, DataFrame):
982
- self._deps = self._batch_inference_validate_snowpark(
983
- dataset=dataset,
984
- inference_method="score",
985
- )
988
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
989
+ self._deps = self._get_dependencies()
986
990
  selected_cols = self._get_active_columns()
987
991
  if len(selected_cols) > 0:
988
992
  dataset = dataset.select(selected_cols)
989
993
  assert isinstance(dataset._session, Session) # keep mypy happy
990
994
  transform_kwargs = dict(
991
995
  session=dataset._session,
992
- dependencies=["snowflake-snowpark-python"] + self._deps,
996
+ dependencies=self._deps,
993
997
  score_sproc_imports=['sklearn'],
994
998
  )
995
999
  elif isinstance(dataset, pd.DataFrame):
@@ -1054,11 +1058,8 @@ class KernelPCA(BaseTransformer):
1054
1058
 
1055
1059
  if isinstance(dataset, DataFrame):
1056
1060
 
1057
- self._deps = self._batch_inference_validate_snowpark(
1058
- dataset=dataset,
1059
- inference_method=inference_method,
1060
-
1061
- )
1061
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1062
+ self._deps = self._get_dependencies()
1062
1063
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1063
1064
  transform_kwargs = dict(
1064
1065
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class MiniBatchDictionaryLearning(BaseTransformer):
70
64
  r"""Mini-batch dictionary learning
71
65
  For more details on this class, see [sklearn.decomposition.MiniBatchDictionaryLearning]
@@ -400,20 +394,17 @@ class MiniBatchDictionaryLearning(BaseTransformer):
400
394
  self,
401
395
  dataset: DataFrame,
402
396
  inference_method: str,
403
- ) -> List[str]:
404
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
405
- return the available package that exists in the snowflake anaconda channel
397
+ ) -> None:
398
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
406
399
 
407
400
  Args:
408
401
  dataset: snowpark dataframe
409
402
  inference_method: the inference method such as predict, score...
410
-
403
+
411
404
  Raises:
412
405
  SnowflakeMLException: If the estimator is not fitted, raise error
413
406
  SnowflakeMLException: If the session is None, raise error
414
407
 
415
- Returns:
416
- A list of available package that exists in the snowflake anaconda channel
417
408
  """
418
409
  if not self._is_fitted:
419
410
  raise exceptions.SnowflakeMLException(
@@ -431,9 +422,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
431
422
  "Session must not specified for snowpark dataset."
432
423
  ),
433
424
  )
434
- # Validate that key package version in user workspace are supported in snowflake conda channel
435
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
436
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
425
+
437
426
 
438
427
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
439
428
  @telemetry.send_api_usage_telemetry(
@@ -479,7 +468,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
479
468
 
480
469
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
481
470
 
482
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
471
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
472
+ self._deps = self._get_dependencies()
483
473
  assert isinstance(
484
474
  dataset._session, Session
485
475
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -564,10 +554,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
564
554
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
565
555
  expected_dtype = convert_sp_to_sf_type(output_types[0])
566
556
 
567
- self._deps = self._batch_inference_validate_snowpark(
568
- dataset=dataset,
569
- inference_method=inference_method,
570
- )
557
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
558
+ self._deps = self._get_dependencies()
571
559
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
572
560
 
573
561
  transform_kwargs = dict(
@@ -634,16 +622,42 @@ class MiniBatchDictionaryLearning(BaseTransformer):
634
622
  self._is_fitted = True
635
623
  return output_result
636
624
 
625
+
626
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
627
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
628
+ """ Fit to data, then transform it
629
+ For more details on this function, see [sklearn.decomposition.MiniBatchDictionaryLearning.fit_transform]
630
+ (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.MiniBatchDictionaryLearning.html#sklearn.decomposition.MiniBatchDictionaryLearning.fit_transform)
631
+
637
632
 
638
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
639
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
640
- """
633
+ Raises:
634
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
635
+
636
+ Args:
637
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
638
+ Snowpark or Pandas DataFrame.
639
+ output_cols_prefix: Prefix for the response columns
641
640
  Returns:
642
641
  Transformed dataset.
643
642
  """
644
- self.fit(dataset)
645
- assert self._sklearn_object is not None
646
- return self._sklearn_object.embedding_
643
+ self._infer_input_output_cols(dataset)
644
+ super()._check_dataset_type(dataset)
645
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
646
+ estimator=self._sklearn_object,
647
+ dataset=dataset,
648
+ input_cols=self.input_cols,
649
+ label_cols=self.label_cols,
650
+ sample_weight_col=self.sample_weight_col,
651
+ autogenerated=self._autogenerated,
652
+ subproject=_SUBPROJECT,
653
+ )
654
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
655
+ drop_input_cols=self._drop_input_cols,
656
+ expected_output_cols_list=self.output_cols,
657
+ )
658
+ self._sklearn_object = fitted_estimator
659
+ self._is_fitted = True
660
+ return output_result
647
661
 
648
662
 
649
663
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -734,10 +748,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
734
748
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
735
749
 
736
750
  if isinstance(dataset, DataFrame):
737
- self._deps = self._batch_inference_validate_snowpark(
738
- dataset=dataset,
739
- inference_method=inference_method,
740
- )
751
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
752
+ self._deps = self._get_dependencies()
741
753
  assert isinstance(
742
754
  dataset._session, Session
743
755
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -802,10 +814,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
802
814
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
803
815
 
804
816
  if isinstance(dataset, DataFrame):
805
- self._deps = self._batch_inference_validate_snowpark(
806
- dataset=dataset,
807
- inference_method=inference_method,
808
- )
817
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
818
+ self._deps = self._get_dependencies()
809
819
  assert isinstance(
810
820
  dataset._session, Session
811
821
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -867,10 +877,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
867
877
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
868
878
 
869
879
  if isinstance(dataset, DataFrame):
870
- self._deps = self._batch_inference_validate_snowpark(
871
- dataset=dataset,
872
- inference_method=inference_method,
873
- )
880
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
881
+ self._deps = self._get_dependencies()
874
882
  assert isinstance(
875
883
  dataset._session, Session
876
884
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -936,10 +944,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
936
944
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
937
945
 
938
946
  if isinstance(dataset, DataFrame):
939
- self._deps = self._batch_inference_validate_snowpark(
940
- dataset=dataset,
941
- inference_method=inference_method,
942
- )
947
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
948
+ self._deps = self._get_dependencies()
943
949
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
944
950
  transform_kwargs = dict(
945
951
  session=dataset._session,
@@ -1001,17 +1007,15 @@ class MiniBatchDictionaryLearning(BaseTransformer):
1001
1007
  transform_kwargs: ScoreKwargsTypedDict = dict()
1002
1008
 
1003
1009
  if isinstance(dataset, DataFrame):
1004
- self._deps = self._batch_inference_validate_snowpark(
1005
- dataset=dataset,
1006
- inference_method="score",
1007
- )
1010
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1011
+ self._deps = self._get_dependencies()
1008
1012
  selected_cols = self._get_active_columns()
1009
1013
  if len(selected_cols) > 0:
1010
1014
  dataset = dataset.select(selected_cols)
1011
1015
  assert isinstance(dataset._session, Session) # keep mypy happy
1012
1016
  transform_kwargs = dict(
1013
1017
  session=dataset._session,
1014
- dependencies=["snowflake-snowpark-python"] + self._deps,
1018
+ dependencies=self._deps,
1015
1019
  score_sproc_imports=['sklearn'],
1016
1020
  )
1017
1021
  elif isinstance(dataset, pd.DataFrame):
@@ -1076,11 +1080,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
1076
1080
 
1077
1081
  if isinstance(dataset, DataFrame):
1078
1082
 
1079
- self._deps = self._batch_inference_validate_snowpark(
1080
- dataset=dataset,
1081
- inference_method=inference_method,
1082
-
1083
- )
1083
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1084
+ self._deps = self._get_dependencies()
1084
1085
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1085
1086
  transform_kwargs = dict(
1086
1087
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class MiniBatchSparsePCA(BaseTransformer):
70
64
  r"""Mini-batch Sparse Principal Components Analysis
71
65
  For more details on this class, see [sklearn.decomposition.MiniBatchSparsePCA]
@@ -345,20 +339,17 @@ class MiniBatchSparsePCA(BaseTransformer):
345
339
  self,
346
340
  dataset: DataFrame,
347
341
  inference_method: str,
348
- ) -> List[str]:
349
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
350
- return the available package that exists in the snowflake anaconda channel
342
+ ) -> None:
343
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
351
344
 
352
345
  Args:
353
346
  dataset: snowpark dataframe
354
347
  inference_method: the inference method such as predict, score...
355
-
348
+
356
349
  Raises:
357
350
  SnowflakeMLException: If the estimator is not fitted, raise error
358
351
  SnowflakeMLException: If the session is None, raise error
359
352
 
360
- Returns:
361
- A list of available package that exists in the snowflake anaconda channel
362
353
  """
363
354
  if not self._is_fitted:
364
355
  raise exceptions.SnowflakeMLException(
@@ -376,9 +367,7 @@ class MiniBatchSparsePCA(BaseTransformer):
376
367
  "Session must not specified for snowpark dataset."
377
368
  ),
378
369
  )
379
- # Validate that key package version in user workspace are supported in snowflake conda channel
380
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
381
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
370
+
382
371
 
383
372
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
384
373
  @telemetry.send_api_usage_telemetry(
@@ -424,7 +413,8 @@ class MiniBatchSparsePCA(BaseTransformer):
424
413
 
425
414
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
426
415
 
427
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
416
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
417
+ self._deps = self._get_dependencies()
428
418
  assert isinstance(
429
419
  dataset._session, Session
430
420
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -509,10 +499,8 @@ class MiniBatchSparsePCA(BaseTransformer):
509
499
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
510
500
  expected_dtype = convert_sp_to_sf_type(output_types[0])
511
501
 
512
- self._deps = self._batch_inference_validate_snowpark(
513
- dataset=dataset,
514
- inference_method=inference_method,
515
- )
502
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
503
+ self._deps = self._get_dependencies()
516
504
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
517
505
 
518
506
  transform_kwargs = dict(
@@ -579,16 +567,42 @@ class MiniBatchSparsePCA(BaseTransformer):
579
567
  self._is_fitted = True
580
568
  return output_result
581
569
 
570
+
571
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
572
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
573
+ """ Fit to data, then transform it
574
+ For more details on this function, see [sklearn.decomposition.MiniBatchSparsePCA.fit_transform]
575
+ (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.MiniBatchSparsePCA.html#sklearn.decomposition.MiniBatchSparsePCA.fit_transform)
576
+
582
577
 
583
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
584
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
585
- """
578
+ Raises:
579
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
580
+
581
+ Args:
582
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
583
+ Snowpark or Pandas DataFrame.
584
+ output_cols_prefix: Prefix for the response columns
586
585
  Returns:
587
586
  Transformed dataset.
588
587
  """
589
- self.fit(dataset)
590
- assert self._sklearn_object is not None
591
- return self._sklearn_object.embedding_
588
+ self._infer_input_output_cols(dataset)
589
+ super()._check_dataset_type(dataset)
590
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
591
+ estimator=self._sklearn_object,
592
+ dataset=dataset,
593
+ input_cols=self.input_cols,
594
+ label_cols=self.label_cols,
595
+ sample_weight_col=self.sample_weight_col,
596
+ autogenerated=self._autogenerated,
597
+ subproject=_SUBPROJECT,
598
+ )
599
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
600
+ drop_input_cols=self._drop_input_cols,
601
+ expected_output_cols_list=self.output_cols,
602
+ )
603
+ self._sklearn_object = fitted_estimator
604
+ self._is_fitted = True
605
+ return output_result
592
606
 
593
607
 
594
608
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -679,10 +693,8 @@ class MiniBatchSparsePCA(BaseTransformer):
679
693
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
680
694
 
681
695
  if isinstance(dataset, DataFrame):
682
- self._deps = self._batch_inference_validate_snowpark(
683
- dataset=dataset,
684
- inference_method=inference_method,
685
- )
696
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
697
+ self._deps = self._get_dependencies()
686
698
  assert isinstance(
687
699
  dataset._session, Session
688
700
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -747,10 +759,8 @@ class MiniBatchSparsePCA(BaseTransformer):
747
759
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
748
760
 
749
761
  if isinstance(dataset, DataFrame):
750
- self._deps = self._batch_inference_validate_snowpark(
751
- dataset=dataset,
752
- inference_method=inference_method,
753
- )
762
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
763
+ self._deps = self._get_dependencies()
754
764
  assert isinstance(
755
765
  dataset._session, Session
756
766
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -812,10 +822,8 @@ class MiniBatchSparsePCA(BaseTransformer):
812
822
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
813
823
 
814
824
  if isinstance(dataset, DataFrame):
815
- self._deps = self._batch_inference_validate_snowpark(
816
- dataset=dataset,
817
- inference_method=inference_method,
818
- )
825
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
826
+ self._deps = self._get_dependencies()
819
827
  assert isinstance(
820
828
  dataset._session, Session
821
829
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -881,10 +889,8 @@ class MiniBatchSparsePCA(BaseTransformer):
881
889
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
882
890
 
883
891
  if isinstance(dataset, DataFrame):
884
- self._deps = self._batch_inference_validate_snowpark(
885
- dataset=dataset,
886
- inference_method=inference_method,
887
- )
892
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
893
+ self._deps = self._get_dependencies()
888
894
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
889
895
  transform_kwargs = dict(
890
896
  session=dataset._session,
@@ -946,17 +952,15 @@ class MiniBatchSparsePCA(BaseTransformer):
946
952
  transform_kwargs: ScoreKwargsTypedDict = dict()
947
953
 
948
954
  if isinstance(dataset, DataFrame):
949
- self._deps = self._batch_inference_validate_snowpark(
950
- dataset=dataset,
951
- inference_method="score",
952
- )
955
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
956
+ self._deps = self._get_dependencies()
953
957
  selected_cols = self._get_active_columns()
954
958
  if len(selected_cols) > 0:
955
959
  dataset = dataset.select(selected_cols)
956
960
  assert isinstance(dataset._session, Session) # keep mypy happy
957
961
  transform_kwargs = dict(
958
962
  session=dataset._session,
959
- dependencies=["snowflake-snowpark-python"] + self._deps,
963
+ dependencies=self._deps,
960
964
  score_sproc_imports=['sklearn'],
961
965
  )
962
966
  elif isinstance(dataset, pd.DataFrame):
@@ -1021,11 +1025,8 @@ class MiniBatchSparsePCA(BaseTransformer):
1021
1025
 
1022
1026
  if isinstance(dataset, DataFrame):
1023
1027
 
1024
- self._deps = self._batch_inference_validate_snowpark(
1025
- dataset=dataset,
1026
- inference_method=inference_method,
1027
-
1028
- )
1028
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1029
+ self._deps = self._get_dependencies()
1029
1030
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1030
1031
  transform_kwargs = dict(
1031
1032
  session = dataset._session,