snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218) hide show
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "lightgbm".replace("sklearn.", ""
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class LGBMRegressor(BaseTransformer):
70
64
  r"""LightGBM regressor
71
65
  For more details on this class, see [lightgbm.LGBMRegressor]
@@ -294,20 +288,17 @@ class LGBMRegressor(BaseTransformer):
294
288
  self,
295
289
  dataset: DataFrame,
296
290
  inference_method: str,
297
- ) -> List[str]:
298
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
299
- return the available package that exists in the snowflake anaconda channel
291
+ ) -> None:
292
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
300
293
 
301
294
  Args:
302
295
  dataset: snowpark dataframe
303
296
  inference_method: the inference method such as predict, score...
304
-
297
+
305
298
  Raises:
306
299
  SnowflakeMLException: If the estimator is not fitted, raise error
307
300
  SnowflakeMLException: If the session is None, raise error
308
301
 
309
- Returns:
310
- A list of available package that exists in the snowflake anaconda channel
311
302
  """
312
303
  if not self._is_fitted:
313
304
  raise exceptions.SnowflakeMLException(
@@ -325,9 +316,7 @@ class LGBMRegressor(BaseTransformer):
325
316
  "Session must not specified for snowpark dataset."
326
317
  ),
327
318
  )
328
- # Validate that key package version in user workspace are supported in snowflake conda channel
329
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
330
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
319
+
331
320
 
