snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (234)
  1. snowflake/ml/_internal/env_utils.py +77 -32
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/_internal/utils/identifier.py +3 -1
  8. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  9. snowflake/ml/dataset/__init__.py +10 -0
  10. snowflake/ml/dataset/dataset.py +454 -129
  11. snowflake/ml/dataset/dataset_factory.py +53 -0
  12. snowflake/ml/dataset/dataset_metadata.py +103 -0
  13. snowflake/ml/dataset/dataset_reader.py +202 -0
  14. snowflake/ml/feature_store/feature_store.py +531 -332
  15. snowflake/ml/feature_store/feature_view.py +40 -23
  16. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  17. snowflake/ml/fileset/sfcfs.py +56 -54
  18. snowflake/ml/fileset/snowfs.py +159 -0
  19. snowflake/ml/fileset/stage_fs.py +49 -17
  20. snowflake/ml/model/__init__.py +2 -2
  21. snowflake/ml/model/_api.py +16 -1
  22. snowflake/ml/model/_client/model/model_impl.py +27 -0
  23. snowflake/ml/model/_client/model/model_version_impl.py +137 -50
  24. snowflake/ml/model/_client/ops/model_ops.py +159 -40
  25. snowflake/ml/model/_client/sql/model.py +25 -2
  26. snowflake/ml/model/_client/sql/model_version.py +131 -2
  27. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  28. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  29. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  30. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  31. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  32. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
  34. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  36. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  37. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  39. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  40. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
  42. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  43. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  44. snowflake/ml/model/_packager/model_packager.py +2 -5
  45. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  46. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  47. snowflake/ml/model/type_hints.py +21 -2
  48. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  49. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  50. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  51. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  52. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  53. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  54. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
  55. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
  56. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  57. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
  58. snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
  59. snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
  60. snowflake/ml/modeling/cluster/birch.py +248 -175
  61. snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
  62. snowflake/ml/modeling/cluster/dbscan.py +246 -175
  63. snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
  64. snowflake/ml/modeling/cluster/k_means.py +248 -175
  65. snowflake/ml/modeling/cluster/mean_shift.py +246 -175
  66. snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
  67. snowflake/ml/modeling/cluster/optics.py +246 -175
  68. snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
  69. snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
  70. snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
  71. snowflake/ml/modeling/compose/column_transformer.py +248 -175
  72. snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
  73. snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
  74. snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
  75. snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
  76. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
  77. snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
  78. snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
  79. snowflake/ml/modeling/covariance/oas.py +246 -175
  80. snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
  81. snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
  82. snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
  83. snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
  84. snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
  85. snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
  86. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
  87. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
  88. snowflake/ml/modeling/decomposition/pca.py +248 -175
  89. snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
  90. snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
  91. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
  92. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
  93. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
  94. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
  95. snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
  96. snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
  97. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
  98. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
  99. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
  100. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
  101. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
  102. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
  103. snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
  104. snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
  105. snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
  106. snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
  107. snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
  108. snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
  109. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
  110. snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
  111. snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
  112. snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
  113. snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
  114. snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
  115. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
  116. snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
  117. snowflake/ml/modeling/framework/_utils.py +8 -1
  118. snowflake/ml/modeling/framework/base.py +72 -37
  119. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
  120. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
  121. snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
  122. snowflake/ml/modeling/impute/knn_imputer.py +248 -175
  123. snowflake/ml/modeling/impute/missing_indicator.py +248 -175
  124. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
  125. snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
  126. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
  127. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
  128. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
  129. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
  130. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
  131. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
  132. snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
  133. snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
  134. snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
  135. snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
  136. snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
  137. snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
  138. snowflake/ml/modeling/linear_model/lars.py +246 -175
  139. snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
  140. snowflake/ml/modeling/linear_model/lasso.py +246 -175
  141. snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
  142. snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
  143. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
  144. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
  145. snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
  146. snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
  147. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
  148. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
  149. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
  150. snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
  151. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
  152. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
  153. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
  154. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
  155. snowflake/ml/modeling/linear_model/perceptron.py +246 -175
  156. snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
  157. snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
  158. snowflake/ml/modeling/linear_model/ridge.py +246 -175
  159. snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
  160. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
  161. snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
  162. snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
  163. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
  164. snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
  165. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
  166. snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
  167. snowflake/ml/modeling/manifold/isomap.py +248 -175
  168. snowflake/ml/modeling/manifold/mds.py +248 -175
  169. snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
  170. snowflake/ml/modeling/manifold/tsne.py +248 -175
  171. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
  172. snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
  173. snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
  174. snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
  175. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
  176. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
  177. snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
  178. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
  179. snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
  180. snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
  181. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
  182. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
  183. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
  184. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
  185. snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
  186. snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
  187. snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
  188. snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
  189. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
  190. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
  191. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
  192. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
  193. snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
  194. snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
  195. snowflake/ml/modeling/pipeline/pipeline.py +517 -35
  196. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  197. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  198. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  199. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  200. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  201. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  202. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
  203. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  204. snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
  205. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  206. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  207. snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
  208. snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
  209. snowflake/ml/modeling/svm/linear_svc.py +246 -175
  210. snowflake/ml/modeling/svm/linear_svr.py +246 -175
  211. snowflake/ml/modeling/svm/nu_svc.py +246 -175
  212. snowflake/ml/modeling/svm/nu_svr.py +246 -175
  213. snowflake/ml/modeling/svm/svc.py +246 -175
  214. snowflake/ml/modeling/svm/svr.py +246 -175
  215. snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
  216. snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
  217. snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
  218. snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
  219. snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
  220. snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
  221. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
  222. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
  223. snowflake/ml/registry/model_registry.py +3 -149
  224. snowflake/ml/registry/registry.py +1 -1
  225. snowflake/ml/version.py +1 -1
  226. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
  227. snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
  228. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  229. snowflake/ml/registry/_artifact_manager.py +0 -156
  230. snowflake/ml/registry/artifact.py +0 -46
  231. snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
  232. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  233. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  234. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
