snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (234)
  1. snowflake/ml/_internal/env_utils.py +77 -32
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/_internal/utils/identifier.py +3 -1
  8. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  9. snowflake/ml/dataset/__init__.py +10 -0
  10. snowflake/ml/dataset/dataset.py +454 -129
  11. snowflake/ml/dataset/dataset_factory.py +53 -0
  12. snowflake/ml/dataset/dataset_metadata.py +103 -0
  13. snowflake/ml/dataset/dataset_reader.py +202 -0
  14. snowflake/ml/feature_store/feature_store.py +531 -332
  15. snowflake/ml/feature_store/feature_view.py +40 -23
  16. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  17. snowflake/ml/fileset/sfcfs.py +56 -54
  18. snowflake/ml/fileset/snowfs.py +159 -0
  19. snowflake/ml/fileset/stage_fs.py +49 -17
  20. snowflake/ml/model/__init__.py +2 -2
  21. snowflake/ml/model/_api.py +16 -1
  22. snowflake/ml/model/_client/model/model_impl.py +27 -0
  23. snowflake/ml/model/_client/model/model_version_impl.py +137 -50
  24. snowflake/ml/model/_client/ops/model_ops.py +159 -40
  25. snowflake/ml/model/_client/sql/model.py +25 -2
  26. snowflake/ml/model/_client/sql/model_version.py +131 -2
  27. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  28. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  29. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  30. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  31. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  32. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
  34. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  36. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  37. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  39. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  40. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
  42. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  43. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  44. snowflake/ml/model/_packager/model_packager.py +2 -5
  45. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  46. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  47. snowflake/ml/model/type_hints.py +21 -2
  48. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  49. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  50. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  51. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  52. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  53. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  54. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
  55. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
  56. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  57. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
  58. snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
  59. snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
  60. snowflake/ml/modeling/cluster/birch.py +248 -175
  61. snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
  62. snowflake/ml/modeling/cluster/dbscan.py +246 -175
  63. snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
  64. snowflake/ml/modeling/cluster/k_means.py +248 -175
  65. snowflake/ml/modeling/cluster/mean_shift.py +246 -175
  66. snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
  67. snowflake/ml/modeling/cluster/optics.py +246 -175
  68. snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
  69. snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
  70. snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
  71. snowflake/ml/modeling/compose/column_transformer.py +248 -175
  72. snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
  73. snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
  74. snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
  75. snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
  76. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
  77. snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
  78. snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
  79. snowflake/ml/modeling/covariance/oas.py +246 -175
  80. snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
  81. snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
  82. snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
  83. snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
  84. snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
  85. snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
  86. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
  87. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
  88. snowflake/ml/modeling/decomposition/pca.py +248 -175
  89. snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
  90. snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
  91. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
  92. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
  93. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
  94. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
  95. snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
  96. snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
  97. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
  98. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
  99. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
  100. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
  101. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
  102. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
  103. snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
  104. snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
  105. snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
  106. snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
  107. snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
  108. snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
  109. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
  110. snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
  111. snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
  112. snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
  113. snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
  114. snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
  115. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
  116. snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
  117. snowflake/ml/modeling/framework/_utils.py +8 -1
  118. snowflake/ml/modeling/framework/base.py +72 -37
  119. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
  120. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
  121. snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
  122. snowflake/ml/modeling/impute/knn_imputer.py +248 -175
  123. snowflake/ml/modeling/impute/missing_indicator.py +248 -175
  124. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
  125. snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
  126. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
  127. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
  128. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
  129. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
  130. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
  131. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
  132. snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
  133. snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
  134. snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
  135. snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
  136. snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
  137. snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
  138. snowflake/ml/modeling/linear_model/lars.py +246 -175
  139. snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
  140. snowflake/ml/modeling/linear_model/lasso.py +246 -175
  141. snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
  142. snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
  143. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
  144. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
  145. snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
  146. snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
  147. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
  148. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
  149. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
  150. snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
  151. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
  152. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
  153. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
  154. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
  155. snowflake/ml/modeling/linear_model/perceptron.py +246 -175
  156. snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
  157. snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
  158. snowflake/ml/modeling/linear_model/ridge.py +246 -175
  159. snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
  160. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
  161. snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
  162. snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
  163. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
  164. snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
  165. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
  166. snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
  167. snowflake/ml/modeling/manifold/isomap.py +248 -175
  168. snowflake/ml/modeling/manifold/mds.py +248 -175
  169. snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
  170. snowflake/ml/modeling/manifold/tsne.py +248 -175
  171. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
  172. snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
  173. snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
  174. snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
  175. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
  176. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
  177. snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
  178. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
  179. snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
  180. snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
  181. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
  182. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
  183. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
  184. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
  185. snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
  186. snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
  187. snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
  188. snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
  189. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
  190. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
  191. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
  192. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
  193. snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
  194. snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
  195. snowflake/ml/modeling/pipeline/pipeline.py +517 -35
  196. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  197. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  198. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  199. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  200. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  201. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  202. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
  203. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  204. snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
  205. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  206. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  207. snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
  208. snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
  209. snowflake/ml/modeling/svm/linear_svc.py +246 -175
  210. snowflake/ml/modeling/svm/linear_svr.py +246 -175
  211. snowflake/ml/modeling/svm/nu_svc.py +246 -175
  212. snowflake/ml/modeling/svm/nu_svr.py +246 -175
  213. snowflake/ml/modeling/svm/svc.py +246 -175
  214. snowflake/ml/modeling/svm/svr.py +246 -175
  215. snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
  216. snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
  217. snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
  218. snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
  219. snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
  220. snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
  221. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
  222. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
  223. snowflake/ml/registry/model_registry.py +3 -149
  224. snowflake/ml/registry/registry.py +1 -1
  225. snowflake/ml/version.py +1 -1
  226. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
  227. snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
  228. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  229. snowflake/ml/registry/_artifact_manager.py +0 -156
  230. snowflake/ml/registry/artifact.py +0 -46
  231. snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
  232. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  233. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  234. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
