snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218)
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class FactorAnalysis(BaseTransformer):
70
64
  r"""Factor Analysis (FA)
71
65
  For more details on this class, see [sklearn.decomposition.FactorAnalysis]
@@ -312,20 +306,17 @@ class FactorAnalysis(BaseTransformer):
312
306
  self,
313
307
  dataset: DataFrame,
314
308
  inference_method: str,
315
- ) -> List[str]:
316
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
317
- return the available package that exists in the snowflake anaconda channel
309
+ ) -> None:
310
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
318
311
 
319
312
  Args:
320
313
  dataset: snowpark dataframe
321
314
  inference_method: the inference method such as predict, score...
322
-
315
+
323
316
  Raises:
324
317
  SnowflakeMLException: If the estimator is not fitted, raise error
325
318
  SnowflakeMLException: If the session is None, raise error
326
319
 
327
- Returns:
328
- A list of available package that exists in the snowflake anaconda channel
329
320
  """
330
321
  if not self._is_fitted:
331
322
  raise exceptions.SnowflakeMLException(
@@ -343,9 +334,7 @@ class FactorAnalysis(BaseTransformer):
343
334
  "Session must not specified for snowpark dataset."
344
335
  ),
345
336
  )
346
- # Validate that key package version in user workspace are supported in snowflake conda channel
347
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
348
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
337
+
349
338
 
