snowflake-ml-python 1.3.1__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. snowflake/ml/_internal/env_utils.py +11 -1
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/utils/formatting.py +1 -1
  7. snowflake/ml/_internal/utils/identifier.py +3 -1
  8. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  9. snowflake/ml/feature_store/feature_store.py +166 -184
  10. snowflake/ml/feature_store/feature_view.py +12 -24
  11. snowflake/ml/fileset/sfcfs.py +56 -50
  12. snowflake/ml/fileset/stage_fs.py +48 -13
  13. snowflake/ml/model/_client/model/model_version_impl.py +6 -49
  14. snowflake/ml/model/_client/ops/model_ops.py +78 -29
  15. snowflake/ml/model/_client/sql/model.py +23 -2
  16. snowflake/ml/model/_client/sql/model_version.py +22 -1
  17. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -3
  18. snowflake/ml/model/_deploy_client/snowservice/deploy.py +5 -2
  19. snowflake/ml/model/_model_composer/model_composer.py +7 -5
  20. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +19 -54
  21. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +8 -1
  22. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
  23. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  24. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  25. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  26. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  27. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +13 -1
  28. snowflake/ml/model/_packager/model_handlers/xgboost.py +1 -1
  29. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  30. snowflake/ml/model/_packager/model_meta/model_meta.py +36 -6
  31. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  32. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  33. snowflake/ml/model/_packager/model_packager.py +2 -2
  34. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  35. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +21 -2
  38. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  40. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  41. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +545 -0
  42. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -5
  43. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +195 -123
  44. snowflake/ml/modeling/cluster/affinity_propagation.py +195 -123
  45. snowflake/ml/modeling/cluster/agglomerative_clustering.py +195 -123
  46. snowflake/ml/modeling/cluster/birch.py +195 -123
  47. snowflake/ml/modeling/cluster/bisecting_k_means.py +195 -123
  48. snowflake/ml/modeling/cluster/dbscan.py +195 -123
  49. snowflake/ml/modeling/cluster/feature_agglomeration.py +195 -123
  50. snowflake/ml/modeling/cluster/k_means.py +195 -123
  51. snowflake/ml/modeling/cluster/mean_shift.py +195 -123
  52. snowflake/ml/modeling/cluster/mini_batch_k_means.py +195 -123
  53. snowflake/ml/modeling/cluster/optics.py +195 -123
  54. snowflake/ml/modeling/cluster/spectral_biclustering.py +195 -123
  55. snowflake/ml/modeling/cluster/spectral_clustering.py +195 -123
  56. snowflake/ml/modeling/cluster/spectral_coclustering.py +195 -123
  57. snowflake/ml/modeling/compose/column_transformer.py +195 -123
  58. snowflake/ml/modeling/compose/transformed_target_regressor.py +195 -123
  59. snowflake/ml/modeling/covariance/elliptic_envelope.py +195 -123
  60. snowflake/ml/modeling/covariance/empirical_covariance.py +195 -123
  61. snowflake/ml/modeling/covariance/graphical_lasso.py +195 -123
  62. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +195 -123
  63. snowflake/ml/modeling/covariance/ledoit_wolf.py +195 -123
  64. snowflake/ml/modeling/covariance/min_cov_det.py +195 -123
  65. snowflake/ml/modeling/covariance/oas.py +195 -123
  66. snowflake/ml/modeling/covariance/shrunk_covariance.py +195 -123
  67. snowflake/ml/modeling/decomposition/dictionary_learning.py +195 -123
  68. snowflake/ml/modeling/decomposition/factor_analysis.py +195 -123
  69. snowflake/ml/modeling/decomposition/fast_ica.py +195 -123
  70. snowflake/ml/modeling/decomposition/incremental_pca.py +195 -123
  71. snowflake/ml/modeling/decomposition/kernel_pca.py +195 -123
  72. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +195 -123
  73. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +195 -123
  74. snowflake/ml/modeling/decomposition/pca.py +195 -123
  75. snowflake/ml/modeling/decomposition/sparse_pca.py +195 -123
  76. snowflake/ml/modeling/decomposition/truncated_svd.py +195 -123
  77. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +195 -123
  78. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +195 -123
  79. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +195 -123
  80. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +195 -123
  81. snowflake/ml/modeling/ensemble/bagging_classifier.py +195 -123
  82. snowflake/ml/modeling/ensemble/bagging_regressor.py +195 -123
  83. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +195 -123
  84. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +195 -123
  85. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +195 -123
  86. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +195 -123
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +195 -123
  88. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +195 -123
  89. snowflake/ml/modeling/ensemble/isolation_forest.py +195 -123
  90. snowflake/ml/modeling/ensemble/random_forest_classifier.py +195 -123
  91. snowflake/ml/modeling/ensemble/random_forest_regressor.py +195 -123
  92. snowflake/ml/modeling/ensemble/stacking_regressor.py +195 -123
  93. snowflake/ml/modeling/ensemble/voting_classifier.py +195 -123
  94. snowflake/ml/modeling/ensemble/voting_regressor.py +195 -123
  95. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +195 -123
  96. snowflake/ml/modeling/feature_selection/select_fdr.py +195 -123
  97. snowflake/ml/modeling/feature_selection/select_fpr.py +195 -123
  98. snowflake/ml/modeling/feature_selection/select_fwe.py +195 -123
  99. snowflake/ml/modeling/feature_selection/select_k_best.py +195 -123
  100. snowflake/ml/modeling/feature_selection/select_percentile.py +195 -123
  101. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +195 -123
  102. snowflake/ml/modeling/feature_selection/variance_threshold.py +195 -123
  103. snowflake/ml/modeling/framework/_utils.py +8 -1
  104. snowflake/ml/modeling/framework/base.py +24 -6
  105. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +195 -123
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +195 -123
  107. snowflake/ml/modeling/impute/iterative_imputer.py +195 -123
  108. snowflake/ml/modeling/impute/knn_imputer.py +195 -123
  109. snowflake/ml/modeling/impute/missing_indicator.py +195 -123
  110. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +195 -123
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +195 -123
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +195 -123
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +195 -123
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +195 -123
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +195 -123
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +198 -125
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +198 -125
  119. snowflake/ml/modeling/linear_model/ard_regression.py +195 -123
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +195 -123
  121. snowflake/ml/modeling/linear_model/elastic_net.py +195 -123
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +195 -123
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +195 -123
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +195 -123
  125. snowflake/ml/modeling/linear_model/lars.py +195 -123
  126. snowflake/ml/modeling/linear_model/lars_cv.py +195 -123
  127. snowflake/ml/modeling/linear_model/lasso.py +195 -123
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +195 -123
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +195 -123
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +195 -123
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +195 -123
  132. snowflake/ml/modeling/linear_model/linear_regression.py +195 -123
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +195 -123
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +195 -123
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +195 -123
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +195 -123
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +195 -123
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +195 -123
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +195 -123
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +195 -123
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +195 -123
  142. snowflake/ml/modeling/linear_model/perceptron.py +195 -123
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +195 -123
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +195 -123
  145. snowflake/ml/modeling/linear_model/ridge.py +195 -123
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +195 -123
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +195 -123
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +195 -123
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +195 -123
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +195 -123
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +195 -123
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +195 -123
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +195 -123
  154. snowflake/ml/modeling/manifold/isomap.py +195 -123
  155. snowflake/ml/modeling/manifold/mds.py +195 -123
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +195 -123
  157. snowflake/ml/modeling/manifold/tsne.py +195 -123
  158. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +195 -123
  159. snowflake/ml/modeling/mixture/gaussian_mixture.py +195 -123
  160. snowflake/ml/modeling/model_selection/grid_search_cv.py +42 -18
  161. snowflake/ml/modeling/model_selection/randomized_search_cv.py +42 -18
  162. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +195 -123
  163. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +195 -123
  164. snowflake/ml/modeling/multiclass/output_code_classifier.py +195 -123
  165. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +195 -123
  166. snowflake/ml/modeling/naive_bayes/categorical_nb.py +195 -123
  167. snowflake/ml/modeling/naive_bayes/complement_nb.py +195 -123
  168. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +195 -123
  169. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +195 -123
  170. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +195 -123
  171. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +195 -123
  172. snowflake/ml/modeling/neighbors/kernel_density.py +195 -123
  173. snowflake/ml/modeling/neighbors/local_outlier_factor.py +195 -123
  174. snowflake/ml/modeling/neighbors/nearest_centroid.py +195 -123
  175. snowflake/ml/modeling/neighbors/nearest_neighbors.py +195 -123
  176. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +195 -123
  177. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +195 -123
  178. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +195 -123
  179. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +195 -123
  180. snowflake/ml/modeling/neural_network/mlp_classifier.py +195 -123
  181. snowflake/ml/modeling/neural_network/mlp_regressor.py +195 -123
  182. snowflake/ml/modeling/pipeline/pipeline.py +4 -4
  183. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  184. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  185. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  186. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  187. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  188. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  189. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +1 -5
  190. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  191. snowflake/ml/modeling/preprocessing/polynomial_features.py +195 -123
  192. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  193. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  194. snowflake/ml/modeling/semi_supervised/label_propagation.py +195 -123
  195. snowflake/ml/modeling/semi_supervised/label_spreading.py +195 -123
  196. snowflake/ml/modeling/svm/linear_svc.py +195 -123
  197. snowflake/ml/modeling/svm/linear_svr.py +195 -123
  198. snowflake/ml/modeling/svm/nu_svc.py +195 -123
  199. snowflake/ml/modeling/svm/nu_svr.py +195 -123
  200. snowflake/ml/modeling/svm/svc.py +195 -123
  201. snowflake/ml/modeling/svm/svr.py +195 -123
  202. snowflake/ml/modeling/tree/decision_tree_classifier.py +195 -123
  203. snowflake/ml/modeling/tree/decision_tree_regressor.py +195 -123
  204. snowflake/ml/modeling/tree/extra_tree_classifier.py +195 -123
  205. snowflake/ml/modeling/tree/extra_tree_regressor.py +195 -123
  206. snowflake/ml/modeling/xgboost/xgb_classifier.py +195 -123
  207. snowflake/ml/modeling/xgboost/xgb_regressor.py +195 -123
  208. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +195 -123
  209. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +195 -123
  210. snowflake/ml/registry/_manager/model_manager.py +5 -1
  211. snowflake/ml/registry/model_registry.py +99 -26
  212. snowflake/ml/registry/registry.py +3 -2
  213. snowflake/ml/version.py +1 -1
  214. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/METADATA +94 -55
  215. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/RECORD +218 -212
  216. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  217. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/LICENSE.txt +0 -0
  218. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/WHEEL +0 -0
  219. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
