snowflake-ml-python 1.4.0__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203) hide show
  1. snowflake/ml/_internal/env_utils.py +11 -1
  2. snowflake/ml/_internal/utils/identifier.py +3 -1
  3. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  4. snowflake/ml/feature_store/feature_store.py +151 -78
  5. snowflake/ml/feature_store/feature_view.py +12 -24
  6. snowflake/ml/fileset/sfcfs.py +56 -50
  7. snowflake/ml/fileset/stage_fs.py +48 -13
  8. snowflake/ml/model/_client/model/model_version_impl.py +2 -50
  9. snowflake/ml/model/_client/ops/model_ops.py +78 -29
  10. snowflake/ml/model/_client/sql/model.py +23 -2
  11. snowflake/ml/model/_client/sql/model_version.py +22 -1
  12. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +19 -54
  13. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +8 -1
  14. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  15. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  16. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  17. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  18. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  19. snowflake/ml/model/_packager/model_meta/model_meta.py +36 -6
  20. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  21. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  22. snowflake/ml/model/_packager/model_packager.py +2 -2
  23. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  24. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  25. snowflake/ml/model/type_hints.py +21 -2
  26. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  27. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  28. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
  29. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  30. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +195 -123
  31. snowflake/ml/modeling/cluster/affinity_propagation.py +195 -123
  32. snowflake/ml/modeling/cluster/agglomerative_clustering.py +195 -123
  33. snowflake/ml/modeling/cluster/birch.py +195 -123
  34. snowflake/ml/modeling/cluster/bisecting_k_means.py +195 -123
  35. snowflake/ml/modeling/cluster/dbscan.py +195 -123
  36. snowflake/ml/modeling/cluster/feature_agglomeration.py +195 -123
  37. snowflake/ml/modeling/cluster/k_means.py +195 -123
  38. snowflake/ml/modeling/cluster/mean_shift.py +195 -123
  39. snowflake/ml/modeling/cluster/mini_batch_k_means.py +195 -123
  40. snowflake/ml/modeling/cluster/optics.py +195 -123
  41. snowflake/ml/modeling/cluster/spectral_biclustering.py +195 -123
  42. snowflake/ml/modeling/cluster/spectral_clustering.py +195 -123
  43. snowflake/ml/modeling/cluster/spectral_coclustering.py +195 -123
  44. snowflake/ml/modeling/compose/column_transformer.py +195 -123
  45. snowflake/ml/modeling/compose/transformed_target_regressor.py +195 -123
  46. snowflake/ml/modeling/covariance/elliptic_envelope.py +195 -123
  47. snowflake/ml/modeling/covariance/empirical_covariance.py +195 -123
  48. snowflake/ml/modeling/covariance/graphical_lasso.py +195 -123
  49. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +195 -123
  50. snowflake/ml/modeling/covariance/ledoit_wolf.py +195 -123
  51. snowflake/ml/modeling/covariance/min_cov_det.py +195 -123
  52. snowflake/ml/modeling/covariance/oas.py +195 -123
  53. snowflake/ml/modeling/covariance/shrunk_covariance.py +195 -123
  54. snowflake/ml/modeling/decomposition/dictionary_learning.py +195 -123
  55. snowflake/ml/modeling/decomposition/factor_analysis.py +195 -123
  56. snowflake/ml/modeling/decomposition/fast_ica.py +195 -123
  57. snowflake/ml/modeling/decomposition/incremental_pca.py +195 -123
  58. snowflake/ml/modeling/decomposition/kernel_pca.py +195 -123
  59. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +195 -123
  60. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +195 -123
  61. snowflake/ml/modeling/decomposition/pca.py +195 -123
  62. snowflake/ml/modeling/decomposition/sparse_pca.py +195 -123
  63. snowflake/ml/modeling/decomposition/truncated_svd.py +195 -123
  64. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +195 -123
  65. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +195 -123
  66. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +195 -123
  67. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +195 -123
  68. snowflake/ml/modeling/ensemble/bagging_classifier.py +195 -123
  69. snowflake/ml/modeling/ensemble/bagging_regressor.py +195 -123
  70. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +195 -123
  71. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +195 -123
  72. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +195 -123
  73. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +195 -123
  74. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +195 -123
  75. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +195 -123
  76. snowflake/ml/modeling/ensemble/isolation_forest.py +195 -123
  77. snowflake/ml/modeling/ensemble/random_forest_classifier.py +195 -123
  78. snowflake/ml/modeling/ensemble/random_forest_regressor.py +195 -123
  79. snowflake/ml/modeling/ensemble/stacking_regressor.py +195 -123
  80. snowflake/ml/modeling/ensemble/voting_classifier.py +195 -123
  81. snowflake/ml/modeling/ensemble/voting_regressor.py +195 -123
  82. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +195 -123
  83. snowflake/ml/modeling/feature_selection/select_fdr.py +195 -123
  84. snowflake/ml/modeling/feature_selection/select_fpr.py +195 -123
  85. snowflake/ml/modeling/feature_selection/select_fwe.py +195 -123
  86. snowflake/ml/modeling/feature_selection/select_k_best.py +195 -123
  87. snowflake/ml/modeling/feature_selection/select_percentile.py +195 -123
  88. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +195 -123
  89. snowflake/ml/modeling/feature_selection/variance_threshold.py +195 -123
  90. snowflake/ml/modeling/framework/_utils.py +8 -1
  91. snowflake/ml/modeling/framework/base.py +9 -1
  92. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +195 -123
  93. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +195 -123
  94. snowflake/ml/modeling/impute/iterative_imputer.py +195 -123
  95. snowflake/ml/modeling/impute/knn_imputer.py +195 -123
  96. snowflake/ml/modeling/impute/missing_indicator.py +195 -123
  97. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +195 -123
  98. snowflake/ml/modeling/kernel_approximation/nystroem.py +195 -123
  99. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +195 -123
  100. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +195 -123
  101. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +195 -123
  102. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +195 -123
  103. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +195 -123
  104. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +195 -123
  105. snowflake/ml/modeling/linear_model/ard_regression.py +195 -123
  106. snowflake/ml/modeling/linear_model/bayesian_ridge.py +195 -123
  107. snowflake/ml/modeling/linear_model/elastic_net.py +195 -123
  108. snowflake/ml/modeling/linear_model/elastic_net_cv.py +195 -123
  109. snowflake/ml/modeling/linear_model/gamma_regressor.py +195 -123
  110. snowflake/ml/modeling/linear_model/huber_regressor.py +195 -123
  111. snowflake/ml/modeling/linear_model/lars.py +195 -123
  112. snowflake/ml/modeling/linear_model/lars_cv.py +195 -123
  113. snowflake/ml/modeling/linear_model/lasso.py +195 -123
  114. snowflake/ml/modeling/linear_model/lasso_cv.py +195 -123
  115. snowflake/ml/modeling/linear_model/lasso_lars.py +195 -123
  116. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +195 -123
  117. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +195 -123
  118. snowflake/ml/modeling/linear_model/linear_regression.py +195 -123
  119. snowflake/ml/modeling/linear_model/logistic_regression.py +195 -123
  120. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +195 -123
  121. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +195 -123
  122. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +195 -123
  123. snowflake/ml/modeling/linear_model/multi_task_lasso.py +195 -123
  124. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +195 -123
  125. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +195 -123
  126. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +195 -123
  127. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +195 -123
  128. snowflake/ml/modeling/linear_model/perceptron.py +195 -123
  129. snowflake/ml/modeling/linear_model/poisson_regressor.py +195 -123
  130. snowflake/ml/modeling/linear_model/ransac_regressor.py +195 -123
  131. snowflake/ml/modeling/linear_model/ridge.py +195 -123
  132. snowflake/ml/modeling/linear_model/ridge_classifier.py +195 -123
  133. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +195 -123
  134. snowflake/ml/modeling/linear_model/ridge_cv.py +195 -123
  135. snowflake/ml/modeling/linear_model/sgd_classifier.py +195 -123
  136. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +195 -123
  137. snowflake/ml/modeling/linear_model/sgd_regressor.py +195 -123
  138. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +195 -123
  139. snowflake/ml/modeling/linear_model/tweedie_regressor.py +195 -123
  140. snowflake/ml/modeling/manifold/isomap.py +195 -123
  141. snowflake/ml/modeling/manifold/mds.py +195 -123
  142. snowflake/ml/modeling/manifold/spectral_embedding.py +195 -123
  143. snowflake/ml/modeling/manifold/tsne.py +195 -123
  144. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +195 -123
  145. snowflake/ml/modeling/mixture/gaussian_mixture.py +195 -123
  146. snowflake/ml/modeling/model_selection/grid_search_cv.py +42 -18
  147. snowflake/ml/modeling/model_selection/randomized_search_cv.py +42 -18
  148. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +195 -123
  149. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +195 -123
  150. snowflake/ml/modeling/multiclass/output_code_classifier.py +195 -123
  151. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +195 -123
  152. snowflake/ml/modeling/naive_bayes/categorical_nb.py +195 -123
  153. snowflake/ml/modeling/naive_bayes/complement_nb.py +195 -123
  154. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +195 -123
  155. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +195 -123
  156. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +195 -123
  157. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +195 -123
  158. snowflake/ml/modeling/neighbors/kernel_density.py +195 -123
  159. snowflake/ml/modeling/neighbors/local_outlier_factor.py +195 -123
  160. snowflake/ml/modeling/neighbors/nearest_centroid.py +195 -123
  161. snowflake/ml/modeling/neighbors/nearest_neighbors.py +195 -123
  162. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +195 -123
  163. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +195 -123
  164. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +195 -123
  165. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +195 -123
  166. snowflake/ml/modeling/neural_network/mlp_classifier.py +195 -123
  167. snowflake/ml/modeling/neural_network/mlp_regressor.py +195 -123
  168. snowflake/ml/modeling/pipeline/pipeline.py +4 -4
  169. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  170. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  171. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  172. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  173. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  174. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  175. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +1 -5
  176. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  177. snowflake/ml/modeling/preprocessing/polynomial_features.py +195 -123
  178. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  179. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  180. snowflake/ml/modeling/semi_supervised/label_propagation.py +195 -123
  181. snowflake/ml/modeling/semi_supervised/label_spreading.py +195 -123
  182. snowflake/ml/modeling/svm/linear_svc.py +195 -123
  183. snowflake/ml/modeling/svm/linear_svr.py +195 -123
  184. snowflake/ml/modeling/svm/nu_svc.py +195 -123
  185. snowflake/ml/modeling/svm/nu_svr.py +195 -123
  186. snowflake/ml/modeling/svm/svc.py +195 -123
  187. snowflake/ml/modeling/svm/svr.py +195 -123
  188. snowflake/ml/modeling/tree/decision_tree_classifier.py +195 -123
  189. snowflake/ml/modeling/tree/decision_tree_regressor.py +195 -123
  190. snowflake/ml/modeling/tree/extra_tree_classifier.py +195 -123
  191. snowflake/ml/modeling/tree/extra_tree_regressor.py +195 -123
  192. snowflake/ml/modeling/xgboost/xgb_classifier.py +195 -123
  193. snowflake/ml/modeling/xgboost/xgb_regressor.py +195 -123
  194. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +195 -123
  195. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +195 -123
  196. snowflake/ml/registry/registry.py +1 -1
  197. snowflake/ml/version.py +1 -1
  198. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/METADATA +68 -57
  199. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/RECORD +202 -200
  200. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  201. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/LICENSE.txt +0 -0
  202. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/WHEEL +0 -0
  203. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ from snowflake.ml.model.model_signature import (
20
20
  FeatureSpec,
21
21
  ModelSignature,
22
22
  _infer_signature,
23
+ _rename_signature_with_snowflake_identifiers,
23
24
  )
