snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218)
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -1,161 +1,487 @@
1
1
  import json
2
- import time
3
- from dataclasses import dataclass
4
- from typing import Any, Dict, List, Optional
2
+ import warnings
3
+ from datetime import datetime
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
5
 
6
- from snowflake.ml.registry.artifact import Artifact, ArtifactType
7
- from snowflake.snowpark import DataFrame, Session
6
+ from snowflake import snowpark
7
+ from snowflake.ml._internal import telemetry
8
+ from snowflake.ml._internal.exceptions import (
9
+ dataset_error_messages,
10
+ dataset_errors,
11
+ error_codes,
12
+ exceptions as snowml_exceptions,
13
+ )
14
+ from snowflake.ml._internal.lineage import data_source
15
+ from snowflake.ml._internal.utils import (
16
+ formatting,
17
+ identifier,
18
+ query_result_checker,
19
+ snowpark_dataframe_utils,
20
+ )
21
+ from snowflake.ml.dataset import dataset_metadata, dataset_reader
22
+ from snowflake.snowpark import exceptions as snowpark_exceptions, functions
8
23
 
24
+ _PROJECT = "Dataset"
25
+ _TELEMETRY_STATEMENT_PARAMS = telemetry.get_function_usage_statement_params(_PROJECT)
26
+ _METADATA_MAX_QUERY_LENGTH = 10000
27
+ _DATASET_VERSION_NAME_COL = "version"
9
28
 
10
- def _get_val_or_null(val: Any) -> Any:
11
- return val if val is not None else "null"
12
29
 
30
+ class DatasetVersion:
31
+ """Represents a version of a Snowflake Dataset"""
13
32
 
14
- def _wrap_embedded_str(s: str) -> str:
15
- s = s.replace("\\", "\\\\")
16
- s = s.replace('"', '\\"')
17
- return s
33
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
34
+ def __init__(
35
+ self,
36
+ dataset: "Dataset",
37
+ version: str,
38
+ ) -> None:
39
+ """Initialize a DatasetVersion object.
18
40
 
41
+ Args:
42
+ dataset: The parent Snowflake Dataset.
43
+ version: Dataset version name.
44
+ """
45
+ self._parent = dataset
46
+ self._version = version
47
+ self._session: snowpark.Session = self._parent._session
19
48
 
20
- DATASET_SCHEMA_VERSION = "1"
49
+ self._properties: Optional[Dict[str, Any]] = None
50
+ self._raw_metadata: Optional[Dict[str, Any]] = None
51
+ self._metadata: Optional[dataset_metadata.DatasetMetadata] = None
21
52
 
53
+ @property
54
+ def name(self) -> str:
55
+ return self._version
22
56
 
23
- @dataclass(frozen=True)
24
- class FeatureStoreMetadata:
25
- """
26
- Feature store metadata.
57
+ @property
58
+ def created_on(self) -> datetime:
59
+ timestamp = self._get_property("created_on")
60
+ assert isinstance(timestamp, datetime)
61
+ return timestamp
27
62
 
28
- Properties:
29
- spine_query: The input query on source table which will be joined with features.
30
- connection_params: a config contains feature store metadata.
31
- features: A list of feature serialized object in the feature store.
63
+ @property
64
+ def comment(self) -> Optional[str]:
65
+ comment: Optional[str] = self._get_property("comment")
66
+ return comment
32
67
 
33
- """
68
+ def _get_property(self, property_name: str, default: Any = None) -> Any:
69
+ if self._properties is None:
70
+ sql_result = (
71
+ query_result_checker.SqlResultValidator(
72
+ self._session,
73
+ f"SHOW VERSIONS LIKE '{self._version}' IN DATASET {self._parent.fully_qualified_name}",
74
+ statement_params=_TELEMETRY_STATEMENT_PARAMS,
75
+ )
76
+ .has_column(_DATASET_VERSION_NAME_COL, allow_empty=False)
77
+ .validate()
78
+ )
79
+ (match_row,) = (r for r in sql_result if r[_DATASET_VERSION_NAME_COL] == self._version)
80
+ self._properties = match_row.as_dict(True)
81
+ return self._properties.get(property_name, default)
82
+
83
+ def _get_metadata(self) -> Optional[dataset_metadata.DatasetMetadata]:
84
+ if self._raw_metadata is None:
85
+ self._raw_metadata = json.loads(self._get_property("metadata", "{}"))
86
+ try:
87
+ self._metadata = (
88
+ dataset_metadata.DatasetMetadata.from_json(self._raw_metadata) if self._raw_metadata else None
89
+ )
90
+ except ValueError as e:
91
+ warnings.warn(f"Metadata parsing failed with error: {e}", UserWarning, stacklevel=2)
92
+ return self._metadata
34
93
 
