snowflake-ml-python 1.3.1__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. snowflake/ml/_internal/env_utils.py +11 -1
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/utils/formatting.py +1 -1
  7. snowflake/ml/_internal/utils/identifier.py +3 -1
  8. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  9. snowflake/ml/feature_store/feature_store.py +166 -184
  10. snowflake/ml/feature_store/feature_view.py +12 -24
  11. snowflake/ml/fileset/sfcfs.py +56 -50
  12. snowflake/ml/fileset/stage_fs.py +48 -13
  13. snowflake/ml/model/_client/model/model_version_impl.py +6 -49
  14. snowflake/ml/model/_client/ops/model_ops.py +78 -29
  15. snowflake/ml/model/_client/sql/model.py +23 -2
  16. snowflake/ml/model/_client/sql/model_version.py +22 -1
  17. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -3
  18. snowflake/ml/model/_deploy_client/snowservice/deploy.py +5 -2
  19. snowflake/ml/model/_model_composer/model_composer.py +7 -5
  20. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +19 -54
  21. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +8 -1
  22. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
  23. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  24. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  25. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  26. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  27. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +13 -1
  28. snowflake/ml/model/_packager/model_handlers/xgboost.py +1 -1
  29. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  30. snowflake/ml/model/_packager/model_meta/model_meta.py +36 -6
  31. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  32. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  33. snowflake/ml/model/_packager/model_packager.py +2 -2
  34. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  35. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +21 -2
  38. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  40. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  41. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +545 -0
  42. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -5
  43. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +195 -123
  44. snowflake/ml/modeling/cluster/affinity_propagation.py +195 -123
  45. snowflake/ml/modeling/cluster/agglomerative_clustering.py +195 -123
  46. snowflake/ml/modeling/cluster/birch.py +195 -123
  47. snowflake/ml/modeling/cluster/bisecting_k_means.py +195 -123
  48. snowflake/ml/modeling/cluster/dbscan.py +195 -123
  49. snowflake/ml/modeling/cluster/feature_agglomeration.py +195 -123
  50. snowflake/ml/modeling/cluster/k_means.py +195 -123
  51. snowflake/ml/modeling/cluster/mean_shift.py +195 -123
  52. snowflake/ml/modeling/cluster/mini_batch_k_means.py +195 -123
  53. snowflake/ml/modeling/cluster/optics.py +195 -123
  54. snowflake/ml/modeling/cluster/spectral_biclustering.py +195 -123
  55. snowflake/ml/modeling/cluster/spectral_clustering.py +195 -123
  56. snowflake/ml/modeling/cluster/spectral_coclustering.py +195 -123
  57. snowflake/ml/modeling/compose/column_transformer.py +195 -123
  58. snowflake/ml/modeling/compose/transformed_target_regressor.py +195 -123
  59. snowflake/ml/modeling/covariance/elliptic_envelope.py +195 -123
  60. snowflake/ml/modeling/covariance/empirical_covariance.py +195 -123
  61. snowflake/ml/modeling/covariance/graphical_lasso.py +195 -123
  62. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +195 -123
  63. snowflake/ml/modeling/covariance/ledoit_wolf.py +195 -123
  64. snowflake/ml/modeling/covariance/min_cov_det.py +195 -123
  65. snowflake/ml/modeling/covariance/oas.py +195 -123
  66. snowflake/ml/modeling/covariance/shrunk_covariance.py +195 -123
  67. snowflake/ml/modeling/decomposition/dictionary_learning.py +195 -123
  68. snowflake/ml/modeling/decomposition/factor_analysis.py +195 -123
  69. snowflake/ml/modeling/decomposition/fast_ica.py +195 -123
  70. snowflake/ml/modeling/decomposition/incremental_pca.py +195 -123
  71. snowflake/ml/modeling/decomposition/kernel_pca.py +195 -123
  72. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +195 -123
  73. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +195 -123
  74. snowflake/ml/modeling/decomposition/pca.py +195 -123
  75. snowflake/ml/modeling/decomposition/sparse_pca.py +195 -123
  76. snowflake/ml/modeling/decomposition/truncated_svd.py +195 -123
  77. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +195 -123
  78. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +195 -123
  79. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +195 -123
  80. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +195 -123
  81. snowflake/ml/modeling/ensemble/bagging_classifier.py +195 -123
  82. snowflake/ml/modeling/ensemble/bagging_regressor.py +195 -123
  83. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +195 -123
  84. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +195 -123
  85. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +195 -123
  86. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +195 -123
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +195 -123
  88. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +195 -123
  89. snowflake/ml/modeling/ensemble/isolation_forest.py +195 -123
  90. snowflake/ml/modeling/ensemble/random_forest_classifier.py +195 -123
  91. snowflake/ml/modeling/ensemble/random_forest_regressor.py +195 -123
  92. snowflake/ml/modeling/ensemble/stacking_regressor.py +195 -123
  93. snowflake/ml/modeling/ensemble/voting_classifier.py +195 -123
  94. snowflake/ml/modeling/ensemble/voting_regressor.py +195 -123
  95. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +195 -123
  96. snowflake/ml/modeling/feature_selection/select_fdr.py +195 -123
  97. snowflake/ml/modeling/feature_selection/select_fpr.py +195 -123
  98. snowflake/ml/modeling/feature_selection/select_fwe.py +195 -123
  99. snowflake/ml/modeling/feature_selection/select_k_best.py +195 -123
  100. snowflake/ml/modeling/feature_selection/select_percentile.py +195 -123
  101. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +195 -123
  102. snowflake/ml/modeling/feature_selection/variance_threshold.py +195 -123
  103. snowflake/ml/modeling/framework/_utils.py +8 -1
  104. snowflake/ml/modeling/framework/base.py +24 -6
  105. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +195 -123
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +195 -123
  107. snowflake/ml/modeling/impute/iterative_imputer.py +195 -123
  108. snowflake/ml/modeling/impute/knn_imputer.py +195 -123
  109. snowflake/ml/modeling/impute/missing_indicator.py +195 -123
  110. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +195 -123
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +195 -123
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +195 -123
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +195 -123
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +195 -123
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +195 -123
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +198 -125
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +198 -125
  119. snowflake/ml/modeling/linear_model/ard_regression.py +195 -123
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +195 -123
  121. snowflake/ml/modeling/linear_model/elastic_net.py +195 -123
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +195 -123
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +195 -123
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +195 -123
  125. snowflake/ml/modeling/linear_model/lars.py +195 -123
  126. snowflake/ml/modeling/linear_model/lars_cv.py +195 -123
  127. snowflake/ml/modeling/linear_model/lasso.py +195 -123
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +195 -123
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +195 -123
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +195 -123
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +195 -123
  132. snowflake/ml/modeling/linear_model/linear_regression.py +195 -123
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +195 -123
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +195 -123
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +195 -123
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +195 -123
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +195 -123
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +195 -123
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +195 -123
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +195 -123
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +195 -123
  142. snowflake/ml/modeling/linear_model/perceptron.py +195 -123
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +195 -123
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +195 -123
  145. snowflake/ml/modeling/linear_model/ridge.py +195 -123
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +195 -123
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +195 -123
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +195 -123
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +195 -123
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +195 -123
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +195 -123
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +195 -123
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +195 -123
  154. snowflake/ml/modeling/manifold/isomap.py +195 -123
  155. snowflake/ml/modeling/manifold/mds.py +195 -123
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +195 -123
  157. snowflake/ml/modeling/manifold/tsne.py +195 -123
  158. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +195 -123
  159. snowflake/ml/modeling/mixture/gaussian_mixture.py +195 -123
  160. snowflake/ml/modeling/model_selection/grid_search_cv.py +42 -18
  161. snowflake/ml/modeling/model_selection/randomized_search_cv.py +42 -18
  162. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +195 -123
  163. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +195 -123
  164. snowflake/ml/modeling/multiclass/output_code_classifier.py +195 -123
  165. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +195 -123
  166. snowflake/ml/modeling/naive_bayes/categorical_nb.py +195 -123
  167. snowflake/ml/modeling/naive_bayes/complement_nb.py +195 -123
  168. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +195 -123
  169. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +195 -123
  170. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +195 -123
  171. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +195 -123
  172. snowflake/ml/modeling/neighbors/kernel_density.py +195 -123
  173. snowflake/ml/modeling/neighbors/local_outlier_factor.py +195 -123
  174. snowflake/ml/modeling/neighbors/nearest_centroid.py +195 -123
  175. snowflake/ml/modeling/neighbors/nearest_neighbors.py +195 -123
  176. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +195 -123
  177. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +195 -123
  178. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +195 -123
  179. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +195 -123
  180. snowflake/ml/modeling/neural_network/mlp_classifier.py +195 -123
  181. snowflake/ml/modeling/neural_network/mlp_regressor.py +195 -123
  182. snowflake/ml/modeling/pipeline/pipeline.py +4 -4
  183. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  184. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  185. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  186. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  187. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  188. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  189. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +1 -5
  190. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  191. snowflake/ml/modeling/preprocessing/polynomial_features.py +195 -123
  192. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  193. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  194. snowflake/ml/modeling/semi_supervised/label_propagation.py +195 -123
  195. snowflake/ml/modeling/semi_supervised/label_spreading.py +195 -123
  196. snowflake/ml/modeling/svm/linear_svc.py +195 -123
  197. snowflake/ml/modeling/svm/linear_svr.py +195 -123
  198. snowflake/ml/modeling/svm/nu_svc.py +195 -123
  199. snowflake/ml/modeling/svm/nu_svr.py +195 -123
  200. snowflake/ml/modeling/svm/svc.py +195 -123
  201. snowflake/ml/modeling/svm/svr.py +195 -123
  202. snowflake/ml/modeling/tree/decision_tree_classifier.py +195 -123
  203. snowflake/ml/modeling/tree/decision_tree_regressor.py +195 -123
  204. snowflake/ml/modeling/tree/extra_tree_classifier.py +195 -123
  205. snowflake/ml/modeling/tree/extra_tree_regressor.py +195 -123
  206. snowflake/ml/modeling/xgboost/xgb_classifier.py +195 -123
  207. snowflake/ml/modeling/xgboost/xgb_regressor.py +195 -123
  208. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +195 -123
  209. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +195 -123
  210. snowflake/ml/registry/_manager/model_manager.py +5 -1
  211. snowflake/ml/registry/model_registry.py +99 -26
  212. snowflake/ml/registry/registry.py +3 -2
  213. snowflake/ml/version.py +1 -1
  214. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/METADATA +94 -55
  215. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/RECORD +218 -212
  216. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  217. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/LICENSE.txt +0 -0
  218. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/WHEEL +0 -0
  219. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
