snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.tree".replace("sklearn."
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class DecisionTreeClassifier(BaseTransformer):
70
64
  r"""A decision tree classifier
71
65
  For more details on this class, see [sklearn.tree.DecisionTreeClassifier]
@@ -391,20 +385,17 @@ class DecisionTreeClassifier(BaseTransformer):
391
385
  self,
392
386
  dataset: DataFrame,
393
387
  inference_method: str,
394
- ) -> List[str]:
395
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
396
- return the available package that exists in the snowflake anaconda channel
388
+ ) -> None:
389
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
397
390
 
398
391
  Args:
399
392
  dataset: snowpark dataframe
400
393
  inference_method: the inference method such as predict, score...
401
-
394
+
402
395
  Raises:
403
396
  SnowflakeMLException: If the estimator is not fitted, raise error
404
397
  SnowflakeMLException: If the session is None, raise error
405
398
 
406
- Returns:
407
- A list of available package that exists in the snowflake anaconda channel
408
399
  """
409
400
  if not self._is_fitted:
410
401
  raise exceptions.SnowflakeMLException(
@@ -422,9 +413,7 @@ class DecisionTreeClassifier(BaseTransformer):
422
413
  "Session must not specified for snowpark dataset."
423
414
  ),
424
415
  )
425
- # Validate that key package version in user workspace are supported in snowflake conda channel
426
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
427
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
416
+
428
417
 