33
33
  BatchInferenceKwargsTypedDict,
34
34
  ScoreKwargsTypedDict
35
35
  )
36
+ from snowflake.ml.model._signatures import utils as model_signature_utils
37
+ from snowflake.ml.model.model_signature import (
38
+ BaseFeatureSpec,
39
+ DataType,
40
+ FeatureSpec,
41
+ ModelSignature,
42
+ _infer_signature,
43
+ _rename_signature_with_snowflake_identifiers,
44
+ )
36
45
 
37
46
  from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
38
47
 
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
43
52
  validate_sklearn_args,
44
53
  )
45
54
 
46
- from snowflake.ml.model.model_signature import (
47
- DataType,
48
- FeatureSpec,
49
- ModelSignature,
50
- _infer_signature,
51
- _rename_signature_with_snowflake_identifiers,
52
- BaseFeatureSpec,
53
- )
54
- from snowflake.ml.model._signatures import utils as model_signature_utils
55
-
56
55
  _PROJECT = "ModelDevelopment"
57
56
  # Derive subproject from module name by removing "sklearn"
58
57
  # and converting module name from underscore to CamelCase
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".re
61
60
 
62
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
63
62
 
64
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
65
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
66
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
67
- return check
68
-
69
-
70
63
  class Nystroem(BaseTransformer):
71
64
  r"""Approximate a kernel map using a subset of the training data
72
65
  For more details on this class, see [sklearn.kernel_approximation.Nystroem]
@@ -247,12 +240,7 @@ class Nystroem(BaseTransformer):
247
240
  )
248
241
  return selected_cols
249
242
 
250
- @telemetry.send_api_usage_telemetry(
251
- project=_PROJECT,
252
- subproject=_SUBPROJECT,
253
- custom_tags=dict([("autogen", True)]),
254
- )
255
- def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "Nystroem":
243
+ def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "Nystroem":
256
244
  """Fit estimator to data
