snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
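
The release itself is the one-line bump in snowflake/ml/version.py (entry 199 above). A minimal sketch for confirming which release is active after upgrading, assuming the module keeps exposing its VERSION string:

# Minimal sketch: confirm the installed release after upgrading.
# Assumes snowflake.ml.version continues to expose a VERSION string (the +1 -1 change above).
from snowflake.ml import version

print(version.VERSION)  # expected: "1.5.0"
assert version.VERSION == "1.5.0", f"unexpected snowflake-ml-python version: {version.VERSION}"
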
snowflake/ml/modeling/decomposition/pca.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("

  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
- return check
-
-
  class PCA(BaseTransformer):
  r"""Principal component analysis (PCA)
  For more details on this class, see [sklearn.decomposition.PCA]
@@ -347,20 +341,17 @@ class PCA(BaseTransformer):
  self,
  dataset: DataFrame,
  inference_method: str,
- ) -> List[str]:
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
- return the available package that exists in the snowflake anaconda channel
+ ) -> None:
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.

  Args:
  dataset: snowpark dataframe
  inference_method: the inference method such as predict, score...
-
+
  Raises:
  SnowflakeMLException: If the estimator is not fitted, raise error
  SnowflakeMLException: If the session is None, raise error

- Returns:
- A list of available package that exists in the snowflake anaconda channel
  """
  if not self._is_fitted:
  raise exceptions.SnowflakeMLException(
@@ -378,9 +369,7 @@ class PCA(BaseTransformer):
  "Session must not specified for snowpark dataset."
  ),
  )
- # Validate that key package version in user workspace are supported in snowflake conda channel
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
  @telemetry.send_api_usage_telemetry(
@@ -426,7 +415,8 @@ class PCA(BaseTransformer):

  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(
  dataset._session, Session
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -511,10 +501,8 @@ class PCA(BaseTransformer):
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
  expected_dtype = convert_sp_to_sf_type(output_types[0])

- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()

  transform_kwargs = dict(
@@ -581,16 +569,42 @@ class PCA(BaseTransformer):
  self._is_fitted = True
  return output_result

+
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+ """ Fit the model with X and apply the dimensionality reduction on X
+ For more details on this function, see [sklearn.decomposition.PCA.fit_transform]
+ (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html#sklearn.decomposition.PCA.fit_transform)
+

- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
- """
+ Raises:
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+ Args:
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+ Snowpark or Pandas DataFrame.
+ output_cols_prefix: Prefix for the response columns
  Returns:
  Transformed dataset.
  """
- self.fit(dataset)
- assert self._sklearn_object is not None
- return self._sklearn_object.embedding_
+ self._infer_input_output_cols(dataset)
+ super()._check_dataset_type(dataset)
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
+ estimator=self._sklearn_object,
+ dataset=dataset,
+ input_cols=self.input_cols,
+ label_cols=self.label_cols,
+ sample_weight_col=self.sample_weight_col,
+ autogenerated=self._autogenerated,
+ subproject=_SUBPROJECT,
+ )
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=self.output_cols,
+ )
+ self._sklearn_object = fitted_estimator
+ self._is_fitted = True
+ return output_result


  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -681,10 +695,8 @@ class PCA(BaseTransformer):
  expected_output_cols = self._get_output_column_names(output_cols_prefix)

  if isinstance(dataset, DataFrame):
- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(
  dataset._session, Session
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -749,10 +761,8 @@ class PCA(BaseTransformer):
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()

  if isinstance(dataset, DataFrame):
- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(
  dataset._session, Session
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -814,10 +824,8 @@ class PCA(BaseTransformer):
  expected_output_cols = self._get_output_column_names(output_cols_prefix)

  if isinstance(dataset, DataFrame):
- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(
  dataset._session, Session
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -885,10 +893,8 @@ class PCA(BaseTransformer):
  expected_output_cols = self._get_output_column_names(output_cols_prefix)

  if isinstance(dataset, DataFrame):
- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
  transform_kwargs = dict(
  session=dataset._session,
@@ -952,17 +958,15 @@ class PCA(BaseTransformer):
  transform_kwargs: ScoreKwargsTypedDict = dict()

  if isinstance(dataset, DataFrame):
- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method="score",
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+ self._deps = self._get_dependencies()
  selected_cols = self._get_active_columns()
  if len(selected_cols) > 0:
  dataset = dataset.select(selected_cols)
  assert isinstance(dataset._session, Session) # keep mypy happy
  transform_kwargs = dict(
  session=dataset._session,
- dependencies=["snowflake-snowpark-python"] + self._deps,
+ dependencies=self._deps,
  score_sproc_imports=['sklearn'],
  )
  elif isinstance(dataset, pd.DataFrame):
@@ -1027,11 +1031,8 @@ class PCA(BaseTransformer):

  if isinstance(dataset, DataFrame):

- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
-
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
  transform_kwargs = dict(
  session = dataset._session,
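
The PCA hunks above replace the permanently disabled _is_fit_transform_method_enabled() stub with a public fit_transform that routes through ModelTrainerBuilder.build_fit_transform and returns a DataFrame instead of a raw numpy array. A minimal usage sketch of that new entry point follows; the column names and values are illustrative and not taken from the diff:

# Sketch of the new fit_transform entry point added in the PCA diff above.
# Column names and data are made up for illustration.
import pandas as pd
from snowflake.ml.modeling.decomposition import PCA

pdf = pd.DataFrame({
    "F1": [1.0, 2.0, 3.0, 4.0],
    "F2": [2.0, 1.0, 4.0, 3.0],
    "F3": [0.5, 0.1, 0.9, 0.3],
})

pca = PCA(n_components=2, input_cols=["F1", "F2", "F3"], output_cols=["PC1", "PC2"])

# pandas in, pandas out: per the new return type in the diff, the result is a
# DataFrame of transformed columns rather than the numpy-array path that the
# removed stub always disabled.
result = pca.fit_transform(pdf)
print(result.head())
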
snowflake/ml/modeling/decomposition/sparse_pca.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("

  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
- return check
-
-
  class SparsePCA(BaseTransformer):
  r"""Sparse Principal Components Analysis (SparsePCA)
  For more details on this class, see [sklearn.decomposition.SparsePCA]
@@ -320,20 +314,17 @@ class SparsePCA(BaseTransformer):
  self,
  dataset: DataFrame,
  inference_method: str,
- ) -> List[str]:
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
- return the available package that exists in the snowflake anaconda channel
+ ) -> None:
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.

  Args:
  dataset: snowpark dataframe
  inference_method: the inference method such as predict, score...
-
+
  Raises:
  SnowflakeMLException: If the estimator is not fitted, raise error
  SnowflakeMLException: If the session is None, raise error

- Returns:
- A list of available package that exists in the snowflake anaconda channel
  """
  if not self._is_fitted:
  raise exceptions.SnowflakeMLException(
@@ -351,9 +342,7 @@ class SparsePCA(BaseTransformer):
  "Session must not specified for snowpark dataset."
  ),
  )
- # Validate that key package version in user workspace are supported in snowflake conda channel
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
  @telemetry.send_api_usage_telemetry(
@@ -399,7 +388,8 @@ class SparsePCA(BaseTransformer):

  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(
  dataset._session, Session
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -484,10 +474,8 @@ class SparsePCA(BaseTransformer):
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
  expected_dtype = convert_sp_to_sf_type(output_types[0])

- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()

  transform_kwargs = dict(
@@ -554,16 +542,42 @@ class SparsePCA(BaseTransformer):
  self._is_fitted = True
  return output_result

+
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+ """ Fit to data, then transform it
+ For more details on this function, see [sklearn.decomposition.SparsePCA.fit_transform]
+ (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.SparsePCA.html#sklearn.decomposition.SparsePCA.fit_transform)
+

- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
- """
+ Raises:
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+ Args:
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+ Snowpark or Pandas DataFrame.
+ output_cols_prefix: Prefix for the response columns
  Returns:
  Transformed dataset.
  """
- self.fit(dataset)
- assert self._sklearn_object is not None
- return self._sklearn_object.embedding_
+ self._infer_input_output_cols(dataset)
+ super()._check_dataset_type(dataset)
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
+ estimator=self._sklearn_object,
+ dataset=dataset,
+ input_cols=self.input_cols,
+ label_cols=self.label_cols,
+ sample_weight_col=self.sample_weight_col,
+ autogenerated=self._autogenerated,
+ subproject=_SUBPROJECT,
+ )
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=self.output_cols,
+ )
+ self._sklearn_object = fitted_estimator
+ self._is_fitted = True
+ return output_result


  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -654,10 +668,8 @@ class SparsePCA(BaseTransformer):
  expected_output_cols = self._get_output_column_names(output_cols_prefix)

  if isinstance(dataset, DataFrame):
- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(
  dataset._session, Session
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -722,10 +734,8 @@ class SparsePCA(BaseTransformer):
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()

  if isinstance(dataset, DataFrame):
- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(
  dataset._session, Session
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -787,10 +797,8 @@ class SparsePCA(BaseTransformer):
  expected_output_cols = self._get_output_column_names(output_cols_prefix)

  if isinstance(dataset, DataFrame):
- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(
  dataset._session, Session
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -856,10 +864,8 @@ class SparsePCA(BaseTransformer):
  expected_output_cols = self._get_output_column_names(output_cols_prefix)

  if isinstance(dataset, DataFrame):
- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
  transform_kwargs = dict(
  session=dataset._session,
@@ -921,17 +927,15 @@ class SparsePCA(BaseTransformer):
  transform_kwargs: ScoreKwargsTypedDict = dict()

  if isinstance(dataset, DataFrame):
- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method="score",
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+ self._deps = self._get_dependencies()
  selected_cols = self._get_active_columns()
  if len(selected_cols) > 0:
  dataset = dataset.select(selected_cols)
  assert isinstance(dataset._session, Session) # keep mypy happy
  transform_kwargs = dict(
  session=dataset._session,
- dependencies=["snowflake-snowpark-python"] + self._deps,
+ dependencies=self._deps,
  score_sproc_imports=['sklearn'],
  )
  elif isinstance(dataset, pd.DataFrame):
@@ -996,11 +1000,8 @@ class SparsePCA(BaseTransformer):

  if isinstance(dataset, DataFrame):

- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
-
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
  transform_kwargs = dict(
  session = dataset._session,
snowflake/ml/modeling/decomposition/truncated_svd.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("

  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
- return check
-
-
  class TruncatedSVD(BaseTransformer):
  r"""Dimensionality reduction using truncated SVD (aka LSA)
  For more details on this class, see [sklearn.decomposition.TruncatedSVD]
@@ -301,20 +295,17 @@ class TruncatedSVD(BaseTransformer):
  self,
  dataset: DataFrame,
  inference_method: str,
- ) -> List[str]:
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
- return the available package that exists in the snowflake anaconda channel
+ ) -> None:
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.

  Args:
  dataset: snowpark dataframe
  inference_method: the inference method such as predict, score...
-
+
  Raises:
  SnowflakeMLException: If the estimator is not fitted, raise error
  SnowflakeMLException: If the session is None, raise error

- Returns:
- A list of available package that exists in the snowflake anaconda channel
  """
  if not self._is_fitted:
  raise exceptions.SnowflakeMLException(
@@ -332,9 +323,7 @@ class TruncatedSVD(BaseTransformer):
  "Session must not specified for snowpark dataset."
  ),
  )
- # Validate that key package version in user workspace are supported in snowflake conda channel
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
  @telemetry.send_api_usage_telemetry(
@@ -380,7 +369,8 @@ class TruncatedSVD(BaseTransformer):

  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(
  dataset._session, Session
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -465,10 +455,8 @@ class TruncatedSVD(BaseTransformer):
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
  expected_dtype = convert_sp_to_sf_type(output_types[0])

- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()

  transform_kwargs = dict(
@@ -535,16 +523,42 @@ class TruncatedSVD(BaseTransformer):
  self._is_fitted = True
  return output_result

+
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+ """ Fit model to X and perform dimensionality reduction on X
+ For more details on this function, see [sklearn.decomposition.TruncatedSVD.fit_transform]
+ (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.TruncatedSVD.html#sklearn.decomposition.TruncatedSVD.fit_transform)
+

- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
- """
+ Raises:
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+ Args:
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+ Snowpark or Pandas DataFrame.
+ output_cols_prefix: Prefix for the response columns
  Returns:
  Transformed dataset.
  """
- self.fit(dataset)
- assert self._sklearn_object is not None
- return self._sklearn_object.embedding_
+ self._infer_input_output_cols(dataset)
+ super()._check_dataset_type(dataset)
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
+ estimator=self._sklearn_object,
+ dataset=dataset,
+ input_cols=self.input_cols,
+ label_cols=self.label_cols,
+ sample_weight_col=self.sample_weight_col,
+ autogenerated=self._autogenerated,
+ subproject=_SUBPROJECT,
+ )
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=self.output_cols,
+ )
+ self._sklearn_object = fitted_estimator
+ self._is_fitted = True
+ return output_result


  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -635,10 +649,8 @@ class TruncatedSVD(BaseTransformer):
  expected_output_cols = self._get_output_column_names(output_cols_prefix)

  if isinstance(dataset, DataFrame):
- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(
  dataset._session, Session
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -703,10 +715,8 @@ class TruncatedSVD(BaseTransformer):
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()

  if isinstance(dataset, DataFrame):
- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(
  dataset._session, Session
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -768,10 +778,8 @@ class TruncatedSVD(BaseTransformer):
  expected_output_cols = self._get_output_column_names(output_cols_prefix)

  if isinstance(dataset, DataFrame):
- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(
  dataset._session, Session
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -837,10 +845,8 @@ class TruncatedSVD(BaseTransformer):
  expected_output_cols = self._get_output_column_names(output_cols_prefix)

  if isinstance(dataset, DataFrame):
- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
  transform_kwargs = dict(
  session=dataset._session,
@@ -902,17 +908,15 @@ class TruncatedSVD(BaseTransformer):
  transform_kwargs: ScoreKwargsTypedDict = dict()

  if isinstance(dataset, DataFrame):
- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method="score",
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+ self._deps = self._get_dependencies()
  selected_cols = self._get_active_columns()
  if len(selected_cols) > 0:
  dataset = dataset.select(selected_cols)
  assert isinstance(dataset._session, Session) # keep mypy happy
  transform_kwargs = dict(
  session=dataset._session,
- dependencies=["snowflake-snowpark-python"] + self._deps,
+ dependencies=self._deps,
  score_sproc_imports=['sklearn'],
  )
  elif isinstance(dataset, pd.DataFrame):
@@ -977,11 +981,8 @@ class TruncatedSVD(BaseTransformer):

  if isinstance(dataset, DataFrame):

- self._deps = self._batch_inference_validate_snowpark(
- dataset=dataset,
- inference_method=inference_method,
-
- )
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+ self._deps = self._get_dependencies()
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
  transform_kwargs = dict(
  session = dataset._session,
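
The SparsePCA and TruncatedSVD hunks apply the same two changes as the PCA file: the new fit_transform entry point, and the split of _batch_inference_validate_snowpark (validation only, now returning None) from self._get_dependencies() (dependency resolution). For completeness, a sketch of the Snowpark path through the same entry point; it assumes an already configured snowflake.snowpark Session named `session`, and the data is illustrative:

# Sketch of the Snowpark branch of the new fit_transform, using TruncatedSVD.
# Assumes `session` is an existing snowflake.snowpark.Session; values are made up.
from snowflake.ml.modeling.decomposition import TruncatedSVD

sp_df = session.create_dataframe(
    [[1.0, 0.0, 2.0], [0.0, 3.0, 1.0], [2.0, 1.0, 0.0]],
    schema=["F1", "F2", "F3"],
)

svd = TruncatedSVD(n_components=2, input_cols=["F1", "F2", "F3"])

# Per the Union[DataFrame, pd.DataFrame] return type in the diff, the
# ModelTrainerBuilder.build_fit_transform dispatch should yield a Snowpark
# DataFrame here, mirroring the pandas example after the PCA section.
out = svd.fit_transform(sp_df, output_cols_prefix="fit_transform_")
out.show()
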