pyspiral 0.6.3__cp310-abi3-macosx_11_0_arm64.whl → 0.6.5__cp310-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spiral/scan.py CHANGED
@@ -1,4 +1,3 @@
- from collections.abc import Iterator
  from typing import TYPE_CHECKING, Any

  import pyarrow as pa
@@ -120,8 +119,11 @@ class Scan:
  self,
  shuffle: ShuffleStrategy | None = None,
  batch_readahead: int | None = None,
+ num_workers: int | None = None,
+ worker_id: int | None = None,
+ infinite: bool = False,
  ) -> "hf.IterableDataset":
- """Returns an Huggingface's IterableDataset.
+ """Returns a Huggingface's IterableDataset.

  Requires `datasets` package to be installed.

@@ -130,39 +132,25 @@ class Scan:
  batch_readahead: Controls how many batches to read ahead concurrently.
  If pipeline includes work after reading (e.g. decoding, transforming, ...) this can be set higher.
  Otherwise, it should be kept low to reduce next batch latency. Defaults to 2.
+ num_workers: If not None, shards the scan across multiple workers.
+ Must be used together with worker_id.
+ worker_id: If not None, the id of the current worker.
+ Scan will only return a subset of the data corresponding to the worker_id.
+ infinite: If True, the returned IterableDataset will loop infinitely over the data,
+ re-shuffling ranges after exhausting all data.
  """
- from datasets import DatasetInfo, Features
- from datasets.iterable_dataset import ArrowExamplesIterable, IterableDataset
-
- def _generate_tables(**kwargs) -> Iterator[tuple[int, pa.Table]]:
- stream = self.core.to_shuffled_record_batches(
- shuffle,
- batch_readahead,
- )
-
- # This key is unused when training with IterableDataset.
- # Default implementation returns shard id, e.g. parquet row group id.
- for i, rb in enumerate(stream):
- yield i, pa.Table.from_batches([rb], stream.schema)
-
- def _hf_compatible_schema(schema: pa.Schema) -> pa.Schema:
- """
- Replace string-view columns in the schema with strings. We do use this converted schema
- as Features in the returned Dataset.
- Remove this method once we have https://github.com/huggingface/datasets/pull/7718
- """
- new_fields = [
- pa.field(field.name, pa.string(), nullable=field.nullable, metadata=field.metadata)
- if field.type == pa.string_view()
- else field
- for field in schema
- ]
- return pa.schema(new_fields)
-
- # NOTE: generate_tables_fn type annotations are wrong, return type must be an iterable of tuples.
- ex_iterable = ArrowExamplesIterable(generate_tables_fn=_generate_tables, kwargs={}) # type: ignore
- info = DatasetInfo(features=Features.from_arrow_schema(_hf_compatible_schema(self.schema.to_arrow())))
- return IterableDataset(ex_iterable=ex_iterable, info=info)
+
+ stream = self.core.to_shuffled_record_batches(
+ shuffle,
+ batch_readahead,
+ num_workers,
+ worker_id,
+ infinite,
+ )
+
+ from spiral.iterable_dataset import to_iterable_dataset
+
+ return to_iterable_dataset(stream)

  def _splits(self) -> list[KeyRange]:
  # Splits the scan into a set of key ranges.
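
For orientation, a minimal sketch of how the new sharding options might be used from a training script. It assumes the Scan method shown above is named `to_iterable_dataset` (its `def` line falls outside this hunk) and that `scan` is an already-constructed Scan; only the `shuffle`, `batch_readahead`, `num_workers`, `worker_id`, and `infinite` parameters come from this diff.

def make_worker_dataset(scan, worker_id: int, num_workers: int):
    # Each worker reads only its shard of the key space; with infinite=True the
    # dataset re-shuffles ranges and keeps looping after the data is exhausted.
    return scan.to_iterable_dataset(
        batch_readahead=2,        # keep low unless decoding/transform work follows the read
        num_workers=num_workers,  # total number of workers sharing the scan
        worker_id=worker_id,      # this worker's id; must be set together with num_workers
        infinite=True,
    )
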
spiral/settings.py CHANGED
@@ -24,6 +24,8 @@ CI = "GITHUB_ACTIONS" in os.environ
  APP_DIR = Path(typer.get_app_dir("pyspiral"))
  LOG_DIR = APP_DIR / "logs"

+ PACKAGE_NAME = "pyspiral"
+

  def validate_token(v, handler: ValidatorFunctionWrapHandler):
  if isinstance(v, str):
@@ -89,7 +91,7 @@ class Settings(BaseSettings):
  def authn(self):
  if self.spiraldb.token:
  return Authn.from_token(self.spiraldb.token)
- return Authn.from_fallback()
+ return Authn.from_fallback(self.spiraldb.uri)

  @functools.cached_property
  def device_code_auth(self) -> DeviceCodeAuth:
spiral/snapshot.py CHANGED
@@ -1,5 +1,6 @@
  from typing import TYPE_CHECKING

+ from spiral import ShuffleStrategy
  from spiral.core.table import Snapshot as CoreSnapshot
  from spiral.core.table.spec import Schema
  from spiral.types_ import Timestamp
@@ -8,6 +9,7 @@ if TYPE_CHECKING:
  import duckdb
  import polars as pl
  import pyarrow.dataset as ds
+ import torch.utils.data as torchdata # noqa

  from spiral.table import Table

@@ -53,3 +55,17 @@ class Snapshot:
  import duckdb

  return duckdb.from_arrow(self.to_dataset())
+
+ def to_iterable_dataset(
+ self,
+ *,
+ shuffle: ShuffleStrategy | None = None,
+ batch_readahead: int | None = None,
+ infinite: bool = False,
+ ) -> "torchdata.IterableDataset":
+ """Returns an iterable dataset compatible with `torch.IterableDataset`.
+
+ See `Table` docs for details on the parameters.
+ """
+ # TODO(marko): WIP.
+ raise NotImplementedError
@@ -25,12 +25,16 @@ class SpiralStream:
  """

  def __init__(
- self, scan: CoreScan, shards: list[Shard], cache_dir: str | None = None, shard_row_block_size: int = 8192
+ self,
+ scan: CoreScan,
+ shards: list[Shard],
+ cache_dir: str | None = None,
+ shard_row_block_size: int | None = None,
  ):
  self._scan = scan
  # TODO(marko): Read shards only on world.is_local_leader in `get_shards` and materialize on disk.
  self._shards = shards
- self.shard_row_block_size = shard_row_block_size
+ self._shard_row_block_size = shard_row_block_size or 8192

  if cache_dir is not None:
  if not os.path.exists(cache_dir):
@@ -99,7 +103,7 @@ class SpiralStream:
  shard_path,
  shard.shard.key_range,
  expected_cardinality=shard.shard.cardinality,
- shard_row_block_size=self.shard_row_block_size,
+ shard_row_block_size=self._shard_row_block_size,
  )

  # Get the size of the file on disk.
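
Note that `shard_row_block_size` is now optional and resolved inside the constructor, which lets callers such as `Table.to_streaming` (below) forward `None` without knowing the default. A tiny sketch of the resolution logic, using a hypothetical standalone helper name:

def resolve_shard_row_block_size(shard_row_block_size: int | None = None) -> int:
    # None (or omitting the argument) falls back to the previous hard-coded default.
    # Note: because `or` is used, an explicit 0 also resolves to 8192.
    return shard_row_block_size or 8192

assert resolve_shard_row_block_size() == 8192
assert resolve_shard_row_block_size(2048) == 2048
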
spiral/table.py CHANGED
@@ -1,6 +1,7 @@
  from datetime import datetime
- from typing import TYPE_CHECKING
+ from typing import TYPE_CHECKING, Optional

+ from spiral import ShuffleStrategy
  from spiral.core.table import Table as CoreTable
  from spiral.core.table.spec import Schema
  from spiral.expressions.base import Expr, ExprLike
@@ -115,7 +116,7 @@ class Table(Expr):


  :param column_paths: Fully qualified column names. (e.g., "column_name" or "nested.field").
- All columns must exist, if a a column doesn't exist the function will return an error.
+ All columns must exist, if a column doesn't exist the function will return an error.
  """
  with self.txn() as txn:
  txn.drop_columns(column_paths)
@@ -126,13 +127,16 @@ class Table(Expr):
  asof = int(asof.timestamp() * 1_000_000)
  return Snapshot(self, self.core.get_snapshot(asof=asof))

- def txn(self) -> Transaction:
+ def txn(self, retries: int | None = 3) -> Transaction:
  """Begins a new transaction. Transaction must be committed for writes to become visible.

+ :param retries: Maximum number of retry attempts on conflict (default: 3). Set to None for a single attempt.
+
  IMPORTANT: While transaction can be used to atomically write data to the table,
  it is important that the primary key columns are unique within the transaction.
+ The behavior is undefined if this is not the case.
  """
- return Transaction(self.spiral._core.transaction(self.core, settings().file_format))
+ return Transaction(self.spiral._core.transaction(self.core, settings().file_format, retries=retries))

  def to_dataset(self) -> "ds.Dataset":
  """Returns a PyArrow Dataset representing the table."""
@@ -146,108 +150,58 @@ class Table(Expr):
  """Returns a DuckDB relation for the Spiral table."""
  return self.snapshot().to_duckdb()

- def to_data_loader(self, *, index: "KeySpaceIndex", **kwargs) -> "torchdata.DataLoader":
- """Returns a PyTorch DataLoader.
+ def to_iterable_dataset(
+ self,
+ *,
+ index: Optional["KeySpaceIndex"] = None,
+ shuffle: ShuffleStrategy | None = None,
+ batch_readahead: int | None = None,
+ infinite: bool = False,
+ ) -> "torchdata.IterableDataset":
+ """Returns an iterable dataset compatible with `torch.IterableDataset`. It can be used for training
+ in local or distributed settings.
+
+ Supports sharding, shuffling, and compatible for multiprocessing with `num_workers`. If projections and
+ filtering are needed, you must create a key space index and pass it when creating the stream.

- Requires `torch` and `streaming` package to be installed.
+ Requires `torch` package to be installed.

  Args:
- index: See `streaming` method.
+ shuffle: Controls sample shuffling. If None, no shuffling is performed.
+ batch_readahead: Controls how many batches to read ahead concurrently.
+ If pipeline includes work after reading (e.g. decoding, transforming, ...) this can be set higher.
+ Otherwise, it should be kept low to reduce next batch latency. Defaults to 2.
+ asof: If provided, only data written before the given timestamp will be returned.
+ If `index` is provided, it must not be used. The index's `asof` will be used instead.
+ infinite: If True, the returned IterableDataset will loop infinitely over the data,
+ re-shuffling ranges after exhausting all data.
+ index: Optional prebuilt KeysIndex to use when creating the stream.
+ The index's `asof` will be used when scanning.
  **kwargs: Additional arguments passed to the PyTorch DataLoader constructor.
-
  """
- from streaming import StreamingDataLoader
-
- dataset_kwargs = {}
- if "batch_size" in kwargs:
- # Keep it in kwargs for DataLoader
- dataset_kwargs["batch_size"] = kwargs["batch_size"]
- if "cache_limit" in kwargs:
- dataset_kwargs["cache_limit"] = kwargs.pop("cache_limit")
- if "sampling_method" in kwargs:
- dataset_kwargs["sampling_method"] = kwargs.pop("sampling_method")
- if "sampling_granularity" in kwargs:
- dataset_kwargs["sampling_granularity"] = kwargs.pop("sampling_granularity")
- if "partition_algo" in kwargs:
- dataset_kwargs["partition_algo"] = kwargs.pop("partition_algo")
- if "num_canonical_nodes" in kwargs:
- dataset_kwargs["num_canonical_nodes"] = kwargs.pop("num_canonical_nodes")
- if "shuffle" in kwargs:
- dataset_kwargs["shuffle"] = kwargs.pop("shuffle")
- if "shuffle_algo" in kwargs:
- dataset_kwargs["shuffle_algo"] = kwargs.pop("shuffle_algo")
- if "shuffle_seed" in kwargs:
- dataset_kwargs["shuffle_seed"] = kwargs.pop("shuffle_seed")
- if "shuffle_block_size" in kwargs:
- dataset_kwargs["shuffle_block_size"] = kwargs.pop("shuffle_block_size")
- if "batching_method" in kwargs:
- dataset_kwargs["batching_method"] = kwargs.pop("batching_method")
- if "replication" in kwargs:
- dataset_kwargs["replication"] = kwargs.pop("replication")
-
- dataset = self.to_streaming_dataset(index=index, **dataset_kwargs)
-
- return StreamingDataLoader(dataset=dataset, **kwargs)
-
- def to_streaming_dataset(
+ # TODO(marko): WIP.
+ raise NotImplementedError
+
+ def to_streaming(
  self,
- *,
  index: "KeySpaceIndex",
- batch_size: int | None = None,
+ *,
+ projection: Expr | None = None,
  cache_dir: str | None = None,
- cache_limit: int | str | None = None,
- predownload: int | None = None,
- sampling_method: str = "balanced",
- sampling_granularity: int = 1,
- partition_algo: str = "relaxed",
- num_canonical_nodes: int | None = None,
- shuffle: bool = False,
- shuffle_algo: str = "py1e",
- shuffle_seed: int = 9176,
- shuffle_block_size: int | None = None,
- batching_method: str = "random",
- replication: int | None = None,
- ) -> "streaming.StreamingDataset":
- """Returns a MosaicML's StreamingDataset that can be used for distributed training.
-
- Requires `streaming` package to be installed.
-
- Args:
- See `streaming` method for `index` arg.
- See MosaicML's `StreamingDataset` for other args.
-
- This is a helper method to construct a single stream dataset from the scan. When multiple streams are combined,
- use `to_stream` to get the SpiralStream and construct the StreamingDataset manually using a `streams` arg.
- """
- from streaming import StreamingDataset
-
- stream = self.to_streaming(index=index, cache_dir=cache_dir)
-
- return StreamingDataset(
- streams=[stream],
- batch_size=batch_size,
- cache_limit=cache_limit,
- predownload=predownload,
- sampling_method=sampling_method,
- sampling_granularity=sampling_granularity,
- partition_algo=partition_algo,
- num_canonical_nodes=num_canonical_nodes,
- shuffle=shuffle,
- shuffle_algo=shuffle_algo,
- shuffle_seed=shuffle_seed,
- shuffle_block_size=shuffle_block_size,
- batching_method=batching_method,
- replication=replication,
- )
-
- def to_streaming(self, index: "KeySpaceIndex", *, cache_dir: str | None = None) -> "streaming.Stream":
+ shard_row_block_size: int | None = None,
+ ) -> "streaming.Stream":
  """Returns a stream to be used with MosaicML's StreamingDataset.

  Requires `streaming` package to be installed.

  Args:
- index: Prebuilt KeysIndex to use when creating the stream. The index's `asof` will be used when scanning.
+ index: Prebuilt KeysIndex to use when creating the stream.
+ The index's `asof` will be used when scanning.
+ projection: Optional projection to use when scanning the table if index's projection is not used.
+ Projection must be compatible with the index's projection for correctness.
  cache_dir: Directory to use for caching data. If None, a temporary directory will be used.
+ shard_row_block_size: Number of rows per segment of a shard file. Defaults to 8192.
+ Value should be set to lower for larger rows.
  """
  from spiral.streaming_ import SpiralStream

@@ -258,7 +212,7 @@ class Table(Expr):

  # We know table from projection is in the session cause this method is on it.
  scan = self.spiral.scan(
- index.projection,
+ projection if projection is not None else index.projection,
  where=index.filter,
  asof=index.asof,
  # TODO(marko): This should be configurable?
@@ -269,4 +223,9 @@ class Table(Expr):
  # We have a world there and can compute shards only on leader.
  shards = self.spiral._core._ops().compute_shards(index=index.core)

- return SpiralStream(scan=scan.core, shards=shards, cache_dir=cache_dir) # type: ignore[return-value]
+ return SpiralStream(
+ scan=scan.core,
+ shards=shards,
+ cache_dir=cache_dir,
+ shard_row_block_size=shard_row_block_size,
+ ) # type: ignore[return-value]
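
To show how the reworked `to_streaming` signature composes with MosaicML streaming, a hedged sketch follows. `table` and `key_index` (a prebuilt KeySpaceIndex) are placeholders, and the `StreamingDataset` arguments mirror the helper that was removed above; `projection` and `shard_row_block_size` are the parameters introduced in this diff.

from streaming import StreamingDataset

stream = table.to_streaming(
    key_index,
    projection=None,            # or a projection compatible with the index's projection
    cache_dir="/tmp/spiral-cache",
    shard_row_block_size=2048,  # below the 8192 default, e.g. for larger rows
)

# A single-stream dataset, roughly what the removed `to_streaming_dataset` helper built.
dataset = StreamingDataset(streams=[stream], batch_size=32, shuffle=True)
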
spiral/transaction.py CHANGED
@@ -45,7 +45,7 @@ class Transaction:


  :param column_paths: Fully qualified column names. (e.g., "column_name" or "nested.field").
- All columns must exist, if a a column doesn't exist the function will return an error.
+ All columns must exist, if a column doesn't exist the function will return an error.
  """
  self._core.drop_columns(column_paths)