snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.manifold".replace("sklea
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return True and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class TSNE(BaseTransformer):
70
64
  r"""T-distributed Stochastic Neighbor Embedding
71
65
  For more details on this class, see [sklearn.manifold.TSNE]
@@ -383,20 +377,17 @@ class TSNE(BaseTransformer):
383
377
  self,
384
378
  dataset: DataFrame,
385
379
  inference_method: str,
386
- ) -> List[str]:
387
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
388
- return the available package that exists in the snowflake anaconda channel
380
+ ) -> None:
381
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
389
382
 
390
383
  Args:
391
384
  dataset: snowpark dataframe
392
385
  inference_method: the inference method such as predict, score...
393
-
386
+
394
387
  Raises:
395
388
  SnowflakeMLException: If the estimator is not fitted, raise error
396
389
  SnowflakeMLException: If the session is None, raise error
397
390
 
398
- Returns:
399
- A list of available package that exists in the snowflake anaconda channel
400
391
  """
401
392
  if not self._is_fitted:
402
393
  raise exceptions.SnowflakeMLException(
@@ -414,9 +405,7 @@ class TSNE(BaseTransformer):
414
405
  "Session must not specified for snowpark dataset."
415
406
  ),
416
407
  )
417
- # Validate that key package version in user workspace are supported in snowflake conda channel
418
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
419
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
408
+
420
409
 
