snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/linear_model/tweedie_regressor.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class TweedieRegressor(BaseTransformer):
     r"""Generalized Linear Model with a Tweedie distribution
     For more details on this class, see [sklearn.linear_model.TweedieRegressor]
@@ -206,7 +218,9 @@ class TweedieRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -290,11 +304,6 @@ class TweedieRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -322,7 +331,9 @@ class TweedieRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -593,6 +604,22 @@ class TweedieRegressor(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -608,8 +635,8 @@ class TweedieRegressor(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -622,13 +649,21 @@ class TweedieRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/manifold/isomap.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.manifold".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class Isomap(BaseTransformer):
     r"""Isomap Embedding
     For more details on this class, see [sklearn.manifold.Isomap]
@@ -199,7 +211,9 @@ class Isomap(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -286,11 +300,6 @@ class Isomap(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -318,7 +327,9 @@ class Isomap(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -589,6 +600,22 @@ class Isomap(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -604,8 +631,8 @@ class Isomap(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -618,13 +645,21 @@ class Isomap(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/manifold/mds.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.manifold".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MDS(BaseTransformer):
     r"""Multidimensional scaling
     For more details on this class, see [sklearn.manifold.MDS]
@@ -184,7 +196,9 @@ class MDS(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -269,11 +283,6 @@ class MDS(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -301,7 +310,9 @@ class MDS(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -570,6 +581,22 @@ class MDS(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -585,8 +612,8 @@ class MDS(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -599,13 +626,21 @@ class MDS(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/manifold/spectral_embedding.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.manifold".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SpectralEmbedding(BaseTransformer):
     r"""Spectral embedding for non-linear dimensionality reduction
     For more details on this class, see [sklearn.manifold.SpectralEmbedding]
@@ -188,7 +200,9 @@ class SpectralEmbedding(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -271,11 +285,6 @@ class SpectralEmbedding(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -303,7 +312,9 @@ class SpectralEmbedding(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -572,6 +583,22 @@ class SpectralEmbedding(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -587,8 +614,8 @@ class SpectralEmbedding(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -601,13 +628,21 @@ class SpectralEmbedding(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/manifold/tsne.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.manifold".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class TSNE(BaseTransformer):
     r"""T-distributed Stochastic Neighbor Embedding
     For more details on this class, see [sklearn.manifold.TSNE]
@@ -240,7 +252,9 @@ class TSNE(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -330,11 +344,6 @@ class TSNE(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -362,7 +371,9 @@ class TSNE(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -631,6 +642,22 @@ class TSNE(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -646,8 +673,8 @@ class TSNE(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -660,13 +687,21 @@ class TSNE(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/metrics/classification.py

@@ -228,16 +228,15 @@ def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_
     Returns:
         Name of the UDTF.
     """
+    batch_size = metrics_utils.BATCH_SIZE

     class ConfusionMatrixComputer:
-        BATCH_SIZE = 1000
-
         def __init__(self) -> None:
             self._initialized = False
             self._confusion_matrix = np.zeros((1, 1))
-            # 2d array containing a batch of input rows. A batch contains self.BATCH_SIZE rows.
+            # 2d array containing a batch of input rows. A batch contains metrics_utils.BATCH_SIZE rows.
             # [sample_weight, y_true, y_pred]
-            self._batched_rows = np.zeros((self.BATCH_SIZE, 1))
+            self._batched_rows = np.zeros((batch_size, 1))
             # Number of columns in the dataset.
             self._n_cols = -1
             # Running count of number of rows added to self._batched_rows.
@@ -255,7 +254,7 @@ def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_
             # 1. Initialize variables.
             if not self._initialized:
                 self._n_cols = len(input_row)
-                self._batched_rows = np.zeros((self.BATCH_SIZE, self._n_cols))
+                self._batched_rows = np.zeros((batch_size, self._n_cols))
                 self._n_label = n_label
                 self._confusion_matrix = np.zeros((self._n_label, self._n_label))
                 self._initialized = True
@@ -264,7 +263,7 @@ def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_
             self._cur_count += 1

             # 2. Compute incremental confusion matrix for the batch.
-            if self._cur_count >= self.BATCH_SIZE:
+            if self._cur_count >= batch_size:
                 self.update_confusion_matrix()
                 self._cur_count = 0

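
The hunk above hoists the batch size out of the handler class: `batch_size` is read once from `metrics_utils.BATCH_SIZE` and captured by the closure that defines `ConfusionMatrixComputer`, so the UDTF shares one module-level constant instead of a per-class copy. The buffering pattern itself, reduced to a runnable sketch with hypothetical names (`BatchedAccumulator`, `process`, `end_partition` mimic the UDTF handler shape; the real handler also tracks labels and builds the confusion matrix):

import numpy as np

BATCH_SIZE = 1000  # module-level constant, mirroring metrics_utils.BATCH_SIZE


class BatchedAccumulator:
    def __init__(self, n_cols: int) -> None:
        self._batched_rows = np.zeros((BATCH_SIZE, n_cols))
        self._cur_count = 0
        self._total = np.zeros(n_cols)

    def process(self, row: np.ndarray) -> None:
        # Buffer rows; only do the vectorized work once per full batch.
        self._batched_rows[self._cur_count] = row
        self._cur_count += 1
        if self._cur_count >= BATCH_SIZE:
            self._flush()

    def _flush(self) -> None:
        self._total += self._batched_rows[: self._cur_count].sum(axis=0)
        self._cur_count = 0

    def end_partition(self) -> np.ndarray:
        self._flush()  # drain the final partial batch
        return self._total


acc = BatchedAccumulator(n_cols=3)
for _ in range(2500):
    acc.process(np.ones(3))
print(acc.end_partition())  # [2500. 2500. 2500.]

Buffering rows before each update keeps the per-row path cheap and lets the flush step use vectorized NumPy operations instead of touching the accumulator on every input row.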
snowflake/ml/modeling/metrics/metrics_utils.py

@@ -15,6 +15,7 @@ from snowflake.snowpark import Session, functions as F, types as T

 LABEL = "LABEL"
 INDEX = "INDEX"
+BATCH_SIZE = 1000


 def register_accumulator_udtf(*, session: Session, statement_params: Dict[str, Any]) -> str:
@@ -82,7 +83,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: Dic
        """This class is registered as a UDTF and computes the sum and dot product
        of columns for each partition of rows. The computations across all the partitions happens
        in parallel using the nodes in the warehouse. In order to avoid keeping the entire partition
-       in memory, we batch the rows (size is 1000) and maintain a running sum and dot prod in self._sum_by_count,
+       in memory, we batch the rows and maintain a running sum and dot prod in self._sum_by_count,
        self._sum_by_countd and self._dot_prod respectively. We return these at the end of the partition.
        """

@@ -95,7 +96,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: Dic
            # delta degree of freedom
            self._ddof = 0
            # Setting the batch size to 1000 based on experimentation. Can be fine tuned later.
-           self._batch_size = 1000
+           self._batch_size = BATCH_SIZE
            # 2d array containing a batch of input rows. A batch contains self._batch_size rows.
            self._batched_rows = np.zeros((self._batch_size, 1))
            # 1d array of length = # of cols. Contains sum(col/count) for each column.
@@ -224,7 +225,7 @@ def check_label_columns(
        TypeError: `y_true_col_names` and `y_pred_col_names` are of different types.
        ValueError: Multilabel `y_true_col_names` and `y_pred_col_names` are of different lengths.
    """
-   if type(y_true_col_names) != type(y_pred_col_names):
+   if type(y_true_col_names) is not type(y_pred_col_names):
        raise TypeError(
            "Label columns should be of the same type."
            f"Got y_true_col_names={type(y_true_col_names)} vs y_pred_col_names={type(y_pred_col_names)}."
@@ -300,6 +301,7 @@ def validate_average_pos_label(average: Optional[str] = None, pos_label: Union[s
            "average != 'binary' (got %r). You may use "
            "labels=[pos_label] to specify a single positive class." % (pos_label, average),
            UserWarning,
+           stacklevel=2,
        )

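
Two small correctness fixes round out metrics_utils.py: `type(a) is not type(b)` compares types by identity, which is what flake8's E721 check asks for and avoids any metaclass `__eq__` surprises, and `stacklevel=2` makes `warnings.warn` attribute the message to the caller of `validate_average_pos_label` rather than to the helper's own line. A quick illustration of both (hypothetical `helper`/`user_code` names):

import warnings

# Type checks should use identity: each type is a singleton object, so `is not`
# is exact and cannot be fooled by an overloaded __eq__ (flake8 E721).
y_true_cols = ["A", "B"]
y_pred_cols = "C"
if type(y_true_cols) is not type(y_pred_cols):
    print("mismatched label column types")  # list vs str


def helper() -> None:
    # stacklevel=2 points the warning one frame up, at helper()'s caller, so
    # users see their own line in the report instead of this internal one.
    warnings.warn("pos_label is ignored", UserWarning, stacklevel=2)


def user_code() -> None:
    helper()


user_code()  # the emitted warning references the helper() call inside user_code()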