snowflake-ml-python 1.4.0__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. snowflake/ml/_internal/env_utils.py +11 -1
  2. snowflake/ml/_internal/utils/identifier.py +3 -1
  3. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  4. snowflake/ml/feature_store/feature_store.py +151 -78
  5. snowflake/ml/feature_store/feature_view.py +12 -24
  6. snowflake/ml/fileset/sfcfs.py +56 -50
  7. snowflake/ml/fileset/stage_fs.py +48 -13
  8. snowflake/ml/model/_client/model/model_version_impl.py +2 -50
  9. snowflake/ml/model/_client/ops/model_ops.py +78 -29
  10. snowflake/ml/model/_client/sql/model.py +23 -2
  11. snowflake/ml/model/_client/sql/model_version.py +22 -1
  12. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +19 -54
  13. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +8 -1
  14. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  15. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  16. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  17. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  18. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  19. snowflake/ml/model/_packager/model_meta/model_meta.py +36 -6
  20. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  21. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  22. snowflake/ml/model/_packager/model_packager.py +2 -2
  23. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  24. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  25. snowflake/ml/model/type_hints.py +21 -2
  26. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  27. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  28. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
  29. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  30. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +195 -123
  31. snowflake/ml/modeling/cluster/affinity_propagation.py +195 -123
  32. snowflake/ml/modeling/cluster/agglomerative_clustering.py +195 -123
  33. snowflake/ml/modeling/cluster/birch.py +195 -123
  34. snowflake/ml/modeling/cluster/bisecting_k_means.py +195 -123
  35. snowflake/ml/modeling/cluster/dbscan.py +195 -123
  36. snowflake/ml/modeling/cluster/feature_agglomeration.py +195 -123
  37. snowflake/ml/modeling/cluster/k_means.py +195 -123
  38. snowflake/ml/modeling/cluster/mean_shift.py +195 -123
  39. snowflake/ml/modeling/cluster/mini_batch_k_means.py +195 -123
  40. snowflake/ml/modeling/cluster/optics.py +195 -123
  41. snowflake/ml/modeling/cluster/spectral_biclustering.py +195 -123
  42. snowflake/ml/modeling/cluster/spectral_clustering.py +195 -123
  43. snowflake/ml/modeling/cluster/spectral_coclustering.py +195 -123
  44. snowflake/ml/modeling/compose/column_transformer.py +195 -123
  45. snowflake/ml/modeling/compose/transformed_target_regressor.py +195 -123
  46. snowflake/ml/modeling/covariance/elliptic_envelope.py +195 -123
  47. snowflake/ml/modeling/covariance/empirical_covariance.py +195 -123
  48. snowflake/ml/modeling/covariance/graphical_lasso.py +195 -123
  49. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +195 -123
  50. snowflake/ml/modeling/covariance/ledoit_wolf.py +195 -123
  51. snowflake/ml/modeling/covariance/min_cov_det.py +195 -123
  52. snowflake/ml/modeling/covariance/oas.py +195 -123
  53. snowflake/ml/modeling/covariance/shrunk_covariance.py +195 -123
  54. snowflake/ml/modeling/decomposition/dictionary_learning.py +195 -123
  55. snowflake/ml/modeling/decomposition/factor_analysis.py +195 -123
  56. snowflake/ml/modeling/decomposition/fast_ica.py +195 -123
  57. snowflake/ml/modeling/decomposition/incremental_pca.py +195 -123
  58. snowflake/ml/modeling/decomposition/kernel_pca.py +195 -123
  59. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +195 -123
  60. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +195 -123
  61. snowflake/ml/modeling/decomposition/pca.py +195 -123
  62. snowflake/ml/modeling/decomposition/sparse_pca.py +195 -123
  63. snowflake/ml/modeling/decomposition/truncated_svd.py +195 -123
  64. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +195 -123
  65. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +195 -123
  66. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +195 -123
  67. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +195 -123
  68. snowflake/ml/modeling/ensemble/bagging_classifier.py +195 -123
  69. snowflake/ml/modeling/ensemble/bagging_regressor.py +195 -123
  70. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +195 -123
  71. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +195 -123
  72. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +195 -123
  73. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +195 -123
  74. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +195 -123
  75. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +195 -123
  76. snowflake/ml/modeling/ensemble/isolation_forest.py +195 -123
  77. snowflake/ml/modeling/ensemble/random_forest_classifier.py +195 -123
  78. snowflake/ml/modeling/ensemble/random_forest_regressor.py +195 -123
  79. snowflake/ml/modeling/ensemble/stacking_regressor.py +195 -123
  80. snowflake/ml/modeling/ensemble/voting_classifier.py +195 -123
  81. snowflake/ml/modeling/ensemble/voting_regressor.py +195 -123
  82. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +195 -123
  83. snowflake/ml/modeling/feature_selection/select_fdr.py +195 -123
  84. snowflake/ml/modeling/feature_selection/select_fpr.py +195 -123
  85. snowflake/ml/modeling/feature_selection/select_fwe.py +195 -123
  86. snowflake/ml/modeling/feature_selection/select_k_best.py +195 -123
  87. snowflake/ml/modeling/feature_selection/select_percentile.py +195 -123
  88. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +195 -123
  89. snowflake/ml/modeling/feature_selection/variance_threshold.py +195 -123
  90. snowflake/ml/modeling/framework/_utils.py +8 -1
  91. snowflake/ml/modeling/framework/base.py +9 -1
  92. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +195 -123
  93. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +195 -123
  94. snowflake/ml/modeling/impute/iterative_imputer.py +195 -123
  95. snowflake/ml/modeling/impute/knn_imputer.py +195 -123
  96. snowflake/ml/modeling/impute/missing_indicator.py +195 -123
  97. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +195 -123
  98. snowflake/ml/modeling/kernel_approximation/nystroem.py +195 -123
  99. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +195 -123
  100. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +195 -123
  101. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +195 -123
  102. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +195 -123
  103. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +195 -123
  104. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +195 -123
  105. snowflake/ml/modeling/linear_model/ard_regression.py +195 -123
  106. snowflake/ml/modeling/linear_model/bayesian_ridge.py +195 -123
  107. snowflake/ml/modeling/linear_model/elastic_net.py +195 -123
  108. snowflake/ml/modeling/linear_model/elastic_net_cv.py +195 -123
  109. snowflake/ml/modeling/linear_model/gamma_regressor.py +195 -123
  110. snowflake/ml/modeling/linear_model/huber_regressor.py +195 -123
  111. snowflake/ml/modeling/linear_model/lars.py +195 -123
  112. snowflake/ml/modeling/linear_model/lars_cv.py +195 -123
  113. snowflake/ml/modeling/linear_model/lasso.py +195 -123
  114. snowflake/ml/modeling/linear_model/lasso_cv.py +195 -123
  115. snowflake/ml/modeling/linear_model/lasso_lars.py +195 -123
  116. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +195 -123
  117. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +195 -123
  118. snowflake/ml/modeling/linear_model/linear_regression.py +195 -123
  119. snowflake/ml/modeling/linear_model/logistic_regression.py +195 -123
  120. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +195 -123
  121. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +195 -123
  122. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +195 -123
  123. snowflake/ml/modeling/linear_model/multi_task_lasso.py +195 -123
  124. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +195 -123
  125. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +195 -123
  126. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +195 -123
  127. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +195 -123
  128. snowflake/ml/modeling/linear_model/perceptron.py +195 -123
  129. snowflake/ml/modeling/linear_model/poisson_regressor.py +195 -123
  130. snowflake/ml/modeling/linear_model/ransac_regressor.py +195 -123
  131. snowflake/ml/modeling/linear_model/ridge.py +195 -123
  132. snowflake/ml/modeling/linear_model/ridge_classifier.py +195 -123
  133. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +195 -123
  134. snowflake/ml/modeling/linear_model/ridge_cv.py +195 -123
  135. snowflake/ml/modeling/linear_model/sgd_classifier.py +195 -123
  136. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +195 -123
  137. snowflake/ml/modeling/linear_model/sgd_regressor.py +195 -123
  138. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +195 -123
  139. snowflake/ml/modeling/linear_model/tweedie_regressor.py +195 -123
  140. snowflake/ml/modeling/manifold/isomap.py +195 -123
  141. snowflake/ml/modeling/manifold/mds.py +195 -123
  142. snowflake/ml/modeling/manifold/spectral_embedding.py +195 -123
  143. snowflake/ml/modeling/manifold/tsne.py +195 -123
  144. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +195 -123
  145. snowflake/ml/modeling/mixture/gaussian_mixture.py +195 -123
  146. snowflake/ml/modeling/model_selection/grid_search_cv.py +42 -18
  147. snowflake/ml/modeling/model_selection/randomized_search_cv.py +42 -18
  148. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +195 -123
  149. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +195 -123
  150. snowflake/ml/modeling/multiclass/output_code_classifier.py +195 -123
  151. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +195 -123
  152. snowflake/ml/modeling/naive_bayes/categorical_nb.py +195 -123
  153. snowflake/ml/modeling/naive_bayes/complement_nb.py +195 -123
  154. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +195 -123
  155. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +195 -123
  156. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +195 -123
  157. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +195 -123
  158. snowflake/ml/modeling/neighbors/kernel_density.py +195 -123
  159. snowflake/ml/modeling/neighbors/local_outlier_factor.py +195 -123
  160. snowflake/ml/modeling/neighbors/nearest_centroid.py +195 -123
  161. snowflake/ml/modeling/neighbors/nearest_neighbors.py +195 -123
  162. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +195 -123
  163. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +195 -123
  164. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +195 -123
  165. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +195 -123
  166. snowflake/ml/modeling/neural_network/mlp_classifier.py +195 -123
  167. snowflake/ml/modeling/neural_network/mlp_regressor.py +195 -123
  168. snowflake/ml/modeling/pipeline/pipeline.py +4 -4
  169. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  170. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  171. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  172. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  173. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  174. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  175. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +1 -5
  176. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  177. snowflake/ml/modeling/preprocessing/polynomial_features.py +195 -123
  178. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  179. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  180. snowflake/ml/modeling/semi_supervised/label_propagation.py +195 -123
  181. snowflake/ml/modeling/semi_supervised/label_spreading.py +195 -123
  182. snowflake/ml/modeling/svm/linear_svc.py +195 -123
  183. snowflake/ml/modeling/svm/linear_svr.py +195 -123
  184. snowflake/ml/modeling/svm/nu_svc.py +195 -123
  185. snowflake/ml/modeling/svm/nu_svr.py +195 -123
  186. snowflake/ml/modeling/svm/svc.py +195 -123
  187. snowflake/ml/modeling/svm/svr.py +195 -123
  188. snowflake/ml/modeling/tree/decision_tree_classifier.py +195 -123
  189. snowflake/ml/modeling/tree/decision_tree_regressor.py +195 -123
  190. snowflake/ml/modeling/tree/extra_tree_classifier.py +195 -123
  191. snowflake/ml/modeling/tree/extra_tree_regressor.py +195 -123
  192. snowflake/ml/modeling/xgboost/xgb_classifier.py +195 -123
  193. snowflake/ml/modeling/xgboost/xgb_regressor.py +195 -123
  194. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +195 -123
  195. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +195 -123
  196. snowflake/ml/registry/registry.py +1 -1
  197. snowflake/ml/version.py +1 -1
  198. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/METADATA +68 -57
  199. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/RECORD +202 -200
  200. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  201. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/LICENSE.txt +0 -0
  202. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/WHEEL +0 -0
  203. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
