snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206):
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class AdaBoostRegressor(BaseTransformer):
70
64
  r"""An AdaBoost regressor
71
65
  For more details on this class, see [sklearn.ensemble.AdaBoostRegressor]
@@ -302,20 +296,17 @@ class AdaBoostRegressor(BaseTransformer):
302
296
  self,
303
297
  dataset: DataFrame,
304
298
  inference_method: str,
305
- ) -> List[str]:
306
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
307
- return the available package that exists in the snowflake anaconda channel
299
+ ) -> None:
300
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
308
301
 
309
302
  Args:
310
303
  dataset: snowpark dataframe
311
304
  inference_method: the inference method such as predict, score...
312
-
305
+
313
306
  Raises:
314
307
  SnowflakeMLException: If the estimator is not fitted, raise error
315
308
  SnowflakeMLException: If the session is None, raise error
316
309
 
317
- Returns:
318
- A list of available package that exists in the snowflake anaconda channel
319
310
  """
320
311
  if not self._is_fitted:
321
312
  raise exceptions.SnowflakeMLException(
@@ -333,9 +324,7 @@ class AdaBoostRegressor(BaseTransformer):
333
324
  "Session must not specified for snowpark dataset."
334
325
  ),
335
326
  )
336
- # Validate that key package version in user workspace are supported in snowflake conda channel
337
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
338
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
327
+
339
328
 
