snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (234)
  1. snowflake/ml/_internal/env_utils.py +77 -32
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/_internal/utils/identifier.py +3 -1
  8. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  9. snowflake/ml/dataset/__init__.py +10 -0
  10. snowflake/ml/dataset/dataset.py +454 -129
  11. snowflake/ml/dataset/dataset_factory.py +53 -0
  12. snowflake/ml/dataset/dataset_metadata.py +103 -0
  13. snowflake/ml/dataset/dataset_reader.py +202 -0
  14. snowflake/ml/feature_store/feature_store.py +531 -332
  15. snowflake/ml/feature_store/feature_view.py +40 -23
  16. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  17. snowflake/ml/fileset/sfcfs.py +56 -54
  18. snowflake/ml/fileset/snowfs.py +159 -0
  19. snowflake/ml/fileset/stage_fs.py +49 -17
  20. snowflake/ml/model/__init__.py +2 -2
  21. snowflake/ml/model/_api.py +16 -1
  22. snowflake/ml/model/_client/model/model_impl.py +27 -0
  23. snowflake/ml/model/_client/model/model_version_impl.py +137 -50
  24. snowflake/ml/model/_client/ops/model_ops.py +159 -40
  25. snowflake/ml/model/_client/sql/model.py +25 -2
  26. snowflake/ml/model/_client/sql/model_version.py +131 -2
  27. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  28. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  29. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  30. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  31. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  32. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
  34. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  36. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  37. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  39. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  40. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
  42. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  43. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  44. snowflake/ml/model/_packager/model_packager.py +2 -5
  45. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  46. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  47. snowflake/ml/model/type_hints.py +21 -2
  48. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  49. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  50. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  51. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  52. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  53. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  54. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
  55. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
  56. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  57. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
  58. snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
  59. snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
  60. snowflake/ml/modeling/cluster/birch.py +248 -175
  61. snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
  62. snowflake/ml/modeling/cluster/dbscan.py +246 -175
  63. snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
  64. snowflake/ml/modeling/cluster/k_means.py +248 -175
  65. snowflake/ml/modeling/cluster/mean_shift.py +246 -175
  66. snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
  67. snowflake/ml/modeling/cluster/optics.py +246 -175
  68. snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
  69. snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
  70. snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
  71. snowflake/ml/modeling/compose/column_transformer.py +248 -175
  72. snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
  73. snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
  74. snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
  75. snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
  76. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
  77. snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
  78. snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
  79. snowflake/ml/modeling/covariance/oas.py +246 -175
  80. snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
  81. snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
  82. snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
  83. snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
  84. snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
  85. snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
  86. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
  87. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
  88. snowflake/ml/modeling/decomposition/pca.py +248 -175
  89. snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
  90. snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
  91. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
  92. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
  93. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
  94. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
  95. snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
  96. snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
  97. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
  98. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
  99. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
  100. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
  101. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
  102. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
  103. snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
  104. snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
  105. snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
  106. snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
  107. snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
  108. snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
  109. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
  110. snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
  111. snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
  112. snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
  113. snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
  114. snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
  115. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
  116. snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
  117. snowflake/ml/modeling/framework/_utils.py +8 -1
  118. snowflake/ml/modeling/framework/base.py +72 -37
  119. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
  120. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
  121. snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
  122. snowflake/ml/modeling/impute/knn_imputer.py +248 -175
  123. snowflake/ml/modeling/impute/missing_indicator.py +248 -175
  124. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
  125. snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
  126. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
  127. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
  128. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
  129. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
  130. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
  131. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
  132. snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
  133. snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
  134. snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
  135. snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
  136. snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
  137. snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
  138. snowflake/ml/modeling/linear_model/lars.py +246 -175
  139. snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
  140. snowflake/ml/modeling/linear_model/lasso.py +246 -175
  141. snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
  142. snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
  143. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
  144. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
  145. snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
  146. snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
  147. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
  148. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
  149. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
  150. snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
  151. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
  152. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
  153. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
  154. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
  155. snowflake/ml/modeling/linear_model/perceptron.py +246 -175
  156. snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
  157. snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
  158. snowflake/ml/modeling/linear_model/ridge.py +246 -175
  159. snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
  160. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
  161. snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
  162. snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
  163. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
  164. snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
  165. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
  166. snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
  167. snowflake/ml/modeling/manifold/isomap.py +248 -175
  168. snowflake/ml/modeling/manifold/mds.py +248 -175
  169. snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
  170. snowflake/ml/modeling/manifold/tsne.py +248 -175
  171. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
  172. snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
  173. snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
  174. snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
  175. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
  176. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
  177. snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
  178. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
  179. snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
  180. snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
  181. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
  182. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
  183. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
  184. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
  185. snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
  186. snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
  187. snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
  188. snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
  189. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
  190. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
  191. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
  192. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
  193. snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
  194. snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
  195. snowflake/ml/modeling/pipeline/pipeline.py +517 -35
  196. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  197. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  198. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  199. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  200. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  201. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  202. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
  203. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  204. snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
  205. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  206. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  207. snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
  208. snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
  209. snowflake/ml/modeling/svm/linear_svc.py +246 -175
  210. snowflake/ml/modeling/svm/linear_svr.py +246 -175
  211. snowflake/ml/modeling/svm/nu_svc.py +246 -175
  212. snowflake/ml/modeling/svm/nu_svr.py +246 -175
  213. snowflake/ml/modeling/svm/svc.py +246 -175
  214. snowflake/ml/modeling/svm/svr.py +246 -175
  215. snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
  216. snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
  217. snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
  218. snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
  219. snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
  220. snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
  221. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
  222. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
  223. snowflake/ml/registry/model_registry.py +3 -149
  224. snowflake/ml/registry/registry.py +1 -1
  225. snowflake/ml/version.py +1 -1
  226. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
  227. snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
  228. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  229. snowflake/ml/registry/_artifact_manager.py +0 -156
  230. snowflake/ml/registry/artifact.py +0 -46
  231. snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
  232. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  233. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  234. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