33
33
  BatchInferenceKwargsTypedDict,
34
34
  ScoreKwargsTypedDict
35
35
  )
36
+ from snowflake.ml.model._signatures import utils as model_signature_utils
37
+ from snowflake.ml.model.model_signature import (
38
+ BaseFeatureSpec,
39
+ DataType,
40
+ FeatureSpec,
41
+ ModelSignature,
42
+ _infer_signature,
43
+ _rename_signature_with_snowflake_identifiers,
44
+ )
36
45
 
37
46
  from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
38
47
 
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
43
52
  validate_sklearn_args,
44
53
  )
45
54
 
46
- from snowflake.ml.model.model_signature import (
47
- DataType,
48
- FeatureSpec,
49
- ModelSignature,
50
- _infer_signature,
51
- _rename_signature_with_snowflake_identifiers,
52
- BaseFeatureSpec,
53
- )
54
- from snowflake.ml.model._signatures import utils as model_signature_utils
55
-
56
55
  _PROJECT = "ModelDevelopment"
57
56
  # Derive subproject from module name by removing "sklearn"
58
57
  # and converting module name from underscore to CamelCase
@@ -319,12 +318,7 @@ class SpectralClustering(BaseTransformer):
319
318
  )
320
319
  return selected_cols
321
320
 