24
25
  from snowflake.ml.modeling._internal.estimator_utils import (
25
26
  gather_dependencies,
@@ -330,7 +331,7 @@ class GridSearchCV(BaseTransformer):
330
331
  )
331
332
  self._sklearn_object = model_trainer.train()
332
333
  self._is_fitted = True
333
- self._get_model_signatures(dataset)
334
+ self._generate_model_signatures(dataset)
334
335
  return self
335
336
 
336
337
  def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> List[str]:
@@ -384,6 +385,9 @@ class GridSearchCV(BaseTransformer):
384
385
 
385
386
  Returns:
386
387
  Transformed dataset.
388
+
389
+ Raises:
390
+ SnowflakeMLException: when the output column(s) doesn't exist in the model signature, raise error
387
391
  """
388
392
  super()._check_dataset_type(dataset)
389
393
 
@@ -396,9 +400,21 @@ class GridSearchCV(BaseTransformer):
396
400
  expected_type_inferred = ""
397
401
  # infer the datatype from label columns
398
402
  if "predict" in self.model_signatures:
399
- expected_type_inferred = convert_sp_to_sf_type(
400
- self.model_signatures["predict"].outputs[0].as_snowpark_type()
401
- )
403
+ # Batch inference takes a single expected output column type. Use the first columns type for now.
404
+ label_cols_signatures = [
405
+ row for row in self.model_signatures["predict"].outputs if row.name in self.output_cols
406
+ ]
407
+ if len(label_cols_signatures) == 0:
408
+ error_str = (
409
+ f"Output columns {self.output_cols} do not match"
410
+ f"model signatures {self.model_signatures['predict'].outputs}."
411
+ )
412
+ raise exceptions.SnowflakeMLException(
413
+ error_code=error_codes.INVALID_ATTRIBUTE,
414
+ original_exception=ValueError(error_str),
415
+ )
416
+
417
+ expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
402
418
  self._deps = self._batch_inference_validate_snowpark(
403
419
  dataset=dataset,
404
420
  inference_method=inference_method,
@@ -785,12 +801,22 @@ class GridSearchCV(BaseTransformer):
785
801
 
786
802
  return output_score
787
803
 
788
- def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
804
+ def to_sklearn(self) -> sklearn.model_selection.GridSearchCV:
805
+ """
806
+ Get sklearn.model_selection.GridSearchCV object.
807
+ """
808
+ assert self._sklearn_object is not None
809
+ return self._sklearn_object
810
+
811
+ def _get_dependencies(self) -> List[str]:
812
+ return self._deps
813
+
814
+ def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
789
815
  self._model_signature_dict = dict()
790
816
 
791
817
  PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
792
818
 
793
- inputs = list(_infer_signature(dataset[self.input_cols], "input"))
819
+ inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
794
820
  outputs: List[BaseFeatureSpec] = []
795
821
  if hasattr(self, "predict"):
796
822
  # keep mypy happy
@@ -798,18 +824,20 @@ class GridSearchCV(BaseTransformer):
798
824
  # For classifier, the type of predict is the same as the type of label
799
825
  if self._sklearn_object._estimator_type == "classifier":
800
826
  # label columns is the desired type for output
801
- outputs = list(_infer_signature(dataset[self.label_cols], "output"))
827
+ outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
802
828
  # rename the output columns
803
829
  outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
804
830
  self._model_signature_dict["predict"] = ModelSignature(
805
831
  inputs, ([] if self._drop_input_cols else inputs) + outputs
806
832
  )
833
+
807
834
  # For regressor, the type of predict is float64
808
835
  elif self._sklearn_object._estimator_type == "regressor":
809
836
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
810
837
  self._model_signature_dict["predict"] = ModelSignature(
811
838
  inputs, ([] if self._drop_input_cols else inputs) + outputs
812
839
  )
840
+
813
841
  for prob_func in PROB_FUNCTIONS:
814
842
  if hasattr(self, prob_func):
815
843
  output_cols_prefix: str = f"{prob_func}_"
@@ -819,6 +847,12 @@ class GridSearchCV(BaseTransformer):
819
847
  inputs, ([] if self._drop_input_cols else inputs) + outputs
820
848
  )
821
849
 
850
+ # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
851
+ items = list(self._model_signature_dict.items())
852
+ for method, signature in items:
853
+ signature._outputs = _rename_signature_with_snowflake_identifiers(signature._outputs)
854
+ self._model_signature_dict[method] = signature
855
+
822
856
  @property
823
857
  def model_signatures(self) -> Dict[str, ModelSignature]:
824
858
  """Returns model signature of current class.
