snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class MultiTaskLassoCV(BaseTransformer):
58
70
  r"""Multi-task Lasso model trained with L1/L2 mixed-norm as regularizer
59
71
  For more details on this class, see [sklearn.linear_model.MultiTaskLassoCV]
@@ -200,7 +212,9 @@ class MultiTaskLassoCV(BaseTransformer):
200
212
  self.set_label_cols(label_cols)
201
213
  self.set_passthrough_cols(passthrough_cols)
202
214
  self.set_drop_input_cols(drop_input_cols)
203
- self.set_sample_weight_col(sample_weight_col)
215
+ self.set_sample_weight_col(sample_weight_col)
216
+ self._use_external_memory_version = False
217
+ self._batch_size = -1
204
218
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
205
219
 
206
220
  self._deps = list(deps)
@@ -287,11 +301,6 @@ class MultiTaskLassoCV(BaseTransformer):
287
301
  if isinstance(dataset, DataFrame):
288
302
  session = dataset._session
289
303
  assert session is not None # keep mypy happy
290
- # Validate that key package version in user workspace are supported in snowflake conda channel
291
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
292
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
293
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
294
-
295
304
  # Specify input columns so column pruning will be enforced
296
305
  selected_cols = self._get_active_columns()
297
306
  if len(selected_cols) > 0:
@@ -319,7 +328,9 @@ class MultiTaskLassoCV(BaseTransformer):
319
328
  label_cols=self.label_cols,
320
329
  sample_weight_col=self.sample_weight_col,
321
330
  autogenerated=self._autogenerated,
322
- subproject=_SUBPROJECT
331
+ subproject=_SUBPROJECT,
332
+ use_external_memory_version=self._use_external_memory_version,
333
+ batch_size=self._batch_size,
323
334
  )
324
335
  self._sklearn_object = model_trainer.train()
325
336
  self._is_fitted = True
@@ -590,6 +601,22 @@ class MultiTaskLassoCV(BaseTransformer):
590
601
  # each row containing a list of values.
591
602
  expected_dtype = "ARRAY"
592
603
 
604
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
605
+ if expected_dtype == "":
606
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
607
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
608
+ expected_dtype = "ARRAY"
609
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
610
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
611
+ expected_dtype = "ARRAY"
612
+ else:
613
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
614
+ # We can only infer the output types from the input types if the following two statemetns are true:
615
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
616
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
617
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
618
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
619
+
593
620
  output_df = self._batch_inference(
594
621
  dataset=dataset,
595
622
  inference_method="transform",
@@ -605,8 +632,8 @@ class MultiTaskLassoCV(BaseTransformer):
605
632
 
606
633
  return output_df
607
634
 
608
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
609
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
635
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
636
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
610
637
  """ Method not supported for this class.
611
638
 
612
639
 
@@ -619,13 +646,21 @@ class MultiTaskLassoCV(BaseTransformer):
619
646
  Returns:
620
647
  Predicted dataset.
621
648
  """
622
- if False:
623
- self.fit(dataset)
624
- assert self._sklearn_object is not None
625
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
626
- return labels
627
- else:
628
- raise NotImplementedError
649
+ self.fit(dataset)
650
+ assert self._sklearn_object is not None
651
+ return self._sklearn_object.labels_
652
+
653
+
654
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
655
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
656
+ """
657
+ Returns:
658
+ Transformed dataset.
659
+ """
660
+ self.fit(dataset)
661
+ assert self._sklearn_object is not None
662
+ return self._sklearn_object.embedding_
663
+
629
664
 
630
665
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
631
666
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class OrthogonalMatchingPursuit(BaseTransformer):
58
70
  r"""Orthogonal Matching Pursuit model (OMP)
59
71
  For more details on this class, see [sklearn.linear_model.OrthogonalMatchingPursuit]
@@ -155,7 +167,9 @@ class OrthogonalMatchingPursuit(BaseTransformer):
155
167
  self.set_label_cols(label_cols)
156
168
  self.set_passthrough_cols(passthrough_cols)
157
169
  self.set_drop_input_cols(drop_input_cols)
158
- self.set_sample_weight_col(sample_weight_col)
170
+ self.set_sample_weight_col(sample_weight_col)
171
+ self._use_external_memory_version = False
172
+ self._batch_size = -1
159
173
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
160
174
 
161
175
  self._deps = list(deps)