33
33
  BatchInferenceKwargsTypedDict,
34
34
  ScoreKwargsTypedDict
35
35
  )
36
+ from snowflake.ml.model._signatures import utils as model_signature_utils
37
+ from snowflake.ml.model.model_signature import (
38
+ BaseFeatureSpec,
39
+ DataType,
40
+ FeatureSpec,
41
+ ModelSignature,
42
+ _infer_signature,
43
+ _rename_signature_with_snowflake_identifiers,
44
+ )
36
45
 
37
46
  from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
38
47
 
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
43
52
  validate_sklearn_args,
44
53
  )
45
54
 
46
- from snowflake.ml.model.model_signature import (
47
- DataType,
48
- FeatureSpec,
49
- ModelSignature,
50
- _infer_signature,
51
- _rename_signature_with_snowflake_identifiers,
52
- BaseFeatureSpec,
53
- )
54
- from snowflake.ml.model._signatures import utils as model_signature_utils
55
-
56
55
  _PROJECT = "ModelDevelopment"
57
56
  # Derive subproject from module name by removing "sklearn"
58
57
  # and converting module name from underscore to CamelCase
@@ -375,12 +374,7 @@ class RandomForestClassifier(BaseTransformer):
375
374
  )
376
375
  return selected_cols
377
376
 
