snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206):
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".re
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class SkewedChi2Sampler(BaseTransformer):
70
64
  r"""Approximate feature map for "skewed chi-squared" kernel
71
65
  For more details on this class, see [sklearn.kernel_approximation.SkewedChi2Sampler]
@@ -269,20 +263,17 @@ class SkewedChi2Sampler(BaseTransformer):
269
263
  self,
270
264
  dataset: DataFrame,
271
265
  inference_method: str,
272
- ) -> List[str]:
273
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
274
- return the available package that exists in the snowflake anaconda channel
266
+ ) -> None:
267
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
275
268
 
276
269
  Args:
277
270
  dataset: snowpark dataframe
278
271
  inference_method: the inference method such as predict, score...
279
-
272
+
280
273
  Raises:
281
274
  SnowflakeMLException: If the estimator is not fitted, raise error
282
275
  SnowflakeMLException: If the session is None, raise error
283
276
 
284
- Returns:
285
- A list of available package that exists in the snowflake anaconda channel
286
277
  """
287
278
  if not self._is_fitted:
288
279
  raise exceptions.SnowflakeMLException(
@@ -300,9 +291,7 @@ class SkewedChi2Sampler(BaseTransformer):
300
291
  "Session must not specified for snowpark dataset."
301
292
  ),
302
293
  )
303
- # Validate that key package version in user workspace are supported in snowflake conda channel
304
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
305
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
294
+
306
295
 