340
329
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
341
330
  @telemetry.send_api_usage_telemetry(
@@ -383,7 +372,8 @@ class AdaBoostRegressor(BaseTransformer):
383
372
 
384
373
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
385
374
 
386
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
375
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
376
+ self._deps = self._get_dependencies()
387
377
  assert isinstance(
388
378
  dataset._session, Session
389
379
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -466,10 +456,8 @@ class AdaBoostRegressor(BaseTransformer):
466
456
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
467
457
  expected_dtype = convert_sp_to_sf_type(output_types[0])
468
458
 
469
- self._deps = self._batch_inference_validate_snowpark(
470
- dataset=dataset,
471
- inference_method=inference_method,
472
- )
459
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
460
+ self._deps = self._get_dependencies()
473
461
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
474
462
 
475
463
  transform_kwargs = dict(
@@ -536,16 +524,40 @@ class AdaBoostRegressor(BaseTransformer):
536
524
  self._is_fitted = True
537
525
  return output_result
538
526
 
527
+
528
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
529
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
530
+ """ Method not supported for this class.
539
531
 
540
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
541
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
542
- """
532
+
533
+ Raises:
534
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
535
+
536
+ Args:
537
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
538
+ Snowpark or Pandas DataFrame.
539
+ output_cols_prefix: Prefix for the response columns
543
540
  Returns:
544
541
  Transformed dataset.
545
542
  """
546
- self.fit(dataset)
547
- assert self._sklearn_object is not None
548
- return self._sklearn_object.embedding_
543
+ self._infer_input_output_cols(dataset)
544
+ super()._check_dataset_type(dataset)
545
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
546
+ estimator=self._sklearn_object,
547
+ dataset=dataset,
548
+ input_cols=self.input_cols,
549
+ label_cols=self.label_cols,
550
+ sample_weight_col=self.sample_weight_col,
551
+ autogenerated=self._autogenerated,
552
+ subproject=_SUBPROJECT,
553
+ )
554
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
555
+ drop_input_cols=self._drop_input_cols,
556
+ expected_output_cols_list=self.output_cols,
557
+ )
558
+ self._sklearn_object = fitted_estimator
559
+ self._is_fitted = True
560
+ return output_result
549
561
 
550
562
 
551
563
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -636,10 +648,8 @@ class AdaBoostRegressor(BaseTransformer):
636
648
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
637
649
 
638
650
  if isinstance(dataset, DataFrame):
639
- self._deps = self._batch_inference_validate_snowpark(
640
- dataset=dataset,
641
- inference_method=inference_method,
642
- )
651
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
652
+ self._deps = self._get_dependencies()
643
653
  assert isinstance(
644
654
  dataset._session, Session
645
655
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -704,10 +714,8 @@ class AdaBoostRegressor(BaseTransformer):
704
714
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
705
715
 
706
716
  if isinstance(dataset, DataFrame):
707
- self._deps = self._batch_inference_validate_snowpark(
708
- dataset=dataset,
709
- inference_method=inference_method,
710
- )
717
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
718
+ self._deps = self._get_dependencies()
711
719
  assert isinstance(
712
720
  dataset._session, Session
713
721
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -769,10 +777,8 @@ class AdaBoostRegressor(BaseTransformer):
769
777
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
770
778
 
771
779
  if isinstance(dataset, DataFrame):
772
- self._deps = self._batch_inference_validate_snowpark(
773
- dataset=dataset,
774
- inference_method=inference_method,
775
- )
780
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
781
+ self._deps = self._get_dependencies()
776
782
  assert isinstance(
777
783
  dataset._session, Session
778
784
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -838,10 +844,8 @@ class AdaBoostRegressor(BaseTransformer):
838
844
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
839
845
 
840
846
  if isinstance(dataset, DataFrame):
841
- self._deps = self._batch_inference_validate_snowpark(
842
- dataset=dataset,
843
- inference_method=inference_method,
844
- )
847
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
848
+ self._deps = self._get_dependencies()
845
849
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
846
850
  transform_kwargs = dict(
847
851
  session=dataset._session,
@@ -905,17 +909,15 @@ class AdaBoostRegressor(BaseTransformer):
905
909
  transform_kwargs: ScoreKwargsTypedDict = dict()
906
910
 
907
911
  if isinstance(dataset, DataFrame):
908
- self._deps = self._batch_inference_validate_snowpark(
909
- dataset=dataset,
910
- inference_method="score",
911
- )
912
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
913
+ self._deps = self._get_dependencies()
912
914
  selected_cols = self._get_active_columns()
913
915
  if len(selected_cols) > 0:
914
916
  dataset = dataset.select(selected_cols)
915
917
  assert isinstance(dataset._session, Session) # keep mypy happy
916
918
  transform_kwargs = dict(
917
919
  session=dataset._session,
918
- dependencies=["snowflake-snowpark-python"] + self._deps,
920
+ dependencies=self._deps,
919
921
  score_sproc_imports=['sklearn'],
920
922
  )
921
923
  elif isinstance(dataset, pd.DataFrame):
@@ -980,11 +982,8 @@ class AdaBoostRegressor(BaseTransformer):
980
982
 
981
983
  if isinstance(dataset, DataFrame):
982
984
 
983
- self._deps = self._batch_inference_validate_snowpark(
984
- dataset=dataset,
985
- inference_method=inference_method,
986
-
987
- )
985
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
986
+ self._deps = self._get_dependencies()
988
987
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
989
988
  transform_kwargs = dict(
990
989
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class BaggingClassifier(BaseTransformer):
70
64
  r"""A Bagging classifier
71
65
  For more details on this class, see [sklearn.ensemble.BaggingClassifier]
@@ -337,20 +331,17 @@ class BaggingClassifier(BaseTransformer):
337
331
  self,
338
332
  dataset: DataFrame,
339
333
  inference_method: str,
340
- ) -> List[str]:
341
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
342
- return the available package that exists in the snowflake anaconda channel
334
+ ) -> None:
335
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
343
336
 
344
337
  Args:
345
338
  dataset: snowpark dataframe
346
339
  inference_method: the inference method such as predict, score...
347
-
340
+
348
341
  Raises:
349
342
  SnowflakeMLException: If the estimator is not fitted, raise error
350
343
  SnowflakeMLException: If the session is None, raise error
351
344
 
352
- Returns:
353
- A list of available package that exists in the snowflake anaconda channel
354
345
  """
355
346
  if not self._is_fitted:
356
347
  raise exceptions.SnowflakeMLException(
@@ -368,9 +359,7 @@ class BaggingClassifier(BaseTransformer):
368
359
  "Session must not specified for snowpark dataset."
369
360
  ),
370
361
  )
371
- # Validate that key package version in user workspace are supported in snowflake conda channel
372
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
373
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
362
+
374
363
 
375
364
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
376
365
  @telemetry.send_api_usage_telemetry(
@@ -418,7 +407,8 @@ class BaggingClassifier(BaseTransformer):
418
407
 
419
408
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
420
409
 
421
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
410
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
411
+ self._deps = self._get_dependencies()
422
412
  assert isinstance(
423
413
  dataset._session, Session
424
414
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -501,10 +491,8 @@ class BaggingClassifier(BaseTransformer):
501
491
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
502
492
  expected_dtype = convert_sp_to_sf_type(output_types[0])
503
493
 
504
- self._deps = self._batch_inference_validate_snowpark(
505
- dataset=dataset,
506
- inference_method=inference_method,
507
- )
494
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
495
+ self._deps = self._get_dependencies()
508
496
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
509
497
 
510
498
  transform_kwargs = dict(
@@ -571,16 +559,40 @@ class BaggingClassifier(BaseTransformer):
571
559
  self._is_fitted = True
572
560
  return output_result
573
561
 
562
+
563
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
564
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
565
+ """ Method not supported for this class.
574
566
 
575
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
576
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
577
- """
567
+
568
+ Raises:
569
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
570
+
571
+ Args:
572
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
573
+ Snowpark or Pandas DataFrame.
574
+ output_cols_prefix: Prefix for the response columns
578
575
  Returns:
579
576
  Transformed dataset.
580
577
  """
581
- self.fit(dataset)
582
- assert self._sklearn_object is not None
583
- return self._sklearn_object.embedding_
578
+ self._infer_input_output_cols(dataset)
579
+ super()._check_dataset_type(dataset)
580
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
581
+ estimator=self._sklearn_object,
582
+ dataset=dataset,
583
+ input_cols=self.input_cols,
584
+ label_cols=self.label_cols,
585
+ sample_weight_col=self.sample_weight_col,
586
+ autogenerated=self._autogenerated,
587
+ subproject=_SUBPROJECT,
588
+ )
589
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
590
+ drop_input_cols=self._drop_input_cols,
591
+ expected_output_cols_list=self.output_cols,
592
+ )
593
+ self._sklearn_object = fitted_estimator
594
+ self._is_fitted = True
595
+ return output_result
584
596
 
