snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.manifold".replace("sklea
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return True and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class Isomap(BaseTransformer):
70
64
  r"""Isomap Embedding
71
65
  For more details on this class, see [sklearn.manifold.Isomap]
@@ -339,20 +333,17 @@ class Isomap(BaseTransformer):
339
333
  self,
340
334
  dataset: DataFrame,
341
335
  inference_method: str,
342
- ) -> List[str]:
343
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
344
- return the available package that exists in the snowflake anaconda channel
336
+ ) -> None:
337
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
345
338
 
346
339
  Args:
347
340
  dataset: snowpark dataframe
348
341
  inference_method: the inference method such as predict, score...
349
-
342
+
350
343
  Raises:
351
344
  SnowflakeMLException: If the estimator is not fitted, raise error
352
345
  SnowflakeMLException: If the session is None, raise error
353
346
 
354
- Returns:
355
- A list of available package that exists in the snowflake anaconda channel
356
347
  """
357
348
  if not self._is_fitted:
358
349
  raise exceptions.SnowflakeMLException(
@@ -370,9 +361,7 @@ class Isomap(BaseTransformer):
370
361
  "Session must not specified for snowpark dataset."
371
362
  ),
372
363
  )
373
- # Validate that key package version in user workspace are supported in snowflake conda channel
374
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
375
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
364
+
376
365
 