257
245
  For more details on this function, see [sklearn.kernel_approximation.Nystroem.fit]
258
246
  (https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.Nystroem.html#sklearn.kernel_approximation.Nystroem.fit)
@@ -279,12 +267,14 @@ class Nystroem(BaseTransformer):
279
267
 
280
268
  self._snowpark_cols = dataset.select(self.input_cols).columns
281
269
 
282
- # If we are already in a stored procedure, no need to kick off another one.
270
+ # If we are already in a stored procedure, no need to kick off another one.
283
271
  if SNOWML_SPROC_ENV in os.environ:
284
272
  statement_params = telemetry.get_function_usage_statement_params(
285
273
  project=_PROJECT,
286
274
  subproject=_SUBPROJECT,
287
- function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), Nystroem.__class__.__name__),
275
+ function_name=telemetry.get_statement_params_full_func_name(
276
+ inspect.currentframe(), Nystroem.__class__.__name__
277
+ ),
288
278
  api_calls=[Session.call],
289
279
  custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
290
280
  )
@@ -305,27 +295,24 @@ class Nystroem(BaseTransformer):
305
295
  )
306
296
  self._sklearn_object = model_trainer.train()
307
297
  self._is_fitted = True
308
- self._get_model_signatures(dataset)
298
+ self._generate_model_signatures(dataset)
309
299
  return self
310
300
 
311
301
  def _batch_inference_validate_snowpark(
312
302
  self,
313
303
  dataset: DataFrame,
314
304
  inference_method: str,
315
- ) -> List[str]:
316
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
317
- return the available package that exists in the snowflake anaconda channel
305
+ ) -> None:
306
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
318
307
 
319
308
  Args:
320
309
  dataset: snowpark dataframe
321
310
  inference_method: the inference method such as predict, score...
322
-
311
+
323
312
  Raises:
324
313
  SnowflakeMLException: If the estimator is not fitted, raise error
325
314
  SnowflakeMLException: If the session is None, raise error
326
315
 
327
- Returns:
328
- A list of available package that exists in the snowflake anaconda channel
329
316
  """
330
317
  if not self._is_fitted:
331
318
  raise exceptions.SnowflakeMLException(
@@ -343,9 +330,7 @@ class Nystroem(BaseTransformer):
343
330
  "Session must not specified for snowpark dataset."
344
331
  ),
345
332
  )
346
- # Validate that key package version in user workspace are supported in snowflake conda channel
347
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
348
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
333
+
349
334
 
350
335
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
351
336
  @telemetry.send_api_usage_telemetry(
@@ -379,7 +364,9 @@ class Nystroem(BaseTransformer):
379
364
  # when it is classifier, infer the datatype from label columns
380
365
  if expected_type_inferred == "" and 'predict' in self.model_signatures:
381
366
  # Batch inference takes a single expected output column type. Use the first columns type for now.
382
- label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
367
+ label_cols_signatures = [
368
+ row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
369
+ ]
383
370
  if len(label_cols_signatures) == 0:
384
371
  error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
385
372
  raise exceptions.SnowflakeMLException(
@@ -387,25 +374,23 @@ class Nystroem(BaseTransformer):
387
374
  original_exception=ValueError(error_str),
388
375
  )
389
376
 
390
- expected_type_inferred = convert_sp_to_sf_type(
391
- label_cols_signatures[0].as_snowpark_type()
392
- )
377
+ expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
393
378
 
394
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
395
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
379
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
380
+ self._deps = self._get_dependencies()
381
+ assert isinstance(
382
+ dataset._session, Session
383
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
396
384
 
397
385
  transform_kwargs = dict(
398
- session = dataset._session,
399
- dependencies = self._deps,
400
- drop_input_cols = self._drop_input_cols,
401
- expected_output_cols_type = expected_type_inferred,
386
+ session=dataset._session,
387
+ dependencies=self._deps,
388
+ drop_input_cols=self._drop_input_cols,
389
+ expected_output_cols_type=expected_type_inferred,
402
390
  )
403
391
 
404
392
  elif isinstance(dataset, pd.DataFrame):
405
- transform_kwargs = dict(
406
- snowpark_input_cols = self._snowpark_cols,
407
- drop_input_cols = self._drop_input_cols
408
- )
393
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
409
394
 
410
395
  transform_handlers = ModelTransformerBuilder.build(
411
396
  dataset=dataset,
@@ -447,7 +432,7 @@ class Nystroem(BaseTransformer):
447
432
  Transformed dataset.
448
433
  """