@@ -827,7 +861,7 @@ class GridSearchCV(BaseTransformer):
827
861
  SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
828
862
 
829
863
  Returns:
830
- Dict[str, ModelSignature]: each method and its input output signature
864
+ each method and its input output signature
831
865
  """
832
866
  if self._model_signature_dict is None:
833
867
  raise exceptions.SnowflakeMLException(
@@ -835,13 +869,3 @@ class GridSearchCV(BaseTransformer):
835
869
  original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
836
870
  )
837
871
  return self._model_signature_dict
838
-
839
- def to_sklearn(self) -> sklearn.model_selection.GridSearchCV:
840
- """
841
- Get sklearn.model_selection.GridSearchCV object.
842
- """
843
- assert self._sklearn_object is not None
844
- return self._sklearn_object
845
-
846
- def _get_dependencies(self) -> List[str]:
847
- return self._deps
@@ -17,6 +17,7 @@ from snowflake.ml.model.model_signature import (
17
17
  FeatureSpec,
18
18
  ModelSignature,
19
19
  _infer_signature,
20
+ _rename_signature_with_snowflake_identifiers,
20
21
  )
21
22
  from snowflake.ml.modeling._internal.estimator_utils import (
22
23
  gather_dependencies,
@@ -343,7 +344,7 @@ class RandomizedSearchCV(BaseTransformer):
343
344
  )
344
345
  self._sklearn_object = model_trainer.train()
345
346
  self._is_fitted = True
346
- self._get_model_signatures(dataset)
347
+ self._generate_model_signatures(dataset)
347
348
  return self
348
349
 
349
350
  def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> List[str]:
@@ -383,6 +384,9 @@ class RandomizedSearchCV(BaseTransformer):
383
384
 
384
385
  Returns:
385
386
  Transformed dataset.
387
+
388
+ Raises:
389
+ SnowflakeMLException: when the output column(s) doesn't exist in the model signature, raise error
386
390
  """
