codebatch 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
codebatch/runner.py ADDED
@@ -0,0 +1,495 @@
1
+ """Shard runner with state machine and atomic index commit.
2
+
3
+ Shard execution follows a monotonic state machine:
4
+ ready -> running -> done|failed
5
+
6
+ State transitions are atomic and never corrupt prior results.
7
+ """
8
+
9
+ import json
10
+ import os
11
+ from datetime import datetime, timezone
12
+ from pathlib import Path
13
+ from typing import Callable, Iterable, Iterator, Optional, Union
14
+
15
+ from .batch import BatchManager
16
+ from .cas import ObjectStore
17
+ from .common import SCHEMA_VERSION, PRODUCER, utc_now_z, object_shard_prefix
18
+ from .snapshot import SnapshotBuilder
19
+
20
+
21
+ class _CountingIterator:
22
+ """Iterator wrapper that counts items as they're yielded."""
23
+
24
+ def __init__(self, iterable: Iterable):
25
+ self._iterator = iter(iterable)
26
+ self.count = 0
27
+
28
+ def __iter__(self):
29
+ return self
30
+
31
+ def __next__(self):
32
+ item = next(self._iterator)
33
+ self.count += 1
34
+ return item
35
+
36
+
37
class ShardRunner:
    """Executes shards: drives their state machine and commits outputs atomically."""

    def __init__(self, store_root: Path):
        """Create a runner rooted at *store_root*.

        Args:
            store_root: Root directory of the CodeBatch store.
        """
        self.store_root = Path(store_root)
        # Collaborators all share the same store root.
        self.batch_manager = BatchManager(store_root)
        self.snapshot_builder = SnapshotBuilder(store_root)
        self.object_store = ObjectStore(store_root)
