pyspiral 0.6.3__cp310-abi3-macosx_11_0_arm64.whl → 0.6.4__cp310-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,106 @@
1
+ from collections.abc import Callable, Iterator
2
+ from typing import TYPE_CHECKING
3
+
4
+ import pyarrow as pa
5
+
6
+ if TYPE_CHECKING:
7
+ import datasets.iterable_dataset as hf # noqa
8
+ import streaming # noqa
9
+ import torch.utils.data as torchdata # noqa
10
+
11
+
12
+ def _hf_compatible_schema(schema: pa.Schema) -> pa.Schema:
13
+ """
14
+ Replace string-view and binary-view columns in the schema with strings/binary.
15
+ Recursively handles nested types (struct, list, etc).
16
+ We use this converted schema as Features in the returned Dataset.
17
+ Remove this method once we have https://github.com/huggingface/datasets/pull/7718
18
+ """
19
+
20
+ def _convert_type(dtype: pa.DataType) -> pa.DataType:
21
+ if dtype == pa.string_view():
22
+ return pa.string()
23
+ elif dtype == pa.binary_view():
24
+ return pa.binary()
25
+ elif pa.types.is_struct(dtype):
26
+ new_fields = [
27
+ pa.field(field.name, _convert_type(field.type), nullable=field.nullable, metadata=field.metadata)
28
+ for field in dtype
29
+ ]
30
+ return pa.struct(new_fields)
31
+ elif pa.types.is_list(dtype):
32
+ return pa.list_(_convert_type(dtype.value_type))
33
+ elif pa.types.is_large_list(dtype):
34
+ return pa.large_list(_convert_type(dtype.value_type))
35
+ elif pa.types.is_fixed_size_list(dtype):
36
+ return pa.list_(_convert_type(dtype.value_type), dtype.list_size)
37
+ elif pa.types.is_map(dtype):
38
+ return pa.map_(_convert_type(dtype.key_type), _convert_type(dtype.item_type))
39
+ else:
40
+ return dtype
41
+
42
+ new_fields = []
43
+ for field in schema:
44
+ new_type = _convert_type(field.type)
45
+ new_fields.append(pa.field(field.name, new_type, nullable=field.nullable, metadata=field.metadata))
46
+
47
+ return pa.schema(new_fields)
48
+
49
+
50
+ def to_iterable_dataset(stream: pa.RecordBatchReader) -> "hf.IterableDataset":
51
+ from datasets import DatasetInfo, Features
52
+ from datasets.builder import ArrowExamplesIterable
53
+ from datasets.iterable_dataset import IterableDataset
54
+
55
+ def _generate_tables(**kwargs) -> Iterator[tuple[int, pa.Table]]:
56
+ # This key is unused when training with IterableDataset.
57
+ # Default implementation returns shard id, e.g. parquet row group id.
58
+ for i, rb in enumerate(stream):
59
+ yield i, pa.Table.from_batches([rb], stream.schema)
60
+
61
+ # TODO(marko): This is temporary until we stop returning IterableDataset from this function.
62
+ class _IterableDataset(IterableDataset):
63
+ # Diff with datasets.iterable_dataset.IterableDataset:
64
+ # - Removes torch handling which attempts to handle worker processes.
65
+ # - Assumes arrow iterator.
66
+ def __iter__(self):
67
+ from datasets.formatting import get_formatter
68
+
69
+ prepared_ex_iterable = self._prepare_ex_iterable_for_iteration()
70
+ if self._formatting and (prepared_ex_iterable.iter_arrow or self._formatting.is_table):
71
+ formatter = get_formatter(self._formatting.format_type, features=self.features)
72
+ iterator = prepared_ex_iterable.iter_arrow()
73
+ for key, pa_table in iterator:
74
+ yield formatter.format_row(pa_table)
75
+ return
76
+
77
+ for key, example in prepared_ex_iterable:
78
+ # no need to format thanks to FormattedExamplesIterable
79
+ yield example
80
+
81
+ def map(self, *args, **kwargs):
82
+ # Map constructs a new IterableDataset, so we need to "patch" it
83
+ base = super().map(*args, **kwargs)
84
+ if isinstance(base, IterableDataset):
85
+ # Patch __iter__ to avoid torch handling
86
+ base.__class__ = _IterableDataset # type: ignore
87
+ return base
88
+
89
+ class _ArrowExamplesIterable(ArrowExamplesIterable):
90
+ def __init__(self, generate_tables_fn: Callable[..., Iterator[tuple[int, pa.Table]]], features: Features):
91
+ # NOTE: generate_tables_fn type annotations are wrong, return type must be an iterable of tuples.
92
+ super().__init__(generate_tables_fn, kwargs={}) # type: ignore
93
+ self._features = features
94
+
95
+ @property
96
+ def is_typed(self) -> bool:
97
+ return True
98
+
99
+ @property
100
+ def features(self) -> Features:
101
+ return self._features
102
+
103
+ target_features = Features.from_arrow_schema(_hf_compatible_schema(stream.schema))
104
+ ex_iterable = _ArrowExamplesIterable(_generate_tables, target_features)
105
+ info = DatasetInfo(features=target_features)
106
+ return _IterableDataset(ex_iterable=ex_iterable, info=info)
@@ -69,7 +69,7 @@ import betterproto2
69
69
 
