snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registry.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0

snowflake/ml/modeling/impute/knn_imputer.py:

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.impute".replace("sklearn
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class KNNImputer(BaseTransformer):
     r"""Imputation for completing missing values using k-Nearest Neighbors
     For more details on this class, see [sklearn.impute.KNNImputer]
@@ -311,20 +305,17 @@ class KNNImputer(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -342,9 +333,7 @@ class KNNImputer(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -390,7 +379,8 @@ class KNNImputer(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -475,10 +465,8 @@ class KNNImputer(BaseTransformer):
                 if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                     expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -545,16 +533,42 @@ class KNNImputer(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit to data, then transform it
+        For more details on this function, see [sklearn.impute.KNNImputer.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.impute.KNNImputer.html#sklearn.impute.KNNImputer.fit_transform)
+
 
-    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -645,10 +659,8 @@ class KNNImputer(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -713,10 +725,8 @@ class KNNImputer(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -778,10 +788,8 @@ class KNNImputer(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -847,10 +855,8 @@ class KNNImputer(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -912,17 +918,15 @@ class KNNImputer(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -987,11 +991,8 @@ class KNNImputer(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
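
The hunks above replace the always-disabled _is_fit_transform_method_enabled gate with a real public fit_transform method routed through ModelTrainerBuilder.build_fit_transform. A minimal sketch of calling the new entry point on a pandas DataFrame, assuming snowflake-ml-python 1.5.0 is installed; the data and column names are illustrative and not taken from the diff:

import pandas as pd
from snowflake.ml.modeling.impute import KNNImputer

# Toy frame with a missing value in each column (illustrative data).
df = pd.DataFrame({"A": [1.0, None, 3.0, 4.0], "B": [2.0, 3.0, None, 5.0]})

imputer = KNNImputer(
    n_neighbors=2,                           # forwarded to sklearn.impute.KNNImputer
    input_cols=["A", "B"],
    output_cols=["A_IMPUTED", "B_IMPUTED"],
)

# New in 1.5.0: fit and transform in one call; per the new signature the result
# is again a pandas (or Snowpark) DataFrame rather than a raw numpy array.
result = imputer.fit_transform(df)
print(result)
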
snowflake/ml/modeling/impute/missing_indicator.py:

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.impute".replace("sklearn
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class MissingIndicator(BaseTransformer):
     r"""Binary indicators for missing values
     For more details on this class, see [sklearn.impute.MissingIndicator]
@@ -285,20 +279,17 @@ class MissingIndicator(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -316,9 +307,7 @@ class MissingIndicator(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -364,7 +353,8 @@ class MissingIndicator(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -449,10 +439,8 @@ class MissingIndicator(BaseTransformer):
                 if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                     expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -519,16 +507,42 @@ class MissingIndicator(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Generate missing values indicator for `X`
+        For more details on this function, see [sklearn.impute.MissingIndicator.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.impute.MissingIndicator.html#sklearn.impute.MissingIndicator.fit_transform)
+
 
-    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -619,10 +633,8 @@ class MissingIndicator(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -687,10 +699,8 @@ class MissingIndicator(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -752,10 +762,8 @@ class MissingIndicator(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -821,10 +829,8 @@ class MissingIndicator(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -886,17 +892,15 @@ class MissingIndicator(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -961,11 +965,8 @@ class MissingIndicator(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
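
MissingIndicator picks up the same fit_transform entry point. A hedged sketch of the Snowpark path, whose validation now follows the two-step pattern shown in the hunks above; the connection placeholders, data, and column names are assumptions for illustration, not values from the diff:

from snowflake.snowpark import Session
from snowflake.ml.modeling.impute import MissingIndicator

# Placeholder credentials; supply a real connection configuration.
connection_parameters = {
    "account": "<account>",
    "user": "<user>",
    "password": "<password>",
    "warehouse": "<warehouse>",
    "database": "<database>",
    "schema": "<schema>",
}
session = Session.builder.configs(connection_parameters).create()

# Both columns contain NULLs, so sklearn's default features="missing-only"
# produces one indicator column per input column here.
df = session.create_dataframe([[1.0, 2.0], [None, 4.0], [3.0, None]], schema=["A", "B"])

indicator = MissingIndicator(input_cols=["A", "B"], output_cols=["A_MISSING", "B_MISSING"])
result = indicator.fit_transform(df)  # Snowpark DataFrame of boolean indicators
result.show()
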
snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py:

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".re
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class AdditiveChi2Sampler(BaseTransformer):
     r"""Approximate feature map for additive chi2 kernel
     For more details on this class, see [sklearn.kernel_approximation.AdditiveChi2Sampler]
@@ -260,20 +254,17 @@ class AdditiveChi2Sampler(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
            dataset: snowpark dataframe
            inference_method: the inference method such as predict, score...
-
+
         Raises:
            SnowflakeMLException: If the estimator is not fitted, raise error
            SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -291,9 +282,7 @@ class AdditiveChi2Sampler(BaseTransformer):
                    "Session must not specified for snowpark dataset."
                ),
            )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
    @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
    @telemetry.send_api_usage_telemetry(
@@ -339,7 +328,8 @@ class AdditiveChi2Sampler(BaseTransformer):
 
                expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -424,10 +414,8 @@ class AdditiveChi2Sampler(BaseTransformer):
                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                    expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
            transform_kwargs = dict(
@@ -494,16 +482,42 @@ class AdditiveChi2Sampler(BaseTransformer):
        self._is_fitted = True
        return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit to data, then transform it
+        For more details on this function, see [sklearn.kernel_approximation.AdditiveChi2Sampler.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.AdditiveChi2Sampler.html#sklearn.kernel_approximation.AdditiveChi2Sampler.fit_transform)
+
 
-    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
        Returns:
            Transformed dataset.
        """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -594,10 +608,8 @@ class AdditiveChi2Sampler(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
        if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -662,10 +674,8 @@ class AdditiveChi2Sampler(BaseTransformer):
        transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
        if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -727,10 +737,8 @@ class AdditiveChi2Sampler(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
        if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -796,10 +804,8 @@ class AdditiveChi2Sampler(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
        if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session=dataset._session,
@@ -861,17 +867,15 @@ class AdditiveChi2Sampler(BaseTransformer):
        transform_kwargs: ScoreKwargsTypedDict = dict()
 
        if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
            selected_cols = self._get_active_columns()
            if len(selected_cols) > 0:
                dataset = dataset.select(selected_cols)
            assert isinstance(dataset._session, Session) # keep mypy happy
            transform_kwargs = dict(
                session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                score_sproc_imports=['sklearn'],
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -936,11 +940,8 @@ class AdditiveChi2Sampler(BaseTransformer):
 
        if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session = dataset._session,
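
The three files shown here also change dependency handling in the same way, and the file list suggests the same applies across the other autogenerated modeling estimators: _batch_inference_validate_snowpark now only validates and returns None, callers read the package list from _get_dependencies(), and the score path stops prepending "snowflake-snowpark-python" by hand. The stub below mirrors that calling pattern only; it is an illustrative reconstruction, not the real BaseTransformer:

from typing import List, Optional


class _StubTransformer:
    """Illustrative stand-in for the autogenerated transformers; not library code."""

    def __init__(self, deps: Optional[List[str]] = None) -> None:
        # The real classes assemble this list from the wrapped estimator's requirements.
        self._dependencies = deps or ["scikit-learn", "snowflake-snowpark-python"]
        self._deps: List[str] = []
        self._is_fitted = True

    def _get_dependencies(self) -> List[str]:
        return self._dependencies

    def _batch_inference_validate_snowpark(self, dataset: object, inference_method: str) -> None:
        # 1.5.0: validate only (fitted estimator, live session); no conda-channel
        # package-version resolution and no return value.
        if not self._is_fitted:
            raise RuntimeError(f"estimator must be fitted before calling {inference_method}")

    def _predict_snowpark(self, dataset: object) -> None:
        # 1.4.1: self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method="predict")
        # 1.5.0: validation and dependency lookup are two separate steps.
        self._batch_inference_validate_snowpark(dataset=dataset, inference_method="predict")
        self._deps = self._get_dependencies()


if __name__ == "__main__":
    t = _StubTransformer()
    t._predict_snowpark(dataset=None)
    print(t._deps)  # ['scikit-learn', 'snowflake-snowpark-python']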