debase 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2169 @@
1
+ """enzyme_lineage_extractor.py
2
+
3
+ Single-file, maintainable CLI tool that pulls an enzyme "family tree" and
4
+ associated sequences from one or two PDFs (manuscript + SI) using Google
5
+ Gemini (or compatible) LLM.
6
+
7
+ Navigate by searching for the numbered section headers:
8
+
9
+ # === 1. CONFIG & CONSTANTS ===
10
+ # === 2. DOMAIN MODELS ===
11
+ # === 3. LOGGING HELPERS ===
12
+ # === 4. PDF HELPERS ===
13
+ # === 5. LLM (GEMINI) HELPERS ===
14
+ # === 6. LINEAGE EXTRACTION ===
15
+ # === 7. SEQUENCE EXTRACTION ===
16
+ # === 8. VALIDATION & MERGE ===
17
+ # === 9. PIPELINE ORCHESTRATOR ===
18
+ # === 10. CLI ENTRYPOINT ===
19
+ """
20
+
21
+ # === 1. CONFIG & CONSTANTS ===
22
+ from __future__ import annotations
23
+ import pandas as pd
24
+ import networkx as nx # light dependency, used only for generation inference
25
+
26
+ import os
27
+ import re
28
+ import json
29
+ import time
30
+ import logging
31
+ from pathlib import Path
32
+ from dataclasses import dataclass, field
33
+ from typing import List, Optional, Union, Tuple
34
+
35
+ MODEL_NAME: str = "gemini-2.5-flash"
36
+ MAX_CHARS: int = 150_000 # Max characters sent to LLM
37
+ SEQ_CHUNK: int = 10 # Batch size when prompting for sequences
38
+ MAX_RETRIES: int = 4 # LLM retry loop
39
+ CACHE_DIR: Path = Path.home() / ".cache" / "enzyme_extractor"
40
+
41
+ # Ensure cache directory exists
42
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
43
+
44
+ # === 2. DOMAIN MODELS ===
45
+ @dataclass
46
+ class Campaign:
47
+ """Representation of a directed evolution campaign."""
48
+ campaign_id: str
49
+ campaign_name: str
50
+ description: str
51
+ model_substrate: Optional[str] = None
52
+ model_product: Optional[str] = None
53
+ substrate_id: Optional[str] = None
54
+ product_id: Optional[str] = None
55
+ data_locations: List[str] = field(default_factory=list)
56
+ reaction_conditions: dict = field(default_factory=dict)
57
+ notes: str = ""
58
+
59
+ @dataclass
60
+ class Variant:
61
+ """Representation of a variant in the evolutionary lineage."""
62
+ variant_id: str
63
+ parent_id: Optional[str]
64
+ mutations: List[str]
65
+ generation: int
66
+ campaign_id: Optional[str] = None # Links variant to campaign
67
+ notes: str = ""
68
+
69
+ @dataclass
70
+ class SequenceBlock:
71
+ """Protein and/or DNA sequence associated with a variant."""
72
+ variant_id: str
73
+ aa_seq: Optional[str] = None
74
+ dna_seq: Optional[str] = None
75
+ confidence: Optional[float] = None
76
+ truncated: bool = False
77
+ metadata: dict = field(default_factory=dict)
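+ # Illustrative relationship between the three records (hypothetical values, not from any paper):
+ #   camp = Campaign(campaign_id="epoxidation", campaign_name="Epoxidation campaign",
+ #                   description="Evolved for a model epoxidation reaction")
+ #   var  = Variant(variant_id="R1-5", parent_id="WT", mutations=["L34V"],
+ #                  generation=1, campaign_id=camp.campaign_id)
+ #   seq  = SequenceBlock(variant_id=var.variant_id, aa_seq="MKTAYIAKQR", dna_seq=None)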
78
+
79
+ # === 3. LOGGING HELPERS ===
80
+
81
+ # --- Debug dump helper ----------------------------------------------------
82
+ def _dump(text: str | bytes, path: Path | str) -> None:
83
+ """Write `text` / `bytes` to `path`, creating parent dirs as needed."""
84
+ p = Path(path)
85
+ p.parent.mkdir(parents=True, exist_ok=True)
86
+ mode = "wb" if isinstance(text, (bytes, bytearray)) else "w"
87
+ with p.open(mode) as fh:
88
+ fh.write(text)
89
+
90
+ def get_logger(name: str = __name__) -> logging.Logger:
91
+ logger = logging.getLogger(name)
92
+ if not logger.handlers:
93
+ handler = logging.StreamHandler()
94
+ fmt = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
95
+ handler.setFormatter(logging.Formatter(fmt=fmt, datefmt="%Y-%m-%d %H:%M:%S"))
96
+ logger.addHandler(handler)
97
+ logger.setLevel(logging.INFO)
98
+ return logger
99
+
100
+ log = get_logger(__name__)
101
+
102
+ # === 4. PDF HELPERS (incl. caption scraper & figure extraction) ===
103
+ try:
104
+ import fitz # PyMuPDF
105
+ except ImportError as exc: # pragma: no cover
106
+ raise ImportError(
107
+ "PyMuPDF is required for PDF parsing. Install with `pip install pymupdf`."
108
+ ) from exc
109
+
110
+ _DOI_REGEX = re.compile(r"10\.[0-9]{4,9}/[-._;()/:A-Z0-9]+", re.I)
111
+
112
+ # PDB ID regex - matches 4-character PDB codes
113
+ _PDB_REGEX = re.compile(r"\b[1-9][A-Z0-9]{3}\b")
114
+
115
+ # Improved caption prefix regex - captures most journal variants
116
+ _CAPTION_PREFIX_RE = re.compile(
117
+ r"""
118
+ ^\s*
119
+ (?:Fig(?:ure)?|Extended\s+Data\s+Fig|ED\s+Fig|Scheme|Chart|
120
+ Table|Supp(?:lementary|l|\.?)\s+(?:Fig(?:ure)?|Table)) # label part
121
+     \s*\.?\s*(?:S?\d+[A-Za-z]?|[IVX]+)   # figure number (also allows "Fig. 1", "Supplementary Fig. S5")
122
+ [.:]?\s* # trailing punctuation/space
123
+ """,
124
+ re.I | re.X,
125
+ )
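+ # Illustrative behaviour of _CAPTION_PREFIX_RE (hypothetical caption strings):
+ #   _CAPTION_PREFIX_RE.match("Figure 2A. Evolution of the variants")  -> match
+ #   _CAPTION_PREFIX_RE.match("Table S1: Variant lineage")             -> match
+ #   _CAPTION_PREFIX_RE.match("Supplementary Figure 5 ...")            -> match
+ #   _CAPTION_PREFIX_RE.match("In Figure 2A we show ...")              -> None (label must be line-initial)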
126
+
127
+
128
+ def _open_doc(pdf_path: str | Path | bytes):
129
+ if isinstance(pdf_path, (str, Path)):
130
+ return fitz.open(pdf_path) # type: ignore[arg-type]
131
+ return fitz.open(stream=pdf_path, filetype="pdf") # type: ignore[arg-type]
132
+
133
+
134
+ def extract_text(pdf_path: str | Path | bytes) -> str:
135
+ """Extract raw text from a PDF file (all blocks)."""
136
+
137
+ doc = _open_doc(pdf_path)
138
+ try:
139
+ return "\n".join(page.get_text() for page in doc)
140
+ finally:
141
+ doc.close()
142
+
143
+
144
+ def extract_captions(pdf_path: str | Path | bytes, max_chars: int = MAX_CHARS) -> str:
145
+ """Extract figure/table captions using the improved regex.
146
+
147
+ The function scans every text line on every page and keeps lines whose first
148
+ token matches `_CAPTION_PREFIX_RE`. This covers labels such as:
149
+ * Fig. 1, Figure 2A, Extended Data Fig 3
150
+ * Table S1, Table 4, Scheme 2, Chart 1B
151
+ * Supplementary Fig. S5, Supp Table 2
152
+ """
153
+
154
+ doc = _open_doc(pdf_path)
155
+ captions: list[str] = []
156
+ try:
157
+ for page in doc:
158
+ page_dict = page.get_text("dict")
159
+ for block in page_dict.get("blocks", []):
160
+ # Get all lines in this block
161
+ block_lines = []
162
+ for line in block.get("lines", []):
163
+ text_line = "".join(span["text"] for span in line.get("spans", []))
164
+ block_lines.append(text_line.strip())
165
+
166
+ # Check if any line starts with a caption prefix
167
+ for i, line in enumerate(block_lines):
168
+ if _CAPTION_PREFIX_RE.match(line):
169
+ # Found a caption start - collect this line and subsequent lines
170
+ # until we hit an empty line or the end of the block
171
+ caption_parts = [line]
172
+ for j in range(i + 1, len(block_lines)):
173
+ next_line = block_lines[j]
174
+ if not next_line: # Empty line signals end of caption
175
+ break
176
+ # Check if next line is a new caption
177
+ if _CAPTION_PREFIX_RE.match(next_line):
178
+ break
179
+ caption_parts.append(next_line)
180
+
181
+ # Join the parts with spaces
182
+ full_caption = " ".join(caption_parts)
183
+ captions.append(full_caption)
184
+ finally:
185
+ doc.close()
186
+
187
+ joined = "\n".join(captions)
188
+ return joined[:max_chars]
189
+
190
+
191
+ def extract_doi(pdf_path: str | Path | bytes) -> Optional[str]:
192
+ """Attempt to locate a DOI inside the PDF."""
193
+
194
+ m = _DOI_REGEX.search(extract_text(pdf_path))
195
+ return m.group(0) if m else None
196
+
197
+
198
+ def extract_pdb_ids(pdf_path: str | Path | bytes) -> List[str]:
199
+ """Extract all PDB IDs from the PDF."""
200
+ text = extract_text(pdf_path)
201
+
202
+ # Find all potential PDB IDs
203
+ pdb_ids = []
204
+ for match in _PDB_REGEX.finditer(text):
205
+ pdb_id = match.group(0).upper()
206
+ # Additional validation - check context for "PDB" mention
207
+ start = max(0, match.start() - 50)
208
+ end = min(len(text), match.end() + 50)
209
+ context = text[start:end].upper()
210
+
211
+ # Only include if "PDB" appears in context or it's a known pattern
212
+ if "PDB" in context or "PROTEIN DATA BANK" in context:
213
+ if pdb_id not in pdb_ids:
214
+ pdb_ids.append(pdb_id)
215
+ log.info(f"Found PDB ID: {pdb_id}")
216
+
217
+ return pdb_ids
218
+
219
+
220
+ def limited_concat(*pdf_paths: str | Path, max_chars: int = MAX_CHARS) -> str:
221
+ """Concatenate **all text** from PDFs, trimmed to `max_chars`."""
222
+
223
+ total = 0
224
+ chunks: list[str] = []
225
+ for p in pdf_paths:
226
+ t = extract_text(p)
227
+ if total + len(t) > max_chars:
228
+ t = t[: max_chars - total]
229
+ chunks.append(t)
230
+ total += len(t)
231
+ if total >= max_chars:
232
+ break
233
+ return "\n".join(chunks)
234
+
235
+
236
+ def limited_caption_concat(*pdf_paths: str | Path, max_chars: int = MAX_CHARS) -> str:
237
+ """Concatenate only caption text from PDFs, trimmed to `max_chars`."""
238
+
239
+ total = 0
240
+ chunks: list[str] = []
241
+ for p in pdf_paths:
242
+ t = extract_captions(p)
243
+ if total + len(t) > max_chars:
244
+ t = t[: max_chars - total]
245
+ chunks.append(t)
246
+ total += len(t)
247
+ if total >= max_chars:
248
+ break
249
+ return "\n".join(chunks)
250
+
251
+
252
+ def extract_figure(pdf_path: Union[str, Path], figure_id: str, debug_dir: Optional[Union[str, Path]] = None) -> Optional[bytes]:
253
+ """Extract a specific figure from a PDF by finding its caption.
254
+
255
+ Returns the figure as PNG bytes if found, None otherwise.
256
+ """
257
+ doc = _open_doc(pdf_path)
258
+ figure_bytes = None
259
+
260
+ try:
261
+ # Search for the exact figure caption text
262
+ search_text = figure_id.strip()
263
+
264
+ for page_num, page in enumerate(doc):
265
+ # Search for the caption text on this page
266
+ text_instances = page.search_for(search_text)
267
+
268
+ if text_instances:
269
+ log.info(f"Found caption '{figure_id}' on page {page_num + 1}")
270
+
271
+ # Get the position of the first instance
272
+ caption_rect = text_instances[0]
273
+
274
+ # Get all images on this page
275
+ image_list = page.get_images()
276
+
277
+ if image_list:
278
+ # Find the image closest to and above the caption
279
+ best_img = None
280
+ best_distance = float('inf')
281
+
282
+ for img_index, img in enumerate(image_list):
283
+ # Get image position
284
+ xref = img[0]
285
+ img_rects = page.get_image_rects(xref)
286
+
287
+ if img_rects:
288
+ img_rect = img_rects[0]
289
+
290
+ # Check if image is above the caption and calculate distance
291
+ if img_rect.y1 <= caption_rect.y0: # Image bottom is above caption top
292
+ distance = caption_rect.y0 - img_rect.y1
293
+ if distance < best_distance and distance < 100: # Within reasonable distance
294
+ best_distance = distance
295
+ best_img = xref
296
+
297
+ if best_img is not None:
298
+ # Extract the identified image
299
+ pix = fitz.Pixmap(doc, best_img)
300
+
301
+ if pix.n - pix.alpha < 4: # GRAY or RGB
302
+ figure_bytes = pix.tobytes("png")
303
+ else: # Convert CMYK to RGB
304
+ pix2 = fitz.Pixmap(fitz.csRGB, pix)
305
+ figure_bytes = pix2.tobytes("png")
306
+ pix2 = None
307
+ pix = None
308
+
309
+ # Save to debug directory if provided
310
+ if debug_dir and figure_bytes:
311
+ debug_path = Path(debug_dir)
312
+ debug_path.mkdir(parents=True, exist_ok=True)
313
+ fig_file = debug_path / f"figure_{figure_id.replace(' ', '_').replace('.', '')}_{int(time.time())}.png"
314
+ with open(fig_file, 'wb') as f:
315
+ f.write(figure_bytes)
316
+ log.info(f"Saved figure to: {fig_file}")
317
+
318
+ break
319
+
320
+ finally:
321
+ doc.close()
322
+
323
+ return figure_bytes
324
+
325
+
326
+ def is_figure_reference(location: str) -> bool:
327
+ """Check if a location string refers to a figure."""
328
+ # Check for common figure patterns
329
+ figure_patterns = [
330
+ r'Fig(?:ure)?\.?\s+', # Fig. 2B, Figure 3
331
+ r'Extended\s+Data\s+Fig', # Extended Data Fig
332
+ r'ED\s+Fig', # ED Fig
333
+ r'Scheme\s+', # Scheme 1
334
+ r'Chart\s+', # Chart 2
335
+ ]
336
+
337
+ location_str = str(location).strip()
338
+ for pattern in figure_patterns:
339
+ if re.search(pattern, location_str, re.I):
340
+ return True
341
+ return False
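+ # Illustrative use (hypothetical location strings):
+ #   is_figure_reference("Figure 2B")  -> True
+ #   is_figure_reference("Fig. 3")     -> True
+ #   is_figure_reference("Table S1")   -> False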
342
+
343
+ # === 5. LLM (GEMINI) HELPERS === ---------------------------------------------
344
+ from typing import Tuple, Any
345
+
346
+ _BACKOFF_BASE = 2.0 # exponential back-off base (seconds)
347
+
348
+ # -- 5.1 Import whichever SDK is installed -----------------------------------
349
+
350
+ def _import_gemini_sdk() -> Tuple[str, Any]:
351
+ """Return (flavor, module) where flavor in {"new", "legacy"}."""
352
+ try:
353
+ import google.generativeai as genai # official SDK >= 1.0
354
+ return "new", genai
355
+ except ImportError:
356
+ try:
357
+ import google_generativeai as genai # legacy prerelease name
358
+ return "legacy", genai
359
+ except ImportError as exc:
360
+ raise ImportError(
361
+ "Neither 'google-generativeai' (>=1.0) nor 'google_generativeai'\n"
362
+ "is installed. Run: pip install --upgrade google-generativeai"
363
+ ) from exc
364
+
365
+ _SDK_FLAVOR, _genai = _import_gemini_sdk()
366
+
367
+ # -- 5.2 Model factory --------------------------------------------------------
368
+
369
+ def get_model():
370
+ """Configure API key and return a `GenerativeModel` instance."""
371
+ api_key = os.getenv("GEMINI_API_KEY")
372
+ if not api_key:
373
+ raise EnvironmentError("Set the GEMINI_API_KEY environment variable.")
374
+ _genai.configure(api_key=api_key)
375
+ # Positional constructor arg works for both SDK flavors
376
+ return _genai.GenerativeModel(MODEL_NAME)
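+ # Minimal usage sketch (assumes GEMINI_API_KEY is exported in the environment):
+ #   model = get_model()
+ #   resp = model.generate_content("Return the word OK")   # single plain-text call, no retries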
377
+
378
+ # -- 5.3 Unified call helper --------------------------------------------------
379
+
380
+ def _extract_text(resp) -> str:
381
+ """
382
+ Pull the *first* textual part out of a GenerativeAI response, handling both
383
+ the old prerelease SDK and the >=1.0 SDK.
384
+
385
+ Returns an empty string if no textual content is found.
386
+ """
387
+ # 1) Legacy SDK (<= 0.4) - still has nice `.text`
388
+ if getattr(resp, "text", None):
389
+ return resp.text
390
+
391
+ # 2) >= 1.0 SDK
392
+ if getattr(resp, "candidates", None):
393
+ cand = resp.candidates[0]
394
+
395
+ # 2a) Some beta builds still expose `.text`
396
+ if getattr(cand, "text", None):
397
+ return cand.text
398
+
399
+ # 2b) Official path: candidate.content.parts[*].text
400
+ if getattr(cand, "content", None):
401
+ parts = [
402
+ part.text # Part objects have .text
403
+ for part in cand.content.parts
404
+ if getattr(part, "text", None)
405
+ ]
406
+ if parts:
407
+ return "".join(parts)
408
+
409
+ # 3) As a last resort fall back to str()
410
+ return str(resp)
411
+
412
+
413
+ def generate_json_with_retry(
414
+ model,
415
+ prompt: str,
416
+ schema_hint: str | None = None,
417
+ *,
418
+ max_retries: int = MAX_RETRIES,
419
+ debug_dir:str | Path | None = None,
420
+ tag: str = 'gemini',
421
+ ):
422
+ """
423
+ Call Gemini with retries & exponential back-off, returning parsed JSON.
424
+
425
+ Also strips Markdown fences that the model may wrap around its JSON.
426
+ """
427
+ # Log prompt details
428
+ log.info("=== GEMINI API CALL: %s ===", tag.upper())
429
+ log.info("Prompt length: %d characters", len(prompt))
430
+ log.info("First 500 chars of prompt:\n%s\n...(truncated)", prompt[:500])
431
+
432
+ # Save full prompt to debug directory if provided
433
+ if debug_dir:
434
+ debug_path = Path(debug_dir)
435
+ debug_path.mkdir(parents=True, exist_ok=True)
436
+ prompt_file = debug_path / f"{tag}_prompt_{int(time.time())}.txt"
437
+ with open(prompt_file, 'w') as f:
438
+ f.write(f"=== PROMPT FOR {tag.upper()} ===\n")
439
+ f.write(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
440
+ f.write(f"Length: {len(prompt)} characters\n")
441
+ f.write("="*80 + "\n\n")
442
+ f.write(prompt)
443
+ log.info("Full prompt saved to: %s", prompt_file)
444
+
445
+ fence_re = re.compile(r"```json|```", re.I)
446
+ for attempt in range(1, max_retries + 1):
447
+ try:
448
+ log.info("Calling Gemini API (attempt %d/%d)...", attempt, max_retries)
449
+ resp = model.generate_content(prompt)
450
+ raw = _extract_text(resp).strip()
451
+
452
+ # Log response
453
+ log.info("Gemini response length: %d characters", len(raw))
454
+ log.info("First 500 chars of response:\n%s\n...(truncated)", raw[:500])
455
+
456
+ # Save full response to debug directory
457
+ if debug_dir:
458
+ response_file = debug_path / f"{tag}_response_{int(time.time())}.txt"
459
+ with open(response_file, 'w') as f:
460
+ f.write(f"=== RESPONSE FOR {tag.upper()} ===\n")
461
+ f.write(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
462
+ f.write(f"Length: {len(raw)} characters\n")
463
+ f.write("="*80 + "\n\n")
464
+ f.write(raw)
465
+ log.info("Full response saved to: %s", response_file)
466
+
467
+ # Remove common Markdown fences
468
+ if raw.startswith("```"):
469
+ raw = fence_re.sub("", raw).strip()
470
+
471
+ # Try to find JSON in the response
472
+ # First, try to parse as-is
473
+ try:
474
+ parsed = json.loads(raw)
475
+ except json.JSONDecodeError:
476
+ # If that fails, look for JSON array or object
477
+ # Find the first '[' or '{' and the matching closing bracket
478
+ json_start = -1
479
+ json_end = -1
480
+ bracket_stack = []
481
+ in_string = False
482
+ escape_next = False
483
+
484
+ for i, char in enumerate(raw):
485
+ if escape_next:
486
+ escape_next = False
487
+ continue
488
+
489
+ if char == '\\':
490
+ escape_next = True
491
+ continue
492
+
493
+ if char == '"' and not escape_next:
494
+ in_string = not in_string
495
+ continue
496
+
497
+ if in_string:
498
+ continue
499
+
500
+ if char in '[{':
501
+ if json_start == -1:
502
+ json_start = i
503
+ bracket_stack.append(char)
504
+ elif char in ']}':
505
+ if bracket_stack:
506
+ opening = bracket_stack.pop()
507
+ if (opening == '[' and char == ']') or (opening == '{' and char == '}'):
508
+ if not bracket_stack: # Found complete JSON
509
+ json_end = i + 1
510
+ break
511
+
512
+ if json_start >= 0 and json_end > json_start:
513
+ # Extract the JSON portion
514
+ json_str = raw[json_start:json_end]
515
+ parsed = json.loads(json_str)
516
+ else:
517
+ # Look for simple [] in the response
518
+ if '[]' in raw:
519
+ parsed = []
520
+ else:
521
+ # No JSON structure found, re-raise the original error
522
+ raise json.JSONDecodeError("No JSON structure found in response", raw, 0)
523
+ log.info("Successfully parsed JSON response")
524
+ return parsed
525
+ except Exception as exc: # broad except OK here
526
+ log.warning(
527
+ "Gemini call failed (attempt %d/%d): %s",
528
+ attempt, max_retries, exc,
529
+ )
530
+ if attempt == max_retries:
531
+ raise
532
+ time.sleep(_BACKOFF_BASE ** attempt)
533
+ # -------------------------------------------------------------------- end 5 ---
534
+
535
+
536
+ # === 6. LINEAGE EXTRACTION (WITH CAMPAIGN SUPPORT) ===
537
+ """
538
+ Variant lineage extractor with campaign identification.
539
+
540
+ Asks Gemini to produce a JSON representation of the evolutionary lineage of
541
+ enzyme variants described in a manuscript. The heavy lifting is done by the
542
+ LLM; this section crafts robust prompts, parses the reply, validates it, and
543
+ exposes a convenient high-level `get_lineage()` helper for the pipeline.
544
+
545
+ June 2025: updated for Google Generative AI SDK >= 1.0 and added rich debug
546
+ hooks (`--debug-dir` dumps prompts, replies, and raw captions).
547
+
548
+ December 2025: Added campaign identification to support multiple directed
549
+ evolution campaigns within a single paper.
550
+ """
551
+
552
+ from pathlib import Path
553
+ from typing import List, Dict, Any
554
+
555
+ # ---- 6.0 Campaign identification prompts -----------------------------------
556
+
557
+ _CAMPAIGN_IDENTIFICATION_PROMPT = """
558
+ You are an expert reader of protein engineering manuscripts.
559
+ Analyze the following manuscript text to identify ALL distinct directed evolution campaigns.
560
+
561
+ Each campaign represents a separate evolutionary lineage targeting different:
562
+ - Model reactions (e.g., different chemical transformations)
563
+ - Substrate scopes
564
+ - Activities (e.g., different enzymatic reactions)
565
+
566
+ Look for:
567
+ 1. Different model substrates/products mentioned (e.g., different substrate/product pairs)
568
+ 2. Distinct enzyme lineage names (e.g., different variant naming patterns)
569
+ 3. Separate evolution trees or lineage tables
570
+ 4. Different reaction schemes or transformations
571
+
572
+ Return a JSON array of campaigns:
573
+ [
574
+ {{
575
+ "campaign_id": "unique_id",
576
+ "campaign_name": "descriptive name",
577
+ "description": "what this campaign evolved for",
578
+ "model_substrate": "substrate name/id",
579
+ "model_product": "product name/id",
580
+ "substrate_id": "id from paper (e.g., 1a)",
581
+ "product_id": "id from paper (e.g., 2a)",
582
+ "data_locations": ["Table S1", "Figure 1"],
583
+ "lineage_hint": "enzyme name pattern",
584
+ "notes": "additional context"
585
+ }}
586
+ ]
587
+
588
+ TEXT:
589
+ {text}
590
+ """.strip()
591
+
592
+ _CAMPAIGN_MAPPING_PROMPT = """
593
+ Given these identified campaigns and the lineage data location, determine which campaign this data belongs to:
594
+
595
+ Campaigns:
596
+ {campaigns}
597
+
598
+ Data location: {location}
599
+ Caption/context: {context}
600
+
601
+ Based on the caption, enzyme names, or reaction details, which campaign does this data belong to?
602
+ Return ONLY the campaign_id as a string.
603
+ """.strip()
604
+
605
+ # ---- 6.1 Prompt templates -------------------------------------------------
606
+
607
+ _LINEAGE_LOC_PROMPT = """
608
+ You are an expert reader of protein engineering manuscripts.
609
+ Given the following article text, list up to {max_results} *locations* (page
610
+ numbers, figure/table IDs, or section headings) that you would review first to
611
+ find the COMPLETE evolutionary lineage of enzyme variants (i.e. which variant
612
+ came from which parent and what mutations were introduced).
613
+
614
+ Respond with a JSON array of objects, each containing:
615
+ - "location": the identifier (e.g. "Table S1", "Figure 2B", "p. 6")
616
+ - "type": one of "table", "figure", "text", "section"
617
+ - "confidence": your confidence score (0-100) that this location contains lineage data
618
+ - "reason": brief explanation of why this location likely contains lineage
619
+
620
+ Order by confidence score (highest first). Tables showing complete variant lineages or
621
+ mutation lists should be ranked higher than figures showing complete variant lineages.
622
+ Text sections should only be used when no suitable tables or figures exist.
623
+
624
+ Don't include oligonucleotide results or results from only a single round.
625
+
626
+ Example output:
627
+ [
628
+ {{"location": "Table S1", "type": "table", "confidence": 95, "reason": "Variant lineage table"}},
629
+ {{"location": "Figure 2B", "type": "figure", "confidence": 70, "reason": "Phylogenetic tree diagram"}},
630
+ {{"location": "Section 3.2", "type": "section", "confidence": 60, "reason": "Evolution description"}}
631
+ ]
632
+ """.strip()
633
+
634
+ _LINEAGE_SCHEMA_HINT = """
635
+ {
636
+ "variants": [
637
+ {
638
+ "variant_id": "string",
639
+ "parent_id": "string | null",
640
+ "mutations": ["string"],
641
+ "generation": "int",
642
+ "campaign_id": "string (optional)",
643
+ "notes": "string (optional)"
644
+ }
645
+ ]
646
+ }
647
+ """.strip()
648
+
649
+ _LINEAGE_EXTRACT_PROMPT = """
650
+ Below is the (optionally truncated) text of a protein-engineering manuscript.
651
+ Your task is to output the **complete evolutionary lineage** as JSON conforming
652
+ exactly to the schema provided below.
653
+
654
+ {campaign_context}
655
+
656
+ Schema:
657
+ ```json
658
+ {schema}
659
+ ```
660
+
661
+ Guidelines:
662
+ * Include every named variant that appears in the lineage (WT, libraries,
663
+ hits, final variant, etc.).
664
+ * If a variant appears multiple times, keep the earliest generation.
665
+ * `mutations` must be a list of human-readable point mutations *relative to
666
+ its immediate parent* (e.g. ["L34V", "S152G"]). If no mutations are listed,
667
+ use an empty list.
668
+ * Generation = 0 for the starting template (WT or first variant supplied by
669
+ the authors). Increment by 1 for each subsequent round.
670
+ * If you are uncertain about any field, add an explanatory string to `notes`.
671
+ * IMPORTANT: Only include variants that belong to the campaign context provided above.
672
+
673
+ Return **ONLY** minified JSON, no markdown fences, no commentary.
674
+
675
+ TEXT:
676
+ ```
677
+ {text}
678
+ ```
679
+ """.strip()
680
+
681
+ _LINEAGE_FIGURE_PROMPT = """
682
+ You are looking at a figure from a protein-engineering manuscript that shows
683
+ the evolutionary lineage of enzyme variants.
684
+
685
+ {campaign_context}
686
+
687
+ Your task is to output the **complete evolutionary lineage** as JSON conforming
688
+ exactly to the schema provided below.
689
+
690
+ Schema:
691
+ ```json
692
+ {schema}
693
+ ```
694
+
695
+ Guidelines:
696
+ * Include every named variant that appears in the lineage diagram/tree
697
+ * Extract parent-child relationships from the visual connections (arrows, lines, etc.)
698
+ * `mutations` must be a list of human-readable point mutations *relative to
699
+ its immediate parent* (e.g. ["L34V", "S152G"]) if shown
700
+ * Generation = 0 for the starting template (WT or first variant). Increment by 1
701
+ for each subsequent round/generation shown in the figure
702
+ * If you are uncertain about any field, add an explanatory string to `notes`
703
+ * IMPORTANT: Only include variants that belong to the campaign context provided above.
704
+
705
+ Return **ONLY** minified JSON, no markdown fences, no commentary.
706
+ """.strip()
707
+
708
+ # ---- 6.2 Helper functions -------------------------------------------------
709
+
710
+ def identify_campaigns(
711
+ text: str,
712
+ model,
713
+ *,
714
+ debug_dir: str | Path | None = None,
715
+ ) -> List[Campaign]:
716
+ """Identify distinct directed evolution campaigns in the manuscript."""
717
+ prompt = _CAMPAIGN_IDENTIFICATION_PROMPT.format(text=text[:30_000])
718
+ campaigns_data: List[dict] = []
719
+ try:
720
+ campaigns_data = generate_json_with_retry(
721
+ model,
722
+ prompt,
723
+ debug_dir=debug_dir,
724
+ tag="campaigns",
725
+ )
726
+ except Exception as exc:
727
+ log.warning("identify_campaigns(): %s", exc)
728
+
729
+ # Convert to Campaign objects
730
+ campaigns = []
731
+ for data in campaigns_data:
732
+ try:
733
+ campaign = Campaign(
734
+ campaign_id=data.get("campaign_id", ""),
735
+ campaign_name=data.get("campaign_name", ""),
736
+ description=data.get("description", ""),
737
+ model_substrate=data.get("model_substrate"),
738
+ model_product=data.get("model_product"),
739
+ substrate_id=data.get("substrate_id"),
740
+ product_id=data.get("product_id"),
741
+ data_locations=data.get("data_locations", []),
742
+ reaction_conditions=data.get("reaction_conditions", {}),
743
+ notes=data.get("notes", "")
744
+ )
745
+ campaigns.append(campaign)
746
+ log.info(f"Identified campaign: {campaign.campaign_name} ({campaign.campaign_id})")
747
+ except Exception as exc:
748
+ log.warning(f"Failed to parse campaign data: {exc}")
749
+
750
+ return campaigns
751
+
752
+ def identify_evolution_locations(
753
+ text: str,
754
+ model,
755
+ *,
756
+ max_results: int = 5,
757
+ debug_dir: str | Path | None = None,
758
+ campaigns: Optional[List[Campaign]] = None,
759
+ ) -> List[dict]:
760
+ """Ask Gemini where in the paper the lineage is probably described."""
761
+ prompt = _LINEAGE_LOC_PROMPT.format(max_results=max_results) + "\n\nTEXT:\n" + text[:15_000]
762
+ locs: List[dict] = []
763
+ try:
764
+ locs = generate_json_with_retry(
765
+ model,
766
+ prompt,
767
+ debug_dir=debug_dir,
768
+ tag="locate",
769
+ )
770
+ except Exception as exc: # pragma: no cover
771
+ log.warning("identify_evolution_locations(): %s", exc)
772
+
773
+ # If we have campaigns, try to map locations to campaigns
774
+ if campaigns and locs:
775
+ for loc in locs:
776
+ # Extract more context around the location
777
+ location_str = loc.get('location', '')
778
+ context = loc.get('reason', '')
779
+
780
+ # Ask Gemini to map this location to a campaign
781
+ if campaigns:
782
+ try:
783
+ campaigns_json = json.dumps([{
784
+ "campaign_id": c.campaign_id,
785
+ "campaign_name": c.campaign_name,
786
+ "lineage_hint": c.notes
787
+ } for c in campaigns])
788
+
789
+ mapping_prompt = _CAMPAIGN_MAPPING_PROMPT.format(
790
+ campaigns=campaigns_json,
791
+ location=location_str,
792
+ context=context
793
+ )
794
+
795
+ # Save mapping prompt to debug if provided
796
+ if debug_dir:
797
+ debug_path = Path(debug_dir)
798
+ debug_path.mkdir(parents=True, exist_ok=True)
799
+ mapping_file = debug_path / f"campaign_mapping_{location_str.replace(' ', '_')}_{int(time.time())}.txt"
800
+ _dump(f"=== CAMPAIGN MAPPING PROMPT ===\nLocation: {location_str}\n{'='*80}\n\n{mapping_prompt}", mapping_file)
801
+
802
+ response = model.generate_content(mapping_prompt)
803
+ campaign_id = _extract_text(response).strip().strip('"')
804
+
805
+ # Save mapping response to debug if provided
806
+ if debug_dir:
807
+ response_file = debug_path / f"campaign_mapping_response_{location_str.replace(' ', '_')}_{int(time.time())}.txt"
808
+ _dump(f"=== CAMPAIGN MAPPING RESPONSE ===\nLocation: {location_str}\nMapped to: {campaign_id}\n{'='*80}\n\n{_extract_text(response)}", response_file)
809
+
810
+ # Add campaign_id to location
811
+ loc['campaign_id'] = campaign_id
812
+ log.info(f"Mapped {location_str} to campaign: {campaign_id}")
813
+ except Exception as exc:
814
+ log.warning(f"Failed to map location to campaign: {exc}")
815
+
816
+ return locs if isinstance(locs, list) else []
817
+
818
+
819
+
820
+ def _parse_variants(data: Dict[str, Any], campaign_id: Optional[str] = None) -> List[Variant]:
821
+ """Convert raw JSON to a list[Variant] with basic validation."""
822
+ variants_json = data.get("variants", []) if isinstance(data, dict) else []
823
+ parsed: List[Variant] = []
824
+ for item in variants_json:
825
+ try:
826
+ variant_id = str(item["variant_id"]).strip()
827
+ parent_id = item.get("parent_id")
828
+ parent_id = str(parent_id).strip() if parent_id else None
829
+ mutations = [str(m).strip() for m in item.get("mutations", [])]
830
+ generation = int(item.get("generation", 0))
831
+ notes = str(item.get("notes", "")).strip()
832
+
833
+ # Use campaign_id from item if present, otherwise use the passed campaign_id,
834
+ # otherwise default to "default"
835
+ variant_campaign_id = item.get("campaign_id") or campaign_id or "default"
836
+
837
+ parsed.append(
838
+ Variant(
839
+ variant_id=variant_id,
840
+ parent_id=parent_id,
841
+ mutations=mutations,
842
+ generation=generation,
843
+ campaign_id=variant_campaign_id,
844
+ notes=notes,
845
+ )
846
+ )
847
+ except Exception as exc: # pragma: no cover
848
+ log.debug("Skipping malformed variant entry %s: %s", item, exc)
849
+ return parsed
850
+
851
+
852
+
853
+ def extract_complete_lineage(
854
+ text: str,
855
+ model,
856
+ *,
857
+ debug_dir: str | Path | None = None,
858
+ campaign_id: Optional[str] = None,
859
+ campaign_info: Optional[Campaign] = None,
860
+ ) -> List[Variant]:
861
+ """Prompt Gemini for the full lineage and return a list[Variant]."""
862
+ # Build campaign context
863
+ campaign_context = ""
864
+ if campaign_info:
865
+ campaign_context = f"""
866
+ CAMPAIGN CONTEXT:
867
+ You are extracting the lineage for the following campaign:
868
+ - Campaign ID: {campaign_info.campaign_id}
869
+ - Campaign: {campaign_info.campaign_name}
870
+ - Description: {campaign_info.description}
871
+ - Model reaction: {campaign_info.substrate_id} → {campaign_info.product_id}
872
+ - Lineage hint: Variants containing "{campaign_info.notes}" belong to this campaign
873
+
874
+ IMPORTANT:
875
+ 1. Only extract variants that belong to this specific campaign.
876
+ 2. Include "campaign_id": "{campaign_info.campaign_id}" for each variant in your response.
877
+ 3. Use the lineage hint pattern above to identify which variants belong to this campaign.
878
+ 4. Include parent variants only if they are direct ancestors in this campaign's lineage.
879
+ """
880
+
881
+ prompt = _LINEAGE_EXTRACT_PROMPT.format(
882
+ campaign_context=campaign_context,
883
+ schema=_LINEAGE_SCHEMA_HINT,
884
+ text=text[:MAX_CHARS],
885
+ )
886
+ raw = generate_json_with_retry(
887
+ model,
888
+ prompt,
889
+ schema_hint=_LINEAGE_SCHEMA_HINT,
890
+ debug_dir=debug_dir,
891
+ tag="lineage",
892
+ )
893
+ variants = _parse_variants(raw, campaign_id=campaign_id)
894
+ log.info("Extracted %d lineage entries", len(variants))
895
+ return variants
896
+
897
+
898
+ def extract_lineage_from_figure(
899
+ figure_bytes: bytes,
900
+ model,
901
+ *,
902
+ debug_dir: str | Path | None = None,
903
+ campaign_id: Optional[str] = None,
904
+ campaign_info: Optional[Campaign] = None,
905
+ ) -> List[Variant]:
906
+ """Extract lineage from a figure image using Gemini's vision capabilities."""
907
+ # Build campaign context
908
+ campaign_context = ""
909
+ if campaign_info:
910
+ campaign_context = f"""
911
+ CAMPAIGN CONTEXT:
912
+ You are extracting the lineage for the following campaign:
913
+ - Campaign: {campaign_info.campaign_name}
914
+ - Description: {campaign_info.description}
915
+ - Model reaction: {campaign_info.substrate_id} → {campaign_info.product_id}
916
+ - Lineage hint: Variants containing "{campaign_info.notes}" belong to this campaign
917
+
918
+ IMPORTANT: Only extract variants that belong to this specific campaign.
919
+ """
920
+
921
+ prompt = _LINEAGE_FIGURE_PROMPT.format(
922
+ campaign_context=campaign_context,
923
+ schema=_LINEAGE_SCHEMA_HINT
924
+ )
925
+
926
+ # Log prompt details
927
+ log.info("=== GEMINI VISION API CALL: FIGURE_LINEAGE ===")
928
+ log.info("Prompt length: %d characters", len(prompt))
929
+ log.info("Image size: %d bytes", len(figure_bytes))
930
+ log.info("First 500 chars of prompt:\n%s\n...(truncated)", prompt[:500])
931
+
932
+ # Save prompt and image to debug directory if provided
933
+ if debug_dir:
934
+ debug_path = Path(debug_dir)
935
+ debug_path.mkdir(parents=True, exist_ok=True)
936
+
937
+ # Save prompt
938
+ prompt_file = debug_path / f"figure_lineage_prompt_{int(time.time())}.txt"
939
+ _dump(f"=== PROMPT FOR FIGURE_LINEAGE ===\nTimestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}\nLength: {len(prompt)} characters\nImage size: {len(figure_bytes)} bytes\n{'='*80}\n\n{prompt}",
940
+ prompt_file)
941
+ log.info("Full prompt saved to: %s", prompt_file)
942
+
943
+ # Save image
944
+ image_file = debug_path / f"figure_lineage_image_{int(time.time())}.png"
945
+ _dump(figure_bytes, image_file)
946
+ log.info("Figure image saved to: %s", image_file)
947
+
948
+ # For Gemini vision API, we need to pass the image differently
949
+ # This will depend on the specific SDK version being used
950
+ try:
951
+ # Create a multimodal prompt with the image
952
+ import PIL.Image
953
+ import io
954
+
955
+ # Convert bytes to PIL Image
956
+ image = PIL.Image.open(io.BytesIO(figure_bytes))
957
+
958
+ log.info("Calling Gemini Vision API...")
959
+ # Generate content with image
960
+ response = model.generate_content([prompt, image])
961
+ raw_text = _extract_text(response).strip()
962
+
963
+ # Log response
964
+ log.info("Gemini figure analysis response length: %d characters", len(raw_text))
965
+ log.info("First 500 chars of response:\n%s\n...(truncated)", raw_text[:500])
966
+
967
+ # Save response to debug directory if provided
968
+ if debug_dir:
969
+ debug_path = Path(debug_dir)
970
+ debug_path.mkdir(parents=True, exist_ok=True)
971
+ response_file = debug_path / f"figure_lineage_response_{int(time.time())}.txt"
972
+ with open(response_file, 'w') as f:
973
+ f.write(f"=== RESPONSE FOR FIGURE LINEAGE ===\n")
974
+ f.write(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
975
+ f.write(f"Length: {len(raw_text)} characters\n")
976
+ f.write("="*80 + "\n\n")
977
+ f.write(raw_text)
978
+ log.info("Full response saved to: %s", response_file)
979
+
980
+ # Parse JSON from response
981
+ fence_re = re.compile(r"```json|```", re.I)
982
+ if raw_text.startswith("```"):
983
+ raw_text = fence_re.sub("", raw_text).strip()
984
+
985
+ raw = json.loads(raw_text)
986
+
987
+ # Handle both array and object formats
988
+ if isinstance(raw, list):
989
+ # Direct array format - convert to expected format
990
+ variants_data = {"variants": raw}
991
+ else:
992
+ # Already in object format
993
+ variants_data = raw
994
+
995
+ variants = _parse_variants(variants_data, campaign_id=campaign_id)
996
+ log.info("Extracted %d lineage entries from figure", len(variants))
997
+ return variants
998
+
999
+ except Exception as exc:
1000
+ log.warning("Failed to extract lineage from figure: %s", exc)
1001
+ return []
1002
+
1003
+
1004
+ # ---- 6.3 Helper for location-based extraction -----------------------------
1005
+
1006
+ def _extract_text_at_locations(text: str, locations: List[Union[str, dict]], context_chars: int = 5000, validate_sequences: bool = False) -> str:
1007
+ """Extract text around identified locations."""
1008
+ if not locations:
1009
+ return text
1010
+
1011
+ extracted_sections = []
1012
+ text_lower = text.lower()
1013
+
1014
+ log.info("Extracting text around %d locations with %d chars context",
1015
+ len(locations), context_chars)
1016
+
1017
+ for location in locations:
1018
+ # Handle both string locations and dict formats
1019
+ if isinstance(location, dict):
1020
+ # New format: {location, type, confidence, reason}
1021
+ location_str = location.get('location', location.get('section', location.get('text', '')))
1022
+ location_type = location.get('type', '')
1023
+
1024
+ # Extract page number if present in location string (e.g., "Table S1" might be on a specific page)
1025
+ page_num = location.get('page', '')
1026
+ search_hint = location.get('search_hint', '')
1027
+
1028
+ # Build search patterns in priority order
1029
+ page_patterns = []
1030
+
1031
+ # 1. Try the exact location string first
1032
+ if location_str:
1033
+ page_patterns.append(location_str.lower())
1034
+
1035
+ # 2. Try page markers if we have a page number
1036
+ if page_num:
1037
+ # Clean page number (remove "page" prefix if present)
1038
+ clean_page = str(page_num).replace('page', '').strip()
1039
+ if clean_page.startswith('S') or clean_page.startswith('s'):
1040
+ page_patterns.extend([f"\n{clean_page}\n", f"\n{clean_page} \n"])
1041
+ else:
1042
+ page_patterns.extend([f"\ns{clean_page}\n", f"\nS{clean_page}\n", f"\n{clean_page}\n"])
1043
+
1044
+ # 3. Try the search hint if provided
1045
+ if search_hint:
1046
+ page_patterns.append(search_hint.lower())
1047
+
1048
+ # 4. Try partial matches for section headers if location looks like a section
1049
+ if location_str and '.' in location_str:
1050
+ text_parts = location_str.split('.')
1051
+ if len(text_parts) > 1:
1052
+ page_patterns.append(text_parts[0].lower() + '.')
1053
+ page_patterns.append(location_str.lower())
1054
+
1055
+ else:
1056
+ # Backward compatibility for string locations
1057
+ page_patterns = [str(location).lower()]
1058
+ location_str = str(location)
1059
+
1060
+ # Try each pattern
1061
+ pos = -1
1062
+ used_pattern = None
1063
+ for pattern in page_patterns:
1064
+ temp_pos = text_lower.find(pattern.lower())
1065
+ if temp_pos != -1:
1066
+ pos = temp_pos
1067
+ used_pattern = pattern
1068
+ log.debug("Found pattern '%s' at position %d", pattern, pos)
1069
+ break
1070
+
1071
+ if pos != -1:
1072
+ if validate_sequences:
1073
+ # For sequence extraction, find ALL occurrences and test each one
1074
+ all_positions = []
1075
+ search_pos = pos
1076
+
1077
+ # Find all occurrences of this pattern
1078
+ while search_pos != -1:
1079
+ all_positions.append(search_pos)
1080
+ search_pos = text_lower.find(used_pattern.lower(), search_pos + 1)
1081
+ if len(all_positions) >= 10: # Limit to 10 occurrences
1082
+ break
1083
+
1084
+ log.info("Found %d occurrences of pattern '%s' for location '%s'",
1085
+ len(all_positions), used_pattern, location_str)
1086
+
1087
+ # Test each position for sequences
1088
+ best_position = -1
1089
+ best_score = 0
1090
+ test_window = 1000 # Test 1000 chars from each position
1091
+
1092
+ for test_pos in all_positions:
1093
+ test_end = min(len(text), test_pos + test_window)
1094
+ test_text = text[test_pos:test_end]
1095
+
1096
+ # Count sequences in this window
1097
+ clean_text = re.sub(r'\s+', '', test_text.upper())
1098
+ aa_matches = len(re.findall(f"[{''.join(_VALID_AA)}]{{50,}}", clean_text))
1099
+ dna_matches = len(re.findall(f"[{''.join(_VALID_DNA)}]{{50,}}", clean_text))
1100
+ score = aa_matches + dna_matches
1101
+
1102
+ if score > 0:
1103
+ log.info("Position %d: found %d AA and %d DNA sequences (score: %d)",
1104
+ test_pos, aa_matches, dna_matches, score)
1105
+
1106
+ if score > best_score:
1107
+ best_score = score
1108
+ best_position = test_pos
1109
+
1110
+ if best_position != -1:
1111
+ # Extract from the best position
1112
+ end = min(len(text), best_position + context_chars)
1113
+ section_text = text[best_position:end]
1114
+ extracted_sections.append(section_text)
1115
+ log.info("Selected position %d with %d sequences for '%s', extracted %d chars",
1116
+ best_position, best_score, location_str, len(section_text))
1117
+ else:
1118
+ log.warning("No sequences found in any of %d occurrences of '%s'",
1119
+ len(all_positions), location_str)
1120
+ else:
1121
+ # For lineage extraction, use the original logic
1122
+ start = max(0, pos - context_chars)
1123
+ end = min(len(text), pos + len(used_pattern) + context_chars)
1124
+ section_text = text[start:end]
1125
+ extracted_sections.append(section_text)
1126
+ log.info("Found '%s' using pattern '%s' at position %d, extracted %d chars",
1127
+ location_str, used_pattern, pos, len(section_text))
1128
+ else:
1129
+ log.warning("Location '%s' not found in text (tried %d patterns)", location_str, len(page_patterns))
1130
+
1131
+ combined = "\n\n[...]\n\n".join(extracted_sections) if extracted_sections else text
1132
+ log.info("Combined %d sections into %d total chars",
1133
+ len(extracted_sections), len(combined))
1134
+ return combined
1135
+
1136
+
1137
+ # ---- 6.4 Public API -------------------------------------------------------
1138
+
1139
+ def get_lineage(
1140
+ caption_text: str,
1141
+ full_text: str,
1142
+ model,
1143
+ *,
1144
+ pdf_paths: Optional[List[Path]] = None,
1145
+ debug_dir: str | Path | None = None,
1146
+ ) -> Tuple[List[Variant], List[Campaign]]:
1147
+ """
1148
+ High-level wrapper used by the pipeline.
1149
+
1150
+ 1. Identify distinct campaigns in the manuscript.
1151
+ 2. Use captions to ask Gemini where the lineage is likely described (fast & focused).
1152
+ 3. Map locations to campaigns.
1153
+ 4. Extract lineage for each campaign separately.
1154
+ 5. Return both variants and campaigns.
1155
+ """
1156
+ # First, identify campaigns in the manuscript
1157
+ campaigns = identify_campaigns(full_text[:50_000], model, debug_dir=debug_dir)
1158
+
1159
+ if campaigns:
1160
+ log.info(f"Identified {len(campaigns)} distinct campaigns")
1161
+ for camp in campaigns:
1162
+ log.info(f" - {camp.campaign_name}: {camp.description}")
1163
+
1164
+ # Use captions for identification - they're concise and focused
1165
+ locations = identify_evolution_locations(caption_text, model, debug_dir=debug_dir, campaigns=campaigns)
1166
+
1167
+ all_variants = []
1168
+
1169
+ if locations:
1170
+ # Log location information
1171
+ location_summary = []
1172
+ for loc in locations[:5]:
1173
+ if isinstance(loc, dict):
1174
+ campaign_info = f", campaign: {loc.get('campaign_id', 'unknown')}" if 'campaign_id' in loc else ""
1175
+ location_summary.append(f"{loc.get('location', 'Unknown')} ({loc.get('type', 'unknown')}, confidence: {loc.get('confidence', 0)}{campaign_info})")
1176
+ else:
1177
+ location_summary.append(str(loc))
1178
+ log.info("Gemini identified %d potential lineage locations: %s",
1179
+ len(locations), ", ".join(location_summary))
1180
+
1181
+ # Group locations by campaign
1182
+ locations_by_campaign = {}
1183
+ for loc in locations:
1184
+ campaign_id = loc.get('campaign_id', 'default') if isinstance(loc, dict) else 'default'
1185
+ if campaign_id not in locations_by_campaign:
1186
+ locations_by_campaign[campaign_id] = []
1187
+ locations_by_campaign[campaign_id].append(loc)
1188
+
1189
+ # Process each campaign's locations
1190
+ for campaign_id, campaign_locations in locations_by_campaign.items():
1191
+ log.info(f"Processing campaign: {campaign_id}")
1192
+
1193
+ # Sort locations by confidence to get the highest confidence one
1194
+ sorted_locations = sorted(campaign_locations,
1195
+ key=lambda x: x.get('confidence', 0) if isinstance(x, dict) else 0,
1196
+ reverse=True)
1197
+
1198
+ # Use only the highest confidence location to avoid duplicates
1199
+ primary_location = sorted_locations[0] if sorted_locations else None
1200
+
1201
+ # Track if we successfully extracted from figure
1202
+ extracted_from_figure = False
1203
+
1204
+ if isinstance(primary_location, dict):
1205
+ location_str = primary_location.get('location', '')
1206
+ location_type = primary_location.get('type', '')
1207
+ confidence = primary_location.get('confidence', 0)
1208
+ reason = primary_location.get('reason', '')
1209
+
1210
+ # Only try figure extraction for high-confidence figures
1211
+ if location_type == 'figure' and confidence >= 80 and pdf_paths:
1212
+ log.info("Primary lineage source is a high-confidence figure: %s (confidence: %d, reason: %s)",
1213
+ location_str, confidence, reason)
1214
+
1215
+ # Try to extract the figure from available PDFs
1216
+ figure_bytes = None
1217
+ for pdf_path in pdf_paths:
1218
+ figure_bytes = extract_figure(pdf_path, location_str, debug_dir=debug_dir)
1219
+ if figure_bytes:
1220
+ log.info("Successfully extracted figure from %s", pdf_path.name)
1221
+ break
1222
+
1223
+ if figure_bytes:
1224
+ # Save figure to debug directory if provided
1225
+ if debug_dir:
1226
+ debug_path = Path(debug_dir)
1227
+ debug_path.mkdir(parents=True, exist_ok=True)
1228
+ figure_file = debug_path / f"lineage_figure_{location_str.replace(' ', '_')}_{int(time.time())}.png"
1229
+ _dump(figure_bytes, figure_file)
1230
+ log.info("Saved lineage figure to: %s", figure_file)
1231
+
1232
+ # Extract lineage from the figure
1233
+ campaign_obj = next((c for c in campaigns if c.campaign_id == campaign_id), None)
1234
+ variants = extract_lineage_from_figure(
1235
+ figure_bytes, model,
1236
+ debug_dir=debug_dir,
1237
+ campaign_id=campaign_id,
1238
+ campaign_info=campaign_obj
1239
+ )
1240
+ if variants:
1241
+ all_variants.extend(variants)
1242
+ extracted_from_figure = True
1243
+ else:
1244
+ log.warning("Failed to extract lineage from figure, falling back to text extraction")
1245
+ else:
1246
+ log.warning("Could not extract figure '%s', falling back to text extraction", location_str)
1247
+ elif location_type == 'table':
1248
+ log.info("Primary lineage source is a table: %s (confidence: %d, reason: %s)",
1249
+ location_str, confidence, reason)
1250
+
1251
+ # Skip text extraction if we already got variants from figure
1252
+ if extracted_from_figure:
1253
+ continue
1254
+
1255
+ # Use text-based extraction (works for tables and text sections)
1256
+ # Extract from full text, not caption text - use only primary location
1257
+ focused_text = _extract_text_at_locations(full_text, [primary_location])
1258
+ log.info("Reduced text from %d to %d chars using primary location %s for campaign %s",
1259
+ len(full_text), len(focused_text),
1260
+ primary_location.get('location', 'Unknown') if isinstance(primary_location, dict) else 'Unknown',
1261
+ campaign_id)
1262
+
1263
+ # Find the campaign object
1264
+ campaign_obj = next((c for c in campaigns if c.campaign_id == campaign_id), None)
1265
+ campaign_variants = extract_complete_lineage(
1266
+ focused_text, model,
1267
+ debug_dir=debug_dir,
1268
+ campaign_id=campaign_id,
1269
+ campaign_info=campaign_obj
1270
+ )
1271
+ all_variants.extend(campaign_variants)
1272
+
1273
+ return all_variants, campaigns
1274
+ else:
1275
+ log.info("Gemini did not identify specific lineage locations")
1276
+ variants = extract_complete_lineage(full_text, model, debug_dir=debug_dir)
1277
+ return variants, campaigns
1278
+
1279
+ # === 7. SEQUENCE EXTRACTION === ----------------------------------------------
1280
+ # Pull every protein and/or DNA sequence for each variant.
1281
+ # 1. Ask Gemini where sequences live (cheap, quick prompt).
1282
+ # 2. Ask Gemini to return the sequences in strict JSON.
1283
+ # 3. Validate and convert to `SequenceBlock` objects.
1284
+
1285
+ # --- 7.0 JSON schema hint ----------------------------------------------------
1286
+ _SEQUENCE_SCHEMA_HINT = """
1287
+ [
1288
+ {
1289
+ "variant_id": "string", // e.g. "IV-G2", "Round4-10"
1290
+ "aa_seq": "string|null", // uppercase amino acids or null
1291
+ "dna_seq": "string|null" // uppercase A/C/G/T or null
1292
+ }
1293
+ ]
1294
+ """.strip()
1295
+
1296
+ # --- 7.1 Quick scan: where are the sequences? --------------------------------
1297
+ _SEQ_LOC_PROMPT = """
1298
+ Find where FULL-LENGTH protein or DNA sequences are located in this document.
1299
+
1300
+ Look for table of contents entries or section listings that mention sequences.
1301
+ Return a JSON array where each element has:
1302
+ - "section": the section heading or description
1303
+ - "page": the page number shown in the table of contents for this section, to your best judgement.
1304
+
1305
+ Focus on:
1306
+ - Table of contents or entries about "Sequence Information" or "Nucleotide and amino acid sequences"
1307
+ - Return the EXACT notation as shown.
1308
+
1309
+ Return [] if no sequence sections are found.
1310
+ Absolutely don't include oligonucleotide or primer sequences; it is better to return nothing than an incomplete sequence, so use your best judgement.
1311
+
1312
+ TEXT (truncated):
1313
+ ```
1314
+ {chunk}
1315
+ ```
1316
+ """.strip()
1317
+
1318
+ def identify_sequence_locations(text: str, model, *, debug_dir: str | Path | None = None) -> list[dict]:
1319
+ """Ask Gemini for promising places to look for sequences."""
1320
+ prompt = _SEQ_LOC_PROMPT.format(chunk=text[:15_000])
1321
+ try:
1322
+ locs = generate_json_with_retry(model, prompt, debug_dir=debug_dir, tag="seq_locations")
1323
+ return locs if isinstance(locs, list) else []
1324
+ except Exception as exc: # pylint: disable=broad-except
1325
+ log.warning("identify_sequence_locations(): %s", exc)
1326
+ return []
1327
+
1328
+ # --- 7.2 Page-based extraction helper ---------------------------------------
1329
+ def _extract_text_from_page(pdf_paths: List[Path], page_num: Union[str, int]) -> str:
1330
+ """Extract text from a specific page number in the PDFs."""
1331
+ # Convert page number to int and handle S-prefix
1332
+ page_str = str(page_num).strip().upper()
1333
+ if page_str.startswith('S'):
1334
+ # Supplementary page - look in the SI PDF (second PDF)
1335
+ actual_page = int(page_str[1:]) - 1 # 0-indexed
1336
+ pdf_index = 1 if len(pdf_paths) > 1 else 0
1337
+ else:
1338
+ # Regular page - look in the main PDF
1339
+ actual_page = int(page_str) - 1 # 0-indexed
1340
+ pdf_index = 0
1341
+
1342
+ if pdf_index >= len(pdf_paths):
1343
+ log.warning("Page %s requested but not enough PDFs provided", page_str)
1344
+ return ""
1345
+
1346
+ try:
1347
+ doc = fitz.open(pdf_paths[pdf_index])
1348
+ if 0 <= actual_page < len(doc):
1349
+ page = doc[actual_page]
1350
+ page_text = page.get_text()
1351
+ doc.close()
1352
+ log.info("Extracted %d chars from page %s of %s",
1353
+ len(page_text), page_str, pdf_paths[pdf_index].name)
1354
+ return page_text
1355
+ else:
1356
+ log.warning("Page %s (index %d) out of range for %s (has %d pages)",
1357
+ page_str, actual_page, pdf_paths[pdf_index].name, len(doc))
1358
+ doc.close()
1359
+ return ""
1360
+ except Exception as e:
1361
+ log.error("Failed to extract page %s: %s", page_str, e)
1362
+ return ""
1363
+
1364
+ # --- 7.3 Location validation with samples -----------------------------------
1365
+ _LOC_VALIDATION_PROMPT = """
1366
+ Which sample contains ACTUAL protein/DNA sequences (long strings of ACDEFGHIKLMNPQRSTVWY or ACGT)?
1367
+ Not mutation lists, but actual sequences.
1368
+
1369
+ {samples}
1370
+
1371
+ Reply with ONLY a number: the location_id of the best sample (or -1 if none have sequences).
1372
+ """.strip()
1373
+
1374
+ def validate_sequence_locations(text: str, locations: list, model, *, pdf_paths: List[Path] = None, debug_dir: str | Path | None = None) -> dict:
1375
+ """Extract samples from each location and ask Gemini to pick the best one."""
1376
+ if not locations:
1377
+ return {"best_location_id": -1, "reason": "No locations provided"}
1378
+
1379
+     # Extract a text sample from each candidate location (up to a few thousand chars each)
1380
+ samples = []
1381
+ for i, location in enumerate(locations[:5]): # Limit to 5 locations
1382
+ sample_text = ""
1383
+
1384
+ # If we have PDFs and location has a page number, use page extraction
1385
+ if pdf_paths and isinstance(location, dict) and 'page' in location:
1386
+ page_num = location['page']
1387
+ page_text = _extract_text_from_page(pdf_paths, page_num)
1388
+
1389
+ # Also try to extract from the next page
1390
+ next_page_text = ""
1391
+ try:
1392
+ page_str = str(page_num).strip().upper()
1393
+ if page_str.startswith('S'):
1394
+ next_page = f"S{int(page_str[1:]) + 1}"
1395
+ else:
1396
+ next_page = str(int(page_str) + 1)
1397
+ next_page_text = _extract_text_from_page(pdf_paths, next_page)
1398
+             except Exception:
1399
+ pass
1400
+
1401
+ # Combine both pages
1402
+ combined_text = page_text + "\n" + next_page_text if next_page_text else page_text
1403
+
1404
+ if combined_text:
1405
+ # Find the section within the combined pages if possible
1406
+ section = location.get('section', location.get('text', ''))
1407
+ if section:
1408
+ # Try to find section in pages
1409
+ section_lower = section.lower()
1410
+ combined_lower = combined_text.lower()
1411
+ pos = combined_lower.find(section_lower)
1412
+ if pos >= 0:
1413
+ # Extract from section start
1414
+ sample_text = combined_text[pos:pos+5000]
1415
+ else:
1416
+ # Section not found, take from beginning
1417
+ sample_text = combined_text[:10000]
1418
+ else:
1419
+ # No section, take from beginning
1420
+ sample_text = combined_text[:10000]
1421
+
1422
+ # Fallback to text search if page extraction didn't work
1423
+ if not sample_text:
1424
+ sample_text = _extract_text_at_locations(
1425
+ text, [location], context_chars=2000, validate_sequences=False
1426
+ )
1427
+
1428
+ samples.append({
1429
+ "location_id": i,
1430
+ "location": str(location),
1431
+ "sample": sample_text[:5000] if sample_text else ""
1432
+ })
1433
+
1434
+ # Ask Gemini to analyze samples
1435
+ prompt = _LOC_VALIDATION_PROMPT.format(samples=json.dumps(samples, indent=2))
1436
+
1437
+ # Save prompt for debugging
1438
+ if debug_dir:
1439
+ _dump(f"=== PROMPT FOR LOCATION_VALIDATION ===\nTimestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}\nLength: {len(prompt)} characters\n{'='*80}\n\n{prompt}",
1440
+ Path(debug_dir) / f"location_validation_prompt_{int(time.time())}.txt")
1441
+
1442
+ try:
1443
+ # Get simple numeric response from Gemini
1444
+ response = model.generate_content(prompt)
1445
+         response_text = _extract_text(response).strip()
1446
+
1447
+ # Save response for debugging
1448
+ if debug_dir:
1449
+ _dump(f"=== RESPONSE FOR LOCATION_VALIDATION ===\nTimestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}\nLength: {len(response_text)} characters\n{'='*80}\n\n{response_text}",
1450
+ Path(debug_dir) / f"location_validation_response_{int(time.time())}.txt")
1451
+
1452
+ # Try to extract the number from response
1453
+ match = re.search(r'-?\d+', response_text)
1454
+ if match:
1455
+ best_id = int(match.group())
1456
+ return {"best_location_id": best_id, "reason": "Selected by Gemini"}
1457
+ else:
1458
+ log.warning("Could not parse location ID from response: %s", response_text)
1459
+ return {"best_location_id": -1, "reason": "Could not parse response"}
1460
+
1461
+ except Exception as exc:
1462
+ log.warning("validate_sequence_locations(): %s", exc)
1463
+ return {"best_location_id": -1, "reason": str(exc)}
1464
+
1465
+ # --- 7.4 Main extraction prompt ---------------------------------------------
1466
+ _SEQ_EXTRACTION_PROMPT = """
1467
+ Extract EVERY distinct enzyme-variant sequence you can find in the text.
1468
+ For each variant return:
1469
+ * variant_id - the label used in the paper (e.g. "R4-10")
1470
+ * aa_seq - amino-acid sequence (uppercase), or null
1471
+ * dna_seq - DNA sequence (A/C/G/T), or null
1472
+
1473
+ Respond ONLY with **minified JSON** that matches the schema below.
1474
+ NO markdown, no code fences, no commentary.
1475
+
1476
+ Schema:
1477
+ ```json
1478
+ {schema}
1479
+ ```
1480
+
1481
+ TEXT (may be truncated):
1482
+ ```
1483
+ {text}
1484
+ ```
1485
+ """.strip()
1486
+
1487
+ def extract_sequences(text: str, model, *, debug_dir: str | Path | None = None) -> list[SequenceBlock]:
1488
+ """Prompt Gemini and convert its JSON reply into SequenceBlock objects."""
1489
+ prompt = _SEQ_EXTRACTION_PROMPT.format(
1490
+ schema=_SEQUENCE_SCHEMA_HINT, text=text[:MAX_CHARS]
1491
+ )
1492
+ data = generate_json_with_retry(model, prompt, _SEQUENCE_SCHEMA_HINT, debug_dir=debug_dir, tag="sequences")
1493
+ return _parse_sequences(data)
1494
+
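+ # Illustrative reply the prompt above aims to elicit (the variant ID and the
+ # sequences are placeholders, not taken from any paper):
+ #   [{"variant_id": "R4-10", "aa_seq": "MKTLLVV...", "dna_seq": null}]
+ # generate_json_with_retry() parses this JSON and _parse_sequences() below
+ # converts each entry into a SequenceBlock.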
1495
+ # --- 7.4 JSON -> dataclass helpers -------------------------------------------
1496
+ _VALID_AA = set("ACDEFGHIKLMNPQRSTVWY")
1497
+ _VALID_DNA = set("ACGT")
1498
+
1499
+ def _contains_sequence(text: str, min_length: int = 50) -> bool:
1500
+ """Check if text contains likely protein or DNA sequences."""
1501
+ # Remove whitespace for checking
1502
+ clean_text = re.sub(r'\s+', '', text.upper())
1503
+
1504
+ # Check for continuous stretches of valid amino acids or DNA
1505
+ # Look for at least min_length consecutive valid characters
1506
+ aa_pattern = f"[{''.join(_VALID_AA)}]{{{min_length},}}"
1507
+ dna_pattern = f"[{''.join(_VALID_DNA)}]{{{min_length},}}"
1508
+
1509
+ return bool(re.search(aa_pattern, clean_text) or re.search(dna_pattern, clean_text))
1510
+
1511
+ def _clean_seq(seq: str | None, alphabet: set[str]) -> str | None:
1512
+ if not seq:
1513
+ return None
1514
+ seq = re.sub(r"\s+", "", seq).upper()
1515
+ return seq if seq and all(ch in alphabet for ch in seq) else None
1516
+
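+ # Example behaviour of the helper above (inputs are made up):
+ #   _clean_seq("mkt lvr\n", _VALID_AA)  -> "MKTLVR"   (whitespace stripped, upper-cased)
+ #   _clean_seq("MKT1LVR", _VALID_AA)    -> None       ("1" is outside the alphabet)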
1517
+ def _parse_sequences(raw: list[dict]) -> list[SequenceBlock]:
1518
+ """Validate and convert raw JSON into SequenceBlock instances."""
1519
+ blocks: list[SequenceBlock] = []
1520
+ for entry in raw:
1521
+ vid = (entry.get("variant_id") or entry.get("id") or "").strip()
1522
+ if not vid:
1523
+ continue
1524
+ aa = _clean_seq(entry.get("aa_seq"), _VALID_AA)
1525
+ dna = _clean_seq(entry.get("dna_seq"), _VALID_DNA)
1526
+
1527
+ conf: float | None = None
1528
+ if aa:
1529
+ conf = sum(c in _VALID_AA for c in aa) / len(aa)
1530
+ elif dna:
1531
+ conf = sum(c in _VALID_DNA for c in dna) / len(dna)
1532
+
1533
+ blocks.append(
1534
+ SequenceBlock(
1535
+ variant_id=vid,
1536
+ aa_seq=aa,
1537
+ dna_seq=dna,
1538
+ confidence=conf,
1539
+ truncated=False,
1540
+ )
1541
+ )
1542
+ return blocks
1543
+
1544
+ # --- 7.5 Convenience wrapper -------------------------------------------------
1545
+ def get_sequences(text: str, model, *, pdf_paths: Optional[List[Path]] = None, debug_dir: str | Path | None = None) -> list[SequenceBlock]:
1546
+ # Phase 1: Identify where sequences might be located
1547
+ locations = identify_sequence_locations(text, model, debug_dir=debug_dir)
1548
+
1549
+ if locations:
1550
+ # Format location info for logging
1551
+ loc_strs = []
1552
+ for loc in locations[:5]:
1553
+ if isinstance(loc, dict):
1554
+ section = loc.get('section', loc.get('text', ''))
1555
+ page = loc.get('page', '')
1556
+ loc_strs.append(f"{section} (page {page})")
1557
+ else:
1558
+ loc_strs.append(str(loc))
1559
+ log.info("Gemini identified %d potential sequence locations: %s",
1560
+ len(locations), ", ".join(loc_strs))
1561
+
1562
+ # Phase 2: Validate locations with sample extraction
1563
+ validation = validate_sequence_locations(text, locations, model, pdf_paths=pdf_paths, debug_dir=debug_dir)
1564
+ best_loc_id = validation.get("best_location_id", -1)
1565
+
1566
+ if 0 <= best_loc_id < len(locations):
1567
+ # Use the validated best location
1568
+ best_location = locations[best_loc_id]
1569
+ log.info("Using validated best location: %s (reason: %s)",
1570
+ loc_strs[best_loc_id] if best_loc_id < len(loc_strs) else str(best_location),
1571
+ validation.get("reason", ""))
1572
+
1573
+ # Extract with suggested strategy
1574
+ strategy = validation.get("extraction_strategy", {})
1575
+ start_offset = strategy.get("start_offset", 0)
1576
+ min_length = strategy.get("min_length", 30000)
1577
+
1578
+ # Try page-based extraction first if we have page info
1579
+ focused_text = ""
1580
+ if pdf_paths and isinstance(best_location, dict) and 'page' in best_location:
1581
+ page_num = best_location['page']
1582
+ # Extract current page plus next 15 pages
1583
+ all_pages = []
1584
+ for i in range(16): # Current + next 15
1585
+ if isinstance(page_num, str) and page_num.upper().startswith('S'):
1586
+ next_page = f"S{int(page_num[1:]) + i}"
1587
+ else:
1588
+ next_page = str(int(page_num) + i)
1589
+ page_text = _extract_text_from_page(pdf_paths, next_page)
1590
+ if page_text:
1591
+ all_pages.append(page_text)
1592
+ else:
1593
+ break
1594
+ if all_pages:
1595
+ focused_text = "\n".join(all_pages)
1596
+ log.info("Extracted %d chars from pages %s through %d more pages",
1597
+ len(focused_text), page_num, len(all_pages) - 1)
1598
+
1599
+ # Fallback to text search if page extraction didn't work
1600
+ if not focused_text:
1601
+ focused_text = _extract_text_at_locations(
1602
+ text, [best_location],
1603
+ context_chars=max(min_length, 30000),
1604
+ validate_sequences=True
1605
+ )
1606
+
1607
+ if focused_text and len(focused_text) < len(text):
1608
+ log.info("Reduced text from %d to %d chars using validated location",
1609
+ len(text), len(focused_text))
1610
+ return extract_sequences(focused_text, model, debug_dir=debug_dir)
1611
+ else:
1612
+ log.warning("Location validation failed or returned invalid location: %s",
1613
+ validation.get("reason", "Unknown"))
1614
+
1615
+ # Fallback to full text
1616
+ log.info("Using full text for sequence extraction")
1617
+ return extract_sequences(text, model, debug_dir=debug_dir)
1618
+
1619
+ # === 7.6 PDB SEQUENCE EXTRACTION === -----------------------------------------
1620
+ """When no sequences are found in the paper, attempt to fetch them from PDB."""
1621
+
1622
+ def fetch_pdb_sequences(pdb_id: str) -> Dict[str, str]:
1623
+ """Fetch protein sequences from PDB using RCSB API.
1624
+
1625
+ Returns dict mapping chain IDs to sequences.
1626
+ """
1627
+ # Use the GraphQL API which is more reliable
1628
+ url = "https://data.rcsb.org/graphql"
1629
+
1630
+ query = """
1631
+ query getSequences($pdb_id: String!) {
1632
+ entry(entry_id: $pdb_id) {
1633
+ polymer_entities {
1634
+ entity_poly {
1635
+ pdbx_seq_one_letter_code_can
1636
+ }
1637
+ rcsb_polymer_entity_container_identifiers {
1638
+ auth_asym_ids
1639
+ }
1640
+ }
1641
+ }
1642
+ }
1643
+ """
1644
+
1645
+ try:
1646
+ import requests
1647
+ response = requests.post(
1648
+ url,
1649
+ json={"query": query, "variables": {"pdb_id": pdb_id.upper()}},
1650
+ timeout=10
1651
+ )
1652
+ response.raise_for_status()
1653
+ data = response.json()
1654
+
1655
+ sequences = {}
1656
+ entry_data = data.get('data', {}).get('entry', {})
1657
+
1658
+ if entry_data:
1659
+ for entity in entry_data.get('polymer_entities', []):
1660
+ # Get sequence
1661
+ seq_data = entity.get('entity_poly', {})
1662
+ sequence = seq_data.get('pdbx_seq_one_letter_code_can', '')
1663
+
1664
+ # Get chain IDs
1665
+ chain_data = entity.get('rcsb_polymer_entity_container_identifiers', {})
1666
+ chain_ids = chain_data.get('auth_asym_ids', [])
1667
+
1668
+ if sequence and chain_ids:
1669
+ # Clean sequence - remove newlines and spaces
1670
+ clean_seq = sequence.replace('\n', '').replace(' ', '').upper()
1671
+
1672
+ # Add sequence for each chain
1673
+ for chain_id in chain_ids:
1674
+ sequences[chain_id] = clean_seq
1675
+ log.info(f"PDB {pdb_id} chain {chain_id}: {len(clean_seq)} residues")
1676
+
1677
+ return sequences
1678
+
1679
+ except Exception as e:
1680
+ log.warning(f"Failed to fetch PDB {pdb_id}: {e}")
1681
+ return {}
1682
+
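+ # Usage sketch (the PDB ID is a placeholder; the return shape follows the
+ # GraphQL fields queried above):
+ #   seqs = fetch_pdb_sequences("1ABC")   # e.g. {"A": "MKT...", "B": "MKT..."}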
1683
+ def match_pdb_to_variants(
1684
+ pdb_sequences: Dict[str, str],
1685
+ variants: List[Variant],
1686
+ lineage_text: str,
1687
+ model,
1688
+ pdb_id: str = None,
1689
+ ) -> Dict[str, str]:
1690
+ """Match PDB chains to variant IDs using LLM analysis of mutations.
1691
+
1692
+ Returns a mapping where each variant maps to at most one PDB chain.
1693
+ Since all chains from a single PDB typically have the same sequence,
1694
+ we match the PDB to a single variant based on context.
1695
+ """
1696
+
1697
+ if not pdb_sequences or not variants:
1698
+ return {}
1699
+
1700
+ # Extract context around PDB ID mentions if possible
1701
+ context_text = ""
1702
+ if pdb_id and lineage_text:
1703
+ # Search for PDB ID mentions in the text
1704
+ pdb_mentions = []
1705
+ text_lower = lineage_text.lower()
1706
+ pdb_lower = pdb_id.lower()
1707
+
1708
+ # Find all occurrences of the PDB ID
1709
+ start = 0
1710
+ while True:
1711
+ pos = text_lower.find(pdb_lower, start)
1712
+ if pos == -1:
1713
+ break
1714
+
1715
+ # Extract context around the mention (300 chars before, 300 after)
1716
+ context_start = max(0, pos - 300)
1717
+ context_end = min(len(lineage_text), pos + len(pdb_id) + 300)
1718
+ context = lineage_text[context_start:context_end]
1719
+
1720
+ # Add ellipsis if truncated
1721
+ if context_start > 0:
1722
+ context = "..." + context
1723
+ if context_end < len(lineage_text):
1724
+ context = context + "..."
1725
+
1726
+ pdb_mentions.append(context)
1727
+ start = pos + 1
1728
+
1729
+ if pdb_mentions:
1730
+ context_text = "\n\n---\n\n".join(pdb_mentions[:3]) # Use up to 3 mentions
1731
+ log.info(f"Found {len(pdb_mentions)} mentions of PDB {pdb_id}")
1732
+ else:
1733
+ # Fallback to general context if no specific mentions found
1734
+ context_text = lineage_text[:2000]
1735
+ else:
1736
+ # Fallback to general context
1737
+ context_text = lineage_text[:2000] if lineage_text else ""
1738
+
1739
+ # Get the first chain's sequence as representative (usually all chains have same sequence)
1740
+ first_chain = list(pdb_sequences.keys())[0]
1741
+ seq_preview = pdb_sequences[first_chain]
1742
+ seq_preview = f"{seq_preview[:50]}...{seq_preview[-20:]}" if len(seq_preview) > 70 else seq_preview
1743
+
1744
+ # Build a prompt for Gemini to match ONE variant to this PDB
1745
+ prompt = f"""Given a PDB structure and enzyme variant information, identify which variant corresponds to this PDB structure.
1746
+
1747
+ PDB ID: {pdb_id or "Unknown"}
1748
+ PDB Sequence (from chain {first_chain}):
1749
+ {seq_preview}
1750
+
1751
+ Variant Information:
1752
+ {json.dumps([{"id": v.variant_id, "mutations": v.mutations, "parent": v.parent_id, "generation": v.generation} for v in variants], indent=2)}
1753
+
1754
+ Context from paper mentioning the PDB:
1755
+ {context_text}
1756
+
1757
+ Based on the context, identify which ONE variant this PDB structure represents.
1758
+ Return ONLY the variant_id as a JSON string, e.g.: "ApePgb GLVRSQL"
1759
+ """
1760
+
1761
+ try:
1762
+ response = model.generate_content(prompt)
1763
+ text = _extract_text(response).strip()
1764
+
1765
+ # Parse JSON response (expecting a single string)
1766
+ if text.startswith("```"):
1767
+ text = text.split("```")[1].strip()
1768
+ if text.startswith("json"):
1769
+ text = text[4:].strip()
1770
+
1771
+ # Remove quotes if present
1772
+ text = text.strip('"\'')
1773
+
1774
+ matched_variant = text
1775
+ log.info(f"PDB {pdb_id} matched to variant: {matched_variant}")
1776
+
1777
+ # Map the matched variant to the first chain (chains of one PDB usually share a sequence)
1778
+ mapping = {}
1779
+ if matched_variant and any(v.variant_id == matched_variant for v in variants):
1780
+ for chain_id in pdb_sequences:
1781
+ mapping[matched_variant] = chain_id
1782
+ break # Only use the first chain
1783
+
1784
+ return mapping
1785
+
1786
+ except Exception as e:
1787
+ log.warning(f"Failed to match PDB to variant: {e}")
1788
+ # No fallback - return empty if we can't match
1789
+ return {}
1790
+
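+ # Result sketch (IDs hypothetical): with pdb_sequences={"A": "MKT..."} and a
+ # lineage containing a variant named "ApePgb GLVRSQL", a successful match
+ # returns {"ApePgb GLVRSQL": "A"}, which run_pipeline() turns into a
+ # SequenceBlock with metadata={"source": "PDB", ...}.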
1791
+ # === 8. MERGE, VALIDATE & SCORE === ------------------------------------------
1792
+ """Glue logic to combine lineage records with sequence blocks and produce a
1793
+ single tidy pandas DataFrame that downstream code (pipeline / CLI) can write
1794
+ as CSV or further analyse.
1795
+
1796
+ Responsibilities
1797
+ ----------------
1798
+ 1. Merge: outer-join on `variant_id`, preserving every lineage row even if a
1799
+ sequence is missing.
1800
+ 2. Generation sanity-check: ensure generation numbers are integers >=0; if
1801
+ missing, infer by walking the lineage graph.
1802
+ 3. Confidence: propagate `SequenceBlock.confidence` or compute a simple score
1803
+ if only raw sequences are present.
1804
+ 4. DOI column: attach the article DOI to every row so the CSV is self-contained.
1805
+ """
1806
+
1807
+
1808
+ # --- 8.1 Generation inference -------------------------------------------------
1809
+
1810
+ def _infer_generations(variants: List[Variant]) -> None:
1811
+ """Fill in missing `generation` fields by walking parent -> child edges.
1812
+
1813
+ We build a directed graph of variant relationships and assign generation
1814
+ numbers by distance from the root(s). If cycles exist (shouldn't!), they
1815
+ are broken arbitrarily and a warning is emitted.
1816
+ """
1817
+ graph = nx.DiGraph()
1818
+ for var in variants:
1819
+ graph.add_node(var.variant_id, obj=var)
1820
+ if var.parent_id:
1821
+ graph.add_edge(var.parent_id, var.variant_id)
1822
+
1823
+ # Detect cycles just in case
1824
+ try:
1825
+ roots = [n for n, d in graph.in_degree() if d == 0]
1826
+ for root in roots:
1827
+ for node, depth in nx.single_source_shortest_path_length(graph, root).items():
1828
+ var: Variant = graph.nodes[node]["obj"] # type: ignore[assignment]
1829
+ var.generation = depth if var.generation is None else var.generation
1830
+ except nx.NetworkXUnfeasible:
1831
+ log.warning("Cycle detected in lineage, generation inference skipped")
1832
+
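+ # Minimal sketch (variant IDs hypothetical): for a chain WT -> M1 -> M2 with no
+ # explicit generations, the shortest-path walk above assigns WT=0, M1=1, M2=2;
+ # variants that already carry a generation keep it.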
1833
+ # --- 8.2 Merge helpers --------------------------------------------------------
1834
+
1835
+
1836
+ def _merge_lineage_and_sequences(
1837
+ lineage: List[Variant], seqs: List[SequenceBlock], doi: Optional[str]
1838
+ ) -> pd.DataFrame:
1839
+ """Return a tidy DataFrame with one row per variant."""
1840
+
1841
+ # 1. Make DataFrames
1842
+ df_lin = pd.DataFrame([
1843
+ {
1844
+ "variant_id": v.variant_id,
1845
+ "parent_id": v.parent_id,
1846
+ "generation": v.generation,
1847
+ "mutations": ";".join(v.mutations) if v.mutations else None,
1848
+ "campaign_id": v.campaign_id,
1849
+ "notes": v.notes,
1850
+ }
1851
+ for v in lineage
1852
+ ])
1853
+
1854
+ df_seq = pd.DataFrame([
1855
+ {
1856
+ "variant_id": s.variant_id,
1857
+ "aa_seq": s.aa_seq,
1858
+ "dna_seq": s.dna_seq,
1859
+ "seq_confidence": s.confidence,
1860
+ "truncated": s.truncated,
1861
+ }
1862
+ for s in seqs
1863
+ ])
1864
+
1865
+ # 2. Outer merge keeps every lineage entry and adds sequence cols when present
1866
+ df = pd.merge(df_lin, df_seq, on="variant_id", how="left")
1867
+
1868
+ # 3. If generation missing after user input, try inference
1869
+ if df["generation"].isna().any():
1870
+ _infer_generations(lineage) # mutates in place
1871
+ df = df.drop(columns=["generation"]).merge(
1872
+ pd.DataFrame(
1873
+ {"variant_id": [v.variant_id for v in lineage], "generation": [v.generation for v in lineage]}
1874
+ ),
1875
+ on="variant_id",
1876
+ how="left",
1877
+ )
1878
+
1879
+ # 4. Attach DOI column for provenance
1880
+ df["doi"] = doi
1881
+
1882
+ # 5. Sort rows: primary by generation, then by variant_id
1883
+ df = df.sort_values(["generation", "variant_id"], kind="mergesort")
1884
+
1885
+ return df
1886
+
1887
+ # --- 8.3 Public API -----------------------------------------------------------
1888
+
1889
+ def merge_and_score(
1890
+ lineage: List[Variant],
1891
+ seqs: List[SequenceBlock],
1892
+ doi: Optional[str] = None,
1893
+ ) -> pd.DataFrame:
1894
+ """User-facing helper imported by the pipeline orchestrator.
1895
+
1896
+ * Ensures lineage + sequence lists are non-empty.
1897
+ * Performs a shallow validation.
1898
+ * Returns a ready-to-export pandas DataFrame.
1899
+ """
1900
+
1901
+ if not lineage:
1902
+ raise ValueError("merge_and_score(): `lineage` list is empty; nothing to merge")
1903
+
1904
+ # If no sequences found, still build a DataFrame so caller can decide what to do.
1905
+ df = _merge_lineage_and_sequences(lineage, seqs, doi)
1906
+
1907
+ # Basic sanity: warn if many missing sequences
1908
+ missing_rate = df["aa_seq"].isna().mean() if "aa_seq" in df else 1.0
1909
+ if missing_rate > 0.5:
1910
+ log.warning(">50%% of variants lack sequences (%d / %d)", df["aa_seq"].isna().sum(), len(df))
1911
+
1912
+ return df
1913
+
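+ # Typical call from the orchestrator in Section 9 (the DOI value is a placeholder):
+ #   df = merge_and_score(lineage, sequences, doi="10.xxxx/example")
+ #   df.to_csv("enzyme_lineage_data.csv", index=False)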
1914
+ # -------------------------------------------------------------------- end 8 ---
1915
+
1916
+ # === 9. PIPELINE ORCHESTRATOR === --------------------------------------------
1917
+ """High-level function that ties together PDF parsing, LLM calls, merging, and
1918
+ CSV export. This is what both the CLI (Section 10) and other Python callers
1919
+ should invoke.
1920
+
1921
+ **New behaviour (June 2025)** - The lineage table is now written to disk *before*
1922
+ sequence extraction begins so that users keep partial results even if the
1923
+ LLM stalls on the longer sequence prompt. Both writes target the same file,
1924
+ `enzyme_lineage_data.csv`, created next to the `--output` path: the lineage-only
1925
+ CSV is saved first and later overwritten by the merged (final) DataFrame.
1926
+ """
1927
+
1928
+ import time
1929
+ from pathlib import Path
1930
+ from typing import Union
1931
+ import pandas as pd
1932
+
1933
+
1934
+ def _lineage_to_dataframe(lineage: list[Variant]) -> pd.DataFrame:
1935
+ """Convert a list[Variant] to a tidy DataFrame (helper for early dump)."""
1936
+ return pd.DataFrame(
1937
+ {
1938
+ "variant_id": [v.variant_id for v in lineage],
1939
+ "parent_id": [v.parent_id for v in lineage],
1940
+ "generation": [v.generation for v in lineage],
1941
+ "mutations": [";".join(v.mutations) if v.mutations else None for v in lineage],
1942
+ "campaign_id": [v.campaign_id for v in lineage],
1943
+ "notes": [v.notes for v in lineage],
1944
+ }
1945
+ )
1946
+
1947
+
1948
+ def run_pipeline(
1949
+ manuscript: Union[str, Path],
1950
+ si: Optional[Union[str, Path]] = None,
1951
+ output_csv: Optional[Union[str, Path]] = None,
1952
+ *,
1953
+ debug_dir: str | Path | None = None,
1954
+ ) -> pd.DataFrame:
1955
+ """Execute the end-to-end extraction pipeline.
1956
+
1957
+ Parameters
1958
+ ----------
1959
+ manuscript : str | Path
1960
+ Path to the main PDF file.
1961
+ si : str | Path | None, optional
1962
+ Path to the Supplementary Information PDF, if available.
1963
+ output_csv : str | Path | None, optional
1964
+ If provided, **both** the early lineage table *and* the final merged
1965
+ table will be written to `enzyme_lineage_data.csv` in this path's parent
1966
+ directory (the final write overwrites the first).
1967
+
1968
+ Returns
1969
+ -------
1970
+ pandas.DataFrame
1971
+ One row per variant with lineage, sequences, and provenance.
1972
+ """
1973
+
1974
+ t0 = time.perf_counter()
1975
+ manuscript = Path(manuscript)
1976
+ si_path = Path(si) if si else None
1977
+
1978
+ # 1. Prepare raw text ------------------------------------------------------
1979
+ # Always load both caption text (for identification) and full text (for extraction)
1980
+ pdf_paths = [p for p in (si_path, manuscript) if p]
1981
+ caption_text = limited_caption_concat(*pdf_paths)
1982
+ full_text = limited_concat(*pdf_paths)
1983
+
1984
+ log.info("Loaded %d chars of captions for identification and %d chars of full text for extraction",
1985
+ len(caption_text), len(full_text))
1986
+
1987
+ # 2. Connect to Gemini -----------------------------------------------------
1988
+ model = get_model()
1989
+
1990
+ # 3. Extract lineage (Section 6) ------------------------------------------
1991
+ lineage, campaigns = get_lineage(caption_text, full_text, model, pdf_paths=pdf_paths, debug_dir=debug_dir)
1992
+
1993
+ if not lineage:
1994
+ raise RuntimeError("Pipeline aborted: failed to extract any lineage data")
1995
+
1996
+ # Save campaigns info if debug_dir provided
1997
+ if debug_dir and campaigns:
1998
+ campaigns_file = Path(debug_dir) / "campaigns.json"
1999
+ campaigns_data = [
2000
+ {
2001
+ "campaign_id": c.campaign_id,
2002
+ "campaign_name": c.campaign_name,
2003
+ "description": c.description,
2004
+ "model_substrate": c.model_substrate,
2005
+ "model_product": c.model_product,
2006
+ "substrate_id": c.substrate_id,
2007
+ "product_id": c.product_id,
2008
+ "data_locations": c.data_locations,
2009
+ "notes": c.notes
2010
+ }
2011
+ for c in campaigns
2012
+ ]
2013
+ _dump(json.dumps(campaigns_data, indent=2), campaigns_file)
2014
+ log.info(f"Saved {len(campaigns)} campaigns to {campaigns_file}")
2015
+
2016
+ # 3a. EARLY SAVE -------------------------------------------------------------
2017
+ if output_csv:
2018
+ early_df = _lineage_to_dataframe(lineage)
2019
+ output_csv_path = Path(output_csv)
2020
+ # Save lineage-only data with specific filename
2021
+ lineage_path = output_csv_path.parent / "enzyme_lineage_data.csv"
2022
+ early_df.to_csv(lineage_path, index=False)
2023
+ log.info(
2024
+ "Saved lineage-only CSV -> %s",
2025
+ lineage_path,
2026
+ )
2027
+
2028
+ # 4. Extract sequences (Section 7) ----------------------------------------
2029
+ sequences = get_sequences(full_text, model, pdf_paths=pdf_paths, debug_dir=debug_dir)
2030
+
2031
+ # 4a. Try PDB extraction if no sequences found -----------------------------
2032
+ if not sequences or all(s.aa_seq is None for s in sequences):
2033
+ log.info("No sequences found in paper, attempting PDB extraction...")
2034
+
2035
+ # Extract PDB IDs from all PDFs
2036
+ pdb_ids = []
2037
+ for pdf_path in pdf_paths:
2038
+ pdb_ids.extend(extract_pdb_ids(pdf_path))
2039
+
2040
+ if pdb_ids:
2041
+ log.info(f"Found PDB IDs: {pdb_ids}")
2042
+
2043
+ # Try each PDB ID until we get sequences
2044
+ for pdb_id in pdb_ids:
2045
+ pdb_sequences = fetch_pdb_sequences(pdb_id)
2046
+
2047
+ if pdb_sequences:
2048
+ # Match PDB chains to variants
2049
+ variant_to_chain = match_pdb_to_variants(
2050
+ pdb_sequences, lineage, full_text, model, pdb_id
2051
+ )
2052
+
2053
+ # Convert to SequenceBlock objects
2054
+ pdb_seq_blocks = []
2055
+ for variant in lineage:
2056
+ if variant.variant_id in variant_to_chain:
2057
+ chain_id = variant_to_chain[variant.variant_id]
2058
+ if chain_id in pdb_sequences:
2059
+ seq_block = SequenceBlock(
2060
+ variant_id=variant.variant_id,
2061
+ aa_seq=pdb_sequences[chain_id],
2062
+ dna_seq=None,
2063
+ confidence=1.0, # High confidence for PDB sequences
2064
+ truncated=False,
2065
+ metadata={"source": "PDB", "pdb_id": pdb_id, "chain": chain_id}
2066
+ )
2067
+ pdb_seq_blocks.append(seq_block)
2068
+ log.info(f"Added PDB sequence for {variant.variant_id} from {pdb_id}:{chain_id}")
2069
+
2070
+ if pdb_seq_blocks:
2071
+ sequences = pdb_seq_blocks
2072
+ log.info(f"Successfully extracted {len(pdb_seq_blocks)} sequences from PDB {pdb_id}")
2073
+ break
2074
+ else:
2075
+ log.warning(f"No sequences found in PDB {pdb_id}")
2076
+ else:
2077
+ log.warning("No PDB IDs found in paper")
2078
+
2079
+ # 5. Merge & score (Section 8) --------------------------------------------
2080
+ doi = extract_doi(manuscript)
2081
+ df_final = merge_and_score(lineage, sequences, doi)
2082
+
2083
+ # 6. Write FINAL CSV -------------------------------------------------------
2084
+ if output_csv:
2085
+ output_csv_path = Path(output_csv)
2086
+ # Save final data with sequences using same filename (overwrites lineage-only)
2087
+ sequence_path = output_csv_path.parent / "enzyme_lineage_data.csv"
2088
+ df_final.to_csv(sequence_path, index=False)
2089
+ log.info(
2090
+ "Overwrote with final results -> %s (%.1f kB)",
2091
+ sequence_path,
2092
+ sequence_path.stat().st_size / 1024,
2093
+ )
2094
+
2095
+ log.info(
2096
+ "Pipeline finished in %.2f s (variants: %d)",
2097
+ time.perf_counter() - t0,
2098
+ len(df_final),
2099
+ )
2100
+ return df_final
2101
+
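+ # Programmatic usage sketch (paths are placeholders):
+ #   df = run_pipeline("manuscript.pdf", si="si.pdf",
+ #                     output_csv="results/out.csv", debug_dir="debug/")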
2102
+ # -------------------------------------------------------------------- end 9 ---
2103
+
2104
+ # === 10. CLI ENTRYPOINT === ----------------------------------------------
2105
+ """Simple argparse wrapper so the script can be run from the command line
2106
+
2107
+ Example:
2108
+
2109
+ python enzyme_lineage_extractor.py \
2110
+ --manuscript paper.pdf \
2111
+ --si supp.pdf \
2112
+ --output lineage.csv \
2113
+ -v
2114
+ """
2115
+
2116
+ import argparse
2117
+ import logging
2118
+ from typing import List, Optional
2119
+
2120
+
2121
+ # -- 10.1 Argument parser ----------------------------------------------------
2122
+
2123
+ def _build_arg_parser() -> argparse.ArgumentParser:
2124
+ p = argparse.ArgumentParser(
2125
+ prog="enzyme_lineage_extractor",
2126
+ description="Extract enzyme variant lineage and sequences from PDFs using Google Gemini",
2127
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
2128
+ )
2129
+ p.add_argument("--manuscript", required=True, help="Path to main manuscript PDF")
2130
+ p.add_argument("--si", help="Path to Supplementary Information PDF")
2131
+ p.add_argument("-o", "--output", help="CSV file for extracted data")
2132
+ p.add_argument(
2133
+ "-v",
2134
+ "--verbose",
2135
+ action="count",
2136
+ default=0,
2137
+ help="Increase verbosity; repeat (-vv) for DEBUG logging",
2138
+ )
2139
+ p.add_argument(
2140
+ "--debug-dir",
2141
+ metavar="DIR",
2142
+ help="Write ALL intermediate artefacts (captions, prompts, raw Gemini replies) to DIR",
2143
+ )
2144
+ return p
2145
+
2146
+
2147
+ # -- 10.2 main() -------------------------------------------------------------
2148
+
2149
+ def main(argv: Optional[List[str]] = None) -> None:
2150
+ parser = _build_arg_parser()
2151
+ args = parser.parse_args(argv)
2152
+
2153
+ # Configure logging early so everything respects the chosen level.
2154
+ level = logging.DEBUG if args.verbose >= 2 else logging.INFO if args.verbose else logging.WARNING
2155
+ logging.basicConfig(level=level, format="%(levelname)s: %(message)s")
2156
+
2157
+ run_pipeline(
2158
+ manuscript=args.manuscript,
2159
+ si=args.si,
2160
+ output_csv=args.output,
2161
+ debug_dir=args.debug_dir,
2162
+ )
2163
+
2164
+
2165
+ if __name__ == "__main__":
2166
+ main()
2167
+
2168
+ # -------------------------------------------------------------------- end 10 ---
2169
+