33
33
  BatchInferenceKwargsTypedDict,
34
34
  ScoreKwargsTypedDict
35
35
  )
36
+ from snowflake.ml.model._signatures import utils as model_signature_utils
37
+ from snowflake.ml.model.model_signature import (
38
+ BaseFeatureSpec,
39
+ DataType,
40
+ FeatureSpec,
41
+ ModelSignature,
42
+ _infer_signature,
43
+ _rename_signature_with_snowflake_identifiers,
44
+ )
36
45
 
37
46
  from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
38
47
 
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
43
52
  validate_sklearn_args,
44
53
  )
45
54
 
46
- from snowflake.ml.model.model_signature import (
47
- DataType,
48
- FeatureSpec,
49
- ModelSignature,
50
- _infer_signature,
51
- _rename_signature_with_snowflake_identifiers,
52
- BaseFeatureSpec,
53
- )
54
- from snowflake.ml.model._signatures import utils as model_signature_utils
55
-
56
55
  _PROJECT = "ModelDevelopment"
57
56
  # Derive subproject from module name by removing "sklearn"
58
57
  # and converting module name from underscore to CamelCase
@@ -379,12 +378,7 @@ class ExtraTreesClassifier(BaseTransformer):
379
378
  )
380
379
  return selected_cols
381
380
 