378
- @telemetry.send_api_usage_telemetry(
379
- project=_PROJECT,
380
- subproject=_SUBPROJECT,
381
- custom_tags=dict([("autogen", True)]),
382
- )
383
- def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "RandomForestClassifier":
377
+ def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "RandomForestClassifier":
384
378
  """Build a forest of trees from the training set (X, y)
385
379
  For more details on this function, see [sklearn.ensemble.RandomForestClassifier.fit]
386
380
  (https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html#sklearn.ensemble.RandomForestClassifier.fit)
@@ -407,12 +401,14 @@ class RandomForestClassifier(BaseTransformer):
407
401
 
408
402
  self._snowpark_cols = dataset.select(self.input_cols).columns
409
403
 
410
- # If we are already in a stored procedure, no need to kick off another one.
404
+ # If we are already in a stored procedure, no need to kick off another one.
411
405
  if SNOWML_SPROC_ENV in os.environ:
412
406
  statement_params = telemetry.get_function_usage_statement_params(
413
407
  project=_PROJECT,
414
408
  subproject=_SUBPROJECT,
415
- function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), RandomForestClassifier.__class__.__name__),
409
+ function_name=telemetry.get_statement_params_full_func_name(
410
+ inspect.currentframe(), RandomForestClassifier.__class__.__name__
411
+ ),
416
412
  api_calls=[Session.call],
417
413
  custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
418
414
  )
@@ -433,7 +429,7 @@ class RandomForestClassifier(BaseTransformer):
433
429
  )
434
430
  self._sklearn_object = model_trainer.train()
435
431
  self._is_fitted = True
436
- self._get_model_signatures(dataset)
432
+ self._generate_model_signatures(dataset)
437
433
  return self
438
434
 
439
435
  def _batch_inference_validate_snowpark(
@@ -509,7 +505,9 @@ class RandomForestClassifier(BaseTransformer):
509
505
  # when it is classifier, infer the datatype from label columns
510
506
  if expected_type_inferred == "" and 'predict' in self.model_signatures:
511
507
  # Batch inference takes a single expected output column type. Use the first columns type for now.
512
- label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
508
+ label_cols_signatures = [
509
+ row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
510
+ ]
513
511
  if len(label_cols_signatures) == 0:
514
512
  error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
515
513
  raise exceptions.SnowflakeMLException(
@@ -517,25 +515,22 @@ class RandomForestClassifier(BaseTransformer):
517
515
  original_exception=ValueError(error_str),
518
516
  )
519
517
 
520
- expected_type_inferred = convert_sp_to_sf_type(
521
- label_cols_signatures[0].as_snowpark_type()
522
- )
518
+ expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
523
519
 
524
520
  self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
525
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
521
+ assert isinstance(
522
+ dataset._session, Session
523
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
526
524
 
527
525
  transform_kwargs = dict(
528
- session = dataset._session,
529
- dependencies = self._deps,
530
- drop_input_cols = self._drop_input_cols,
531
- expected_output_cols_type = expected_type_inferred,
526
+ session=dataset._session,
527
+ dependencies=self._deps,
528
+ drop_input_cols=self._drop_input_cols,
529
+ expected_output_cols_type=expected_type_inferred,
532
530
  )
533
531
 
534
532
  elif isinstance(dataset, pd.DataFrame):
535
- transform_kwargs = dict(
536
- snowpark_input_cols = self._snowpark_cols,
537
- drop_input_cols = self._drop_input_cols
538
- )
533
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
539
534
 
540
535
  transform_handlers = ModelTransformerBuilder.build(
541
536
  dataset=dataset,
@@ -575,7 +570,7 @@ class RandomForestClassifier(BaseTransformer):
575
570
  Transformed dataset.
576
571
  """