421
410
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
422
411
  @telemetry.send_api_usage_telemetry(
@@ -462,7 +451,8 @@ class TSNE(BaseTransformer):
462
451
 
463
452
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
464
453
 
465
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
454
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
455
+ self._deps = self._get_dependencies()
466
456
  assert isinstance(
467
457
  dataset._session, Session
468
458
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -545,10 +535,8 @@ class TSNE(BaseTransformer):
545
535
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
546
536
  expected_dtype = convert_sp_to_sf_type(output_types[0])
547
537
 
548
- self._deps = self._batch_inference_validate_snowpark(
549
- dataset=dataset,
550
- inference_method=inference_method,
551
- )
538
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
539
+ self._deps = self._get_dependencies()
552
540
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
553
541
 
554
542
  transform_kwargs = dict(
@@ -615,16 +603,42 @@ class TSNE(BaseTransformer):
615
603
  self._is_fitted = True
616
604
  return output_result
617
605
 
606
+
607
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
608
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
609
+ """ Fit X into an embedded space and return that transformed output
610
+ For more details on this function, see [sklearn.manifold.TSNE.fit_transform]
611
+ (https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html#sklearn.manifold.TSNE.fit_transform)
612
+
618
613
 
619
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
620
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
621
- """
614
+ Raises:
615
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
616
+
617
+ Args:
618
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
619
+ Snowpark or Pandas DataFrame.
620
+ output_cols_prefix: Prefix for the response columns
622
621
  Returns:
623
622
  Transformed dataset.
624
623
  """
625
- self.fit(dataset)
626
- assert self._sklearn_object is not None
627
- return self._sklearn_object.embedding_
624
+ self._infer_input_output_cols(dataset)
625
+ super()._check_dataset_type(dataset)
626
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
627
+ estimator=self._sklearn_object,
628
+ dataset=dataset,
629
+ input_cols=self.input_cols,
630
+ label_cols=self.label_cols,
631
+ sample_weight_col=self.sample_weight_col,
632
+ autogenerated=self._autogenerated,
633
+ subproject=_SUBPROJECT,
634
+ )
635
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
636
+ drop_input_cols=self._drop_input_cols,
637
+ expected_output_cols_list=self.output_cols,
638
+ )
639
+ self._sklearn_object = fitted_estimator
640
+ self._is_fitted = True
641
+ return output_result
628
642
 
629
643
 
630
644
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -715,10 +729,8 @@ class TSNE(BaseTransformer):
715
729
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
716
730
 
717
731
  if isinstance(dataset, DataFrame):
718
- self._deps = self._batch_inference_validate_snowpark(
719
- dataset=dataset,
720
- inference_method=inference_method,
721
- )
732
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
733
+ self._deps = self._get_dependencies()
722
734
  assert isinstance(
723
735
  dataset._session, Session
724
736
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -783,10 +795,8 @@ class TSNE(BaseTransformer):
783
795
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
784
796
 
785
797
  if isinstance(dataset, DataFrame):
786
- self._deps = self._batch_inference_validate_snowpark(
787
- dataset=dataset,
788
- inference_method=inference_method,
789
- )
798
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
799
+ self._deps = self._get_dependencies()
790
800
  assert isinstance(
791
801
  dataset._session, Session
792
802
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -848,10 +858,8 @@ class TSNE(BaseTransformer):
848
858
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
849
859
 
850
860
  if isinstance(dataset, DataFrame):
851
- self._deps = self._batch_inference_validate_snowpark(
852
- dataset=dataset,
853
- inference_method=inference_method,
854
- )
861
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
862
+ self._deps = self._get_dependencies()
855
863
  assert isinstance(
856
864
  dataset._session, Session
857
865
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -917,10 +925,8 @@ class TSNE(BaseTransformer):
917
925
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
918
926
 
919
927
  if isinstance(dataset, DataFrame):
920
- self._deps = self._batch_inference_validate_snowpark(
921
- dataset=dataset,
922
- inference_method=inference_method,
923
- )
928
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
929
+ self._deps = self._get_dependencies()
924
930
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
925
931
  transform_kwargs = dict(
926
932
  session=dataset._session,
@@ -982,17 +988,15 @@ class TSNE(BaseTransformer):
982
988
  transform_kwargs: ScoreKwargsTypedDict = dict()
983
989
 
984
990
  if isinstance(dataset, DataFrame):
985
- self._deps = self._batch_inference_validate_snowpark(
986
- dataset=dataset,
987
- inference_method="score",
988
- )
991
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
992
+ self._deps = self._get_dependencies()
989
993
  selected_cols = self._get_active_columns()
990
994
  if len(selected_cols) > 0:
991
995
  dataset = dataset.select(selected_cols)
992
996
  assert isinstance(dataset._session, Session) # keep mypy happy
993
997
  transform_kwargs = dict(
994
998
  session=dataset._session,
995
- dependencies=["snowflake-snowpark-python"] + self._deps,
999
+ dependencies=self._deps,
996
1000
  score_sproc_imports=['sklearn'],
997
1001
  )
998
1002
  elif isinstance(dataset, pd.DataFrame):
@@ -1057,11 +1061,8 @@ class TSNE(BaseTransformer):
1057
1061
 
1058
1062
  if isinstance(dataset, DataFrame):
1059
1063
 
1060
- self._deps = self._batch_inference_validate_snowpark(
1061
- dataset=dataset,
1062
- inference_method=inference_method,
1063
-
1064
- )
1064
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1065
+ self._deps = self._get_dependencies()
1065
1066
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1066
1067
  transform_kwargs = dict(
1067
1068
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.mixture".replace("sklear
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class BayesianGaussianMixture(BaseTransformer):
70
64
  r"""Variational Bayesian estimation of a Gaussian mixture
71
65
  For more details on this class, see [sklearn.mixture.BayesianGaussianMixture]
@@ -386,20 +380,17 @@ class BayesianGaussianMixture(BaseTransformer):
386
380
  self,
387
381
  dataset: DataFrame,
388
382
  inference_method: str,
389
- ) -> List[str]:
390
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
391
- return the available package that exists in the snowflake anaconda channel
383
+ ) -> None:
384
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
392
385
 
393
386
  Args:
394
387
  dataset: snowpark dataframe
395
388
  inference_method: the inference method such as predict, score...
396
-
389
+
397
390
  Raises:
398
391
  SnowflakeMLException: If the estimator is not fitted, raise error
399
392
  SnowflakeMLException: If the session is None, raise error
400
393
 
401
- Returns:
402
- A list of available package that exists in the snowflake anaconda channel
403
394
  """
404
395
  if not self._is_fitted:
405
396
  raise exceptions.SnowflakeMLException(
@@ -417,9 +408,7 @@ class BayesianGaussianMixture(BaseTransformer):
417
408
  "Session must not specified for snowpark dataset."
418
409
  ),
419
410
  )
420
- # Validate that key package version in user workspace are supported in snowflake conda channel
421
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
422
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
411
+
423
412
 
424
413
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
425
414
  @telemetry.send_api_usage_telemetry(
@@ -467,7 +456,8 @@ class BayesianGaussianMixture(BaseTransformer):
467
456
 
468
457
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
469
458
 
470
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
459
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
460
+ self._deps = self._get_dependencies()
471
461
  assert isinstance(
472
462
  dataset._session, Session
473
463
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -550,10 +540,8 @@ class BayesianGaussianMixture(BaseTransformer):
550
540
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
551
541
  expected_dtype = convert_sp_to_sf_type(output_types[0])
552
542
 
553
- self._deps = self._batch_inference_validate_snowpark(
554
- dataset=dataset,
555
- inference_method=inference_method,
556
- )
543
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
544
+ self._deps = self._get_dependencies()
557
545
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
558
546
 
559
547
  transform_kwargs = dict(
@@ -622,16 +610,40 @@ class BayesianGaussianMixture(BaseTransformer):
622
610
  self._is_fitted = True
623
611
  return output_result
624
612
 
613
+
614
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
615
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
616
+ """ Method not supported for this class.
617
+
625
618
 
626
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
627
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
628
- """
619
+ Raises:
620
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
621
+
622
+ Args:
623
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
624
+ Snowpark or Pandas DataFrame.
625
+ output_cols_prefix: Prefix for the response columns
629
626
  Returns:
630
627
  Transformed dataset.
631
628
  """
632
- self.fit(dataset)
633
- assert self._sklearn_object is not None
634
- return self._sklearn_object.embedding_
629
+ self._infer_input_output_cols(dataset)
630
+ super()._check_dataset_type(dataset)
631
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
632
+ estimator=self._sklearn_object,
633
+ dataset=dataset,
634
+ input_cols=self.input_cols,
635
+ label_cols=self.label_cols,
636
+ sample_weight_col=self.sample_weight_col,
637
+ autogenerated=self._autogenerated,
638
+ subproject=_SUBPROJECT,
639
+ )
640
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
641
+ drop_input_cols=self._drop_input_cols,
642
+ expected_output_cols_list=self.output_cols,
643
+ )
644
+ self._sklearn_object = fitted_estimator
645
+ self._is_fitted = True
646
+ return output_result
635
647
 