382
- @telemetry.send_api_usage_telemetry(
383
- project=_PROJECT,
384
- subproject=_SUBPROJECT,
385
- custom_tags=dict([("autogen", True)]),
386
- )
387
- def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "ExtraTreesClassifier":
381
+ def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "ExtraTreesClassifier":
388
382
  """Build a forest of trees from the training set (X, y)
389
383
  For more details on this function, see [sklearn.ensemble.ExtraTreesClassifier.fit]
390
384
  (https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesClassifier.html#sklearn.ensemble.ExtraTreesClassifier.fit)
@@ -411,12 +405,14 @@ class ExtraTreesClassifier(BaseTransformer):
411
405
 
412
406
  self._snowpark_cols = dataset.select(self.input_cols).columns
413
407
 
414
- # If we are already in a stored procedure, no need to kick off another one.
408
+ # If we are already in a stored procedure, no need to kick off another one.
415
409
  if SNOWML_SPROC_ENV in os.environ:
416
410
  statement_params = telemetry.get_function_usage_statement_params(
417
411
  project=_PROJECT,
418
412
  subproject=_SUBPROJECT,
419
- function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), ExtraTreesClassifier.__class__.__name__),
413
+ function_name=telemetry.get_statement_params_full_func_name(
414
+ inspect.currentframe(), ExtraTreesClassifier.__class__.__name__
415
+ ),
420
416
  api_calls=[Session.call],
421
417
  custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
422
418
  )
@@ -437,7 +433,7 @@ class ExtraTreesClassifier(BaseTransformer):
437
433
  )
438
434
  self._sklearn_object = model_trainer.train()
439
435
  self._is_fitted = True
440
- self._get_model_signatures(dataset)
436
+ self._generate_model_signatures(dataset)
441
437
  return self
442
438
 
443
439
  def _batch_inference_validate_snowpark(
@@ -513,7 +509,9 @@ class ExtraTreesClassifier(BaseTransformer):
513
509
  # when it is classifier, infer the datatype from label columns
514
510
  if expected_type_inferred == "" and 'predict' in self.model_signatures:
515
511
  # Batch inference takes a single expected output column type. Use the first columns type for now.
516
- label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
512
+ label_cols_signatures = [
513
+ row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
514
+ ]
517
515
  if len(label_cols_signatures) == 0:
518
516
  error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
519
517
  raise exceptions.SnowflakeMLException(
@@ -521,25 +519,22 @@ class ExtraTreesClassifier(BaseTransformer):
521
519
  original_exception=ValueError(error_str),
522
520
  )
523
521
 
524
- expected_type_inferred = convert_sp_to_sf_type(
525
- label_cols_signatures[0].as_snowpark_type()
526
- )
522
+ expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
527
523
 
528
524
  self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
529
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
525
+ assert isinstance(
526
+ dataset._session, Session
527
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
530
528
 
531
529
  transform_kwargs = dict(
532
- session = dataset._session,
533
- dependencies = self._deps,
534
- drop_input_cols = self._drop_input_cols,
535
- expected_output_cols_type = expected_type_inferred,
530
+ session=dataset._session,
531
+ dependencies=self._deps,
532
+ drop_input_cols=self._drop_input_cols,
533
+ expected_output_cols_type=expected_type_inferred,
536
534
  )
537
535
 
538
536
  elif isinstance(dataset, pd.DataFrame):
539
- transform_kwargs = dict(
540
- snowpark_input_cols = self._snowpark_cols,
541
- drop_input_cols = self._drop_input_cols
542
- )
537
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
543
538
 
544
539
  transform_handlers = ModelTransformerBuilder.build(
545
540
  dataset=dataset,
@@ -579,7 +574,7 @@ class ExtraTreesClassifier(BaseTransformer):
579
574
  Transformed dataset.
580
575
  """
