snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class KMeans(BaseTransformer):
70
64
  r"""K-Means clustering
71
65
  For more details on this class, see [sklearn.cluster.KMeans]
@@ -338,20 +332,17 @@ class KMeans(BaseTransformer):
338
332
  self,
339
333
  dataset: DataFrame,
340
334
  inference_method: str,
341
- ) -> List[str]:
342
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
343
- return the available package that exists in the snowflake anaconda channel
335
+ ) -> None:
336
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
344
337
 
345
338
  Args:
346
339
  dataset: snowpark dataframe
347
340
  inference_method: the inference method such as predict, score...
348
-
341
+
349
342
  Raises:
350
343
  SnowflakeMLException: If the estimator is not fitted, raise error
351
344
  SnowflakeMLException: If the session is None, raise error
352
345
 
353
- Returns:
354
- A list of available package that exists in the snowflake anaconda channel
355
346
  """
356
347
  if not self._is_fitted:
357
348
  raise exceptions.SnowflakeMLException(
@@ -369,9 +360,7 @@ class KMeans(BaseTransformer):
369
360
  "Session must not specified for snowpark dataset."
370
361
  ),
371
362
  )
372
- # Validate that key package version in user workspace are supported in snowflake conda channel
373
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
374
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
363
+
375
364
 
