snowflake-ml-python 1.5.2__py3-none-any.whl → 1.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250)
  1. snowflake/cortex/__init__.py +2 -1
  2. snowflake/cortex/_complete.py +240 -16
  3. snowflake/cortex/_extract_answer.py +0 -1
  4. snowflake/cortex/_sentiment.py +0 -1
  5. snowflake/cortex/_sse_client.py +81 -0
  6. snowflake/cortex/_summarize.py +0 -1
  7. snowflake/cortex/_translate.py +0 -1
  8. snowflake/cortex/_util.py +34 -10
  9. snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
  10. snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
  11. snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
  12. snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
  13. snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
  14. snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
  15. snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
  16. snowflake/ml/_internal/telemetry.py +26 -0
  17. snowflake/ml/_internal/utils/identifier.py +14 -0
  18. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
  19. snowflake/ml/dataset/dataset.py +54 -32
  20. snowflake/ml/dataset/dataset_factory.py +3 -4
  21. snowflake/ml/feature_store/feature_store.py +440 -243
  22. snowflake/ml/feature_store/feature_view.py +61 -9
  23. snowflake/ml/fileset/embedded_stage_fs.py +25 -21
  24. snowflake/ml/fileset/fileset.py +2 -2
  25. snowflake/ml/fileset/snowfs.py +4 -15
  26. snowflake/ml/fileset/stage_fs.py +6 -8
  27. snowflake/ml/lineage/__init__.py +3 -0
  28. snowflake/ml/lineage/lineage_node.py +139 -0
  29. snowflake/ml/model/_client/model/model_impl.py +47 -14
  30. snowflake/ml/model/_client/model/model_version_impl.py +82 -2
  31. snowflake/ml/model/_client/ops/model_ops.py +77 -5
  32. snowflake/ml/model/_client/sql/model.py +1 -0
  33. snowflake/ml/model/_client/sql/model_version.py +47 -4
  34. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  35. snowflake/ml/model/_model_composer/model_composer.py +7 -6
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +7 -1
  37. snowflake/ml/model/_model_composer/model_method/function_generator.py +17 -1
  38. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +79 -0
  39. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -3
  40. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -5
  41. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  42. snowflake/ml/model/_packager/model_handlers/_utils.py +1 -0
  43. snowflake/ml/model/_packager/model_handlers/catboost.py +2 -2
  44. snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
  45. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
  46. snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -2
  47. snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
  48. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
  49. snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
  50. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
  51. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
  52. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
  53. snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  56. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  57. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
  58. snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
  59. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  60. snowflake/ml/model/_packager/model_packager.py +9 -4
  61. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  62. snowflake/ml/model/_signatures/builtins_handler.py +2 -1
  63. snowflake/ml/model/_signatures/core.py +13 -1
  64. snowflake/ml/model/_signatures/pandas_handler.py +2 -0
  65. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  66. snowflake/ml/model/custom_model.py +22 -2
  67. snowflake/ml/model/model_signature.py +2 -0
  68. snowflake/ml/model/type_hints.py +74 -4
  69. snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
  70. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +158 -121
  71. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +2 -0
  72. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +39 -18
  73. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +88 -134
  74. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +22 -17
  75. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  76. snowflake/ml/modeling/cluster/affinity_propagation.py +5 -3
  77. snowflake/ml/modeling/cluster/agglomerative_clustering.py +5 -3
  78. snowflake/ml/modeling/cluster/birch.py +5 -3
  79. snowflake/ml/modeling/cluster/bisecting_k_means.py +5 -3
  80. snowflake/ml/modeling/cluster/dbscan.py +5 -3
  81. snowflake/ml/modeling/cluster/feature_agglomeration.py +5 -3
  82. snowflake/ml/modeling/cluster/k_means.py +5 -3
  83. snowflake/ml/modeling/cluster/mean_shift.py +5 -3
  84. snowflake/ml/modeling/cluster/mini_batch_k_means.py +5 -3
  85. snowflake/ml/modeling/cluster/optics.py +5 -3
  86. snowflake/ml/modeling/cluster/spectral_biclustering.py +5 -3
  87. snowflake/ml/modeling/cluster/spectral_clustering.py +5 -3
  88. snowflake/ml/modeling/cluster/spectral_coclustering.py +5 -3
  89. snowflake/ml/modeling/compose/column_transformer.py +5 -3
  90. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  91. snowflake/ml/modeling/covariance/elliptic_envelope.py +5 -3
  92. snowflake/ml/modeling/covariance/empirical_covariance.py +5 -3
  93. snowflake/ml/modeling/covariance/graphical_lasso.py +5 -3
  94. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +5 -3
  95. snowflake/ml/modeling/covariance/ledoit_wolf.py +5 -3
  96. snowflake/ml/modeling/covariance/min_cov_det.py +5 -3
  97. snowflake/ml/modeling/covariance/oas.py +5 -3
  98. snowflake/ml/modeling/covariance/shrunk_covariance.py +5 -3
  99. snowflake/ml/modeling/decomposition/dictionary_learning.py +5 -3
  100. snowflake/ml/modeling/decomposition/factor_analysis.py +5 -3
  101. snowflake/ml/modeling/decomposition/fast_ica.py +5 -3
  102. snowflake/ml/modeling/decomposition/incremental_pca.py +5 -3
  103. snowflake/ml/modeling/decomposition/kernel_pca.py +5 -3
  104. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -3
  105. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -3
  106. snowflake/ml/modeling/decomposition/pca.py +5 -3
  107. snowflake/ml/modeling/decomposition/sparse_pca.py +5 -3
  108. snowflake/ml/modeling/decomposition/truncated_svd.py +5 -3
  109. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  110. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  111. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  112. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  113. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  114. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  115. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  116. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  117. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  118. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  119. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  120. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  121. snowflake/ml/modeling/ensemble/isolation_forest.py +5 -3
  122. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  123. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  124. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  125. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  126. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  127. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  128. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  129. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  130. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  131. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  132. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  133. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
  134. snowflake/ml/modeling/feature_selection/variance_threshold.py +5 -3
  135. snowflake/ml/modeling/framework/base.py +3 -8
  136. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  137. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  138. snowflake/ml/modeling/impute/iterative_imputer.py +5 -3
  139. snowflake/ml/modeling/impute/knn_imputer.py +5 -3
  140. snowflake/ml/modeling/impute/missing_indicator.py +5 -3
  141. snowflake/ml/modeling/impute/simple_imputer.py +8 -4
  142. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +5 -3
  143. snowflake/ml/modeling/kernel_approximation/nystroem.py +5 -3
  144. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +5 -3
  145. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +5 -3
  146. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +5 -3
  147. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  148. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  149. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  150. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  151. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  152. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  153. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  154. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  155. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  156. snowflake/ml/modeling/linear_model/lars.py +1 -1
  157. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  158. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  159. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  160. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  161. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  162. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  163. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  164. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  165. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  166. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  167. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  168. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  169. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  170. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  171. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  172. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  173. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  174. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  175. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  176. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  177. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  178. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  179. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  180. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  181. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -3
  182. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  183. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  184. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  185. snowflake/ml/modeling/manifold/isomap.py +5 -3
  186. snowflake/ml/modeling/manifold/mds.py +5 -3
  187. snowflake/ml/modeling/manifold/spectral_embedding.py +5 -3
  188. snowflake/ml/modeling/manifold/tsne.py +5 -3
  189. snowflake/ml/modeling/metrics/ranking.py +3 -0
  190. snowflake/ml/modeling/metrics/regression.py +3 -0
  191. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +5 -3
  192. snowflake/ml/modeling/mixture/gaussian_mixture.py +5 -3
  193. snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
  194. snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
  195. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  196. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  197. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  198. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  199. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  200. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  201. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  202. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  203. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  204. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  205. snowflake/ml/modeling/neighbors/kernel_density.py +5 -3
  206. snowflake/ml/modeling/neighbors/local_outlier_factor.py +5 -3
  207. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  208. snowflake/ml/modeling/neighbors/nearest_neighbors.py +5 -3
  209. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  210. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  211. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  212. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +5 -3
  213. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  214. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  215. snowflake/ml/modeling/pipeline/pipeline.py +6 -0
  216. snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
  217. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
  218. snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
  219. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
  220. snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
  221. snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
  222. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +53 -11
  223. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +44 -13
  224. snowflake/ml/modeling/preprocessing/polynomial_features.py +5 -3
  225. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
  226. snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
  227. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  228. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  229. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  230. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  231. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  232. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  233. snowflake/ml/modeling/svm/svc.py +1 -1
  234. snowflake/ml/modeling/svm/svr.py +1 -1
  235. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  236. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  237. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  238. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  239. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  240. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  241. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  242. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  243. snowflake/ml/registry/_manager/model_manager.py +16 -3
  244. snowflake/ml/version.py +1 -1
  245. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/METADATA +51 -7
  246. snowflake_ml_python-1.5.4.dist-info/RECORD +389 -0
  247. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/WHEEL +1 -1
  248. snowflake_ml_python-1.5.2.dist-info/RECORD +0 -384
  249. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/LICENSE.txt +0 -0
  250. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/preprocessing/max_abs_scaler.py CHANGED
@@ -28,11 +28,15 @@ class MaxAbsScaler(base.BaseTransformer):
 
     Args:
         input_cols: Optional[Union[str, List[str]]], default=None
-            The name(s) of one or more columns in a DataFrame containing a feature to be scaled.
+            The name(s) of one or more columns in the input DataFrame containing feature(s) to be scaled. Input
+            columns must be specified before fit with this argument or after initialization with the
+            `set_input_cols` method. This argument is optional for API consistency.
 
         output_cols: Optional[Union[str, List[str]]], default=None
-            The name(s) of one or more columns in a DataFrame in which results will be stored. The number of
-            columns specified must match the number of input columns.
+            The name(s) to assign output columns in the output DataFrame. The number of
+            columns specified must equal the number of input columns. Output columns must be specified before transform
+            with this argument or after initialization with the `set_output_cols` method. This argument is optional for
+            API consistency.
 
         passthrough_cols: Optional[Union[str, List[str]]], default=None
             A string or a list of strings indicating column names to be excluded from any
snowflake/ml/modeling/preprocessing/min_max_scaler.py CHANGED
@@ -29,12 +29,15 @@ class MinMaxScaler(base.BaseTransformer):
             Whether to clip transformed values of held-out data to the specified feature range (default is True).
 
         input_cols: Optional[Union[str, List[str]]], default=None
-            The name(s) of one or more columns in a DataFrame containing a feature to be scaled. Each specified
-            input column is scaled independently and stored in the corresponding output column.
+            The name(s) of one or more columns in the input DataFrame containing feature(s) to be scaled. Input
+            columns must be specified before fit with this argument or after initialization with the
+            `set_input_cols` method. This argument is optional for API consistency.
 
         output_cols: Optional[Union[str, List[str]]], default=None
-            The name(s) of one or more columns in a DataFrame in which results will be stored. The number of
-            columns specified must match the number of input columns.
+            The name(s) to assign output columns in the output DataFrame. The number of
+            columns specified must equal the number of input columns. Output columns must be specified before transform
+            with this argument or after initialization with the `set_output_cols` method. This argument is optional for
+            API consistency.
 
         passthrough_cols: Optional[Union[str, List[str]]], default=None
             A string or a list of strings indicating column names to be excluded from any
snowflake/ml/modeling/preprocessing/normalizer.py CHANGED
@@ -28,11 +28,15 @@ class Normalizer(base.BaseTransformer):
             values. It must be one of 'l1', 'l2', or 'max'.
 
         input_cols: Optional[Union[str, List[str]]]
-            Columns to use as inputs during transform.
+            The name(s) of one or more columns in the input DataFrame containing feature(s) to be normalized. Input
+            columns must be specified before transform with this argument or after initialization with the
+            `set_input_cols` method. This argument is optional for API consistency.
 
         output_cols: Optional[Union[str, List[str]]]
-            A string or list of strings representing column names that will store the output of transform operation.
-            The length of `output_cols` must equal the length of `input_cols`.
+            The name(s) to assign output columns in the output DataFrame. The number of
+            columns specified must equal the number of input columns. Output columns must be specified before transform
+            with this argument or after initialization with the `set_output_cols` method. This argument is optional for
+            API consistency.
 
         passthrough_cols: Optional[Union[str, List[str]]]
             A string or a list of strings indicating column names to be excluded from any
snowflake/ml/modeling/preprocessing/one_hot_encoder.py CHANGED
@@ -101,16 +101,20 @@ class OneHotEncoder(base.BaseTransformer):
     (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html).
 
     Args:
-        categories: 'auto' or dict {column_name: np.ndarray([category])}, default='auto'
+        categories: 'auto', list of array-like, or dict {column_name: np.ndarray([category])}, default='auto'
            Categories (unique values) per feature:
            - 'auto': Determine categories automatically from the training data.
+            - list: ``categories[i]`` holds the categories expected in the ith
+              column. The passed categories should not mix strings and numeric
+              values within a single feature, and should be sorted in case of
+              numeric values.
            - dict: ``categories[column_name]`` holds the categories expected in
              the column provided. The passed categories should not mix strings
              and numeric values within a single feature, and should be sorted in
              case of numeric values.
            The used categories can be found in the ``categories_`` attribute.
 
-        drop: {first’, if_binary} or an array-like of shape (n_features,), default=None
+        drop: {'first', 'if_binary'} or an array-like of shape (n_features,), default=None
            Specifies a methodology to use to drop one of the categories per
            feature. This is useful in situations where perfectly collinear
            features cause problems, such as when feeding the resulting data
@@ -157,10 +161,18 @@ class OneHotEncoder(base.BaseTransformer):
            there is no limit to the number of output features.
 
        input_cols: Optional[Union[str, List[str]]], default=None
-           Single or multiple input columns.
+           The name(s) of one or more columns in the input DataFrame containing feature(s) to be encoded. Input
+           columns must be specified before fit with this argument or after initialization with the
+           `set_input_cols` method. This argument is optional for API consistency.
 
        output_cols: Optional[Union[str, List[str]]], default=None
-           Single or multiple output columns.
+           The prefix to be used for encoded output for each input column. The number of
+           output column prefixes specified must match the number of input columns. Output column prefixes must be
+           specified before transform with this argument or after initialization with the `set_output_cols` method.
+
+           Note: Dense output column names are case-sensitive and resolve identifiers following Snowflake rules, e.g.
+           `"PREFIX_a"`, `PREFIX_A`, `"prefix_A"`. Therefore, there is no need to provide double-quoted column names
+           as that would result in invalid identifiers.
 
        passthrough_cols: Optional[Union[str, List[str]]]
            A string or a list of strings indicating column names to be excluded from any
@@ -198,7 +210,7 @@ class OneHotEncoder(base.BaseTransformer):
     def __init__(
         self,
         *,
-        categories: Union[str, Dict[str, type_utils.LiteralNDArrayType]] = "auto",
+        categories: Union[str, List[type_utils.LiteralNDArrayType], Dict[str, type_utils.LiteralNDArrayType]] = "auto",
         drop: Optional[Union[str, npt.ArrayLike]] = None,
         sparse: bool = False,
         handle_unknown: str = "error",
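Taken together, the docstring and signature changes above mean `categories` now accepts three forms. A minimal sketch of the new list form next to the pre-existing dict form; the column names and category values here are illustrative, not from this diff:

    import numpy as np

    from snowflake.ml.modeling.preprocessing import OneHotEncoder

    cats = np.array(["blue", "green", "red"])

    # Pre-1.5.4 style: a dict keyed by input column name.
    enc_dict = OneHotEncoder(categories={"COLOR": cats}, input_cols=["COLOR"], output_cols=["COLOR_OHE"])

    # New in this release: a list aligned positionally with `input_cols`,
    # mirroring scikit-learn's own `categories` argument.
    enc_list = OneHotEncoder(categories=[cats], input_cols=["COLOR"], output_cols=["COLOR_OHE"])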
@@ -432,8 +444,19 @@ class OneHotEncoder(base.BaseTransformer):
         assert found_state_df is not None
         if self.categories != "auto":
             state_data = []
-            assert isinstance(self.categories, dict)
-            for input_col, cats in self.categories.items():
+            if isinstance(self.categories, list):
+                categories_map = {col_name: cats for col_name, cats in zip(self.input_cols, self.categories)}
+            elif isinstance(self.categories, dict):
+                categories_map = self.categories
+            else:
+                raise exceptions.SnowflakeMLException(
+                    error_code=error_codes.INVALID_ARGUMENT,
+                    original_exception=ValueError(
+                        f"Invalid type {type(self.categories)} provided for argument `categories`"
+                    ),
+                )
+
+            for input_col, cats in categories_map.items():
                 for cat in cats.tolist():
                     state_data.append([input_col, cat])
             # states of given categories
@@ -557,6 +580,8 @@ class OneHotEncoder(base.BaseTransformer):
                 else:
                     categories[k] = vectorized_func(v)
             self.categories_ = categories
+        elif isinstance(self.categories, list):
+            self.categories_ = {col_name: cats for col_name, cats in zip(self.input_cols, self.categories)}
         else:
             self.categories_ = self.categories
 
@@ -842,8 +867,15 @@ class OneHotEncoder(base.BaseTransformer):
         # In case of fitting with pandas dataframe and transforming with snowpark dataframe
         # state_pandas cannot recognize the datatype of _CATEGORY and _FITTED_CATEGORY column
         # Therefore, apply the convert_to_string_excluding_nan function to _CATEGORY and _FITTED_CATEGORY
-        state_pandas[[_CATEGORY]] = state_pandas[[_CATEGORY]].applymap(convert_to_string_excluding_nan)
-        state_pandas[[_FITTED_CATEGORY]] = state_pandas[[_FITTED_CATEGORY]].applymap(convert_to_string_excluding_nan)
+        # applymap is deprecated since pandas 2.1.0, replaced by map
+        if pd.__version__ < "2.1.0":
+            state_pandas[[_CATEGORY]] = state_pandas[[_CATEGORY]].applymap(convert_to_string_excluding_nan)
+            state_pandas[[_FITTED_CATEGORY]] = state_pandas[[_FITTED_CATEGORY]].applymap(
+                convert_to_string_excluding_nan
+            )
+        else:
+            state_pandas[[_CATEGORY]] = state_pandas[[_CATEGORY]].map(convert_to_string_excluding_nan)
+            state_pandas[[_FITTED_CATEGORY]] = state_pandas[[_FITTED_CATEGORY]].map(convert_to_string_excluding_nan)
         state_df = dataset._session.create_dataframe(state_pandas)
 
         transformed_dataset = dataset
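The version gate above, shown standalone; `convert` is a hypothetical stand-in for `convert_to_string_excluding_nan`. Note that the released code compares version strings lexicographically, which happens to work for "2.1.0" but would misorder e.g. "2.10.0"; `packaging.version.parse` is the robust alternative.

    import pandas as pd

    def convert(v):
        # Hypothetical stand-in: stringify everything except NaN/None.
        return v if pd.isna(v) else str(v)

    df = pd.DataFrame({"_CATEGORY": ["a", 1, None]})
    # pandas renamed DataFrame.applymap to DataFrame.map in 2.1.0.
    if pd.__version__ < "2.1.0":
        df[["_CATEGORY"]] = df[["_CATEGORY"]].applymap(convert)
    else:
        df[["_CATEGORY"]] = df[["_CATEGORY"]].map(convert)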
@@ -1001,7 +1033,7 @@ class OneHotEncoder(base.BaseTransformer):
                 error_code=error_codes.INVALID_ATTRIBUTE,
                 original_exception=ValueError(f"Unsupported `categories` value: {self.categories}."),
             )
-        elif isinstance(self.categories, dict):
+        elif isinstance(self.categories, (dict, list)):
             if len(self.categories) != len(self.input_cols):
                 raise exceptions.SnowflakeMLException(
                     error_code=error_codes.INVALID_ATTRIBUTE,
@@ -1010,7 +1042,7 @@ class OneHotEncoder(base.BaseTransformer):
                         f"({len(self.input_cols)})."
                     ),
                 )
-            elif set(self.categories.keys()) != set(self.input_cols):
+            elif isinstance(self.categories, dict) and set(self.categories.keys()) != set(self.input_cols):
                 raise exceptions.SnowflakeMLException(
                     error_code=error_codes.INVALID_ATTRIBUTE,
                     original_exception=ValueError(
@@ -1529,6 +1561,16 @@ class OneHotEncoder(base.BaseTransformer):
         default_sklearn_args = _utils.get_default_args(default_sklearn_obj.__class__.__init__)
         given_args = self.get_params()
 
+        if "categories" in given_args and isinstance(given_args["categories"], dict):
+            # sklearn requires a list of array-like to satisfy the `categories` arg
+            try:
+                given_args["categories"] = [given_args["categories"][input_col] for input_col in self.input_cols]
+            except KeyError as e:
+                raise exceptions.SnowflakeMLException(
+                    error_code=error_codes.INVALID_ARGUMENT,
+                    original_exception=e,
+                )
+
         # replace 'sparse' with 'sparse_output' when scikit-learn>=1.2
         sklearn_version = sklearn.__version__
         if version.parse(sklearn_version) >= version.parse(_SKLEARN_DEPRECATED_KEYWORD_TO_VERSION_DICT["sparse"]):
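The conversion this last hunk adds, reduced to its core: scikit-learn's encoders accept `categories` only as a list ordered like the columns, so a dict keyed by column name is reordered through `input_cols`. Column names and values here are illustrative:

    import numpy as np

    input_cols = ["SIZE", "COLOR"]
    categories = {
        "COLOR": np.array(["blue", "green", "red"]),
        "SIZE": np.array(["S", "M", "L"]),
    }
    # Reorder dict values positionally; a missing key raises KeyError,
    # which the wrapped code surfaces as a SnowflakeMLException.
    sklearn_categories = [categories[col] for col in input_cols]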
snowflake/ml/modeling/preprocessing/ordinal_encoder.py CHANGED
@@ -45,9 +45,11 @@ class OrdinalEncoder(base.BaseTransformer):
     (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OrdinalEncoder.html).
 
     Args:
-        categories: Union[str, Dict[str, type_utils.LiteralNDArrayType]], default="auto"
+        categories: Union[str, List[type_utils.LiteralNDArrayType], Dict[str, type_utils.LiteralNDArrayType]],
+            default="auto"
            The string 'auto' (the default) causes the categories to be extracted from the input columns.
-           To specify the categories yourself, pass a dictionary mapping the column name to an ndarray containing the
+           To specify the categories yourself, pass either (1) a list of ndarrays containing the categories or
+           (2) a dictionary mapping the column name to an ndarray containing the
            categories.
 
         handle_unknown: str, default="error"
@@ -67,11 +69,14 @@ class OrdinalEncoder(base.BaseTransformer):
            The value to be used to encode unknown categories.
 
         input_cols: Optional[Union[str, List[str]]], default=None
-            The name(s) of one or more columns in a DataFrame containing a feature to be encoded.
+            The name(s) of one or more columns in the input DataFrame containing feature(s) to be encoded. Input
+            columns must be specified before fit with this argument or after initialization with the
+            `set_input_cols` method. This argument is optional for API consistency.
 
         output_cols: Optional[Union[str, List[str]]], default=None
-            The name(s) of one or more columns in a DataFrame in which results will be stored. The number of
-            columns specified must match the number of input columns.
+            The prefix to be used for encoded output for each input column. The number of
+            output column prefixes specified must equal the number of input columns. Output column prefixes must be
+            specified before transform with this argument or after initialization with the `set_output_cols` method.
 
         passthrough_cols: Optional[Union[str, List[str]]], default=None
             A string or a list of strings indicating column names to be excluded from any
@@ -93,7 +98,7 @@ class OrdinalEncoder(base.BaseTransformer):
     def __init__(
         self,
         *,
-        categories: Union[str, Dict[str, type_utils.LiteralNDArrayType]] = "auto",
+        categories: Union[str, List[type_utils.LiteralNDArrayType], Dict[str, type_utils.LiteralNDArrayType]] = "auto",
         handle_unknown: str = "error",
         unknown_value: Optional[Union[int, float]] = None,
         encoded_missing_value: Union[int, float] = np.nan,
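`OrdinalEncoder` gains the same positional list form as `OneHotEncoder`. A minimal sketch with illustrative values:

    import numpy as np

    from snowflake.ml.modeling.preprocessing import OrdinalEncoder

    enc = OrdinalEncoder(
        categories=[np.array(["S", "M", "L"])],  # aligned with input_cols order
        input_cols=["SIZE"],
        output_cols=["SIZE_ORD"],
    )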
@@ -111,9 +116,13 @@ class OrdinalEncoder(base.BaseTransformer):
         a single column of integers (0 to n_categories - 1) per feature.
 
         Args:
-            categories: 'auto' or dict {column_name: ndarray([category])}, default='auto'
+            categories: 'auto', list of array-like, or dict {column_name: ndarray([category])}, default='auto'
                Categories (unique values) per feature:
                - 'auto': Determine categories automatically from the training data.
+                - list: ``categories[i]`` holds the categories expected in the ith
+                  column. The passed categories should not mix strings and numeric
+                  values within a single feature, and should be sorted in case of
+                  numeric values.
                - dict: ``categories[column_name]`` holds the categories expected in
                  the column provided. The passed categories should not mix strings
                  and numeric values within a single feature, and should be sorted in
@@ -247,7 +256,7 @@ class OrdinalEncoder(base.BaseTransformer):
         # columns: COLUMN_NAME, CATEGORY, INDEX
         state_df = self._get_category_index_state_df(dataset)
         # save the dataframe on server side so that transform doesn't need to upload
-        state_df.write.save_as_table(  # type: ignore[call-overload]
+        state_df.write.save_as_table(
             self._vocab_table_name,
             mode="overwrite",
             table_type="temporary",
@@ -314,8 +323,19 @@ class OrdinalEncoder(base.BaseTransformer):
         assert found_state_df is not None
         if self.categories != "auto":
             state_data = []
-            assert isinstance(self.categories, dict)
-            for input_col, cats in self.categories.items():
+            if isinstance(self.categories, list):
+                categories_map = {col_name: cats for col_name, cats in zip(self.input_cols, self.categories)}
+            elif isinstance(self.categories, dict):
+                categories_map = self.categories
+            else:
+                raise exceptions.SnowflakeMLException(
+                    error_code=error_codes.INVALID_ARGUMENT,
+                    original_exception=ValueError(
+                        f"Invalid type {type(self.categories)} provided for argument `categories`"
+                    ),
+                )
+
+            for input_col, cats in categories_map.items():
                 for idx, cat in enumerate(cats.tolist()):
                     state_data.append([input_col, cat, idx])
             # states of given categories
@@ -365,6 +385,8 @@ class OrdinalEncoder(base.BaseTransformer):
                 for col_name, cats in grouped_categories.items()
             }
             self.categories_ = categories
+        elif isinstance(self.categories, list):
+            self.categories_ = {col_name: cats for col_name, cats in zip(self.input_cols, self.categories)}
         else:
             self.categories_ = self.categories
 
@@ -520,7 +542,7 @@ class OrdinalEncoder(base.BaseTransformer):
         )
 
         batch_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
-        transformed_dataset.write.save_as_table(  # type: ignore[call-overload]
+        transformed_dataset.write.save_as_table(
             batch_table_name,
             mode="overwrite",
             table_type="temporary",
@@ -545,6 +567,15 @@ class OrdinalEncoder(base.BaseTransformer):
             snowml_only_keywords=_SNOWML_ONLY_KEYWORDS,
             sklearn_added_keyword_to_version_dict=_SKLEARN_ADDED_KEYWORD_TO_VERSION_DICT,
         )
+        if "categories" in sklearn_args and isinstance(sklearn_args["categories"], dict):
+            # sklearn requires a list of array-like to satisfy the `categories` arg
+            try:
+                sklearn_args["categories"] = [sklearn_args["categories"][input_col] for input_col in self.input_cols]
+            except KeyError as e:
+                raise exceptions.SnowflakeMLException(
+                    error_code=error_codes.INVALID_ARGUMENT,
+                    original_exception=e,
+                )
         return preprocessing.OrdinalEncoder(**sklearn_args)
 
     def _create_sklearn_object(self) -> preprocessing.OrdinalEncoder:
@@ -567,7 +598,7 @@ class OrdinalEncoder(base.BaseTransformer):
                 error_code=error_codes.INVALID_ATTRIBUTE,
                 original_exception=ValueError(f"Unsupported `categories` value: {self.categories}."),
             )
-        elif isinstance(self.categories, dict):
+        elif isinstance(self.categories, (dict, list)):
             if len(self.categories) != len(self.input_cols):
                 raise exceptions.SnowflakeMLException(
                     error_code=error_codes.INVALID_ATTRIBUTE,
@@ -576,7 +607,7 @@ class OrdinalEncoder(base.BaseTransformer):
                         f"({len(self.input_cols)})."
                     ),
                 )
-            elif set(self.categories.keys()) != set(self.input_cols):
+            elif isinstance(self.categories, dict) and set(self.categories.keys()) != set(self.input_cols):
                 raise exceptions.SnowflakeMLException(
                     error_code=error_codes.INVALID_ATTRIBUTE,
                     original_exception=ValueError(
snowflake/ml/modeling/preprocessing/polynomial_features.py CHANGED
@@ -76,8 +76,10 @@ class PolynomialFeatures(BaseTransformer):
        initialization with the `set_input_cols` method.
 
    label_cols: Optional[Union[str, List[str]]]
-       This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
-
+       A string or list of strings representing column names that contain labels.
+       Label columns must be specified with this parameter during initialization
+       or with the `set_label_cols` method before fitting.
+
    output_cols: Optional[Union[str, List[str]]]
        A string or list of strings representing column names that will store the
        output of predict and transform operations. The length of output_cols must
@@ -251,7 +253,7 @@ class PolynomialFeatures(BaseTransformer):
                inspect.currentframe(), PolynomialFeatures.__class__.__name__
            ),
            api_calls=[Session.call],
-           custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+           custom_tags={"autogen": True} if self._autogenerated else None,
        )
        pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
        pd_df.columns = dataset.columns
snowflake/ml/modeling/preprocessing/robust_scaler.py CHANGED
@@ -37,12 +37,15 @@ class RobustScaler(base.BaseTransformer):
            the dataset is scaled down. If less than 1, the dataset is scaled up.
 
        input_cols: Optional[Union[str, List[str]]], default=None
-           The name(s) of one or more columns in a DataFrame containing a feature to be scaled.
+           The name(s) of one or more columns in the input DataFrame containing feature(s) to be scaled. Input
+           columns must be specified before fit with this argument or after initialization with the
+           `set_input_cols` method. This argument is optional for API consistency.
 
        output_cols: Optional[Union[str, List[str]]], default=None
-           The name(s) of one or more columns in a DataFrame in which results will be stored. The number of
-           columns specified must match the number of input columns. For dense output, the column names specified are
-           used as base names for the columns created for each category.
+           The name(s) to assign output columns in the output DataFrame. The number of
+           columns specified must equal the number of input columns. Output columns must be specified before transform
+           with this argument or after initialization with the `set_output_cols` method. This argument is optional for
+           API consistency.
 
        passthrough_cols: Optional[Union[str, List[str]]], default=None
            A string or a list of strings indicating column names to be excluded from any
snowflake/ml/modeling/preprocessing/standard_scaler.py CHANGED
@@ -26,11 +26,15 @@ class StandardScaler(base.BaseTransformer):
            If True, scale the data unit variance (i.e. unit standard deviation).
 
        input_cols: Optional[Union[str, List[str]]], default=None
-           The name(s) of one or more columns in a DataFrame containing a feature to be scaled.
+           The name(s) of one or more columns in the input DataFrame containing feature(s) to be scaled. Input
+           columns must be specified before fit with this argument or after initialization with the
+           `set_input_cols` method. This argument is optional for API consistency.
 
        output_cols: Optional[Union[str, List[str]]], default=None
-           The name(s) of one or more columns in a DataFrame in which results will be stored. The number of
-           columns specified must match the number of input columns.
+           The name(s) to assign output columns in the output DataFrame. The number of
+           columns specified must equal the number of input columns. Output columns must be specified before transform
+           with this argument or after initialization with the `set_output_cols` method. This argument is optional for
+           API consistency.
 
        passthrough_cols: Optional[Union[str, List[str]]], default=None
            A string or a list of strings indicating column names to be excluded from any
snowflake/ml/modeling/semi_supervised/label_propagation.py CHANGED
@@ -257,7 +257,7 @@ class LabelPropagation(BaseTransformer):
                inspect.currentframe(), LabelPropagation.__class__.__name__
            ),
            api_calls=[Session.call],
-           custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+           custom_tags={"autogen": True} if self._autogenerated else None,
        )
        pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
        pd_df.columns = dataset.columns
snowflake/ml/modeling/semi_supervised/label_spreading.py CHANGED
@@ -266,7 +266,7 @@ class LabelSpreading(BaseTransformer):
                inspect.currentframe(), LabelSpreading.__class__.__name__
            ),
            api_calls=[Session.call],
-           custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+           custom_tags={"autogen": True} if self._autogenerated else None,
        )
        pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
        pd_df.columns = dataset.columns
snowflake/ml/modeling/svm/linear_svc.py CHANGED
@@ -322,7 +322,7 @@ class LinearSVC(BaseTransformer):
                inspect.currentframe(), LinearSVC.__class__.__name__
            ),
            api_calls=[Session.call],
-           custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+           custom_tags={"autogen": True} if self._autogenerated else None,
        )
        pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
        pd_df.columns = dataset.columns
snowflake/ml/modeling/svm/linear_svr.py CHANGED
@@ -294,7 +294,7 @@ class LinearSVR(BaseTransformer):
                inspect.currentframe(), LinearSVR.__class__.__name__
            ),
            api_calls=[Session.call],
-           custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+           custom_tags={"autogen": True} if self._autogenerated else None,
        )
        pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
        pd_df.columns = dataset.columns
snowflake/ml/modeling/svm/nu_svc.py CHANGED
@@ -328,7 +328,7 @@ class NuSVC(BaseTransformer):
                inspect.currentframe(), NuSVC.__class__.__name__
            ),
            api_calls=[Session.call],
-           custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+           custom_tags={"autogen": True} if self._autogenerated else None,
        )
        pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
        pd_df.columns = dataset.columns
snowflake/ml/modeling/svm/nu_svr.py CHANGED
@@ -289,7 +289,7 @@ class NuSVR(BaseTransformer):
                inspect.currentframe(), NuSVR.__class__.__name__
            ),
            api_calls=[Session.call],
-           custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+           custom_tags={"autogen": True} if self._autogenerated else None,
        )
        pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
        pd_df.columns = dataset.columns
snowflake/ml/modeling/svm/svc.py CHANGED
@@ -331,7 +331,7 @@ class SVC(BaseTransformer):
                inspect.currentframe(), SVC.__class__.__name__
            ),
            api_calls=[Session.call],
-           custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+           custom_tags={"autogen": True} if self._autogenerated else None,
        )
        pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
        pd_df.columns = dataset.columns
snowflake/ml/modeling/svm/svr.py CHANGED
@@ -292,7 +292,7 @@ class SVR(BaseTransformer):
                inspect.currentframe(), SVR.__class__.__name__
            ),
            api_calls=[Session.call],
-           custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+           custom_tags={"autogen": True} if self._autogenerated else None,
        )
        pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
        pd_df.columns = dataset.columns
snowflake/ml/modeling/tree/decision_tree_classifier.py CHANGED
@@ -359,7 +359,7 @@ class DecisionTreeClassifier(BaseTransformer):
                inspect.currentframe(), DecisionTreeClassifier.__class__.__name__
            ),
            api_calls=[Session.call],
-           custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+           custom_tags={"autogen": True} if self._autogenerated else None,
        )
        pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
        pd_df.columns = dataset.columns
snowflake/ml/modeling/tree/decision_tree_regressor.py CHANGED
@@ -341,7 +341,7 @@ class DecisionTreeRegressor(BaseTransformer):
                inspect.currentframe(), DecisionTreeRegressor.__class__.__name__
            ),
            api_calls=[Session.call],
-           custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+           custom_tags={"autogen": True} if self._autogenerated else None,
        )
        pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
        pd_df.columns = dataset.columns
snowflake/ml/modeling/tree/extra_tree_classifier.py CHANGED
@@ -351,7 +351,7 @@ class ExtraTreeClassifier(BaseTransformer):
                inspect.currentframe(), ExtraTreeClassifier.__class__.__name__
            ),
            api_calls=[Session.call],
-           custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+           custom_tags={"autogen": True} if self._autogenerated else None,
        )
        pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
        pd_df.columns = dataset.columns
snowflake/ml/modeling/tree/extra_tree_regressor.py CHANGED
@@ -333,7 +333,7 @@ class ExtraTreeRegressor(BaseTransformer):
                inspect.currentframe(), ExtraTreeRegressor.__class__.__name__
            ),
            api_calls=[Session.call],
-           custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+           custom_tags={"autogen": True} if self._autogenerated else None,
        )
        pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
        pd_df.columns = dataset.columns
snowflake/ml/modeling/xgboost/xgb_classifier.py CHANGED
@@ -451,7 +451,7 @@ class XGBClassifier(BaseTransformer):
                inspect.currentframe(), XGBClassifier.__class__.__name__
            ),
            api_calls=[Session.call],
-           custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+           custom_tags={"autogen": True} if self._autogenerated else None,
        )
        pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
        pd_df.columns = dataset.columns
snowflake/ml/modeling/xgboost/xgb_regressor.py CHANGED
@@ -450,7 +450,7 @@ class XGBRegressor(BaseTransformer):
                inspect.currentframe(), XGBRegressor.__class__.__name__
            ),
            api_calls=[Session.call],
-           custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+           custom_tags={"autogen": True} if self._autogenerated else None,
        )
        pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
        pd_df.columns = dataset.columns
snowflake/ml/modeling/xgboost/xgbrf_classifier.py CHANGED
@@ -455,7 +455,7 @@ class XGBRFClassifier(BaseTransformer):
                inspect.currentframe(), XGBRFClassifier.__class__.__name__
            ),
            api_calls=[Session.call],
-           custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+           custom_tags={"autogen": True} if self._autogenerated else None,
        )
        pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
        pd_df.columns = dataset.columns
snowflake/ml/modeling/xgboost/xgbrf_regressor.py CHANGED
@@ -455,7 +455,7 @@ class XGBRFRegressor(BaseTransformer):
                inspect.currentframe(), XGBRFRegressor.__class__.__name__
            ),
            api_calls=[Session.call],
-           custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+           custom_tags={"autogen": True} if self._autogenerated else None,
        )
        pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
        pd_df.columns = dataset.columns
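The repeated one-line change across these autogenerated estimators swaps a dict built from a list of pairs for the equivalent literal; both produce the same mapping:

    # Purely a readability/idiom change; behavior is identical.
    assert dict([("autogen", True)]) == {"autogen": True}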
snowflake/ml/registry/_manager/model_manager.py CHANGED
@@ -4,12 +4,14 @@ from typing import Any, Dict, List, Optional, Union
 import pandas as pd
 from absl.logging import logging
 
+from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.human_readable_id import hrid_generator
 from snowflake.ml._internal.utils import sql_identifier
 from snowflake.ml.model import model_signature, type_hints as model_types
 from snowflake.ml.model._client.model import model_impl, model_version_impl
 from snowflake.ml.model._client.ops import metadata_ops, model_ops
 from snowflake.ml.model._model_composer import model_composer
+from snowflake.ml.model._packager.model_meta import model_meta
 from snowflake.snowpark import session
 
 logger = logging.getLogger(__name__)
@@ -124,7 +126,10 @@ class ModelManager:
            version_name=version_name_id,
            statement_params=statement_params,
        ):
-           raise ValueError(f"Model {model_name} version {version_name} already existed.")
+           raise ValueError(
+               f"Model {model_name} version {version_name} already existed. "
+               + "To auto-generate `version_name`, skip that argument."
+           )
 
        stage_path = self._model_ops.prepare_model_stage_path(
            database_name=database_name_id,
@@ -134,8 +139,10 @@ class ModelManager:
 
        logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
 
-       mc = model_composer.ModelComposer(self._model_ops._session, stage_path=stage_path)
-       mc.save(
+       mc = model_composer.ModelComposer(
+           self._model_ops._session, stage_path=stage_path, statement_params=statement_params
+       )
+       model_metadata: model_meta.ModelMetadata = mc.save(
            name=model_name_id.resolved(),
            model=model,
            signatures=signatures,
@@ -147,6 +154,12 @@ class ModelManager:
            ext_modules=ext_modules,
            options=options,
        )
+       statement_params = telemetry.add_statement_params_custom_tags(
+           statement_params, model_metadata.telemetry_metadata()
+       )
+       statement_params = telemetry.add_statement_params_custom_tags(
+           statement_params, {"model_version_name": version_name_id}
+       )
 
        logger.info("Start creating MODEL object for you in the Snowflake.")
 
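The reworded error hints at the fix: omit `version_name` and the registry generates one (via the human-readable ID generator imported in this module). A sketch assuming an active Snowpark `session`, a `Registry` instance `reg`, and an already-fitted `model`:

    from snowflake.ml.registry import Registry

    reg = Registry(session=session)

    # Logging to an existing version_name now raises the clearer error above;
    # omitting the argument lets the registry auto-generate a fresh one.
    mv = reg.log_model(model, model_name="MY_MODEL")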
snowflake/ml/version.py CHANGED
@@ -1 +1 @@
-VERSION="1.5.2"
+VERSION="1.5.4"