snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206):
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.gaussian_process".replac
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class GaussianProcessClassifier(BaseTransformer):
70
64
  r"""Gaussian process classification (GPC) based on Laplace approximation
71
65
  For more details on this class, see [sklearn.gaussian_process.GaussianProcessClassifier]
@@ -352,20 +346,17 @@ class GaussianProcessClassifier(BaseTransformer):
352
346
  self,
353
347
  dataset: DataFrame,
354
348
  inference_method: str,
355
- ) -> List[str]:
356
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
357
- return the available package that exists in the snowflake anaconda channel
349
+ ) -> None:
350
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
358
351
 
359
352
  Args:
360
353
  dataset: snowpark dataframe
361
354
  inference_method: the inference method such as predict, score...
362
-
355
+
363
356
  Raises:
364
357
  SnowflakeMLException: If the estimator is not fitted, raise error
365
358
  SnowflakeMLException: If the session is None, raise error
366
359
 
367
- Returns:
368
- A list of available package that exists in the snowflake anaconda channel
369
360
  """
370
361
  if not self._is_fitted:
371
362
  raise exceptions.SnowflakeMLException(
@@ -383,9 +374,7 @@ class GaussianProcessClassifier(BaseTransformer):
383
374
  "Session must not specified for snowpark dataset."
384
375
  ),
385
376
  )
386
- # Validate that key package version in user workspace are supported in snowflake conda channel
387
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
388
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
377
+
389
378
 