449
434
  super()._check_dataset_type(dataset)
450
- inference_method="transform"
435
+ inference_method = "transform"
451
436
 
452
437
  # This dictionary contains optional kwargs for batch inference. These kwargs
453
438
  # are specific to the type of dataset used.
@@ -477,24 +462,19 @@ class Nystroem(BaseTransformer):
477
462
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
478
463
  expected_dtype = convert_sp_to_sf_type(output_types[0])
479
464
 
480
- self._deps = self._batch_inference_validate_snowpark(
481
- dataset=dataset,
482
- inference_method=inference_method,
483
- )
465
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
466
+ self._deps = self._get_dependencies()
484
467
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
485
468
 
486
469
  transform_kwargs = dict(
487
- session = dataset._session,
488
- dependencies = self._deps,
489
- drop_input_cols = self._drop_input_cols,
490
- expected_output_cols_type = expected_dtype,
470
+ session=dataset._session,
471
+ dependencies=self._deps,
472
+ drop_input_cols=self._drop_input_cols,
473
+ expected_output_cols_type=expected_dtype,
491
474
  )
492
475
 
493
476
  elif isinstance(dataset, pd.DataFrame):
494
- transform_kwargs = dict(
495
- snowpark_input_cols = self._snowpark_cols,
496
- drop_input_cols = self._drop_input_cols
497
- )
477
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
498
478
 
499
479
  transform_handlers = ModelTransformerBuilder.build(
500
480
  dataset=dataset,
@@ -513,7 +493,11 @@ class Nystroem(BaseTransformer):
513
493
  return output_df
514
494
 
515
495
  @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
516
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
496
+ def fit_predict(
497
+ self,
498
+ dataset: Union[DataFrame, pd.DataFrame],
499
+ output_cols_prefix: str = "fit_predict_",
500
+ ) -> Union[DataFrame, pd.DataFrame]:
517
501
  """ Method not supported for this class.
518
502
 
519
503
 
@@ -538,22 +522,106 @@ class Nystroem(BaseTransformer):
538
522
  )
539
523
  output_result, fitted_estimator = model_trainer.train_fit_predict(
540
524
  drop_input_cols=self._drop_input_cols,
541
- expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
525
+ expected_output_cols_list=(
526
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
527
+ ),
542
528
  )
543
529
  self._sklearn_object = fitted_estimator
544
530
  self._is_fitted = True
545
531
  return output_result
546
532
 
533
+
534
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
535
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
536
+ """ Fit to data, then transform it
537
+ For more details on this function, see [sklearn.kernel_approximation.Nystroem.fit_transform]
538
+ (https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.Nystroem.html#sklearn.kernel_approximation.Nystroem.fit_transform)
539
+
540
+
541
+ Raises:
542
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
547
543
 
548
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
549
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
550
- """
544
+ Args:
545
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
546
+ Snowpark or Pandas DataFrame.
547
+ output_cols_prefix: Prefix for the response columns
551
548
  Returns:
552
549
  Transformed dataset.
553
550
  """
554
- self.fit(dataset)
555
- assert self._sklearn_object is not None
556
- return self._sklearn_object.embedding_
551
+ self._infer_input_output_cols(dataset)
552
+ super()._check_dataset_type(dataset)
553
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
554
+ estimator=self._sklearn_object,
555
+ dataset=dataset,
556
+ input_cols=self.input_cols,
557
+ label_cols=self.label_cols,
558
+ sample_weight_col=self.sample_weight_col,
559
+ autogenerated=self._autogenerated,
560
+ subproject=_SUBPROJECT,
561
+ )
562
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
563
+ drop_input_cols=self._drop_input_cols,
564
+ expected_output_cols_list=self.output_cols,
565
+ )
566
+ self._sklearn_object = fitted_estimator
567
+ self._is_fitted = True
568
+ return output_result
569
+
570
+
571
+ def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
572
+ """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
573
+ Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
574
+ """
575
+ output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
576
+ # The following condition is introduced for kneighbors methods, and not used in other methods
577
+ if output_cols:
578
+ output_cols = [
579
+ identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
580
+ for c in output_cols
581
+ ]
582
+ elif getattr(self._sklearn_object, "classes_", None) is None:
583
+ output_cols = [output_cols_prefix]
584
+ elif self._sklearn_object is not None:
585
+ classes = self._sklearn_object.classes_
586
+ if isinstance(classes, numpy.ndarray):
587
+ output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
588
+ elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
589
+ # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
590
+ output_cols = []
591
+ for i, cl in enumerate(classes):
592
+ # For binary classification, there is only one output column for each class
593
+ # ndarray as the two classes are complementary.
594
+ if len(cl) == 2:
595
+ output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
596
+ else:
597
+ output_cols.extend([
598
+ f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
599
+ ])
600
+ else:
601
+ output_cols = []
602
+
603
+ # Make sure column names are valid snowflake identifiers.
604
+ assert output_cols is not None # Make MyPy happy
605
+ rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
606
+
607
+ return rv
608
+
609
+ def _align_expected_output_names(
610
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
611
+ ) -> List[str]:
612
+ # in case the inferred output column names dimension is different
613
+ # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
614
+ output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
615
+ output_df_columns = list(output_df_pd.columns)
616
+ output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
617
+ if self.sample_weight_col:
618
+ output_df_columns_set -= set(self.sample_weight_col)
619
+ # if the dimension of inferred output column names is correct; use it
620
+ if len(expected_output_cols_list) == len(output_df_columns_set):
621
+ return expected_output_cols_list
622
+ # otherwise, use the sklearn estimator's output
623
+ else:
624
+ return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
557
625
 
