snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class SpectralCoclustering(BaseTransformer):
70
64
  r"""Spectral Co-Clustering algorithm (Dhillon, 2001)
71
65
  For more details on this class, see [sklearn.cluster.SpectralCoclustering]
@@ -301,20 +295,17 @@ class SpectralCoclustering(BaseTransformer):
301
295
  self,
302
296
  dataset: DataFrame,
303
297
  inference_method: str,
304
- ) -> List[str]:
305
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
306
- return the available package that exists in the snowflake anaconda channel
298
+ ) -> None:
299
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
307
300
 
308
301
  Args:
309
302
  dataset: snowpark dataframe
310
303
  inference_method: the inference method such as predict, score...
311
-
304
+
312
305
  Raises:
313
306
  SnowflakeMLException: If the estimator is not fitted, raise error
314
307
  SnowflakeMLException: If the session is None, raise error
315
308
 
316
- Returns:
317
- A list of available package that exists in the snowflake anaconda channel
318
309
  """
319
310
  if not self._is_fitted:
320
311
  raise exceptions.SnowflakeMLException(
@@ -332,9 +323,7 @@ class SpectralCoclustering(BaseTransformer):
332
323
  "Session must not specified for snowpark dataset."
333
324
  ),
334
325
  )
335
- # Validate that key package version in user workspace are supported in snowflake conda channel
336
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
337
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
326
+
338
327
 