33
33
  BatchInferenceKwargsTypedDict,
34
34
  ScoreKwargsTypedDict
35
35
  )
36
+ from snowflake.ml.model._signatures import utils as model_signature_utils
37
+ from snowflake.ml.model.model_signature import (
38
+ BaseFeatureSpec,
39
+ DataType,
40
+ FeatureSpec,
41
+ ModelSignature,
42
+ _infer_signature,
43
+ _rename_signature_with_snowflake_identifiers,
44
+ )
36
45
 
37
46
  from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
38
47
 
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
43
52
  validate_sklearn_args,
44
53
  )
45
54
 
46
- from snowflake.ml.model.model_signature import (
47
- DataType,
48
- FeatureSpec,
49
- ModelSignature,
50
- _infer_signature,
51
- _rename_signature_with_snowflake_identifiers,
52
- BaseFeatureSpec,
53
- )
54
- from snowflake.ml.model._signatures import utils as model_signature_utils
55
-
56
55
  _PROJECT = "ModelDevelopment"
57
56
  # Derive subproject from module name by removing "sklearn"
58
57
  # and converting module name from underscore to CamelCase
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
61
60
 
62
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
63
62
 
64
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
65
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
66
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
67
- return check
68
-
69
-
70
63
  class OrthogonalMatchingPursuit(BaseTransformer):
71
64
  r"""Orthogonal Matching Pursuit model (OMP)
72
65
  For more details on this class, see [sklearn.linear_model.OrthogonalMatchingPursuit]
@@ -227,12 +220,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
227
220
  )
228
221
  return selected_cols
229
222
 
230
- @telemetry.send_api_usage_telemetry(
231
- project=_PROJECT,
232
- subproject=_SUBPROJECT,
233
- custom_tags=dict([("autogen", True)]),
234
- )
235
- def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "OrthogonalMatchingPursuit":
223
+ def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "OrthogonalMatchingPursuit":
236
224
  """Fit the model using X, y as training data