376
365
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
377
366
  @telemetry.send_api_usage_telemetry(
@@ -419,7 +408,8 @@ class KMeans(BaseTransformer):
419
408
 
420
409
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
421
410
 
422
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
411
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
412
+ self._deps = self._get_dependencies()
423
413
  assert isinstance(
424
414
  dataset._session, Session
425
415
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -504,10 +494,8 @@ class KMeans(BaseTransformer):
504
494
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
505
495
  expected_dtype = convert_sp_to_sf_type(output_types[0])
506
496
 
507
- self._deps = self._batch_inference_validate_snowpark(
508
- dataset=dataset,
509
- inference_method=inference_method,
510
- )
497
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
498
+ self._deps = self._get_dependencies()
511
499
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
512
500
 
513
501
  transform_kwargs = dict(
@@ -576,16 +564,42 @@ class KMeans(BaseTransformer):
576
564
  self._is_fitted = True
577
565
  return output_result
578
566
 
567
+
568
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
569
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
570
+ """ Compute clustering and transform X to cluster-distance space
571
+ For more details on this function, see [sklearn.cluster.KMeans.fit_transform]
572
+ (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans.fit_transform)
573
+
579
574
 
580
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
581
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
582
- """
575
+ Raises:
576
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
577
+
578
+ Args:
579
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
580
+ Snowpark or Pandas DataFrame.
581
+ output_cols_prefix: Prefix for the response columns
583
582
  Returns:
584
583
  Transformed dataset.
585
584
  """
586
- self.fit(dataset)
587
- assert self._sklearn_object is not None
588
- return self._sklearn_object.embedding_
585
+ self._infer_input_output_cols(dataset)
586
+ super()._check_dataset_type(dataset)
587
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
588
+ estimator=self._sklearn_object,
589
+ dataset=dataset,
590
+ input_cols=self.input_cols,
591
+ label_cols=self.label_cols,
592
+ sample_weight_col=self.sample_weight_col,
593
+ autogenerated=self._autogenerated,
594
+ subproject=_SUBPROJECT,
595
+ )
596
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
597
+ drop_input_cols=self._drop_input_cols,
598
+ expected_output_cols_list=self.output_cols,
599
+ )
600
+ self._sklearn_object = fitted_estimator
601
+ self._is_fitted = True
602
+ return output_result
589
603
 
590
604
 
591
605
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -676,10 +690,8 @@ class KMeans(BaseTransformer):
676
690
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
677
691
 
678
692
  if isinstance(dataset, DataFrame):
679
- self._deps = self._batch_inference_validate_snowpark(
680
- dataset=dataset,
681
- inference_method=inference_method,
682
- )
693
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
694
+ self._deps = self._get_dependencies()
683
695
  assert isinstance(
684
696
  dataset._session, Session
685
697
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -744,10 +756,8 @@ class KMeans(BaseTransformer):
744
756
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
745
757
 
746
758
  if isinstance(dataset, DataFrame):
747
- self._deps = self._batch_inference_validate_snowpark(
748
- dataset=dataset,
749
- inference_method=inference_method,
750
- )
759
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
760
+ self._deps = self._get_dependencies()
751
761
  assert isinstance(
752
762
  dataset._session, Session
753
763
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -809,10 +819,8 @@ class KMeans(BaseTransformer):
809
819
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
810
820
 
811
821
  if isinstance(dataset, DataFrame):
812
- self._deps = self._batch_inference_validate_snowpark(
813
- dataset=dataset,
814
- inference_method=inference_method,
815
- )
822
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
823
+ self._deps = self._get_dependencies()
816
824
  assert isinstance(
817
825
  dataset._session, Session
818
826
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -878,10 +886,8 @@ class KMeans(BaseTransformer):
878
886
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
879
887
 
880
888
  if isinstance(dataset, DataFrame):
881
- self._deps = self._batch_inference_validate_snowpark(
882
- dataset=dataset,
883
- inference_method=inference_method,
884
- )
889
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
890
+ self._deps = self._get_dependencies()
885
891
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
886
892
  transform_kwargs = dict(
887
893
  session=dataset._session,
@@ -945,17 +951,15 @@ class KMeans(BaseTransformer):
945
951
  transform_kwargs: ScoreKwargsTypedDict = dict()
946
952
 
947
953
  if isinstance(dataset, DataFrame):
948
- self._deps = self._batch_inference_validate_snowpark(
949
- dataset=dataset,
950
- inference_method="score",
951
- )
954
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
955
+ self._deps = self._get_dependencies()
952
956
  selected_cols = self._get_active_columns()
953
957
  if len(selected_cols) > 0:
954
958
  dataset = dataset.select(selected_cols)
955
959
  assert isinstance(dataset._session, Session) # keep mypy happy
956
960
  transform_kwargs = dict(
957
961
  session=dataset._session,
958
- dependencies=["snowflake-snowpark-python"] + self._deps,
962
+ dependencies=self._deps,
959
963
  score_sproc_imports=['sklearn'],
960
964
  )
961
965
  elif isinstance(dataset, pd.DataFrame):
@@ -1020,11 +1024,8 @@ class KMeans(BaseTransformer):
1020
1024
 
1021
1025
  if isinstance(dataset, DataFrame):
1022
1026
 
1023
- self._deps = self._batch_inference_validate_snowpark(
1024
- dataset=dataset,
1025
- inference_method=inference_method,
1026
-
1027
- )
1027
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1028
+ self._deps = self._get_dependencies()
1028
1029
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1029
1030
  transform_kwargs = dict(
1030
1031
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class MeanShift(BaseTransformer):
70
64
  r"""Mean shift clustering using a flat kernel
71
65
  For more details on this class, see [sklearn.cluster.MeanShift]
@@ -314,20 +308,17 @@ class MeanShift(BaseTransformer):
314
308
  self,
315
309
  dataset: DataFrame,
316
310
  inference_method: str,
317
- ) -> List[str]:
318
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
319
- return the available package that exists in the snowflake anaconda channel
311
+ ) -> None:
312
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
320
313
 
321
314
  Args:
322
315
  dataset: snowpark dataframe
323
316
  inference_method: the inference method such as predict, score...
324
-
317
+
325
318
  Raises:
326
319
  SnowflakeMLException: If the estimator is not fitted, raise error
327
320
  SnowflakeMLException: If the session is None, raise error
328
321
 
329
- Returns:
330
- A list of available package that exists in the snowflake anaconda channel
331
322
  """
332
323
  if not self._is_fitted:
333
324
  raise exceptions.SnowflakeMLException(
@@ -345,9 +336,7 @@ class MeanShift(BaseTransformer):
345
336
  "Session must not specified for snowpark dataset."
346
337
  ),
347
338
  )
348
- # Validate that key package version in user workspace are supported in snowflake conda channel
349
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
350
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
339
+
351
340
 
352
341
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
353
342
  @telemetry.send_api_usage_telemetry(
@@ -395,7 +384,8 @@ class MeanShift(BaseTransformer):
395
384
 
396
385
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
397
386
 
398
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
387
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
388
+ self._deps = self._get_dependencies()
399
389
  assert isinstance(
400
390
  dataset._session, Session
401
391
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -478,10 +468,8 @@ class MeanShift(BaseTransformer):
478
468
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
479
469
  expected_dtype = convert_sp_to_sf_type(output_types[0])
480
470
 
481
- self._deps = self._batch_inference_validate_snowpark(
482
- dataset=dataset,
483
- inference_method=inference_method,
484
- )
471
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
472
+ self._deps = self._get_dependencies()
485
473
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
486
474
 
487
475
  transform_kwargs = dict(
@@ -550,16 +538,40 @@ class MeanShift(BaseTransformer):
550
538
  self._is_fitted = True
551
539
  return output_result
552
540
 
541
+
542
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
543
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
544
+ """ Method not supported for this class.
545
+
553
546
 
554
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
555
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
556
- """
547
+ Raises:
548
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
549
+
550
+ Args:
551
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
552
+ Snowpark or Pandas DataFrame.
553
+ output_cols_prefix: Prefix for the response columns
557
554
  Returns:
558
555
  Transformed dataset.
559
556
  """
560
- self.fit(dataset)
561
- assert self._sklearn_object is not None
562
- return self._sklearn_object.embedding_
557
+ self._infer_input_output_cols(dataset)
558
+ super()._check_dataset_type(dataset)
559
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
560
+ estimator=self._sklearn_object,
561
+ dataset=dataset,
562
+ input_cols=self.input_cols,
563
+ label_cols=self.label_cols,
564
+ sample_weight_col=self.sample_weight_col,
565
+ autogenerated=self._autogenerated,
566
+ subproject=_SUBPROJECT,
567
+ )
568
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
569
+ drop_input_cols=self._drop_input_cols,
570
+ expected_output_cols_list=self.output_cols,
571
+ )
572
+ self._sklearn_object = fitted_estimator
573
+ self._is_fitted = True
574
+ return output_result
563
575
 