377
366
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
378
367
  @telemetry.send_api_usage_telemetry(
@@ -418,7 +407,8 @@ class Isomap(BaseTransformer):
418
407
 
419
408
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
420
409
 
421
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
410
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
411
+ self._deps = self._get_dependencies()
422
412
  assert isinstance(
423
413
  dataset._session, Session
424
414
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -503,10 +493,8 @@ class Isomap(BaseTransformer):
503
493
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
504
494
  expected_dtype = convert_sp_to_sf_type(output_types[0])
505
495
 
506
- self._deps = self._batch_inference_validate_snowpark(
507
- dataset=dataset,
508
- inference_method=inference_method,
509
- )
496
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
497
+ self._deps = self._get_dependencies()
510
498
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
511
499
 
512
500
  transform_kwargs = dict(
@@ -573,16 +561,42 @@ class Isomap(BaseTransformer):
573
561
  self._is_fitted = True
574
562
  return output_result
575
563
 
564
+
565
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
566
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
567
+ """ Fit the model from data in X and transform X
568
+ For more details on this function, see [sklearn.manifold.Isomap.fit_transform]
569
+ (https://scikit-learn.org/stable/modules/generated/sklearn.manifold.Isomap.html#sklearn.manifold.Isomap.fit_transform)
570
+
576
571
 
577
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
578
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
579
- """
572
+ Raises:
573
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
574
+
575
+ Args:
576
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
577
+ Snowpark or Pandas DataFrame.
578
+ output_cols_prefix: Prefix for the response columns
580
579
  Returns:
581
580
  Transformed dataset.
582
581
  """
583
- self.fit(dataset)
584
- assert self._sklearn_object is not None
585
- return self._sklearn_object.embedding_
582
+ self._infer_input_output_cols(dataset)
583
+ super()._check_dataset_type(dataset)
584
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
585
+ estimator=self._sklearn_object,
586
+ dataset=dataset,
587
+ input_cols=self.input_cols,
588
+ label_cols=self.label_cols,
589
+ sample_weight_col=self.sample_weight_col,
590
+ autogenerated=self._autogenerated,
591
+ subproject=_SUBPROJECT,
592
+ )
593
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
594
+ drop_input_cols=self._drop_input_cols,
595
+ expected_output_cols_list=self.output_cols,
596
+ )
597
+ self._sklearn_object = fitted_estimator
598
+ self._is_fitted = True
599
+ return output_result
586
600
 
587
601
 
588
602
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -673,10 +687,8 @@ class Isomap(BaseTransformer):
673
687
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
674
688
 
675
689
  if isinstance(dataset, DataFrame):
676
- self._deps = self._batch_inference_validate_snowpark(
677
- dataset=dataset,
678
- inference_method=inference_method,
679
- )
690
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
691
+ self._deps = self._get_dependencies()
680
692
  assert isinstance(
681
693
  dataset._session, Session
682
694
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -741,10 +753,8 @@ class Isomap(BaseTransformer):
741
753
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
742
754
 
743
755
  if isinstance(dataset, DataFrame):
744
- self._deps = self._batch_inference_validate_snowpark(
745
- dataset=dataset,
746
- inference_method=inference_method,
747
- )
756
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
757
+ self._deps = self._get_dependencies()
748
758
  assert isinstance(
749
759
  dataset._session, Session
750
760
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -806,10 +816,8 @@ class Isomap(BaseTransformer):
806
816
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
807
817
 
808
818
  if isinstance(dataset, DataFrame):
809
- self._deps = self._batch_inference_validate_snowpark(
810
- dataset=dataset,
811
- inference_method=inference_method,
812
- )
819
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
820
+ self._deps = self._get_dependencies()
813
821
  assert isinstance(
814
822
  dataset._session, Session
815
823
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -875,10 +883,8 @@ class Isomap(BaseTransformer):
875
883
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
876
884
 
877
885
  if isinstance(dataset, DataFrame):
878
- self._deps = self._batch_inference_validate_snowpark(
879
- dataset=dataset,
880
- inference_method=inference_method,
881
- )
886
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
887
+ self._deps = self._get_dependencies()
882
888
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
883
889
  transform_kwargs = dict(
884
890
  session=dataset._session,
@@ -940,17 +946,15 @@ class Isomap(BaseTransformer):
940
946
  transform_kwargs: ScoreKwargsTypedDict = dict()
941
947
 
942
948
  if isinstance(dataset, DataFrame):
943
- self._deps = self._batch_inference_validate_snowpark(
944
- dataset=dataset,
945
- inference_method="score",
946
- )
949
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
950
+ self._deps = self._get_dependencies()
947
951
  selected_cols = self._get_active_columns()
948
952
  if len(selected_cols) > 0:
949
953
  dataset = dataset.select(selected_cols)
950
954
  assert isinstance(dataset._session, Session) # keep mypy happy
951
955
  transform_kwargs = dict(
952
956
  session=dataset._session,
953
- dependencies=["snowflake-snowpark-python"] + self._deps,
957
+ dependencies=self._deps,
954
958
  score_sproc_imports=['sklearn'],
955
959
  )
956
960
  elif isinstance(dataset, pd.DataFrame):
@@ -1015,11 +1019,8 @@ class Isomap(BaseTransformer):
1015
1019
 
1016
1020
  if isinstance(dataset, DataFrame):
1017
1021
 
1018
- self._deps = self._batch_inference_validate_snowpark(
1019
- dataset=dataset,
1020
- inference_method=inference_method,
1021
-
1022
- )
1022
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1023
+ self._deps = self._get_dependencies()
1023
1024
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1024
1025
  transform_kwargs = dict(
1025
1026
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.manifold".replace("sklea
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return True and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class MDS(BaseTransformer):
70
64
  r"""Multidimensional scaling
71
65
  For more details on this class, see [sklearn.manifold.MDS]
@@ -322,20 +316,17 @@ class MDS(BaseTransformer):
322
316
  self,
323
317
  dataset: DataFrame,
324
318
  inference_method: str,
325
- ) -> List[str]:
326
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
327
- return the available package that exists in the snowflake anaconda channel
319
+ ) -> None:
320
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
328
321
 
329
322
  Args:
330
323
  dataset: snowpark dataframe
331
324
  inference_method: the inference method such as predict, score...
332
-
325
+
333
326
  Raises:
334
327
  SnowflakeMLException: If the estimator is not fitted, raise error
335
328
  SnowflakeMLException: If the session is None, raise error
336
329
 
337
- Returns:
338
- A list of available package that exists in the snowflake anaconda channel
339
330
  """
340
331
  if not self._is_fitted:
341
332
  raise exceptions.SnowflakeMLException(
@@ -353,9 +344,7 @@ class MDS(BaseTransformer):
353
344
  "Session must not specified for snowpark dataset."
354
345
  ),
355
346
  )
356
- # Validate that key package version in user workspace are supported in snowflake conda channel
357
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
358
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
347
+
359
348
 
360
349
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
361
350
  @telemetry.send_api_usage_telemetry(
@@ -401,7 +390,8 @@ class MDS(BaseTransformer):
401
390
 
402
391
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
403
392
 
404
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
393
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
394
+ self._deps = self._get_dependencies()
405
395
  assert isinstance(
406
396
  dataset._session, Session
407
397
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -484,10 +474,8 @@ class MDS(BaseTransformer):
484
474
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
485
475
  expected_dtype = convert_sp_to_sf_type(output_types[0])
486
476
 
487
- self._deps = self._batch_inference_validate_snowpark(
488
- dataset=dataset,
489
- inference_method=inference_method,
490
- )
477
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
478
+ self._deps = self._get_dependencies()
491
479
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
492
480
 
493
481
  transform_kwargs = dict(
@@ -554,16 +542,42 @@ class MDS(BaseTransformer):
554
542
  self._is_fitted = True
555
543
  return output_result
556
544
 
545
+
546
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
547
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
548
+ """ Fit the data from `X`, and returns the embedded coordinates
549
+ For more details on this function, see [sklearn.manifold.MDS.fit_transform]
550
+ (https://scikit-learn.org/stable/modules/generated/sklearn.manifold.MDS.html#sklearn.manifold.MDS.fit_transform)
551
+
557
552
 
558
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
559
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
560
- """
553
+ Raises:
554
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
555
+
556
+ Args:
557
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
558
+ Snowpark or Pandas DataFrame.
559
+ output_cols_prefix: Prefix for the response columns
561
560
  Returns:
562
561
  Transformed dataset.
563
562
  """
564
- self.fit(dataset)
565
- assert self._sklearn_object is not None
566
- return self._sklearn_object.embedding_
563
+ self._infer_input_output_cols(dataset)
564
+ super()._check_dataset_type(dataset)
565
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
566
+ estimator=self._sklearn_object,
567
+ dataset=dataset,
568
+ input_cols=self.input_cols,
569
+ label_cols=self.label_cols,
570
+ sample_weight_col=self.sample_weight_col,
571
+ autogenerated=self._autogenerated,
572
+ subproject=_SUBPROJECT,
573
+ )
574
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
575
+ drop_input_cols=self._drop_input_cols,
576
+ expected_output_cols_list=self.output_cols,
577
+ )
578
+ self._sklearn_object = fitted_estimator
579
+ self._is_fitted = True
580
+ return output_result
567
581
 