@@ -235,11 +249,6 @@ class OrthogonalMatchingPursuit(BaseTransformer):
235
249
  if isinstance(dataset, DataFrame):
236
250
  session = dataset._session
237
251
  assert session is not None # keep mypy happy
238
- # Validate that key package version in user workspace are supported in snowflake conda channel
239
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
240
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
241
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
242
-
243
252
  # Specify input columns so column pruning will be enforced
244
253
  selected_cols = self._get_active_columns()
245
254
  if len(selected_cols) > 0:
@@ -267,7 +276,9 @@ class OrthogonalMatchingPursuit(BaseTransformer):
267
276
  label_cols=self.label_cols,
268
277
  sample_weight_col=self.sample_weight_col,
269
278
  autogenerated=self._autogenerated,
270
- subproject=_SUBPROJECT
279
+ subproject=_SUBPROJECT,
280
+ use_external_memory_version=self._use_external_memory_version,
281
+ batch_size=self._batch_size,
271
282
  )
272
283
  self._sklearn_object = model_trainer.train()
273
284
  self._is_fitted = True
@@ -538,6 +549,22 @@ class OrthogonalMatchingPursuit(BaseTransformer):
538
549
  # each row containing a list of values.
539
550
  expected_dtype = "ARRAY"
540
551
 