564
576
 
565
577
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -650,10 +662,8 @@ class MeanShift(BaseTransformer):
650
662
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
651
663
 
652
664
  if isinstance(dataset, DataFrame):
653
- self._deps = self._batch_inference_validate_snowpark(
654
- dataset=dataset,
655
- inference_method=inference_method,
656
- )
665
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
666
+ self._deps = self._get_dependencies()
657
667
  assert isinstance(
658
668
  dataset._session, Session
659
669
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -718,10 +728,8 @@ class MeanShift(BaseTransformer):
718
728
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
719
729
 
720
730
  if isinstance(dataset, DataFrame):
721
- self._deps = self._batch_inference_validate_snowpark(
722
- dataset=dataset,
723
- inference_method=inference_method,
724
- )
731
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
732
+ self._deps = self._get_dependencies()
725
733
  assert isinstance(
726
734
  dataset._session, Session
727
735
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -783,10 +791,8 @@ class MeanShift(BaseTransformer):
783
791
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
784
792
 
785
793
  if isinstance(dataset, DataFrame):
786
- self._deps = self._batch_inference_validate_snowpark(
787
- dataset=dataset,
788
- inference_method=inference_method,
789
- )
794
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
795
+ self._deps = self._get_dependencies()
790
796
  assert isinstance(
791
797
  dataset._session, Session
792
798
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -852,10 +858,8 @@ class MeanShift(BaseTransformer):
852
858
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
853
859
 
854
860
  if isinstance(dataset, DataFrame):
855
- self._deps = self._batch_inference_validate_snowpark(
856
- dataset=dataset,
857
- inference_method=inference_method,
858
- )
861
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
862
+ self._deps = self._get_dependencies()
859
863
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
860
864
  transform_kwargs = dict(
861
865
  session=dataset._session,
@@ -917,17 +921,15 @@ class MeanShift(BaseTransformer):
917
921
  transform_kwargs: ScoreKwargsTypedDict = dict()
918
922
 
919
923
  if isinstance(dataset, DataFrame):
920
- self._deps = self._batch_inference_validate_snowpark(
921
- dataset=dataset,
922
- inference_method="score",
923
- )
924
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
925
+ self._deps = self._get_dependencies()
924
926
  selected_cols = self._get_active_columns()
925
927
  if len(selected_cols) > 0:
926
928
  dataset = dataset.select(selected_cols)
927
929
  assert isinstance(dataset._session, Session) # keep mypy happy
928
930
  transform_kwargs = dict(
929
931
  session=dataset._session,
930
- dependencies=["snowflake-snowpark-python"] + self._deps,
932
+ dependencies=self._deps,
931
933
  score_sproc_imports=['sklearn'],
932
934
  )
933
935
  elif isinstance(dataset, pd.DataFrame):
@@ -992,11 +994,8 @@ class MeanShift(BaseTransformer):
992
994
 
993
995
  if isinstance(dataset, DataFrame):
994
996
 
995
- self._deps = self._batch_inference_validate_snowpark(
996
- dataset=dataset,
997
- inference_method=inference_method,
998
-
999
- )
997
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
998
+ self._deps = self._get_dependencies()
1000
999
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1001
1000
  transform_kwargs = dict(
1002
1001
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class MiniBatchKMeans(BaseTransformer):
70
64
  r"""Mini-Batch K-Means clustering
71
65
  For more details on this class, see [sklearn.cluster.MiniBatchKMeans]
@@ -364,20 +358,17 @@ class MiniBatchKMeans(BaseTransformer):
364
358
  self,
365
359
  dataset: DataFrame,
366
360
  inference_method: str,
367
- ) -> List[str]:
368
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
369
- return the available package that exists in the snowflake anaconda channel
361
+ ) -> None:
362
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
370
363
 
371
364
  Args:
372
365
  dataset: snowpark dataframe
373
366
  inference_method: the inference method such as predict, score...
374
-
367
+
375
368
  Raises:
376
369
  SnowflakeMLException: If the estimator is not fitted, raise error
377
370
  SnowflakeMLException: If the session is None, raise error
378
371
 
379
- Returns:
380
- A list of available package that exists in the snowflake anaconda channel
381
372
  """
382
373
  if not self._is_fitted:
383
374
  raise exceptions.SnowflakeMLException(
@@ -395,9 +386,7 @@ class MiniBatchKMeans(BaseTransformer):
395
386
  "Session must not specified for snowpark dataset."
396
387
  ),
397
388
  )
398
- # Validate that key package version in user workspace are supported in snowflake conda channel
399
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
400
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
389
+
401
390
 
402
391
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
403
392
  @telemetry.send_api_usage_telemetry(
@@ -445,7 +434,8 @@ class MiniBatchKMeans(BaseTransformer):
445
434
 
446
435
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
447
436
 
448
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
437
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
438
+ self._deps = self._get_dependencies()
449
439
  assert isinstance(
450
440
  dataset._session, Session
451
441
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -530,10 +520,8 @@ class MiniBatchKMeans(BaseTransformer):
530
520
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
531
521
  expected_dtype = convert_sp_to_sf_type(output_types[0])
532
522
 
533
- self._deps = self._batch_inference_validate_snowpark(
534
- dataset=dataset,
535
- inference_method=inference_method,
536
- )
523
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
524
+ self._deps = self._get_dependencies()
537
525
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
538
526
 
539
527
  transform_kwargs = dict(
@@ -602,16 +590,42 @@ class MiniBatchKMeans(BaseTransformer):
602
590
  self._is_fitted = True
603
591
  return output_result
604
592
 
593
+
594
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
595
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
596
+ """ Compute clustering and transform X to cluster-distance space
597
+ For more details on this function, see [sklearn.cluster.MiniBatchKMeans.fit_transform]
598
+ (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MiniBatchKMeans.html#sklearn.cluster.MiniBatchKMeans.fit_transform)
599
+
605
600
 
606
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
607
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
608
- """
601
+ Raises:
602
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
603
+
604
+ Args:
605
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
606
+ Snowpark or Pandas DataFrame.
607
+ output_cols_prefix: Prefix for the response columns
609
608
  Returns:
610
609
  Transformed dataset.
611
610
  """
612
- self.fit(dataset)
613
- assert self._sklearn_object is not None
614
- return self._sklearn_object.embedding_
611
+ self._infer_input_output_cols(dataset)
612
+ super()._check_dataset_type(dataset)
613
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
614
+ estimator=self._sklearn_object,
615
+ dataset=dataset,
616
+ input_cols=self.input_cols,
617
+ label_cols=self.label_cols,
618
+ sample_weight_col=self.sample_weight_col,
619
+ autogenerated=self._autogenerated,
620
+ subproject=_SUBPROJECT,
621
+ )
622
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
623
+ drop_input_cols=self._drop_input_cols,
624
+ expected_output_cols_list=self.output_cols,
625
+ )
626
+ self._sklearn_object = fitted_estimator
627
+ self._is_fitted = True
628
+ return output_result
615
629
 
616
630
 
617
631
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -702,10 +716,8 @@ class MiniBatchKMeans(BaseTransformer):
702
716
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
703
717
 
704
718
  if isinstance(dataset, DataFrame):
705
- self._deps = self._batch_inference_validate_snowpark(
706
- dataset=dataset,
707
- inference_method=inference_method,
708
- )
719
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
720
+ self._deps = self._get_dependencies()
709
721
  assert isinstance(
710
722
  dataset._session, Session
711
723
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -770,10 +782,8 @@ class MiniBatchKMeans(BaseTransformer):
770
782
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
771
783
 
772
784
  if isinstance(dataset, DataFrame):
773
- self._deps = self._batch_inference_validate_snowpark(
774
- dataset=dataset,
775
- inference_method=inference_method,
776
- )
785
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
786
+ self._deps = self._get_dependencies()
777
787
  assert isinstance(
778
788
  dataset._session, Session
779
789
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -835,10 +845,8 @@ class MiniBatchKMeans(BaseTransformer):
835
845
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
836
846
 
837
847
  if isinstance(dataset, DataFrame):
838
- self._deps = self._batch_inference_validate_snowpark(
839
- dataset=dataset,
840
- inference_method=inference_method,
841
- )
848
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
849
+ self._deps = self._get_dependencies()
842
850
  assert isinstance(
843
851
  dataset._session, Session
844
852
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -904,10 +912,8 @@ class MiniBatchKMeans(BaseTransformer):
904
912
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
905
913
 
906
914
  if isinstance(dataset, DataFrame):
907
- self._deps = self._batch_inference_validate_snowpark(
908
- dataset=dataset,
909
- inference_method=inference_method,
910
- )
915
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
916
+ self._deps = self._get_dependencies()
911
917
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
912
918
  transform_kwargs = dict(
913
919
  session=dataset._session,
@@ -971,17 +977,15 @@ class MiniBatchKMeans(BaseTransformer):
971
977
  transform_kwargs: ScoreKwargsTypedDict = dict()
972
978
 
973
979
  if isinstance(dataset, DataFrame):
974
- self._deps = self._batch_inference_validate_snowpark(
975
- dataset=dataset,
976
- inference_method="score",
977
- )
980
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
981
+ self._deps = self._get_dependencies()
978
982
  selected_cols = self._get_active_columns()
979
983
  if len(selected_cols) > 0:
980
984
  dataset = dataset.select(selected_cols)
981
985
  assert isinstance(dataset._session, Session) # keep mypy happy
982
986
  transform_kwargs = dict(
983
987
  session=dataset._session,
984
- dependencies=["snowflake-snowpark-python"] + self._deps,
988
+ dependencies=self._deps,
985
989
  score_sproc_imports=['sklearn'],
986
990
  )
987
991
  elif isinstance(dataset, pd.DataFrame):
@@ -1046,11 +1050,8 @@ class MiniBatchKMeans(BaseTransformer):
1046
1050
 
1047
1051
  if isinstance(dataset, DataFrame):
1048
1052
 
1049
- self._deps = self._batch_inference_validate_snowpark(
1050
- dataset=dataset,
1051
- inference_method=inference_method,
1052
-
1053
- )
1053
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1054
+ self._deps = self._get_dependencies()
1054
1055
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1055
1056
  transform_kwargs = dict(
1056
1057
  session = dataset._session,