pyspiral 0.6.2__cp310-abi3-macosx_11_0_arm64.whl → 0.6.4__cp310-abi3-macosx_11_0_arm64.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
spiral/table.py CHANGED
@@ -1,6 +1,7 @@
  from datetime import datetime
- from typing import TYPE_CHECKING
+ from typing import TYPE_CHECKING, Optional
 
+ from spiral import ShuffleStrategy
  from spiral.core.table import Table as CoreTable
  from spiral.core.table.spec import Schema
  from spiral.expressions.base import Expr, ExprLike
@@ -131,6 +132,7 @@ class Table(Expr):
 
          IMPORTANT: While transaction can be used to atomically write data to the table,
          it is important that the primary key columns are unique within the transaction.
+         The behavior is undefined if this is not the case.
          """
          return Transaction(self.spiral._core.transaction(self.core, settings().file_format))
 
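
Note: the uniqueness requirement above is easiest to see in a short sketch. This is illustrative only; the `transaction()` entry point and the `txn.write(...)` call shape are assumptions, since the method that opens the transaction is not visible in this hunk.

    # Hypothetical usage; names `transaction` and `write` are assumed, not confirmed by this diff.
    with table.transaction() as txn:
        # Primary-key values must be unique within the transaction;
        # writing a duplicate key leaves the behavior undefined.
        txn.write({"id": [1, 2, 3], "value": ["a", "b", "c"]})
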
@@ -146,108 +148,58 @@ class Table(Expr):
          """Returns a DuckDB relation for the Spiral table."""
          return self.snapshot().to_duckdb()
 
-     def to_data_loader(self, *, index: "KeySpaceIndex", **kwargs) -> "torchdata.DataLoader":
-         """Returns a PyTorch DataLoader.
+     def to_iterable_dataset(
+         self,
+         *,
+         index: Optional["KeySpaceIndex"] = None,
+         shuffle: ShuffleStrategy | None = None,
+         batch_readahead: int | None = None,
+         infinite: bool = False,
+     ) -> "torchdata.IterableDataset":
+         """Returns an iterable dataset compatible with `torch.IterableDataset`. It can be used for training
+         in local or distributed settings.
 
-         Requires `torch` and `streaming` package to be installed.
+         Supports sharding and shuffling, and is compatible with multiprocessing via `num_workers`. If projections
+         and filtering are needed, you must create a key space index and pass it when creating the stream.
+
+         Requires the `torch` package to be installed.
 
          Args:
-             index: See `streaming` method.
+             shuffle: Controls sample shuffling. If None, no shuffling is performed.
+             batch_readahead: Controls how many batches to read ahead concurrently.
+                 If the pipeline includes work after reading (e.g. decoding, transforming, ...), this can be
+                 set higher. Otherwise, keep it low to reduce next-batch latency. Defaults to 2.
+             asof: If provided, only data written before the given timestamp will be returned.
+                 Must not be set when `index` is provided; the index's `asof` is used instead.
+             infinite: If True, the returned IterableDataset will loop infinitely over the data,
+                 re-shuffling ranges after exhausting all data.
+             index: Optional prebuilt KeySpaceIndex to use when creating the stream.
+                 The index's `asof` will be used when scanning.
              **kwargs: Additional arguments passed to the PyTorch DataLoader constructor.
-
          """
-         from streaming import StreamingDataLoader
-
-         dataset_kwargs = {}
-         if "batch_size" in kwargs:
-             # Keep it in kwargs for DataLoader
-             dataset_kwargs["batch_size"] = kwargs["batch_size"]
-         if "cache_limit" in kwargs:
-             dataset_kwargs["cache_limit"] = kwargs.pop("cache_limit")
-         if "sampling_method" in kwargs:
-             dataset_kwargs["sampling_method"] = kwargs.pop("sampling_method")
-         if "sampling_granularity" in kwargs:
-             dataset_kwargs["sampling_granularity"] = kwargs.pop("sampling_granularity")
-         if "partition_algo" in kwargs:
-             dataset_kwargs["partition_algo"] = kwargs.pop("partition_algo")
-         if "num_canonical_nodes" in kwargs:
-             dataset_kwargs["num_canonical_nodes"] = kwargs.pop("num_canonical_nodes")
-         if "shuffle" in kwargs:
-             dataset_kwargs["shuffle"] = kwargs.pop("shuffle")
-         if "shuffle_algo" in kwargs:
-             dataset_kwargs["shuffle_algo"] = kwargs.pop("shuffle_algo")
-         if "shuffle_seed" in kwargs:
-             dataset_kwargs["shuffle_seed"] = kwargs.pop("shuffle_seed")
-         if "shuffle_block_size" in kwargs:
-             dataset_kwargs["shuffle_block_size"] = kwargs.pop("shuffle_block_size")
-         if "batching_method" in kwargs:
-             dataset_kwargs["batching_method"] = kwargs.pop("batching_method")
-         if "replication" in kwargs:
-             dataset_kwargs["replication"] = kwargs.pop("replication")
-
-         dataset = self.to_streaming_dataset(index=index, **dataset_kwargs)
-
-         return StreamingDataLoader(dataset=dataset, **kwargs)
-
-     def to_streaming_dataset(
+         # TODO(marko): WIP.
+         raise NotImplementedError
+
+     def to_streaming(
          self,
-         *,
          index: "KeySpaceIndex",
-         batch_size: int | None = None,
+         *,
+         projection: Expr | None = None,
          cache_dir: str | None = None,
-         cache_limit: int | str | None = None,
-         predownload: int | None = None,
-         sampling_method: str = "balanced",
-         sampling_granularity: int = 1,
-         partition_algo: str = "relaxed",
-         num_canonical_nodes: int | None = None,
-         shuffle: bool = False,
-         shuffle_algo: str = "py1e",
-         shuffle_seed: int = 9176,
-         shuffle_block_size: int | None = None,
-         batching_method: str = "random",
-         replication: int | None = None,
-     ) -> "streaming.StreamingDataset":
-         """Returns a MosaicML's StreamingDataset that can be used for distributed training.
-
-         Requires `streaming` package to be installed.
-
-         Args:
-             See `streaming` method for `index` arg.
-             See MosaicML's `StreamingDataset` for other args.
-
-         This is a helper method to construct a single stream dataset from the scan. When multiple streams are combined,
-         use `to_stream` to get the SpiralStream and construct the StreamingDataset manually using a `streams` arg.
-         """
-         from streaming import StreamingDataset
-
-         stream = self.to_streaming(index=index, cache_dir=cache_dir)
-
-         return StreamingDataset(
-             streams=[stream],
-             batch_size=batch_size,
-             cache_limit=cache_limit,
-             predownload=predownload,
-             sampling_method=sampling_method,
-             sampling_granularity=sampling_granularity,
-             partition_algo=partition_algo,
-             num_canonical_nodes=num_canonical_nodes,
-             shuffle=shuffle,
-             shuffle_algo=shuffle_algo,
-             shuffle_seed=shuffle_seed,
-             shuffle_block_size=shuffle_block_size,
-             batching_method=batching_method,
-             replication=replication,
-         )
-
-     def to_streaming(self, index: "KeySpaceIndex", *, cache_dir: str | None = None) -> "streaming.Stream":
+         shard_row_block_size: int | None = None,
+     ) -> "streaming.Stream":
          """Returns a stream to be used with MosaicML's StreamingDataset.
 
          Requires `streaming` package to be installed.
 
          Args:
-             index: Prebuilt KeysIndex to use when creating the stream. The index's `asof` will be used when scanning.
+             index: Prebuilt KeySpaceIndex to use when creating the stream.
+                 The index's `asof` will be used when scanning.
+             projection: Optional projection to use when scanning the table instead of the index's projection.
+                 The projection must be compatible with the index's projection for correctness.
              cache_dir: Directory to use for caching data. If None, a temporary directory will be used.
+             shard_row_block_size: Number of rows per segment of a shard file. Defaults to 8192.
+                 Set a lower value for larger rows.
          """
          from spiral.streaming_ import SpiralStream
 
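
Note: taken together, this hunk swaps the `streaming`-based DataLoader helpers for a (not yet implemented) `to_iterable_dataset` and a slimmer `to_streaming`. A sketch of the new surface with illustrative argument values; `to_iterable_dataset` still raises NotImplementedError in this release, so only the call shape is shown.

    # Call shape only: to_iterable_dataset is a stub in 0.6.4 (see TODO above).
    ds = table.to_iterable_dataset(
        index=key_space_index,  # optional prebuilt KeySpaceIndex
        shuffle=None,           # a ShuffleStrategy, or None to disable shuffling
        batch_readahead=2,      # documented default
        infinite=False,         # loop and re-shuffle forever when True
    )

    # to_streaming is functional and returns a stream for MosaicML's StreamingDataset.
    stream = table.to_streaming(
        key_space_index,
        projection=None,            # defaults to the index's projection
        cache_dir=None,             # None -> temporary directory
        shard_row_block_size=8192,  # documented default; lower it for large rows
    )
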
@@ -258,7 +210,7 @@ class Table(Expr):
 
          # We know table from projection is in the session cause this method is on it.
          scan = self.spiral.scan(
-             index.projection,
+             projection if projection is not None else index.projection,
              where=index.filter,
              asof=index.asof,
              # TODO(marko): This should be configurable?
@@ -269,4 +221,9 @@ class Table(Expr):
          # We have a world there and can compute shards only on leader.
          shards = self.spiral._core._ops().compute_shards(index=index.core)
 
-         return SpiralStream(scan=scan.core, shards=shards, cache_dir=cache_dir)  # type: ignore[return-value]
+         return SpiralStream(
+             scan=scan.core,
+             shards=shards,
+             cache_dir=cache_dir,
+             shard_row_block_size=shard_row_block_size,
+         )  # type: ignore[return-value]
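
Note: with `to_streaming_dataset` and `to_data_loader` removed, callers now wrap the stream themselves, as the old docstring already recommended for multi-stream setups. A minimal sketch using MosaicML's `streaming` package (argument values illustrative):

    from streaming import StreamingDataLoader, StreamingDataset

    # Combine one or more Spiral streams via the `streams` argument.
    dataset = StreamingDataset(streams=[stream], batch_size=32, shuffle=True)
    loader = StreamingDataLoader(dataset=dataset, batch_size=32)
    for batch in loader:
        ...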