snowflake-ml-python 1.3.1__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. snowflake/ml/_internal/env_utils.py +11 -1
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/utils/formatting.py +1 -1
  7. snowflake/ml/_internal/utils/identifier.py +3 -1
  8. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  9. snowflake/ml/feature_store/feature_store.py +166 -184
  10. snowflake/ml/feature_store/feature_view.py +12 -24
  11. snowflake/ml/fileset/sfcfs.py +56 -50
  12. snowflake/ml/fileset/stage_fs.py +48 -13
  13. snowflake/ml/model/_client/model/model_version_impl.py +6 -49
  14. snowflake/ml/model/_client/ops/model_ops.py +78 -29
  15. snowflake/ml/model/_client/sql/model.py +23 -2
  16. snowflake/ml/model/_client/sql/model_version.py +22 -1
  17. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -3
  18. snowflake/ml/model/_deploy_client/snowservice/deploy.py +5 -2
  19. snowflake/ml/model/_model_composer/model_composer.py +7 -5
  20. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +19 -54
  21. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +8 -1
  22. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
  23. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  24. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  25. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  26. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  27. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +13 -1
  28. snowflake/ml/model/_packager/model_handlers/xgboost.py +1 -1
  29. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  30. snowflake/ml/model/_packager/model_meta/model_meta.py +36 -6
  31. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  32. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  33. snowflake/ml/model/_packager/model_packager.py +2 -2
  34. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  35. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +21 -2
  38. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  40. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  41. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +545 -0
  42. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -5
  43. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +195 -123
  44. snowflake/ml/modeling/cluster/affinity_propagation.py +195 -123
  45. snowflake/ml/modeling/cluster/agglomerative_clustering.py +195 -123
  46. snowflake/ml/modeling/cluster/birch.py +195 -123
  47. snowflake/ml/modeling/cluster/bisecting_k_means.py +195 -123
  48. snowflake/ml/modeling/cluster/dbscan.py +195 -123
  49. snowflake/ml/modeling/cluster/feature_agglomeration.py +195 -123
  50. snowflake/ml/modeling/cluster/k_means.py +195 -123
  51. snowflake/ml/modeling/cluster/mean_shift.py +195 -123
  52. snowflake/ml/modeling/cluster/mini_batch_k_means.py +195 -123
  53. snowflake/ml/modeling/cluster/optics.py +195 -123
  54. snowflake/ml/modeling/cluster/spectral_biclustering.py +195 -123
  55. snowflake/ml/modeling/cluster/spectral_clustering.py +195 -123
  56. snowflake/ml/modeling/cluster/spectral_coclustering.py +195 -123
  57. snowflake/ml/modeling/compose/column_transformer.py +195 -123
  58. snowflake/ml/modeling/compose/transformed_target_regressor.py +195 -123
  59. snowflake/ml/modeling/covariance/elliptic_envelope.py +195 -123
  60. snowflake/ml/modeling/covariance/empirical_covariance.py +195 -123
  61. snowflake/ml/modeling/covariance/graphical_lasso.py +195 -123
  62. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +195 -123
  63. snowflake/ml/modeling/covariance/ledoit_wolf.py +195 -123
  64. snowflake/ml/modeling/covariance/min_cov_det.py +195 -123
  65. snowflake/ml/modeling/covariance/oas.py +195 -123
  66. snowflake/ml/modeling/covariance/shrunk_covariance.py +195 -123
  67. snowflake/ml/modeling/decomposition/dictionary_learning.py +195 -123
  68. snowflake/ml/modeling/decomposition/factor_analysis.py +195 -123
  69. snowflake/ml/modeling/decomposition/fast_ica.py +195 -123
  70. snowflake/ml/modeling/decomposition/incremental_pca.py +195 -123
  71. snowflake/ml/modeling/decomposition/kernel_pca.py +195 -123
  72. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +195 -123
  73. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +195 -123
  74. snowflake/ml/modeling/decomposition/pca.py +195 -123
  75. snowflake/ml/modeling/decomposition/sparse_pca.py +195 -123
  76. snowflake/ml/modeling/decomposition/truncated_svd.py +195 -123
  77. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +195 -123
  78. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +195 -123
  79. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +195 -123
  80. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +195 -123
  81. snowflake/ml/modeling/ensemble/bagging_classifier.py +195 -123
  82. snowflake/ml/modeling/ensemble/bagging_regressor.py +195 -123
  83. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +195 -123
  84. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +195 -123
  85. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +195 -123
  86. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +195 -123
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +195 -123
  88. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +195 -123
  89. snowflake/ml/modeling/ensemble/isolation_forest.py +195 -123
  90. snowflake/ml/modeling/ensemble/random_forest_classifier.py +195 -123
  91. snowflake/ml/modeling/ensemble/random_forest_regressor.py +195 -123
  92. snowflake/ml/modeling/ensemble/stacking_regressor.py +195 -123
  93. snowflake/ml/modeling/ensemble/voting_classifier.py +195 -123
  94. snowflake/ml/modeling/ensemble/voting_regressor.py +195 -123
  95. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +195 -123
  96. snowflake/ml/modeling/feature_selection/select_fdr.py +195 -123
  97. snowflake/ml/modeling/feature_selection/select_fpr.py +195 -123
  98. snowflake/ml/modeling/feature_selection/select_fwe.py +195 -123
  99. snowflake/ml/modeling/feature_selection/select_k_best.py +195 -123
  100. snowflake/ml/modeling/feature_selection/select_percentile.py +195 -123
  101. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +195 -123
  102. snowflake/ml/modeling/feature_selection/variance_threshold.py +195 -123
  103. snowflake/ml/modeling/framework/_utils.py +8 -1
  104. snowflake/ml/modeling/framework/base.py +24 -6
  105. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +195 -123
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +195 -123
  107. snowflake/ml/modeling/impute/iterative_imputer.py +195 -123
  108. snowflake/ml/modeling/impute/knn_imputer.py +195 -123
  109. snowflake/ml/modeling/impute/missing_indicator.py +195 -123
  110. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +195 -123
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +195 -123
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +195 -123
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +195 -123
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +195 -123
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +195 -123
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +198 -125
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +198 -125
  119. snowflake/ml/modeling/linear_model/ard_regression.py +195 -123
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +195 -123
  121. snowflake/ml/modeling/linear_model/elastic_net.py +195 -123
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +195 -123
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +195 -123
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +195 -123
  125. snowflake/ml/modeling/linear_model/lars.py +195 -123
  126. snowflake/ml/modeling/linear_model/lars_cv.py +195 -123
  127. snowflake/ml/modeling/linear_model/lasso.py +195 -123
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +195 -123
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +195 -123
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +195 -123
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +195 -123
  132. snowflake/ml/modeling/linear_model/linear_regression.py +195 -123
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +195 -123
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +195 -123
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +195 -123
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +195 -123
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +195 -123
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +195 -123
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +195 -123
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +195 -123
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +195 -123
  142. snowflake/ml/modeling/linear_model/perceptron.py +195 -123
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +195 -123
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +195 -123
  145. snowflake/ml/modeling/linear_model/ridge.py +195 -123
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +195 -123
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +195 -123
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +195 -123
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +195 -123
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +195 -123
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +195 -123
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +195 -123
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +195 -123
  154. snowflake/ml/modeling/manifold/isomap.py +195 -123
  155. snowflake/ml/modeling/manifold/mds.py +195 -123
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +195 -123
  157. snowflake/ml/modeling/manifold/tsne.py +195 -123
  158. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +195 -123
  159. snowflake/ml/modeling/mixture/gaussian_mixture.py +195 -123
  160. snowflake/ml/modeling/model_selection/grid_search_cv.py +42 -18
  161. snowflake/ml/modeling/model_selection/randomized_search_cv.py +42 -18
  162. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +195 -123
  163. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +195 -123
  164. snowflake/ml/modeling/multiclass/output_code_classifier.py +195 -123
  165. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +195 -123
  166. snowflake/ml/modeling/naive_bayes/categorical_nb.py +195 -123
  167. snowflake/ml/modeling/naive_bayes/complement_nb.py +195 -123
  168. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +195 -123
  169. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +195 -123
  170. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +195 -123
  171. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +195 -123
  172. snowflake/ml/modeling/neighbors/kernel_density.py +195 -123
  173. snowflake/ml/modeling/neighbors/local_outlier_factor.py +195 -123
  174. snowflake/ml/modeling/neighbors/nearest_centroid.py +195 -123
  175. snowflake/ml/modeling/neighbors/nearest_neighbors.py +195 -123
  176. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +195 -123
  177. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +195 -123
  178. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +195 -123
  179. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +195 -123
  180. snowflake/ml/modeling/neural_network/mlp_classifier.py +195 -123
  181. snowflake/ml/modeling/neural_network/mlp_regressor.py +195 -123
  182. snowflake/ml/modeling/pipeline/pipeline.py +4 -4
  183. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  184. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  185. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  186. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  187. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  188. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  189. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +1 -5
  190. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  191. snowflake/ml/modeling/preprocessing/polynomial_features.py +195 -123
  192. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  193. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  194. snowflake/ml/modeling/semi_supervised/label_propagation.py +195 -123
  195. snowflake/ml/modeling/semi_supervised/label_spreading.py +195 -123
  196. snowflake/ml/modeling/svm/linear_svc.py +195 -123
  197. snowflake/ml/modeling/svm/linear_svr.py +195 -123
  198. snowflake/ml/modeling/svm/nu_svc.py +195 -123
  199. snowflake/ml/modeling/svm/nu_svr.py +195 -123
  200. snowflake/ml/modeling/svm/svc.py +195 -123
  201. snowflake/ml/modeling/svm/svr.py +195 -123
  202. snowflake/ml/modeling/tree/decision_tree_classifier.py +195 -123
  203. snowflake/ml/modeling/tree/decision_tree_regressor.py +195 -123
  204. snowflake/ml/modeling/tree/extra_tree_classifier.py +195 -123
  205. snowflake/ml/modeling/tree/extra_tree_regressor.py +195 -123
  206. snowflake/ml/modeling/xgboost/xgb_classifier.py +195 -123
  207. snowflake/ml/modeling/xgboost/xgb_regressor.py +195 -123
  208. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +195 -123
  209. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +195 -123
  210. snowflake/ml/registry/_manager/model_manager.py +5 -1
  211. snowflake/ml/registry/model_registry.py +99 -26
  212. snowflake/ml/registry/registry.py +3 -2
  213. snowflake/ml/version.py +1 -1
  214. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/METADATA +94 -55
  215. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/RECORD +218 -212
  216. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  217. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/LICENSE.txt +0 -0
  218. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/WHEEL +0 -0
  219. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ from snowflake.ml.model.model_signature import (
20
20
  FeatureSpec,
21
21
  ModelSignature,
22
22
  _infer_signature,
23
+ _rename_signature_with_snowflake_identifiers,
23
24
  )