429
418
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
430
419
  @telemetry.send_api_usage_telemetry(
@@ -472,7 +461,8 @@ class DecisionTreeClassifier(BaseTransformer):
472
461
 
473
462
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
474
463
 
475
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
464
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
465
+ self._deps = self._get_dependencies()
476
466
  assert isinstance(
477
467
  dataset._session, Session
478
468
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -555,10 +545,8 @@ class DecisionTreeClassifier(BaseTransformer):
555
545
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
556
546
  expected_dtype = convert_sp_to_sf_type(output_types[0])
557
547
 
558
- self._deps = self._batch_inference_validate_snowpark(
559
- dataset=dataset,
560
- inference_method=inference_method,
561
- )
548
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
549
+ self._deps = self._get_dependencies()
562
550
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
563
551
 
564
552
  transform_kwargs = dict(
@@ -625,16 +613,40 @@ class DecisionTreeClassifier(BaseTransformer):
625
613
  self._is_fitted = True
626
614
  return output_result
627
615
 
616
+
617
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
618
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
619
+ """ Method not supported for this class.
628
620
 
629
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
630
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
631
- """
621
+
622
+ Raises:
623
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
624
+
625
+ Args:
626
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
627
+ Snowpark or Pandas DataFrame.
628
+ output_cols_prefix: Prefix for the response columns
632
629
  Returns:
633
630
  Transformed dataset.
634
631
  """
635
- self.fit(dataset)
636
- assert self._sklearn_object is not None
637
- return self._sklearn_object.embedding_
632
+ self._infer_input_output_cols(dataset)
633
+ super()._check_dataset_type(dataset)
634
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
635
+ estimator=self._sklearn_object,
636
+ dataset=dataset,
637
+ input_cols=self.input_cols,
638
+ label_cols=self.label_cols,
639
+ sample_weight_col=self.sample_weight_col,
640
+ autogenerated=self._autogenerated,
641
+ subproject=_SUBPROJECT,
642
+ )
643
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
644
+ drop_input_cols=self._drop_input_cols,
645
+ expected_output_cols_list=self.output_cols,
646
+ )
647
+ self._sklearn_object = fitted_estimator
648
+ self._is_fitted = True
649
+ return output_result
638
650
 
639
651
 
640
652
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -727,10 +739,8 @@ class DecisionTreeClassifier(BaseTransformer):
727
739
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
728
740
 
729
741
  if isinstance(dataset, DataFrame):
730
- self._deps = self._batch_inference_validate_snowpark(
731
- dataset=dataset,
732
- inference_method=inference_method,
733
- )
742
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
743
+ self._deps = self._get_dependencies()
734
744
  assert isinstance(
735
745
  dataset._session, Session
736
746
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -797,10 +807,8 @@ class DecisionTreeClassifier(BaseTransformer):
797
807
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
798
808
 
799
809
  if isinstance(dataset, DataFrame):
800
- self._deps = self._batch_inference_validate_snowpark(
801
- dataset=dataset,
802
- inference_method=inference_method,
803
- )
810
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
811
+ self._deps = self._get_dependencies()
804
812
  assert isinstance(
805
813
  dataset._session, Session
806
814
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -862,10 +870,8 @@ class DecisionTreeClassifier(BaseTransformer):
862
870
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
863
871
 
864
872
  if isinstance(dataset, DataFrame):
865
- self._deps = self._batch_inference_validate_snowpark(
866
- dataset=dataset,
867
- inference_method=inference_method,
868
- )
873
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
874
+ self._deps = self._get_dependencies()
869
875
  assert isinstance(
870
876
  dataset._session, Session
871
877
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -931,10 +937,8 @@ class DecisionTreeClassifier(BaseTransformer):
931
937
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
932
938
 
933
939
  if isinstance(dataset, DataFrame):
934
- self._deps = self._batch_inference_validate_snowpark(
935
- dataset=dataset,
936
- inference_method=inference_method,
937
- )
940
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
941
+ self._deps = self._get_dependencies()
938
942
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
939
943
  transform_kwargs = dict(
940
944
  session=dataset._session,
@@ -998,17 +1002,15 @@ class DecisionTreeClassifier(BaseTransformer):
998
1002
  transform_kwargs: ScoreKwargsTypedDict = dict()
999
1003
 
1000
1004
  if isinstance(dataset, DataFrame):
1001
- self._deps = self._batch_inference_validate_snowpark(
1002
- dataset=dataset,
1003
- inference_method="score",
1004
- )
1005
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1006
+ self._deps = self._get_dependencies()
1005
1007
  selected_cols = self._get_active_columns()
1006
1008
  if len(selected_cols) > 0:
1007
1009
  dataset = dataset.select(selected_cols)
1008
1010
  assert isinstance(dataset._session, Session) # keep mypy happy
1009
1011
  transform_kwargs = dict(
1010
1012
  session=dataset._session,
1011
- dependencies=["snowflake-snowpark-python"] + self._deps,
1013
+ dependencies=self._deps,
1012
1014
  score_sproc_imports=['sklearn'],
1013
1015
  )
1014
1016
  elif isinstance(dataset, pd.DataFrame):
@@ -1073,11 +1075,8 @@ class DecisionTreeClassifier(BaseTransformer):
1073
1075
 
1074
1076
  if isinstance(dataset, DataFrame):
1075
1077
 
1076
- self._deps = self._batch_inference_validate_snowpark(
1077
- dataset=dataset,
1078
- inference_method=inference_method,
1079
-
1080
- )
1078
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1079
+ self._deps = self._get_dependencies()
1081
1080
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1082
1081
  transform_kwargs = dict(
1083
1082
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.tree".replace("sklearn."
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class DecisionTreeRegressor(BaseTransformer):
70
64
  r"""A decision tree regressor
71
65
  For more details on this class, see [sklearn.tree.DecisionTreeRegressor]
@@ -373,20 +367,17 @@ class DecisionTreeRegressor(BaseTransformer):
373
367
  self,
374
368
  dataset: DataFrame,
375
369
  inference_method: str,
376
- ) -> List[str]:
377
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
378
- return the available package that exists in the snowflake anaconda channel
370
+ ) -> None:
371
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
379
372
 
380
373
  Args:
381
374
  dataset: snowpark dataframe
382
375
  inference_method: the inference method such as predict, score...
383
-
376
+
384
377
  Raises:
385
378
  SnowflakeMLException: If the estimator is not fitted, raise error
386
379
  SnowflakeMLException: If the session is None, raise error
387
380
 
388
- Returns:
389
- A list of available package that exists in the snowflake anaconda channel
390
381
  """
391
382
  if not self._is_fitted:
392
383
  raise exceptions.SnowflakeMLException(
@@ -404,9 +395,7 @@ class DecisionTreeRegressor(BaseTransformer):
404
395
  "Session must not specified for snowpark dataset."
405
396
  ),
406
397
  )
407
- # Validate that key package version in user workspace are supported in snowflake conda channel
408
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
409
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
398
+
410
399
 
411
400
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
412
401
  @telemetry.send_api_usage_telemetry(
@@ -454,7 +443,8 @@ class DecisionTreeRegressor(BaseTransformer):
454
443
 
455
444
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
456
445
 
457
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
446
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
447
+ self._deps = self._get_dependencies()
458
448
  assert isinstance(
459
449
  dataset._session, Session
460
450
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -537,10 +527,8 @@ class DecisionTreeRegressor(BaseTransformer):
537
527
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
538
528
  expected_dtype = convert_sp_to_sf_type(output_types[0])
539
529
 
540
- self._deps = self._batch_inference_validate_snowpark(
541
- dataset=dataset,
542
- inference_method=inference_method,
543
- )
530
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
531
+ self._deps = self._get_dependencies()
544
532
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
545
533
 
546
534
  transform_kwargs = dict(
@@ -607,16 +595,40 @@ class DecisionTreeRegressor(BaseTransformer):
607
595
  self._is_fitted = True
608
596
  return output_result
609
597
 
598
+
599
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
600
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
601
+ """ Method not supported for this class.
610
602
 
611
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
612
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
613
- """
603
+
604
+ Raises:
605
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
606
+
607
+ Args:
608
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
609
+ Snowpark or Pandas DataFrame.
610
+ output_cols_prefix: Prefix for the response columns
614
611
  Returns:
615
612
  Transformed dataset.
616
613
  """
617
- self.fit(dataset)
618
- assert self._sklearn_object is not None
619
- return self._sklearn_object.embedding_
614
+ self._infer_input_output_cols(dataset)
615
+ super()._check_dataset_type(dataset)
616
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
617
+ estimator=self._sklearn_object,
618
+ dataset=dataset,
619
+ input_cols=self.input_cols,
620
+ label_cols=self.label_cols,
621
+ sample_weight_col=self.sample_weight_col,
622
+ autogenerated=self._autogenerated,
623
+ subproject=_SUBPROJECT,
624
+ )
625
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
626
+ drop_input_cols=self._drop_input_cols,
627
+ expected_output_cols_list=self.output_cols,
628
+ )
629
+ self._sklearn_object = fitted_estimator
630
+ self._is_fitted = True
631
+ return output_result
620
632
 
