snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (234) hide show
  1. snowflake/ml/_internal/env_utils.py +77 -32
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/_internal/utils/identifier.py +3 -1
  8. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  9. snowflake/ml/dataset/__init__.py +10 -0
  10. snowflake/ml/dataset/dataset.py +454 -129
  11. snowflake/ml/dataset/dataset_factory.py +53 -0
  12. snowflake/ml/dataset/dataset_metadata.py +103 -0
  13. snowflake/ml/dataset/dataset_reader.py +202 -0
  14. snowflake/ml/feature_store/feature_store.py +531 -332
  15. snowflake/ml/feature_store/feature_view.py +40 -23
  16. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  17. snowflake/ml/fileset/sfcfs.py +56 -54
  18. snowflake/ml/fileset/snowfs.py +159 -0
  19. snowflake/ml/fileset/stage_fs.py +49 -17
  20. snowflake/ml/model/__init__.py +2 -2
  21. snowflake/ml/model/_api.py +16 -1
  22. snowflake/ml/model/_client/model/model_impl.py +27 -0
  23. snowflake/ml/model/_client/model/model_version_impl.py +137 -50
  24. snowflake/ml/model/_client/ops/model_ops.py +159 -40
  25. snowflake/ml/model/_client/sql/model.py +25 -2
  26. snowflake/ml/model/_client/sql/model_version.py +131 -2
  27. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  28. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  29. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  30. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  31. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  32. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
  34. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  36. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  37. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  39. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  40. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
  42. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  43. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  44. snowflake/ml/model/_packager/model_packager.py +2 -5
  45. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  46. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  47. snowflake/ml/model/type_hints.py +21 -2
  48. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  49. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  50. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  51. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  52. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  53. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  54. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
  55. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
  56. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  57. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
  58. snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
  59. snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
  60. snowflake/ml/modeling/cluster/birch.py +248 -175
  61. snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
  62. snowflake/ml/modeling/cluster/dbscan.py +246 -175
  63. snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
  64. snowflake/ml/modeling/cluster/k_means.py +248 -175
  65. snowflake/ml/modeling/cluster/mean_shift.py +246 -175
  66. snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
  67. snowflake/ml/modeling/cluster/optics.py +246 -175
  68. snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
  69. snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
  70. snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
  71. snowflake/ml/modeling/compose/column_transformer.py +248 -175
  72. snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
  73. snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
  74. snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
  75. snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
  76. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
  77. snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
  78. snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
  79. snowflake/ml/modeling/covariance/oas.py +246 -175
  80. snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
  81. snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
  82. snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
  83. snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
  84. snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
  85. snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
  86. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
  87. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
  88. snowflake/ml/modeling/decomposition/pca.py +248 -175
  89. snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
  90. snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
  91. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
  92. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
  93. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
  94. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
  95. snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
  96. snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
  97. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
  98. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
  99. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
  100. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
  101. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
  102. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
  103. snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
  104. snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
  105. snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
  106. snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
  107. snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
  108. snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
  109. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
  110. snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
  111. snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
  112. snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
  113. snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
  114. snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
  115. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
  116. snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
  117. snowflake/ml/modeling/framework/_utils.py +8 -1
  118. snowflake/ml/modeling/framework/base.py +72 -37
  119. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
  120. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
  121. snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
  122. snowflake/ml/modeling/impute/knn_imputer.py +248 -175
  123. snowflake/ml/modeling/impute/missing_indicator.py +248 -175
  124. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
  125. snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
  126. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
  127. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
  128. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
  129. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
  130. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
  131. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
  132. snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
  133. snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
  134. snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
  135. snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
  136. snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
  137. snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
  138. snowflake/ml/modeling/linear_model/lars.py +246 -175
  139. snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
  140. snowflake/ml/modeling/linear_model/lasso.py +246 -175
  141. snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
  142. snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
  143. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
  144. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
  145. snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
  146. snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
  147. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
  148. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
  149. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
  150. snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
  151. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
  152. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
  153. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
  154. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
  155. snowflake/ml/modeling/linear_model/perceptron.py +246 -175
  156. snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
  157. snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
  158. snowflake/ml/modeling/linear_model/ridge.py +246 -175
  159. snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
  160. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
  161. snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
  162. snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
  163. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
  164. snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
  165. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
  166. snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
  167. snowflake/ml/modeling/manifold/isomap.py +248 -175
  168. snowflake/ml/modeling/manifold/mds.py +248 -175
  169. snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
  170. snowflake/ml/modeling/manifold/tsne.py +248 -175
  171. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
  172. snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
  173. snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
  174. snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
  175. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
  176. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
  177. snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
  178. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
  179. snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
  180. snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
  181. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
  182. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
  183. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
  184. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
  185. snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
  186. snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
  187. snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
  188. snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
  189. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
  190. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
  191. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
  192. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
  193. snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
  194. snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
  195. snowflake/ml/modeling/pipeline/pipeline.py +517 -35
  196. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  197. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  198. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  199. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  200. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  201. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  202. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
  203. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  204. snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
  205. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  206. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  207. snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
  208. snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
  209. snowflake/ml/modeling/svm/linear_svc.py +246 -175
  210. snowflake/ml/modeling/svm/linear_svr.py +246 -175
  211. snowflake/ml/modeling/svm/nu_svc.py +246 -175
  212. snowflake/ml/modeling/svm/nu_svr.py +246 -175
  213. snowflake/ml/modeling/svm/svc.py +246 -175
  214. snowflake/ml/modeling/svm/svr.py +246 -175
  215. snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
  216. snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
  217. snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
  218. snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
  219. snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
  220. snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
  221. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
  222. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
  223. snowflake/ml/registry/model_registry.py +3 -149
  224. snowflake/ml/registry/registry.py +1 -1
  225. snowflake/ml/version.py +1 -1
  226. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
  227. snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
  228. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  229. snowflake/ml/registry/_artifact_manager.py +0 -156
  230. snowflake/ml/registry/artifact.py +0 -46
  231. snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
  232. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  233. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  234. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
