snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class BisectingKMeans(BaseTransformer):
70
64
  r"""Bisecting K-Means clustering
71
65
  For more details on this class, see [sklearn.cluster.BisectingKMeans]
@@ -343,20 +337,17 @@ class BisectingKMeans(BaseTransformer):
343
337
  self,
344
338
  dataset: DataFrame,
345
339
  inference_method: str,
346
- ) -> List[str]:
347
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
348
- return the available package that exists in the snowflake anaconda channel
340
+ ) -> None:
341
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
349
342
 
350
343
  Args:
351
344
  dataset: snowpark dataframe
352
345
  inference_method: the inference method such as predict, score...
353
-
346
+
354
347
  Raises:
355
348
  SnowflakeMLException: If the estimator is not fitted, raise error
356
349
  SnowflakeMLException: If the session is None, raise error
357
350
 
358
- Returns:
359
- A list of available package that exists in the snowflake anaconda channel
360
351
  """
361
352
  if not self._is_fitted:
362
353
  raise exceptions.SnowflakeMLException(
@@ -374,9 +365,7 @@ class BisectingKMeans(BaseTransformer):
374
365
  "Session must not specified for snowpark dataset."
375
366
  ),
376
367
  )
377
- # Validate that key package version in user workspace are supported in snowflake conda channel
378
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
379
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
368
+
380
369
 