332
321
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
333
322
  @telemetry.send_api_usage_telemetry(
@@ -375,7 +364,8 @@ class LGBMRegressor(BaseTransformer):
375
364
 
376
365
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
377
366
 
378
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
367
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
368
+ self._deps = self._get_dependencies()
379
369
  assert isinstance(
380
370
  dataset._session, Session
381
371
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -458,10 +448,8 @@ class LGBMRegressor(BaseTransformer):
458
448
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
459
449
  expected_dtype = convert_sp_to_sf_type(output_types[0])
460
450
 
461
- self._deps = self._batch_inference_validate_snowpark(
462
- dataset=dataset,
463
- inference_method=inference_method,
464
- )
451
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
452
+ self._deps = self._get_dependencies()
465
453
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
466
454
 
467
455
  transform_kwargs = dict(
@@ -528,16 +516,40 @@ class LGBMRegressor(BaseTransformer):
528
516
  self._is_fitted = True
529
517
  return output_result
530
518
 
519
+
520
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
521
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
522
+ """ Method not supported for this class.
531
523
 
532
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
533
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
534
- """
524
+
525
+ Raises:
526
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
527
+
528
+ Args:
529
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
530
+ Snowpark or Pandas DataFrame.
531
+ output_cols_prefix: Prefix for the response columns
535
532
  Returns:
536
533
  Transformed dataset.
537
534
  """
538
- self.fit(dataset)
539
- assert self._sklearn_object is not None
540
- return self._sklearn_object.embedding_
535
+ self._infer_input_output_cols(dataset)
536
+ super()._check_dataset_type(dataset)
537
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
538
+ estimator=self._sklearn_object,
539
+ dataset=dataset,
540
+ input_cols=self.input_cols,
541
+ label_cols=self.label_cols,
542
+ sample_weight_col=self.sample_weight_col,
543
+ autogenerated=self._autogenerated,
544
+ subproject=_SUBPROJECT,
545
+ )
546
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
547
+ drop_input_cols=self._drop_input_cols,
548
+ expected_output_cols_list=self.output_cols,
549
+ )
550
+ self._sklearn_object = fitted_estimator
551
+ self._is_fitted = True
552
+ return output_result
541
553
 
542
554
 
543
555
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -628,10 +640,8 @@ class LGBMRegressor(BaseTransformer):
628
640
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
629
641
 
630
642
  if isinstance(dataset, DataFrame):
631
- self._deps = self._batch_inference_validate_snowpark(
632
- dataset=dataset,
633
- inference_method=inference_method,
634
- )
643
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
644
+ self._deps = self._get_dependencies()
635
645
  assert isinstance(
636
646
  dataset._session, Session
637
647
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -696,10 +706,8 @@ class LGBMRegressor(BaseTransformer):
696
706
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
697
707
 
698
708
  if isinstance(dataset, DataFrame):
699
- self._deps = self._batch_inference_validate_snowpark(
700
- dataset=dataset,
701
- inference_method=inference_method,
702
- )
709
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
710
+ self._deps = self._get_dependencies()
703
711
  assert isinstance(
704
712
  dataset._session, Session
705
713
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -761,10 +769,8 @@ class LGBMRegressor(BaseTransformer):
761
769
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
762
770
 
763
771
  if isinstance(dataset, DataFrame):
764
- self._deps = self._batch_inference_validate_snowpark(
765
- dataset=dataset,
766
- inference_method=inference_method,
767
- )
772
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
773
+ self._deps = self._get_dependencies()
768
774
  assert isinstance(
769
775
  dataset._session, Session
770
776
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -830,10 +836,8 @@ class LGBMRegressor(BaseTransformer):
830
836
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
831
837
 
832
838
  if isinstance(dataset, DataFrame):
833
- self._deps = self._batch_inference_validate_snowpark(
834
- dataset=dataset,
835
- inference_method=inference_method,
836
- )
839
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
840
+ self._deps = self._get_dependencies()
837
841
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
838
842
  transform_kwargs = dict(
839
843
  session=dataset._session,
@@ -897,17 +901,15 @@ class LGBMRegressor(BaseTransformer):
897
901
  transform_kwargs: ScoreKwargsTypedDict = dict()
898
902
 
899
903
  if isinstance(dataset, DataFrame):
900
- self._deps = self._batch_inference_validate_snowpark(
901
- dataset=dataset,
902
- inference_method="score",
903
- )
904
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
905
+ self._deps = self._get_dependencies()
904
906
  selected_cols = self._get_active_columns()
905
907
  if len(selected_cols) > 0:
906
908
  dataset = dataset.select(selected_cols)
907
909
  assert isinstance(dataset._session, Session) # keep mypy happy
908
910
  transform_kwargs = dict(
909
911
  session=dataset._session,
910
- dependencies=["snowflake-snowpark-python"] + self._deps,
912
+ dependencies=self._deps,
911
913
  score_sproc_imports=['lightgbm', 'sklearn'],
912
914
  )
913
915
  elif isinstance(dataset, pd.DataFrame):
@@ -972,11 +974,8 @@ class LGBMRegressor(BaseTransformer):
972
974
 
973
975
  if isinstance(dataset, DataFrame):
974
976
 
975
- self._deps = self._batch_inference_validate_snowpark(
976
- dataset=dataset,
977
- inference_method=inference_method,
978
-
979
- )
977
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
978
+ self._deps = self._get_dependencies()
980
979
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
981
980
  transform_kwargs = dict(
982
981
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class ARDRegression(BaseTransformer):
70
64
  r"""Bayesian ARD regression
71
65
  For more details on this class, see [sklearn.linear_model.ARDRegression]
@@ -319,20 +313,17 @@ class ARDRegression(BaseTransformer):
319
313
  self,
320
314
  dataset: DataFrame,
321
315
  inference_method: str,
322
- ) -> List[str]:
323
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
324
- return the available package that exists in the snowflake anaconda channel
316
+ ) -> None:
317
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
325
318
 
326
319
  Args:
327
320
  dataset: snowpark dataframe
328
321
  inference_method: the inference method such as predict, score...
329
-
322
+
330
323
  Raises:
331
324
  SnowflakeMLException: If the estimator is not fitted, raise error
332
325
  SnowflakeMLException: If the session is None, raise error
333
326
 
334
- Returns:
335
- A list of available package that exists in the snowflake anaconda channel
336
327
  """
337
328
  if not self._is_fitted:
338
329
  raise exceptions.SnowflakeMLException(
@@ -350,9 +341,7 @@ class ARDRegression(BaseTransformer):
350
341
  "Session must not specified for snowpark dataset."
351
342
  ),
352
343
  )
353
- # Validate that key package version in user workspace are supported in snowflake conda channel
354
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
355
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
344
+
356
345
 
357
346
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
358
347
  @telemetry.send_api_usage_telemetry(
@@ -400,7 +389,8 @@ class ARDRegression(BaseTransformer):
400
389
 
401
390
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
402
391
 
403
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
392
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
393
+ self._deps = self._get_dependencies()
404
394
  assert isinstance(
405
395
  dataset._session, Session
406
396
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -483,10 +473,8 @@ class ARDRegression(BaseTransformer):
483
473
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
484
474
  expected_dtype = convert_sp_to_sf_type(output_types[0])
485
475
 
486
- self._deps = self._batch_inference_validate_snowpark(
487
- dataset=dataset,
488
- inference_method=inference_method,
489
- )
476
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
477
+ self._deps = self._get_dependencies()
490
478
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
491
479
 
492
480
  transform_kwargs = dict(
@@ -553,16 +541,40 @@ class ARDRegression(BaseTransformer):
553
541
  self._is_fitted = True
554
542
  return output_result
555
543
 
544
+
545
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
546
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
547
+ """ Method not supported for this class.
556
548
 
557
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
558
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
559
- """
549
+
550
+ Raises:
551
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
552
+
553
+ Args:
554
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
555
+ Snowpark or Pandas DataFrame.
556
+ output_cols_prefix: Prefix for the response columns
560
557
  Returns:
561
558
  Transformed dataset.
562
559
  """
563
- self.fit(dataset)
564
- assert self._sklearn_object is not None
565
- return self._sklearn_object.embedding_
560
+ self._infer_input_output_cols(dataset)
561
+ super()._check_dataset_type(dataset)
562
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
563
+ estimator=self._sklearn_object,
564
+ dataset=dataset,
565
+ input_cols=self.input_cols,
566
+ label_cols=self.label_cols,
567
+ sample_weight_col=self.sample_weight_col,
568
+ autogenerated=self._autogenerated,
569
+ subproject=_SUBPROJECT,
570
+ )
571
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
572
+ drop_input_cols=self._drop_input_cols,
573
+ expected_output_cols_list=self.output_cols,
574
+ )
575
+ self._sklearn_object = fitted_estimator
576
+ self._is_fitted = True
577
+ return output_result
566
578
 
