snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (234)
  1. snowflake/ml/_internal/env_utils.py +77 -32
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/_internal/utils/identifier.py +3 -1
  8. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  9. snowflake/ml/dataset/__init__.py +10 -0
  10. snowflake/ml/dataset/dataset.py +454 -129
  11. snowflake/ml/dataset/dataset_factory.py +53 -0
  12. snowflake/ml/dataset/dataset_metadata.py +103 -0
  13. snowflake/ml/dataset/dataset_reader.py +202 -0
  14. snowflake/ml/feature_store/feature_store.py +531 -332
  15. snowflake/ml/feature_store/feature_view.py +40 -23
  16. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  17. snowflake/ml/fileset/sfcfs.py +56 -54
  18. snowflake/ml/fileset/snowfs.py +159 -0
  19. snowflake/ml/fileset/stage_fs.py +49 -17
  20. snowflake/ml/model/__init__.py +2 -2
  21. snowflake/ml/model/_api.py +16 -1
  22. snowflake/ml/model/_client/model/model_impl.py +27 -0
  23. snowflake/ml/model/_client/model/model_version_impl.py +137 -50
  24. snowflake/ml/model/_client/ops/model_ops.py +159 -40
  25. snowflake/ml/model/_client/sql/model.py +25 -2
  26. snowflake/ml/model/_client/sql/model_version.py +131 -2
  27. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  28. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  29. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  30. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  31. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  32. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
  34. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  36. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  37. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  39. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  40. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
  42. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  43. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  44. snowflake/ml/model/_packager/model_packager.py +2 -5
  45. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  46. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  47. snowflake/ml/model/type_hints.py +21 -2
  48. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  49. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  50. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  51. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  52. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  53. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  54. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
  55. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
  56. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  57. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
  58. snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
  59. snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
  60. snowflake/ml/modeling/cluster/birch.py +248 -175
  61. snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
  62. snowflake/ml/modeling/cluster/dbscan.py +246 -175
  63. snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
  64. snowflake/ml/modeling/cluster/k_means.py +248 -175
  65. snowflake/ml/modeling/cluster/mean_shift.py +246 -175
  66. snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
  67. snowflake/ml/modeling/cluster/optics.py +246 -175
  68. snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
  69. snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
  70. snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
  71. snowflake/ml/modeling/compose/column_transformer.py +248 -175
  72. snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
  73. snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
  74. snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
  75. snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
  76. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
  77. snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
  78. snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
  79. snowflake/ml/modeling/covariance/oas.py +246 -175
  80. snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
  81. snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
  82. snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
  83. snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
  84. snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
  85. snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
  86. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
  87. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
  88. snowflake/ml/modeling/decomposition/pca.py +248 -175
  89. snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
  90. snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
  91. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
  92. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
  93. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
  94. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
  95. snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
  96. snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
  97. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
  98. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
  99. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
  100. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
  101. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
  102. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
  103. snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
  104. snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
  105. snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
  106. snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
  107. snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
  108. snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
  109. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
  110. snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
  111. snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
  112. snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
  113. snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
  114. snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
  115. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
  116. snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
  117. snowflake/ml/modeling/framework/_utils.py +8 -1
  118. snowflake/ml/modeling/framework/base.py +72 -37
  119. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
  120. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
  121. snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
  122. snowflake/ml/modeling/impute/knn_imputer.py +248 -175
  123. snowflake/ml/modeling/impute/missing_indicator.py +248 -175
  124. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
  125. snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
  126. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
  127. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
  128. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
  129. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
  130. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
  131. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
  132. snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
  133. snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
  134. snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
  135. snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
  136. snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
  137. snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
  138. snowflake/ml/modeling/linear_model/lars.py +246 -175
  139. snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
  140. snowflake/ml/modeling/linear_model/lasso.py +246 -175
  141. snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
  142. snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
  143. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
  144. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
  145. snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
  146. snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
  147. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
  148. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
  149. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
  150. snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
  151. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
  152. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
  153. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
  154. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
  155. snowflake/ml/modeling/linear_model/perceptron.py +246 -175
  156. snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
  157. snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
  158. snowflake/ml/modeling/linear_model/ridge.py +246 -175
  159. snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
  160. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
  161. snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
  162. snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
  163. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
  164. snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
  165. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
  166. snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
  167. snowflake/ml/modeling/manifold/isomap.py +248 -175
  168. snowflake/ml/modeling/manifold/mds.py +248 -175
  169. snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
  170. snowflake/ml/modeling/manifold/tsne.py +248 -175
  171. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
  172. snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
  173. snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
  174. snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
  175. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
  176. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
  177. snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
  178. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
  179. snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
  180. snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
  181. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
  182. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
  183. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
  184. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
  185. snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
  186. snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
  187. snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
  188. snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
  189. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
  190. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
  191. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
  192. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
  193. snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
  194. snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
  195. snowflake/ml/modeling/pipeline/pipeline.py +517 -35
  196. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  197. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  198. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  199. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  200. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  201. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  202. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
  203. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  204. snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
  205. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  206. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  207. snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
  208. snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
  209. snowflake/ml/modeling/svm/linear_svc.py +246 -175
  210. snowflake/ml/modeling/svm/linear_svr.py +246 -175
  211. snowflake/ml/modeling/svm/nu_svc.py +246 -175
  212. snowflake/ml/modeling/svm/nu_svr.py +246 -175
  213. snowflake/ml/modeling/svm/svc.py +246 -175
  214. snowflake/ml/modeling/svm/svr.py +246 -175
  215. snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
  216. snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
  217. snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
  218. snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
  219. snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
  220. snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
  221. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
  222. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
  223. snowflake/ml/registry/model_registry.py +3 -149
  224. snowflake/ml/registry/registry.py +1 -1
  225. snowflake/ml/version.py +1 -1
  226. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
  227. snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
  228. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  229. snowflake/ml/registry/_artifact_manager.py +0 -156
  230. snowflake/ml/registry/artifact.py +0 -46
  231. snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
  232. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  233. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  234. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
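
The hunks below show one representative changed file. Judging by the `VarianceThreshold` class and the `"sklearn.feature_selection"` subproject string, it is snowflake/ml/modeling/feature_selection/variance_threshold.py (entry 116 above, +248 -175); the same autogenerated pattern repeats across the other modeling classes in this release.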
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
+from snowflake.ml.model._signatures import utils as model_signature_utils
+from snowflake.ml.model.model_signature import (
+    BaseFeatureSpec,
+    DataType,
+    FeatureSpec,
+    ModelSignature,
+    _infer_signature,
+    _rename_signature_with_snowflake_identifiers,
+)
 
 from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
 
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     validate_sklearn_args,
 )
 
-from snowflake.ml.model.model_signature import (
-    DataType,
-    FeatureSpec,
-    ModelSignature,
-    _infer_signature,
-    _rename_signature_with_snowflake_identifiers,
-    BaseFeatureSpec,
-)
-from snowflake.ml.model._signatures import utils as model_signature_utils
-
 _PROJECT = "ModelDevelopment"
 # Derive subproject from module name by removing "sklearn"
 # and converting module name from underscore to CamelCase
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class VarianceThreshold(BaseTransformer):
     r"""Feature selector that removes all low-variance features
     For more details on this class, see [sklearn.feature_selection.VarianceThreshold]
@@ -196,12 +189,7 @@ class VarianceThreshold(BaseTransformer):
         )
         return selected_cols
 
-    @telemetry.send_api_usage_telemetry(
-        project=_PROJECT,
-        subproject=_SUBPROJECT,
-        custom_tags=dict([("autogen", True)]),
-    )
-    def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "VarianceThreshold":
+    def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "VarianceThreshold":
         """Learn empirical variances from X
         For more details on this function, see [sklearn.feature_selection.VarianceThreshold.fit]
         (https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.VarianceThreshold.html#sklearn.feature_selection.VarianceThreshold.fit)
@@ -228,12 +216,14 @@ class VarianceThreshold(BaseTransformer):
 
         self._snowpark_cols = dataset.select(self.input_cols).columns
 
-        # If we are already in a stored procedure, no need to kick off another one.
+        # If we are already in a stored procedure, no need to kick off another one.
         if SNOWML_SPROC_ENV in os.environ:
             statement_params = telemetry.get_function_usage_statement_params(
                 project=_PROJECT,
                 subproject=_SUBPROJECT,
-                function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), VarianceThreshold.__class__.__name__),
+                function_name=telemetry.get_statement_params_full_func_name(
+                    inspect.currentframe(), VarianceThreshold.__class__.__name__
+                ),
                 api_calls=[Session.call],
                 custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
             )
@@ -254,27 +244,24 @@ class VarianceThreshold(BaseTransformer):
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
-        self._get_model_signatures(dataset)
+        self._generate_model_signatures(dataset)
         return self
 
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -292,9 +279,7 @@ class VarianceThreshold(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -328,7 +313,9 @@ class VarianceThreshold(BaseTransformer):
             # when it is classifier, infer the datatype from label columns
             if expected_type_inferred == "" and 'predict' in self.model_signatures:
                 # Batch inference takes a single expected output column type. Use the first columns type for now.
-                label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
+                label_cols_signatures = [
+                    row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
+                ]
                 if len(label_cols_signatures) == 0:
                     error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
                     raise exceptions.SnowflakeMLException(
@@ -336,25 +323,23 @@ class VarianceThreshold(BaseTransformer):
                         original_exception=ValueError(error_str),
                     )
 
-                expected_type_inferred = convert_sp_to_sf_type(
-                    label_cols_signatures[0].as_snowpark_type()
-                )
+                expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
-            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_type_inferred,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_type_inferred,
             )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -396,7 +381,7 @@ class VarianceThreshold(BaseTransformer):
             Transformed dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="transform"
+        inference_method = "transform"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
@@ -426,24 +411,19 @@ class VarianceThreshold(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_dtype,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_dtype,
            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -462,7 +442,11 @@ class VarianceThreshold(BaseTransformer):
         return output_df
 
     @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
+    def fit_predict(
+        self,
+        dataset: Union[DataFrame, pd.DataFrame],
+        output_cols_prefix: str = "fit_predict_",
+    ) -> Union[DataFrame, pd.DataFrame]:
         """ Method not supported for this class.
 
 
@@ -487,22 +471,106 @@ class VarianceThreshold(BaseTransformer):
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
             drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
+            expected_output_cols_list=(
+                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
+            ),
         )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit to data, then transform it
+        For more details on this function, see [sklearn.feature_selection.VarianceThreshold.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.VarianceThreshold.html#sklearn.feature_selection.VarianceThreshold.fit_transform)
+
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
 
-    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
+
+
+    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
+        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
+        """
+        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
+        # The following condition is introduced for kneighbors methods, and not used in other methods
+        if output_cols:
+            output_cols = [
+                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
+                for c in output_cols
+            ]
+        elif getattr(self._sklearn_object, "classes_", None) is None:
+            output_cols = [output_cols_prefix]
+        elif self._sklearn_object is not None:
+            classes = self._sklearn_object.classes_
+            if isinstance(classes, numpy.ndarray):
+                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
+            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
+                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
+                output_cols = []
+                for i, cl in enumerate(classes):
+                    # For binary classification, there is only one output column for each class
+                    # ndarray as the two classes are complementary.
+                    if len(cl) == 2:
+                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
+                    else:
+                        output_cols.extend([
+                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
+                        ])
        else:
+            output_cols = []
+
+        # Make sure column names are valid snowflake identifiers.
+        assert output_cols is not None  # Make MyPy happy
+        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
+
+        return rv
+
+    def _align_expected_output_names(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
+    ) -> List[str]:
+        # in case the inferred output column names dimension is different
+        # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
+        output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
+        output_df_columns = list(output_df_pd.columns)
+        output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
+        if self.sample_weight_col:
+            output_df_columns_set -= set(self.sample_weight_col)
+        # if the dimension of inferred output column names is correct; use it
+        if len(expected_output_cols_list) == len(output_df_columns_set):
+            return expected_output_cols_list
+        # otherwise, use the sklearn estimator's output
+        else:
+            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
 
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -534,24 +602,26 @@ class VarianceThreshold(BaseTransformer):
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
            )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -563,7 +633,7 @@ class VarianceThreshold(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -593,29 +663,30 @@ class VarianceThreshold(BaseTransformer):
             Output dataset with log probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="predict_log_proba"
+        inference_method = "predict_log_proba"
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
            )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -628,7 +699,7 @@ class VarianceThreshold(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
        )
         return output_df
@@ -654,30 +725,32 @@ class VarianceThreshold(BaseTransformer):
             Output dataset with results of the decision function for the samples in input dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="decision_function"
+        inference_method = "decision_function"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
            )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -690,7 +763,7 @@ class VarianceThreshold(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
        )
         return output_df
@@ -719,17 +792,17 @@ class VarianceThreshold(BaseTransformer):
             Output dataset with probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="score_samples"
+        inference_method = "score_samples"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -737,6 +810,9 @@ class VarianceThreshold(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
            )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
             transform_kwargs = dict(
@@ -755,7 +831,7 @@ class VarianceThreshold(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
        )
         return output_df
@@ -788,17 +864,15 @@ class VarianceThreshold(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
            )
         elif isinstance(dataset, pd.DataFrame):
@@ -863,11 +937,8 @@ class VarianceThreshold(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
@@ -900,50 +971,84 @@ class VarianceThreshold(BaseTransformer):
        )
         return output_df
 
+
+
+    def to_sklearn(self) -> Any:
+        """Get sklearn.feature_selection.VarianceThreshold object.
+        """
+        if self._sklearn_object is None:
+            self._sklearn_object = self._create_sklearn_object()
+        return self._sklearn_object
+
+    def to_xgboost(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_xgboost()",
+                    "to_sklearn()"
+                )
+            ),
+        )
 
-    def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
+    def to_lightgbm(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_lightgbm()",
+                    "to_sklearn()"
+                )
+            ),
+        )
+
+    def _get_dependencies(self) -> List[str]:
+        return self._deps
+
+
+    def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         self._model_signature_dict = dict()
 
         PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
 
-        inputs = list(_infer_signature(dataset[self.input_cols], "input"))
+        inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
         outputs: List[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
-            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
+            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
             # For classifier, the type of predict is the same as the type of label
-            if self._sklearn_object._estimator_type == 'classifier':
-                # label columns is the desired type for output
+            if self._sklearn_object._estimator_type == "classifier":
+                # label columns is the desired type for output
                 outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
                 # rename the output columns
                 outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
             # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
             # For outlier models, returns -1 for outliers and 1 for inliers.
-            # Clusterer returns int64 cluster labels.
+            # Clusterer returns int64 cluster labels.
             elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
                 outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
             # For regressor, the type of predict is float64
-            elif self._sklearn_object._estimator_type == 'regressor':
+            elif self._sklearn_object._estimator_type == "regressor":
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
+                self._model_signature_dict[prob_func] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
 
         # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
         items = list(self._model_signature_dict.items())
@@ -956,10 +1061,10 @@ class VarianceThreshold(BaseTransformer):
         """Returns model signature of current class.
 
         Raises:
-            exceptions.SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
+            SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
 
         Returns:
-            Dict[str, ModelSignature]: each method and its input output signature
+            Dict with each method and its input output signature
         """
         if self._model_signature_dict is None:
             raise exceptions.SnowflakeMLException(
@@ -967,35 +1072,3 @@ class VarianceThreshold(BaseTransformer):
                 original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
            )
         return self._model_signature_dict
-
-    def to_sklearn(self) -> Any:
-        """Get sklearn.feature_selection.VarianceThreshold object.
-        """
-        if self._sklearn_object is None:
-            self._sklearn_object = self._create_sklearn_object()
-        return self._sklearn_object
-
-    def to_xgboost(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_xgboost()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def to_lightgbm(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_lightgbm()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def _get_dependencies(self) -> List[str]:
-        return self._deps
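
Taken together, these hunks illustrate the 1.5.0 API shift that repeats across the autogenerated estimators: `fit` moves to `_fit` (the telemetry decorator now lives in the base class), dependency validation splits into `_batch_inference_validate_snowpark()` plus `_get_dependencies()`, and `fit_transform` becomes a real trainer-backed method that returns a DataFrame instead of the previously disabled numpy-array path. A minimal usage sketch, assuming a pandas DataFrame and illustrative column names (FEATURE_1/FEATURE_2 are not taken from the diff):

import pandas as pd
from snowflake.ml.modeling.feature_selection import VarianceThreshold

# Two illustrative feature columns; FEATURE_1 has zero variance.
df = pd.DataFrame({"FEATURE_1": [1.0, 1.0, 1.0], "FEATURE_2": [1.0, 2.0, 3.0]})

selector = VarianceThreshold(input_cols=["FEATURE_1", "FEATURE_2"])
# As of 1.5.0, fit_transform routes through ModelTrainerBuilder.build_fit_transform
# and returns a Snowpark or pandas DataFrame rather than a numpy array.
transformed = selector.fit_transform(df)
print(transformed)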