flowapy 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flowa/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """Flowa - Variant literature assessment pipeline."""
2
+
3
+ __version__ = '0.1.0'
flowa/aggregate.py ADDED
@@ -0,0 +1,372 @@
1
+ """Aggregate evidence across all papers for a variant."""
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ import re
7
+ import time
8
+ from typing import Any
9
+
10
+ import logfire
11
+ import typer
12
+ from pydantic import BaseModel
13
+ from pydantic_ai import Agent, ModelRetry, NativeOutput, RunContext
14
+
15
+ from flowa.clinvar import format_clinvar_for_prompt, query_clinvar
16
+ from flowa.models import create_model, get_model_settings
17
+ from flowa.prompts import load_prompt_and_schema
18
+ from flowa.resolve import CitationQuery, resolve_citations
19
+ from flowa.schema import AGGREGATION_SCHEMA_VERSION, with_schema_version
20
+ from flowa.settings import ModelConfig, Settings
21
+ from flowa.storage import (
22
+ assessment_url,
23
+ encode_doi,
24
+ exists,
25
+ paper_url,
26
+ read_bytes,
27
+ read_json,
28
+ read_text,
29
+ write_bytes,
30
+ write_json,
31
+ )
32
+
33
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Upper bound on output tokens for one aggregation call. Thinking and the
# structured output share this budget; the value matches Sonnet 4.6's max output.
_AGGREGATE_MAX_TOKENS = 64_000
37
+
38
+
39
+ # Paper ID generation ({LastName}{Year} format), ported from palit.
40
+
41
+
42
def _extract_first_author_last_name(authors: str) -> str:
    """Extract the first author's last name from an authors string.

    Authors are semicolon-separated in "Last, First" format:
        "Smith, John A; Doe, Jane B; ..." -> "Smith"
        "van der Berg, Anna; Doe, Jane; ..." -> "VanDerBerg"

    Args:
        authors: Semicolon-separated author list; may be empty.

    Returns:
        The first author's last name with every whitespace-separated part
        stripped of characters outside A-Z/a-z, capitalized, and joined.
        Returns 'Unknown' when no usable last name can be derived (empty
        input, empty last-name portion, or a name with no ASCII letters).
    """
    if not authors:
        return 'Unknown'
    first_author = authors.split(';', 1)[0]
    # Everything before the first comma is the last-name portion.
    last_name = first_author.split(',', 1)[0]
    # Multi-word last names are joined; each part is capitalized on its own.
    cleaned_parts = (re.sub(r'[^A-Za-z]', '', part).capitalize() for part in last_name.split())
    # A name made up solely of non-ASCII-letter characters would otherwise
    # collapse to '' and yield a paper ID consisting of just the year.
    return ''.join(cleaned_parts) or 'Unknown'
59
+
60
+
61
def _collision_suffix(index: int) -> str:
    """Return the letter suffix for a 0-based collision index: a..z, aa, ab, ..."""
    letters = ''
    remaining = index + 1
    while remaining > 0:
        remaining, offset = divmod(remaining - 1, 26)
        letters = chr(ord('a') + offset) + letters
    return letters


def generate_paper_ids(
    evidence_list: list[dict[str, Any]],
) -> tuple[dict[str, str], dict[str, str]]:
    """Generate {LastName}{Year} paper IDs from evidence list.

    Each item must have a 'doi' key; 'authors' and 'date' are read with
    empty-string defaults. The year is the first four characters of 'date',
    or 'Unknown' when the date is missing or shorter than four characters.

    Args:
        evidence_list: Evidence entries, one per paper. A DOI that appears
            more than once is only considered the first time.

    Returns:
        Tuple of (paper_id_to_doi, doi_to_paper_id) mappings.
        Collisions are disambiguated with letter suffixes (a, b, c, ...,
        z, aa, ab, ...) assigned in sorted-DOI order.
    """
    base_id_to_dois: dict[str, list[str]] = {}
    seen_dois: set[str] = set()

    for evidence in evidence_list:
        doi = evidence['doi']
        # A repeated DOI is the same paper, not a collision with itself.
        if doi in seen_dois:
            continue
        seen_dois.add(doi)
        authors = evidence.get('authors', '')
        date = evidence.get('date', '')
        last_name = _extract_first_author_last_name(authors)
        year = date[:4] if date and len(date) >= 4 else 'Unknown'
        base_id_to_dois.setdefault(f'{last_name}{year}', []).append(doi)

    paper_id_to_doi: dict[str, str] = {}
    doi_to_paper_id: dict[str, str] = {}
    for base_id, dois in base_id_to_dois.items():
        if len(dois) == 1:
            paper_id_to_doi[base_id] = dois[0]
            doi_to_paper_id[dois[0]] = base_id
            continue
        # Sorting makes suffix assignment independent of input order.
        for i, doi in enumerate(sorted(dois)):
            suffixed_id = f'{base_id}{_collision_suffix(i)}'
            paper_id_to_doi[suffixed_id] = doi
            doi_to_paper_id[doi] = suffixed_id

    return paper_id_to_doi, doi_to_paper_id