322
- @telemetry.send_api_usage_telemetry(
323
- project=_PROJECT,
324
- subproject=_SUBPROJECT,
325
- custom_tags=dict([("autogen", True)]),
326
- )
327
- def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "SpectralClustering":
321
+ def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "SpectralClustering":
328
322
  """Perform spectral clustering from features, or affinity matrix
329
323
  For more details on this function, see [sklearn.cluster.SpectralClustering.fit]
330
324
  (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.SpectralClustering.html#sklearn.cluster.SpectralClustering.fit)
@@ -351,12 +345,14 @@ class SpectralClustering(BaseTransformer):
351
345
 
352
346
  self._snowpark_cols = dataset.select(self.input_cols).columns
353
347
 
354
- # If we are already in a stored procedure, no need to kick off another one.
348
+ # If we are already in a stored procedure, no need to kick off another one.
355
349
  if SNOWML_SPROC_ENV in os.environ:
356
350
  statement_params = telemetry.get_function_usage_statement_params(
357
351
  project=_PROJECT,
358
352
  subproject=_SUBPROJECT,
359
- function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), SpectralClustering.__class__.__name__),
353
+ function_name=telemetry.get_statement_params_full_func_name(
354
+ inspect.currentframe(), SpectralClustering.__class__.__name__
355
+ ),
360
356
  api_calls=[Session.call],
361
357
  custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
362
358
  )
@@ -377,7 +373,7 @@ class SpectralClustering(BaseTransformer):
377
373
  )
378
374
  self._sklearn_object = model_trainer.train()
379
375
  self._is_fitted = True
380
- self._get_model_signatures(dataset)
376
+ self._generate_model_signatures(dataset)
381
377
  return self
382
378
 
383
379
  def _batch_inference_validate_snowpark(
@@ -451,7 +447,9 @@ class SpectralClustering(BaseTransformer):
451
447
  # when it is classifier, infer the datatype from label columns
452
448
  if expected_type_inferred == "" and 'predict' in self.model_signatures:
453
449
  # Batch inference takes a single expected output column type. Use the first columns type for now.
454
- label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
450
+ label_cols_signatures = [
451
+ row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
452
+ ]
455
453
  if len(label_cols_signatures) == 0:
456
454
  error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
457
455
  raise exceptions.SnowflakeMLException(
@@ -459,25 +457,22 @@ class SpectralClustering(BaseTransformer):
459
457
  original_exception=ValueError(error_str),
460
458
  )
461
459
 
462
- expected_type_inferred = convert_sp_to_sf_type(
463
- label_cols_signatures[0].as_snowpark_type()
464
- )
460
+ expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
465
461
 
466
462
  self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
467
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
463
+ assert isinstance(
464
+ dataset._session, Session
465
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
468
466
 
469
467
  transform_kwargs = dict(
470
- session = dataset._session,
471
- dependencies = self._deps,
472
- drop_input_cols = self._drop_input_cols,
473
- expected_output_cols_type = expected_type_inferred,
468
+ session=dataset._session,
469
+ dependencies=self._deps,
470
+ drop_input_cols=self._drop_input_cols,
471
+ expected_output_cols_type=expected_type_inferred,
474
472
  )
475
473
 
476
474
  elif isinstance(dataset, pd.DataFrame):
477
- transform_kwargs = dict(
478
- snowpark_input_cols = self._snowpark_cols,
479
- drop_input_cols = self._drop_input_cols
480
- )
475
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
481
476
 
482
477
  transform_handlers = ModelTransformerBuilder.build(
483
478
  dataset=dataset,
@@ -517,7 +512,7 @@ class SpectralClustering(BaseTransformer):
517
512
  Transformed dataset.
518
513
  """