577
572
  super()._check_dataset_type(dataset)
578
- inference_method="transform"
573
+ inference_method = "transform"
579
574
 
580
575
  # This dictionary contains optional kwargs for batch inference. These kwargs
581
576
  # are specific to the type of dataset used.
@@ -612,17 +607,14 @@ class RandomForestClassifier(BaseTransformer):
612
607
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
613
608
 
614
609
  transform_kwargs = dict(
615
- session = dataset._session,
616
- dependencies = self._deps,
617
- drop_input_cols = self._drop_input_cols,
618
- expected_output_cols_type = expected_dtype,
610
+ session=dataset._session,
611
+ dependencies=self._deps,
612
+ drop_input_cols=self._drop_input_cols,
613
+ expected_output_cols_type=expected_dtype,
619
614
  )
620
615
 
621
616
  elif isinstance(dataset, pd.DataFrame):
622
- transform_kwargs = dict(
623
- snowpark_input_cols = self._snowpark_cols,
624
- drop_input_cols = self._drop_input_cols
625
- )
617
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
626
618
 
627
619
  transform_handlers = ModelTransformerBuilder.build(
628
620
  dataset=dataset,
@@ -641,7 +633,11 @@ class RandomForestClassifier(BaseTransformer):
641
633
  return output_df
642
634
 
643
635
  @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
644
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
636
+ def fit_predict(
637
+ self,
638
+ dataset: Union[DataFrame, pd.DataFrame],
639
+ output_cols_prefix: str = "fit_predict_",
640
+ ) -> Union[DataFrame, pd.DataFrame]:
645
641
  """ Method not supported for this class.
646
642
 
647
643
 
@@ -666,7 +662,9 @@ class RandomForestClassifier(BaseTransformer):
666
662
  )
667
663
  output_result, fitted_estimator = model_trainer.train_fit_predict(
668
664
  drop_input_cols=self._drop_input_cols,
669
- expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
665
+ expected_output_cols_list=(
666
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
667
+ ),
670
668
  )
671
669
  self._sklearn_object = fitted_estimator
672
670
  self._is_fitted = True
@@ -683,6 +681,62 @@ class RandomForestClassifier(BaseTransformer):
683
681
  assert self._sklearn_object is not None
684
682
  return self._sklearn_object.embedding_
685
683
 
684
+
685
+ def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
686
+ """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
687
+ Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
688
+ """
689
+ output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
690
+ # The following condition is introduced for kneighbors methods, and not used in other methods
691
+ if output_cols:
692
+ output_cols = [
693
+ identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
694
+ for c in output_cols
695
+ ]
696
+ elif getattr(self._sklearn_object, "classes_", None) is None:
697
+ output_cols = [output_cols_prefix]
698
+ elif self._sklearn_object is not None:
699
+ classes = self._sklearn_object.classes_
700
+ if isinstance(classes, numpy.ndarray):
701
+ output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
702
+ elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
703
+ # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
704
+ output_cols = []
705
+ for i, cl in enumerate(classes):
706
+ # For binary classification, there is only one output column for each class
707
+ # ndarray as the two classes are complementary.
708
+ if len(cl) == 2:
709
+ output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
710
+ else:
711
+ output_cols.extend([
712
+ f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
713
+ ])
714
+ else:
715
+ output_cols = []
716
+
717
+ # Make sure column names are valid snowflake identifiers.
718
+ assert output_cols is not None # Make MyPy happy
719
+ rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
720
+
721
+ return rv
722
+
723
+ def _align_expected_output_names(
724
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
725
+ ) -> List[str]:
726
+ # in case the inferred output column names dimension is different
727
+ # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
728
+ output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
729
+ output_df_columns = list(output_df_pd.columns)
730
+ output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
731
+ if self.sample_weight_col:
732
+ output_df_columns_set -= set(self.sample_weight_col)
733
+ # if the dimension of inferred output column names is correct; use it
734
+ if len(expected_output_cols_list) == len(output_df_columns_set):
735
+ return expected_output_cols_list
736
+ # otherwise, use the sklearn estimator's output
737
+ else:
738
+ return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
739
+
686
740
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
687
741
  @telemetry.send_api_usage_telemetry(
688
742
  project=_PROJECT,
@@ -715,24 +769,28 @@ class RandomForestClassifier(BaseTransformer):
715
769
  # are specific to the type of dataset used.
716
770
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
717
771
 
772
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
773
+
718
774
  if isinstance(dataset, DataFrame):
719
775
  self._deps = self._batch_inference_validate_snowpark(
720
776
  dataset=dataset,
721
777
  inference_method=inference_method,
722
778
  )
723
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
779
+ assert isinstance(
780
+ dataset._session, Session
781
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
724
782
  transform_kwargs = dict(
725
783
  session=dataset._session,
726
784
  dependencies=self._deps,
727
- drop_input_cols = self._drop_input_cols,
785
+ drop_input_cols=self._drop_input_cols,
728
786
  expected_output_cols_type="float",
729
787
  )
788
+ expected_output_cols = self._align_expected_output_names(
789
+ inference_method, dataset, expected_output_cols, output_cols_prefix
790
+ )
730
791
 
731
792
  elif isinstance(dataset, pd.DataFrame):
732
- transform_kwargs = dict(
733
- snowpark_input_cols = self._snowpark_cols,
734
- drop_input_cols = self._drop_input_cols
735
- )
793
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
736
794
 
737
795
  transform_handlers = ModelTransformerBuilder.build(
738
796
  dataset=dataset,
@@ -744,7 +802,7 @@ class RandomForestClassifier(BaseTransformer):
744
802
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
745
803
  inference_method=inference_method,
746
804
  input_cols=self.input_cols,
747
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
805
+ expected_output_cols=expected_output_cols,
748
806
  **transform_kwargs
749
807
  )
750
808
  return output_df
@@ -776,7 +834,8 @@ class RandomForestClassifier(BaseTransformer):
776
834
  Output dataset with log probability of the sample for each class in the model.
777
835
  """
778
836
  super()._check_dataset_type(dataset)
779
- inference_method="predict_log_proba"
837
+ inference_method = "predict_log_proba"
838
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
780
839
 
781
840
  # This dictionary contains optional kwargs for batch inference. These kwargs
782
841
  # are specific to the type of dataset used.
@@ -787,18 +846,20 @@ class RandomForestClassifier(BaseTransformer):
787
846
  dataset=dataset,
788
847
  inference_method=inference_method,
789
848
  )