585
597
 
586
598
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -673,10 +685,8 @@ class BaggingClassifier(BaseTransformer):
673
685
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
674
686
 
675
687
  if isinstance(dataset, DataFrame):
676
- self._deps = self._batch_inference_validate_snowpark(
677
- dataset=dataset,
678
- inference_method=inference_method,
679
- )
688
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
689
+ self._deps = self._get_dependencies()
680
690
  assert isinstance(
681
691
  dataset._session, Session
682
692
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -743,10 +753,8 @@ class BaggingClassifier(BaseTransformer):
743
753
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
744
754
 
745
755
  if isinstance(dataset, DataFrame):
746
- self._deps = self._batch_inference_validate_snowpark(
747
- dataset=dataset,
748
- inference_method=inference_method,
749
- )
756
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
757
+ self._deps = self._get_dependencies()
750
758
  assert isinstance(
751
759
  dataset._session, Session
752
760
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -810,10 +818,8 @@ class BaggingClassifier(BaseTransformer):
810
818
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
811
819
 
812
820
  if isinstance(dataset, DataFrame):
813
- self._deps = self._batch_inference_validate_snowpark(
814
- dataset=dataset,
815
- inference_method=inference_method,
816
- )
821
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
822
+ self._deps = self._get_dependencies()
817
823
  assert isinstance(
818
824
  dataset._session, Session
819
825
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -879,10 +885,8 @@ class BaggingClassifier(BaseTransformer):
879
885
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
880
886
 
881
887
  if isinstance(dataset, DataFrame):
882
- self._deps = self._batch_inference_validate_snowpark(
883
- dataset=dataset,
884
- inference_method=inference_method,
885
- )
888
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
889
+ self._deps = self._get_dependencies()
886
890
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
887
891
  transform_kwargs = dict(
888
892
  session=dataset._session,
@@ -946,17 +950,15 @@ class BaggingClassifier(BaseTransformer):
946
950
  transform_kwargs: ScoreKwargsTypedDict = dict()
947
951
 
948
952
  if isinstance(dataset, DataFrame):
949
- self._deps = self._batch_inference_validate_snowpark(
950
- dataset=dataset,
951
- inference_method="score",
952
- )
953
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
954
+ self._deps = self._get_dependencies()
953
955
  selected_cols = self._get_active_columns()
954
956
  if len(selected_cols) > 0:
955
957
  dataset = dataset.select(selected_cols)
956
958
  assert isinstance(dataset._session, Session) # keep mypy happy
957
959
  transform_kwargs = dict(
958
960
  session=dataset._session,
959
- dependencies=["snowflake-snowpark-python"] + self._deps,
961
+ dependencies=self._deps,
960
962
  score_sproc_imports=['sklearn'],
961
963
  )
962
964
  elif isinstance(dataset, pd.DataFrame):
@@ -1021,11 +1023,8 @@ class BaggingClassifier(BaseTransformer):
1021
1023
 
1022
1024
  if isinstance(dataset, DataFrame):
1023
1025
 
1024
- self._deps = self._batch_inference_validate_snowpark(
1025
- dataset=dataset,
1026
- inference_method=inference_method,
1027
-
1028
- )
1026
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1027
+ self._deps = self._get_dependencies()
1029
1028
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1030
1029
  transform_kwargs = dict(
1031
1030
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class BaggingRegressor(BaseTransformer):
70
64
  r"""A Bagging regressor
71
65
  For more details on this class, see [sklearn.ensemble.BaggingRegressor]
@@ -337,20 +331,17 @@ class BaggingRegressor(BaseTransformer):
337
331
  self,
338
332
  dataset: DataFrame,
339
333
  inference_method: str,
340
- ) -> List[str]:
341
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
342
- return the available package that exists in the snowflake anaconda channel
334
+ ) -> None:
335
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
343
336
 
344
337
  Args:
345
338
  dataset: snowpark dataframe
346
339
  inference_method: the inference method such as predict, score...
347
-
340
+
348
341
  Raises:
349
342
  SnowflakeMLException: If the estimator is not fitted, raise error
350
343
  SnowflakeMLException: If the session is None, raise error
351
344
 
352
- Returns:
353
- A list of available package that exists in the snowflake anaconda channel
354
345
  """
355
346
  if not self._is_fitted:
356
347
  raise exceptions.SnowflakeMLException(
@@ -368,9 +359,7 @@ class BaggingRegressor(BaseTransformer):
368
359
  "Session must not specified for snowpark dataset."
369
360
  ),
370
361
  )
371
- # Validate that key package version in user workspace are supported in snowflake conda channel
372
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
373
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
362
+
374
363
 
375
364
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
376
365
  @telemetry.send_api_usage_telemetry(
@@ -418,7 +407,8 @@ class BaggingRegressor(BaseTransformer):
418
407
 
419
408
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
420
409
 
421
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
410
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
411
+ self._deps = self._get_dependencies()
422
412
  assert isinstance(
423
413
  dataset._session, Session
424
414
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -501,10 +491,8 @@ class BaggingRegressor(BaseTransformer):
501
491
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
502
492
  expected_dtype = convert_sp_to_sf_type(output_types[0])
503
493
 
504
- self._deps = self._batch_inference_validate_snowpark(
505
- dataset=dataset,
506
- inference_method=inference_method,
507
- )
494
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
495
+ self._deps = self._get_dependencies()
508
496
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
509
497
 
510
498
  transform_kwargs = dict(
@@ -571,16 +559,40 @@ class BaggingRegressor(BaseTransformer):
571
559
  self._is_fitted = True
572
560
  return output_result
573
561
 
562
+
563
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
564
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
565
+ """ Method not supported for this class.
574
566
 
575
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
576
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
577
- """
567
+
568
+ Raises:
569
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
570
+
571
+ Args:
572
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
573
+ Snowpark or Pandas DataFrame.
574
+ output_cols_prefix: Prefix for the response columns
578
575
  Returns:
579
576
  Transformed dataset.
580
577
  """
581
- self.fit(dataset)
582
- assert self._sklearn_object is not None
583
- return self._sklearn_object.embedding_
578
+ self._infer_input_output_cols(dataset)
579
+ super()._check_dataset_type(dataset)
580
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
581
+ estimator=self._sklearn_object,
582
+ dataset=dataset,
583
+ input_cols=self.input_cols,
584
+ label_cols=self.label_cols,
585
+ sample_weight_col=self.sample_weight_col,
586
+ autogenerated=self._autogenerated,
587
+ subproject=_SUBPROJECT,
588
+ )
589
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
590
+ drop_input_cols=self._drop_input_cols,
591
+ expected_output_cols_list=self.output_cols,
592
+ )
593
+ self._sklearn_object = fitted_estimator
594
+ self._is_fitted = True
595
+ return output_result
584
596
 
585
597
 
586
598
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -671,10 +683,8 @@ class BaggingRegressor(BaseTransformer):
671
683
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
672
684
 
673
685
  if isinstance(dataset, DataFrame):
674
- self._deps = self._batch_inference_validate_snowpark(
675
- dataset=dataset,
676
- inference_method=inference_method,
677
- )
686
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
687
+ self._deps = self._get_dependencies()
678
688
  assert isinstance(
679
689
  dataset._session, Session
680
690
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -739,10 +749,8 @@ class BaggingRegressor(BaseTransformer):
739
749
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
740
750
 
741
751
  if isinstance(dataset, DataFrame):
742
- self._deps = self._batch_inference_validate_snowpark(
743
- dataset=dataset,
744
- inference_method=inference_method,
745
- )
752
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
753
+ self._deps = self._get_dependencies()
746
754
  assert isinstance(
747
755
  dataset._session, Session
748
756
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -804,10 +812,8 @@ class BaggingRegressor(BaseTransformer):
804
812
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
805
813
 
806
814
  if isinstance(dataset, DataFrame):
807
- self._deps = self._batch_inference_validate_snowpark(
808
- dataset=dataset,
809
- inference_method=inference_method,
810
- )
815
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
816
+ self._deps = self._get_dependencies()
811
817
  assert isinstance(
812
818
  dataset._session, Session
813
819
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -873,10 +879,8 @@ class BaggingRegressor(BaseTransformer):
873
879
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
874
880
 
875
881
  if isinstance(dataset, DataFrame):
876
- self._deps = self._batch_inference_validate_snowpark(
877
- dataset=dataset,
878
- inference_method=inference_method,
879
- )
882
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
883
+ self._deps = self._get_dependencies()
880
884
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
881
885
  transform_kwargs = dict(
882
886
  session=dataset._session,
@@ -940,17 +944,15 @@ class BaggingRegressor(BaseTransformer):
940
944
  transform_kwargs: ScoreKwargsTypedDict = dict()
941
945
 
942
946
  if isinstance(dataset, DataFrame):
943
- self._deps = self._batch_inference_validate_snowpark(
944
- dataset=dataset,
945
- inference_method="score",
946
- )
947
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
948
+ self._deps = self._get_dependencies()
947
949
  selected_cols = self._get_active_columns()
948
950
  if len(selected_cols) > 0:
949
951
  dataset = dataset.select(selected_cols)
950
952
  assert isinstance(dataset._session, Session) # keep mypy happy
951
953
  transform_kwargs = dict(
952
954
  session=dataset._session,
953
- dependencies=["snowflake-snowpark-python"] + self._deps,
955
+ dependencies=self._deps,
954
956
  score_sproc_imports=['sklearn'],
955
957
  )
956
958
  elif isinstance(dataset, pd.DataFrame):
@@ -1015,11 +1017,8 @@ class BaggingRegressor(BaseTransformer):
1015
1017
 
1016
1018
  if isinstance(dataset, DataFrame):
1017
1019
 
1018
- self._deps = self._batch_inference_validate_snowpark(
1019
- dataset=dataset,
1020
- inference_method=inference_method,
1021
-
1022
- )
1020
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1021
+ self._deps = self._get_dependencies()
1023
1022
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1024
1023
  transform_kwargs = dict(
1025
1024
  session = dataset._session,