pyspiral-0.6.3-cp310-abi3-macosx_11_0_arm64.whl → pyspiral-0.6.4-cp310-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyspiral-0.6.3.dist-info → pyspiral-0.6.4.dist-info}/METADATA +3 -3
- {pyspiral-0.6.3.dist-info → pyspiral-0.6.4.dist-info}/RECORD +29 -27
- {pyspiral-0.6.3.dist-info → pyspiral-0.6.4.dist-info}/WHEEL +1 -1
- spiral/_lib.abi3.so +0 -0
- spiral/api/client.py +1 -1
- spiral/api/filesystems.py +9 -40
- spiral/cli/app.py +42 -6
- spiral/cli/fs.py +25 -60
- spiral/cli/login.py +3 -2
- spiral/core/_tools/__init__.pyi +5 -0
- spiral/core/client/__init__.pyi +12 -1
- spiral/core/table/__init__.pyi +3 -0
- spiral/debug/manifests.py +26 -18
- spiral/expressions/__init__.py +2 -2
- spiral/expressions/base.py +9 -3
- spiral/iterable_dataset.py +106 -0
- spiral/protogen/_/arrow/flight/protocol/sql/__init__.py +1 -1
- spiral/protogen/_/google/protobuf/__init__.py +121 -1
- spiral/protogen/_/scandal/__init__.py +1 -1
- spiral/protogen/_/spfs/__init__.py +1 -1
- spiral/protogen/_/spql/__init__.py +1 -1
- spiral/protogen/_/substrait/__init__.py +1 -1
- spiral/protogen/_/substrait/extensions/__init__.py +1 -1
- spiral/scan.py +22 -34
- spiral/settings.py +2 -0
- spiral/snapshot.py +16 -0
- spiral/streaming_/stream.py +7 -3
- spiral/table.py +48 -91
- {pyspiral-0.6.3.dist-info → pyspiral-0.6.4.dist-info}/entry_points.txt +0 -0
spiral/iterable_dataset.py ADDED
@@ -0,0 +1,106 @@
+from collections.abc import Callable, Iterator
+from typing import TYPE_CHECKING
+
+import pyarrow as pa
+
+if TYPE_CHECKING:
+    import datasets.iterable_dataset as hf  # noqa
+    import streaming  # noqa
+    import torch.utils.data as torchdata  # noqa
+
+
+def _hf_compatible_schema(schema: pa.Schema) -> pa.Schema:
+    """
+    Replace string-view and binary-view columns in the schema with strings/binary.
+    Recursively handles nested types (struct, list, etc).
+    We use this converted schema as Features in the returned Dataset.
+    Remove this method once we have https://github.com/huggingface/datasets/pull/7718
+    """
+
+    def _convert_type(dtype: pa.DataType) -> pa.DataType:
+        if dtype == pa.string_view():
+            return pa.string()
+        elif dtype == pa.binary_view():
+            return pa.binary()
+        elif pa.types.is_struct(dtype):
+            new_fields = [
+                pa.field(field.name, _convert_type(field.type), nullable=field.nullable, metadata=field.metadata)
+                for field in dtype
+            ]
+            return pa.struct(new_fields)
+        elif pa.types.is_list(dtype):
+            return pa.list_(_convert_type(dtype.value_type))
+        elif pa.types.is_large_list(dtype):
+            return pa.large_list(_convert_type(dtype.value_type))
+        elif pa.types.is_fixed_size_list(dtype):
+            return pa.list_(_convert_type(dtype.value_type), dtype.list_size)
+        elif pa.types.is_map(dtype):
+            return pa.map_(_convert_type(dtype.key_type), _convert_type(dtype.item_type))
+        else:
+            return dtype
+
+    new_fields = []
+    for field in schema:
+        new_type = _convert_type(field.type)
+        new_fields.append(pa.field(field.name, new_type, nullable=field.nullable, metadata=field.metadata))
+
+    return pa.schema(new_fields)
+
+
+def to_iterable_dataset(stream: pa.RecordBatchReader) -> "hf.IterableDataset":
+    from datasets import DatasetInfo, Features
+    from datasets.builder import ArrowExamplesIterable
+    from datasets.iterable_dataset import IterableDataset
+
+    def _generate_tables(**kwargs) -> Iterator[tuple[int, pa.Table]]:
+        # This key is unused when training with IterableDataset.
+        # Default implementation returns shard id, e.g. parquet row group id.
+        for i, rb in enumerate(stream):
+            yield i, pa.Table.from_batches([rb], stream.schema)
+
+    # TODO(marko): This is temporary until we stop returning IterableDataset from this function.
+    class _IterableDataset(IterableDataset):
+        # Diff with datasets.iterable_dataset.IterableDataset:
+        # - Removes torch handling which attempts to handle worker processes.
+        # - Assumes arrow iterator.
+        def __iter__(self):
+            from datasets.formatting import get_formatter
+
+            prepared_ex_iterable = self._prepare_ex_iterable_for_iteration()
+            if self._formatting and (prepared_ex_iterable.iter_arrow or self._formatting.is_table):
+                formatter = get_formatter(self._formatting.format_type, features=self.features)
+                iterator = prepared_ex_iterable.iter_arrow()
+                for key, pa_table in iterator:
+                    yield formatter.format_row(pa_table)
+                return
+
+            for key, example in prepared_ex_iterable:
+                # no need to format thanks to FormattedExamplesIterable
+                yield example
+
+        def map(self, *args, **kwargs):
+            # Map constructs a new IterableDataset, so we need to "patch" it
+            base = super().map(*args, **kwargs)
+            if isinstance(base, IterableDataset):
+                # Patch __iter__ to avoid torch handling
+                base.__class__ = _IterableDataset  # type: ignore
+            return base
+
+    class _ArrowExamplesIterable(ArrowExamplesIterable):
+        def __init__(self, generate_tables_fn: Callable[..., Iterator[tuple[int, pa.Table]]], features: Features):
+            # NOTE: generate_tables_fn type annotations are wrong, return type must be an iterable of tuples.
+            super().__init__(generate_tables_fn, kwargs={})  # type: ignore
+            self._features = features
+
+        @property
+        def is_typed(self) -> bool:
+            return True
+
+        @property
+        def features(self) -> Features:
+            return self._features
+
+    target_features = Features.from_arrow_schema(_hf_compatible_schema(stream.schema))
+    ex_iterable = _ArrowExamplesIterable(_generate_tables, target_features)
+    info = DatasetInfo(features=target_features)
+    return _IterableDataset(ex_iterable=ex_iterable, info=info)
spiral/protogen/_/google/protobuf/__init__.py CHANGED
@@ -21,12 +21,15 @@ __all__ = (
     "FeatureSet",
     "FeatureSetDefaults",
     "FeatureSetDefaultsFeatureSetEditionDefault",
+    "FeatureSetEnforceNamingStyle",
     "FeatureSetEnumType",
     "FeatureSetFieldPresence",
     "FeatureSetJsonFormat",
     "FeatureSetMessageEncoding",
     "FeatureSetRepeatedFieldEncoding",
     "FeatureSetUtf8Validation",
+    "FeatureSetVisibilityFeature",
+    "FeatureSetVisibilityFeatureDefaultSymbolVisibility",
     "FieldDescriptorProto",
     "FieldDescriptorProtoLabel",
     "FieldDescriptorProtoType",
@@ -54,6 +57,7 @@ __all__ = (
     "ServiceOptions",
     "SourceCodeInfo",
     "SourceCodeInfoLocation",
+    "SymbolVisibility",
     "UninterpretedOption",
     "UninterpretedOptionNamePart",
 )
@@ -66,7 +70,7 @@ import betterproto2

 from ...message_pool import default_message_pool

-_COMPILER_VERSION = "0.
+_COMPILER_VERSION = "0.9.0"
 betterproto2.check_compiler_version(_COMPILER_VERSION)


@@ -174,6 +178,14 @@ class ExtensionRangeOptionsVerificationState(betterproto2.Enum):
     UNVERIFIED = 1


+class FeatureSetEnforceNamingStyle(betterproto2.Enum):
+    ENFORCE_NAMING_STYLE_UNKNOWN = 0
+
+    STYLE2024 = 1
+
+    STYLE_LEGACY = 2
+
+
 class FeatureSetEnumType(betterproto2.Enum):
     ENUM_TYPE_UNKNOWN = 0

@@ -224,6 +236,32 @@ class FeatureSetUtf8Validation(betterproto2.Enum):
     NONE = 3


+class FeatureSetVisibilityFeatureDefaultSymbolVisibility(betterproto2.Enum):
+    DEFAULT_SYMBOL_VISIBILITY_UNKNOWN = 0
+
+    EXPORT_ALL = 1
+    """
+    Default pre-EDITION_2024, all UNSET visibility are export.
+    """
+
+    EXPORT_TOP_LEVEL = 2
+    """
+    All top-level symbols default to export, nested default to local.
+    """
+
+    LOCAL_ALL = 3
+    """
+    All symbols default to local.
+    """
+
+    STRICT = 4
+    """
+    All symbols local by default. Nested types cannot be exported.
+    With special case caveat for message { enum {} reserved 1 to max; }
+    This is the recommended setting for new protos.
+    """
+
+
 class FieldDescriptorProtoLabel(betterproto2.Enum):
     OPTIONAL = 1
     """
@@ -512,6 +550,22 @@ class MethodOptionsIdempotencyLevel(betterproto2.Enum):
     """


+class SymbolVisibility(betterproto2.Enum):
+    """
+    Describes the 'visibility' of a symbol with respect to the proto import
+    system. Symbols can only be imported when the visibility rules do not prevent
+    it (ex: local symbols cannot be imported). Visibility modifiers can only set
+    on `message` and `enum` as they are the only types available to be referenced
+    from other files.
+    """
+
+    VISIBILITY_UNSET = 0
+
+    VISIBILITY_LOCAL = 1
+
+    VISIBILITY_EXPORT = 2
+
+
 @dataclass(eq=False, repr=False)
 class Any(betterproto2.Message):
     """
@@ -744,6 +798,13 @@ class DescriptorProto(betterproto2.Message):
     A given name may only be reserved once.
     """

+    visibility: "SymbolVisibility" = betterproto2.field(
+        11, betterproto2.TYPE_ENUM, default_factory=lambda: SymbolVisibility(0)
+    )
+    """
+    Support for `export` and `local` keywords on enums.
+    """
+

 default_message_pool.register_message("google.protobuf", "DescriptorProto", DescriptorProto)

@@ -835,6 +896,13 @@ class EnumDescriptorProto(betterproto2.Message):
     be reserved once.
     """

+    visibility: "SymbolVisibility" = betterproto2.field(
+        6, betterproto2.TYPE_ENUM, default_factory=lambda: SymbolVisibility(0)
+    )
+    """
+    Support for `export` and `local` keywords on enums.
+    """
+

 default_message_pool.register_message("google.protobuf", "EnumDescriptorProto", EnumDescriptorProto)

@@ -895,6 +963,9 @@ class EnumOptions(betterproto2.Message):
     features: "FeatureSet | None" = betterproto2.field(7, betterproto2.TYPE_MESSAGE, optional=True)
     """
     Any features defined in the specific edition.
+    WARNING: This field should only be used by protobuf plugins or special
+    cases like the proto compiler. Other uses are discouraged and
+    developers should rely on the protoreflect APIs for their client language.
     """

     uninterpreted_option: "list[UninterpretedOption]" = betterproto2.field(
@@ -942,6 +1013,9 @@ class EnumValueOptions(betterproto2.Message):
     features: "FeatureSet | None" = betterproto2.field(2, betterproto2.TYPE_MESSAGE, optional=True)
     """
     Any features defined in the specific edition.
+    WARNING: This field should only be used by protobuf plugins or special
+    cases like the proto compiler. Other uses are discouraged and
+    developers should rely on the protoreflect APIs for their client language.
     """

     debug_redact: "bool" = betterproto2.field(3, betterproto2.TYPE_BOOL)
@@ -1082,10 +1156,26 @@ class FeatureSet(betterproto2.Message):
         6, betterproto2.TYPE_ENUM, default_factory=lambda: FeatureSetJsonFormat(0)
     )

+    enforce_naming_style: "FeatureSetEnforceNamingStyle" = betterproto2.field(
+        7, betterproto2.TYPE_ENUM, default_factory=lambda: FeatureSetEnforceNamingStyle(0)
+    )
+
+    default_symbol_visibility: "FeatureSetVisibilityFeatureDefaultSymbolVisibility" = betterproto2.field(
+        8, betterproto2.TYPE_ENUM, default_factory=lambda: FeatureSetVisibilityFeatureDefaultSymbolVisibility(0)
+    )
+

 default_message_pool.register_message("google.protobuf", "FeatureSet", FeatureSet)


+@dataclass(eq=False, repr=False)
+class FeatureSetVisibilityFeature(betterproto2.Message):
+    pass
+
+
+default_message_pool.register_message("google.protobuf", "FeatureSet.VisibilityFeature", FeatureSetVisibilityFeature)
+
+
 @dataclass(eq=False, repr=False)
 class FeatureSetDefaults(betterproto2.Message):
     """
@@ -1340,6 +1430,9 @@ class FieldOptions(betterproto2.Message):
     features: "FeatureSet | None" = betterproto2.field(21, betterproto2.TYPE_MESSAGE, optional=True)
     """
     Any features defined in the specific edition.
+    WARNING: This field should only be used by protobuf plugins or special
+    cases like the proto compiler. Other uses are discouraged and
+    developers should rely on the protoreflect APIs for their client language.
     """

     feature_support: "FieldOptionsFeatureSupport | None" = betterproto2.field(
@@ -1438,6 +1531,12 @@ class FileDescriptorProto(betterproto2.Message):
     For Google-internal migration only. Do not use.
     """

+    option_dependency: "list[str]" = betterproto2.field(15, betterproto2.TYPE_STRING, repeated=True)
+    """
+    Names of files imported by this file purely for the purpose of providing
+    option extensions. These are excluded from the dependency list above.
+    """
+
     message_type: "list[DescriptorProto]" = betterproto2.field(4, betterproto2.TYPE_MESSAGE, repeated=True)
     """
     All top-level definitions in this file.
@@ -1465,11 +1564,17 @@ class FileDescriptorProto(betterproto2.Message):
     The supported values are "proto2", "proto3", and "editions".

     If `edition` is present, this value must be "editions".
+    WARNING: This field should only be used by protobuf plugins or special
+    cases like the proto compiler. Other uses are discouraged and
+    developers should rely on the protoreflect APIs for their client language.
     """

     edition: "Edition" = betterproto2.field(14, betterproto2.TYPE_ENUM, default_factory=lambda: Edition(0))
     """
     The edition of the proto file.
+    WARNING: This field should only be used by protobuf plugins or special
+    cases like the proto compiler. Other uses are discouraged and
+    developers should rely on the protoreflect APIs for their client language.
     """


@@ -1665,6 +1770,9 @@ class FileOptions(betterproto2.Message):
     features: "FeatureSet | None" = betterproto2.field(50, betterproto2.TYPE_MESSAGE, optional=True)
     """
     Any features defined in the specific edition.
+    WARNING: This field should only be used by protobuf plugins or special
+    cases like the proto compiler. Other uses are discouraged and
+    developers should rely on the protoreflect APIs for their client language.
     """

     uninterpreted_option: "list[UninterpretedOption]" = betterproto2.field(
@@ -1817,6 +1925,9 @@ class MessageOptions(betterproto2.Message):
     features: "FeatureSet | None" = betterproto2.field(12, betterproto2.TYPE_MESSAGE, optional=True)
     """
     Any features defined in the specific edition.
+    WARNING: This field should only be used by protobuf plugins or special
+    cases like the proto compiler. Other uses are discouraged and
+    developers should rely on the protoreflect APIs for their client language.
     """

     uninterpreted_option: "list[UninterpretedOption]" = betterproto2.field(
@@ -1889,6 +2000,9 @@ class MethodOptions(betterproto2.Message):
     features: "FeatureSet | None" = betterproto2.field(35, betterproto2.TYPE_MESSAGE, optional=True)
     """
     Any features defined in the specific edition.
+    WARNING: This field should only be used by protobuf plugins or special
+    cases like the proto compiler. Other uses are discouraged and
+    developers should rely on the protoreflect APIs for their client language.
     """

     uninterpreted_option: "list[UninterpretedOption]" = betterproto2.field(
@@ -1921,6 +2035,9 @@ class OneofOptions(betterproto2.Message):
     features: "FeatureSet | None" = betterproto2.field(1, betterproto2.TYPE_MESSAGE, optional=True)
     """
     Any features defined in the specific edition.
+    WARNING: This field should only be used by protobuf plugins or special
+    cases like the proto compiler. Other uses are discouraged and
+    developers should rely on the protoreflect APIs for their client language.
     """

     uninterpreted_option: "list[UninterpretedOption]" = betterproto2.field(
@@ -1955,6 +2072,9 @@ class ServiceOptions(betterproto2.Message):
     features: "FeatureSet | None" = betterproto2.field(34, betterproto2.TYPE_MESSAGE, optional=True)
     """
     Any features defined in the specific edition.
+    WARNING: This field should only be used by protobuf plugins or special
+    cases like the proto compiler. Other uses are discouraged and
+    developers should rely on the protoreflect APIs for their client language.
     """

     deprecated: "bool" = betterproto2.field(33, betterproto2.TYPE_BOOL)
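
The regenerated bindings above track the upstream descriptor.proto additions for Edition 2024 (symbol visibility and new FeatureSet fields). A small illustrative sketch, not from the package, assuming the regenerated module imports at the path shown in the file list and that betterproto2 messages construct like ordinary dataclasses:

    from spiral.protogen._.google.protobuf import (
        DescriptorProto,
        FeatureSet,
        FeatureSetEnforceNamingStyle,
        SymbolVisibility,
    )

    # New field 11 on DescriptorProto: symbol visibility, defaulting to VISIBILITY_UNSET.
    msg = DescriptorProto(name="MyMessage", visibility=SymbolVisibility.VISIBILITY_EXPORT)

    # New FeatureSet fields 7 and 8 added in this regeneration.
    features = FeatureSet(enforce_naming_style=FeatureSetEnforceNamingStyle.STYLE2024)

    print(msg.visibility, features.default_symbol_visibility)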
spiral/scan.py CHANGED
@@ -1,4 +1,3 @@
-from collections.abc import Iterator
 from typing import TYPE_CHECKING, Any

 import pyarrow as pa
@@ -120,8 +119,11 @@ class Scan:
         self,
         shuffle: ShuffleStrategy | None = None,
         batch_readahead: int | None = None,
+        num_workers: int | None = None,
+        worker_id: int | None = None,
+        infinite: bool = False,
     ) -> "hf.IterableDataset":
-        """Returns
+        """Returns a Huggingface's IterableDataset.

         Requires `datasets` package to be installed.

@@ -130,39 +132,25 @@ class Scan:
             batch_readahead: Controls how many batches to read ahead concurrently.
                 If pipeline includes work after reading (e.g. decoding, transforming, ...) this can be set higher.
                 Otherwise, it should be kept low to reduce next batch latency. Defaults to 2.
+            num_workers: If not None, shards the scan across multiple workers.
+                Must be used together with worker_id.
+            worker_id: If not None, the id of the current worker.
+                Scan will only return a subset of the data corresponding to the worker_id.
+            infinite: If True, the returned IterableDataset will loop infinitely over the data,
+                re-shuffling ranges after exhausting all data.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-                yield i, pa.Table.from_batches([rb], stream.schema)
-
-        def _hf_compatible_schema(schema: pa.Schema) -> pa.Schema:
-            """
-            Replace string-view columns in the schema with strings. We do use this converted schema
-            as Features in the returned Dataset.
-            Remove this method once we have https://github.com/huggingface/datasets/pull/7718
-            """
-            new_fields = [
-                pa.field(field.name, pa.string(), nullable=field.nullable, metadata=field.metadata)
-                if field.type == pa.string_view()
-                else field
-                for field in schema
-            ]
-            return pa.schema(new_fields)
-
-        # NOTE: generate_tables_fn type annotations are wrong, return type must be an iterable of tuples.
-        ex_iterable = ArrowExamplesIterable(generate_tables_fn=_generate_tables, kwargs={})  # type: ignore
-        info = DatasetInfo(features=Features.from_arrow_schema(_hf_compatible_schema(self.schema.to_arrow())))
-        return IterableDataset(ex_iterable=ex_iterable, info=info)
+
+        stream = self.core.to_shuffled_record_batches(
+            shuffle,
+            batch_readahead,
+            num_workers,
+            worker_id,
+            infinite,
+        )
+
+        from spiral.iterable_dataset import to_iterable_dataset
+
+        return to_iterable_dataset(stream)

     def _splits(self) -> list[KeyRange]:
         # Splits the scan into a set of key ranges.
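
The new num_workers / worker_id / infinite parameters make it possible to shard a scan across data-loading workers. A hedged sketch (not from the package) of wiring that into a torch DataLoader follows; how the `scan` object itself is obtained is elided, and the wrapper class here is hypothetical.

    import torch.utils.data as torchdata


    class ShardedSpiralDataset(torchdata.IterableDataset):
        """Hypothetical wrapper: one Scan shard per DataLoader worker."""

        def __init__(self, scan, total_workers: int, infinite: bool = False):
            self.scan = scan
            self.total_workers = total_workers
            self.infinite = infinite

        def __iter__(self):
            info = torchdata.get_worker_info()
            worker_id = info.id if info is not None else 0
            # Each worker reads only its subset of the scan; `infinite` keeps the
            # stream looping (re-shuffling ranges) for long training runs.
            ds = self.scan.to_iterable_dataset(
                num_workers=self.total_workers,
                worker_id=worker_id,
                infinite=self.infinite,
            )
            yield from ds


    # loader = torchdata.DataLoader(
    #     ShardedSpiralDataset(scan, total_workers=4, infinite=True),
    #     num_workers=4,  # should match total_workers so every shard is read exactly once
    #     batch_size=None,
    # )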
spiral/settings.py CHANGED
spiral/snapshot.py CHANGED
@@ -1,5 +1,6 @@
 from typing import TYPE_CHECKING

+from spiral import ShuffleStrategy
 from spiral.core.table import Snapshot as CoreSnapshot
 from spiral.core.table.spec import Schema
 from spiral.types_ import Timestamp
@@ -8,6 +9,7 @@ if TYPE_CHECKING:
     import duckdb
     import polars as pl
     import pyarrow.dataset as ds
+    import torch.utils.data as torchdata  # noqa

     from spiral.table import Table

@@ -53,3 +55,17 @@ class Snapshot:
         import duckdb

         return duckdb.from_arrow(self.to_dataset())
+
+    def to_iterable_dataset(
+        self,
+        *,
+        shuffle: ShuffleStrategy | None = None,
+        batch_readahead: int | None = None,
+        infinite: bool = False,
+    ) -> "torchdata.IterableDataset":
+        """Returns an iterable dataset compatible with `torch.IterableDataset`.
+
+        See `Table` docs for details on the parameters.
+        """
+        # TODO(marko): WIP.
+        raise NotImplementedError
spiral/streaming_/stream.py CHANGED
@@ -25,12 +25,16 @@ class SpiralStream:
     """

     def __init__(
-        self,
+        self,
+        scan: CoreScan,
+        shards: list[Shard],
+        cache_dir: str | None = None,
+        shard_row_block_size: int | None = None,
     ):
         self._scan = scan
         # TODO(marko): Read shards only on world.is_local_leader in `get_shards` and materialize on disk.
         self._shards = shards
-        self.
+        self._shard_row_block_size = shard_row_block_size or 8192

         if cache_dir is not None:
             if not os.path.exists(cache_dir):
@@ -99,7 +103,7 @@ class SpiralStream:
             shard_path,
             shard.shard.key_range,
             expected_cardinality=shard.shard.cardinality,
-            shard_row_block_size=self.
+            shard_row_block_size=self._shard_row_block_size,
         )

         # Get the size of the file on disk.