snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in their public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/linear_model/ridge_cv.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class RidgeCV(BaseTransformer):
     r"""Ridge regression with built-in cross-validation
     For more details on this class, see [sklearn.linear_model.RidgeCV]
@@ -195,7 +207,9 @@ class RidgeCV(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -277,11 +291,6 @@ class RidgeCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -309,7 +318,9 @@ class RidgeCV(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -580,6 +591,22 @@ class RidgeCV(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -595,8 +622,8 @@ class RidgeCV(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -609,13 +636,21 @@ class RidgeCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/linear_model/sgd_classifier.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SGDClassifier(BaseTransformer):
     r"""Linear classifiers (SVM, logistic regression, etc
     For more details on this class, see [sklearn.linear_model.SGDClassifier]
@@ -300,7 +312,9 @@ class SGDClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -396,11 +410,6 @@ class SGDClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -428,7 +437,9 @@ class SGDClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -699,6 +710,22 @@ class SGDClassifier(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -714,8 +741,8 @@ class SGDClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -728,13 +755,21 @@ class SGDClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/linear_model/sgd_one_class_svm.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SGDOneClassSVM(BaseTransformer):
     r"""Solves linear One-Class SVM using Stochastic Gradient Descent
     For more details on this class, see [sklearn.linear_model.SGDOneClassSVM]
@@ -207,7 +219,9 @@ class SGDOneClassSVM(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -294,11 +308,6 @@ class SGDOneClassSVM(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -326,7 +335,9 @@ class SGDOneClassSVM(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -597,6 +608,22 @@ class SGDOneClassSVM(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -612,8 +639,8 @@ class SGDOneClassSVM(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Perform fit on X and returns labels for X
         For more details on this function, see [sklearn.linear_model.SGDOneClassSVM.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDOneClassSVM.html#sklearn.linear_model.SGDOneClassSVM.fit_predict)
@@ -628,13 +655,21 @@ class SGDOneClassSVM(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/linear_model/sgd_regressor.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SGDRegressor(BaseTransformer):
     r"""Linear model fitted by minimizing a regularized empirical loss with SGD
     For more details on this class, see [sklearn.linear_model.SGDRegressor]
@@ -268,7 +280,9 @@ class SGDRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -362,11 +376,6 @@ class SGDRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -394,7 +403,9 @@ class SGDRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -665,6 +676,22 @@ class SGDRegressor(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -680,8 +707,8 @@ class SGDRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -694,13 +721,21 @@ class SGDRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/linear_model/theil_sen_regressor.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class TheilSenRegressor(BaseTransformer):
     r"""Theil-Sen Estimator: robust multivariate regression model
     For more details on this class, see [sklearn.linear_model.TheilSenRegressor]
@@ -180,7 +192,9 @@ class TheilSenRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -264,11 +278,6 @@ class TheilSenRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -296,7 +305,9 @@ class TheilSenRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -567,6 +578,22 @@ class TheilSenRegressor(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -582,8 +609,8 @@ class TheilSenRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -596,13 +623,21 @@ class TheilSenRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
 
    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.