snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (234)
  1. snowflake/ml/_internal/env_utils.py +77 -32
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/_internal/utils/identifier.py +3 -1
  8. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  9. snowflake/ml/dataset/__init__.py +10 -0
  10. snowflake/ml/dataset/dataset.py +454 -129
  11. snowflake/ml/dataset/dataset_factory.py +53 -0
  12. snowflake/ml/dataset/dataset_metadata.py +103 -0
  13. snowflake/ml/dataset/dataset_reader.py +202 -0
  14. snowflake/ml/feature_store/feature_store.py +531 -332
  15. snowflake/ml/feature_store/feature_view.py +40 -23
  16. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  17. snowflake/ml/fileset/sfcfs.py +56 -54
  18. snowflake/ml/fileset/snowfs.py +159 -0
  19. snowflake/ml/fileset/stage_fs.py +49 -17
  20. snowflake/ml/model/__init__.py +2 -2
  21. snowflake/ml/model/_api.py +16 -1
  22. snowflake/ml/model/_client/model/model_impl.py +27 -0
  23. snowflake/ml/model/_client/model/model_version_impl.py +137 -50
  24. snowflake/ml/model/_client/ops/model_ops.py +159 -40
  25. snowflake/ml/model/_client/sql/model.py +25 -2
  26. snowflake/ml/model/_client/sql/model_version.py +131 -2
  27. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  28. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  29. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  30. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  31. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  32. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
  34. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  36. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  37. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  39. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  40. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
  42. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  43. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  44. snowflake/ml/model/_packager/model_packager.py +2 -5
  45. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  46. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  47. snowflake/ml/model/type_hints.py +21 -2
  48. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  49. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  50. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  51. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  52. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  53. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  54. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
  55. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
  56. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  57. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
  58. snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
  59. snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
  60. snowflake/ml/modeling/cluster/birch.py +248 -175
  61. snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
  62. snowflake/ml/modeling/cluster/dbscan.py +246 -175
  63. snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
  64. snowflake/ml/modeling/cluster/k_means.py +248 -175
  65. snowflake/ml/modeling/cluster/mean_shift.py +246 -175
  66. snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
  67. snowflake/ml/modeling/cluster/optics.py +246 -175
  68. snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
  69. snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
  70. snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
  71. snowflake/ml/modeling/compose/column_transformer.py +248 -175
  72. snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
  73. snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
  74. snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
  75. snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
  76. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
  77. snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
  78. snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
  79. snowflake/ml/modeling/covariance/oas.py +246 -175
  80. snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
  81. snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
  82. snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
  83. snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
  84. snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
  85. snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
  86. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
  87. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
  88. snowflake/ml/modeling/decomposition/pca.py +248 -175
  89. snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
  90. snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
  91. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
  92. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
  93. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
  94. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
  95. snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
  96. snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
  97. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
  98. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
  99. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
  100. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
  101. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
  102. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
  103. snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
  104. snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
  105. snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
  106. snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
  107. snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
  108. snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
  109. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
  110. snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
  111. snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
  112. snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
  113. snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
  114. snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
  115. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
  116. snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
  117. snowflake/ml/modeling/framework/_utils.py +8 -1
  118. snowflake/ml/modeling/framework/base.py +72 -37
  119. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
  120. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
  121. snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
  122. snowflake/ml/modeling/impute/knn_imputer.py +248 -175
  123. snowflake/ml/modeling/impute/missing_indicator.py +248 -175
  124. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
  125. snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
  126. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
  127. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
  128. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
  129. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
  130. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
  131. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
  132. snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
  133. snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
  134. snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
  135. snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
  136. snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
  137. snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
  138. snowflake/ml/modeling/linear_model/lars.py +246 -175
  139. snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
  140. snowflake/ml/modeling/linear_model/lasso.py +246 -175
  141. snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
  142. snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
  143. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
  144. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
  145. snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
  146. snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
  147. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
  148. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
  149. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
  150. snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
  151. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
  152. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
  153. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
  154. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
  155. snowflake/ml/modeling/linear_model/perceptron.py +246 -175
  156. snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
  157. snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
  158. snowflake/ml/modeling/linear_model/ridge.py +246 -175
  159. snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
  160. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
  161. snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
  162. snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
  163. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
  164. snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
  165. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
  166. snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
  167. snowflake/ml/modeling/manifold/isomap.py +248 -175
  168. snowflake/ml/modeling/manifold/mds.py +248 -175
  169. snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
  170. snowflake/ml/modeling/manifold/tsne.py +248 -175
  171. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
  172. snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
  173. snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
  174. snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
  175. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
  176. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
  177. snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
  178. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
  179. snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
  180. snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
  181. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
  182. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
  183. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
  184. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
  185. snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
  186. snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
  187. snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
  188. snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
  189. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
  190. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
  191. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
  192. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
  193. snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
  194. snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
  195. snowflake/ml/modeling/pipeline/pipeline.py +517 -35
  196. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  197. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  198. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  199. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  200. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  201. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  202. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
  203. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  204. snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
  205. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  206. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  207. snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
  208. snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
  209. snowflake/ml/modeling/svm/linear_svc.py +246 -175
  210. snowflake/ml/modeling/svm/linear_svr.py +246 -175
  211. snowflake/ml/modeling/svm/nu_svc.py +246 -175
  212. snowflake/ml/modeling/svm/nu_svr.py +246 -175
  213. snowflake/ml/modeling/svm/svc.py +246 -175
  214. snowflake/ml/modeling/svm/svr.py +246 -175
  215. snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
  216. snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
  217. snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
  218. snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
  219. snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
  220. snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
  221. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
  222. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
  223. snowflake/ml/registry/model_registry.py +3 -149
  224. snowflake/ml/registry/registry.py +1 -1
  225. snowflake/ml/version.py +1 -1
  226. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
  227. snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
  228. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  229. snowflake/ml/registry/_artifact_manager.py +0 -156
  230. snowflake/ml/registry/artifact.py +0 -46
  231. snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
  232. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  233. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  234. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,103 @@