96
+
97
+
98
# Validation-failure metric, labelled per rule: shows how often the
# aggregation LLM emits shape-invalid output and which rule it violates.
# Logfire is configured in cli.py; when it isn't (tests), the counter is a
# no-op.
_aggregate_retry_counter = logfire.metric_counter(
    'flowa_aggregate_validation_errors_total',
    description='Shape-validation rule violations found by the aggregate output_validator',
)
105
+
106
+
107
def create_aggregate_agent(
    model: ModelConfig,
    paper_id_to_doi: dict[str, str],
    output_type: type[BaseModel],
) -> Agent[None, BaseModel]:
    """Build the aggregation agent and attach its shape validator.

    Args:
        model: Model configuration passed to ``create_model`` and
            ``get_model_settings``.
        paper_id_to_doi: Known paper IDs; the validator rejects any paper_id
            that is not a key of this mapping.
        output_type: Pydantic model the agent must produce.

    Returns:
        The configured agent (3 retries, medium effort, capped output tokens).
    """
    agent: Agent[None, BaseModel] = Agent(
        create_model(model),
        output_type=NativeOutput(output_type),
        retries=3,
        model_settings=get_model_settings(model, effort='medium', max_tokens=_AGGREGATE_MAX_TOKENS),
    )

    @agent.output_validator
    def validate_shape(ctx: RunContext[None], result: BaseModel) -> BaseModel:
        """Check paper_id membership and group integrity only.

        Semantic cross-field checks (citation fidelity, grouping order) are
        deliberately not done here; they live in chat-service, through which
        every artifact mutation flows, so the rules stay authored in one place.

        Raises:
            ModelRetry: When any rule is violated; the message lists every
                violation so the model can fix them all in one retry.
        """
        problems: list[str] = []

        def flag(rule: str, message: str) -> None:
            # Record the message for the retry prompt and bump the per-rule metric.
            problems.append(message)
            _aggregate_retry_counter.add(1, {'rule': rule})

        for category in result.results:  # type: ignore[attr-defined]
            code = getattr(category, 'code', '<unknown>')
            listed_ids = [paper.paper_id for paper in category.papers]
            distinct_ids = set(listed_ids)

            # Rule: a paper may appear at most once per category.
            if len(distinct_ids) != len(listed_ids):
                repeated = sorted(pid for pid in distinct_ids if listed_ids.count(pid) > 1)
                flag('paper_id_duplicate', f'code={code}: papers[] has duplicate paper_ids: {repeated}')

            # Rule: every listed paper must be one we handed to the model.
            for pid in listed_ids:
                if pid not in paper_id_to_doi:
                    flag('paper_id_unknown', f'code={code}: papers[] has unknown paper_id={pid}')

            # Rule: a claim may only cite a paper listed in the same category.
            for claim in category.claims:
                if claim.paper_id not in distinct_ids:
                    flag(
                        'claim_paper_missing',
                        f'code={code}: claim cites paper_id={claim.paper_id} not present in papers[]',
                    )

        if problems:
            raise ModelRetry('Invalid aggregate output: ' + '; '.join(problems))

        return result

    return agent