519
514
  super()._check_dataset_type(dataset)
520
- inference_method="transform"
515
+ inference_method = "transform"
521
516
 
522
517
  # This dictionary contains optional kwargs for batch inference. These kwargs
523
518
  # are specific to the type of dataset used.
@@ -554,17 +549,14 @@ class SpectralClustering(BaseTransformer):
554
549
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
555
550
 
556
551
  transform_kwargs = dict(
557
- session = dataset._session,
558
- dependencies = self._deps,
559
- drop_input_cols = self._drop_input_cols,
560
- expected_output_cols_type = expected_dtype,
552
+ session=dataset._session,
553
+ dependencies=self._deps,
554
+ drop_input_cols=self._drop_input_cols,
555
+ expected_output_cols_type=expected_dtype,
561
556
  )
562
557
 
563
558
  elif isinstance(dataset, pd.DataFrame):
564
- transform_kwargs = dict(
565
- snowpark_input_cols = self._snowpark_cols,
566
- drop_input_cols = self._drop_input_cols
567
- )
559
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
568
560
 
569
561
  transform_handlers = ModelTransformerBuilder.build(
570
562
  dataset=dataset,
@@ -583,7 +575,11 @@ class SpectralClustering(BaseTransformer):
583
575
  return output_df
584
576
 
585
577
  @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
586
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
578
+ def fit_predict(
579
+ self,
580
+ dataset: Union[DataFrame, pd.DataFrame],
581
+ output_cols_prefix: str = "fit_predict_",
582
+ ) -> Union[DataFrame, pd.DataFrame]:
587
583
  """ Perform spectral clustering on `X` and return cluster labels
588
584
  For more details on this function, see [sklearn.cluster.SpectralClustering.fit_predict]
589
585
  (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.SpectralClustering.html#sklearn.cluster.SpectralClustering.fit_predict)
@@ -610,7 +606,9 @@ class SpectralClustering(BaseTransformer):
610
606
  )
611
607
  output_result, fitted_estimator = model_trainer.train_fit_predict(
612
608
  drop_input_cols=self._drop_input_cols,
613
- expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
609
+ expected_output_cols_list=(
610
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
611
+ ),
614
612
  )
615
613
  self._sklearn_object = fitted_estimator
616
614
  self._is_fitted = True
@@ -627,6 +625,62 @@ class SpectralClustering(BaseTransformer):
627
625
  assert self._sklearn_object is not None
628
626
  return self._sklearn_object.embedding_
629
627
 
628
+
629
+ def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
630
+ """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
631
+ Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
632
+ """
633
+ output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
634
+ # The following condition is introduced for kneighbors methods, and not used in other methods
635
+ if output_cols:
636
+ output_cols = [
637
+ identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
638
+ for c in output_cols
639
+ ]
640
+ elif getattr(self._sklearn_object, "classes_", None) is None:
641
+ output_cols = [output_cols_prefix]
642
+ elif self._sklearn_object is not None:
643
+ classes = self._sklearn_object.classes_
644
+ if isinstance(classes, numpy.ndarray):
645
+ output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
646
+ elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
647
+ # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
648
+ output_cols = []
649
+ for i, cl in enumerate(classes):
650
+ # For binary classification, there is only one output column for each class
651
+ # ndarray as the two classes are complementary.
652
+ if len(cl) == 2:
653
+ output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
654
+ else:
655
+ output_cols.extend([
656
+ f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
657
+ ])
658
+ else:
659
+ output_cols = []
660
+
661
+ # Make sure column names are valid snowflake identifiers.
662
+ assert output_cols is not None # Make MyPy happy
663
+ rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
664
+
665
+ return rv
666
+
667
+ def _align_expected_output_names(
668
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
669
+ ) -> List[str]:
670
+ # in case the inferred output column names dimension is different
671
+ # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
672
+ output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
673
+ output_df_columns = list(output_df_pd.columns)
674
+ output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
675
+ if self.sample_weight_col:
676
+ output_df_columns_set -= set(self.sample_weight_col)
677
+ # if the dimension of inferred output column names is correct; use it
678
+ if len(expected_output_cols_list) == len(output_df_columns_set):
679
+ return expected_output_cols_list
680
+ # otherwise, use the sklearn estimator's output
681
+ else:
682
+ return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
683
+
630
684
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
631
685
  @telemetry.send_api_usage_telemetry(
632
686
  project=_PROJECT,
@@ -657,24 +711,28 @@ class SpectralClustering(BaseTransformer):
657
711
  # are specific to the type of dataset used.
658
712
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
659
713
 
714
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
715
+
660
716
  if isinstance(dataset, DataFrame):
661
717
  self._deps = self._batch_inference_validate_snowpark(
662
718
  dataset=dataset,
663
719
  inference_method=inference_method,
664
720
  )
665
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
721
+ assert isinstance(
722
+ dataset._session, Session
723
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
666
724
  transform_kwargs = dict(
667
725
  session=dataset._session,
668
726
  dependencies=self._deps,
669
- drop_input_cols = self._drop_input_cols,
727
+ drop_input_cols=self._drop_input_cols,
670
728
  expected_output_cols_type="float",
671
729
  )
730
+ expected_output_cols = self._align_expected_output_names(
731
+ inference_method, dataset, expected_output_cols, output_cols_prefix
732
+ )
672
733
 
673
734
  elif isinstance(dataset, pd.DataFrame):
674
- transform_kwargs = dict(
675
- snowpark_input_cols = self._snowpark_cols,
676
- drop_input_cols = self._drop_input_cols
677
- )
735
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
678
736
 
679
737
  transform_handlers = ModelTransformerBuilder.build(
680
738
  dataset=dataset,
@@ -686,7 +744,7 @@ class SpectralClustering(BaseTransformer):
686
744
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
687
745
  inference_method=inference_method,
688
746
  input_cols=self.input_cols,
689
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
747
+ expected_output_cols=expected_output_cols,
690
748
  **transform_kwargs
691
749
  )
692
750
  return output_df
@@ -716,7 +774,8 @@ class SpectralClustering(BaseTransformer):
716
774
  Output dataset with log probability of the sample for each class in the model.
717
775
  """
718
776
  super()._check_dataset_type(dataset)
719
- inference_method="predict_log_proba"
777
+ inference_method = "predict_log_proba"
778
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
720
779
 
721
780
  # This dictionary contains optional kwargs for batch inference. These kwargs
722
781
  # are specific to the type of dataset used.
@@ -727,18 +786,20 @@ class SpectralClustering(BaseTransformer):
727
786
  dataset=dataset,
728
787
  inference_method=inference_method,
729
788
  )
