snowflake-ml-python 1.5.2__py3-none-any.whl → 1.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250)
  1. snowflake/cortex/__init__.py +2 -1
  2. snowflake/cortex/_complete.py +240 -16
  3. snowflake/cortex/_extract_answer.py +0 -1
  4. snowflake/cortex/_sentiment.py +0 -1
  5. snowflake/cortex/_sse_client.py +81 -0
  6. snowflake/cortex/_summarize.py +0 -1
  7. snowflake/cortex/_translate.py +0 -1
  8. snowflake/cortex/_util.py +34 -10
  9. snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
  10. snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
  11. snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
  12. snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
  13. snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
  14. snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
  15. snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
  16. snowflake/ml/_internal/telemetry.py +26 -0
  17. snowflake/ml/_internal/utils/identifier.py +14 -0
  18. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
  19. snowflake/ml/dataset/dataset.py +54 -32
  20. snowflake/ml/dataset/dataset_factory.py +3 -4
  21. snowflake/ml/feature_store/feature_store.py +440 -243
  22. snowflake/ml/feature_store/feature_view.py +61 -9
  23. snowflake/ml/fileset/embedded_stage_fs.py +25 -21
  24. snowflake/ml/fileset/fileset.py +2 -2
  25. snowflake/ml/fileset/snowfs.py +4 -15
  26. snowflake/ml/fileset/stage_fs.py +6 -8
  27. snowflake/ml/lineage/__init__.py +3 -0
  28. snowflake/ml/lineage/lineage_node.py +139 -0
  29. snowflake/ml/model/_client/model/model_impl.py +47 -14
  30. snowflake/ml/model/_client/model/model_version_impl.py +82 -2
  31. snowflake/ml/model/_client/ops/model_ops.py +77 -5
  32. snowflake/ml/model/_client/sql/model.py +1 -0
  33. snowflake/ml/model/_client/sql/model_version.py +47 -4
  34. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  35. snowflake/ml/model/_model_composer/model_composer.py +7 -6
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +7 -1
  37. snowflake/ml/model/_model_composer/model_method/function_generator.py +17 -1
  38. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +79 -0
  39. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -3
  40. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -5
  41. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  42. snowflake/ml/model/_packager/model_handlers/_utils.py +1 -0
  43. snowflake/ml/model/_packager/model_handlers/catboost.py +2 -2
  44. snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
  45. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
  46. snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -2
  47. snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
  48. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
  49. snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
  50. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
  51. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
  52. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
  53. snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  56. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  57. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
  58. snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
  59. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  60. snowflake/ml/model/_packager/model_packager.py +9 -4
  61. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  62. snowflake/ml/model/_signatures/builtins_handler.py +2 -1
  63. snowflake/ml/model/_signatures/core.py +13 -1
  64. snowflake/ml/model/_signatures/pandas_handler.py +2 -0
  65. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  66. snowflake/ml/model/custom_model.py +22 -2
  67. snowflake/ml/model/model_signature.py +2 -0
  68. snowflake/ml/model/type_hints.py +74 -4
  69. snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
  70. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +158 -121
  71. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +2 -0
  72. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +39 -18
  73. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +88 -134
  74. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +22 -17
  75. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  76. snowflake/ml/modeling/cluster/affinity_propagation.py +5 -3
  77. snowflake/ml/modeling/cluster/agglomerative_clustering.py +5 -3
  78. snowflake/ml/modeling/cluster/birch.py +5 -3
  79. snowflake/ml/modeling/cluster/bisecting_k_means.py +5 -3
  80. snowflake/ml/modeling/cluster/dbscan.py +5 -3
  81. snowflake/ml/modeling/cluster/feature_agglomeration.py +5 -3
  82. snowflake/ml/modeling/cluster/k_means.py +5 -3
  83. snowflake/ml/modeling/cluster/mean_shift.py +5 -3
  84. snowflake/ml/modeling/cluster/mini_batch_k_means.py +5 -3
  85. snowflake/ml/modeling/cluster/optics.py +5 -3
  86. snowflake/ml/modeling/cluster/spectral_biclustering.py +5 -3
  87. snowflake/ml/modeling/cluster/spectral_clustering.py +5 -3
  88. snowflake/ml/modeling/cluster/spectral_coclustering.py +5 -3
  89. snowflake/ml/modeling/compose/column_transformer.py +5 -3
  90. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  91. snowflake/ml/modeling/covariance/elliptic_envelope.py +5 -3
  92. snowflake/ml/modeling/covariance/empirical_covariance.py +5 -3
  93. snowflake/ml/modeling/covariance/graphical_lasso.py +5 -3
  94. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +5 -3
  95. snowflake/ml/modeling/covariance/ledoit_wolf.py +5 -3
  96. snowflake/ml/modeling/covariance/min_cov_det.py +5 -3
  97. snowflake/ml/modeling/covariance/oas.py +5 -3
  98. snowflake/ml/modeling/covariance/shrunk_covariance.py +5 -3
  99. snowflake/ml/modeling/decomposition/dictionary_learning.py +5 -3
  100. snowflake/ml/modeling/decomposition/factor_analysis.py +5 -3
  101. snowflake/ml/modeling/decomposition/fast_ica.py +5 -3
  102. snowflake/ml/modeling/decomposition/incremental_pca.py +5 -3
  103. snowflake/ml/modeling/decomposition/kernel_pca.py +5 -3
  104. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -3
  105. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -3
  106. snowflake/ml/modeling/decomposition/pca.py +5 -3
  107. snowflake/ml/modeling/decomposition/sparse_pca.py +5 -3
  108. snowflake/ml/modeling/decomposition/truncated_svd.py +5 -3
  109. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  110. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  111. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  112. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  113. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  114. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  115. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  116. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  117. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  118. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  119. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  120. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  121. snowflake/ml/modeling/ensemble/isolation_forest.py +5 -3
  122. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  123. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  124. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  125. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  126. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  127. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  128. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  129. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  130. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  131. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  132. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  133. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
  134. snowflake/ml/modeling/feature_selection/variance_threshold.py +5 -3
  135. snowflake/ml/modeling/framework/base.py +3 -8
  136. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  137. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  138. snowflake/ml/modeling/impute/iterative_imputer.py +5 -3
  139. snowflake/ml/modeling/impute/knn_imputer.py +5 -3
  140. snowflake/ml/modeling/impute/missing_indicator.py +5 -3
  141. snowflake/ml/modeling/impute/simple_imputer.py +8 -4
  142. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +5 -3
  143. snowflake/ml/modeling/kernel_approximation/nystroem.py +5 -3
  144. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +5 -3
  145. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +5 -3
  146. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +5 -3
  147. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  148. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  149. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  150. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  151. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  152. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  153. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  154. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  155. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  156. snowflake/ml/modeling/linear_model/lars.py +1 -1
  157. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  158. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  159. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  160. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  161. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  162. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  163. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  164. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  165. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  166. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  167. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  168. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  169. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  170. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  171. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  172. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  173. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  174. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  175. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  176. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  177. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  178. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  179. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  180. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  181. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -3
  182. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  183. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  184. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  185. snowflake/ml/modeling/manifold/isomap.py +5 -3
  186. snowflake/ml/modeling/manifold/mds.py +5 -3
  187. snowflake/ml/modeling/manifold/spectral_embedding.py +5 -3
  188. snowflake/ml/modeling/manifold/tsne.py +5 -3
  189. snowflake/ml/modeling/metrics/ranking.py +3 -0
  190. snowflake/ml/modeling/metrics/regression.py +3 -0
  191. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +5 -3
  192. snowflake/ml/modeling/mixture/gaussian_mixture.py +5 -3
  193. snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
  194. snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
  195. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  196. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  197. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  198. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  199. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  200. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  201. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  202. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  203. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  204. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  205. snowflake/ml/modeling/neighbors/kernel_density.py +5 -3
  206. snowflake/ml/modeling/neighbors/local_outlier_factor.py +5 -3
  207. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  208. snowflake/ml/modeling/neighbors/nearest_neighbors.py +5 -3
  209. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  210. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  211. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  212. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +5 -3
  213. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  214. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  215. snowflake/ml/modeling/pipeline/pipeline.py +6 -0
  216. snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
  217. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
  218. snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
  219. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
  220. snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
  221. snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
  222. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +53 -11
  223. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +44 -13
  224. snowflake/ml/modeling/preprocessing/polynomial_features.py +5 -3
  225. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
  226. snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
  227. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  228. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  229. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  230. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  231. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  232. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  233. snowflake/ml/modeling/svm/svc.py +1 -1
  234. snowflake/ml/modeling/svm/svr.py +1 -1
  235. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  236. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  237. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  238. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  239. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  240. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  241. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  242. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  243. snowflake/ml/registry/_manager/model_manager.py +16 -3
  244. snowflake/ml/version.py +1 -1
  245. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/METADATA +51 -7
  246. snowflake_ml_python-1.5.4.dist-info/RECORD +389 -0
  247. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/WHEEL +1 -1
  248. snowflake_ml_python-1.5.2.dist-info/RECORD +0 -384
  249. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/LICENSE.txt +0 -0
  250. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/top_level.txt +0 -0