339
328
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
340
329
  @telemetry.send_api_usage_telemetry(
@@ -380,7 +369,8 @@ class SpectralCoclustering(BaseTransformer):
380
369
 
381
370
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
382
371
 
383
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
372
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
373
+ self._deps = self._get_dependencies()
384
374
  assert isinstance(
385
375
  dataset._session, Session
386
376
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -463,10 +453,8 @@ class SpectralCoclustering(BaseTransformer):
463
453
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
464
454
  expected_dtype = convert_sp_to_sf_type(output_types[0])
465
455
 
466
- self._deps = self._batch_inference_validate_snowpark(
467
- dataset=dataset,
468
- inference_method=inference_method,
469
- )
456
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
457
+ self._deps = self._get_dependencies()
470
458
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
471
459
 
472
460
  transform_kwargs = dict(
@@ -533,16 +521,40 @@ class SpectralCoclustering(BaseTransformer):
533
521
  self._is_fitted = True
534
522
  return output_result
535
523
 
524
+
525
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
526
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
527
+ """ Method not supported for this class.
536
528
 
537
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
538
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
539
- """
529
+
530
+ Raises:
531
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
532
+
533
+ Args:
534
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
535
+ Snowpark or Pandas DataFrame.
536
+ output_cols_prefix: Prefix for the response columns
540
537
  Returns:
541
538
  Transformed dataset.
542
539
  """
543
- self.fit(dataset)
544
- assert self._sklearn_object is not None
545
- return self._sklearn_object.embedding_
540
+ self._infer_input_output_cols(dataset)
541
+ super()._check_dataset_type(dataset)
542
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
543
+ estimator=self._sklearn_object,
544
+ dataset=dataset,
545
+ input_cols=self.input_cols,
546
+ label_cols=self.label_cols,
547
+ sample_weight_col=self.sample_weight_col,
548
+ autogenerated=self._autogenerated,
549
+ subproject=_SUBPROJECT,
550
+ )
551
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
552
+ drop_input_cols=self._drop_input_cols,
553
+ expected_output_cols_list=self.output_cols,
554
+ )
555
+ self._sklearn_object = fitted_estimator
556
+ self._is_fitted = True
557
+ return output_result
546
558
 
547
559
 
548
560
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -633,10 +645,8 @@ class SpectralCoclustering(BaseTransformer):
633
645
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
634
646
 
635
647
  if isinstance(dataset, DataFrame):
636
- self._deps = self._batch_inference_validate_snowpark(
637
- dataset=dataset,
638
- inference_method=inference_method,
639
- )
648
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
649
+ self._deps = self._get_dependencies()
640
650
  assert isinstance(
641
651
  dataset._session, Session
642
652
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -701,10 +711,8 @@ class SpectralCoclustering(BaseTransformer):
701
711
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
702
712
 
703
713
  if isinstance(dataset, DataFrame):
704
- self._deps = self._batch_inference_validate_snowpark(
705
- dataset=dataset,
706
- inference_method=inference_method,
707
- )
714
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
715
+ self._deps = self._get_dependencies()
708
716
  assert isinstance(
709
717
  dataset._session, Session
710
718
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -766,10 +774,8 @@ class SpectralCoclustering(BaseTransformer):
766
774
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
767
775
 
768
776
  if isinstance(dataset, DataFrame):
769
- self._deps = self._batch_inference_validate_snowpark(
770
- dataset=dataset,
771
- inference_method=inference_method,
772
- )
777
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
778
+ self._deps = self._get_dependencies()
773
779
  assert isinstance(
774
780
  dataset._session, Session
775
781
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -835,10 +841,8 @@ class SpectralCoclustering(BaseTransformer):
835
841
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
836
842
 
837
843
  if isinstance(dataset, DataFrame):
838
- self._deps = self._batch_inference_validate_snowpark(
839
- dataset=dataset,
840
- inference_method=inference_method,
841
- )
844
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
845
+ self._deps = self._get_dependencies()
842
846
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
843
847
  transform_kwargs = dict(
844
848
  session=dataset._session,
@@ -900,17 +904,15 @@ class SpectralCoclustering(BaseTransformer):
900
904
  transform_kwargs: ScoreKwargsTypedDict = dict()
901
905
 
902
906
  if isinstance(dataset, DataFrame):
903
- self._deps = self._batch_inference_validate_snowpark(
904
- dataset=dataset,
905
- inference_method="score",
906
- )
907
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
908
+ self._deps = self._get_dependencies()
907
909
  selected_cols = self._get_active_columns()
908
910
  if len(selected_cols) > 0:
909
911
  dataset = dataset.select(selected_cols)
910
912
  assert isinstance(dataset._session, Session) # keep mypy happy
911
913
  transform_kwargs = dict(
912
914
  session=dataset._session,
913
- dependencies=["snowflake-snowpark-python"] + self._deps,
915
+ dependencies=self._deps,
914
916
  score_sproc_imports=['sklearn'],
915
917
  )
916
918
  elif isinstance(dataset, pd.DataFrame):
@@ -975,11 +977,8 @@ class SpectralCoclustering(BaseTransformer):
975
977
 
976
978
  if isinstance(dataset, DataFrame):
977
979
 
978
- self._deps = self._batch_inference_validate_snowpark(
979
- dataset=dataset,
980
- inference_method=inference_method,
981
-
982
- )
980
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
981
+ self._deps = self._get_dependencies()
983
982
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
984
983
  transform_kwargs = dict(
985
984
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.compose".replace("sklear
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class ColumnTransformer(BaseTransformer):
70
64
  r"""Applies transformers to columns of an array or pandas DataFrame
71
65
  For more details on this class, see [sklearn.compose.ColumnTransformer]
@@ -331,20 +325,17 @@ class ColumnTransformer(BaseTransformer):
331
325
  self,
332
326
  dataset: DataFrame,
333
327
  inference_method: str,
334
- ) -> List[str]:
335
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
336
- return the available package that exists in the snowflake anaconda channel
328
+ ) -> None:
329
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
337
330
 
338
331
  Args:
339
332
  dataset: snowpark dataframe
340
333
  inference_method: the inference method such as predict, score...
341
-
334
+
342
335
  Raises:
343
336
  SnowflakeMLException: If the estimator is not fitted, raise error
344
337
  SnowflakeMLException: If the session is None, raise error
345
338
 
346
- Returns:
347
- A list of available package that exists in the snowflake anaconda channel
348
339
  """
349
340
  if not self._is_fitted:
350
341
  raise exceptions.SnowflakeMLException(
@@ -362,9 +353,7 @@ class ColumnTransformer(BaseTransformer):
362
353
  "Session must not specified for snowpark dataset."
363
354
  ),
364
355
  )
365
- # Validate that key package version in user workspace are supported in snowflake conda channel
366
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
367
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
356
+
368
357
 
369
358
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
370
359
  @telemetry.send_api_usage_telemetry(
@@ -410,7 +399,8 @@ class ColumnTransformer(BaseTransformer):
410
399
 
411
400
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
412
401
 
413
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
402
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
403
+ self._deps = self._get_dependencies()
414
404
  assert isinstance(
415
405
  dataset._session, Session
416
406
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -495,10 +485,8 @@ class ColumnTransformer(BaseTransformer):
495
485
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
496
486
  expected_dtype = convert_sp_to_sf_type(output_types[0])
497
487
 
498
- self._deps = self._batch_inference_validate_snowpark(
499
- dataset=dataset,
500
- inference_method=inference_method,
501
- )
488
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
489
+ self._deps = self._get_dependencies()
502
490
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
503
491
 
504
492
  transform_kwargs = dict(
@@ -565,16 +553,42 @@ class ColumnTransformer(BaseTransformer):
565
553
  self._is_fitted = True
566
554
  return output_result
567
555
 
556
+
557
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
558
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
559
+ """ Fit all transformers, transform the data and concatenate results
560
+ For more details on this function, see [sklearn.compose.ColumnTransformer.fit_transform]
561
+ (https://scikit-learn.org/stable/modules/generated/sklearn.compose.ColumnTransformer.html#sklearn.compose.ColumnTransformer.fit_transform)
562
+
568
563
 
569
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
570
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
571
- """
564
+ Raises:
565
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
566
+
567
+ Args:
568
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
569
+ Snowpark or Pandas DataFrame.
570
+ output_cols_prefix: Prefix for the response columns
572
571
  Returns:
573
572
  Transformed dataset.
574
573
  """
575
- self.fit(dataset)
576
- assert self._sklearn_object is not None
577
- return self._sklearn_object.embedding_
574
+ self._infer_input_output_cols(dataset)
575
+ super()._check_dataset_type(dataset)
576
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
577
+ estimator=self._sklearn_object,
578
+ dataset=dataset,
579
+ input_cols=self.input_cols,
580
+ label_cols=self.label_cols,
581
+ sample_weight_col=self.sample_weight_col,
582
+ autogenerated=self._autogenerated,
583
+ subproject=_SUBPROJECT,
584
+ )
585
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
586
+ drop_input_cols=self._drop_input_cols,
587
+ expected_output_cols_list=self.output_cols,
588
+ )
589
+ self._sklearn_object = fitted_estimator
590
+ self._is_fitted = True
591
+ return output_result
578
592
 