390
379
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
391
380
  @telemetry.send_api_usage_telemetry(
@@ -433,7 +422,8 @@ class GaussianProcessClassifier(BaseTransformer):
433
422
 
434
423
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
435
424
 
436
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
425
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
426
+ self._deps = self._get_dependencies()
437
427
  assert isinstance(
438
428
  dataset._session, Session
439
429
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -516,10 +506,8 @@ class GaussianProcessClassifier(BaseTransformer):
516
506
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
517
507
  expected_dtype = convert_sp_to_sf_type(output_types[0])
518
508
 
519
- self._deps = self._batch_inference_validate_snowpark(
520
- dataset=dataset,
521
- inference_method=inference_method,
522
- )
509
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
510
+ self._deps = self._get_dependencies()
523
511
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
524
512
 
525
513
  transform_kwargs = dict(
@@ -586,16 +574,40 @@ class GaussianProcessClassifier(BaseTransformer):
586
574
  self._is_fitted = True
587
575
  return output_result
588
576
 
577
+
578
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
579
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
580
+ """ Method not supported for this class.
589
581
 
590
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
591
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
592
- """
582
+
583
+ Raises:
584
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
585
+
586
+ Args:
587
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
588
+ Snowpark or Pandas DataFrame.
589
+ output_cols_prefix: Prefix for the response columns
593
590
  Returns:
594
591
  Transformed dataset.
595
592
  """
596
- self.fit(dataset)
597
- assert self._sklearn_object is not None
598
- return self._sklearn_object.embedding_
593
+ self._infer_input_output_cols(dataset)
594
+ super()._check_dataset_type(dataset)
595
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
596
+ estimator=self._sklearn_object,
597
+ dataset=dataset,
598
+ input_cols=self.input_cols,
599
+ label_cols=self.label_cols,
600
+ sample_weight_col=self.sample_weight_col,
601
+ autogenerated=self._autogenerated,
602
+ subproject=_SUBPROJECT,
603
+ )
604
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
605
+ drop_input_cols=self._drop_input_cols,
606
+ expected_output_cols_list=self.output_cols,
607
+ )
608
+ self._sklearn_object = fitted_estimator
609
+ self._is_fitted = True
610
+ return output_result
599
611
 
600
612
 
601
613
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -688,10 +700,8 @@ class GaussianProcessClassifier(BaseTransformer):
688
700
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
689
701
 
690
702
  if isinstance(dataset, DataFrame):
691
- self._deps = self._batch_inference_validate_snowpark(
692
- dataset=dataset,
693
- inference_method=inference_method,
694
- )
703
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
704
+ self._deps = self._get_dependencies()
695
705
  assert isinstance(
696
706
  dataset._session, Session
697
707
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -758,10 +768,8 @@ class GaussianProcessClassifier(BaseTransformer):
758
768
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
759
769
 
760
770
  if isinstance(dataset, DataFrame):
761
- self._deps = self._batch_inference_validate_snowpark(
762
- dataset=dataset,
763
- inference_method=inference_method,
764
- )
771
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
772
+ self._deps = self._get_dependencies()
765
773
  assert isinstance(
766
774
  dataset._session, Session
767
775
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -823,10 +831,8 @@ class GaussianProcessClassifier(BaseTransformer):
823
831
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
824
832
 
825
833
  if isinstance(dataset, DataFrame):
826
- self._deps = self._batch_inference_validate_snowpark(
827
- dataset=dataset,
828
- inference_method=inference_method,
829
- )
834
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
835
+ self._deps = self._get_dependencies()
830
836
  assert isinstance(
831
837
  dataset._session, Session
832
838
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -892,10 +898,8 @@ class GaussianProcessClassifier(BaseTransformer):
892
898
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
893
899
 
894
900
  if isinstance(dataset, DataFrame):
895
- self._deps = self._batch_inference_validate_snowpark(
896
- dataset=dataset,
897
- inference_method=inference_method,
898
- )
901
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
902
+ self._deps = self._get_dependencies()
899
903
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
900
904
  transform_kwargs = dict(
901
905
  session=dataset._session,
@@ -959,17 +963,15 @@ class GaussianProcessClassifier(BaseTransformer):
959
963
  transform_kwargs: ScoreKwargsTypedDict = dict()
960
964
 
961
965
  if isinstance(dataset, DataFrame):
962
- self._deps = self._batch_inference_validate_snowpark(
963
- dataset=dataset,
964
- inference_method="score",
965
- )
966
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
967
+ self._deps = self._get_dependencies()
966
968
  selected_cols = self._get_active_columns()
967
969
  if len(selected_cols) > 0:
968
970
  dataset = dataset.select(selected_cols)
969
971
  assert isinstance(dataset._session, Session) # keep mypy happy
970
972
  transform_kwargs = dict(
971
973
  session=dataset._session,
972
- dependencies=["snowflake-snowpark-python"] + self._deps,
974
+ dependencies=self._deps,
973
975
  score_sproc_imports=['sklearn'],
974
976
  )
975
977
  elif isinstance(dataset, pd.DataFrame):
@@ -1034,11 +1036,8 @@ class GaussianProcessClassifier(BaseTransformer):
1034
1036
 
1035
1037
  if isinstance(dataset, DataFrame):
1036
1038
 
1037
- self._deps = self._batch_inference_validate_snowpark(
1038
- dataset=dataset,
1039
- inference_method=inference_method,
1040
-
1041
- )
1039
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1040
+ self._deps = self._get_dependencies()
1042
1041
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1043
1042
  transform_kwargs = dict(
1044
1043
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.gaussian_process".replac
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class GaussianProcessRegressor(BaseTransformer):
70
64
  r"""Gaussian process regression (GPR)
71
65
  For more details on this class, see [sklearn.gaussian_process.GaussianProcessRegressor]
@@ -343,20 +337,17 @@ class GaussianProcessRegressor(BaseTransformer):
343
337
  self,
344
338
  dataset: DataFrame,
345
339
  inference_method: str,
346
- ) -> List[str]:
347
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
348
- return the available package that exists in the snowflake anaconda channel
340
+ ) -> None:
341
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
349
342
 
350
343
  Args:
351
344
  dataset: snowpark dataframe
352
345
  inference_method: the inference method such as predict, score...
353
-
346
+
354
347
  Raises:
355
348
  SnowflakeMLException: If the estimator is not fitted, raise error
356
349
  SnowflakeMLException: If the session is None, raise error
357
350
 
358
- Returns:
359
- A list of available package that exists in the snowflake anaconda channel
360
351
  """
361
352
  if not self._is_fitted:
362
353
  raise exceptions.SnowflakeMLException(
@@ -374,9 +365,7 @@ class GaussianProcessRegressor(BaseTransformer):
374
365
  "Session must not specified for snowpark dataset."
375
366
  ),
376
367
  )
377
- # Validate that key package version in user workspace are supported in snowflake conda channel
378
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
379
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
368
+
380
369
 
381
370
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
382
371
  @telemetry.send_api_usage_telemetry(
@@ -424,7 +413,8 @@ class GaussianProcessRegressor(BaseTransformer):
424
413
 
425
414
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
426
415
 
427
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
416
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
417
+ self._deps = self._get_dependencies()
428
418
  assert isinstance(
429
419
  dataset._session, Session
430
420
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -507,10 +497,8 @@ class GaussianProcessRegressor(BaseTransformer):
507
497
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
508
498
  expected_dtype = convert_sp_to_sf_type(output_types[0])
509
499
 
510
- self._deps = self._batch_inference_validate_snowpark(
511
- dataset=dataset,
512
- inference_method=inference_method,
513
- )
500
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
501
+ self._deps = self._get_dependencies()
514
502
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
515
503
 
516
504
  transform_kwargs = dict(
@@ -577,16 +565,40 @@ class GaussianProcessRegressor(BaseTransformer):
577
565
  self._is_fitted = True
578
566
  return output_result
579
567
 
568
+
569
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
570
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
571
+ """ Method not supported for this class.
580
572
 
581
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
582
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
583
- """
573
+
574
+ Raises:
575
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
576
+
577
+ Args:
578
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
579
+ Snowpark or Pandas DataFrame.
580
+ output_cols_prefix: Prefix for the response columns
584
581
  Returns:
585
582
  Transformed dataset.
586
583
  """
587
- self.fit(dataset)
588
- assert self._sklearn_object is not None
589
- return self._sklearn_object.embedding_
584
+ self._infer_input_output_cols(dataset)
585
+ super()._check_dataset_type(dataset)
586
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
587
+ estimator=self._sklearn_object,
588
+ dataset=dataset,
589
+ input_cols=self.input_cols,
590
+ label_cols=self.label_cols,
591
+ sample_weight_col=self.sample_weight_col,
592
+ autogenerated=self._autogenerated,
593
+ subproject=_SUBPROJECT,
594
+ )
595
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
596
+ drop_input_cols=self._drop_input_cols,
597
+ expected_output_cols_list=self.output_cols,
598
+ )
599
+ self._sklearn_object = fitted_estimator
600
+ self._is_fitted = True
601
+ return output_result
590
602
 