@@ -1,21 +1,11 @@
1
1
  import copy
2
2
  import functools
3
- from typing import Any, Callable, List
3
+ from typing import Any, Callable, List, Optional
4
4
 
5
5
  from snowflake import snowpark
6
6
  from snowflake.ml._internal.lineage import data_source
7
7
 
8
- DATA_SOURCES_ATTR = "_data_sources"
9
-
10
-
11
- def _get_datasources(*args: Any) -> List[data_source.DataSource]:
12
- """Helper method for extracting data sources attribute from DataFrames in an argument list"""
13
- result = []
14
- for arg in args:
15
- srcs = getattr(arg, DATA_SOURCES_ATTR, None)
16
- if isinstance(srcs, list) and all(isinstance(s, data_source.DataSource) for s in srcs):
17
- result += srcs
18
- return result
8
+ _DATA_SOURCES_ATTR = "_data_sources"
19
9
 
20
10
 
21
11
  def _wrap_func(
@@ -32,6 +22,37 @@ def _wrap_func(
32
22
  return wrapped
33
23
 
34
24
 
25
+ def _wrap_class_func(fn: Callable[..., snowpark.DataFrame]) -> Callable[..., snowpark.DataFrame]:
26
+ @functools.wraps(fn)
27
+ def wrapped(*args: Any, **kwargs: Any) -> snowpark.DataFrame:
28
+ df = fn(*args, **kwargs)
29
+ data_sources = get_data_sources(*args, *kwargs.values())
30
+ if data_sources:
31
+ patch_dataframe(df, data_sources, inplace=True)
32
+ return df
33
+
34
+ return wrapped
35
+
36
+
37
+ def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
38
+ """Helper method for extracting data sources attribute from DataFrames in an argument list"""
39
+ result: Optional[List[data_source.DataSource]] = None
40
+ for arg in args:
41
+ srcs = getattr(arg, _DATA_SOURCES_ATTR, None)
42
+ if isinstance(srcs, list) and all(isinstance(s, data_source.DataSource) for s in srcs):
43
+ if result is None:
44
+ result = []
45
+ result += srcs
46
+ return result
47
+
48
+
49
+ def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSource]]) -> None:
50
+ """Helper method for attaching data sources to an object"""
51
+ if data_sources:
52
+ assert all(isinstance(ds, data_source.DataSource) for ds in data_sources)
53
+ setattr(obj, _DATA_SOURCES_ATTR, data_sources)
54
+
55
+
35
56
  def patch_dataframe(
36
57
  df: snowpark.DataFrame, data_sources: List[data_source.DataSource], inplace: bool = False
37
58
  ) -> snowpark.DataFrame:
@@ -62,7 +83,7 @@ def patch_dataframe(
62
83
  ]
63
84
  if not inplace:
64
85
  df = copy.copy(df)
65
- setattr(df, DATA_SOURCES_ATTR, data_sources)
86
+ set_data_sources(df, data_sources)
66
87
  for func in funcs:
67
88
  fn = getattr(df, func, None)
68
89
  if fn is not None:
@@ -70,18 +91,6 @@ def patch_dataframe(
70
91
  return df
71
92
 
72
93
 
73
- def _wrap_class_func(fn: Callable[..., snowpark.DataFrame]) -> Callable[..., snowpark.DataFrame]:
74
- @functools.wraps(fn)
75
- def wrapped(*args: Any, **kwargs: Any) -> snowpark.DataFrame:
76
- df = fn(*args, **kwargs)
77
- data_sources = _get_datasources(*args) + _get_datasources(*kwargs.values())
78
- if data_sources:
79
- patch_dataframe(df, data_sources, inplace=True)
80
- return df
81
-
82
- return wrapped
83
-
84
-
85
94
  # Class-level monkey-patches
86
95
  for klass, func_list in {
87
96
  snowpark.DataFrame: [
@@ -10,6 +10,7 @@ from typing import (
10
10
  Dict,
11
11
  Iterable,
12
12
  List,
13
+ Mapping,
13
14
  Optional,
14
15
  Tuple,
15
16
  TypeVar,
@@ -92,6 +93,31 @@ def get_statement_params(
92
93
  )
93
94
 
94
95
 
96
+ def add_statement_params_custom_tags(
97
+ statement_params: Optional[Dict[str, Any]], custom_tags: Mapping[str, Any]
98
+ ) -> Dict[str, Any]:
99
+ """
100
+ Add custom_tags to existing statement_params. Overwrite keys in custom_tags dict that already exist.
101
+ If existing statement_params are not provided, do nothing as the information cannot be effectively tracked.
102
+
103
+ Args:
104
+ statement_params: Existing statement_params dictionary.
105
+ custom_tags: Dictionary of existing k/v pairs to add as custom_tags
106
+
107
+ Returns:
108
+ new statement_params dictionary with all keys and an updated custom_tags field.
109
+ """
110
+ if not statement_params:
111
+ return {}
112
+ existing_custom_tags: Dict[str, Any] = statement_params.pop(TelemetryField.KEY_CUSTOM_TAGS.value, {})
113
+ existing_custom_tags.update(custom_tags)
114
+ # NOTE: This can be done with | operator after upgrade from py3.8
115
+ return {
116
+ **statement_params,
117
+ TelemetryField.KEY_CUSTOM_TAGS.value: existing_custom_tags,
118
+ }
119
+
120
+
95
121
  # TODO: we can merge this with get_statement_params after code clean up
96
122
  def get_statement_params_full_func_name(frame: Optional[types.FrameType], class_name: Optional[str] = None) -> str:
97
123
  """
@@ -165,6 +165,20 @@ def parse_schema_level_object_identifier(
165
165
  )
166
166
 
167
167
 
168
+ def is_fully_qualified_name(name: str) -> bool:
169
+ """
170
+ Checks if a given name is a fully qualified name, which is in the format '<db>.<schema>.<object_name>'.
171
+
172
+ Args:
173
+ name: The name to be checked.
174
+
175
+ Returns:
176
+ bool: True if the name is fully qualified, False otherwise.
177
+ """
178
+ res = parse_schema_level_object_identifier(name)
179
+ return res[0] is not None and res[1] is not None and res[2] is not None and not res[3]
180
+
181
+
168
182
  def get_schema_level_object_identifier(
169
183
  db: Optional[str],
170
184
  schema: Optional[str],
@@ -1,22 +1,27 @@
1
1
  import logging
2
2
  import warnings
3
+ from typing import List, Optional
3
4
 
4
5
  from snowflake import snowpark
6
+ from snowflake.ml._internal.utils import sql_identifier
5
7
  from snowflake.snowpark import functions, types
6
8
 
7
9
 
8
- def cast_snowpark_dataframe(df: snowpark.DataFrame) -> snowpark.DataFrame:
10
+ def cast_snowpark_dataframe(df: snowpark.DataFrame, ignore_columns: Optional[List[str]] = None) -> snowpark.DataFrame:
9
11
  """Cast columns in the dataframe to types that are compatible with tensor.
10
12
 
11
13
  It assists FileSet.make() in performing implicit data casting.
12
14
 
13
15
  Args:
14
16
  df: A snowpark dataframe.
17
+ ignore_columns: Columns to exclude from casting. These columns will be propagated unchanged.
15
18
 
16
19
  Returns:
17
20
  A snowpark dataframe whose data type has been casted.
18
21
  """
19
22
 
23
+ ignore_cols_set = {sql_identifier.SqlIdentifier(c).identifier() for c in ignore_columns} if ignore_columns else {}
24
+
20
25
  fields = df.schema.fields
21
26
  selected_cols = []
22
27
  for field in fields:
@@ -40,7 +45,9 @@ def cast_snowpark_dataframe(df: snowpark.DataFrame) -> snowpark.DataFrame:
40
45
  dest = field.datatype
41
46
  selected_cols.append(functions.cast(functions.col(src), dest).alias(src))
42
47
  else:
43
- if field.datatype in (types.DateType(), types.TimestampType(), types.TimeType()):
48
+ if field.column_identifier.name in ignore_cols_set:
49
+ pass
50
+ elif field.datatype in (types.DateType(), types.TimestampType(), types.TimeType()):
44
51
  logging.warning(
45
52
  "A Column with DATE or TIMESTAMP data type detected. "
46
53
  "It might not be able to get converted to tensors. "
@@ -90,7 +97,9 @@ def cast_snowpark_dataframe_column_types(df: snowpark.DataFrame) -> snowpark.Dat
90
97
  " is being automatically converted to DoubleType in the Snowpark DataFrame. "
91
98
  "This automatic conversion may lead to potential precision loss and rounding errors. "
92
99
  "If you wish to prevent this conversion, you should manually perform "
93
- "the necessary data type conversion."
100
+ "the necessary data type conversion.",
101
+ UserWarning,
102
+ stacklevel=2,
94
103
  )
95
104
  else:
96
105
  # IntegerType default as NUMBER(38, 0), but
@@ -102,7 +111,9 @@ def cast_snowpark_dataframe_column_types(df: snowpark.DataFrame) -> snowpark.Dat
102
111
  " is being automatically converted to LongType in the Snowpark DataFrame. "
103
112
  "This automatic conversion may lead to potential precision loss and rounding errors. "
104
113
  "If you wish to prevent this conversion, you should manually perform "
105
- "the necessary data type conversion."
114
+ "the necessary data type conversion.",
115
+ UserWarning,
116
+ stacklevel=2,
106
117
  )
107
118
  selected_cols.append(functions.cast(functions.col(src), dest_dtype).alias(src))
108
119
  # TODO: add more type handling or error message
@@ -19,6 +19,7 @@ from snowflake.ml._internal.utils import (
19
19
  snowpark_dataframe_utils,
20
20
  )
21
21
  from snowflake.ml.dataset import dataset_metadata, dataset_reader
22
+ from snowflake.ml.lineage import lineage_node
22
23
  from snowflake.snowpark import exceptions as snowpark_exceptions, functions
23
24
 
24
25
  _PROJECT = "Dataset"
@@ -65,6 +66,20 @@ class DatasetVersion:
65
66
  comment: Optional[str] = self._get_property("comment")
66
67
  return comment
67
68
 
69
+ @property
70
+ def label_cols(self) -> List[str]:
71
+ metadata = self._get_metadata()
72
+ if metadata is None or metadata.label_cols is None:
73
+ return []
74
+ return metadata.label_cols
75
+
76
+ @property
77
+ def exclude_cols(self) -> List[str]:
78
+ metadata = self._get_metadata()
79
+ if metadata is None or metadata.exclude_cols is None:
80
+ return []
81
+ return metadata.exclude_cols
82
+
68
83
  def _get_property(self, property_name: str, default: Any = None) -> Any:
69
84
  if self._properties is None:
70
85
  sql_result = (
@@ -91,17 +106,6 @@ class DatasetVersion:
91
106
  warnings.warn(f"Metadata parsing failed with error: {e}", UserWarning, stacklevel=2)
92
107
  return self._metadata
93
108
 
94
- def _get_exclude_cols(self) -> List[str]:
95
- metadata = self._get_metadata()
96
- if metadata is None:
97
- return []
98
- cols = []
99
- if metadata.exclude_cols:
100
- cols.extend(metadata.exclude_cols)
101
- if metadata.label_cols:
102
- cols.extend(metadata.label_cols)
103
- return cols
104
-
105
109
  def url(self) -> str:
106
110
  """Returns the URL of the DatasetVersion contents in Snowflake.
107
111
 
@@ -122,7 +126,7 @@ class DatasetVersion:
122
126
  return f"{self.__class__.__name__}(dataset='{self._parent.fully_qualified_name}', version='{self.name}')"
123
127
 
124
128
 
125
- class Dataset:
129
+ class Dataset(lineage_node.LineageNode):
126
130
  """Represents a Snowflake Dataset which is organized into versions."""
127
131
 
128
132
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
@@ -135,18 +139,31 @@ class Dataset:
135
139
  selected_version: Optional[str] = None,
136
140
  ) -> None:
137
141
  """Initialize a lazily evaluated Dataset object"""
138
- self._session = session
139
142
  self._db = database
140
143
  self._schema = schema
141
144
  self._name = name
142
- self._fully_qualified_name = identifier.get_schema_level_object_identifier(database, schema, name)
145
+
146
+ super().__init__(
147
+ session,
148
+ identifier.get_schema_level_object_identifier(database, schema, name),
149
+ domain="dataset",
150
+ version=selected_version,
151
+ )
143
152
 
144
153
  self._version = DatasetVersion(self, selected_version) if selected_version else None
145
154
  self._reader: Optional[dataset_reader.DatasetReader] = None
146
155
 
156
+ def __repr__(self) -> str:
157
+ return (
158
+ f"{self.__class__.__name__}(\n"
159
+ f" name='{self._lineage_node_name}',\n"
160
+ f" version='{self._version._version if self._version else None}',\n"
161
+ f")"
162
+ )
163
+
147
164
  @property
148
165
  def fully_qualified_name(self) -> str:
149
- return self._fully_qualified_name
166
+ return self._lineage_node_name
150
167
 
151
168
  @property
152
169
  def selected_version(self) -> Optional[DatasetVersion]:
@@ -165,10 +182,10 @@ class Dataset:
165
182
  self._session,
166
183
  [
167
184
  data_source.DataSource(
168
- fully_qualified_name=self._fully_qualified_name,
185
+ fully_qualified_name=self._lineage_node_name,
169
186
  version=v.name,
170
187
  url=v.url(),
171
- exclude_cols=v._get_exclude_cols(),
188
+ exclude_cols=(v.label_cols + v.exclude_cols),
172
189
  )
173
190
  ],
174
191
  )
@@ -227,9 +244,8 @@ class Dataset:
227
244
  try:
228
245
  session.sql(query).collect(statement_params=_TELEMETRY_STATEMENT_PARAMS)
229
246
  return Dataset(session, db, schema, ds_name)
230
- except snowpark_exceptions.SnowparkClientException as e:
231
- # Snowpark wraps the Python Connector error code in the head of the error message.
232
- if e.message.startswith(dataset_errors.ERRNO_OBJECT_ALREADY_EXISTS):
247
+ except snowpark_exceptions.SnowparkSQLException as e:
248
+ if e.sql_error_code == dataset_errors.ERRNO_OBJECT_ALREADY_EXISTS:
233
249
  raise snowml_exceptions.SnowflakeMLException(
234
250
  error_code=error_codes.OBJECT_ALREADY_EXISTS,
235
251
  original_exception=dataset_errors.DatasetExistError(
@@ -293,7 +309,7 @@ class Dataset:
293
309
  Raises:
294
310
  SnowflakeMLException: The Dataset no longer exists.
295
311
  SnowflakeMLException: The specified Dataset version already exists.
296
- snowpark_exceptions.SnowparkClientException: An error occurred during Dataset creation.
312
+ snowpark_exceptions.SnowparkSQLException: An error occurred during Dataset creation.
297
313
 
298
314
  Note: During the generation of stage files, data casting will occur. The casting rules are as follows::
299
315
  - Data casting:
@@ -318,7 +334,8 @@ class Dataset:
318
334
  - DateType(DATE): Not supported. A warning will be logged.
319
335
  - VariantType(VARIANT): Not supported. A warning will be logged.
320
336
  """
321
- casted_df = snowpark_dataframe_utils.cast_snowpark_dataframe(input_dataframe)
337
+ cast_ignore_cols = (exclude_cols or []) + (label_cols or [])
338
+ casted_df = snowpark_dataframe_utils.cast_snowpark_dataframe(input_dataframe, ignore_columns=cast_ignore_cols)
322
339
 
323
340
  if shuffle:
324
341
  casted_df = casted_df.order_by(functions.random())
@@ -364,19 +381,19 @@ class Dataset:
364
381
 
365
382
  return Dataset(self._session, self._db, self._schema, self._name, version)
366
383
 
367
- except snowpark_exceptions.SnowparkClientException as e:
368
- if e.message.startswith(dataset_errors.ERRNO_DATASET_NOT_EXIST):
384
+ except snowpark_exceptions.SnowparkSQLException as e:
385
+ if e.sql_error_code == dataset_errors.ERRNO_DATASET_NOT_EXIST:
369
386
  raise snowml_exceptions.SnowflakeMLException(
370
387
  error_code=error_codes.NOT_FOUND,
371
388
  original_exception=dataset_errors.DatasetNotExistError(
372
389
  dataset_error_messages.DATASET_NOT_EXIST.format(self.fully_qualified_name)
373
390
  ),
374
391
  ) from e
375
- elif (
376
- e.message.startswith(dataset_errors.ERRNO_DATASET_VERSION_ALREADY_EXISTS)
377
- or e.message.startswith(dataset_errors.ERRNO_VERSION_ALREADY_EXISTS)
378
- or e.message.startswith(dataset_errors.ERRNO_FILES_ALREADY_EXISTING)
379
- ):
392
+ elif e.sql_error_code in {
393
+ dataset_errors.ERRNO_DATASET_VERSION_ALREADY_EXISTS,
394
+ dataset_errors.ERRNO_VERSION_ALREADY_EXISTS,
395
+ dataset_errors.ERRNO_FILES_ALREADY_EXISTING,
396
+ }:
380
397
  raise snowml_exceptions.SnowflakeMLException(
381
398
  error_code=error_codes.OBJECT_ALREADY_EXISTS,
382
399
  original_exception=dataset_errors.DatasetExistError(
@@ -432,9 +449,8 @@ class Dataset:
432
449
  .has_column(_DATASET_VERSION_NAME_COL, allow_empty=True)
433
450
  .validate()
434
451
  )
435
- except snowpark_exceptions.SnowparkClientException as e:
436
- # Snowpark wraps the Python Connector error code in the head of the error message.
437
- if e.message.startswith(dataset_errors.ERRNO_OBJECT_NOT_EXIST):
452
+ except snowpark_exceptions.SnowparkSQLException as e:
453
+ if e.sql_error_code == dataset_errors.ERRNO_OBJECT_NOT_EXIST:
438
454
  raise snowml_exceptions.SnowflakeMLException(
439
455
  error_code=error_codes.NOT_FOUND,
440
456
  original_exception=dataset_errors.DatasetNotExistError(
@@ -456,6 +472,12 @@ class Dataset:
456
472
  ),
457
473
  )
458
474
 
475
+ @staticmethod
476
+ def _load_from_lineage_node(session: snowpark.Session, name: str, version: str) -> "Dataset":
477
+ return Dataset.load(session, name).select_version(version)
478
+
479
+
480
+ lineage_node.DOMAIN_LINEAGE_REGISTRY["dataset"] = Dataset
459
481
 
460
482
  # Utility methods
461
483
 
@@ -16,8 +16,7 @@ def create_from_dataframe(
16
16
  **version_kwargs: Any,
17
17
  ) -> dataset.Dataset:
18
18
  """
19
- Create a new versioned Dataset from a DataFrame and returns
20
- a DatasetReader for the newly created Dataset version.
19
+ Create a new versioned Dataset from a DataFrame.
21
20
 
22
21
  Args:
23
22
  session: The Snowpark Session instance to use.
@@ -39,7 +38,7 @@ def create_from_dataframe(
39
38
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
40
39
  def load_dataset(session: snowpark.Session, name: str, version: str) -> dataset.Dataset:
41
40
  """
42
- Load a versioned Dataset into a DatasetReader.
41
+ Load a versioned Dataset.
43
42
 
44
43
  Args:
45
44
  session: The Snowpark Session instance to use.
@@ -47,7 +46,7 @@ def load_dataset(session: snowpark.Session, name: str, version: str) -> dataset.
47
46
  version: The dataset version name.
48
47
 
49
48
  Returns:
50
- A DatasetReader object.
49
+ A Dataset object.
51
50
  """
52
51
  ds: dataset.Dataset = dataset.Dataset.load(session, name).select_version(version)
53
52
  return ds