pyspiral-0.7.18-cp312-abi3-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. pyspiral-0.7.18.dist-info/METADATA +52 -0
  2. pyspiral-0.7.18.dist-info/RECORD +110 -0
  3. pyspiral-0.7.18.dist-info/WHEEL +4 -0
  4. pyspiral-0.7.18.dist-info/entry_points.txt +3 -0
  5. spiral/__init__.py +55 -0
  6. spiral/_lib.abi3.so +0 -0
  7. spiral/adbc.py +411 -0
  8. spiral/api/__init__.py +78 -0
  9. spiral/api/admin.py +15 -0
  10. spiral/api/client.py +164 -0
  11. spiral/api/filesystems.py +134 -0
  12. spiral/api/key_space_indexes.py +23 -0
  13. spiral/api/organizations.py +77 -0
  14. spiral/api/projects.py +219 -0
  15. spiral/api/telemetry.py +19 -0
  16. spiral/api/text_indexes.py +56 -0
  17. spiral/api/types.py +23 -0
  18. spiral/api/workers.py +40 -0
  19. spiral/api/workloads.py +52 -0
  20. spiral/arrow_.py +216 -0
  21. spiral/cli/__init__.py +88 -0
  22. spiral/cli/__main__.py +4 -0
  23. spiral/cli/admin.py +14 -0
  24. spiral/cli/app.py +108 -0
  25. spiral/cli/console.py +95 -0
  26. spiral/cli/fs.py +76 -0
  27. spiral/cli/iceberg.py +97 -0
  28. spiral/cli/key_spaces.py +103 -0
  29. spiral/cli/login.py +25 -0
  30. spiral/cli/orgs.py +90 -0
  31. spiral/cli/printer.py +53 -0
  32. spiral/cli/projects.py +147 -0
  33. spiral/cli/state.py +7 -0
  34. spiral/cli/tables.py +197 -0
  35. spiral/cli/telemetry.py +17 -0
  36. spiral/cli/text.py +115 -0
  37. spiral/cli/types.py +50 -0
  38. spiral/cli/workloads.py +58 -0
  39. spiral/client.py +256 -0
  40. spiral/core/__init__.pyi +0 -0
  41. spiral/core/_tools/__init__.pyi +5 -0
  42. spiral/core/authn/__init__.pyi +21 -0
  43. spiral/core/client/__init__.pyi +285 -0
  44. spiral/core/config/__init__.pyi +35 -0
  45. spiral/core/expr/__init__.pyi +15 -0
  46. spiral/core/expr/images/__init__.pyi +3 -0
  47. spiral/core/expr/list_/__init__.pyi +4 -0
  48. spiral/core/expr/refs/__init__.pyi +4 -0
  49. spiral/core/expr/str_/__init__.pyi +3 -0
  50. spiral/core/expr/struct_/__init__.pyi +6 -0
  51. spiral/core/expr/text/__init__.pyi +5 -0
  52. spiral/core/expr/udf/__init__.pyi +14 -0
  53. spiral/core/expr/video/__init__.pyi +3 -0
  54. spiral/core/table/__init__.pyi +141 -0
  55. spiral/core/table/manifests/__init__.pyi +35 -0
  56. spiral/core/table/metastore/__init__.pyi +58 -0
  57. spiral/core/table/spec/__init__.pyi +215 -0
  58. spiral/dataloader.py +299 -0
  59. spiral/dataset.py +264 -0
  60. spiral/datetime_.py +27 -0
  61. spiral/debug/__init__.py +0 -0
  62. spiral/debug/manifests.py +87 -0
  63. spiral/debug/metrics.py +56 -0
  64. spiral/debug/scan.py +266 -0
  65. spiral/enrichment.py +306 -0
  66. spiral/expressions/__init__.py +274 -0
  67. spiral/expressions/base.py +167 -0
  68. spiral/expressions/file.py +17 -0
  69. spiral/expressions/http.py +17 -0
  70. spiral/expressions/list_.py +68 -0
  71. spiral/expressions/s3.py +16 -0
  72. spiral/expressions/str_.py +39 -0
  73. spiral/expressions/struct.py +59 -0
  74. spiral/expressions/text.py +62 -0
  75. spiral/expressions/tiff.py +222 -0
  76. spiral/expressions/udf.py +60 -0
  77. spiral/grpc_.py +32 -0
  78. spiral/iceberg.py +31 -0
  79. spiral/iterable_dataset.py +106 -0
  80. spiral/key_space_index.py +44 -0
  81. spiral/project.py +227 -0
  82. spiral/protogen/_/__init__.py +0 -0
  83. spiral/protogen/_/arrow/__init__.py +0 -0
  84. spiral/protogen/_/arrow/flight/__init__.py +0 -0
  85. spiral/protogen/_/arrow/flight/protocol/__init__.py +0 -0
  86. spiral/protogen/_/arrow/flight/protocol/sql/__init__.py +2548 -0
  87. spiral/protogen/_/google/__init__.py +0 -0
  88. spiral/protogen/_/google/protobuf/__init__.py +2310 -0
  89. spiral/protogen/_/message_pool.py +3 -0
  90. spiral/protogen/_/py.typed +0 -0
  91. spiral/protogen/_/scandal/__init__.py +190 -0
  92. spiral/protogen/_/spfs/__init__.py +72 -0
  93. spiral/protogen/_/spql/__init__.py +61 -0
  94. spiral/protogen/_/substrait/__init__.py +6196 -0
  95. spiral/protogen/_/substrait/extensions/__init__.py +169 -0
  96. spiral/protogen/__init__.py +0 -0
  97. spiral/protogen/util.py +41 -0
  98. spiral/py.typed +0 -0
  99. spiral/scan.py +363 -0
  100. spiral/server.py +17 -0
  101. spiral/settings.py +36 -0
  102. spiral/snapshot.py +56 -0
  103. spiral/streaming_/__init__.py +3 -0
  104. spiral/streaming_/reader.py +133 -0
  105. spiral/streaming_/stream.py +157 -0
  106. spiral/substrait_.py +274 -0
  107. spiral/table.py +224 -0
  108. spiral/text_index.py +17 -0
  109. spiral/transaction.py +155 -0
  110. spiral/types_.py +6 -0
@@ -0,0 +1,56 @@
+ from typing import Any
+
+
+ def display_metrics(metrics: dict[str, Any]) -> None:
+     """Display metrics in a formatted table."""
+     print(
+         f"{'Metric':<40} {'Type':<10} {'Count':<8} {'Avg':<12} {'Min':<12} "
+         f"{'Max':<12} {'P95':<12} {'P99':<12} {'StdDev':<12}"
+     )
+     print("=" * 140)
+
+     for metric_name, data in sorted(metrics.items()):
+         metric_type = data["type"]
+         count = data["count"]
+         avg = _format_value(data["avg"], metric_type, metric_name)
+         min_val = _format_value(data["min"], metric_type, metric_name)
+         max_val = _format_value(data["max"], metric_type, metric_name)
+         p95 = _format_value(data["p95"], metric_type, metric_name)
+         p99 = _format_value(data["p99"], metric_type, metric_name)
+         stddev = _format_value(data["stddev"], metric_type, metric_name)
+
+         print(
+             f"{metric_name:<40} {metric_type:<10} {count:<8} {avg:<12} {min_val:<12} "
+             f"{max_val:<12} {p95:<12} {p99:<12} {stddev:<12}"
+         )
+
+
+ def _format_duration(nanoseconds: float) -> str:
+     """Convert nanoseconds to human-readable duration."""
+     if nanoseconds >= 1_000_000_000:
+         return f"{nanoseconds / 1_000_000_000:.2f}s"
+     elif nanoseconds >= 1_000_000:
+         return f"{nanoseconds / 1_000_000:.2f}ms"
+     elif nanoseconds >= 1_000:
+         return f"{nanoseconds / 1_000:.2f}μs"
+     else:
+         return f"{nanoseconds:.0f}ns"
+
+
+ def _format_bytes(bytes_value: float) -> str:
+     """Convert bytes to human-readable size."""
+     for unit in ["B", "KB", "MB", "GB"]:
+         if bytes_value < 1024:
+             return f"{bytes_value:.1f}{unit}"
+         bytes_value /= 1024
+     return f"{bytes_value:.1f}TB"
+
+
+ def _format_value(value: float, metric_type: str, metric_name: str) -> str:
+     """Format a value based on metric type and name."""
+     if metric_type == "timer" or "duration" in metric_name:
+         return _format_duration(value)
+     elif "bytes" in metric_name:
+         return _format_bytes(value)
+     else:
+         return f"{value:,.0f}"
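For context, a minimal sketch (not part of the package) of how `display_metrics` might be driven. The dict shape (`type`, `count`, `avg`, `min`, `max`, `p95`, `p99`, `stddev` keys) is read directly from the function above; the metric names and numbers are invented:

# Hypothetical input; keys mirror what display_metrics reads above.
metrics = {
    "scan.read.duration": {
        "type": "timer", "count": 42, "avg": 1_250_000.0, "min": 310_000.0,
        "max": 9_800_000.0, "p95": 4_100_000.0, "p99": 8_700_000.0, "stddev": 2_100_000.0,
    },
    "scan.read.bytes": {
        "type": "counter", "count": 42, "avg": 524_288.0, "min": 4_096.0,
        "max": 8_388_608.0, "p95": 2_097_152.0, "p99": 6_291_456.0, "stddev": 1_048_576.0,
    },
}
display_metrics(metrics)
# Names containing "duration" format via _format_duration (the avg above
# renders as 1.25ms); names containing "bytes" format via _format_bytes
# (512.0KB); anything else prints as a comma-grouped number.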
spiral/debug/scan.py ADDED
@@ -0,0 +1,266 @@
+ from datetime import datetime
+
+ from spiral.core.table import Scan
+ from spiral.core.table.manifests import FragmentFile, FragmentManifest
+ from spiral.core.table.spec import Key
+ from spiral.types_ import Timestamp
+
+
+ def show_scan(scan: Scan):
+     """Displays a scan in a way that is useful for debugging."""
+     table_ids = scan.table_ids()
+     if len(table_ids) > 1:
+         raise NotImplementedError("Multiple table scan is not supported.")
+     table_id = table_ids[0]
+     column_groups = scan.column_groups()
+
+     splits = scan.splits()
+     key_space_state = scan.key_space_state(table_id)
+
+     # Collect all key bounds from all manifests so that all visualizations are aligned.
+     key_points = set()
+     key_space_manifest = key_space_state.manifest
+     for i in range(len(key_space_manifest)):
+         fragment_file = key_space_manifest[i]
+         key_points.add(fragment_file.key_extent.min)
+         key_points.add(fragment_file.key_extent.max)
+     for cg in column_groups:
+         cg_scan = scan.column_group_state(cg)
+         cg_manifest = cg_scan.manifest
+         for i in range(len(cg_manifest)):
+             fragment_file = cg_manifest[i]
+             key_points.add(fragment_file.key_extent.min)
+             key_points.add(fragment_file.key_extent.max)
+
+     # Make sure split points exist in all key points.
+     for s in splits[:-1]:  # Don't take the last end.
+         key_points.add(s.end)
+     key_points = list(sorted(key_points))
+
+     show_manifest(key_space_manifest, scope="Key space", key_points=key_points, splits=splits)
+     for cg in column_groups:
+         cg_scan = scan.column_group_state(cg)
+         # Skip the table id at the start of the column group path.
+         show_manifest(cg_scan.manifest, scope=".".join(cg.path[1:]), key_points=key_points, splits=splits)
+
+
+ def show_manifest(manifest: FragmentManifest, scope: str | None = None, key_points: list[Key] | None = None, splits: list | None = None):
+     try:
+         import matplotlib.patches as patches
+         import matplotlib.pyplot as plt
+     except ImportError:
+         raise ImportError("matplotlib is required for debug")
+
+     total_fragments = len(manifest)
+
+     size_points = set()
+     for i in range(total_fragments):
+         manifest_file: FragmentFile = manifest[i]
+         size_points.add(manifest_file.size_bytes)
+     size_points = list(sorted(size_points))
+
+     if key_points is None:
+         key_points = set()
+
+         for i in range(total_fragments):
+             manifest_file: FragmentFile = manifest[i]
+
+             key_points.add(manifest_file.key_extent.min)
+             key_points.add(manifest_file.key_extent.max)
+
+         if splits is not None:
+             for split in splits[:-1]:
+                 key_points.add(split.end)
+
+         key_points = list(sorted(key_points))
+
+     # Create figure and axis with specified size
+     fig, ax = plt.subplots(figsize=(12, 8))
+
+     # Plot each rectangle
+     for i in range(total_fragments):
+         manifest_file: FragmentFile = manifest[i]
+
+         left = key_points.index(manifest_file.key_extent.min)
+         right = key_points.index(manifest_file.key_extent.max)
+         height = size_points.index(manifest_file.size_bytes) + 1
+
+         color = _get_fragment_color(manifest_file, i, total_fragments)
+
+         # Create rectangle patch
+         rect = patches.Rectangle(
+             (left, 0),  # (x, y)
+             right - left,  # width
+             height,  # height
+             facecolor=color,  # fill color
+             edgecolor="black",  # border color
+             alpha=0.5,  # transparency
+             linewidth=1,  # border width
+             label=manifest_file.id,  # label for legend
+         )
+
+         ax.add_patch(rect)
+
+     # Set axis limits with some padding
+     ax.set_xlim(-0.5, len(key_points) - 1 + 0.5)
+     ax.set_ylim(-0.5, len(size_points) + 0.5)
+
+     # Create split markers on the x-axis
+     if splits is not None:
+         split_positions = [key_points.index(split.end) for split in splits[:-1]]
+
+         # Add split markers at the bottom
+         for pos in split_positions:
+             ax.annotate("▲", xy=(pos, 0), ha="center", va="top", color="red", annotation_clip=False)
+
+     # Add grid
+     ax.grid(True, linestyle="--", alpha=0.7, zorder=0)
+
+     # Add labels and title
+     ax.set_title("Fragment Distribution" if scope is None else f"{scope} Fragment Distribution")
+     ax.set_xlabel("Key Index")
+     ax.set_ylabel("Size Index")
+
+     # Add legend
+     ax.legend(bbox_to_anchor=(1, 1), loc="upper left", fontsize="small")
+
+     # Adjust layout to prevent label cutoff
+     plt.tight_layout()
+
+     plot = FragmentManifestPlot(fig, ax, manifest)
+     fig.canvas.mpl_connect("motion_notify_event", plot.hover)
+
+     plt.show()
+
+
+ def _get_fragment_color(manifest_file: FragmentFile, color_index, total_colors):
+     import matplotlib.cm as cm
+
+     if manifest_file.compacted_at is not None:
+         # Use a shade of gray for compacted fragments.
+         # Vary the shade based on the index to distinguish different compacted fragments.
+         gray_value = 0.3 + (0.5 * (color_index / total_colors))
+         return (gray_value, gray_value, gray_value)
+     else:
+         # Use the viridis colormap for non-compacted fragments.
+         return cm.viridis(color_index / total_colors)
+
+
+ def _get_human_size(size_bytes: int) -> str:
+     # Convert bytes to a human-readable format.
+     for unit in ["B", "KB", "MB", "GB", "TB"]:
+         if size_bytes < 1024:
+             return f"{size_bytes:.2f} {unit}"
+         size_bytes /= 1024
+     return f"{size_bytes:.2f} PB"
+
+
+ def _maybe_truncate(text, max_length: int = 30) -> str:
+     text = str(text)
+     if len(text) <= max_length:
+         return text
+
+     half_length = (max_length - 3) // 2
+     return text[:half_length] + "..." + text[-half_length:]
+
+
+ def _get_fragment_legend(manifest_file: FragmentFile):
+     return "\n".join(
+         [
+             f"id: {manifest_file.id}",
+             f"size: {_get_human_size(manifest_file.size_bytes)} ({manifest_file.size_bytes} bytes)",
+             f"key_span: {manifest_file.key_span}",
+             f"key_min: {_maybe_truncate(manifest_file.key_extent.min)}",
+             f"key_max: {_maybe_truncate(manifest_file.key_extent.max)}",
+             f"format: {manifest_file.format}",
+             f"level: {manifest_file.level}",
+             f"committed_at: {_format_timestamp(manifest_file.committed_at)}",
+             f"compacted_at: {_format_timestamp(manifest_file.compacted_at)}",
+             f"ks_id: {manifest_file.ks_id}",
+         ]
+     )
+
+
+ def _format_timestamp(ts: Timestamp | None) -> str:
+     # Format a microsecond timestamp, or show None.
+     if ts is None:
+         return "None"
+     try:
+         return datetime.fromtimestamp(ts / 1e6).strftime("%Y-%m-%d %H:%M:%S")
+     except ValueError:
+         return str(ts)
+
+
+ class FragmentManifestPlot:
+     def __init__(self, fig, ax, manifest: FragmentManifest):
+         self.fig = fig
+         self.ax = ax
+         self.manifest = manifest
+
+         # Position the annotation in the bottom right corner
+         self.annotation = ax.annotate(
+             "",
+             xy=(0.98, 0.02),  # Position in axes coordinates
+             xycoords="axes fraction",
+             bbox=dict(boxstyle="round,pad=0.5", fc="white", ec="gray", alpha=0.8),
+             ha="right",  # Right-align text
+             va="bottom",  # Bottom-align text
+             visible=False,
+         )
+         self.highlighted_rect = None
+         self.highlighted_legend = None
+
+     def hover(self, event):
+         if event.inaxes != self.ax:
+             # Check if we're hovering over the legend
+             legend = self.ax.get_legend()
+             if legend and legend.contains(event)[0]:
+                 # Find which legend item we're hovering over
+                 for i, legend_text in enumerate(legend.get_texts()):
+                     if legend_text.contains(event)[0]:
+                         manifest_file = self.manifest[i]
+                         self._show_legend(manifest_file, i, legend_text)
+                         return
+             self._hide_legend()
+             return
+
+         # Check rectangles in the main plot
+         for i, rect in enumerate(self.ax.patches):
+             if rect.contains(event)[0]:
+                 manifest_file = self.manifest[i]
+                 self._show_legend(manifest_file, i, rect)
+                 return
+
+         self._hide_legend()
+
+     def _show_legend(self, manifest_file, index, highlight_obj):
+         import matplotlib.patches as patches
+
+         # Update tooltip text
+         self.annotation.set_text(_get_fragment_legend(manifest_file))
+         self.annotation.set_visible(True)
+
+         # Handle highlighting
+         if isinstance(highlight_obj, patches.Rectangle):
+             # Highlighting rectangle in main plot
+             if self.highlighted_rect and self.highlighted_rect != highlight_obj:
+                 self.highlighted_rect.set_alpha(0.5)
+             highlight_obj.set_alpha(0.8)
+             self.highlighted_rect = highlight_obj
+         else:
+             # Highlighting legend text
+             if self.highlighted_rect:
+                 self.highlighted_rect.set_alpha(0.5)
+             # Find and highlight corresponding rectangle
+             rect = self.ax.patches[index]
+             rect.set_alpha(0.8)
+             self.highlighted_rect = rect
+
+         self.fig.canvas.draw_idle()
+
+     def _hide_legend(self):
+         if self.annotation.get_visible():
+             self.annotation.set_visible(False)
+             if self.highlighted_rect:
+                 self.highlighted_rect.set_alpha(0.5)
+             self.fig.canvas.draw_idle()
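For orientation, a minimal sketch of how these debug helpers might be invoked. It assumes a configured client; `sp.table(...)`, `Spiral.scan(...)`, and the `.core` attribute appear in enrichment.py below, but the table identifier and the scan construction here are hypothetical:

# Illustrative only: assumes a configured client and an existing table;
# `sp.scan(...)` and `scan.core` are assumed to behave as in enrichment.py.
from spiral import Spiral
from spiral.debug.scan import show_manifest, show_scan

sp = Spiral()
table = sp.table("org.project.table")  # hypothetical table id
scan = sp.scan(table)                  # projection over the table (assumed)
show_scan(scan.core)                   # one fragment chart per column group

# Or visualize a single manifest directly, per show_scan's own accessors:
ks_state = scan.core.key_space_state(scan.core.table_ids()[0])
show_manifest(ks_state.manifest, scope="Key space")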
spiral/enrichment.py ADDED
@@ -0,0 +1,306 @@
+ from __future__ import annotations
+
+ import dataclasses
+ import logging
+ from functools import partial
+ from typing import TYPE_CHECKING
+
+ from spiral.core.client import KeyColumns, Shard
+ from spiral.core.table import KeyRange
+ from spiral.core.table.spec import Key, Operation
+ from spiral.expressions import Expr
+
+ if TYPE_CHECKING:
+     import dask.distributed
+
+     from spiral import Scan, Table
+
+ logger = logging.getLogger(__name__)
+
+
+ class Enrichment:
+     """
+     An enrichment derives new columns from existing ones, e.g. by fetching data from object storage
+     with `se.s3.get` or by computing embeddings. With the column group design supporting hundreds of
+     thousands of columns, horizontally expanding tables are a powerful primitive.
+
+     NOTE: Spiral aims to optimize enrichments where the source and destination table are the same.
+     """
+
+     def __init__(
+         self,
+         table: Table,
+         projection: Expr,
+         where: Expr | None,
+     ):
+         self._table = table
+         self._projection = projection
+         self._where = where
+
+     @property
+     def table(self) -> Table:
+         """The table to write back into."""
+         return self._table
+
+     @property
+     def projection(self) -> Expr:
+         """The projection expression."""
+         return self._projection
+
+     @property
+     def where(self) -> Expr | None:
+         """The filter expression."""
+         return self._where
+
+     def _scan(self) -> Scan:
+         return self._table.spiral.scan(self._projection, where=self._where, _key_columns=KeyColumns.Included)
+
+     def apply(
+         self,
+         *,
+         batch_readahead: int | None = None,
+         partition_size_bytes: int | None = None,
+         txn_dump: str | None = None,
+     ) -> None:
+         """Apply the enrichment onto the table in a streaming fashion.
+
+         For large tables, consider using `apply_dask` for distributed execution.
+
+         Args:
+             batch_readahead: Optional number of batches to read ahead while streaming.
+                 If not provided, the default readahead is used.
+             partition_size_bytes: The maximum partition size in bytes.
+                 If not provided, the default partition size is used.
+             txn_dump: Optional path to dump the transaction JSON for debugging.
+         """
+
+         txn = self._table.txn()
+
+         txn.writeback(
+             self._scan(),
+             partition_size_bytes=partition_size_bytes,
+             batch_readahead=batch_readahead,
+         )
+
+         if txn.is_empty():
+             logger.warning("Transaction not committed. No rows were read for enrichment.")
+             return
+
+         txn.commit(txn_dump=txn_dump)
+
+     def apply_dask(
+         self,
+         *,
+         partition_size_bytes: int | None = None,
+         max_task_size: int | None = None,
+         checkpoint_dump: str | None = None,
+         shards: list[Shard] | None = None,
+         txn_dump: str | None = None,
+         client: dask.distributed.Client | None = None,
+         **kwargs,
+     ) -> None:
+         """Use distributed Dask to apply the enrichment. Requires `dask[distributed]` to be installed.
+
+         If the "address" of an existing Dask cluster is not provided in `kwargs`, a local cluster is created.
+
+         Dask execution has some limitations, e.g. UDFs are not currently supported. These limitations
+         usually manifest as serialization errors when Dask workers attempt to serialize the state. If you
+         encounter such issues, consider splitting the enrichment into a UDF-only derivation that is
+         executed in a streaming fashion, followed by a Dask enrichment for the rest of the computation.
+         If that is not possible, please reach out to support for assistance.
+
+         How shards are determined:
+             - If `shards` is provided, those are used directly.
+             - Else, if `checkpoint_dump` is provided, shards are loaded from the checkpoint.
+             - Else, if `max_task_size` is provided, shards are created based on the task size.
+             - Else, the scan's default sharding is used.
+
+         Args:
+             partition_size_bytes: The maximum partition size in bytes.
+                 If not provided, the default partition size is used.
+             max_task_size: Optional task size limit, in number of rows. Used for sharding.
+                 If provided and a checkpoint is present, the checkpoint shards are used instead.
+                 If not provided, the scan's default sharding is used.
+             checkpoint_dump: Optional path to dump intermediate checkpoints for incremental progress.
+             shards: Optional list of shards to process.
+                 If provided, `max_task_size` and `checkpoint_dump` are ignored.
+             txn_dump: Optional path to dump the transaction JSON for debugging.
+             client: Optional Dask distributed client. If not provided, a new client is created.
+             **kwargs: Additional keyword arguments to pass to `dask.distributed.Client`,
+                 such as `address` to connect to an existing cluster.
+         """
+         if client is None:
+             try:
+                 from dask.distributed import Client
+             except ImportError:
+                 raise ImportError("dask is not installed, please install dask[distributed] to use this feature.")
+
+             # Connect before doing any work.
+             client = Client(**kwargs)
+
+         # Start a transaction BEFORE the planning scan.
+         tx = self._table.txn()
+         plan_scan = self._scan()
+
+         # Determine the "tasks". Start from provided shards.
+         task_shards = shards
+         # If shards are not provided, try loading from checkpoint.
+         if task_shards is None and checkpoint_dump is not None:
+             checkpoint: list[KeyRange] | None = _checkpoint_load_key_ranges(checkpoint_dump)
+             if checkpoint is None:
+                 logger.info(f"No existing checkpoint found at {checkpoint_dump}. Starting from scratch.")
+             else:
+                 logger.info(f"Resuming enrichment from checkpoint at {checkpoint_dump} with {len(checkpoint)} ranges.")
+                 task_shards = [Shard(kr, None) for kr in checkpoint]
+         # If still no shards, try creating from max task size.
+         if task_shards is None and max_task_size is not None:
+             task_shards = self._table.spiral.compute_shards(max_task_size, self.projection, self.where)
+         # Fall back to the scan's default sharding.
+         if task_shards is None:
+             task_shards = plan_scan.shards()
+
+         # Partially bind the enrichment function.
+         _compute = partial(
+             _enrichment_task,
+             settings_json=self._table.spiral.config.to_json(),
+             state_json=plan_scan.core.plan_state().to_json(),
+             output_table_id=self._table.table_id,
+             partition_size_bytes=partition_size_bytes,
+             incremental=checkpoint_dump is not None,
+         )
+         enrichments = client.map(_compute, task_shards)
+
+         logger.info(f"Applying enrichment with {len(task_shards)} shards. Follow progress at {client.dashboard_link}")
+
+         failed_ranges = []
+         try:
+             for result, shard in zip(client.gather(enrichments), task_shards):
+                 result: EnrichmentTaskResult
+
+                 if result.error is not None:
+                     logger.error(f"Enrichment task failed for range {shard.key_range}: {result.error}")
+                     failed_ranges.append(shard.key_range)
+                     continue
+
+                 tx.include(result.ops)
+         except Exception as e:
+             # If not incremental, re-raise the exception.
+             if checkpoint_dump is None:
+                 raise
+
+             # Handle worker failures (e.g., KilledWorker from Dask).
+             from dask.distributed import KilledWorker
+
+             if isinstance(e, KilledWorker):
+                 logger.error(f"Dask worker was killed during enrichment: {e}")
+
+             # Try to gather partial results and mark remaining tasks as failed.
+             for future, shard in zip(enrichments, task_shards):
+                 if future.done() and not future.exception():
+                     try:
+                         result = future.result()
+
+                         if result.error is not None:
+                             logger.error(f"Enrichment task failed for range {shard.key_range}: {result.error}")
+                             failed_ranges.append(shard.key_range)
+                             continue
+
+                         tx.include(result.ops)
+                     except Exception:
+                         # Task failed or incomplete, add to failed ranges.
+                         failed_ranges.append(shard.key_range)
+                 else:
+                     # Task didn't complete, add to failed ranges.
+                     failed_ranges.append(shard.key_range)
+
+         # Dump a checkpoint of failed ranges, if any.
+         if checkpoint_dump is not None:
+             logger.info(
+                 f"Dumping checkpoint with failed {len(failed_ranges)}/{len(task_shards)} ranges to {checkpoint_dump}."
+             )
+             _checkpoint_dump_key_ranges(checkpoint_dump, failed_ranges)
+
+         if tx.is_empty():
+             logger.warning("Transaction not committed. No rows were read for enrichment.")
+             return
+
+         # Always compact in distributed enrichment.
+         tx.commit(compact=True, txn_dump=txn_dump)
+
+
+ def _checkpoint_load_key_ranges(checkpoint_dump: str) -> list[KeyRange] | None:
+     import json
+     import os
+
+     if not os.path.exists(checkpoint_dump):
+         return None
+
+     with open(checkpoint_dump) as f:
+         data = json.load(f)
+         return [
+             KeyRange(begin=Key(bytes.fromhex(r["begin"])), end=Key(bytes.fromhex(r["end"])))
+             for r in data.get("key_ranges", [])
+         ]
+
+
+ def _checkpoint_dump_key_ranges(checkpoint_dump: str, ranges: list[KeyRange]):
+     import json
+     import os
+
+     os.makedirs(os.path.dirname(checkpoint_dump), exist_ok=True)
+     with open(checkpoint_dump, "w") as f:
+         json.dump(
+             {"key_ranges": [{"begin": bytes(r.begin).hex(), "end": bytes(r.end).hex()} for r in ranges]},
+             f,
+         )
+
+
+ @dataclasses.dataclass
+ class EnrichmentTaskResult:
+     ops: list[Operation]
+     error: str | None = None
+
+     def __getstate__(self):
+         return {
+             "ops": [op.to_json() for op in self.ops],
+             "error": self.error,
+         }
+
+     def __setstate__(self, state):
+         self.ops = [Operation.from_json(op_json) for op_json in state["ops"]]
+         self.error = state["error"]
+
+
+ # NOTE(marko): This function must be picklable!
+ def _enrichment_task(
+     shard: Shard,
+     *,
+     settings_json: str,
+     state_json: str,
+     output_table_id,
+     partition_size_bytes: int | None,
+     incremental: bool,
+ ) -> EnrichmentTaskResult:
+     # Returns operations that can be included in a transaction.
+     from spiral import Scan, Spiral
+     from spiral.core.table import ScanState
+     from spiral.settings import ClientSettings
+
+     settings = ClientSettings.from_json(settings_json)
+     sp = Spiral(config=settings)
+     state = ScanState.from_json(state_json)
+     task_scan = Scan(sp, sp.core.load_scan(state))
+     table = sp.table(output_table_id)
+     task_tx = table.txn()
+
+     try:
+         task_tx.writeback(task_scan, key_range=shard.key_range, partition_size_bytes=partition_size_bytes)
+         return EnrichmentTaskResult(ops=task_tx.take())
+     except Exception as e:
+         task_tx.abort()
+
+         if incremental:
+             return EnrichmentTaskResult(ops=[], error=str(e))
+
+         logger.error(f"Enrichment task failed for shard {shard}: {e}")
+         raise
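To tie the pieces together, a usage sketch for the enrichment flow, under stated assumptions: the `Enrichment(table, projection, where)` constructor matches `__init__` above, and `se.s3.get` is named in the class docstring, but the table id, the `uri` column, and the dict-style projection are illustrative:

# Illustrative assumptions: a configured client, an existing table with a
# `uri` column, and dict projections being accepted as expressions.
import spiral.expressions as se
from spiral import Spiral
from spiral.enrichment import Enrichment

sp = Spiral()
table = sp.table("org.project.docs")  # hypothetical table id
enrichment = Enrichment(table, {"raw": se.s3.get(table["uri"])}, where=None)

# Single-machine streaming execution:
enrichment.apply()

# Or, distributed execution with incremental checkpointing:
enrichment.apply_dask(max_task_size=1_000_000, checkpoint_dump="/tmp/enrichment.ckpt.json")

Per `_checkpoint_dump_key_ranges` above, the checkpoint is a JSON file of the form {"key_ranges": [{"begin": "<hex>", "end": "<hex>"}, ...]} holding the failed key ranges, so re-running `apply_dask` with the same `checkpoint_dump` resumes only the shards that failed.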