1
+ import dataclasses
2
+ import json
3
+ import typing
4
+ from typing import Any, Dict, List, Optional, Union
5
+
6
# Key added to the serialized DatasetMetadata JSON to record the class name of `properties`,
# so that `from_json` knows which class to rebuild.
_PROPERTY_TYPE_KEY = "$proptype$"
DATASET_SCHEMA_VERSION = "1"


@dataclasses.dataclass(frozen=True)
class FeatureStoreMetadata:
    """
    Feature store metadata.

    Properties:
        spine_query: The input query on source table which will be joined with features.
        serialized_feature_views: A list of serialized feature objects in the feature store.
        spine_timestamp_col: Timestamp column which was used for point-in-time correct feature lookup.
    """

    spine_query: str
    serialized_feature_views: List[str]
    spine_timestamp_col: Optional[str] = None

    def to_json(self) -> str:
        """Serialize all fields of this object to a JSON object string."""
        return json.dumps(dataclasses.asdict(self))

    @classmethod
    def from_json(cls, input_json: Union[Dict[str, Any], str, bytes]) -> "FeatureStoreMetadata":
        """Build an instance from a dict of field values or from a JSON document.

        Args:
            input_json: Either an already-parsed dict, or a JSON string/bytes to be parsed with `json.loads`.

        Returns:
            A FeatureStoreMetadata populated from the given field values.
        """
        if isinstance(input_json, dict):
            return cls(**input_json)
        return cls(**json.loads(input_json))


DatasetPropertiesType = Union[
    FeatureStoreMetadata,
]

# Union[T] gets automatically squashed to T, so default to [T] if get_args() returns empty
_DatasetPropTypes = typing.get_args(DatasetPropertiesType) or [DatasetPropertiesType]
# Lookup from class name (the value stored under _PROPERTY_TYPE_KEY) to the class itself.
_DatasetPropTypeDict = {t.__name__: t for t in _DatasetPropTypes}