790
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
849
+ assert isinstance(
850
+ dataset._session, Session
851
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
791
852
  transform_kwargs = dict(
792
853
  session=dataset._session,
793
854
  dependencies=self._deps,
794
- drop_input_cols = self._drop_input_cols,
855
+ drop_input_cols=self._drop_input_cols,
795
856
  expected_output_cols_type="float",
796
857
  )
858
+ expected_output_cols = self._align_expected_output_names(
859
+ inference_method, dataset, expected_output_cols, output_cols_prefix
860
+ )
797
861
  elif isinstance(dataset, pd.DataFrame):
798
- transform_kwargs = dict(
799
- snowpark_input_cols = self._snowpark_cols,
800
- drop_input_cols = self._drop_input_cols
801
- )
862
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
802
863
 
803
864
  transform_handlers = ModelTransformerBuilder.build(
804
865
  dataset=dataset,
@@ -811,7 +872,7 @@ class RandomForestClassifier(BaseTransformer):
811
872
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
812
873
  inference_method=inference_method,
813
874
  input_cols=self.input_cols,
814
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
875
+ expected_output_cols=expected_output_cols,
815
876
  **transform_kwargs
816
877
  )
817
878
  return output_df
@@ -837,30 +898,34 @@ class RandomForestClassifier(BaseTransformer):
837
898
  Output dataset with results of the decision function for the samples in input dataset.
838
899
  """
839
900
  super()._check_dataset_type(dataset)
840
- inference_method="decision_function"
901
+ inference_method = "decision_function"
841
902
 
842
903
  # This dictionary contains optional kwargs for batch inference. These kwargs
843
904
  # are specific to the type of dataset used.
844
905
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
845
906
 
907
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
908
+
846
909
  if isinstance(dataset, DataFrame):
847
910
  self._deps = self._batch_inference_validate_snowpark(
848
911
  dataset=dataset,
849
912
  inference_method=inference_method,
850
913
  )
851
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
914
+ assert isinstance(
915
+ dataset._session, Session
916
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
852
917
  transform_kwargs = dict(
853
918
  session=dataset._session,
854
919
  dependencies=self._deps,
855
- drop_input_cols = self._drop_input_cols,
920
+ drop_input_cols=self._drop_input_cols,
856
921
  expected_output_cols_type="float",
857
922
  )
923
+ expected_output_cols = self._align_expected_output_names(
924
+ inference_method, dataset, expected_output_cols, output_cols_prefix
925
+ )
858
926
 
859
927
  elif isinstance(dataset, pd.DataFrame):
860
- transform_kwargs = dict(
861
- snowpark_input_cols = self._snowpark_cols,
862
- drop_input_cols = self._drop_input_cols
863
- )
928
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
864
929
 
865
930
  transform_handlers = ModelTransformerBuilder.build(
866
931
  dataset=dataset,
@@ -873,7 +938,7 @@ class RandomForestClassifier(BaseTransformer):
873
938
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
874
939
  inference_method=inference_method,
875
940
  input_cols=self.input_cols,
876
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
941
+ expected_output_cols=expected_output_cols,
877
942
  **transform_kwargs
878
943
  )
879
944
  return output_df
@@ -902,12 +967,14 @@ class RandomForestClassifier(BaseTransformer):
902
967
  Output dataset with probability of the sample for each class in the model.
903
968
  """
904
969
  super()._check_dataset_type(dataset)
905
- inference_method="score_samples"
970
+ inference_method = "score_samples"
906
971
 
907
972
  # This dictionary contains optional kwargs for batch inference. These kwargs
908
973
  # are specific to the type of dataset used.
909
974
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
910
975
 
976
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
977
+
911
978
  if isinstance(dataset, DataFrame):
912
979
  self._deps = self._batch_inference_validate_snowpark(
913
980
  dataset=dataset,
@@ -920,6 +987,9 @@ class RandomForestClassifier(BaseTransformer):
920
987
  drop_input_cols = self._drop_input_cols,
921
988
  expected_output_cols_type="float",
922
989
  )
990
+ expected_output_cols = self._align_expected_output_names(
991
+ inference_method, dataset, expected_output_cols, output_cols_prefix
992
+ )
923
993
 
924
994
  elif isinstance(dataset, pd.DataFrame):
925
995
  transform_kwargs = dict(
@@ -938,7 +1008,7 @@ class RandomForestClassifier(BaseTransformer):
938
1008
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
939
1009
  inference_method=inference_method,
940
1010
  input_cols=self.input_cols,
941
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
1011
+ expected_output_cols=expected_output_cols,
942
1012
  **transform_kwargs
943
1013
  )
944
1014
  return output_df
@@ -1085,50 +1155,84 @@ class RandomForestClassifier(BaseTransformer):
1085
1155
  )
