snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff compares publicly available package versions as they were released to their respective public registries; it is provided for informational purposes only.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
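Most of the churn above is in the autogenerated modeling wrappers (the repeated +51 -16 entries), but the headline change is the reworked registry surface: snowflake/ml/registry/registry.py, the new _manager/model_manager.py, and tag support in _client/sql/tag.py. A minimal sketch of how the 1.2.x Registry entry point is used, assuming an active Snowpark session; the model object and names are illustrative:

    from snowflake.ml.registry import Registry

    reg = Registry(session=session)   # `session` is an existing snowpark.Session
    mv = reg.log_model(
        my_model,                     # illustrative: any supported model object
        model_name="MY_MODEL",
        version_name="v1",
    )
    mv.run(test_df, function_name="predict")  # invoke the model inside Snowflake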
snowflake/ml/modeling/neural_network/mlp_classifier.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neural_network".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MLPClassifier(BaseTransformer):
     r"""Multi-layer Perceptron classifier
     For more details on this class, see [sklearn.neural_network.MLPClassifier]
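The two factories just added are what these generated wrappers now hand to scikit-learn's available_if decorator (see the fit_predict/fit_transform hunks below): because `False and ...` always short-circuits to False, the gated methods stay hidden on this estimator. A standalone sketch of that gating, assuming scikit-learn >= 1.0; the Demo class is illustrative:

    from sklearn.utils.metaestimators import available_if

    def _is_demo_method_enabled():
        # Same shape as the generated factories: the leading `False and`
        # guarantees the check reports the method as unavailable.
        def check(self) -> bool:
            return False and callable(getattr(self, "_impl", None))
        return check

    class Demo:
        @available_if(_is_demo_method_enabled())
        def fit_predict(self, X):
            ...

    hasattr(Demo(), "fit_predict")  # False: available_if hides the attribute,
                                    # so callers behave as if the method is absent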
@@ -297,7 +309,9 @@ class MLPClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -395,11 +409,6 @@ class MLPClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -427,7 +436,9 @@ class MLPClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
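The new use_external_memory_version/batch_size arguments (hard-coded off for this estimator) are threaded into the model trainer construction; per the file list, their consumer is the new snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py. For context, a minimal sketch of the XGBoost external-memory pattern such a trainer builds on, assuming xgboost >= 1.5; the file names and batch layout are illustrative:

    import os
    import xgboost as xgb
    from sklearn.datasets import load_svmlight_file

    class BatchIterator(xgb.DataIter):
        """Feed XGBoost one batch at a time instead of one in-memory matrix."""
        def __init__(self, paths):
            self._paths = paths
            self._idx = 0
            super().__init__(cache_prefix=os.path.join(".", "cache"))

        def next(self, input_data):
            if self._idx == len(self._paths):
                return 0  # no more batches
            X, y = load_svmlight_file(self._paths[self._idx])
            input_data(data=X, label=y)
            self._idx += 1
            return 1

        def reset(self):
            self._idx = 0

    # A DMatrix built from an iterator pages batches through an on-disk cache.
    dtrain = xgb.DMatrix(BatchIterator(["part-0.svm", "part-1.svm"]))
    booster = xgb.train({"tree_method": "hist"}, dtrain)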
@@ -698,6 +709,22 @@ class MLPClassifier(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
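The fallback above commits to a concrete Snowflake type only when every type inferred from the input columns matches and the input/output column counts line up; clustering and decomposition transformers whose output width differs from n_clusters/n_components go to "ARRAY", and everything else stays "" (variant downstream). A hypothetical standalone rendering of that decision, with convert_sp_to_sf_type passed in for illustration:

    def infer_transform_dtype(skl_obj, output_cols, input_sp_types, convert_sp_to_sf_type):
        expected_dtype = ""
        if hasattr(skl_obj, "n_clusters") and skl_obj.n_clusters != len(output_cols):
            expected_dtype = "ARRAY"   # clustering: row holds one value per cluster
        elif hasattr(skl_obj, "n_components") and skl_obj.n_components != len(output_cols):
            expected_dtype = "ARRAY"   # decomposition: one value per component
        elif input_sp_types and all(t == input_sp_types[0] for t in input_sp_types) \
                and len(input_sp_types) == len(output_cols):
            expected_dtype = convert_sp_to_sf_type(input_sp_types[0])
        return expected_dtype          # "" means: fall back to VARIANT downstream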
@@ -713,8 +740,8 @@ class MLPClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -727,13 +754,21 @@ class MLPClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/neural_network/mlp_regressor.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neural_network".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MLPRegressor(BaseTransformer):
     r"""Multi-layer Perceptron regressor
     For more details on this class, see [sklearn.neural_network.MLPRegressor]
@@ -293,7 +305,9 @@ class MLPRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -391,11 +405,6 @@ class MLPRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -423,7 +432,9 @@ class MLPRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -694,6 +705,22 @@ class MLPRegressor(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -709,8 +736,8 @@ class MLPRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -723,13 +750,21 @@ class MLPRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/preprocessing/min_max_scaler.py
@@ -8,8 +8,9 @@ from sklearn.preprocessing import _data as sklearn_preprocessing_data
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
+from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml.modeling.framework import _utils, base
-from snowflake.snowpark import functions as F
+from snowflake.snowpark import functions as F, types as T
 
 
 class MinMaxScaler(base.BaseTransformer):
@@ -125,6 +126,18 @@ class MinMaxScaler(base.BaseTransformer):
         self.data_max_ = {}
         self.data_range_ = {}
 
+    def _check_input_column_types(self, dataset: snowpark.DataFrame) -> None:
+        for field in dataset.schema.fields:
+            if field.name in self.input_cols:
+                if not issubclass(type(field.datatype), T._NumericType):
+                    raise exceptions.SnowflakeMLException(
+                        error_code=error_codes.INVALID_DATA_TYPE,
+                        original_exception=TypeError(
+                            f"Non-numeric input column {field.name} datatype {field.datatype} "
+                            "is not supported by the MinMaxScaler."
+                        ),
+                    )
+
     @telemetry.send_api_usage_telemetry(
         project=base.PROJECT,
         subproject=base.SUBPROJECT,
@@ -169,6 +182,7 @@ class MinMaxScaler(base.BaseTransformer):
             self.data_range_[input_col] = float(sklearn_scaler.data_range_[i])
 
     def _fit_snowpark(self, dataset: snowpark.DataFrame) -> None:
+        self._check_input_column_types(dataset)
         computed_states = self._compute(dataset, self.input_cols, self.custom_states)
 
         # assign states to the object
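With this check, fitting MinMaxScaler on a non-numeric column now fails fast with a typed error instead of surfacing later as a query failure. Hypothetical usage, assuming an active Snowpark session:

    from snowflake.ml.modeling.preprocessing import MinMaxScaler

    df = session.create_dataframe([("a", 1.0), ("b", 2.0)], schema=["CAT", "NUM"])
    scaler = MinMaxScaler(input_cols=["CAT"], output_cols=["CAT_SCALED"])
    scaler.fit(df)  # raises SnowflakeMLException (INVALID_DATA_TYPE): CAT is a StringType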
snowflake/ml/modeling/preprocessing/polynomial_features.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.preprocessing".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class PolynomialFeatures(BaseTransformer):
     r"""Generate polynomial and interaction features
     For more details on this class, see [sklearn.preprocessing.PolynomialFeatures]
@@ -151,7 +163,9 @@ class PolynomialFeatures(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -230,11 +244,6 @@ class PolynomialFeatures(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -262,7 +271,9 @@ class PolynomialFeatures(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -533,6 +544,22 @@ class PolynomialFeatures(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -548,8 +575,8 @@ class PolynomialFeatures(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -562,13 +589,21 @@ class PolynomialFeatures(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/semi_supervised/label_propagation.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.semi_supervised".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LabelPropagation(BaseTransformer):
     r"""Label Propagation classifier
     For more details on this class, see [sklearn.semi_supervised.LabelPropagation]
@@ -155,7 +167,9 @@ class LabelPropagation(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -236,11 +250,6 @@ class LabelPropagation(BaseTransformer):
         if isinstance(dataset, DataFrame):
            session = dataset._session
            assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
            # Specify input columns so column pruning will be enforced
            selected_cols = self._get_active_columns()
            if len(selected_cols) > 0:
@@ -268,7 +277,9 @@ class LabelPropagation(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -539,6 +550,22 @@ class LabelPropagation(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -554,8 +581,8 @@ class LabelPropagation(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -568,13 +595,21 @@ class LabelPropagation(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/semi_supervised/label_spreading.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.semi_supervised".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LabelSpreading(BaseTransformer):
     r"""LabelSpreading model for semi-supervised learning
     For more details on this class, see [sklearn.semi_supervised.LabelSpreading]
@@ -163,7 +175,9 @@ class LabelSpreading(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -245,11 +259,6 @@ class LabelSpreading(BaseTransformer):
         if isinstance(dataset, DataFrame):
            session = dataset._session
            assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
            # Specify input columns so column pruning will be enforced
            selected_cols = self._get_active_columns()
            if len(selected_cols) > 0:
@@ -277,7 +286,9 @@ class LabelSpreading(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -548,6 +559,22 @@ class LabelSpreading(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -563,8 +590,8 @@ class LabelSpreading(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -577,13 +604,21 @@ class LabelSpreading(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.