237
225
  For more details on this function, see [sklearn.linear_model.OrthogonalMatchingPursuit.fit]
238
226
  (https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.OrthogonalMatchingPursuit.html#sklearn.linear_model.OrthogonalMatchingPursuit.fit)
@@ -259,12 +247,14 @@ class OrthogonalMatchingPursuit(BaseTransformer):
259
247
 
260
248
  self._snowpark_cols = dataset.select(self.input_cols).columns
261
249
 
262
- # If we are already in a stored procedure, no need to kick off another one.
250
+ # If we are already in a stored procedure, no need to kick off another one.
263
251
  if SNOWML_SPROC_ENV in os.environ:
264
252
  statement_params = telemetry.get_function_usage_statement_params(
265
253
  project=_PROJECT,
266
254
  subproject=_SUBPROJECT,
267
- function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), OrthogonalMatchingPursuit.__class__.__name__),
255
+ function_name=telemetry.get_statement_params_full_func_name(
256
+ inspect.currentframe(), OrthogonalMatchingPursuit.__class__.__name__
257
+ ),
268
258
  api_calls=[Session.call],
269
259
  custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
270
260
  )
@@ -285,27 +275,24 @@ class OrthogonalMatchingPursuit(BaseTransformer):
285
275
  )
286
276
  self._sklearn_object = model_trainer.train()
287
277
  self._is_fitted = True
288
- self._get_model_signatures(dataset)
278
+ self._generate_model_signatures(dataset)
289
279
  return self
290
280
 
291
281
  def _batch_inference_validate_snowpark(
292
282
  self,
293
283
  dataset: DataFrame,
294
284
  inference_method: str,
295
- ) -> List[str]:
296
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
297
- return the available package that exists in the snowflake anaconda channel
285
+ ) -> None:
286
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
298
287
 
299
288
  Args:
300
289
  dataset: snowpark dataframe
301
290
  inference_method: the inference method such as predict, score...
302
-
291
+
303
292
  Raises:
304
293
  SnowflakeMLException: If the estimator is not fitted, raise error
305
294
  SnowflakeMLException: If the session is None, raise error
306
295
 
307
- Returns:
308
- A list of available package that exists in the snowflake anaconda channel
309
296
  """
310
297
  if not self._is_fitted:
311
298
  raise exceptions.SnowflakeMLException(
@@ -323,9 +310,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
323
310
  "Session must not specified for snowpark dataset."
324
311
  ),
325
312
  )
326
- # Validate that key package version in user workspace are supported in snowflake conda channel
327
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
328
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
313
+
329
314
 
330
315
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
331
316
  @telemetry.send_api_usage_telemetry(
@@ -361,7 +346,9 @@ class OrthogonalMatchingPursuit(BaseTransformer):
361
346
  # when it is classifier, infer the datatype from label columns
362
347
  if expected_type_inferred == "" and 'predict' in self.model_signatures:
363
348
  # Batch inference takes a single expected output column type. Use the first columns type for now.
364
- label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
349
+ label_cols_signatures = [
350
+ row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
351
+ ]
365
352
  if len(label_cols_signatures) == 0:
366
353
  error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
367
354
  raise exceptions.SnowflakeMLException(
@@ -369,25 +356,23 @@ class OrthogonalMatchingPursuit(BaseTransformer):
369
356
  original_exception=ValueError(error_str),
370
357
  )
371
358
 
372
- expected_type_inferred = convert_sp_to_sf_type(
373
- label_cols_signatures[0].as_snowpark_type()
374
- )
359
+ expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
375
360
 
376
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
377
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
361
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
362
+ self._deps = self._get_dependencies()
363
+ assert isinstance(
364
+ dataset._session, Session
365
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
378
366
 
379
367
  transform_kwargs = dict(
380
- session = dataset._session,
381
- dependencies = self._deps,
382
- drop_input_cols = self._drop_input_cols,
383
- expected_output_cols_type = expected_type_inferred,
368
+ session=dataset._session,
369
+ dependencies=self._deps,
370
+ drop_input_cols=self._drop_input_cols,
371
+ expected_output_cols_type=expected_type_inferred,
384
372
  )
385
373
 
386
374
  elif isinstance(dataset, pd.DataFrame):
387
- transform_kwargs = dict(
388
- snowpark_input_cols = self._snowpark_cols,
389
- drop_input_cols = self._drop_input_cols
390
- )
375
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
391
376
 
392
377
  transform_handlers = ModelTransformerBuilder.build(
393
378
  dataset=dataset,
@@ -427,7 +412,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
427
412
  Transformed dataset.
428
413
  """