730
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
789
+ assert isinstance(
790
+ dataset._session, Session
791
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
731
792
  transform_kwargs = dict(
732
793
  session=dataset._session,
733
794
  dependencies=self._deps,
734
- drop_input_cols = self._drop_input_cols,
795
+ drop_input_cols=self._drop_input_cols,
735
796
  expected_output_cols_type="float",
736
797
  )
798
+ expected_output_cols = self._align_expected_output_names(
799
+ inference_method, dataset, expected_output_cols, output_cols_prefix
800
+ )
737
801
  elif isinstance(dataset, pd.DataFrame):
738
- transform_kwargs = dict(
739
- snowpark_input_cols = self._snowpark_cols,
740
- drop_input_cols = self._drop_input_cols
741
- )
802
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
742
803
 
743
804
  transform_handlers = ModelTransformerBuilder.build(
744
805
  dataset=dataset,
@@ -751,7 +812,7 @@ class SpectralClustering(BaseTransformer):
751
812
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
752
813
  inference_method=inference_method,
753
814
  input_cols=self.input_cols,
754
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
815
+ expected_output_cols=expected_output_cols,
755
816
  **transform_kwargs
756
817
  )
757
818
  return output_df
@@ -777,30 +838,34 @@ class SpectralClustering(BaseTransformer):
777
838
  Output dataset with results of the decision function for the samples in input dataset.
778
839
  """
779
840
  super()._check_dataset_type(dataset)
780
- inference_method="decision_function"
841
+ inference_method = "decision_function"
781
842
 
782
843
  # This dictionary contains optional kwargs for batch inference. These kwargs
783
844
  # are specific to the type of dataset used.
784
845
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
785
846
 
847
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
848
+
786
849
  if isinstance(dataset, DataFrame):
787
850
  self._deps = self._batch_inference_validate_snowpark(
788
851
  dataset=dataset,
789
852
  inference_method=inference_method,
790
853
  )
791
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
854
+ assert isinstance(
855
+ dataset._session, Session
856
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
792
857
  transform_kwargs = dict(
793
858
  session=dataset._session,
794
859
  dependencies=self._deps,
795
- drop_input_cols = self._drop_input_cols,
860
+ drop_input_cols=self._drop_input_cols,
796
861
  expected_output_cols_type="float",
797
862
  )
863
+ expected_output_cols = self._align_expected_output_names(
864
+ inference_method, dataset, expected_output_cols, output_cols_prefix
865
+ )
798
866
 
799
867
  elif isinstance(dataset, pd.DataFrame):
800
- transform_kwargs = dict(
801
- snowpark_input_cols = self._snowpark_cols,
802
- drop_input_cols = self._drop_input_cols
803
- )
868
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
804
869
 
805
870
  transform_handlers = ModelTransformerBuilder.build(
806
871
  dataset=dataset,
@@ -813,7 +878,7 @@ class SpectralClustering(BaseTransformer):
813
878
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
814
879
  inference_method=inference_method,
815
880
  input_cols=self.input_cols,
816
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
881
+ expected_output_cols=expected_output_cols,
817
882
  **transform_kwargs
818
883
  )
819
884
  return output_df
@@ -842,12 +907,14 @@ class SpectralClustering(BaseTransformer):
842
907
  Output dataset with probability of the sample for each class in the model.
843
908
  """
844
909
  super()._check_dataset_type(dataset)
845
- inference_method="score_samples"
910
+ inference_method = "score_samples"
846
911
 
847
912
  # This dictionary contains optional kwargs for batch inference. These kwargs
848
913
  # are specific to the type of dataset used.
849
914
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
850
915
 
916
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
917
+
851
918
  if isinstance(dataset, DataFrame):
852
919
  self._deps = self._batch_inference_validate_snowpark(
853
920
  dataset=dataset,
@@ -860,6 +927,9 @@ class SpectralClustering(BaseTransformer):
860
927
  drop_input_cols = self._drop_input_cols,
861
928
  expected_output_cols_type="float",
862
929
  )
930
+ expected_output_cols = self._align_expected_output_names(
931
+ inference_method, dataset, expected_output_cols, output_cols_prefix
932
+ )
863
933
 
864
934
  elif isinstance(dataset, pd.DataFrame):
865
935
  transform_kwargs = dict(
@@ -878,7 +948,7 @@ class SpectralClustering(BaseTransformer):
878
948
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
879
949
  inference_method=inference_method,
880
950
  input_cols=self.input_cols,
881
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
951
+ expected_output_cols=expected_output_cols,
882
952
  **transform_kwargs
883
953
  )
884
954
  return output_df
@@ -1023,50 +1093,84 @@ class SpectralClustering(BaseTransformer):
1023
1093
  )
