snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218) hide show
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class OPTICS(BaseTransformer):
70
64
  r"""Estimate clustering structure from vector array
71
65
  For more details on this class, see [sklearn.cluster.OPTICS]
@@ -384,20 +378,17 @@ class OPTICS(BaseTransformer):
384
378
  self,
385
379
  dataset: DataFrame,
386
380
  inference_method: str,
387
- ) -> List[str]:
388
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
389
- return the available package that exists in the snowflake anaconda channel
381
+ ) -> None:
382
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
390
383
 
391
384
  Args:
392
385
  dataset: snowpark dataframe
393
386
  inference_method: the inference method such as predict, score...
394
-
387
+
395
388
  Raises:
396
389
  SnowflakeMLException: If the estimator is not fitted, raise error
397
390
  SnowflakeMLException: If the session is None, raise error
398
391
 
399
- Returns:
400
- A list of available package that exists in the snowflake anaconda channel
401
392
  """
402
393
  if not self._is_fitted:
403
394
  raise exceptions.SnowflakeMLException(
@@ -415,9 +406,7 @@ class OPTICS(BaseTransformer):
415
406
  "Session must not specified for snowpark dataset."
416
407
  ),
417
408
  )
418
- # Validate that key package version in user workspace are supported in snowflake conda channel
419
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
420
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
409
+
421
410
 
