snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -61,12 +61,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla
61
61
 
62
62
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
63
63
 
64
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
65
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
66
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
67
- return check
68
-
69
-
70
64
  class SelectFwe(BaseTransformer):
71
65
  r"""Filter: Select the p-values corresponding to Family-wise error rate
72
66
  For more details on this class, see [sklearn.feature_selection.SelectFwe]
@@ -266,20 +260,17 @@ class SelectFwe(BaseTransformer):
266
260
  self,
267
261
  dataset: DataFrame,
268
262
  inference_method: str,
269
- ) -> List[str]:
270
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
271
- return the available package that exists in the snowflake anaconda channel
263
+ ) -> None:
264
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
272
265
 
273
266
  Args:
274
267
  dataset: snowpark dataframe
275
268
  inference_method: the inference method such as predict, score...
276
-
269
+
277
270
  Raises:
278
271
  SnowflakeMLException: If the estimator is not fitted, raise error
279
272
  SnowflakeMLException: If the session is None, raise error
280
273
 
281
- Returns:
282
- A list of available package that exists in the snowflake anaconda channel
283
274
  """
284
275
  if not self._is_fitted:
285
276
  raise exceptions.SnowflakeMLException(
@@ -297,9 +288,7 @@ class SelectFwe(BaseTransformer):
297
288
  "Session must not specified for snowpark dataset."
298
289
  ),
299
290
  )
300
- # Validate that key package version in user workspace are supported in snowflake conda channel
301
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
302
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
291
+
303
292
 
