snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/metrics/ranking.py

@@ -122,7 +122,8 @@ def precision_recall_curve(
         result_module = cloudpickle.loads(pickled_result_module)
         return result_module.serialize(session, (precision, recall, thresholds))  # type: ignore[no-any-return]
 
-    result_object = result.deserialize(session, precision_recall_curve_anon_sproc(session))
+    kwargs = telemetry.get_sproc_statement_params_kwargs(precision_recall_curve_anon_sproc, statement_params)
+    result_object = result.deserialize(session, precision_recall_curve_anon_sproc(session, **kwargs))
     res: Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
     return res
 
@@ -271,7 +272,8 @@ def roc_auc_score(
         result_module = cloudpickle.loads(pickled_result_module)
         return result_module.serialize(session, auc)  # type: ignore[no-any-return]
 
-    result_object = result.deserialize(session, roc_auc_score_anon_sproc(session))
+    kwargs = telemetry.get_sproc_statement_params_kwargs(roc_auc_score_anon_sproc, statement_params)
+    result_object = result.deserialize(session, roc_auc_score_anon_sproc(session, **kwargs))
     auc: Union[float, npt.NDArray[np.float_]] = result_object
     return auc
 
@@ -372,7 +374,9 @@ def roc_curve(
         result_module = cloudpickle.loads(pickled_result_module)
         return result_module.serialize(session, (fpr, tpr, thresholds))  # type: ignore[no-any-return]
 
-    result_object = result.deserialize(session, roc_curve_anon_sproc(session))
+    kwargs = telemetry.get_sproc_statement_params_kwargs(roc_curve_anon_sproc, statement_params)
+    result_object = result.deserialize(session, roc_curve_anon_sproc(session, **kwargs))
+
     res: Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
 
     return res
snowflake/ml/modeling/metrics/regression.py

@@ -108,7 +108,8 @@ def d2_absolute_error_score(
         result_module = cloudpickle.loads(pickled_snowflake_result)
         return result_module.serialize(session, score)  # type: ignore[no-any-return]
 
-    result_object = result.deserialize(session, d2_absolute_error_score_anon_sproc(session))
+    kwargs = telemetry.get_sproc_statement_params_kwargs(d2_absolute_error_score_anon_sproc, statement_params)
+    result_object = result.deserialize(session, d2_absolute_error_score_anon_sproc(session, **kwargs))
     score: Union[float, npt.NDArray[np.float_]] = result_object
     return score
 
@@ -205,7 +206,8 @@ def d2_pinball_score(
         result_module = cloudpickle.loads(pickled_result_module)
         return result_module.serialize(session, score)  # type: ignore[no-any-return]
 
-    result_object = result.deserialize(session, d2_pinball_score_anon_sproc(session))
+    kwargs = telemetry.get_sproc_statement_params_kwargs(d2_pinball_score_anon_sproc, statement_params)
+    result_object = result.deserialize(session, d2_pinball_score_anon_sproc(session, **kwargs))
 
     score: Union[float, npt.NDArray[np.float_]] = result_object
     return score
 
@@ -319,7 +321,8 @@ def explained_variance_score(
         result_module = cloudpickle.loads(pickled_result_module)
         return result_module.serialize(session, score)  # type: ignore[no-any-return]
 
-    result_object = result.deserialize(session, explained_variance_score_anon_sproc(session))
+    kwargs = telemetry.get_sproc_statement_params_kwargs(explained_variance_score_anon_sproc, statement_params)
+    result_object = result.deserialize(session, explained_variance_score_anon_sproc(session, **kwargs))
     score: Union[float, npt.NDArray[np.float_]] = result_object
     return score
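The change is identical at all six call sites across ranking.py and regression.py: each anonymous stored-procedure call now forwards statement-level telemetry parameters. A minimal sketch of what `telemetry.get_sproc_statement_params_kwargs` plausibly does (the helper name and call shape come from this diff; the body below is an assumption, not the packaged implementation):

    import inspect
    from typing import Any, Callable, Dict

    def get_sproc_statement_params_kwargs(sproc: Callable[..., Any],
                                          statement_params: Dict[str, str]) -> Dict[str, Any]:
        # Forward statement_params only if the sproc declares the parameter,
        # so older sproc wrappers keep working with a plain (session) call.
        if "statement_params" in inspect.signature(sproc).parameters:
            return {"statement_params": statement_params}
        return {}

The call sites then unpack the result, e.g. `roc_curve_anon_sproc(session, **kwargs)`, so a sproc without the parameter is invoked exactly as before.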
snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.mixture".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class BayesianGaussianMixture(BaseTransformer):
     r"""Variational Bayesian estimation of a Gaussian mixture
     For more details on this class, see [sklearn.mixture.BayesianGaussianMixture]
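These always-off helpers are wired into scikit-learn's `available_if` descriptor on the `fit_predict`/`fit_transform` methods shown further down. Since `False and callable(...)` short-circuits to `False`, the decorated method raises `AttributeError` on attribute access and is invisible to `hasattr`, which is how the autogenerated classes keep the methods declared but disabled. A minimal, self-contained sketch of that gating mechanism (the `Demo` class is illustrative, not part of the package):

    from sklearn.utils.metaestimators import available_if

    class Demo:
        def _never_enabled(self) -> bool:
            # Mirrors the generated check: always False.
            return False

        @available_if(_never_enabled)
        def fit_predict(self) -> str:
            return "never reached"

    print(hasattr(Demo(), "fit_predict"))  # False: access raises AttributeError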
@@ -241,7 +253,9 @@ class BayesianGaussianMixture(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -333,11 +347,6 @@ class BayesianGaussianMixture(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -365,7 +374,9 @@ class BayesianGaussianMixture(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -636,6 +647,22 @@ class BayesianGaussianMixture(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
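The block added here is a type-inference fallback for transforms whose output type the code-generation factory left blank: a cluster or component count that differs from the number of output columns forces an ARRAY column, and otherwise a concrete Snowflake type is reused from the input signature only when every inferred type matches and the column counts line up. The same decision, extracted into a standalone sketch (function and parameter names are illustrative; `_infer_signature` and `convert_sp_to_sf_type` are the helpers the diff itself calls):

    from typing import Any, Callable, List, Sequence

    def infer_expected_dtype(sk_obj: Any, output_cols: Sequence[str],
                             input_types: List[Any],
                             convert: Callable[[Any], str]) -> str:
        # Mismatched cluster/component counts mean each row holds a list -> ARRAY.
        if hasattr(sk_obj, "n_clusters") and sk_obj.n_clusters != len(output_cols):
            return "ARRAY"
        if hasattr(sk_obj, "n_components") and sk_obj.n_components != len(output_cols):
            return "ARRAY"
        # Reuse the common input type only when all types agree and the
        # input and output column counts match.
        if input_types and all(t == input_types[0] for t in input_types) and len(input_types) == len(output_cols):
            return convert(input_types[0])
        return ""  # unresolved: the caller falls back to variant handling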
@@ -651,8 +678,8 @@ class BayesianGaussianMixture(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Estimate model parameters using X and predict the labels for X
         For more details on this function, see [sklearn.mixture.BayesianGaussianMixture.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.mixture.BayesianGaussianMixture.html#sklearn.mixture.BayesianGaussianMixture.fit_predict)
@@ -667,13 +694,21 @@ class BayesianGaussianMixture(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/mixture/gaussian_mixture.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.mixture".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class GaussianMixture(BaseTransformer):
     r"""Gaussian Mixture
     For more details on this class, see [sklearn.mixture.GaussianMixture]
@@ -217,7 +229,9 @@ class GaussianMixture(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -306,11 +320,6 @@ class GaussianMixture(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -338,7 +347,9 @@ class GaussianMixture(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -609,6 +620,22 @@ class GaussianMixture(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -624,8 +651,8 @@ class GaussianMixture(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Estimate model parameters using X and predict the labels for X
         For more details on this function, see [sklearn.mixture.GaussianMixture.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html#sklearn.mixture.GaussianMixture.fit_predict)
@@ -640,13 +667,21 @@ class GaussianMixture(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/multiclass/one_vs_one_classifier.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.multiclass".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class OneVsOneClassifier(BaseTransformer):
     r"""One-vs-one multiclass strategy
     For more details on this class, see [sklearn.multiclass.OneVsOneClassifier]
@@ -141,7 +153,9 @@ class OneVsOneClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         self._deps = list(deps)
@@ -218,11 +232,6 @@ class OneVsOneClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -250,7 +259,9 @@ class OneVsOneClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -521,6 +532,22 @@ class OneVsOneClassifier(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -536,8 +563,8 @@ class OneVsOneClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -550,13 +577,21 @@ class OneVsOneClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.multiclass".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class OneVsRestClassifier(BaseTransformer):
     r"""One-vs-the-rest (OvR) multiclass strategy
     For more details on this class, see [sklearn.multiclass.OneVsRestClassifier]
@@ -149,7 +161,9 @@ class OneVsRestClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         self._deps = list(deps)
@@ -227,11 +241,6 @@ class OneVsRestClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -259,7 +268,9 @@ class OneVsRestClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -530,6 +541,22 @@ class OneVsRestClassifier(BaseTransformer):
            # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -545,8 +572,8 @@ class OneVsRestClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -559,13 +586,21 @@ class OneVsRestClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/multiclass/output_code_classifier.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.multiclass".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class OutputCodeClassifier(BaseTransformer):
     r"""(Error-Correcting) Output-Code multiclass strategy
     For more details on this class, see [sklearn.multiclass.OutputCodeClassifier]
@@ -151,7 +163,9 @@ class OutputCodeClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         self._deps = list(deps)
@@ -230,11 +244,6 @@ class OutputCodeClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -262,7 +271,9 @@ class OutputCodeClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -533,6 +544,22 @@ class OutputCodeClassifier(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -548,8 +575,8 @@ class OutputCodeClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -562,13 +589,21 @@ class OutputCodeClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.