1086
1156
  return output_df
1087
1157
 
1158
+
1159
+
1160
+ def to_sklearn(self) -> Any:
1161
+ """Get sklearn.ensemble.RandomForestClassifier object.
1162
+ """
1163
+ if self._sklearn_object is None:
1164
+ self._sklearn_object = self._create_sklearn_object()
1165
+ return self._sklearn_object
1166
+
1167
+ def to_xgboost(self) -> Any:
1168
+ raise exceptions.SnowflakeMLException(
1169
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1170
+ original_exception=AttributeError(
1171
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1172
+ "to_xgboost()",
1173
+ "to_sklearn()"
1174
+ )
1175
+ ),
1176
+ )
1177
+
1178
+ def to_lightgbm(self) -> Any:
1179
+ raise exceptions.SnowflakeMLException(
1180
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1181
+ original_exception=AttributeError(
1182
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1183
+ "to_lightgbm()",
1184
+ "to_sklearn()"
1185
+ )
1186
+ ),
1187
+ )
1088
1188
 
1089
- def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
1189
+ def _get_dependencies(self) -> List[str]:
1190
+ return self._deps
1191
+
1192
+
1193
+ def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
1090
1194
  self._model_signature_dict = dict()
1091
1195
 
1092
1196
  PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
1093
1197
 