33
33
  BatchInferenceKwargsTypedDict,
34
34
  ScoreKwargsTypedDict
35
35
  )
36
+ from snowflake.ml.model._signatures import utils as model_signature_utils
37
+ from snowflake.ml.model.model_signature import (
38
+ BaseFeatureSpec,
39
+ DataType,
40
+ FeatureSpec,
41
+ ModelSignature,
42
+ _infer_signature,
43
+ _rename_signature_with_snowflake_identifiers,
44
+ )
36
45
 
37
46
  from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
38
47
 
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
43
52
  validate_sklearn_args,
44
53
  )
45
54
 
46
- from snowflake.ml.model.model_signature import (
47
- DataType,
48
- FeatureSpec,
49
- ModelSignature,
50
- _infer_signature,
51
- _rename_signature_with_snowflake_identifiers,
52
- BaseFeatureSpec,
53
- )
54
- from snowflake.ml.model._signatures import utils as model_signature_utils
55
-
56
55
  _PROJECT = "ModelDevelopment"
57
56
  # Derive subproject from module name by removing "sklearn"
58
57
  # and converting module name from underscore to CamelCase
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".re
61
60
 
62
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
63
62
 
64
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
65
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
66
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
67
- return check
68
-
69
-
70
63
  class AdditiveChi2Sampler(BaseTransformer):
71
64
  r"""Approximate feature map for additive chi2 kernel
72
65
  For more details on this class, see [sklearn.kernel_approximation.AdditiveChi2Sampler]
@@ -199,12 +192,7 @@ class AdditiveChi2Sampler(BaseTransformer):
199
192
  )
200
193
  return selected_cols
201
194
 
202
- @telemetry.send_api_usage_telemetry(
203
- project=_PROJECT,
204
- subproject=_SUBPROJECT,
205
- custom_tags=dict([("autogen", True)]),
206
- )
207
- def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "AdditiveChi2Sampler":
195
+ def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "AdditiveChi2Sampler":
208
196
  """Only validates estimator's parameters
209
197
  For more details on this function, see [sklearn.kernel_approximation.AdditiveChi2Sampler.fit]