579
593
 
580
594
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -665,10 +679,8 @@ class ColumnTransformer(BaseTransformer):
665
679
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
666
680
 
667
681
  if isinstance(dataset, DataFrame):
668
- self._deps = self._batch_inference_validate_snowpark(
669
- dataset=dataset,
670
- inference_method=inference_method,
671
- )
682
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
683
+ self._deps = self._get_dependencies()
672
684
  assert isinstance(
673
685
  dataset._session, Session
674
686
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -733,10 +745,8 @@ class ColumnTransformer(BaseTransformer):
733
745
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
734
746
 
735
747
  if isinstance(dataset, DataFrame):
736
- self._deps = self._batch_inference_validate_snowpark(
737
- dataset=dataset,
738
- inference_method=inference_method,
739
- )
748
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
749
+ self._deps = self._get_dependencies()
740
750
  assert isinstance(
741
751
  dataset._session, Session
742
752
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -798,10 +808,8 @@ class ColumnTransformer(BaseTransformer):
798
808
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
799
809
 
800
810
  if isinstance(dataset, DataFrame):
801
- self._deps = self._batch_inference_validate_snowpark(
802
- dataset=dataset,
803
- inference_method=inference_method,
804
- )
811
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
812
+ self._deps = self._get_dependencies()
805
813
  assert isinstance(
806
814
  dataset._session, Session
807
815
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -867,10 +875,8 @@ class ColumnTransformer(BaseTransformer):
867
875
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
868
876
 
869
877
  if isinstance(dataset, DataFrame):
870
- self._deps = self._batch_inference_validate_snowpark(
871
- dataset=dataset,
872
- inference_method=inference_method,
873
- )
878
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
879
+ self._deps = self._get_dependencies()
874
880
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
875
881
  transform_kwargs = dict(
876
882
  session=dataset._session,
@@ -932,17 +938,15 @@ class ColumnTransformer(BaseTransformer):
932
938
  transform_kwargs: ScoreKwargsTypedDict = dict()
933
939
 
934
940
  if isinstance(dataset, DataFrame):
935
- self._deps = self._batch_inference_validate_snowpark(
936
- dataset=dataset,
937
- inference_method="score",
938
- )
941
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
942
+ self._deps = self._get_dependencies()
939
943
  selected_cols = self._get_active_columns()
940
944
  if len(selected_cols) > 0:
941
945
  dataset = dataset.select(selected_cols)
942
946
  assert isinstance(dataset._session, Session) # keep mypy happy
943
947
  transform_kwargs = dict(
944
948
  session=dataset._session,
945
- dependencies=["snowflake-snowpark-python"] + self._deps,
949
+ dependencies=self._deps,
946
950
  score_sproc_imports=['sklearn'],
947
951
  )
948
952
  elif isinstance(dataset, pd.DataFrame):
@@ -1007,11 +1011,8 @@ class ColumnTransformer(BaseTransformer):
1007
1011
 
1008
1012
  if isinstance(dataset, DataFrame):
1009
1013
 
1010
- self._deps = self._batch_inference_validate_snowpark(
1011
- dataset=dataset,
1012
- inference_method=inference_method,
1013
-
1014
- )
1014
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1015
+ self._deps = self._get_dependencies()
1015
1016
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1016
1017
  transform_kwargs = dict(
1017
1018
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.compose".replace("sklear
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class TransformedTargetRegressor(BaseTransformer):
70
64
  r"""Meta-estimator to regress on a transformed target
71
65
  For more details on this class, see [sklearn.compose.TransformedTargetRegressor]
@@ -292,20 +286,17 @@ class TransformedTargetRegressor(BaseTransformer):
292
286
  self,
293
287
  dataset: DataFrame,
294
288
  inference_method: str,
295
- ) -> List[str]:
296
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
297
- return the available package that exists in the snowflake anaconda channel
289
+ ) -> None:
290
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
298
291
 
299
292
  Args:
300
293
  dataset: snowpark dataframe
301
294
  inference_method: the inference method such as predict, score...
302
-
295
+
303
296
  Raises:
304
297
  SnowflakeMLException: If the estimator is not fitted, raise error
305
298
  SnowflakeMLException: If the session is None, raise error
306
299
 
307
- Returns:
308
- A list of available package that exists in the snowflake anaconda channel
309
300
  """
310
301
  if not self._is_fitted:
311
302
  raise exceptions.SnowflakeMLException(
@@ -323,9 +314,7 @@ class TransformedTargetRegressor(BaseTransformer):
323
314
  "Session must not specified for snowpark dataset."
324
315
  ),
325
316
  )
326
- # Validate that key package version in user workspace are supported in snowflake conda channel
327
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
328
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
317
+
329
318
 
330
319
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
331
320
  @telemetry.send_api_usage_telemetry(
@@ -373,7 +362,8 @@ class TransformedTargetRegressor(BaseTransformer):
373
362
 
374
363
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
375
364
 
376
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
365
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
366
+ self._deps = self._get_dependencies()
377
367
  assert isinstance(
378
368
  dataset._session, Session
379
369
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -456,10 +446,8 @@ class TransformedTargetRegressor(BaseTransformer):
456
446
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
457
447
  expected_dtype = convert_sp_to_sf_type(output_types[0])
458
448
 
459
- self._deps = self._batch_inference_validate_snowpark(
460
- dataset=dataset,
461
- inference_method=inference_method,
462
- )
449
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
450
+ self._deps = self._get_dependencies()
463
451
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
464
452
 
465
453
  transform_kwargs = dict(
@@ -526,16 +514,40 @@ class TransformedTargetRegressor(BaseTransformer):
526
514
  self._is_fitted = True
527
515
  return output_result
528
516
 
517
+
518
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
519
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
520
+ """ Method not supported for this class.
529
521
 
530
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
531
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
532
- """
522
+
523
+ Raises:
524
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
525
+
526
+ Args:
527
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
528
+ Snowpark or Pandas DataFrame.
529
+ output_cols_prefix: Prefix for the response columns
533
530
  Returns:
534
531
  Transformed dataset.
535
532
  """
536
- self.fit(dataset)
537
- assert self._sklearn_object is not None
538
- return self._sklearn_object.embedding_
533
+ self._infer_input_output_cols(dataset)
534
+ super()._check_dataset_type(dataset)
535
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
536
+ estimator=self._sklearn_object,
537
+ dataset=dataset,
538
+ input_cols=self.input_cols,
539
+ label_cols=self.label_cols,
540
+ sample_weight_col=self.sample_weight_col,
541
+ autogenerated=self._autogenerated,
542
+ subproject=_SUBPROJECT,
543
+ )
544
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
545
+ drop_input_cols=self._drop_input_cols,
546
+ expected_output_cols_list=self.output_cols,
547
+ )
548
+ self._sklearn_object = fitted_estimator
549
+ self._is_fitted = True
550
+ return output_result
539
551
 
540
552
 
541
553
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -626,10 +638,8 @@ class TransformedTargetRegressor(BaseTransformer):
626
638
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
627
639
 
628
640
  if isinstance(dataset, DataFrame):
629
- self._deps = self._batch_inference_validate_snowpark(
630
- dataset=dataset,
631
- inference_method=inference_method,
632
- )
641
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
642
+ self._deps = self._get_dependencies()
633
643
  assert isinstance(
634
644
  dataset._session, Session
635
645
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -694,10 +704,8 @@ class TransformedTargetRegressor(BaseTransformer):
694
704
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
695
705
 
696
706
  if isinstance(dataset, DataFrame):
697
- self._deps = self._batch_inference_validate_snowpark(
698
- dataset=dataset,
699
- inference_method=inference_method,
700
- )
707
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
708
+ self._deps = self._get_dependencies()
701
709
  assert isinstance(
702
710
  dataset._session, Session
703
711
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -759,10 +767,8 @@ class TransformedTargetRegressor(BaseTransformer):
759
767
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
760
768
 
761
769
  if isinstance(dataset, DataFrame):
762
- self._deps = self._batch_inference_validate_snowpark(
763
- dataset=dataset,
764
- inference_method=inference_method,
765
- )
770
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
771
+ self._deps = self._get_dependencies()
766
772
  assert isinstance(
767
773
  dataset._session, Session
768
774
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -828,10 +834,8 @@ class TransformedTargetRegressor(BaseTransformer):
828
834
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
829
835
 
830
836
  if isinstance(dataset, DataFrame):
831
- self._deps = self._batch_inference_validate_snowpark(
832
- dataset=dataset,
833
- inference_method=inference_method,
834
- )
837
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
838
+ self._deps = self._get_dependencies()
835
839
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
836
840
  transform_kwargs = dict(
837
841
  session=dataset._session,
@@ -895,17 +899,15 @@ class TransformedTargetRegressor(BaseTransformer):
895
899
  transform_kwargs: ScoreKwargsTypedDict = dict()
896
900
 
897
901
  if isinstance(dataset, DataFrame):
898
- self._deps = self._batch_inference_validate_snowpark(
899
- dataset=dataset,
900
- inference_method="score",
901
- )
902
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
903
+ self._deps = self._get_dependencies()
902
904
  selected_cols = self._get_active_columns()
903
905
  if len(selected_cols) > 0:
904
906
  dataset = dataset.select(selected_cols)
905
907
  assert isinstance(dataset._session, Session) # keep mypy happy
906
908
  transform_kwargs = dict(
907
909
  session=dataset._session,
908
- dependencies=["snowflake-snowpark-python"] + self._deps,
910
+ dependencies=self._deps,
909
911
  score_sproc_imports=['sklearn'],
910
912
  )
911
913
  elif isinstance(dataset, pd.DataFrame):
@@ -970,11 +972,8 @@ class TransformedTargetRegressor(BaseTransformer):
970
972
 
971
973
  if isinstance(dataset, DataFrame):
972
974
 
973
- self._deps = self._batch_inference_validate_snowpark(
974
- dataset=dataset,
975
- inference_method=inference_method,
976
-
977
- )
975
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
976
+ self._deps = self._get_dependencies()
978
977
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
979
978
  transform_kwargs = dict(
980
979
  session = dataset._session,