24
25
  from snowflake.ml.modeling._internal.estimator_utils import (
25
26
  gather_dependencies,
@@ -330,7 +331,7 @@ class GridSearchCV(BaseTransformer):
330
331
  )
331
332
  self._sklearn_object = model_trainer.train()
332
333
  self._is_fitted = True
333
- self._get_model_signatures(dataset)
334
+ self._generate_model_signatures(dataset)
334
335
  return self
335
336
 
336
337
  def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> List[str]:
@@ -384,6 +385,9 @@ class GridSearchCV(BaseTransformer):
384
385
 
385
386
  Returns:
386
387
  Transformed dataset.
388
+
389
+ Raises:
390
+ SnowflakeMLException: when the output column(s) doesn't exist in the model signature, raise error
387
391
  """
388
392
  super()._check_dataset_type(dataset)
389
393
 
@@ -396,9 +400,21 @@ class GridSearchCV(BaseTransformer):
396
400
  expected_type_inferred = ""
397
401
  # infer the datatype from label columns
398
402
  if "predict" in self.model_signatures:
399
- expected_type_inferred = convert_sp_to_sf_type(
400
- self.model_signatures["predict"].outputs[0].as_snowpark_type()
401
- )
403
+ # Batch inference takes a single expected output column type. Use the first columns type for now.
404
+ label_cols_signatures = [
405
+ row for row in self.model_signatures["predict"].outputs if row.name in self.output_cols
406
+ ]
407
+ if len(label_cols_signatures) == 0:
408
+ error_str = (
409
+ f"Output columns {self.output_cols} do not match"
410
+ f"model signatures {self.model_signatures['predict'].outputs}."
411
+ )
412
+ raise exceptions.SnowflakeMLException(
413
+ error_code=error_codes.INVALID_ATTRIBUTE,
414
+ original_exception=ValueError(error_str),
415
+ )
416
+
417
+ expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
402
418
  self._deps = self._batch_inference_validate_snowpark(
403
419
  dataset=dataset,
404
420
  inference_method=inference_method,
@@ -785,12 +801,22 @@ class GridSearchCV(BaseTransformer):
785
801
 
786
802
  return output_score
787
803
 
788
- def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
804
+ def to_sklearn(self) -> sklearn.model_selection.GridSearchCV:
805
+ """
806
+ Get sklearn.model_selection.GridSearchCV object.
807
+ """
808
+ assert self._sklearn_object is not None
809
+ return self._sklearn_object
810
+
811
+ def _get_dependencies(self) -> List[str]:
812
+ return self._deps
813
+
814
+ def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
789
815
  self._model_signature_dict = dict()
790
816
 
791
817
  PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
792
818
 
793
- inputs = list(_infer_signature(dataset[self.input_cols], "input"))
819
+ inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
794
820
  outputs: List[BaseFeatureSpec] = []
795
821
  if hasattr(self, "predict"):
796
822
  # keep mypy happy
@@ -798,18 +824,20 @@ class GridSearchCV(BaseTransformer):
798
824
  # For classifier, the type of predict is the same as the type of label
799
825
  if self._sklearn_object._estimator_type == "classifier":
800
826
  # label columns is the desired type for output
801
- outputs = list(_infer_signature(dataset[self.label_cols], "output"))
827
+ outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
802
828
  # rename the output columns
803
829
  outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
804
830
  self._model_signature_dict["predict"] = ModelSignature(
805
831
  inputs, ([] if self._drop_input_cols else inputs) + outputs
806
832
  )
833
+
807
834
  # For regressor, the type of predict is float64
808
835
  elif self._sklearn_object._estimator_type == "regressor":
809
836
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
810
837
  self._model_signature_dict["predict"] = ModelSignature(
811
838
  inputs, ([] if self._drop_input_cols else inputs) + outputs
812
839
  )
840
+
813
841
  for prob_func in PROB_FUNCTIONS:
814
842
  if hasattr(self, prob_func):
815
843
  output_cols_prefix: str = f"{prob_func}_"
@@ -819,6 +847,12 @@ class GridSearchCV(BaseTransformer):
819
847
  inputs, ([] if self._drop_input_cols else inputs) + outputs
820
848
  )
821
849
 
850
+ # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
851
+ items = list(self._model_signature_dict.items())
852
+ for method, signature in items:
853
+ signature._outputs = _rename_signature_with_snowflake_identifiers(signature._outputs)
854
+ self._model_signature_dict[method] = signature
855
+
822
856
  @property
823
857
  def model_signatures(self) -> Dict[str, ModelSignature]:
824
858
  """Returns model signature of current class.