591
603
 
592
604
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -677,10 +689,8 @@ class GaussianProcessRegressor(BaseTransformer):
677
689
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
678
690
 
679
691
  if isinstance(dataset, DataFrame):
680
- self._deps = self._batch_inference_validate_snowpark(
681
- dataset=dataset,
682
- inference_method=inference_method,
683
- )
692
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
693
+ self._deps = self._get_dependencies()
684
694
  assert isinstance(
685
695
  dataset._session, Session
686
696
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -745,10 +755,8 @@ class GaussianProcessRegressor(BaseTransformer):
745
755
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
746
756
 
747
757
  if isinstance(dataset, DataFrame):
748
- self._deps = self._batch_inference_validate_snowpark(
749
- dataset=dataset,
750
- inference_method=inference_method,
751
- )
758
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
759
+ self._deps = self._get_dependencies()
752
760
  assert isinstance(
753
761
  dataset._session, Session
754
762
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -810,10 +818,8 @@ class GaussianProcessRegressor(BaseTransformer):
810
818
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
811
819
 
812
820
  if isinstance(dataset, DataFrame):
813
- self._deps = self._batch_inference_validate_snowpark(
814
- dataset=dataset,
815
- inference_method=inference_method,
816
- )
821
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
822
+ self._deps = self._get_dependencies()
817
823
  assert isinstance(
818
824
  dataset._session, Session
819
825
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -879,10 +885,8 @@ class GaussianProcessRegressor(BaseTransformer):
879
885
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
880
886
 
881
887
  if isinstance(dataset, DataFrame):
882
- self._deps = self._batch_inference_validate_snowpark(
883
- dataset=dataset,
884
- inference_method=inference_method,
885
- )
888
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
889
+ self._deps = self._get_dependencies()
886
890
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
887
891
  transform_kwargs = dict(
888
892
  session=dataset._session,
@@ -946,17 +950,15 @@ class GaussianProcessRegressor(BaseTransformer):
946
950
  transform_kwargs: ScoreKwargsTypedDict = dict()
947
951
 
948
952
  if isinstance(dataset, DataFrame):
949
- self._deps = self._batch_inference_validate_snowpark(
950
- dataset=dataset,
951
- inference_method="score",
952
- )
953
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
954
+ self._deps = self._get_dependencies()
953
955
  selected_cols = self._get_active_columns()
954
956
  if len(selected_cols) > 0:
955
957
  dataset = dataset.select(selected_cols)
956
958
  assert isinstance(dataset._session, Session) # keep mypy happy
957
959
  transform_kwargs = dict(
958
960
  session=dataset._session,
959
- dependencies=["snowflake-snowpark-python"] + self._deps,
961
+ dependencies=self._deps,
960
962
  score_sproc_imports=['sklearn'],
961
963
  )
962
964
  elif isinstance(dataset, pd.DataFrame):
@@ -1021,11 +1023,8 @@ class GaussianProcessRegressor(BaseTransformer):
1021
1023
 
1022
1024
  if isinstance(dataset, DataFrame):
1023
1025
 
1024
- self._deps = self._batch_inference_validate_snowpark(
1025
- dataset=dataset,
1026
- inference_method=inference_method,
1027
-
1028
- )
1026
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1027
+ self._deps = self._get_dependencies()
1029
1028
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1030
1029
  transform_kwargs = dict(
1031
1030
  session = dataset._session,
@@ -61,12 +61,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.impute".replace("sklearn
61
61
 
62
62
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
63
63
 
64
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
65
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
66
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
67
- return check
68
-
69
-
70
64
  class IterativeImputer(BaseTransformer):
71
65
  r"""Multivariate imputer that estimates each feature from all the others
72
66
  For more details on this class, see [sklearn.impute.IterativeImputer]
@@ -385,20 +379,17 @@ class IterativeImputer(BaseTransformer):
385
379
  self,
386
380
  dataset: DataFrame,
387
381
  inference_method: str,
388
- ) -> List[str]:
389
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
390
- return the available package that exists in the snowflake anaconda channel
382
+ ) -> None:
383
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
391
384
 
392
385
  Args:
393
386
  dataset: snowpark dataframe
394
387
  inference_method: the inference method such as predict, score...
395
-
388
+
396
389
  Raises:
397
390
  SnowflakeMLException: If the estimator is not fitted, raise error
398
391
  SnowflakeMLException: If the session is None, raise error
399
392
 
400
- Returns:
401
- A list of available package that exists in the snowflake anaconda channel
402
393
  """
403
394
  if not self._is_fitted:
404
395
  raise exceptions.SnowflakeMLException(
@@ -416,9 +407,7 @@ class IterativeImputer(BaseTransformer):
416
407
  "Session must not specified for snowpark dataset."
417
408
  ),
418
409
  )
419
- # Validate that key package version in user workspace are supported in snowflake conda channel
420
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
421
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
410
+
422
411
 
423
412
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
424
413
  @telemetry.send_api_usage_telemetry(
@@ -464,7 +453,8 @@ class IterativeImputer(BaseTransformer):
464
453
 
465
454
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
466
455
 
467
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
456
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
457
+ self._deps = self._get_dependencies()
468
458
  assert isinstance(
469
459
  dataset._session, Session
470
460
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -549,10 +539,8 @@ class IterativeImputer(BaseTransformer):
549
539
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
550
540
  expected_dtype = convert_sp_to_sf_type(output_types[0])
551
541
 
552
- self._deps = self._batch_inference_validate_snowpark(
553
- dataset=dataset,
554
- inference_method=inference_method,
555
- )
542
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
543
+ self._deps = self._get_dependencies()
556
544
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
557
545
 
558
546
  transform_kwargs = dict(
@@ -619,16 +607,42 @@ class IterativeImputer(BaseTransformer):
619
607
  self._is_fitted = True
620
608
  return output_result
621
609
 
610
+
611
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
612
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
613
+ """ Fit the imputer on `X` and return the transformed `X`
614
+ For more details on this function, see [sklearn.impute.IterativeImputer.fit_transform]
615
+ (https://scikit-learn.org/stable/modules/generated/sklearn.impute.IterativeImputer.html#sklearn.impute.IterativeImputer.fit_transform)
616
+
622
617
 
623
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
624
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
625
- """
618
+ Raises:
619
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
620
+
621
+ Args:
622
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
623
+ Snowpark or Pandas DataFrame.
624
+ output_cols_prefix: Prefix for the response columns
626
625
  Returns:
627
626
  Transformed dataset.
628
627
  """
629
- self.fit(dataset)
630
- assert self._sklearn_object is not None
631
- return self._sklearn_object.embedding_
628
+ self._infer_input_output_cols(dataset)
629
+ super()._check_dataset_type(dataset)
630
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
631
+ estimator=self._sklearn_object,
632
+ dataset=dataset,
633
+ input_cols=self.input_cols,
634
+ label_cols=self.label_cols,
635
+ sample_weight_col=self.sample_weight_col,
636
+ autogenerated=self._autogenerated,
637
+ subproject=_SUBPROJECT,
638
+ )
639
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
640
+ drop_input_cols=self._drop_input_cols,
641
+ expected_output_cols_list=self.output_cols,
642
+ )
643
+ self._sklearn_object = fitted_estimator
644
+ self._is_fitted = True
645
+ return output_result
632
646
 
633
647
 
634
648
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -719,10 +733,8 @@ class IterativeImputer(BaseTransformer):
719
733
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
720
734
 
721
735
  if isinstance(dataset, DataFrame):
722
- self._deps = self._batch_inference_validate_snowpark(
723
- dataset=dataset,
724
- inference_method=inference_method,
725
- )
736
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
737
+ self._deps = self._get_dependencies()
726
738
  assert isinstance(
727
739
  dataset._session, Session
728
740
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -787,10 +799,8 @@ class IterativeImputer(BaseTransformer):
787
799
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
788
800
 
789
801
  if isinstance(dataset, DataFrame):
790
- self._deps = self._batch_inference_validate_snowpark(
791
- dataset=dataset,
792
- inference_method=inference_method,
793
- )
802
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
803
+ self._deps = self._get_dependencies()
794
804
  assert isinstance(
795
805
  dataset._session, Session
796
806
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -852,10 +862,8 @@ class IterativeImputer(BaseTransformer):
852
862
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
853
863
 
854
864
  if isinstance(dataset, DataFrame):
855
- self._deps = self._batch_inference_validate_snowpark(
856
- dataset=dataset,
857
- inference_method=inference_method,
858
- )
865
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
866
+ self._deps = self._get_dependencies()
859
867
  assert isinstance(
860
868
  dataset._session, Session
861
869
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -921,10 +929,8 @@ class IterativeImputer(BaseTransformer):
921
929
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
922
930
 
923
931
  if isinstance(dataset, DataFrame):
924
- self._deps = self._batch_inference_validate_snowpark(
925
- dataset=dataset,
926
- inference_method=inference_method,
927
- )
932
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
933
+ self._deps = self._get_dependencies()
928
934
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
929
935
  transform_kwargs = dict(
930
936
  session=dataset._session,
@@ -986,17 +992,15 @@ class IterativeImputer(BaseTransformer):
986
992
  transform_kwargs: ScoreKwargsTypedDict = dict()
987
993
 
988
994
  if isinstance(dataset, DataFrame):
989
- self._deps = self._batch_inference_validate_snowpark(
990
- dataset=dataset,
991
- inference_method="score",
992
- )
995
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
996
+ self._deps = self._get_dependencies()
993
997
  selected_cols = self._get_active_columns()
994
998
  if len(selected_cols) > 0:
995
999
  dataset = dataset.select(selected_cols)
996
1000
  assert isinstance(dataset._session, Session) # keep mypy happy
997
1001
  transform_kwargs = dict(
998
1002
  session=dataset._session,
999
- dependencies=["snowflake-snowpark-python"] + self._deps,
1003
+ dependencies=self._deps,
1000
1004
  score_sproc_imports=['sklearn'],
1001
1005
  )
1002
1006
  elif isinstance(dataset, pd.DataFrame):
@@ -1061,11 +1065,8 @@ class IterativeImputer(BaseTransformer):
1061
1065
 
1062
1066
  if isinstance(dataset, DataFrame):
1063
1067
 
1064
- self._deps = self._batch_inference_validate_snowpark(
1065
- dataset=dataset,
1066
- inference_method=inference_method,
1067
-
1068
- )
1068
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1069
+ self._deps = self._get_dependencies()
1069
1070
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1070
1071
  transform_kwargs = dict(
1071
1072
  session = dataset._session,