210
198
  (https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.AdditiveChi2Sampler.html#sklearn.kernel_approximation.AdditiveChi2Sampler.fit)
@@ -231,12 +219,14 @@ class AdditiveChi2Sampler(BaseTransformer):
231
219
 
232
220
  self._snowpark_cols = dataset.select(self.input_cols).columns
233
221
 
234
- # If we are already in a stored procedure, no need to kick off another one.
222
+ # If we are already in a stored procedure, no need to kick off another one.
235
223
  if SNOWML_SPROC_ENV in os.environ:
236
224
  statement_params = telemetry.get_function_usage_statement_params(
237
225
  project=_PROJECT,
238
226
  subproject=_SUBPROJECT,
239
- function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), AdditiveChi2Sampler.__class__.__name__),
227
+ function_name=telemetry.get_statement_params_full_func_name(
228
+ inspect.currentframe(), AdditiveChi2Sampler.__class__.__name__
229
+ ),
240
230
  api_calls=[Session.call],
241
231
  custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
242
232
  )
@@ -257,27 +247,24 @@ class AdditiveChi2Sampler(BaseTransformer):
257
247
  )
258
248
  self._sklearn_object = model_trainer.train()
259
249
  self._is_fitted = True
260
- self._get_model_signatures(dataset)
250
+ self._generate_model_signatures(dataset)
261
251
  return self
262
252
 
263
253
  def _batch_inference_validate_snowpark(
264
254
  self,
265
255
  dataset: DataFrame,
266
256
  inference_method: str,
267
- ) -> List[str]:
268
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
269
- return the available package that exists in the snowflake anaconda channel
257
+ ) -> None:
258
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
270
259
 
271
260
  Args:
272
261
  dataset: snowpark dataframe
273
262
  inference_method: the inference method such as predict, score...
274
-
263
+
275
264
  Raises:
276
265
  SnowflakeMLException: If the estimator is not fitted, raise error
277
266
  SnowflakeMLException: If the session is None, raise error
278
267
 
279
- Returns:
280
- A list of available package that exists in the snowflake anaconda channel
281
268
  """
282
269
  if not self._is_fitted:
283
270
  raise exceptions.SnowflakeMLException(
@@ -295,9 +282,7 @@ class AdditiveChi2Sampler(BaseTransformer):
295
282
  "Session must not specified for snowpark dataset."
296
283
  ),
297
284
  )
298
- # Validate that key package version in user workspace are supported in snowflake conda channel
299
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
300
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
285
+
301
286
 
302
287
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
303
288
  @telemetry.send_api_usage_telemetry(
@@ -331,7 +316,9 @@ class AdditiveChi2Sampler(BaseTransformer):
331
316
  # when it is classifier, infer the datatype from label columns
332
317
  if expected_type_inferred == "" and 'predict' in self.model_signatures:
333
318
  # Batch inference takes a single expected output column type. Use the first columns type for now.
334
- label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
319
+ label_cols_signatures = [
320
+ row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
321
+ ]
335
322
  if len(label_cols_signatures) == 0:
336
323
  error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
337
324
  raise exceptions.SnowflakeMLException(
@@ -339,25 +326,23 @@ class AdditiveChi2Sampler(BaseTransformer):
339
326
  original_exception=ValueError(error_str),
340
327
  )
341
328
 
342
- expected_type_inferred = convert_sp_to_sf_type(
343
- label_cols_signatures[0].as_snowpark_type()
344
- )
329
+ expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
345
330
 
346
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
347
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
331
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
332
+ self._deps = self._get_dependencies()
333
+ assert isinstance(
334
+ dataset._session, Session
335
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
348
336
 
349
337
  transform_kwargs = dict(
350
- session = dataset._session,
351
- dependencies = self._deps,
352
- drop_input_cols = self._drop_input_cols,
353
- expected_output_cols_type = expected_type_inferred,
338
+ session=dataset._session,
339
+ dependencies=self._deps,
340
+ drop_input_cols=self._drop_input_cols,
341
+ expected_output_cols_type=expected_type_inferred,
354
342
  )
355
343
 
356
344
  elif isinstance(dataset, pd.DataFrame):
357
- transform_kwargs = dict(
358
- snowpark_input_cols = self._snowpark_cols,
359
- drop_input_cols = self._drop_input_cols
360
- )
345
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
361
346
 
362
347
  transform_handlers = ModelTransformerBuilder.build(
363
348
  dataset=dataset,
@@ -399,7 +384,7 @@ class AdditiveChi2Sampler(BaseTransformer):
399
384
  Transformed dataset.
400
385
  """