@@ -827,7 +861,7 @@ class GridSearchCV(BaseTransformer):
827
861
  SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
828
862
 
829
863
  Returns:
830
- Dict[str, ModelSignature]: each method and its input output signature
864
+ each method and its input output signature
831
865
  """
832
866
  if self._model_signature_dict is None:
833
867
  raise exceptions.SnowflakeMLException(
@@ -835,13 +869,3 @@ class GridSearchCV(BaseTransformer):
835
869
  original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
836
870
  )
837
871
  return self._model_signature_dict
838
-
839
- def to_sklearn(self) -> sklearn.model_selection.GridSearchCV:
840
- """
841
- Get sklearn.model_selection.GridSearchCV object.
842
- """
843
- assert self._sklearn_object is not None
844
- return self._sklearn_object
845
-
846
- def _get_dependencies(self) -> List[str]:
847
- return self._deps
@@ -17,6 +17,7 @@ from snowflake.ml.model.model_signature import (
17
17
  FeatureSpec,
18
18
  ModelSignature,
19
19
  _infer_signature,
20
+ _rename_signature_with_snowflake_identifiers,
20
21
  )
21
22
  from snowflake.ml.modeling._internal.estimator_utils import (
22
23
  gather_dependencies,
@@ -343,7 +344,7 @@ class RandomizedSearchCV(BaseTransformer):
343
344
  )
344
345
  self._sklearn_object = model_trainer.train()
345
346
  self._is_fitted = True
346
- self._get_model_signatures(dataset)
347
+ self._generate_model_signatures(dataset)
347
348
  return self
348
349
 
349
350
  def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> List[str]:
@@ -383,6 +384,9 @@ class RandomizedSearchCV(BaseTransformer):
383
384
 
384
385
  Returns:
385
386
  Transformed dataset.
387
+
388
+ Raises:
389
+ SnowflakeMLException: when the output column(s) doesn't exist in the model signature, raise error
386
390
  """
