snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/xgboost/xgb_classifier.py

@@ -53,6 +53,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "xgboost".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class XGBClassifier(BaseTransformer):
     r"""Implementation of the scikit-learn API for XGBoost classification
     For more details on this class, see [xgboost.XGBClassifier]

@@ -105,6 +117,22 @@ class XGBClassifier(BaseTransformer):
     drop_input_cols: Optional[bool], default=False
         If set, the response of predict(), transform() methods will not contain input columns.
 
+    use_external_memory_version: bool, default=False
+        If set, external memory version of XGBoost trainer is used. External memory training
+        is done in a two-step process. First,in the preprocessing step, input data is read and
+        parsed into an internal format, which can be CSR, CSC, or sorted CSC, and stored in
+        in-memory buffers. The in-memory buffers are continuously flushed out to disk when
+        predefined memory limit is reached. Second, in the tree construction step, the data
+        pages are streamed from disk via a multi-threaded pre-fetcher as needed for tree construction.
+        Note:'tree_method's 'approx', and 'hist' are supported in the external memory version.
+        Note:'grow_policy=depthwise' is used for optimal performance in the external memory version.
+
+    batch_size: int, default=10000
+        Number of rows in each batch of input data while using external memory training.
+        It is not recommended to set small batch sizes, like 32 samples per batch, as this
+        can seriously hurt performance in gradient boosting. Set the batch_size as large as possible
+        based on the available memory.
+
     n_estimators: int
         Number of boosting rounds.
 

@@ -325,6 +353,8 @@ class XGBClassifier(BaseTransformer):
         passthrough_cols: Optional[Union[str, Iterable[str]]] = None,
         drop_input_cols: Optional[bool] = False,
         sample_weight_col: Optional[str] = None,
+        use_external_memory_version: bool = False,
+        batch_size: int = 10000,
         **kwargs,
     ) -> None:
         super().__init__()

@@ -334,7 +364,9 @@ class XGBClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = use_external_memory_version
+        self._batch_size = batch_size
         deps: Set[str] = set([f'numpy=={np.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)

@@ -412,11 +444,6 @@ class XGBClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -444,7 +471,9 @@ class XGBClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True

@@ -715,6 +744,22 @@ class XGBClassifier(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -730,8 +775,8 @@ class XGBClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 

@@ -744,13 +789,21 @@ class XGBClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/xgboost/xgb_regressor.py

@@ -53,6 +53,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "xgboost".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class XGBRegressor(BaseTransformer):
     r"""Implementation of the scikit-learn API for XGBoost regression
     For more details on this class, see [xgboost.XGBRegressor]

@@ -105,6 +117,22 @@ class XGBRegressor(BaseTransformer):
     drop_input_cols: Optional[bool], default=False
         If set, the response of predict(), transform() methods will not contain input columns.
 
+    use_external_memory_version: bool, default=False
+        If set, external memory version of XGBoost trainer is used. External memory training
+        is done in a two-step process. First,in the preprocessing step, input data is read and
+        parsed into an internal format, which can be CSR, CSC, or sorted CSC, and stored in
+        in-memory buffers. The in-memory buffers are continuously flushed out to disk when
+        predefined memory limit is reached. Second, in the tree construction step, the data
+        pages are streamed from disk via a multi-threaded pre-fetcher as needed for tree construction.
+        Note:'tree_method's 'approx', and 'hist' are supported in the external memory version.
+        Note:'grow_policy=depthwise' is used for optimal performance in the external memory version.
+
+    batch_size: int, default=10000
+        Number of rows in each batch of input data while using external memory training.
+        It is not recommended to set small batch sizes, like 32 samples per batch, as this
+        can seriously hurt performance in gradient boosting. Set the batch_size as large as possible
+        based on the available memory.
+
     n_estimators: int
         Number of gradient boosted trees. Equivalent to number of boosting
         rounds.

@@ -325,6 +353,8 @@ class XGBRegressor(BaseTransformer):
         passthrough_cols: Optional[Union[str, Iterable[str]]] = None,
         drop_input_cols: Optional[bool] = False,
         sample_weight_col: Optional[str] = None,
+        use_external_memory_version: bool = False,
+        batch_size: int = 10000,
         **kwargs,
     ) -> None:
         super().__init__()

@@ -334,7 +364,9 @@ class XGBRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = use_external_memory_version
+        self._batch_size = batch_size
         deps: Set[str] = set([f'numpy=={np.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)

@@ -411,11 +443,6 @@ class XGBRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -443,7 +470,9 @@ class XGBRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True

@@ -714,6 +743,22 @@ class XGBRegressor(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -729,8 +774,8 @@ class XGBRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 

@@ -743,13 +788,21 @@ class XGBRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/xgboost/xgbrf_classifier.py

@@ -53,6 +53,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "xgboost".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class XGBRFClassifier(BaseTransformer):
     r"""scikit-learn API for XGBoost random forest classification
     For more details on this class, see [xgboost.XGBRFClassifier]

@@ -105,6 +117,22 @@ class XGBRFClassifier(BaseTransformer):
     drop_input_cols: Optional[bool], default=False
         If set, the response of predict(), transform() methods will not contain input columns.
 
+    use_external_memory_version: bool, default=False
+        If set, external memory version of XGBoost trainer is used. External memory training
+        is done in a two-step process. First,in the preprocessing step, input data is read and
+        parsed into an internal format, which can be CSR, CSC, or sorted CSC, and stored in
+        in-memory buffers. The in-memory buffers are continuously flushed out to disk when
+        predefined memory limit is reached. Second, in the tree construction step, the data
+        pages are streamed from disk via a multi-threaded pre-fetcher as needed for tree construction.
+        Note:'tree_method's 'approx', and 'hist' are supported in the external memory version.
+        Note:'grow_policy=depthwise' is used for optimal performance in the external memory version.
+
+    batch_size: int, default=10000
+        Number of rows in each batch of input data while using external memory training.
+        It is not recommended to set small batch sizes, like 32 samples per batch, as this
+        can seriously hurt performance in gradient boosting. Set the batch_size as large as possible
+        based on the available memory.
+
     n_estimators: int
         Number of trees in random forest to fit.
 

@@ -327,6 +355,8 @@ class XGBRFClassifier(BaseTransformer):
         passthrough_cols: Optional[Union[str, Iterable[str]]] = None,
         drop_input_cols: Optional[bool] = False,
         sample_weight_col: Optional[str] = None,
+        use_external_memory_version: bool = False,
+        batch_size: int = 10000,
         **kwargs,
     ) -> None:
         super().__init__()

@@ -336,7 +366,9 @@ class XGBRFClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = use_external_memory_version
+        self._batch_size = batch_size
         deps: Set[str] = set([f'numpy=={np.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)

@@ -416,11 +448,6 @@ class XGBRFClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -448,7 +475,9 @@ class XGBRFClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True

@@ -719,6 +748,22 @@ class XGBRFClassifier(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -734,8 +779,8 @@ class XGBRFClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 

@@ -748,13 +793,21 @@ class XGBRFClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/xgboost/xgbrf_regressor.py

@@ -53,6 +53,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "xgboost".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class XGBRFRegressor(BaseTransformer):
     r"""scikit-learn API for XGBoost random forest regression
     For more details on this class, see [xgboost.XGBRFRegressor]

@@ -105,6 +117,22 @@ class XGBRFRegressor(BaseTransformer):
     drop_input_cols: Optional[bool], default=False
         If set, the response of predict(), transform() methods will not contain input columns.
 
+    use_external_memory_version: bool, default=False
+        If set, external memory version of XGBoost trainer is used. External memory training
+        is done in a two-step process. First,in the preprocessing step, input data is read and
+        parsed into an internal format, which can be CSR, CSC, or sorted CSC, and stored in
+        in-memory buffers. The in-memory buffers are continuously flushed out to disk when
+        predefined memory limit is reached. Second, in the tree construction step, the data
+        pages are streamed from disk via a multi-threaded pre-fetcher as needed for tree construction.
+        Note:'tree_method's 'approx', and 'hist' are supported in the external memory version.
+        Note:'grow_policy=depthwise' is used for optimal performance in the external memory version.
+
+    batch_size: int, default=10000
+        Number of rows in each batch of input data while using external memory training.
+        It is not recommended to set small batch sizes, like 32 samples per batch, as this
+        can seriously hurt performance in gradient boosting. Set the batch_size as large as possible
+        based on the available memory.
+
     n_estimators: int
         Number of trees in random forest to fit.
 

@@ -327,6 +355,8 @@ class XGBRFRegressor(BaseTransformer):
         passthrough_cols: Optional[Union[str, Iterable[str]]] = None,
         drop_input_cols: Optional[bool] = False,
         sample_weight_col: Optional[str] = None,
+        use_external_memory_version: bool = False,
+        batch_size: int = 10000,
         **kwargs,
     ) -> None:
         super().__init__()

@@ -336,7 +366,9 @@ class XGBRFRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = use_external_memory_version
+        self._batch_size = batch_size
         deps: Set[str] = set([f'numpy=={np.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)

@@ -416,11 +448,6 @@ class XGBRFRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -448,7 +475,9 @@ class XGBRFRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True

@@ -719,6 +748,22 @@ class XGBRFRegressor(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -734,8 +779,8 @@ class XGBRFRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 

@@ -748,13 +793,21 @@ class XGBRFRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/registry/__init__.py

@@ -0,0 +1,3 @@
+from snowflake.ml.registry.registry import Registry
+
+__all__ = ["Registry"]
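
The new registry __init__.py re-exports Registry at the subpackage root. A hedged sketch of the import path it enables; `session` and `my_model` are assumed to exist already, and the log_model call follows the Registry API as documented for this release:

    # Sketch only: `session` (a snowflake.snowpark.Session) and `my_model` are assumed.
    from snowflake.ml.registry import Registry

    registry = Registry(session=session)  # import path enabled by the new __init__.py
    mv = registry.log_model(my_model, model_name="MY_MODEL", version_name="V1")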