@dataclasses.dataclass(frozen=True)
class DatasetMetadata:
    """
    Dataset metadata.

    Properties:
        source_query: The query string used to produce the Dataset.
        owner: The owner of the Dataset.
        exclude_cols: Name of column(s) in dataset to be excluded during training/testing.
            These are typically columns for human inspection such as timestamp or other meta-information.
            Columns included in `label_cols` do not need to be included here.
        label_cols: Name of column(s) in dataset that contains labels.
        properties: Additional metadata properties.
        schema_version: Version of the serialized schema. Always set to DATASET_SCHEMA_VERSION;
            it cannot be passed to the constructor.
    """

    source_query: str
    owner: str
    exclude_cols: Optional[List[str]] = None
    label_cols: Optional[List[str]] = None
    properties: Optional[DatasetPropertiesType] = None
    schema_version: str = dataclasses.field(default=DATASET_SCHEMA_VERSION, init=False)

    def to_json(self) -> str:
        """Serialize this object to a JSON object string.

        When `properties` is set, its class name is stored under `_PROPERTY_TYPE_KEY` so that
        `from_json` can rebuild the right class.

        Returns:
            The JSON string.

        Raises:
            ValueError: `properties` is set to an instance of an unsupported type.
        """
        state_dict = dataclasses.asdict(self)
        if self.properties:
            prop_type = type(self.properties).__name__
            if prop_type not in _DatasetPropTypeDict:
                raise ValueError(
                    f"Unsupported `properties` type={prop_type} (supported={','.join(_DatasetPropTypeDict.keys())})"
                )
            state_dict[_PROPERTY_TYPE_KEY] = prop_type
        return json.dumps(state_dict)

    @classmethod
    def from_json(cls, input_json: Union[Dict[str, Any], str, bytes]) -> "DatasetMetadata":
        """Build an instance from a dict or JSON document in the format produced by `to_json`.

        Args:
            input_json: Either an already-parsed dict, or a JSON string/bytes. A dict argument is
                not modified.

        Returns:
            The deserialized DatasetMetadata.

        Raises:
            ValueError: The input is empty, is not valid JSON, is not a JSON object, or does not
                match the expected schema.
        """
        if not input_json:
            raise ValueError("input_json was empty or None")
        try:
            if isinstance(input_json, dict):
                # Work on a shallow copy: keys are popped and replaced below, and the caller's
                # dict must stay intact (otherwise parsing the same dict twice would fail).
                state_dict: Dict[str, Any] = dict(input_json)
            else:
                state_dict = json.loads(input_json, strict=False)
                if not isinstance(state_dict, dict):
                    # Converted to ValueError by the handler below.
                    raise TypeError(f"Expected a JSON object, got {type(state_dict).__name__}")

            # TODO: Validate schema version
            # `schema_version` is not a constructor argument (init=False), so it has to be removed.
            state_dict.pop("schema_version", DATASET_SCHEMA_VERSION)

            prop_type = state_dict.pop(_PROPERTY_TYPE_KEY, None)
            prop_values = state_dict.get("properties", {})
            if prop_type:
                prop_cls = _DatasetPropTypeDict.get(prop_type, None)
                if prop_cls is None:
                    raise TypeError(
                        f"Unsupported `properties` type={prop_type} (supported={','.join(_DatasetPropTypeDict.keys())})"
                    )
                state_dict["properties"] = prop_cls(**prop_values)
            elif prop_values:
                raise TypeError(f"`properties` provided but missing `{_PROPERTY_TYPE_KEY}`")
            return cls(**state_dict)
        except TypeError as e:
            # Covers unexpected/missing constructor arguments as well as the explicit raises above.
            raise ValueError("Invalid input schema") from e
@@ -0,0 +1,202 @@
1
+ from typing import Any, List
2
+
3
+ import pandas as pd
4
+
5
+ from snowflake import snowpark
6
+ from snowflake.ml._internal import telemetry
7
+ from snowflake.ml._internal.lineage import data_source, dataset_dataframe
8
+ from snowflake.ml._internal.utils import import_utils
9
+ from snowflake.ml.fileset import snowfs
10
+
11
_PROJECT = "Dataset"
_SUBPROJECT = "DatasetReader"
TARGET_FILE_SIZE = 32 * 2**20  # The max file size for data loading.