1094
- inputs = list(_infer_signature(dataset[self.input_cols], "input"))
1198
+ inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
1095
1199
  outputs: List[BaseFeatureSpec] = []
1096
1200
  if hasattr(self, "predict"):
1097
1201
  # keep mypy happy
1098
- assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
1202
+ assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
1099
1203
  # For classifier, the type of predict is the same as the type of label
1100
- if self._sklearn_object._estimator_type == 'classifier':
1101
- # label columns is the desired type for output
1204
+ if self._sklearn_object._estimator_type == "classifier":
1205
+ # label columns is the desired type for output
1102
1206
  outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
1103
1207
  # rename the output columns
1104
1208
  outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
1105
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1106
- ([] if self._drop_input_cols else inputs)
1107
- + outputs)
1209
+ self._model_signature_dict["predict"] = ModelSignature(
1210
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1211
+ )
1108
1212
  # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
1109
1213
  # For outlier models, returns -1 for outliers and 1 for inliers.
1110
- # Clusterer returns int64 cluster labels.
1214
+ # Clusterer returns int64 cluster labels.
1111
1215
  elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
1112
1216
  outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
1113
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1114
- ([] if self._drop_input_cols else inputs)
1115
- + outputs)
1116
-
1217
+ self._model_signature_dict["predict"] = ModelSignature(
1218
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1219
+ )
1220
+
1117
1221
  # For regressor, the type of predict is float64
