pyspiral-0.7.18-cp312-abi3-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyspiral-0.7.18.dist-info/METADATA +52 -0
- pyspiral-0.7.18.dist-info/RECORD +110 -0
- pyspiral-0.7.18.dist-info/WHEEL +4 -0
- pyspiral-0.7.18.dist-info/entry_points.txt +3 -0
- spiral/__init__.py +55 -0
- spiral/_lib.abi3.so +0 -0
- spiral/adbc.py +411 -0
- spiral/api/__init__.py +78 -0
- spiral/api/admin.py +15 -0
- spiral/api/client.py +164 -0
- spiral/api/filesystems.py +134 -0
- spiral/api/key_space_indexes.py +23 -0
- spiral/api/organizations.py +77 -0
- spiral/api/projects.py +219 -0
- spiral/api/telemetry.py +19 -0
- spiral/api/text_indexes.py +56 -0
- spiral/api/types.py +23 -0
- spiral/api/workers.py +40 -0
- spiral/api/workloads.py +52 -0
- spiral/arrow_.py +216 -0
- spiral/cli/__init__.py +88 -0
- spiral/cli/__main__.py +4 -0
- spiral/cli/admin.py +14 -0
- spiral/cli/app.py +108 -0
- spiral/cli/console.py +95 -0
- spiral/cli/fs.py +76 -0
- spiral/cli/iceberg.py +97 -0
- spiral/cli/key_spaces.py +103 -0
- spiral/cli/login.py +25 -0
- spiral/cli/orgs.py +90 -0
- spiral/cli/printer.py +53 -0
- spiral/cli/projects.py +147 -0
- spiral/cli/state.py +7 -0
- spiral/cli/tables.py +197 -0
- spiral/cli/telemetry.py +17 -0
- spiral/cli/text.py +115 -0
- spiral/cli/types.py +50 -0
- spiral/cli/workloads.py +58 -0
- spiral/client.py +256 -0
- spiral/core/__init__.pyi +0 -0
- spiral/core/_tools/__init__.pyi +5 -0
- spiral/core/authn/__init__.pyi +21 -0
- spiral/core/client/__init__.pyi +285 -0
- spiral/core/config/__init__.pyi +35 -0
- spiral/core/expr/__init__.pyi +15 -0
- spiral/core/expr/images/__init__.pyi +3 -0
- spiral/core/expr/list_/__init__.pyi +4 -0
- spiral/core/expr/refs/__init__.pyi +4 -0
- spiral/core/expr/str_/__init__.pyi +3 -0
- spiral/core/expr/struct_/__init__.pyi +6 -0
- spiral/core/expr/text/__init__.pyi +5 -0
- spiral/core/expr/udf/__init__.pyi +14 -0
- spiral/core/expr/video/__init__.pyi +3 -0
- spiral/core/table/__init__.pyi +141 -0
- spiral/core/table/manifests/__init__.pyi +35 -0
- spiral/core/table/metastore/__init__.pyi +58 -0
- spiral/core/table/spec/__init__.pyi +215 -0
- spiral/dataloader.py +299 -0
- spiral/dataset.py +264 -0
- spiral/datetime_.py +27 -0
- spiral/debug/__init__.py +0 -0
- spiral/debug/manifests.py +87 -0
- spiral/debug/metrics.py +56 -0
- spiral/debug/scan.py +266 -0
- spiral/enrichment.py +306 -0
- spiral/expressions/__init__.py +274 -0
- spiral/expressions/base.py +167 -0
- spiral/expressions/file.py +17 -0
- spiral/expressions/http.py +17 -0
- spiral/expressions/list_.py +68 -0
- spiral/expressions/s3.py +16 -0
- spiral/expressions/str_.py +39 -0
- spiral/expressions/struct.py +59 -0
- spiral/expressions/text.py +62 -0
- spiral/expressions/tiff.py +222 -0
- spiral/expressions/udf.py +60 -0
- spiral/grpc_.py +32 -0
- spiral/iceberg.py +31 -0
- spiral/iterable_dataset.py +106 -0
- spiral/key_space_index.py +44 -0
- spiral/project.py +227 -0
- spiral/protogen/_/__init__.py +0 -0
- spiral/protogen/_/arrow/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/protocol/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/protocol/sql/__init__.py +2548 -0
- spiral/protogen/_/google/__init__.py +0 -0
- spiral/protogen/_/google/protobuf/__init__.py +2310 -0
- spiral/protogen/_/message_pool.py +3 -0
- spiral/protogen/_/py.typed +0 -0
- spiral/protogen/_/scandal/__init__.py +190 -0
- spiral/protogen/_/spfs/__init__.py +72 -0
- spiral/protogen/_/spql/__init__.py +61 -0
- spiral/protogen/_/substrait/__init__.py +6196 -0
- spiral/protogen/_/substrait/extensions/__init__.py +169 -0
- spiral/protogen/__init__.py +0 -0
- spiral/protogen/util.py +41 -0
- spiral/py.typed +0 -0
- spiral/scan.py +363 -0
- spiral/server.py +17 -0
- spiral/settings.py +36 -0
- spiral/snapshot.py +56 -0
- spiral/streaming_/__init__.py +3 -0
- spiral/streaming_/reader.py +133 -0
- spiral/streaming_/stream.py +157 -0
- spiral/substrait_.py +274 -0
- spiral/table.py +224 -0
- spiral/text_index.py +17 -0
- spiral/transaction.py +155 -0
- spiral/types_.py +6 -0
spiral/debug/metrics.py
ADDED
@@ -0,0 +1,56 @@
from typing import Any


def display_metrics(metrics: dict[str, Any]) -> None:
    """Display metrics in a formatted table."""
    print(
        f"{'Metric':<40} {'Type':<10} {'Count':<8} {'Avg':<12} {'Min':<12} "
        f"{'Max':<12} {'P95':<12} {'P99':<12} {'StdDev':<12}"
    )
    print("=" * 140)

    for metric_name, data in sorted(metrics.items()):
        metric_type = data["type"]
        count = data["count"]
        avg = _format_value(data["avg"], metric_type, metric_name)
        min_val = _format_value(data["min"], metric_type, metric_name)
        max_val = _format_value(data["max"], metric_type, metric_name)
        p95 = _format_value(data["p95"], metric_type, metric_name)
        p99 = _format_value(data["p99"], metric_type, metric_name)
        stddev = _format_value(data["stddev"], metric_type, metric_name)

        print(
            f"{metric_name:<40} {metric_type:<10} {count:<8} {avg:<12} {min_val:<12} "
            f"{max_val:<12} {p95:<12} {p99:<12} {stddev:<12}"
        )


def _format_duration(nanoseconds: float) -> str:
    """Convert nanoseconds to human-readable duration."""
    if nanoseconds >= 1_000_000_000:
        return f"{nanoseconds / 1_000_000_000:.2f}s"
    elif nanoseconds >= 1_000_000:
        return f"{nanoseconds / 1_000_000:.2f}ms"
    elif nanoseconds >= 1_000:
        return f"{nanoseconds / 1_000:.2f}μs"
    else:
        return f"{nanoseconds:.0f}ns"


def _format_bytes(bytes_value: float) -> str:
    """Convert bytes to human-readable size."""
    for unit in ["B", "KB", "MB", "GB"]:
        if bytes_value < 1024:
            return f"{bytes_value:.1f}{unit}"
        bytes_value /= 1024
    return f"{bytes_value:.1f}TB"


def _format_value(value: float, metric_type: str, metric_name: str) -> str:
    """Format a value based on metric type and name."""
    if metric_type == "timer" or "duration" in metric_name:
        return _format_duration(value)
    elif "bytes" in metric_name:
        return _format_bytes(value)
    else:
        return f"{value:,.0f}"
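
For reference, `display_metrics` expects each metrics entry to carry exactly the statistics it reads: `type`, `count`, `avg`, `min`, `max`, `p95`, `p99`, and `stddev`. A minimal sketch of calling it with hand-built values (the metric names and numbers below are illustrative, not from the package):

    from spiral.debug.metrics import display_metrics

    # Hypothetical metrics dict; the keys match what display_metrics reads.
    metrics = {
        "scan.read.duration": {
            "type": "timer", "count": 42,
            "avg": 1_500_000, "min": 800_000, "max": 9_000_000,  # nanoseconds
            "p95": 4_000_000, "p99": 8_500_000, "stddev": 700_000,
        },
        "scan.read.bytes": {
            "type": "counter", "count": 42,
            "avg": 262_144, "min": 4_096, "max": 1_048_576,  # bytes
            "p95": 786_432, "p99": 1_000_000, "stddev": 90_000,
        },
    }
    display_metrics(metrics)
    # "duration" metrics render via _format_duration (1_500_000ns -> "1.50ms"),
    # "bytes" metrics via _format_bytes (262_144 -> "256.0KB").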
spiral/debug/scan.py
ADDED
@@ -0,0 +1,266 @@
from datetime import datetime

from spiral.core.table import Scan
from spiral.core.table.manifests import FragmentFile, FragmentManifest
from spiral.core.table.spec import Key
from spiral.types_ import Timestamp


def show_scan(scan: Scan):
    """Displays a scan in a way that is useful for debugging."""
    table_ids = scan.table_ids()
    if len(table_ids) > 1:
        raise NotImplementedError("Multiple table scan is not supported.")
    table_id = table_ids[0]
    column_groups = scan.column_groups()

    splits = scan.splits()
    key_space_state = scan.key_space_state(table_id)

    # Collect all key bounds from all manifests. This makes sure all visualizations are aligned.
    key_points = set()
    key_space_manifest = key_space_state.manifest
    for i in range(len(key_space_manifest)):
        fragment_file = key_space_manifest[i]
        key_points.add(fragment_file.key_extent.min)
        key_points.add(fragment_file.key_extent.max)
    for cg in column_groups:
        cg_scan = scan.column_group_state(cg)
        cg_manifest = cg_scan.manifest
        for i in range(len(cg_manifest)):
            fragment_file = cg_manifest[i]
            key_points.add(fragment_file.key_extent.min)
            key_points.add(fragment_file.key_extent.max)

    # Make sure split points exist in the key points.
    for s in splits[:-1]:  # Don't take the last end.
        key_points.add(s.end)
    key_points = sorted(key_points)

    show_manifest(key_space_manifest, scope="Key space", key_points=key_points, splits=splits)
    for cg in column_groups:
        cg_scan = scan.column_group_state(cg)
        # Skip the table id at the start of the column group path.
        show_manifest(cg_scan.manifest, scope=".".join(cg.path[1:]), key_points=key_points, splits=splits)


def show_manifest(
    manifest: FragmentManifest,
    scope: str | None = None,
    key_points: list[Key] | None = None,
    splits: list | None = None,
):
    try:
        import matplotlib.patches as patches
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError("matplotlib is required for debug")

    total_fragments = len(manifest)

    size_points = set()
    for i in range(total_fragments):
        manifest_file: FragmentFile = manifest[i]
        size_points.add(manifest_file.size_bytes)
    size_points = sorted(size_points)

    if key_points is None:
        key_points = set()

        for i in range(total_fragments):
            manifest_file = manifest[i]
            key_points.add(manifest_file.key_extent.min)
            key_points.add(manifest_file.key_extent.max)

        if splits is not None:
            for split in splits[:-1]:
                key_points.add(split.end)

        key_points = sorted(key_points)

    # Create figure and axis with specified size.
    fig, ax = plt.subplots(figsize=(12, 8))

    # Plot each rectangle.
    for i in range(total_fragments):
        manifest_file = manifest[i]

        left = key_points.index(manifest_file.key_extent.min)
        right = key_points.index(manifest_file.key_extent.max)
        height = size_points.index(manifest_file.size_bytes) + 1

        color = _get_fragment_color(manifest_file, i, total_fragments)

        # Create rectangle patch.
        rect = patches.Rectangle(
            (left, 0),  # (x, y)
            right - left,  # width
            height,  # height
            facecolor=color,  # fill color
            edgecolor="black",  # border color
            alpha=0.5,  # transparency
            linewidth=1,  # border width
            label=manifest_file.id,  # label for legend
        )

        ax.add_patch(rect)

    # Set axis limits with some padding.
    ax.set_xlim(-0.5, len(key_points) - 1 + 0.5)
    ax.set_ylim(-0.5, len(size_points) + 0.5)

    # Create split markers on the x-axis.
    if splits is not None:
        split_positions = [key_points.index(split.end) for split in splits[:-1]]

        # Add split markers at the bottom.
        for pos in split_positions:
            ax.annotate("▲", xy=(pos, 0), ha="center", va="top", color="red", annotation_clip=False)

    # Add grid.
    ax.grid(True, linestyle="--", alpha=0.7, zorder=0)

    # Add labels and title.
    ax.set_title("Fragment Distribution" if scope is None else f"{scope} Fragment Distribution")
    ax.set_xlabel("Key Index")
    ax.set_ylabel("Size Index")

    # Add legend.
    ax.legend(bbox_to_anchor=(1, 1), loc="upper left", fontsize="small")

    # Adjust layout to prevent label cutoff.
    plt.tight_layout()

    plot = FragmentManifestPlot(fig, ax, manifest)
    fig.canvas.mpl_connect("motion_notify_event", plot.hover)

    plt.show()


def _get_fragment_color(manifest_file: FragmentFile, color_index, total_colors):
    import matplotlib.cm as cm

    if manifest_file.compacted_at is not None:
        # Use a shade of gray for compacted fragments, varying the shade
        # with the index to distinguish different compacted fragments.
        gray_value = 0.3 + (0.5 * (color_index / total_colors))
        return (gray_value, gray_value, gray_value)
    else:
        # Use the viridis colormap for non-compacted fragments.
        return cm.viridis(color_index / total_colors)


def _get_human_size(size_bytes: int) -> str:
    # Convert bytes to a human-readable format.
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if size_bytes < 1024:
            return f"{size_bytes:.2f} {unit}"
        size_bytes /= 1024
    return f"{size_bytes:.2f} PB"


def _maybe_truncate(text, max_length: int = 30) -> str:
    text = str(text)
    if len(text) <= max_length:
        return text

    half_length = (max_length - 3) // 2
    return text[:half_length] + "..." + text[-half_length:]


def _get_fragment_legend(manifest_file: FragmentFile):
    return "\n".join(
        [
            f"id: {manifest_file.id}",
            f"size: {_get_human_size(manifest_file.size_bytes)} ({manifest_file.size_bytes} bytes)",
            f"key_span: {manifest_file.key_span}",
            f"key_min: {_maybe_truncate(manifest_file.key_extent.min)}",
            f"key_max: {_maybe_truncate(manifest_file.key_extent.max)}",
            f"format: {manifest_file.format}",
            f"level: {manifest_file.level}",
            f"committed_at: {_format_timestamp(manifest_file.committed_at)}",
            f"compacted_at: {_format_timestamp(manifest_file.compacted_at)}",
            f"ks_id: {manifest_file.ks_id}",
        ]
    )


def _format_timestamp(ts: Timestamp | None) -> str:
    # Format a timestamp, or show None.
    if ts is None:
        return "None"
    try:
        return datetime.fromtimestamp(ts / 1e6).strftime("%Y-%m-%d %H:%M:%S")
    except ValueError:
        return str(ts)


class FragmentManifestPlot:
    def __init__(self, fig, ax, manifest: FragmentManifest):
        self.fig = fig
        self.ax = ax
        self.manifest = manifest

        # Position the annotation in the bottom right corner.
        self.annotation = ax.annotate(
            "",
            xy=(0.98, 0.02),  # position in axes coordinates
            xycoords="axes fraction",
            bbox=dict(boxstyle="round,pad=0.5", fc="white", ec="gray", alpha=0.8),
            ha="right",  # right-align text
            va="bottom",  # bottom-align text
            visible=False,
        )
        self.highlighted_rect = None
        self.highlighted_legend = None

    def hover(self, event):
        if event.inaxes != self.ax:
            # Check if we're hovering over the legend.
            legend = self.ax.get_legend()
            if legend and legend.contains(event)[0]:
                # Find which legend item we're hovering over.
                for i, legend_text in enumerate(legend.get_texts()):
                    if legend_text.contains(event)[0]:
                        manifest_file = self.manifest[i]
                        self._show_legend(manifest_file, i, legend_text)
                        return
            self._hide_legend()
            return

        # Check rectangles in the main plot.
        for i, rect in enumerate(self.ax.patches):
            if rect.contains(event)[0]:
                manifest_file = self.manifest[i]
                self._show_legend(manifest_file, i, rect)
                return

        self._hide_legend()

    def _show_legend(self, manifest_file, index, highlight_obj):
        import matplotlib.patches as patches

        # Update tooltip text.
        self.annotation.set_text(_get_fragment_legend(manifest_file))
        self.annotation.set_visible(True)

        # Handle highlighting.
        if isinstance(highlight_obj, patches.Rectangle):
            # Highlighting a rectangle in the main plot.
            if self.highlighted_rect and self.highlighted_rect != highlight_obj:
                self.highlighted_rect.set_alpha(0.5)
            highlight_obj.set_alpha(0.8)
            self.highlighted_rect = highlight_obj
        else:
            # Highlighting legend text.
            if self.highlighted_rect:
                self.highlighted_rect.set_alpha(0.5)
            # Find and highlight the corresponding rectangle.
            rect = self.ax.patches[index]
            rect.set_alpha(0.8)
            self.highlighted_rect = rect

        self.fig.canvas.draw_idle()

    def _hide_legend(self):
        if self.annotation.get_visible():
            self.annotation.set_visible(False)
            if self.highlighted_rect:
                self.highlighted_rect.set_alpha(0.5)
            self.fig.canvas.draw_idle()
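
A sketch of how these helpers might be driven, assuming a core `Scan` handle is available; the `.core` attribute used below is an assumption based on its use in spiral/enrichment.py, not a documented entry point:

    from spiral.debug.scan import show_manifest, show_scan

    core_scan = sp.scan(projection).core  # hypothetical: unwrap the core Scan
    show_scan(core_scan)  # one aligned fragment plot per column group, plus the key space

    # Or plot a single manifest directly; hover over rectangles or legend
    # entries for per-fragment details.
    ks_state = core_scan.key_space_state(core_scan.table_ids()[0])
    show_manifest(ks_state.manifest, scope="Key space")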
spiral/enrichment.py
ADDED
@@ -0,0 +1,306 @@
from __future__ import annotations

import dataclasses
import logging
from functools import partial
from typing import TYPE_CHECKING

from spiral.core.client import KeyColumns, Shard
from spiral.core.table import KeyRange
from spiral.core.table.spec import Key, Operation
from spiral.expressions import Expr

if TYPE_CHECKING:
    import dask.distributed

    from spiral import Scan, Table

logger = logging.getLogger(__name__)


class Enrichment:
    """
    An enrichment is used to derive new columns from existing ones, such as fetching data from object storage
    with `se.s3.get` or computing embeddings. With the column group design supporting hundreds of thousands of
    columns, horizontally expanding tables are a powerful primitive.

    NOTE: Spiral aims to optimize enrichments where the source and destination table are the same.
    """

    def __init__(
        self,
        table: Table,
        projection: Expr,
        where: Expr | None,
    ):
        self._table = table
        self._projection = projection
        self._where = where

    @property
    def table(self) -> Table:
        """The table to write back into."""
        return self._table

    @property
    def projection(self) -> Expr:
        """The projection expression."""
        return self._projection

    @property
    def where(self) -> Expr | None:
        """The filter expression."""
        return self._where

    def _scan(self) -> Scan:
        return self._table.spiral.scan(self._projection, where=self._where, _key_columns=KeyColumns.Included)

    def apply(
        self,
        *,
        batch_readahead: int | None = None,
        partition_size_bytes: int | None = None,
        txn_dump: str | None = None,
    ) -> None:
        """Apply the enrichment onto the table in a streaming fashion.

        For large tables, consider using `apply_dask` for distributed execution.

        Args:
            batch_readahead: Optional number of batches to read ahead.
            partition_size_bytes: The maximum partition size in bytes.
                If not provided, the default partition size is used.
            txn_dump: Optional path to dump the transaction JSON for debugging.
        """
        txn = self._table.txn()

        txn.writeback(
            self._scan(),
            partition_size_bytes=partition_size_bytes,
            batch_readahead=batch_readahead,
        )

        if txn.is_empty():
            logger.warning("Transaction not committed. No rows were read for enrichment.")
            return

        txn.commit(txn_dump=txn_dump)

    def apply_dask(
        self,
        *,
        partition_size_bytes: int | None = None,
        max_task_size: int | None = None,
        checkpoint_dump: str | None = None,
        shards: list[Shard] | None = None,
        txn_dump: str | None = None,
        client: dask.distributed.Client | None = None,
        **kwargs,
    ) -> None:
        """Use distributed Dask to apply the enrichment. Requires `dask[distributed]` to be installed.

        If the "address" of an existing Dask cluster is not provided in `kwargs`, a local cluster will be created.

        Dask execution has some limitations, e.g. UDFs are not currently supported. These limitations
        usually manifest as serialization errors when Dask workers attempt to serialize the state. If you
        encounter such issues, consider splitting the enrichment into a UDF-only derivation executed in a
        streaming fashion, followed by a Dask enrichment for the rest of the computation.
        If that is not possible, please reach out to support for assistance.

        How shards are determined:
        - If `shards` is provided, those will be used directly.
        - Else, if `checkpoint_dump` is provided, shards will be loaded from the checkpoint.
        - Else, if `max_task_size` is provided, shards will be created based on the task size.
        - Else, the scan's default sharding will be used.

        Args:
            partition_size_bytes: The maximum partition size in bytes.
                If not provided, the default partition size is used.
            max_task_size: Optional task size limit, in number of rows. Used for sharding.
                If provided and a checkpoint is present, the checkpoint shards will be used instead.
                If not provided, the scan's default sharding will be used.
            checkpoint_dump: Optional path to dump intermediate checkpoints for incremental progress.
            shards: Optional list of shards to process.
                If provided, `max_task_size` and `checkpoint_dump` are ignored.
            txn_dump: Optional path to dump the transaction JSON for debugging.
            client: Optional Dask distributed client. If not provided, a new client will be created.
            **kwargs: Additional keyword arguments to pass to `dask.distributed.Client`,
                such as `address` to connect to an existing cluster.
        """
        if client is None:
            try:
                from dask.distributed import Client
            except ImportError:
                raise ImportError("dask is not installed, please install dask[distributed] to use this feature.")

            # Connect before doing any work.
            client = Client(**kwargs)

        # Start a transaction BEFORE the planning scan.
        tx = self._table.txn()
        plan_scan = self._scan()

        # Determine the "tasks". Start from provided shards.
        task_shards = shards
        # If shards are not provided, try loading them from a checkpoint.
        if task_shards is None and checkpoint_dump is not None:
            checkpoint: list[KeyRange] | None = _checkpoint_load_key_ranges(checkpoint_dump)
            if checkpoint is None:
                logger.info(f"No existing checkpoint found at {checkpoint_dump}. Starting from scratch.")
            else:
                logger.info(f"Resuming enrichment from checkpoint at {checkpoint_dump} with {len(checkpoint)} ranges.")
                task_shards = [Shard(kr, None) for kr in checkpoint]
        # If there are still no shards, try creating them from the max task size.
        if task_shards is None and max_task_size is not None:
            task_shards = self._table.spiral.compute_shards(max_task_size, self.projection, self.where)
        # Fall back to the scan's default sharding.
        if task_shards is None:
            task_shards = plan_scan.shards()

        # Partially bind the enrichment function.
        _compute = partial(
            _enrichment_task,
            settings_json=self._table.spiral.config.to_json(),
            state_json=plan_scan.core.plan_state().to_json(),
            output_table_id=self._table.table_id,
            partition_size_bytes=partition_size_bytes,
            incremental=checkpoint_dump is not None,
        )
        enrichments = client.map(_compute, task_shards)

        logger.info(f"Applying enrichment with {len(task_shards)} shards. Follow progress at {client.dashboard_link}")

        failed_ranges = []
        try:
            for result, shard in zip(client.gather(enrichments), task_shards):
                result: EnrichmentTaskResult

                if result.error is not None:
                    logger.error(f"Enrichment task failed for range {shard.key_range}: {result.error}")
                    failed_ranges.append(shard.key_range)
                    continue

                tx.include(result.ops)
        except Exception as e:
            # If not incremental, re-raise the exception.
            if checkpoint_dump is None:
                raise e

            # Handle worker failures (e.g., KilledWorker from Dask).
            from dask.distributed import KilledWorker

            if isinstance(e, KilledWorker):
                logger.error(f"Dask worker was killed during enrichment: {e}")

            # Try to gather partial results and mark remaining tasks as failed.
            for future, shard in zip(enrichments, task_shards):
                if future.done() and not future.exception():
                    try:
                        result = future.result()

                        if result.error is not None:
                            logger.error(f"Enrichment task failed for range {shard.key_range}: {result.error}")
                            failed_ranges.append(shard.key_range)
                            continue

                        tx.include(result.ops)
                    except Exception:
                        # Task failed or is incomplete, add it to the failed ranges.
                        failed_ranges.append(shard.key_range)
                else:
                    # Task didn't complete, add it to the failed ranges.
                    failed_ranges.append(shard.key_range)

        # Dump a checkpoint of the failed ranges, if any.
        if checkpoint_dump is not None:
            logger.info(
                f"Dumping checkpoint with failed {len(failed_ranges)}/{len(task_shards)} ranges to {checkpoint_dump}."
            )
            _checkpoint_dump_key_ranges(checkpoint_dump, failed_ranges)

        if tx.is_empty():
            logger.warning("Transaction not committed. No rows were read for enrichment.")
            return

        # Always compact in distributed enrichment.
        tx.commit(compact=True, txn_dump=txn_dump)


def _checkpoint_load_key_ranges(checkpoint_dump: str) -> list[KeyRange] | None:
    import json
    import os

    if not os.path.exists(checkpoint_dump):
        return None

    with open(checkpoint_dump) as f:
        data = json.load(f)
        return [
            KeyRange(begin=Key(bytes.fromhex(r["begin"])), end=Key(bytes.fromhex(r["end"])))
            for r in data.get("key_ranges", [])
        ]


def _checkpoint_dump_key_ranges(checkpoint_dump: str, ranges: list[KeyRange]):
    import json
    import os

    # Guard against bare filenames, for which dirname() returns "".
    dirname = os.path.dirname(checkpoint_dump)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    with open(checkpoint_dump, "w") as f:
        json.dump(
            {"key_ranges": [{"begin": bytes(r.begin).hex(), "end": bytes(r.end).hex()} for r in ranges]},
            f,
        )


@dataclasses.dataclass
class EnrichmentTaskResult:
    ops: list[Operation]
    error: str | None = None

    def __getstate__(self):
        return {
            "ops": [op.to_json() for op in self.ops],
            "error": self.error,
        }

    def __setstate__(self, state):
        self.ops = [Operation.from_json(op_json) for op_json in state["ops"]]
        self.error = state["error"]


# NOTE(marko): This function must be picklable!
def _enrichment_task(
    shard: Shard,
    *,
    settings_json: str,
    state_json: str,
    output_table_id,
    partition_size_bytes: int | None,
    incremental: bool,
) -> EnrichmentTaskResult:
    # Returns operations that can be included in a transaction.
    from spiral import Scan, Spiral
    from spiral.core.table import ScanState
    from spiral.settings import ClientSettings

    settings = ClientSettings.from_json(settings_json)
    sp = Spiral(config=settings)
    state = ScanState.from_json(state_json)
    task_scan = Scan(sp, sp.core.load_scan(state))
    table = sp.table(output_table_id)
    task_tx = table.txn()

    try:
        task_tx.writeback(task_scan, key_range=shard.key_range, partition_size_bytes=partition_size_bytes)
        return EnrichmentTaskResult(ops=task_tx.take())
    except Exception as e:
        task_tx.abort()

        if incremental:
            return EnrichmentTaskResult(ops=[], error=str(e))

        logger.error(f"Enrichment task failed for shard {shard}: {e}")
        raise e
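
A sketch of the intended flow, assuming `table` is an existing `spiral.Table` handle; constructing `Enrichment` directly and the `se.s3.get(table["uri"])` projection are assumptions for illustration, not necessarily the public entry point:

    import spiral.expressions as se
    from spiral.enrichment import Enrichment

    # Hypothetical projection: fetch object bytes for each row's `uri` column.
    projection = se.s3.get(table["uri"])
    enrichment = Enrichment(table, projection, where=None)

    # Streaming execution in the current process:
    enrichment.apply(partition_size_bytes=128 * 1024 * 1024)

    # Distributed execution with incremental retries: failed key ranges are
    # dumped to the checkpoint file, and a re-run resumes from exactly those.
    enrichment.apply_dask(
        max_task_size=1_000_000,  # rows per shard
        checkpoint_dump="/tmp/enrichment/ckpt.json",
        address="tcp://scheduler:8786",  # forwarded to dask.distributed.Client
    )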