snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -61,12 +61,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla
61
61
 
62
62
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
63
63
 
64
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
65
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
66
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
67
- return check
68
-
69
-
70
64
  class GenericUnivariateSelect(BaseTransformer):
71
65
  r"""Univariate feature selector with configurable strategy
72
66
  For more details on this class, see [sklearn.feature_selection.GenericUnivariateSelect]
@@ -270,20 +264,17 @@ class GenericUnivariateSelect(BaseTransformer):
270
264
  self,
271
265
  dataset: DataFrame,
272
266
  inference_method: str,
273
- ) -> List[str]:
274
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
275
- return the available package that exists in the snowflake anaconda channel
267
+ ) -> None:
268
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
276
269
 
277
270
  Args:
278
271
  dataset: snowpark dataframe
279
272
  inference_method: the inference method such as predict, score...
280
-
273
+
281
274
  Raises:
282
275
  SnowflakeMLException: If the estimator is not fitted, raise error
283
276
  SnowflakeMLException: If the session is None, raise error
284
277
 
285
- Returns:
286
- A list of available package that exists in the snowflake anaconda channel
287
278
  """
288
279
  if not self._is_fitted:
289
280
  raise exceptions.SnowflakeMLException(
@@ -301,9 +292,7 @@ class GenericUnivariateSelect(BaseTransformer):
301
292
  "Session must not specified for snowpark dataset."
302
293
  ),
303
294
  )
304
- # Validate that key package version in user workspace are supported in snowflake conda channel
305
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
306
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
295
+
307
296
 