class DatasetReader:
    """Connector layer that exposes Snowflake Dataset data to application frameworks."""

    @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
    def __init__(
        self,
        session: snowpark.Session,
        sources: List[data_source.DataSource],
    ) -> None:
        """Initialize a DatasetReader object.

        Args:
            session: Snowpark Session to interact with Snowflake backend.
            sources: Data sources to read from.

        Raises:
            ValueError: `sources` arg was empty or null
        """
        if not sources:
            raise ValueError("Invalid input: empty `sources` list not allowed")

        self._session = session
        self._sources = sources
        # File system used for every file listing and read performed by this reader.
        self._fs: snowfs.SnowFileSystem = snowfs.SnowFileSystem(
            snowpark_session=self._session,
            cache_type="bytes",
            block_size=2 * TARGET_FILE_SIZE,
        )

        # Lazily filled by `_list_files`.
        self._files: List[str] = []

    def _list_files(self) -> List[str]:
        """Private helper that lists the files of all data sources and caches a non-empty result."""
        if not self._files:
            listed: List[str] = []
            for src in self._sources:
                listed.extend(self._fs.ls(src.url))  # type: ignore[arg-type]
            # A single sort over the combined listing gives the same consistent ordering.
            listed.sort()
            self._files = listed
        return self._files

    @property
    def data_sources(self) -> List[data_source.DataSource]:
        """The data sources this reader was created with."""
        return self._sources

    @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
    def files(self) -> List[str]:
        """Get the list of remote file paths for the data sources of this reader.

        The file paths follow the snow protocol.

        Returns:
            A list of remote file paths

        Example:
        >>> dsv.files()
        ----
        ["snow://dataset/mydb.myschema.mydataset/versions/test/data_0_0_0.snappy.parquet",
         "snow://dataset/mydb.myschema.mydataset/versions/test/data_0_0_1.snappy.parquet"]
        """
        paths = self._list_files()
        return list(map(self._fs.unstrip_protocol, paths))

    @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
    def filesystem(self) -> snowfs.SnowFileSystem:
        """Return the fsspec FileSystem which can be used to load the paths returned by `files()`"""
        return self._fs

    @telemetry.send_api_usage_telemetry(
        project=_PROJECT,
        subproject=_SUBPROJECT,
        func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
    )
    def to_torch_datapipe(self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True) -> Any:
        """Transform the Snowflake data into a ready-to-use Pytorch datapipe.

        Return a Pytorch datapipe which iterates on rows of data.

        Args:
            batch_size: It specifies the size of each data batch which will be
                yielded in the result datapipe
            shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and
                rows in each file will also be shuffled.
            drop_last_batch: Whether the last batch of data should be dropped. If set to be true,
                then the last batch will get dropped if its size is smaller than the given batch_size.

        Returns:
            A Pytorch iterable datapipe that yields data.

        Examples:
        >>> dp = dataset.to_torch_datapipe(batch_size=1)
        >>> for data in dp:
        >>>     print(data)
        ----
        {'_COL_1':[10]}
        """
        iterable_wrapper_cls = import_utils.import_or_get_dummy("torchdata.datapipes.iter.IterableWrapper")[0]
        datapipe_module = import_utils.import_or_get_dummy("snowflake.ml.fileset.torch_datapipe")[0]

        self._fs.optimize_read(self._list_files())

        file_dp = iterable_wrapper_cls(self._list_files())
        return datapipe_module.ReadAndParseParquet(file_dp, self._fs, batch_size, shuffle, drop_last_batch)

    @telemetry.send_api_usage_telemetry(
        project=_PROJECT,
        subproject=_SUBPROJECT,
        func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
    )
    def to_tf_dataset(self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True) -> Any:
        """Transform the Snowflake data into a ready-to-use TensorFlow tf.data.Dataset.

        Args:
            batch_size: It specifies the size of each data batch which will be
                yielded in the result dataset
            shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and
                rows in each file will also be shuffled.
            drop_last_batch: Whether the last batch of data should be dropped. If set to be true,
                then the last batch will get dropped if its size is smaller than the given batch_size.

        Returns:
            A tf.data.Dataset that yields batched tf.Tensors.

        Examples:
        >>> dp = dataset.to_tf_dataset(batch_size=1)
        >>> for data in dp:
        >>>     print(data)
        ----
        {'_COL_1': <tf.Tensor: shape=(1,), dtype=int64, numpy=[10]>}
        """
        tf_module = import_utils.import_or_get_dummy("snowflake.ml.fileset.tf_dataset")[0]

        self._fs.optimize_read(self._list_files())

        return tf_module.read_and_parse_parquet(self._list_files(), self._fs, batch_size, shuffle, drop_last_batch)

    @telemetry.send_api_usage_telemetry(
        project=_PROJECT,
        subproject=_SUBPROJECT,
        func_params_to_log=["only_feature_cols"],
    )
    def to_snowpark_dataframe(self, only_feature_cols: bool = False) -> snowpark.DataFrame:
        """Convert the data of this reader's sources to a Snowpark DataFrame.

        Args:
            only_feature_cols: If True, drops exclude_cols and label_cols from returned DataFrame.
                The original data is unaffected.

        Returns:
            A Snowpark dataframe that contains the data of all data sources, combined by column name.

        Note: The dataframe generated by this method might not have the same schema as the original one. Specifically,
            - NUMBER type with scale != 0 will become float.
            - Unsupported types (see comments of :func:`Dataset.create_version`) will not have any guarantee.
                For example, an OBJECT column may be scanned back as a STRING column.
        """
        file_path_pattern = ".*data_.*[.]parquet"

        # First read every source into its own DataFrame ...
        frames: List[snowpark.DataFrame] = []
        for src in self._sources:
            frame = self._session.read.option("pattern", file_path_pattern).parquet(src.url)
            if only_feature_cols and src.exclude_cols:
                frame = frame.drop(src.exclude_cols)
            frames.append(frame)

        # ... then fold them together, matching columns by name.
        result = frames[0]
        for frame in frames[1:]:
            result = result.union_all_by_name(frame)
        return dataset_dataframe.DatasetDataFrame.from_dataframe(result, data_sources=self._sources, inplace=True)

    @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
    def to_pandas(self) -> pd.DataFrame:
        """Retrieve the contents of all data sources as a Pandas Dataframe"""
        paths = self._list_files()
        if not paths:
            return pd.DataFrame()  # Return empty DataFrame
        self._fs.optimize_read(paths)

        def _load(path: str) -> pd.DataFrame:
            # Read one parquet file through the reader's file system.
            with self._fs.open(path) as handle:
                return pd.read_parquet(handle)

        frames = [_load(path) for path in paths]
        if len(frames) == 1:
            return frames[0]
        return pd.concat(frames, ignore_index=True, copy=False)