558
626
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
559
627
  @telemetry.send_api_usage_telemetry(
@@ -585,24 +653,26 @@ class Nystroem(BaseTransformer):
585
653
  # are specific to the type of dataset used.
586
654
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
587
655
 
656
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
657
+
588
658
  if isinstance(dataset, DataFrame):
589
- self._deps = self._batch_inference_validate_snowpark(
590
- dataset=dataset,
591
- inference_method=inference_method,
592
- )
593
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
659
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
660
+ self._deps = self._get_dependencies()
661
+ assert isinstance(
662
+ dataset._session, Session
663
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
594
664
  transform_kwargs = dict(
595
665
  session=dataset._session,
596
666
  dependencies=self._deps,
597
- drop_input_cols = self._drop_input_cols,
667
+ drop_input_cols=self._drop_input_cols,
598
668
  expected_output_cols_type="float",
599
669
  )
670
+ expected_output_cols = self._align_expected_output_names(
671
+ inference_method, dataset, expected_output_cols, output_cols_prefix
672
+ )
600
673
 
601
674
  elif isinstance(dataset, pd.DataFrame):
602
- transform_kwargs = dict(
603
- snowpark_input_cols = self._snowpark_cols,
604
- drop_input_cols = self._drop_input_cols
605
- )
675
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
606
676
 
607
677
  transform_handlers = ModelTransformerBuilder.build(
608
678
  dataset=dataset,
@@ -614,7 +684,7 @@ class Nystroem(BaseTransformer):
614
684
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
615
685
  inference_method=inference_method,
616
686
  input_cols=self.input_cols,
617
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
687
+ expected_output_cols=expected_output_cols,
618
688
  **transform_kwargs
619
689
  )
620
690
  return output_df
@@ -644,29 +714,30 @@ class Nystroem(BaseTransformer):
644
714
  Output dataset with log probability of the sample for each class in the model.
645
715
  """
646
716
  super()._check_dataset_type(dataset)
647
- inference_method="predict_log_proba"
717
+ inference_method = "predict_log_proba"
718
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
648
719
 
649
720
  # This dictionary contains optional kwargs for batch inference. These kwargs
650
721
  # are specific to the type of dataset used.
651
722
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
652
723
 
653
724
  if isinstance(dataset, DataFrame):
654
- self._deps = self._batch_inference_validate_snowpark(
655
- dataset=dataset,
656
- inference_method=inference_method,
657
- )
658
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
725
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
726
+ self._deps = self._get_dependencies()
727
+ assert isinstance(
728
+ dataset._session, Session
729
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
659
730
  transform_kwargs = dict(
660
731
  session=dataset._session,
661
732
  dependencies=self._deps,
662
- drop_input_cols = self._drop_input_cols,
733
+ drop_input_cols=self._drop_input_cols,
663
734
  expected_output_cols_type="float",
664
735
  )
736
+ expected_output_cols = self._align_expected_output_names(
737
+ inference_method, dataset, expected_output_cols, output_cols_prefix
738
+ )
665
739
  elif isinstance(dataset, pd.DataFrame):
666
- transform_kwargs = dict(
667
- snowpark_input_cols = self._snowpark_cols,
668
- drop_input_cols = self._drop_input_cols
669
- )
740
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
670
741
 
671
742
  transform_handlers = ModelTransformerBuilder.build(
672
743
  dataset=dataset,
@@ -679,7 +750,7 @@ class Nystroem(BaseTransformer):
679
750
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
680
751
  inference_method=inference_method,
681
752
  input_cols=self.input_cols,
682
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
753
+ expected_output_cols=expected_output_cols,
683
754
  **transform_kwargs
684
755
  )
685
756
  return output_df
@@ -705,30 +776,32 @@ class Nystroem(BaseTransformer):
705
776
  Output dataset with results of the decision function for the samples in input dataset.
706
777
  """
707
778
  super()._check_dataset_type(dataset)
708
- inference_method="decision_function"
779
+ inference_method = "decision_function"
709
780
 
710
781
  # This dictionary contains optional kwargs for batch inference. These kwargs
711
782
  # are specific to the type of dataset used.
712
783
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
713
784
 
785
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
786
+
714
787
  if isinstance(dataset, DataFrame):
715
- self._deps = self._batch_inference_validate_snowpark(
716
- dataset=dataset,
717
- inference_method=inference_method,
718
- )
719
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
788
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
789
+ self._deps = self._get_dependencies()
790
+ assert isinstance(
791
+ dataset._session, Session
792
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
720
793
  transform_kwargs = dict(
721
794
  session=dataset._session,
722
795
  dependencies=self._deps,
723
- drop_input_cols = self._drop_input_cols,
796
+ drop_input_cols=self._drop_input_cols,
724
797
  expected_output_cols_type="float",
725
798
  )
799
+ expected_output_cols = self._align_expected_output_names(
800
+ inference_method, dataset, expected_output_cols, output_cols_prefix
801
+ )
726
802
 
727
803
  elif isinstance(dataset, pd.DataFrame):
728
- transform_kwargs = dict(
729
- snowpark_input_cols = self._snowpark_cols,
730
- drop_input_cols = self._drop_input_cols
731
- )
804
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
732
805
 
733
806
  transform_handlers = ModelTransformerBuilder.build(
734
807
  dataset=dataset,
@@ -741,7 +814,7 @@ class Nystroem(BaseTransformer):
741
814
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
742
815
  inference_method=inference_method,
743
816
  input_cols=self.input_cols,
744
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
817
+ expected_output_cols=expected_output_cols,
745
818
  **transform_kwargs
746
819
  )
747
820
  return output_df
@@ -770,17 +843,17 @@ class Nystroem(BaseTransformer):
770
843
  Output dataset with probability of the sample for each class in the model.
771
844
  """
772
845
  super()._check_dataset_type(dataset)
773
- inference_method="score_samples"
846
+ inference_method = "score_samples"
774
847
 
775
848
  # This dictionary contains optional kwargs for batch inference. These kwargs
776
849
  # are specific to the type of dataset used.
777
850
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
778
851
 
852
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
853
+
779
854
  if isinstance(dataset, DataFrame):
780
- self._deps = self._batch_inference_validate_snowpark(
781
- dataset=dataset,
782
- inference_method=inference_method,
783
- )
855
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
856
+ self._deps = self._get_dependencies()
784
857
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
785
858
  transform_kwargs = dict(
786
859
  session=dataset._session,
@@ -788,6 +861,9 @@ class Nystroem(BaseTransformer):
788
861
  drop_input_cols = self._drop_input_cols,
789
862
  expected_output_cols_type="float",
790
863
  )
864
+ expected_output_cols = self._align_expected_output_names(
865
+ inference_method, dataset, expected_output_cols, output_cols_prefix
866
+ )
791
867
 
792
868
  elif isinstance(dataset, pd.DataFrame):
793
869
  transform_kwargs = dict(
@@ -806,7 +882,7 @@ class Nystroem(BaseTransformer):
806
882
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
807
883
  inference_method=inference_method,
808
884
  input_cols=self.input_cols,
809
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
885
+ expected_output_cols=expected_output_cols,
810
886
  **transform_kwargs
811
887
  )
812
888
  return output_df
@@ -839,17 +915,15 @@ class Nystroem(BaseTransformer):
839
915
  transform_kwargs: ScoreKwargsTypedDict = dict()
840
916
 
841
917
  if isinstance(dataset, DataFrame):
842
- self._deps = self._batch_inference_validate_snowpark(
843
- dataset=dataset,
844
- inference_method="score",
845
- )
918
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
919
+ self._deps = self._get_dependencies()
846
920
  selected_cols = self._get_active_columns()
847
921
  if len(selected_cols) > 0:
848
922
  dataset = dataset.select(selected_cols)
849
923
  assert isinstance(dataset._session, Session) # keep mypy happy
850
924
  transform_kwargs = dict(
851
925
  session=dataset._session,
852
- dependencies=["snowflake-snowpark-python"] + self._deps,
926
+ dependencies=self._deps,
853
927
  score_sproc_imports=['sklearn'],
854
928
  )
855
929
  elif isinstance(dataset, pd.DataFrame):
@@ -914,11 +988,8 @@ class Nystroem(BaseTransformer):
914
988
 
915
989
  if isinstance(dataset, DataFrame):
916
990
 
917
- self._deps = self._batch_inference_validate_snowpark(
918
- dataset=dataset,
919
- inference_method=inference_method,
920
-
921
- )
991
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
992
+ self._deps = self._get_dependencies()
922
993
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
923
994
  transform_kwargs = dict(
924
995
  session = dataset._session,
@@ -951,50 +1022,84 @@ class Nystroem(BaseTransformer):
951
1022
  )
952
1023
  return output_df
953
1024
 
1025
+
1026
+
1027
+ def to_sklearn(self) -> Any:
1028
+ """Get sklearn.kernel_approximation.Nystroem object.
1029
+ """
1030
+ if self._sklearn_object is None:
1031
+ self._sklearn_object = self._create_sklearn_object()
1032
+ return self._sklearn_object
1033
+
1034
+ def to_xgboost(self) -> Any:
1035
+ raise exceptions.SnowflakeMLException(
1036
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1037
+ original_exception=AttributeError(
1038
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1039
+ "to_xgboost()",
1040
+ "to_sklearn()"
1041
+ )
1042
+ ),
1043
+ )
954
1044
 
955
- def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
1045
+ def to_lightgbm(self) -> Any:
1046
+ raise exceptions.SnowflakeMLException(
1047
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1048
+ original_exception=AttributeError(
1049
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1050
+ "to_lightgbm()",
1051
+ "to_sklearn()"
1052
+ )
1053
+ ),
1054
+ )
1055
+
1056
+ def _get_dependencies(self) -> List[str]:
1057
+ return self._deps
1058
+
1059
+
1060
+ def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
956
1061
  self._model_signature_dict = dict()
957
1062
 
958
1063
  PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
959
1064
 
960
- inputs = list(_infer_signature(dataset[self.input_cols], "input"))
1065
+ inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
961
1066
  outputs: List[BaseFeatureSpec] = []
962
1067
  if hasattr(self, "predict"):
963
1068
  # keep mypy happy
964
- assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
1069
+ assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
965
1070
  # For classifier, the type of predict is the same as the type of label
966
- if self._sklearn_object._estimator_type == 'classifier':
967
- # label columns is the desired type for output
1071
+ if self._sklearn_object._estimator_type == "classifier":
1072
+ # label columns is the desired type for output
968
1073
  outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
969
1074
  # rename the output columns
970
1075
  outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
971
- self._model_signature_dict["predict"] = ModelSignature(inputs,
972
- ([] if self._drop_input_cols else inputs)
973
- + outputs)
1076
+ self._model_signature_dict["predict"] = ModelSignature(
1077
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1078
+ )
974
1079
  # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
975
1080
  # For outlier models, returns -1 for outliers and 1 for inliers.
976
- # Clusterer returns int64 cluster labels.
1081
+ # Clusterer returns int64 cluster labels.
977
1082
  elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
978
1083
  outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
979
- self._model_signature_dict["predict"] = ModelSignature(inputs,
980
- ([] if self._drop_input_cols else inputs)
981
- + outputs)
982
-
1084
+ self._model_signature_dict["predict"] = ModelSignature(
1085
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1086
+ )
1087
+
983
1088
  # For regressor, the type of predict is float64
984
- elif self._sklearn_object._estimator_type == 'regressor':
1089
+ elif self._sklearn_object._estimator_type == "regressor":
985
1090
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
986
- self._model_signature_dict["predict"] = ModelSignature(inputs,
987
- ([] if self._drop_input_cols else inputs)
988
- + outputs)
989
-
1091
+ self._model_signature_dict["predict"] = ModelSignature(
1092
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1093
+ )
1094
+
990
1095
  for prob_func in PROB_FUNCTIONS:
991
1096
  if hasattr(self, prob_func):
992
1097
  output_cols_prefix: str = f"{prob_func}_"
993
1098
  output_column_names = self._get_output_column_names(output_cols_prefix)
994
1099
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
995
- self._model_signature_dict[prob_func] = ModelSignature(inputs,
996
- ([] if self._drop_input_cols else inputs)
997
- + outputs)
1100
+ self._model_signature_dict[prob_func] = ModelSignature(
1101
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1102
+ )
998
1103
 
999
1104
  # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
1000
1105
  items = list(self._model_signature_dict.items())
@@ -1007,10 +1112,10 @@ class Nystroem(BaseTransformer):
1007
1112
  """Returns model signature of current class.
1008
1113
 
1009
1114
  Raises:
1010
- exceptions.SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1115
+ SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1011
1116
 
1012
1117
  Returns:
1013
- Dict[str, ModelSignature]: each method and its input output signature
1118
+ Dict with each method and its input output signature
1014
1119
  """
1015
1120
  if self._model_signature_dict is None:
1016
1121
  raise exceptions.SnowflakeMLException(
@@ -1018,35 +1123,3 @@ class Nystroem(BaseTransformer):
1018
1123
  original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
1019
1124
  )
1020
1125
  return self._model_signature_dict
1021
-
1022
- def to_sklearn(self) -> Any:
1023
- """Get sklearn.kernel_approximation.Nystroem object.
1024
- """
1025
- if self._sklearn_object is None:
1026
- self._sklearn_object = self._create_sklearn_object()
1027
- return self._sklearn_object
1028
-
1029
- def to_xgboost(self) -> Any:
1030
- raise exceptions.SnowflakeMLException(
1031
- error_code=error_codes.METHOD_NOT_ALLOWED,
1032
- original_exception=AttributeError(
1033
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1034
- "to_xgboost()",
1035
- "to_sklearn()"
1036
- )
1037
- ),
1038
- )
1039
-
1040
- def to_lightgbm(self) -> Any:
1041
- raise exceptions.SnowflakeMLException(
1042
- error_code=error_codes.METHOD_NOT_ALLOWED,
1043
- original_exception=AttributeError(
1044
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1045
- "to_lightgbm()",
1046
- "to_sklearn()"
1047
- )
1048
- ),
1049
- )
1050
-
1051
- def _get_dependencies(self) -> List[str]:
1052
- return self._deps