snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (234)
  1. snowflake/ml/_internal/env_utils.py +77 -32
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/_internal/utils/identifier.py +3 -1
  8. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  9. snowflake/ml/dataset/__init__.py +10 -0
  10. snowflake/ml/dataset/dataset.py +454 -129
  11. snowflake/ml/dataset/dataset_factory.py +53 -0
  12. snowflake/ml/dataset/dataset_metadata.py +103 -0
  13. snowflake/ml/dataset/dataset_reader.py +202 -0
  14. snowflake/ml/feature_store/feature_store.py +531 -332
  15. snowflake/ml/feature_store/feature_view.py +40 -23
  16. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  17. snowflake/ml/fileset/sfcfs.py +56 -54
  18. snowflake/ml/fileset/snowfs.py +159 -0
  19. snowflake/ml/fileset/stage_fs.py +49 -17
  20. snowflake/ml/model/__init__.py +2 -2
  21. snowflake/ml/model/_api.py +16 -1
  22. snowflake/ml/model/_client/model/model_impl.py +27 -0
  23. snowflake/ml/model/_client/model/model_version_impl.py +137 -50
  24. snowflake/ml/model/_client/ops/model_ops.py +159 -40
  25. snowflake/ml/model/_client/sql/model.py +25 -2
  26. snowflake/ml/model/_client/sql/model_version.py +131 -2
  27. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  28. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  29. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  30. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  31. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  32. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
  34. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  36. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  37. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  39. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  40. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
  42. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  43. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  44. snowflake/ml/model/_packager/model_packager.py +2 -5
  45. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  46. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  47. snowflake/ml/model/type_hints.py +21 -2
  48. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  49. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  50. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  51. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  52. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  53. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  54. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
  55. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
  56. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  57. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
  58. snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
  59. snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
  60. snowflake/ml/modeling/cluster/birch.py +248 -175
  61. snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
  62. snowflake/ml/modeling/cluster/dbscan.py +246 -175
  63. snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
  64. snowflake/ml/modeling/cluster/k_means.py +248 -175
  65. snowflake/ml/modeling/cluster/mean_shift.py +246 -175
  66. snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
  67. snowflake/ml/modeling/cluster/optics.py +246 -175
  68. snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
  69. snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
  70. snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
  71. snowflake/ml/modeling/compose/column_transformer.py +248 -175
  72. snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
  73. snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
  74. snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
  75. snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
  76. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
  77. snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
  78. snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
  79. snowflake/ml/modeling/covariance/oas.py +246 -175
  80. snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
  81. snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
  82. snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
  83. snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
  84. snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
  85. snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
  86. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
  87. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
  88. snowflake/ml/modeling/decomposition/pca.py +248 -175
  89. snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
  90. snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
  91. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
  92. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
  93. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
  94. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
  95. snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
  96. snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
  97. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
  98. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
  99. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
  100. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
  101. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
  102. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
  103. snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
  104. snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
  105. snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
  106. snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
  107. snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
  108. snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
  109. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
  110. snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
  111. snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
  112. snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
  113. snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
  114. snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
  115. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
  116. snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
  117. snowflake/ml/modeling/framework/_utils.py +8 -1
  118. snowflake/ml/modeling/framework/base.py +72 -37
  119. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
  120. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
  121. snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
  122. snowflake/ml/modeling/impute/knn_imputer.py +248 -175
  123. snowflake/ml/modeling/impute/missing_indicator.py +248 -175
  124. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
  125. snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
  126. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
  127. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
  128. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
  129. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
  130. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
  131. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
  132. snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
  133. snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
  134. snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
  135. snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
  136. snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
  137. snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
  138. snowflake/ml/modeling/linear_model/lars.py +246 -175
  139. snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
  140. snowflake/ml/modeling/linear_model/lasso.py +246 -175
  141. snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
  142. snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
  143. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
  144. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
  145. snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
  146. snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
  147. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
  148. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
  149. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
  150. snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
  151. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
  152. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
  153. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
  154. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
  155. snowflake/ml/modeling/linear_model/perceptron.py +246 -175
  156. snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
  157. snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
  158. snowflake/ml/modeling/linear_model/ridge.py +246 -175
  159. snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
  160. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
  161. snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
  162. snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
  163. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
  164. snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
  165. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
  166. snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
  167. snowflake/ml/modeling/manifold/isomap.py +248 -175
  168. snowflake/ml/modeling/manifold/mds.py +248 -175
  169. snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
  170. snowflake/ml/modeling/manifold/tsne.py +248 -175
  171. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
  172. snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
  173. snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
  174. snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
  175. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
  176. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
  177. snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
  178. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
  179. snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
  180. snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
  181. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
  182. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
  183. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
  184. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
  185. snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
  186. snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
  187. snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
  188. snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
  189. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
  190. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
  191. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
  192. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
  193. snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
  194. snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
  195. snowflake/ml/modeling/pipeline/pipeline.py +517 -35
  196. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  197. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  198. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  199. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  200. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  201. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  202. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
  203. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  204. snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
  205. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  206. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  207. snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
  208. snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
  209. snowflake/ml/modeling/svm/linear_svc.py +246 -175
  210. snowflake/ml/modeling/svm/linear_svr.py +246 -175
  211. snowflake/ml/modeling/svm/nu_svc.py +246 -175
  212. snowflake/ml/modeling/svm/nu_svr.py +246 -175
  213. snowflake/ml/modeling/svm/svc.py +246 -175
  214. snowflake/ml/modeling/svm/svr.py +246 -175
  215. snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
  216. snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
  217. snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
  218. snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
  219. snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
  220. snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
  221. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
  222. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
  223. snowflake/ml/registry/model_registry.py +3 -149
  224. snowflake/ml/registry/registry.py +1 -1
  225. snowflake/ml/version.py +1 -1
  226. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
  227. snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
  228. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  229. snowflake/ml/registry/_artifact_manager.py +0 -156
  230. snowflake/ml/registry/artifact.py +0 -46
  231. snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
  232. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  233. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  234. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -32,6 +32,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