429
414
  super()._check_dataset_type(dataset)
430
- inference_method="transform"
415
+ inference_method = "transform"
431
416
 
432
417
  # This dictionary contains optional kwargs for batch inference. These kwargs
433
418
  # are specific to the type of dataset used.
@@ -457,24 +442,19 @@ class OrthogonalMatchingPursuit(BaseTransformer):
457
442
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
458
443
  expected_dtype = convert_sp_to_sf_type(output_types[0])
459
444
 
460
- self._deps = self._batch_inference_validate_snowpark(
461
- dataset=dataset,
462
- inference_method=inference_method,
463
- )
445
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
446
+ self._deps = self._get_dependencies()
464
447
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
465
448
 
466
449
  transform_kwargs = dict(
467
- session = dataset._session,
468
- dependencies = self._deps,
469
- drop_input_cols = self._drop_input_cols,
470
- expected_output_cols_type = expected_dtype,
450
+ session=dataset._session,
451
+ dependencies=self._deps,
452
+ drop_input_cols=self._drop_input_cols,
453
+ expected_output_cols_type=expected_dtype,
471
454
  )
472
455
 
473
456
  elif isinstance(dataset, pd.DataFrame):
474
- transform_kwargs = dict(
475
- snowpark_input_cols = self._snowpark_cols,
476
- drop_input_cols = self._drop_input_cols
477
- )
457
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
478
458
 
479
459
  transform_handlers = ModelTransformerBuilder.build(
480
460
  dataset=dataset,
@@ -493,7 +473,11 @@ class OrthogonalMatchingPursuit(BaseTransformer):
493
473
  return output_df
494
474
 
495
475
  @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
496
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
476
+ def fit_predict(
477
+ self,
478
+ dataset: Union[DataFrame, pd.DataFrame],
479
+ output_cols_prefix: str = "fit_predict_",
480
+ ) -> Union[DataFrame, pd.DataFrame]:
497
481
  """ Method not supported for this class.
498
482
 
499
483
 
@@ -518,22 +502,104 @@ class OrthogonalMatchingPursuit(BaseTransformer):
518
502
  )
519
503
  output_result, fitted_estimator = model_trainer.train_fit_predict(
520
504
  drop_input_cols=self._drop_input_cols,
521
- expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
505
+ expected_output_cols_list=(
506
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
507
+ ),
522
508
  )
523
509
  self._sklearn_object = fitted_estimator
524
510
  self._is_fitted = True
525
511
  return output_result
526
512
 
513
+
514
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
515
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
516
+ """ Method not supported for this class.
517
+
527
518
 
528
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
529
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
530
- """
519
+ Raises:
520
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
521
+
522
+ Args:
523
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
524
+ Snowpark or Pandas DataFrame.
525
+ output_cols_prefix: Prefix for the response columns
531
526
  Returns:
532
527
  Transformed dataset.
533
528
  """
534
- self.fit(dataset)
535
- assert self._sklearn_object is not None
536
- return self._sklearn_object.embedding_
529
+ self._infer_input_output_cols(dataset)
530
+ super()._check_dataset_type(dataset)
531
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
532
+ estimator=self._sklearn_object,
533
+ dataset=dataset,
534
+ input_cols=self.input_cols,
535
+ label_cols=self.label_cols,
536
+ sample_weight_col=self.sample_weight_col,
537
+ autogenerated=self._autogenerated,
538
+ subproject=_SUBPROJECT,
539
+ )
540
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
541
+ drop_input_cols=self._drop_input_cols,
542
+ expected_output_cols_list=self.output_cols,
543
+ )
544
+ self._sklearn_object = fitted_estimator
545
+ self._is_fitted = True
546
+ return output_result
547
+
548
+
549
+ def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
550
+ """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
551
+ Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
552
+ """
553
+ output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
554
+ # The following condition is introduced for kneighbors methods, and not used in other methods
555
+ if output_cols:
556
+ output_cols = [
557
+ identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
558
+ for c in output_cols
559
+ ]
560
+ elif getattr(self._sklearn_object, "classes_", None) is None:
561
+ output_cols = [output_cols_prefix]
562
+ elif self._sklearn_object is not None:
563
+ classes = self._sklearn_object.classes_
564
+ if isinstance(classes, numpy.ndarray):
565
+ output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
566
+ elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
567
+ # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
568
+ output_cols = []
569
+ for i, cl in enumerate(classes):
570
+ # For binary classification, there is only one output column for each class
571
+ # ndarray as the two classes are complementary.
572
+ if len(cl) == 2:
573
+ output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
574
+ else:
575
+ output_cols.extend([
576
+ f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
577
+ ])
578
+ else:
579
+ output_cols = []
580
+
581
+ # Make sure column names are valid snowflake identifiers.
582
+ assert output_cols is not None # Make MyPy happy
583
+ rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
584
+
585
+ return rv
586
+
587
+ def _align_expected_output_names(
588
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
589
+ ) -> List[str]:
590
+ # in case the inferred output column names dimension is different
591
+ # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
592
+ output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
593
+ output_df_columns = list(output_df_pd.columns)
594
+ output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
595
+ if self.sample_weight_col:
596
+ output_df_columns_set -= set(self.sample_weight_col)
597
+ # if the dimension of inferred output column names is correct; use it
598
+ if len(expected_output_cols_list) == len(output_df_columns_set):
599
+ return expected_output_cols_list
600
+ # otherwise, use the sklearn estimator's output
601
+ else:
602
+ return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
537
603
 