307
296
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
308
297
  @telemetry.send_api_usage_telemetry(
@@ -348,7 +337,8 @@ class SkewedChi2Sampler(BaseTransformer):
348
337
 
349
338
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
350
339
 
351
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
340
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
341
+ self._deps = self._get_dependencies()
352
342
  assert isinstance(
353
343
  dataset._session, Session
354
344
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -433,10 +423,8 @@ class SkewedChi2Sampler(BaseTransformer):
433
423
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
434
424
  expected_dtype = convert_sp_to_sf_type(output_types[0])
435
425
 
436
- self._deps = self._batch_inference_validate_snowpark(
437
- dataset=dataset,
438
- inference_method=inference_method,
439
- )
426
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
427
+ self._deps = self._get_dependencies()
440
428
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
441
429
 
442
430
  transform_kwargs = dict(
@@ -503,16 +491,42 @@ class SkewedChi2Sampler(BaseTransformer):
503
491
  self._is_fitted = True
504
492
  return output_result
505
493
 
494
+
495
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
496
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
497
+ """ Fit to data, then transform it
498
+ For more details on this function, see [sklearn.kernel_approximation.SkewedChi2Sampler.fit_transform]
499
+ (https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.SkewedChi2Sampler.html#sklearn.kernel_approximation.SkewedChi2Sampler.fit_transform)
500
+
506
501
 
507
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
508
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
509
- """
502
+ Raises:
503
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
504
+
505
+ Args:
506
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
507
+ Snowpark or Pandas DataFrame.
508
+ output_cols_prefix: Prefix for the response columns
510
509
  Returns:
511
510
  Transformed dataset.
512
511
  """
513
- self.fit(dataset)
514
- assert self._sklearn_object is not None
515
- return self._sklearn_object.embedding_
512
+ self._infer_input_output_cols(dataset)
513
+ super()._check_dataset_type(dataset)
514
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
515
+ estimator=self._sklearn_object,
516
+ dataset=dataset,
517
+ input_cols=self.input_cols,
518
+ label_cols=self.label_cols,
519
+ sample_weight_col=self.sample_weight_col,
520
+ autogenerated=self._autogenerated,
521
+ subproject=_SUBPROJECT,
522
+ )
523
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
524
+ drop_input_cols=self._drop_input_cols,
525
+ expected_output_cols_list=self.output_cols,
526
+ )
527
+ self._sklearn_object = fitted_estimator
528
+ self._is_fitted = True
529
+ return output_result
516
530
 
517
531
 
518
532
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -603,10 +617,8 @@ class SkewedChi2Sampler(BaseTransformer):
603
617
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
604
618
 
605
619
  if isinstance(dataset, DataFrame):
606
- self._deps = self._batch_inference_validate_snowpark(
607
- dataset=dataset,
608
- inference_method=inference_method,
609
- )
620
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
621
+ self._deps = self._get_dependencies()
610
622
  assert isinstance(
611
623
  dataset._session, Session
612
624
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -671,10 +683,8 @@ class SkewedChi2Sampler(BaseTransformer):
671
683
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
672
684
 
673
685
  if isinstance(dataset, DataFrame):
674
- self._deps = self._batch_inference_validate_snowpark(
675
- dataset=dataset,
676
- inference_method=inference_method,
677
- )
686
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
687
+ self._deps = self._get_dependencies()
678
688
  assert isinstance(
679
689
  dataset._session, Session
680
690
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -736,10 +746,8 @@ class SkewedChi2Sampler(BaseTransformer):
736
746
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
737
747
 
738
748
  if isinstance(dataset, DataFrame):
739
- self._deps = self._batch_inference_validate_snowpark(
740
- dataset=dataset,
741
- inference_method=inference_method,
742
- )
749
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
750
+ self._deps = self._get_dependencies()
743
751
  assert isinstance(
744
752
  dataset._session, Session
745
753
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -805,10 +813,8 @@ class SkewedChi2Sampler(BaseTransformer):
805
813
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
806
814
 
807
815
  if isinstance(dataset, DataFrame):
808
- self._deps = self._batch_inference_validate_snowpark(
809
- dataset=dataset,
810
- inference_method=inference_method,
811
- )
816
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
817
+ self._deps = self._get_dependencies()
812
818
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
813
819
  transform_kwargs = dict(
814
820
  session=dataset._session,
@@ -870,17 +876,15 @@ class SkewedChi2Sampler(BaseTransformer):
870
876
  transform_kwargs: ScoreKwargsTypedDict = dict()
871
877
 
872
878
  if isinstance(dataset, DataFrame):
873
- self._deps = self._batch_inference_validate_snowpark(
874
- dataset=dataset,
875
- inference_method="score",
876
- )
879
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
880
+ self._deps = self._get_dependencies()
877
881
  selected_cols = self._get_active_columns()
878
882
  if len(selected_cols) > 0:
879
883
  dataset = dataset.select(selected_cols)
880
884
  assert isinstance(dataset._session, Session) # keep mypy happy
881
885
  transform_kwargs = dict(
882
886
  session=dataset._session,
883
- dependencies=["snowflake-snowpark-python"] + self._deps,
887
+ dependencies=self._deps,
884
888
  score_sproc_imports=['sklearn'],
885
889
  )
886
890
  elif isinstance(dataset, pd.DataFrame):
@@ -945,11 +949,8 @@ class SkewedChi2Sampler(BaseTransformer):
945
949
 
946
950
  if isinstance(dataset, DataFrame):
947
951
 
948
- self._deps = self._batch_inference_validate_snowpark(
949
- dataset=dataset,
950
- inference_method=inference_method,
951
-
952
- )
952
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
953
+ self._deps = self._get_dependencies()
953
954
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
954
955
  transform_kwargs = dict(
955
956
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_ridge".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class KernelRidge(BaseTransformer):
70
64
  r"""Kernel ridge regression
71
65
  For more details on this class, see [sklearn.kernel_ridge.KernelRidge]
@@ -305,20 +299,17 @@ class KernelRidge(BaseTransformer):
305
299
  self,
306
300
  dataset: DataFrame,
307
301
  inference_method: str,
308
- ) -> List[str]:
309
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
310
- return the available package that exists in the snowflake anaconda channel
302
+ ) -> None:
303
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
311
304
 
312
305
  Args:
313
306
  dataset: snowpark dataframe
314
307
  inference_method: the inference method such as predict, score...
315
-
308
+
316
309
  Raises:
317
310
  SnowflakeMLException: If the estimator is not fitted, raise error
318
311
  SnowflakeMLException: If the session is None, raise error
319
312
 
320
- Returns:
321
- A list of available package that exists in the snowflake anaconda channel
322
313
  """
323
314
  if not self._is_fitted:
324
315
  raise exceptions.SnowflakeMLException(
@@ -336,9 +327,7 @@ class KernelRidge(BaseTransformer):
336
327
  "Session must not specified for snowpark dataset."
337
328
  ),
338
329
  )
339
- # Validate that key package version in user workspace are supported in snowflake conda channel
340
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
341
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
330
+
342
331
 
343
332
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
344
333
  @telemetry.send_api_usage_telemetry(
@@ -386,7 +375,8 @@ class KernelRidge(BaseTransformer):
386
375
 
387
376
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
388
377
 
389
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
378
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
379
+ self._deps = self._get_dependencies()
390
380
  assert isinstance(
391
381
  dataset._session, Session
392
382
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -469,10 +459,8 @@ class KernelRidge(BaseTransformer):
469
459
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
470
460
  expected_dtype = convert_sp_to_sf_type(output_types[0])
471
461
 
472
- self._deps = self._batch_inference_validate_snowpark(
473
- dataset=dataset,
474
- inference_method=inference_method,
475
- )
462
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
463
+ self._deps = self._get_dependencies()
476
464
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
477
465
 
478
466
  transform_kwargs = dict(
@@ -539,16 +527,40 @@ class KernelRidge(BaseTransformer):
539
527
  self._is_fitted = True
540
528
  return output_result
541
529
 
530
+
531
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
532
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
533
+ """ Method not supported for this class.
542
534
 
543
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
544
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
545
- """
535
+
536
+ Raises:
537
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
538
+
539
+ Args:
540
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
541
+ Snowpark or Pandas DataFrame.
542
+ output_cols_prefix: Prefix for the response columns
546
543
  Returns:
547
544
  Transformed dataset.
548
545
  """
549
- self.fit(dataset)
550
- assert self._sklearn_object is not None
551
- return self._sklearn_object.embedding_
546
+ self._infer_input_output_cols(dataset)
547
+ super()._check_dataset_type(dataset)
548
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
549
+ estimator=self._sklearn_object,
550
+ dataset=dataset,
551
+ input_cols=self.input_cols,
552
+ label_cols=self.label_cols,
553
+ sample_weight_col=self.sample_weight_col,
554
+ autogenerated=self._autogenerated,
555
+ subproject=_SUBPROJECT,
556
+ )
557
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
558
+ drop_input_cols=self._drop_input_cols,
559
+ expected_output_cols_list=self.output_cols,
560
+ )
561
+ self._sklearn_object = fitted_estimator
562
+ self._is_fitted = True
563
+ return output_result
552
564
 