70
70
  from .....message_pool import default_message_pool
71
71
 
72
- _COMPILER_VERSION = "0.8.0"
72
+ _COMPILER_VERSION = "0.9.0"
73
73
  betterproto2.check_compiler_version(_COMPILER_VERSION)
74
74
 
75
75
 
@@ -21,12 +21,15 @@ __all__ = (
21
21
  "FeatureSet",
22
22
  "FeatureSetDefaults",
23
23
  "FeatureSetDefaultsFeatureSetEditionDefault",
24
+ "FeatureSetEnforceNamingStyle",
24
25
  "FeatureSetEnumType",
25
26
  "FeatureSetFieldPresence",
26
27
  "FeatureSetJsonFormat",
27
28
  "FeatureSetMessageEncoding",
28
29
  "FeatureSetRepeatedFieldEncoding",
29
30
  "FeatureSetUtf8Validation",
31
+ "FeatureSetVisibilityFeature",
32
+ "FeatureSetVisibilityFeatureDefaultSymbolVisibility",
30
33
  "FieldDescriptorProto",
31
34
  "FieldDescriptorProtoLabel",
32
35
  "FieldDescriptorProtoType",
@@ -54,6 +57,7 @@ __all__ = (
54
57
  "ServiceOptions",
55
58
  "SourceCodeInfo",
56
59
  "SourceCodeInfoLocation",
60
+ "SymbolVisibility",
57
61
  "UninterpretedOption",
58
62
  "UninterpretedOptionNamePart",
59
63
  )
@@ -66,7 +70,7 @@ import betterproto2
66
70
 
67
71
  from ...message_pool import default_message_pool
68
72
 
69
- _COMPILER_VERSION = "0.8.0"
73
+ _COMPILER_VERSION = "0.9.0"
70
74
  betterproto2.check_compiler_version(_COMPILER_VERSION)
71
75
 
72
76
 
@@ -174,6 +178,14 @@ class ExtensionRangeOptionsVerificationState(betterproto2.Enum):
174
178
  UNVERIFIED = 1
175
179
 
176
180
 
181
+ class FeatureSetEnforceNamingStyle(betterproto2.Enum):
182
+ ENFORCE_NAMING_STYLE_UNKNOWN = 0
183
+
184
+ STYLE2024 = 1
185
+
186
+ STYLE_LEGACY = 2
187
+
188
+
177
189
  class FeatureSetEnumType(betterproto2.Enum):
178
190
  ENUM_TYPE_UNKNOWN = 0
179
191
 
@@ -224,6 +236,32 @@ class FeatureSetUtf8Validation(betterproto2.Enum):
224
236
  NONE = 3
225
237
 
226
238
 
239
+ class FeatureSetVisibilityFeatureDefaultSymbolVisibility(betterproto2.Enum):
240
+ DEFAULT_SYMBOL_VISIBILITY_UNKNOWN = 0
241
+
242
+ EXPORT_ALL = 1
243
+ """
244
+ Default pre-EDITION_2024, all UNSET visibility are export.
245
+ """
246
+
247
+ EXPORT_TOP_LEVEL = 2
248
+ """
249
+ All top-level symbols default to export, nested default to local.
250
+ """
251
+
252
+ LOCAL_ALL = 3
253
+ """
254
+ All symbols default to local.
255
+ """
256
+
257
+ STRICT = 4
258
+ """
259
+ All symbols local by default. Nested types cannot be exported.
260
+ With special case caveat for message { enum {} reserved 1 to max; }
261
+ This is the recommended setting for new protos.
262
+ """
263
+
264
+
227
265
  class FieldDescriptorProtoLabel(betterproto2.Enum):
228
266
  OPTIONAL = 1
229
267
  """
@@ -512,6 +550,22 @@ class MethodOptionsIdempotencyLevel(betterproto2.Enum):
512
550
  """
513
551
 
514
552
 
553
+ class SymbolVisibility(betterproto2.Enum):
554
+ """
555
+ Describes the 'visibility' of a symbol with respect to the proto import
556
+ system. Symbols can only be imported when the visibility rules do not prevent
557
+ it (ex: local symbols cannot be imported). Visibility modifiers can only set
558
+ on `message` and `enum` as they are the only types available to be referenced
559
+ from other files.
560
+ """
561
+
562
+ VISIBILITY_UNSET = 0
563
+
564
+ VISIBILITY_LOCAL = 1
565
+
566
+ VISIBILITY_EXPORT = 2
567
+
568
+
515
569
  @dataclass(eq=False, repr=False)
516
570
  class Any(betterproto2.Message):
517
571
  """
@@ -744,6 +798,13 @@ class DescriptorProto(betterproto2.Message):
744
798
  A given name may only be reserved once.
745
799
  """
746
800
 
801
+ visibility: "SymbolVisibility" = betterproto2.field(
802
+ 11, betterproto2.TYPE_ENUM, default_factory=lambda: SymbolVisibility(0)
803
+ )
804
+ """
805
+ Support for `export` and `local` keywords on enums.
806
+ """
807
+
747
808
 
748
809
  default_message_pool.register_message("google.protobuf", "DescriptorProto", DescriptorProto)
749
810
 
@@ -835,6 +896,13 @@ class EnumDescriptorProto(betterproto2.Message):
835
896
  be reserved once.
836
897
  """
837
898
 
899
+ visibility: "SymbolVisibility" = betterproto2.field(
900
+ 6, betterproto2.TYPE_ENUM, default_factory=lambda: SymbolVisibility(0)
901
+ )
902
+ """
903
+ Support for `export` and `local` keywords on enums.
904
+ """
905
+
838
906
 
839
907
  default_message_pool.register_message("google.protobuf", "EnumDescriptorProto", EnumDescriptorProto)
840
908
 
@@ -895,6 +963,9 @@ class EnumOptions(betterproto2.Message):
895
963
  features: "FeatureSet | None" = betterproto2.field(7, betterproto2.TYPE_MESSAGE, optional=True)
896
964
  """
897
965
  Any features defined in the specific edition.
966
+ WARNING: This field should only be used by protobuf plugins or special
967
+ cases like the proto compiler. Other uses are discouraged and
968
+ developers should rely on the protoreflect APIs for their client language.
898
969
  """
899
970
 
900
971
  uninterpreted_option: "list[UninterpretedOption]" = betterproto2.field(
@@ -942,6 +1013,9 @@ class EnumValueOptions(betterproto2.Message):
942
1013
  features: "FeatureSet | None" = betterproto2.field(2, betterproto2.TYPE_MESSAGE, optional=True)
943
1014
  """
944
1015
  Any features defined in the specific edition.
1016
+ WARNING: This field should only be used by protobuf plugins or special
1017
+ cases like the proto compiler. Other uses are discouraged and
1018
+ developers should rely on the protoreflect APIs for their client language.
945
1019
  """
946
1020
 
947
1021
  debug_redact: "bool" = betterproto2.field(3, betterproto2.TYPE_BOOL)
@@ -1082,10 +1156,26 @@ class FeatureSet(betterproto2.Message):
1082
1156
  6, betterproto2.TYPE_ENUM, default_factory=lambda: FeatureSetJsonFormat(0)
1083
1157
  )
1084
1158
 
1159
+ enforce_naming_style: "FeatureSetEnforceNamingStyle" = betterproto2.field(
1160
+ 7, betterproto2.TYPE_ENUM, default_factory=lambda: FeatureSetEnforceNamingStyle(0)
1161
+ )
1162
+
1163
+ default_symbol_visibility: "FeatureSetVisibilityFeatureDefaultSymbolVisibility" = betterproto2.field(
1164
+ 8, betterproto2.TYPE_ENUM, default_factory=lambda: FeatureSetVisibilityFeatureDefaultSymbolVisibility(0)
1165
+ )
1166
+
1085
1167
 
1086
1168
  default_message_pool.register_message("google.protobuf", "FeatureSet", FeatureSet)
1087
1169
 
1088
1170
 
1171
+ @dataclass(eq=False, repr=False)
1172
+ class FeatureSetVisibilityFeature(betterproto2.Message):
1173
+ pass
1174
+
1175
+
1176
+ default_message_pool.register_message("google.protobuf", "FeatureSet.VisibilityFeature", FeatureSetVisibilityFeature)
1177
+
1178
+
1089
1179
  @dataclass(eq=False, repr=False)
1090
1180
  class FeatureSetDefaults(betterproto2.Message):
1091
1181
  """
@@ -1340,6 +1430,9 @@ class FieldOptions(betterproto2.Message):
1340
1430
  features: "FeatureSet | None" = betterproto2.field(21, betterproto2.TYPE_MESSAGE, optional=True)
1341
1431
  """
1342
1432
  Any features defined in the specific edition.
1433
+ WARNING: This field should only be used by protobuf plugins or special
1434
+ cases like the proto compiler. Other uses are discouraged and
1435
+ developers should rely on the protoreflect APIs for their client language.
1343
1436
  """
1344
1437
 
1345
1438
  feature_support: "FieldOptionsFeatureSupport | None" = betterproto2.field(
@@ -1438,6 +1531,12 @@ class FileDescriptorProto(betterproto2.Message):
1438
1531
  For Google-internal migration only. Do not use.
1439
1532
  """
1440
1533
 
1534
+ option_dependency: "list[str]" = betterproto2.field(15, betterproto2.TYPE_STRING, repeated=True)
1535
+ """
1536
+ Names of files imported by this file purely for the purpose of providing
1537
+ option extensions. These are excluded from the dependency list above.
1538
+ """
1539
+
1441
1540
  message_type: "list[DescriptorProto]" = betterproto2.field(4, betterproto2.TYPE_MESSAGE, repeated=True)
1442
1541
  """
1443
1542
  All top-level definitions in this file.
@@ -1465,11 +1564,17 @@ class FileDescriptorProto(betterproto2.Message):
1465
1564
  The supported values are "proto2", "proto3", and "editions".
1466
1565
 
1467
1566
  If `edition` is present, this value must be "editions".
1567
+ WARNING: This field should only be used by protobuf plugins or special
1568
+ cases like the proto compiler. Other uses are discouraged and
1569
+ developers should rely on the protoreflect APIs for their client language.
1468
1570
  """
1469
1571
 
1470
1572
  edition: "Edition" = betterproto2.field(14, betterproto2.TYPE_ENUM, default_factory=lambda: Edition(0))
1471
1573
  """
1472
1574
  The edition of the proto file.
1575
+ WARNING: This field should only be used by protobuf plugins or special
1576
+ cases like the proto compiler. Other uses are discouraged and
1577
+ developers should rely on the protoreflect APIs for their client language.
1473
1578
  """
1474
1579
 
1475
1580
 
@@ -1665,6 +1770,9 @@ class FileOptions(betterproto2.Message):
1665
1770
  features: "FeatureSet | None" = betterproto2.field(50, betterproto2.TYPE_MESSAGE, optional=True)
1666
1771
  """
1667
1772
  Any features defined in the specific edition.
1773
+ WARNING: This field should only be used by protobuf plugins or special
1774
+ cases like the proto compiler. Other uses are discouraged and
1775
+ developers should rely on the protoreflect APIs for their client language.
1668
1776
  """
1669
1777
 
1670
1778
  uninterpreted_option: "list[UninterpretedOption]" = betterproto2.field(
@@ -1817,6 +1925,9 @@ class MessageOptions(betterproto2.Message):
1817
1925
  features: "FeatureSet | None" = betterproto2.field(12, betterproto2.TYPE_MESSAGE, optional=True)
1818
1926
  """
1819
1927
  Any features defined in the specific edition.
1928
+ WARNING: This field should only be used by protobuf plugins or special
1929
+ cases like the proto compiler. Other uses are discouraged and
1930
+ developers should rely on the protoreflect APIs for their client language.
1820
1931
  """
1821
1932
 
1822
1933
  uninterpreted_option: "list[UninterpretedOption]" = betterproto2.field(
@@ -1889,6 +2000,9 @@ class MethodOptions(betterproto2.Message):
1889
2000
  features: "FeatureSet | None" = betterproto2.field(35, betterproto2.TYPE_MESSAGE, optional=True)
1890
2001
  """
1891
2002
  Any features defined in the specific edition.
2003
+ WARNING: This field should only be used by protobuf plugins or special
2004
+ cases like the proto compiler. Other uses are discouraged and
2005
+ developers should rely on the protoreflect APIs for their client language.
1892
2006
  """
1893
2007
 
1894
2008
  uninterpreted_option: "list[UninterpretedOption]" = betterproto2.field(
@@ -1921,6 +2035,9 @@ class OneofOptions(betterproto2.Message):
1921
2035
  features: "FeatureSet | None" = betterproto2.field(1, betterproto2.TYPE_MESSAGE, optional=True)
1922
2036
  """
1923
2037
  Any features defined in the specific edition.
2038
+ WARNING: This field should only be used by protobuf plugins or special
2039
+ cases like the proto compiler. Other uses are discouraged and
2040
+ developers should rely on the protoreflect APIs for their client language.
1924
2041
  """
1925
2042
 
1926
2043
  uninterpreted_option: "list[UninterpretedOption]" = betterproto2.field(
@@ -1955,6 +2072,9 @@ class ServiceOptions(betterproto2.Message):
1955
2072
  features: "FeatureSet | None" = betterproto2.field(34, betterproto2.TYPE_MESSAGE, optional=True)
1956
2073
  """
1957
2074
  Any features defined in the specific edition.
2075
+ WARNING: This field should only be used by protobuf plugins or special
2076
+ cases like the proto compiler. Other uses are discouraged and
2077
+ developers should rely on the protoreflect APIs for their client language.
1958
2078
  """
1959
2079
 
1960
2080
  deprecated: "bool" = betterproto2.field(33, betterproto2.TYPE_BOOL)
@@ -25,7 +25,7 @@ import grpc
25
25
 
26
26
  from ..message_pool import default_message_pool
27
27
 
28
- _COMPILER_VERSION = "0.8.0"
28
+ _COMPILER_VERSION = "0.9.0"
29
29
  betterproto2.check_compiler_version(_COMPILER_VERSION)
30
30
 
31
31
 
@@ -16,7 +16,7 @@ import betterproto2
16
16
 
17
17
  from ..message_pool import default_message_pool
18
18
 
19
- _COMPILER_VERSION = "0.8.0"
19
+ _COMPILER_VERSION = "0.9.0"
20
20
  betterproto2.check_compiler_version(_COMPILER_VERSION)
21
21
 
22
22
 
@@ -17,7 +17,7 @@ import betterproto2
17
17
 
18
18
  from ..message_pool import default_message_pool
19
19
 
20
- _COMPILER_VERSION = "0.8.0"
20
+ _COMPILER_VERSION = "0.9.0"
21
21
  betterproto2.check_compiler_version(_COMPILER_VERSION)
22
22
 
23
23
 
@@ -246,7 +246,7 @@ import betterproto2
246
246
 
247
247
  from ..message_pool import default_message_pool
248
248
 
249
- _COMPILER_VERSION = "0.8.0"
249
+ _COMPILER_VERSION = "0.9.0"
250
250
  betterproto2.check_compiler_version(_COMPILER_VERSION)
251
251
 
252
252
 
@@ -18,7 +18,7 @@ import betterproto2
18
18
 
19
19
  from ...message_pool import default_message_pool
20
20
 
21
- _COMPILER_VERSION = "0.8.0"
21
+ _COMPILER_VERSION = "0.9.0"
22
22
  betterproto2.check_compiler_version(_COMPILER_VERSION)
23
23
 
24
24
 
spiral/scan.py CHANGED
@@ -1,4 +1,3 @@
1
- from collections.abc import Iterator
2
1
  from typing import TYPE_CHECKING, Any
3
2
 
4
3
  import pyarrow as pa
@@ -120,8 +119,11 @@ class Scan:
120
119
  self,
121
120
  shuffle: ShuffleStrategy | None = None,
122
121
  batch_readahead: int | None = None,
122
+ num_workers: int | None = None,
123
+ worker_id: int | None = None,
124
+ infinite: bool = False,
123
125
  ) -> "hf.IterableDataset":
124
- """Returns an Huggingface's IterableDataset.
126
+ """Returns a Huggingface's IterableDataset.
125
127
 
126
128
  Requires `datasets` package to be installed.
127
129
 
@@ -130,39 +132,25 @@ class Scan:
130
132
  batch_readahead: Controls how many batches to read ahead concurrently.
131
133
  If pipeline includes work after reading (e.g. decoding, transforming, ...) this can be set higher.
132
134
  Otherwise, it should be kept low to reduce next batch latency. Defaults to 2.
135
+ num_workers: If not None, shards the scan across multiple workers.
136
+ Must be used together with worker_id.
137
+ worker_id: If not None, the id of the current worker.
138
+ Scan will only return a subset of the data corresponding to the worker_id.
139
+ infinite: If True, the returned IterableDataset will loop infinitely over the data,
140
+ re-shuffling ranges after exhausting all data.
133
141
  """
134
- from datasets import DatasetInfo, Features
135
- from datasets.iterable_dataset import ArrowExamplesIterable, IterableDataset
136
-
137
- def _generate_tables(**kwargs) -> Iterator[tuple[int, pa.Table]]:
138
- stream = self.core.to_shuffled_record_batches(
139
- shuffle,
140
- batch_readahead,
141
- )
142
-
143
- # This key is unused when training with IterableDataset.
144
- # Default implementation returns shard id, e.g. parquet row group id.
145
- for i, rb in enumerate(stream):
146
- yield i, pa.Table.from_batches([rb], stream.schema)
147
-
148
- def _hf_compatible_schema(schema: pa.Schema) -> pa.Schema:
149
- """
150
- Replace string-view columns in the schema with strings. We do use this converted schema
151
- as Features in the returned Dataset.
152
- Remove this method once we have https://github.com/huggingface/datasets/pull/7718
153
- """
154
- new_fields = [
155
- pa.field(field.name, pa.string(), nullable=field.nullable, metadata=field.metadata)
156
- if field.type == pa.string_view()
157
- else field
158
- for field in schema
159
- ]
160
- return pa.schema(new_fields)
161
-
162
- # NOTE: generate_tables_fn type annotations are wrong, return type must be an iterable of tuples.
163
- ex_iterable = ArrowExamplesIterable(generate_tables_fn=_generate_tables, kwargs={}) # type: ignore
164
- info = DatasetInfo(features=Features.from_arrow_schema(_hf_compatible_schema(self.schema.to_arrow())))
165
- return IterableDataset(ex_iterable=ex_iterable, info=info)
142
+
143
+ stream = self.core.to_shuffled_record_batches(
144
+ shuffle,
145
+ batch_readahead,
146
+ num_workers,
147
+ worker_id,
148
+ infinite,
149
+ )
150
+
151
+ from spiral.iterable_dataset import to_iterable_dataset
152
+
153
+ return to_iterable_dataset(stream)
166
154
 
167
155
  def _splits(self) -> list[KeyRange]:
168
156
  # Splits the scan into a set of key ranges.
spiral/settings.py CHANGED
@@ -24,6 +24,8 @@ CI = "GITHUB_ACTIONS" in os.environ
24
24
  APP_DIR = Path(typer.get_app_dir("pyspiral"))
25
25
  LOG_DIR = APP_DIR / "logs"
26
26
 
27
+ PACKAGE_NAME = "pyspiral"
28
+
27
29
 
28
30
  def validate_token(v, handler: ValidatorFunctionWrapHandler):
29
31
  if isinstance(v, str):
spiral/snapshot.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from typing import TYPE_CHECKING
2
2
 
3
+ from spiral import ShuffleStrategy
3
4
  from spiral.core.table import Snapshot as CoreSnapshot
4
5
  from spiral.core.table.spec import Schema
5
6
  from spiral.types_ import Timestamp
@@ -8,6 +9,7 @@ if TYPE_CHECKING:
8
9
  import duckdb
9
10
  import polars as pl
10
11
  import pyarrow.dataset as ds
12
+ import torch.utils.data as torchdata # noqa
11
13
 
12
14
  from spiral.table import Table
13
15
 
@@ -53,3 +55,17 @@ class Snapshot:
53
55
  import duckdb
54
56
 
55
57
  return duckdb.from_arrow(self.to_dataset())
58
+
59
+ def to_iterable_dataset(
60
+ self,
61
+ *,
62
+ shuffle: ShuffleStrategy | None = None,
63
+ batch_readahead: int | None = None,
64
+ infinite: bool = False,
65
+ ) -> "torchdata.IterableDataset":
66
+ """Returns an iterable dataset compatible with `torch.IterableDataset`.
67
+
68
+ See `Table` docs for details on the parameters.
69
+ """
70
+ # TODO(marko): WIP.
71
+ raise NotImplementedError
@@ -25,12 +25,16 @@ class SpiralStream:
25
25
  """
26
26
 
27
27
  def __init__(
28
- self, scan: CoreScan, shards: list[Shard], cache_dir: str | None = None, shard_row_block_size: int = 8192
28
+ self,
29
+ scan: CoreScan,
30
+ shards: list[Shard],
31
+ cache_dir: str | None = None,
32
+ shard_row_block_size: int | None = None,
29
33
  ):
30
34
  self._scan = scan
31
35
  # TODO(marko): Read shards only on world.is_local_leader in `get_shards` and materialize on disk.
32
36
  self._shards = shards
33
- self.shard_row_block_size = shard_row_block_size
37
+ self._shard_row_block_size = shard_row_block_size or 8192
34
38
 
35
39
  if cache_dir is not None:
36
40
  if not os.path.exists(cache_dir):
@@ -99,7 +103,7 @@ class SpiralStream:
99
103
  shard_path,
100
104
  shard.shard.key_range,
101
105
  expected_cardinality=shard.shard.cardinality,
102
- shard_row_block_size=self.shard_row_block_size,
106
+ shard_row_block_size=self._shard_row_block_size,
103
107
  )
104
108
 
105
109
  # Get the size of the file on disk.