581
576
  super()._check_dataset_type(dataset)
582
- inference_method="transform"
577
+ inference_method = "transform"
583
578
 
584
579
  # This dictionary contains optional kwargs for batch inference. These kwargs
585
580
  # are specific to the type of dataset used.
@@ -616,17 +611,14 @@ class ExtraTreesClassifier(BaseTransformer):
616
611
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
617
612
 
618
613
  transform_kwargs = dict(
619
- session = dataset._session,
620
- dependencies = self._deps,
621
- drop_input_cols = self._drop_input_cols,
622
- expected_output_cols_type = expected_dtype,
614
+ session=dataset._session,
615
+ dependencies=self._deps,
616
+ drop_input_cols=self._drop_input_cols,
617
+ expected_output_cols_type=expected_dtype,
623
618
  )
624
619
 
625
620
  elif isinstance(dataset, pd.DataFrame):
626
- transform_kwargs = dict(
627
- snowpark_input_cols = self._snowpark_cols,
628
- drop_input_cols = self._drop_input_cols
629
- )
621
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
630
622
 
631
623
  transform_handlers = ModelTransformerBuilder.build(
632
624
  dataset=dataset,
@@ -645,7 +637,11 @@ class ExtraTreesClassifier(BaseTransformer):
645
637
  return output_df
646
638
 
647
639
  @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
648
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
640
+ def fit_predict(
641
+ self,
642
+ dataset: Union[DataFrame, pd.DataFrame],
643
+ output_cols_prefix: str = "fit_predict_",
644
+ ) -> Union[DataFrame, pd.DataFrame]:
649
645
  """ Method not supported for this class.
650
646
 
651
647
 
@@ -670,7 +666,9 @@ class ExtraTreesClassifier(BaseTransformer):
670
666
  )
671
667
  output_result, fitted_estimator = model_trainer.train_fit_predict(
672
668
  drop_input_cols=self._drop_input_cols,
673
- expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
669
+ expected_output_cols_list=(
670
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
671
+ ),
674
672
  )
675
673
  self._sklearn_object = fitted_estimator
676
674
  self._is_fitted = True
@@ -687,6 +685,62 @@ class ExtraTreesClassifier(BaseTransformer):
687
685
  assert self._sklearn_object is not None
688
686
  return self._sklearn_object.embedding_
689
687
 
688
+
689
+ def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
690
+ """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
691
+ Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
692
+ """
693
+ output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
694
+ # The following condition is introduced for kneighbors methods, and not used in other methods
695
+ if output_cols:
696
+ output_cols = [
697
+ identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
698
+ for c in output_cols
699
+ ]
700
+ elif getattr(self._sklearn_object, "classes_", None) is None:
701
+ output_cols = [output_cols_prefix]
702
+ elif self._sklearn_object is not None:
703
+ classes = self._sklearn_object.classes_
704
+ if isinstance(classes, numpy.ndarray):
705
+ output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
706
+ elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
707
+ # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
708
+ output_cols = []
709
+ for i, cl in enumerate(classes):
710
+ # For binary classification, there is only one output column for each class
711
+ # ndarray as the two classes are complementary.
712
+ if len(cl) == 2:
713
+ output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
714
+ else:
715
+ output_cols.extend([
716
+ f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
717
+ ])
718
+ else:
719
+ output_cols = []
720
+
721
+ # Make sure column names are valid snowflake identifiers.
722
+ assert output_cols is not None # Make MyPy happy
723
+ rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
724
+
725
+ return rv
726
+
727
+ def _align_expected_output_names(
728
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
729
+ ) -> List[str]:
730
+ # in case the inferred output column names dimension is different
731
+ # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
732
+ output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
733
+ output_df_columns = list(output_df_pd.columns)
734
+ output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
735
+ if self.sample_weight_col:
736
+ output_df_columns_set -= set(self.sample_weight_col)
737
+ # if the dimension of inferred output column names is correct; use it
738
+ if len(expected_output_cols_list) == len(output_df_columns_set):
739
+ return expected_output_cols_list
740
+ # otherwise, use the sklearn estimator's output
741
+ else:
742
+ return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
743
+
690
744
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
691
745
  @telemetry.send_api_usage_telemetry(
692
746
  project=_PROJECT,
@@ -719,24 +773,28 @@ class ExtraTreesClassifier(BaseTransformer):
719
773
  # are specific to the type of dataset used.
720
774
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
721
775
 
776
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
777
+
722
778
  if isinstance(dataset, DataFrame):
723
779
  self._deps = self._batch_inference_validate_snowpark(
724
780
  dataset=dataset,
725
781
  inference_method=inference_method,
726
782
  )
727
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
783
+ assert isinstance(
784
+ dataset._session, Session
785
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
728
786
  transform_kwargs = dict(
729
787
  session=dataset._session,
730
788
  dependencies=self._deps,
731
- drop_input_cols = self._drop_input_cols,
789
+ drop_input_cols=self._drop_input_cols,
732
790
  expected_output_cols_type="float",
733
791
  )
792
+ expected_output_cols = self._align_expected_output_names(
793
+ inference_method, dataset, expected_output_cols, output_cols_prefix
794
+ )
734
795
 
735
796
  elif isinstance(dataset, pd.DataFrame):
736
- transform_kwargs = dict(
737
- snowpark_input_cols = self._snowpark_cols,
738
- drop_input_cols = self._drop_input_cols
739
- )
797
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
740
798
 
741
799
  transform_handlers = ModelTransformerBuilder.build(
742
800
  dataset=dataset,
@@ -748,7 +806,7 @@ class ExtraTreesClassifier(BaseTransformer):
748
806
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
749
807
  inference_method=inference_method,
750
808
  input_cols=self.input_cols,
751
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
809
+ expected_output_cols=expected_output_cols,
752
810
  **transform_kwargs
753
811
  )
754
812
  return output_df
@@ -780,7 +838,8 @@ class ExtraTreesClassifier(BaseTransformer):
780
838
  Output dataset with log probability of the sample for each class in the model.
781
839
  """
782
840
  super()._check_dataset_type(dataset)
783
- inference_method="predict_log_proba"
841
+ inference_method = "predict_log_proba"
842
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
784
843
 
785
844
  # This dictionary contains optional kwargs for batch inference. These kwargs
786
845
  # are specific to the type of dataset used.
@@ -791,18 +850,20 @@ class ExtraTreesClassifier(BaseTransformer):
791
850
  dataset=dataset,
792
851
  inference_method=inference_method,
793
852
  )