381
370
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
382
371
  @telemetry.send_api_usage_telemetry(
@@ -424,7 +413,8 @@ class BisectingKMeans(BaseTransformer):
424
413
 
425
414
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
426
415
 
427
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
416
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
417
+ self._deps = self._get_dependencies()
428
418
  assert isinstance(
429
419
  dataset._session, Session
430
420
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -509,10 +499,8 @@ class BisectingKMeans(BaseTransformer):
509
499
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
510
500
  expected_dtype = convert_sp_to_sf_type(output_types[0])
511
501
 
512
- self._deps = self._batch_inference_validate_snowpark(
513
- dataset=dataset,
514
- inference_method=inference_method,
515
- )
502
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
503
+ self._deps = self._get_dependencies()
516
504
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
517
505
 
518
506
  transform_kwargs = dict(
@@ -581,16 +569,42 @@ class BisectingKMeans(BaseTransformer):
581
569
  self._is_fitted = True
582
570
  return output_result
583
571
 
572
+
573
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
574
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
575
+ """ Compute clustering and transform X to cluster-distance space
576
+ For more details on this function, see [sklearn.cluster.BisectingKMeans.fit_transform]
577
+ (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.BisectingKMeans.html#sklearn.cluster.BisectingKMeans.fit_transform)
578
+
584
579
 
585
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
586
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
587
- """
580
+ Raises:
581
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
582
+
583
+ Args:
584
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
585
+ Snowpark or Pandas DataFrame.
586
+ output_cols_prefix: Prefix for the response columns
588
587
  Returns:
589
588
  Transformed dataset.
590
589
  """
591
- self.fit(dataset)
592
- assert self._sklearn_object is not None
593
- return self._sklearn_object.embedding_
590
+ self._infer_input_output_cols(dataset)
591
+ super()._check_dataset_type(dataset)
592
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
593
+ estimator=self._sklearn_object,
594
+ dataset=dataset,
595
+ input_cols=self.input_cols,
596
+ label_cols=self.label_cols,
597
+ sample_weight_col=self.sample_weight_col,
598
+ autogenerated=self._autogenerated,
599
+ subproject=_SUBPROJECT,
600
+ )
601
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
602
+ drop_input_cols=self._drop_input_cols,
603
+ expected_output_cols_list=self.output_cols,
604
+ )
605
+ self._sklearn_object = fitted_estimator
606
+ self._is_fitted = True
607
+ return output_result
594
608
 
595
609
 
596
610
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -681,10 +695,8 @@ class BisectingKMeans(BaseTransformer):
681
695
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
682
696
 
683
697
  if isinstance(dataset, DataFrame):
684
- self._deps = self._batch_inference_validate_snowpark(
685
- dataset=dataset,
686
- inference_method=inference_method,
687
- )
698
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
699
+ self._deps = self._get_dependencies()
688
700
  assert isinstance(
689
701
  dataset._session, Session
690
702
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -749,10 +761,8 @@ class BisectingKMeans(BaseTransformer):
749
761
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
750
762
 
751
763
  if isinstance(dataset, DataFrame):
752
- self._deps = self._batch_inference_validate_snowpark(
753
- dataset=dataset,
754
- inference_method=inference_method,
755
- )
764
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
765
+ self._deps = self._get_dependencies()
756
766
  assert isinstance(
757
767
  dataset._session, Session
758
768
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -814,10 +824,8 @@ class BisectingKMeans(BaseTransformer):
814
824
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
815
825
 
816
826
  if isinstance(dataset, DataFrame):
817
- self._deps = self._batch_inference_validate_snowpark(
818
- dataset=dataset,
819
- inference_method=inference_method,
820
- )
827
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
828
+ self._deps = self._get_dependencies()
821
829
  assert isinstance(
822
830
  dataset._session, Session
823
831
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -883,10 +891,8 @@ class BisectingKMeans(BaseTransformer):
883
891
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
884
892
 
885
893
  if isinstance(dataset, DataFrame):
886
- self._deps = self._batch_inference_validate_snowpark(
887
- dataset=dataset,
888
- inference_method=inference_method,
889
- )
894
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
895
+ self._deps = self._get_dependencies()
890
896
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
891
897
  transform_kwargs = dict(
892
898
  session=dataset._session,
@@ -950,17 +956,15 @@ class BisectingKMeans(BaseTransformer):
950
956
  transform_kwargs: ScoreKwargsTypedDict = dict()
951
957
 
952
958
  if isinstance(dataset, DataFrame):
953
- self._deps = self._batch_inference_validate_snowpark(
954
- dataset=dataset,
955
- inference_method="score",
956
- )
959
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
960
+ self._deps = self._get_dependencies()
957
961
  selected_cols = self._get_active_columns()
958
962
  if len(selected_cols) > 0:
959
963
  dataset = dataset.select(selected_cols)
960
964
  assert isinstance(dataset._session, Session) # keep mypy happy
961
965
  transform_kwargs = dict(
962
966
  session=dataset._session,
963
- dependencies=["snowflake-snowpark-python"] + self._deps,
967
+ dependencies=self._deps,
964
968
  score_sproc_imports=['sklearn'],
965
969
  )
966
970
  elif isinstance(dataset, pd.DataFrame):
@@ -1025,11 +1029,8 @@ class BisectingKMeans(BaseTransformer):
1025
1029
 
1026
1030
  if isinstance(dataset, DataFrame):
1027
1031
 
1028
- self._deps = self._batch_inference_validate_snowpark(
1029
- dataset=dataset,
1030
- inference_method=inference_method,
1031
-
1032
- )
1032
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1033
+ self._deps = self._get_dependencies()
1033
1034
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1034
1035
  transform_kwargs = dict(
1035
1036
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class DBSCAN(BaseTransformer):
70
64
  r"""Perform DBSCAN clustering from vector array or distance matrix
71
65
  For more details on this class, see [sklearn.cluster.DBSCAN]
@@ -311,20 +305,17 @@ class DBSCAN(BaseTransformer):
311
305
  self,
312
306
  dataset: DataFrame,
313
307
  inference_method: str,
314
- ) -> List[str]:
315
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
316
- return the available package that exists in the snowflake anaconda channel
308
+ ) -> None:
309
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
317
310
 
318
311
  Args:
319
312
  dataset: snowpark dataframe
320
313
  inference_method: the inference method such as predict, score...
321
-
314
+
322
315
  Raises:
323
316
  SnowflakeMLException: If the estimator is not fitted, raise error
324
317
  SnowflakeMLException: If the session is None, raise error
325
318
 
326
- Returns:
327
- A list of available package that exists in the snowflake anaconda channel
328
319
  """
329
320
  if not self._is_fitted:
330
321
  raise exceptions.SnowflakeMLException(
@@ -342,9 +333,7 @@ class DBSCAN(BaseTransformer):
342
333
  "Session must not specified for snowpark dataset."
343
334
  ),
344
335
  )
345
- # Validate that key package version in user workspace are supported in snowflake conda channel
346
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
347
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
336
+
348
337
 
349
338
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
350
339
  @telemetry.send_api_usage_telemetry(
@@ -390,7 +379,8 @@ class DBSCAN(BaseTransformer):
390
379
 
391
380
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
392
381
 
393
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
382
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
383
+ self._deps = self._get_dependencies()
394
384
  assert isinstance(
395
385
  dataset._session, Session
396
386
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -473,10 +463,8 @@ class DBSCAN(BaseTransformer):
473
463
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
474
464
  expected_dtype = convert_sp_to_sf_type(output_types[0])
475
465
 
476
- self._deps = self._batch_inference_validate_snowpark(
477
- dataset=dataset,
478
- inference_method=inference_method,
479
- )
466
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
467
+ self._deps = self._get_dependencies()
480
468
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
481
469
 
482
470
  transform_kwargs = dict(
@@ -545,16 +533,40 @@ class DBSCAN(BaseTransformer):
545
533
  self._is_fitted = True
546
534
  return output_result
547
535
 
536
+
537
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
538
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
539
+ """ Method not supported for this class.
540
+
548
541
 
549
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
550
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
551
- """
542
+ Raises:
543
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
544
+
545
+ Args:
546
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
547
+ Snowpark or Pandas DataFrame.
548
+ output_cols_prefix: Prefix for the response columns
552
549
  Returns:
553
550
  Transformed dataset.
554
551
  """
555
- self.fit(dataset)
556
- assert self._sklearn_object is not None
557
- return self._sklearn_object.embedding_
552
+ self._infer_input_output_cols(dataset)
553
+ super()._check_dataset_type(dataset)
554
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
555
+ estimator=self._sklearn_object,
556
+ dataset=dataset,
557
+ input_cols=self.input_cols,
558
+ label_cols=self.label_cols,
559
+ sample_weight_col=self.sample_weight_col,
560
+ autogenerated=self._autogenerated,
561
+ subproject=_SUBPROJECT,
562
+ )
563
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
564
+ drop_input_cols=self._drop_input_cols,
565
+ expected_output_cols_list=self.output_cols,
566
+ )
567
+ self._sklearn_object = fitted_estimator
568
+ self._is_fitted = True
569
+ return output_result
558
570
 
