deepresearch-flow 0.5.1__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. deepresearch_flow/paper/cli.py +63 -0
  2. deepresearch_flow/paper/config.py +87 -12
  3. deepresearch_flow/paper/db.py +1041 -34
  4. deepresearch_flow/paper/db_ops.py +124 -19
  5. deepresearch_flow/paper/extract.py +1546 -152
  6. deepresearch_flow/paper/prompt_templates/deep_read_phi_system.j2 +2 -0
  7. deepresearch_flow/paper/prompt_templates/deep_read_phi_user.j2 +5 -0
  8. deepresearch_flow/paper/prompt_templates/deep_read_system.j2 +2 -0
  9. deepresearch_flow/paper/prompt_templates/deep_read_user.j2 +272 -40
  10. deepresearch_flow/paper/prompt_templates/eight_questions_phi_system.j2 +1 -0
  11. deepresearch_flow/paper/prompt_templates/eight_questions_phi_user.j2 +2 -0
  12. deepresearch_flow/paper/prompt_templates/eight_questions_system.j2 +2 -0
  13. deepresearch_flow/paper/prompt_templates/eight_questions_user.j2 +4 -0
  14. deepresearch_flow/paper/prompt_templates/simple_phi_system.j2 +2 -0
  15. deepresearch_flow/paper/prompt_templates/simple_system.j2 +2 -0
  16. deepresearch_flow/paper/prompt_templates/simple_user.j2 +2 -0
  17. deepresearch_flow/paper/providers/azure_openai.py +45 -3
  18. deepresearch_flow/paper/providers/openai_compatible.py +45 -3
  19. deepresearch_flow/paper/schemas/deep_read_phi_schema.json +1 -0
  20. deepresearch_flow/paper/schemas/deep_read_schema.json +1 -0
  21. deepresearch_flow/paper/schemas/default_paper_schema.json +6 -0
  22. deepresearch_flow/paper/schemas/eight_questions_schema.json +1 -0
  23. deepresearch_flow/paper/snapshot/__init__.py +4 -0
  24. deepresearch_flow/paper/snapshot/api.py +941 -0
  25. deepresearch_flow/paper/snapshot/builder.py +965 -0
  26. deepresearch_flow/paper/snapshot/identity.py +239 -0
  27. deepresearch_flow/paper/snapshot/schema.py +245 -0
  28. deepresearch_flow/paper/snapshot/tests/__init__.py +2 -0
  29. deepresearch_flow/paper/snapshot/tests/test_identity.py +123 -0
  30. deepresearch_flow/paper/snapshot/text.py +154 -0
  31. deepresearch_flow/paper/template_registry.py +1 -0
  32. deepresearch_flow/paper/templates/deep_read.md.j2 +4 -0
  33. deepresearch_flow/paper/templates/deep_read_phi.md.j2 +4 -0
  34. deepresearch_flow/paper/templates/default_paper.md.j2 +4 -0
  35. deepresearch_flow/paper/templates/eight_questions.md.j2 +4 -0
  36. deepresearch_flow/paper/web/app.py +10 -3
  37. deepresearch_flow/recognize/cli.py +380 -103
  38. deepresearch_flow/recognize/markdown.py +31 -7
  39. deepresearch_flow/recognize/math.py +47 -12
  40. deepresearch_flow/recognize/mermaid.py +320 -10
  41. deepresearch_flow/recognize/organize.py +29 -7
  42. deepresearch_flow/translator/cli.py +71 -20
  43. deepresearch_flow/translator/engine.py +220 -81
  44. deepresearch_flow/translator/prompts.py +19 -2
  45. deepresearch_flow/translator/protector.py +15 -3
  46. {deepresearch_flow-0.5.1.dist-info → deepresearch_flow-0.6.0.dist-info}/METADATA +407 -33
  47. {deepresearch_flow-0.5.1.dist-info → deepresearch_flow-0.6.0.dist-info}/RECORD +51 -43
  48. {deepresearch_flow-0.5.1.dist-info → deepresearch_flow-0.6.0.dist-info}/WHEEL +1 -1
  49. {deepresearch_flow-0.5.1.dist-info → deepresearch_flow-0.6.0.dist-info}/entry_points.txt +0 -0
  50. {deepresearch_flow-0.5.1.dist-info → deepresearch_flow-0.6.0.dist-info}/licenses/LICENSE +0 -0
  51. {deepresearch_flow-0.5.1.dist-info → deepresearch_flow-0.6.0.dist-info}/top_level.txt +0 -0
@@ -3,13 +3,19 @@
  from __future__ import annotations

  import asyncio
+ import hashlib
  import json
+ import math
+ from collections import deque
  from dataclasses import dataclass, field
- from datetime import datetime
+ from datetime import datetime, timezone
  from pathlib import Path
  from typing import Any, Iterable
+ import importlib.resources as resources
+ import contextlib
  import logging
  import re
+ import signal
  import time

  import coloredlogs
@@ -19,12 +25,20 @@ from jsonschema import Draft7Validator
  from rich.console import Console
  from rich.table import Table
  from tqdm import tqdm
- from deepresearch_flow.paper.config import PaperConfig, ProviderConfig, resolve_api_keys
+ from deepresearch_flow.paper.config import (
+     ApiKeyConfig,
+     PaperConfig,
+     ProviderConfig,
+     resolve_api_key_configs,
+     resolve_api_keys,
+ )
  from deepresearch_flow.paper.llm import backoff_delay, call_provider
  from deepresearch_flow.paper.prompts import DEFAULT_SYSTEM_PROMPT, DEFAULT_USER_PROMPT
  from deepresearch_flow.paper.render import render_papers, resolve_render_template
  from deepresearch_flow.paper.schema import schema_to_prompt, validate_schema
  from deepresearch_flow.paper.template_registry import (
+     StageDefinition,
+     get_template_bundle,
      get_stage_definitions,
      load_custom_prompt_templates,
      load_prompt_templates,
@@ -53,22 +67,257 @@ class ExtractionError:
      stage_name: str | None = None


+ def _mask_key(value: str, *, keep: int = 4) -> str:
+     if not value:
+         return "<empty>"
+     if len(value) <= keep:
+         return value
+     return f"...{value[-keep:]}"
+
+
  class KeyRotator:
-     def __init__(self, keys: list[str]) -> None:
+     def __init__(self, keys: list[ApiKeyConfig], *, cooldown_seconds: float, verbose: bool) -> None:
          self._keys = keys
          self._idx = 0
          self._lock = asyncio.Lock()
+         self._cooldown_seconds = max(cooldown_seconds, 0.0)
+         self._verbose = verbose
+         self._cooldowns: dict[str, float] = {key.key: 0.0 for key in keys}
+         self._quota_until: dict[str, float] = {key.key: 0.0 for key in keys}
+         self._key_meta: dict[str, ApiKeyConfig] = {key.key: key for key in keys}
+         self._error_counts: dict[str, int] = {key.key: 0 for key in keys}
+         self._last_pause_until: float = 0.0
+         self._last_key_quota_until: dict[str, float] = {key.key: 0.0 for key in keys}

      async def next_key(self) -> str | None:
          if not self._keys:
              return None
+         while True:
+             wait_for: float | None = None
+             wait_until_epoch: float | None = None
+             pause_reason: str | None = None
+             should_log_pause = False
+             async with self._lock:
+                 now = time.monotonic()
+                 now_epoch = time.time()
+                 total = len(self._keys)
+                 for offset in range(total):
+                     idx = (self._idx + offset) % total
+                     key = self._keys[idx].key
+                     if (
+                         self._cooldowns.get(key, 0.0) <= now
+                         and self._quota_until.get(key, 0.0) <= now_epoch
+                     ):
+                         self._idx = idx + 1
+                         return key
+                 waits: list[float] = []
+                 has_cooldown_wait = False
+                 has_quota_wait = False
+                 for meta in self._keys:
+                     key = meta.key
+                     cooldown_wait = max(self._cooldowns.get(key, 0.0) - now, 0.0)
+                     quota_wait = max(self._quota_until.get(key, 0.0) - now_epoch, 0.0)
+                     if cooldown_wait > 0:
+                         has_cooldown_wait = True
+                     if quota_wait > 0:
+                         has_quota_wait = True
+                     waits.append(max(cooldown_wait, quota_wait))
+                 wait_for = min(waits) if waits else None
+                 if wait_for is not None:
+                     wait_until_epoch = now_epoch + wait_for
+                     if wait_until_epoch > self._last_pause_until + 0.5:
+                         self._last_pause_until = wait_until_epoch
+                         if has_quota_wait and has_cooldown_wait:
+                             pause_reason = "quota/cooldown"
+                         elif has_quota_wait:
+                             pause_reason = "quota"
+                         elif has_cooldown_wait:
+                             pause_reason = "cooldown"
+                         else:
+                             pause_reason = "unknown"
+                         should_log_pause = True
+             if wait_for is None:
+                 return None
+             wait_for = max(wait_for, 0.01)
+             if should_log_pause and wait_until_epoch is not None:
+                 reset_dt = datetime.fromtimestamp(wait_until_epoch).astimezone().isoformat()
+                 logger.warning(
+                     "All API keys unavailable (%s); pausing %.2fs until %s",
+                     pause_reason,
+                     wait_for,
+                     reset_dt,
+                 )
+             elif self._verbose:
+                 logger.debug("All API keys cooling down; waiting %.2fs", wait_for)
+             await asyncio.sleep(wait_for)
+
+     async def key_pool_wait(self) -> tuple[float | None, str | None, float | None]:
+         if not self._keys:
+             return None, None, None
          async with self._lock:
-             key = self._keys[self._idx % len(self._keys)]
-             self._idx += 1
-             return key
+             now = time.monotonic()
+             now_epoch = time.time()
+             for meta in self._keys:
+                 key = meta.key
+                 if (
+                     self._cooldowns.get(key, 0.0) <= now
+                     and self._quota_until.get(key, 0.0) <= now_epoch
+                 ):
+                     return 0.0, None, None
+
+             waits: list[float] = []
+             has_cooldown_wait = False
+             has_quota_wait = False
+             for meta in self._keys:
+                 key = meta.key
+                 cooldown_wait = max(self._cooldowns.get(key, 0.0) - now, 0.0)
+                 quota_wait = max(self._quota_until.get(key, 0.0) - now_epoch, 0.0)
+                 if cooldown_wait > 0:
+                     has_cooldown_wait = True
+                 if quota_wait > 0:
+                     has_quota_wait = True
+                 waits.append(max(cooldown_wait, quota_wait))
+
+             wait_for = min(waits) if waits else None
+             if wait_for is None:
+                 return None, None, None
+             if has_quota_wait and has_cooldown_wait:
+                 reason = "quota/cooldown"
+             elif has_quota_wait:
+                 reason = "quota"
+             elif has_cooldown_wait:
+                 reason = "cooldown"
+             else:
+                 reason = "unknown"
+             wait_until_epoch = now_epoch + wait_for
+             return wait_for, reason, wait_until_epoch
+
+     async def mark_error(self, key: str) -> None:
+         if key not in self._cooldowns:
+             return
+         async with self._lock:
+             now = time.monotonic()
+             self._error_counts[key] = self._error_counts.get(key, 0) + 1
+             cooldown_until = now + self._cooldown_seconds
+             current = self._cooldowns.get(key, 0.0)
+             self._cooldowns[key] = max(current, cooldown_until)
+             if cooldown_until > current:
+                 logger.warning(
+                     "API key %s cooling down for %.2fs (errors=%d)",
+                     _mask_key(key),
+                     self._cooldown_seconds,
+                     self._error_counts[key],
+                 )
+             elif self._verbose:
+                 logger.debug(
+                     "API key cooldown applied (%.2fs, errors=%d)",
+                     self._cooldown_seconds,
+                     self._error_counts[key],
+                 )
+
+     async def mark_quota_exceeded(self, key: str, message: str, status_code: int | None) -> bool:
+         if key not in self._key_meta:
+             return False
+         meta = self._key_meta[key]
+         tokens = meta.quota_error_tokens
+         if not tokens:
+             return False
+         candidate = message
+         try:
+             data = json.loads(message)
+         except (TypeError, json.JSONDecodeError):
+             data = None
+         if isinstance(data, dict):
+             collected: list[str] = [message]
+             error = data.get("error")
+             if isinstance(error, dict):
+                 for key_name in ("code", "type", "message"):
+                     value = error.get(key_name)
+                     if isinstance(value, str):
+                         collected.append(value)
+             for key_name in ("code", "type", "message"):
+                 value = data.get(key_name)
+                 if isinstance(value, str):
+                     collected.append(value)
+             candidate = " ".join(collected)
+         lower_msg = candidate.lower()
+         tokens_match = any(token.lower() in lower_msg for token in tokens)
+         if not tokens_match:
+             return False
+         matched_tokens = [token for token in tokens if token.lower() in lower_msg]
+         reset_epoch = _compute_next_reset_epoch(meta)
+         if reset_epoch is None:
+             logger.warning(
+                 "API key %s hit quota trigger but no reset_time/quota_duration configured "
+                 "(matched=%s, status_code=%s)",
+                 _mask_key(key),
+                 ",".join(matched_tokens) or "<none>",
+                 status_code if status_code is not None else "unknown",
+             )
+             return False
+         async with self._lock:
+             current = self._quota_until.get(key, 0.0)
+             self._quota_until[key] = max(current, reset_epoch)
+             if reset_epoch > self._last_key_quota_until.get(key, 0.0):
+                 self._last_key_quota_until[key] = reset_epoch
+                 wait_for = max(reset_epoch - time.time(), 0.0)
+                 reset_dt = datetime.fromtimestamp(reset_epoch).astimezone().isoformat()
+                 logger.warning(
+                     "API key %s quota exhausted; pausing %.2fs until %s (matched=%s, status_code=%s)",
+                     _mask_key(key),
+                     wait_for,
+                     reset_dt,
+                     ",".join(matched_tokens) or "<none>",
+                     status_code if status_code is not None else "unknown",
+                 )
+             elif self._verbose:
+                 wait_for = max(reset_epoch - time.time(), 0.0)
+                 reset_dt = datetime.fromtimestamp(reset_epoch).astimezone().isoformat()
+                 logger.debug(
+                     "API key %s quota exhausted; cooldown %.2fs until %s",
+                     _mask_key(key),
+                     wait_for,
+                     reset_dt,
+                 )
+         return True


  logger = logging.getLogger(__name__)
+ _console = Console()
+
+
+ def log_extraction_failure(
+     path: str,
+     error_type: str,
+     error_message: str,
+     *,
+     status_code: int | None = None,
+ ) -> None:
+     message = error_message.strip()
+     if not message:
+         message = "no error message"
+     if status_code is not None and f"status_code={status_code}" not in message:
+         message = f"{message} (status_code={status_code})"
+     console_message = message
+     if len(console_message) > 500:
+         console_message = f"{console_message[:500]}..."
+     logger.warning(
+         "Extraction failed for %s (%s): %s",
+         path,
+         error_type,
+         message,
+     )
+     _console.print(
+         f"[bold red]Extraction failed[/] [dim]{path}[/]\n"
+         f"[bold yellow]{error_type}[/]: {console_message}"
+     )
+
+
+ def _summarize_error_message(message: str, limit: int = 300) -> str:
+     text = (message or "").strip() or "no error message"
+     if len(text) > limit:
+         return f"{text[:limit]}..."
+     return text


  def configure_logging(verbose: bool) -> None:
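Note on the rewritten KeyRotator: instead of bare round-robin over raw strings, it now tracks per-key error cooldowns and quota windows and can pause callers until a key frees up. A minimal usage sketch, not part of the package; the rotator argument is assumed to be built from resolved ApiKeyConfig entries as extract_documents does further down:

async def use_one_key(rotator: KeyRotator) -> None:
    # Picks the next key that is neither cooling down nor inside a quota window;
    # sleeps (with a logged pause) while every key is unavailable.
    api_key = await rotator.next_key()
    if api_key is None:
        return  # no keys configured
    try:
        ...  # issue the provider request with api_key here
    except Exception as exc:
        # A message matching the key's quota_error_tokens pauses the key until
        # its computed reset epoch; anything else gets the short cooldown.
        if not await rotator.mark_quota_exceeded(api_key, str(exc), 429):
            await rotator.mark_error(api_key)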
@@ -76,10 +325,180 @@ def configure_logging(verbose: bool) -> None:
      coloredlogs.install(level=level, fmt="%(asctime)s %(levelname)s %(message)s")


+ @dataclass(frozen=True)
+ class DocTask:
+     path: Path
+     stage_index: int
+     stage_name: str
+     stage_fields: list[str]
+
+
+ @dataclass
+ class DocState:
+     total_stages: int
+     next_index: int = 0
+     failed: bool = False
+     lock: asyncio.Lock = field(default_factory=asyncio.Lock)
+     event: asyncio.Event = field(default_factory=asyncio.Event)
+
+
+ @dataclass
+ class DocDagState:
+     total_stages: int
+     remaining: int
+     completed: set[str] = field(default_factory=set)
+     in_flight: set[str] = field(default_factory=set)
+     failed: bool = False
+     finalized: bool = False
+     lock: asyncio.Lock = field(default_factory=asyncio.Lock)
+
+
+ @dataclass
+ class DocContext:
+     path: Path
+     source_path: str
+     content: str
+     truncated_content: str
+     truncation: dict[str, Any] | None
+     source_hash: str
+     stage_path: Path | None
+     stage_state: dict[str, Any] | None
+     stages: dict[str, dict[str, Any]]
+     stage_meta: dict[str, dict[str, Any]]
+
+
  def _count_prompt_chars(messages: list[dict[str, str]]) -> int:
      return sum(len(message.get("content") or "") for message in messages)


+ def _load_prompt_template_sources(name: str) -> tuple[str, str]:
+     bundle = get_template_bundle(name)
+     system_path = resources.files("deepresearch_flow.paper.prompt_templates").joinpath(
+         bundle.prompt_system
+     )
+     user_path = resources.files("deepresearch_flow.paper.prompt_templates").joinpath(
+         bundle.prompt_user
+     )
+     return (
+         system_path.read_text(encoding="utf-8"),
+         user_path.read_text(encoding="utf-8"),
+     )
+
+
+ def _compute_prompt_hash(
+     *,
+     prompt_template: str,
+     output_language: str,
+     stage_name: str | None,
+     stage_fields: list[str],
+     custom_prompt: bool,
+     prompt_system_path: Path | None,
+     prompt_user_path: Path | None,
+ ) -> str:
+     if custom_prompt and prompt_system_path and prompt_user_path:
+         system_text = prompt_system_path.read_text(encoding="utf-8")
+         user_text = prompt_user_path.read_text(encoding="utf-8")
+     else:
+         system_text, user_text = _load_prompt_template_sources(prompt_template)
+     payload = {
+         "prompt_template": prompt_template,
+         "output_language": output_language,
+         "stage_name": stage_name or "",
+         "stage_fields": stage_fields,
+         "system_template": system_text,
+         "user_template": user_text,
+     }
+     raw = json.dumps(payload, ensure_ascii=False, sort_keys=True)
+     return hashlib.sha256(raw.encode("utf-8")).hexdigest()
+
+
+ def _resolve_stage_dependencies(
+     stage_definitions: list[StageDefinition],
+ ) -> dict[str, list[str]]:
+     deps: dict[str, list[str]] = {}
+     for idx, stage_def in enumerate(stage_definitions):
+         if stage_def.depends_on is None:
+             if idx == 0:
+                 deps[stage_def.name] = []
+             else:
+                 deps[stage_def.name] = [stage_definitions[idx - 1].name]
+         else:
+             deps[stage_def.name] = list(stage_def.depends_on)
+     return deps
+
+
+ def _build_dependency_graph(
+     stage_definitions: list[StageDefinition],
+     deps: dict[str, list[str]],
+ ) -> dict[str, list[str]]:
+     stage_names = {stage_def.name for stage_def in stage_definitions}
+     for stage_name, dependencies in deps.items():
+         for dependency in dependencies:
+             if dependency not in stage_names:
+                 raise ValueError(
+                     f"Stage '{stage_name}' depends on unknown stage '{dependency}'"
+                 )
+
+     dependents: dict[str, list[str]] = {name: [] for name in stage_names}
+     indegree: dict[str, int] = {name: 0 for name in stage_names}
+     for stage_name, dependencies in deps.items():
+         indegree[stage_name] = len(dependencies)
+         for dependency in dependencies:
+             dependents[dependency].append(stage_name)
+
+     queue = deque(name for name, degree in indegree.items() if degree == 0)
+     visited = 0
+     while queue:
+         node = queue.popleft()
+         visited += 1
+         for child in dependents[node]:
+             indegree[child] -= 1
+             if indegree[child] == 0:
+                 queue.append(child)
+
+     if visited != len(stage_names):
+         raise ValueError("Stage dependency cycle detected in template definition")
+
+     return dependents
+
+
+ def _parse_reset_time(reset_time: str) -> datetime | None:
+     candidate = reset_time.strip()
+     match = re.search(
+         r"(\d{4}-\d{2}-\d{2})[T ](\d{2}:\d{2}:\d{2}(?:\.\d+)?)(?:\s*(Z|[+-]\d{2}:?\d{2}))?",
+         candidate,
+     )
+     if not match:
+         return None
+     date_part, time_part, tz_part = match.group(1), match.group(2), match.group(3)
+     if tz_part:
+         tz_part = tz_part.replace("Z", "+00:00")
+         tz_part = re.sub(r"([+-]\d{2})(\d{2})$", r"\1:\2", tz_part)
+     iso_str = f"{date_part}T{time_part}{tz_part or ''}"
+     try:
+         return datetime.fromisoformat(iso_str)
+     except ValueError:
+         return None
+
+
+ def _compute_next_reset_epoch(meta: ApiKeyConfig) -> float | None:
+     if not meta.reset_time or not meta.quota_duration:
+         return None
+     base = _parse_reset_time(meta.reset_time)
+     if not base:
+         return None
+     duration = meta.quota_duration
+     if duration <= 0:
+         return None
+     now = datetime.now(timezone.utc)
+     base_utc = base.astimezone(timezone.utc)
+     if now <= base_utc:
+         return base_utc.timestamp()
+     elapsed = (now - base_utc).total_seconds()
+     cycles = math.floor(elapsed / duration) + 1
+     return base_utc.timestamp() + cycles * duration
+
+
  def _estimate_tokens_for_chars(char_count: int) -> int:
      if char_count <= 0:
          return 0
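Note: _compute_next_reset_epoch anchors on a configured reset timestamp and rolls it forward by whole quota_duration cycles, so a quota-exhausted key is paused until the next cycle boundary rather than for a fixed interval. A self-contained sketch of the same arithmetic, with hypothetical values:

import math
from datetime import datetime, timezone

def next_reset_epoch(base: datetime, duration_s: float, now: datetime) -> float:
    # Same shape as _compute_next_reset_epoch: first boundary strictly after `now`.
    base_utc = base.astimezone(timezone.utc)
    if now <= base_utc:
        return base_utc.timestamp()
    cycles = math.floor((now - base_utc).total_seconds() / duration_s) + 1
    return base_utc.timestamp() + cycles * duration_s

base = datetime.fromisoformat("2025-01-01T00:00:00+00:00")  # hypothetical reset_time
now = datetime(2025, 1, 3, 7, 0, tzinfo=timezone.utc)
print(next_reset_epoch(base, 86400.0, now))  # epoch of 2025-01-04T00:00:00Z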
@@ -304,6 +723,22 @@ def load_errors(path: Path) -> list[dict[str, Any]]:
      return []


+ def load_retry_list(path: Path) -> list[dict[str, Any]]:
+     if not path.exists():
+         raise click.ClickException(f"Retry list JSON not found: {path}")
+     try:
+         data = json.loads(path.read_text(encoding="utf-8"))
+     except json.JSONDecodeError as exc:
+         raise click.ClickException(f"Invalid retry list JSON: {exc}") from exc
+     if isinstance(data, dict):
+         items = data.get("items")
+         if isinstance(items, list):
+             return [item for item in items if isinstance(item, dict)]
+     if isinstance(data, list):
+         return [item for item in data if isinstance(item, dict)]
+     raise click.ClickException("Retry list JSON must be a list or contain an 'items' list")
+
+
  def write_json(path: Path, data: Any) -> None:
      path.parent.mkdir(parents=True, exist_ok=True)
      path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
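Note: load_retry_list accepts either a bare JSON array of objects or an object wrapping an "items" array, and silently drops non-dict entries. Combined with the retry_stages handling later in this diff, the two useful entry shapes look like this (paths and stage names are hypothetical):

retry_list = [
    {"source_path": "papers/foo.md"},                              # retry every stage
    {"source_path": "papers/bar.md", "retry_stages": ["stage_2"]}, # retry one stage only
]
wrapped = {"items": retry_list}  # equally accepted top-level shape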
@@ -376,6 +811,7 @@ async def call_with_retries(
      backoff_max_seconds: float,
      client: httpx.AsyncClient,
      validator: Draft7Validator,
+     key_rotator: KeyRotator | None = None,
      throttle: RequestThrottle | None = None,
      stats: ExtractionStats | None = None,
  ) -> dict[str, Any]:
@@ -384,6 +820,8 @@ async def call_with_retries(
      prompt_chars = _count_prompt_chars(messages)
      while attempt < max_retries:
          attempt += 1
+         if key_rotator:
+             api_key = await key_rotator.next_key()
          if throttle:
              await throttle.tick()
          if stats:
@@ -402,9 +840,29 @@ async def call_with_retries(
              if stats:
                  await stats.add_output_chars(len(response_text))
          except ProviderError as exc:
+             quota_hit = False
+             if api_key and key_rotator:
+                 quota_hit = await key_rotator.mark_quota_exceeded(
+                     api_key,
+                     str(exc),
+                     exc.status_code,
+                 )
              if exc.structured_error and use_structured != "none":
+                 logger.warning(
+                     "Structured response failed; retrying without structured output "
+                     "(provider=%s, model=%s, status_code=%s): %s",
+                     provider.name,
+                     model,
+                     exc.status_code if exc.status_code is not None else "unknown",
+                     _summarize_error_message(str(exc)),
+                 )
                  use_structured = "none"
                  continue
+             if quota_hit:
+                 attempt -= 1
+                 continue
+             if api_key and key_rotator and not quota_hit and should_retry_error(exc):
+                 await key_rotator.mark_error(api_key)
              if should_retry_error(exc) and attempt < max_retries:
                  await asyncio.sleep(backoff_delay(backoff_base_seconds, attempt, backoff_max_seconds))
                  continue
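Note: when a quota trigger matches, the loop decrements attempt before continuing, so waiting out a quota window does not spend retry budget; only genuinely retryable errors count toward max_retries. A toy model of that accounting (no I/O, hypothetical outcomes):

attempt, max_retries = 0, 3
for outcome in ("quota", "quota", "retryable_error"):
    attempt += 1
    if outcome == "quota":
        attempt -= 1  # quota pauses are free; the budget is preserved
print(attempt)  # 1 -- only the real error counted against max_retries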
@@ -448,9 +906,16 @@ async def extract_documents(
      split: bool,
      split_dir: Path | None,
      force: bool,
+     force_stages: list[str],
      retry_failed: bool,
+     retry_failed_stages: bool,
+     retry_list_path: Path | None,
+     stage_dag: bool,
+     start_idx: int,
+     end_idx: int,
      dry_run: bool,
      max_concurrency_override: int | None,
+     timeout_seconds: float,
      prompt_template: str,
      output_language: str,
      custom_prompt: bool,
@@ -469,19 +934,86 @@ async def extract_documents(
      markdown_files = discover_markdown(inputs, glob_pattern, recursive=True)
      template_tag = prompt_template if not custom_prompt else "custom"

+     total_files = len(markdown_files)
+     if start_idx != 0 or end_idx != -1:
+         slice_end = end_idx if end_idx != -1 else None
+         markdown_files = markdown_files[start_idx:slice_end]
+         logger.info(
+             "Applied range filter [%d:%d]. Files: %d -> %d",
+             start_idx,
+             end_idx,
+             total_files,
+             len(markdown_files),
+         )
+         if not markdown_files:
+             logger.warning(
+                 "Range filter yielded 0 files (range=%d:%d, total=%d)",
+                 start_idx,
+                 end_idx,
+                 total_files,
+             )
+
+     retry_list_entries = load_retry_list(retry_list_path) if retry_list_path else []
+     error_entries = load_errors(errors_path) if retry_failed or retry_failed_stages else []
+     retry_stage_map: dict[str, set[str]] = {}
+     retry_full_paths: set[str] = set()
+
      if retry_failed:
-         error_entries = load_errors(errors_path)
-         retry_paths = {Path(entry.get("source_path", "")).resolve() for entry in error_entries}
-         markdown_files = [path for path in markdown_files if path in retry_paths]
+         for entry in error_entries:
+             source_path = entry.get("source_path")
+             if not source_path:
+                 continue
+             retry_full_paths.add(str(Path(source_path).resolve()))
+
+     if retry_failed_stages:
+         for entry in error_entries:
+             source_path = entry.get("source_path")
+             stage_name = entry.get("stage_name")
+             if not source_path:
+                 continue
+             resolved = str(Path(source_path).resolve())
+             if not stage_name:
+                 retry_full_paths.add(resolved)
+                 continue
+             retry_stage_map.setdefault(resolved, set()).add(stage_name)
+
+     if retry_list_entries:
+         for entry in retry_list_entries:
+             source_path = entry.get("source_path")
+             if not source_path:
+                 continue
+             resolved = str(Path(source_path).resolve())
+             retry_stages = entry.get("retry_stages")
+             if isinstance(retry_stages, list) and retry_stages:
+                 stage_set = {
+                     stage for stage in retry_stages if isinstance(stage, str) and stage.strip()
+                 }
+                 if stage_set:
+                     retry_stage_map.setdefault(resolved, set()).update(stage_set)
+                 continue
+             retry_full_paths.add(resolved)
+
+     retry_mode = retry_failed or retry_failed_stages or bool(retry_list_entries)
+     retry_stages_mode = retry_failed_stages or bool(retry_stage_map)
+
+     if retry_mode:
+         retry_paths = set(retry_full_paths) | set(retry_stage_map.keys())
+         markdown_files = [
+             path for path in markdown_files if str(path.resolve()) in retry_paths
+         ]
          logger.debug("Retrying %d markdown files", len(markdown_files))
+         if not markdown_files:
+             logger.warning("Retry list produced 0 files to process")
      else:
          logger.debug("Discovered %d markdown files", len(markdown_files))

+     stage_definitions = get_stage_definitions(prompt_template) if not custom_prompt else []
+     multi_stage = bool(stage_definitions)
+     stage_dag_enabled = stage_dag and multi_stage
+
      if dry_run:
          input_chars = 0
          prompt_chars = 0
-         stage_definitions = get_stage_definitions(prompt_template) if not custom_prompt else []
-         multi_stage = bool(stage_definitions)
          metadata_fields = [
              "paper_title",
              "paper_authors",
@@ -525,6 +1057,21 @@ async def extract_documents(
              )
              prompt_chars += _count_prompt_chars(messages)

+         if stage_dag_enabled and stage_definitions:
+             stage_dependencies = _resolve_stage_dependencies(stage_definitions)
+             _build_dependency_graph(stage_definitions, stage_dependencies)
+             plan_table = Table(
+                 title="stage DAG plan (dry-run)",
+                 header_style="bold cyan",
+                 title_style="bold magenta",
+             )
+             plan_table.add_column("Stage", style="cyan", no_wrap=True)
+             plan_table.add_column("Depends on", style="white")
+             for stage_def in stage_definitions:
+                 deps = stage_dependencies.get(stage_def.name, [])
+                 plan_table.add_row(stage_def.name, ", ".join(deps) if deps else "none")
+             Console().print(plan_table)
+

          duration = time.monotonic() - start_time
          prompt_tokens = _estimate_tokens_for_chars(prompt_chars)
@@ -562,17 +1109,25 @@ async def extract_documents(
          if isinstance(entry, dict) and entry.get("source_path")
      }

-     rotator = KeyRotator(resolve_api_keys(provider.api_keys))
+     cooldown_seconds = max(1.0, float(config.extract.backoff_base_seconds))
+     resolved_keys = resolve_api_key_configs(provider.api_keys)
+     rotator = KeyRotator(
+         resolved_keys,
+         cooldown_seconds=cooldown_seconds,
+         verbose=verbose,
+     )
      max_concurrency = max_concurrency_override or config.extract.max_concurrency
      semaphore = asyncio.Semaphore(max_concurrency)

      errors: list[ExtractionError] = []
      results: dict[str, dict[str, Any]] = {}
-     stage_definitions = get_stage_definitions(prompt_template) if not custom_prompt else []
-     multi_stage = bool(stage_definitions)
      stage_output_dir = Path("paper_stage_outputs")
      if multi_stage:
          stage_output_dir.mkdir(parents=True, exist_ok=True)
+         if stage_dag_enabled:
+             logger.info("Multi-stage scheduler: DAG")
+         else:
+             logger.info("Multi-stage scheduler: sequential")

      throttle = None
      if sleep_every is not None or sleep_time is not None:
@@ -594,125 +1149,214 @@ async def extract_documents(
      )
      stats = ExtractionStats(doc_bar=doc_bar)

-     async def process_one(path: Path, client: httpx.AsyncClient) -> None:
-         source_path = str(path.resolve())
-         current_stage: str | None = None
-         try:
-             if verbose:
-                 logger.debug("Processing %s", source_path)
-             content = read_text(path)
-             await stats.add_input_chars(len(content))
-             source_hash = compute_source_hash(content)
-             stage_state: dict[str, Any] | None = None
-             stage_path: Path | None = None
+     results_lock = asyncio.Lock()
+     logger.info("Request timeout set to %.1fs", timeout_seconds)

-             if not force and not retry_failed:
-                 existing_entry = existing_by_path.get(source_path)
-                 if existing_entry and existing_entry.get("source_hash") == source_hash:
-                     results[source_path] = existing_entry
-                     if stage_bar:
-                         stage_bar.update(len(stage_definitions))
-                     return
+     pause_threshold_seconds = max(0.0, float(config.extract.pause_threshold_seconds))
+     pause_watchdog_seconds = max(60.0, pause_threshold_seconds * 6.0)
+     pause_gate: asyncio.Event | None = None
+     pause_task: asyncio.Task[None] | None = None

-             truncated_content, truncation = truncate_content(
-                 content, config.extract.truncate_max_chars, config.extract.truncate_strategy
-             )
+     shutdown_event = asyncio.Event()
+     shutdown_reason: str | None = None

-             api_key = await rotator.next_key()
+     def request_shutdown(reason: str) -> None:
+         nonlocal shutdown_reason
+         if shutdown_event.is_set():
+             return
+         shutdown_reason = reason
+         shutdown_event.set()
+         if pause_gate:
+             pause_gate.set()
+         logger.warning("Graceful shutdown requested (%s); draining in-flight tasks", reason)
+
+     loop = asyncio.get_running_loop()
+     for sig in (signal.SIGINT, signal.SIGTERM):
+         try:
+             loop.add_signal_handler(sig, request_shutdown, sig.name)
+         except (NotImplementedError, RuntimeError, ValueError):
+             signal.signal(
+                 sig,
+                 lambda *_args, _sig=sig: loop.call_soon_threadsafe(
+                     request_shutdown, _sig.name
+                 ),
+             )

-             if multi_stage:
-                 stage_path = stage_output_dir / f"{stable_hash(source_path)}.json"
-                 stage_state = load_stage_state(stage_path) if not force else None
-                 if stage_state and stage_state.get("source_hash") != source_hash:
-                     stage_state = None
-                 if stage_state is None:
-                     stage_state = {
-                         "source_path": source_path,
-                         "source_hash": source_hash,
-                         "prompt_template": prompt_template,
-                         "output_language": output_language,
-                         "stages": {},
-                     }
-
-             if multi_stage and stage_state is not None and stage_path is not None:
-                 stages: dict[str, dict[str, Any]] = stage_state.get("stages", {})
-                 metadata_fields = [
-                     "paper_title",
-                     "paper_authors",
-                     "publication_date",
-                     "publication_venue",
-                 ]
-                 for stage_def in stage_definitions:
-                     stage_name = stage_def.name
-                     stage_fields = stage_def.fields
-                     current_stage = stage_name
-                     if stage_name in stages and not force:
-                         if stage_bar:
-                             stage_bar.update(1)
+     if resolved_keys:
+         pause_gate = asyncio.Event()
+         pause_gate.set()
+
+         async def pause_watcher() -> None:
+             paused = False
+             try:
+                 while True:
+                     wait_for, reason, wait_until_epoch = await rotator.key_pool_wait()
+                     if wait_for is None or wait_for <= 0:
+                         if paused:
+                             paused = False
+                             pause_gate.set()
+                             logger.info("Queue resumed; key pool available")
+                         await asyncio.sleep(0.5)
                          continue
-                     try:
-                         required_fields = metadata_fields + stage_fields
-                         stage_schema = build_stage_schema(schema, required_fields)
-                         stage_validator = validate_schema(stage_schema)
-                         previous_outputs = json.dumps(stages, ensure_ascii=False)
-                         messages = build_messages(
-                             truncated_content,
-                             stage_schema,
-                             provider,
-                             prompt_template,
-                             output_language,
-                             custom_prompt=False,
-                             prompt_system_path=None,
-                             prompt_user_path=None,
-                             stage_name=stage_name,
-                             stage_fields=required_fields,
-                             previous_outputs=previous_outputs,
-                         )
-                         async with semaphore:
-                             data = await call_with_retries(
-                                 provider,
-                                 model,
-                                 messages,
-                                 stage_schema,
-                                 api_key,
-                                 timeout=60.0,
-                                 structured_mode=provider.structured_mode,
-                                 max_retries=config.extract.max_retries,
-                                 backoff_base_seconds=config.extract.backoff_base_seconds,
-                                 backoff_max_seconds=config.extract.backoff_max_seconds,
-                                 client=client,
-                                 validator=stage_validator,
-                                 throttle=throttle,
-                                 stats=stats,
+                     if wait_for <= pause_threshold_seconds:
+                         if paused:
+                             paused = False
+                             pause_gate.set()
+                             logger.info(
+                                 "Queue resumed; key pool wait %.2fs below threshold %.2fs",
+                                 wait_for,
+                                 pause_threshold_seconds,
                              )
-                         stages[stage_name] = data
-                         stage_state["stages"] = stages
-                         write_json_atomic(stage_path, stage_state)
-                     finally:
-                         if stage_bar:
-                             stage_bar.update(1)
+                         await asyncio.sleep(min(wait_for, pause_threshold_seconds))
+                         continue
+                     if not paused:
+                         paused = True
+                         pause_gate.clear()
+                         reset_dt = (
+                             datetime.fromtimestamp(wait_until_epoch).astimezone().isoformat()
+                             if wait_until_epoch
+                             else "unknown"
+                         )
+                         logger.warning(
+                             "Queue paused (keys unavailable: %s); waiting %.2fs until %s",
+                             reason or "unknown",
+                             wait_for,
+                             reset_dt,
+                         )
+                     await asyncio.sleep(min(wait_for, max(pause_threshold_seconds, 1.0)))
+             except asyncio.CancelledError:
+                 raise
+             except Exception:
+                 logger.exception("Queue pause watcher failed; releasing pause gate")
+                 pause_gate.set()

-             merged: dict[str, Any] = {}
-             for stage_def in stage_definitions:
-                 merged.update(stages.get(stage_def.name, {}))
-             errors_in_doc = sorted(validator.iter_errors(merged), key=lambda e: e.path)
-             if errors_in_doc:
-                 raise ProviderError(
-                     f"Schema validation failed: {errors_in_doc[0].message}",
-                     error_type="validation_error",
-                 )
-             data = append_metadata(
-                 merged,
-                 source_path=source_path,
-                 source_hash=source_hash,
-                 provider=provider.name,
-                 model=model,
-                 truncation=truncation,
-                 prompt_template=prompt_template,
-                 output_language=output_language,
+         pause_task = asyncio.create_task(pause_watcher())
+
+     async def await_key_pool_ready() -> None:
+         if shutdown_event.is_set():
+             return
+         if not pause_gate or pause_gate.is_set():
+             return
+         try:
+             pause_wait = asyncio.create_task(pause_gate.wait())
+             shutdown_wait = asyncio.create_task(shutdown_event.wait())
+             try:
+                 await asyncio.wait_for(
+                     asyncio.wait(
+                         [pause_wait, shutdown_wait],
+                         return_when=asyncio.FIRST_COMPLETED,
+                     ),
+                     timeout=pause_watchdog_seconds,
+                 )
+             finally:
+                 for task in (pause_wait, shutdown_wait):
+                     task.cancel()
+         except asyncio.TimeoutError:
+             logger.warning(
+                 "Queue pause watchdog timeout; rechecking key pool availability"
+             )
+             wait_for, _, _ = await rotator.key_pool_wait()
+             if wait_for is None or wait_for <= 0:
+                 pause_gate.set()
+
+     async def drain_queue(queue: asyncio.Queue[Any], drain_lock: asyncio.Lock) -> int:
+         drained = 0
+         async with drain_lock:
+             while True:
+                 try:
+                     queue.get_nowait()
+                 except asyncio.QueueEmpty:
+                     break
+                 queue.task_done()
+                 drained += 1
+         return drained
+
+     async def wait_for_queue(queue: asyncio.Queue[Any], drain_lock: asyncio.Lock) -> None:
+         if shutdown_event.is_set():
+             await drain_queue(queue, drain_lock)
+             await queue.join()
+             return
+         join_task = asyncio.create_task(queue.join())
+         shutdown_task = asyncio.create_task(shutdown_event.wait())
+         done, _pending = await asyncio.wait(
+             [join_task, shutdown_task],
+             return_when=asyncio.FIRST_COMPLETED,
+         )
+         if shutdown_task in done:
+             await drain_queue(queue, drain_lock)
+             await queue.join()
+         for task in (join_task, shutdown_task):
+             task.cancel()
+
+     def build_output_payload() -> dict[str, Any]:
+         final_results: list[dict[str, Any]] = []
+         seen = set()
+         for entry in existing:
+             path = entry.get("source_path") if isinstance(entry, dict) else None
+             if path and path in results:
+                 final_results.append(results[path])
+                 seen.add(path)
+             elif path:
+                 final_results.append(entry)
+                 seen.add(path)
+
+         for path, entry in results.items():
+             if path not in seen:
+                 final_results.append(entry)
+         return {"template_tag": template_tag, "papers": final_results}
+
+     async def persist_output_snapshot() -> None:
+         async with results_lock:
+             payload = build_output_payload()
+             await asyncio.to_thread(write_json, output_path, payload)
+
+     def build_merged(ctx: DocContext) -> dict[str, Any]:
+         merged: dict[str, Any] = {}
+         for stage_def in stage_definitions:
+             merged.update(ctx.stages.get(stage_def.name, {}))
+         return merged
+
+     async def update_results(ctx: DocContext) -> None:
+         merged = build_merged(ctx)
+         data = append_metadata(
+             merged,
+             source_path=ctx.source_path,
+             source_hash=ctx.source_hash,
+             provider=provider.name,
+             model=model,
+             truncation=ctx.truncation,
+             prompt_template=prompt_template,
+             output_language=output_language,
+         )
+         async with results_lock:
+             results[ctx.source_path] = data
+         await persist_output_snapshot()
+
+     async def run_single_stage(client: httpx.AsyncClient) -> None:
+         doc_queue: asyncio.Queue[Path] = asyncio.Queue()
+         drain_lock = asyncio.Lock()
+
+         async def process_one(path: Path) -> None:
+             source_path = str(path.resolve())
+             current_stage: str | None = None
+             try:
+                 if shutdown_event.is_set():
+                     return
+                 if verbose:
+                     logger.debug("Processing %s", source_path)
+                 content = read_text(path)
+                 await stats.add_input_chars(len(content))
+                 source_hash = compute_source_hash(content)
+
+                 if not force and not retry_mode:
+                     existing_entry = existing_by_path.get(source_path)
+                     if existing_entry and existing_entry.get("source_hash") == source_hash:
+                         results[source_path] = existing_entry
+                         return
+
+                 truncated_content, truncation = truncate_content(
+                     content, config.extract.truncate_max_chars, config.extract.truncate_strategy
                  )
-             results[source_path] = data
-         else:
                  messages = build_messages(
                      truncated_content,
                      schema,
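Note: the shutdown wiring above prefers loop.add_signal_handler and falls back to signal.signal where the loop does not support it (add_signal_handler raises NotImplementedError on Windows event loops). A minimal standalone sketch of the same pattern:

import asyncio
import signal

async def main() -> None:
    stop = asyncio.Event()
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, stop.set)
        except (NotImplementedError, RuntimeError, ValueError):
            # Fallback path: hand the set() back to the loop thread safely.
            signal.signal(sig, lambda *_: loop.call_soon_threadsafe(stop.set))
    await stop.wait()  # workers would drain their queues once this fires

asyncio.run(main())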
@@ -723,20 +1367,24 @@ async def extract_documents(
                      prompt_system_path=prompt_system_path,
                      prompt_user_path=prompt_user_path,
                  )
+                 if shutdown_event.is_set():
+                     return
+                 await await_key_pool_ready()
                  async with semaphore:
                      data = await call_with_retries(
                          provider,
                          model,
                          messages,
                          schema,
-                         api_key,
-                         timeout=60.0,
+                         None,
+                         timeout=timeout_seconds,
                          structured_mode=provider.structured_mode,
                          max_retries=config.extract.max_retries,
                          backoff_base_seconds=config.extract.backoff_base_seconds,
                          backoff_max_seconds=config.extract.backoff_max_seconds,
                          client=client,
                          validator=validator,
+                         key_rotator=rotator,
                          throttle=throttle,
                          stats=stats,
                      )
@@ -752,38 +1400,770 @@ async def extract_documents(
752
1400
  output_language=output_language,
753
1401
  )
754
1402
  results[source_path] = data
755
- except ProviderError as exc:
756
- logger.warning("Extraction failed for %s: %s", source_path, exc)
757
- errors.append(
758
- ExtractionError(
759
- path=path,
760
- provider=provider.name,
761
- model=model,
762
- error_type=exc.error_type,
763
- error_message=str(exc),
764
- stage_name=current_stage if multi_stage else None,
1403
+ except ProviderError as exc:
1404
+ log_extraction_failure(
1405
+ source_path,
1406
+ exc.error_type,
1407
+ str(exc),
1408
+ status_code=exc.status_code,
765
1409
  )
1410
+ errors.append(
1411
+ ExtractionError(
1412
+ path=path,
1413
+ provider=provider.name,
1414
+ model=model,
1415
+ error_type=exc.error_type,
1416
+ error_message=str(exc),
1417
+ stage_name=current_stage if multi_stage else None,
1418
+ )
1419
+ )
1420
+ except Exception as exc: # pragma: no cover - safety net
1421
+ logger.exception("Unexpected error while processing %s", source_path)
1422
+ errors.append(
1423
+ ExtractionError(
1424
+ path=path,
1425
+ provider=provider.name,
1426
+ model=model,
1427
+ error_type="unexpected_error",
1428
+ error_message=str(exc),
1429
+ stage_name=current_stage if multi_stage else None,
1430
+ )
1431
+ )
1432
+ finally:
1433
+ if doc_bar:
1434
+ doc_bar.update(1)
1435
+
1436
+ for path in markdown_files:
1437
+ if shutdown_event.is_set():
1438
+ break
1439
+ doc_queue.put_nowait(path)
1440
+
1441
+ async def worker() -> None:
1442
+ while True:
1443
+ if shutdown_event.is_set():
1444
+ await drain_queue(doc_queue, drain_lock)
1445
+ return
1446
+ try:
1447
+ path = doc_queue.get_nowait()
1448
+ except asyncio.QueueEmpty:
1449
+ return
1450
+ await process_one(path)
1451
+ doc_queue.task_done()
1452
+
1453
+ workers = [asyncio.create_task(worker()) for _ in range(max_concurrency)]
1454
+ if markdown_files:
1455
+ await wait_for_queue(doc_queue, drain_lock)
1456
+ for w in workers:
1457
+ w.cancel()
1458
+
1459
+ async def run_multi_stage(client: httpx.AsyncClient) -> None:
1460
+ metadata_fields = [
1461
+ "paper_title",
1462
+ "paper_authors",
1463
+ "publication_date",
1464
+ "publication_venue",
1465
+ ]
1466
+ force_stage_set = set(force_stages or [])
1467
+ doc_contexts: dict[Path, DocContext] = {}
1468
+ doc_states: dict[Path, DocState] = {}
1469
+ task_queue: asyncio.Queue[DocTask] = asyncio.Queue()
1470
+ drain_lock = asyncio.Lock()
1471
+
1472
+ for path in markdown_files:
1473
+ if shutdown_event.is_set():
1474
+ break
1475
+ source_path = str(path.resolve())
1476
+ if verbose:
1477
+ logger.debug("Preparing %s", source_path)
1478
+ content = read_text(path)
1479
+ await stats.add_input_chars(len(content))
1480
+ source_hash = compute_source_hash(content)
1481
+ truncated_content, truncation = truncate_content(
1482
+ content, config.extract.truncate_max_chars, config.extract.truncate_strategy
766
1483
  )
767
- except Exception as exc: # pragma: no cover - safety net
768
- logger.exception("Unexpected error while processing %s", source_path)
769
- errors.append(
770
- ExtractionError(
771
- path=path,
772
- provider=provider.name,
773
- model=model,
774
- error_type="unexpected_error",
775
- error_message=str(exc),
776
- stage_name=current_stage if multi_stage else None,
1484
+ stage_path = stage_output_dir / f"{stable_hash(source_path)}.json"
1485
+ stage_state = load_stage_state(stage_path) if not force else None
1486
+ if stage_state and stage_state.get("source_hash") != source_hash:
1487
+ stage_state = None
1488
+ if stage_state is None:
1489
+ stage_state = {
1490
+ "source_path": source_path,
1491
+ "source_hash": source_hash,
1492
+ "prompt_template": prompt_template,
1493
+ "output_language": output_language,
1494
+ "stages": {},
1495
+ "stage_meta": {},
1496
+ }
1497
+ stages: dict[str, dict[str, Any]] = stage_state.get("stages", {})
1498
+ stage_meta: dict[str, dict[str, Any]] = stage_state.get("stage_meta", {})
1499
+ doc_contexts[path] = DocContext(
1500
+ path=path,
1501
+ source_path=source_path,
1502
+ content=content,
1503
+ truncated_content=truncated_content,
1504
+ truncation=truncation,
1505
+ source_hash=source_hash,
1506
+ stage_path=stage_path,
1507
+ stage_state=stage_state,
1508
+ stages=stages,
1509
+ stage_meta=stage_meta,
1510
+ )
1511
+ doc_states[path] = DocState(total_stages=len(stage_definitions))
1512
+ for idx, stage_def in enumerate(stage_definitions):
1513
+ required_fields = metadata_fields + stage_def.fields
1514
+ task_queue.put_nowait(
1515
+ DocTask(
1516
+ path=path,
1517
+ stage_index=idx,
1518
+ stage_name=stage_def.name,
1519
+ stage_fields=required_fields,
1520
+ )
1521
+ )
1522
+
1523
+ async def run_task(task: DocTask) -> None:
1524
+ ctx = doc_contexts[task.path]
1525
+ state = doc_states[task.path]
1526
+
1527
+ while True:
1528
+ if shutdown_event.is_set():
1529
+ async with state.lock:
1530
+ state.failed = True
1531
+ state.event.set()
1532
+ if stage_bar:
1533
+ stage_bar.update(1)
1534
+ return
1535
+ async with state.lock:
1536
+ if state.failed:
1537
+ if stage_bar:
1538
+ stage_bar.update(1)
1539
+ return
1540
+ if task.stage_index == state.next_index:
1541
+ break
1542
+ wait_event = state.event
1543
+ await wait_event.wait()
1544
+
1545
+ current_stage = task.stage_name
1546
+ is_retry_full = ctx.source_path in retry_full_paths
1547
+ retry_stages = (
1548
+ retry_stage_map.get(ctx.source_path)
1549
+ if retry_stages_mode and not is_retry_full
1550
+ else None
1551
+ )
1552
+ if shutdown_event.is_set():
1553
+ async with state.lock:
1554
+ state.failed = True
1555
+ state.event.set()
1556
+ if stage_bar:
1557
+ stage_bar.update(1)
1558
+ return
1559
+ stage_record = ctx.stages.get(current_stage)
1560
+ if (
1561
+ retry_stages_mode
1562
+ and not is_retry_full
1563
+ and retry_stages is not None
1564
+ and current_stage not in retry_stages
1565
+ and stage_record is not None
1566
+ ):
1567
+ if stage_bar:
1568
+ stage_bar.update(1)
1569
+ await update_results(ctx)
1570
+ final_validation_error: str | None = None
1571
+ if not state.failed and task.stage_index == state.total_stages - 1:
1572
+ merged = build_merged(ctx)
1573
+ errors_in_doc = sorted(validator.iter_errors(merged), key=lambda e: e.path)
1574
+ if errors_in_doc:
1575
+ final_validation_error = errors_in_doc[0].message
1576
+
1577
+ async with state.lock:
1578
+ if final_validation_error:
1579
+ errors.append(
1580
+ ExtractionError(
1581
+ path=task.path,
1582
+ provider=provider.name,
1583
+ model=model,
1584
+ error_type="validation_error",
1585
+ error_message=f"Schema validation failed: {final_validation_error}",
1586
+ stage_name=current_stage,
1587
+ )
1588
+ )
1589
+ state.failed = True
1590
+ if not state.failed:
1591
+ state.next_index += 1
1592
+ if state.next_index >= state.total_stages or state.failed:
1593
+ if doc_bar and state.total_stages:
1594
+ doc_bar.update(1)
1595
+ state.event.set()
1596
+ state.event = asyncio.Event()
1597
+ return
1598
+ prompt_hash = prompt_hash_map[current_stage]
1599
+ stage_schema = stage_schema_map[current_stage]
1600
+ stage_validator = stage_validator_map[current_stage]
1601
+ stage_meta = ctx.stage_meta.get(current_stage, {})
1602
+ needs_run = force or current_stage in force_stage_set
1603
+ if stage_record is None:
1604
+ needs_run = True
1605
+ if is_retry_full:
1606
+ needs_run = True
1607
+ if retry_stages is not None and current_stage in retry_stages:
1608
+ needs_run = True
1609
+ if stage_meta.get("prompt_hash") != prompt_hash:
1610
+ needs_run = True
1611
+ if stage_record is not None and not needs_run:
1612
+ errors_in_stage = sorted(stage_validator.iter_errors(stage_record), key=lambda e: e.path)
1613
+ if errors_in_stage:
1614
+ needs_run = True
1615
+
1616
+ if not needs_run:
1617
+ if stage_bar:
1618
+ stage_bar.update(1)
1619
+ await update_results(ctx)
1620
+ else:
1621
+ try:
1622
+ previous_outputs = json.dumps(ctx.stages, ensure_ascii=False)
1623
+ messages = build_messages(
1624
+ ctx.truncated_content,
1625
+ stage_schema,
1626
+ provider,
1627
+ prompt_template,
1628
+ output_language,
1629
+ custom_prompt=False,
1630
+ prompt_system_path=None,
1631
+ prompt_user_path=None,
1632
+ stage_name=current_stage,
1633
+ stage_fields=task.stage_fields,
1634
+ previous_outputs=previous_outputs,
1635
+ )
1636
+ async with semaphore:
1637
+ data = await call_with_retries(
1638
+ provider,
1639
+ model,
1640
+ messages,
1641
+ stage_schema,
1642
+ None,
1643
+ timeout=timeout_seconds,
1644
+ structured_mode=provider.structured_mode,
1645
+ max_retries=config.extract.max_retries,
1646
+ backoff_base_seconds=config.extract.backoff_base_seconds,
1647
+ backoff_max_seconds=config.extract.backoff_max_seconds,
1648
+ client=client,
1649
+ validator=stage_validator,
1650
+ key_rotator=rotator,
1651
+ throttle=throttle,
1652
+ stats=stats,
1653
+ )
1654
+ ctx.stages[current_stage] = data
1655
+ ctx.stage_meta[current_stage] = {"prompt_hash": prompt_hash}
1656
+ ctx.stage_state["stages"] = ctx.stages
1657
+ ctx.stage_state["stage_meta"] = ctx.stage_meta
1658
+ write_json_atomic(ctx.stage_path, ctx.stage_state)
1659
+ if stage_bar:
1660
+ stage_bar.update(1)
1661
+ await update_results(ctx)
1662
+ except ProviderError as exc:
1663
+ log_extraction_failure(
1664
+ ctx.source_path,
1665
+ exc.error_type,
1666
+ str(exc),
1667
+ status_code=exc.status_code,
1668
+ )
1669
+ errors.append(
1670
+ ExtractionError(
1671
+ path=task.path,
1672
+ provider=provider.name,
1673
+ model=model,
1674
+ error_type=exc.error_type,
1675
+ error_message=str(exc),
1676
+ stage_name=current_stage,
1677
+ )
1678
+ )
1679
+ async with state.lock:
1680
+ state.failed = True
1681
+ if stage_bar:
1682
+ stage_bar.update(1)
1683
+ except Exception as exc: # pragma: no cover - safety net
1684
+ logger.exception("Unexpected error while processing %s", ctx.source_path)
1685
+ errors.append(
1686
+ ExtractionError(
1687
+ path=task.path,
1688
+ provider=provider.name,
1689
+ model=model,
1690
+ error_type="unexpected_error",
1691
+ error_message=str(exc),
1692
+ stage_name=current_stage,
1693
+ )
1694
+ )
1695
+ async with state.lock:
1696
+ state.failed = True
1697
+ if stage_bar:
1698
+ stage_bar.update(1)
1699
+
1700
+ final_validation_error: str | None = None
1701
+ if not state.failed and task.stage_index == state.total_stages - 1:
1702
+ merged = build_merged(ctx)
1703
+ errors_in_doc = sorted(validator.iter_errors(merged), key=lambda e: e.path)
1704
+ if errors_in_doc:
1705
+ final_validation_error = errors_in_doc[0].message
1706
+
1707
+ async with state.lock:
1708
+ if final_validation_error:
1709
+ errors.append(
1710
+ ExtractionError(
1711
+ path=task.path,
1712
+ provider=provider.name,
1713
+ model=model,
1714
+ error_type="validation_error",
1715
+ error_message=f"Schema validation failed: {final_validation_error}",
1716
+ stage_name=current_stage,
1717
+ )
1718
+ )
1719
+ state.failed = True
1720
+ if not state.failed:
1721
+ state.next_index += 1
1722
+ if state.next_index >= state.total_stages or state.failed:
1723
+ if doc_bar and state.total_stages:
1724
+ doc_bar.update(1)
1725
+ state.event.set()
1726
+ state.event = asyncio.Event()
1727
+
1728
+ async def worker() -> None:
1729
+ while True:
1730
+ if shutdown_event.is_set():
1731
+ await drain_queue(task_queue, drain_lock)
1732
+ return
1733
+ await await_key_pool_ready()
1734
+ try:
1735
+ task = task_queue.get_nowait()
1736
+ except asyncio.QueueEmpty:
1737
+ return
1738
+ await run_task(task)
1739
+ task_queue.task_done()
1740
+
1741
+ workers = [asyncio.create_task(worker()) for _ in range(max_concurrency)]
1742
+ await wait_for_queue(task_queue, drain_lock)
1743
+ for w in workers:
1744
+ w.cancel()
1745
+
1746
+ async def run_multi_stage_dag(client: httpx.AsyncClient) -> None:
1747
+ metadata_fields = [
1748
+ "paper_title",
1749
+ "paper_authors",
1750
+ "publication_date",
1751
+ "publication_venue",
1752
+ ]
1753
+ force_stage_set = set(force_stages or [])
1754
+ stage_dependencies = _resolve_stage_dependencies(stage_definitions)
1755
+ dependents_map = _build_dependency_graph(stage_definitions, stage_dependencies)
1756
+ stage_index_map = {stage_def.name: idx for idx, stage_def in enumerate(stage_definitions)}
1757
+ stage_fields_map = {
1758
+ stage_def.name: metadata_fields + stage_def.fields for stage_def in stage_definitions
1759
+ }
1760
+ stage_schema_map: dict[str, dict[str, Any]] = {}
1761
+ stage_validator_map: dict[str, Draft7Validator] = {}
1762
+ prompt_hash_map: dict[str, str] = {}
1763
+ for stage_def in stage_definitions:
1764
+ stage_name = stage_def.name
1765
+ stage_schema = build_stage_schema(schema, stage_fields_map[stage_name])
1766
+ stage_schema_map[stage_name] = stage_schema
1767
+ stage_validator_map[stage_name] = validate_schema(stage_schema)
1768
+ prompt_hash_map[stage_name] = _compute_prompt_hash(
1769
+ prompt_template=prompt_template,
1770
+ output_language=output_language,
1771
+ stage_name=stage_name,
1772
+ stage_fields=stage_fields_map[stage_name],
1773
+ custom_prompt=custom_prompt,
1774
+ prompt_system_path=prompt_system_path,
1775
+ prompt_user_path=prompt_user_path,
1776
+ )
1777
+
1778
+ total_reused = 0
1779
+ fully_completed = 0
1780
+
1781
+ doc_contexts: dict[Path, DocContext] = {}
1782
+ doc_states: dict[Path, DocDagState] = {}
1783
+ task_queue: asyncio.Queue[DocTask] = asyncio.Queue()
1784
+ drain_lock = asyncio.Lock()
1785
+
1786
+ async def enqueue_ready(path: Path, stage_names: Iterable[str]) -> None:
1787
+ if shutdown_event.is_set():
1788
+ return
1789
+ state = doc_states[path]
1790
+ for stage_name in stage_names:
1791
+ async with state.lock:
1792
+ if state.failed:
1793
+ continue
1794
+ if stage_name in state.completed or stage_name in state.in_flight:
1795
+ continue
1796
+ dependencies = stage_dependencies.get(stage_name, [])
1797
+ if any(dep not in state.completed for dep in dependencies):
1798
+ continue
1799
+ state.in_flight.add(stage_name)
1800
+ task_queue.put_nowait(
1801
+ DocTask(
1802
+ path=path,
1803
+ stage_index=stage_index_map[stage_name],
1804
+ stage_name=stage_name,
1805
+ stage_fields=stage_fields_map[stage_name],
1806
+ )
777
1807
  )
1808
+
1809
+        for path in markdown_files:
+            if shutdown_event.is_set():
+                break
+            source_path = str(path.resolve())
+            if verbose:
+                logger.debug("Preparing %s", source_path)
+            content = read_text(path)
+            await stats.add_input_chars(len(content))
+            source_hash = compute_source_hash(content)
+            truncated_content, truncation = truncate_content(
+                content, config.extract.truncate_max_chars, config.extract.truncate_strategy
+            )
+            stage_path = stage_output_dir / f"{stable_hash(source_path)}.json"
+            stage_state = load_stage_state(stage_path) if not force else None
+            if stage_state and stage_state.get("source_hash") != source_hash:
+                stage_state = None
+            if stage_state is None:
+                stage_state = {
+                    "source_path": source_path,
+                    "source_hash": source_hash,
+                    "prompt_template": prompt_template,
+                    "output_language": output_language,
+                    "stages": {},
+                    "stage_meta": {},
+                }
+            stages: dict[str, dict[str, Any]] = stage_state.get("stages", {})
+            stage_meta: dict[str, dict[str, Any]] = stage_state.get("stage_meta", {})
+            doc_contexts[path] = DocContext(
+                path=path,
+                source_path=source_path,
+                content=content,
+                truncated_content=truncated_content,
+                truncation=truncation,
+                source_hash=source_hash,
+                stage_path=stage_path,
+                stage_state=stage_state,
+                stages=stages,
+                stage_meta=stage_meta,
+            )
+            doc_states[path] = DocDagState(
+                total_stages=len(stage_definitions),
+                remaining=len(stage_definitions),
+            )
+            state = doc_states[path]
+            is_retry_full = source_path in retry_full_paths
+            retry_stages = (
+                retry_stage_map.get(source_path)
+                if retry_stages_mode and not is_retry_full
+                else None
+            )
+            reused_count = 0
+            for stage_def in stage_definitions:
+                stage_name = stage_def.name
+                stage_record = stages.get(stage_name)
+                if (
+                    retry_stages_mode
+                    and not is_retry_full
+                    and retry_stages is not None
+                    and stage_name not in retry_stages
+                    and stage_record is not None
+                ):
+                    state.completed.add(stage_name)
+                    state.remaining -= 1
+                    reused_count += 1
+                    continue
+
+                stage_meta_entry = stage_meta.get(stage_name, {})
+                needs_run = force or stage_name in force_stage_set
+                if stage_record is None:
+                    needs_run = True
+                if is_retry_full:
+                    needs_run = True
+                if retry_stages is not None and stage_name in retry_stages:
+                    needs_run = True
+                if stage_meta_entry.get("prompt_hash") != prompt_hash_map[stage_name]:
+                    needs_run = True
+                if stage_record is not None and not needs_run:
+                    errors_in_stage = sorted(
+                        stage_validator_map[stage_name].iter_errors(stage_record),
+                        key=lambda e: e.path,
+                    )
+                    if errors_in_stage:
+                        needs_run = True
+
+                if not needs_run:
+                    state.completed.add(stage_name)
+                    state.remaining -= 1
+                    reused_count += 1
+
+            if reused_count:
+                total_reused += reused_count
+                if stage_bar:
+                    stage_bar.update(reused_count)
+
+            if state.remaining == 0:
+                state.finalized = True
+                if doc_bar:
+                    doc_bar.update(1)
+                await update_results(doc_contexts[path])
+                merged = build_merged(doc_contexts[path])
+                errors_in_doc = sorted(validator.iter_errors(merged), key=lambda e: e.path)
+                if errors_in_doc:
+                    errors.append(
+                        ExtractionError(
+                            path=path,
+                            provider=provider.name,
+                            model=model,
+                            error_type="validation_error",
+                            error_message=f"Schema validation failed: {errors_in_doc[0].message}",
+                            stage_name=stage_definitions[-1].name if stage_definitions else None,
+                        )
+                    )
+                    state.failed = True
+                fully_completed += 1
+                continue
+
+            await enqueue_ready(path, [stage_def.name for stage_def in stage_definitions])
+
+        if total_reused:
+            logger.info(
+                "DAG precheck reused %d/%d stages",
+                total_reused,
+                len(markdown_files) * len(stage_definitions),
+            )
+        if fully_completed:
+            logger.info("DAG precheck fully satisfied %d docs (no stages queued)", fully_completed)
+
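+        # finalize_stage records a stage's completion, charges stages skipped by
+        # a failure to the progress bar, and finalizes the document once no
+        # stages remain.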
+        async def finalize_stage(
+            ctx: DocContext,
+            stage_name: str,
+            *,
+            failed: bool,
+        ) -> None:
+            state = doc_states[ctx.path]
+            skip_count = 0
+            doc_done = False
+            is_failed = False
+            async with state.lock:
+                state.in_flight.discard(stage_name)
+                if stage_name not in state.completed:
+                    state.completed.add(stage_name)
+                    state.remaining -= 1
+                if failed and not state.failed:
+                    state.failed = True
+                    skip_count = state.remaining - len(state.in_flight)
+                    if skip_count < 0:
+                        skip_count = 0
+                    state.remaining -= skip_count
+                if state.remaining == 0 and not state.finalized:
+                    state.finalized = True
+                    doc_done = True
+                is_failed = state.failed
+
+            if stage_bar:
+                stage_bar.update(1)
+                if skip_count:
+                    stage_bar.update(skip_count)
+
+            if not is_failed:
+                await enqueue_ready(ctx.path, dependents_map.get(stage_name, []))
+
+            if doc_done:
+                if doc_bar:
+                    doc_bar.update(1)
+                if not is_failed:
+                    merged = build_merged(ctx)
+                    errors_in_doc = sorted(validator.iter_errors(merged), key=lambda e: e.path)
+                    if errors_in_doc:
+                        errors.append(
+                            ExtractionError(
+                                path=ctx.path,
+                                provider=provider.name,
+                                model=model,
+                                error_type="validation_error",
+                                error_message=f"Schema validation failed: {errors_in_doc[0].message}",
+                                stage_name=stage_name,
+                            )
+                        )
+                        async with state.lock:
+                            state.failed = True
+
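+        # run_task re-evaluates the reuse conditions (state may have changed
+        # since the precheck) before spending a provider call on a stage.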
+        async def run_task(task: DocTask) -> None:
+            ctx = doc_contexts[task.path]
+            state = doc_states[task.path]
+            current_stage = task.stage_name
+
+            if shutdown_event.is_set():
+                await finalize_stage(ctx, current_stage, failed=False)
+                return
+            async with state.lock:
+                if state.failed:
+                    state.in_flight.discard(current_stage)
+            if state.failed:
+                await finalize_stage(ctx, current_stage, failed=False)
+                return
+
+            is_retry_full = ctx.source_path in retry_full_paths
+            retry_stages = (
+                retry_stage_map.get(ctx.source_path)
+                if retry_stages_mode and not is_retry_full
+                else None
+            )
+            stage_record = ctx.stages.get(current_stage)
+            if (
+                retry_stages_mode
+                and not is_retry_full
+                and retry_stages is not None
+                and current_stage not in retry_stages
+                and stage_record is not None
+            ):
+                await update_results(ctx)
+                await finalize_stage(ctx, current_stage, failed=False)
+                return
+
+            prompt_hash = _compute_prompt_hash(
+                prompt_template=prompt_template,
+                output_language=output_language,
+                stage_name=current_stage,
+                stage_fields=task.stage_fields,
+                custom_prompt=custom_prompt,
+                prompt_system_path=prompt_system_path,
+                prompt_user_path=prompt_user_path,
             )
-        finally:
-            if doc_bar:
-                doc_bar.update(1)
+            stage_schema = build_stage_schema(schema, task.stage_fields)
+            stage_validator = validate_schema(stage_schema)
+            stage_meta = ctx.stage_meta.get(current_stage, {})
+            needs_run = force or current_stage in force_stage_set
+            if stage_record is None:
+                needs_run = True
+            if is_retry_full:
+                needs_run = True
+            if retry_stages is not None and current_stage in retry_stages:
+                needs_run = True
+            if stage_meta.get("prompt_hash") != prompt_hash:
+                needs_run = True
+            if stage_record is not None and not needs_run:
+                errors_in_stage = sorted(
+                    stage_validator.iter_errors(stage_record), key=lambda e: e.path
+                )
+                if errors_in_stage:
+                    needs_run = True
+
+            if not needs_run:
+                await update_results(ctx)
+                await finalize_stage(ctx, current_stage, failed=False)
+                return
+
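+            # Completed dependency outputs are serialized to JSON and handed to
+            # the prompt builder as prior context for this stage.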
+            try:
+                dependencies = stage_dependencies.get(current_stage, [])
+                previous_payload = {
+                    dep: ctx.stages.get(dep) for dep in dependencies if dep in ctx.stages
+                }
+                previous_outputs = (
+                    json.dumps(previous_payload, ensure_ascii=False) if previous_payload else ""
+                )
+                messages = build_messages(
+                    ctx.truncated_content,
+                    stage_schema,
+                    provider,
+                    prompt_template,
+                    output_language,
+                    custom_prompt=False,
+                    prompt_system_path=None,
+                    prompt_user_path=None,
+                    stage_name=current_stage,
+                    stage_fields=task.stage_fields,
+                    previous_outputs=previous_outputs,
+                )
+                async with semaphore:
+                    data = await call_with_retries(
+                        provider,
+                        model,
+                        messages,
+                        stage_schema,
+                        None,
+                        timeout=timeout_seconds,
+                        structured_mode=provider.structured_mode,
+                        max_retries=config.extract.max_retries,
+                        backoff_base_seconds=config.extract.backoff_base_seconds,
+                        backoff_max_seconds=config.extract.backoff_max_seconds,
+                        client=client,
+                        validator=stage_validator,
+                        key_rotator=rotator,
+                        throttle=throttle,
+                        stats=stats,
+                    )
+                async with state.lock:
+                    ctx.stages[current_stage] = data
+                    ctx.stage_meta[current_stage] = {"prompt_hash": prompt_hash}
+                    ctx.stage_state["stages"] = ctx.stages
+                    ctx.stage_state["stage_meta"] = ctx.stage_meta
+                    write_json_atomic(ctx.stage_path, ctx.stage_state)
+                await update_results(ctx)
+                await finalize_stage(ctx, current_stage, failed=False)
+            except ProviderError as exc:
+                log_extraction_failure(
+                    ctx.source_path,
+                    exc.error_type,
+                    str(exc),
+                    status_code=exc.status_code,
+                )
+                errors.append(
+                    ExtractionError(
+                        path=task.path,
+                        provider=provider.name,
+                        model=model,
+                        error_type=exc.error_type,
+                        error_message=str(exc),
+                        stage_name=current_stage,
+                    )
+                )
+                await finalize_stage(ctx, current_stage, failed=True)
+            except Exception as exc:  # pragma: no cover - safety net
+                logger.exception("Unexpected error while processing %s", ctx.source_path)
+                errors.append(
+                    ExtractionError(
+                        path=task.path,
+                        provider=provider.name,
+                        model=model,
+                        error_type="unexpected_error",
+                        error_message=str(exc),
+                        stage_name=current_stage,
+                    )
+                )
+                await finalize_stage(ctx, current_stage, failed=True)
+
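+        # Each worker loops until the queue is empty; on shutdown it drains the
+        # queue so pending tasks are released instead of executed.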
+        async def worker() -> None:
+            while True:
+                if shutdown_event.is_set():
+                    await drain_queue(task_queue, drain_lock)
+                    return
+                await await_key_pool_ready()
+                try:
+                    task = task_queue.get_nowait()
+                except asyncio.QueueEmpty:
+                    return
+                await run_task(task)
+                task_queue.task_done()
+
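+        # Spin up at most max_concurrency workers, wait for the queue to drain,
+        # then cancel them.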
+        workers = [asyncio.create_task(worker()) for _ in range(max_concurrency)]
+        await wait_for_queue(task_queue, drain_lock)
+        for w in workers:
+            w.cancel()
 
     try:
         async with httpx.AsyncClient() as client:
-            await asyncio.gather(*(process_one(path, client) for path in markdown_files))
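+            # Replace the old per-document gather with mode dispatch; the DAG
+            # scheduler runs only when stage_dag_enabled is set.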
+            if multi_stage:
+                if stage_dag_enabled:
+                    await run_multi_stage_dag(client)
+                else:
+                    await run_multi_stage(client)
+            else:
+                await run_single_stage(client)
     finally:
+        if pause_task:
+            pause_task.cancel()
+            with contextlib.suppress(asyncio.CancelledError):
+                await pause_task
         if doc_bar:
             doc_bar.close()
         if stage_bar:
@@ -796,7 +2176,7 @@ async def extract_documents(
         if path and path in results:
             final_results.append(results[path])
             seen.add(path)
-        elif path and not retry_failed:
+        elif path:
            final_results.append(entry)
            seen.add(path)
 
@@ -820,6 +2200,12 @@ async def extract_documents(
     ]
     write_json(errors_path, error_payload)
 
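+    # Make an interrupted run visible in the logs instead of exiting silently.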
+    if shutdown_event.is_set():
+        logger.warning(
+            "Graceful shutdown completed (%s)",
+            shutdown_reason or "signal",
+        )
+
     if split:
         target_dir = split_dir or output_path.parent
         target_dir.mkdir(parents=True, exist_ok=True)
@@ -861,7 +2247,15 @@ async def extract_documents(
     table.add_column("Value", style="white", overflow="fold")
     table.add_row("Documents", f"{doc_count} total")
     table.add_row("Successful", str(doc_count - len(errors)))
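+    # Stage-level accounting for the summary table: failed stages carry a
+    # stage_name on their error record; retried stages come from the retry map.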
+    failed_stage_count = sum(1 for err in errors if err.stage_name)
+    retried_stage_count = 0
+    if retry_stages_mode:
+        retried_stage_count = sum(
+            len(retry_stage_map.get(str(path.resolve()), set())) for path in markdown_files
+        )
     table.add_row("Errors", str(len(errors)))
+    table.add_row("Failed stages", str(failed_stage_count))
+    table.add_row("Retried stages", str(retried_stage_count))
     table.add_row("Output JSON", str(output_path))
     table.add_row("Errors JSON", str(errors_path))
     table.add_row("Duration", _format_duration(duration))