387
391
  super()._check_dataset_type(dataset)
388
392
 
@@ -395,9 +399,21 @@ class RandomizedSearchCV(BaseTransformer):
395
399
  expected_type_inferred = ""
396
400
  # infer the datatype from label columns
397
401
  if "predict" in self.model_signatures:
398
- expected_type_inferred = convert_sp_to_sf_type(
399
- self.model_signatures["predict"].outputs[0].as_snowpark_type()
400
- )
402
+ # Batch inference takes a single expected output column type. Use the first columns type for now.
403
+ label_cols_signatures = [
404
+ row for row in self.model_signatures["predict"].outputs if row.name in self.output_cols
405
+ ]
406
+ if len(label_cols_signatures) == 0:
407
+ error_str = (
408
+ f"Output columns {self.output_cols} do not match"
409
+ f"model signatures {self.model_signatures['predict'].outputs}."
410
+ )
411
+ raise exceptions.SnowflakeMLException(
412
+ error_code=error_codes.INVALID_ATTRIBUTE,
413
+ original_exception=ValueError(error_str),
414
+ )
415
+
416
+ expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
401
417
  self._deps = self._batch_inference_validate_snowpark(
402
418
  dataset=dataset,
403
419
  inference_method=inference_method,
@@ -780,12 +796,22 @@ class RandomizedSearchCV(BaseTransformer):
780
796
 
781
797
  return output_score
782
798
 
783
- def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
799
+ def to_sklearn(self) -> sklearn.model_selection.RandomizedSearchCV:
800
+ """
801
+ Get sklearn.model_selection.RandomizedSearchCV object.
802
+ """
803
+ assert self._sklearn_object is not None
804
+ return self._sklearn_object
805
+
806
+ def _get_dependencies(self) -> List[str]:
807
+ return self._deps
808
+
809
+ def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
784
810
  self._model_signature_dict = dict()
785
811
 
786
812
  PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
787
813
 
788
- inputs = list(_infer_signature(dataset[self.input_cols], "input"))
814
+ inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
789
815
  outputs: List[BaseFeatureSpec] = []
790
816
  if hasattr(self, "predict"):
791
817
  # keep mypy happy
@@ -793,18 +819,20 @@ class RandomizedSearchCV(BaseTransformer):
793
819
  # For classifier, the type of predict is the same as the type of label
794
820
  if self._sklearn_object._estimator_type == "classifier":
795
821
  # label columns is the desired type for output
796
- outputs = list(_infer_signature(dataset[self.label_cols], "output"))
822
+ outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
797
823
  # rename the output columns
798
824
  outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
799
825
  self._model_signature_dict["predict"] = ModelSignature(
800
826
  inputs, ([] if self._drop_input_cols else inputs) + outputs
801
827
  )
828
+
802
829
  # For regressor, the type of predict is float64
803
830
  elif self._sklearn_object._estimator_type == "regressor":
804
831
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
805
832
  self._model_signature_dict["predict"] = ModelSignature(
806
833
  inputs, ([] if self._drop_input_cols else inputs) + outputs
807
834
  )
835
+
808
836
  for prob_func in PROB_FUNCTIONS:
809
837
  if hasattr(self, prob_func):
810
838
  output_cols_prefix: str = f"{prob_func}_"
@@ -814,6 +842,12 @@ class RandomizedSearchCV(BaseTransformer):
814
842
  inputs, ([] if self._drop_input_cols else inputs) + outputs
815
843
  )