387
391
  super()._check_dataset_type(dataset)
388
392
 
@@ -395,9 +399,21 @@ class RandomizedSearchCV(BaseTransformer):
395
399
  expected_type_inferred = ""
396
400
  # infer the datatype from label columns
397
401
  if "predict" in self.model_signatures:
398
- expected_type_inferred = convert_sp_to_sf_type(
399
- self.model_signatures["predict"].outputs[0].as_snowpark_type()
400
- )
402
+ # Batch inference takes a single expected output column type. Use the first columns type for now.
403
+ label_cols_signatures = [
404
+ row for row in self.model_signatures["predict"].outputs if row.name in self.output_cols
405
+ ]
406
+ if len(label_cols_signatures) == 0:
407
+ error_str = (
408
+ f"Output columns {self.output_cols} do not match"
409
+ f"model signatures {self.model_signatures['predict'].outputs}."
410
+ )
411
+ raise exceptions.SnowflakeMLException(
412
+ error_code=error_codes.INVALID_ATTRIBUTE,
413
+ original_exception=ValueError(error_str),
414
+ )
415
+
416
+ expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
401
417
  self._deps = self._batch_inference_validate_snowpark(
402
418
  dataset=dataset,
403
419
  inference_method=inference_method,
@@ -780,12 +796,22 @@ class RandomizedSearchCV(BaseTransformer):
780
796
 
781
797
  return output_score
782
798
 
783
- def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
799
+ def to_sklearn(self) -> sklearn.model_selection.RandomizedSearchCV:
800
+ """
801
+ Get sklearn.model_selection.RandomizedSearchCV object.
802
+ """
803
+ assert self._sklearn_object is not None
804
+ return self._sklearn_object
805
+
806
+ def _get_dependencies(self) -> List[str]:
807
+ return self._deps
808
+
809
+ def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
784
810
  self._model_signature_dict = dict()
785
811
 
786
812
  PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
787
813
 
788
- inputs = list(_infer_signature(dataset[self.input_cols], "input"))
814
+ inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
789
815
  outputs: List[BaseFeatureSpec] = []
790
816
  if hasattr(self, "predict"):
791
817
  # keep mypy happy
@@ -793,18 +819,20 @@ class RandomizedSearchCV(BaseTransformer):
793
819
  # For classifier, the type of predict is the same as the type of label
794
820
  if self._sklearn_object._estimator_type == "classifier":
795
821
  # label columns is the desired type for output
796
- outputs = list(_infer_signature(dataset[self.label_cols], "output"))
822
+ outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
797
823
  # rename the output columns
798
824
  outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
799
825
  self._model_signature_dict["predict"] = ModelSignature(
800
826
  inputs, ([] if self._drop_input_cols else inputs) + outputs
801
827
  )
828
+
802
829
  # For regressor, the type of predict is float64
803
830
  elif self._sklearn_object._estimator_type == "regressor":
804
831
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
805
832
  self._model_signature_dict["predict"] = ModelSignature(
806
833
  inputs, ([] if self._drop_input_cols else inputs) + outputs
807
834
  )
835
+
808
836
  for prob_func in PROB_FUNCTIONS:
809
837
  if hasattr(self, prob_func):
810
838
  output_cols_prefix: str = f"{prob_func}_"
@@ -814,6 +842,12 @@ class RandomizedSearchCV(BaseTransformer):
814
842
  inputs, ([] if self._drop_input_cols else inputs) + outputs
815
843
  )