794
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
853
+ assert isinstance(
854
+ dataset._session, Session
855
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
795
856
  transform_kwargs = dict(
796
857
  session=dataset._session,
797
858
  dependencies=self._deps,
798
- drop_input_cols = self._drop_input_cols,
859
+ drop_input_cols=self._drop_input_cols,
799
860
  expected_output_cols_type="float",
800
861
  )
862
+ expected_output_cols = self._align_expected_output_names(
863
+ inference_method, dataset, expected_output_cols, output_cols_prefix
864
+ )
801
865
  elif isinstance(dataset, pd.DataFrame):
802
- transform_kwargs = dict(
803
- snowpark_input_cols = self._snowpark_cols,
804
- drop_input_cols = self._drop_input_cols
805
- )
866
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
806
867
 
807
868
  transform_handlers = ModelTransformerBuilder.build(
808
869
  dataset=dataset,
@@ -815,7 +876,7 @@ class ExtraTreesClassifier(BaseTransformer):
815
876
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
816
877
  inference_method=inference_method,
817
878
  input_cols=self.input_cols,
818
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
879
+ expected_output_cols=expected_output_cols,
819
880
  **transform_kwargs
820
881
  )
821
882
  return output_df
@@ -841,30 +902,34 @@ class ExtraTreesClassifier(BaseTransformer):
841
902
  Output dataset with results of the decision function for the samples in input dataset.
842
903
  """
843
904
  super()._check_dataset_type(dataset)
844
- inference_method="decision_function"
905
+ inference_method = "decision_function"
845
906
 
846
907
  # This dictionary contains optional kwargs for batch inference. These kwargs
847
908
  # are specific to the type of dataset used.
848
909
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
849
910
 
911
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
912
+
850
913
  if isinstance(dataset, DataFrame):
851
914
  self._deps = self._batch_inference_validate_snowpark(
852
915
  dataset=dataset,
853
916
  inference_method=inference_method,
854
917
  )
855
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
918
+ assert isinstance(
919
+ dataset._session, Session
920
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
856
921
  transform_kwargs = dict(
857
922
  session=dataset._session,
858
923
  dependencies=self._deps,
859
- drop_input_cols = self._drop_input_cols,
924
+ drop_input_cols=self._drop_input_cols,
860
925
  expected_output_cols_type="float",
861
926
  )
927
+ expected_output_cols = self._align_expected_output_names(
928
+ inference_method, dataset, expected_output_cols, output_cols_prefix
929
+ )
862
930
 
863
931
  elif isinstance(dataset, pd.DataFrame):
864
- transform_kwargs = dict(
865
- snowpark_input_cols = self._snowpark_cols,
866
- drop_input_cols = self._drop_input_cols
867
- )
932
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
868
933
 
869
934
  transform_handlers = ModelTransformerBuilder.build(
870
935
  dataset=dataset,
@@ -877,7 +942,7 @@ class ExtraTreesClassifier(BaseTransformer):
877
942
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
878
943
  inference_method=inference_method,
879
944
  input_cols=self.input_cols,
880
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
945
+ expected_output_cols=expected_output_cols,
881
946
  **transform_kwargs
882
947
  )
883
948
  return output_df
@@ -906,12 +971,14 @@ class ExtraTreesClassifier(BaseTransformer):
906
971
  Output dataset with probability of the sample for each class in the model.
907
972
  """
908
973
  super()._check_dataset_type(dataset)
909
- inference_method="score_samples"
974
+ inference_method = "score_samples"
910
975
 
911
976
  # This dictionary contains optional kwargs for batch inference. These kwargs
912
977
  # are specific to the type of dataset used.
913
978
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
914
979
 
980
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
981
+
915
982
  if isinstance(dataset, DataFrame):
916
983
  self._deps = self._batch_inference_validate_snowpark(
917
984
  dataset=dataset,
@@ -924,6 +991,9 @@ class ExtraTreesClassifier(BaseTransformer):
924
991
  drop_input_cols = self._drop_input_cols,
925
992
  expected_output_cols_type="float",
926
993
  )
994
+ expected_output_cols = self._align_expected_output_names(
995
+ inference_method, dataset, expected_output_cols, output_cols_prefix
996
+ )
927
997
 
928
998
  elif isinstance(dataset, pd.DataFrame):
929
999
  transform_kwargs = dict(
@@ -942,7 +1012,7 @@ class ExtraTreesClassifier(BaseTransformer):
942
1012
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
943
1013
  inference_method=inference_method,
944
1014
  input_cols=self.input_cols,
945
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
1015
+ expected_output_cols=expected_output_cols,
946
1016
  **transform_kwargs
947
1017
  )
948
1018
  return output_df
@@ -1089,50 +1159,84 @@ class ExtraTreesClassifier(BaseTransformer):
1089
1159
  )