559
571
 
560
572
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -645,10 +657,8 @@ class DBSCAN(BaseTransformer):
645
657
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
646
658
 
647
659
  if isinstance(dataset, DataFrame):
648
- self._deps = self._batch_inference_validate_snowpark(
649
- dataset=dataset,
650
- inference_method=inference_method,
651
- )
660
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
661
+ self._deps = self._get_dependencies()
652
662
  assert isinstance(
653
663
  dataset._session, Session
654
664
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -713,10 +723,8 @@ class DBSCAN(BaseTransformer):
713
723
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
714
724
 
715
725
  if isinstance(dataset, DataFrame):
716
- self._deps = self._batch_inference_validate_snowpark(
717
- dataset=dataset,
718
- inference_method=inference_method,
719
- )
726
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
727
+ self._deps = self._get_dependencies()
720
728
  assert isinstance(
721
729
  dataset._session, Session
722
730
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -778,10 +786,8 @@ class DBSCAN(BaseTransformer):
778
786
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
779
787
 
780
788
  if isinstance(dataset, DataFrame):
781
- self._deps = self._batch_inference_validate_snowpark(
782
- dataset=dataset,
783
- inference_method=inference_method,
784
- )
789
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
790
+ self._deps = self._get_dependencies()
785
791
  assert isinstance(
786
792
  dataset._session, Session
787
793
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -847,10 +853,8 @@ class DBSCAN(BaseTransformer):
847
853
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
848
854
 
849
855
  if isinstance(dataset, DataFrame):
850
- self._deps = self._batch_inference_validate_snowpark(
851
- dataset=dataset,
852
- inference_method=inference_method,
853
- )
856
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
857
+ self._deps = self._get_dependencies()
854
858
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
855
859
  transform_kwargs = dict(
856
860
  session=dataset._session,
@@ -912,17 +916,15 @@ class DBSCAN(BaseTransformer):
912
916
  transform_kwargs: ScoreKwargsTypedDict = dict()
913
917
 
914
918
  if isinstance(dataset, DataFrame):
915
- self._deps = self._batch_inference_validate_snowpark(
916
- dataset=dataset,
917
- inference_method="score",
918
- )
919
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
920
+ self._deps = self._get_dependencies()
919
921
  selected_cols = self._get_active_columns()
920
922
  if len(selected_cols) > 0:
921
923
  dataset = dataset.select(selected_cols)
922
924
  assert isinstance(dataset._session, Session) # keep mypy happy
923
925
  transform_kwargs = dict(
924
926
  session=dataset._session,
925
- dependencies=["snowflake-snowpark-python"] + self._deps,
927
+ dependencies=self._deps,
926
928
  score_sproc_imports=['sklearn'],
927
929
  )
928
930
  elif isinstance(dataset, pd.DataFrame):
@@ -987,11 +989,8 @@ class DBSCAN(BaseTransformer):
987
989
 
988
990
  if isinstance(dataset, DataFrame):
989
991
 
990
- self._deps = self._batch_inference_validate_snowpark(
991
- dataset=dataset,
992
- inference_method=inference_method,
993
-
994
- )
992
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
993
+ self._deps = self._get_dependencies()
995
994
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
996
995
  transform_kwargs = dict(
997
996
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class FeatureAgglomeration(BaseTransformer):
70
64
  r"""Agglomerate features
71
65
  For more details on this class, see [sklearn.cluster.FeatureAgglomeration]
@@ -343,20 +337,17 @@ class FeatureAgglomeration(BaseTransformer):
343
337
  self,
344
338
  dataset: DataFrame,
345
339
  inference_method: str,
346
- ) -> List[str]:
347
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
348
- return the available package that exists in the snowflake anaconda channel
340
+ ) -> None:
341
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
349
342
 
350
343
  Args:
351
344
  dataset: snowpark dataframe
352
345
  inference_method: the inference method such as predict, score...
353
-
346
+
354
347
  Raises:
355
348
  SnowflakeMLException: If the estimator is not fitted, raise error
356
349
  SnowflakeMLException: If the session is None, raise error
357
350
 
358
- Returns:
359
- A list of available package that exists in the snowflake anaconda channel
360
351
  """
361
352
  if not self._is_fitted:
362
353
  raise exceptions.SnowflakeMLException(
@@ -374,9 +365,7 @@ class FeatureAgglomeration(BaseTransformer):
374
365
  "Session must not specified for snowpark dataset."
375
366
  ),
376
367
  )
377
- # Validate that key package version in user workspace are supported in snowflake conda channel
378
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
379
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
368
+
380
369
 
381
370
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
382
371
  @telemetry.send_api_usage_telemetry(
@@ -422,7 +411,8 @@ class FeatureAgglomeration(BaseTransformer):
422
411
 
423
412
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
424
413
 
425
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
414
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
415
+ self._deps = self._get_dependencies()
426
416
  assert isinstance(
427
417
  dataset._session, Session
428
418
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -507,10 +497,8 @@ class FeatureAgglomeration(BaseTransformer):
507
497
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
508
498
  expected_dtype = convert_sp_to_sf_type(output_types[0])
509
499
 
510
- self._deps = self._batch_inference_validate_snowpark(
511
- dataset=dataset,
512
- inference_method=inference_method,
513
- )
500
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
501
+ self._deps = self._get_dependencies()
514
502
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
515
503
 
516
504
  transform_kwargs = dict(
@@ -579,16 +567,42 @@ class FeatureAgglomeration(BaseTransformer):
579
567
  self._is_fitted = True
580
568
  return output_result
581
569
 
570
+
571
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
572
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
573
+ """ Fit to data, then transform it
574
+ For more details on this function, see [sklearn.cluster.FeatureAgglomeration.fit_transform]
575
+ (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.FeatureAgglomeration.html#sklearn.cluster.FeatureAgglomeration.fit_transform)
576
+
582
577
 
583
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
584
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
585
- """
578
+ Raises:
579
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
580
+
581
+ Args:
582
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
583
+ Snowpark or Pandas DataFrame.
584
+ output_cols_prefix: Prefix for the response columns
586
585
  Returns:
587
586
  Transformed dataset.
588
587
  """
589
- self.fit(dataset)
590
- assert self._sklearn_object is not None
591
- return self._sklearn_object.embedding_
588
+ self._infer_input_output_cols(dataset)
589
+ super()._check_dataset_type(dataset)
590
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
591
+ estimator=self._sklearn_object,
592
+ dataset=dataset,
593
+ input_cols=self.input_cols,
594
+ label_cols=self.label_cols,
595
+ sample_weight_col=self.sample_weight_col,
596
+ autogenerated=self._autogenerated,
597
+ subproject=_SUBPROJECT,
598
+ )
599
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
600
+ drop_input_cols=self._drop_input_cols,
601
+ expected_output_cols_list=self.output_cols,
602
+ )
603
+ self._sklearn_object = fitted_estimator
604
+ self._is_fitted = True
605
+ return output_result
592
606
 
593
607
 
594
608
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -679,10 +693,8 @@ class FeatureAgglomeration(BaseTransformer):
679
693
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
680
694
 
681
695
  if isinstance(dataset, DataFrame):
682
- self._deps = self._batch_inference_validate_snowpark(
683
- dataset=dataset,
684
- inference_method=inference_method,
685
- )
696
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
697
+ self._deps = self._get_dependencies()
686
698
  assert isinstance(
687
699
  dataset._session, Session
688
700
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -747,10 +759,8 @@ class FeatureAgglomeration(BaseTransformer):
747
759
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
748
760
 
749
761
  if isinstance(dataset, DataFrame):
750
- self._deps = self._batch_inference_validate_snowpark(
751
- dataset=dataset,
752
- inference_method=inference_method,
753
- )
762
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
763
+ self._deps = self._get_dependencies()
754
764
  assert isinstance(
755
765
  dataset._session, Session
756
766
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -812,10 +822,8 @@ class FeatureAgglomeration(BaseTransformer):
812
822
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
813
823
 
814
824
  if isinstance(dataset, DataFrame):
815
- self._deps = self._batch_inference_validate_snowpark(
816
- dataset=dataset,
817
- inference_method=inference_method,
818
- )
825
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
826
+ self._deps = self._get_dependencies()
819
827
  assert isinstance(
820
828
  dataset._session, Session
821
829
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -881,10 +889,8 @@ class FeatureAgglomeration(BaseTransformer):
881
889
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
882
890
 
883
891
  if isinstance(dataset, DataFrame):
884
- self._deps = self._batch_inference_validate_snowpark(
885
- dataset=dataset,
886
- inference_method=inference_method,
887
- )
892
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
893
+ self._deps = self._get_dependencies()
888
894
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
889
895
  transform_kwargs = dict(
890
896
  session=dataset._session,
@@ -946,17 +952,15 @@ class FeatureAgglomeration(BaseTransformer):
946
952
  transform_kwargs: ScoreKwargsTypedDict = dict()
947
953
 
948
954
  if isinstance(dataset, DataFrame):
949
- self._deps = self._batch_inference_validate_snowpark(
950
- dataset=dataset,
951
- inference_method="score",
952
- )
955
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
956
+ self._deps = self._get_dependencies()
953
957
  selected_cols = self._get_active_columns()
954
958
  if len(selected_cols) > 0:
955
959
  dataset = dataset.select(selected_cols)
956
960
  assert isinstance(dataset._session, Session) # keep mypy happy
957
961
  transform_kwargs = dict(
958
962
  session=dataset._session,
959
- dependencies=["snowflake-snowpark-python"] + self._deps,
963
+ dependencies=self._deps,
960
964
  score_sproc_imports=['sklearn'],
961
965
  )
962
966
  elif isinstance(dataset, pd.DataFrame):
@@ -1021,11 +1025,8 @@ class FeatureAgglomeration(BaseTransformer):
1021
1025
 
1022
1026
  if isinstance(dataset, DataFrame):
1023
1027
 
1024
- self._deps = self._batch_inference_validate_snowpark(
1025
- dataset=dataset,
1026
- inference_method=inference_method,
1027
-
1028
- )
1028
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1029
+ self._deps = self._get_dependencies()
1029
1030
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1030
1031
  transform_kwargs = dict(
1031
1032
  session = dataset._session,