401
386
  super()._check_dataset_type(dataset)
402
- inference_method="transform"
387
+ inference_method = "transform"
403
388
 
404
389
  # This dictionary contains optional kwargs for batch inference. These kwargs
405
390
  # are specific to the type of dataset used.
@@ -429,24 +414,19 @@ class AdditiveChi2Sampler(BaseTransformer):
429
414
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
430
415
  expected_dtype = convert_sp_to_sf_type(output_types[0])
431
416
 
432
- self._deps = self._batch_inference_validate_snowpark(
433
- dataset=dataset,
434
- inference_method=inference_method,
435
- )
417
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
418
+ self._deps = self._get_dependencies()
436
419
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
437
420
 
438
421
  transform_kwargs = dict(
439
- session = dataset._session,
440
- dependencies = self._deps,
441
- drop_input_cols = self._drop_input_cols,
442
- expected_output_cols_type = expected_dtype,
422
+ session=dataset._session,
423
+ dependencies=self._deps,
424
+ drop_input_cols=self._drop_input_cols,
425
+ expected_output_cols_type=expected_dtype,
443
426
  )
444
427
 
445
428
  elif isinstance(dataset, pd.DataFrame):
446
- transform_kwargs = dict(
447
- snowpark_input_cols = self._snowpark_cols,
448
- drop_input_cols = self._drop_input_cols
449
- )
429
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
450
430
 
451
431
  transform_handlers = ModelTransformerBuilder.build(
452
432
  dataset=dataset,
@@ -465,7 +445,11 @@ class AdditiveChi2Sampler(BaseTransformer):
465
445
  return output_df
466
446
 
467
447
  @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
468
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
448
+ def fit_predict(
449
+ self,
450
+ dataset: Union[DataFrame, pd.DataFrame],
451
+ output_cols_prefix: str = "fit_predict_",
452
+ ) -> Union[DataFrame, pd.DataFrame]:
469
453
  """ Method not supported for this class.
470
454
 
471
455
 
@@ -490,22 +474,106 @@ class AdditiveChi2Sampler(BaseTransformer):
490
474
  )
491
475
  output_result, fitted_estimator = model_trainer.train_fit_predict(
492
476
  drop_input_cols=self._drop_input_cols,
493
- expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
477
+ expected_output_cols_list=(
478
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
479
+ ),
494
480
  )
495
481
  self._sklearn_object = fitted_estimator
496
482
  self._is_fitted = True
497
483
  return output_result
498
484
 
485
+
486
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
487
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
488
+ """ Fit to data, then transform it
489
+ For more details on this function, see [sklearn.kernel_approximation.AdditiveChi2Sampler.fit_transform]
490
+ (https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.AdditiveChi2Sampler.html#sklearn.kernel_approximation.AdditiveChi2Sampler.fit_transform)
491
+
492
+
493
+ Raises:
494
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
499
495
 
500
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
501
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
502
- """
496
+ Args:
497
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
498
+ Snowpark or Pandas DataFrame.
499
+ output_cols_prefix: Prefix for the response columns
503
500
  Returns:
504
501
  Transformed dataset.
505
502
  """
506
- self.fit(dataset)
507
- assert self._sklearn_object is not None
508
- return self._sklearn_object.embedding_
503
+ self._infer_input_output_cols(dataset)
504
+ super()._check_dataset_type(dataset)
505
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
506
+ estimator=self._sklearn_object,
507
+ dataset=dataset,
508
+ input_cols=self.input_cols,
509
+ label_cols=self.label_cols,
510
+ sample_weight_col=self.sample_weight_col,
511
+ autogenerated=self._autogenerated,
512
+ subproject=_SUBPROJECT,
513
+ )
514
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
515
+ drop_input_cols=self._drop_input_cols,
516
+ expected_output_cols_list=self.output_cols,
517
+ )
518
+ self._sklearn_object = fitted_estimator
519
+ self._is_fitted = True
520
+ return output_result
521
+
522
+
523
+ def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
524
+ """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
525
+ Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
526
+ """
527
+ output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
528
+ # The following condition is introduced for kneighbors methods, and not used in other methods
529
+ if output_cols:
530
+ output_cols = [
531
+ identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
532
+ for c in output_cols
533
+ ]
534
+ elif getattr(self._sklearn_object, "classes_", None) is None:
535
+ output_cols = [output_cols_prefix]
536
+ elif self._sklearn_object is not None:
537
+ classes = self._sklearn_object.classes_
538
+ if isinstance(classes, numpy.ndarray):
539
+ output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
540
+ elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
541
+ # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
542
+ output_cols = []
543
+ for i, cl in enumerate(classes):
544
+ # For binary classification, there is only one output column for each class
545
+ # ndarray as the two classes are complementary.
546
+ if len(cl) == 2:
547
+ output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
548
+ else:
549
+ output_cols.extend([
550
+ f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
551
+ ])
552
+ else:
553
+ output_cols = []
554
+
555
+ # Make sure column names are valid snowflake identifiers.
556
+ assert output_cols is not None # Make MyPy happy
557
+ rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
558
+
559
+ return rv
560
+
561
+ def _align_expected_output_names(
562
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
563
+ ) -> List[str]:
564
+ # in case the inferred output column names dimension is different
565
+ # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
566
+ output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
567
+ output_df_columns = list(output_df_pd.columns)
568
+ output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
569
+ if self.sample_weight_col:
570
+ output_df_columns_set -= set(self.sample_weight_col)
571
+ # if the dimension of inferred output column names is correct; use it
572
+ if len(expected_output_cols_list) == len(output_df_columns_set):
573
+ return expected_output_cols_list
574
+ # otherwise, use the sklearn estimator's output
575
+ else:
576
+ return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
509
577
 