1090
1160
  return output_df
1091
1161
 
1162
+
1163
+
1164
+ def to_sklearn(self) -> Any:
1165
+ """Get sklearn.ensemble.ExtraTreesClassifier object.
1166
+ """
1167
+ if self._sklearn_object is None:
1168
+ self._sklearn_object = self._create_sklearn_object()
1169
+ return self._sklearn_object
1170
+
1171
+ def to_xgboost(self) -> Any:
1172
+ raise exceptions.SnowflakeMLException(
1173
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1174
+ original_exception=AttributeError(
1175
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1176
+ "to_xgboost()",
1177
+ "to_sklearn()"
1178
+ )
1179
+ ),
1180
+ )
1181
+
1182
+ def to_lightgbm(self) -> Any:
1183
+ raise exceptions.SnowflakeMLException(
1184
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1185
+ original_exception=AttributeError(
1186
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1187
+ "to_lightgbm()",
1188
+ "to_sklearn()"
1189
+ )
1190
+ ),
1191
+ )
1092
1192
 
1093
- def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
1193
+ def _get_dependencies(self) -> List[str]:
1194
+ return self._deps
1195
+
1196
+
1197
+ def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
1094
1198
  self._model_signature_dict = dict()
1095
1199
 
1096
1200
  PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
1097
1201
 
