snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +2 -0
- snowflake/cortex/_classify_text.py +36 -0
- snowflake/cortex/_complete.py +66 -35
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
- snowflake/ml/_internal/telemetry.py +26 -2
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/data/_internal/arrow_ingestor.py +284 -0
- snowflake/ml/data/data_connector.py +186 -0
- snowflake/ml/data/data_ingestor.py +45 -0
- snowflake/ml/data/data_source.py +23 -0
- snowflake/ml/data/ingestor_utils.py +62 -0
- snowflake/ml/data/torch_dataset.py +33 -0
- snowflake/ml/dataset/dataset.py +1 -13
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +23 -117
- snowflake/ml/feature_store/access_manager.py +7 -1
- snowflake/ml/feature_store/entity.py +19 -2
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +37 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +30 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/example_helper.py +278 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +44 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +9 -0
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
- snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
- snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +36 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +24 -0
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +8 -0
- snowflake/ml/feature_store/feature_store.py +637 -76
- snowflake/ml/feature_store/feature_view.py +316 -9
- snowflake/ml/fileset/stage_fs.py +18 -10
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +171 -20
- snowflake/ml/model/_client/ops/model_ops.py +105 -27
- snowflake/ml/model/_client/ops/service_ops.py +121 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +129 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
- snowflake/ml/model/_model_composer/model_composer.py +14 -14
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +33 -17
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +30 -3
- snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +52 -3
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +80 -3
- snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
- snowflake/ml/model/_packager/model_handlers/xgboost.py +71 -3
- snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
- snowflake/ml/model/_packager/model_packager.py +2 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +7 -7
- snowflake/ml/model/model_signature.py +4 -4
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/impute/simple_imputer.py +26 -0
- snowflake/ml/modeling/pipeline/pipeline.py +7 -4
- snowflake/ml/registry/_manager/model_manager.py +16 -2
- snowflake/ml/registry/registry.py +100 -13
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +81 -2
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +99 -66
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/lineage/data_source.py +0 -10
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
snowflake/ml/data/_internal/arrow_ingestor.py
ADDED
@@ -0,0 +1,284 @@
+import collections
+import logging
+import os
+import time
+from typing import Any, Deque, Dict, Iterator, List, Optional, Union
+
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
+import pyarrow as pa
+import pyarrow.dataset as pds
+
+from snowflake import snowpark
+from snowflake.connector import result_batch
+from snowflake.ml.data import data_ingestor, data_source, ingestor_utils
+
+_EMPTY_RECORD_BATCH = pa.RecordBatch.from_arrays([], [])
+
+# The row count for batches read from PyArrow Dataset. This number should be large enough so that
+# dataset.to_batches() would read in a very large portion of, if not entirely, a parquet file.
+_DEFAULT_DATASET_BATCH_SIZE = 1000000
+
+
+class _RecordBatchesBuffer:
+    """A queue that stores record batches and tracks the total num of rows in it."""
+
+    def __init__(self) -> None:
+        self.buffer: Deque[pa.RecordBatch] = collections.deque()
+        self.num_rows = 0
+
+    def append(self, rb: pa.RecordBatch) -> None:
+        self.buffer.append(rb)
+        self.num_rows += rb.num_rows
+
+    def appendleft(self, rb: pa.RecordBatch) -> None:
+        self.buffer.appendleft(rb)
+        self.num_rows += rb.num_rows
+
+    def popleft(self) -> pa.RecordBatch:
+        popped = self.buffer.popleft()
+        self.num_rows -= popped.num_rows
+        return popped
+
+
+class ArrowIngestor(data_ingestor.DataIngestor):
+    """Read and parse the data sources into an Arrow Dataset and yield batched numpy array in dict."""
+
+    def __init__(
+        self,
+        session: snowpark.Session,
+        data_sources: List[data_source.DataSource],
+        format: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Args:
+            session: The Snowpark Session to use.
+            data_sources: List of data sources to ingest.
+            format: Currently "parquet", "ipc"/"arrow"/"feather", "csv", "json", and "orc" are supported.
+                Will be inferred if not specified.
+            kwargs: Miscellaneous arguments passed to underlying PyArrow Dataset initializer.
+        """
+        self._session = session
+        self._data_sources = data_sources
+        self._format = format
+        self._kwargs = kwargs
+
+        self._schema: Optional[pa.Schema] = None
+
+    @classmethod
+    def from_sources(cls, session: snowpark.Session, sources: List[data_source.DataSource]) -> "ArrowIngestor":
+        return cls(session, sources)
+
+    @property
+    def data_sources(self) -> List[data_source.DataSource]:
+        return self._data_sources
+
+    def to_batches(
+        self,
+        batch_size: int,
+        shuffle: bool = True,
+        drop_last_batch: bool = True,
+    ) -> Iterator[Dict[str, npt.NDArray[Any]]]:
+        """Iterate through PyArrow Dataset to generate batches whose length equals to expected batch size.
+
+        As we are generating batches with the exactly same length, the last few rows in each file might get left as they
+        are not long enough to form a batch. These rows will be put into a temporary buffer and combine with the first
+        few rows of the next file to generate a new batch.
+
+        Args:
+            batch_size: Specifies the size of each batch that will be yield
+            shuffle: Whether the data in the file will be shuffled. If set to be true, it will first randomly shuffle
+                the order of files, and then shuflle the order of rows in each file.
+            drop_last_batch: Whether the last batch of data should be dropped. If set to be true, then the last
+                batch will get dropped if its size is smaller than the given batch_size.
+
+        Yields:
+            A dict mapping column names to the corresponding data fetch from that column.
+        """
+        self._rb_buffer = _RecordBatchesBuffer()
+
+        # Extract schema if not already known
+        dataset = self._get_dataset(shuffle)
+        if self._schema is None:
+            self._schema = dataset.schema
+
+        for rb in _retryable_batches(dataset, batch_size=max(_DEFAULT_DATASET_BATCH_SIZE, batch_size)):
+            if shuffle:
+                rb = rb.take(np.random.permutation(rb.num_rows))
+            self._rb_buffer.append(rb)
+            while self._rb_buffer.num_rows >= batch_size:
+                yield self._get_batches_from_buffer(batch_size)
+
+        if self._rb_buffer.num_rows and not drop_last_batch:
+            yield self._get_batches_from_buffer(batch_size)
+
+    def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
+        ds = self._get_dataset(shuffle=False)
+        table = ds.to_table() if limit is None else ds.head(num_rows=limit)
+        return table.to_pandas()
+
+    def _get_dataset(self, shuffle: bool) -> pds.Dataset:
+        format = self._format
+        sources: List[Any] = []
+        source_format = None
+        for source in self._data_sources:
+            if isinstance(source, str):
+                sources.append(source)
+                source_format = format or os.path.splitext(source)[-1]
+            elif isinstance(source, data_source.DatasetInfo):
+                if not self._kwargs.get("filesystem"):
+                    self._kwargs["filesystem"] = ingestor_utils.get_dataset_filesystem(self._session, source)
+                sources.extend(
+                    ingestor_utils.get_dataset_files(self._session, source, filesystem=self._kwargs["filesystem"])
+                )
+                source_format = "parquet"
+            elif isinstance(source, data_source.DataFrameInfo):
+                # FIXME: This currently loads all result batches into memory so that it
+                # can be passed into pyarrow.dataset as a list/tuple of pa.RecordBatches
+                # We may be able to optimize this by splitting the result batches into
+                # in-memory (first batch) and file URLs (subsequent batches) and creating a
+                # union dataset.
+                result_batches = ingestor_utils.get_dataframe_result_batches(self._session, source)
+                sources.extend(
+                    b.to_arrow(self._session.connection)
+                    if isinstance(b, result_batch.ArrowResultBatch)
+                    else b.to_arrow()
+                    for b in result_batches
+                )
+                # HACK: Mitigate typing inconsistencies in Snowpark results
+                if len(sources) > 0:
+                    sources = [_cast_if_needed(s, sources[-1].schema) for s in sources]
+                source_format = None  # Arrow Dataset expects "None" for in-memory datasets
+            else:
+                raise RuntimeError(f"Unsupported data source type: {type(source)}")
+
+            # Make sure source types not mixed
+            if format and format != source_format:
+                raise RuntimeError(f"Unexpected data source format (expected {format}, found {source_format})")
+            format = source_format
+
+        # Re-shuffle input files on each iteration start
+        if shuffle:
+            np.random.shuffle(sources)
+        pa_dataset: pds.Dataset = pds.dataset(sources, format=format, **self._kwargs)
+        return pa_dataset
+
+    def _get_batches_from_buffer(self, batch_size: int) -> Dict[str, npt.NDArray[Any]]:
+        """Generate new batches from the existing record batch buffer."""
+        cnt_rbs_num_rows = 0
+        candidates = []
+
+        # Keep popping record batches in buffer until there are enough rows for a batch.
+        while self._rb_buffer.num_rows and cnt_rbs_num_rows < batch_size:
+            candidate = self._rb_buffer.popleft()
+            cnt_rbs_num_rows += candidate.num_rows
+            candidates.append(candidate)
+
+        # When there are more rows than needed, slice the last popped batch to fit batch_size.
+        if cnt_rbs_num_rows > batch_size:
+            row_diff = cnt_rbs_num_rows - batch_size
+            slice_target = candidates[-1]
+            cut_off = slice_target.num_rows - row_diff
+            to_merge = slice_target.slice(length=cut_off)
+            left_over = slice_target.slice(offset=cut_off)
+            candidates[-1] = to_merge
+            self._rb_buffer.appendleft(left_over)
+
+        res = _merge_record_batches(candidates)
+        return _record_batch_to_arrays(res)
+
+
+def _merge_record_batches(record_batches: List[pa.RecordBatch]) -> pa.RecordBatch:
+    """Merge a list of arrow RecordBatches into one. Similar to MergeTables."""
+    if not record_batches:
+        return _EMPTY_RECORD_BATCH
+    if len(record_batches) == 1:
+        return record_batches[0]
+    record_batches = list(filter(lambda rb: rb.num_rows > 0, record_batches))
+    one_chunk_table = pa.Table.from_batches(record_batches).combine_chunks()
+    batches = one_chunk_table.to_batches(max_chunksize=None)
+    return batches[0]
+
+
+def _record_batch_to_arrays(rb: pa.RecordBatch) -> Dict[str, npt.NDArray[Any]]:
+    """Transform the record batch to a (string, numpy array) dict."""
+    batch_dict = {}
+    for column, column_schema in zip(rb, rb.schema):
+        # zero_copy_only=False because of nans. Ideally nans should have been imputed in feature engineering.
+        array = column.to_numpy(zero_copy_only=False)
+        batch_dict[column_schema.name] = array
+    return batch_dict
+
+
+def _retryable_batches(
+    dataset: pds.Dataset, batch_size: int, max_retries: int = 3, delay: int = 0
+) -> Iterator[pa.RecordBatch]:
+    """Make the Dataset to_batches retryable."""
+    retries = 0
+    current_batch_index = 0
+
+    while True:
+        try:
+            for batch_index, batch in enumerate(dataset.to_batches(batch_size=batch_size)):
+                if batch_index < current_batch_index:
+                    # Skip batches that have already been processed
+                    continue
+
+                yield batch
+                current_batch_index = batch_index + 1
+            # Exit the loop once all batches are processed
+            break
+
+        except Exception as e:
+            if retries < max_retries:
+                retries += 1
+                logging.info(f"Error encountered: {e}. Retrying {retries}/{max_retries}...")
+                time.sleep(delay)
+            else:
+                raise e
+
+
+def _cast_if_needed(
+    batch: Union[pa.Table, pa.RecordBatch], schema: Optional[pa.Schema] = None
+) -> Union[pa.Table, pa.RecordBatch]:
+    """
+    Cast the batch to be compatible with downstream frameworks. Returns original batch if cast is not necessary.
+    Besides casting types to match `schema` (if provided), this function also applies the following casting:
+    - Decimal (fixed-point) types: Convert to float or integer types based on scale and byte length
+
+    Args:
+        batch: The PyArrow batch to cast if needed
+        schema: Optional schema the batch should be casted to match. Note that compatibility type casting takes
+            precedence over the provided schema, e.g. if the schema has decimal types the result will be further
+            cast into integer/float types.
+
+    Returns:
+        The type-casted PyArrow batch, or the original batch if casting was not necessary
+    """
+    schema = schema or batch.schema
+    assert len(batch.schema) == len(schema)
+    fields = []
+    cast_needed = False
+    for field, target in zip(batch.schema, schema):
+        # Need to convert decimal types to supported types. This behavior supersedes target schema data types
+        if pa.types.is_decimal(target.type):
+            byte_length = int(target.metadata.get(b"byteLength", 8))
+            if int(target.metadata.get(b"scale", 0)) > 0:
+                target = target.with_type(pa.float32() if byte_length == 4 else pa.float64())
+            else:
+                if byte_length == 2:
+                    target = target.with_type(pa.int16())
+                elif byte_length == 4:
+                    target = target.with_type(pa.int32())
+                else:  # Cap out at 64-bit
+                    target = target.with_type(pa.int64())
+        if not field.equals(target):
+            cast_needed = True
+            field = target
+        fields.append(field)
+
+    if cast_needed:
+        return batch.cast(pa.schema(fields))
+    return batch
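The fixed-size batching in to_batches()/_get_batches_from_buffer() above amounts to buffering Arrow record batches and slicing the last one whenever the buffer overshoots the requested size. A rough standalone sketch of that idea using plain PyArrow (the rebatch helper below is illustrative, not part of the package):

import collections
from typing import Deque, Iterator, List

import pyarrow as pa


def rebatch(batches: Iterator[pa.RecordBatch], batch_size: int) -> Iterator[pa.RecordBatch]:
    # Buffer incoming record batches and emit exactly batch_size rows at a time,
    # carrying leftover rows forward; leftovers at end of input are dropped,
    # mirroring drop_last_batch=True.
    buffer: Deque[pa.RecordBatch] = collections.deque()
    buffered_rows = 0
    for rb in batches:
        buffer.append(rb)
        buffered_rows += rb.num_rows
        while buffered_rows >= batch_size:
            taken: List[pa.RecordBatch] = []
            rows = 0
            while rows < batch_size:
                head = buffer.popleft()
                buffered_rows -= head.num_rows
                if rows + head.num_rows > batch_size:
                    # Slice the overshooting batch and push the remainder back on the left.
                    cut = batch_size - rows
                    taken.append(head.slice(length=cut))
                    leftover = head.slice(offset=cut)
                    buffer.appendleft(leftover)
                    buffered_rows += leftover.num_rows
                    rows = batch_size
                else:
                    taken.append(head)
                    rows += head.num_rows
            yield pa.Table.from_batches(taken).combine_chunks().to_batches(max_chunksize=None)[0]


ten_rows = pa.RecordBatch.from_pydict({"x": list(range(10))})
for out in rebatch(iter([ten_rows]), batch_size=4):
    print(out.num_rows)  # prints 4 twice; the trailing 2 rows are dropped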
snowflake/ml/data/data_connector.py
ADDED
@@ -0,0 +1,186 @@
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Type, TypeVar
+
+import numpy.typing as npt
+from typing_extensions import deprecated
+
+from snowflake import snowpark
+from snowflake.ml._internal import telemetry
+from snowflake.ml.data import data_ingestor, data_source
+from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
+
+if TYPE_CHECKING:
+    import pandas as pd
+    import tensorflow as tf
+    from torch.utils import data as torch_data
+
+    # This module can't actually depend on dataset to avoid a circular dependency
+    # Dataset -> DatasetReader -> DataConnector -!-> Dataset
+    from snowflake.ml import dataset
+
+_PROJECT = "DataConnector"
+
+DataConnectorType = TypeVar("DataConnectorType", bound="DataConnector")
+
+
+class DataConnector:
+    """Snowflake data reader which provides application integration connectors"""
+
+    DEFAULT_INGESTOR_CLASS: Type[data_ingestor.DataIngestor] = ArrowIngestor
+
+    def __init__(
+        self,
+        ingestor: data_ingestor.DataIngestor,
+    ) -> None:
+        self._ingestor = ingestor
+
+    @classmethod
+    @snowpark._internal.utils.private_preview(version="1.6.0")
+    def from_dataframe(
+        cls: Type[DataConnectorType],
+        df: snowpark.DataFrame,
+        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+        **kwargs: Any
+    ) -> DataConnectorType:
+        if len(df.queries["queries"]) != 1 or len(df.queries["post_actions"]) != 0:
+            raise ValueError("DataFrames with multiple queries and/or post-actions not supported")
+        source = data_source.DataFrameInfo(df.queries["queries"][0])
+        assert df._session is not None
+        return cls.from_sources(df._session, [source], ingestor_class=ingestor_class, **kwargs)
+
+    @classmethod
+    def from_dataset(
+        cls: Type[DataConnectorType],
+        ds: "dataset.Dataset",
+        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+        **kwargs: Any
+    ) -> DataConnectorType:
+        dsv = ds.selected_version
+        assert dsv is not None
+        source = data_source.DatasetInfo(
+            ds.fully_qualified_name, dsv.name, dsv.url(), exclude_cols=(dsv.label_cols + dsv.exclude_cols)
+        )
+        return cls.from_sources(ds._session, [source], ingestor_class=ingestor_class, **kwargs)
+
+    @classmethod
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda cls: cls.__name__,
+        func_params_to_log=["sources", "ingestor_class"],
+    )
+    def from_sources(
+        cls: Type[DataConnectorType],
+        session: snowpark.Session,
+        sources: List[data_source.DataSource],
+        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+        **kwargs: Any
+    ) -> DataConnectorType:
+        ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
+        ingestor = ingestor_class.from_sources(session, sources)
+        return cls(ingestor, **kwargs)
+
+    @property
+    def data_sources(self) -> List[data_source.DataSource]:
+        return self._ingestor.data_sources
+
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda self: type(self).__name__,
+        func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
+    )
+    def to_tf_dataset(
+        self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True
+    ) -> "tf.data.Dataset":
+        """Transform the Snowflake data into a ready-to-use TensorFlow tf.data.Dataset.
+
+        Args:
+            batch_size: It specifies the size of each data batch which will be
+                yield in the result datapipe
+            shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and
+                rows in each file will also be shuffled.
+            drop_last_batch: Whether the last batch of data should be dropped. If set to be true,
+                then the last batch will get dropped if its size is smaller than the given batch_size.
+
+        Returns:
+            A tf.data.Dataset that yields batched tf.Tensors.
+        """
+        import tensorflow as tf
+
+        def generator() -> Generator[Dict[str, npt.NDArray[Any]], None, None]:
+            yield from self._ingestor.to_batches(batch_size, shuffle, drop_last_batch)
+
+        # Derive TensorFlow signature
+        first_batch = next(self._ingestor.to_batches(1, shuffle=False, drop_last_batch=False))
+        tf_signature = {
+            k: tf.TensorSpec(shape=(None,), dtype=tf.dtypes.as_dtype(v.dtype), name=k) for k, v in first_batch.items()
+        }
+
+        return tf.data.Dataset.from_generator(generator, output_signature=tf_signature)
+
+    @deprecated(
+        "to_torch_datapipe() is deprecated and will be removed in a future release. Use to_torch_dataset() instead"
+    )
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda self: type(self).__name__,
+        func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
+    )
+    def to_torch_datapipe(
+        self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True
+    ) -> "torch_data.IterDataPipe":  # type: ignore[type-arg]
+        """Transform the Snowflake data into a ready-to-use Pytorch datapipe.
+
+        Return a Pytorch datapipe which iterates on rows of data.
+
+        Args:
+            batch_size: It specifies the size of each data batch which will be
+                yield in the result datapipe
+            shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and
+                rows in each file will also be shuffled.
+            drop_last_batch: Whether the last batch of data should be dropped. If set to be true,
+                then the last batch will get dropped if its size is smaller than the given batch_size.
+
+        Returns:
+            A Pytorch iterable datapipe that yield data.
+        """
+        from torch.utils.data.datapipes import iter as torch_iter
+
+        return torch_iter.IterableWrapper(  # type: ignore[no-untyped-call]
+            self._ingestor.to_batches(batch_size, shuffle, drop_last_batch)
+        )
+
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda self: type(self).__name__,
+        func_params_to_log=["shuffle"],
+    )
+    def to_torch_dataset(self, *, shuffle: bool = False) -> "torch_data.IterableDataset":  # type: ignore[type-arg]
+        """Transform the Snowflake data into a PyTorch Iterable Dataset to be used with a DataLoader.
+
+        Return a PyTorch Dataset which iterates on rows of data.
+
+        Args:
+            shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and
+                rows in each file will also be shuffled.
+
+        Returns:
+            A PyTorch Iterable Dataset that yields data.
+        """
+        from snowflake.ml.data import torch_dataset
+
+        return torch_dataset.TorchDataset(self._ingestor, shuffle)
+
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda self: type(self).__name__,
+        func_params_to_log=["limit"],
+    )
+    def to_pandas(self, limit: Optional[int] = None) -> "pd.DataFrame":
+        """Retrieve the Snowflake data as a Pandas DataFrame.
+
+        Args:
+            limit: If specified, the maximum number of rows to load into the DataFrame.
+
+        Returns:
+            A Pandas DataFrame.
+        """
+        return self._ingestor.to_pandas(limit)
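The connector above is the user-facing entry point of the new snowflake.ml.data package. A minimal usage sketch, assuming an existing Snowpark session (the table name is illustrative):

from snowflake.ml.data.data_connector import DataConnector

# `session` is an existing snowpark.Session; the table name is illustrative.
df = session.table("MY_DB.MY_SCHEMA.TRAINING_DATA")

# Single-query DataFrames can be wrapped directly (private preview as of 1.6.0).
dc = DataConnector.from_dataframe(df)

pdf = dc.to_pandas(limit=1000)                         # materialize a sample as pandas
torch_ds = dc.to_torch_dataset(shuffle=True)           # feed a torch DataLoader
tf_ds = dc.to_tf_dataset(batch_size=32, shuffle=True)  # or a tf.data pipeline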
snowflake/ml/data/data_ingestor.py
ADDED
@@ -0,0 +1,45 @@
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Protocol,
+    Type,
+    TypeVar,
+)
+
+from numpy import typing as npt
+
+from snowflake import snowpark
+from snowflake.ml.data import data_source
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+
+DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
+
+
+class DataIngestor(Protocol):
+    @classmethod
+    def from_sources(
+        cls: Type[DataIngestorType], session: snowpark.Session, sources: List[data_source.DataSource]
+    ) -> DataIngestorType:
+        raise NotImplementedError
+
+    @property
+    def data_sources(self) -> List[data_source.DataSource]:
+        raise NotImplementedError
+
+    def to_batches(
+        self,
+        batch_size: int,
+        shuffle: bool = True,
+        drop_last_batch: bool = True,
+    ) -> Iterator[Dict[str, npt.NDArray[Any]]]:
+        raise NotImplementedError
+
+    def to_pandas(self, limit: Optional[int] = None) -> "pd.DataFrame":
+        raise NotImplementedError
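Because DataIngestor is a typing.Protocol, any class providing these four members satisfies it and can be plugged into DataConnector.from_sources(..., ingestor_class=...). A hedged sketch of such a custom ingestor (the InMemoryIngestor class below is hypothetical, not part of the package):

from typing import Any, Dict, Iterator, List, Optional

import numpy as np
import pandas as pd

from snowflake import snowpark
from snowflake.ml.data import data_source


class InMemoryIngestor:
    """Hypothetical ingestor that serves a pandas DataFrame held in memory."""

    def __init__(self, sources: List[data_source.DataSource], frame: pd.DataFrame) -> None:
        self._sources = sources
        self._frame = frame

    @classmethod
    def from_sources(cls, session: snowpark.Session, sources: List[data_source.DataSource]) -> "InMemoryIngestor":
        # A real implementation would resolve every source; here we just execute
        # the first DataFrameInfo query into pandas via the session.
        info = sources[0]
        assert isinstance(info, data_source.DataFrameInfo)
        return cls(sources, session.sql(info.sql).to_pandas())

    @property
    def data_sources(self) -> List[data_source.DataSource]:
        return self._sources

    def to_batches(
        self, batch_size: int, shuffle: bool = True, drop_last_batch: bool = True
    ) -> Iterator[Dict[str, np.ndarray]]:
        frame = self._frame.sample(frac=1.0) if shuffle else self._frame
        stop = len(frame) - (len(frame) % batch_size) if drop_last_batch else len(frame)
        for start in range(0, stop, batch_size):
            chunk = frame.iloc[start : start + batch_size]
            yield {col: chunk[col].to_numpy() for col in chunk.columns}

    def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
        return self._frame.head(limit) if limit is not None else self._frame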
snowflake/ml/data/data_source.py
ADDED
@@ -0,0 +1,23 @@
+import dataclasses
+from typing import List, Optional, Union
+
+
+@dataclasses.dataclass(frozen=True)
+class DataFrameInfo:
+    """Serializable information from Snowpark DataFrames"""
+
+    sql: str
+    query_id: Optional[str] = None
+
+
+@dataclasses.dataclass(frozen=True)
+class DatasetInfo:
+    """Serializable information from SnowML Datasets"""
+
+    fully_qualified_name: str
+    version: str
+    url: Optional[str] = None
+    exclude_cols: Optional[List[str]] = None
+
+
+DataSource = Union[DataFrameInfo, DatasetInfo, str]
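The union means a source can be a query description, a Dataset version, or a plain file path string; for example (all values below are illustrative):

from typing import List

from snowflake.ml.data import data_source

sources: List[data_source.DataSource] = [
    data_source.DataFrameInfo(sql="SELECT * FROM MY_DB.MY_SCHEMA.TRAINING_DATA"),
    data_source.DatasetInfo(
        fully_qualified_name="MY_DB.MY_SCHEMA.MY_DATASET",
        version="v1",
        url="snow://dataset/MY_DB.MY_SCHEMA.MY_DATASET/versions/v1",  # illustrative URL
        exclude_cols=["LABEL"],
    ),
    "path/to/staged_or_local_file.parquet",
]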
snowflake/ml/data/ingestor_utils.py
ADDED
@@ -0,0 +1,62 @@
+from typing import List, Optional
+
+import fsspec
+
+from snowflake import snowpark
+from snowflake.connector import result_batch
+from snowflake.ml.data import data_source
+from snowflake.ml.fileset import snowfs
+
+_TARGET_FILE_SIZE = 32 * 2**20  # The max file size for data loading.
+
+
+def get_dataframe_result_batches(
+    session: snowpark.Session, df_info: data_source.DataFrameInfo
+) -> List[result_batch.ResultBatch]:
+    """Retrieve the ResultBatches for a given query"""
+    cursor = session._conn._cursor
+
+    if df_info.query_id:
+        query_id = df_info.query_id
+    else:
+        query_id = session.sql(df_info.sql).collect_nowait().query_id
+
+    # TODO: Check if query result cache is still live
+    cursor.get_results_from_sfqid(sfqid=query_id)
+
+    # Prefetch hook should be set by `get_results_from_sfqid`
+    # This call blocks until the query results are ready
+    if cursor._prefetch_hook is None:
+        raise RuntimeError("Loading data from result query failed unexpectedly. Please contact Snowflake support.")
+    cursor._prefetch_hook()
+    batches = cursor.get_result_batches()
+    if batches is None:
+        raise ValueError(
+            "Failed to retrieve training data. Query status:" f" {session._conn._conn.get_query_status(query_id)}"
+        )
+    return batches
+
+
+def get_dataset_filesystem(
+    session: snowpark.Session, ds_info: Optional[data_source.DatasetInfo] = None
+) -> fsspec.AbstractFileSystem:
+    """Get the fsspec filesystem for a given Dataset"""
+    # We can't directly load the Dataset to avoid a circular dependency
+    # Dataset -> DatasetReader -> DataConnector -> DataIngestor -> (?) ingestor_utils -> Dataset
+    # TODO: Automatically pick appropriate fsspec implementation based on protocol in URL
+    return snowfs.SnowFileSystem(
+        snowpark_session=session,
+        cache_type="bytes",
+        block_size=2 * _TARGET_FILE_SIZE,
+    )
+
+
+def get_dataset_files(
+    session: snowpark.Session, ds_info: data_source.DatasetInfo, filesystem: Optional[fsspec.AbstractFileSystem] = None
+) -> List[str]:
+    """Get the list of files in a given Dataset"""
+    if filesystem is None:
+        filesystem = get_dataset_filesystem(session, ds_info)
+    assert bool(ds_info.url)  # Not null or empty
+    files = sorted(filesystem.ls(ds_info.url))
+    return [filesystem.unstrip_protocol(f) for f in files]
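These helpers are what ArrowIngestor calls when resolving sources: a DataFrameInfo becomes Snowflake connector result batches, while a DatasetInfo becomes a list of stage file URLs on an fsspec filesystem. A hedged sketch, assuming an existing session and the illustrative source values from above:

from snowflake.ml.data import data_source, ingestor_utils

df_info = data_source.DataFrameInfo(sql="SELECT * FROM MY_DB.MY_SCHEMA.TRAINING_DATA")
batches = ingestor_utils.get_dataframe_result_batches(session, df_info)  # connector ResultBatch objects

ds_info = data_source.DatasetInfo(
    fully_qualified_name="MY_DB.MY_SCHEMA.MY_DATASET",
    version="v1",
    url="snow://dataset/MY_DB.MY_SCHEMA.MY_DATASET/versions/v1",  # illustrative URL
)
fs = ingestor_utils.get_dataset_filesystem(session, ds_info)
files = ingestor_utils.get_dataset_files(session, ds_info, filesystem=fs)  # fully qualified parquet file URLs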
snowflake/ml/data/torch_dataset.py
ADDED
@@ -0,0 +1,33 @@
+from typing import Any, Dict, Iterator
+
+import torch.utils.data
+
+from snowflake.ml.data import data_ingestor
+
+
+class TorchDataset(torch.utils.data.IterableDataset[Dict[str, Any]]):
+    """Implementation of PyTorch IterableDataset"""
+
+    def __init__(self, ingestor: data_ingestor.DataIngestor, shuffle: bool = False) -> None:
+        """Not intended for direct usage. Use DataConnector.to_torch_dataset() instead"""
+        self._ingestor = ingestor
+        self._shuffle = shuffle
+
+    def __iter__(self) -> Iterator[Dict[str, Any]]:
+        max_idx = 0
+        filter_idx = 0
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is not None:
+            max_idx = worker_info.num_workers - 1
+            filter_idx = worker_info.id
+
+        counter = 0
+        for batch in self._ingestor.to_batches(batch_size=1, shuffle=self._shuffle, drop_last_batch=False):
+            # Skip indices during multi-process data loading to prevent data duplication
+            if counter == filter_idx:
+                yield {k: v.item() for k, v in batch.items()}
+
+            if counter < max_idx:
+                counter += 1
+            else:
+                counter = 0
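In practice this class comes from DataConnector.to_torch_dataset(); the counter/filter logic above is what prevents row duplication when a DataLoader spawns multiple workers. A sketch, reusing the dc connector from the earlier example:

from torch.utils.data import DataLoader

torch_ds = dc.to_torch_dataset(shuffle=True)

# Each worker process re-runs __iter__ and keeps only every num_workers-th row
# based on its worker id, so rows are not duplicated across workers. Batching
# and collation are left to the DataLoader.
loader = DataLoader(torch_ds, batch_size=32, num_workers=2)
for batch in loader:
    # batch is a dict of column name -> tensor of shape (batch_size,)
    break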
snowflake/ml/dataset/dataset.py
CHANGED
@@ -11,7 +11,6 @@ from snowflake.ml._internal.exceptions import (
     error_codes,
     exceptions as snowml_exceptions,
 )
-from snowflake.ml._internal.lineage import data_source
 from snowflake.ml._internal.utils import (
     formatting,
     identifier,
@@ -177,18 +176,7 @@ class Dataset(lineage_node.LineageNode):
                 original_exception=RuntimeError("No Dataset version selected."),
             )
         if self._reader is None:
-
-            self._reader = dataset_reader.DatasetReader(
-                self._session,
-                [
-                    data_source.DataSource(
-                        fully_qualified_name=self._lineage_node_name,
-                        version=v.name,
-                        url=v.url(),
-                        exclude_cols=(v.label_cols + v.exclude_cols),
-                    )
-                ],
-            )
+            self._reader = dataset_reader.DatasetReader.from_dataset(self, snowpark_session=self._session)
         return self._reader
 
     @staticmethod
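The public entry point is unchanged by this refactor: a Dataset's reader is still reached through the read property, which now delegates to DatasetReader.from_dataset. A sketch, assuming an existing session and an illustrative dataset name:

from snowflake.ml import dataset

ds = dataset.load_dataset(session, "MY_DB.MY_SCHEMA.MY_DATASET", "v1")
pdf = ds.read.to_pandas()              # DatasetReader exposes the DataConnector surface
torch_ds = ds.read.to_torch_dataset()  # available alongside the older to_torch_datapipe()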
snowflake/ml/dataset/dataset_metadata.py
CHANGED
@@ -15,11 +15,13 @@ class FeatureStoreMetadata:
     Properties:
         spine_query: The input query on source table which will be joined with features.
         serialized_feature_views: A list of serialized feature objects in the feature store.
+        compact_feature_views: A compact representation of a FeatureView or FeatureViewSlice.
         spine_timestamp_col: Timestamp column which was used for point-in-time correct feature lookup.
     """
 
     spine_query: str
-    serialized_feature_views: List[str]
+    serialized_feature_views: Optional[List[str]] = None
+    compact_feature_views: Optional[List[str]] = None
     spine_timestamp_col: Optional[str] = None
 
     def to_json(self) -> str:
|