510
578
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
511
579
  @telemetry.send_api_usage_telemetry(
@@ -537,24 +605,26 @@ class AdditiveChi2Sampler(BaseTransformer):
537
605
  # are specific to the type of dataset used.
538
606
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
539
607
 
608
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
609
+
540
610
  if isinstance(dataset, DataFrame):
541
- self._deps = self._batch_inference_validate_snowpark(
542
- dataset=dataset,
543
- inference_method=inference_method,
544
- )
545
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
611
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
612
+ self._deps = self._get_dependencies()
613
+ assert isinstance(
614
+ dataset._session, Session
615
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
546
616
  transform_kwargs = dict(
547
617
  session=dataset._session,
548
618
  dependencies=self._deps,
549
- drop_input_cols = self._drop_input_cols,
619
+ drop_input_cols=self._drop_input_cols,
550
620
  expected_output_cols_type="float",
551
621
  )
622
+ expected_output_cols = self._align_expected_output_names(
623
+ inference_method, dataset, expected_output_cols, output_cols_prefix
624
+ )
552
625
 
553
626
  elif isinstance(dataset, pd.DataFrame):
554
- transform_kwargs = dict(
555
- snowpark_input_cols = self._snowpark_cols,
556
- drop_input_cols = self._drop_input_cols
557
- )
627
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
558
628
 
559
629
  transform_handlers = ModelTransformerBuilder.build(
560
630
  dataset=dataset,
@@ -566,7 +636,7 @@ class AdditiveChi2Sampler(BaseTransformer):
566
636
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
567
637
  inference_method=inference_method,
568
638
  input_cols=self.input_cols,
569
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
639
+ expected_output_cols=expected_output_cols,
570
640
  **transform_kwargs
571
641
  )
572
642
  return output_df
@@ -596,29 +666,30 @@ class AdditiveChi2Sampler(BaseTransformer):
596
666
  Output dataset with log probability of the sample for each class in the model.
597
667
  """
598
668
  super()._check_dataset_type(dataset)
599
- inference_method="predict_log_proba"
669
+ inference_method = "predict_log_proba"
670
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
600
671
 
601
672
  # This dictionary contains optional kwargs for batch inference. These kwargs
602
673
  # are specific to the type of dataset used.
603
674
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
604
675
 
605
676
  if isinstance(dataset, DataFrame):
606
- self._deps = self._batch_inference_validate_snowpark(
607
- dataset=dataset,
608
- inference_method=inference_method,
609
- )
610
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
677
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
678
+ self._deps = self._get_dependencies()
679
+ assert isinstance(
680
+ dataset._session, Session
681
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
611
682
  transform_kwargs = dict(
612
683
  session=dataset._session,
613
684
  dependencies=self._deps,
614
- drop_input_cols = self._drop_input_cols,
685
+ drop_input_cols=self._drop_input_cols,
615
686
  expected_output_cols_type="float",
616
687
  )
688
+ expected_output_cols = self._align_expected_output_names(
689
+ inference_method, dataset, expected_output_cols, output_cols_prefix
690
+ )
617
691
  elif isinstance(dataset, pd.DataFrame):
618
- transform_kwargs = dict(
619
- snowpark_input_cols = self._snowpark_cols,
620
- drop_input_cols = self._drop_input_cols
621
- )
692
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
622
693
 
623
694
  transform_handlers = ModelTransformerBuilder.build(
624
695
  dataset=dataset,
@@ -631,7 +702,7 @@ class AdditiveChi2Sampler(BaseTransformer):
631
702
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
632
703
  inference_method=inference_method,
633
704
  input_cols=self.input_cols,
634
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
705
+ expected_output_cols=expected_output_cols,
635
706
  **transform_kwargs
636
707
  )
637
708
  return output_df
@@ -657,30 +728,32 @@ class AdditiveChi2Sampler(BaseTransformer):
657
728
  Output dataset with results of the decision function for the samples in input dataset.
658
729
  """
659
730
  super()._check_dataset_type(dataset)
660
- inference_method="decision_function"
731
+ inference_method = "decision_function"
661
732
 
662
733
  # This dictionary contains optional kwargs for batch inference. These kwargs
663
734
  # are specific to the type of dataset used.
664
735
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
665
736
 
737
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
738
+
666
739
  if isinstance(dataset, DataFrame):
667
- self._deps = self._batch_inference_validate_snowpark(
668
- dataset=dataset,
669
- inference_method=inference_method,
670
- )
671
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
740
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
741
+ self._deps = self._get_dependencies()
742
+ assert isinstance(
743
+ dataset._session, Session
744
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
672
745
  transform_kwargs = dict(
673
746
  session=dataset._session,
674
747
  dependencies=self._deps,
675
- drop_input_cols = self._drop_input_cols,
748
+ drop_input_cols=self._drop_input_cols,
676
749
  expected_output_cols_type="float",
677
750
  )
751
+ expected_output_cols = self._align_expected_output_names(
752
+ inference_method, dataset, expected_output_cols, output_cols_prefix
753
+ )
678
754
 
679
755
  elif isinstance(dataset, pd.DataFrame):
680
- transform_kwargs = dict(
681
- snowpark_input_cols = self._snowpark_cols,
682
- drop_input_cols = self._drop_input_cols
683
- )
756
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
684
757
 
685
758
  transform_handlers = ModelTransformerBuilder.build(
686
759
  dataset=dataset,
@@ -693,7 +766,7 @@ class AdditiveChi2Sampler(BaseTransformer):
693
766
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
694
767
  inference_method=inference_method,
695
768
  input_cols=self.input_cols,
696
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
769
+ expected_output_cols=expected_output_cols,
697
770
  **transform_kwargs
698
771
  )
699
772
  return output_df
@@ -722,17 +795,17 @@ class AdditiveChi2Sampler(BaseTransformer):
722
795
  Output dataset with probability of the sample for each class in the model.
723
796
  """
724
797
  super()._check_dataset_type(dataset)
725
- inference_method="score_samples"
798
+ inference_method = "score_samples"
726
799
 
727
800
  # This dictionary contains optional kwargs for batch inference. These kwargs
728
801
  # are specific to the type of dataset used.
729
802
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
730
803
 
804
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
805
+
731
806
  if isinstance(dataset, DataFrame):
732
- self._deps = self._batch_inference_validate_snowpark(
733
- dataset=dataset,
734
- inference_method=inference_method,
735
- )
807
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
808
+ self._deps = self._get_dependencies()
736
809
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
737
810
  transform_kwargs = dict(
738
811
  session=dataset._session,
@@ -740,6 +813,9 @@ class AdditiveChi2Sampler(BaseTransformer):
740
813
  drop_input_cols = self._drop_input_cols,
741
814
  expected_output_cols_type="float",
742
815
  )
816
+ expected_output_cols = self._align_expected_output_names(
817
+ inference_method, dataset, expected_output_cols, output_cols_prefix
818
+ )
743
819
 
744
820
  elif isinstance(dataset, pd.DataFrame):
745
821
  transform_kwargs = dict(
@@ -758,7 +834,7 @@ class AdditiveChi2Sampler(BaseTransformer):
758
834
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
759
835
  inference_method=inference_method,
760
836
  input_cols=self.input_cols,
761
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
837
+ expected_output_cols=expected_output_cols,
762
838
  **transform_kwargs
763
839
  )
764
840
  return output_df
@@ -791,17 +867,15 @@ class AdditiveChi2Sampler(BaseTransformer):
791
867
  transform_kwargs: ScoreKwargsTypedDict = dict()
792
868
 
793
869
  if isinstance(dataset, DataFrame):
794
- self._deps = self._batch_inference_validate_snowpark(
795
- dataset=dataset,
796
- inference_method="score",
797
- )
870
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
871
+ self._deps = self._get_dependencies()
798
872
  selected_cols = self._get_active_columns()
799
873
  if len(selected_cols) > 0:
800
874
  dataset = dataset.select(selected_cols)
801
875
  assert isinstance(dataset._session, Session) # keep mypy happy
802
876
  transform_kwargs = dict(
803
877
  session=dataset._session,
804
- dependencies=["snowflake-snowpark-python"] + self._deps,
878
+ dependencies=self._deps,
805
879
  score_sproc_imports=['sklearn'],
806
880
  )
807
881
  elif isinstance(dataset, pd.DataFrame):
@@ -866,11 +940,8 @@ class AdditiveChi2Sampler(BaseTransformer):
866
940
 
867
941
  if isinstance(dataset, DataFrame):
868
942
 
869
- self._deps = self._batch_inference_validate_snowpark(
870
- dataset=dataset,
871
- inference_method=inference_method,
872
-
873
- )
943
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
944
+ self._deps = self._get_dependencies()
874
945
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
875
946
  transform_kwargs = dict(
876
947
  session = dataset._session,
@@ -903,50 +974,84 @@ class AdditiveChi2Sampler(BaseTransformer):
903
974
  )
904
975
  return output_df
905
976
 
977
+
978
+
979
+ def to_sklearn(self) -> Any:
980
+ """Get sklearn.kernel_approximation.AdditiveChi2Sampler object.
981
+ """
982
+ if self._sklearn_object is None:
983
+ self._sklearn_object = self._create_sklearn_object()
984
+ return self._sklearn_object
985
+
986
+ def to_xgboost(self) -> Any:
987
+ raise exceptions.SnowflakeMLException(
988
+ error_code=error_codes.METHOD_NOT_ALLOWED,
989
+ original_exception=AttributeError(
990
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
991
+ "to_xgboost()",
992
+ "to_sklearn()"
993
+ )
994
+ ),
995
+ )
906
996
 
907
- def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
997
+ def to_lightgbm(self) -> Any:
998
+ raise exceptions.SnowflakeMLException(
999
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1000
+ original_exception=AttributeError(
1001
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1002
+ "to_lightgbm()",
1003
+ "to_sklearn()"
1004
+ )
1005
+ ),
1006
+ )
1007
+
1008
+ def _get_dependencies(self) -> List[str]:
1009
+ return self._deps
1010
+
1011
+
1012
+ def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
908
1013
  self._model_signature_dict = dict()
909
1014
 
910
1015
  PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
911
1016
 
912
- inputs = list(_infer_signature(dataset[self.input_cols], "input"))
1017
+ inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
913
1018
  outputs: List[BaseFeatureSpec] = []
914
1019
  if hasattr(self, "predict"):
915
1020
  # keep mypy happy
916
- assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
1021
+ assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
917
1022
  # For classifier, the type of predict is the same as the type of label
918
- if self._sklearn_object._estimator_type == 'classifier':
919
- # label columns is the desired type for output
1023
+ if self._sklearn_object._estimator_type == "classifier":
1024
+ # label columns is the desired type for output
920
1025
  outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
921
1026
  # rename the output columns
922
1027
  outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
923
- self._model_signature_dict["predict"] = ModelSignature(inputs,
924
- ([] if self._drop_input_cols else inputs)
925
- + outputs)
1028
+ self._model_signature_dict["predict"] = ModelSignature(
1029
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1030
+ )
926
1031
  # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
927
1032
  # For outlier models, returns -1 for outliers and 1 for inliers.
928
- # Clusterer returns int64 cluster labels.
1033
+ # Clusterer returns int64 cluster labels.
929
1034
  elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
930
1035
  outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
931
- self._model_signature_dict["predict"] = ModelSignature(inputs,
932
- ([] if self._drop_input_cols else inputs)
933
- + outputs)
934
-
1036
+ self._model_signature_dict["predict"] = ModelSignature(
1037
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1038
+ )
1039
+
935
1040
  # For regressor, the type of predict is float64
936
- elif self._sklearn_object._estimator_type == 'regressor':
1041
+ elif self._sklearn_object._estimator_type == "regressor":
937
1042
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
938
- self._model_signature_dict["predict"] = ModelSignature(inputs,
939
- ([] if self._drop_input_cols else inputs)
940
- + outputs)
941
-
1043
+ self._model_signature_dict["predict"] = ModelSignature(
1044
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1045
+ )
1046
+
942
1047
  for prob_func in PROB_FUNCTIONS:
943
1048
  if hasattr(self, prob_func):
944
1049
  output_cols_prefix: str = f"{prob_func}_"
945
1050
  output_column_names = self._get_output_column_names(output_cols_prefix)
946
1051
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
947
- self._model_signature_dict[prob_func] = ModelSignature(inputs,
948
- ([] if self._drop_input_cols else inputs)
949
- + outputs)
1052
+ self._model_signature_dict[prob_func] = ModelSignature(
1053
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1054
+ )
950
1055
 
951
1056
  # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
952
1057
  items = list(self._model_signature_dict.items())
@@ -959,10 +1064,10 @@ class AdditiveChi2Sampler(BaseTransformer):
959
1064
  """Returns model signature of current class.
960
1065
 
961
1066
  Raises:
962
- exceptions.SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1067
+ SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
963
1068
 
964
1069
  Returns:
965
- Dict[str, ModelSignature]: each method and its input output signature
1070
+ Dict with each method and its input output signature
966
1071
  """
967
1072
  if self._model_signature_dict is None:
968
1073
  raise exceptions.SnowflakeMLException(
@@ -970,35 +1075,3 @@ class AdditiveChi2Sampler(BaseTransformer):
970
1075
  original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
971
1076
  )
972
1077
  return self._model_signature_dict
973
-
974
- def to_sklearn(self) -> Any:
975
- """Get sklearn.kernel_approximation.AdditiveChi2Sampler object.
976
- """
977
- if self._sklearn_object is None:
978
- self._sklearn_object = self._create_sklearn_object()
979
- return self._sklearn_object
980
-
981
- def to_xgboost(self) -> Any:
982
- raise exceptions.SnowflakeMLException(
983
- error_code=error_codes.METHOD_NOT_ALLOWED,
984
- original_exception=AttributeError(
985
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
986
- "to_xgboost()",
987
- "to_sklearn()"
988
- )
989
- ),
990
- )
991
-
992
- def to_lightgbm(self) -> Any:
993
- raise exceptions.SnowflakeMLException(
994
- error_code=error_codes.METHOD_NOT_ALLOWED,
995
- original_exception=AttributeError(
996
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
997
- "to_lightgbm()",
998
- "to_sklearn()"
999
- )
1000
- ),
1001
- )
1002
-
1003
- def _get_dependencies(self) -> List[str]:
1004
- return self._deps