32
32
  BatchInferenceKwargsTypedDict,
33
33
  ScoreKwargsTypedDict
34
34
  )
35
+ from snowflake.ml.model._signatures import utils as model_signature_utils
36
+ from snowflake.ml.model.model_signature import (
37
+ BaseFeatureSpec,
38
+ DataType,
39
+ FeatureSpec,
40
+ ModelSignature,
41
+ _infer_signature,
42
+ _rename_signature_with_snowflake_identifiers,
43
+ )
35
44
 
36
45
  from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
37
46
 
@@ -42,16 +51,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
42
51
  validate_sklearn_args,
43
52
  )
44
53
 
45
- from snowflake.ml.model.model_signature import (
46
- DataType,
47
- FeatureSpec,
48
- ModelSignature,
49
- _infer_signature,
50
- _rename_signature_with_snowflake_identifiers,
51
- BaseFeatureSpec,
52
- )
53
- from snowflake.ml.model._signatures import utils as model_signature_utils
54
-
55
54
  _PROJECT = "ModelDevelopment"
56
55
  # Derive subproject from module name by removing "sklearn"
57
56
  # and converting module name from underscore to CamelCase
@@ -60,12 +59,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "xgboost".replace("sklearn.", "")
60
59
 
61
60
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
61
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
62
  class XGBRegressor(BaseTransformer):
70
63
  r"""Implementation of the scikit-learn API for XGBoost regression
71
64
  For more details on this class, see [xgboost.XGBRegressor]
@@ -421,12 +414,7 @@ class XGBRegressor(BaseTransformer):
421
414
  )
422
415
  return selected_cols
423
416
 
424
- @telemetry.send_api_usage_telemetry(
425
- project=_PROJECT,
426
- subproject=_SUBPROJECT,
427
- custom_tags=dict([("autogen", True)]),
428
- )
429
- def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "XGBRegressor":
417
+ def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "XGBRegressor":
430
418
  """Fit gradient boosting model