350
339
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
351
340
  @telemetry.send_api_usage_telemetry(
@@ -391,7 +380,8 @@ class FactorAnalysis(BaseTransformer):
391
380
 
392
381
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
393
382
 
394
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
383
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
384
+ self._deps = self._get_dependencies()
395
385
  assert isinstance(
396
386
  dataset._session, Session
397
387
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -476,10 +466,8 @@ class FactorAnalysis(BaseTransformer):
476
466
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
477
467
  expected_dtype = convert_sp_to_sf_type(output_types[0])
478
468
 
479
- self._deps = self._batch_inference_validate_snowpark(
480
- dataset=dataset,
481
- inference_method=inference_method,
482
- )
469
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
470
+ self._deps = self._get_dependencies()
483
471
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
484
472
 
485
473
  transform_kwargs = dict(
@@ -546,16 +534,42 @@ class FactorAnalysis(BaseTransformer):
546
534
  self._is_fitted = True
547
535
  return output_result
548
536
 
537
+
538
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
539
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
540
+ """ Fit to data, then transform it
541
+ For more details on this function, see [sklearn.decomposition.FactorAnalysis.fit_transform]
542
+ (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.FactorAnalysis.html#sklearn.decomposition.FactorAnalysis.fit_transform)
543
+
549
544
 
550
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
551
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
552
- """
545
+ Raises:
546
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
547
+
548
+ Args:
549
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
550
+ Snowpark or Pandas DataFrame.
551
+ output_cols_prefix: Prefix for the response columns
553
552
  Returns:
554
553
  Transformed dataset.
555
554
  """
556
- self.fit(dataset)
557
- assert self._sklearn_object is not None
558
- return self._sklearn_object.embedding_
555
+ self._infer_input_output_cols(dataset)
556
+ super()._check_dataset_type(dataset)
557
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
558
+ estimator=self._sklearn_object,
559
+ dataset=dataset,
560
+ input_cols=self.input_cols,
561
+ label_cols=self.label_cols,
562
+ sample_weight_col=self.sample_weight_col,
563
+ autogenerated=self._autogenerated,
564
+ subproject=_SUBPROJECT,
565
+ )
566
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
567
+ drop_input_cols=self._drop_input_cols,
568
+ expected_output_cols_list=self.output_cols,
569
+ )
570
+ self._sklearn_object = fitted_estimator
571
+ self._is_fitted = True
572
+ return output_result
559
573
 
560
574
 
561
575
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -646,10 +660,8 @@ class FactorAnalysis(BaseTransformer):
646
660
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
647
661
 
648
662
  if isinstance(dataset, DataFrame):
649
- self._deps = self._batch_inference_validate_snowpark(
650
- dataset=dataset,
651
- inference_method=inference_method,
652
- )
663
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
664
+ self._deps = self._get_dependencies()
653
665
  assert isinstance(
654
666
  dataset._session, Session
655
667
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -714,10 +726,8 @@ class FactorAnalysis(BaseTransformer):
714
726
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
715
727
 
716
728
  if isinstance(dataset, DataFrame):
717
- self._deps = self._batch_inference_validate_snowpark(
718
- dataset=dataset,
719
- inference_method=inference_method,
720
- )
729
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
730
+ self._deps = self._get_dependencies()
721
731
  assert isinstance(
722
732
  dataset._session, Session
723
733
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -779,10 +789,8 @@ class FactorAnalysis(BaseTransformer):
779
789
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
780
790
 
781
791
  if isinstance(dataset, DataFrame):
782
- self._deps = self._batch_inference_validate_snowpark(
783
- dataset=dataset,
784
- inference_method=inference_method,
785
- )
792
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
793
+ self._deps = self._get_dependencies()
786
794
  assert isinstance(
787
795
  dataset._session, Session
788
796
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -850,10 +858,8 @@ class FactorAnalysis(BaseTransformer):
850
858
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
851
859
 
852
860
  if isinstance(dataset, DataFrame):
853
- self._deps = self._batch_inference_validate_snowpark(
854
- dataset=dataset,
855
- inference_method=inference_method,
856
- )
861
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
862
+ self._deps = self._get_dependencies()
857
863
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
858
864
  transform_kwargs = dict(
859
865
  session=dataset._session,
@@ -917,17 +923,15 @@ class FactorAnalysis(BaseTransformer):
917
923
  transform_kwargs: ScoreKwargsTypedDict = dict()
918
924
 
919
925
  if isinstance(dataset, DataFrame):
920
- self._deps = self._batch_inference_validate_snowpark(
921
- dataset=dataset,
922
- inference_method="score",
923
- )
926
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
927
+ self._deps = self._get_dependencies()
924
928
  selected_cols = self._get_active_columns()
925
929
  if len(selected_cols) > 0:
926
930
  dataset = dataset.select(selected_cols)
927
931
  assert isinstance(dataset._session, Session) # keep mypy happy
928
932
  transform_kwargs = dict(
929
933
  session=dataset._session,
930
- dependencies=["snowflake-snowpark-python"] + self._deps,
934
+ dependencies=self._deps,
931
935
  score_sproc_imports=['sklearn'],
932
936
  )
933
937
  elif isinstance(dataset, pd.DataFrame):
@@ -992,11 +996,8 @@ class FactorAnalysis(BaseTransformer):
992
996
 
993
997
  if isinstance(dataset, DataFrame):
994
998
 
995
- self._deps = self._batch_inference_validate_snowpark(
996
- dataset=dataset,
997
- inference_method=inference_method,
998
-
999
- )
999
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1000
+ self._deps = self._get_dependencies()
1000
1001
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1001
1002
  transform_kwargs = dict(
1002
1003
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class FastICA(BaseTransformer):
70
64
  r"""FastICA: a fast algorithm for Independent Component Analysis
71
65
  For more details on this class, see [sklearn.decomposition.FastICA]
@@ -330,20 +324,17 @@ class FastICA(BaseTransformer):
330
324
  self,
331
325
  dataset: DataFrame,
332
326
  inference_method: str,
333
- ) -> List[str]:
334
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
335
- return the available package that exists in the snowflake anaconda channel
327
+ ) -> None:
328
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
336
329
 
337
330
  Args:
338
331
  dataset: snowpark dataframe
339
332
  inference_method: the inference method such as predict, score...
340
-
333
+
341
334
  Raises:
342
335
  SnowflakeMLException: If the estimator is not fitted, raise error
343
336
  SnowflakeMLException: If the session is None, raise error
344
337
 
345
- Returns:
346
- A list of available package that exists in the snowflake anaconda channel
347
338
  """
348
339
  if not self._is_fitted:
349
340
  raise exceptions.SnowflakeMLException(
@@ -361,9 +352,7 @@ class FastICA(BaseTransformer):
361
352
  "Session must not specified for snowpark dataset."
362
353
  ),
363
354
  )
364
- # Validate that key package version in user workspace are supported in snowflake conda channel
365
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
366
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
355
+
367
356
 
368
357
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
369
358
  @telemetry.send_api_usage_telemetry(
@@ -409,7 +398,8 @@ class FastICA(BaseTransformer):
409
398
 
410
399
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
411
400
 
412
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
401
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
402
+ self._deps = self._get_dependencies()
413
403
  assert isinstance(
414
404
  dataset._session, Session
415
405
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -494,10 +484,8 @@ class FastICA(BaseTransformer):
494
484
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
495
485
  expected_dtype = convert_sp_to_sf_type(output_types[0])
496
486
 
497
- self._deps = self._batch_inference_validate_snowpark(
498
- dataset=dataset,
499
- inference_method=inference_method,
500
- )
487
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
488
+ self._deps = self._get_dependencies()
501
489
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
502
490
 
503
491
  transform_kwargs = dict(
@@ -564,16 +552,42 @@ class FastICA(BaseTransformer):
564
552
  self._is_fitted = True
565
553
  return output_result
566
554
 
555
+
556
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
557
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
558
+ """ Fit the model and recover the sources from X
559
+ For more details on this function, see [sklearn.decomposition.FastICA.fit_transform]
560
+ (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.FastICA.html#sklearn.decomposition.FastICA.fit_transform)
561
+
567
562
 
568
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
569
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
570
- """
563
+ Raises:
564
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
565
+
566
+ Args:
567
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
568
+ Snowpark or Pandas DataFrame.
569
+ output_cols_prefix: Prefix for the response columns
571
570
  Returns:
572
571
  Transformed dataset.
573
572
  """
574
- self.fit(dataset)
575
- assert self._sklearn_object is not None
576
- return self._sklearn_object.embedding_
573
+ self._infer_input_output_cols(dataset)
574
+ super()._check_dataset_type(dataset)
575
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
576
+ estimator=self._sklearn_object,
577
+ dataset=dataset,
578
+ input_cols=self.input_cols,
579
+ label_cols=self.label_cols,
580
+ sample_weight_col=self.sample_weight_col,
581
+ autogenerated=self._autogenerated,
582
+ subproject=_SUBPROJECT,
583
+ )
584
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
585
+ drop_input_cols=self._drop_input_cols,
586
+ expected_output_cols_list=self.output_cols,
587
+ )
588
+ self._sklearn_object = fitted_estimator
589
+ self._is_fitted = True
590
+ return output_result
577
591
 