552
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
553
+ if expected_dtype == "":
554
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
555
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
556
+ expected_dtype = "ARRAY"
557
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
558
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
559
+ expected_dtype = "ARRAY"
560
+ else:
561
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
562
+ # We can only infer the output types from the input types if the following two statemetns are true:
563
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
564
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
565
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
566
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
567
+
541
568
  output_df = self._batch_inference(
542
569
  dataset=dataset,
543
570
  inference_method="transform",
@@ -553,8 +580,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
553
580
 
554
581
  return output_df
555
582
 
556
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
557
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
583
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
584
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
558
585
  """ Method not supported for this class.
559
586
 
560
587
 
@@ -567,13 +594,21 @@ class OrthogonalMatchingPursuit(BaseTransformer):
567
594
  Returns:
568
595
  Predicted dataset.
569
596
  """
570
- if False:
571
- self.fit(dataset)
572
- assert self._sklearn_object is not None
573
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
574
- return labels
575
- else:
576
- raise NotImplementedError
597
+ self.fit(dataset)
598
+ assert self._sklearn_object is not None
599
+ return self._sklearn_object.labels_
600
+
601
+
602
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
603
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
604
+ """
605
+ Returns:
606
+ Transformed dataset.
607
+ """
608
+ self.fit(dataset)
609
+ assert self._sklearn_object is not None
610
+ return self._sklearn_object.embedding_
611
+
577
612
 
578
613
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
579
614
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class PassiveAggressiveClassifier(BaseTransformer):
58
70
  r"""Passive Aggressive Classifier
59
71
  For more details on this class, see [sklearn.linear_model.PassiveAggressiveClassifier]
@@ -219,7 +231,9 @@ class PassiveAggressiveClassifier(BaseTransformer):
219
231
  self.set_label_cols(label_cols)
220
232
  self.set_passthrough_cols(passthrough_cols)
221
233
  self.set_drop_input_cols(drop_input_cols)
222
- self.set_sample_weight_col(sample_weight_col)
234
+ self.set_sample_weight_col(sample_weight_col)
235
+ self._use_external_memory_version = False
236
+ self._batch_size = -1
223
237
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
224
238
 
225
239
  self._deps = list(deps)
@@ -309,11 +323,6 @@ class PassiveAggressiveClassifier(BaseTransformer):
309
323
  if isinstance(dataset, DataFrame):
310
324
  session = dataset._session
311
325
  assert session is not None # keep mypy happy
312
- # Validate that key package version in user workspace are supported in snowflake conda channel
313
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
314
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
315
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
316
-
317
326
  # Specify input columns so column pruning will be enforced
318
327
  selected_cols = self._get_active_columns()
319
328
  if len(selected_cols) > 0:
@@ -341,7 +350,9 @@ class PassiveAggressiveClassifier(BaseTransformer):
341
350
  label_cols=self.label_cols,
342
351
  sample_weight_col=self.sample_weight_col,
343
352
  autogenerated=self._autogenerated,
344
- subproject=_SUBPROJECT
353
+ subproject=_SUBPROJECT,
354
+ use_external_memory_version=self._use_external_memory_version,
355
+ batch_size=self._batch_size,
345
356
  )
346
357
  self._sklearn_object = model_trainer.train()
347
358
  self._is_fitted = True
@@ -612,6 +623,22 @@ class PassiveAggressiveClassifier(BaseTransformer):
612
623
  # each row containing a list of values.
613
624
  expected_dtype = "ARRAY"
614
625
 
626
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
627
+ if expected_dtype == "":
628
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
629
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
630
+ expected_dtype = "ARRAY"
631
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
632
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
633
+ expected_dtype = "ARRAY"
634
+ else:
635
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
636
+ # We can only infer the output types from the input types if the following two statemetns are true:
637
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
638
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
639
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
640
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
641
+
615
642
  output_df = self._batch_inference(
616
643
  dataset=dataset,
617
644
  inference_method="transform",
@@ -627,8 +654,8 @@ class PassiveAggressiveClassifier(BaseTransformer):
627
654
 
628
655
  return output_df
629
656
 
630
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
631
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
657
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
658
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
632
659
  """ Method not supported for this class.
633
660
 
634
661
 
@@ -641,13 +668,21 @@ class PassiveAggressiveClassifier(BaseTransformer):
641
668
  Returns:
642
669
  Predicted dataset.
643
670
  """
644
- if False:
645
- self.fit(dataset)
646
- assert self._sklearn_object is not None
647
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
648
- return labels
649
- else:
650
- raise NotImplementedError
671
+ self.fit(dataset)
672
+ assert self._sklearn_object is not None
673
+ return self._sklearn_object.labels_
674
+
675
+
676
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
677
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
678
+ """
679
+ Returns:
680
+ Transformed dataset.
681
+ """
682
+ self.fit(dataset)
683
+ assert self._sklearn_object is not None
684
+ return self._sklearn_object.embedding_
685
+
651
686
 
652
687
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
653
688
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class PassiveAggressiveRegressor(BaseTransformer):
58
70
  r"""Passive Aggressive Regressor
59
71
  For more details on this class, see [sklearn.linear_model.PassiveAggressiveRegressor]
@@ -206,7 +218,9 @@ class PassiveAggressiveRegressor(BaseTransformer):
206
218
  self.set_label_cols(label_cols)
207
219
  self.set_passthrough_cols(passthrough_cols)
208
220
  self.set_drop_input_cols(drop_input_cols)
209
- self.set_sample_weight_col(sample_weight_col)
221
+ self.set_sample_weight_col(sample_weight_col)
222
+ self._use_external_memory_version = False
223
+ self._batch_size = -1
210
224
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
211
225
 
212
226
  self._deps = list(deps)
@@ -295,11 +309,6 @@ class PassiveAggressiveRegressor(BaseTransformer):
295
309
  if isinstance(dataset, DataFrame):
296
310
  session = dataset._session
297
311
  assert session is not None # keep mypy happy
298
- # Validate that key package version in user workspace are supported in snowflake conda channel
299
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
300
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
301
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
302
-
303
312
  # Specify input columns so column pruning will be enforced
304
313
  selected_cols = self._get_active_columns()
305
314
  if len(selected_cols) > 0:
@@ -327,7 +336,9 @@ class PassiveAggressiveRegressor(BaseTransformer):
327
336
  label_cols=self.label_cols,
328
337
  sample_weight_col=self.sample_weight_col,
329
338
  autogenerated=self._autogenerated,
330
- subproject=_SUBPROJECT
339
+ subproject=_SUBPROJECT,
340
+ use_external_memory_version=self._use_external_memory_version,
341
+ batch_size=self._batch_size,
331
342
  )
332
343
  self._sklearn_object = model_trainer.train()
333
344
  self._is_fitted = True
@@ -598,6 +609,22 @@ class PassiveAggressiveRegressor(BaseTransformer):
598
609
  # each row containing a list of values.
599
610
  expected_dtype = "ARRAY"
600
611
 