636
648
 
637
649
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -724,10 +736,8 @@ class BayesianGaussianMixture(BaseTransformer):
724
736
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
725
737
 
726
738
  if isinstance(dataset, DataFrame):
727
- self._deps = self._batch_inference_validate_snowpark(
728
- dataset=dataset,
729
- inference_method=inference_method,
730
- )
739
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
740
+ self._deps = self._get_dependencies()
731
741
  assert isinstance(
732
742
  dataset._session, Session
733
743
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -794,10 +804,8 @@ class BayesianGaussianMixture(BaseTransformer):
794
804
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
795
805
 
796
806
  if isinstance(dataset, DataFrame):
797
- self._deps = self._batch_inference_validate_snowpark(
798
- dataset=dataset,
799
- inference_method=inference_method,
800
- )
807
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
808
+ self._deps = self._get_dependencies()
801
809
  assert isinstance(
802
810
  dataset._session, Session
803
811
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -859,10 +867,8 @@ class BayesianGaussianMixture(BaseTransformer):
859
867
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
860
868
 
861
869
  if isinstance(dataset, DataFrame):
862
- self._deps = self._batch_inference_validate_snowpark(
863
- dataset=dataset,
864
- inference_method=inference_method,
865
- )
870
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
871
+ self._deps = self._get_dependencies()
866
872
  assert isinstance(
867
873
  dataset._session, Session
868
874
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -930,10 +936,8 @@ class BayesianGaussianMixture(BaseTransformer):
930
936
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
931
937
 
932
938
  if isinstance(dataset, DataFrame):
933
- self._deps = self._batch_inference_validate_snowpark(
934
- dataset=dataset,
935
- inference_method=inference_method,
936
- )
939
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
940
+ self._deps = self._get_dependencies()
937
941
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
938
942
  transform_kwargs = dict(
939
943
  session=dataset._session,
@@ -997,17 +1001,15 @@ class BayesianGaussianMixture(BaseTransformer):
997
1001
  transform_kwargs: ScoreKwargsTypedDict = dict()
998
1002
 
999
1003
  if isinstance(dataset, DataFrame):
1000
- self._deps = self._batch_inference_validate_snowpark(
1001
- dataset=dataset,
1002
- inference_method="score",
1003
- )
1004
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1005
+ self._deps = self._get_dependencies()
1004
1006
  selected_cols = self._get_active_columns()
1005
1007
  if len(selected_cols) > 0:
1006
1008
  dataset = dataset.select(selected_cols)
1007
1009
  assert isinstance(dataset._session, Session) # keep mypy happy
1008
1010
  transform_kwargs = dict(
1009
1011
  session=dataset._session,
1010
- dependencies=["snowflake-snowpark-python"] + self._deps,
1012
+ dependencies=self._deps,
1011
1013
  score_sproc_imports=['sklearn'],
1012
1014
  )
1013
1015
  elif isinstance(dataset, pd.DataFrame):
@@ -1072,11 +1074,8 @@ class BayesianGaussianMixture(BaseTransformer):
1072
1074
 
1073
1075
  if isinstance(dataset, DataFrame):
1074
1076
 
1075
- self._deps = self._batch_inference_validate_snowpark(
1076
- dataset=dataset,
1077
- inference_method=inference_method,
1078
-
1079
- )
1077
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1078
+ self._deps = self._get_dependencies()
1080
1079
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1081
1080
  transform_kwargs = dict(
1082
1081
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.mixture".replace("sklear
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class GaussianMixture(BaseTransformer):
70
64
  r"""Gaussian Mixture
71
65
  For more details on this class, see [sklearn.mixture.GaussianMixture]
@@ -359,20 +353,17 @@ class GaussianMixture(BaseTransformer):
359
353
  self,
360
354
  dataset: DataFrame,
361
355
  inference_method: str,
362
- ) -> List[str]:
363
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
364
- return the available package that exists in the snowflake anaconda channel
356
+ ) -> None:
357
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
365
358
 
366
359
  Args:
367
360
  dataset: snowpark dataframe
368
361
  inference_method: the inference method such as predict, score...
369
-
362
+
370
363
  Raises:
371
364
  SnowflakeMLException: If the estimator is not fitted, raise error
372
365
  SnowflakeMLException: If the session is None, raise error
373
366
 
374
- Returns:
375
- A list of available package that exists in the snowflake anaconda channel
376
367
  """
377
368
  if not self._is_fitted:
378
369
  raise exceptions.SnowflakeMLException(
@@ -390,9 +381,7 @@ class GaussianMixture(BaseTransformer):
390
381
  "Session must not specified for snowpark dataset."
391
382
  ),
392
383
  )
393
- # Validate that key package version in user workspace are supported in snowflake conda channel
394
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
395
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
384
+
396
385
 
397
386
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
398
387
  @telemetry.send_api_usage_telemetry(
@@ -440,7 +429,8 @@ class GaussianMixture(BaseTransformer):
440
429
 
441
430
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
442
431
 
443
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
432
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
433
+ self._deps = self._get_dependencies()
444
434
  assert isinstance(
445
435
  dataset._session, Session
446
436
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -523,10 +513,8 @@ class GaussianMixture(BaseTransformer):
523
513
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
524
514
  expected_dtype = convert_sp_to_sf_type(output_types[0])
525
515
 
526
- self._deps = self._batch_inference_validate_snowpark(
527
- dataset=dataset,
528
- inference_method=inference_method,
529
- )
516
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
517
+ self._deps = self._get_dependencies()
530
518
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
531
519
 
532
520
  transform_kwargs = dict(
@@ -595,16 +583,40 @@ class GaussianMixture(BaseTransformer):
595
583
  self._is_fitted = True
596
584
  return output_result
597
585
 
586
+
587
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
588
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
589
+ """ Method not supported for this class.
590
+
598
591
 
599
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
600
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
601
- """
592
+ Raises:
593
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
594
+
595
+ Args:
596
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
597
+ Snowpark or Pandas DataFrame.
598
+ output_cols_prefix: Prefix for the response columns
602
599
  Returns:
603
600
  Transformed dataset.
604
601
  """
605
- self.fit(dataset)
606
- assert self._sklearn_object is not None
607
- return self._sklearn_object.embedding_
602
+ self._infer_input_output_cols(dataset)
603
+ super()._check_dataset_type(dataset)
604
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
605
+ estimator=self._sklearn_object,
606
+ dataset=dataset,
607
+ input_cols=self.input_cols,
608
+ label_cols=self.label_cols,
609
+ sample_weight_col=self.sample_weight_col,
610
+ autogenerated=self._autogenerated,
611
+ subproject=_SUBPROJECT,
612
+ )
613
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
614
+ drop_input_cols=self._drop_input_cols,
615
+ expected_output_cols_list=self.output_cols,
616
+ )
617
+ self._sklearn_object = fitted_estimator
618
+ self._is_fitted = True
619
+ return output_result
608
620
 
609
621
 
610
622
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -697,10 +709,8 @@ class GaussianMixture(BaseTransformer):
697
709
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
698
710
 
699
711
  if isinstance(dataset, DataFrame):
700
- self._deps = self._batch_inference_validate_snowpark(
701
- dataset=dataset,
702
- inference_method=inference_method,
703
- )
712
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
713
+ self._deps = self._get_dependencies()
704
714
  assert isinstance(
705
715
  dataset._session, Session
706
716
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -767,10 +777,8 @@ class GaussianMixture(BaseTransformer):
767
777
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
768
778
 
769
779
  if isinstance(dataset, DataFrame):
770
- self._deps = self._batch_inference_validate_snowpark(
771
- dataset=dataset,
772
- inference_method=inference_method,
773
- )
780
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
781
+ self._deps = self._get_dependencies()
774
782
  assert isinstance(
775
783
  dataset._session, Session
776
784
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -832,10 +840,8 @@ class GaussianMixture(BaseTransformer):
832
840
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
833
841
 
834
842
  if isinstance(dataset, DataFrame):
835
- self._deps = self._batch_inference_validate_snowpark(
836
- dataset=dataset,
837
- inference_method=inference_method,
838
- )
843
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
844
+ self._deps = self._get_dependencies()
839
845
  assert isinstance(
840
846
  dataset._session, Session
841
847
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -903,10 +909,8 @@ class GaussianMixture(BaseTransformer):
903
909
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
904
910
 
905
911
  if isinstance(dataset, DataFrame):
906
- self._deps = self._batch_inference_validate_snowpark(
907
- dataset=dataset,
908
- inference_method=inference_method,
909
- )
912
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
913
+ self._deps = self._get_dependencies()
910
914
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
911
915
  transform_kwargs = dict(
912
916
  session=dataset._session,
@@ -970,17 +974,15 @@ class GaussianMixture(BaseTransformer):
970
974
  transform_kwargs: ScoreKwargsTypedDict = dict()
971
975
 
972
976
  if isinstance(dataset, DataFrame):
973
- self._deps = self._batch_inference_validate_snowpark(
974
- dataset=dataset,
975
- inference_method="score",
976
- )
977
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
978
+ self._deps = self._get_dependencies()
977
979
  selected_cols = self._get_active_columns()
978
980
  if len(selected_cols) > 0:
979
981
  dataset = dataset.select(selected_cols)
980
982
  assert isinstance(dataset._session, Session) # keep mypy happy
981
983
  transform_kwargs = dict(
982
984
  session=dataset._session,
983
- dependencies=["snowflake-snowpark-python"] + self._deps,
985
+ dependencies=self._deps,
984
986
  score_sproc_imports=['sklearn'],
985
987
  )
986
988
  elif isinstance(dataset, pd.DataFrame):
@@ -1045,11 +1047,8 @@ class GaussianMixture(BaseTransformer):
1045
1047
 
1046
1048
  if isinstance(dataset, DataFrame):
1047
1049
 
1048
- self._deps = self._batch_inference_validate_snowpark(
1049
- dataset=dataset,
1050
- inference_method=inference_method,
1051
-
1052
- )
1050
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1051
+ self._deps = self._get_dependencies()
1053
1052
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1054
1053
  transform_kwargs = dict(
1055
1054
  session = dataset._session,