578
592
 
579
593
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -664,10 +678,8 @@ class FastICA(BaseTransformer):
664
678
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
665
679
 
666
680
  if isinstance(dataset, DataFrame):
667
- self._deps = self._batch_inference_validate_snowpark(
668
- dataset=dataset,
669
- inference_method=inference_method,
670
- )
681
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
682
+ self._deps = self._get_dependencies()
671
683
  assert isinstance(
672
684
  dataset._session, Session
673
685
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -732,10 +744,8 @@ class FastICA(BaseTransformer):
732
744
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
733
745
 
734
746
  if isinstance(dataset, DataFrame):
735
- self._deps = self._batch_inference_validate_snowpark(
736
- dataset=dataset,
737
- inference_method=inference_method,
738
- )
747
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
748
+ self._deps = self._get_dependencies()
739
749
  assert isinstance(
740
750
  dataset._session, Session
741
751
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -797,10 +807,8 @@ class FastICA(BaseTransformer):
797
807
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
798
808
 
799
809
  if isinstance(dataset, DataFrame):
800
- self._deps = self._batch_inference_validate_snowpark(
801
- dataset=dataset,
802
- inference_method=inference_method,
803
- )
810
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
811
+ self._deps = self._get_dependencies()
804
812
  assert isinstance(
805
813
  dataset._session, Session
806
814
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -866,10 +874,8 @@ class FastICA(BaseTransformer):
866
874
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
867
875
 
868
876
  if isinstance(dataset, DataFrame):
869
- self._deps = self._batch_inference_validate_snowpark(
870
- dataset=dataset,
871
- inference_method=inference_method,
872
- )
877
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
878
+ self._deps = self._get_dependencies()
873
879
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
874
880
  transform_kwargs = dict(
875
881
  session=dataset._session,
@@ -931,17 +937,15 @@ class FastICA(BaseTransformer):
931
937
  transform_kwargs: ScoreKwargsTypedDict = dict()
932
938
 
933
939
  if isinstance(dataset, DataFrame):
934
- self._deps = self._batch_inference_validate_snowpark(
935
- dataset=dataset,
936
- inference_method="score",
937
- )
940
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
941
+ self._deps = self._get_dependencies()
938
942
  selected_cols = self._get_active_columns()
939
943
  if len(selected_cols) > 0:
940
944
  dataset = dataset.select(selected_cols)
941
945
  assert isinstance(dataset._session, Session) # keep mypy happy
942
946
  transform_kwargs = dict(
943
947
  session=dataset._session,
944
- dependencies=["snowflake-snowpark-python"] + self._deps,
948
+ dependencies=self._deps,
945
949
  score_sproc_imports=['sklearn'],
946
950
  )
947
951
  elif isinstance(dataset, pd.DataFrame):
@@ -1006,11 +1010,8 @@ class FastICA(BaseTransformer):
1006
1010
 
1007
1011
  if isinstance(dataset, DataFrame):
1008
1012
 
1009
- self._deps = self._batch_inference_validate_snowpark(
1010
- dataset=dataset,
1011
- inference_method=inference_method,
1012
-
1013
- )
1013
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1014
+ self._deps = self._get_dependencies()
1014
1015
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1015
1016
  transform_kwargs = dict(
1016
1017
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class IncrementalPCA(BaseTransformer):
70
64
  r"""Incremental principal components analysis (IPCA)
71
65
  For more details on this class, see [sklearn.decomposition.IncrementalPCA]
@@ -282,20 +276,17 @@ class IncrementalPCA(BaseTransformer):
282
276
  self,
283
277
  dataset: DataFrame,
284
278
  inference_method: str,
285
- ) -> List[str]:
286
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
287
- return the available package that exists in the snowflake anaconda channel
279
+ ) -> None:
280
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
288
281
 
289
282
  Args:
290
283
  dataset: snowpark dataframe
291
284
  inference_method: the inference method such as predict, score...
292
-
285
+
293
286
  Raises:
294
287
  SnowflakeMLException: If the estimator is not fitted, raise error
295
288
  SnowflakeMLException: If the session is None, raise error
296
289
 
297
- Returns:
298
- A list of available package that exists in the snowflake anaconda channel
299
290
  """
300
291
  if not self._is_fitted:
301
292
  raise exceptions.SnowflakeMLException(
@@ -313,9 +304,7 @@ class IncrementalPCA(BaseTransformer):
313
304
  "Session must not specified for snowpark dataset."
314
305
  ),
315
306
  )
316
- # Validate that key package version in user workspace are supported in snowflake conda channel
317
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
318
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
307
+
319
308
 
320
309
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
321
310
  @telemetry.send_api_usage_telemetry(
@@ -361,7 +350,8 @@ class IncrementalPCA(BaseTransformer):
361
350
 
362
351
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
363
352
 
364
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
353
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
354
+ self._deps = self._get_dependencies()
365
355
  assert isinstance(
366
356
  dataset._session, Session
367
357
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -446,10 +436,8 @@ class IncrementalPCA(BaseTransformer):
446
436
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
447
437
  expected_dtype = convert_sp_to_sf_type(output_types[0])
448
438
 
449
- self._deps = self._batch_inference_validate_snowpark(
450
- dataset=dataset,
451
- inference_method=inference_method,
452
- )
439
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
440
+ self._deps = self._get_dependencies()
453
441
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
454
442
 
455
443
  transform_kwargs = dict(
@@ -516,16 +504,42 @@ class IncrementalPCA(BaseTransformer):
516
504
  self._is_fitted = True
517
505
  return output_result
518
506
 
507
+
508
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
509
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
510
+ """ Fit to data, then transform it
511
+ For more details on this function, see [sklearn.decomposition.IncrementalPCA.fit_transform]
512
+ (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.IncrementalPCA.html#sklearn.decomposition.IncrementalPCA.fit_transform)
513
+
519
514
 
520
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
521
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
522
- """
515
+ Raises:
516
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
517
+
518
+ Args:
519
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
520
+ Snowpark or Pandas DataFrame.
521
+ output_cols_prefix: Prefix for the response columns
523
522
  Returns:
524
523
  Transformed dataset.
525
524
  """
526
- self.fit(dataset)
527
- assert self._sklearn_object is not None
528
- return self._sklearn_object.embedding_
525
+ self._infer_input_output_cols(dataset)
526
+ super()._check_dataset_type(dataset)
527
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
528
+ estimator=self._sklearn_object,
529
+ dataset=dataset,
530
+ input_cols=self.input_cols,
531
+ label_cols=self.label_cols,
532
+ sample_weight_col=self.sample_weight_col,
533
+ autogenerated=self._autogenerated,
534
+ subproject=_SUBPROJECT,
535
+ )
536
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
537
+ drop_input_cols=self._drop_input_cols,
538
+ expected_output_cols_list=self.output_cols,
539
+ )
540
+ self._sklearn_object = fitted_estimator
541
+ self._is_fitted = True
542
+ return output_result
529
543
 
530
544
 
531
545
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -616,10 +630,8 @@ class IncrementalPCA(BaseTransformer):
616
630
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
617
631
 
618
632
  if isinstance(dataset, DataFrame):
619
- self._deps = self._batch_inference_validate_snowpark(
620
- dataset=dataset,
621
- inference_method=inference_method,
622
- )
633
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
634
+ self._deps = self._get_dependencies()
623
635
  assert isinstance(
624
636
  dataset._session, Session
625
637
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -684,10 +696,8 @@ class IncrementalPCA(BaseTransformer):
684
696
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
685
697
 
686
698
  if isinstance(dataset, DataFrame):
687
- self._deps = self._batch_inference_validate_snowpark(
688
- dataset=dataset,
689
- inference_method=inference_method,
690
- )
699
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
700
+ self._deps = self._get_dependencies()
691
701
  assert isinstance(
692
702
  dataset._session, Session
693
703
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -749,10 +759,8 @@ class IncrementalPCA(BaseTransformer):
749
759
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
750
760
 
751
761
  if isinstance(dataset, DataFrame):
752
- self._deps = self._batch_inference_validate_snowpark(
753
- dataset=dataset,
754
- inference_method=inference_method,
755
- )
762
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
763
+ self._deps = self._get_dependencies()
756
764
  assert isinstance(
757
765
  dataset._session, Session
758
766
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -818,10 +826,8 @@ class IncrementalPCA(BaseTransformer):
818
826
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
819
827
 
820
828
  if isinstance(dataset, DataFrame):
821
- self._deps = self._batch_inference_validate_snowpark(
822
- dataset=dataset,
823
- inference_method=inference_method,
824
- )
829
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
830
+ self._deps = self._get_dependencies()
825
831
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
826
832
  transform_kwargs = dict(
827
833
  session=dataset._session,
@@ -883,17 +889,15 @@ class IncrementalPCA(BaseTransformer):
883
889
  transform_kwargs: ScoreKwargsTypedDict = dict()
884
890
 
885
891
  if isinstance(dataset, DataFrame):
886
- self._deps = self._batch_inference_validate_snowpark(
887
- dataset=dataset,
888
- inference_method="score",
889
- )
892
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
893
+ self._deps = self._get_dependencies()
890
894
  selected_cols = self._get_active_columns()
891
895
  if len(selected_cols) > 0:
892
896
  dataset = dataset.select(selected_cols)
893
897
  assert isinstance(dataset._session, Session) # keep mypy happy
894
898
  transform_kwargs = dict(
895
899
  session=dataset._session,
896
- dependencies=["snowflake-snowpark-python"] + self._deps,
900
+ dependencies=self._deps,
897
901
  score_sproc_imports=['sklearn'],
898
902
  )
899
903
  elif isinstance(dataset, pd.DataFrame):
@@ -958,11 +962,8 @@ class IncrementalPCA(BaseTransformer):
958
962
 
959
963
  if isinstance(dataset, DataFrame):
960
964
 
961
- self._deps = self._batch_inference_validate_snowpark(
962
- dataset=dataset,
963
- inference_method=inference_method,
964
-
965
- )
965
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
966
+ self._deps = self._get_dependencies()
966
967
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
967
968
  transform_kwargs = dict(
968
969
  session = dataset._session,