568
582
 
569
583
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -654,10 +668,8 @@ class MDS(BaseTransformer):
654
668
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
655
669
 
656
670
  if isinstance(dataset, DataFrame):
657
- self._deps = self._batch_inference_validate_snowpark(
658
- dataset=dataset,
659
- inference_method=inference_method,
660
- )
671
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
672
+ self._deps = self._get_dependencies()
661
673
  assert isinstance(
662
674
  dataset._session, Session
663
675
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -722,10 +734,8 @@ class MDS(BaseTransformer):
722
734
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
723
735
 
724
736
  if isinstance(dataset, DataFrame):
725
- self._deps = self._batch_inference_validate_snowpark(
726
- dataset=dataset,
727
- inference_method=inference_method,
728
- )
737
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
738
+ self._deps = self._get_dependencies()
729
739
  assert isinstance(
730
740
  dataset._session, Session
731
741
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -787,10 +797,8 @@ class MDS(BaseTransformer):
787
797
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
788
798
 
789
799
  if isinstance(dataset, DataFrame):
790
- self._deps = self._batch_inference_validate_snowpark(
791
- dataset=dataset,
792
- inference_method=inference_method,
793
- )
800
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
801
+ self._deps = self._get_dependencies()
794
802
  assert isinstance(
795
803
  dataset._session, Session
796
804
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -856,10 +864,8 @@ class MDS(BaseTransformer):
856
864
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
857
865
 
858
866
  if isinstance(dataset, DataFrame):
859
- self._deps = self._batch_inference_validate_snowpark(
860
- dataset=dataset,
861
- inference_method=inference_method,
862
- )
867
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
868
+ self._deps = self._get_dependencies()
863
869
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
864
870
  transform_kwargs = dict(
865
871
  session=dataset._session,
@@ -921,17 +927,15 @@ class MDS(BaseTransformer):
921
927
  transform_kwargs: ScoreKwargsTypedDict = dict()
922
928
 
923
929
  if isinstance(dataset, DataFrame):
924
- self._deps = self._batch_inference_validate_snowpark(
925
- dataset=dataset,
926
- inference_method="score",
927
- )
930
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
931
+ self._deps = self._get_dependencies()
928
932
  selected_cols = self._get_active_columns()
929
933
  if len(selected_cols) > 0:
930
934
  dataset = dataset.select(selected_cols)
931
935
  assert isinstance(dataset._session, Session) # keep mypy happy
932
936
  transform_kwargs = dict(
933
937
  session=dataset._session,
934
- dependencies=["snowflake-snowpark-python"] + self._deps,
938
+ dependencies=self._deps,
935
939
  score_sproc_imports=['sklearn'],
936
940
  )
937
941
  elif isinstance(dataset, pd.DataFrame):
@@ -996,11 +1000,8 @@ class MDS(BaseTransformer):
996
1000
 
997
1001
  if isinstance(dataset, DataFrame):
998
1002
 
999
- self._deps = self._batch_inference_validate_snowpark(
1000
- dataset=dataset,
1001
- inference_method=inference_method,
1002
-
1003
- )
1003
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1004
+ self._deps = self._get_dependencies()
1004
1005
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1005
1006
  transform_kwargs = dict(
1006
1007
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.manifold".replace("sklea
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return True and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class SpectralEmbedding(BaseTransformer):
70
64
  r"""Spectral embedding for non-linear dimensionality reduction
71
65
  For more details on this class, see [sklearn.manifold.SpectralEmbedding]
@@ -324,20 +318,17 @@ class SpectralEmbedding(BaseTransformer):
324
318
  self,
325
319
  dataset: DataFrame,
326
320
  inference_method: str,
327
- ) -> List[str]:
328
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
329
- return the available package that exists in the snowflake anaconda channel
321
+ ) -> None:
322
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
330
323
 
331
324
  Args:
332
325
  dataset: snowpark dataframe
333
326
  inference_method: the inference method such as predict, score...
334
-
327
+
335
328
  Raises:
336
329
  SnowflakeMLException: If the estimator is not fitted, raise error
337
330
  SnowflakeMLException: If the session is None, raise error
338
331
 
339
- Returns:
340
- A list of available package that exists in the snowflake anaconda channel
341
332
  """
342
333
  if not self._is_fitted:
343
334
  raise exceptions.SnowflakeMLException(
@@ -355,9 +346,7 @@ class SpectralEmbedding(BaseTransformer):
355
346
  "Session must not specified for snowpark dataset."
356
347
  ),
357
348
  )
358
- # Validate that key package version in user workspace are supported in snowflake conda channel
359
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
360
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
349
+
361
350
 
362
351
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
363
352
  @telemetry.send_api_usage_telemetry(
@@ -403,7 +392,8 @@ class SpectralEmbedding(BaseTransformer):
403
392
 
404
393
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
405
394
 
406
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
395
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
396
+ self._deps = self._get_dependencies()
407
397
  assert isinstance(
408
398
  dataset._session, Session
409
399
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -486,10 +476,8 @@ class SpectralEmbedding(BaseTransformer):
486
476
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
487
477
  expected_dtype = convert_sp_to_sf_type(output_types[0])
488
478
 
489
- self._deps = self._batch_inference_validate_snowpark(
490
- dataset=dataset,
491
- inference_method=inference_method,
492
- )
479
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
480
+ self._deps = self._get_dependencies()
493
481
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
494
482
 
495
483
  transform_kwargs = dict(
@@ -556,16 +544,42 @@ class SpectralEmbedding(BaseTransformer):
556
544
  self._is_fitted = True
557
545
  return output_result
558
546
 
547
+
548
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
549
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
550
+ """ Fit the model from data in X and transform X
551
+ For more details on this function, see [sklearn.manifold.SpectralEmbedding.fit_transform]
552
+ (https://scikit-learn.org/stable/modules/generated/sklearn.manifold.SpectralEmbedding.html#sklearn.manifold.SpectralEmbedding.fit_transform)
553
+
559
554
 
560
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
561
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
562
- """
555
+ Raises:
556
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
557
+
558
+ Args:
559
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
560
+ Snowpark or Pandas DataFrame.
561
+ output_cols_prefix: Prefix for the response columns
563
562
  Returns:
564
563
  Transformed dataset.
565
564
  """
566
- self.fit(dataset)
567
- assert self._sklearn_object is not None
568
- return self._sklearn_object.embedding_
565
+ self._infer_input_output_cols(dataset)
566
+ super()._check_dataset_type(dataset)
567
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
568
+ estimator=self._sklearn_object,
569
+ dataset=dataset,
570
+ input_cols=self.input_cols,
571
+ label_cols=self.label_cols,
572
+ sample_weight_col=self.sample_weight_col,
573
+ autogenerated=self._autogenerated,
574
+ subproject=_SUBPROJECT,
575
+ )
576
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
577
+ drop_input_cols=self._drop_input_cols,
578
+ expected_output_cols_list=self.output_cols,
579
+ )
580
+ self._sklearn_object = fitted_estimator
581
+ self._is_fitted = True
582
+ return output_result
569
583
 
570
584
 
571
585
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -656,10 +670,8 @@ class SpectralEmbedding(BaseTransformer):
656
670
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
657
671
 
658
672
  if isinstance(dataset, DataFrame):
659
- self._deps = self._batch_inference_validate_snowpark(
660
- dataset=dataset,
661
- inference_method=inference_method,
662
- )
673
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
674
+ self._deps = self._get_dependencies()
663
675
  assert isinstance(
664
676
  dataset._session, Session
665
677
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -724,10 +736,8 @@ class SpectralEmbedding(BaseTransformer):
724
736
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
725
737
 
726
738
  if isinstance(dataset, DataFrame):
727
- self._deps = self._batch_inference_validate_snowpark(
728
- dataset=dataset,
729
- inference_method=inference_method,
730
- )
739
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
740
+ self._deps = self._get_dependencies()
731
741
  assert isinstance(
732
742
  dataset._session, Session
733
743
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -789,10 +799,8 @@ class SpectralEmbedding(BaseTransformer):
789
799
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
790
800
 
791
801
  if isinstance(dataset, DataFrame):
792
- self._deps = self._batch_inference_validate_snowpark(
793
- dataset=dataset,
794
- inference_method=inference_method,
795
- )
802
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
803
+ self._deps = self._get_dependencies()
796
804
  assert isinstance(
797
805
  dataset._session, Session
798
806
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -858,10 +866,8 @@ class SpectralEmbedding(BaseTransformer):
858
866
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
859
867
 
860
868
  if isinstance(dataset, DataFrame):
861
- self._deps = self._batch_inference_validate_snowpark(
862
- dataset=dataset,
863
- inference_method=inference_method,
864
- )
869
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
870
+ self._deps = self._get_dependencies()
865
871
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
866
872
  transform_kwargs = dict(
867
873
  session=dataset._session,
@@ -923,17 +929,15 @@ class SpectralEmbedding(BaseTransformer):
923
929
  transform_kwargs: ScoreKwargsTypedDict = dict()
924
930
 
925
931
  if isinstance(dataset, DataFrame):
926
- self._deps = self._batch_inference_validate_snowpark(
927
- dataset=dataset,
928
- inference_method="score",
929
- )
932
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
933
+ self._deps = self._get_dependencies()
930
934
  selected_cols = self._get_active_columns()
931
935
  if len(selected_cols) > 0:
932
936
  dataset = dataset.select(selected_cols)
933
937
  assert isinstance(dataset._session, Session) # keep mypy happy
934
938
  transform_kwargs = dict(
935
939
  session=dataset._session,
936
- dependencies=["snowflake-snowpark-python"] + self._deps,
940
+ dependencies=self._deps,
937
941
  score_sproc_imports=['sklearn'],
938
942
  )
939
943
  elif isinstance(dataset, pd.DataFrame):
@@ -998,11 +1002,8 @@ class SpectralEmbedding(BaseTransformer):
998
1002
 
999
1003
  if isinstance(dataset, DataFrame):
1000
1004
 
1001
- self._deps = self._batch_inference_validate_snowpark(
1002
- dataset=dataset,
1003
- inference_method=inference_method,
1004
-
1005
- )
1005
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1006
+ self._deps = self._get_dependencies()
1006
1007
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1007
1008
  transform_kwargs = dict(
1008
1009
  session = dataset._session,