1098
- inputs = list(_infer_signature(dataset[self.input_cols], "input"))
1202
+ inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
1099
1203
  outputs: List[BaseFeatureSpec] = []
1100
1204
  if hasattr(self, "predict"):
1101
1205
  # keep mypy happy
1102
- assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
1206
+ assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
1103
1207
  # For classifier, the type of predict is the same as the type of label
1104
- if self._sklearn_object._estimator_type == 'classifier':
1105
- # label columns is the desired type for output
1208
+ if self._sklearn_object._estimator_type == "classifier":
1209
+ # label columns is the desired type for output
1106
1210
  outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
1107
1211
  # rename the output columns
1108
1212
  outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
1109
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1110
- ([] if self._drop_input_cols else inputs)
1111
- + outputs)
1213
+ self._model_signature_dict["predict"] = ModelSignature(
1214
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1215
+ )
1112
1216
  # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
1113
1217
  # For outlier models, returns -1 for outliers and 1 for inliers.
1114
- # Clusterer returns int64 cluster labels.
1218
+ # Clusterer returns int64 cluster labels.
1115
1219
  elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
1116
1220
  outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
1117
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1118
- ([] if self._drop_input_cols else inputs)
1119
- + outputs)
1120
-
1221
+ self._model_signature_dict["predict"] = ModelSignature(
1222
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1223
+ )
1224
+
1121
1225
  # For regressor, the type of predict is float64
