pyspiral 0.8.9__cp311-abi3-macosx_11_0_arm64.whl → 0.9.9__cp311-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,7 +5,8 @@ import pyarrow as pa
 from spiral.api.types import DatasetName, IndexName, ProjectId, RootUri, TableId, TableName
 from spiral.core.authn import Authn
 from spiral.core.config import ClientSettings
-from spiral.core.table import ColumnGroupState, KeyRange, KeySpaceState, Scan, ScanState, Snapshot, Table, Transaction
+from spiral.core.table import KeyRange, Scan, ScanContext, Snapshot, Table, Transaction
+from spiral.core.table.manifests import FragmentManifest
 from spiral.core.table.spec import ColumnGroup, Schema
 from spiral.expressions import Expr

@@ -36,16 +37,23 @@ class Spiral:
         asof: int | None = None,
         shard: Shard | None = None,
         key_columns: KeyColumns | None = None,
+        progress: bool = True,
     ) -> Scan:
         """Construct a table scan."""
         ...

-    def load_scan(self, plan_state: ScanState) -> Scan:
-        """Load a scan from a serialized scan state."""
+    def load_scan(self, context: ScanContext) -> Scan:
+        """Load a scan from a serialized scan context."""
         ...

-    def transaction(self, table: Table, *, partition_max_bytes: int | None = None) -> Transaction:
-        """Being a table transaction."""
+    def transaction(
+        self,
+        table: Table,
+        *,
+        partition_max_bytes: int | None = None,
+        compact_threshold: int | None = None,
+    ) -> Transaction:
+        """Begin a table transaction."""
         ...

     def search(
@@ -220,31 +228,32 @@ class Internal:
         Flush the write-ahead log of the table.
         """
         ...
-    def update_text_index(self, index: TextIndex, snapshot: Snapshot) -> None:
+    def truncate_metadata(self, table: Table) -> None:
         """
-        Index table changes up to the given snapshot.
+        Truncate the column group metadata of the table.
+
+        This removes compacted fragments from metadata.
+        IMPORTANT: The command will break as-of before truncation for the table.
         """
         ...
-    def update_key_space_index(self, index: KeySpaceIndex, snapshot: Snapshot) -> None:
+    def update_text_index(self, index: TextIndex, snapshot: Snapshot) -> None:
         """
         Index table changes up to the given snapshot.
         """
         ...
-    def key_space_state(self, snapshot: Snapshot) -> KeySpaceState:
+    def update_key_space_index(self, index: KeySpaceIndex, snapshot: Snapshot) -> None:
         """
-        The key space state for the table.
+        Index table changes up to the given snapshot.
         """
         ...
-    def column_group_state(
-        self, snapshot: Snapshot, key_space_state: KeySpaceState, column_group: ColumnGroup
-    ) -> ColumnGroupState:
+    def key_space_manifest(self, snapshot: Snapshot) -> FragmentManifest:
         """
-        The state the column group of the table.
+        The manifest of the key space of the table as of the given snapshot.
         """
         ...
-    def column_groups_states(self, snapshot: Snapshot, key_space_state: KeySpaceState) -> list[ColumnGroupState]:
+    def column_group_manifest(self, snapshot: Snapshot, column_group: ColumnGroup) -> FragmentManifest:
         """
-        The state of each column group of the table.
+        The manifest of the given column group of the table as of the given snapshot.
         """
         ...
     def key_space_index_shards(self, index: KeySpaceIndex) -> list[Shard]:
@@ -52,10 +52,12 @@ class Snapshot:
     table: Table
     wal: WriteAheadLog

-class ScanState:
-    def to_json(self) -> str: ...
+    def column_groups(self) -> list[ColumnGroup]: ...
+
+class ScanContext:
+    def to_bytes_compressed(self) -> bytes: ...
     @staticmethod
-    def from_json(json: str) -> ScanState: ...
+    def from_bytes_compressed(compressed: bytes) -> ScanContext: ...

 class MaterializablePlan:
     pass
@@ -73,26 +75,36 @@ class Scan:
     def is_empty(self) -> bool: ...
     def shards(self) -> list[Shard]: ...
     def table_ids(self) -> list[str]: ...
+    def context(self) -> ScanContext: ...
     def column_groups(self) -> list[ColumnGroup]: ...
-    def column_group_state(self, column_group: ColumnGroup) -> ColumnGroupState: ...
-    def key_space_state(self, table_id: str) -> KeySpaceState: ...
-    def plan_state(self) -> ScanState: ...
+    def key_space_manifest(self, table_id: str) -> FragmentManifest:
+        """
+        Manifest of the key fragments for the given table id.
+        """
+        ...
+    def column_group_manifest(self, column_group: ColumnGroup) -> FragmentManifest:
+        """
+        Manifest of the fragments for the given column group.
+        """
+        ...
+    def plan_context(self) -> ScanContext: ...
     def materializable_plan(self) -> MaterializablePlan: ...
     def to_record_batches(
         self,
         *,
         shards: list[Shard] | None = None,
-        key_table: pa.Table | pa.RecordBatch | None = None,
+        key_table: pa.Table | pa.RecordBatchReader | None = None,
         batch_readahead: int | None = None,
-        progress: bool = True,
+        batch_aligned: bool = False,
+        hide_progress_bar: bool = False,
     ) -> pa.RecordBatchReader: ...
     def to_unordered_record_batches(
         self,
         *,
         shards: list[Shard] | None = None,
-        key_table: pa.Table | pa.RecordBatch | None = None,
+        key_table: pa.Table | pa.RecordBatchReader | None = None,
         batch_readahead: int | None = None,
-        progress: bool = True,
+        hide_progress_bar: bool = False,
     ) -> pa.RecordBatchReader: ...
     def to_shuffled_record_batches(
         self,
@@ -115,17 +127,6 @@ class Scan:
     ) -> EvaluatedPlanStream: ...
     def metrics(self) -> dict[str, Any]: ...

-class KeySpaceState:
-    manifest: FragmentManifest
-
-    def key_schema(self) -> Schema: ...
-
-class ColumnGroupState:
-    manifest: FragmentManifest
-    column_group: ColumnGroup
-
-    def schema(self) -> Schema: ...
-
 class Transaction:
     status: str

@@ -137,6 +138,7 @@ class Transaction:
     def ops(self) -> list[Operation]: ...
     def take(self) -> list[Operation]: ...
     def include(self, ops: list[Operation]): ...
-    def commit(self, *, compact: bool = False): ...
+    def commit(self): ...
     def abort(self): ...
     def is_empty(self) -> bool: ...
+    def snapshot(self) -> Snapshot: ...
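
Taken together, the stub changes above replace the JSON-serialized ScanState with a compressed ScanContext, expose fragment manifests directly, and move compaction tuning off commit(). A minimal sketch of the new round trip, based only on the signatures declared above; sp, table, and projection are assumed to already exist and are not defined by this diff:

from spiral.core.table import ScanContext

# Serialize a scan plan (previously plan_state().to_json()) ...
scan = sp.scan(projection, progress=False)
payload: bytes = scan.context().to_bytes_compressed()

# ... and rebuild it elsewhere (previously ScanState.from_json + load_scan).
resumed = sp.load_scan(ScanContext.from_bytes_compressed(payload))

# commit() no longer accepts compact=True; compaction is configured up front.
txn = sp.transaction(table, compact_threshold=None)
txn.commit()
snap = txn.snapshot()  # new: inspect the transaction's snapshot
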
spiral/debug/manifests.py CHANGED
@@ -14,9 +14,9 @@ def display_scan_manifests(scan: Scan):
     if len(scan.table_ids()) != 1:
         raise NotImplementedError("Multiple table scans are not supported.")
     table_id = scan.table_ids()[0]
-    key_space_manifest = scan.key_space_state(table_id).manifest
+    key_space_manifest = scan.key_space_manifest(table_id)
     column_group_manifests = [
-        (column_group, scan.column_group_state(column_group).manifest) for column_group in scan.column_groups()
+        (column_group, scan.column_group_manifest(column_group)) for column_group in scan.column_groups()
     ]

     display_manifests(key_space_manifest, column_group_manifests, scan.key_schema(), None)
@@ -61,8 +61,10 @@ def _table_of_fragments(manifest: FragmentManifest, title: str, key_schema: Sche
     # Create rich table
     table = Table(title=None, show_header=True, header_style="bold")
     table.add_column("ID", style="cyan", no_wrap=True)
-    table.add_column("Size (Metadata)", justify="right")
+    table.add_column("Data", justify="right")
+    table.add_column("Metadata", justify="right")
     table.add_column("Format", justify="center")
+    table.add_column("Key Space", justify="center")
     table.add_column("Key Span", justify="center")
     table.add_column("Key Range", justify="center")
     table.add_column("Level", justify="center")
@@ -74,12 +76,20 @@ def _table_of_fragments(manifest: FragmentManifest, title: str, key_schema: Sche
         if max_rows is not None and i >= max_rows:
             break

-        committed_str = str(datetime_.from_timestamp_micros(fragment.committed_at)) if fragment.committed_at else "N/A"
-        compacted_str = str(datetime_.from_timestamp_micros(fragment.compacted_at)) if fragment.compacted_at else "N/A"
-
-        size_with_metadata = (
-            f"{_format_bytes(fragment.size_bytes)} ({_format_bytes(len(fragment.format_metadata or b''))})"
+        committed_str = (
+            datetime_.from_timestamp_micros(fragment.committed_at).strftime("%Y-%m-%d %H:%M:%S")
+            if fragment.committed_at
+            else "N/A"
+        )
+        compacted_str = (
+            datetime_.from_timestamp_micros(fragment.compacted_at).strftime("%Y-%m-%d %H:%M:%S")
+            if fragment.compacted_at
+            else "N/A"
         )
+
+        data_size = _format_bytes(fragment.size_bytes)
+        metadata_size = _format_bytes(len(fragment.format_metadata or b""))
+        key_space = fragment.ks_id
         key_span = f"{fragment.key_span.begin}..{fragment.key_span.end}"
         min_key = pretty_key(bytes(fragment.key_extent.min), key_schema)
         max_key = pretty_key(bytes(fragment.key_extent.max), key_schema)
@@ -91,8 +101,10 @@ def _table_of_fragments(manifest: FragmentManifest, title: str, key_schema: Sche

         table.add_row(
             fragment.id,
-            size_with_metadata,
+            data_size,
+            metadata_size,
             str(fragment.format),
+            key_space,
             key_span,
             key_range,
             str(fragment.level),
spiral/debug/scan.py CHANGED
@@ -15,18 +15,16 @@ def show_scan(scan: Scan):
     column_groups = scan.column_groups()

     splits = [s.key_range for s in scan.shards()]
-    key_space_state = scan.key_space_state(table_id)
+    key_space_manifest = scan.key_space_manifest(table_id)

     # Collect all key bounds from all manifests. This makes sure all visualizations are aligned.
     key_points = set()
-    key_space_manifest = key_space_state.manifest
     for i in range(len(key_space_manifest)):
         fragment_file = key_space_manifest[i]
         key_points.add(fragment_file.key_extent.min)
         key_points.add(fragment_file.key_extent.max)
     for cg in column_groups:
-        cg_scan = scan.column_group_state(cg)
-        cg_manifest = cg_scan.manifest
+        cg_manifest = scan.column_group_manifest(cg)
         for i in range(len(cg_manifest)):
             fragment_file = cg_manifest[i]
             key_points.add(fragment_file.key_extent.min)
@@ -39,9 +37,9 @@ def show_scan(scan: Scan):

     show_manifest(key_space_manifest, scope="Key space", key_points=key_points, splits=splits)
     for cg in scan.column_groups():
-        cg_scan = scan.column_group_state(cg)
+        cg_manifest = scan.column_group_manifest(cg)
         # Skip table id from the start of the column group.
-        show_manifest(cg_scan.manifest, scope=".".join(cg.path[1:]), key_points=key_points, splits=splits)
+        show_manifest(cg_manifest, scope=".".join(cg.path[1:]), key_points=key_points, splits=splits)


 def show_manifest(manifest: FragmentManifest, scope: str = None, key_points: list[Key] = None, splits: list = None):
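
With the *State wrappers removed, both debug helpers now read FragmentManifest objects straight off the scan. A short sketch of the same traversal outside the debug modules, assuming scan is an existing single-table Scan; only methods and fields that already appear in this diff are used:

table_id = scan.table_ids()[0]
ks_manifest = scan.key_space_manifest(table_id)      # was scan.key_space_state(table_id).manifest
for i in range(len(ks_manifest)):
    fragment = ks_manifest[i]
    print(fragment.id, fragment.size_bytes, bytes(fragment.key_extent.min), bytes(fragment.key_extent.max))

for cg in scan.column_groups():
    cg_manifest = scan.column_group_manifest(cg)     # was scan.column_group_state(cg).manifest
    print(".".join(cg.path[1:]), len(cg_manifest), "fragments")
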
spiral/demo.py CHANGED
@@ -1,16 +1,69 @@
 """Demo data to play with SpiralDB"""

 import functools
+import hashlib
+import os
 import time
+from pathlib import Path

 import duckdb
+import numpy as np
 import pandas as pd
 import pyarrow as pa
+import pyarrow.parquet as pq
 from datasets import load_dataset

 from spiral import Project, Spiral, Table


+# Cache configuration
+def _get_cache_dir() -> Path | None:
+    """Get cache directory from environment variable, or None if caching is disabled."""
+    cache_dir = os.environ.get("SPIRAL_DEMO_CACHE_DIR")
+    if cache_dir:
+        path = Path(cache_dir)
+        path.mkdir(parents=True, exist_ok=True)
+        return path
+    return None
+
+
+def _cache_key(*parts: str) -> str:
+    """Generate a cache key from components."""
+    return "-".join(str(p).replace("-", "_") for p in parts)
+
+
+def _get_cached_table(cache_key: str) -> pa.Table | None:
+    """Load Arrow table from cache if available."""
+    cache_dir = _get_cache_dir()
+    if not cache_dir:
+        return None
+
+    cache_file = cache_dir / f"{cache_key}.parquet"
+    if not cache_file.exists():
+        return None
+
+    try:
+        return pq.read_table(cache_file)
+    except Exception as e:
+        # On any error (corruption, etc.), return None to trigger re-download
+        print(f"Warning: Failed to load cache {cache_file}: {e}")
+        return None
+
+
+def _save_to_cache(cache_key: str, table: pa.Table) -> None:
+    """Save Arrow table to cache."""
+    cache_dir = _get_cache_dir()
+    if not cache_dir:
+        return
+
+    cache_file = cache_dir / f"{cache_key}.parquet"
+    try:
+        pq.write_table(table, cache_file, compression="zstd")
+        print(f"Cached data to {cache_file}")
+    except Exception as e:
+        print(f"Warning: Failed to save cache {cache_file}: {e}")
+
+
 def _install_duckdb_extension(name: str, max_retries: int = 3) -> None:
     """Install and load a DuckDB extension with retry logic for flaky CI environments."""
     for attempt in range(max_retries):
@@ -30,22 +83,37 @@ def demo_project(sp: Spiral) -> Project:


 @functools.lru_cache(maxsize=1)
-def images(sp: Spiral) -> Table:
+def images(sp: Spiral, limit=10) -> Table:
     table = demo_project(sp).create_table(
         "openimages.images-v1", key_schema=pa.schema([("idx", pa.int64())]), exist_ok=False
     )

-    # Load URLs from a TSV file
-    df = pd.read_csv(
-        "https://storage.googleapis.com/cvdf-datasets/oid/open-images-dataset-validation.tsv",
-        names=["url", "size", "etag"],
-        skiprows=1,
-        sep="\t",
-        header=None,
-    )
-    # For this example, we load just a few rows, but Spiral can handle many more.
-    df = pa.Table.from_pandas(df[:10])
-    df = df.append_column("idx", pa.array(range(len(df))))
+    # Try to load from cache first
+    # Use a hash of the URL to create a stable cache key
+    url = "https://storage.googleapis.com/cvdf-datasets/oid/open-images-dataset-validation.tsv"
+    url_hash = hashlib.md5(url.encode()).hexdigest()[:8]
+    cache_key = _cache_key("images", "v1", f"url-{url_hash}", f"limit-{limit}")
+    df = _get_cached_table(cache_key)
+
+    if df is None:
+        # Cache miss - download from Google Cloud Storage
+        print(f"Cache miss for {cache_key}, downloading from GCS...")
+        # Load URLs from a TSV file
+        df_pandas = pd.read_csv(
+            url,
+            names=["url", "size", "etag"],
+            skiprows=1,
+            sep="\t",
+            header=None,
+        )
+        # For this example, we load just a few rows, but Spiral can handle many more.
+        df = pa.Table.from_pandas(df_pandas[:limit])
+        df = df.append_column("idx", pa.array(range(len(df))))
+
+        # Save to cache for future runs
+        _save_to_cache(cache_key, df)
+    else:
+        print(f"Cache hit for {cache_key}")

     # Write just the metadata - lightweight and fast
     table.write(df)
@@ -57,30 +125,44 @@ def gharchive(sp: Spiral, limit=100, period=None) -> Table:
     if period is None:
         period = pd.Period("2023-01-01T00:00:00Z", freq="h")

-    _install_duckdb_extension("httpfs")
+    # Try to load from cache first
+    period_str = f"{period.strftime('%Y-%m-%d')}-{str(period.hour)}"
+    cache_key = _cache_key("gharchive", "v1", f"period-{period_str}", f"limit-{limit}")
+    cached_events = _get_cached_table(cache_key)
+
+    if cached_events is None:
+        # Cache miss - download from gharchive
+        print(f"Cache miss for {cache_key}, downloading from gharchive.org...")
+        _install_duckdb_extension("httpfs")
+
+        json_gz_url = f"https://data.gharchive.org/{period_str}.json.gz"
+        arrow_table = (
+            duckdb.read_json(json_gz_url, union_by_name=True)
+            .limit(limit)
+            .select("""
+            * REPLACE (
+                cast(created_at AS TIMESTAMP_MS) AS created_at,
+            )
+            """)
+            .to_arrow_table()
+        )

-    json_gz_url = f"https://data.gharchive.org/{period.strftime('%Y-%m-%d')}-{str(period.hour)}.json.gz"
-    arrow_table = (
-        duckdb.read_json(json_gz_url, union_by_name=True)
-        .limit(limit)
-        .select("""
-        * REPLACE (
-            cast(created_at AS TIMESTAMP_MS) AS created_at,
+        events = duckdb.from_arrow(arrow_table).order("created_at, id").distinct().to_arrow_table()
+        events = (
+            events.drop_columns("id")
+            .add_column(0, "id", events["id"].cast(pa.large_string()))
+            .drop_columns("created_at")
+            .add_column(0, "created_at", events["created_at"].cast(pa.timestamp("ms")))
+            .drop_columns("org")
         )
-    """)
-        .to_arrow_table()
-    )

-    events = duckdb.from_arrow(arrow_table).order("created_at, id").distinct().to_arrow_table()
-    events = (
-        events.drop_columns("id")
-        .add_column(0, "id", events["id"].cast(pa.large_string()))
-        .drop_columns("created_at")
-        .add_column(0, "created_at", events["created_at"].cast(pa.timestamp("ms")))
-        .drop_columns("org")
-    )
+        # Save to cache for future runs
+        _save_to_cache(cache_key, events)
+    else:
+        print(f"Cache hit for {cache_key}")
+        events = cached_events

-    key_schema = pa.schema([("created_at", pa.timestamp("ms")), ("id", pa.string_view())])
+    key_schema = pa.schema([("created_at", pa.timestamp("ms")), ("id", pa.string())])
     table = demo_project(sp).create_table("gharchive.events", key_schema=key_schema, exist_ok=False)
     table.write(events, push_down_nulls=True)
     return table
@@ -88,13 +170,38 @@ def gharchive(sp: Spiral, limit=100, period=None) -> Table:

 @functools.lru_cache(maxsize=1)
 def fineweb(sp: Spiral, limit=100) -> Table:
-    table = demo_project(sp).create_table(
-        "fineweb.v1", key_schema=pa.schema([("id", pa.string_view())]), exist_ok=False
-    )
+    table = demo_project(sp).create_table("fineweb.v1", key_schema=pa.schema([("id", pa.string())]), exist_ok=False)
+
+    # Try to load from cache first
+    cache_key = _cache_key("fineweb", "v1", f"limit-{limit}")
+    arrow_table = _get_cached_table(cache_key)
+
+    if arrow_table is None:
+        # Cache miss - download from HuggingFace
+        print(f"Cache miss for {cache_key}, downloading from HuggingFace...")
+        ds = load_dataset("HuggingFaceFW/fineweb", "sample-10BT", streaming=True)
+        data = ds["train"].take(limit)
+        arrow_table = pa.Table.from_pylist(data.to_list())

-    ds = load_dataset("HuggingFaceFW/fineweb", "sample-10BT", streaming=True)
-    data = ds["train"].take(limit)
-    arrow_table = pa.Table.from_pylist(data.to_list())
+        # Save to cache for future runs
+        _save_to_cache(cache_key, arrow_table)
+    else:
+        print(f"Cache hit for {cache_key}")

     table.write(arrow_table, push_down_nulls=True)
     return table
+
+
+@functools.lru_cache(maxsize=1)
+def abc(sp: Spiral, limit=100) -> Table:
+    table = demo_project(sp).create_table("abc", key_schema=pa.schema([("a", pa.int64())]), exist_ok=False)
+
+    table.write(
+        {
+            "a": pa.array(np.arange(limit)),
+            "b": pa.array(np.arange(100, 100 + limit)),
+            "c": pa.array(np.repeat(99, limit)),
+        }
+    )
+
+    return table
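
The demo loaders now cache their downloads as Parquet when SPIRAL_DEMO_CACHE_DIR is set. A usage sketch, assuming sp is an existing Spiral client and that the module is importable as spiral.demo (its path in this wheel); the cache directory is just a placeholder:

import os

os.environ["SPIRAL_DEMO_CACHE_DIR"] = "/tmp/spiral-demo-cache"  # any writable directory

from spiral import demo

events = demo.gharchive(sp, limit=100)  # first run downloads, then writes <cache>/gharchive-....parquet
text = demo.fineweb(sp, limit=100)      # later runs print "Cache hit" and skip the download
toy = demo.abc(sp, limit=100)           # new toy table with integer columns a, b, c
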
spiral/enrichment.py CHANGED
@@ -2,18 +2,18 @@ from __future__ import annotations

 import dataclasses
 import logging
-from functools import partial
+from functools import partial, reduce
 from typing import TYPE_CHECKING

 from spiral.core.client import Shard
 from spiral.core.table import KeyRange
-from spiral.core.table.spec import Key, Operation
+from spiral.core.table.spec import Key
 from spiral.expressions import Expr

 if TYPE_CHECKING:
     import dask.distributed

-    from spiral import Scan, Table
+    from spiral import Scan, Table, TransactionOps

 logger = logging.getLogger(__name__)

@@ -52,12 +52,13 @@ class Enrichment:
         """The filter expression."""
         return self._where

-    def _scan(self) -> Scan:
-        return self._table.spiral.scan(self._projection, where=self._where)
+    def _scan(self, shard: Shard | None = None) -> Scan:
+        return self._table.spiral.scan(self._projection, where=self._where, shard=shard)

     def apply(
         self,
         *,
+        shards: list[Shard] | None = None,
         txn_dump: str | None = None,
     ) -> None:
         """Apply the enrichment onto the table in a streaming fashion.
@@ -65,12 +66,17 @@ class Enrichment:
         For large tables, consider using `apply_dask` for distributed execution.

         Args:
+            shards: Optional list of shards to process.
             txn_dump: Optional path to dump the transaction JSON for debugging.
         """
+        # Combine multiple shards into one covering the full key range.
+        encompassing_shard: Shard | None = None
+        if shards:
+            encompassing_shard = reduce(lambda a, b: a | b, shards)

         txn = self._table.txn()

-        txn.writeback(self._scan())
+        txn.writeback(self._scan(encompassing_shard), shards=shards)

         if txn.is_empty():
             logger.warning("Transaction not committed. No rows were read for enrichment.")
@@ -150,7 +156,7 @@ class Enrichment:
         _compute = partial(
             _enrichment_task,
             config_json=self._table.spiral.config.to_json(),
-            state_json=plan_scan.core.plan_state().to_json(),
+            state_bytes=plan_scan.core.plan_context().to_bytes_compressed(),
             output_table_id=self._table.table_id,
             incremental=checkpoint_dump is not None,
         )
@@ -210,8 +216,7 @@ class Enrichment:
             logger.warning("Transaction not committed. No rows were read for enrichment.")
             return

-        # Always compact in distributed enrichment.
-        tx.commit(compact=True, txn_dump=txn_dump)
+        tx.commit(txn_dump=txn_dump)


 def _checkpoint_load_key_ranges(checkpoint_dump: str) -> list[KeyRange] | None:
@@ -243,26 +248,16 @@ def _checkpoint_dump_key_ranges(checkpoint_dump: str, ranges: list[KeyRange]):

 @dataclasses.dataclass
 class EnrichmentTaskResult:
-    ops: list[Operation]
+    ops: TransactionOps | None = None
     error: str | None = None

-    def __getstate__(self):
-        return {
-            "ops": [op.to_json() for op in self.ops],
-            "error": self.error,
-        }
-
-    def __setstate__(self, state):
-        self.ops = [Operation.from_json(op_json) for op_json in state["ops"]]
-        self.error = state["error"]
-

 # NOTE(marko): This function must be picklable!
 def _enrichment_task(
     shard: Shard,
     *,
     config_json: str,
-    state_json: str,
+    state_bytes: bytes,
     output_table_id,
     incremental: bool,
 ) -> EnrichmentTaskResult:
@@ -272,7 +267,7 @@ def _enrichment_task(

     config = ClientSettings.from_json(config_json)
     sp = Spiral(config=config)
-    task_scan = sp.resume_scan(state_json)
+    task_scan = sp.resume_scan(state_bytes)

     table = sp.table(output_table_id)
     task_tx = table.txn()
@@ -284,7 +279,7 @@ def _enrichment_task(
         task_tx.abort()

         if incremental:
-            return EnrichmentTaskResult(ops=[], error=str(e))
+            return EnrichmentTaskResult(error=str(e))

         logger.error(f"Enrichment task failed for shard {shard}: {e}")
         raise e
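
Finally, Enrichment.apply() gained a shards= parameter, and task state now travels as compressed ScanContext bytes rather than JSON. A hedged usage sketch; enrichment is assumed to be an already-constructed Enrichment and shards a list of Shard values (for example taken from Scan.shards()), neither of which is defined by this diff, and the dump path is a placeholder:

# Restrict the enrichment to a subset of shards. Per the change above, apply()
# ORs the shards into one encompassing shard for the scan and forwards the
# original list to txn.writeback().
enrichment.apply(shards=shards[:2], txn_dump="/tmp/enrichment-txn.json")

# Calling apply() with no shards keeps the old behaviour of scanning the full table.
enrichment.apply()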