553
565
 
554
566
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -639,10 +651,8 @@ class KernelRidge(BaseTransformer):
639
651
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
640
652
 
641
653
  if isinstance(dataset, DataFrame):
642
- self._deps = self._batch_inference_validate_snowpark(
643
- dataset=dataset,
644
- inference_method=inference_method,
645
- )
654
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
655
+ self._deps = self._get_dependencies()
646
656
  assert isinstance(
647
657
  dataset._session, Session
648
658
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -707,10 +717,8 @@ class KernelRidge(BaseTransformer):
707
717
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
708
718
 
709
719
  if isinstance(dataset, DataFrame):
710
- self._deps = self._batch_inference_validate_snowpark(
711
- dataset=dataset,
712
- inference_method=inference_method,
713
- )
720
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
721
+ self._deps = self._get_dependencies()
714
722
  assert isinstance(
715
723
  dataset._session, Session
716
724
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -772,10 +780,8 @@ class KernelRidge(BaseTransformer):
772
780
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
773
781
 
774
782
  if isinstance(dataset, DataFrame):
775
- self._deps = self._batch_inference_validate_snowpark(
776
- dataset=dataset,
777
- inference_method=inference_method,
778
- )
783
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
784
+ self._deps = self._get_dependencies()
779
785
  assert isinstance(
780
786
  dataset._session, Session
781
787
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -841,10 +847,8 @@ class KernelRidge(BaseTransformer):
841
847
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
842
848
 
843
849
  if isinstance(dataset, DataFrame):
844
- self._deps = self._batch_inference_validate_snowpark(
845
- dataset=dataset,
846
- inference_method=inference_method,
847
- )
850
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
851
+ self._deps = self._get_dependencies()
848
852
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
849
853
  transform_kwargs = dict(
850
854
  session=dataset._session,
@@ -908,17 +912,15 @@ class KernelRidge(BaseTransformer):
908
912
  transform_kwargs: ScoreKwargsTypedDict = dict()
909
913
 
910
914
  if isinstance(dataset, DataFrame):
911
- self._deps = self._batch_inference_validate_snowpark(
912
- dataset=dataset,
913
- inference_method="score",
914
- )
915
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
916
+ self._deps = self._get_dependencies()
915
917
  selected_cols = self._get_active_columns()
916
918
  if len(selected_cols) > 0:
917
919
  dataset = dataset.select(selected_cols)
918
920
  assert isinstance(dataset._session, Session) # keep mypy happy
919
921
  transform_kwargs = dict(
920
922
  session=dataset._session,
921
- dependencies=["snowflake-snowpark-python"] + self._deps,
923
+ dependencies=self._deps,
922
924
  score_sproc_imports=['sklearn'],
923
925
  )
924
926
  elif isinstance(dataset, pd.DataFrame):
@@ -983,11 +985,8 @@ class KernelRidge(BaseTransformer):
983
985
 
984
986
  if isinstance(dataset, DataFrame):
985
987
 
986
- self._deps = self._batch_inference_validate_snowpark(
987
- dataset=dataset,
988
- inference_method=inference_method,
989
-
990
- )
988
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
989
+ self._deps = self._get_dependencies()
991
990
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
992
991
  transform_kwargs = dict(
993
992
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "lightgbm".replace("sklearn.", ""
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class LGBMClassifier(BaseTransformer):
70
64
  r"""LightGBM classifier
71
65
  For more details on this class, see [lightgbm.LGBMClassifier]
@@ -294,20 +288,17 @@ class LGBMClassifier(BaseTransformer):
294
288
  self,
295
289
  dataset: DataFrame,
296
290
  inference_method: str,
297
- ) -> List[str]:
298
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
299
- return the available package that exists in the snowflake anaconda channel
291
+ ) -> None:
292
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
300
293
 
301
294
  Args:
302
295
  dataset: snowpark dataframe
303
296
  inference_method: the inference method such as predict, score...
304
-
297
+
305
298
  Raises:
306
299
  SnowflakeMLException: If the estimator is not fitted, raise error
307
300
  SnowflakeMLException: If the session is None, raise error
308
301
 
309
- Returns:
310
- A list of available package that exists in the snowflake anaconda channel
311
302
  """
312
303
  if not self._is_fitted:
313
304
  raise exceptions.SnowflakeMLException(
@@ -325,9 +316,7 @@ class LGBMClassifier(BaseTransformer):
325
316
  "Session must not specified for snowpark dataset."
326
317
  ),
327
318
  )
328
- # Validate that key package version in user workspace are supported in snowflake conda channel
329
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
330
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
319
+
331
320
 
332
321
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
333
322
  @telemetry.send_api_usage_telemetry(
@@ -375,7 +364,8 @@ class LGBMClassifier(BaseTransformer):
375
364
 
376
365
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
377
366
 
378
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
367
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
368
+ self._deps = self._get_dependencies()
379
369
  assert isinstance(
380
370
  dataset._session, Session
381
371
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -458,10 +448,8 @@ class LGBMClassifier(BaseTransformer):
458
448
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
459
449
  expected_dtype = convert_sp_to_sf_type(output_types[0])
460
450
 
461
- self._deps = self._batch_inference_validate_snowpark(
462
- dataset=dataset,
463
- inference_method=inference_method,
464
- )
451
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
452
+ self._deps = self._get_dependencies()
465
453
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
466
454
 
467
455
  transform_kwargs = dict(
@@ -528,16 +516,40 @@ class LGBMClassifier(BaseTransformer):
528
516
  self._is_fitted = True
529
517
  return output_result
530
518
 
519
+
520
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
521
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
522
+ """ Method not supported for this class.
531
523
 
532
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
533
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
534
- """
524
+
525
+ Raises:
526
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
527
+
528
+ Args:
529
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
530
+ Snowpark or Pandas DataFrame.
531
+ output_cols_prefix: Prefix for the response columns
535
532
  Returns:
536
533
  Transformed dataset.
537
534
  """
538
- self.fit(dataset)
539
- assert self._sklearn_object is not None
540
- return self._sklearn_object.embedding_
535
+ self._infer_input_output_cols(dataset)
536
+ super()._check_dataset_type(dataset)
537
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
538
+ estimator=self._sklearn_object,
539
+ dataset=dataset,
540
+ input_cols=self.input_cols,
541
+ label_cols=self.label_cols,
542
+ sample_weight_col=self.sample_weight_col,
543
+ autogenerated=self._autogenerated,
544
+ subproject=_SUBPROJECT,
545
+ )
546
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
547
+ drop_input_cols=self._drop_input_cols,
548
+ expected_output_cols_list=self.output_cols,
549
+ )
550
+ self._sklearn_object = fitted_estimator
551
+ self._is_fitted = True
552
+ return output_result
541
553
 
542
554
 
543
555
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -630,10 +642,8 @@ class LGBMClassifier(BaseTransformer):
630
642
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
631
643
 
632
644
  if isinstance(dataset, DataFrame):
633
- self._deps = self._batch_inference_validate_snowpark(
634
- dataset=dataset,
635
- inference_method=inference_method,
636
- )
645
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
646
+ self._deps = self._get_dependencies()
637
647
  assert isinstance(
638
648
  dataset._session, Session
639
649
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -700,10 +710,8 @@ class LGBMClassifier(BaseTransformer):
700
710
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
701
711
 
702
712
  if isinstance(dataset, DataFrame):
703
- self._deps = self._batch_inference_validate_snowpark(
704
- dataset=dataset,
705
- inference_method=inference_method,
706
- )
713
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
714
+ self._deps = self._get_dependencies()
707
715
  assert isinstance(
708
716
  dataset._session, Session
709
717
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -765,10 +773,8 @@ class LGBMClassifier(BaseTransformer):
765
773
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
766
774
 
767
775
  if isinstance(dataset, DataFrame):
768
- self._deps = self._batch_inference_validate_snowpark(
769
- dataset=dataset,
770
- inference_method=inference_method,
771
- )
776
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
777
+ self._deps = self._get_dependencies()
772
778
  assert isinstance(
773
779
  dataset._session, Session
774
780
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -834,10 +840,8 @@ class LGBMClassifier(BaseTransformer):
834
840
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
835
841
 
836
842
  if isinstance(dataset, DataFrame):
837
- self._deps = self._batch_inference_validate_snowpark(
838
- dataset=dataset,
839
- inference_method=inference_method,
840
- )
843
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
844
+ self._deps = self._get_dependencies()
841
845
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
842
846
  transform_kwargs = dict(
843
847
  session=dataset._session,
@@ -901,17 +905,15 @@ class LGBMClassifier(BaseTransformer):
901
905
  transform_kwargs: ScoreKwargsTypedDict = dict()
902
906
 
903
907
  if isinstance(dataset, DataFrame):
904
- self._deps = self._batch_inference_validate_snowpark(
905
- dataset=dataset,
906
- inference_method="score",
907
- )
908
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
909
+ self._deps = self._get_dependencies()
908
910
  selected_cols = self._get_active_columns()
909
911
  if len(selected_cols) > 0:
910
912
  dataset = dataset.select(selected_cols)
911
913
  assert isinstance(dataset._session, Session) # keep mypy happy
912
914
  transform_kwargs = dict(
913
915
  session=dataset._session,
914
- dependencies=["snowflake-snowpark-python"] + self._deps,
916
+ dependencies=self._deps,
915
917
  score_sproc_imports=['lightgbm', 'sklearn'],
916
918
  )
917
919
  elif isinstance(dataset, pd.DataFrame):
@@ -976,11 +978,8 @@ class LGBMClassifier(BaseTransformer):
976
978
 
977
979
  if isinstance(dataset, DataFrame):
978
980
 
979
- self._deps = self._batch_inference_validate_snowpark(
980
- dataset=dataset,
981
- inference_method=inference_method,
982
-
983
- )
981
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
982
+ self._deps = self._get_dependencies()
984
983
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
985
984
  transform_kwargs = dict(
986
985
  session = dataset._session,