1024
1094
  return output_df
1025
1095
 
1096
+
1097
+
1098
+ def to_sklearn(self) -> Any:
1099
+ """Get sklearn.cluster.SpectralClustering object.
1100
+ """
1101
+ if self._sklearn_object is None:
1102
+ self._sklearn_object = self._create_sklearn_object()
1103
+ return self._sklearn_object
1104
+
1105
+ def to_xgboost(self) -> Any:
1106
+ raise exceptions.SnowflakeMLException(
1107
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1108
+ original_exception=AttributeError(
1109
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1110
+ "to_xgboost()",
1111
+ "to_sklearn()"
1112
+ )
1113
+ ),
1114
+ )
1115
+
1116
+ def to_lightgbm(self) -> Any:
1117
+ raise exceptions.SnowflakeMLException(
1118
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1119
+ original_exception=AttributeError(
1120
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1121
+ "to_lightgbm()",
1122
+ "to_sklearn()"
1123
+ )
1124
+ ),
1125
+ )
1026
1126
 
1027
- def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
1127
+ def _get_dependencies(self) -> List[str]:
1128
+ return self._deps
1129
+
1130
+
1131
+ def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
1028
1132
  self._model_signature_dict = dict()
1029
1133
 
1030
1134
  PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
1031
1135
 