33
33
  BatchInferenceKwargsTypedDict,
34
34
  ScoreKwargsTypedDict
35
35
  )
36
+ from snowflake.ml.model._signatures import utils as model_signature_utils
37
+ from snowflake.ml.model.model_signature import (
38
+ BaseFeatureSpec,
39
+ DataType,
40
+ FeatureSpec,
41
+ ModelSignature,
42
+ _infer_signature,
43
+ _rename_signature_with_snowflake_identifiers,
44
+ )
36
45
 
37
46
  from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
38
47
 
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
43
52
  validate_sklearn_args,
44
53
  )
45
54
 
46
- from snowflake.ml.model.model_signature import (
47
- DataType,
48
- FeatureSpec,
49
- ModelSignature,
50
- _infer_signature,
51
- _rename_signature_with_snowflake_identifiers,
52
- BaseFeatureSpec,
53
- )
54
- from snowflake.ml.model._signatures import utils as model_signature_utils
55
-
56
55
  _PROJECT = "ModelDevelopment"
57
56
  # Derive subproject from module name by removing "sklearn"
58
57
  # and converting module name from underscore to CamelCase
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
61
60
 
62
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
63
62
 
64
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
65
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
66
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
67
- return check
68
-
69
-
70
63
  class KMeans(BaseTransformer):
71
64
  r"""K-Means clustering
72
65
  For more details on this class, see [sklearn.cluster.KMeans]
@@ -277,12 +270,7 @@ class KMeans(BaseTransformer):
277
270
  )
278
271
  return selected_cols
279
272
 
280
- @telemetry.send_api_usage_telemetry(
281
- project=_PROJECT,
282
- subproject=_SUBPROJECT,
283
- custom_tags=dict([("autogen", True)]),
284
- )
285
- def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "KMeans":
273
+ def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "KMeans":
286
274
  """Compute k-means clustering
287
275
  For more details on this function, see [sklearn.cluster.KMeans.fit]
288
276
  (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans.fit)
@@ -309,12 +297,14 @@ class KMeans(BaseTransformer):
309
297
 
310
298
  self._snowpark_cols = dataset.select(self.input_cols).columns