567
579
 
568
580
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -653,10 +665,8 @@ class ARDRegression(BaseTransformer):
653
665
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
654
666
 
655
667
  if isinstance(dataset, DataFrame):
656
- self._deps = self._batch_inference_validate_snowpark(
657
- dataset=dataset,
658
- inference_method=inference_method,
659
- )
668
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
669
+ self._deps = self._get_dependencies()
660
670
  assert isinstance(
661
671
  dataset._session, Session
662
672
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -721,10 +731,8 @@ class ARDRegression(BaseTransformer):
721
731
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
722
732
 
723
733
  if isinstance(dataset, DataFrame):
724
- self._deps = self._batch_inference_validate_snowpark(
725
- dataset=dataset,
726
- inference_method=inference_method,
727
- )
734
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
735
+ self._deps = self._get_dependencies()
728
736
  assert isinstance(
729
737
  dataset._session, Session
730
738
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -786,10 +794,8 @@ class ARDRegression(BaseTransformer):
786
794
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
787
795
 
788
796
  if isinstance(dataset, DataFrame):
789
- self._deps = self._batch_inference_validate_snowpark(
790
- dataset=dataset,
791
- inference_method=inference_method,
792
- )
797
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
798
+ self._deps = self._get_dependencies()
793
799
  assert isinstance(
794
800
  dataset._session, Session
795
801
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -855,10 +861,8 @@ class ARDRegression(BaseTransformer):
855
861
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
856
862
 
857
863
  if isinstance(dataset, DataFrame):
858
- self._deps = self._batch_inference_validate_snowpark(
859
- dataset=dataset,
860
- inference_method=inference_method,
861
- )
864
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
865
+ self._deps = self._get_dependencies()
862
866
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
863
867
  transform_kwargs = dict(
864
868
  session=dataset._session,
@@ -922,17 +926,15 @@ class ARDRegression(BaseTransformer):
922
926
  transform_kwargs: ScoreKwargsTypedDict = dict()
923
927
 
924
928
  if isinstance(dataset, DataFrame):
925
- self._deps = self._batch_inference_validate_snowpark(
926
- dataset=dataset,
927
- inference_method="score",
928
- )
929
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
930
+ self._deps = self._get_dependencies()
929
931
  selected_cols = self._get_active_columns()
930
932
  if len(selected_cols) > 0:
931
933
  dataset = dataset.select(selected_cols)
932
934
  assert isinstance(dataset._session, Session) # keep mypy happy
933
935
  transform_kwargs = dict(
934
936
  session=dataset._session,
935
- dependencies=["snowflake-snowpark-python"] + self._deps,
937
+ dependencies=self._deps,
936
938
  score_sproc_imports=['sklearn'],
937
939
  )
938
940
  elif isinstance(dataset, pd.DataFrame):
@@ -997,11 +999,8 @@ class ARDRegression(BaseTransformer):
997
999
 
998
1000
  if isinstance(dataset, DataFrame):
999
1001
 
1000
- self._deps = self._batch_inference_validate_snowpark(
1001
- dataset=dataset,
1002
- inference_method=inference_method,
1003
-
1004
- )
1002
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1003
+ self._deps = self._get_dependencies()
1005
1004
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1006
1005
  transform_kwargs = dict(
1007
1006
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class BayesianRidge(BaseTransformer):
70
64
  r"""Bayesian ridge regression
71
65
  For more details on this class, see [sklearn.linear_model.BayesianRidge]
@@ -330,20 +324,17 @@ class BayesianRidge(BaseTransformer):
330
324
  self,
331
325
  dataset: DataFrame,
332
326
  inference_method: str,
333
- ) -> List[str]:
334
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
335
- return the available package that exists in the snowflake anaconda channel
327
+ ) -> None:
328
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
336
329
 
337
330
  Args:
338
331
  dataset: snowpark dataframe
339
332
  inference_method: the inference method such as predict, score...
340
-
333
+
341
334
  Raises:
342
335
  SnowflakeMLException: If the estimator is not fitted, raise error
343
336
  SnowflakeMLException: If the session is None, raise error
344
337
 
345
- Returns:
346
- A list of available package that exists in the snowflake anaconda channel
347
338
  """
348
339
  if not self._is_fitted:
349
340
  raise exceptions.SnowflakeMLException(
@@ -361,9 +352,7 @@ class BayesianRidge(BaseTransformer):
361
352
  "Session must not specified for snowpark dataset."
362
353
  ),
363
354
  )
364
- # Validate that key package version in user workspace are supported in snowflake conda channel
365
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
366
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
355
+
367
356
 
368
357
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
369
358
  @telemetry.send_api_usage_telemetry(
@@ -411,7 +400,8 @@ class BayesianRidge(BaseTransformer):
411
400
 
412
401
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
413
402
 
414
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
403
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
404
+ self._deps = self._get_dependencies()
415
405
  assert isinstance(
416
406
  dataset._session, Session
417
407
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -494,10 +484,8 @@ class BayesianRidge(BaseTransformer):
494
484
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
495
485
  expected_dtype = convert_sp_to_sf_type(output_types[0])
496
486
 
497
- self._deps = self._batch_inference_validate_snowpark(
498
- dataset=dataset,
499
- inference_method=inference_method,
500
- )
487
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
488
+ self._deps = self._get_dependencies()
501
489
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
502
490
 
503
491
  transform_kwargs = dict(
@@ -564,16 +552,40 @@ class BayesianRidge(BaseTransformer):
564
552
  self._is_fitted = True
565
553
  return output_result
566
554
 
555
+
556
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
557
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
558
+ """ Method not supported for this class.
567
559
 
568
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
569
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
570
- """
560
+
561
+ Raises:
562
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
563
+
564
+ Args:
565
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
566
+ Snowpark or Pandas DataFrame.
567
+ output_cols_prefix: Prefix for the response columns
571
568
  Returns:
572
569
  Transformed dataset.
573
570
  """
574
- self.fit(dataset)
575
- assert self._sklearn_object is not None
576
- return self._sklearn_object.embedding_
571
+ self._infer_input_output_cols(dataset)
572
+ super()._check_dataset_type(dataset)
573
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
574
+ estimator=self._sklearn_object,
575
+ dataset=dataset,
576
+ input_cols=self.input_cols,
577
+ label_cols=self.label_cols,
578
+ sample_weight_col=self.sample_weight_col,
579
+ autogenerated=self._autogenerated,
580
+ subproject=_SUBPROJECT,
581
+ )
582
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
583
+ drop_input_cols=self._drop_input_cols,
584
+ expected_output_cols_list=self.output_cols,
585
+ )
586
+ self._sklearn_object = fitted_estimator
587
+ self._is_fitted = True
588
+ return output_result
577
589
 
578
590
 
579
591
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -664,10 +676,8 @@ class BayesianRidge(BaseTransformer):
664
676
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
665
677
 
666
678
  if isinstance(dataset, DataFrame):
667
- self._deps = self._batch_inference_validate_snowpark(
668
- dataset=dataset,
669
- inference_method=inference_method,
670
- )
679
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
680
+ self._deps = self._get_dependencies()
671
681
  assert isinstance(
672
682
  dataset._session, Session
673
683
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -732,10 +742,8 @@ class BayesianRidge(BaseTransformer):
732
742
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
733
743
 
734
744
  if isinstance(dataset, DataFrame):
735
- self._deps = self._batch_inference_validate_snowpark(
736
- dataset=dataset,
737
- inference_method=inference_method,
738
- )
745
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
746
+ self._deps = self._get_dependencies()
739
747
  assert isinstance(
740
748
  dataset._session, Session
741
749
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -797,10 +805,8 @@ class BayesianRidge(BaseTransformer):
797
805
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
798
806
 
799
807
  if isinstance(dataset, DataFrame):
800
- self._deps = self._batch_inference_validate_snowpark(
801
- dataset=dataset,
802
- inference_method=inference_method,
803
- )
808
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
809
+ self._deps = self._get_dependencies()
804
810
  assert isinstance(
805
811
  dataset._session, Session
806
812
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -866,10 +872,8 @@ class BayesianRidge(BaseTransformer):
866
872
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
867
873
 
868
874
  if isinstance(dataset, DataFrame):
869
- self._deps = self._batch_inference_validate_snowpark(
870
- dataset=dataset,
871
- inference_method=inference_method,
872
- )
875
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
876
+ self._deps = self._get_dependencies()
873
877
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
874
878
  transform_kwargs = dict(
875
879
  session=dataset._session,
@@ -933,17 +937,15 @@ class BayesianRidge(BaseTransformer):
933
937
  transform_kwargs: ScoreKwargsTypedDict = dict()
934
938
 
935
939
  if isinstance(dataset, DataFrame):
936
- self._deps = self._batch_inference_validate_snowpark(
937
- dataset=dataset,
938
- inference_method="score",
939
- )
940
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
941
+ self._deps = self._get_dependencies()
940
942
  selected_cols = self._get_active_columns()
941
943
  if len(selected_cols) > 0:
942
944
  dataset = dataset.select(selected_cols)
943
945
  assert isinstance(dataset._session, Session) # keep mypy happy
944
946
  transform_kwargs = dict(
945
947
  session=dataset._session,
946
- dependencies=["snowflake-snowpark-python"] + self._deps,
948
+ dependencies=self._deps,
947
949
  score_sproc_imports=['sklearn'],
948
950
  )
949
951
  elif isinstance(dataset, pd.DataFrame):
@@ -1008,11 +1010,8 @@ class BayesianRidge(BaseTransformer):
1008
1010
 
1009
1011
  if isinstance(dataset, DataFrame):
1010
1012
 
1011
- self._deps = self._batch_inference_validate_snowpark(
1012
- dataset=dataset,
1013
- inference_method=inference_method,
1014
-
1015
- )
1013
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1014
+ self._deps = self._get_dependencies()
1016
1015
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1017
1016
  transform_kwargs = dict(
1018
1017
  session = dataset._session,