538
604
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
539
605
  @telemetry.send_api_usage_telemetry(
@@ -565,24 +631,26 @@ class OrthogonalMatchingPursuit(BaseTransformer):
565
631
  # are specific to the type of dataset used.
566
632
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
567
633
 
634
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
635
+
568
636
  if isinstance(dataset, DataFrame):
569
- self._deps = self._batch_inference_validate_snowpark(
570
- dataset=dataset,
571
- inference_method=inference_method,
572
- )
573
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
637
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
638
+ self._deps = self._get_dependencies()
639
+ assert isinstance(
640
+ dataset._session, Session
641
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
574
642
  transform_kwargs = dict(
575
643
  session=dataset._session,
576
644
  dependencies=self._deps,
577
- drop_input_cols = self._drop_input_cols,
645
+ drop_input_cols=self._drop_input_cols,
578
646
  expected_output_cols_type="float",
579
647
  )
648
+ expected_output_cols = self._align_expected_output_names(
649
+ inference_method, dataset, expected_output_cols, output_cols_prefix
650
+ )
580
651
 
581
652
  elif isinstance(dataset, pd.DataFrame):
582
- transform_kwargs = dict(
583
- snowpark_input_cols = self._snowpark_cols,
584
- drop_input_cols = self._drop_input_cols
585
- )
653
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
586
654
 
587
655
  transform_handlers = ModelTransformerBuilder.build(
588
656
  dataset=dataset,
@@ -594,7 +662,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
594
662
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
595
663
  inference_method=inference_method,
596
664
  input_cols=self.input_cols,
597
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
665
+ expected_output_cols=expected_output_cols,
598
666
  **transform_kwargs
599
667
  )
600
668
  return output_df
@@ -624,29 +692,30 @@ class OrthogonalMatchingPursuit(BaseTransformer):
624
692
  Output dataset with log probability of the sample for each class in the model.
625
693
  """
626
694
  super()._check_dataset_type(dataset)
627
- inference_method="predict_log_proba"
695
+ inference_method = "predict_log_proba"
696
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
628
697
 
629
698
  # This dictionary contains optional kwargs for batch inference. These kwargs
630
699
  # are specific to the type of dataset used.
631
700
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
632
701
 
633
702
  if isinstance(dataset, DataFrame):
634
- self._deps = self._batch_inference_validate_snowpark(
635
- dataset=dataset,
636
- inference_method=inference_method,
637
- )
638
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
703
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
704
+ self._deps = self._get_dependencies()
705
+ assert isinstance(
706
+ dataset._session, Session
707
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
639
708
  transform_kwargs = dict(
640
709
  session=dataset._session,
641
710
  dependencies=self._deps,
642
- drop_input_cols = self._drop_input_cols,
711
+ drop_input_cols=self._drop_input_cols,
643
712
  expected_output_cols_type="float",
644
713
  )
714
+ expected_output_cols = self._align_expected_output_names(
715
+ inference_method, dataset, expected_output_cols, output_cols_prefix
716
+ )
645
717
  elif isinstance(dataset, pd.DataFrame):
646
- transform_kwargs = dict(
647
- snowpark_input_cols = self._snowpark_cols,
648
- drop_input_cols = self._drop_input_cols
649
- )
718
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
650
719
 
651
720
  transform_handlers = ModelTransformerBuilder.build(
652
721
  dataset=dataset,
@@ -659,7 +728,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
659
728
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
660
729
  inference_method=inference_method,
661
730
  input_cols=self.input_cols,
662
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
731
+ expected_output_cols=expected_output_cols,
663
732
  **transform_kwargs
664
733
  )
665
734
  return output_df
@@ -685,30 +754,32 @@ class OrthogonalMatchingPursuit(BaseTransformer):
685
754
  Output dataset with results of the decision function for the samples in input dataset.
686
755
  """
687
756
  super()._check_dataset_type(dataset)
688
- inference_method="decision_function"
757
+ inference_method = "decision_function"
689
758
 
690
759
  # This dictionary contains optional kwargs for batch inference. These kwargs
691
760
  # are specific to the type of dataset used.
692
761
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
693
762
 
763
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
764
+
694
765
  if isinstance(dataset, DataFrame):
695
- self._deps = self._batch_inference_validate_snowpark(
696
- dataset=dataset,
697
- inference_method=inference_method,
698
- )
699
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
766
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
767
+ self._deps = self._get_dependencies()
768
+ assert isinstance(
769
+ dataset._session, Session
770
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
700
771
  transform_kwargs = dict(
701
772
  session=dataset._session,
702
773
  dependencies=self._deps,
703
- drop_input_cols = self._drop_input_cols,
774
+ drop_input_cols=self._drop_input_cols,
704
775
  expected_output_cols_type="float",
705
776
  )
777
+ expected_output_cols = self._align_expected_output_names(
778
+ inference_method, dataset, expected_output_cols, output_cols_prefix
779
+ )
706
780
 
707
781
  elif isinstance(dataset, pd.DataFrame):
708
- transform_kwargs = dict(
709
- snowpark_input_cols = self._snowpark_cols,
710
- drop_input_cols = self._drop_input_cols
711
- )
782
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
712
783
 
713
784
  transform_handlers = ModelTransformerBuilder.build(
714
785
  dataset=dataset,
@@ -721,7 +792,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
721
792
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
722
793
  inference_method=inference_method,
723
794
  input_cols=self.input_cols,
724
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
795
+ expected_output_cols=expected_output_cols,
725
796
  **transform_kwargs
726
797
  )
727
798
  return output_df
@@ -750,17 +821,17 @@ class OrthogonalMatchingPursuit(BaseTransformer):
750
821
  Output dataset with probability of the sample for each class in the model.
751
822
  """
752
823
  super()._check_dataset_type(dataset)
753
- inference_method="score_samples"
824
+ inference_method = "score_samples"
754
825
 
755
826
  # This dictionary contains optional kwargs for batch inference. These kwargs
756
827
  # are specific to the type of dataset used.
757
828
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
758
829
 
830
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
831
+
759
832
  if isinstance(dataset, DataFrame):
760
- self._deps = self._batch_inference_validate_snowpark(
761
- dataset=dataset,
762
- inference_method=inference_method,
763
- )
833
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
834
+ self._deps = self._get_dependencies()
764
835
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
765
836
  transform_kwargs = dict(
766
837
  session=dataset._session,
@@ -768,6 +839,9 @@ class OrthogonalMatchingPursuit(BaseTransformer):
768
839
  drop_input_cols = self._drop_input_cols,
769
840
  expected_output_cols_type="float",
770
841
  )
842
+ expected_output_cols = self._align_expected_output_names(
843
+ inference_method, dataset, expected_output_cols, output_cols_prefix
844
+ )
771
845
 
772
846
  elif isinstance(dataset, pd.DataFrame):
773
847
  transform_kwargs = dict(
@@ -786,7 +860,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
786
860
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
787
861
  inference_method=inference_method,
788
862
  input_cols=self.input_cols,
789
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
863
+ expected_output_cols=expected_output_cols,
790
864
  **transform_kwargs
791
865
  )
792
866
  return output_df
@@ -821,17 +895,15 @@ class OrthogonalMatchingPursuit(BaseTransformer):
821
895
  transform_kwargs: ScoreKwargsTypedDict = dict()
822
896
 
823
897
  if isinstance(dataset, DataFrame):
824
- self._deps = self._batch_inference_validate_snowpark(
825
- dataset=dataset,
826
- inference_method="score",
827
- )
898
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
899
+ self._deps = self._get_dependencies()
828
900
  selected_cols = self._get_active_columns()
829
901
  if len(selected_cols) > 0:
830
902
  dataset = dataset.select(selected_cols)
831
903
  assert isinstance(dataset._session, Session) # keep mypy happy
832
904
  transform_kwargs = dict(
833
905
  session=dataset._session,
834
- dependencies=["snowflake-snowpark-python"] + self._deps,
906
+ dependencies=self._deps,
835
907
  score_sproc_imports=['sklearn'],
836
908
  )
837
909
  elif isinstance(dataset, pd.DataFrame):
@@ -896,11 +968,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
896
968
 
897
969
  if isinstance(dataset, DataFrame):
898
970
 
899
- self._deps = self._batch_inference_validate_snowpark(
900
- dataset=dataset,
901
- inference_method=inference_method,
902
-
903
- )
971
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
972
+ self._deps = self._get_dependencies()
904
973
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
905
974
  transform_kwargs = dict(
906
975
  session = dataset._session,
@@ -933,50 +1002,84 @@ class OrthogonalMatchingPursuit(BaseTransformer):
933
1002
  )
934
1003
  return output_df
935
1004
 
1005
+
1006
+
1007
+ def to_sklearn(self) -> Any:
1008
+ """Get sklearn.linear_model.OrthogonalMatchingPursuit object.
1009
+ """
1010
+ if self._sklearn_object is None:
1011
+ self._sklearn_object = self._create_sklearn_object()
1012
+ return self._sklearn_object
1013
+
1014
+ def to_xgboost(self) -> Any:
1015
+ raise exceptions.SnowflakeMLException(
1016
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1017
+ original_exception=AttributeError(
1018
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1019
+ "to_xgboost()",
1020
+ "to_sklearn()"
1021
+ )
1022
+ ),
1023
+ )
1024
+
1025
+ def to_lightgbm(self) -> Any:
1026
+ raise exceptions.SnowflakeMLException(
1027
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1028
+ original_exception=AttributeError(
1029
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1030
+ "to_lightgbm()",
1031
+ "to_sklearn()"
1032
+ )
1033
+ ),
1034
+ )
1035
+
1036
+ def _get_dependencies(self) -> List[str]:
1037
+ return self._deps
1038
+
936
1039
 
937
- def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
1040
+ def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
938
1041
  self._model_signature_dict = dict()
939
1042
 
940
1043
  PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
941
1044
 
942
- inputs = list(_infer_signature(dataset[self.input_cols], "input"))
1045
+ inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
943
1046
  outputs: List[BaseFeatureSpec] = []
944
1047
  if hasattr(self, "predict"):
945
1048
  # keep mypy happy
946
- assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
1049
+ assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
947
1050
  # For classifier, the type of predict is the same as the type of label
948
- if self._sklearn_object._estimator_type == 'classifier':
949
- # label columns is the desired type for output
1051
+ if self._sklearn_object._estimator_type == "classifier":
1052
+ # label columns is the desired type for output
950
1053
  outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
951
1054
  # rename the output columns
952
1055
  outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
953
- self._model_signature_dict["predict"] = ModelSignature(inputs,
954
- ([] if self._drop_input_cols else inputs)
955
- + outputs)
1056
+ self._model_signature_dict["predict"] = ModelSignature(
1057
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1058
+ )
956
1059
  # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
957
1060
  # For outlier models, returns -1 for outliers and 1 for inliers.
958
- # Clusterer returns int64 cluster labels.
1061
+ # Clusterer returns int64 cluster labels.
959
1062
  elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
960
1063
  outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
961
- self._model_signature_dict["predict"] = ModelSignature(inputs,
962
- ([] if self._drop_input_cols else inputs)
963
- + outputs)
964
-
1064
+ self._model_signature_dict["predict"] = ModelSignature(
1065
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1066
+ )
1067
+
965
1068
  # For regressor, the type of predict is float64
966
- elif self._sklearn_object._estimator_type == 'regressor':
1069
+ elif self._sklearn_object._estimator_type == "regressor":
967
1070
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
968
- self._model_signature_dict["predict"] = ModelSignature(inputs,
969
- ([] if self._drop_input_cols else inputs)
970
- + outputs)
971
-
1071
+ self._model_signature_dict["predict"] = ModelSignature(
1072
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1073
+ )
1074
+
972
1075
  for prob_func in PROB_FUNCTIONS:
973
1076
  if hasattr(self, prob_func):
974
1077
  output_cols_prefix: str = f"{prob_func}_"
975
1078
  output_column_names = self._get_output_column_names(output_cols_prefix)
976
1079
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
977
- self._model_signature_dict[prob_func] = ModelSignature(inputs,
978
- ([] if self._drop_input_cols else inputs)
979
- + outputs)
1080
+ self._model_signature_dict[prob_func] = ModelSignature(
1081
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1082
+ )
980
1083
 
981
1084
  # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
982
1085
  items = list(self._model_signature_dict.items())
@@ -989,10 +1092,10 @@ class OrthogonalMatchingPursuit(BaseTransformer):
989
1092
  """Returns model signature of current class.
990
1093
 
991
1094
  Raises:
992
- exceptions.SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1095
+ SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
993
1096
 
994
1097
  Returns:
995
- Dict[str, ModelSignature]: each method and its input output signature
1098
+ Dict with each method and its input output signature
996
1099
  """
997
1100
  if self._model_signature_dict is None:
998
1101
  raise exceptions.SnowflakeMLException(
@@ -1000,35 +1103,3 @@ class OrthogonalMatchingPursuit(BaseTransformer):
1000
1103
  original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
1001
1104
  )
1002
1105
  return self._model_signature_dict
1003
-
1004
- def to_sklearn(self) -> Any:
1005
- """Get sklearn.linear_model.OrthogonalMatchingPursuit object.
1006
- """
1007
- if self._sklearn_object is None:
1008
- self._sklearn_object = self._create_sklearn_object()
1009
- return self._sklearn_object
1010
-
1011
- def to_xgboost(self) -> Any:
1012
- raise exceptions.SnowflakeMLException(
1013
- error_code=error_codes.METHOD_NOT_ALLOWED,
1014
- original_exception=AttributeError(
1015
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1016
- "to_xgboost()",
1017
- "to_sklearn()"
1018
- )
1019
- ),
1020
- )
1021
-
1022
- def to_lightgbm(self) -> Any:
1023
- raise exceptions.SnowflakeMLException(
1024
- error_code=error_codes.METHOD_NOT_ALLOWED,
1025
- original_exception=AttributeError(
1026
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1027
- "to_lightgbm()",
1028
- "to_sklearn()"
1029
- )
1030
- ),
1031
- )
1032
-
1033
- def _get_dependencies(self) -> List[str]:
1034
- return self._deps