311
299
 
312
- # If we are already in a stored procedure, no need to kick off another one.
300
+ # If we are already in a stored procedure, no need to kick off another one.
313
301
  if SNOWML_SPROC_ENV in os.environ:
314
302
  statement_params = telemetry.get_function_usage_statement_params(
315
303
  project=_PROJECT,
316
304
  subproject=_SUBPROJECT,
317
- function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), KMeans.__class__.__name__),
305
+ function_name=telemetry.get_statement_params_full_func_name(
306
+ inspect.currentframe(), KMeans.__class__.__name__
307
+ ),
318
308
  api_calls=[Session.call],
319
309
  custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
320
310
  )
@@ -335,27 +325,24 @@ class KMeans(BaseTransformer):
335
325
  )
336
326
  self._sklearn_object = model_trainer.train()
337
327
  self._is_fitted = True
338
- self._get_model_signatures(dataset)
328
+ self._generate_model_signatures(dataset)
339
329
  return self
340
330
 
341
331
  def _batch_inference_validate_snowpark(
342
332
  self,
343
333
  dataset: DataFrame,
344
334
  inference_method: str,
345
- ) -> List[str]:
346
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
347
- return the available package that exists in the snowflake anaconda channel
335
+ ) -> None:
336
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
348
337
 
349
338
  Args:
350
339
  dataset: snowpark dataframe
351
340
  inference_method: the inference method such as predict, score...
352
-
341
+
353
342
  Raises:
354
343
  SnowflakeMLException: If the estimator is not fitted, raise error
355
344
  SnowflakeMLException: If the session is None, raise error
356
345
 
357
- Returns:
358
- A list of available package that exists in the snowflake anaconda channel
359
346
  """
360
347
  if not self._is_fitted:
361
348
  raise exceptions.SnowflakeMLException(
@@ -373,9 +360,7 @@ class KMeans(BaseTransformer):
373
360
  "Session must not specified for snowpark dataset."
374
361
  ),
375
362
  )
376
- # Validate that key package version in user workspace are supported in snowflake conda channel
377
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
378
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
363
+
379
364
 
380
365
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
381
366
  @telemetry.send_api_usage_telemetry(
@@ -411,7 +396,9 @@ class KMeans(BaseTransformer):
411
396
  # when it is classifier, infer the datatype from label columns
412
397
  if expected_type_inferred == "" and 'predict' in self.model_signatures:
413
398
  # Batch inference takes a single expected output column type. Use the first columns type for now.
414
- label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
399
+ label_cols_signatures = [
400
+ row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
401
+ ]
415
402
  if len(label_cols_signatures) == 0:
416
403
  error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
417
404
  raise exceptions.SnowflakeMLException(
@@ -419,25 +406,23 @@ class KMeans(BaseTransformer):
419
406
  original_exception=ValueError(error_str),
420
407
  )
421
408
 
422
- expected_type_inferred = convert_sp_to_sf_type(
423
- label_cols_signatures[0].as_snowpark_type()
424
- )
409
+ expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
425
410
 
426
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
427
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
411
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
412
+ self._deps = self._get_dependencies()
413
+ assert isinstance(
414
+ dataset._session, Session
415
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
428
416
 
429
417
  transform_kwargs = dict(
430
- session = dataset._session,
431
- dependencies = self._deps,
432
- drop_input_cols = self._drop_input_cols,
433
- expected_output_cols_type = expected_type_inferred,
418
+ session=dataset._session,
419
+ dependencies=self._deps,
420
+ drop_input_cols=self._drop_input_cols,
421
+ expected_output_cols_type=expected_type_inferred,
434
422
  )
435
423
 
436
424
  elif isinstance(dataset, pd.DataFrame):
437
- transform_kwargs = dict(
438
- snowpark_input_cols = self._snowpark_cols,
439
- drop_input_cols = self._drop_input_cols
440
- )
425
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
441
426
 
442
427
  transform_handlers = ModelTransformerBuilder.build(
443
428
  dataset=dataset,
@@ -479,7 +464,7 @@ class KMeans(BaseTransformer):
479
464
  Transformed dataset.
480
465
  """
481
466
  super()._check_dataset_type(dataset)
