codebatch 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codebatch/__init__.py +3 -0
- codebatch/batch.py +366 -0
- codebatch/cas.py +170 -0
- codebatch/cli.py +432 -0
- codebatch/common.py +104 -0
- codebatch/paths.py +196 -0
- codebatch/query.py +242 -0
- codebatch/runner.py +495 -0
- codebatch/snapshot.py +340 -0
- codebatch/store.py +162 -0
- codebatch/tasks/__init__.py +37 -0
- codebatch/tasks/analyze.py +109 -0
- codebatch/tasks/lint.py +244 -0
- codebatch/tasks/parse.py +304 -0
- codebatch/tasks/symbols.py +223 -0
- codebatch-0.1.0.dist-info/METADATA +66 -0
- codebatch-0.1.0.dist-info/RECORD +19 -0
- codebatch-0.1.0.dist-info/WHEEL +4 -0
- codebatch-0.1.0.dist-info/entry_points.txt +2 -0
codebatch/runner.py
ADDED
@@ -0,0 +1,495 @@
"""Shard runner with state machine and atomic index commit.

Shard execution follows a monotonic state machine:
ready -> running -> done|failed

State transitions are atomic and never corrupt prior results.
"""

import json
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Callable, Iterable, Iterator, Optional, Union

from .batch import BatchManager
from .cas import ObjectStore
from .common import SCHEMA_VERSION, PRODUCER, utc_now_z, object_shard_prefix
from .snapshot import SnapshotBuilder


class _CountingIterator:
    """Iterator wrapper that counts items as they're yielded."""

    def __init__(self, iterable: Iterable):
        self._iterator = iter(iterable)
        self.count = 0

    def __iter__(self):
        return self

    def __next__(self):
        item = next(self._iterator)
        self.count += 1
        return item

class ShardRunner:
    """Runs individual shards with state management and atomic output commits."""

    def __init__(self, store_root: Path):
        """Initialize the shard runner.

        Args:
            store_root: Root directory of the CodeBatch store.
        """
        self.store_root = Path(store_root)
        self.batch_manager = BatchManager(store_root)
        self.snapshot_builder = SnapshotBuilder(store_root)
        self.object_store = ObjectStore(store_root)

    def _shard_dir(self, batch_id: str, task_id: str, shard_id: str) -> Path:
        """Get the shard directory path."""
        return self.store_root / "batches" / batch_id / "tasks" / task_id / "shards" / shard_id

    def _load_state(self, batch_id: str, task_id: str, shard_id: str) -> dict:
        """Load shard state."""
        state_path = self._shard_dir(batch_id, task_id, shard_id) / "state.json"
        with open(state_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _save_state(self, batch_id: str, task_id: str, shard_id: str, state: dict) -> None:
        """Save shard state atomically."""
        shard_dir = self._shard_dir(batch_id, task_id, shard_id)
        state_path = shard_dir / "state.json"

        # Atomic write via temp file with PID for race safety
        temp_path = state_path.with_suffix(f".tmp.{os.getpid()}")
        with open(temp_path, "w", encoding="utf-8") as f:
            json.dump(state, f, indent=2)
        # On Windows, rename fails if target exists; use replace instead
        temp_path.replace(state_path)

    def _append_event(
        self,
        events_path: Path,
        event: str,
        batch_id: str,
        task_id: Optional[str] = None,
        shard_id: Optional[str] = None,
        attempt: Optional[int] = None,
        duration_ms: Optional[int] = None,
        error: Optional[dict] = None,
        stats: Optional[dict] = None,
    ) -> None:
        """Append an event record to events.jsonl."""
        record = {
            "schema_version": SCHEMA_VERSION,
            "ts": utc_now_z(),
            "event": event,
            "batch_id": batch_id,
        }
        if task_id:
            record["task_id"] = task_id
        if shard_id:
            record["shard_id"] = shard_id
        if attempt is not None:
            record["attempt"] = attempt
        if duration_ms is not None:
            record["duration_ms"] = duration_ms
        if error:
            record["error"] = error
        if stats:
            record["stats"] = stats

        with open(events_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, separators=(",", ":")))
            f.write("\n")

    def _iter_shard_files(
        self, snapshot_id: str, shard_id: str
    ) -> Iterator[dict]:
        """Stream files assigned to a shard based on hash prefix.

        Files are assigned to shards based on the first two hex chars of their object hash.
        Uses streaming to avoid loading entire index into memory.

        Args:
            snapshot_id: Snapshot ID.
            shard_id: Shard ID (two hex chars, e.g., 'ab').

        Yields:
            File index records assigned to this shard.
        """
        for record in self.snapshot_builder.iter_file_index(snapshot_id):
            # Extract shard prefix from object ref (handles both sha256:<hex> and bare hex)
            obj_shard = object_shard_prefix(record["object"])
            if obj_shard == shard_id:
                yield record

    def _get_shard_files(
        self, snapshot_id: str, shard_id: str
    ) -> list[dict]:
        """Get files assigned to a shard based on hash prefix.

        Files are assigned to shards based on the first two hex chars of their object hash.

        Args:
            snapshot_id: Snapshot ID.
            shard_id: Shard ID (two hex chars, e.g., 'ab').

        Returns:
            List of file index records assigned to this shard.
        """
        return list(self._iter_shard_files(snapshot_id, shard_id))

    def run_shard(
        self,
        batch_id: str,
        task_id: str,
        shard_id: str,
        executor: Callable[[dict, Iterable[dict], "ShardRunner"], Iterable[dict]],
    ) -> dict:
        """Run a shard with the given executor.

        The executor receives:
        - task config dict
        - iterable of file records for this shard (may be iterator or list)
        - this runner instance (for CAS access)

        The executor should return an iterable of output records.
        Executors that need random access can call list() on the input.

        Args:
            batch_id: Batch ID.
            task_id: Task ID.
            shard_id: Shard ID.
            executor: Function that processes files and returns output records.

        Returns:
            Final shard state.
        """
        shard_dir = self._shard_dir(batch_id, task_id, shard_id)
        task_events_path = shard_dir.parent.parent / "events.jsonl"
        batch_events_path = shard_dir.parent.parent.parent.parent / "events.jsonl"

        # Load current state
        state = self._load_state(batch_id, task_id, shard_id)

        # Check if already done
        if state["status"] == "done":
            return state

        # Enforce dependency completion (Phase 2 requirement)
        deps_ok, incomplete = self.check_deps_complete(batch_id, task_id, shard_id)
        if not deps_ok:
            raise ValueError(
                f"Cannot run task '{task_id}' shard '{shard_id}': "
                f"dependencies not complete: {incomplete}"
            )

        # Increment attempt counter
        state["attempt"] = state.get("attempt", 0) + 1
        attempt = state["attempt"]

        # Transition to running
        state["status"] = "running"
        state["started_at"] = utc_now_z()
        self._save_state(batch_id, task_id, shard_id, state)

        # Log shard_started event to both task and batch
        for events_path in [task_events_path, batch_events_path]:
            self._append_event(
                events_path,
                "shard_started",
                batch_id,
                task_id=task_id,
                shard_id=shard_id,
                attempt=attempt,
            )

        start_time = datetime.now(timezone.utc)

        try:
            # Load task config
            task = self.batch_manager.load_task(batch_id, task_id)
            batch = self.batch_manager.load_batch(batch_id)
            snapshot_id = batch["snapshot_id"]

            # Get files for this shard as streaming iterator with counting
            # Executors that need random access should materialize with list()
            shard_files = _CountingIterator(self._iter_shard_files(snapshot_id, shard_id))

            # Enrich config with execution context for tasks that need it
            # (e.g., symbols task needs batch_id/shard_id for iter_prior_outputs)
            exec_config = dict(task["config"])
            exec_config["_batch_id"] = batch_id
            exec_config["_task_id"] = task_id
            exec_config["_shard_id"] = shard_id
            exec_config["_snapshot_id"] = snapshot_id

            # Execute - output_records may be iterator or list
            output_records = executor(exec_config, shard_files, self)

            # Write outputs atomically, counting as we go
            outputs_path = shard_dir / "outputs.index.jsonl"
            temp_outputs_path = outputs_path.with_suffix(f".tmp.{os.getpid()}")
            outputs_written = 0

            with open(temp_outputs_path, "w", encoding="utf-8") as f:
                for record in output_records:
                    # Ensure required fields
                    record.setdefault("schema_version", SCHEMA_VERSION)
                    record.setdefault("snapshot_id", snapshot_id)
                    record.setdefault("batch_id", batch_id)
                    record.setdefault("task_id", task_id)
                    record.setdefault("shard_id", shard_id)
                    record.setdefault("ts", utc_now_z())
                    f.write(json.dumps(record, separators=(",", ":")))
                    f.write("\n")
                    outputs_written += 1

            # Atomic rename (use replace for Windows compatibility)
            temp_outputs_path.replace(outputs_path)

            # Calculate duration
            end_time = datetime.now(timezone.utc)
            duration_ms = int((end_time - start_time).total_seconds() * 1000)

            # Transition to done
            state["status"] = "done"
            state["completed_at"] = utc_now_z()
            state["stats"] = {
                "files_processed": shard_files.count,
                "outputs_written": outputs_written,
            }
            self._save_state(batch_id, task_id, shard_id, state)

            # Log shard_completed event to both task and batch
            for events_path in [task_events_path, batch_events_path]:
                self._append_event(
                    events_path,
                    "shard_completed",
                    batch_id,
                    task_id=task_id,
                    shard_id=shard_id,
                    attempt=attempt,
                    duration_ms=duration_ms,
                    stats=state["stats"],
                )

        except Exception as e:
            # Calculate duration
            end_time = datetime.now(timezone.utc)
            duration_ms = int((end_time - start_time).total_seconds() * 1000)

            # Transition to failed
            error_info = {
                "code": type(e).__name__,
                "message": str(e),
            }
            state["status"] = "failed"
            state["completed_at"] = utc_now_z()
            state["error"] = error_info
            self._save_state(batch_id, task_id, shard_id, state)

            # Log shard_failed event to both task and batch
            for events_path in [task_events_path, batch_events_path]:
                self._append_event(
                    events_path,
                    "shard_failed",
                    batch_id,
                    task_id=task_id,
                    shard_id=shard_id,
                    attempt=attempt,
                    duration_ms=duration_ms,
                    error=error_info,
                )

        return state

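    # Hedged editorial sketch (not part of the released file): an executor that
    # satisfies the contract described in run_shard's docstring would take the
    # enriched config, the shard's file records, and the runner, and yield output
    # record dicts. The name and the "kind"/"path" fields below are illustrative
    # assumptions; only the "object" field is guaranteed by the file index.
    #
    #     def count_files_executor(config, files, runner):
    #         for rec in files:
    #             yield {"kind": "file_seen", "object": rec["object"], "path": rec.get("path")}
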
    def reset_shard(self, batch_id: str, task_id: str, shard_id: str) -> dict:
        """Reset a shard to ready state for retry.

        Only failed shards can be reset.

        Args:
            batch_id: Batch ID.
            task_id: Task ID.
            shard_id: Shard ID.

        Returns:
            New shard state.

        Raises:
            ValueError: If shard is not in failed state.
        """
        state = self._load_state(batch_id, task_id, shard_id)

        if state["status"] != "failed":
            raise ValueError(f"Can only reset failed shards, current status: {state['status']}")

        # Keep attempt counter for tracking
        attempt = state.get("attempt", 0)

        # Reset to ready
        new_state = {
            "schema_name": "codebatch.shard_state",
            "schema_version": SCHEMA_VERSION,
            "producer": PRODUCER,
            "shard_id": shard_id,
            "task_id": task_id,
            "batch_id": batch_id,
            "status": "ready",
            "attempt": attempt,  # Preserve attempt count
        }
        self._save_state(batch_id, task_id, shard_id, new_state)

        # Log retry event
        task_events_path = self._shard_dir(batch_id, task_id, shard_id).parent.parent / "events.jsonl"
        self._append_event(
            task_events_path,
            "shard_retrying",
            batch_id,
            task_id=task_id,
            shard_id=shard_id,
            attempt=attempt + 1,
        )

        return new_state

    def get_shard_outputs(self, batch_id: str, task_id: str, shard_id: str) -> list[dict]:
        """Get output records for a shard.

        Args:
            batch_id: Batch ID.
            task_id: Task ID.
            shard_id: Shard ID.

        Returns:
            List of output records.
        """
        outputs_path = self._shard_dir(batch_id, task_id, shard_id) / "outputs.index.jsonl"
        if not outputs_path.exists():
            return []

        records = []
        with open(outputs_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    records.append(json.loads(line))
        return records

    def iter_prior_outputs(
        self,
        batch_id: str,
        task_id: str,
        shard_id: str,
        kind: Optional[str] = None,
    ) -> Iterator[dict]:
        """Stream output records from a prior task in the same shard.

        This is the approved mechanism for tasks to consume outputs from
        their dependencies. Tasks may only read from their own shard.

        Args:
            batch_id: Batch ID.
            task_id: Task ID of the dependency (e.g., "01_parse").
            shard_id: Shard ID (must match current shard).
            kind: Optional filter by output kind (e.g., "ast", "diagnostic").

        Yields:
            Output records from the prior task, optionally filtered by kind.

        Raises:
            FileNotFoundError: If the dependency task shard doesn't exist.
        """
        outputs_path = self._shard_dir(batch_id, task_id, shard_id) / "outputs.index.jsonl"
        if not outputs_path.exists():
            return

        with open(outputs_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                record = json.loads(line)
                if kind is None or record.get("kind") == kind:
                    yield record

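    # Hedged editorial sketch (not part of the released file): a downstream
    # executor could use the context keys that run_shard injects into its config
    # to stream a dependency's records in the same shard. The task id "01_parse"
    # and the kinds shown are taken from the docstring examples above; the
    # "summary" kind is an assumption.
    #
    #     def summarize_executor(config, files, runner):
    #         asts = runner.iter_prior_outputs(
    #             config["_batch_id"], "01_parse", config["_shard_id"], kind="ast"
    #         )
    #         for rec in asts:
    #             yield {"kind": "summary", "source_task": rec["task_id"]}
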
    def write_shard_outputs(
        self,
        batch_id: str,
        task_id: str,
        shard_id: str,
        records: Iterable[dict],
        snapshot_id: str,
    ) -> int:
        """Write shard outputs atomically with per-shard replacement.

        This is the ONLY approved mechanism for writing outputs.index.jsonl.
        Tasks should use this helper rather than writing directly.

        The write is atomic: temp file -> replace. This enforces the
        per-shard replacement policy (no appending).

        Args:
            batch_id: Batch ID.
            task_id: Task ID.
            shard_id: Shard ID.
            records: Iterable of output records.
            snapshot_id: Snapshot ID for record enrichment.

        Returns:
            Number of records written.
        """
        shard_dir = self._shard_dir(batch_id, task_id, shard_id)
        outputs_path = shard_dir / "outputs.index.jsonl"
        temp_outputs_path = outputs_path.with_suffix(f".tmp.{os.getpid()}")

        outputs_written = 0
        with open(temp_outputs_path, "w", encoding="utf-8") as f:
            for record in records:
                # Ensure required fields
                record.setdefault("schema_version", SCHEMA_VERSION)
                record.setdefault("snapshot_id", snapshot_id)
                record.setdefault("batch_id", batch_id)
                record.setdefault("task_id", task_id)
                record.setdefault("shard_id", shard_id)
                record.setdefault("ts", utc_now_z())
                f.write(json.dumps(record, separators=(",", ":")))
                f.write("\n")
                outputs_written += 1

        # Atomic replace (not append!)
        temp_outputs_path.replace(outputs_path)
        return outputs_written

    def check_deps_complete(self, batch_id: str, task_id: str, shard_id: str) -> tuple[bool, list[str]]:
        """Check if all dependencies for a task are complete in this shard.

        Args:
            batch_id: Batch ID.
            task_id: Task ID to check.
            shard_id: Shard ID.

        Returns:
            Tuple of (all_complete, incomplete_task_ids).
        """
        task = self.batch_manager.load_task(batch_id, task_id)
        deps = task.get("inputs", {}).get("tasks", [])

        if not deps:
            return True, []

        incomplete = []
        for dep_task_id in deps:
            try:
                state = self._load_state(batch_id, dep_task_id, shard_id)
                if state.get("status") != "done":
                    incomplete.append(dep_task_id)
            except FileNotFoundError:
                incomplete.append(dep_task_id)

        return len(incomplete) == 0, incomplete
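Taken together, driving the runner might look like the minimal sketch below. This is a hedged illustration, not documentation shipped with the package: the store path, batch/task/shard IDs, and the executor are hypothetical, and a real batch, task, and ready-state shard must already exist on disk. Only ShardRunner, run_shard, and reset_shard come from the file above.

from pathlib import Path
from codebatch.runner import ShardRunner

runner = ShardRunner(Path("/tmp/codebatch-store"))  # hypothetical store root

def echo_executor(config, files, runner):
    # One output record per file in the shard; "kind" and "path" are illustrative.
    for rec in files:
        yield {"kind": "echo", "object": rec["object"], "path": rec.get("path")}

state = runner.run_shard("batch-001", "00_echo", "ab", echo_executor)
if state["status"] == "failed":
    # Failed shards can be returned to "ready" and re-run.
    runner.reset_shard("batch-001", "00_echo", "ab")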