156
+
157
+
158
def resolve_aggregate_citations(
    aggregate_dict: dict[str, Any],
    paper_id_to_doi: dict[str, str],
    pdf_bytes_cache: dict[str, bytes],
    markdown_cache: dict[str, str],
    metadata_cache: dict[str, dict[str, Any]],
) -> None:
    """Attach resolved bboxes to every claim citation, in place.

    The quotes are gathered per DOI and handed to
    `flowa.resolve.resolve_citations`, with the two caches acting as the PDF
    and Markdown loaders. Each citation dict then gains a 'bboxes' list, and
    `aggregate_dict` gains a 'paper_id_mapping' entry.

    Args:
        aggregate_dict: Dumped aggregate output; mutated in place.
        paper_id_to_doi: Maps each claim's paper_id to its DOI.
        pdf_bytes_cache: PDF bytes keyed by DOI.
        markdown_cache: Markdown text keyed by DOI.
        metadata_cache: Paper metadata keyed by DOI; must contain every DOI
            in `paper_id_to_doi`.
    """

    def iter_claims():
        # Walk claims in document order: category by category.
        for category in aggregate_dict['results']:
            yield from category['claims']

    # Pass 1: gather the quotes to resolve, bucketed by the paper they cite.
    quotes_by_doi: dict[str, list[str]] = {}
    for claim in iter_claims():
        doi = paper_id_to_doi[claim['paper_id']]
        for citation in claim['citations']:
            quotes_by_doi.setdefault(doi, []).append(citation['quote'])

    queries = [CitationQuery(doi=doi, quotes=quotes) for doi, quotes in quotes_by_doi.items()]
    resolution = resolve_citations(
        queries,
        pdf_loader=pdf_bytes_cache.get,
        markdown_loader=markdown_cache.get,
    )

    # Pass 2: write the bboxes back onto the citations they belong to.
    for claim in iter_claims():
        doi = paper_id_to_doi[claim['paper_id']]
        for citation in claim['citations']:
            quote = citation['quote']
            found = resolution.resolved.get(doi, {}).get(quote, [])
            citation['bboxes'] = [box.model_dump() for box in found]
            if not found:
                log.warning('No bboxes resolved for %s quote: %.80s...', doi, quote)

    # paper_id_mapping: {AuthorYear -> {doi, pmid}} lets consumers tie prose
    # citations back to papers; they build the reverse index on read.
    mapping: dict[str, dict[str, Any]] = {}
    for pid, doi in paper_id_to_doi.items():
        mapping[pid] = {'doi': doi, 'pmid': metadata_cache[doi].get('pmid')}
    aggregate_dict['paper_id_mapping'] = mapping