431
419
  For more details on this function, see [xgboost.XGBRegressor.fit]
432
420
  (https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.XGBRegressor.fit)
@@ -453,12 +441,14 @@ class XGBRegressor(BaseTransformer):
453
441
 
454
442
  self._snowpark_cols = dataset.select(self.input_cols).columns
455
443
 
456
- # If we are already in a stored procedure, no need to kick off another one.
444
+ # If we are already in a stored procedure, no need to kick off another one.
457
445
  if SNOWML_SPROC_ENV in os.environ:
458
446
  statement_params = telemetry.get_function_usage_statement_params(
459
447
  project=_PROJECT,
460
448
  subproject=_SUBPROJECT,
461
- function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), XGBRegressor.__class__.__name__),
449
+ function_name=telemetry.get_statement_params_full_func_name(
450
+ inspect.currentframe(), XGBRegressor.__class__.__name__
451
+ ),
462
452
  api_calls=[Session.call],
463
453
  custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
464
454
  )
@@ -479,27 +469,24 @@ class XGBRegressor(BaseTransformer):
479
469
  )
480
470
  self._sklearn_object = model_trainer.train()
481
471
  self._is_fitted = True
482
- self._get_model_signatures(dataset)
472
+ self._generate_model_signatures(dataset)
483
473
  return self
484
474
 
485
475
  def _batch_inference_validate_snowpark(
486
476
  self,
487
477
  dataset: DataFrame,
488
478
  inference_method: str,
489
- ) -> List[str]:
490
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
491
- return the available package that exists in the snowflake anaconda channel
479
+ ) -> None:
480
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
492
481
 
493
482
  Args:
494
483
  dataset: snowpark dataframe
495
484
  inference_method: the inference method such as predict, score...
496
-
485
+
497
486
  Raises:
498
487
  SnowflakeMLException: If the estimator is not fitted, raise error
499
488
  SnowflakeMLException: If the session is None, raise error
500
489
 
501
- Returns:
502
- A list of available package that exists in the snowflake anaconda channel
503
490
  """
504
491
  if not self._is_fitted:
505
492
  raise exceptions.SnowflakeMLException(
@@ -517,9 +504,7 @@ class XGBRegressor(BaseTransformer):
517
504
  "Session must not specified for snowpark dataset."
518
505
  ),
519
506
  )
520
- # Validate that key package version in user workspace are supported in snowflake conda channel
521
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
522
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
507
+
523
508
 
524
509
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
525
510
  @telemetry.send_api_usage_telemetry(
@@ -555,7 +540,9 @@ class XGBRegressor(BaseTransformer):
555
540
  # when it is classifier, infer the datatype from label columns
556
541
  if expected_type_inferred == "" and 'predict' in self.model_signatures:
557
542
  # Batch inference takes a single expected output column type. Use the first columns type for now.
558
- label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
543
+ label_cols_signatures = [
544
+ row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
545
+ ]
559
546
  if len(label_cols_signatures) == 0:
560
547
  error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
561
548
  raise exceptions.SnowflakeMLException(
@@ -563,25 +550,23 @@ class XGBRegressor(BaseTransformer):
563
550
  original_exception=ValueError(error_str),
564
551
  )
565
552
 
566
- expected_type_inferred = convert_sp_to_sf_type(
567
- label_cols_signatures[0].as_snowpark_type()
568
- )
553
+ expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
569
554
 
570
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
571
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
555
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
556
+ self._deps = self._get_dependencies()
557
+ assert isinstance(
558
+ dataset._session, Session
559
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
572
560
 
573
561
  transform_kwargs = dict(
574
- session = dataset._session,
575
- dependencies = self._deps,
576
- drop_input_cols = self._drop_input_cols,
577
- expected_output_cols_type = expected_type_inferred,
562
+ session=dataset._session,
563
+ dependencies=self._deps,
564
+ drop_input_cols=self._drop_input_cols,
565
+ expected_output_cols_type=expected_type_inferred,
578
566
  )
579
567
 
580
568
  elif isinstance(dataset, pd.DataFrame):
581
- transform_kwargs = dict(
582
- snowpark_input_cols = self._snowpark_cols,
583
- drop_input_cols = self._drop_input_cols
584
- )
569
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
585
570
 
586
571
  transform_handlers = ModelTransformerBuilder.build(
587
572
  dataset=dataset,
@@ -621,7 +606,7 @@ class XGBRegressor(BaseTransformer):
621
606
  Transformed dataset.
622
607
  """