304
293
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
305
294
  @telemetry.send_api_usage_telemetry(
@@ -345,7 +334,8 @@ class SelectFwe(BaseTransformer):
345
334
 
346
335
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
347
336
 
348
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
337
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
338
+ self._deps = self._get_dependencies()
349
339
  assert isinstance(
350
340
  dataset._session, Session
351
341
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -430,10 +420,8 @@ class SelectFwe(BaseTransformer):
430
420
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
431
421
  expected_dtype = convert_sp_to_sf_type(output_types[0])
432
422
 
433
- self._deps = self._batch_inference_validate_snowpark(
434
- dataset=dataset,
435
- inference_method=inference_method,
436
- )
423
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
424
+ self._deps = self._get_dependencies()
437
425
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
438
426
 
439
427
  transform_kwargs = dict(
@@ -500,16 +488,42 @@ class SelectFwe(BaseTransformer):
500
488
  self._is_fitted = True
501
489
  return output_result
502
490
 
491
+
492
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
493
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
494
+ """ Fit to data, then transform it
495
+ For more details on this function, see [sklearn.feature_selection.SelectFwe.fit_transform]
496
+ (https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectFwe.html#sklearn.feature_selection.SelectFwe.fit_transform)
497
+
503
498
 
504
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
505
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
506
- """
499
+ Raises:
500
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
501
+
502
+ Args:
503
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
504
+ Snowpark or Pandas DataFrame.
505
+ output_cols_prefix: Prefix for the response columns
507
506
  Returns:
508
507
  Transformed dataset.
509
508
  """
510
- self.fit(dataset)
511
- assert self._sklearn_object is not None
512
- return self._sklearn_object.embedding_
509
+ self._infer_input_output_cols(dataset)
510
+ super()._check_dataset_type(dataset)
511
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
512
+ estimator=self._sklearn_object,
513
+ dataset=dataset,
514
+ input_cols=self.input_cols,
515
+ label_cols=self.label_cols,
516
+ sample_weight_col=self.sample_weight_col,
517
+ autogenerated=self._autogenerated,
518
+ subproject=_SUBPROJECT,
519
+ )
520
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
521
+ drop_input_cols=self._drop_input_cols,
522
+ expected_output_cols_list=self.output_cols,
523
+ )
524
+ self._sklearn_object = fitted_estimator
525
+ self._is_fitted = True
526
+ return output_result
513
527
 
514
528
 
515
529
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -600,10 +614,8 @@ class SelectFwe(BaseTransformer):
600
614
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
601
615
 
602
616
  if isinstance(dataset, DataFrame):
603
- self._deps = self._batch_inference_validate_snowpark(
604
- dataset=dataset,
605
- inference_method=inference_method,
606
- )
617
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
618
+ self._deps = self._get_dependencies()
607
619
  assert isinstance(
608
620
  dataset._session, Session
609
621
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -668,10 +680,8 @@ class SelectFwe(BaseTransformer):
668
680
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
669
681
 
670
682
  if isinstance(dataset, DataFrame):
671
- self._deps = self._batch_inference_validate_snowpark(
672
- dataset=dataset,
673
- inference_method=inference_method,
674
- )
683
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
684
+ self._deps = self._get_dependencies()
675
685
  assert isinstance(
676
686
  dataset._session, Session
677
687
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -733,10 +743,8 @@ class SelectFwe(BaseTransformer):
733
743
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
734
744
 
735
745
  if isinstance(dataset, DataFrame):
736
- self._deps = self._batch_inference_validate_snowpark(
737
- dataset=dataset,
738
- inference_method=inference_method,
739
- )
746
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
747
+ self._deps = self._get_dependencies()
740
748
  assert isinstance(
741
749
  dataset._session, Session
742
750
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -802,10 +810,8 @@ class SelectFwe(BaseTransformer):
802
810
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
803
811
 
804
812
  if isinstance(dataset, DataFrame):
805
- self._deps = self._batch_inference_validate_snowpark(
806
- dataset=dataset,
807
- inference_method=inference_method,
808
- )
813
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
814
+ self._deps = self._get_dependencies()
809
815
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
810
816
  transform_kwargs = dict(
811
817
  session=dataset._session,
@@ -867,17 +873,15 @@ class SelectFwe(BaseTransformer):
867
873
  transform_kwargs: ScoreKwargsTypedDict = dict()
868
874
 
869
875
  if isinstance(dataset, DataFrame):
870
- self._deps = self._batch_inference_validate_snowpark(
871
- dataset=dataset,
872
- inference_method="score",
873
- )
876
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
877
+ self._deps = self._get_dependencies()
874
878
  selected_cols = self._get_active_columns()
875
879
  if len(selected_cols) > 0:
876
880
  dataset = dataset.select(selected_cols)
877
881
  assert isinstance(dataset._session, Session) # keep mypy happy
878
882
  transform_kwargs = dict(
879
883
  session=dataset._session,
880
- dependencies=["snowflake-snowpark-python"] + self._deps,
884
+ dependencies=self._deps,
881
885
  score_sproc_imports=['sklearn'],
882
886
  )
883
887
  elif isinstance(dataset, pd.DataFrame):
@@ -942,11 +946,8 @@ class SelectFwe(BaseTransformer):
942
946
 
943
947
  if isinstance(dataset, DataFrame):
944
948
 
945
- self._deps = self._batch_inference_validate_snowpark(
946
- dataset=dataset,
947
- inference_method=inference_method,
948
-
949
- )
949
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
950
+ self._deps = self._get_dependencies()
950
951
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
951
952
  transform_kwargs = dict(
952
953
  session = dataset._session,
@@ -61,12 +61,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla
61
61
 
62
62
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
63
63
 
64
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
65
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
66
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
67
- return check
68
-
69
-
70
64
  class SelectKBest(BaseTransformer):
71
65
  r"""Select features according to the k highest scores
72
66
  For more details on this class, see [sklearn.feature_selection.SelectKBest]
@@ -267,20 +261,17 @@ class SelectKBest(BaseTransformer):
267
261
  self,
268
262
  dataset: DataFrame,
269
263
  inference_method: str,
270
- ) -> List[str]:
271
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
272
- return the available package that exists in the snowflake anaconda channel
264
+ ) -> None:
265
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
273
266
 
274
267
  Args:
275
268
  dataset: snowpark dataframe
276
269
  inference_method: the inference method such as predict, score...
277
-
270
+
278
271
  Raises:
279
272
  SnowflakeMLException: If the estimator is not fitted, raise error
280
273
  SnowflakeMLException: If the session is None, raise error
281
274
 
282
- Returns:
283
- A list of available package that exists in the snowflake anaconda channel
284
275
  """
285
276
  if not self._is_fitted:
286
277
  raise exceptions.SnowflakeMLException(
@@ -298,9 +289,7 @@ class SelectKBest(BaseTransformer):
298
289
  "Session must not specified for snowpark dataset."
299
290
  ),
300
291
  )
301
- # Validate that key package version in user workspace are supported in snowflake conda channel
302
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
303
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
292
+
304
293
 
305
294
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
306
295
  @telemetry.send_api_usage_telemetry(
@@ -346,7 +335,8 @@ class SelectKBest(BaseTransformer):
346
335
 
347
336
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
348
337
 
349
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
338
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
339
+ self._deps = self._get_dependencies()
350
340
  assert isinstance(
351
341
  dataset._session, Session
352
342
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -431,10 +421,8 @@ class SelectKBest(BaseTransformer):
431
421
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
432
422
  expected_dtype = convert_sp_to_sf_type(output_types[0])
433
423
 
434
- self._deps = self._batch_inference_validate_snowpark(
435
- dataset=dataset,
436
- inference_method=inference_method,
437
- )
424
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
425
+ self._deps = self._get_dependencies()
438
426
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
439
427
 
440
428
  transform_kwargs = dict(
@@ -501,16 +489,42 @@ class SelectKBest(BaseTransformer):
501
489
  self._is_fitted = True
502
490
  return output_result
503
491
 
492
+
493
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
494
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
495
+ """ Fit to data, then transform it
496
+ For more details on this function, see [sklearn.feature_selection.SelectKBest.fit_transform]
497
+ (https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectKBest.html#sklearn.feature_selection.SelectKBest.fit_transform)
498
+
504
499
 
505
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
506
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
507
- """
500
+ Raises:
501
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
502
+
503
+ Args:
504
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
505
+ Snowpark or Pandas DataFrame.
506
+ output_cols_prefix: Prefix for the response columns
508
507
  Returns:
509
508
  Transformed dataset.
510
509
  """
511
- self.fit(dataset)
512
- assert self._sklearn_object is not None
513
- return self._sklearn_object.embedding_
510
+ self._infer_input_output_cols(dataset)
511
+ super()._check_dataset_type(dataset)
512
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
513
+ estimator=self._sklearn_object,
514
+ dataset=dataset,
515
+ input_cols=self.input_cols,
516
+ label_cols=self.label_cols,
517
+ sample_weight_col=self.sample_weight_col,
518
+ autogenerated=self._autogenerated,
519
+ subproject=_SUBPROJECT,
520
+ )
521
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
522
+ drop_input_cols=self._drop_input_cols,
523
+ expected_output_cols_list=self.output_cols,
524
+ )
525
+ self._sklearn_object = fitted_estimator
526
+ self._is_fitted = True
527
+ return output_result
514
528
 
