pyspiral 0.8.9__cp311-abi3-macosx_11_0_arm64.whl → 0.9.9__cp311-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyspiral-0.8.9.dist-info → pyspiral-0.9.9.dist-info}/METADATA +4 -2
- {pyspiral-0.8.9.dist-info → pyspiral-0.9.9.dist-info}/RECORD +39 -34
- spiral/__init__.py +3 -2
- spiral/_lib.abi3.so +0 -0
- spiral/api/__init__.py +7 -0
- spiral/api/client.py +86 -8
- spiral/api/projects.py +4 -2
- spiral/api/tables.py +77 -0
- spiral/arrow_.py +4 -155
- spiral/cli/app.py +10 -4
- spiral/cli/chooser.py +30 -0
- spiral/cli/fs.py +3 -2
- spiral/cli/iceberg.py +1 -1
- spiral/cli/key_spaces.py +4 -4
- spiral/cli/orgs.py +1 -1
- spiral/cli/projects.py +2 -2
- spiral/cli/tables.py +47 -20
- spiral/cli/telemetry.py +13 -6
- spiral/cli/text.py +4 -4
- spiral/cli/transactions.py +84 -0
- spiral/cli/{types.py → types_.py} +6 -6
- spiral/cli/workloads.py +4 -4
- spiral/client.py +70 -8
- spiral/core/client/__init__.pyi +25 -16
- spiral/core/table/__init__.pyi +24 -22
- spiral/debug/manifests.py +21 -9
- spiral/debug/scan.py +4 -6
- spiral/demo.py +145 -38
- spiral/enrichment.py +18 -23
- spiral/expressions/__init__.py +3 -75
- spiral/expressions/base.py +5 -10
- spiral/huggingface.py +456 -0
- spiral/input.py +131 -0
- spiral/ray_.py +75 -0
- spiral/scan.py +218 -64
- spiral/table.py +5 -4
- spiral/transaction.py +95 -15
- spiral/iterable_dataset.py +0 -106
- {pyspiral-0.8.9.dist-info → pyspiral-0.9.9.dist-info}/WHEEL +0 -0
- {pyspiral-0.8.9.dist-info → pyspiral-0.9.9.dist-info}/entry_points.txt +0 -0
spiral/core/client/__init__.pyi
CHANGED
@@ -5,7 +5,8 @@ import pyarrow as pa
 from spiral.api.types import DatasetName, IndexName, ProjectId, RootUri, TableId, TableName
 from spiral.core.authn import Authn
 from spiral.core.config import ClientSettings
-from spiral.core.table import
+from spiral.core.table import KeyRange, Scan, ScanContext, Snapshot, Table, Transaction
+from spiral.core.table.manifests import FragmentManifest
 from spiral.core.table.spec import ColumnGroup, Schema
 from spiral.expressions import Expr

@@ -36,16 +37,23 @@ class Spiral:
         asof: int | None = None,
         shard: Shard | None = None,
         key_columns: KeyColumns | None = None,
+        progress: bool = True,
     ) -> Scan:
         """Construct a table scan."""
         ...

-    def load_scan(self,
-        """Load a scan from a serialized scan
+    def load_scan(self, context: ScanContext) -> Scan:
+        """Load a scan from a serialized scan context."""
         ...

-    def transaction(
-
+    def transaction(
+        self,
+        table: Table,
+        *,
+        partition_max_bytes: int | None = None,
+        compact_threshold: int | None = None,
+    ) -> Transaction:
+        """Begin a table transaction."""
         ...

     def search(
@@ -220,31 +228,32 @@ class Internal:
         Flush the write-ahead log of the table.
         """
         ...
-    def
+    def truncate_metadata(self, table: Table) -> None:
         """
-
+        Truncate the column group metadata of the table.
+
+        This removes compacted fragments from metadata.
+        IMPORTANT: The command will break as-of before truncation for the table.
         """
         ...
-    def
+    def update_text_index(self, index: TextIndex, snapshot: Snapshot) -> None:
         """
         Index table changes up to the given snapshot.
         """
         ...
-    def
+    def update_key_space_index(self, index: KeySpaceIndex, snapshot: Snapshot) -> None:
         """
-
+        Index table changes up to the given snapshot.
         """
         ...
-    def
-        self, snapshot: Snapshot, key_space_state: KeySpaceState, column_group: ColumnGroup
-    ) -> ColumnGroupState:
+    def key_space_manifest(self, snapshot: Snapshot) -> FragmentManifest:
         """
-        The
+        The manifest of the key space of the table as of the given snapshot.
         """
         ...
-    def
+    def column_group_manifest(self, snapshot: Snapshot, column_group: ColumnGroup) -> FragmentManifest:
         """
-        The
+        The manifest of the given column group of the table as of the given snapshot.
         """
         ...
     def key_space_index_shards(self, index: KeySpaceIndex) -> list[Shard]:
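Taken together, the stub changes above replace the 0.8.x signatures with an explicit ScanContext argument for load_scan and keyword-only tuning knobs on transaction. A minimal, hedged sketch of driving the new surface, using only the signatures declared above; how the core Spiral client and the Table handle are obtained is assumed and not shown in this diff.

# Hedged sketch based on the 0.9.x stubs above; `sp` is assumed to be a
# spiral.core.client.Spiral instance and `table` a spiral.core.table.Table.
from spiral.core.table import ScanContext, Table, Transaction


def reload_scan(sp, serialized: bytes):
    # load_scan() now takes an explicit, deserialized ScanContext.
    ctx = ScanContext.from_bytes_compressed(serialized)
    return sp.load_scan(ctx)


def begin_write(sp, table: Table) -> Transaction:
    # transaction() now takes the target table plus keyword-only tuning knobs.
    return sp.transaction(table, partition_max_bytes=None, compact_threshold=None)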
spiral/core/table/__init__.pyi
CHANGED
@@ -52,10 +52,12 @@ class Snapshot:
     table: Table
     wal: WriteAheadLog

-
-
+    def column_groups(self) -> list[ColumnGroup]: ...
+
+class ScanContext:
+    def to_bytes_compressed(self) -> bytes: ...
     @staticmethod
-    def
+    def from_bytes_compressed(compressed: bytes) -> ScanContext: ...

 class MaterializablePlan:
     pass
@@ -73,26 +75,36 @@ class Scan:
     def is_empty(self) -> bool: ...
     def shards(self) -> list[Shard]: ...
     def table_ids(self) -> list[str]: ...
+    def context(self) -> ScanContext: ...
     def column_groups(self) -> list[ColumnGroup]: ...
-    def
-
-
+    def key_space_manifest(self, table_id: str) -> FragmentManifest:
+        """
+        Manifest of the key fragments for the given table id.
+        """
+        ...
+    def column_group_manifest(self, column_group: ColumnGroup) -> FragmentManifest:
+        """
+        Manifest of the fragments for the given column group.
+        """
+        ...
+    def plan_context(self) -> ScanContext: ...
     def materializable_plan(self) -> MaterializablePlan: ...
     def to_record_batches(
         self,
         *,
         shards: list[Shard] | None = None,
-        key_table: pa.Table | pa.
+        key_table: pa.Table | pa.RecordBatchReader | None = None,
         batch_readahead: int | None = None,
-
+        batch_aligned: bool = False,
+        hide_progress_bar: bool = False,
     ) -> pa.RecordBatchReader: ...
     def to_unordered_record_batches(
         self,
         *,
         shards: list[Shard] | None = None,
-        key_table: pa.Table | pa.
+        key_table: pa.Table | pa.RecordBatchReader | None = None,
         batch_readahead: int | None = None,
-
+        hide_progress_bar: bool = False,
     ) -> pa.RecordBatchReader: ...
     def to_shuffled_record_batches(
         self,
@@ -115,17 +127,6 @@ class Scan:
     ) -> EvaluatedPlanStream: ...
     def metrics(self) -> dict[str, Any]: ...

-class KeySpaceState:
-    manifest: FragmentManifest
-
-    def key_schema(self) -> Schema: ...
-
-class ColumnGroupState:
-    manifest: FragmentManifest
-    column_group: ColumnGroup
-
-    def schema(self) -> Schema: ...
-
 class Transaction:
     status: str

@@ -137,6 +138,7 @@ class Transaction:
     def ops(self) -> list[Operation]: ...
     def take(self) -> list[Operation]: ...
     def include(self, ops: list[Operation]): ...
-    def commit(self
+    def commit(self): ...
     def abort(self): ...
     def is_empty(self) -> bool: ...
+    def snapshot(self) -> Snapshot: ...
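The removed KeySpaceState/ColumnGroupState classes are superseded by the new ScanContext plus the manifest accessors on Scan. A rough sketch of the serialization round trip and manifest access, assuming a Scan obtained from the client; its construction is outside this diff.

# Hedged sketch using only methods declared in the stubs above.
from spiral.core.table import Scan, ScanContext


def ship_scan(scan: Scan) -> bytes:
    # Serialize the scan's context, e.g. to hand it to a worker process.
    return scan.context().to_bytes_compressed()


def restore_context(payload: bytes) -> ScanContext:
    # The receiving side rebuilds the context; Spiral.load_scan() accepts it.
    return ScanContext.from_bytes_compressed(payload)


def inspect_manifests(scan: Scan) -> None:
    # Fragment manifests are now read straight off the scan.
    for table_id in scan.table_ids():
        print(scan.key_space_manifest(table_id))
    for cg in scan.column_groups():
        print(scan.column_group_manifest(cg))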
spiral/debug/manifests.py
CHANGED
@@ -14,9 +14,9 @@ def display_scan_manifests(scan: Scan):
     if len(scan.table_ids()) != 1:
         raise NotImplementedError("Multiple table scans are not supported.")
     table_id = scan.table_ids()[0]
-    key_space_manifest = scan.
+    key_space_manifest = scan.key_space_manifest(table_id)
     column_group_manifests = [
-        (column_group, scan.
+        (column_group, scan.column_group_manifest(column_group)) for column_group in scan.column_groups()
     ]

     display_manifests(key_space_manifest, column_group_manifests, scan.key_schema(), None)
@@ -61,8 +61,10 @@ def _table_of_fragments(manifest: FragmentManifest, title: str, key_schema: Sche
     # Create rich table
     table = Table(title=None, show_header=True, header_style="bold")
     table.add_column("ID", style="cyan", no_wrap=True)
-    table.add_column("
+    table.add_column("Data", justify="right")
+    table.add_column("Metadata", justify="right")
     table.add_column("Format", justify="center")
+    table.add_column("Key Space", justify="center")
     table.add_column("Key Span", justify="center")
     table.add_column("Key Range", justify="center")
     table.add_column("Level", justify="center")
@@ -74,12 +76,20 @@ def _table_of_fragments(manifest: FragmentManifest, title: str, key_schema: Sche
         if max_rows is not None and i >= max_rows:
             break

-        committed_str =
-
-
-
-
+        committed_str = (
+            datetime_.from_timestamp_micros(fragment.committed_at).strftime("%Y-%m-%d %H:%M:%S")
+            if fragment.committed_at
+            else "N/A"
+        )
+        compacted_str = (
+            datetime_.from_timestamp_micros(fragment.compacted_at).strftime("%Y-%m-%d %H:%M:%S")
+            if fragment.compacted_at
+            else "N/A"
         )
+
+        data_size = _format_bytes(fragment.size_bytes)
+        metadata_size = _format_bytes(len(fragment.format_metadata or b""))
+        key_space = fragment.ks_id
         key_span = f"{fragment.key_span.begin}..{fragment.key_span.end}"
         min_key = pretty_key(bytes(fragment.key_extent.min), key_schema)
         max_key = pretty_key(bytes(fragment.key_extent.max), key_schema)
@@ -91,8 +101,10 @@ def _table_of_fragments(manifest: FragmentManifest, title: str, key_schema: Sche

         table.add_row(
             fragment.id,
-
+            data_size,
+            metadata_size,
             str(fragment.format),
+            key_space,
             key_span,
             key_range,
             str(fragment.level),
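The debug table now reports per-fragment data size, format-metadata size, and key space alongside the key span. Those values are plain attributes on manifest entries, so the same information can be dumped without rich; a hedged sketch, with attribute names taken from the code above and manifest indexing as used in spiral/debug/scan.py.

# Hedged sketch: print the fields the debug table renders, given a FragmentManifest
# obtained from scan.key_space_manifest(...) or scan.column_group_manifest(...).
from spiral.core.table.manifests import FragmentManifest


def dump_fragments(manifest: FragmentManifest) -> None:
    for i in range(len(manifest)):
        fragment = manifest[i]
        print(
            fragment.id,
            fragment.size_bytes,  # rendered as "Data"
            len(fragment.format_metadata or b""),  # rendered as "Metadata"
            fragment.ks_id,  # rendered as "Key Space"
            f"{fragment.key_span.begin}..{fragment.key_span.end}",
            fragment.level,
        )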
spiral/debug/scan.py
CHANGED
@@ -15,18 +15,16 @@ def show_scan(scan: Scan):
     column_groups = scan.column_groups()

     splits = [s.key_range for s in scan.shards()]
-
+    key_space_manifest = scan.key_space_manifest(table_id)

     # Collect all key bounds from all manifests. This makes sure all visualizations are aligned.
     key_points = set()
-    key_space_manifest = key_space_state.manifest
     for i in range(len(key_space_manifest)):
         fragment_file = key_space_manifest[i]
         key_points.add(fragment_file.key_extent.min)
         key_points.add(fragment_file.key_extent.max)
     for cg in column_groups:
-
-        cg_manifest = cg_scan.manifest
+        cg_manifest = scan.column_group_manifest(cg)
         for i in range(len(cg_manifest)):
             fragment_file = cg_manifest[i]
             key_points.add(fragment_file.key_extent.min)
@@ -39,9 +37,9 @@ def show_scan(scan: Scan):

     show_manifest(key_space_manifest, scope="Key space", key_points=key_points, splits=splits)
     for cg in scan.column_groups():
-
+        cg_manifest = scan.column_group_manifest(cg)
         # Skip table id from the start of the column group.
-        show_manifest(
+        show_manifest(cg_manifest, scope=".".join(cg.path[1:]), key_points=key_points, splits=splits)


 def show_manifest(manifest: FragmentManifest, scope: str = None, key_points: list[Key] = None, splits: list = None):
spiral/demo.py
CHANGED
@@ -1,16 +1,69 @@
 """Demo data to play with SpiralDB"""

 import functools
+import hashlib
+import os
 import time
+from pathlib import Path

 import duckdb
+import numpy as np
 import pandas as pd
 import pyarrow as pa
+import pyarrow.parquet as pq
 from datasets import load_dataset

 from spiral import Project, Spiral, Table


+# Cache configuration
+def _get_cache_dir() -> Path | None:
+    """Get cache directory from environment variable, or None if caching is disabled."""
+    cache_dir = os.environ.get("SPIRAL_DEMO_CACHE_DIR")
+    if cache_dir:
+        path = Path(cache_dir)
+        path.mkdir(parents=True, exist_ok=True)
+        return path
+    return None
+
+
+def _cache_key(*parts: str) -> str:
+    """Generate a cache key from components."""
+    return "-".join(str(p).replace("-", "_") for p in parts)
+
+
+def _get_cached_table(cache_key: str) -> pa.Table | None:
+    """Load Arrow table from cache if available."""
+    cache_dir = _get_cache_dir()
+    if not cache_dir:
+        return None
+
+    cache_file = cache_dir / f"{cache_key}.parquet"
+    if not cache_file.exists():
+        return None
+
+    try:
+        return pq.read_table(cache_file)
+    except Exception as e:
+        # On any error (corruption, etc.), return None to trigger re-download
+        print(f"Warning: Failed to load cache {cache_file}: {e}")
+        return None
+
+
+def _save_to_cache(cache_key: str, table: pa.Table) -> None:
+    """Save Arrow table to cache."""
+    cache_dir = _get_cache_dir()
+    if not cache_dir:
+        return
+
+    cache_file = cache_dir / f"{cache_key}.parquet"
+    try:
+        pq.write_table(table, cache_file, compression="zstd")
+        print(f"Cached data to {cache_file}")
+    except Exception as e:
+        print(f"Warning: Failed to save cache {cache_file}: {e}")
+
+
 def _install_duckdb_extension(name: str, max_retries: int = 3) -> None:
     """Install and load a DuckDB extension with retry logic for flaky CI environments."""
     for attempt in range(max_retries):
@@ -30,22 +83,37 @@ def demo_project(sp: Spiral) -> Project:


 @functools.lru_cache(maxsize=1)
-def images(sp: Spiral) -> Table:
+def images(sp: Spiral, limit=10) -> Table:
     table = demo_project(sp).create_table(
         "openimages.images-v1", key_schema=pa.schema([("idx", pa.int64())]), exist_ok=False
     )

-    #
-
-
-
-
-
-
-
-
-
+    # Try to load from cache first
+    # Use a hash of the URL to create a stable cache key
+    url = "https://storage.googleapis.com/cvdf-datasets/oid/open-images-dataset-validation.tsv"
+    url_hash = hashlib.md5(url.encode()).hexdigest()[:8]
+    cache_key = _cache_key("images", "v1", f"url-{url_hash}", f"limit-{limit}")
+    df = _get_cached_table(cache_key)
+
+    if df is None:
+        # Cache miss - download from Google Cloud Storage
+        print(f"Cache miss for {cache_key}, downloading from GCS...")
+        # Load URLs from a TSV file
+        df_pandas = pd.read_csv(
+            url,
+            names=["url", "size", "etag"],
+            skiprows=1,
+            sep="\t",
+            header=None,
+        )
+        # For this example, we load just a few rows, but Spiral can handle many more.
+        df = pa.Table.from_pandas(df_pandas[:limit])
+        df = df.append_column("idx", pa.array(range(len(df))))
+
+        # Save to cache for future runs
+        _save_to_cache(cache_key, df)
+    else:
+        print(f"Cache hit for {cache_key}")

     # Write just the metadata - lightweight and fast
     table.write(df)
@@ -57,30 +125,44 @@ def gharchive(sp: Spiral, limit=100, period=None) -> Table:
     if period is None:
         period = pd.Period("2023-01-01T00:00:00Z", freq="h")

-
+    # Try to load from cache first
+    period_str = f"{period.strftime('%Y-%m-%d')}-{str(period.hour)}"
+    cache_key = _cache_key("gharchive", "v1", f"period-{period_str}", f"limit-{limit}")
+    cached_events = _get_cached_table(cache_key)
+
+    if cached_events is None:
+        # Cache miss - download from gharchive
+        print(f"Cache miss for {cache_key}, downloading from gharchive.org...")
+        _install_duckdb_extension("httpfs")
+
+        json_gz_url = f"https://data.gharchive.org/{period_str}.json.gz"
+        arrow_table = (
+            duckdb.read_json(json_gz_url, union_by_name=True)
+            .limit(limit)
+            .select("""
+                * REPLACE (
+                    cast(created_at AS TIMESTAMP_MS) AS created_at,
+                )
+            """)
+            .to_arrow_table()
+        )

-
-
-
-
-
-
-
+        events = duckdb.from_arrow(arrow_table).order("created_at, id").distinct().to_arrow_table()
+        events = (
+            events.drop_columns("id")
+            .add_column(0, "id", events["id"].cast(pa.large_string()))
+            .drop_columns("created_at")
+            .add_column(0, "created_at", events["created_at"].cast(pa.timestamp("ms")))
+            .drop_columns("org")
         )
-            """)
-        .to_arrow_table()
-    )

-
-
-
-
-
-        .add_column(0, "created_at", events["created_at"].cast(pa.timestamp("ms")))
-        .drop_columns("org")
-    )
+        # Save to cache for future runs
+        _save_to_cache(cache_key, events)
+    else:
+        print(f"Cache hit for {cache_key}")
+        events = cached_events

-    key_schema = pa.schema([("created_at", pa.timestamp("ms")), ("id", pa.
+    key_schema = pa.schema([("created_at", pa.timestamp("ms")), ("id", pa.string())])
     table = demo_project(sp).create_table("gharchive.events", key_schema=key_schema, exist_ok=False)
     table.write(events, push_down_nulls=True)
     return table
@@ -88,13 +170,38 @@ def gharchive(sp: Spiral, limit=100, period=None) -> Table:

 @functools.lru_cache(maxsize=1)
 def fineweb(sp: Spiral, limit=100) -> Table:
-    table = demo_project(sp).create_table(
-
-
+    table = demo_project(sp).create_table("fineweb.v1", key_schema=pa.schema([("id", pa.string())]), exist_ok=False)
+
+    # Try to load from cache first
+    cache_key = _cache_key("fineweb", "v1", f"limit-{limit}")
+    arrow_table = _get_cached_table(cache_key)
+
+    if arrow_table is None:
+        # Cache miss - download from HuggingFace
+        print(f"Cache miss for {cache_key}, downloading from HuggingFace...")
+        ds = load_dataset("HuggingFaceFW/fineweb", "sample-10BT", streaming=True)
+        data = ds["train"].take(limit)
+        arrow_table = pa.Table.from_pylist(data.to_list())

-
-
-
+        # Save to cache for future runs
+        _save_to_cache(cache_key, arrow_table)
+    else:
+        print(f"Cache hit for {cache_key}")

     table.write(arrow_table, push_down_nulls=True)
     return table
+
+
+@functools.lru_cache(maxsize=1)
+def abc(sp: Spiral, limit=100) -> Table:
+    table = demo_project(sp).create_table("abc", key_schema=pa.schema([("a", pa.int64())]), exist_ok=False)
+
+    table.write(
+        {
+            "a": pa.array(np.arange(limit)),
+            "b": pa.array(np.arange(100, 100 + limit)),
+            "c": pa.array(np.repeat(99, limit)),
+        }
+    )
+
+    return table
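The demo module now persists each downloaded dataset as Parquet whenever SPIRAL_DEMO_CACHE_DIR is set, so repeated runs skip the network fetch. A hedged usage sketch; how the Spiral client is configured and authenticated is assumed and not part of this diff.

# Point the demo cache at a local directory before building the demo tables.
import os

from spiral import Spiral, demo

os.environ["SPIRAL_DEMO_CACHE_DIR"] = "/tmp/spiral-demo-cache"

sp = Spiral()  # assumed: a configured/authenticated client
events_table = demo.gharchive(sp, limit=100)  # downloads once, then served from cache
docs_table = demo.fineweb(sp, limit=100)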
spiral/enrichment.py
CHANGED
@@ -2,18 +2,18 @@ from __future__ import annotations

 import dataclasses
 import logging
-from functools import partial
+from functools import partial, reduce
 from typing import TYPE_CHECKING

 from spiral.core.client import Shard
 from spiral.core.table import KeyRange
-from spiral.core.table.spec import Key
+from spiral.core.table.spec import Key
 from spiral.expressions import Expr

 if TYPE_CHECKING:
     import dask.distributed

-    from spiral import Scan, Table
+    from spiral import Scan, Table, TransactionOps

 logger = logging.getLogger(__name__)

@@ -52,12 +52,13 @@ class Enrichment:
         """The filter expression."""
         return self._where

-    def _scan(self) -> Scan:
-        return self._table.spiral.scan(self._projection, where=self._where)
+    def _scan(self, shard: Shard | None = None) -> Scan:
+        return self._table.spiral.scan(self._projection, where=self._where, shard=shard)

     def apply(
         self,
         *,
+        shards: list[Shard] | None = None,
         txn_dump: str | None = None,
     ) -> None:
         """Apply the enrichment onto the table in a streaming fashion.
@@ -65,12 +66,17 @@ class Enrichment:
         For large tables, consider using `apply_dask` for distributed execution.

         Args:
+            shards: Optional list of shards to process.
             txn_dump: Optional path to dump the transaction JSON for debugging.
         """
+        # Combine multiple shards into one covering the full key range.
+        encompassing_shard: Shard | None = None
+        if shards:
+            encompassing_shard = reduce(lambda a, b: a | b, shards)

         txn = self._table.txn()

-        txn.writeback(self._scan())
+        txn.writeback(self._scan(encompassing_shard), shards=shards)

         if txn.is_empty():
             logger.warning("Transaction not committed. No rows were read for enrichment.")
@@ -150,7 +156,7 @@ class Enrichment:
         _compute = partial(
             _enrichment_task,
             config_json=self._table.spiral.config.to_json(),
-
+            state_bytes=plan_scan.core.plan_context().to_bytes_compressed(),
             output_table_id=self._table.table_id,
             incremental=checkpoint_dump is not None,
         )
@@ -210,8 +216,7 @@ class Enrichment:
             logger.warning("Transaction not committed. No rows were read for enrichment.")
             return

-
-        tx.commit(compact=True, txn_dump=txn_dump)
+        tx.commit(txn_dump=txn_dump)


 def _checkpoint_load_key_ranges(checkpoint_dump: str) -> list[KeyRange] | None:
@@ -243,26 +248,16 @@ def _checkpoint_dump_key_ranges(checkpoint_dump: str, ranges: list[KeyRange]):

 @dataclasses.dataclass
 class EnrichmentTaskResult:
-    ops:
+    ops: TransactionOps | None = None
     error: str | None = None

-    def __getstate__(self):
-        return {
-            "ops": [op.to_json() for op in self.ops],
-            "error": self.error,
-        }
-
-    def __setstate__(self, state):
-        self.ops = [Operation.from_json(op_json) for op_json in state["ops"]]
-        self.error = state["error"]
-

 # NOTE(marko): This function must be picklable!
 def _enrichment_task(
     shard: Shard,
     *,
     config_json: str,
-
+    state_bytes: bytes,
     output_table_id,
     incremental: bool,
 ) -> EnrichmentTaskResult:
@@ -272,7 +267,7 @@ def _enrichment_task(

     config = ClientSettings.from_json(config_json)
     sp = Spiral(config=config)
-    task_scan = sp.resume_scan(
+    task_scan = sp.resume_scan(state_bytes)

     table = sp.table(output_table_id)
     task_tx = table.txn()
@@ -284,7 +279,7 @@ def _enrichment_task(
         task_tx.abort()

         if incremental:
-            return EnrichmentTaskResult(
+            return EnrichmentTaskResult(error=str(e))

         logger.error(f"Enrichment task failed for shard {shard}: {e}")
         raise e
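With the scan state now carried as compressed ScanContext bytes (plan_context().to_bytes_compressed() above), enrichment workers receive plain bytes, and apply() accepts an optional shard subset. A hedged sketch of the new entry point; the Enrichment instance and the shard list are assumed to come from application code not shown in this diff.

# Restrict an enrichment run to a subset of shards (signature from the diff above).
from spiral.core.client import Shard


def apply_on_shards(enrichment, shards: list[Shard], txn_dump: str | None = None) -> None:
    # Internally the shards are folded into one encompassing shard for the scan,
    # while the writeback itself stays limited to the given shards.
    enrichment.apply(shards=shards, txn_dump=txn_dump)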
|