308
297
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
309
298
  @telemetry.send_api_usage_telemetry(
@@ -349,7 +338,8 @@ class GenericUnivariateSelect(BaseTransformer):
349
338
 
350
339
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
351
340
 
352
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
341
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
342
+ self._deps = self._get_dependencies()
353
343
  assert isinstance(
354
344
  dataset._session, Session
355
345
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -434,10 +424,8 @@ class GenericUnivariateSelect(BaseTransformer):
434
424
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
435
425
  expected_dtype = convert_sp_to_sf_type(output_types[0])
436
426
 
437
- self._deps = self._batch_inference_validate_snowpark(
438
- dataset=dataset,
439
- inference_method=inference_method,
440
- )
427
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
428
+ self._deps = self._get_dependencies()
441
429
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
442
430
 
443
431
  transform_kwargs = dict(
@@ -504,16 +492,42 @@ class GenericUnivariateSelect(BaseTransformer):
504
492
  self._is_fitted = True
505
493
  return output_result
506
494
 
495
+
496
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
497
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
498
+ """ Fit to data, then transform it
499
+ For more details on this function, see [sklearn.feature_selection.GenericUnivariateSelect.fit_transform]
500
+ (https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.GenericUnivariateSelect.html#sklearn.feature_selection.GenericUnivariateSelect.fit_transform)
501
+
507
502
 
508
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
509
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
510
- """
503
+ Raises:
504
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
505
+
506
+ Args:
507
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
508
+ Snowpark or Pandas DataFrame.
509
+ output_cols_prefix: Prefix for the response columns
511
510
  Returns:
512
511
  Transformed dataset.
513
512
  """
514
- self.fit(dataset)
515
- assert self._sklearn_object is not None
516
- return self._sklearn_object.embedding_
513
+ self._infer_input_output_cols(dataset)
514
+ super()._check_dataset_type(dataset)
515
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
516
+ estimator=self._sklearn_object,
517
+ dataset=dataset,
518
+ input_cols=self.input_cols,
519
+ label_cols=self.label_cols,
520
+ sample_weight_col=self.sample_weight_col,
521
+ autogenerated=self._autogenerated,
522
+ subproject=_SUBPROJECT,
523
+ )
524
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
525
+ drop_input_cols=self._drop_input_cols,
526
+ expected_output_cols_list=self.output_cols,
527
+ )
528
+ self._sklearn_object = fitted_estimator
529
+ self._is_fitted = True
530
+ return output_result
517
531
 
518
532
 
519
533
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -604,10 +618,8 @@ class GenericUnivariateSelect(BaseTransformer):
604
618
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
605
619
 
606
620
  if isinstance(dataset, DataFrame):
607
- self._deps = self._batch_inference_validate_snowpark(
608
- dataset=dataset,
609
- inference_method=inference_method,
610
- )
621
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
622
+ self._deps = self._get_dependencies()
611
623
  assert isinstance(
612
624
  dataset._session, Session
613
625
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -672,10 +684,8 @@ class GenericUnivariateSelect(BaseTransformer):
672
684
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
673
685
 
674
686
  if isinstance(dataset, DataFrame):
675
- self._deps = self._batch_inference_validate_snowpark(
676
- dataset=dataset,
677
- inference_method=inference_method,
678
- )
687
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
688
+ self._deps = self._get_dependencies()
679
689
  assert isinstance(
680
690
  dataset._session, Session
681
691
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -737,10 +747,8 @@ class GenericUnivariateSelect(BaseTransformer):
737
747
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
738
748
 
739
749
  if isinstance(dataset, DataFrame):
740
- self._deps = self._batch_inference_validate_snowpark(
741
- dataset=dataset,
742
- inference_method=inference_method,
743
- )
750
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
751
+ self._deps = self._get_dependencies()
744
752
  assert isinstance(
745
753
  dataset._session, Session
746
754
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -806,10 +814,8 @@ class GenericUnivariateSelect(BaseTransformer):
806
814
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
807
815
 
808
816
  if isinstance(dataset, DataFrame):
809
- self._deps = self._batch_inference_validate_snowpark(
810
- dataset=dataset,
811
- inference_method=inference_method,
812
- )
817
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
818
+ self._deps = self._get_dependencies()
813
819
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
814
820
  transform_kwargs = dict(
815
821
  session=dataset._session,
@@ -871,17 +877,15 @@ class GenericUnivariateSelect(BaseTransformer):
871
877
  transform_kwargs: ScoreKwargsTypedDict = dict()
872
878
 
873
879
  if isinstance(dataset, DataFrame):
874
- self._deps = self._batch_inference_validate_snowpark(
875
- dataset=dataset,
876
- inference_method="score",
877
- )
880
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
881
+ self._deps = self._get_dependencies()
878
882
  selected_cols = self._get_active_columns()
879
883
  if len(selected_cols) > 0:
880
884
  dataset = dataset.select(selected_cols)
881
885
  assert isinstance(dataset._session, Session) # keep mypy happy
882
886
  transform_kwargs = dict(
883
887
  session=dataset._session,
884
- dependencies=["snowflake-snowpark-python"] + self._deps,
888
+ dependencies=self._deps,
885
889
  score_sproc_imports=['sklearn'],
886
890
  )
887
891
  elif isinstance(dataset, pd.DataFrame):
@@ -946,11 +950,8 @@ class GenericUnivariateSelect(BaseTransformer):
946
950
 
947
951
  if isinstance(dataset, DataFrame):
948
952
 
949
- self._deps = self._batch_inference_validate_snowpark(
950
- dataset=dataset,
951
- inference_method=inference_method,
952
-
953
- )
953
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
954
+ self._deps = self._get_dependencies()
954
955
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
955
956
  transform_kwargs = dict(
956
957
  session = dataset._session,
@@ -61,12 +61,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla
61
61
 
62
62
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
63
63
 
64
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
65
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
66
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
67
- return check
68
-
69
-
70
64
  class SelectFdr(BaseTransformer):
71
65
  r"""Filter: Select the p-values for an estimated false discovery rate
72
66
  For more details on this class, see [sklearn.feature_selection.SelectFdr]
@@ -266,20 +260,17 @@ class SelectFdr(BaseTransformer):
266
260
  self,
267
261
  dataset: DataFrame,
268
262
  inference_method: str,
269
- ) -> List[str]:
270
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
271
- return the available package that exists in the snowflake anaconda channel
263
+ ) -> None:
264
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
272
265
 
273
266
  Args:
274
267
  dataset: snowpark dataframe
275
268
  inference_method: the inference method such as predict, score...
276
-
269
+
277
270
  Raises:
278
271
  SnowflakeMLException: If the estimator is not fitted, raise error
279
272
  SnowflakeMLException: If the session is None, raise error
280
273
 
281
- Returns:
282
- A list of available package that exists in the snowflake anaconda channel
283
274
  """
284
275
  if not self._is_fitted:
285
276
  raise exceptions.SnowflakeMLException(
@@ -297,9 +288,7 @@ class SelectFdr(BaseTransformer):
297
288
  "Session must not specified for snowpark dataset."
298
289
  ),
299
290
  )
300
- # Validate that key package version in user workspace are supported in snowflake conda channel
301
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
302
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
291
+
303
292
 
304
293
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
305
294
  @telemetry.send_api_usage_telemetry(
@@ -345,7 +334,8 @@ class SelectFdr(BaseTransformer):
345
334
 
346
335
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
347
336
 
348
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
337
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
338
+ self._deps = self._get_dependencies()
349
339
  assert isinstance(
350
340
  dataset._session, Session
351
341
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -430,10 +420,8 @@ class SelectFdr(BaseTransformer):
430
420
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
431
421
  expected_dtype = convert_sp_to_sf_type(output_types[0])
432
422
 
433
- self._deps = self._batch_inference_validate_snowpark(
434
- dataset=dataset,
435
- inference_method=inference_method,
436
- )
423
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
424
+ self._deps = self._get_dependencies()
437
425
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
438
426
 
439
427
  transform_kwargs = dict(
@@ -500,16 +488,42 @@ class SelectFdr(BaseTransformer):
500
488
  self._is_fitted = True
501
489
  return output_result
502
490
 
491
+
492
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
493
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
494
+ """ Fit to data, then transform it
495
+ For more details on this function, see [sklearn.feature_selection.SelectFdr.fit_transform]
496
+ (https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectFdr.html#sklearn.feature_selection.SelectFdr.fit_transform)
497
+
503
498
 
504
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
505
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
506
- """
499
+ Raises:
500
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
501
+
502
+ Args:
503
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
504
+ Snowpark or Pandas DataFrame.
505
+ output_cols_prefix: Prefix for the response columns
507
506
  Returns:
508
507
  Transformed dataset.
509
508
  """
510
- self.fit(dataset)
511
- assert self._sklearn_object is not None
512
- return self._sklearn_object.embedding_
509
+ self._infer_input_output_cols(dataset)
510
+ super()._check_dataset_type(dataset)
511
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
512
+ estimator=self._sklearn_object,
513
+ dataset=dataset,
514
+ input_cols=self.input_cols,
515
+ label_cols=self.label_cols,
516
+ sample_weight_col=self.sample_weight_col,
517
+ autogenerated=self._autogenerated,
518
+ subproject=_SUBPROJECT,
519
+ )
520
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
521
+ drop_input_cols=self._drop_input_cols,
522
+ expected_output_cols_list=self.output_cols,
523
+ )
524
+ self._sklearn_object = fitted_estimator
525
+ self._is_fitted = True
526
+ return output_result
513
527
 