623
608
  super()._check_dataset_type(dataset)
624
- inference_method="transform"
609
+ inference_method = "transform"
625
610
 
626
611
  # This dictionary contains optional kwargs for batch inference. These kwargs
627
612
  # are specific to the type of dataset used.
@@ -651,24 +636,19 @@ class XGBRegressor(BaseTransformer):
651
636
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
652
637
  expected_dtype = convert_sp_to_sf_type(output_types[0])
653
638
 
654
- self._deps = self._batch_inference_validate_snowpark(
655
- dataset=dataset,
656
- inference_method=inference_method,
657
- )
639
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
640
+ self._deps = self._get_dependencies()
658
641
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
659
642
 
660
643
  transform_kwargs = dict(
661
- session = dataset._session,
662
- dependencies = self._deps,
663
- drop_input_cols = self._drop_input_cols,
664
- expected_output_cols_type = expected_dtype,
644
+ session=dataset._session,
645
+ dependencies=self._deps,
646
+ drop_input_cols=self._drop_input_cols,
647
+ expected_output_cols_type=expected_dtype,
665
648
  )
666
649
 
667
650
  elif isinstance(dataset, pd.DataFrame):
668
- transform_kwargs = dict(
669
- snowpark_input_cols = self._snowpark_cols,
670
- drop_input_cols = self._drop_input_cols
671
- )
651
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
672
652
 
673
653
  transform_handlers = ModelTransformerBuilder.build(
674
654
  dataset=dataset,
@@ -687,7 +667,11 @@ class XGBRegressor(BaseTransformer):
687
667
  return output_df
688
668
 
689
669
  @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
690
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
670
+ def fit_predict(
671
+ self,
672
+ dataset: Union[DataFrame, pd.DataFrame],
673
+ output_cols_prefix: str = "fit_predict_",
674
+ ) -> Union[DataFrame, pd.DataFrame]:
691
675
  """ Method not supported for this class.
692
676
 
693
677
 
@@ -712,22 +696,104 @@ class XGBRegressor(BaseTransformer):
712
696
  )
713
697
  output_result, fitted_estimator = model_trainer.train_fit_predict(
714
698
  drop_input_cols=self._drop_input_cols,
715
- expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
699
+ expected_output_cols_list=(
700
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
701
+ ),
716
702
  )
717
703
  self._sklearn_object = fitted_estimator
718
704
  self._is_fitted = True
719
705
  return output_result
720
706
 
707
+
708
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
709
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
710
+ """ Method not supported for this class.
711
+
721
712
 
722
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
723
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
724
- """
713
+ Raises:
714
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
715
+
716
+ Args:
717
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
718
+ Snowpark or Pandas DataFrame.
719
+ output_cols_prefix: Prefix for the response columns
725
720
  Returns:
726
721
  Transformed dataset.
727
722
  """
728
- self.fit(dataset)
729
- assert self._sklearn_object is not None
730
- return self._sklearn_object.embedding_
723
+ self._infer_input_output_cols(dataset)
724
+ super()._check_dataset_type(dataset)
725
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
726
+ estimator=self._sklearn_object,
727
+ dataset=dataset,
728
+ input_cols=self.input_cols,
729
+ label_cols=self.label_cols,
730
+ sample_weight_col=self.sample_weight_col,
731
+ autogenerated=self._autogenerated,
732
+ subproject=_SUBPROJECT,
733
+ )
734
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
735
+ drop_input_cols=self._drop_input_cols,
736
+ expected_output_cols_list=self.output_cols,
737
+ )
738
+ self._sklearn_object = fitted_estimator
739
+ self._is_fitted = True
740
+ return output_result
741
+
742
+
743
+ def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
744
+ """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
745
+ Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
746
+ """
747
+ output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
748
+ # The following condition is introduced for kneighbors methods, and not used in other methods
749
+ if output_cols:
750
+ output_cols = [
751
+ identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
752
+ for c in output_cols
753
+ ]
754
+ elif getattr(self._sklearn_object, "classes_", None) is None:
755
+ output_cols = [output_cols_prefix]
756
+ elif self._sklearn_object is not None:
757
+ classes = self._sklearn_object.classes_
758
+ if isinstance(classes, numpy.ndarray):
759
+ output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
760
+ elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
761
+ # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
762
+ output_cols = []
763
+ for i, cl in enumerate(classes):
764
+ # For binary classification, there is only one output column for each class
765
+ # ndarray as the two classes are complementary.
766
+ if len(cl) == 2:
767
+ output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
768
+ else:
769
+ output_cols.extend([
770
+ f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
771
+ ])
772
+ else:
773
+ output_cols = []
774
+
775
+ # Make sure column names are valid snowflake identifiers.
776
+ assert output_cols is not None # Make MyPy happy
777
+ rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
778
+
779
+ return rv
780
+
781
+ def _align_expected_output_names(
782
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
783
+ ) -> List[str]:
784
+ # in case the inferred output column names dimension is different
785
+ # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
786
+ output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
787
+ output_df_columns = list(output_df_pd.columns)
788
+ output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
789
+ if self.sample_weight_col:
790
+ output_df_columns_set -= set(self.sample_weight_col)
791
+ # if the dimension of inferred output column names is correct; use it
792
+ if len(expected_output_cols_list) == len(output_df_columns_set):
793
+ return expected_output_cols_list
794
+ # otherwise, use the sklearn estimator's output
795
+ else:
796
+ return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
731
797
 