203
+
204
+
205
async def aggregate_evidence_async(
    base: str,
    variant_id: str,
    model: ModelConfig,
    ncbi_api_key: str | None = None,
    prompt_set: str = 'generic',
    dry_run: bool = False,
) -> None:
    """Aggregate evidence across all papers for a variant.

    Loads the stored variant details, query data and per-paper extractions,
    fetches ClinVar data, renders the aggregation prompt, runs the LLM agent,
    resolves claim citations, and writes 'aggregation.json' and
    'aggregation_raw.json' for the variant.

    Args:
        base: Storage base passed to ``assessment_url`` / ``paper_url``.
        variant_id: Variant whose assessment artifacts are read and written.
        model: Model configuration used to build the aggregation agent.
        ncbi_api_key: Optional key forwarded to ``query_clinvar``.
        prompt_set: Prompt set name forwarded to ``load_prompt_and_schema``.
        dry_run: When true, print the rendered prompt and the paper ID
            mapping, then return without calling the LLM or writing output.
    """
    aggregation_url = assessment_url(base, variant_id, 'aggregation.json')
    aggregation_raw_url = assessment_url(base, variant_id, 'aggregation_raw.json')

    # Load variant details and query data (stored by query command)
    variant_details = json.dumps(read_json(assessment_url(base, variant_id, 'variant_details.json')))
    query_data = read_json(assessment_url(base, variant_id, 'query.json'))
    dois = query_data['dois']

    # Reconstruct the colon-glued HGVS c. expression for ClinVar exact-phrase
    # search. The variant_spec persisted here carries `transcript` and `hgvs_c`
    # as separate fields; downstream callers join them at the use site.
    spec_item = query_data['variant_spec']['variants'][0]
    hgvs_c_full = f'{spec_item["transcript"]}:{spec_item["hgvs_c"]}'

    # Fetch ClinVar evidence
    clinvar_data = query_clinvar(hgvs_c_full, ncbi_api_key)
    clinvar_text = format_clinvar_for_prompt(clinvar_data)

    # Load extractions and metadata for each paper. PDF bytes and Markdown are
    # cached for post-LLM citation resolution: PdfIndex takes both (markdown
    # denoises the indexed PDF chars). Both files are produced together by the
    # pipeline; a missing markdown.md at this point is a storage corruption and
    # surfaces as FileNotFoundError below.
    evidence_extractions: list[dict[str, Any]] = []
    pdf_bytes_cache: dict[str, bytes] = {}
    markdown_cache: dict[str, str] = {}
    metadata_cache: dict[str, dict[str, Any]] = {}

    for doi in dois:
        extraction_url = assessment_url(base, variant_id, 'extractions', f'{encode_doi(doi)}.json')

        if not exists(extraction_url):
            log.info('Skipping %s: no extraction', doi)
            continue

        extraction_data = read_json(extraction_url)

        if not extraction_data.get('variant_discussed'):
            log.info('Skipping %s: variant not discussed', doi)
            continue

        pdf_bytes_cache[doi] = read_bytes(paper_url(base, doi, 'source.pdf'))
        markdown_cache[doi] = read_text(paper_url(base, doi, 'markdown.md'))
        metadata = read_json(paper_url(base, doi, 'metadata.json'))
        metadata_cache[doi] = metadata

        # Support both new-shape ('claims') and legacy ('evidence') extractions for
        # backfill convenience — the schema renamed EvidenceFinding -> Claim and
        # dropped commentary, but we still want to consume older extraction JSON.
        claims = extraction_data.get('claims')
        if claims is None:
            legacy = extraction_data.get('evidence', [])
            claims = [
                {
                    'text': item.get('finding', ''),
                    'citations': [{'quote': c['quote']} for c in item.get('citations', [])],
                }
                for item in legacy
            ]

        entry: dict[str, Any] = {
            'doi': doi,
            'title': metadata['title'],
            'authors': metadata['authors'],
            'date': metadata['date'],
            'claims': claims,
        }
        # Only include pmid when the metadata actually has one.
        if metadata.get('pmid'):
            entry['pmid'] = metadata['pmid']
        evidence_extractions.append(entry)

    # Generate paper_ids and replace DOIs with human-readable IDs for the LLM
    paper_id_to_doi: dict[str, str] = {}
    if evidence_extractions:
        paper_id_to_doi, doi_to_paper_id = generate_paper_ids(evidence_extractions)
        for entry in evidence_extractions:
            entry['paper_id'] = doi_to_paper_id[entry.pop('doi')]
        # Newest first. A missing (None/empty) date is coerced to '' so it
        # sorts last instead of raising TypeError against string dates.
        evidence_extractions.sort(key=lambda x: x['date'] or '', reverse=True)

    log.info(
        'Aggregating evidence from %d papers + ClinVar (model: %s)',
        len(evidence_extractions),
        model.name,
    )

    # Load prompt and schema from prompt set
    prompt_template, output_type = load_prompt_and_schema('aggregation', prompt_set)

    evidence_text = (
        json.dumps(evidence_extractions, indent=2)
        if evidence_extractions
        else 'No papers discussing this variant were found.'
    )
    prompt = prompt_template.render(
        variant_details=variant_details,
        clinvar_data=clinvar_text,
        evidence_extractions=evidence_text,
    )

    if dry_run:
        print('=== PROMPT ===')
        print(prompt)
        print('\n=== PAPER ID MAPPING ===')
        for pid, doi in paper_id_to_doi.items():
            print(f'  {pid} -> {doi}')
        return

    agent = create_aggregate_agent(model, paper_id_to_doi, output_type)

    log.info('Calling LLM for aggregate assessment')
    t0 = time.monotonic()
    # Stream so bytes flow during extended thinking; otherwise the connection
    # goes silent for many minutes and trips our Bedrock read_timeout.
    async with agent.run_stream(prompt) as stream_result:
        output = await stream_result.get_output()
        raw_messages_json = stream_result.all_messages_json()
    elapsed = time.monotonic() - t0

    # Post-LLM: resolve quotes to bboxes and attach the paper_id -> DOI/PMID
    # mapping. Claims keep their paper_id; the mapping is added alongside.
    aggregate_dict = output.model_dump()
    with logfire.span('flowa.resolve_citations', paper_count=len(paper_id_to_doi)):
        resolve_aggregate_citations(aggregate_dict, paper_id_to_doi, pdf_bytes_cache, markdown_cache, metadata_cache)

    # Store structured aggregation result
    write_json(aggregation_url, with_schema_version(aggregate_dict, AGGREGATION_SCHEMA_VERSION))

    # Store raw LLM conversation for debugging
    write_bytes(aggregation_raw_url, raw_messages_json)

    results_list = output.results  # type: ignore[attr-defined]
    total_claims = sum(len(cat_result.claims) for cat_result in results_list)
    total_papers = sum(len(cat_result.papers) for cat_result in results_list)
    log.info(
        'Aggregated variant %s: %d categories, %d claims across %d papers in %.1fs',
        variant_id,
        len(results_list),
        total_claims,
        total_papers,
        elapsed,
    )