514
528
 
515
529
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -600,10 +614,8 @@ class SelectFdr(BaseTransformer):
600
614
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
601
615
 
602
616
  if isinstance(dataset, DataFrame):
603
- self._deps = self._batch_inference_validate_snowpark(
604
- dataset=dataset,
605
- inference_method=inference_method,
606
- )
617
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
618
+ self._deps = self._get_dependencies()
607
619
  assert isinstance(
608
620
  dataset._session, Session
609
621
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -668,10 +680,8 @@ class SelectFdr(BaseTransformer):
668
680
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
669
681
 
670
682
  if isinstance(dataset, DataFrame):
671
- self._deps = self._batch_inference_validate_snowpark(
672
- dataset=dataset,
673
- inference_method=inference_method,
674
- )
683
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
684
+ self._deps = self._get_dependencies()
675
685
  assert isinstance(
676
686
  dataset._session, Session
677
687
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -733,10 +743,8 @@ class SelectFdr(BaseTransformer):
733
743
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
734
744
 
735
745
  if isinstance(dataset, DataFrame):
736
- self._deps = self._batch_inference_validate_snowpark(
737
- dataset=dataset,
738
- inference_method=inference_method,
739
- )
746
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
747
+ self._deps = self._get_dependencies()
740
748
  assert isinstance(
741
749
  dataset._session, Session
742
750
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -802,10 +810,8 @@ class SelectFdr(BaseTransformer):
802
810
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
803
811
 
804
812
  if isinstance(dataset, DataFrame):
805
- self._deps = self._batch_inference_validate_snowpark(
806
- dataset=dataset,
807
- inference_method=inference_method,
808
- )
813
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
814
+ self._deps = self._get_dependencies()
809
815
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
810
816
  transform_kwargs = dict(
811
817
  session=dataset._session,
@@ -867,17 +873,15 @@ class SelectFdr(BaseTransformer):
867
873
  transform_kwargs: ScoreKwargsTypedDict = dict()
868
874
 
869
875
  if isinstance(dataset, DataFrame):
870
- self._deps = self._batch_inference_validate_snowpark(
871
- dataset=dataset,
872
- inference_method="score",
873
- )
876
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
877
+ self._deps = self._get_dependencies()
874
878
  selected_cols = self._get_active_columns()
875
879
  if len(selected_cols) > 0:
876
880
  dataset = dataset.select(selected_cols)
877
881
  assert isinstance(dataset._session, Session) # keep mypy happy
878
882
  transform_kwargs = dict(
879
883
  session=dataset._session,
880
- dependencies=["snowflake-snowpark-python"] + self._deps,
884
+ dependencies=self._deps,
881
885
  score_sproc_imports=['sklearn'],
882
886
  )
883
887
  elif isinstance(dataset, pd.DataFrame):
@@ -942,11 +946,8 @@ class SelectFdr(BaseTransformer):
942
946
 
943
947
  if isinstance(dataset, DataFrame):
944
948
 
945
- self._deps = self._batch_inference_validate_snowpark(
946
- dataset=dataset,
947
- inference_method=inference_method,
948
-
949
- )
949
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
950
+ self._deps = self._get_dependencies()
950
951
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
951
952
  transform_kwargs = dict(
952
953
  session = dataset._session,
@@ -61,12 +61,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla
61
61
 
62
62
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
63
63
 
64
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
65
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
66
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
67
- return check
68
-
69
-
70
64
  class SelectFpr(BaseTransformer):
71
65
  r"""Filter: Select the pvalues below alpha based on a FPR test
72
66
  For more details on this class, see [sklearn.feature_selection.SelectFpr]
@@ -266,20 +260,17 @@ class SelectFpr(BaseTransformer):
266
260
  self,
267
261
  dataset: DataFrame,
268
262
  inference_method: str,
269
- ) -> List[str]:
270
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
271
- return the available package that exists in the snowflake anaconda channel
263
+ ) -> None:
264
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
272
265
 
273
266
  Args:
274
267
  dataset: snowpark dataframe
275
268
  inference_method: the inference method such as predict, score...
276
-
269
+
277
270
  Raises:
278
271
  SnowflakeMLException: If the estimator is not fitted, raise error
279
272
  SnowflakeMLException: If the session is None, raise error
280
273
 
281
- Returns:
282
- A list of available package that exists in the snowflake anaconda channel
283
274
  """
284
275
  if not self._is_fitted:
285
276
  raise exceptions.SnowflakeMLException(
@@ -297,9 +288,7 @@ class SelectFpr(BaseTransformer):
297
288
  "Session must not specified for snowpark dataset."
298
289
  ),
299
290
  )
300
- # Validate that key package version in user workspace are supported in snowflake conda channel
301
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
302
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
291
+
303
292
 
304
293
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
305
294
  @telemetry.send_api_usage_telemetry(
@@ -345,7 +334,8 @@ class SelectFpr(BaseTransformer):
345
334
 
346
335
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
347
336
 
348
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
337
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
338
+ self._deps = self._get_dependencies()
349
339
  assert isinstance(
350
340
  dataset._session, Session
351
341
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -430,10 +420,8 @@ class SelectFpr(BaseTransformer):
430
420
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
431
421
  expected_dtype = convert_sp_to_sf_type(output_types[0])
432
422
 
433
- self._deps = self._batch_inference_validate_snowpark(
434
- dataset=dataset,
435
- inference_method=inference_method,
436
- )
423
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
424
+ self._deps = self._get_dependencies()
437
425
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
438
426
 
439
427
  transform_kwargs = dict(
@@ -500,16 +488,42 @@ class SelectFpr(BaseTransformer):
500
488
  self._is_fitted = True
501
489
  return output_result
502
490
 
491
+
492
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
493
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
494
+ """ Fit to data, then transform it
495
+ For more details on this function, see [sklearn.feature_selection.SelectFpr.fit_transform]
496
+ (https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectFpr.html#sklearn.feature_selection.SelectFpr.fit_transform)
497
+
503
498
 
504
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
505
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
506
- """
499
+ Raises:
500
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
501
+
502
+ Args:
503
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
504
+ Snowpark or Pandas DataFrame.
505
+ output_cols_prefix: Prefix for the response columns
507
506
  Returns:
508
507
  Transformed dataset.
509
508
  """
510
- self.fit(dataset)
511
- assert self._sklearn_object is not None
512
- return self._sklearn_object.embedding_
509
+ self._infer_input_output_cols(dataset)
510
+ super()._check_dataset_type(dataset)
511
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
512
+ estimator=self._sklearn_object,
513
+ dataset=dataset,
514
+ input_cols=self.input_cols,
515
+ label_cols=self.label_cols,
516
+ sample_weight_col=self.sample_weight_col,
517
+ autogenerated=self._autogenerated,
518
+ subproject=_SUBPROJECT,
519
+ )
520
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
521
+ drop_input_cols=self._drop_input_cols,
522
+ expected_output_cols_list=self.output_cols,
523
+ )
524
+ self._sklearn_object = fitted_estimator
525
+ self._is_fitted = True
526
+ return output_result
513
527
 
514
528
 
515
529
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -600,10 +614,8 @@ class SelectFpr(BaseTransformer):
600
614
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
601
615
 
602
616
  if isinstance(dataset, DataFrame):
603
- self._deps = self._batch_inference_validate_snowpark(
604
- dataset=dataset,
605
- inference_method=inference_method,
606
- )
617
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
618
+ self._deps = self._get_dependencies()
607
619
  assert isinstance(
608
620
  dataset._session, Session
609
621
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -668,10 +680,8 @@ class SelectFpr(BaseTransformer):
668
680
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
669
681
 
670
682
  if isinstance(dataset, DataFrame):
671
- self._deps = self._batch_inference_validate_snowpark(
672
- dataset=dataset,
673
- inference_method=inference_method,
674
- )
683
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
684
+ self._deps = self._get_dependencies()
675
685
  assert isinstance(
676
686
  dataset._session, Session
677
687
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -733,10 +743,8 @@ class SelectFpr(BaseTransformer):
733
743
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
734
744
 
735
745
  if isinstance(dataset, DataFrame):
736
- self._deps = self._batch_inference_validate_snowpark(
737
- dataset=dataset,
738
- inference_method=inference_method,
739
- )
746
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
747
+ self._deps = self._get_dependencies()
740
748
  assert isinstance(
741
749
  dataset._session, Session
742
750
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -802,10 +810,8 @@ class SelectFpr(BaseTransformer):
802
810
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
803
811
 
804
812
  if isinstance(dataset, DataFrame):
805
- self._deps = self._batch_inference_validate_snowpark(
806
- dataset=dataset,
807
- inference_method=inference_method,
808
- )
813
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
814
+ self._deps = self._get_dependencies()
809
815
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
810
816
  transform_kwargs = dict(
811
817
  session=dataset._session,
@@ -867,17 +873,15 @@ class SelectFpr(BaseTransformer):
867
873
  transform_kwargs: ScoreKwargsTypedDict = dict()
868
874
 
869
875
  if isinstance(dataset, DataFrame):
870
- self._deps = self._batch_inference_validate_snowpark(
871
- dataset=dataset,
872
- inference_method="score",
873
- )
876
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
877
+ self._deps = self._get_dependencies()
874
878
  selected_cols = self._get_active_columns()
875
879
  if len(selected_cols) > 0:
876
880
  dataset = dataset.select(selected_cols)
877
881
  assert isinstance(dataset._session, Session) # keep mypy happy
878
882
  transform_kwargs = dict(
879
883
  session=dataset._session,
880
- dependencies=["snowflake-snowpark-python"] + self._deps,
884
+ dependencies=self._deps,
881
885
  score_sproc_imports=['sklearn'],
882
886
  )
883
887
  elif isinstance(dataset, pd.DataFrame):
@@ -942,11 +946,8 @@ class SelectFpr(BaseTransformer):
942
946
 
943
947
  if isinstance(dataset, DataFrame):
944
948
 
945
- self._deps = self._batch_inference_validate_snowpark(
946
- dataset=dataset,
947
- inference_method=inference_method,
948
-
949
- )
949
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
950
+ self._deps = self._get_dependencies()
950
951
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
951
952
  transform_kwargs = dict(
952
953
  session = dataset._session,