snowflake-ml-python 1.3.1__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
Files changed (219)
  1. snowflake/ml/_internal/env_utils.py +11 -1
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/utils/formatting.py +1 -1
  7. snowflake/ml/_internal/utils/identifier.py +3 -1
  8. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  9. snowflake/ml/feature_store/feature_store.py +166 -184
  10. snowflake/ml/feature_store/feature_view.py +12 -24
  11. snowflake/ml/fileset/sfcfs.py +56 -50
  12. snowflake/ml/fileset/stage_fs.py +48 -13
  13. snowflake/ml/model/_client/model/model_version_impl.py +6 -49
  14. snowflake/ml/model/_client/ops/model_ops.py +78 -29
  15. snowflake/ml/model/_client/sql/model.py +23 -2
  16. snowflake/ml/model/_client/sql/model_version.py +22 -1
  17. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -3
  18. snowflake/ml/model/_deploy_client/snowservice/deploy.py +5 -2
  19. snowflake/ml/model/_model_composer/model_composer.py +7 -5
  20. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +19 -54
  21. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +8 -1
  22. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
  23. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  24. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  25. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  26. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  27. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +13 -1
  28. snowflake/ml/model/_packager/model_handlers/xgboost.py +1 -1
  29. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  30. snowflake/ml/model/_packager/model_meta/model_meta.py +36 -6
  31. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  32. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  33. snowflake/ml/model/_packager/model_packager.py +2 -2
  34. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  35. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +21 -2
  38. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  40. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  41. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +545 -0
  42. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -5
  43. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +195 -123
  44. snowflake/ml/modeling/cluster/affinity_propagation.py +195 -123
  45. snowflake/ml/modeling/cluster/agglomerative_clustering.py +195 -123
  46. snowflake/ml/modeling/cluster/birch.py +195 -123
  47. snowflake/ml/modeling/cluster/bisecting_k_means.py +195 -123
  48. snowflake/ml/modeling/cluster/dbscan.py +195 -123
  49. snowflake/ml/modeling/cluster/feature_agglomeration.py +195 -123
  50. snowflake/ml/modeling/cluster/k_means.py +195 -123
  51. snowflake/ml/modeling/cluster/mean_shift.py +195 -123
  52. snowflake/ml/modeling/cluster/mini_batch_k_means.py +195 -123
  53. snowflake/ml/modeling/cluster/optics.py +195 -123
  54. snowflake/ml/modeling/cluster/spectral_biclustering.py +195 -123
  55. snowflake/ml/modeling/cluster/spectral_clustering.py +195 -123
  56. snowflake/ml/modeling/cluster/spectral_coclustering.py +195 -123
  57. snowflake/ml/modeling/compose/column_transformer.py +195 -123
  58. snowflake/ml/modeling/compose/transformed_target_regressor.py +195 -123
  59. snowflake/ml/modeling/covariance/elliptic_envelope.py +195 -123
  60. snowflake/ml/modeling/covariance/empirical_covariance.py +195 -123
  61. snowflake/ml/modeling/covariance/graphical_lasso.py +195 -123
  62. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +195 -123
  63. snowflake/ml/modeling/covariance/ledoit_wolf.py +195 -123
  64. snowflake/ml/modeling/covariance/min_cov_det.py +195 -123
  65. snowflake/ml/modeling/covariance/oas.py +195 -123
  66. snowflake/ml/modeling/covariance/shrunk_covariance.py +195 -123
  67. snowflake/ml/modeling/decomposition/dictionary_learning.py +195 -123
  68. snowflake/ml/modeling/decomposition/factor_analysis.py +195 -123
  69. snowflake/ml/modeling/decomposition/fast_ica.py +195 -123
  70. snowflake/ml/modeling/decomposition/incremental_pca.py +195 -123
  71. snowflake/ml/modeling/decomposition/kernel_pca.py +195 -123
  72. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +195 -123
  73. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +195 -123
  74. snowflake/ml/modeling/decomposition/pca.py +195 -123
  75. snowflake/ml/modeling/decomposition/sparse_pca.py +195 -123
  76. snowflake/ml/modeling/decomposition/truncated_svd.py +195 -123
  77. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +195 -123
  78. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +195 -123
  79. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +195 -123
  80. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +195 -123
  81. snowflake/ml/modeling/ensemble/bagging_classifier.py +195 -123
  82. snowflake/ml/modeling/ensemble/bagging_regressor.py +195 -123
  83. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +195 -123
  84. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +195 -123
  85. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +195 -123
  86. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +195 -123
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +195 -123
  88. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +195 -123
  89. snowflake/ml/modeling/ensemble/isolation_forest.py +195 -123
  90. snowflake/ml/modeling/ensemble/random_forest_classifier.py +195 -123
  91. snowflake/ml/modeling/ensemble/random_forest_regressor.py +195 -123
  92. snowflake/ml/modeling/ensemble/stacking_regressor.py +195 -123
  93. snowflake/ml/modeling/ensemble/voting_classifier.py +195 -123
  94. snowflake/ml/modeling/ensemble/voting_regressor.py +195 -123
  95. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +195 -123
  96. snowflake/ml/modeling/feature_selection/select_fdr.py +195 -123
  97. snowflake/ml/modeling/feature_selection/select_fpr.py +195 -123
  98. snowflake/ml/modeling/feature_selection/select_fwe.py +195 -123
  99. snowflake/ml/modeling/feature_selection/select_k_best.py +195 -123
  100. snowflake/ml/modeling/feature_selection/select_percentile.py +195 -123
  101. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +195 -123
  102. snowflake/ml/modeling/feature_selection/variance_threshold.py +195 -123
  103. snowflake/ml/modeling/framework/_utils.py +8 -1
  104. snowflake/ml/modeling/framework/base.py +24 -6
  105. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +195 -123
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +195 -123
  107. snowflake/ml/modeling/impute/iterative_imputer.py +195 -123
  108. snowflake/ml/modeling/impute/knn_imputer.py +195 -123
  109. snowflake/ml/modeling/impute/missing_indicator.py +195 -123
  110. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +195 -123
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +195 -123
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +195 -123
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +195 -123
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +195 -123
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +195 -123
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +198 -125
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +198 -125
  119. snowflake/ml/modeling/linear_model/ard_regression.py +195 -123
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +195 -123
  121. snowflake/ml/modeling/linear_model/elastic_net.py +195 -123
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +195 -123
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +195 -123
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +195 -123
  125. snowflake/ml/modeling/linear_model/lars.py +195 -123
  126. snowflake/ml/modeling/linear_model/lars_cv.py +195 -123
  127. snowflake/ml/modeling/linear_model/lasso.py +195 -123
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +195 -123
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +195 -123
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +195 -123
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +195 -123
  132. snowflake/ml/modeling/linear_model/linear_regression.py +195 -123
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +195 -123
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +195 -123
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +195 -123
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +195 -123
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +195 -123
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +195 -123
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +195 -123
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +195 -123
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +195 -123
  142. snowflake/ml/modeling/linear_model/perceptron.py +195 -123
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +195 -123
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +195 -123
  145. snowflake/ml/modeling/linear_model/ridge.py +195 -123
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +195 -123
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +195 -123
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +195 -123
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +195 -123
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +195 -123
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +195 -123
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +195 -123
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +195 -123
  154. snowflake/ml/modeling/manifold/isomap.py +195 -123
  155. snowflake/ml/modeling/manifold/mds.py +195 -123
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +195 -123
  157. snowflake/ml/modeling/manifold/tsne.py +195 -123
  158. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +195 -123
  159. snowflake/ml/modeling/mixture/gaussian_mixture.py +195 -123
  160. snowflake/ml/modeling/model_selection/grid_search_cv.py +42 -18
  161. snowflake/ml/modeling/model_selection/randomized_search_cv.py +42 -18
  162. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +195 -123
  163. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +195 -123
  164. snowflake/ml/modeling/multiclass/output_code_classifier.py +195 -123
  165. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +195 -123
  166. snowflake/ml/modeling/naive_bayes/categorical_nb.py +195 -123
  167. snowflake/ml/modeling/naive_bayes/complement_nb.py +195 -123
  168. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +195 -123
  169. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +195 -123
  170. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +195 -123
  171. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +195 -123
  172. snowflake/ml/modeling/neighbors/kernel_density.py +195 -123
  173. snowflake/ml/modeling/neighbors/local_outlier_factor.py +195 -123
  174. snowflake/ml/modeling/neighbors/nearest_centroid.py +195 -123
  175. snowflake/ml/modeling/neighbors/nearest_neighbors.py +195 -123
  176. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +195 -123
  177. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +195 -123
  178. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +195 -123
  179. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +195 -123
  180. snowflake/ml/modeling/neural_network/mlp_classifier.py +195 -123
  181. snowflake/ml/modeling/neural_network/mlp_regressor.py +195 -123
  182. snowflake/ml/modeling/pipeline/pipeline.py +4 -4
  183. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  184. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  185. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  186. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  187. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  188. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  189. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +1 -5
  190. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  191. snowflake/ml/modeling/preprocessing/polynomial_features.py +195 -123
  192. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  193. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  194. snowflake/ml/modeling/semi_supervised/label_propagation.py +195 -123
  195. snowflake/ml/modeling/semi_supervised/label_spreading.py +195 -123
  196. snowflake/ml/modeling/svm/linear_svc.py +195 -123
  197. snowflake/ml/modeling/svm/linear_svr.py +195 -123
  198. snowflake/ml/modeling/svm/nu_svc.py +195 -123
  199. snowflake/ml/modeling/svm/nu_svr.py +195 -123
  200. snowflake/ml/modeling/svm/svc.py +195 -123
  201. snowflake/ml/modeling/svm/svr.py +195 -123
  202. snowflake/ml/modeling/tree/decision_tree_classifier.py +195 -123
  203. snowflake/ml/modeling/tree/decision_tree_regressor.py +195 -123
  204. snowflake/ml/modeling/tree/extra_tree_classifier.py +195 -123
  205. snowflake/ml/modeling/tree/extra_tree_regressor.py +195 -123
  206. snowflake/ml/modeling/xgboost/xgb_classifier.py +195 -123
  207. snowflake/ml/modeling/xgboost/xgb_regressor.py +195 -123
  208. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +195 -123
  209. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +195 -123
  210. snowflake/ml/registry/_manager/model_manager.py +5 -1
  211. snowflake/ml/registry/model_registry.py +99 -26
  212. snowflake/ml/registry/registry.py +3 -2
  213. snowflake/ml/version.py +1 -1
  214. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/METADATA +94 -55
  215. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/RECORD +218 -212
  216. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  217. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/LICENSE.txt +0 -0
  218. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/WHEEL +0 -0
  219. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/top_level.txt +0 -0
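Among the files added above, `snowflake/ml/model/_packager/model_handlers/catboost.py` and `lightgbm.py` introduce packager handlers for natively trained CatBoost and LightGBM models. A minimal sketch of what that enables through the model registry, assuming the 1.4-style `Registry.log_model` API; the session setup, model name, and training data below are illustrative only:

```python
# Illustrative sketch: logs a natively trained LightGBM model to the model registry,
# relying on the new lightgbm.py model handler. Names and data are placeholders.
import pandas as pd
from lightgbm import LGBMClassifier
from snowflake.ml.registry import Registry


def log_native_lightgbm(session, features: pd.DataFrame, labels: pd.Series):
    model = LGBMClassifier(n_estimators=50)
    model.fit(features, labels)

    registry = Registry(session=session)
    return registry.log_model(
        model,
        model_name="lgbm_demo",               # placeholder name
        version_name="v1",
        sample_input_data=features.head(10),  # used to infer the model signature
    )
```

The remainder of this diff shows the regenerated estimator code, using DecisionTreeClassifier as the representative example of the +195/-123 pattern repeated across the modeling classes.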
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
+from snowflake.ml.model._signatures import utils as model_signature_utils
+from snowflake.ml.model.model_signature import (
+    BaseFeatureSpec,
+    DataType,
+    FeatureSpec,
+    ModelSignature,
+    _infer_signature,
+    _rename_signature_with_snowflake_identifiers,
+)
 
 from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
 
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     validate_sklearn_args,
 )
 
-from snowflake.ml.model.model_signature import (
-    DataType,
-    FeatureSpec,
-    ModelSignature,
-    _infer_signature,
-    _rename_signature_with_snowflake_identifiers,
-    BaseFeatureSpec,
-)
-from snowflake.ml.model._signatures import utils as model_signature_utils
-
 _PROJECT = "ModelDevelopment"
 # Derive subproject from module name by removing "sklearn"
 # and converting module name from underscore to CamelCase
@@ -330,12 +329,7 @@ class DecisionTreeClassifier(BaseTransformer):
         )
         return selected_cols
 
-    @telemetry.send_api_usage_telemetry(
-        project=_PROJECT,
-        subproject=_SUBPROJECT,
-        custom_tags=dict([("autogen", True)]),
-    )
-    def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "DecisionTreeClassifier":
+    def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "DecisionTreeClassifier":
         """Build a decision tree classifier from the training set (X, y)
         For more details on this function, see [sklearn.tree.DecisionTreeClassifier.fit]
         (https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier.fit)
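The public `fit()` and its telemetry decorator disappear from the generated class here; presumably they now live once in the shared base transformer, which dispatches to the new `_fit()`. A minimal sketch of that assumed template-method split (class names and bodies are illustrative, not copied from the library):

```python
# Sketch under the assumption that the base class now owns the public fit():
# cross-cutting concerns (telemetry, dataset checks) run once in fit(), and each
# generated estimator supplies only its estimator-specific _fit().
from typing import Any


class BaseTransformerSketch:
    def fit(self, dataset: Any) -> "BaseTransformerSketch":
        # telemetry and input validation would happen here for every estimator
        return self._fit(dataset)

    def _fit(self, dataset: Any) -> "BaseTransformerSketch":
        raise NotImplementedError


class DecisionTreeClassifierSketch(BaseTransformerSketch):
    def _fit(self, dataset: Any) -> "DecisionTreeClassifierSketch":
        # estimator-specific training, as in the hunk above
        return self
```

From a caller's perspective, `estimator.fit(df)` keeps working unchanged.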
@@ -362,12 +356,14 @@ class DecisionTreeClassifier(BaseTransformer):
 
         self._snowpark_cols = dataset.select(self.input_cols).columns
 
-        # If we are already in a stored procedure, no need to kick off another one.
+        # If we are already in a stored procedure, no need to kick off another one.
         if SNOWML_SPROC_ENV in os.environ:
             statement_params = telemetry.get_function_usage_statement_params(
                 project=_PROJECT,
                 subproject=_SUBPROJECT,
-                function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), DecisionTreeClassifier.__class__.__name__),
+                function_name=telemetry.get_statement_params_full_func_name(
+                    inspect.currentframe(), DecisionTreeClassifier.__class__.__name__
+                ),
                 api_calls=[Session.call],
                 custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
             )
@@ -388,7 +384,7 @@ class DecisionTreeClassifier(BaseTransformer):
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
-        self._get_model_signatures(dataset)
+        self._generate_model_signatures(dataset)
         return self
 
     def _batch_inference_validate_snowpark(
@@ -464,7 +460,9 @@ class DecisionTreeClassifier(BaseTransformer):
             # when it is classifier, infer the datatype from label columns
             if expected_type_inferred == "" and 'predict' in self.model_signatures:
                 # Batch inference takes a single expected output column type. Use the first columns type for now.
-                label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
+                label_cols_signatures = [
+                    row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
+                ]
                 if len(label_cols_signatures) == 0:
                     error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
                     raise exceptions.SnowflakeMLException(
@@ -472,25 +470,22 @@ class DecisionTreeClassifier(BaseTransformer):
                         original_exception=ValueError(error_str),
                     )
 
-                expected_type_inferred = convert_sp_to_sf_type(
-                    label_cols_signatures[0].as_snowpark_type()
-                )
+                expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
             self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_type_inferred,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_type_inferred,
             )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -530,7 +525,7 @@ class DecisionTreeClassifier(BaseTransformer):
             Transformed dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="transform"
+        inference_method = "transform"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
@@ -567,17 +562,14 @@ class DecisionTreeClassifier(BaseTransformer):
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_dtype,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_dtype,
            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -596,7 +588,11 @@ class DecisionTreeClassifier(BaseTransformer):
         return output_df
 
     @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
+    def fit_predict(
+        self,
+        dataset: Union[DataFrame, pd.DataFrame],
+        output_cols_prefix: str = "fit_predict_",
+    ) -> Union[DataFrame, pd.DataFrame]:
         """ Method not supported for this class.
 
 
@@ -621,7 +617,9 @@ class DecisionTreeClassifier(BaseTransformer):
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
             drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
+            expected_output_cols_list=(
+                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
+            ),
         )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
@@ -638,6 +636,62 @@ class DecisionTreeClassifier(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
+
+    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
+        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
+        """
+        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
+        # The following condition is introduced for kneighbors methods, and not used in other methods
+        if output_cols:
+            output_cols = [
+                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
+                for c in output_cols
+            ]
+        elif getattr(self._sklearn_object, "classes_", None) is None:
+            output_cols = [output_cols_prefix]
+        elif self._sklearn_object is not None:
+            classes = self._sklearn_object.classes_
+            if isinstance(classes, numpy.ndarray):
+                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
+            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
+                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
+                output_cols = []
+                for i, cl in enumerate(classes):
+                    # For binary classification, there is only one output column for each class
+                    # ndarray as the two classes are complementary.
+                    if len(cl) == 2:
+                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
+                    else:
+                        output_cols.extend([
+                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
+                        ])
+        else:
+            output_cols = []
+
+        # Make sure column names are valid snowflake identifiers.
+        assert output_cols is not None  # Make MyPy happy
+        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
+
+        return rv
+
+    def _align_expected_output_names(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
+    ) -> List[str]:
+        # in case the inferred output column names dimension is different
+        # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
+        output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
+        output_df_columns = list(output_df_pd.columns)
+        output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
+        if self.sample_weight_col:
+            output_df_columns_set -= set(self.sample_weight_col)
+        # if the dimension of inferred output column names is correct; use it
+        if len(expected_output_cols_list) == len(output_df_columns_set):
+            return expected_output_cols_list
+        # otherwise, use the sklearn estimator's output
+        else:
+            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+
     @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
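To make the naming scheme in the new `_get_output_column_names` concrete: for a single-output classifier it appends each entry of `classes_` to the prefix, so `predict_proba` on a binary estimator yields `predict_proba_0` and `predict_proba_1` (before the final Snowflake-identifier normalization, which may change casing). A standalone sketch of just that branch; the helper name is hypothetical:

```python
from typing import List

import numpy as np


def classifier_output_column_names(prefix: str, classes: np.ndarray) -> List[str]:
    """Mimics the single-output classifier branch of _get_output_column_names."""
    return [f"{prefix}{c}" for c in classes.tolist()]


# e.g. predict_proba on a binary classifier:
print(classifier_output_column_names("predict_proba_", np.array([0, 1])))
# ['predict_proba_0', 'predict_proba_1']
```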
@@ -670,24 +724,28 @@ class DecisionTreeClassifier(BaseTransformer):
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
             self._deps = self._batch_inference_validate_snowpark(
                 dataset=dataset,
                 inference_method=inference_method,
             )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -699,7 +757,7 @@ class DecisionTreeClassifier(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -731,7 +789,8 @@ class DecisionTreeClassifier(BaseTransformer):
             Output dataset with log probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="predict_log_proba"
+        inference_method = "predict_log_proba"
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
@@ -742,18 +801,20 @@ class DecisionTreeClassifier(BaseTransformer):
                 dataset=dataset,
                 inference_method=inference_method,
             )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -766,7 +827,7 @@ class DecisionTreeClassifier(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -792,30 +853,34 @@ class DecisionTreeClassifier(BaseTransformer):
             Output dataset with results of the decision function for the samples in input dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="decision_function"
+        inference_method = "decision_function"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
             self._deps = self._batch_inference_validate_snowpark(
                 dataset=dataset,
                 inference_method=inference_method,
             )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -828,7 +893,7 @@ class DecisionTreeClassifier(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -857,12 +922,14 @@ class DecisionTreeClassifier(BaseTransformer):
             Output dataset with probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="score_samples"
+        inference_method = "score_samples"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
             self._deps = self._batch_inference_validate_snowpark(
                 dataset=dataset,
@@ -875,6 +942,9 @@ class DecisionTreeClassifier(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
             transform_kwargs = dict(
@@ -893,7 +963,7 @@ class DecisionTreeClassifier(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
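All of the inference hunks above follow the same pattern: the expected output names are computed up front with `_get_output_column_names` and then reconciled by `_align_expected_output_names`, which runs the method on a one-row pandas sample and falls back to the columns the sklearn estimator actually produced when the counts disagree. A standalone sketch of that reconciliation (the helper name is hypothetical and the `sample_weight_col` handling is omitted):

```python
from typing import List, Set


def align_expected_output_names(expected: List[str], produced_df_cols: List[str], input_cols: List[str]) -> List[str]:
    """Mimics _align_expected_output_names: keep the expected names only when their
    count matches the new columns the estimator actually produced."""
    new_cols: Set[str] = set(produced_df_cols) - set(input_cols)
    if len(expected) == len(new_cols):
        return expected
    # otherwise fall back to the estimator's own columns, in their original order
    return sorted(new_cols, key=lambda c: produced_df_cols.index(c))


# A binary predict_proba produced two new columns, so the expected names are kept:
print(align_expected_output_names(
    ["PREDICT_PROBA_0", "PREDICT_PROBA_1"],
    ["F1", "F2", "predict_proba_0", "predict_proba_1"],
    ["F1", "F2"],
))
# ['PREDICT_PROBA_0', 'PREDICT_PROBA_1']
```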
@@ -1040,50 +1110,84 @@ class DecisionTreeClassifier(BaseTransformer):
         )
         return output_df
 
+
+
+    def to_sklearn(self) -> Any:
+        """Get sklearn.tree.DecisionTreeClassifier object.
+        """
+        if self._sklearn_object is None:
+            self._sklearn_object = self._create_sklearn_object()
+        return self._sklearn_object
+
+    def to_xgboost(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_xgboost()",
+                    "to_sklearn()"
+                )
+            ),
+        )
+
+    def to_lightgbm(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_lightgbm()",
+                    "to_sklearn()"
+                )
+            ),
+        )
 
-    def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
+    def _get_dependencies(self) -> List[str]:
+        return self._deps
+
+
+    def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         self._model_signature_dict = dict()
 
         PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
 
-        inputs = list(_infer_signature(dataset[self.input_cols], "input"))
+        inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
         outputs: List[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
-            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
+            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
             # For classifier, the type of predict is the same as the type of label
-            if self._sklearn_object._estimator_type == 'classifier':
-                # label columns is the desired type for output
+            if self._sklearn_object._estimator_type == "classifier":
+                # label columns is the desired type for output
                 outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
                 # rename the output columns
                 outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
             # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
             # For outlier models, returns -1 for outliers and 1 for inliers.
-            # Clusterer returns int64 cluster labels.
+            # Clusterer returns int64 cluster labels.
             elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
                 outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
             # For regressor, the type of predict is float64
-            elif self._sklearn_object._estimator_type == 'regressor':
+            elif self._sklearn_object._estimator_type == "regressor":
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
+                self._model_signature_dict[prob_func] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
 
         # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
         items = list(self._model_signature_dict.items())
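The conversion helpers relocated in this hunk keep their contract: `to_sklearn()` hands back the underlying fitted scikit-learn estimator, while `to_xgboost()` and `to_lightgbm()` raise a METHOD_NOT_ALLOWED error for this sklearn-backed class. A hedged usage sketch with throwaway pandas data; column names and parameters are illustrative:

```python
# Illustrative only: exercises the conversion helpers on a locally fitted estimator.
import pandas as pd

from snowflake.ml.modeling.tree import DecisionTreeClassifier

train_df = pd.DataFrame(
    {"F1": [0.0, 1.0, 2.0, 3.0], "F2": [1.0, 0.0, 1.0, 0.0], "TARGET": [0, 1, 0, 1]}
)
clf = DecisionTreeClassifier(input_cols=["F1", "F2"], label_cols=["TARGET"], output_cols=["PREDICTED"])
clf.fit(train_df)

sk_model = clf.to_sklearn()   # the underlying sklearn.tree.DecisionTreeClassifier

try:
    clf.to_xgboost()          # conversion is not supported for an sklearn-backed estimator
except Exception as err:      # raises SnowflakeMLException with METHOD_NOT_ALLOWED
    print(type(err).__name__, err)
```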
@@ -1096,10 +1200,10 @@ class DecisionTreeClassifier(BaseTransformer):
         """Returns model signature of current class.
 
         Raises:
-            exceptions.SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
+            SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
 
         Returns:
-            Dict[str, ModelSignature]: each method and its input output signature
+            Dict with each method and its input output signature
         """
         if self._model_signature_dict is None:
             raise exceptions.SnowflakeMLException(
@@ -1107,35 +1211,3 @@ class DecisionTreeClassifier(BaseTransformer):
                 original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
             )
         return self._model_signature_dict
-
-    def to_sklearn(self) -> Any:
-        """Get sklearn.tree.DecisionTreeClassifier object.
-        """
-        if self._sklearn_object is None:
-            self._sklearn_object = self._create_sklearn_object()
-        return self._sklearn_object
-
-    def to_xgboost(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_xgboost()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def to_lightgbm(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_lightgbm()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def _get_dependencies(self) -> List[str]:
-        return self._deps