snowflake-ml-python 1.5.3__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. snowflake/cortex/__init__.py +4 -1
  2. snowflake/cortex/_classify_text.py +36 -0
  3. snowflake/cortex/_complete.py +281 -21
  4. snowflake/cortex/_extract_answer.py +0 -1
  5. snowflake/cortex/_sentiment.py +0 -1
  6. snowflake/cortex/_summarize.py +0 -1
  7. snowflake/cortex/_translate.py +0 -1
  8. snowflake/cortex/_util.py +12 -85
  9. snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
  10. snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
  11. snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
  12. snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
  13. snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
  14. snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
  15. snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
  16. snowflake/ml/_internal/telemetry.py +38 -2
  17. snowflake/ml/_internal/utils/identifier.py +14 -0
  18. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
  19. snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
  20. snowflake/ml/data/_internal/ingestor_utils.py +58 -0
  21. snowflake/ml/data/data_connector.py +133 -0
  22. snowflake/ml/data/data_ingestor.py +28 -0
  23. snowflake/ml/data/data_source.py +23 -0
  24. snowflake/ml/dataset/dataset.py +39 -32
  25. snowflake/ml/dataset/dataset_reader.py +18 -118
  26. snowflake/ml/feature_store/access_manager.py +7 -1
  27. snowflake/ml/feature_store/entity.py +19 -2
  28. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
  29. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
  30. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
  31. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
  32. snowflake/ml/feature_store/examples/example_helper.py +240 -0
  33. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
  34. snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
  35. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
  36. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
  37. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
  38. snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
  39. snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
  40. snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
  41. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
  42. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
  43. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
  44. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
  45. snowflake/ml/feature_store/feature_store.py +987 -264
  46. snowflake/ml/feature_store/feature_view.py +228 -13
  47. snowflake/ml/fileset/embedded_stage_fs.py +25 -21
  48. snowflake/ml/fileset/fileset.py +2 -2
  49. snowflake/ml/fileset/snowfs.py +4 -15
  50. snowflake/ml/fileset/stage_fs.py +24 -18
  51. snowflake/ml/lineage/__init__.py +3 -0
  52. snowflake/ml/lineage/lineage_node.py +139 -0
  53. snowflake/ml/model/_client/model/model_impl.py +47 -14
  54. snowflake/ml/model/_client/model/model_version_impl.py +82 -2
  55. snowflake/ml/model/_client/ops/model_ops.py +77 -5
  56. snowflake/ml/model/_client/sql/model.py +1 -0
  57. snowflake/ml/model/_client/sql/model_version.py +45 -2
  58. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  59. snowflake/ml/model/_model_composer/model_composer.py +15 -17
  60. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -17
  61. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  62. snowflake/ml/model/_model_composer/model_method/function_generator.py +20 -4
  63. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
  64. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +55 -0
  65. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -34
  66. snowflake/ml/model/_model_composer/model_method/model_method.py +10 -7
  67. snowflake/ml/model/_packager/model_handlers/_base.py +13 -3
  68. snowflake/ml/model/_packager/model_handlers/_utils.py +59 -1
  69. snowflake/ml/model/_packager/model_handlers/catboost.py +44 -2
  70. snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
  71. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
  72. snowflake/ml/model/_packager/model_handlers/lightgbm.py +70 -2
  73. snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
  74. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
  75. snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
  76. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
  77. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
  78. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
  79. snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
  80. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  81. snowflake/ml/model/_packager/model_handlers/xgboost.py +61 -2
  82. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  83. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
  84. snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
  85. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  86. snowflake/ml/model/_packager/model_packager.py +9 -4
  87. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  88. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
  89. snowflake/ml/model/custom_model.py +22 -2
  90. snowflake/ml/model/model_signature.py +4 -4
  91. snowflake/ml/model/type_hints.py +77 -4
  92. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +3 -1
  93. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
  94. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +1 -0
  95. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +6 -0
  96. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +1 -0
  97. snowflake/ml/modeling/cluster/affinity_propagation.py +4 -2
  98. snowflake/ml/modeling/cluster/agglomerative_clustering.py +4 -2
  99. snowflake/ml/modeling/cluster/birch.py +4 -2
  100. snowflake/ml/modeling/cluster/bisecting_k_means.py +4 -2
  101. snowflake/ml/modeling/cluster/dbscan.py +4 -2
  102. snowflake/ml/modeling/cluster/feature_agglomeration.py +4 -2
  103. snowflake/ml/modeling/cluster/k_means.py +4 -2
  104. snowflake/ml/modeling/cluster/mean_shift.py +4 -2
  105. snowflake/ml/modeling/cluster/mini_batch_k_means.py +4 -2
  106. snowflake/ml/modeling/cluster/optics.py +4 -2
  107. snowflake/ml/modeling/cluster/spectral_biclustering.py +4 -2
  108. snowflake/ml/modeling/cluster/spectral_clustering.py +4 -2
  109. snowflake/ml/modeling/cluster/spectral_coclustering.py +4 -2
  110. snowflake/ml/modeling/compose/column_transformer.py +4 -2
  111. snowflake/ml/modeling/covariance/elliptic_envelope.py +4 -2
  112. snowflake/ml/modeling/covariance/empirical_covariance.py +4 -2
  113. snowflake/ml/modeling/covariance/graphical_lasso.py +4 -2
  114. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +4 -2
  115. snowflake/ml/modeling/covariance/ledoit_wolf.py +4 -2
  116. snowflake/ml/modeling/covariance/min_cov_det.py +4 -2
  117. snowflake/ml/modeling/covariance/oas.py +4 -2
  118. snowflake/ml/modeling/covariance/shrunk_covariance.py +4 -2
  119. snowflake/ml/modeling/decomposition/dictionary_learning.py +4 -2
  120. snowflake/ml/modeling/decomposition/factor_analysis.py +4 -2
  121. snowflake/ml/modeling/decomposition/fast_ica.py +4 -2
  122. snowflake/ml/modeling/decomposition/incremental_pca.py +4 -2
  123. snowflake/ml/modeling/decomposition/kernel_pca.py +4 -2
  124. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +4 -2
  125. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +4 -2
  126. snowflake/ml/modeling/decomposition/pca.py +4 -2
  127. snowflake/ml/modeling/decomposition/sparse_pca.py +4 -2
  128. snowflake/ml/modeling/decomposition/truncated_svd.py +4 -2
  129. snowflake/ml/modeling/ensemble/isolation_forest.py +4 -2
  130. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +4 -2
  131. snowflake/ml/modeling/feature_selection/variance_threshold.py +4 -2
  132. snowflake/ml/modeling/impute/iterative_imputer.py +4 -2
  133. snowflake/ml/modeling/impute/knn_imputer.py +4 -2
  134. snowflake/ml/modeling/impute/missing_indicator.py +4 -2
  135. snowflake/ml/modeling/impute/simple_imputer.py +26 -0
  136. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +4 -2
  137. snowflake/ml/modeling/kernel_approximation/nystroem.py +4 -2
  138. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +4 -2
  139. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +4 -2
  140. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +4 -2
  141. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +4 -2
  142. snowflake/ml/modeling/manifold/isomap.py +4 -2
  143. snowflake/ml/modeling/manifold/mds.py +4 -2
  144. snowflake/ml/modeling/manifold/spectral_embedding.py +4 -2
  145. snowflake/ml/modeling/manifold/tsne.py +4 -2
  146. snowflake/ml/modeling/metrics/ranking.py +3 -0
  147. snowflake/ml/modeling/metrics/regression.py +3 -0
  148. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +4 -2
  149. snowflake/ml/modeling/mixture/gaussian_mixture.py +4 -2
  150. snowflake/ml/modeling/neighbors/kernel_density.py +4 -2
  151. snowflake/ml/modeling/neighbors/local_outlier_factor.py +4 -2
  152. snowflake/ml/modeling/neighbors/nearest_neighbors.py +4 -2
  153. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +4 -2
  154. snowflake/ml/modeling/pipeline/pipeline.py +5 -4
  155. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +43 -9
  156. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +36 -8
  157. snowflake/ml/modeling/preprocessing/polynomial_features.py +4 -2
  158. snowflake/ml/registry/_manager/model_manager.py +16 -3
  159. snowflake/ml/registry/registry.py +100 -13
  160. snowflake/ml/version.py +1 -1
  161. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +81 -7
  162. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +165 -139
  163. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
  164. snowflake/ml/_internal/lineage/data_source.py +0 -10
  165. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
  166. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/snowflake/ml/data/data_connector.py
@@ -0,0 +1,133 @@
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Type, TypeVar
+
+import numpy.typing as npt
+
+from snowflake import snowpark
+from snowflake.ml._internal import telemetry
+from snowflake.ml.data import data_ingestor, data_source
+from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor as DefaultIngestor
+
+if TYPE_CHECKING:
+    import pandas as pd
+    import tensorflow as tf
+    from torch.utils import data as torch_data
+
+    # This module can't actually depend on dataset to avoid a circular dependency
+    # Dataset -> DatasetReader -> DataConnector -!-> Dataset
+    from snowflake.ml import dataset
+
+_PROJECT = "DataConnector"
+
+DataConnectorType = TypeVar("DataConnectorType", bound="DataConnector")
+
+
+class DataConnector:
+    """Snowflake data reader which provides application integration connectors"""
+
+    def __init__(
+        self,
+        ingestor: data_ingestor.DataIngestor,
+    ) -> None:
+        self._ingestor = ingestor
+
+    @classmethod
+    def from_dataframe(cls: Type[DataConnectorType], df: snowpark.DataFrame, **kwargs: Any) -> DataConnectorType:
+        if len(df.queries["queries"]) != 1 or len(df.queries["post_actions"]) != 0:
+            raise ValueError("DataFrames with multiple queries and/or post-actions not supported")
+        source = data_source.DataFrameInfo(df.queries["queries"][0])
+        assert df._session is not None
+        ingestor = DefaultIngestor(df._session, [source])
+        return cls(ingestor, **kwargs)
+
+    @classmethod
+    def from_dataset(cls: Type[DataConnectorType], ds: "dataset.Dataset", **kwargs: Any) -> DataConnectorType:
+        dsv = ds.selected_version
+        assert dsv is not None
+        source = data_source.DatasetInfo(
+            ds.fully_qualified_name, dsv.name, dsv.url(), exclude_cols=(dsv.label_cols + dsv.exclude_cols)
+        )
+        ingestor = DefaultIngestor(ds._session, [source])
+        return cls(ingestor, **kwargs)
+
+    @property
+    def data_sources(self) -> List[data_source.DataSource]:
+        return self._ingestor.data_sources
+
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda self: type(self).__name__,
+        func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
+    )
+    def to_tf_dataset(
+        self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True
+    ) -> "tf.data.Dataset":
+        """Transform the Snowflake data into a ready-to-use TensorFlow tf.data.Dataset.
+
+        Args:
+            batch_size: It specifies the size of each data batch which will be
+                yield in the result datapipe
+            shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and
+                rows in each file will also be shuffled.
+            drop_last_batch: Whether the last batch of data should be dropped. If set to be true,
+                then the last batch will get dropped if its size is smaller than the given batch_size.
+
+        Returns:
+            A tf.data.Dataset that yields batched tf.Tensors.
+        """
+        import tensorflow as tf
+
+        def generator() -> Generator[Dict[str, npt.NDArray[Any]], None, None]:
+            yield from self._ingestor.to_batches(batch_size, shuffle, drop_last_batch)
+
+        # Derive TensorFlow signature
+        first_batch = next(self._ingestor.to_batches(1, shuffle=False, drop_last_batch=False))
+        tf_signature = {
+            k: tf.TensorSpec(shape=(None,), dtype=tf.dtypes.as_dtype(v.dtype), name=k) for k, v in first_batch.items()
+        }
+
+        return tf.data.Dataset.from_generator(generator, output_signature=tf_signature)
+
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda self: type(self).__name__,
+        func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
+    )
+    def to_torch_datapipe(
+        self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True
+    ) -> "torch_data.IterDataPipe":  # type: ignore[type-arg]
+        """Transform the Snowflake data into a ready-to-use Pytorch datapipe.
+
+        Return a Pytorch datapipe which iterates on rows of data.
+
+        Args:
+            batch_size: It specifies the size of each data batch which will be
+                yield in the result datapipe
+            shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and
+                rows in each file will also be shuffled.
+            drop_last_batch: Whether the last batch of data should be dropped. If set to be true,
+                then the last batch will get dropped if its size is smaller than the given batch_size.
+
+        Returns:
+            A Pytorch iterable datapipe that yield data.
+        """
+        from torch.utils.data.datapipes import iter as torch_iter
+
+        return torch_iter.IterableWrapper(  # type: ignore[no-untyped-call]
+            self._ingestor.to_batches(batch_size, shuffle, drop_last_batch)
+        )
+
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda self: type(self).__name__,
+        func_params_to_log=["limit"],
+    )
+    def to_pandas(self, limit: Optional[int] = None) -> "pd.DataFrame":
+        """Retrieve the Snowflake data as a Pandas DataFrame.
+
+        Args:
+            limit: If specified, the maximum number of rows to load into the DataFrame.
+
+        Returns:
+            A Pandas DataFrame.
+        """
+        return self._ingestor.to_pandas(limit)
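
For orientation, a minimal usage sketch of the DataConnector API added above. The session and table name are placeholders; only `from_dataframe`, `to_pandas`, `to_torch_datapipe`, and `to_tf_dataset` come from the diff itself:

    from snowflake.ml.data.data_connector import DataConnector

    # `session` is assumed to be an existing snowpark.Session; the table name is hypothetical.
    df = session.table("MY_DB.MY_SCHEMA.MY_TRAINING_TABLE")

    dc = DataConnector.from_dataframe(df)  # single-query DataFrames only, per the check above
    pdf = dc.to_pandas(limit=10_000)  # at most 10k rows into a local pandas DataFrame
    pipe = dc.to_torch_datapipe(batch_size=32, shuffle=True)  # iterable of dict-of-ndarray batches
    tfds = dc.to_tf_dataset(batch_size=32)  # tf.data.Dataset of batched tf.Tensors
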
--- /dev/null
+++ b/snowflake/ml/data/data_ingestor.py
@@ -0,0 +1,28 @@
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Protocol, TypeVar
+
+from numpy import typing as npt
+
+from snowflake.ml.data import data_source
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+
+DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
+
+
+class DataIngestor(Protocol):
+    @property
+    def data_sources(self) -> List[data_source.DataSource]:
+        raise NotImplementedError
+
+    def to_batches(
+        self,
+        batch_size: int,
+        shuffle: bool = True,
+        drop_last_batch: bool = True,
+    ) -> Iterator[Dict[str, npt.NDArray[Any]]]:
+        raise NotImplementedError
+
+    def to_pandas(self, limit: Optional[int] = None) -> "pd.DataFrame":
+        raise NotImplementedError
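
Because `DataIngestor` is a `typing.Protocol`, any object exposing these three members satisfies it structurally. A toy in-memory ingestor, purely illustrative and not part of the package, might look like this:

    from typing import Any, Dict, Iterator, List, Optional

    import numpy.typing as npt
    import pandas as pd

    from snowflake.ml.data import data_source


    class InMemoryIngestor:
        """Toy ingestor backed by a pandas DataFrame; structurally satisfies DataIngestor."""

        def __init__(self, df: pd.DataFrame) -> None:
            self._df = df

        @property
        def data_sources(self) -> List[data_source.DataSource]:
            return []  # no Snowflake-backed sources in this toy example

        def to_batches(
            self, batch_size: int, shuffle: bool = True, drop_last_batch: bool = True
        ) -> Iterator[Dict[str, npt.NDArray[Any]]]:
            df = self._df.sample(frac=1.0) if shuffle else self._df
            stop = len(df) - (len(df) % batch_size) if drop_last_batch else len(df)
            for start in range(0, stop, batch_size):
                chunk = df.iloc[start : start + batch_size]
                yield {col: chunk[col].to_numpy() for col in chunk.columns}

        def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
            return self._df if limit is None else self._df.head(limit)

An instance of such a class could be handed straight to DataConnector, since the connector relies only on this protocol.
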
--- /dev/null
+++ b/snowflake/ml/data/data_source.py
@@ -0,0 +1,23 @@
+import dataclasses
+from typing import List, Optional, Union
+
+
+@dataclasses.dataclass(frozen=True)
+class DataFrameInfo:
+    """Serializable information from Snowpark DataFrames"""
+
+    sql: str
+    query_id: Optional[str] = None
+
+
+@dataclasses.dataclass(frozen=True)
+class DatasetInfo:
+    """Serializable information from SnowML Datasets"""
+
+    fully_qualified_name: str
+    version: str
+    url: Optional[str] = None
+    exclude_cols: Optional[List[str]] = None
+
+
+DataSource = Union[DataFrameInfo, DatasetInfo, str]
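
These are plain frozen dataclasses; the sketch below is roughly what `DataConnector.from_dataframe` and `from_dataset` construct internally (the SQL text and identifiers are placeholders):

    from snowflake.ml.data import data_source

    df_info = data_source.DataFrameInfo(sql="SELECT * FROM MY_DB.MY_SCHEMA.MY_TABLE")
    ds_info = data_source.DatasetInfo(
        fully_qualified_name="MY_DB.MY_SCHEMA.MY_DATASET",
        version="v1",
        exclude_cols=["LABEL"],
    )
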
--- a/snowflake/ml/dataset/dataset.py
+++ b/snowflake/ml/dataset/dataset.py
@@ -11,7 +11,6 @@ from snowflake.ml._internal.exceptions import (
     error_codes,
     exceptions as snowml_exceptions,
 )
-from snowflake.ml._internal.lineage import data_source
 from snowflake.ml._internal.utils import (
     formatting,
     identifier,
@@ -19,6 +18,7 @@ from snowflake.ml._internal.utils import (
     snowpark_dataframe_utils,
 )
 from snowflake.ml.dataset import dataset_metadata, dataset_reader
+from snowflake.ml.lineage import lineage_node
 from snowflake.snowpark import exceptions as snowpark_exceptions, functions
 
 _PROJECT = "Dataset"
@@ -125,7 +125,7 @@ class DatasetVersion:
         return f"{self.__class__.__name__}(dataset='{self._parent.fully_qualified_name}', version='{self.name}')"
 
 
-class Dataset:
+class Dataset(lineage_node.LineageNode):
     """Represents a Snowflake Dataset which is organized into versions."""
 
     @telemetry.send_api_usage_telemetry(project=_PROJECT)
@@ -138,18 +138,31 @@ class Dataset:
         selected_version: Optional[str] = None,
     ) -> None:
         """Initialize a lazily evaluated Dataset object"""
-        self._session = session
         self._db = database
         self._schema = schema
         self._name = name
-        self._fully_qualified_name = identifier.get_schema_level_object_identifier(database, schema, name)
+
+        super().__init__(
+            session,
+            identifier.get_schema_level_object_identifier(database, schema, name),
+            domain="dataset",
+            version=selected_version,
+        )
 
         self._version = DatasetVersion(self, selected_version) if selected_version else None
         self._reader: Optional[dataset_reader.DatasetReader] = None
 
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}(\n"
+            f" name='{self._lineage_node_name}',\n"
+            f" version='{self._version._version if self._version else None}',\n"
+            f")"
+        )
+
     @property
     def fully_qualified_name(self) -> str:
-        return self._fully_qualified_name
+        return self._lineage_node_name
 
     @property
     def selected_version(self) -> Optional[DatasetVersion]:
@@ -163,18 +176,7 @@ class Dataset:
                 original_exception=RuntimeError("No Dataset version selected."),
             )
         if self._reader is None:
-            v = self.selected_version
-            self._reader = dataset_reader.DatasetReader(
-                self._session,
-                [
-                    data_source.DataSource(
-                        fully_qualified_name=self._fully_qualified_name,
-                        version=v.name,
-                        url=v.url(),
-                        exclude_cols=(v.label_cols + v.exclude_cols),
-                    )
-                ],
-            )
+            self._reader = dataset_reader.DatasetReader.from_dataset(self, snowpark_session=self._session)
         return self._reader
 
     @staticmethod
@@ -230,9 +232,8 @@ class Dataset:
         try:
             session.sql(query).collect(statement_params=_TELEMETRY_STATEMENT_PARAMS)
             return Dataset(session, db, schema, ds_name)
-        except snowpark_exceptions.SnowparkClientException as e:
-            # Snowpark wraps the Python Connector error code in the head of the error message.
-            if e.message.startswith(dataset_errors.ERRNO_OBJECT_ALREADY_EXISTS):
+        except snowpark_exceptions.SnowparkSQLException as e:
+            if e.sql_error_code == dataset_errors.ERRNO_OBJECT_ALREADY_EXISTS:
                 raise snowml_exceptions.SnowflakeMLException(
                     error_code=error_codes.OBJECT_ALREADY_EXISTS,
                     original_exception=dataset_errors.DatasetExistError(
@@ -296,7 +297,7 @@ class Dataset:
         Raises:
             SnowflakeMLException: The Dataset no longer exists.
             SnowflakeMLException: The specified Dataset version already exists.
-            snowpark_exceptions.SnowparkClientException: An error occurred during Dataset creation.
+            snowpark_exceptions.SnowparkSQLException: An error occurred during Dataset creation.
 
         Note: During the generation of stage files, data casting will occur. The casting rules are as follows::
             - Data casting:
@@ -321,7 +322,8 @@
                 - DateType(DATE): Not supported. A warning will be logged.
                 - VariantType(VARIANT): Not supported. A warning will be logged.
         """
-        casted_df = snowpark_dataframe_utils.cast_snowpark_dataframe(input_dataframe)
+        cast_ignore_cols = (exclude_cols or []) + (label_cols or [])
+        casted_df = snowpark_dataframe_utils.cast_snowpark_dataframe(input_dataframe, ignore_columns=cast_ignore_cols)
 
         if shuffle:
             casted_df = casted_df.order_by(functions.random())
@@ -367,19 +369,19 @@
 
             return Dataset(self._session, self._db, self._schema, self._name, version)
 
-        except snowpark_exceptions.SnowparkClientException as e:
-            if e.message.startswith(dataset_errors.ERRNO_DATASET_NOT_EXIST):
+        except snowpark_exceptions.SnowparkSQLException as e:
+            if e.sql_error_code == dataset_errors.ERRNO_DATASET_NOT_EXIST:
                 raise snowml_exceptions.SnowflakeMLException(
                     error_code=error_codes.NOT_FOUND,
                     original_exception=dataset_errors.DatasetNotExistError(
                         dataset_error_messages.DATASET_NOT_EXIST.format(self.fully_qualified_name)
                     ),
                 ) from e
-            elif (
-                e.message.startswith(dataset_errors.ERRNO_DATASET_VERSION_ALREADY_EXISTS)
-                or e.message.startswith(dataset_errors.ERRNO_VERSION_ALREADY_EXISTS)
-                or e.message.startswith(dataset_errors.ERRNO_FILES_ALREADY_EXISTING)
-            ):
+            elif e.sql_error_code in {
+                dataset_errors.ERRNO_DATASET_VERSION_ALREADY_EXISTS,
+                dataset_errors.ERRNO_VERSION_ALREADY_EXISTS,
+                dataset_errors.ERRNO_FILES_ALREADY_EXISTING,
+            }:
                 raise snowml_exceptions.SnowflakeMLException(
                     error_code=error_codes.OBJECT_ALREADY_EXISTS,
                     original_exception=dataset_errors.DatasetExistError(
@@ -435,9 +437,8 @@
                 .has_column(_DATASET_VERSION_NAME_COL, allow_empty=True)
                 .validate()
             )
-        except snowpark_exceptions.SnowparkClientException as e:
-            # Snowpark wraps the Python Connector error code in the head of the error message.
-            if e.message.startswith(dataset_errors.ERRNO_OBJECT_NOT_EXIST):
+        except snowpark_exceptions.SnowparkSQLException as e:
+            if e.sql_error_code == dataset_errors.ERRNO_OBJECT_NOT_EXIST:
                 raise snowml_exceptions.SnowflakeMLException(
                     error_code=error_codes.NOT_FOUND,
                     original_exception=dataset_errors.DatasetNotExistError(
@@ -459,6 +460,12 @@
                 ),
             )
 
+    @staticmethod
+    def _load_from_lineage_node(session: snowpark.Session, name: str, version: str) -> "Dataset":
+        return Dataset.load(session, name).select_version(version)
+
+
+lineage_node.DOMAIN_LINEAGE_REGISTRY["dataset"] = Dataset
 
 
 # Utility methods
--- a/snowflake/ml/dataset/dataset_reader.py
+++ b/snowflake/ml/dataset/dataset_reader.py
@@ -1,48 +1,31 @@
-from typing import Any, List
-
-import pandas as pd
-from pyarrow import parquet as pq
+from typing import List, Optional
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
-from snowflake.ml._internal.lineage import data_source, lineage_utils
-from snowflake.ml._internal.utils import import_utils
+from snowflake.ml._internal.lineage import lineage_utils
+from snowflake.ml.data import data_connector, data_ingestor, data_source
+from snowflake.ml.data._internal import ingestor_utils
 from snowflake.ml.fileset import snowfs
 
 _PROJECT = "Dataset"
 _SUBPROJECT = "DatasetReader"
-TARGET_FILE_SIZE = 32 * 2**20  # The max file size for data loading.
 
 
-class DatasetReader:
+class DatasetReader(data_connector.DataConnector):
     """Snowflake Dataset abstraction which provides application integration connectors"""
 
     @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
     def __init__(
         self,
-        session: snowpark.Session,
-        sources: List[data_source.DataSource],
+        ingestor: data_ingestor.DataIngestor,
+        *,
+        snowpark_session: snowpark.Session,
     ) -> None:
-        """Initialize a DatasetVersion object.
+        super().__init__(ingestor)
 
-        Args:
-            session: Snowpark Session to interact with Snowflake backend.
-            sources: Data sources to read from.
-
-        Raises:
-            ValueError: `sources` arg was empty or null
-        """
-        if not sources:
-            raise ValueError("Invalid input: empty `sources` list not allowed")
-        self._session = session
-        self._sources = sources
-        self._fs: snowfs.SnowFileSystem = snowfs.SnowFileSystem(
-            snowpark_session=self._session,
-            cache_type="bytes",
-            block_size=2 * TARGET_FILE_SIZE,
-        )
-
-        self._files: List[str] = []
+        self._session: snowpark.Session = snowpark_session
+        self._fs: snowfs.SnowFileSystem = ingestor_utils.get_dataset_filesystem(self._session)
+        self._files: Optional[List[str]] = None
 
     def _list_files(self) -> List[str]:
         """Private helper function that lists all files in this DatasetVersion and caches the results."""
@@ -50,18 +33,14 @@ class DatasetReader:
             return self._files
 
         files: List[str] = []
-        for source in self._sources:
-            # Sort within each source for consistent ordering
-            files.extend(sorted(self._fs.ls(source.url)))  # type: ignore[arg-type]
+        for source in self.data_sources:
+            assert isinstance(source, data_source.DatasetInfo)
+            files.extend(ingestor_utils.get_dataset_files(self._session, source, filesystem=self._fs))
         files.sort()
 
         self._files = files
         return self._files
 
-    @property
-    def data_sources(self) -> List[data_source.DataSource]:
-        return self._sources
-
     @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
     def files(self) -> List[str]:
         """Get the list of remote file paths for the current DatasetVersion.
@@ -85,76 +64,6 @@ class DatasetReader:
         """Return an fsspec FileSystem which can be used to load the DatasetVersion's `files()`"""
         return self._fs
 
-    @telemetry.send_api_usage_telemetry(
-        project=_PROJECT,
-        subproject=_SUBPROJECT,
-        func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
-    )
-    def to_torch_datapipe(self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True) -> Any:
-        """Transform the Snowflake data into a ready-to-use Pytorch datapipe.
-
-        Return a Pytorch datapipe which iterates on rows of data.
-
-        Args:
-            batch_size: It specifies the size of each data batch which will be
-                yield in the result datapipe
-            shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and
-                rows in each file will also be shuffled.
-            drop_last_batch: Whether the last batch of data should be dropped. If set to be true,
-                then the last batch will get dropped if its size is smaller than the given batch_size.
-
-        Returns:
-            A Pytorch iterable datapipe that yield data.
-
-        Examples:
-            >>> dp = dataset.to_torch_datapipe(batch_size=1)
-            >>> for data in dp:
-            >>>     print(data)
-            ----
-            {'_COL_1':[10]}
-        """
-        IterableWrapper, _ = import_utils.import_or_get_dummy("torchdata.datapipes.iter.IterableWrapper")
-        torch_datapipe_module, _ = import_utils.import_or_get_dummy("snowflake.ml.fileset.torch_datapipe")
-
-        self._fs.optimize_read(self._list_files())
-
-        input_dp = IterableWrapper(self._list_files())
-        return torch_datapipe_module.ReadAndParseParquet(input_dp, self._fs, batch_size, shuffle, drop_last_batch)
-
-    @telemetry.send_api_usage_telemetry(
-        project=_PROJECT,
-        subproject=_SUBPROJECT,
-        func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
-    )
-    def to_tf_dataset(self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True) -> Any:
-        """Transform the Snowflake data into a ready-to-use TensorFlow tf.data.Dataset.
-
-        Args:
-            batch_size: It specifies the size of each data batch which will be
-                yield in the result datapipe
-            shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and
-                rows in each file will also be shuffled.
-            drop_last_batch: Whether the last batch of data should be dropped. If set to be true,
-                then the last batch will get dropped if its size is smaller than the given batch_size.
-
-        Returns:
-            A tf.data.Dataset that yields batched tf.Tensors.
-
-        Examples:
-            >>> dp = dataset.to_tf_dataset(batch_size=1)
-            >>> for data in dp:
-            >>>     print(data)
-            ----
-            {'_COL_1': <tf.Tensor: shape=(1,), dtype=int64, numpy=[10]>}
-        """
-        tf_dataset_module, _ = import_utils.import_or_get_dummy("snowflake.ml.fileset.tf_dataset")
-
-        self._fs.optimize_read(self._list_files())
-
-        return tf_dataset_module.read_and_parse_parquet(
-            self._list_files(), self._fs, batch_size, shuffle, drop_last_batch
-        )
-
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
         subproject=_SUBPROJECT,
@@ -177,7 +86,8 @@ class DatasetReader:
         """
         file_path_pattern = ".*data_.*[.]parquet"
         dfs: List[snowpark.DataFrame] = []
-        for source in self._sources:
+        for source in self.data_sources:
+            assert isinstance(source, data_source.DatasetInfo) and source.url is not None
             df = self._session.read.option("pattern", file_path_pattern).parquet(source.url)
             if only_feature_cols and source.exclude_cols:
                 df = df.drop(source.exclude_cols)
@@ -186,14 +96,4 @@
         combined_df = dfs[0]
         for df in dfs[1:]:
             combined_df = combined_df.union_all_by_name(df)
-        return lineage_utils.patch_dataframe(combined_df, data_sources=self._sources, inplace=True)
-
-    @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
-    def to_pandas(self) -> pd.DataFrame:
-        """Retrieve the DatasetVersion contents as a Pandas Dataframe"""
-        files = self._list_files()
-        if not files:
-            return pd.DataFrame()  # Return empty DataFrame
-        self._fs.optimize_read(files)
-        pd_ds = pq.ParquetDataset(files, filesystem=self._fs)
-        return pd_ds.read_pandas().to_pandas()
+        return lineage_utils.patch_dataframe(combined_df, data_sources=self.data_sources, inplace=True)
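
Putting the dataset.py and dataset_reader.py changes together: DatasetReader now subclasses DataConnector, so the pandas/Torch/TensorFlow conversions come from the shared connector instead of the removed fileset-based implementations. A hedged end-to-end sketch using only names that appear in the hunks above (session, database, schema, and version strings are placeholders):

    from snowflake.ml.dataset import dataset, dataset_reader

    ds = dataset.Dataset.load(session, "MY_DB.MY_SCHEMA.MY_DATASET").select_version("v1")

    # Roughly what the Dataset reader property now does internally:
    reader = dataset_reader.DatasetReader.from_dataset(ds, snowpark_session=session)

    files = reader.files()                           # still DatasetReader-specific
    pdf = reader.to_pandas(limit=1_000)              # inherited from DataConnector
    pipe = reader.to_torch_datapipe(batch_size=32)   # inherited from DataConnector
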
--- a/snowflake/ml/feature_store/access_manager.py
+++ b/snowflake/ml/feature_store/access_manager.py
@@ -273,7 +273,13 @@
     assert current_role is not None  # to make mypy happy
     try:
         session.use_role(producer_role)
-        fs = FeatureStore(session, database, schema, warehouse, creation_mode=CreationMode.CREATE_IF_NOT_EXIST)
+        fs = FeatureStore(
+            session,
+            database,
+            schema,
+            default_warehouse=warehouse,
+            creation_mode=CreationMode.CREATE_IF_NOT_EXIST,
+        )
     finally:
         session.use_role(current_role)
 
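
The updated call site suggests that in 1.6.0 the warehouse is passed to `FeatureStore` as the keyword argument `default_warehouse` rather than positionally. A minimal sketch of the new constructor call; all identifiers below are placeholders:

    from snowflake.ml.feature_store import CreationMode, FeatureStore

    fs = FeatureStore(
        session,
        "MY_DB",
        "MY_FEATURE_STORE_SCHEMA",
        default_warehouse="MY_WH",
        creation_mode=CreationMode.CREATE_IF_NOT_EXIST,
    )
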
--- a/snowflake/ml/feature_store/entity.py
+++ b/snowflake/ml/feature_store/entity.py
@@ -22,7 +22,7 @@ class Entity:
     It can also be used for FeatureView search and lineage tracking.
     """
 
-    def __init__(self, name: str, join_keys: List[str], desc: str = "") -> None:
+    def __init__(self, name: str, join_keys: List[str], *, desc: str = "") -> None:
         """
         Creates an Entity instance.
 
@@ -30,6 +30,23 @@
             name: name of the Entity.
             join_keys: join keys associated with a FeatureView, used for feature retrieval.
             desc: description of the Entity.
+
+        Example::
+
+            >>> fs = FeatureStore(...)
+            >>> e_1 = Entity(
+            ...     name="my_entity",
+            ...     join_keys=['col_1'],
+            ...     desc='My first entity.'
+            ... )
+            >>> fs.register_entity(e_1)
+            >>> fs.list_entities().show()
+            -----------------------------------------------------------
+            |"NAME"     |"JOIN_KEYS"  |"DESC"            |"OWNER"     |
+            -----------------------------------------------------------
+            |MY_ENTITY  |["COL_1"]    |My first entity.  |REGTEST_RL  |
+            -----------------------------------------------------------
+
         """
         self._validate(name, join_keys)
 
@@ -65,7 +82,7 @@
 
     @staticmethod
    def _construct_entity(name: str, join_keys: List[str], desc: str, owner: str) -> "Entity":
-        e = Entity(name, join_keys, desc)
+        e = Entity(name, join_keys, desc=desc)
         e.owner = owner
         return e
 
--- /dev/null
+++ b/snowflake/ml/feature_store/examples/citibike_trip_features/entities.py
@@ -0,0 +1,20 @@
+from typing import List
+
+from snowflake.ml.feature_store import Entity
+
+end_station_id = Entity(
+    name="end_station_id",
+    join_keys=["end_station_id"],
+    desc="The id of an end station.",
+)
+
+trip_id = Entity(
+    name="trip_id",
+    join_keys=["trip_id"],
+    desc="The id of a trip.",
+)
+
+
+# This will be invoked by example_helper.py. Do not change function name.
+def get_all_entities() -> List[Entity]:
+    return [end_station_id, trip_id]
--- /dev/null
+++ b/snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py
@@ -0,0 +1,31 @@
+from typing import List
+
+from snowflake.ml.feature_store import FeatureView
+from snowflake.ml.feature_store.examples.citibike_trip_features.entities import (
+    end_station_id,
+)
+from snowflake.snowpark import DataFrame, Session
+
+
+# This function will be invoked by example_helper.py. Do not change the name.
+def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
+    """Create a feature view about trip station."""
+    query = session.sql(
+        f"""
+        select
+            end_station_id,
+            count(end_station_id) as f_count_1d,
+            avg(end_station_latitude) as f_avg_latitude_1d,
+            avg(end_station_longitude) as f_avg_longtitude_1d
+        from {source_tables[0]}
+        group by end_station_id
+        """
+    )
+
+    return FeatureView(
+        name="f_station_1d",  # name of feature view
+        entities=[end_station_id],  # entities
+        feature_df=query,  # definition query
+        refresh_freq="1d",  # refresh frequency. '1d' means it refreshes everyday
+        desc="Station features refreshed every day.",
+    )
--- /dev/null
+++ b/snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py
@@ -0,0 +1,24 @@
+from typing import List
+
+from snowflake.ml.feature_store import FeatureView
+from snowflake.ml.feature_store.examples.citibike_trip_features.entities import trip_id
+from snowflake.snowpark import DataFrame, Session, functions as F
+
+
+# This function will be invoked by example_helper.py. Do not change the name.
+def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
+    """Create a feature view about trip."""
+    feature_df = source_dfs[0].select(
+        "trip_id",
+        F.col("birth_year").alias("f_birth_year"),
+        F.col("gender").alias("f_gender"),
+        F.col("bikeid").alias("f_bikeid"),
+    )
+
+    return FeatureView(
+        name="f_trip",  # name of feature view
+        entities=[trip_id],  # entities
+        feature_df=feature_df,  # definition query
+        refresh_freq=None,  # refresh frequency. None indicates it never refresh
+        desc="Static trip features",
+    )
--- /dev/null
+++ b/snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml
@@ -0,0 +1,4 @@
+---
+source_data: citibike_trips
+label_columns: tripduration
+add_id_column: trip_id