612
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
613
+ if expected_dtype == "":
614
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
615
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
616
+ expected_dtype = "ARRAY"
617
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
618
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
619
+ expected_dtype = "ARRAY"
620
+ else:
621
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
622
+ # We can only infer the output types from the input types if the following two statemetns are true:
623
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
624
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
625
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
626
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
627
+
601
628
  output_df = self._batch_inference(
602
629
  dataset=dataset,
603
630
  inference_method="transform",
@@ -613,8 +640,8 @@ class PassiveAggressiveRegressor(BaseTransformer):
613
640
 
614
641
  return output_df
615
642
 
616
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
617
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
643
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
644
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
618
645
  """ Method not supported for this class.
619
646
 
620
647
 
@@ -627,13 +654,21 @@ class PassiveAggressiveRegressor(BaseTransformer):
627
654
  Returns:
628
655
  Predicted dataset.
629
656
  """
630
- if False:
631
- self.fit(dataset)
632
- assert self._sklearn_object is not None
633
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
634
- return labels
635
- else:
636
- raise NotImplementedError
657
+ self.fit(dataset)
658
+ assert self._sklearn_object is not None
659
+ return self._sklearn_object.labels_
660
+
661
+
662
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
663
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
664
+ """
665
+ Returns:
666
+ Transformed dataset.
667
+ """
668
+ self.fit(dataset)
669
+ assert self._sklearn_object is not None
670
+ return self._sklearn_object.embedding_
671
+
637
672
 
638
673
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
639
674
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class Perceptron(BaseTransformer):
58
70
  r"""Linear perceptron classifier
59
71
  For more details on this class, see [sklearn.linear_model.Perceptron]
@@ -217,7 +229,9 @@ class Perceptron(BaseTransformer):
217
229
  self.set_label_cols(label_cols)
218
230
  self.set_passthrough_cols(passthrough_cols)
219
231
  self.set_drop_input_cols(drop_input_cols)
220
- self.set_sample_weight_col(sample_weight_col)
232
+ self.set_sample_weight_col(sample_weight_col)
233
+ self._use_external_memory_version = False
234
+ self._batch_size = -1
221
235
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
222
236
 
223
237
  self._deps = list(deps)
@@ -308,11 +322,6 @@ class Perceptron(BaseTransformer):
308
322
  if isinstance(dataset, DataFrame):
309
323
  session = dataset._session
310
324
  assert session is not None # keep mypy happy
311
- # Validate that key package version in user workspace are supported in snowflake conda channel
312
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
313
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
314
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
315
-
316
325
  # Specify input columns so column pruning will be enforced
317
326
  selected_cols = self._get_active_columns()
318
327
  if len(selected_cols) > 0:
@@ -340,7 +349,9 @@ class Perceptron(BaseTransformer):
340
349
  label_cols=self.label_cols,
341
350
  sample_weight_col=self.sample_weight_col,
342
351
  autogenerated=self._autogenerated,
343
- subproject=_SUBPROJECT
352
+ subproject=_SUBPROJECT,
353
+ use_external_memory_version=self._use_external_memory_version,
354
+ batch_size=self._batch_size,
344
355
  )
345
356
  self._sklearn_object = model_trainer.train()
346
357
  self._is_fitted = True
@@ -611,6 +622,22 @@ class Perceptron(BaseTransformer):
611
622
  # each row containing a list of values.
612
623
  expected_dtype = "ARRAY"
613
624
 
625
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
626
+ if expected_dtype == "":
627
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
628
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
629
+ expected_dtype = "ARRAY"
630
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
631
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
632
+ expected_dtype = "ARRAY"
633
+ else:
634
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
635
+ # We can only infer the output types from the input types if the following two statemetns are true:
636
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
637
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
638
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
639
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
640
+
614
641
  output_df = self._batch_inference(
615
642
  dataset=dataset,
616
643
  inference_method="transform",
@@ -626,8 +653,8 @@ class Perceptron(BaseTransformer):
626
653
 
627
654
  return output_df
628
655
 
629
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
630
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
656
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
657
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
631
658
  """ Method not supported for this class.
632
659
 
633
660
 
@@ -640,13 +667,21 @@ class Perceptron(BaseTransformer):
640
667
  Returns:
641
668
  Predicted dataset.
642
669
  """
643
- if False:
644
- self.fit(dataset)
645
- assert self._sklearn_object is not None
646
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
647
- return labels
648
- else:
649
- raise NotImplementedError
670
+ self.fit(dataset)
671
+ assert self._sklearn_object is not None
672
+ return self._sklearn_object.labels_
673
+
674
+
675
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
676
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
677
+ """
678
+ Returns:
679
+ Transformed dataset.
680
+ """
681
+ self.fit(dataset)
682
+ assert self._sklearn_object is not None
683
+ return self._sklearn_object.embedding_
684
+
650
685
 
651
686
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
652
687
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.