816
844
 
845
+ # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
846
+ items = list(self._model_signature_dict.items())
847
+ for method, signature in items:
848
+ signature._outputs = _rename_signature_with_snowflake_identifiers(signature._outputs)
849
+ self._model_signature_dict[method] = signature
850
+
817
851
  @property
818
852
  def model_signatures(self) -> Dict[str, ModelSignature]:
819
853
  """Returns model signature of current class.
@@ -822,7 +856,7 @@ class RandomizedSearchCV(BaseTransformer):
822
856
  SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
823
857
 
824
858
  Returns:
825
- Dict[str, ModelSignature]: each method and its input output signature
859
+ each method and its input output signature
826
860
  """
827
861
  if self._model_signature_dict is None:
828
862
  raise exceptions.SnowflakeMLException(
@@ -830,13 +864,3 @@ class RandomizedSearchCV(BaseTransformer):
830
864
  original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
831
865
  )
832
866
  return self._model_signature_dict
833
-
834
- def to_sklearn(self) -> sklearn.model_selection.RandomizedSearchCV:
835
- """
836
- Get sklearn.model_selection.RandomizedSearchCV object.
837
- """
838
- assert self._sklearn_object is not None
839
- return self._sklearn_object
840
-
841
- def _get_dependencies(self) -> List[str]:
842
- return self._deps