1032
- inputs = list(_infer_signature(dataset[self.input_cols], "input"))
1136
+ inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
1033
1137
  outputs: List[BaseFeatureSpec] = []
1034
1138
  if hasattr(self, "predict"):
1035
1139
  # keep mypy happy
1036
- assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
1140
+ assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
1037
1141
  # For classifier, the type of predict is the same as the type of label
1038
- if self._sklearn_object._estimator_type == 'classifier':
1039
- # label columns is the desired type for output
1142
+ if self._sklearn_object._estimator_type == "classifier":
1143
+ # label columns is the desired type for output
1040
1144
  outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
1041
1145
  # rename the output columns
1042
1146
  outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
1043
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1044
- ([] if self._drop_input_cols else inputs)
1045
- + outputs)
1147
+ self._model_signature_dict["predict"] = ModelSignature(
1148
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1149
+ )
1046
1150
  # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
1047
1151
  # For outlier models, returns -1 for outliers and 1 for inliers.
1048
- # Clusterer returns int64 cluster labels.
1152
+ # Clusterer returns int64 cluster labels.
1049
1153
  elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
1050
1154
  outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
1051
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1052
- ([] if self._drop_input_cols else inputs)
1053
- + outputs)
1054
-
1155
+ self._model_signature_dict["predict"] = ModelSignature(
1156
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1157
+ )
1158
+
1055
1159
  # For regressor, the type of predict is float64