816
844
 
845
+ # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
846
+ items = list(self._model_signature_dict.items())
847
+ for method, signature in items:
848
+ signature._outputs = _rename_signature_with_snowflake_identifiers(signature._outputs)
849
+ self._model_signature_dict[method] = signature
850
+
817
851
  @property
818
852
  def model_signatures(self) -> Dict[str, ModelSignature]:
819
853
  """Returns model signature of current class.
@@ -822,7 +856,7 @@ class RandomizedSearchCV(BaseTransformer):
822
856
  SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
823
857
 
824
858
  Returns:
825
- Dict[str, ModelSignature]: each method and its input output signature
859
+ each method and its input output signature
826
860
  """
827
861
  if self._model_signature_dict is None:
828
862
  raise exceptions.SnowflakeMLException(
@@ -830,13 +864,3 @@ class RandomizedSearchCV(BaseTransformer):
830
864
  original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
831
865
  )
832
866
  return self._model_signature_dict
833
-
834
- def to_sklearn(self) -> sklearn.model_selection.RandomizedSearchCV:
835
- """
836
- Get sklearn.model_selection.RandomizedSearchCV object.
837
- """
838
- assert self._sklearn_object is not None
839
- return self._sklearn_object
840
-
841
- def _get_dependencies(self) -> List[str]:
842
- return self._deps