422
411
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
423
412
  @telemetry.send_api_usage_telemetry(
@@ -463,7 +452,8 @@ class OPTICS(BaseTransformer):
463
452
 
464
453
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
465
454
 
466
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
455
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
456
+ self._deps = self._get_dependencies()
467
457
  assert isinstance(
468
458
  dataset._session, Session
469
459
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -546,10 +536,8 @@ class OPTICS(BaseTransformer):
546
536
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
547
537
  expected_dtype = convert_sp_to_sf_type(output_types[0])
548
538
 
549
- self._deps = self._batch_inference_validate_snowpark(
550
- dataset=dataset,
551
- inference_method=inference_method,
552
- )
539
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
540
+ self._deps = self._get_dependencies()
553
541
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
554
542
 
555
543
  transform_kwargs = dict(
@@ -618,16 +606,40 @@ class OPTICS(BaseTransformer):
618
606
  self._is_fitted = True
619
607
  return output_result
620
608
 
609
+
610
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
611
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
612
+ """ Method not supported for this class.
613
+
621
614
 
622
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
623
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
624
- """
615
+ Raises:
616
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
617
+
618
+ Args:
619
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
620
+ Snowpark or Pandas DataFrame.
621
+ output_cols_prefix: Prefix for the response columns
625
622
  Returns:
626
623
  Transformed dataset.
627
624
  """
628
- self.fit(dataset)
629
- assert self._sklearn_object is not None
630
- return self._sklearn_object.embedding_
625
+ self._infer_input_output_cols(dataset)
626
+ super()._check_dataset_type(dataset)
627
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
628
+ estimator=self._sklearn_object,
629
+ dataset=dataset,
630
+ input_cols=self.input_cols,
631
+ label_cols=self.label_cols,
632
+ sample_weight_col=self.sample_weight_col,
633
+ autogenerated=self._autogenerated,
634
+ subproject=_SUBPROJECT,
635
+ )
636
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
637
+ drop_input_cols=self._drop_input_cols,
638
+ expected_output_cols_list=self.output_cols,
639
+ )
640
+ self._sklearn_object = fitted_estimator
641
+ self._is_fitted = True
642
+ return output_result
631
643
 
632
644
 
633
645
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -718,10 +730,8 @@ class OPTICS(BaseTransformer):
718
730
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
719
731
 
720
732
  if isinstance(dataset, DataFrame):
721
- self._deps = self._batch_inference_validate_snowpark(
722
- dataset=dataset,
723
- inference_method=inference_method,
724
- )
733
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
734
+ self._deps = self._get_dependencies()
725
735
  assert isinstance(
726
736
  dataset._session, Session
727
737
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -786,10 +796,8 @@ class OPTICS(BaseTransformer):
786
796
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
787
797
 
788
798
  if isinstance(dataset, DataFrame):
789
- self._deps = self._batch_inference_validate_snowpark(
790
- dataset=dataset,
791
- inference_method=inference_method,
792
- )
799
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
800
+ self._deps = self._get_dependencies()
793
801
  assert isinstance(
794
802
  dataset._session, Session
795
803
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -851,10 +859,8 @@ class OPTICS(BaseTransformer):
851
859
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
852
860
 
853
861
  if isinstance(dataset, DataFrame):
854
- self._deps = self._batch_inference_validate_snowpark(
855
- dataset=dataset,
856
- inference_method=inference_method,
857
- )
862
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
863
+ self._deps = self._get_dependencies()
858
864
  assert isinstance(
859
865
  dataset._session, Session
860
866
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -920,10 +926,8 @@ class OPTICS(BaseTransformer):
920
926
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
921
927
 
922
928
  if isinstance(dataset, DataFrame):
923
- self._deps = self._batch_inference_validate_snowpark(
924
- dataset=dataset,
925
- inference_method=inference_method,
926
- )
929
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
930
+ self._deps = self._get_dependencies()
927
931
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
928
932
  transform_kwargs = dict(
929
933
  session=dataset._session,
@@ -985,17 +989,15 @@ class OPTICS(BaseTransformer):
985
989
  transform_kwargs: ScoreKwargsTypedDict = dict()
986
990
 
987
991
  if isinstance(dataset, DataFrame):
988
- self._deps = self._batch_inference_validate_snowpark(
989
- dataset=dataset,
990
- inference_method="score",
991
- )
992
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
993
+ self._deps = self._get_dependencies()
992
994
  selected_cols = self._get_active_columns()
993
995
  if len(selected_cols) > 0:
994
996
  dataset = dataset.select(selected_cols)
995
997
  assert isinstance(dataset._session, Session) # keep mypy happy
996
998
  transform_kwargs = dict(
997
999
  session=dataset._session,
998
- dependencies=["snowflake-snowpark-python"] + self._deps,
1000
+ dependencies=self._deps,
999
1001
  score_sproc_imports=['sklearn'],
1000
1002
  )
1001
1003
  elif isinstance(dataset, pd.DataFrame):
@@ -1060,11 +1062,8 @@ class OPTICS(BaseTransformer):
1060
1062
 
1061
1063
  if isinstance(dataset, DataFrame):
1062
1064
 
1063
- self._deps = self._batch_inference_validate_snowpark(
1064
- dataset=dataset,
1065
- inference_method=inference_method,
1066
-
1067
- )
1065
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1066
+ self._deps = self._get_dependencies()
1068
1067
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1069
1068
  transform_kwargs = dict(
1070
1069
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class SpectralBiclustering(BaseTransformer):
70
64
  r"""Spectral biclustering (Kluger, 2003)
71
65
  For more details on this class, see [sklearn.cluster.SpectralBiclustering]
@@ -322,20 +316,17 @@ class SpectralBiclustering(BaseTransformer):
322
316
  self,
323
317
  dataset: DataFrame,
324
318
  inference_method: str,
325
- ) -> List[str]:
326
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
327
- return the available package that exists in the snowflake anaconda channel
319
+ ) -> None:
320
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
328
321
 
329
322
  Args:
330
323
  dataset: snowpark dataframe
331
324
  inference_method: the inference method such as predict, score...
332
-
325
+
333
326
  Raises:
334
327
  SnowflakeMLException: If the estimator is not fitted, raise error
335
328
  SnowflakeMLException: If the session is None, raise error
336
329
 
337
- Returns:
338
- A list of available package that exists in the snowflake anaconda channel
339
330
  """
340
331
  if not self._is_fitted:
341
332
  raise exceptions.SnowflakeMLException(
@@ -353,9 +344,7 @@ class SpectralBiclustering(BaseTransformer):
353
344
  "Session must not specified for snowpark dataset."
354
345
  ),
355
346
  )
356
- # Validate that key package version in user workspace are supported in snowflake conda channel
357
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
358
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
347
+
359
348
 
360
349
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
361
350
  @telemetry.send_api_usage_telemetry(
@@ -401,7 +390,8 @@ class SpectralBiclustering(BaseTransformer):
401
390
 
402
391
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
403
392
 
404
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
393
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
394
+ self._deps = self._get_dependencies()
405
395
  assert isinstance(
406
396
  dataset._session, Session
407
397
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -484,10 +474,8 @@ class SpectralBiclustering(BaseTransformer):
484
474
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
485
475
  expected_dtype = convert_sp_to_sf_type(output_types[0])
486
476
 
487
- self._deps = self._batch_inference_validate_snowpark(
488
- dataset=dataset,
489
- inference_method=inference_method,
490
- )
477
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
478
+ self._deps = self._get_dependencies()
491
479
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
492
480
 
493
481
  transform_kwargs = dict(
@@ -554,16 +542,40 @@ class SpectralBiclustering(BaseTransformer):
554
542
  self._is_fitted = True
555
543
  return output_result
556
544
 
545
+
546
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
547
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
548
+ """ Method not supported for this class.
557
549
 
558
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
559
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
560
- """
550
+
551
+ Raises:
552
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
553
+
554
+ Args:
555
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
556
+ Snowpark or Pandas DataFrame.
557
+ output_cols_prefix: Prefix for the response columns
561
558
  Returns:
562
559
  Transformed dataset.
563
560
  """
564
- self.fit(dataset)
565
- assert self._sklearn_object is not None
566
- return self._sklearn_object.embedding_
561
+ self._infer_input_output_cols(dataset)
562
+ super()._check_dataset_type(dataset)
563
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
564
+ estimator=self._sklearn_object,
565
+ dataset=dataset,
566
+ input_cols=self.input_cols,
567
+ label_cols=self.label_cols,
568
+ sample_weight_col=self.sample_weight_col,
569
+ autogenerated=self._autogenerated,
570
+ subproject=_SUBPROJECT,
571
+ )
572
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
573
+ drop_input_cols=self._drop_input_cols,
574
+ expected_output_cols_list=self.output_cols,
575
+ )
576
+ self._sklearn_object = fitted_estimator
577
+ self._is_fitted = True
578
+ return output_result
567
579
 
