deepresearch-flow 0.5.1__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepresearch_flow/paper/cli.py +63 -0
- deepresearch_flow/paper/config.py +87 -12
- deepresearch_flow/paper/db.py +1154 -35
- deepresearch_flow/paper/db_ops.py +124 -19
- deepresearch_flow/paper/extract.py +1546 -152
- deepresearch_flow/paper/prompt_templates/deep_read_phi_system.j2 +2 -0
- deepresearch_flow/paper/prompt_templates/deep_read_phi_user.j2 +5 -0
- deepresearch_flow/paper/prompt_templates/deep_read_system.j2 +2 -0
- deepresearch_flow/paper/prompt_templates/deep_read_user.j2 +272 -40
- deepresearch_flow/paper/prompt_templates/eight_questions_phi_system.j2 +1 -0
- deepresearch_flow/paper/prompt_templates/eight_questions_phi_user.j2 +2 -0
- deepresearch_flow/paper/prompt_templates/eight_questions_system.j2 +2 -0
- deepresearch_flow/paper/prompt_templates/eight_questions_user.j2 +4 -0
- deepresearch_flow/paper/prompt_templates/simple_phi_system.j2 +2 -0
- deepresearch_flow/paper/prompt_templates/simple_system.j2 +2 -0
- deepresearch_flow/paper/prompt_templates/simple_user.j2 +2 -0
- deepresearch_flow/paper/providers/azure_openai.py +45 -3
- deepresearch_flow/paper/providers/openai_compatible.py +45 -3
- deepresearch_flow/paper/schemas/deep_read_phi_schema.json +1 -0
- deepresearch_flow/paper/schemas/deep_read_schema.json +1 -0
- deepresearch_flow/paper/schemas/default_paper_schema.json +6 -0
- deepresearch_flow/paper/schemas/eight_questions_schema.json +1 -0
- deepresearch_flow/paper/snapshot/__init__.py +4 -0
- deepresearch_flow/paper/snapshot/api.py +941 -0
- deepresearch_flow/paper/snapshot/builder.py +965 -0
- deepresearch_flow/paper/snapshot/identity.py +239 -0
- deepresearch_flow/paper/snapshot/schema.py +245 -0
- deepresearch_flow/paper/snapshot/tests/__init__.py +2 -0
- deepresearch_flow/paper/snapshot/tests/test_identity.py +123 -0
- deepresearch_flow/paper/snapshot/text.py +154 -0
- deepresearch_flow/paper/template_registry.py +1 -0
- deepresearch_flow/paper/templates/deep_read.md.j2 +4 -0
- deepresearch_flow/paper/templates/deep_read_phi.md.j2 +4 -0
- deepresearch_flow/paper/templates/default_paper.md.j2 +4 -0
- deepresearch_flow/paper/templates/eight_questions.md.j2 +4 -0
- deepresearch_flow/paper/web/app.py +10 -3
- deepresearch_flow/recognize/cli.py +380 -103
- deepresearch_flow/recognize/markdown.py +31 -7
- deepresearch_flow/recognize/math.py +47 -12
- deepresearch_flow/recognize/mermaid.py +320 -10
- deepresearch_flow/recognize/organize.py +29 -7
- deepresearch_flow/translator/cli.py +71 -20
- deepresearch_flow/translator/engine.py +220 -81
- deepresearch_flow/translator/prompts.py +19 -2
- deepresearch_flow/translator/protector.py +15 -3
- deepresearch_flow-0.6.1.dist-info/METADATA +849 -0
- {deepresearch_flow-0.5.1.dist-info → deepresearch_flow-0.6.1.dist-info}/RECORD +51 -43
- {deepresearch_flow-0.5.1.dist-info → deepresearch_flow-0.6.1.dist-info}/WHEEL +1 -1
- deepresearch_flow-0.5.1.dist-info/METADATA +0 -440
- {deepresearch_flow-0.5.1.dist-info → deepresearch_flow-0.6.1.dist-info}/entry_points.txt +0 -0
- {deepresearch_flow-0.5.1.dist-info → deepresearch_flow-0.6.1.dist-info}/licenses/LICENSE +0 -0
- {deepresearch_flow-0.5.1.dist-info → deepresearch_flow-0.6.1.dist-info}/top_level.txt +0 -0
@@ -3,13 +3,19 @@
 from __future__ import annotations
 
 import asyncio
+import hashlib
 import json
+import math
+from collections import deque
 from dataclasses import dataclass, field
-from datetime import datetime
+from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any, Iterable
+import importlib.resources as resources
+import contextlib
 import logging
 import re
+import signal
 import time
 
 import coloredlogs
@@ -19,12 +25,20 @@ from jsonschema import Draft7Validator
 from rich.console import Console
 from rich.table import Table
 from tqdm import tqdm
-from deepresearch_flow.paper.config import
+from deepresearch_flow.paper.config import (
+    ApiKeyConfig,
+    PaperConfig,
+    ProviderConfig,
+    resolve_api_key_configs,
+    resolve_api_keys,
+)
 from deepresearch_flow.paper.llm import backoff_delay, call_provider
 from deepresearch_flow.paper.prompts import DEFAULT_SYSTEM_PROMPT, DEFAULT_USER_PROMPT
 from deepresearch_flow.paper.render import render_papers, resolve_render_template
 from deepresearch_flow.paper.schema import schema_to_prompt, validate_schema
 from deepresearch_flow.paper.template_registry import (
+    StageDefinition,
+    get_template_bundle,
     get_stage_definitions,
     load_custom_prompt_templates,
     load_prompt_templates,
@@ -53,22 +67,257 @@ class ExtractionError:
     stage_name: str | None = None
 
 
+def _mask_key(value: str, *, keep: int = 4) -> str:
+    if not value:
+        return "<empty>"
+    if len(value) <= keep:
+        return value
+    return f"...{value[-keep:]}"
+
+
 class KeyRotator:
-    def __init__(self, keys: list[
+    def __init__(self, keys: list[ApiKeyConfig], *, cooldown_seconds: float, verbose: bool) -> None:
         self._keys = keys
         self._idx = 0
         self._lock = asyncio.Lock()
+        self._cooldown_seconds = max(cooldown_seconds, 0.0)
+        self._verbose = verbose
+        self._cooldowns: dict[str, float] = {key.key: 0.0 for key in keys}
+        self._quota_until: dict[str, float] = {key.key: 0.0 for key in keys}
+        self._key_meta: dict[str, ApiKeyConfig] = {key.key: key for key in keys}
+        self._error_counts: dict[str, int] = {key.key: 0 for key in keys}
+        self._last_pause_until: float = 0.0
+        self._last_key_quota_until: dict[str, float] = {key.key: 0.0 for key in keys}
 
     async def next_key(self) -> str | None:
         if not self._keys:
             return None
+        while True:
+            wait_for: float | None = None
+            wait_until_epoch: float | None = None
+            pause_reason: str | None = None
+            should_log_pause = False
+            async with self._lock:
+                now = time.monotonic()
+                now_epoch = time.time()
+                total = len(self._keys)
+                for offset in range(total):
+                    idx = (self._idx + offset) % total
+                    key = self._keys[idx].key
+                    if (
+                        self._cooldowns.get(key, 0.0) <= now
+                        and self._quota_until.get(key, 0.0) <= now_epoch
+                    ):
+                        self._idx = idx + 1
+                        return key
+                waits: list[float] = []
+                has_cooldown_wait = False
+                has_quota_wait = False
+                for meta in self._keys:
+                    key = meta.key
+                    cooldown_wait = max(self._cooldowns.get(key, 0.0) - now, 0.0)
+                    quota_wait = max(self._quota_until.get(key, 0.0) - now_epoch, 0.0)
+                    if cooldown_wait > 0:
+                        has_cooldown_wait = True
+                    if quota_wait > 0:
+                        has_quota_wait = True
+                    waits.append(max(cooldown_wait, quota_wait))
+                wait_for = min(waits) if waits else None
+                if wait_for is not None:
+                    wait_until_epoch = now_epoch + wait_for
+                    if wait_until_epoch > self._last_pause_until + 0.5:
+                        self._last_pause_until = wait_until_epoch
+                        if has_quota_wait and has_cooldown_wait:
+                            pause_reason = "quota/cooldown"
+                        elif has_quota_wait:
+                            pause_reason = "quota"
+                        elif has_cooldown_wait:
+                            pause_reason = "cooldown"
+                        else:
+                            pause_reason = "unknown"
+                        should_log_pause = True
+            if wait_for is None:
+                return None
+            wait_for = max(wait_for, 0.01)
+            if should_log_pause and wait_until_epoch is not None:
+                reset_dt = datetime.fromtimestamp(wait_until_epoch).astimezone().isoformat()
+                logger.warning(
+                    "All API keys unavailable (%s); pausing %.2fs until %s",
+                    pause_reason,
+                    wait_for,
+                    reset_dt,
+                )
+            elif self._verbose:
+                logger.debug("All API keys cooling down; waiting %.2fs", wait_for)
+            await asyncio.sleep(wait_for)
+
+    async def key_pool_wait(self) -> tuple[float | None, str | None, float | None]:
+        if not self._keys:
+            return None, None, None
         async with self._lock:
-
-
-
+            now = time.monotonic()
+            now_epoch = time.time()
+            for meta in self._keys:
+                key = meta.key
+                if (
+                    self._cooldowns.get(key, 0.0) <= now
+                    and self._quota_until.get(key, 0.0) <= now_epoch
+                ):
+                    return 0.0, None, None
+
+            waits: list[float] = []
+            has_cooldown_wait = False
+            has_quota_wait = False
+            for meta in self._keys:
+                key = meta.key
+                cooldown_wait = max(self._cooldowns.get(key, 0.0) - now, 0.0)
+                quota_wait = max(self._quota_until.get(key, 0.0) - now_epoch, 0.0)
+                if cooldown_wait > 0:
+                    has_cooldown_wait = True
+                if quota_wait > 0:
+                    has_quota_wait = True
+                waits.append(max(cooldown_wait, quota_wait))
+
+            wait_for = min(waits) if waits else None
+            if wait_for is None:
+                return None, None, None
+            if has_quota_wait and has_cooldown_wait:
+                reason = "quota/cooldown"
+            elif has_quota_wait:
+                reason = "quota"
+            elif has_cooldown_wait:
+                reason = "cooldown"
+            else:
+                reason = "unknown"
+            wait_until_epoch = now_epoch + wait_for
+            return wait_for, reason, wait_until_epoch
+
+    async def mark_error(self, key: str) -> None:
+        if key not in self._cooldowns:
+            return
+        async with self._lock:
+            now = time.monotonic()
+            self._error_counts[key] = self._error_counts.get(key, 0) + 1
+            cooldown_until = now + self._cooldown_seconds
+            current = self._cooldowns.get(key, 0.0)
+            self._cooldowns[key] = max(current, cooldown_until)
+            if cooldown_until > current:
+                logger.warning(
+                    "API key %s cooling down for %.2fs (errors=%d)",
+                    _mask_key(key),
+                    self._cooldown_seconds,
+                    self._error_counts[key],
+                )
+            elif self._verbose:
+                logger.debug(
+                    "API key cooldown applied (%.2fs, errors=%d)",
+                    self._cooldown_seconds,
+                    self._error_counts[key],
+                )
+
+    async def mark_quota_exceeded(self, key: str, message: str, status_code: int | None) -> bool:
+        if key not in self._key_meta:
+            return False
+        meta = self._key_meta[key]
+        tokens = meta.quota_error_tokens
+        if not tokens:
+            return False
+        candidate = message
+        try:
+            data = json.loads(message)
+        except (TypeError, json.JSONDecodeError):
+            data = None
+        if isinstance(data, dict):
+            collected: list[str] = [message]
+            error = data.get("error")
+            if isinstance(error, dict):
+                for key_name in ("code", "type", "message"):
+                    value = error.get(key_name)
+                    if isinstance(value, str):
+                        collected.append(value)
+            for key_name in ("code", "type", "message"):
+                value = data.get(key_name)
+                if isinstance(value, str):
+                    collected.append(value)
+            candidate = " ".join(collected)
+        lower_msg = candidate.lower()
+        tokens_match = any(token.lower() in lower_msg for token in tokens)
+        if not tokens_match:
+            return False
+        matched_tokens = [token for token in tokens if token.lower() in lower_msg]
+        reset_epoch = _compute_next_reset_epoch(meta)
+        if reset_epoch is None:
+            logger.warning(
+                "API key %s hit quota trigger but no reset_time/quota_duration configured "
+                "(matched=%s, status_code=%s)",
+                _mask_key(key),
+                ",".join(matched_tokens) or "<none>",
+                status_code if status_code is not None else "unknown",
+            )
+            return False
+        async with self._lock:
+            current = self._quota_until.get(key, 0.0)
+            self._quota_until[key] = max(current, reset_epoch)
+            if reset_epoch > self._last_key_quota_until.get(key, 0.0):
+                self._last_key_quota_until[key] = reset_epoch
+                wait_for = max(reset_epoch - time.time(), 0.0)
+                reset_dt = datetime.fromtimestamp(reset_epoch).astimezone().isoformat()
+                logger.warning(
+                    "API key %s quota exhausted; pausing %.2fs until %s (matched=%s, status_code=%s)",
+                    _mask_key(key),
+                    wait_for,
+                    reset_dt,
+                    ",".join(matched_tokens) or "<none>",
+                    status_code if status_code is not None else "unknown",
+                )
+            elif self._verbose:
+                wait_for = max(reset_epoch - time.time(), 0.0)
+                reset_dt = datetime.fromtimestamp(reset_epoch).astimezone().isoformat()
+                logger.debug(
+                    "API key %s quota exhausted; cooldown %.2fs until %s",
+                    _mask_key(key),
+                    wait_for,
+                    reset_dt,
+                )
+        return True
 
 
 logger = logging.getLogger(__name__)
+_console = Console()
+
+
+def log_extraction_failure(
+    path: str,
+    error_type: str,
+    error_message: str,
+    *,
+    status_code: int | None = None,
+) -> None:
+    message = error_message.strip()
+    if not message:
+        message = "no error message"
+    if status_code is not None and f"status_code={status_code}" not in message:
+        message = f"{message} (status_code={status_code})"
+    console_message = message
+    if len(console_message) > 500:
+        console_message = f"{console_message[:500]}..."
+    logger.warning(
+        "Extraction failed for %s (%s): %s",
+        path,
+        error_type,
+        message,
+    )
+    _console.print(
+        f"[bold red]Extraction failed[/] [dim]{path}[/]\n"
+        f"[bold yellow]{error_type}[/]: {console_message}"
+    )
+
+
+def _summarize_error_message(message: str, limit: int = 300) -> str:
+    text = (message or "").strip() or "no error message"
+    if len(text) > limit:
+        return f"{text[:limit]}..."
+    return text
 
 
 def configure_logging(verbose: bool) -> None:
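The new `KeyRotator` scans keys round-robin and hands out the first one that is neither in an error cooldown (tracked on the monotonic clock) nor inside a quota window (tracked on the wall clock). A condensed, self-contained sketch of that scan, with the locking and logging stripped out; the `Key` class is a stand-in for `ApiKeyConfig`:

```python
import time


class Key:
    """Stand-in for ApiKeyConfig; only the .key attribute matters here."""

    def __init__(self, key: str) -> None:
        self.key = key


def pick_key(
    keys: list[Key],
    cooldowns: dict[str, float],
    quota_until: dict[str, float],
    start_idx: int,
) -> tuple[str | None, int]:
    """Return (key, next_start_idx) for the first usable key, else (None, start_idx)."""
    now = time.monotonic()   # error cooldowns expire on the monotonic clock
    now_epoch = time.time()  # quota windows expire at wall-clock epochs
    total = len(keys)
    for offset in range(total):
        idx = (start_idx + offset) % total
        key = keys[idx].key
        if cooldowns.get(key, 0.0) <= now and quota_until.get(key, 0.0) <= now_epoch:
            return key, idx + 1  # resume after this key on the next call
    return None, start_idx       # every key is cooling down or out of quota
```

When the scan comes up empty, the real method computes the minimum wait across all keys and sleeps instead of returning, so callers only ever see a usable key or an empty pool.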
@@ -76,10 +325,180 @@ def configure_logging(verbose: bool) -> None:
     coloredlogs.install(level=level, fmt="%(asctime)s %(levelname)s %(message)s")
 
 
+@dataclass(frozen=True)
+class DocTask:
+    path: Path
+    stage_index: int
+    stage_name: str
+    stage_fields: list[str]
+
+
+@dataclass
+class DocState:
+    total_stages: int
+    next_index: int = 0
+    failed: bool = False
+    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
+    event: asyncio.Event = field(default_factory=asyncio.Event)
+
+
+@dataclass
+class DocDagState:
+    total_stages: int
+    remaining: int
+    completed: set[str] = field(default_factory=set)
+    in_flight: set[str] = field(default_factory=set)
+    failed: bool = False
+    finalized: bool = False
+    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
+
+
+@dataclass
+class DocContext:
+    path: Path
+    source_path: str
+    content: str
+    truncated_content: str
+    truncation: dict[str, Any] | None
+    source_hash: str
+    stage_path: Path | None
+    stage_state: dict[str, Any] | None
+    stages: dict[str, dict[str, Any]]
+    stage_meta: dict[str, dict[str, Any]]
+
+
 def _count_prompt_chars(messages: list[dict[str, str]]) -> int:
     return sum(len(message.get("content") or "") for message in messages)
 
 
+def _load_prompt_template_sources(name: str) -> tuple[str, str]:
+    bundle = get_template_bundle(name)
+    system_path = resources.files("deepresearch_flow.paper.prompt_templates").joinpath(
+        bundle.prompt_system
+    )
+    user_path = resources.files("deepresearch_flow.paper.prompt_templates").joinpath(
+        bundle.prompt_user
+    )
+    return (
+        system_path.read_text(encoding="utf-8"),
+        user_path.read_text(encoding="utf-8"),
+    )
+
+
+def _compute_prompt_hash(
+    *,
+    prompt_template: str,
+    output_language: str,
+    stage_name: str | None,
+    stage_fields: list[str],
+    custom_prompt: bool,
+    prompt_system_path: Path | None,
+    prompt_user_path: Path | None,
+) -> str:
+    if custom_prompt and prompt_system_path and prompt_user_path:
+        system_text = prompt_system_path.read_text(encoding="utf-8")
+        user_text = prompt_user_path.read_text(encoding="utf-8")
+    else:
+        system_text, user_text = _load_prompt_template_sources(prompt_template)
+    payload = {
+        "prompt_template": prompt_template,
+        "output_language": output_language,
+        "stage_name": stage_name or "",
+        "stage_fields": stage_fields,
+        "system_template": system_text,
+        "user_template": user_text,
+    }
+    raw = json.dumps(payload, ensure_ascii=False, sort_keys=True)
+    return hashlib.sha256(raw.encode("utf-8")).hexdigest()
+
+
+def _resolve_stage_dependencies(
+    stage_definitions: list[StageDefinition],
+) -> dict[str, list[str]]:
+    deps: dict[str, list[str]] = {}
+    for idx, stage_def in enumerate(stage_definitions):
+        if stage_def.depends_on is None:
+            if idx == 0:
+                deps[stage_def.name] = []
+            else:
+                deps[stage_def.name] = [stage_definitions[idx - 1].name]
+        else:
+            deps[stage_def.name] = list(stage_def.depends_on)
+    return deps
+
+
+def _build_dependency_graph(
+    stage_definitions: list[StageDefinition],
+    deps: dict[str, list[str]],
+) -> dict[str, list[str]]:
+    stage_names = {stage_def.name for stage_def in stage_definitions}
+    for stage_name, dependencies in deps.items():
+        for dependency in dependencies:
+            if dependency not in stage_names:
+                raise ValueError(
+                    f"Stage '{stage_name}' depends on unknown stage '{dependency}'"
+                )
+
+    dependents: dict[str, list[str]] = {name: [] for name in stage_names}
+    indegree: dict[str, int] = {name: 0 for name in stage_names}
+    for stage_name, dependencies in deps.items():
+        indegree[stage_name] = len(dependencies)
+        for dependency in dependencies:
+            dependents[dependency].append(stage_name)
+
+    queue = deque(name for name, degree in indegree.items() if degree == 0)
+    visited = 0
+    while queue:
+        node = queue.popleft()
+        visited += 1
+        for child in dependents[node]:
+            indegree[child] -= 1
+            if indegree[child] == 0:
+                queue.append(child)
+
+    if visited != len(stage_names):
+        raise ValueError("Stage dependency cycle detected in template definition")
+
+    return dependents
+
+
+def _parse_reset_time(reset_time: str) -> datetime | None:
+    candidate = reset_time.strip()
+    match = re.search(
+        r"(\d{4}-\d{2}-\d{2})[T ](\d{2}:\d{2}:\d{2}(?:\.\d+)?)(?:\s*(Z|[+-]\d{2}:?\d{2}))?",
+        candidate,
+    )
+    if not match:
+        return None
+    date_part, time_part, tz_part = match.group(1), match.group(2), match.group(3)
+    if tz_part:
+        tz_part = tz_part.replace("Z", "+00:00")
+        tz_part = re.sub(r"([+-]\d{2})(\d{2})$", r"\1:\2", tz_part)
+    iso_str = f"{date_part}T{time_part}{tz_part or ''}"
+    try:
+        return datetime.fromisoformat(iso_str)
+    except ValueError:
+        return None
+
+
+def _compute_next_reset_epoch(meta: ApiKeyConfig) -> float | None:
+    if not meta.reset_time or not meta.quota_duration:
+        return None
+    base = _parse_reset_time(meta.reset_time)
+    if not base:
+        return None
+    duration = meta.quota_duration
+    if duration <= 0:
+        return None
+    now = datetime.now(timezone.utc)
+    base_utc = base.astimezone(timezone.utc)
+    if now <= base_utc:
+        return base_utc.timestamp()
+    elapsed = (now - base_utc).total_seconds()
+    cycles = math.floor(elapsed / duration) + 1
+    return base_utc.timestamp() + cycles * duration
+
+
 def _estimate_tokens_for_chars(char_count: int) -> int:
     if char_count <= 0:
         return 0
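`_compute_next_reset_epoch` treats the configured `reset_time` as an anchor and rolls it forward in whole `quota_duration` cycles until the result lies in the future. A worked sketch of the same arithmetic; the anchor and duration values are invented for the example:

```python
import math
from datetime import datetime, timezone


def next_reset_epoch(base: datetime, duration_seconds: float, now: datetime) -> float:
    """Roll the anchor forward by whole cycles until it is at or after `now`."""
    base_utc = base.astimezone(timezone.utc)
    if now <= base_utc:
        return base_utc.timestamp()
    elapsed = (now - base_utc).total_seconds()
    cycles = math.floor(elapsed / duration_seconds) + 1
    return base_utc.timestamp() + cycles * duration_seconds


# A daily quota anchored at midnight UTC: two and a half days after the anchor,
# elapsed/duration = 2.5, so floor(2.5) + 1 = 3 cycles -> the upcoming midnight.
anchor = datetime(2024, 1, 1, tzinfo=timezone.utc)
now = datetime(2024, 1, 3, 12, 0, tzinfo=timezone.utc)
assert next_reset_epoch(anchor, 86400.0, now) == datetime(2024, 1, 4, tzinfo=timezone.utc).timestamp()
```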
@@ -304,6 +723,22 @@ def load_errors(path: Path) -> list[dict[str, Any]]:
     return []
 
 
+def load_retry_list(path: Path) -> list[dict[str, Any]]:
+    if not path.exists():
+        raise click.ClickException(f"Retry list JSON not found: {path}")
+    try:
+        data = json.loads(path.read_text(encoding="utf-8"))
+    except json.JSONDecodeError as exc:
+        raise click.ClickException(f"Invalid retry list JSON: {exc}") from exc
+    if isinstance(data, dict):
+        items = data.get("items")
+        if isinstance(items, list):
+            return [item for item in items if isinstance(item, dict)]
+    if isinstance(data, list):
+        return [item for item in data if isinstance(item, dict)]
+    raise click.ClickException("Retry list JSON must be a list or contain an 'items' list")
+
+
 def write_json(path: Path, data: Any) -> None:
     path.parent.mkdir(parents=True, exist_ok=True)
     path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
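`load_retry_list` accepts either a bare JSON array or an object wrapping the array under `items`, and silently drops non-dict entries. Combined with the `retry_stages` handling added later in this diff, an entry can request a full re-run or retries of specific stages only. A sketch of both accepted shapes; the paths and stage name are invented:

```python
import json
from pathlib import Path

entries = [
    {"source_path": "papers/example.md"},                               # full re-run
    {"source_path": "papers/other.md", "retry_stages": ["deep_read"]},  # stage-level retry
]

# Both serializations are accepted by load_retry_list:
Path("retry_bare.json").write_text(json.dumps(entries, indent=2), encoding="utf-8")
Path("retry_wrapped.json").write_text(json.dumps({"items": entries}, indent=2), encoding="utf-8")
```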
@@ -376,6 +811,7 @@ async def call_with_retries(
     backoff_max_seconds: float,
     client: httpx.AsyncClient,
     validator: Draft7Validator,
+    key_rotator: KeyRotator | None = None,
     throttle: RequestThrottle | None = None,
     stats: ExtractionStats | None = None,
 ) -> dict[str, Any]:
@@ -384,6 +820,8 @@ async def call_with_retries(
     prompt_chars = _count_prompt_chars(messages)
     while attempt < max_retries:
         attempt += 1
+        if key_rotator:
+            api_key = await key_rotator.next_key()
         if throttle:
             await throttle.tick()
         if stats:
@@ -402,9 +840,29 @@
             if stats:
                 await stats.add_output_chars(len(response_text))
         except ProviderError as exc:
+            quota_hit = False
+            if api_key and key_rotator:
+                quota_hit = await key_rotator.mark_quota_exceeded(
+                    api_key,
+                    str(exc),
+                    exc.status_code,
+                )
             if exc.structured_error and use_structured != "none":
+                logger.warning(
+                    "Structured response failed; retrying without structured output "
+                    "(provider=%s, model=%s, status_code=%s): %s",
+                    provider.name,
+                    model,
+                    exc.status_code if exc.status_code is not None else "unknown",
+                    _summarize_error_message(str(exc)),
+                )
                 use_structured = "none"
                 continue
+            if quota_hit:
+                attempt -= 1
+                continue
+            if api_key and key_rotator and not quota_hit and should_retry_error(exc):
+                await key_rotator.mark_error(api_key)
             if should_retry_error(exc) and attempt < max_retries:
                 await asyncio.sleep(backoff_delay(backoff_base_seconds, attempt, backoff_max_seconds))
                 continue
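Note the `attempt -= 1` on the quota path: a pause for a quota reset does not consume a retry, while ordinary retryable errors still do. A minimal sketch of that accounting, with a plain exception and predicate standing in for `ProviderError` and the rotator:

```python
def run_with_retries(call, is_quota_hit, max_retries: int = 3):
    """Retry budget only shrinks on real failures; quota pauses are free."""
    attempt = 0
    while attempt < max_retries:
        attempt += 1
        try:
            return call()
        except RuntimeError as exc:
            if is_quota_hit(exc):
                attempt -= 1  # the rotator paused until reset; don't charge the budget
                continue
            if attempt < max_retries:
                continue      # an ordinary retryable error; one attempt spent
            raise
```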
@@ -448,9 +906,16 @@ async def extract_documents(
     split: bool,
     split_dir: Path | None,
     force: bool,
+    force_stages: list[str],
     retry_failed: bool,
+    retry_failed_stages: bool,
+    retry_list_path: Path | None,
+    stage_dag: bool,
+    start_idx: int,
+    end_idx: int,
     dry_run: bool,
     max_concurrency_override: int | None,
+    timeout_seconds: float,
     prompt_template: str,
     output_language: str,
     custom_prompt: bool,
@@ -469,19 +934,86 @@ async def extract_documents(
     markdown_files = discover_markdown(inputs, glob_pattern, recursive=True)
     template_tag = prompt_template if not custom_prompt else "custom"
 
+    total_files = len(markdown_files)
+    if start_idx != 0 or end_idx != -1:
+        slice_end = end_idx if end_idx != -1 else None
+        markdown_files = markdown_files[start_idx:slice_end]
+        logger.info(
+            "Applied range filter [%d:%d]. Files: %d -> %d",
+            start_idx,
+            end_idx,
+            total_files,
+            len(markdown_files),
+        )
+        if not markdown_files:
+            logger.warning(
+                "Range filter yielded 0 files (range=%d:%d, total=%d)",
+                start_idx,
+                end_idx,
+                total_files,
+            )
+
+    retry_list_entries = load_retry_list(retry_list_path) if retry_list_path else []
+    error_entries = load_errors(errors_path) if retry_failed or retry_failed_stages else []
+    retry_stage_map: dict[str, set[str]] = {}
+    retry_full_paths: set[str] = set()
+
     if retry_failed:
-
-
-
+        for entry in error_entries:
+            source_path = entry.get("source_path")
+            if not source_path:
+                continue
+            retry_full_paths.add(str(Path(source_path).resolve()))
+
+    if retry_failed_stages:
+        for entry in error_entries:
+            source_path = entry.get("source_path")
+            stage_name = entry.get("stage_name")
+            if not source_path:
+                continue
+            resolved = str(Path(source_path).resolve())
+            if not stage_name:
+                retry_full_paths.add(resolved)
+                continue
+            retry_stage_map.setdefault(resolved, set()).add(stage_name)
+
+    if retry_list_entries:
+        for entry in retry_list_entries:
+            source_path = entry.get("source_path")
+            if not source_path:
+                continue
+            resolved = str(Path(source_path).resolve())
+            retry_stages = entry.get("retry_stages")
+            if isinstance(retry_stages, list) and retry_stages:
+                stage_set = {
+                    stage for stage in retry_stages if isinstance(stage, str) and stage.strip()
+                }
+                if stage_set:
+                    retry_stage_map.setdefault(resolved, set()).update(stage_set)
+                    continue
+            retry_full_paths.add(resolved)
+
+    retry_mode = retry_failed or retry_failed_stages or bool(retry_list_entries)
+    retry_stages_mode = retry_failed_stages or bool(retry_stage_map)
+
+    if retry_mode:
+        retry_paths = set(retry_full_paths) | set(retry_stage_map.keys())
+        markdown_files = [
+            path for path in markdown_files if str(path.resolve()) in retry_paths
+        ]
         logger.debug("Retrying %d markdown files", len(markdown_files))
+        if not markdown_files:
+            logger.warning("Retry list produced 0 files to process")
     else:
         logger.debug("Discovered %d markdown files", len(markdown_files))
 
+    stage_definitions = get_stage_definitions(prompt_template) if not custom_prompt else []
+    multi_stage = bool(stage_definitions)
+    stage_dag_enabled = stage_dag and multi_stage
+
     if dry_run:
         input_chars = 0
         prompt_chars = 0
-        stage_definitions = get_stage_definitions(prompt_template) if not custom_prompt else []
-        multi_stage = bool(stage_definitions)
         metadata_fields = [
             "paper_title",
             "paper_authors",
@@ -525,6 +1057,21 @@ async def extract_documents(
         )
         prompt_chars += _count_prompt_chars(messages)
 
+        if stage_dag_enabled and stage_definitions:
+            stage_dependencies = _resolve_stage_dependencies(stage_definitions)
+            _build_dependency_graph(stage_definitions, stage_dependencies)
+            plan_table = Table(
+                title="stage DAG plan (dry-run)",
+                header_style="bold cyan",
+                title_style="bold magenta",
+            )
+            plan_table.add_column("Stage", style="cyan", no_wrap=True)
+            plan_table.add_column("Depends on", style="white")
+            for stage_def in stage_definitions:
+                deps = stage_dependencies.get(stage_def.name, [])
+                plan_table.add_row(stage_def.name, ", ".join(deps) if deps else "none")
+            Console().print(plan_table)
+
         duration = time.monotonic() - start_time
         prompt_tokens = _estimate_tokens_for_chars(prompt_chars)
         completion_tokens = 0
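`_resolve_stage_dependencies` defaults each stage with no `depends_on` to its predecessor (the first stage gets no dependency), and `_build_dependency_graph` then rejects unknown names and cycles with a Kahn-style topological pass. A small sketch of the defaulting rule; the stage names are invented:

```python
# Each tuple is (name, depends_on); None means "inherit the previous stage".
defs = [("metadata", None), ("methods", None), ("results", ["metadata"])]

deps: dict[str, list[str]] = {}
for idx, (name, depends_on) in enumerate(defs):
    if depends_on is None:
        deps[name] = [] if idx == 0 else [defs[idx - 1][0]]
    else:
        deps[name] = list(depends_on)

# The default chain is linear; explicit depends_on entries may fan out.
assert deps == {"metadata": [], "methods": ["metadata"], "results": ["metadata"]}
```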
@@ -562,17 +1109,25 @@ async def extract_documents(
         if isinstance(entry, dict) and entry.get("source_path")
     }
 
-
+    cooldown_seconds = max(1.0, float(config.extract.backoff_base_seconds))
+    resolved_keys = resolve_api_key_configs(provider.api_keys)
+    rotator = KeyRotator(
+        resolved_keys,
+        cooldown_seconds=cooldown_seconds,
+        verbose=verbose,
+    )
     max_concurrency = max_concurrency_override or config.extract.max_concurrency
     semaphore = asyncio.Semaphore(max_concurrency)
 
     errors: list[ExtractionError] = []
     results: dict[str, dict[str, Any]] = {}
-    stage_definitions = get_stage_definitions(prompt_template) if not custom_prompt else []
-    multi_stage = bool(stage_definitions)
     stage_output_dir = Path("paper_stage_outputs")
     if multi_stage:
         stage_output_dir.mkdir(parents=True, exist_ok=True)
+        if stage_dag_enabled:
+            logger.info("Multi-stage scheduler: DAG")
+        else:
+            logger.info("Multi-stage scheduler: sequential")
 
     throttle = None
     if sleep_every is not None or sleep_time is not None:
@@ -594,125 +1149,214 @@ async def extract_documents(
     )
     stats = ExtractionStats(doc_bar=doc_bar)
 
-
-
-        current_stage: str | None = None
-        try:
-            if verbose:
-                logger.debug("Processing %s", source_path)
-            content = read_text(path)
-            await stats.add_input_chars(len(content))
-            source_hash = compute_source_hash(content)
-            stage_state: dict[str, Any] | None = None
-            stage_path: Path | None = None
+    results_lock = asyncio.Lock()
+    logger.info("Request timeout set to %.1fs", timeout_seconds)
 
-
-
-
-
-                if stage_bar:
-                    stage_bar.update(len(stage_definitions))
-                return
+    pause_threshold_seconds = max(0.0, float(config.extract.pause_threshold_seconds))
+    pause_watchdog_seconds = max(60.0, pause_threshold_seconds * 6.0)
+    pause_gate: asyncio.Event | None = None
+    pause_task: asyncio.Task[None] | None = None
 
-
-
-            )
+    shutdown_event = asyncio.Event()
+    shutdown_reason: str | None = None
 
-
+    def request_shutdown(reason: str) -> None:
+        nonlocal shutdown_reason
+        if shutdown_event.is_set():
+            return
+        shutdown_reason = reason
+        shutdown_event.set()
+        if pause_gate:
+            pause_gate.set()
+        logger.warning("Graceful shutdown requested (%s); draining in-flight tasks", reason)
+
+    loop = asyncio.get_running_loop()
+    for sig in (signal.SIGINT, signal.SIGTERM):
+        try:
+            loop.add_signal_handler(sig, request_shutdown, sig.name)
+        except (NotImplementedError, RuntimeError, ValueError):
+            signal.signal(
+                sig,
+                lambda *_args, _sig=sig: loop.call_soon_threadsafe(
+                    request_shutdown, _sig.name
+                ),
+            )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            stages: dict[str, dict[str, Any]] = stage_state.get("stages", {})
-            metadata_fields = [
-                "paper_title",
-                "paper_authors",
-                "publication_date",
-                "publication_venue",
-            ]
-            for stage_def in stage_definitions:
-                stage_name = stage_def.name
-                stage_fields = stage_def.fields
-                current_stage = stage_name
-                if stage_name in stages and not force:
-                    if stage_bar:
-                        stage_bar.update(1)
+    if resolved_keys:
+        pause_gate = asyncio.Event()
+        pause_gate.set()
+
+        async def pause_watcher() -> None:
+            paused = False
+            try:
+                while True:
+                    wait_for, reason, wait_until_epoch = await rotator.key_pool_wait()
+                    if wait_for is None or wait_for <= 0:
+                        if paused:
+                            paused = False
+                            pause_gate.set()
+                            logger.info("Queue resumed; key pool available")
+                        await asyncio.sleep(0.5)
                         continue
-
-
-
-
-
-
-
-
-                        provider,
-                        prompt_template,
-                        output_language,
-                        custom_prompt=False,
-                        prompt_system_path=None,
-                        prompt_user_path=None,
-                        stage_name=stage_name,
-                        stage_fields=required_fields,
-                        previous_outputs=previous_outputs,
-                    )
-                    async with semaphore:
-                        data = await call_with_retries(
-                            provider,
-                            model,
-                            messages,
-                            stage_schema,
-                            api_key,
-                            timeout=60.0,
-                            structured_mode=provider.structured_mode,
-                            max_retries=config.extract.max_retries,
-                            backoff_base_seconds=config.extract.backoff_base_seconds,
-                            backoff_max_seconds=config.extract.backoff_max_seconds,
-                            client=client,
-                            validator=stage_validator,
-                            throttle=throttle,
-                            stats=stats,
+                    if wait_for <= pause_threshold_seconds:
+                        if paused:
+                            paused = False
+                            pause_gate.set()
+                            logger.info(
+                                "Queue resumed; key pool wait %.2fs below threshold %.2fs",
+                                wait_for,
+                                pause_threshold_seconds,
                             )
-
-
-
-
-
+                        await asyncio.sleep(min(wait_for, pause_threshold_seconds))
+                        continue
+                    if not paused:
+                        paused = True
+                        pause_gate.clear()
+                        reset_dt = (
+                            datetime.fromtimestamp(wait_until_epoch).astimezone().isoformat()
+                            if wait_until_epoch
+                            else "unknown"
+                        )
+                        logger.warning(
+                            "Queue paused (keys unavailable: %s); waiting %.2fs until %s",
+                            reason or "unknown",
+                            wait_for,
+                            reset_dt,
+                        )
+                    await asyncio.sleep(min(wait_for, max(pause_threshold_seconds, 1.0)))
+            except asyncio.CancelledError:
+                raise
+            except Exception:
+                logger.exception("Queue pause watcher failed; releasing pause gate")
+                pause_gate.set()
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        pause_task = asyncio.create_task(pause_watcher())
+
+    async def await_key_pool_ready() -> None:
+        if shutdown_event.is_set():
+            return
+        if not pause_gate or pause_gate.is_set():
+            return
+        try:
+            pause_wait = asyncio.create_task(pause_gate.wait())
+            shutdown_wait = asyncio.create_task(shutdown_event.wait())
+            try:
+                await asyncio.wait_for(
+                    asyncio.wait(
+                        [pause_wait, shutdown_wait],
+                        return_when=asyncio.FIRST_COMPLETED,
+                    ),
+                    timeout=pause_watchdog_seconds,
+                )
+            finally:
+                for task in (pause_wait, shutdown_wait):
+                    task.cancel()
+        except asyncio.TimeoutError:
+            logger.warning(
+                "Queue pause watchdog timeout; rechecking key pool availability"
+            )
+            wait_for, _, _ = await rotator.key_pool_wait()
+            if wait_for is None or wait_for <= 0:
+                pause_gate.set()
+
+    async def drain_queue(queue: asyncio.Queue[Any], drain_lock: asyncio.Lock) -> int:
+        drained = 0
+        async with drain_lock:
+            while True:
+                try:
+                    queue.get_nowait()
+                except asyncio.QueueEmpty:
+                    break
+                queue.task_done()
+                drained += 1
+        return drained
+
+    async def wait_for_queue(queue: asyncio.Queue[Any], drain_lock: asyncio.Lock) -> None:
+        if shutdown_event.is_set():
+            await drain_queue(queue, drain_lock)
+            await queue.join()
+            return
+        join_task = asyncio.create_task(queue.join())
+        shutdown_task = asyncio.create_task(shutdown_event.wait())
+        done, _pending = await asyncio.wait(
+            [join_task, shutdown_task],
+            return_when=asyncio.FIRST_COMPLETED,
+        )
+        if shutdown_task in done:
+            await drain_queue(queue, drain_lock)
+            await queue.join()
+        for task in (join_task, shutdown_task):
+            task.cancel()
+
+    def build_output_payload() -> dict[str, Any]:
+        final_results: list[dict[str, Any]] = []
+        seen = set()
+        for entry in existing:
+            path = entry.get("source_path") if isinstance(entry, dict) else None
+            if path and path in results:
+                final_results.append(results[path])
+                seen.add(path)
+            elif path:
+                final_results.append(entry)
+                seen.add(path)
+
+        for path, entry in results.items():
+            if path not in seen:
+                final_results.append(entry)
+        return {"template_tag": template_tag, "papers": final_results}
+
+    async def persist_output_snapshot() -> None:
+        async with results_lock:
+            payload = build_output_payload()
+            await asyncio.to_thread(write_json, output_path, payload)
+
+    def build_merged(ctx: DocContext) -> dict[str, Any]:
+        merged: dict[str, Any] = {}
+        for stage_def in stage_definitions:
+            merged.update(ctx.stages.get(stage_def.name, {}))
+        return merged
+
+    async def update_results(ctx: DocContext) -> None:
+        merged = build_merged(ctx)
+        data = append_metadata(
+            merged,
+            source_path=ctx.source_path,
+            source_hash=ctx.source_hash,
+            provider=provider.name,
+            model=model,
+            truncation=ctx.truncation,
+            prompt_template=prompt_template,
+            output_language=output_language,
+        )
+        async with results_lock:
+            results[ctx.source_path] = data
+        await persist_output_snapshot()
+
+    async def run_single_stage(client: httpx.AsyncClient) -> None:
+        doc_queue: asyncio.Queue[Path] = asyncio.Queue()
+        drain_lock = asyncio.Lock()
+
+        async def process_one(path: Path) -> None:
+            source_path = str(path.resolve())
+            current_stage: str | None = None
+            try:
+                if shutdown_event.is_set():
+                    return
+                if verbose:
+                    logger.debug("Processing %s", source_path)
+                content = read_text(path)
+                await stats.add_input_chars(len(content))
+                source_hash = compute_source_hash(content)
+
+                if not force and not retry_mode:
+                    existing_entry = existing_by_path.get(source_path)
+                    if existing_entry and existing_entry.get("source_hash") == source_hash:
+                        results[source_path] = existing_entry
+                        return
+
+                truncated_content, truncation = truncate_content(
+                    content, config.extract.truncate_max_chars, config.extract.truncate_strategy
                 )
-            results[source_path] = data
-        else:
                 messages = build_messages(
                     truncated_content,
                     schema,
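The pause gate is an ordinary `asyncio.Event` that workers await before taking new work, with the shutdown event always able to unblock the wait. A stripped-down sketch of that handshake; the timing and names are invented:

```python
import asyncio


async def demo() -> None:
    gate = asyncio.Event()      # set = workers may pull new work
    shutdown = asyncio.Event()  # set = stop waiting regardless of the gate
    gate.set()

    async def wait_ready() -> None:
        if shutdown.is_set() or gate.is_set():
            return
        gate_wait = asyncio.create_task(gate.wait())
        stop_wait = asyncio.create_task(shutdown.wait())
        try:
            await asyncio.wait({gate_wait, stop_wait}, return_when=asyncio.FIRST_COMPLETED)
        finally:
            gate_wait.cancel()
            stop_wait.cancel()

    gate.clear()  # the watcher pauses the queue...
    asyncio.get_running_loop().call_later(0.1, gate.set)  # ...and reopens it later
    await wait_ready()  # returns once the gate reopens (or shutdown fires)


asyncio.run(demo())
```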
@@ -723,20 +1367,24 @@ async def extract_documents(
                     prompt_system_path=prompt_system_path,
                     prompt_user_path=prompt_user_path,
                 )
+                if shutdown_event.is_set():
+                    return
+                await await_key_pool_ready()
                 async with semaphore:
                     data = await call_with_retries(
                         provider,
                         model,
                         messages,
                         schema,
-
-                        timeout=
+                        None,
+                        timeout=timeout_seconds,
                         structured_mode=provider.structured_mode,
                         max_retries=config.extract.max_retries,
                         backoff_base_seconds=config.extract.backoff_base_seconds,
                         backoff_max_seconds=config.extract.backoff_max_seconds,
                         client=client,
                         validator=validator,
+                        key_rotator=rotator,
                         throttle=throttle,
                         stats=stats,
                     )
@@ -752,38 +1400,770 @@ async def extract_documents(
|
|
|
752
1400
|
output_language=output_language,
|
|
753
1401
|
)
|
|
754
1402
|
results[source_path] = data
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
model=model,
|
|
762
|
-
error_type=exc.error_type,
|
|
763
|
-
error_message=str(exc),
|
|
764
|
-
stage_name=current_stage if multi_stage else None,
|
|
1403
|
+
except ProviderError as exc:
|
|
1404
|
+
log_extraction_failure(
|
|
1405
|
+
source_path,
|
|
1406
|
+
exc.error_type,
|
|
1407
|
+
str(exc),
|
|
1408
|
+
status_code=exc.status_code,
|
|
765
1409
|
)
|
|
1410
|
+
errors.append(
|
|
1411
|
+
ExtractionError(
|
|
1412
|
+
path=path,
|
|
1413
|
+
provider=provider.name,
|
|
1414
|
+
model=model,
|
|
1415
|
+
error_type=exc.error_type,
|
|
1416
|
+
error_message=str(exc),
|
|
1417
|
+
stage_name=current_stage if multi_stage else None,
|
|
1418
|
+
)
|
|
1419
|
+
)
|
|
1420
|
+
except Exception as exc: # pragma: no cover - safety net
|
|
1421
|
+
logger.exception("Unexpected error while processing %s", source_path)
|
|
1422
|
+
errors.append(
|
|
1423
|
+
ExtractionError(
|
|
1424
|
+
path=path,
|
|
1425
|
+
provider=provider.name,
|
|
1426
|
+
model=model,
|
|
1427
|
+
error_type="unexpected_error",
|
|
1428
|
+
error_message=str(exc),
|
|
1429
|
+
stage_name=current_stage if multi_stage else None,
|
|
1430
|
+
)
|
|
1431
|
+
)
|
|
1432
|
+
finally:
|
|
1433
|
+
if doc_bar:
|
|
1434
|
+
doc_bar.update(1)
|
|
1435
|
+
|
|
1436
|
+
for path in markdown_files:
|
|
1437
|
+
if shutdown_event.is_set():
|
|
1438
|
+
break
|
|
1439
|
+
doc_queue.put_nowait(path)
|
|
1440
|
+
|
|
1441
|
+
async def worker() -> None:
|
|
1442
|
+
while True:
|
|
1443
|
+
if shutdown_event.is_set():
|
|
1444
|
+
await drain_queue(doc_queue, drain_lock)
|
|
1445
|
+
return
|
|
1446
|
+
try:
|
|
1447
|
+
path = doc_queue.get_nowait()
|
|
1448
|
+
except asyncio.QueueEmpty:
|
|
1449
|
+
return
|
|
1450
|
+
await process_one(path)
|
|
1451
|
+
doc_queue.task_done()
|
|
1452
|
+
|
|
1453
|
+
workers = [asyncio.create_task(worker()) for _ in range(max_concurrency)]
|
|
1454
|
+
if markdown_files:
|
|
1455
|
+
await wait_for_queue(doc_queue, drain_lock)
|
|
1456
|
+
for w in workers:
|
|
1457
|
+
w.cancel()
|
|
1458
|
+
|
|
1459
|
+
async def run_multi_stage(client: httpx.AsyncClient) -> None:
|
|
1460
|
+
metadata_fields = [
|
|
1461
|
+
"paper_title",
|
|
1462
|
+
"paper_authors",
|
|
1463
|
+
"publication_date",
|
|
1464
|
+
"publication_venue",
|
|
1465
|
+
]
|
|
1466
|
+
force_stage_set = set(force_stages or [])
|
|
1467
|
+
doc_contexts: dict[Path, DocContext] = {}
|
|
1468
|
+
doc_states: dict[Path, DocState] = {}
|
|
1469
|
+
task_queue: asyncio.Queue[DocTask] = asyncio.Queue()
|
|
1470
|
+
drain_lock = asyncio.Lock()
|
|
1471
|
+
|
|
1472
|
+
for path in markdown_files:
|
|
1473
|
+
if shutdown_event.is_set():
|
|
1474
|
+
break
|
|
1475
|
+
source_path = str(path.resolve())
|
|
1476
|
+
if verbose:
|
|
1477
|
+
logger.debug("Preparing %s", source_path)
|
|
1478
|
+
content = read_text(path)
|
|
1479
|
+
await stats.add_input_chars(len(content))
|
|
1480
|
+
source_hash = compute_source_hash(content)
|
|
1481
|
+
truncated_content, truncation = truncate_content(
|
|
1482
|
+
content, config.extract.truncate_max_chars, config.extract.truncate_strategy
|
|
766
1483
|
)
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
1484
|
+
stage_path = stage_output_dir / f"{stable_hash(source_path)}.json"
|
|
1485
|
+
stage_state = load_stage_state(stage_path) if not force else None
|
|
1486
|
+
if stage_state and stage_state.get("source_hash") != source_hash:
|
|
1487
|
+
stage_state = None
|
|
1488
|
+
if stage_state is None:
|
|
1489
|
+
stage_state = {
|
|
1490
|
+
"source_path": source_path,
|
|
1491
|
+
"source_hash": source_hash,
|
|
1492
|
+
"prompt_template": prompt_template,
|
|
1493
|
+
"output_language": output_language,
|
|
1494
|
+
"stages": {},
|
|
1495
|
+
"stage_meta": {},
|
|
1496
|
+
}
|
|
1497
|
+
stages: dict[str, dict[str, Any]] = stage_state.get("stages", {})
|
|
1498
|
+
stage_meta: dict[str, dict[str, Any]] = stage_state.get("stage_meta", {})
|
|
1499
|
+
doc_contexts[path] = DocContext(
|
|
1500
|
+
path=path,
|
|
1501
|
+
source_path=source_path,
|
|
1502
|
+
content=content,
|
|
1503
|
+
truncated_content=truncated_content,
|
|
1504
|
+
truncation=truncation,
|
|
1505
|
+
source_hash=source_hash,
|
|
1506
|
+
stage_path=stage_path,
|
|
1507
|
+
stage_state=stage_state,
|
|
1508
|
+
stages=stages,
|
|
1509
|
+
stage_meta=stage_meta,
|
|
1510
|
+
)
|
|
1511
|
+
doc_states[path] = DocState(total_stages=len(stage_definitions))
|
|
1512
|
+
for idx, stage_def in enumerate(stage_definitions):
|
|
1513
|
+
required_fields = metadata_fields + stage_def.fields
|
|
1514
|
+
task_queue.put_nowait(
|
|
1515
|
+
DocTask(
|
|
1516
|
+
path=path,
|
|
1517
|
+
stage_index=idx,
|
|
1518
|
+
stage_name=stage_def.name,
|
|
1519
|
+
stage_fields=required_fields,
|
|
1520
|
+
)
|
|
1521
|
+
)
|
|
1522
|
+
|
|
1523
|
+
async def run_task(task: DocTask) -> None:
|
|
1524
|
+
ctx = doc_contexts[task.path]
|
|
1525
|
+
state = doc_states[task.path]
|
|
1526
|
+
|
|
1527
|
+
while True:
|
|
1528
|
+
if shutdown_event.is_set():
|
|
1529
|
+
async with state.lock:
|
|
1530
|
+
state.failed = True
|
|
1531
|
+
state.event.set()
|
|
1532
|
+
if stage_bar:
|
|
1533
|
+
stage_bar.update(1)
|
|
1534
|
+
return
|
|
1535
|
+
async with state.lock:
|
|
1536
|
+
if state.failed:
|
|
1537
|
+
if stage_bar:
|
|
1538
|
+
stage_bar.update(1)
|
|
1539
|
+
return
|
|
1540
|
+
if task.stage_index == state.next_index:
|
|
1541
|
+
break
|
|
1542
|
+
wait_event = state.event
|
|
1543
|
+
await wait_event.wait()
|
|
1544
|
+
|
|
1545
|
+
current_stage = task.stage_name
|
|
1546
|
+
is_retry_full = ctx.source_path in retry_full_paths
|
|
1547
|
+
retry_stages = (
|
|
1548
|
+
retry_stage_map.get(ctx.source_path)
|
|
1549
|
+
if retry_stages_mode and not is_retry_full
|
|
1550
|
+
else None
|
|
1551
|
+
)
|
|
1552
|
+
if shutdown_event.is_set():
|
|
1553
|
+
async with state.lock:
|
|
1554
|
+
state.failed = True
|
|
1555
|
+
state.event.set()
|
|
1556
|
+
if stage_bar:
|
|
1557
|
+
stage_bar.update(1)
|
|
1558
|
+
return
|
|
1559
|
+
stage_record = ctx.stages.get(current_stage)
|
|
1560
|
+
if (
|
|
1561
|
+
retry_stages_mode
|
|
1562
|
+
and not is_retry_full
|
|
1563
|
+
and retry_stages is not None
|
|
1564
|
+
and current_stage not in retry_stages
|
|
1565
|
+
and stage_record is not None
|
|
1566
|
+
):
|
|
1567
|
+
if stage_bar:
|
|
1568
|
+
stage_bar.update(1)
|
|
1569
|
+
await update_results(ctx)
|
|
1570
|
+
final_validation_error: str | None = None
|
|
1571
|
+
if not state.failed and task.stage_index == state.total_stages - 1:
|
|
1572
|
+
merged = build_merged(ctx)
|
|
1573
|
+
errors_in_doc = sorted(validator.iter_errors(merged), key=lambda e: e.path)
|
|
1574
|
+
if errors_in_doc:
|
|
1575
|
+
final_validation_error = errors_in_doc[0].message
|
|
1576
|
+
|
|
1577
|
+
async with state.lock:
|
|
1578
|
+
if final_validation_error:
|
|
1579
|
+
errors.append(
|
|
1580
|
+
ExtractionError(
|
|
1581
|
+
path=task.path,
|
|
1582
|
+
provider=provider.name,
|
|
1583
|
+
model=model,
|
|
1584
|
+
error_type="validation_error",
|
|
1585
|
+
error_message=f"Schema validation failed: {final_validation_error}",
|
|
1586
|
+
stage_name=current_stage,
|
|
1587
|
+
)
|
|
1588
|
+
)
|
|
1589
|
+
state.failed = True
|
|
1590
|
+
if not state.failed:
|
|
1591
|
+
state.next_index += 1
|
|
1592
|
+
if state.next_index >= state.total_stages or state.failed:
|
|
1593
|
+
if doc_bar and state.total_stages:
|
|
1594
|
+
doc_bar.update(1)
|
|
1595
|
+
state.event.set()
|
|
1596
|
+
state.event = asyncio.Event()
|
|
1597
|
+
return
|
|
1598
|
+
prompt_hash = prompt_hash_map[current_stage]
|
|
1599
|
+
stage_schema = stage_schema_map[current_stage]
|
|
1600
|
+
stage_validator = stage_validator_map[current_stage]
|
|
1601
|
+
stage_meta = ctx.stage_meta.get(current_stage, {})
|
|
1602
|
+
needs_run = force or current_stage in force_stage_set
|
|
1603
|
+
if stage_record is None:
|
|
1604
|
+
needs_run = True
|
|
1605
|
+
if is_retry_full:
|
|
1606
|
+
needs_run = True
|
|
1607
|
+
+            if retry_stages is not None and current_stage in retry_stages:
+                needs_run = True
+            if stage_meta.get("prompt_hash") != prompt_hash:
+                needs_run = True
+            if stage_record is not None and not needs_run:
+                errors_in_stage = sorted(stage_validator.iter_errors(stage_record), key=lambda e: e.path)
+                if errors_in_stage:
+                    needs_run = True
+
+            if not needs_run:
+                if stage_bar:
+                    stage_bar.update(1)
+                await update_results(ctx)
+            else:
+                try:
+                    previous_outputs = json.dumps(ctx.stages, ensure_ascii=False)
+                    messages = build_messages(
+                        ctx.truncated_content,
+                        stage_schema,
+                        provider,
+                        prompt_template,
+                        output_language,
+                        custom_prompt=False,
+                        prompt_system_path=None,
+                        prompt_user_path=None,
+                        stage_name=current_stage,
+                        stage_fields=task.stage_fields,
+                        previous_outputs=previous_outputs,
+                    )
+                    async with semaphore:
+                        data = await call_with_retries(
+                            provider,
+                            model,
+                            messages,
+                            stage_schema,
+                            None,
+                            timeout=timeout_seconds,
+                            structured_mode=provider.structured_mode,
+                            max_retries=config.extract.max_retries,
+                            backoff_base_seconds=config.extract.backoff_base_seconds,
+                            backoff_max_seconds=config.extract.backoff_max_seconds,
+                            client=client,
+                            validator=stage_validator,
+                            key_rotator=rotator,
+                            throttle=throttle,
+                            stats=stats,
+                        )
+                    ctx.stages[current_stage] = data
+                    ctx.stage_meta[current_stage] = {"prompt_hash": prompt_hash}
+                    ctx.stage_state["stages"] = ctx.stages
+                    ctx.stage_state["stage_meta"] = ctx.stage_meta
+                    write_json_atomic(ctx.stage_path, ctx.stage_state)
+                    if stage_bar:
+                        stage_bar.update(1)
+                    await update_results(ctx)
+                except ProviderError as exc:
+                    log_extraction_failure(
+                        ctx.source_path,
+                        exc.error_type,
+                        str(exc),
+                        status_code=exc.status_code,
+                    )
+                    errors.append(
+                        ExtractionError(
+                            path=task.path,
+                            provider=provider.name,
+                            model=model,
+                            error_type=exc.error_type,
+                            error_message=str(exc),
+                            stage_name=current_stage,
+                        )
+                    )
+                    async with state.lock:
+                        state.failed = True
+                    if stage_bar:
+                        stage_bar.update(1)
+                except Exception as exc:  # pragma: no cover - safety net
+                    logger.exception("Unexpected error while processing %s", ctx.source_path)
+                    errors.append(
+                        ExtractionError(
+                            path=task.path,
+                            provider=provider.name,
+                            model=model,
+                            error_type="unexpected_error",
+                            error_message=str(exc),
+                            stage_name=current_stage,
+                        )
+                    )
+                    async with state.lock:
+                        state.failed = True
+                    if stage_bar:
+                        stage_bar.update(1)
+
+            final_validation_error: str | None = None
+            if not state.failed and task.stage_index == state.total_stages - 1:
+                merged = build_merged(ctx)
+                errors_in_doc = sorted(validator.iter_errors(merged), key=lambda e: e.path)
+                if errors_in_doc:
+                    final_validation_error = errors_in_doc[0].message
+
+            async with state.lock:
+                if final_validation_error:
+                    errors.append(
+                        ExtractionError(
+                            path=task.path,
+                            provider=provider.name,
+                            model=model,
+                            error_type="validation_error",
+                            error_message=f"Schema validation failed: {final_validation_error}",
+                            stage_name=current_stage,
+                        )
+                    )
+                    state.failed = True
+                if not state.failed:
+                    state.next_index += 1
+                if state.next_index >= state.total_stages or state.failed:
+                    if doc_bar and state.total_stages:
+                        doc_bar.update(1)
+                state.event.set()
+                state.event = asyncio.Event()
+
+        async def worker() -> None:
+            while True:
+                if shutdown_event.is_set():
+                    await drain_queue(task_queue, drain_lock)
+                    return
+                await await_key_pool_ready()
+                try:
+                    task = task_queue.get_nowait()
+                except asyncio.QueueEmpty:
+                    return
+                await run_task(task)
+                task_queue.task_done()
+
+        workers = [asyncio.create_task(worker()) for _ in range(max_concurrency)]
+        await wait_for_queue(task_queue, drain_lock)
+        for w in workers:
+            w.cancel()
+
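Note: both `run_multi_stage` above and the DAG runner added below drain a pre-filled `asyncio.Queue` with a fixed pool of worker tasks that exit once the queue is empty. A minimal, self-contained sketch of that pattern (`run_pool` and `handler` are illustrative names, not the package's API; `wait_for_queue` in the diff appears to play the role of `queue.join()` with extra shutdown handling):

    import asyncio

    async def run_pool(items, handler, max_concurrency: int) -> None:
        # Pre-fill the queue, then let a fixed set of workers drain it.
        queue: asyncio.Queue = asyncio.Queue()
        for item in items:
            queue.put_nowait(item)

        async def worker() -> None:
            while True:
                try:
                    item = queue.get_nowait()
                except asyncio.QueueEmpty:
                    return  # queue drained; this worker exits
                try:
                    await handler(item)
                finally:
                    queue.task_done()

        workers = [asyncio.create_task(worker()) for _ in range(max_concurrency)]
        await queue.join()  # resolves once task_done() has matched every put
        for w in workers:
            w.cancel()  # cancelling an already-finished task is a no-op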
+    async def run_multi_stage_dag(client: httpx.AsyncClient) -> None:
+        metadata_fields = [
+            "paper_title",
+            "paper_authors",
+            "publication_date",
+            "publication_venue",
+        ]
+        force_stage_set = set(force_stages or [])
+        stage_dependencies = _resolve_stage_dependencies(stage_definitions)
+        dependents_map = _build_dependency_graph(stage_definitions, stage_dependencies)
+        stage_index_map = {stage_def.name: idx for idx, stage_def in enumerate(stage_definitions)}
+        stage_fields_map = {
+            stage_def.name: metadata_fields + stage_def.fields for stage_def in stage_definitions
+        }
+        stage_schema_map: dict[str, dict[str, Any]] = {}
+        stage_validator_map: dict[str, Draft7Validator] = {}
+        prompt_hash_map: dict[str, str] = {}
+        for stage_def in stage_definitions:
+            stage_name = stage_def.name
+            stage_schema = build_stage_schema(schema, stage_fields_map[stage_name])
+            stage_schema_map[stage_name] = stage_schema
+            stage_validator_map[stage_name] = validate_schema(stage_schema)
+            prompt_hash_map[stage_name] = _compute_prompt_hash(
+                prompt_template=prompt_template,
+                output_language=output_language,
+                stage_name=stage_name,
+                stage_fields=stage_fields_map[stage_name],
+                custom_prompt=custom_prompt,
+                prompt_system_path=prompt_system_path,
+                prompt_user_path=prompt_user_path,
+            )
+
+        total_reused = 0
+        fully_completed = 0
+
+        doc_contexts: dict[Path, DocContext] = {}
+        doc_states: dict[Path, DocDagState] = {}
+        task_queue: asyncio.Queue[DocTask] = asyncio.Queue()
+        drain_lock = asyncio.Lock()
+
+        async def enqueue_ready(path: Path, stage_names: Iterable[str]) -> None:
+            if shutdown_event.is_set():
+                return
+            state = doc_states[path]
+            for stage_name in stage_names:
+                async with state.lock:
+                    if state.failed:
+                        continue
+                    if stage_name in state.completed or stage_name in state.in_flight:
+                        continue
+                    dependencies = stage_dependencies.get(stage_name, [])
+                    if any(dep not in state.completed for dep in dependencies):
+                        continue
+                    state.in_flight.add(stage_name)
+                    task_queue.put_nowait(
+                        DocTask(
+                            path=path,
+                            stage_index=stage_index_map[stage_name],
+                            stage_name=stage_name,
+                            stage_fields=stage_fields_map[stage_name],
+                        )
                     )
+
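Note: `enqueue_ready` is the scheduling core of the new DAG mode. A stage is queued only when it is not completed, not already in flight, and every one of its dependencies has completed; everything else is skipped and re-examined when a dependency finishes. A condensed sketch of that readiness rule (a hypothetical helper, not the package's code):

    def ready_stages(
        stage_deps: dict[str, list[str]],
        completed: set[str],
        in_flight: set[str],
    ) -> list[str]:
        # A stage is ready when all of its dependencies are done and it has
        # neither finished nor already been handed to a worker.
        return [
            name
            for name, deps in stage_deps.items()
            if name not in completed
            and name not in in_flight
            and all(dep in completed for dep in deps)
        ]

    # ready_stages({"summary": [], "critique": ["summary"]}, {"summary"}, set())
    # -> ["critique"]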
+        for path in markdown_files:
+            if shutdown_event.is_set():
+                break
+            source_path = str(path.resolve())
+            if verbose:
+                logger.debug("Preparing %s", source_path)
+            content = read_text(path)
+            await stats.add_input_chars(len(content))
+            source_hash = compute_source_hash(content)
+            truncated_content, truncation = truncate_content(
+                content, config.extract.truncate_max_chars, config.extract.truncate_strategy
+            )
+            stage_path = stage_output_dir / f"{stable_hash(source_path)}.json"
+            stage_state = load_stage_state(stage_path) if not force else None
+            if stage_state and stage_state.get("source_hash") != source_hash:
+                stage_state = None
+            if stage_state is None:
+                stage_state = {
+                    "source_path": source_path,
+                    "source_hash": source_hash,
+                    "prompt_template": prompt_template,
+                    "output_language": output_language,
+                    "stages": {},
+                    "stage_meta": {},
+                }
+            stages: dict[str, dict[str, Any]] = stage_state.get("stages", {})
+            stage_meta: dict[str, dict[str, Any]] = stage_state.get("stage_meta", {})
+            doc_contexts[path] = DocContext(
+                path=path,
+                source_path=source_path,
+                content=content,
+                truncated_content=truncated_content,
+                truncation=truncation,
+                source_hash=source_hash,
+                stage_path=stage_path,
+                stage_state=stage_state,
+                stages=stages,
+                stage_meta=stage_meta,
+            )
+            doc_states[path] = DocDagState(
+                total_stages=len(stage_definitions),
+                remaining=len(stage_definitions),
+            )
+            state = doc_states[path]
+            is_retry_full = source_path in retry_full_paths
+            retry_stages = (
+                retry_stage_map.get(source_path)
+                if retry_stages_mode and not is_retry_full
+                else None
+            )
+            reused_count = 0
+            for stage_def in stage_definitions:
+                stage_name = stage_def.name
+                stage_record = stages.get(stage_name)
+                if (
+                    retry_stages_mode
+                    and not is_retry_full
+                    and retry_stages is not None
+                    and stage_name not in retry_stages
+                    and stage_record is not None
+                ):
+                    state.completed.add(stage_name)
+                    state.remaining -= 1
+                    reused_count += 1
+                    continue
+
+                stage_meta_entry = stage_meta.get(stage_name, {})
+                needs_run = force or stage_name in force_stage_set
+                if stage_record is None:
+                    needs_run = True
+                if is_retry_full:
+                    needs_run = True
+                if retry_stages is not None and stage_name in retry_stages:
+                    needs_run = True
+                if stage_meta_entry.get("prompt_hash") != prompt_hash_map[stage_name]:
+                    needs_run = True
+                if stage_record is not None and not needs_run:
+                    errors_in_stage = sorted(
+                        stage_validator_map[stage_name].iter_errors(stage_record),
+                        key=lambda e: e.path,
+                    )
+                    if errors_in_stage:
+                        needs_run = True
+
+                if not needs_run:
+                    state.completed.add(stage_name)
+                    state.remaining -= 1
+                    reused_count += 1
+
+            if reused_count:
+                total_reused += reused_count
+                if stage_bar:
+                    stage_bar.update(reused_count)
+
+            if state.remaining == 0:
+                state.finalized = True
+                if doc_bar:
+                    doc_bar.update(1)
+                await update_results(doc_contexts[path])
+                merged = build_merged(doc_contexts[path])
+                errors_in_doc = sorted(validator.iter_errors(merged), key=lambda e: e.path)
+                if errors_in_doc:
+                    errors.append(
+                        ExtractionError(
+                            path=path,
+                            provider=provider.name,
+                            model=model,
+                            error_type="validation_error",
+                            error_message=f"Schema validation failed: {errors_in_doc[0].message}",
+                            stage_name=stage_definitions[-1].name if stage_definitions else None,
+                        )
+                    )
+                    state.failed = True
+                fully_completed += 1
+                continue
+
+            await enqueue_ready(path, [stage_def.name for stage_def in stage_definitions])
+
+        if total_reused:
+            logger.info(
+                "DAG precheck reused %d/%d stages",
+                total_reused,
+                len(markdown_files) * len(stage_definitions),
+            )
+        if fully_completed:
+            logger.info("DAG precheck fully satisfied %d docs (no stages queued)", fully_completed)
+
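Note: the precheck loop above reuses a cached stage record only while three things hold: the source hash matches, the recorded prompt_hash equals the freshly computed one, and the record still validates against the stage schema. `_compute_prompt_hash` itself is outside this hunk; a plausible shape for such a digest, shown only to make the invalidation rule concrete (an assumption, not the package's implementation):

    import hashlib
    import json

    def compute_prompt_hash_sketch(**prompt_parts: object) -> str:
        # Digest every input that shapes the prompt; if any part changes,
        # the cached stage output is treated as stale and the stage re-runs.
        blob = json.dumps(prompt_parts, sort_keys=True, default=str, ensure_ascii=False)
        return hashlib.sha256(blob.encode("utf-8")).hexdigest()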
+        async def finalize_stage(
+            ctx: DocContext,
+            stage_name: str,
+            *,
+            failed: bool,
+        ) -> None:
+            state = doc_states[ctx.path]
+            skip_count = 0
+            doc_done = False
+            is_failed = False
+            async with state.lock:
+                state.in_flight.discard(stage_name)
+                if stage_name not in state.completed:
+                    state.completed.add(stage_name)
+                    state.remaining -= 1
+                if failed and not state.failed:
+                    state.failed = True
+                    skip_count = state.remaining - len(state.in_flight)
+                    if skip_count < 0:
+                        skip_count = 0
+                    state.remaining -= skip_count
+                if state.remaining == 0 and not state.finalized:
+                    state.finalized = True
+                    doc_done = True
+                is_failed = state.failed
+
+            if stage_bar:
+                stage_bar.update(1)
+                if skip_count:
+                    stage_bar.update(skip_count)
+
+            if not is_failed:
+                await enqueue_ready(ctx.path, dependents_map.get(stage_name, []))
+
+            if doc_done:
+                if doc_bar:
+                    doc_bar.update(1)
+                if not is_failed:
+                    merged = build_merged(ctx)
+                    errors_in_doc = sorted(validator.iter_errors(merged), key=lambda e: e.path)
+                    if errors_in_doc:
+                        errors.append(
+                            ExtractionError(
+                                path=ctx.path,
+                                provider=provider.name,
+                                model=model,
+                                error_type="validation_error",
+                                error_message=f"Schema validation failed: {errors_in_doc[0].message}",
+                                stage_name=stage_name,
+                            )
+                        )
+                        async with state.lock:
+                            state.failed = True
+
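Note: the failure branch in `finalize_stage` writes off, in one step, every stage that is neither completed nor still in flight, because nothing will ever enqueue them once the document is marked failed; the stages still in flight will each call `finalize_stage` themselves. A worked example of the accounting, with illustrative numbers:

    # 6 stages total; 3 completed earlier; this stage just failed (already
    # moved into completed above); 1 stage is still running.
    remaining = 2                               # stages never started
    in_flight = 1                               # will finalize itself later
    skip_count = max(remaining - in_flight, 0)  # 1 stage is written off now
    remaining -= skip_count                     # doc finalizes once the
                                                # in-flight stage returns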
+        async def run_task(task: DocTask) -> None:
+            ctx = doc_contexts[task.path]
+            state = doc_states[task.path]
+            current_stage = task.stage_name
+
+            if shutdown_event.is_set():
+                await finalize_stage(ctx, current_stage, failed=False)
+                return
+            async with state.lock:
+                if state.failed:
+                    state.in_flight.discard(current_stage)
+            if state.failed:
+                await finalize_stage(ctx, current_stage, failed=False)
+                return
+
+            is_retry_full = ctx.source_path in retry_full_paths
+            retry_stages = (
+                retry_stage_map.get(ctx.source_path)
+                if retry_stages_mode and not is_retry_full
+                else None
+            )
+            stage_record = ctx.stages.get(current_stage)
+            if (
+                retry_stages_mode
+                and not is_retry_full
+                and retry_stages is not None
+                and current_stage not in retry_stages
+                and stage_record is not None
+            ):
+                await update_results(ctx)
+                await finalize_stage(ctx, current_stage, failed=False)
+                return
+
+            prompt_hash = _compute_prompt_hash(
+                prompt_template=prompt_template,
+                output_language=output_language,
+                stage_name=current_stage,
+                stage_fields=task.stage_fields,
+                custom_prompt=custom_prompt,
+                prompt_system_path=prompt_system_path,
+                prompt_user_path=prompt_user_path,
             )
-
-
-
+            stage_schema = build_stage_schema(schema, task.stage_fields)
+            stage_validator = validate_schema(stage_schema)
+            stage_meta = ctx.stage_meta.get(current_stage, {})
+            needs_run = force or current_stage in force_stage_set
+            if stage_record is None:
+                needs_run = True
+            if is_retry_full:
+                needs_run = True
+            if retry_stages is not None and current_stage in retry_stages:
+                needs_run = True
+            if stage_meta.get("prompt_hash") != prompt_hash:
+                needs_run = True
+            if stage_record is not None and not needs_run:
+                errors_in_stage = sorted(
+                    stage_validator.iter_errors(stage_record), key=lambda e: e.path
+                )
+                if errors_in_stage:
+                    needs_run = True
+
+            if not needs_run:
+                await update_results(ctx)
+                await finalize_stage(ctx, current_stage, failed=False)
+                return
+
+            try:
+                dependencies = stage_dependencies.get(current_stage, [])
+                previous_payload = {
+                    dep: ctx.stages.get(dep) for dep in dependencies if dep in ctx.stages
+                }
+                previous_outputs = (
+                    json.dumps(previous_payload, ensure_ascii=False) if previous_payload else ""
+                )
+                messages = build_messages(
+                    ctx.truncated_content,
+                    stage_schema,
+                    provider,
+                    prompt_template,
+                    output_language,
+                    custom_prompt=False,
+                    prompt_system_path=None,
+                    prompt_user_path=None,
+                    stage_name=current_stage,
+                    stage_fields=task.stage_fields,
+                    previous_outputs=previous_outputs,
+                )
+                async with semaphore:
+                    data = await call_with_retries(
+                        provider,
+                        model,
+                        messages,
+                        stage_schema,
+                        None,
+                        timeout=timeout_seconds,
+                        structured_mode=provider.structured_mode,
+                        max_retries=config.extract.max_retries,
+                        backoff_base_seconds=config.extract.backoff_base_seconds,
+                        backoff_max_seconds=config.extract.backoff_max_seconds,
+                        client=client,
+                        validator=stage_validator,
+                        key_rotator=rotator,
+                        throttle=throttle,
+                        stats=stats,
+                    )
+                async with state.lock:
+                    ctx.stages[current_stage] = data
+                    ctx.stage_meta[current_stage] = {"prompt_hash": prompt_hash}
+                    ctx.stage_state["stages"] = ctx.stages
+                    ctx.stage_state["stage_meta"] = ctx.stage_meta
+                    write_json_atomic(ctx.stage_path, ctx.stage_state)
+                await update_results(ctx)
+                await finalize_stage(ctx, current_stage, failed=False)
+            except ProviderError as exc:
+                log_extraction_failure(
+                    ctx.source_path,
+                    exc.error_type,
+                    str(exc),
+                    status_code=exc.status_code,
+                )
+                errors.append(
+                    ExtractionError(
+                        path=task.path,
+                        provider=provider.name,
+                        model=model,
+                        error_type=exc.error_type,
+                        error_message=str(exc),
+                        stage_name=current_stage,
+                    )
+                )
+                await finalize_stage(ctx, current_stage, failed=True)
+            except Exception as exc:  # pragma: no cover - safety net
+                logger.exception("Unexpected error while processing %s", ctx.source_path)
+                errors.append(
+                    ExtractionError(
+                        path=task.path,
+                        provider=provider.name,
+                        model=model,
+                        error_type="unexpected_error",
+                        error_message=str(exc),
+                        stage_name=current_stage,
+                    )
+                )
+                await finalize_stage(ctx, current_stage, failed=True)
+
+        async def worker() -> None:
+            while True:
+                if shutdown_event.is_set():
+                    await drain_queue(task_queue, drain_lock)
+                    return
+                await await_key_pool_ready()
+                try:
+                    task = task_queue.get_nowait()
+                except asyncio.QueueEmpty:
+                    return
+                await run_task(task)
+                task_queue.task_done()
+
+        workers = [asyncio.create_task(worker()) for _ in range(max_concurrency)]
+        await wait_for_queue(task_queue, drain_lock)
+        for w in workers:
+            w.cancel()
 
     try:
         async with httpx.AsyncClient() as client:
-
+            if multi_stage:
+                if stage_dag_enabled:
+                    await run_multi_stage_dag(client)
+                else:
+                    await run_multi_stage(client)
+            else:
+                await run_single_stage(client)
     finally:
+        if pause_task:
+            pause_task.cancel()
+            with contextlib.suppress(asyncio.CancelledError):
+                await pause_task
     if doc_bar:
         doc_bar.close()
     if stage_bar:
@@ -796,7 +2176,7 @@ async def extract_documents(
         if path and path in results:
             final_results.append(results[path])
             seen.add(path)
-        elif path
+        elif path:
             final_results.append(entry)
             seen.add(path)
 
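Note: the one-character change in this hunk is load-bearing. As rendered here, 0.5.1 shipped `elif path` without the trailing colon, which is a SyntaxError in Python, so this code path could not have compiled as written. In context the fixed dedup loop has roughly this shape (the loop head and the `source_path` key are inferred for illustration, not shown in the hunk):

    final_results: list[dict] = []
    seen: set[str] = set()
    for entry in combined:                        # loop head inferred
        path = entry.get("source_path")           # key name inferred
        if path and path in results:
            final_results.append(results[path])   # prefer the fresh result
            seen.add(path)
        elif path:                                # the fixed branch
            final_results.append(entry)           # fall back to prior entry
            seen.add(path)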
@@ -820,6 +2200,12 @@ async def extract_documents(
     ]
     write_json(errors_path, error_payload)
 
+    if shutdown_event.is_set():
+        logger.warning(
+            "Graceful shutdown completed (%s)",
+            shutdown_reason or "signal",
+        )
+
     if split:
         target_dir = split_dir or output_path.parent
         target_dir.mkdir(parents=True, exist_ok=True)
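Note: the warning added in the hunk above fires only when `shutdown_event` was set mid-run; how the event gets set is outside the hunk. Typical wiring for such an event on POSIX looks like the following sketch (an assumption about this package; `loop.add_signal_handler` is also unavailable on Windows's default event loop):

    import asyncio
    import signal

    def install_shutdown_handlers(shutdown_event: asyncio.Event) -> None:
        # The first SIGINT/SIGTERM flips the event; workers notice it,
        # drain the queue without starting new stages, and flush state.
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(sig, shutdown_event.set)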
@@ -861,7 +2247,15 @@ async def extract_documents(
     table.add_column("Value", style="white", overflow="fold")
     table.add_row("Documents", f"{doc_count} total")
     table.add_row("Successful", str(doc_count - len(errors)))
+    failed_stage_count = sum(1 for err in errors if err.stage_name)
+    retried_stage_count = 0
+    if retry_stages_mode:
+        retried_stage_count = sum(
+            len(retry_stage_map.get(str(path.resolve()), set())) for path in markdown_files
+        )
     table.add_row("Errors", str(len(errors)))
+    table.add_row("Failed stages", str(failed_stage_count))
+    table.add_row("Retried stages", str(retried_stage_count))
     table.add_row("Output JSON", str(output_path))
     table.add_row("Errors JSON", str(errors_path))
     table.add_row("Duration", _format_duration(duration))