354
+
355
+
356
def aggregate_evidence(
    variant_id: str = typer.Option(..., '--variant-id', help='Variant identifier'),
    dry_run: bool = typer.Option(False, '--dry-run', help='Dump prompt and exit without calling LLM'),
) -> None:
    """Aggregate evidence across all papers for a variant.

    Reads extraction results from assessments/{variant_id}/extractions/,
    variant details from variant_details.json, and paper metadata from
    papers/{encoded_doi}/metadata.json. Calls LLM for aggregate assessment and
    stores result to assessments/{variant_id}/aggregation.json.
    """
    # NOTE: the docstring above is registered CLI help text, so it is kept verbatim.
    # Synchronous CLI wrapper: pull configuration from Settings, then drive
    # the async implementation to completion on a fresh event loop.
    settings = Settings()  # type: ignore[call-arg]
    pipeline = aggregate_evidence_async(
        base=settings.flowa_storage_base,
        variant_id=variant_id,
        model=settings.flowa_extraction_model,
        ncbi_api_key=settings.ncbi_api_key,
        prompt_set=settings.flowa_prompt_set,
        dry_run=dry_run,
    )
    asyncio.run(pipeline)
flowa/cli.py ADDED
@@ -0,0 +1,81 @@
1
+ """Main CLI entry point for Flowa."""
2
+
3
+ import logging
4
+ import os
5
+ import sys
6
+
7
+ import logfire
8
+ import typer
9
+ from pydantic_ai import Agent
10
+ from pydantic_ai.models.instrumented import InstrumentationSettings
11
+
12
+ from flowa import __version__, aggregate, convert, download, extract, query, resolve, run
13
+
14
# The Typer application that every sub-command below is registered on.
app = typer.Typer(
    add_completion=False,
    name='flowa',
    help='Variant literature assessment pipeline with AI extraction',
)

# Root logger goes to stderr; stdout stays reserved for structured output.
# This runs after the imports above so that force=True replaces any logging
# configuration a library installed while it was being imported.
logging.basicConfig(
    stream=sys.stderr,
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s %(levelname)s %(name)s: %(message)s',
    force=True,  # Override any existing configuration
)

# OpenTelemetry setup via logfire. With send_to_logfire=False, export happens
# only over OTLP (to OTEL_EXPORTER_OTLP_ENDPOINT). Gracefully degrades to a
# no-op when the env var is unset (local dev, tests).
logfire.configure(
    service_name='flowa-worker',
    send_to_logfire=False,
)

# Patches the Agent class globally, so every agent run (extraction,
# aggregation, transcription) is captured with zero changes in those modules.
Agent.instrument_all(
    InstrumentationSettings(
        version=3,
        include_content=False,
    )
)
45
+
46
+
47
@app.callback()
def main(
    log_level: str = typer.Option(
        os.environ.get('FLOWA_LOG_LEVEL', 'INFO'),
        '--log-level',
        '-l',
        help='Logging level (DEBUG, INFO, WARNING, ERROR)',
    ),
) -> None:
    """Flowa - Variant literature assessment pipeline."""
    # Runs before any sub-command: applies the requested level to the root
    # logger. The default comes from FLOWA_LOG_LEVEL (read at import time).
    level = getattr(logging, log_level.upper(), None)
    # getattr can also return non-level attributes of the logging module
    # (e.g. BASIC_FORMAT is a str); anything that is not an int level falls
    # back to INFO, the same as an unknown name, instead of making setLevel raise.
    if not isinstance(level, int):
        level = logging.INFO
    logging.getLogger().setLevel(level)
    # Suppress noisy third-party loggers even at DEBUG
    for name in ('pdfminer', 'pdfplumber', 'pypdfium2'):
        logging.getLogger(name).setLevel(logging.WARNING)
62
+
63
+
64
# Register commands: each pipeline stage is exposed under its own CLI name.
for _command_name, _handler in (
    ('run', run.run),
    ('query', query.query_dois),
    ('download', download.download_paper),
    ('convert', convert.convert_paper),
    ('extract', extract.extract_paper),
    ('aggregate', aggregate.aggregate_evidence),
    ('resolve', resolve.resolve),
):
    app.command(name=_command_name)(_handler)
# Drop the loop variables so the module namespace stays unchanged.
del _command_name, _handler
72
+
73
+
74
@app.command()
def version() -> None:
    """Show version information."""
    # Printed to stdout; the value is the package's __version__.
    message = f'Flowa version {__version__}'
    print(message)
78
+
79
+
80
# Allow running this module directly as a script.
if __name__ == '__main__':
    app()