515
529
 
516
530
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -601,10 +615,8 @@ class SelectKBest(BaseTransformer):
601
615
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
602
616
 
603
617
  if isinstance(dataset, DataFrame):
604
- self._deps = self._batch_inference_validate_snowpark(
605
- dataset=dataset,
606
- inference_method=inference_method,
607
- )
618
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
619
+ self._deps = self._get_dependencies()
608
620
  assert isinstance(
609
621
  dataset._session, Session
610
622
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -669,10 +681,8 @@ class SelectKBest(BaseTransformer):
669
681
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
670
682
 
671
683
  if isinstance(dataset, DataFrame):
672
- self._deps = self._batch_inference_validate_snowpark(
673
- dataset=dataset,
674
- inference_method=inference_method,
675
- )
684
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
685
+ self._deps = self._get_dependencies()
676
686
  assert isinstance(
677
687
  dataset._session, Session
678
688
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -734,10 +744,8 @@ class SelectKBest(BaseTransformer):
734
744
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
735
745
 
736
746
  if isinstance(dataset, DataFrame):
737
- self._deps = self._batch_inference_validate_snowpark(
738
- dataset=dataset,
739
- inference_method=inference_method,
740
- )
747
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
748
+ self._deps = self._get_dependencies()
741
749
  assert isinstance(
742
750
  dataset._session, Session
743
751
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -803,10 +811,8 @@ class SelectKBest(BaseTransformer):
803
811
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
804
812
 
805
813
  if isinstance(dataset, DataFrame):
806
- self._deps = self._batch_inference_validate_snowpark(
807
- dataset=dataset,
808
- inference_method=inference_method,
809
- )
814
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
815
+ self._deps = self._get_dependencies()
810
816
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
811
817
  transform_kwargs = dict(
812
818
  session=dataset._session,
@@ -868,17 +874,15 @@ class SelectKBest(BaseTransformer):
868
874
  transform_kwargs: ScoreKwargsTypedDict = dict()
869
875
 
870
876
  if isinstance(dataset, DataFrame):
871
- self._deps = self._batch_inference_validate_snowpark(
872
- dataset=dataset,
873
- inference_method="score",
874
- )
877
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
878
+ self._deps = self._get_dependencies()
875
879
  selected_cols = self._get_active_columns()
876
880
  if len(selected_cols) > 0:
877
881
  dataset = dataset.select(selected_cols)
878
882
  assert isinstance(dataset._session, Session) # keep mypy happy
879
883
  transform_kwargs = dict(
880
884
  session=dataset._session,
881
- dependencies=["snowflake-snowpark-python"] + self._deps,
885
+ dependencies=self._deps,
882
886
  score_sproc_imports=['sklearn'],
883
887
  )
884
888
  elif isinstance(dataset, pd.DataFrame):
@@ -943,11 +947,8 @@ class SelectKBest(BaseTransformer):
943
947
 
944
948
  if isinstance(dataset, DataFrame):
945
949
 
946
- self._deps = self._batch_inference_validate_snowpark(
947
- dataset=dataset,
948
- inference_method=inference_method,
949
-
950
- )
950
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
951
+ self._deps = self._get_dependencies()
951
952
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
952
953
  transform_kwargs = dict(
953
954
  session = dataset._session,
@@ -61,12 +61,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla
61
61
 
62
62
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
63
63
 
64
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
65
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
66
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
67
- return check
68
-
69
-
70
64
  class SelectPercentile(BaseTransformer):
71
65
  r"""Select features according to a percentile of the highest scores
72
66
  For more details on this class, see [sklearn.feature_selection.SelectPercentile]
@@ -266,20 +260,17 @@ class SelectPercentile(BaseTransformer):
266
260
  self,
267
261
  dataset: DataFrame,
268
262
  inference_method: str,
269
- ) -> List[str]:
270
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
271
- return the available package that exists in the snowflake anaconda channel
263
+ ) -> None:
264
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
272
265
 
273
266
  Args:
274
267
  dataset: snowpark dataframe
275
268
  inference_method: the inference method such as predict, score...
276
-
269
+
277
270
  Raises:
278
271
  SnowflakeMLException: If the estimator is not fitted, raise error
279
272
  SnowflakeMLException: If the session is None, raise error
280
273
 
281
- Returns:
282
- A list of available package that exists in the snowflake anaconda channel
283
274
  """
284
275
  if not self._is_fitted:
285
276
  raise exceptions.SnowflakeMLException(
@@ -297,9 +288,7 @@ class SelectPercentile(BaseTransformer):
297
288
  "Session must not specified for snowpark dataset."
298
289
  ),
299
290
  )
300
- # Validate that key package version in user workspace are supported in snowflake conda channel
301
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
302
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
291
+
303
292
 
304
293
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
305
294
  @telemetry.send_api_usage_telemetry(
@@ -345,7 +334,8 @@ class SelectPercentile(BaseTransformer):
345
334
 
346
335
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
347
336
 
348
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
337
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
338
+ self._deps = self._get_dependencies()
349
339
  assert isinstance(
350
340
  dataset._session, Session
351
341
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -430,10 +420,8 @@ class SelectPercentile(BaseTransformer):
430
420
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
431
421
  expected_dtype = convert_sp_to_sf_type(output_types[0])
432
422
 
433
- self._deps = self._batch_inference_validate_snowpark(
434
- dataset=dataset,
435
- inference_method=inference_method,
436
- )
423
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
424
+ self._deps = self._get_dependencies()
437
425
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
438
426
 
439
427
  transform_kwargs = dict(
@@ -500,16 +488,42 @@ class SelectPercentile(BaseTransformer):
500
488
  self._is_fitted = True
501
489
  return output_result
502
490
 
491
+
492
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
493
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
494
+ """ Fit to data, then transform it
495
+ For more details on this function, see [sklearn.feature_selection.SelectPercentile.fit_transform]
496
+ (https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectPercentile.html#sklearn.feature_selection.SelectPercentile.fit_transform)
497
+
503
498
 
504
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
505
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
506
- """
499
+ Raises:
500
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
501
+
502
+ Args:
503
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
504
+ Snowpark or Pandas DataFrame.
505
+ output_cols_prefix: Prefix for the response columns
507
506
  Returns:
508
507
  Transformed dataset.
509
508
  """
510
- self.fit(dataset)
511
- assert self._sklearn_object is not None
512
- return self._sklearn_object.embedding_
509
+ self._infer_input_output_cols(dataset)
510
+ super()._check_dataset_type(dataset)
511
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
512
+ estimator=self._sklearn_object,
513
+ dataset=dataset,
514
+ input_cols=self.input_cols,
515
+ label_cols=self.label_cols,
516
+ sample_weight_col=self.sample_weight_col,
517
+ autogenerated=self._autogenerated,
518
+ subproject=_SUBPROJECT,
519
+ )
520
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
521
+ drop_input_cols=self._drop_input_cols,
522
+ expected_output_cols_list=self.output_cols,
523
+ )
524
+ self._sklearn_object = fitted_estimator
525
+ self._is_fitted = True
526
+ return output_result
513
527
 
514
528
 
515
529
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -600,10 +614,8 @@ class SelectPercentile(BaseTransformer):
600
614
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
601
615
 
602
616
  if isinstance(dataset, DataFrame):
603
- self._deps = self._batch_inference_validate_snowpark(
604
- dataset=dataset,
605
- inference_method=inference_method,
606
- )
617
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
618
+ self._deps = self._get_dependencies()
607
619
  assert isinstance(
608
620
  dataset._session, Session
609
621
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -668,10 +680,8 @@ class SelectPercentile(BaseTransformer):
668
680
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
669
681
 
670
682
  if isinstance(dataset, DataFrame):
671
- self._deps = self._batch_inference_validate_snowpark(
672
- dataset=dataset,
673
- inference_method=inference_method,
674
- )
683
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
684
+ self._deps = self._get_dependencies()
675
685
  assert isinstance(
676
686
  dataset._session, Session
677
687
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -733,10 +743,8 @@ class SelectPercentile(BaseTransformer):
733
743
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
734
744
 
735
745
  if isinstance(dataset, DataFrame):
736
- self._deps = self._batch_inference_validate_snowpark(
737
- dataset=dataset,
738
- inference_method=inference_method,
739
- )
746
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
747
+ self._deps = self._get_dependencies()
740
748
  assert isinstance(
741
749
  dataset._session, Session
742
750
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -802,10 +810,8 @@ class SelectPercentile(BaseTransformer):
802
810
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
803
811
 
804
812
  if isinstance(dataset, DataFrame):
805
- self._deps = self._batch_inference_validate_snowpark(
806
- dataset=dataset,
807
- inference_method=inference_method,
808
- )
813
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
814
+ self._deps = self._get_dependencies()
809
815
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
810
816
  transform_kwargs = dict(
811
817
  session=dataset._session,
@@ -867,17 +873,15 @@ class SelectPercentile(BaseTransformer):
867
873
  transform_kwargs: ScoreKwargsTypedDict = dict()
868
874
 
869
875
  if isinstance(dataset, DataFrame):
870
- self._deps = self._batch_inference_validate_snowpark(
871
- dataset=dataset,
872
- inference_method="score",
873
- )
876
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
877
+ self._deps = self._get_dependencies()
874
878
  selected_cols = self._get_active_columns()
875
879
  if len(selected_cols) > 0:
876
880
  dataset = dataset.select(selected_cols)
877
881
  assert isinstance(dataset._session, Session) # keep mypy happy
878
882
  transform_kwargs = dict(
879
883
  session=dataset._session,
880
- dependencies=["snowflake-snowpark-python"] + self._deps,
884
+ dependencies=self._deps,
881
885
  score_sproc_imports=['sklearn'],
882
886
  )
883
887
  elif isinstance(dataset, pd.DataFrame):
@@ -942,11 +946,8 @@ class SelectPercentile(BaseTransformer):
942
946
 
943
947
  if isinstance(dataset, DataFrame):
944
948
 
945
- self._deps = self._batch_inference_validate_snowpark(
946
- dataset=dataset,
947
- inference_method=inference_method,
948
-
949
- )
949
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
950
+ self._deps = self._get_dependencies()
950
951
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
951
952
  transform_kwargs = dict(
952
953
  session = dataset._session,