35
- spine_query: str
36
- connection_params: Dict[str, str]
37
- features: List[str]
94
+ def _get_exclude_cols(self) -> List[str]:
95
+ metadata = self._get_metadata()
96
+ if metadata is None:
97
+ return []
98
+ cols = []
99
+ if metadata.exclude_cols:
100
+ cols.extend(metadata.exclude_cols)
101
+ if metadata.label_cols:
102
+ cols.extend(metadata.label_cols)
103
+ return cols
38
104
 
39
- def to_json(self) -> str:
40
- state_dict = {
41
- # TODO(zhe): Additional wrap is needed because ml_.artifact.ad_artifact takes a dict
42
- # but we retrieve it as an object. Snowpark serialization is inconsistent with
43
- # our deserialization. A fix is let artifact table stores string and callers
44
- # handles both serialization and deserialization.
45
- "spine_query": self.spine_query,
46
- "connection_params": json.dumps(self.connection_params),
47
- "features": json.dumps(self.features),
48
- }
49
- return json.dumps(state_dict)
105
+ def url(self) -> str:
106
+ """Returns the URL of the DatasetVersion contents in Snowflake.
107
+
108
+ Returns:
109
+ Snowflake URL string.
110
+ """
111
+ path = f"snow://dataset/{self._parent.fully_qualified_name}/versions/{self._version}/"
112
+ return path
50
113
 