732
798
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
733
799
  @telemetry.send_api_usage_telemetry(
@@ -759,24 +825,26 @@ class XGBRegressor(BaseTransformer):
759
825
  # are specific to the type of dataset used.
760
826
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
761
827
 
828
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
829
+
762
830
  if isinstance(dataset, DataFrame):
763
- self._deps = self._batch_inference_validate_snowpark(
764
- dataset=dataset,
765
- inference_method=inference_method,
766
- )
767
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
831
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
832
+ self._deps = self._get_dependencies()
833
+ assert isinstance(
834
+ dataset._session, Session
835
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
768
836
  transform_kwargs = dict(
769
837
  session=dataset._session,
770
838
  dependencies=self._deps,
771
- drop_input_cols = self._drop_input_cols,
839
+ drop_input_cols=self._drop_input_cols,
772
840
  expected_output_cols_type="float",
773
841
  )
842
+ expected_output_cols = self._align_expected_output_names(
843
+ inference_method, dataset, expected_output_cols, output_cols_prefix
844
+ )
774
845
 
775
846
  elif isinstance(dataset, pd.DataFrame):
776
- transform_kwargs = dict(
777
- snowpark_input_cols = self._snowpark_cols,
778
- drop_input_cols = self._drop_input_cols
779
- )
847
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
780
848
 
781
849
  transform_handlers = ModelTransformerBuilder.build(
782
850
  dataset=dataset,
@@ -788,7 +856,7 @@ class XGBRegressor(BaseTransformer):
788
856
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
789
857
  inference_method=inference_method,
790
858
  input_cols=self.input_cols,
791
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
859
+ expected_output_cols=expected_output_cols,
792
860
  **transform_kwargs
793
861
  )
794
862
  return output_df
@@ -818,29 +886,30 @@ class XGBRegressor(BaseTransformer):
818
886
  Output dataset with log probability of the sample for each class in the model.
819
887
  """
820
888
  super()._check_dataset_type(dataset)
821
- inference_method="predict_log_proba"
889
+ inference_method = "predict_log_proba"
890
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
822
891
 
823
892
  # This dictionary contains optional kwargs for batch inference. These kwargs
824
893
  # are specific to the type of dataset used.
825
894
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
826
895
 
827
896
  if isinstance(dataset, DataFrame):
828
- self._deps = self._batch_inference_validate_snowpark(
829
- dataset=dataset,
830
- inference_method=inference_method,
831
- )
832
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
897
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
898
+ self._deps = self._get_dependencies()
899
+ assert isinstance(
900
+ dataset._session, Session
901
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
833
902
  transform_kwargs = dict(
834
903
  session=dataset._session,
835
904
  dependencies=self._deps,
836
- drop_input_cols = self._drop_input_cols,
905
+ drop_input_cols=self._drop_input_cols,
837
906
  expected_output_cols_type="float",
838
907
  )
908
+ expected_output_cols = self._align_expected_output_names(
909
+ inference_method, dataset, expected_output_cols, output_cols_prefix
910
+ )
839
911
  elif isinstance(dataset, pd.DataFrame):
840
- transform_kwargs = dict(
841
- snowpark_input_cols = self._snowpark_cols,
842
- drop_input_cols = self._drop_input_cols
843
- )
912
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
844
913
 
845
914
  transform_handlers = ModelTransformerBuilder.build(
846
915
  dataset=dataset,
@@ -853,7 +922,7 @@ class XGBRegressor(BaseTransformer):
853
922
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
854
923
  inference_method=inference_method,
855
924
  input_cols=self.input_cols,
856
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
925
+ expected_output_cols=expected_output_cols,
857
926
  **transform_kwargs
858
927
  )
859
928
  return output_df
@@ -879,30 +948,32 @@ class XGBRegressor(BaseTransformer):
879
948
  Output dataset with results of the decision function for the samples in input dataset.
880
949
  """
881
950
  super()._check_dataset_type(dataset)
882
- inference_method="decision_function"
951
+ inference_method = "decision_function"
883
952
 
884
953
  # This dictionary contains optional kwargs for batch inference. These kwargs
885
954
  # are specific to the type of dataset used.
886
955
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
887
956
 
957
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
958
+
888
959
  if isinstance(dataset, DataFrame):
889
- self._deps = self._batch_inference_validate_snowpark(
890
- dataset=dataset,
891
- inference_method=inference_method,
892
- )
893
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
960
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
961
+ self._deps = self._get_dependencies()
962
+ assert isinstance(
963
+ dataset._session, Session
964
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
894
965
  transform_kwargs = dict(
895
966
  session=dataset._session,
896
967
  dependencies=self._deps,
897
- drop_input_cols = self._drop_input_cols,
968
+ drop_input_cols=self._drop_input_cols,
898
969
  expected_output_cols_type="float",
899
970
  )
971
+ expected_output_cols = self._align_expected_output_names(
972
+ inference_method, dataset, expected_output_cols, output_cols_prefix
973
+ )
900
974
 
901
975
  elif isinstance(dataset, pd.DataFrame):
902
- transform_kwargs = dict(
903
- snowpark_input_cols = self._snowpark_cols,
904
- drop_input_cols = self._drop_input_cols
905
- )
976
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
906
977
 
907
978
  transform_handlers = ModelTransformerBuilder.build(
908
979
  dataset=dataset,
@@ -915,7 +986,7 @@ class XGBRegressor(BaseTransformer):
915
986
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
916
987
  inference_method=inference_method,
917
988
  input_cols=self.input_cols,
918
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
989
+ expected_output_cols=expected_output_cols,
919
990
  **transform_kwargs
920
991
  )
921
992
  return output_df
@@ -944,17 +1015,17 @@ class XGBRegressor(BaseTransformer):
944
1015
  Output dataset with probability of the sample for each class in the model.
945
1016
  """
946
1017
  super()._check_dataset_type(dataset)
947
- inference_method="score_samples"
1018
+ inference_method = "score_samples"
948
1019
 
949
1020
  # This dictionary contains optional kwargs for batch inference. These kwargs
950
1021
  # are specific to the type of dataset used.
951
1022
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
952
1023
 
1024
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
1025
+
953
1026
  if isinstance(dataset, DataFrame):
954
- self._deps = self._batch_inference_validate_snowpark(
955
- dataset=dataset,
956
- inference_method=inference_method,
957
- )
1027
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1028
+ self._deps = self._get_dependencies()
958
1029
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
959
1030
  transform_kwargs = dict(
960
1031
  session=dataset._session,
@@ -962,6 +1033,9 @@ class XGBRegressor(BaseTransformer):
962
1033
  drop_input_cols = self._drop_input_cols,
963
1034
  expected_output_cols_type="float",
964
1035
  )
1036
+ expected_output_cols = self._align_expected_output_names(
1037
+ inference_method, dataset, expected_output_cols, output_cols_prefix
1038
+ )
965
1039
 
966
1040
  elif isinstance(dataset, pd.DataFrame):
967
1041
  transform_kwargs = dict(
@@ -980,7 +1054,7 @@ class XGBRegressor(BaseTransformer):
980
1054
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
981
1055
  inference_method=inference_method,
982
1056
  input_cols=self.input_cols,
983
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
1057
+ expected_output_cols=expected_output_cols,
984
1058
  **transform_kwargs
985
1059
  )
986
1060
  return output_df
@@ -1015,17 +1089,15 @@ class XGBRegressor(BaseTransformer):
1015
1089
  transform_kwargs: ScoreKwargsTypedDict = dict()
1016
1090
 
1017
1091
  if isinstance(dataset, DataFrame):
1018
- self._deps = self._batch_inference_validate_snowpark(
1019
- dataset=dataset,
1020
- inference_method="score",
1021
- )
1092
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1093
+ self._deps = self._get_dependencies()
1022
1094
  selected_cols = self._get_active_columns()
1023
1095
  if len(selected_cols) > 0:
1024
1096
  dataset = dataset.select(selected_cols)
1025
1097
  assert isinstance(dataset._session, Session) # keep mypy happy
1026
1098
  transform_kwargs = dict(
1027
1099
  session=dataset._session,
1028
- dependencies=["snowflake-snowpark-python"] + self._deps,
1100
+ dependencies=self._deps,
1029
1101
  score_sproc_imports=['xgboost'],
1030
1102
  )
1031
1103
  elif isinstance(dataset, pd.DataFrame):
@@ -1090,11 +1162,8 @@ class XGBRegressor(BaseTransformer):
1090
1162
 
1091
1163
  if isinstance(dataset, DataFrame):
1092
1164
 
1093
- self._deps = self._batch_inference_validate_snowpark(
1094
- dataset=dataset,
1095
- inference_method=inference_method,
1096
-
1097
- )
1165
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1166
+ self._deps = self._get_dependencies()
1098
1167
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1099
1168
  transform_kwargs = dict(
1100
1169
  session = dataset._session,
@@ -1127,50 +1196,84 @@ class XGBRegressor(BaseTransformer):
1127
1196
  )
1128
1197
  return output_df
1129
1198
 
1199
+
1200
+
1201
+ def to_xgboost(self) -> Any:
1202
+ """Get xgboost.XGBRegressor object.
1203
+ """
1204
+ if self._sklearn_object is None:
1205
+ self._sklearn_object = self._create_sklearn_object()
1206
+ return self._sklearn_object
1207
+
1208
+ def to_sklearn(self) -> Any:
1209
+ raise exceptions.SnowflakeMLException(
1210
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1211
+ original_exception=AttributeError(
1212
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1213
+ "to_sklearn()",
1214
+ "to_xgboost()"
1215
+ )
1216
+ ),
1217
+ )
1218
+
1219
+ def to_lightgbm(self) -> Any:
1220
+ raise exceptions.SnowflakeMLException(
1221
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1222
+ original_exception=AttributeError(
1223
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1224
+ "to_lightgbm()",
1225
+ "to_xgboost()"
1226
+ )
1227
+ ),
1228
+ )
1229
+
1230
+ def _get_dependencies(self) -> List[str]:
1231
+ return self._deps
1232
+
1130
1233
 
1131
- def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
1234
+ def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
1132
1235
  self._model_signature_dict = dict()
1133
1236
 
1134
1237
  PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
1135
1238
 
1136
- inputs = list(_infer_signature(dataset[self.input_cols], "input"))
1239
+ inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
1137
1240
  outputs: List[BaseFeatureSpec] = []
1138
1241
  if hasattr(self, "predict"):
1139
1242
  # keep mypy happy
1140
- assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
1243
+ assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
1141
1244
  # For classifier, the type of predict is the same as the type of label
1142
- if self._sklearn_object._estimator_type == 'classifier':
1143
- # label columns is the desired type for output
1245
+ if self._sklearn_object._estimator_type == "classifier":
1246
+ # label columns is the desired type for output
1144
1247
  outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
1145
1248
  # rename the output columns
1146
1249
  outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
1147
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1148
- ([] if self._drop_input_cols else inputs)
1149
- + outputs)
1250
+ self._model_signature_dict["predict"] = ModelSignature(
1251
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1252
+ )
1150
1253
  # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
1151
1254
  # For outlier models, returns -1 for outliers and 1 for inliers.
1152
- # Clusterer returns int64 cluster labels.
1255
+ # Clusterer returns int64 cluster labels.
1153
1256
  elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
1154
1257
  outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
1155
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1156
- ([] if self._drop_input_cols else inputs)
1157
- + outputs)
1158
-
1258
+ self._model_signature_dict["predict"] = ModelSignature(
1259
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1260
+ )
1261
+
1159
1262
  # For regressor, the type of predict is float64
1160
- elif self._sklearn_object._estimator_type == 'regressor':
1263
+ elif self._sklearn_object._estimator_type == "regressor":
1161
1264
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
1162
- self._model_signature_dict["predict"] = ModelSignature(inputs,
1163
- ([] if self._drop_input_cols else inputs)
1164
- + outputs)
1165
-
1265
+ self._model_signature_dict["predict"] = ModelSignature(
1266
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1267
+ )
1268
+
1166
1269
  for prob_func in PROB_FUNCTIONS:
1167
1270
  if hasattr(self, prob_func):
1168
1271
  output_cols_prefix: str = f"{prob_func}_"
1169
1272
  output_column_names = self._get_output_column_names(output_cols_prefix)
1170
1273
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
1171
- self._model_signature_dict[prob_func] = ModelSignature(inputs,
1172
- ([] if self._drop_input_cols else inputs)
1173
- + outputs)
1274
+ self._model_signature_dict[prob_func] = ModelSignature(
1275
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1276
+ )
1174
1277
 
1175
1278
  # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
1176
1279
  items = list(self._model_signature_dict.items())
@@ -1183,10 +1286,10 @@ class XGBRegressor(BaseTransformer):
1183
1286
  """Returns model signature of current class.
1184
1287
 
1185
1288
  Raises:
1186
- exceptions.SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1289
+ SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1187
1290
 
1188
1291
  Returns:
1189
- Dict[str, ModelSignature]: each method and its input output signature
1292
+ Dict with each method and its input output signature
1190
1293
  """
1191
1294
  if self._model_signature_dict is None:
1192
1295
  raise exceptions.SnowflakeMLException(
@@ -1194,35 +1297,3 @@ class XGBRegressor(BaseTransformer):
1194
1297
  original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
1195
1298
  )
1196
1299
  return self._model_signature_dict
1197
-
1198
- def to_xgboost(self) -> Any:
1199
- """Get xgboost.XGBRegressor object.
1200
- """
1201
- if self._sklearn_object is None:
1202
- self._sklearn_object = self._create_sklearn_object()
1203
- return self._sklearn_object
1204
-
1205
- def to_sklearn(self) -> Any:
1206
- raise exceptions.SnowflakeMLException(
1207
- error_code=error_codes.METHOD_NOT_ALLOWED,
1208
- original_exception=AttributeError(
1209
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1210
- "to_sklearn()",
1211
- "to_xgboost()"
1212
- )
1213
- ),
1214
- )
1215
-
1216
- def to_lightgbm(self) -> Any:
1217
- raise exceptions.SnowflakeMLException(
1218
- error_code=error_codes.METHOD_NOT_ALLOWED,
1219
- original_exception=AttributeError(
1220
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1221
- "to_lightgbm()",
1222
- "to_xgboost()"
1223
- )
1224
- ),
1225
- )
1226
-
1227
- def _get_dependencies(self) -> List[str]:
1228
- return self._deps