621
633
 
622
634
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -707,10 +719,8 @@ class DecisionTreeRegressor(BaseTransformer):
707
719
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
708
720
 
709
721
  if isinstance(dataset, DataFrame):
710
- self._deps = self._batch_inference_validate_snowpark(
711
- dataset=dataset,
712
- inference_method=inference_method,
713
- )
722
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
723
+ self._deps = self._get_dependencies()
714
724
  assert isinstance(
715
725
  dataset._session, Session
716
726
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -775,10 +785,8 @@ class DecisionTreeRegressor(BaseTransformer):
775
785
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
776
786
 
777
787
  if isinstance(dataset, DataFrame):
778
- self._deps = self._batch_inference_validate_snowpark(
779
- dataset=dataset,
780
- inference_method=inference_method,
781
- )
788
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
789
+ self._deps = self._get_dependencies()
782
790
  assert isinstance(
783
791
  dataset._session, Session
784
792
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -840,10 +848,8 @@ class DecisionTreeRegressor(BaseTransformer):
840
848
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
841
849
 
842
850
  if isinstance(dataset, DataFrame):
843
- self._deps = self._batch_inference_validate_snowpark(
844
- dataset=dataset,
845
- inference_method=inference_method,
846
- )
851
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
852
+ self._deps = self._get_dependencies()
847
853
  assert isinstance(
848
854
  dataset._session, Session
849
855
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -909,10 +915,8 @@ class DecisionTreeRegressor(BaseTransformer):
909
915
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
910
916
 
911
917
  if isinstance(dataset, DataFrame):
912
- self._deps = self._batch_inference_validate_snowpark(
913
- dataset=dataset,
914
- inference_method=inference_method,
915
- )
918
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
919
+ self._deps = self._get_dependencies()
916
920
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
917
921
  transform_kwargs = dict(
918
922
  session=dataset._session,
@@ -976,17 +980,15 @@ class DecisionTreeRegressor(BaseTransformer):
976
980
  transform_kwargs: ScoreKwargsTypedDict = dict()
977
981
 
978
982
  if isinstance(dataset, DataFrame):
979
- self._deps = self._batch_inference_validate_snowpark(
980
- dataset=dataset,
981
- inference_method="score",
982
- )
983
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
984
+ self._deps = self._get_dependencies()
983
985
  selected_cols = self._get_active_columns()
984
986
  if len(selected_cols) > 0:
985
987
  dataset = dataset.select(selected_cols)
986
988
  assert isinstance(dataset._session, Session) # keep mypy happy
987
989
  transform_kwargs = dict(
988
990
  session=dataset._session,
989
- dependencies=["snowflake-snowpark-python"] + self._deps,
991
+ dependencies=self._deps,
990
992
  score_sproc_imports=['sklearn'],
991
993
  )
992
994
  elif isinstance(dataset, pd.DataFrame):
@@ -1051,11 +1053,8 @@ class DecisionTreeRegressor(BaseTransformer):
1051
1053
 
1052
1054
  if isinstance(dataset, DataFrame):
1053
1055
 
1054
- self._deps = self._batch_inference_validate_snowpark(
1055
- dataset=dataset,
1056
- inference_method=inference_method,
1057
-
1058
- )
1056
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1057
+ self._deps = self._get_dependencies()
1059
1058
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1060
1059
  transform_kwargs = dict(
1061
1060
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.tree".replace("sklearn."
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class ExtraTreeClassifier(BaseTransformer):
70
64
  r"""An extremely randomized tree classifier
71
65
  For more details on this class, see [sklearn.tree.ExtraTreeClassifier]
@@ -383,20 +377,17 @@ class ExtraTreeClassifier(BaseTransformer):
383
377
  self,
384
378
  dataset: DataFrame,
385
379
  inference_method: str,
386
- ) -> List[str]:
387
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
388
- return the available package that exists in the snowflake anaconda channel
380
+ ) -> None:
381
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
389
382
 
390
383
  Args:
391
384
  dataset: snowpark dataframe
392
385
  inference_method: the inference method such as predict, score...
393
-
386
+
394
387
  Raises:
395
388
  SnowflakeMLException: If the estimator is not fitted, raise error
396
389
  SnowflakeMLException: If the session is None, raise error
397
390
 
398
- Returns:
399
- A list of available package that exists in the snowflake anaconda channel
400
391
  """
401
392
  if not self._is_fitted:
402
393
  raise exceptions.SnowflakeMLException(
@@ -414,9 +405,7 @@ class ExtraTreeClassifier(BaseTransformer):
414
405
  "Session must not specified for snowpark dataset."
415
406
  ),
416
407
  )
417
- # Validate that key package version in user workspace are supported in snowflake conda channel
418
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
419
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
408
+
420
409
 
421
410
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
422
411
  @telemetry.send_api_usage_telemetry(
@@ -464,7 +453,8 @@ class ExtraTreeClassifier(BaseTransformer):
464
453
 
465
454
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
466
455
 
467
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
456
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
457
+ self._deps = self._get_dependencies()
468
458
  assert isinstance(
469
459
  dataset._session, Session
470
460
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -547,10 +537,8 @@ class ExtraTreeClassifier(BaseTransformer):
547
537
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
548
538
  expected_dtype = convert_sp_to_sf_type(output_types[0])
549
539
 
550
- self._deps = self._batch_inference_validate_snowpark(
551
- dataset=dataset,
552
- inference_method=inference_method,
553
- )
540
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
541
+ self._deps = self._get_dependencies()
554
542
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
555
543
 
556
544
  transform_kwargs = dict(
@@ -617,16 +605,40 @@ class ExtraTreeClassifier(BaseTransformer):
617
605
  self._is_fitted = True
618
606
  return output_result
619
607
 
608
+
609
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
610
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
611
+ """ Method not supported for this class.
620
612
 
621
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
622
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
623
- """
613
+
614
+ Raises:
615
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
616
+
617
+ Args:
618
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
619
+ Snowpark or Pandas DataFrame.
620
+ output_cols_prefix: Prefix for the response columns
624
621
  Returns:
625
622
  Transformed dataset.
626
623
  """
627
- self.fit(dataset)
628
- assert self._sklearn_object is not None
629
- return self._sklearn_object.embedding_
624
+ self._infer_input_output_cols(dataset)
625
+ super()._check_dataset_type(dataset)
626
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
627
+ estimator=self._sklearn_object,
628
+ dataset=dataset,
629
+ input_cols=self.input_cols,
630
+ label_cols=self.label_cols,
631
+ sample_weight_col=self.sample_weight_col,
632
+ autogenerated=self._autogenerated,
633
+ subproject=_SUBPROJECT,
634
+ )
635
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
636
+ drop_input_cols=self._drop_input_cols,
637
+ expected_output_cols_list=self.output_cols,
638
+ )
639
+ self._sklearn_object = fitted_estimator
640
+ self._is_fitted = True
641
+ return output_result
630
642
 
631
643
 
632
644
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -719,10 +731,8 @@ class ExtraTreeClassifier(BaseTransformer):
719
731
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
720
732
 
721
733
  if isinstance(dataset, DataFrame):
722
- self._deps = self._batch_inference_validate_snowpark(
723
- dataset=dataset,
724
- inference_method=inference_method,
725
- )
734
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
735
+ self._deps = self._get_dependencies()
726
736
  assert isinstance(
727
737
  dataset._session, Session
728
738
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -789,10 +799,8 @@ class ExtraTreeClassifier(BaseTransformer):
789
799
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
790
800
 
791
801
  if isinstance(dataset, DataFrame):
792
- self._deps = self._batch_inference_validate_snowpark(
793
- dataset=dataset,
794
- inference_method=inference_method,
795
- )
802
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
803
+ self._deps = self._get_dependencies()
796
804
  assert isinstance(
797
805
  dataset._session, Session
798
806
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -854,10 +862,8 @@ class ExtraTreeClassifier(BaseTransformer):
854
862
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
855
863
 
856
864
  if isinstance(dataset, DataFrame):
857
- self._deps = self._batch_inference_validate_snowpark(
858
- dataset=dataset,
859
- inference_method=inference_method,
860
- )
865
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
866
+ self._deps = self._get_dependencies()
861
867
  assert isinstance(
862
868
  dataset._session, Session
863
869
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -923,10 +929,8 @@ class ExtraTreeClassifier(BaseTransformer):
923
929
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
924
930
 
925
931
  if isinstance(dataset, DataFrame):
926
- self._deps = self._batch_inference_validate_snowpark(
927
- dataset=dataset,
928
- inference_method=inference_method,
929
- )
932
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
933
+ self._deps = self._get_dependencies()
930
934
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
931
935
  transform_kwargs = dict(
932
936
  session=dataset._session,
@@ -990,17 +994,15 @@ class ExtraTreeClassifier(BaseTransformer):
990
994
  transform_kwargs: ScoreKwargsTypedDict = dict()
991
995
 
992
996
  if isinstance(dataset, DataFrame):
993
- self._deps = self._batch_inference_validate_snowpark(
994
- dataset=dataset,
995
- inference_method="score",
996
- )
997
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
998
+ self._deps = self._get_dependencies()
997
999
  selected_cols = self._get_active_columns()
998
1000
  if len(selected_cols) > 0:
999
1001
  dataset = dataset.select(selected_cols)
1000
1002
  assert isinstance(dataset._session, Session) # keep mypy happy
1001
1003
  transform_kwargs = dict(
1002
1004
  session=dataset._session,
1003
- dependencies=["snowflake-snowpark-python"] + self._deps,
1005
+ dependencies=self._deps,
1004
1006
  score_sproc_imports=['sklearn'],
1005
1007
  )
1006
1008
  elif isinstance(dataset, pd.DataFrame):
@@ -1065,11 +1067,8 @@ class ExtraTreeClassifier(BaseTransformer):
1065
1067
 
1066
1068
  if isinstance(dataset, DataFrame):
1067
1069
 
1068
- self._deps = self._batch_inference_validate_snowpark(
1069
- dataset=dataset,
1070
- inference_method=inference_method,
1071
-
1072
- )
1070
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1071
+ self._deps = self._get_dependencies()
1073
1072
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1074
1073
  transform_kwargs = dict(
1075
1074
  session = dataset._session,