482
- inference_method="transform"
467
+ inference_method = "transform"
483
468
 
484
469
  # This dictionary contains optional kwargs for batch inference. These kwargs
485
470
  # are specific to the type of dataset used.
@@ -509,24 +494,19 @@ class KMeans(BaseTransformer):
509
494
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
510
495
  expected_dtype = convert_sp_to_sf_type(output_types[0])
511
496
 
512
- self._deps = self._batch_inference_validate_snowpark(
513
- dataset=dataset,
514
- inference_method=inference_method,
515
- )
497
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
498
+ self._deps = self._get_dependencies()
516
499
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
517
500
 
518
501
  transform_kwargs = dict(
519
- session = dataset._session,
520
- dependencies = self._deps,
521
- drop_input_cols = self._drop_input_cols,
522
- expected_output_cols_type = expected_dtype,
502
+ session=dataset._session,
503
+ dependencies=self._deps,
504
+ drop_input_cols=self._drop_input_cols,
505
+ expected_output_cols_type=expected_dtype,
523
506
  )
524
507
 
525
508
  elif isinstance(dataset, pd.DataFrame):
526
- transform_kwargs = dict(
527
- snowpark_input_cols = self._snowpark_cols,
528
- drop_input_cols = self._drop_input_cols
529
- )
509
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
530
510
 
531
511
  transform_handlers = ModelTransformerBuilder.build(
532
512
  dataset=dataset,
@@ -545,7 +525,11 @@ class KMeans(BaseTransformer):
545
525
  return output_df
546
526
 
547
527
  @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
548
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
528
+ def fit_predict(
529
+ self,
530
+ dataset: Union[DataFrame, pd.DataFrame],
531
+ output_cols_prefix: str = "fit_predict_",
532
+ ) -> Union[DataFrame, pd.DataFrame]:
549
533
  """ Compute cluster centers and predict cluster index for each sample
550
534
  For more details on this function, see [sklearn.cluster.KMeans.fit_predict]
551
535
  (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans.fit_predict)
@@ -572,22 +556,106 @@ class KMeans(BaseTransformer):
572
556
  )
573
557
  output_result, fitted_estimator = model_trainer.train_fit_predict(
574
558
  drop_input_cols=self._drop_input_cols,
575
- expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
559
+ expected_output_cols_list=(
560
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
561
+ ),
576
562
  )
577
563
  self._sklearn_object = fitted_estimator
578
564
  self._is_fitted = True
579
565
  return output_result
580
566
 
567
+
568
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
569
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
570
+ """ Compute clustering and transform X to cluster-distance space
571
+ For more details on this function, see [sklearn.cluster.KMeans.fit_transform]
572
+ (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans.fit_transform)
573
+
574
+
575
+ Raises:
576
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
581
577
 
582
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
583
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
584
- """
578
+ Args:
579
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
580
+ Snowpark or Pandas DataFrame.
581
+ output_cols_prefix: Prefix for the response columns
585
582
  Returns:
586
583
  Transformed dataset.
587
584
  """
588
- self.fit(dataset)
589
- assert self._sklearn_object is not None
590
- return self._sklearn_object.embedding_
585
+ self._infer_input_output_cols(dataset)
586
+ super()._check_dataset_type(dataset)
587
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
588
+ estimator=self._sklearn_object,
589
+ dataset=dataset,
590
+ input_cols=self.input_cols,
591
+ label_cols=self.label_cols,
592
+ sample_weight_col=self.sample_weight_col,
593
+ autogenerated=self._autogenerated,
594
+ subproject=_SUBPROJECT,
595
+ )
596
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
597
+ drop_input_cols=self._drop_input_cols,
598
+ expected_output_cols_list=self.output_cols,
599
+ )
600
+ self._sklearn_object = fitted_estimator
601
+ self._is_fitted = True
602
+ return output_result
603
+
604
+
605
+ def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
606
+ """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
607
+ Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
608
+ """
609
+ output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
610
+ # The following condition is introduced for kneighbors methods, and not used in other methods
611
+ if output_cols:
612
+ output_cols = [
613
+ identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
614
+ for c in output_cols
615
+ ]
616
+ elif getattr(self._sklearn_object, "classes_", None) is None:
617
+ output_cols = [output_cols_prefix]
618
+ elif self._sklearn_object is not None:
619
+ classes = self._sklearn_object.classes_
620
+ if isinstance(classes, numpy.ndarray):
621
+ output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
622
+ elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
623
+ # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
624
+ output_cols = []
625
+ for i, cl in enumerate(classes):
626
+ # For binary classification, there is only one output column for each class
627
+ # ndarray as the two classes are complementary.
628
+ if len(cl) == 2:
629
+ output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
630
+ else:
631
+ output_cols.extend([
632
+ f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
633
+ ])
634
+ else:
635
+ output_cols = []
636
+
637
+ # Make sure column names are valid snowflake identifiers.
638
+ assert output_cols is not None # Make MyPy happy
639
+ rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
640
+
641
+ return rv
642
+
643
+ def _align_expected_output_names(
644
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
645
+ ) -> List[str]:
646
+ # in case the inferred output column names dimension is different
647
+ # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
648
+ output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
649
+ output_df_columns = list(output_df_pd.columns)
650
+ output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
651
+ if self.sample_weight_col:
652
+ output_df_columns_set -= set(self.sample_weight_col)
653
+ # if the dimension of inferred output column names is correct; use it
654
+ if len(expected_output_cols_list) == len(output_df_columns_set):
655
+ return expected_output_cols_list
656
+ # otherwise, use the sklearn estimator's output
657
+ else:
658
+ return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
591
659
 
