norm_toolkit 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,840 @@
1
+ """
2
+ Async PostgreSQL normalizer for biomedical concept normalization.
3
+
4
+ Works with PostgreSQL databases using the same schema as DuckDB databases
5
+ built by build_umls_duckdb, build_ontology_duckdb, or build_merged_duckdb.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import asyncio
11
+ import json
12
+ from collections.abc import Mapping, Sequence
13
+
14
+ import asyncpg
15
+ import polars as pl
16
+ from lvg_norm import lvg_normalize
17
+
18
+ from norm_toolkit.constants import (
19
+ ATOMS_TABLE,
20
+ CONCEPTS_TABLE,
21
+ DEFAULT_PREFER_TTYS,
22
+ DEFS_TABLE,
23
+ EDGES_TABLE,
24
+ EXACT_BUMP,
25
+ HIT_STRUCT_TYPE,
26
+ ISPREF_WEIGHT,
27
+ NS_TABLE,
28
+ NW_TABLE,
29
+ RANK_MULTIPLIER,
30
+ STT_WEIGHT,
31
+ TTY_WEIGHT,
32
+ TYPES_TABLE,
33
+ )
34
+ from norm_toolkit.models import ConceptInfo, SemanticType
35
+
36
+
37
class PostgresNormalizer:
    """
    Async normalizer using PostgreSQL via asyncpg.

    Optimized for small batch processing (1-5 strings at a time). Batch inputs
    travel inside the query as bind parameters (VALUES lists for normalization,
    arrays for concept-ID lookups) instead of being loaded into temp tables.

    Every caller-supplied value (query strings, term types, source names,
    concept IDs) is sent as a bind parameter and never spliced into SQL text.
    Table names are built from ``schema`` plus the table-name constants and ARE
    interpolated into SQL, so ``schema`` must come from trusted configuration.

    Lifecycle:
        async: ``n = PostgresNormalizer(pool); await n.initialize()`` -- the
               caller owns and closes ``pool``.
        sync:  ``n = PostgresNormalizer.create_sync(dsn)`` ... ``n.close_sync()``
               -- the normalizer owns its pool and its private event loop.
    """

    def __init__(self, pool: asyncpg.Pool, schema: str = "public") -> None:
        """
        Initialize the normalizer with an external connection pool.

        Args:
            pool: asyncpg connection pool (caller manages lifecycle)
            schema: PostgreSQL schema where tables are located (default: "public").
                An empty string leaves table names unqualified.

        Note:
            After creating the normalizer, call `await normalizer.initialize()`
            to detect database capabilities before using other methods.
        """
        self._pool = pool
        self._schema = schema
        # Only set by create_sync(); the *_sync methods require it.
        self._loop: asyncio.AbstractEventLoop | None = None
        # Capability flags, all False until initialize() probes the database.
        self._has_types = False
        self._has_defs = False
        self._has_edges = False
        self._has_stt = False
        self._initialized = False

        # Build qualified table names
        prefix = f"{schema}." if schema else ""
        self._ns_table = f"{prefix}{NS_TABLE}"
        self._nw_table = f"{prefix}{NW_TABLE}"
        self._atoms_table = f"{prefix}{ATOMS_TABLE}"
        self._concepts_table = f"{prefix}{CONCEPTS_TABLE}"
        self._types_table = f"{prefix}{TYPES_TABLE}"
        self._defs_table = f"{prefix}{DEFS_TABLE}"
        self._edges_table = f"{prefix}{EDGES_TABLE}"

    @classmethod
    def create_sync(cls, dsn: str, schema: str = "public", min_size: int = 1, max_size: int = 10) -> PostgresNormalizer:
        """
        Create a normalizer synchronously with its own event loop.

        Use this factory method for sync-only usage. The normalizer will manage
        its own event loop and pool, allowing you to call normalize_sync().
        If pool creation or capability detection fails, the pool (if created)
        and the private loop are closed before the exception propagates.

        Args:
            dsn: PostgreSQL connection string (e.g., "postgresql://user:pass@host:5432/db")
            schema: PostgreSQL schema where tables are located (default: "public")
            min_size: Minimum pool connections
            max_size: Maximum pool connections

        Example:
            >>> normalizer = PostgresNormalizer.create_sync("postgresql://...")
            >>> result = normalizer.normalize_sync(["diabetes"])
            >>> normalizer.close_sync()
        """
        loop = asyncio.new_event_loop()
        pool = None

        async def _create():
            return await asyncpg.create_pool(dsn, min_size=min_size, max_size=max_size)

        try:
            pool = loop.run_until_complete(_create())
            instance = cls(pool, schema=schema)
            instance._loop = loop
            loop.run_until_complete(instance.initialize())
        except BaseException:
            # Don't leak the pool's connections or the private loop on failure.
            try:
                if pool is not None:
                    loop.run_until_complete(pool.close())
            finally:
                loop.close()
            raise
        return instance

    async def initialize(self) -> None:
        """
        Detect database capabilities.

        Must be called after __init__ before using normalize/concept_info methods.
        Sets the types/defs/edges flags when the matching table has at least one
        row, and the stt flag when the atoms table has a non-NULL ``stt`` value.
        """
        self._has_types = await self._table_has_rows(self._types_table)
        self._has_defs = await self._table_has_rows(self._defs_table)
        self._has_edges = await self._table_has_rows(self._edges_table)
        self._has_stt = await self._column_has_values(self._atoms_table, "stt")
        self._initialized = True

    async def _table_has_rows(self, table: str) -> bool:
        """
        Check if a table exists and has rows.

        Errors reported by the server (e.g. undefined table, missing privilege)
        yield False so optional tables may simply be absent. Client-side failures
        (closed pool, lost connection, timeouts) are not ``PostgresError`` and
        propagate instead of silently disabling the feature.
        """
        try:
            async with self._pool.acquire() as con:
                result = await con.fetchval(f"SELECT 1 FROM {table} LIMIT 1")
                return result is not None
        except asyncpg.PostgresError:
            return False

    async def _column_has_values(self, table: str, column: str) -> bool:
        """
        Check if a column has any non-null values.

        Same error policy as ``_table_has_rows``: a missing table/column (a
        server-side error) yields False; client-side failures propagate.
        """
        try:
            async with self._pool.acquire() as con:
                result = await con.fetchval(f"SELECT 1 FROM {table} WHERE {column} IS NOT NULL LIMIT 1")
                return result is not None
        except asyncpg.PostgresError:
            return False

    async def normalize(
        self,
        strings: Sequence[str],
        top_k: int = 25,
        prefer_ttys: list[str] | None = None,
        filter_sources: list[str] | None = None,
        exclude_sources: list[str] | None = None,
        allow_partial: bool = True,
        min_coverage: float = 0.6,
        min_word_hits: int | None = None,
        coverage_weight: int = 25,
    ) -> pl.DataFrame:
        """
        Normalize input strings to ranked concepts.

        Args:
            strings: Input strings to normalize. A single ``str`` is treated as
                one query; any iterable of strings is accepted.
            top_k: Maximum number of results per query
            prefer_ttys: Term types to prefer (e.g., ["PT", "MH"]); defaults to
                DEFAULT_PREFER_TTYS when None
            filter_sources: Restrict to these sources (include only)
            exclude_sources: Exclude these sources
            allow_partial: Enable word-overlap partial matching
            min_coverage: Minimum fraction of query words that must match
            min_word_hits: Minimum absolute word hits required
            coverage_weight: Weight for coverage in scoring

        Returns:
            DataFrame with columns: input_string, hits (list of match structs),
            one row per input string in input order.
        """
        # Materialize once: the input is read twice (normalization + result rows),
        # and a bare str must not be split into one query per character.
        queries = [strings] if isinstance(strings, str) else list(strings)

        if prefer_ttys is None:
            prefer_ttys = DEFAULT_PREFER_TTYS

        # LVG can yield several normalized variants per input (or none).
        q_to_nstrs: dict[str, list[str]] = {q: list(lvg_normalize(q) or []) for q in queries}

        return await self._lookup(
            q_to_nstrs=q_to_nstrs,
            all_queries=queries,
            prefer_ttys=prefer_ttys,
            filter_sources=filter_sources,
            exclude_sources=exclude_sources,
            top_k=top_k,
            allow_partial=allow_partial,
            min_coverage=min_coverage,
            min_word_hits=min_word_hits,
            coverage_weight=coverage_weight,
        )

    async def _lookup(
        self,
        q_to_nstrs: Mapping[str, Sequence[str]],
        all_queries: Sequence[str],
        prefer_ttys: list[str] | None,
        filter_sources: list[str] | None,
        exclude_sources: list[str] | None,
        *,
        top_k: int = 25,
        allow_partial: bool = True,
        min_coverage: float = 0.6,
        min_word_hits: int | None = None,
        coverage_weight: int = 25,
    ) -> pl.DataFrame:
        """
        Core lookup via exact + partial match paths.

        Exact path: a normalized query string equals an ``nstr`` of the NS table.
        Partial path (``allow_partial``): the words of a normalized string are
        matched against the NW table; a string qualifies when its word hits reach
        ``max(min_word_hits, ceil(need * min_coverage))``.
        Each path keeps the best atom per (query, concept) and scores it; the
        paths are then merged, de-duplicated per concept and cut to ``top_k``.

        Returns:
            DataFrame with one row per entry of ``all_queries`` (same order) and
            columns ``input_string`` and ``hits`` (``List(HIT_STRUCT_TYPE)``,
            empty when nothing matched).
        """
        top_k = max(1, int(top_k))

        # (query, normalized string) pairs; duplicate and empty nstrs dropped.
        qmap_rows = [(q, nstr) for q, nstrs in q_to_nstrs.items() for nstr in dict.fromkeys(nstrs) if nstr]

        if not qmap_rows or not all_queries:
            return pl.DataFrame({"input_string": list(all_queries), "hits": [[] for _ in all_queries]}).cast(
                {"hits": pl.List(HIT_STRUCT_TYPE)}
            )

        # Every caller-supplied value becomes a bind parameter. bind() appends the
        # value and returns its "$n" placeholder, so numbering can't drift.
        params: list[object] = []

        def bind(value: object) -> str:
            params.append(value)
            return f"${len(params)}"

        qmap_values = ", ".join(f"({bind(q)}, {bind(nstr)})" for q, nstr in qmap_rows)

        # Distinct words per normalized string, only needed by the partial path.
        qwords_rows = (
            [(q, nstr, w) for q, nstr in qmap_rows for w in dict.fromkeys(nstr.split()) if w]
            if allow_partial
            else []
        )
        use_partial = bool(qwords_rows)
        qwords_values = ", ".join(f"({bind(q)}, {bind(nstr)}, {bind(w)})" for q, nstr, w in qwords_rows)

        # Each query carries its position: PostgreSQL only guarantees row order
        # under ORDER BY, and the output must line up with all_queries.
        allq_values = ", ".join(f"({bind(q)}, {pos})" for pos, q in enumerate(all_queries))

        # Term-type preference: 0/1 bump when the atom's name_type is preferred.
        tty_bump_expr = "0"
        if prefer_ttys:
            tty_bump_expr = f"CASE WHEN a.name_type = ANY({bind(list(prefer_ttys))}::text[]) THEN 1 ELSE 0 END"

        # Source filtering, applied to atoms and (partial path) to NW rows.
        atom_conds: list[str] = []
        nw_conds: list[str] = []
        if filter_sources:
            ph = bind(list(filter_sources))
            atom_conds.append(f"a.source = ANY({ph}::text[])")
            nw_conds.append(f"nw.source = ANY({ph}::text[])")
        if exclude_sources:
            ph = bind(list(exclude_sources))
            # `<> ALL(array)` is the array form of NOT IN (same NULL handling).
            atom_conds.append(f"a.source <> ALL({ph}::text[])")
            nw_conds.append(f"nw.source <> ALL({ph}::text[])")
        atom_where = f"WHERE {' AND '.join(atom_conds)}" if atom_conds else ""
        nw_and = "".join(f" AND {cond}" for cond in nw_conds)

        # STT bump only when initialize() found non-NULL stt values.
        stt_bump_expr = "CASE WHEN a.stt='PF' THEN 1 ELSE 0 END" if self._has_stt else "0"

        # Scoring constants (numeric values formatted, never raw user text).
        min_hits_sql = str(int(min_word_hits)) if min_word_hits is not None else "0"
        cov_sql = f"{float(min_coverage):.6f}"

        # Atom projection shared by both candidate CTEs (atoms aliased as `a`).
        atom_cols = f"""a.concept_id, a.identifier, a.str, a.source, a.name_type, a.ispref, a.rank,
                CASE WHEN a.ispref='Y' THEN 1 ELSE 0 END AS ispref_bump,
                {stt_bump_expr} AS stt_bump,
                {tty_bump_expr} AS tty_bump"""

        def dedup_and_score(stage: str, exact_bonus: str, is_exact: str) -> str:
            """Best atom per (query, concept) from cand_<stage>, then its score."""
            # NULLS LAST: PostgreSQL sorts NULLs first under DESC, which would let
            # an atom without a rank (if any exist) beat every ranked one.
            return f"""dedup_{stage} AS (
            SELECT *,
                ROW_NUMBER() OVER (
                    PARTITION BY Q, concept_id
                    ORDER BY rank DESC NULLS LAST, ispref_bump DESC, stt_bump DESC, tty_bump DESC, concept_id
                ) AS rnc
            FROM cand_{stage}
        ),
        scored_{stage} AS (
            SELECT
                Q, NSTR, concept_id, identifier, str, source, name_type, ispref, rank,
                (rank*{RANK_MULTIPLIER} + ispref_bump*{ISPREF_WEIGHT} + stt_bump*{STT_WEIGHT}
                 + tty_bump*{TTY_WEIGHT}{exact_bonus} + ROUND(coverage * {coverage_weight}))::INTEGER AS total_score,
                {is_exact} AS is_exact
            FROM dedup_{stage}
            WHERE rnc = 1
        )"""

        ctes: list[str] = [f"qmap(Q, NSTR) AS (VALUES {qmap_values})"]
        if use_partial:
            ctes.append(f"qwords(Q, NSTR, NWD) AS (VALUES {qwords_values})")
        ctes.append(f"allq(Q, pos) AS (VALUES {allq_values})")

        # Exact path: normalized string -> NS table -> atoms.
        ctes.append(
            f"""cand_exact AS (
            SELECT q.Q, q.NSTR, {atom_cols},
                1.0 AS coverage
            FROM qmap q
            JOIN {self._ns_table} ns ON ns.nstr = q.NSTR
            JOIN {self._atoms_table} a
                ON a.concept_id = ns.concept_id
                AND a.name_id = ns.name_id
            {atom_where}
        )"""
        )
        ctes.append(dedup_and_score("exact", f" + {EXACT_BUMP}", "TRUE"))
        scored_union = "SELECT * FROM scored_exact"

        # Partial path: word overlap through the NW table.
        if use_partial:
            ctes.append(
                """qn AS (
            SELECT Q, NSTR, COUNT(DISTINCT NWD) AS need
            FROM qwords
            GROUP BY Q, NSTR
        )"""
            )
            ctes.append(
                f"""hits AS (
            SELECT qw.Q, qw.NSTR, nw.string_id, nw.concept_id,
                   COUNT(DISTINCT qw.NWD) AS hits
            FROM qwords qw
            JOIN {self._nw_table} nw ON nw.nwd = qw.NWD{nw_and}
            GROUP BY qw.Q, qw.NSTR, nw.string_id, nw.concept_id
        )"""
            )
            ctes.append(
                f"""good AS (
            SELECT h.Q, h.NSTR, h.string_id, h.concept_id, h.hits, qn.need,
                   CAST(h.hits AS DOUBLE PRECISION)/NULLIF(qn.need,0) AS coverage
            FROM hits h
            JOIN qn ON qn.Q = h.Q AND qn.NSTR = h.NSTR
            WHERE h.hits >= GREATEST({min_hits_sql}, CAST(CEIL(qn.need * {cov_sql}) AS INTEGER))
        )"""
            )
            ctes.append(
                f"""cand_partial AS (
            SELECT g.Q, g.NSTR, {atom_cols},
                COALESCE(g.coverage, 0.0) AS coverage
            FROM good g
            JOIN {self._atoms_table} a ON a.string_id = g.string_id
            {atom_where}
        )"""
            )
            ctes.append(dedup_and_score("partial", "", "FALSE"))
            scored_union += " UNION ALL SELECT * FROM scored_partial"

        # Merge paths, keep one row per (query, concept), rank and cut to top_k.
        ctes.append(f"scored AS ({scored_union})")
        ctes.append(
            """dedup_concept AS (
            SELECT *,
                ROW_NUMBER() OVER (PARTITION BY Q, concept_id ORDER BY total_score DESC NULLS LAST) AS rcid
            FROM scored
        )"""
        )
        ctes.append(
            """best AS (
            SELECT *,
                ROW_NUMBER() OVER (PARTITION BY Q ORDER BY total_score DESC NULLS LAST, concept_id) AS rn
            FROM dedup_concept
            WHERE rcid = 1
        )"""
        )
        ctes.append(f"topk AS (SELECT * FROM best WHERE rn <= {top_k})")
        ctes.append(
            """agg AS (
            SELECT
                Q,
                JSON_AGG(
                    JSON_BUILD_OBJECT(
                        'global_identifier', concept_id,
                        'identifier', identifier,
                        'nstr', NSTR,
                        'name', str,
                        'source', source,
                        'name_type', name_type,
                        'score', rank,
                        'total_score', total_score,
                        'match_type', CASE WHEN is_exact THEN 'exact' ELSE 'partial' END
                    ) ORDER BY total_score DESC NULLS LAST, concept_id
                ) AS hits
            FROM topk
            GROUP BY Q
        )"""
        )

        sql = (
            "WITH\n        "
            + ",\n        ".join(ctes)
            + """
        SELECT aq.Q AS input_string, agg.hits
        FROM allq aq
        LEFT JOIN agg ON agg.Q = aq.Q
        ORDER BY aq.pos
        """
        )

        async with self._pool.acquire() as con:
            rows = await con.fetch(sql, *params)

        # asyncpg returns json as text by default; a pool with a registered JSON
        # codec hands back already-decoded lists, so accept both.
        data = []
        for row in rows:
            raw_hits = row["hits"]
            if not raw_hits:
                hits = []
            elif isinstance(raw_hits, (str, bytes)):
                hits = json.loads(raw_hits)
            else:
                hits = list(raw_hits)
            data.append({"input_string": row["input_string"], "hits": hits})

        return pl.DataFrame(data).cast({"hits": pl.List(HIT_STRUCT_TYPE)})

    async def concept_info(
        self,
        concept_ids: Sequence[str],
        prefer_ttys: list[str] | None = None,
        prefer_def_sources: list[str] | None = None,
    ) -> dict[str, ConceptInfo]:
        """
        Get detailed information for concepts.

        The preferred name is the best atom by (preferred tty, ispref='Y',
        stt='PF', rank, str). Synonyms are the remaining names, one per
        case-insensitive spelling, in the same order. Definitions and semantic
        types are filled in only when initialize() found those tables populated.

        Args:
            concept_ids: List of concept IDs (duplicates are ignored)
            prefer_ttys: Preferred term types; DEFAULT_PREFER_TTYS when None
            prefer_def_sources: Preferred sources for definitions

        Returns:
            Dict mapping concept_id to ConceptInfo (an entry with empty fields
            for IDs that have no atoms)
        """
        if not concept_ids:
            return {}

        if prefer_ttys is None:
            prefer_ttys = DEFAULT_PREFER_TTYS

        id_list = list(dict.fromkeys(concept_ids))

        # Initialize results with defaults (fresh lists per entry).
        res: dict[str, ConceptInfo] = {
            cid: ConceptInfo(
                concept_id=cid,
                identifier=None,
                source=None,
                preferred_name=None,
                name_type=None,
                description=None,
                def_source=None,
                synonyms=[],
                semantic_types=[],
            )
            for cid in id_list
        }

        params: list[object] = [id_list]
        tty_bump = "0"
        if prefer_ttys:
            params.append(list(prefer_ttys))
            tty_bump = "CASE WHEN a.name_type = ANY($2::text[]) THEN 1 ELSE 0 END"

        # `a.stt` is referenced only when the column is known to carry values; the
        # atoms table may not have an stt column at all.
        stt_bump = "CASE WHEN a.stt='PF' THEN 1 ELSE 0 END" if self._has_stt else "0"

        # One ordering for the preferred name, synonym de-duplication and synonym
        # output. NULLS LAST keeps unranked atoms (if any) below ranked ones.
        name_order = "tty_bump DESC, ispref_bump DESC, stt_bump DESC, rank DESC NULLS LAST, str"

        sql = f"""
        WITH
        idmap(concept_id) AS (SELECT UNNEST($1::text[])),
        cand AS (
            SELECT
                c.concept_id, a.str, a.source AS sab, a.name_type AS tty,
                a.identifier, a.rank,
                CASE WHEN a.ispref='Y' THEN 1 ELSE 0 END AS ispref_bump,
                {stt_bump} AS stt_bump,
                {tty_bump} AS tty_bump
            FROM idmap c
            JOIN {self._atoms_table} a ON a.concept_id = c.concept_id
        ),
        name_best AS (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY concept_id ORDER BY {name_order}) AS rn
            FROM cand
        ),
        chosen AS (
            SELECT concept_id, str AS preferred_name, sab AS name_sab, tty AS name_tty, identifier
            FROM name_best
            WHERE rn = 1
        ),
        syn_rank AS (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY concept_id, LOWER(str) ORDER BY {name_order}) AS rstr
            FROM cand
        ),
        syn_best_uniq AS (
            SELECT s.concept_id, s.str, s.tty_bump, s.ispref_bump, s.stt_bump, s.rank
            FROM syn_rank s
            LEFT JOIN chosen ch ON ch.concept_id = s.concept_id
            WHERE s.rstr = 1 AND s.str <> ch.preferred_name
        ),
        syn_agg AS (
            SELECT concept_id, ARRAY_AGG(str ORDER BY {name_order}) AS synonyms
            FROM syn_best_uniq
            GROUP BY concept_id
        )
        SELECT c.concept_id,
               ch.preferred_name, ch.name_sab, ch.name_tty, ch.identifier,
               sa.synonyms
        FROM idmap c
        LEFT JOIN chosen ch ON ch.concept_id = c.concept_id
        LEFT JOIN syn_agg sa ON sa.concept_id = c.concept_id
        """

        async with self._pool.acquire() as con:
            rows = await con.fetch(sql, *params)

        for row in rows:
            ent = res[row["concept_id"]]

            if row["preferred_name"] is not None:
                ent.preferred_name = row["preferred_name"]
                ent.source = row["name_sab"]
                ent.name_type = row["name_tty"]
                ent.identifier = row["identifier"]

            synonyms = row["synonyms"]
            if synonyms:
                ent.synonyms = list(dict.fromkeys(synonyms))

        # Definitions (if available)
        if self._has_defs:
            await self._populate_definitions(res, id_list, prefer_def_sources)

        # Semantic types (if available)
        if self._has_types:
            await self._populate_semantic_types(res, id_list)

        return res

    async def _populate_definitions(
        self,
        res: dict[str, ConceptInfo],
        id_list: list[str],
        prefer_def_sources: list[str] | None,
    ) -> None:
        """
        Populate definitions for concepts in place.

        Picks one definition per concept: a preferred source first, then the
        longest text. NULL/empty texts are excluded in SQL. Otherwise
        ``length(NULL)`` would sort first under PostgreSQL's DESC default and
        hide a real definition.
        """
        params: list[object] = [id_list]
        def_pref_bump = "0"
        if prefer_def_sources:
            params.append(list(prefer_def_sources))
            def_pref_bump = "CASE WHEN d.source = ANY($2::text[]) THEN 1 ELSE 0 END"

        sql = f"""
        WITH
        idmap(concept_id) AS (SELECT UNNEST($1::text[])),
        def_cand AS (
            SELECT
                d.concept_id, d.source AS sab, d.def_text,
                {def_pref_bump} AS def_pref_bump,
                length(d.def_text) AS def_len
            FROM {self._defs_table} d
            JOIN idmap c ON c.concept_id = d.concept_id
            WHERE d.def_text IS NOT NULL AND d.def_text <> ''
        ),
        def_best AS (
            SELECT *,
                ROW_NUMBER() OVER (
                    PARTITION BY concept_id
                    ORDER BY def_pref_bump DESC, def_len DESC
                ) AS drn
            FROM def_cand
        )
        SELECT concept_id, def_text, sab AS def_sab
        FROM def_best
        WHERE drn = 1
        """

        async with self._pool.acquire() as con:
            rows = await con.fetch(sql, *params)

        for row in rows:
            cid = row["concept_id"]
            if cid in res and row["def_text"]:
                res[cid].description = row["def_text"]
                res[cid].def_source = row["def_sab"]

    async def _populate_semantic_types(
        self,
        res: dict[str, ConceptInfo],
        id_list: list[str],
    ) -> None:
        """
        Populate semantic types for concepts in place.

        Rows are ordered by (type_tree, type_id). A type that appears under
        several type_tree values is appended only once, because SemanticType
        does not carry the tree.
        """
        sql = f"""
        SELECT DISTINCT t.concept_id, t.type_id, t.type_name, t.type_tree
        FROM {self._types_table} t
        WHERE t.concept_id = ANY($1::text[])
        ORDER BY t.concept_id, t.type_tree, t.type_id
        """

        async with self._pool.acquire() as con:
            rows = await con.fetch(sql, id_list)

        seen: set[tuple[str, str, str]] = set()
        for row in rows:
            cid, type_id, type_name = row["concept_id"], row["type_id"], row["type_name"]
            if cid not in res or not type_id or not type_name:
                continue
            key = (cid, type_id, type_name)
            if key in seen:
                continue
            seen.add(key)
            res[cid].semantic_types.append(SemanticType(type_id=type_id, type_name=type_name))

    async def concept_semantic_types(
        self,
        concept_ids: Sequence[str],
    ) -> dict[str, list[dict[str, str]]]:
        """
        Get semantic types for concepts.

        Returns dict mapping concept_id to list of {"tui": ..., "sty": ...},
        ordered by (type_tree, type_id). Each (tui, sty) pair is listed once even
        if it appears under several type_tree values. Every requested ID is a
        key; it maps to [] when the types table is unavailable or has no rows.
        """
        if not self._has_types or not concept_ids:
            return {cid: [] for cid in concept_ids}

        id_list = list(dict.fromkeys(concept_ids))

        sql = f"""
        SELECT DISTINCT t.concept_id, t.type_id AS tui, t.type_name AS sty, t.type_tree
        FROM {self._types_table} t
        WHERE t.concept_id = ANY($1::text[])
        ORDER BY t.concept_id, t.type_tree, t.type_id
        """

        async with self._pool.acquire() as con:
            rows = await con.fetch(sql, id_list)

        res: dict[str, list[dict[str, str]]] = {cid: [] for cid in id_list}
        seen: set[tuple[str, str, str]] = set()
        for row in rows:
            key = (row["concept_id"], row["tui"], row["sty"])
            if key in seen:
                continue
            seen.add(key)
            res[row["concept_id"]].append({"tui": row["tui"], "sty": row["sty"]})

        return res

    async def get_narrower_concepts(
        self,
        concept_id: str,
        max_depth: int | None = 10,
        filter_sources: list[str] | None = None,
    ) -> list[str]:
        """
        Get all narrower (descendant) concept IDs using recursive traversal.

        Uses the hierarchy edges to walk down the tree/DAG from the given concept.
        The recursive CTE uses UNION (not UNION ALL), so rows already produced
        are discarded. With ``max_depth=None`` this makes the walk terminate even
        if the edges contain a cycle. With a bound it collapses nodes reached at
        the same depth via several paths.

        Args:
            concept_id: Starting concept ID (broader term)
            max_depth: Maximum depth to traverse (1 = direct children only, None = all descendants)
            filter_sources: Only follow edges from these sources (e.g., ["SNOMEDCT_US"])

        Returns:
            List of descendant concept IDs (excludes the starting concept), in no
            particular order
        """
        if not self._has_edges:
            return []

        params: list[object] = [concept_id]
        conditions: list[str] = []
        if max_depth is None:
            columns, seed, step = "concept_id", "$1::VARCHAR", "e.child_id"
        else:
            params.append(int(max_depth))
            columns, seed, step = "concept_id, depth", "$1::VARCHAR, 0", "e.child_id, w.depth + 1"
            conditions.append("w.depth < $2")
        if filter_sources:
            params.append(list(filter_sources))
            conditions.append(f"e.source = ANY(${len(params)}::text[])")
        where = f"WHERE {' AND '.join(conditions)}" if conditions else ""

        query = f"""
        WITH RECURSIVE walk({columns}) AS (
            SELECT {seed}

            UNION

            SELECT {step}
            FROM walk w
            JOIN {self._edges_table} e ON e.parent_id = w.concept_id
            {where}
        )
        SELECT DISTINCT concept_id
        FROM walk
        WHERE concept_id != $1
        """

        async with self._pool.acquire() as con:
            rows = await con.fetch(query, *params)

        return [r["concept_id"] for r in rows]

    def normalize_sync(
        self,
        strings: Sequence[str],
        top_k: int = 25,
        prefer_ttys: list[str] | None = None,
        filter_sources: list[str] | None = None,
        exclude_sources: list[str] | None = None,
        allow_partial: bool = True,
        min_coverage: float = 0.6,
        min_word_hits: int | None = None,
        coverage_weight: int = 25,
    ) -> pl.DataFrame:
        """
        Synchronous wrapper around normalize().

        Requires the normalizer to be created with create_sync() factory method;
        runs normalize() to completion on the normalizer's private event loop.

        Raises:
            RuntimeError: if the normalizer was not created with create_sync().
        """
        if self._loop is None:
            raise RuntimeError("normalize_sync() requires normalizer created with create_sync()")

        return self._loop.run_until_complete(
            self.normalize(
                strings=strings,
                top_k=top_k,
                prefer_ttys=prefer_ttys,
                filter_sources=filter_sources,
                exclude_sources=exclude_sources,
                allow_partial=allow_partial,
                min_coverage=min_coverage,
                min_word_hits=min_word_hits,
                coverage_weight=coverage_weight,
            )
        )

    async def close(self) -> None:
        """
        Close the connection pool.

        Note: Only call this if you want to close the pool. If the pool
        is managed externally, the caller should close it instead.
        """
        await self._pool.close()

    def close_sync(self) -> None:
        """
        Synchronously close the connection pool and event loop.

        Use this when the normalizer was created with create_sync(). The loop is
        closed even if closing the pool fails. Calling it again after the loop
        is closed is a no-op.

        Raises:
            RuntimeError: if the normalizer was not created with create_sync().
        """
        if self._loop is None:
            raise RuntimeError("close_sync() requires normalizer created with create_sync()")

        loop = self._loop
        if loop.is_closed():
            return
        try:
            loop.run_until_complete(self._pool.close())
        finally:
            loop.close()