1056
- elif self._sklearn_object._estimator_type == 'regressor':
1160
+ elif self._sklearn_object._estimator_type == "regressor":
1057
1161
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
1058
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1059
- ([] if self._drop_input_cols else inputs)
1060
- + outputs)
1061
-
1162
+ self._model_signature_dict["predict"] = ModelSignature(
1163
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1164
+ )
1165
+
1062
1166
  for prob_func in PROB_FUNCTIONS:
1063
1167
  if hasattr(self, prob_func):
1064
1168
  output_cols_prefix: str = f"{prob_func}_"
1065
1169
  output_column_names = self._get_output_column_names(output_cols_prefix)
1066
1170
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
1067
- self._model_signature_dict[prob_func] = ModelSignature(inputs,
1068
- ([] if self._drop_input_cols else inputs)
1069
- + outputs)
1171
+ self._model_signature_dict[prob_func] = ModelSignature(
1172
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1173
+ )
1070
1174
 
1071
1175
  # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
1072
1176
  items = list(self._model_signature_dict.items())
@@ -1079,10 +1183,10 @@ class SpectralClustering(BaseTransformer):
1079
1183
  """Returns model signature of current class.
1080
1184
 
1081
1185
  Raises:
1082
- exceptions.SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1186
+ SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1083
1187
 
1084
1188
  Returns:
1085
- Dict[str, ModelSignature]: each method and its input output signature
1189
+ Dict with each method and its input output signature
1086
1190
  """
1087
1191
  if self._model_signature_dict is None:
1088
1192
  raise exceptions.SnowflakeMLException(
@@ -1090,35 +1194,3 @@ class SpectralClustering(BaseTransformer):
1090
1194
  original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
1091
1195
  )
1092
1196
  return self._model_signature_dict
1093
-
1094
- def to_sklearn(self) -> Any:
1095
- """Get sklearn.cluster.SpectralClustering object.
1096
- """
1097
- if self._sklearn_object is None:
1098
- self._sklearn_object = self._create_sklearn_object()
1099
- return self._sklearn_object
1100
-
1101
- def to_xgboost(self) -> Any:
1102
- raise exceptions.SnowflakeMLException(
1103
- error_code=error_codes.METHOD_NOT_ALLOWED,
1104
- original_exception=AttributeError(
1105
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1106
- "to_xgboost()",
1107
- "to_sklearn()"
1108
- )
1109
- ),
1110
- )
1111
-
1112
- def to_lightgbm(self) -> Any:
1113
- raise exceptions.SnowflakeMLException(
1114
- error_code=error_codes.METHOD_NOT_ALLOWED,
1115
- original_exception=AttributeError(
1116
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1117
- "to_lightgbm()",
1118
- "to_sklearn()"
1119
- )
1120
- ),
1121
- )
1122
-
1123
- def _get_dependencies(self) -> List[str]:
1124
- return self._deps