568
580
 
569
581
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -654,10 +666,8 @@ class SpectralBiclustering(BaseTransformer):
654
666
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
655
667
 
656
668
  if isinstance(dataset, DataFrame):
657
- self._deps = self._batch_inference_validate_snowpark(
658
- dataset=dataset,
659
- inference_method=inference_method,
660
- )
669
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
670
+ self._deps = self._get_dependencies()
661
671
  assert isinstance(
662
672
  dataset._session, Session
663
673
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -722,10 +732,8 @@ class SpectralBiclustering(BaseTransformer):
722
732
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
723
733
 
724
734
  if isinstance(dataset, DataFrame):
725
- self._deps = self._batch_inference_validate_snowpark(
726
- dataset=dataset,
727
- inference_method=inference_method,
728
- )
735
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
736
+ self._deps = self._get_dependencies()
729
737
  assert isinstance(
730
738
  dataset._session, Session
731
739
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -787,10 +795,8 @@ class SpectralBiclustering(BaseTransformer):
787
795
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
788
796
 
789
797
  if isinstance(dataset, DataFrame):
790
- self._deps = self._batch_inference_validate_snowpark(
791
- dataset=dataset,
792
- inference_method=inference_method,
793
- )
798
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
799
+ self._deps = self._get_dependencies()
794
800
  assert isinstance(
795
801
  dataset._session, Session
796
802
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -856,10 +862,8 @@ class SpectralBiclustering(BaseTransformer):
856
862
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
857
863
 
858
864
  if isinstance(dataset, DataFrame):
859
- self._deps = self._batch_inference_validate_snowpark(
860
- dataset=dataset,
861
- inference_method=inference_method,
862
- )
865
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
866
+ self._deps = self._get_dependencies()
863
867
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
864
868
  transform_kwargs = dict(
865
869
  session=dataset._session,
@@ -921,17 +925,15 @@ class SpectralBiclustering(BaseTransformer):
921
925
  transform_kwargs: ScoreKwargsTypedDict = dict()
922
926
 
923
927
  if isinstance(dataset, DataFrame):
924
- self._deps = self._batch_inference_validate_snowpark(
925
- dataset=dataset,
926
- inference_method="score",
927
- )
928
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
929
+ self._deps = self._get_dependencies()
928
930
  selected_cols = self._get_active_columns()
929
931
  if len(selected_cols) > 0:
930
932
  dataset = dataset.select(selected_cols)
931
933
  assert isinstance(dataset._session, Session) # keep mypy happy
932
934
  transform_kwargs = dict(
933
935
  session=dataset._session,
934
- dependencies=["snowflake-snowpark-python"] + self._deps,
936
+ dependencies=self._deps,
935
937
  score_sproc_imports=['sklearn'],
936
938
  )
937
939
  elif isinstance(dataset, pd.DataFrame):
@@ -996,11 +998,8 @@ class SpectralBiclustering(BaseTransformer):
996
998
 
997
999
  if isinstance(dataset, DataFrame):
998
1000
 
999
- self._deps = self._batch_inference_validate_snowpark(
1000
- dataset=dataset,
1001
- inference_method=inference_method,
1002
-
1003
- )
1001
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1002
+ self._deps = self._get_dependencies()
1004
1003
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1005
1004
  transform_kwargs = dict(
1006
1005
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class SpectralClustering(BaseTransformer):
70
64
  r"""Apply clustering to a projection of the normalized Laplacian
71
65
  For more details on this class, see [sklearn.cluster.SpectralClustering]
@@ -380,20 +374,17 @@ class SpectralClustering(BaseTransformer):
380
374
  self,
381
375
  dataset: DataFrame,
382
376
  inference_method: str,
383
- ) -> List[str]:
384
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
385
- return the available package that exists in the snowflake anaconda channel
377
+ ) -> None:
378
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
386
379
 
387
380
  Args:
388
381
  dataset: snowpark dataframe
389
382
  inference_method: the inference method such as predict, score...
390
-
383
+
391
384
  Raises:
392
385
  SnowflakeMLException: If the estimator is not fitted, raise error
393
386
  SnowflakeMLException: If the session is None, raise error
394
387
 
395
- Returns:
396
- A list of available package that exists in the snowflake anaconda channel
397
388
  """
398
389
  if not self._is_fitted:
399
390
  raise exceptions.SnowflakeMLException(
@@ -411,9 +402,7 @@ class SpectralClustering(BaseTransformer):
411
402
  "Session must not specified for snowpark dataset."
412
403
  ),
413
404
  )
414
- # Validate that key package version in user workspace are supported in snowflake conda channel
415
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
416
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
405
+
417
406
 
418
407
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
419
408
  @telemetry.send_api_usage_telemetry(
@@ -459,7 +448,8 @@ class SpectralClustering(BaseTransformer):
459
448
 
460
449
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
461
450
 
462
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
451
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
452
+ self._deps = self._get_dependencies()
463
453
  assert isinstance(
464
454
  dataset._session, Session
465
455
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -542,10 +532,8 @@ class SpectralClustering(BaseTransformer):
542
532
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
543
533
  expected_dtype = convert_sp_to_sf_type(output_types[0])
544
534
 
545
- self._deps = self._batch_inference_validate_snowpark(
546
- dataset=dataset,
547
- inference_method=inference_method,
548
- )
535
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
536
+ self._deps = self._get_dependencies()
549
537
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
550
538
 
551
539
  transform_kwargs = dict(
@@ -614,16 +602,40 @@ class SpectralClustering(BaseTransformer):
614
602
  self._is_fitted = True
615
603
  return output_result
616
604
 
605
+
606
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
607
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
608
+ """ Method not supported for this class.
609
+
617
610
 
618
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
619
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
620
- """
611
+ Raises:
612
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
613
+
614
+ Args:
615
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
616
+ Snowpark or Pandas DataFrame.
617
+ output_cols_prefix: Prefix for the response columns
621
618
  Returns:
622
619
  Transformed dataset.
623
620
  """
624
- self.fit(dataset)
625
- assert self._sklearn_object is not None
626
- return self._sklearn_object.embedding_
621
+ self._infer_input_output_cols(dataset)
622
+ super()._check_dataset_type(dataset)
623
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
624
+ estimator=self._sklearn_object,
625
+ dataset=dataset,
626
+ input_cols=self.input_cols,
627
+ label_cols=self.label_cols,
628
+ sample_weight_col=self.sample_weight_col,
629
+ autogenerated=self._autogenerated,
630
+ subproject=_SUBPROJECT,
631
+ )
632
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
633
+ drop_input_cols=self._drop_input_cols,
634
+ expected_output_cols_list=self.output_cols,
635
+ )
636
+ self._sklearn_object = fitted_estimator
637
+ self._is_fitted = True
638
+ return output_result
627
639
 
628
640
 
629
641
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -714,10 +726,8 @@ class SpectralClustering(BaseTransformer):
714
726
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
715
727
 
716
728
  if isinstance(dataset, DataFrame):
717
- self._deps = self._batch_inference_validate_snowpark(
718
- dataset=dataset,
719
- inference_method=inference_method,
720
- )
729
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
730
+ self._deps = self._get_dependencies()
721
731
  assert isinstance(
722
732
  dataset._session, Session
723
733
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -782,10 +792,8 @@ class SpectralClustering(BaseTransformer):
782
792
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
783
793
 
784
794
  if isinstance(dataset, DataFrame):
785
- self._deps = self._batch_inference_validate_snowpark(
786
- dataset=dataset,
787
- inference_method=inference_method,
788
- )
795
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
796
+ self._deps = self._get_dependencies()
789
797
  assert isinstance(
790
798
  dataset._session, Session
791
799
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -847,10 +855,8 @@ class SpectralClustering(BaseTransformer):
847
855
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
848
856
 
849
857
  if isinstance(dataset, DataFrame):
850
- self._deps = self._batch_inference_validate_snowpark(
851
- dataset=dataset,
852
- inference_method=inference_method,
853
- )
858
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
859
+ self._deps = self._get_dependencies()
854
860
  assert isinstance(
855
861
  dataset._session, Session
856
862
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -916,10 +922,8 @@ class SpectralClustering(BaseTransformer):
916
922
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
917
923
 
918
924
  if isinstance(dataset, DataFrame):
919
- self._deps = self._batch_inference_validate_snowpark(
920
- dataset=dataset,
921
- inference_method=inference_method,
922
- )
925
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
926
+ self._deps = self._get_dependencies()
923
927
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
924
928
  transform_kwargs = dict(
925
929
  session=dataset._session,
@@ -981,17 +985,15 @@ class SpectralClustering(BaseTransformer):
981
985
  transform_kwargs: ScoreKwargsTypedDict = dict()
982
986
 
983
987
  if isinstance(dataset, DataFrame):
984
- self._deps = self._batch_inference_validate_snowpark(
985
- dataset=dataset,
986
- inference_method="score",
987
- )
988
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
989
+ self._deps = self._get_dependencies()
988
990
  selected_cols = self._get_active_columns()
989
991
  if len(selected_cols) > 0:
990
992
  dataset = dataset.select(selected_cols)
991
993
  assert isinstance(dataset._session, Session) # keep mypy happy
992
994
  transform_kwargs = dict(
993
995
  session=dataset._session,
994
- dependencies=["snowflake-snowpark-python"] + self._deps,
996
+ dependencies=self._deps,
995
997
  score_sproc_imports=['sklearn'],
996
998
  )
997
999
  elif isinstance(dataset, pd.DataFrame):
@@ -1056,11 +1058,8 @@ class SpectralClustering(BaseTransformer):
1056
1058
 
1057
1059
  if isinstance(dataset, DataFrame):
1058
1060
 
1059
- self._deps = self._batch_inference_validate_snowpark(
1060
- dataset=dataset,
1061
- inference_method=inference_method,
1062
-
1063
- )
1061
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1062
+ self._deps = self._get_dependencies()
1064
1063
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1065
1064
  transform_kwargs = dict(
1066
1065
  session = dataset._session,