1122
- elif self._sklearn_object._estimator_type == 'regressor':
1226
+ elif self._sklearn_object._estimator_type == "regressor":
1123
1227
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
1124
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1125
- ([] if self._drop_input_cols else inputs)
1126
- + outputs)
1127
-
1228
+ self._model_signature_dict["predict"] = ModelSignature(
1229
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1230
+ )
1231
+
1128
1232
  for prob_func in PROB_FUNCTIONS:
1129
1233
  if hasattr(self, prob_func):
1130
1234
  output_cols_prefix: str = f"{prob_func}_"
1131
1235
  output_column_names = self._get_output_column_names(output_cols_prefix)
1132
1236
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
1133
- self._model_signature_dict[prob_func] = ModelSignature(inputs,
1134
- ([] if self._drop_input_cols else inputs)
1135
- + outputs)
1237
+ self._model_signature_dict[prob_func] = ModelSignature(
1238
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1239
+ )
1136
1240
 
1137
1241
  # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
1138
1242
  items = list(self._model_signature_dict.items())
@@ -1145,10 +1249,10 @@ class ExtraTreesClassifier(BaseTransformer):
1145
1249
  """Returns model signature of current class.
1146
1250
 
1147
1251
  Raises:
1148
- exceptions.SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1252
+ SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1149
1253
 
1150
1254
  Returns:
1151
- Dict[str, ModelSignature]: each method and its input output signature
1255
+ Dict with each method and its input output signature
1152
1256
  """
1153
1257
  if self._model_signature_dict is None:
1154
1258
  raise exceptions.SnowflakeMLException(
@@ -1156,35 +1260,3 @@ class ExtraTreesClassifier(BaseTransformer):
1156
1260
  original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
1157
1261
  )
1158
1262
  return self._model_signature_dict
1159
-
1160
- def to_sklearn(self) -> Any:
1161
- """Get sklearn.ensemble.ExtraTreesClassifier object.
1162
- """
1163
- if self._sklearn_object is None:
1164
- self._sklearn_object = self._create_sklearn_object()
1165
- return self._sklearn_object
1166
-
1167
- def to_xgboost(self) -> Any:
1168
- raise exceptions.SnowflakeMLException(
1169
- error_code=error_codes.METHOD_NOT_ALLOWED,
1170
- original_exception=AttributeError(
1171
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1172
- "to_xgboost()",
1173
- "to_sklearn()"
1174
- )
1175
- ),
1176
- )
1177
-
1178
- def to_lightgbm(self) -> Any:
1179
- raise exceptions.SnowflakeMLException(
1180
- error_code=error_codes.METHOD_NOT_ALLOWED,
1181
- original_exception=AttributeError(
1182
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1183
- "to_lightgbm()",
1184
- "to_sklearn()"
1185
- )
1186
- ),
1187
- )
1188
-
1189
- def _get_dependencies(self) -> List[str]:
1190
- return self._deps