51
- @classmethod
52
- def from_json(cls, json_str: str) -> "FeatureStoreMetadata":
53
- json_dict = json.loads(json_str)
54
- return cls(
55
- spine_query=json_dict["spine_query"],
56
- connection_params=json.loads(json_dict["connection_params"]),
57
- features=json.loads(json_dict["features"]),
114
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
115
+ def list_files(self, subdir: Optional[str] = None) -> List[snowpark.Row]:
116
+ """Get the list of remote file paths for the current DatasetVersion."""
117
+ return self._session.sql(f"LIST {self.url()}{subdir or ''}").collect(
118
+ statement_params=_TELEMETRY_STATEMENT_PARAMS
58
119
  )
59
120
 
121
+ def __repr__(self) -> str:
122
+ return f"{self.__class__.__name__}(dataset='{self._parent.fully_qualified_name}', version='{self.name}')"
60
123
 
61
- class Dataset(Artifact):
62
- """Metadata of dataset."""
63
124
 
125
+ class Dataset:
126
+ """Represents a Snowflake Dataset which is organized into versions."""
127
+
128
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
64
129
  def __init__(
65
130
  self,
66
- session: Session,
67
- df: DataFrame,
68
- generation_timestamp: Optional[float] = None,
69
- materialized_table: Optional[str] = None,
70
- snapshot_table: Optional[str] = None,
71
- timestamp_col: Optional[str] = None,
72
- label_cols: Optional[List[str]] = None,
73
- feature_store_metadata: Optional[FeatureStoreMetadata] = None,
74
- desc: str = "",
131
+ session: snowpark.Session,
132
+ database: str,
133
+ schema: str,
134
+ name: str,
135
+ selected_version: Optional[str] = None,
75
136
  ) -> None:
76
- """Initialize dataset object.
137
+ """Initialize a lazily evaluated Dataset object"""
138
+ self._session = session
139
+ self._db = database
140
+ self._schema = schema
141
+ self._name = name
142
+ self._fully_qualified_name = identifier.get_schema_level_object_identifier(database, schema, name)
143
+
144
+ self._version = DatasetVersion(self, selected_version) if selected_version else None
145
+ self._reader: Optional[dataset_reader.DatasetReader] = None
146
+
147
+ @property
148
+ def fully_qualified_name(self) -> str:
149
+ return self._fully_qualified_name
150
+
151
+ @property
152
+ def selected_version(self) -> Optional[DatasetVersion]:
153
+ return self._version
154
+
155
+ @property
156
+ def read(self) -> dataset_reader.DatasetReader:
157
+ if not self.selected_version:
158
+ raise snowml_exceptions.SnowflakeMLException(
159
+ error_code=error_codes.INVALID_ATTRIBUTE,
160
+ original_exception=RuntimeError("No Dataset version selected."),
161
+ )
162
+ if self._reader is None:
163
+ v = self.selected_version
164
+ self._reader = dataset_reader.DatasetReader(
165
+ self._session,
166
+ [
167
+ data_source.DataSource(
168
+ fully_qualified_name=self._fully_qualified_name,
169
+ version=v.name,
170
+ url=v.url(),
171
+ exclude_cols=v._get_exclude_cols(),
172
+ )
173
+ ],
174
+ )
175
+ return self._reader
176
+
177
+ @staticmethod
178
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
179
+ def load(session: snowpark.Session, name: str) -> "Dataset":
180
+ """
181
+ Load an existing Snowflake Dataset. DatasetVersions can be created from the Dataset object
182
+ using `Dataset.create_version()` and loaded with `Dataset.version()`.
77
183
 
78
184
  Args:
79
- session: An active snowpark session.
80
- df: A dataframe object representing the dataset generation.
81
- generation_timestamp: The timestamp when this dataset is generated. It will use current time if
82
- not provided.
83
- materialized_table: The destination table name which data will writes into.
84
- snapshot_table: A snapshot table name on the materialized table.
85
- timestamp_col: Timestamp column which was used for point-in-time correct feature lookup.
86
- label_cols: Name of column(s) in materialized_table that contains labels.
87
- feature_store_metadata: A feature store metadata object.
88
- desc: A description about this dataset.
185
+ session: Snowpark Session to interact with Snowflake backend.
186
+ name: Name of dataset to load. May optionally be a schema-level identifier.
187
+
188
+ Returns:
189
+ Dataset object representing loaded dataset
190
+
191
+ Raises:
192
+ ValueError: name is not a valid Snowflake identifier
193
+ DatasetNotExistError: Specified Dataset does not exist
194
+
195
+ # noqa: DAR402
89
196
  """
90
- self.df = df
91
- self.generation_timestamp = generation_timestamp if generation_timestamp is not None else time.time()
92
- self.materialized_table = materialized_table
93
- self.snapshot_table = snapshot_table
94
- self.timestamp_col = timestamp_col
95
- self.label_cols = label_cols
96
- self.feature_store_metadata = feature_store_metadata
97
- self.desc = desc
98
- self.owner = session.sql("SELECT CURRENT_USER()").collect()[0]["CURRENT_USER()"]
99
- self.schema_version = DATASET_SCHEMA_VERSION
100
-
101
- super().__init__(type=ArtifactType.DATASET, spec=self.to_json())
102
-
103
- def load_features(self) -> Optional[List[str]]:
104
- if self.feature_store_metadata is not None:
105
- return self.feature_store_metadata.features
106
- else:
107
- return None
108
-
109
- def features_df(self) -> DataFrame:
110
- result = self.df
111
- if self.timestamp_col is not None:
112
- result = result.drop(self.timestamp_col)
113
- if self.label_cols is not None:
114
- result = result.drop(self.label_cols)
115
- return result
116
-
117
- def to_json(self) -> str:
118
- if len(self.df.queries["queries"]) != 1:
119
- raise ValueError(
120
- f"""df dataframe must contain only 1 query.
121
- Got {len(self.df.queries['queries'])}: {self.df.queries['queries']}
122
- """
197
+ db, schema, ds_name = _get_schema_level_identifier(session, name)
198
+ _validate_dataset_exists(session, db, schema, ds_name)
199
+ return Dataset(session, db, schema, ds_name)
200
+
201
+ @staticmethod
202
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
203
+ def create(session: snowpark.Session, name: str, exist_ok: bool = False) -> "Dataset":
204
+ """
205
+ Create a new Snowflake Dataset. DatasetVersions can be created from the Dataset object
206
+ using `Dataset.create_version()` and loaded with `Dataset.version()`.
207
+
208
+ Args:
209
+ session: Snowpark Session to interact with Snowflake backend.
210
+ name: Name of dataset to create. May optionally be a schema-level identifier.
211
+ exist_ok: If False, raises an exception if specified Dataset already exists
212
+
213
+ Returns:
214
+ Dataset object representing created dataset
215
+
216
+ Raises:
217
+ ValueError: name is not a valid Snowflake identifier
218
+ DatasetExistError: Specified Dataset already exists
219
+ DatasetError: Dataset creation failed
220
+
221
+ # noqa: DAR401
222
+ # noqa: DAR402
223
+ """
224
+ db, schema, ds_name = _get_schema_level_identifier(session, name)
225
+ ds_fqn = identifier.get_schema_level_object_identifier(db, schema, ds_name)
226
+ query = f"CREATE DATASET{' IF NOT EXISTS' if exist_ok else ''} {ds_fqn}"
227
+ try:
228
+ session.sql(query).collect(statement_params=_TELEMETRY_STATEMENT_PARAMS)
229
+ return Dataset(session, db, schema, ds_name)
230
+ except snowpark_exceptions.SnowparkClientException as e:
231
+ # Snowpark wraps the Python Connector error code in the head of the error message.
232
+ if e.message.startswith(dataset_errors.ERRNO_OBJECT_ALREADY_EXISTS):
233
+ raise snowml_exceptions.SnowflakeMLException(
234
+ error_code=error_codes.OBJECT_ALREADY_EXISTS,
235
+ original_exception=dataset_errors.DatasetExistError(
236
+ dataset_error_messages.DATASET_ALREADY_EXISTS.format(name)
237
+ ),
238
+ ) from e
239
+ else:
240
+ raise
241
+
242
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
243
+ def list_versions(self, detailed: bool = False) -> Union[List[str], List[snowpark.Row]]:
244
+ """Return list of versions"""
245
+ versions = self._list_versions()
246
+ versions.sort(key=lambda r: r[_DATASET_VERSION_NAME_COL])
247
+ if not detailed:
248
+ return [r[_DATASET_VERSION_NAME_COL] for r in versions]
249
+ return versions
250
+
251
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
252
+ def select_version(self, version: str) -> "Dataset":
253
+ """Return a new Dataset instance with the specified version selected.
254
+
255
+ Args:
256
+ version: Dataset version name.
257
+
258
+ Returns:
259
+ Dataset object.
260
+ """
261
+ self._validate_version_exists(version)
262
+ return Dataset(self._session, self._db, self._schema, self._name, version)
263
+
264
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
265
+ def create_version(
266
+ self,
267
+ version: str,
268
+ input_dataframe: snowpark.DataFrame,
269
+ shuffle: bool = False,
270
+ exclude_cols: Optional[List[str]] = None,
271
+ label_cols: Optional[List[str]] = None,
272
+ properties: Optional[dataset_metadata.DatasetPropertiesType] = None,
273
+ partition_by: Optional[str] = None,
274
+ comment: Optional[str] = None,
275
+ ) -> "Dataset":
276
+ """Create a new version of the current Dataset.
277
+
278
+ The result Dataset object captures the query result deterministically as stage files.
279
+
280
+ Args:
281
+ version: Dataset version name. Data contents are materialized to the Dataset entity.
282
+ input_dataframe: A Snowpark DataFrame which yields the Dataset contents.
283
+ shuffle: A boolean indicating whether the data should be shuffled globally. Defaults to False.
284
+ exclude_cols: Name of column(s) in dataset to be excluded during training/testing (e.g. timestamp).
285
+ label_cols: Name of column(s) in dataset that contains labels.
286
+ properties: Custom metadata properties, saved under `DatasetMetadata.properties`
287
+ partition_by: Optional SQL expression to use as the partitioning scheme within the new Dataset version.
288
+ comment: A descriptive comment about this dataset.
289
+
290
+ Returns:
291
+ A Dataset object with the newly created version selected.
292
+
293
+ Raises:
294
+ SnowflakeMLException: The Dataset no longer exists.
295
+ SnowflakeMLException: The specified Dataset version already exists.
296
+ snowpark_exceptions.SnowparkClientException: An error occurred during Dataset creation.
297
+
298
+ Note: During the generation of stage files, data casting will occur. The casting rules are as follows::
299
+ - Data casting:
300
+ - DecimalType(NUMBER):
301
+ - If its scale is zero, cast to BIGINT
302
+ - If its scale is non-zero, cast to FLOAT
303
+ - DoubleType(DOUBLE): Cast to FLOAT.
304
+ - ByteType(TINYINT): Cast to SMALLINT.
305
+ - ShortType(SMALLINT): Cast to SMALLINT.
306
+ - IntegerType(INT): Cast to INT.
307
+ - LongType(BIGINT): Cast to BIGINT.
308
+ - No action:
309
+ - FloatType(FLOAT): No action.
310
+ - StringType(String): No action.
311
+ - BinaryType(BINARY): No action.
312
+ - BooleanType(BOOLEAN): No action.
313
+ - Not supported:
314
+ - ArrayType(ARRAY): Not supported. A warning will be logged.
315
+ - MapType(OBJECT): Not supported. A warning will be logged.
316
+ - TimestampType(TIMESTAMP): Not supported. A warning will be logged.
317
+ - TimeType(TIME): Not supported. A warning will be logged.
318
+ - DateType(DATE): Not supported. A warning will be logged.
319
+ - VariantType(VARIANT): Not supported. A warning will be logged.
320
+ """
321
+ casted_df = snowpark_dataframe_utils.cast_snowpark_dataframe(input_dataframe)
322
+
323
+ if shuffle:
324
+ casted_df = casted_df.order_by(functions.random())
325
+
326
+ source_query = json.dumps(input_dataframe.queries)
327
+ if len(source_query) > _METADATA_MAX_QUERY_LENGTH:
328
+ warnings.warn(
329
+ "Source query exceeded max query length, dropping from metadata (limit=%d, actual=%d)"
330
+ % (_METADATA_MAX_QUERY_LENGTH, len(source_query)),
331
+ stacklevel=2,
123
332
  )
333
+ source_query = "<query too long>"
124
334
 
125
- state_dict = {
126
- "df_query": _wrap_embedded_str(self.df.queries["queries"][0]),
127
- "generation_timestamp": self.generation_timestamp,
128
- "owner": self.owner,
129
- "materialized_table": _wrap_embedded_str(_get_val_or_null(self.materialized_table)),
130
- "snapshot_table": _wrap_embedded_str(_get_val_or_null(self.snapshot_table)),
131
- "timestamp_col": _wrap_embedded_str(_get_val_or_null(self.timestamp_col)),
132
- "label_cols": _get_val_or_null(self.label_cols),
133
- "feature_store_metadata": _wrap_embedded_str(self.feature_store_metadata.to_json())
134
- if self.feature_store_metadata is not None
135
- else "null",
136
- "schema_version": self.schema_version,
137
- "desc": self.desc,
138
- }
139
- return json.dumps(state_dict)
140
-
141
- @classmethod
142
- def from_json(cls, json_str: str, session: Session) -> "Dataset":
143
- json_dict = json.loads(json_str, strict=False)
144
- json_dict["df"] = session.sql(json_dict.pop("df_query"))
145
-
146
- fs_meta_json = json_dict["feature_store_metadata"]
147
- json_dict["feature_store_metadata"] = (
148
- FeatureStoreMetadata.from_json(fs_meta_json) if fs_meta_json != "null" else None
335
+ metadata = dataset_metadata.DatasetMetadata(
336
+ source_query=source_query,
337
+ owner=self._session.sql("SELECT CURRENT_USER()").collect(statement_params=_TELEMETRY_STATEMENT_PARAMS)[0][
338
+ "CURRENT_USER()"
339
+ ],
340
+ exclude_cols=exclude_cols,
341
+ label_cols=label_cols,
342
+ properties=properties,
149
343
  )
150
344
 
151
- schema_version = json_dict.pop("schema_version")
152
- owner = json_dict.pop("owner")
345
+ post_actions = casted_df._plan.post_actions
346
+ try:
347
+ # Execute all but the last query; the final query gets passed to ALTER DATASET ADD VERSION
348
+ query = casted_df._plan.queries[-1].sql.strip()
349
+ if len(casted_df._plan.queries) > 1:
350
+ casted_df._plan.queries = casted_df._plan.queries[:-1]
351
+ casted_df._plan.post_actions = []
352
+ casted_df.collect(statement_params=_TELEMETRY_STATEMENT_PARAMS)
353
+ sql_command = "ALTER DATASET {} ADD VERSION '{}' FROM ({})".format(
354
+ self.fully_qualified_name,
355
+ version,
356
+ query,
357
+ )
358
+ if partition_by:
359
+ sql_command += f" PARTITION BY {partition_by}"
360
+ if comment:
361
+ sql_command += f" COMMENT={formatting.format_value_for_select(comment)}"
362
+ sql_command += f" METADATA=$${metadata.to_json()}$$"
363
+ self._session.sql(sql_command).collect(statement_params=_TELEMETRY_STATEMENT_PARAMS)
364
+
365
+ return Dataset(self._session, self._db, self._schema, self._name, version)
153
366
 
154
- result = cls(session, **json_dict)
155
- result.schema_version = schema_version
156
- result.owner = owner
367
+ except snowpark_exceptions.SnowparkClientException as e:
368
+ if e.message.startswith(dataset_errors.ERRNO_DATASET_NOT_EXIST):
369
+ raise snowml_exceptions.SnowflakeMLException(
370
+ error_code=error_codes.NOT_FOUND,
371
+ original_exception=dataset_errors.DatasetNotExistError(
372
+ dataset_error_messages.DATASET_NOT_EXIST.format(self.fully_qualified_name)
373
+ ),
374
+ ) from e
375
+ elif (
376
+ e.message.startswith(dataset_errors.ERRNO_DATASET_VERSION_ALREADY_EXISTS)
377
+ or e.message.startswith(dataset_errors.ERRNO_VERSION_ALREADY_EXISTS)
378
+ or e.message.startswith(dataset_errors.ERRNO_FILES_ALREADY_EXISTING)
379
+ ):
380
+ raise snowml_exceptions.SnowflakeMLException(
381
+ error_code=error_codes.OBJECT_ALREADY_EXISTS,
382
+ original_exception=dataset_errors.DatasetExistError(
383
+ dataset_error_messages.DATASET_VERSION_ALREADY_EXISTS.format(self.fully_qualified_name, version)
384
+ ),
385
+ ) from e
386
+ else:
387
+ raise
388
+ finally:
389
+ for action in post_actions:
390
+ self._session.sql(action.sql.strip()).collect(statement_params=_TELEMETRY_STATEMENT_PARAMS)
157
391
 
158
- return result
392
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
393
+ def delete_version(self, version_name: str) -> None:
394
+ """Delete the Dataset version
159
395
 
160
- def __eq__(self, other: object) -> bool:
161
- return isinstance(other, Dataset) and self.to_json() == other.to_json()
396
+ Args:
397
+ version_name: Name of version to delete from Dataset
398
+
399
+ Raises:
400
+ SnowflakeMLException: The Dataset version could not be deleted.
401
+ """
402
+ delete_sql = f"ALTER DATASET {self.fully_qualified_name} DROP VERSION '{version_name}'"
403
+ try:
404
+ self._session.sql(delete_sql).collect(
405
+ statement_params=_TELEMETRY_STATEMENT_PARAMS,
406
+ )
407
+ except snowpark_exceptions.SnowparkClientException as e:
408
+ raise snowml_exceptions.SnowflakeMLException(
409
+ error_code=error_codes.SNOWML_DELETE_FAILED,
410
+ original_exception=dataset_errors.DatasetCannotDeleteError(str(e)),
411
+ ) from e
412
+ return
413
+
414
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
415
+ def delete(self) -> None:
416
+ """Delete Dataset and all contained versions"""
417
+ # TODO: Check and warn if any versions exist
418
+ self._session.sql(f"DROP DATASET {self.fully_qualified_name}").collect(
419
+ statement_params=_TELEMETRY_STATEMENT_PARAMS
420
+ )
421
+
422
+ def _list_versions(self, pattern: Optional[str] = None) -> List[snowpark.Row]:
423
+ """Return list of versions"""
424
+ try:
425
+ pattern_clause = f" LIKE '{pattern}'" if pattern else ""
426
+ return (
427
+ query_result_checker.SqlResultValidator(
428
+ self._session,
429
+ f"SHOW VERSIONS{pattern_clause} IN DATASET {self.fully_qualified_name}",
430
+ statement_params=_TELEMETRY_STATEMENT_PARAMS,
431
+ )
432
+ .has_column(_DATASET_VERSION_NAME_COL, allow_empty=True)
433
+ .validate()
434
+ )
435
+ except snowpark_exceptions.SnowparkClientException as e:
436
+ # Snowpark wraps the Python Connector error code in the head of the error message.
437
+ if e.message.startswith(dataset_errors.ERRNO_OBJECT_NOT_EXIST):
438
+ raise snowml_exceptions.SnowflakeMLException(
439
+ error_code=error_codes.NOT_FOUND,
440
+ original_exception=dataset_errors.DatasetNotExistError(
441
+ dataset_error_messages.DATASET_NOT_EXIST.format(self.fully_qualified_name)
442
+ ),
443
+ ) from e
444
+ else:
445
+ raise
446
+
447
+ def _validate_version_exists(self, version: str) -> None:
448
+ """Verify that the requested version exists. Raises SnowflakeMLException (NOT_FOUND) if version not found"""
449
+ matches = self._list_versions(version)
450
+ matches = [m for m in matches if m[_DATASET_VERSION_NAME_COL] == version] # Case sensitive match
451
+ if len(matches) == 0:
452
+ raise snowml_exceptions.SnowflakeMLException(
453
+ error_code=error_codes.NOT_FOUND,
454
+ original_exception=dataset_errors.DatasetNotExistError(
455
+ dataset_error_messages.DATASET_VERSION_NOT_EXIST.format(self.fully_qualified_name, version)
456
+ ),
457
+ )
458
+
459
+
460
+ # Utility methods
461
+
462
+
463
+ def _get_schema_level_identifier(session: snowpark.Session, dataset_name: str) -> Tuple[str, str, str]:
464
+ """Resolve a dataset name into a validated schema-level location identifier"""
465
+ db, schema, object_name, others = identifier.parse_schema_level_object_identifier(dataset_name)
466
+ if others:
467
+ raise ValueError(f"Invalid identifier: unexpected '{others}'")
468
+ db = db or session.get_current_database()
469
+ schema = schema or session.get_current_schema()
470
+ return str(db), str(schema), str(object_name)
471
+
472
+
473
+ def _validate_dataset_exists(session: snowpark.Session, db: str, schema: str, dataset_name: str) -> None:
474
+ # FIXME: Once we switch version to SQL Identifiers we can just use version check with version=''
475
+ dataset_name = identifier.resolve_identifier(dataset_name)
476
+ if len(dataset_name) > 0 and dataset_name[0] == '"' and dataset_name[-1] == '"':
477
+ dataset_name = identifier.get_unescaped_names(dataset_name)
478
+ # Case sensitive match
479
+ query = f"show datasets like '{dataset_name}' in schema {db}.{schema} starts with '{dataset_name}'"
480
+ ds_matches = session.sql(query).count()
481
+ if ds_matches == 0:
482
+ raise snowml_exceptions.SnowflakeMLException(
483
+ error_code=error_codes.NOT_FOUND,
484
+ original_exception=dataset_errors.DatasetNotExistError(
485
+ dataset_error_messages.DATASET_NOT_EXIST.format(dataset_name)
486
+ ),
487
+ )
@@ -0,0 +1,53 @@
1
+ from typing import Any
2
+
3
+ from snowflake import snowpark
4
+ from snowflake.ml._internal import telemetry
5
+ from snowflake.ml.dataset import dataset
6
+
7
+ _PROJECT = "Dataset"
8
+
9
+
10
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
11
+ def create_from_dataframe(
12
+ session: snowpark.Session,
13
+ name: str,
14
+ version: str,
15
+ input_dataframe: snowpark.DataFrame,
16
+ **version_kwargs: Any,
17
+ ) -> dataset.Dataset:
18
+ """
19
+ Create a new versioned Dataset from a DataFrame and return
20
+ a Dataset object with the newly created version selected.
21
+
22
+ Args:
23
+ session: The Snowpark Session instance to use.
24
+ name: The dataset name
25
+ version: The dataset version name
26
+ input_dataframe: DataFrame containing data to be saved to the created Dataset.
27
+ version_kwargs: Keyword arguments passed to dataset version creation.
28
+ See `Dataset.create_version()` documentation for supported arguments.
29
+
30
+ Returns:
31
+ A Dataset object.
32
+ """
33
+ ds: dataset.Dataset = dataset.Dataset.create(session, name, exist_ok=True)
34
+ ds.create_version(version, input_dataframe=input_dataframe, **version_kwargs)
35
+ ds = ds.select_version(version) # select_version returns a new copy
36
+ return ds
37
+
38
+
39
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
40
+ def load_dataset(session: snowpark.Session, name: str, version: str) -> dataset.Dataset:
41
+ """
42
+ Load a versioned Dataset and select the requested version.
43
+
44
+ Args:
45
+ session: The Snowpark Session instance to use.
46
+ name: The dataset name.
47
+ version: The dataset version name.
48
+
49
+ Returns:
50
+ A Dataset object with the requested version selected.
51
+ """
52
+ ds: dataset.Dataset = dataset.Dataset.load(session, name).select_version(version)
53
+ return ds