592
660
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
593
661
  @telemetry.send_api_usage_telemetry(
@@ -619,24 +687,26 @@ class KMeans(BaseTransformer):
619
687
  # are specific to the type of dataset used.
620
688
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
621
689
 
690
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
691
+
622
692
  if isinstance(dataset, DataFrame):
623
- self._deps = self._batch_inference_validate_snowpark(
624
- dataset=dataset,
625
- inference_method=inference_method,
626
- )
627
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
693
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
694
+ self._deps = self._get_dependencies()
695
+ assert isinstance(
696
+ dataset._session, Session
697
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
628
698
  transform_kwargs = dict(
629
699
  session=dataset._session,
630
700
  dependencies=self._deps,
631
- drop_input_cols = self._drop_input_cols,
701
+ drop_input_cols=self._drop_input_cols,
632
702
  expected_output_cols_type="float",
633
703
  )
704
+ expected_output_cols = self._align_expected_output_names(
705
+ inference_method, dataset, expected_output_cols, output_cols_prefix
706
+ )
634
707
 
635
708
  elif isinstance(dataset, pd.DataFrame):
636
- transform_kwargs = dict(
637
- snowpark_input_cols = self._snowpark_cols,
638
- drop_input_cols = self._drop_input_cols
639
- )
709
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
640
710
 
641
711
  transform_handlers = ModelTransformerBuilder.build(
642
712
  dataset=dataset,
@@ -648,7 +718,7 @@ class KMeans(BaseTransformer):
648
718
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
649
719
  inference_method=inference_method,
650
720
  input_cols=self.input_cols,
651
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
721
+ expected_output_cols=expected_output_cols,
652
722
  **transform_kwargs
653
723
  )
654
724
  return output_df
@@ -678,29 +748,30 @@ class KMeans(BaseTransformer):
678
748
  Output dataset with log probability of the sample for each class in the model.
679
749
  """
680
750
  super()._check_dataset_type(dataset)
681
- inference_method="predict_log_proba"
751
+ inference_method = "predict_log_proba"
752
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
682
753
 
683
754
  # This dictionary contains optional kwargs for batch inference. These kwargs
684
755
  # are specific to the type of dataset used.
685
756
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
686
757
 
687
758
  if isinstance(dataset, DataFrame):
688
- self._deps = self._batch_inference_validate_snowpark(
689
- dataset=dataset,
690
- inference_method=inference_method,
691
- )
692
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
759
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
760
+ self._deps = self._get_dependencies()
761
+ assert isinstance(
762
+ dataset._session, Session
763
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
693
764
  transform_kwargs = dict(
694
765
  session=dataset._session,
695
766
  dependencies=self._deps,
696
- drop_input_cols = self._drop_input_cols,
767
+ drop_input_cols=self._drop_input_cols,
697
768
  expected_output_cols_type="float",
698
769
  )
770
+ expected_output_cols = self._align_expected_output_names(
771
+ inference_method, dataset, expected_output_cols, output_cols_prefix
772
+ )
699
773
  elif isinstance(dataset, pd.DataFrame):
700
- transform_kwargs = dict(
701
- snowpark_input_cols = self._snowpark_cols,
702
- drop_input_cols = self._drop_input_cols
703
- )
774
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
704
775
 
705
776
  transform_handlers = ModelTransformerBuilder.build(
706
777
  dataset=dataset,
@@ -713,7 +784,7 @@ class KMeans(BaseTransformer):
713
784
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
714
785
  inference_method=inference_method,
715
786
  input_cols=self.input_cols,
716
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
787
+ expected_output_cols=expected_output_cols,
717
788
  **transform_kwargs
718
789
  )
719
790
  return output_df
@@ -739,30 +810,32 @@ class KMeans(BaseTransformer):
739
810
  Output dataset with results of the decision function for the samples in input dataset.
740
811
  """
741
812
  super()._check_dataset_type(dataset)
742
- inference_method="decision_function"
813
+ inference_method = "decision_function"
743
814
 
744
815
  # This dictionary contains optional kwargs for batch inference. These kwargs
745
816
  # are specific to the type of dataset used.
746
817
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
747
818
 
819
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
820
+
748
821
  if isinstance(dataset, DataFrame):
749
- self._deps = self._batch_inference_validate_snowpark(
750
- dataset=dataset,
751
- inference_method=inference_method,
752
- )
753
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
822
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
823
+ self._deps = self._get_dependencies()
824
+ assert isinstance(
825
+ dataset._session, Session
826
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
754
827
  transform_kwargs = dict(
755
828
  session=dataset._session,
756
829
  dependencies=self._deps,
757
- drop_input_cols = self._drop_input_cols,
830
+ drop_input_cols=self._drop_input_cols,
758
831
  expected_output_cols_type="float",
759
832
  )
833
+ expected_output_cols = self._align_expected_output_names(
834
+ inference_method, dataset, expected_output_cols, output_cols_prefix
835
+ )
760
836
 
761
837
  elif isinstance(dataset, pd.DataFrame):
762
- transform_kwargs = dict(
763
- snowpark_input_cols = self._snowpark_cols,
764
- drop_input_cols = self._drop_input_cols
765
- )
838
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
766
839
 
767
840
  transform_handlers = ModelTransformerBuilder.build(
768
841
  dataset=dataset,
@@ -775,7 +848,7 @@ class KMeans(BaseTransformer):
775
848
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
776
849
  inference_method=inference_method,
777
850
  input_cols=self.input_cols,
778
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
851
+ expected_output_cols=expected_output_cols,
779
852
  **transform_kwargs
780
853
  )
781
854
  return output_df
@@ -804,17 +877,17 @@ class KMeans(BaseTransformer):
804
877
  Output dataset with probability of the sample for each class in the model.
805
878
  """
806
879
  super()._check_dataset_type(dataset)
807
- inference_method="score_samples"
880
+ inference_method = "score_samples"
808
881
 
809
882
  # This dictionary contains optional kwargs for batch inference. These kwargs
810
883
  # are specific to the type of dataset used.
811
884
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
812
885
 
886
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
887
+
813
888
  if isinstance(dataset, DataFrame):
814
- self._deps = self._batch_inference_validate_snowpark(
815
- dataset=dataset,
816
- inference_method=inference_method,
817
- )
889
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
890
+ self._deps = self._get_dependencies()
818
891
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
819
892
  transform_kwargs = dict(
820
893
  session=dataset._session,
@@ -822,6 +895,9 @@ class KMeans(BaseTransformer):
822
895
  drop_input_cols = self._drop_input_cols,
823
896
  expected_output_cols_type="float",
824
897
  )
898
+ expected_output_cols = self._align_expected_output_names(
899
+ inference_method, dataset, expected_output_cols, output_cols_prefix
900
+ )
825
901
 
826
902
  elif isinstance(dataset, pd.DataFrame):
827
903
  transform_kwargs = dict(
@@ -840,7 +916,7 @@ class KMeans(BaseTransformer):
840
916
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
841
917
  inference_method=inference_method,
842
918
  input_cols=self.input_cols,
843
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
919
+ expected_output_cols=expected_output_cols,
844
920
  **transform_kwargs
845
921
  )
846
922
  return output_df
@@ -875,17 +951,15 @@ class KMeans(BaseTransformer):
875
951
  transform_kwargs: ScoreKwargsTypedDict = dict()
876
952
 
877
953
  if isinstance(dataset, DataFrame):
878
- self._deps = self._batch_inference_validate_snowpark(
879
- dataset=dataset,
880
- inference_method="score",
881
- )
954
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
955
+ self._deps = self._get_dependencies()
882
956
  selected_cols = self._get_active_columns()
883
957
  if len(selected_cols) > 0:
884
958
  dataset = dataset.select(selected_cols)
885
959
  assert isinstance(dataset._session, Session) # keep mypy happy
886
960
  transform_kwargs = dict(
887
961
  session=dataset._session,
888
- dependencies=["snowflake-snowpark-python"] + self._deps,
962
+ dependencies=self._deps,
889
963
  score_sproc_imports=['sklearn'],
890
964
  )
891
965
  elif isinstance(dataset, pd.DataFrame):
@@ -950,11 +1024,8 @@ class KMeans(BaseTransformer):
950
1024
 
951
1025
  if isinstance(dataset, DataFrame):
952
1026
 
953
- self._deps = self._batch_inference_validate_snowpark(
954
- dataset=dataset,
955
- inference_method=inference_method,
956
-
957
- )
1027
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1028
+ self._deps = self._get_dependencies()
958
1029
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
959
1030
  transform_kwargs = dict(
960
1031
  session = dataset._session,
@@ -987,50 +1058,84 @@ class KMeans(BaseTransformer):
987
1058
  )
988
1059
  return output_df
989
1060
 
1061
+
1062
+
1063
def to_sklearn(self) -> Any:
    """Return the wrapped sklearn.cluster.KMeans object.

    The object is built via ``self._create_sklearn_object()`` only the first time it is
    needed. It is then cached on ``self._sklearn_object`` and the same instance is
    returned on later calls.
    """
    if self._sklearn_object is not None:
        return self._sklearn_object
    self._sklearn_object = self._create_sklearn_object()
    return self._sklearn_object
1069
+
1070
def to_xgboost(self) -> Any:
    """Reject conversion to an XGBoost object.

    This estimator wraps sklearn.cluster.KMeans, so callers are pointed to ``to_sklearn()``.

    Raises:
        SnowflakeMLException: Always, with error code ``METHOD_NOT_ALLOWED`` wrapping an
            ``AttributeError`` that names the unsupported and the supported conversion.
    """
    unsupported_msg = modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format("to_xgboost()", "to_sklearn()")
    raise exceptions.SnowflakeMLException(
        error_code=error_codes.METHOD_NOT_ALLOWED,
        original_exception=AttributeError(unsupported_msg),
    )
990
1080
 
991
- def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
1081
def to_lightgbm(self) -> Any:
    """Reject conversion to a LightGBM object.

    Only ``to_sklearn()`` is supported for this sklearn.cluster.KMeans wrapper.

    Raises:
        SnowflakeMLException: Always, with error code ``METHOD_NOT_ALLOWED`` wrapping an
            ``AttributeError`` that names the unsupported and the supported conversion.
    """
    cause = AttributeError(
        modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format("to_lightgbm()", "to_sklearn()")
    )
    raise exceptions.SnowflakeMLException(error_code=error_codes.METHOD_NOT_ALLOWED, original_exception=cause)
1091
+
1092
def _get_dependencies(self) -> List[str]:
    """Return the dependency list recorded on ``self._deps``.

    The very same list object is handed back (no copy). The Snowpark branches pass it
    on as the ``dependencies`` batch-inference/score kwarg.
    """
    recorded_deps: List[str] = self._deps
    return recorded_deps
1094
+
1095
+
1096
+ def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
992
1097
  self._model_signature_dict = dict()
993
1098
 
994
1099
  PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
995
1100
 
996
- inputs = list(_infer_signature(dataset[self.input_cols], "input"))
1101
+ inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
997
1102
  outputs: List[BaseFeatureSpec] = []
998
1103
  if hasattr(self, "predict"):
999
1104
  # keep mypy happy
1000
- assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
1105
+ assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
1001
1106
  # For classifier, the type of predict is the same as the type of label
1002
- if self._sklearn_object._estimator_type == 'classifier':
1003
- # label columns is the desired type for output
1107
+ if self._sklearn_object._estimator_type == "classifier":
1108
+ # label columns is the desired type for output
1004
1109
  outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
1005
1110
  # rename the output columns
1006
1111
  outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
1007
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1008
- ([] if self._drop_input_cols else inputs)
1009
- + outputs)
1112
+ self._model_signature_dict["predict"] = ModelSignature(
1113
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1114
+ )
1010
1115
  # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
1011
1116
  # For outlier models, returns -1 for outliers and 1 for inliers.
1012
- # Clusterer returns int64 cluster labels.
1117
+ # Clusterer returns int64 cluster labels.
1013
1118
  elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
1014
1119
  outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
1015
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1016
- ([] if self._drop_input_cols else inputs)
1017
- + outputs)
1018
-
1120
+ self._model_signature_dict["predict"] = ModelSignature(
1121
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1122
+ )
1123
+
1019
1124
  # For regressor, the type of predict is float64
1020
- elif self._sklearn_object._estimator_type == 'regressor':
1125
+ elif self._sklearn_object._estimator_type == "regressor":
1021
1126
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
1022
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1023
- ([] if self._drop_input_cols else inputs)
1024
- + outputs)
1025
-
1127
+ self._model_signature_dict["predict"] = ModelSignature(
1128
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1129
+ )
1130
+
1026
1131
  for prob_func in PROB_FUNCTIONS:
1027
1132
  if hasattr(self, prob_func):
1028
1133
  output_cols_prefix: str = f"{prob_func}_"
1029
1134
  output_column_names = self._get_output_column_names(output_cols_prefix)
1030
1135
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
1031
- self._model_signature_dict[prob_func] = ModelSignature(inputs,
1032
- ([] if self._drop_input_cols else inputs)
1033
- + outputs)
1136
+ self._model_signature_dict[prob_func] = ModelSignature(
1137
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1138
+ )
1034
1139
 
1035
1140
  # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
1036
1141
  items = list(self._model_signature_dict.items())
@@ -1043,10 +1148,10 @@ class KMeans(BaseTransformer):
1043
1148
  """Returns model signature of current class.
1044
1149
 
1045
1150
  Raises:
1046
- exceptions.SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1151
+ SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1047
1152
 
1048
1153
  Returns:
1049
- Dict[str, ModelSignature]: each method and its input output signature
1154
+ Dict with each method and its input output signature
1050
1155
  """
1051
1156
  if self._model_signature_dict is None:
1052
1157
  raise exceptions.SnowflakeMLException(
@@ -1054,35 +1159,3 @@ class KMeans(BaseTransformer):
1054
1159
  original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
1055
1160
  )
1056
1161
  return self._model_signature_dict
1057
-
1058
- def to_sklearn(self) -> Any:
1059
- """Get sklearn.cluster.KMeans object.
1060
- """
1061
- if self._sklearn_object is None:
1062
- self._sklearn_object = self._create_sklearn_object()
1063
- return self._sklearn_object
1064
-
1065
- def to_xgboost(self) -> Any:
1066
- raise exceptions.SnowflakeMLException(
1067
- error_code=error_codes.METHOD_NOT_ALLOWED,
1068
- original_exception=AttributeError(
1069
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1070
- "to_xgboost()",
1071
- "to_sklearn()"
1072
- )
1073
- ),
1074
- )
1075
-
1076
- def to_lightgbm(self) -> Any:
1077
- raise exceptions.SnowflakeMLException(
1078
- error_code=error_codes.METHOD_NOT_ALLOWED,
1079
- original_exception=AttributeError(
1080
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1081
- "to_lightgbm()",
1082
- "to_sklearn()"
1083
- )
1084
- ),
1085
- )
1086
-
1087
- def _get_dependencies(self) -> List[str]:
1088
- return self._deps