1118
- elif self._sklearn_object._estimator_type == 'regressor':
1222
+ elif self._sklearn_object._estimator_type == "regressor":
1119
1223
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
1120
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1121
- ([] if self._drop_input_cols else inputs)
1122
- + outputs)
1123
-
1224
+ self._model_signature_dict["predict"] = ModelSignature(
1225
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1226
+ )
1227
+
1124
1228
  for prob_func in PROB_FUNCTIONS:
1125
1229
  if hasattr(self, prob_func):
1126
1230
  output_cols_prefix: str = f"{prob_func}_"
1127
1231
  output_column_names = self._get_output_column_names(output_cols_prefix)
1128
1232
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
1129
- self._model_signature_dict[prob_func] = ModelSignature(inputs,
1130
- ([] if self._drop_input_cols else inputs)
1131
- + outputs)
1233
+ self._model_signature_dict[prob_func] = ModelSignature(
1234
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1235
+ )
1132
1236
 
1133
1237
  # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
1134
1238
  items = list(self._model_signature_dict.items())
@@ -1141,10 +1245,10 @@ class RandomForestClassifier(BaseTransformer):
1141
1245
  """Returns model signature of current class.
1142
1246
 
1143
1247
  Raises:
1144
- exceptions.SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1248
+ SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1145
1249
 
1146
1250
  Returns:
1147
- Dict[str, ModelSignature]: each method and its input output signature
1251
+ Dict with each method and its input output signature
1148
1252
  """
1149
1253
  if self._model_signature_dict is None:
1150
1254
  raise exceptions.SnowflakeMLException(
@@ -1152,35 +1256,3 @@ class RandomForestClassifier(BaseTransformer):
1152
1256
  original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
1153
1257
  )
1154
1258
  return self._model_signature_dict
1155
-
1156
- def to_sklearn(self) -> Any:
1157
- """Get sklearn.ensemble.RandomForestClassifier object.
1158
- """
1159
- if self._sklearn_object is None:
1160
- self._sklearn_object = self._create_sklearn_object()
1161
- return self._sklearn_object
1162
-
1163
- def to_xgboost(self) -> Any:
1164
- raise exceptions.SnowflakeMLException(
1165
- error_code=error_codes.METHOD_NOT_ALLOWED,
1166
- original_exception=AttributeError(
1167
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1168
- "to_xgboost()",
1169
- "to_sklearn()"
1170
- )
1171
- ),
1172
- )
1173
-
1174
- def to_lightgbm(self) -> Any:
1175
- raise exceptions.SnowflakeMLException(
1176
- error_code=error_codes.METHOD_NOT_ALLOWED,
1177
- original_exception=AttributeError(
1178
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1179
- "to_lightgbm()",
1180
- "to_sklearn()"
1181
- )
1182
- ),
1183
- )
1184
-
1185
- def _get_dependencies(self) -> List[str]:
1186
- return self._deps