50
+
51
+ def _shard_dir(self, batch_id: str, task_id: str, shard_id: str) -> Path:
52
+ """Get the shard directory path."""
53
+ return self.store_root / "batches" / batch_id / "tasks" / task_id / "shards" / shard_id
54
+
55
+ def _load_state(self, batch_id: str, task_id: str, shard_id: str) -> dict:
56
+ """Load shard state."""
57
+ state_path = self._shard_dir(batch_id, task_id, shard_id) / "state.json"
58
+ with open(state_path, "r", encoding="utf-8") as f:
59
+ return json.load(f)
60
+
61
+ def _save_state(self, batch_id: str, task_id: str, shard_id: str, state: dict) -> None:
62
+ """Save shard state atomically."""
63
+ shard_dir = self._shard_dir(batch_id, task_id, shard_id)
64
+ state_path = shard_dir / "state.json"
65
+
66
+ # Atomic write via temp file with PID for race safety
67
+ temp_path = state_path.with_suffix(f".tmp.{os.getpid()}")
68
+ with open(temp_path, "w", encoding="utf-8") as f:
69
+ json.dump(state, f, indent=2)
70
+ # On Windows, rename fails if target exists; use replace instead
71
+ temp_path.replace(state_path)
72
+
73
+ def _append_event(
74
+ self,
75
+ events_path: Path,
76
+ event: str,
77
+ batch_id: str,
78
+ task_id: str = None,
79
+ shard_id: str = None,
80
+ attempt: int = None,
81
+ duration_ms: int = None,
82
+ error: dict = None,
83
+ stats: dict = None,
84
+ ) -> None:
85
+ """Append an event record to events.jsonl."""
86
+ record = {
87
+ "schema_version": SCHEMA_VERSION,
88
+ "ts": utc_now_z(),
89
+ "event": event,
90
+ "batch_id": batch_id,
91
+ }
92
+ if task_id:
93
+ record["task_id"] = task_id
94
+ if shard_id:
95
+ record["shard_id"] = shard_id
96
+ if attempt is not None:
97
+ record["attempt"] = attempt
98
+ if duration_ms is not None:
99
+ record["duration_ms"] = duration_ms
100
+ if error:
101
+ record["error"] = error
102
+ if stats:
103
+ record["stats"] = stats
104
+
105
+ with open(events_path, "a", encoding="utf-8") as f:
106
+ f.write(json.dumps(record, separators=(",", ":")))
107
+ f.write("\n")
108
+
109
+ def _iter_shard_files(
110
+ self, snapshot_id: str, shard_id: str
111
+ ) -> Iterator[dict]:
112
+ """Stream files assigned to a shard based on hash prefix.
113
+
114
+ Files are assigned to shards based on the first two hex chars of their object hash.
115
+ Uses streaming to avoid loading entire index into memory.
116
+
117
+ Args:
118
+ snapshot_id: Snapshot ID.
119
+ shard_id: Shard ID (two hex chars, e.g., 'ab').
120
+
121
+ Yields:
122
+ File index records assigned to this shard.
123
+ """
124
+ for record in self.snapshot_builder.iter_file_index(snapshot_id):
125
+ # Extract shard prefix from object ref (handles both sha256:<hex> and bare hex)
126
+ obj_shard = object_shard_prefix(record["object"])
127
+ if obj_shard == shard_id:
128
+ yield record
129
+
130
+ def _get_shard_files(
131
+ self, snapshot_id: str, shard_id: str
132
+ ) -> list[dict]:
133
+ """Get files assigned to a shard based on hash prefix.
134
+
135
+ Files are assigned to shards based on the first two hex chars of their object hash.
136
+
137
+ Args:
138
+ snapshot_id: Snapshot ID.
139
+ shard_id: Shard ID (two hex chars, e.g., 'ab').
140
+
141
+ Returns:
142
+ List of file index records assigned to this shard.
143
+ """
144
+ return list(self._iter_shard_files(snapshot_id, shard_id))
145
+
146
+ def run_shard(
147
+ self,
148
+ batch_id: str,
149
+ task_id: str,
150
+ shard_id: str,
151
+ executor: Callable[[dict, Iterable[dict], "ShardRunner"], Iterable[dict]],
152
+ ) -> dict:
153
+ """Run a shard with the given executor.
154
+
155
+ The executor receives:
156
+ - task config dict
157
+ - iterable of file records for this shard (may be iterator or list)
158
+ - this runner instance (for CAS access)
159
+
160
+ The executor should return an iterable of output records.
161
+ Executors that need random access can call list() on the input.
162
+
163
+ Args:
164
+ batch_id: Batch ID.
165
+ task_id: Task ID.
166
+ shard_id: Shard ID.
167
+ executor: Function that processes files and returns output records.
168
+
169
+ Returns:
170
+ Final shard state.
171
+ """
172
+ shard_dir = self._shard_dir(batch_id, task_id, shard_id)
173
+ task_events_path = shard_dir.parent.parent / "events.jsonl"
174
+ batch_events_path = shard_dir.parent.parent.parent.parent / "events.jsonl"
175
+
176
+ # Load current state
177
+ state = self._load_state(batch_id, task_id, shard_id)
178
+
179
+ # Check if already done
180
+ if state["status"] == "done":
181
+ return state
182
+
183
+ # Enforce dependency completion (Phase 2 requirement)
184
+ deps_ok, incomplete = self.check_deps_complete(batch_id, task_id, shard_id)
185
+ if not deps_ok:
186
+ raise ValueError(
187
+ f"Cannot run task '{task_id}' shard '{shard_id}': "
188
+ f"dependencies not complete: {incomplete}"
189
+ )
190
+
191
+ # Increment attempt counter
192
+ state["attempt"] = state.get("attempt", 0) + 1
193
+ attempt = state["attempt"]
194
+
195
+ # Transition to running
196
+ state["status"] = "running"
197
+ state["started_at"] = utc_now_z()
198
+ self._save_state(batch_id, task_id, shard_id, state)
199
+
200
+ # Log shard_started event to both task and batch
201
+ for events_path in [task_events_path, batch_events_path]:
202
+ self._append_event(
203
+ events_path,
204
+ "shard_started",
205
+ batch_id,
206
+ task_id=task_id,
207
+ shard_id=shard_id,
208
+ attempt=attempt,
209
+ )
210
+
211
+ start_time = datetime.now(timezone.utc)
212
+
213
+ try:
214
+ # Load task config
215
+ task = self.batch_manager.load_task(batch_id, task_id)
216
+ batch = self.batch_manager.load_batch(batch_id)
217
+ snapshot_id = batch["snapshot_id"]
218
+
219
+ # Get files for this shard as streaming iterator with counting
220
+ # Executors that need random access should materialize with list()
221
+ shard_files = _CountingIterator(self._iter_shard_files(snapshot_id, shard_id))
222
+
223
+ # Enrich config with execution context for tasks that need it
224
+ # (e.g., symbols task needs batch_id/shard_id for iter_prior_outputs)
225
+ exec_config = dict(task["config"])
226
+ exec_config["_batch_id"] = batch_id
227
+ exec_config["_task_id"] = task_id
228
+ exec_config["_shard_id"] = shard_id
229
+ exec_config["_snapshot_id"] = snapshot_id
230
+
231
+ # Execute - output_records may be iterator or list
232
+ output_records = executor(exec_config, shard_files, self)
233
+
234
+ # Write outputs atomically, counting as we go
235
+ outputs_path = shard_dir / "outputs.index.jsonl"
236
+ temp_outputs_path = outputs_path.with_suffix(f".tmp.{os.getpid()}")
237
+ outputs_written = 0
238
+
239
+ with open(temp_outputs_path, "w", encoding="utf-8") as f:
240
+ for record in output_records:
241
+ # Ensure required fields
242
+ record.setdefault("schema_version", SCHEMA_VERSION)
243
+ record.setdefault("snapshot_id", snapshot_id)
244
+ record.setdefault("batch_id", batch_id)
245
+ record.setdefault("task_id", task_id)
246
+ record.setdefault("shard_id", shard_id)
247
+ record.setdefault("ts", utc_now_z())
248
+ f.write(json.dumps(record, separators=(",", ":")))
249
+ f.write("\n")
250
+ outputs_written += 1
251
+
252
+ # Atomic rename (use replace for Windows compatibility)
253
+ temp_outputs_path.replace(outputs_path)
254
+
255
+ # Calculate duration
256
+ end_time = datetime.now(timezone.utc)
257
+ duration_ms = int((end_time - start_time).total_seconds() * 1000)
258
+
259
+ # Transition to done
260
+ state["status"] = "done"
261
+ state["completed_at"] = utc_now_z()
262
+ state["stats"] = {
263
+ "files_processed": shard_files.count,
264
+ "outputs_written": outputs_written,
265
+ }
266
+ self._save_state(batch_id, task_id, shard_id, state)
267
+
268
+ # Log shard_completed event to both task and batch
269
+ for events_path in [task_events_path, batch_events_path]:
270
+ self._append_event(
271
+ events_path,
272
+ "shard_completed",
273
+ batch_id,
274
+ task_id=task_id,
275
+ shard_id=shard_id,
276
+ attempt=attempt,
277
+ duration_ms=duration_ms,
278
+ stats=state["stats"],
279
+ )
280
+
281
+ except Exception as e:
282
+ # Calculate duration
283
+ end_time = datetime.now(timezone.utc)
284
+ duration_ms = int((end_time - start_time).total_seconds() * 1000)
285
+
286
+ # Transition to failed
287
+ error_info = {
288
+ "code": type(e).__name__,
289
+ "message": str(e),
290
+ }
291
+ state["status"] = "failed"
292
+ state["completed_at"] = utc_now_z()
293
+ state["error"] = error_info
294
+ self._save_state(batch_id, task_id, shard_id, state)
295
+
296
+ # Log shard_failed event to both task and batch
297
+ for events_path in [task_events_path, batch_events_path]:
298
+ self._append_event(
299
+ events_path,
300
+ "shard_failed",
301
+ batch_id,
302
+ task_id=task_id,
303
+ shard_id=shard_id,
304
+ attempt=attempt,
305
+ duration_ms=duration_ms,
306
+ error=error_info,
307
+ )
308
+
309
+ return state
310
+
311
+ def reset_shard(self, batch_id: str, task_id: str, shard_id: str) -> dict:
312
+ """Reset a shard to ready state for retry.
313
+
314
+ Only failed shards can be reset.
315
+
316
+ Args:
317
+ batch_id: Batch ID.
318
+ task_id: Task ID.
319
+ shard_id: Shard ID.
320
+
321
+ Returns:
322
+ New shard state.
323
+
324
+ Raises:
325
+ ValueError: If shard is not in failed state.
326
+ """
327
+ state = self._load_state(batch_id, task_id, shard_id)
328
+
329
+ if state["status"] != "failed":
330
+ raise ValueError(f"Can only reset failed shards, current status: {state['status']}")
331
+
332
+ # Keep attempt counter for tracking
333
+ attempt = state.get("attempt", 0)
334
+
335
+ # Reset to ready
336
+ new_state = {
337
+ "schema_name": "codebatch.shard_state",
338
+ "schema_version": SCHEMA_VERSION,
339
+ "producer": PRODUCER,
340
+ "shard_id": shard_id,
341
+ "task_id": task_id,
342
+ "batch_id": batch_id,
343
+ "status": "ready",
344
+ "attempt": attempt, # Preserve attempt count
345
+ }
346
+ self._save_state(batch_id, task_id, shard_id, new_state)
347
+
348
+ # Log retry event
349
+ task_events_path = self._shard_dir(batch_id, task_id, shard_id).parent.parent / "events.jsonl"
350
+ self._append_event(
351
+ task_events_path,
352
+ "shard_retrying",
353
+ batch_id,
354
+ task_id=task_id,
355
+ shard_id=shard_id,
356
+ attempt=attempt + 1,
357
+ )
358
+
359
+ return new_state
360
+
361
+ def get_shard_outputs(self, batch_id: str, task_id: str, shard_id: str) -> list[dict]:
362
+ """Get output records for a shard.
363
+
364
+ Args:
365
+ batch_id: Batch ID.
366
+ task_id: Task ID.
367
+ shard_id: Shard ID.
368
+
369
+ Returns:
370
+ List of output records.
371
+ """
372
+ outputs_path = self._shard_dir(batch_id, task_id, shard_id) / "outputs.index.jsonl"
373
+ if not outputs_path.exists():
374
+ return []
375
+
376
+ records = []
377
+ with open(outputs_path, "r", encoding="utf-8") as f:
378
+ for line in f:
379
+ line = line.strip()
380
+ if line:
381
+ records.append(json.loads(line))
382
+ return records
383
+
384
+ def iter_prior_outputs(
385
+ self,
386
+ batch_id: str,
387
+ task_id: str,
388
+ shard_id: str,
389
+ kind: Optional[str] = None,
390
+ ) -> Iterator[dict]:
391
+ """Stream output records from a prior task in the same shard.
392
+
393
+ This is the approved mechanism for tasks to consume outputs from
394
+ their dependencies. Tasks may only read from their own shard.
395
+
396
+ Args:
397
+ batch_id: Batch ID.
398
+ task_id: Task ID of the dependency (e.g., "01_parse").
399
+ shard_id: Shard ID (must match current shard).
400
+ kind: Optional filter by output kind (e.g., "ast", "diagnostic").
401
+
402
+ Yields:
403
+ Output records from the prior task, optionally filtered by kind.
404
+
405
+ Raises:
406
+ FileNotFoundError: If the dependency task shard doesn't exist.
407
+ """
408
+ outputs_path = self._shard_dir(batch_id, task_id, shard_id) / "outputs.index.jsonl"
409
+ if not outputs_path.exists():
410
+ return
411
+
412
+ with open(outputs_path, "r", encoding="utf-8") as f:
413
+ for line in f:
414
+ line = line.strip()
415
+ if not line:
416
+ continue
417
+ record = json.loads(line)
418
+ if kind is None or record.get("kind") == kind:
419
+ yield record
420
+
421
+ def write_shard_outputs(
422
+ self,
423
+ batch_id: str,
424
+ task_id: str,
425
+ shard_id: str,
426
+ records: Iterable[dict],
427
+ snapshot_id: str,
428
+ ) -> int:
429
+ """Write shard outputs atomically with per-shard replacement.
430
+
431
+ This is the ONLY approved mechanism for writing outputs.index.jsonl.
432
+ Tasks should use this helper rather than writing directly.
433
+
434
+ The write is atomic: temp file -> replace. This enforces the
435
+ per-shard replacement policy (no appending).
436
+
437
+ Args:
438
+ batch_id: Batch ID.
439
+ task_id: Task ID.
440
+ shard_id: Shard ID.
441
+ records: Iterable of output records.
442
+ snapshot_id: Snapshot ID for record enrichment.
443
+
444
+ Returns:
445
+ Number of records written.
446
+ """
447
+ shard_dir = self._shard_dir(batch_id, task_id, shard_id)
448
+ outputs_path = shard_dir / "outputs.index.jsonl"
449
+ temp_outputs_path = outputs_path.with_suffix(f".tmp.{os.getpid()}")
450
+
451
+ outputs_written = 0
452
+ with open(temp_outputs_path, "w", encoding="utf-8") as f:
453
+ for record in records:
454
+ # Ensure required fields
455
+ record.setdefault("schema_version", SCHEMA_VERSION)
456
+ record.setdefault("snapshot_id", snapshot_id)
457
+ record.setdefault("batch_id", batch_id)
458
+ record.setdefault("task_id", task_id)
459
+ record.setdefault("shard_id", shard_id)
460
+ record.setdefault("ts", utc_now_z())
461
+ f.write(json.dumps(record, separators=(",", ":")))
462
+ f.write("\n")
463
+ outputs_written += 1
464
+
465
+ # Atomic replace (not append!)
466
+ temp_outputs_path.replace(outputs_path)
467
+ return outputs_written
468
+
469
+ def check_deps_complete(self, batch_id: str, task_id: str, shard_id: str) -> tuple[bool, list[str]]:
470
+ """Check if all dependencies for a task are complete in this shard.
471
+
472
+ Args:
473
+ batch_id: Batch ID.
474
+ task_id: Task ID to check.
475
+ shard_id: Shard ID.
476
+
477
+ Returns:
478
+ Tuple of (all_complete, incomplete_task_ids).
479
+ """
480
+ task = self.batch_manager.load_task(batch_id, task_id)
481
+ deps = task.get("inputs", {}).get("tasks", [])
482
+
483
+ if not deps:
484
+ return True, []
485
+
486
+ incomplete = []
487
+ for dep_task_id in deps:
488
+ try:
489
+ state = self._load_state(batch_id, dep_task_id, shard_id)
490
+ if state.get("status") != "done":
491
+ incomplete.append(dep_task_id)
492
+ except FileNotFoundError:
493
+ incomplete.append(dep_task_id)
494
+
495
+ return len(incomplete) == 0, incomplete