norm_toolkit 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,679 @@
1
+ """
2
+ Unified normalizer for biomedical concept normalization.
3
+
4
+ Works with DuckDB databases built by build_umls_duckdb, build_ontology_duckdb,
5
+ or build_merged_duckdb. All use a standardized schema.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import contextlib
11
+ from collections.abc import Mapping, Sequence
12
+
13
+ import duckdb
14
+ import polars as pl
15
+ from lvg_norm import lvg_normalize
16
+
17
+ from norm_toolkit.constants import (
18
+ ATOMS_TABLE,
19
+ DEFAULT_PREFER_TTYS,
20
+ DEFS_TABLE,
21
+ EDGES_TABLE,
22
+ EXACT_BUMP,
23
+ HIT_STRUCT_TYPE,
24
+ ISPREF_WEIGHT,
25
+ NS_TABLE,
26
+ NW_TABLE,
27
+ RANK_MULTIPLIER,
28
+ STT_WEIGHT,
29
+ TTY_WEIGHT,
30
+ TYPES_TABLE,
31
+ )
32
+ from norm_toolkit.models import ConceptInfo, SemanticType
33
+
34
+
35
class DuckDBNormalizer:
    """
    High-throughput normalizer using DuckDB.

    Works with databases built by any of the build functions. Uses exact match
    via normalized string index and optional partial match via word-level index.

    The connection is opened read-only and released by ``close()``. Instances
    can also be used as context managers::

        with DuckDBNormalizer(db_path) as normalizer:
            df = normalizer.normalize(["myocardial infarction"])
    """

    def __init__(
        self,
        db_path: str,
        threads: int = 8,
    ) -> None:
        """
        Initialize the normalizer.

        Args:
            db_path: Path to DuckDB database file
            threads: Number of DuckDB threads to use
        """
        self.db_path = db_path
        self.con = duckdb.connect(db_path, read_only=True)
        try:
            # int() ensures only a number can reach the PRAGMA text.
            self.con.execute(f"PRAGMA threads={int(threads)}")
        except BaseException:
            # Do not leak the freshly opened connection if configuration fails.
            self.con.close()
            raise

        # Detect database capabilities once; optional tables may be missing or empty.
        self._has_types = self._table_has_rows(TYPES_TABLE)
        self._has_defs = self._table_has_rows(DEFS_TABLE)
        self._has_edges = self._table_has_rows(EDGES_TABLE)
        self._has_stt = self._column_has_values(ATOMS_TABLE, "stt")

    def __enter__(self) -> DuckDBNormalizer:
        """Return self so the normalizer can be used in a ``with`` block."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the connection when leaving a ``with`` block; exceptions propagate."""
        self.close()

    # ------------------------------------------------------------------
    # SQL / connection helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _sql_str(value: object) -> str:
        """Render ``value`` as a single-quoted SQL string literal, doubling embedded quotes."""
        return "'" + str(value).replace("'", "''") + "'"

    @classmethod
    def _sql_in_list(cls, values: Sequence[object]) -> str:
        """Render ``values`` as the comma-separated body of an ``IN (...)`` list."""
        return ", ".join(cls._sql_str(v) for v in values)

    @classmethod
    def _sql_values_rows(cls, values: Sequence[object]) -> str:
        """
        Render ``values`` as single-column ``VALUES`` rows, e.g. ``('PT'), ('MH')``.

        Duplicates are dropped so a LEFT JOIN against these rows cannot
        multiply the rows of the joined table.
        """
        return ", ".join(f"({cls._sql_str(v)})" for v in dict.fromkeys(values))

    def _unregister(self, *names: str) -> None:
        """Best-effort removal of temporary views; each name is attempted independently."""
        for name in names:
            with contextlib.suppress(Exception):
                self.con.unregister(name)

    def _table_has_rows(self, table: str) -> bool:
        """Check if a table exists and has rows."""
        try:
            result = self.con.execute(f"SELECT 1 FROM {table} LIMIT 1").fetchone()
        except duckdb.Error:
            # Missing table -> the capability is simply absent.
            return False
        return result is not None

    def _column_has_values(self, table: str, column: str) -> bool:
        """Check if a column exists and has any non-null values."""
        try:
            result = self.con.execute(f"SELECT 1 FROM {table} WHERE {column} IS NOT NULL LIMIT 1").fetchone()
        except duckdb.Error:
            # Missing table/column -> the capability is simply absent.
            return False
        return result is not None

    # ------------------------------------------------------------------
    # Lookup
    # ------------------------------------------------------------------

    def _lookup(
        self,
        q_to_nstrs: Mapping[str, Sequence[str]],
        all_queries: Sequence[str],
        prefer_ttys: list[str] | None,
        filter_sources: list[str] | None,
        exclude_sources: list[str] | None,
        *,
        top_k: int = 25,
        allow_partial: bool = True,
        min_coverage: float = 0.6,
        min_word_hits: int | None = None,
        coverage_weight: int = 25,
    ) -> pl.DataFrame:
        """
        Core lookup via exact + partial match paths.

        Args:
            q_to_nstrs: Maps each query to its normalized strings.
            all_queries: Queries to report, in output order (duplicates allowed).
            prefer_ttys: Term types that receive the TTY bump.
            filter_sources: If given, only atoms/words from these sources are used.
            exclude_sources: If given, atoms/words from these sources are ignored.
            top_k: Maximum hits per query (clamped to at least 1).
            allow_partial: Also run the word-overlap partial path.
            min_coverage: Minimum fraction of a normalized string's words that must hit.
            min_word_hits: Minimum absolute number of word hits for a partial match.
            coverage_weight: Multiplier applied to coverage in ``total_score``.

        Returns:
            DataFrame with columns ``input_string`` and ``hits`` (list of
            ``HIT_STRUCT_TYPE``), one row per entry of ``all_queries`` in the
            same order; queries without matches get an empty list.
        """
        top_k = max(1, int(top_k))

        # Flatten to (query, normalized string) rows, dropping duplicates and empties.
        rows = [(q, nstr) for q, nstrs in q_to_nstrs.items() for nstr in dict.fromkeys(nstrs) if nstr]

        if not rows:
            return pl.DataFrame({"input_string": list(all_queries), "hits": [[] for _ in all_queries]}).cast(
                {"input_string": pl.Utf8, "hits": pl.List(HIT_STRUCT_TYPE)}
            )

        views: dict[str, pl.DataFrame] = {
            "qmap": pl.DataFrame(rows, schema={"Q": pl.Utf8, "NSTR": pl.Utf8}, orient="row"),
            # QIDX records the caller's order: DuckDB does not guarantee that a
            # JOIN preserves input order, so the final SELECT sorts on it.
            "allq": pl.DataFrame(
                {"QIDX": list(range(len(all_queries))), "Q": list(all_queries)},
                schema={"QIDX": pl.Int64, "Q": pl.Utf8},
            ),
        }
        if allow_partial:
            # Word-level table for the partial path; explicit dtypes keep the view
            # typed even if no words survive splitting.
            word_rows = [(q, n, w) for q, n in rows for w in dict.fromkeys(n.split()) if w]
            views["qwords"] = pl.DataFrame(
                word_rows, schema={"Q": pl.Utf8, "NSTR": pl.Utf8, "NWD": pl.Utf8}, orient="row"
            )

        sql = self._build_lookup_sql(
            prefer_ttys,
            filter_sources,
            exclude_sources,
            top_k=top_k,
            allow_partial=allow_partial,
            min_coverage=min_coverage,
            min_word_hits=min_word_hits,
            coverage_weight=coverage_weight,
        )

        registered: list[str] = []
        try:
            for name, frame in views.items():
                self.con.register(name, frame.to_arrow())
                registered.append(name)
            out = self.con.execute(sql).pl()
        finally:
            # Always drop the temporary views, even if the query failed.
            self._unregister(*registered)

        return out.with_columns(pl.col("hits").fill_null([]).cast(pl.List(HIT_STRUCT_TYPE)))

    def _build_lookup_sql(
        self,
        prefer_ttys: list[str] | None,
        filter_sources: list[str] | None,
        exclude_sources: list[str] | None,
        *,
        top_k: int,
        allow_partial: bool,
        min_coverage: float,
        min_word_hits: int | None,
        coverage_weight: int,
    ) -> str:
        """Build the scoring/aggregation SQL that reads the ``qmap``, ``qwords`` and ``allq`` views."""
        # Preferred term types: LEFT JOIN an inline VALUES list; matched atoms get tty_bump = 1.
        tty_join = ""
        tty_bump_expr = "0"
        if prefer_ttys:
            tty_join = f"LEFT JOIN (VALUES {self._sql_values_rows(prefer_ttys)}) AS pt(tty) ON a.name_type = pt.tty"
            tty_bump_expr = "CASE WHEN pt.tty IS NULL THEN 0 ELSE 1 END"

        # Source filtering (include and exclude), applied to atoms (a) and the word index (nw).
        source_filter_exprs: list[str] = []
        nw_filter_clauses: list[str] = []
        if filter_sources:
            filt_vals = self._sql_in_list(filter_sources)
            source_filter_exprs.append(f"a.source IN ({filt_vals})")
            nw_filter_clauses.append(f"nw.source IN ({filt_vals})")
        if exclude_sources:
            excl_vals = self._sql_in_list(exclude_sources)
            source_filter_exprs.append(f"a.source NOT IN ({excl_vals})")
            nw_filter_clauses.append(f"nw.source NOT IN ({excl_vals})")
        nw_filter_clause = "".join(f" AND {clause}" for clause in nw_filter_clauses)
        combined_where = f"WHERE {' AND '.join(source_filter_exprs)}" if source_filter_exprs else ""

        # STT bump only when the database actually populates a.stt.
        stt_bump_expr = "CASE WHEN a.stt='PF' THEN 1 ELSE 0 END" if self._has_stt else "0"

        # Scoring constants
        min_hits_sql = str(int(min_word_hits)) if min_word_hits is not None else "0"
        cov_sql = f"{min_coverage:.6f}"

        # Exact path: normalized string index -> atoms; coverage is 1.0 by definition.
        exact_cte = f"""
        cand_exact AS (
            SELECT
                q.Q, q.NSTR,
                a.concept_id,
                a.identifier,
                a.str,
                a.source,
                a.name_type,
                a.ispref,
                a.rank,
                CASE WHEN a.ispref='Y' THEN 1 ELSE 0 END AS ispref_bump,
                {stt_bump_expr} AS stt_bump,
                {tty_bump_expr} AS tty_bump,
                1.0 AS coverage
            FROM qmap q
            JOIN {NS_TABLE} ns ON ns.nstr = q.NSTR
            JOIN {ATOMS_TABLE} a
                ON a.concept_id = ns.concept_id
                AND a.name_id = ns.name_id
            {tty_join}
            {combined_where}
        ),
        dedup_exact AS (
            SELECT *,
                ROW_NUMBER() OVER (
                    PARTITION BY Q, concept_id
                    ORDER BY rank DESC, ispref_bump DESC, stt_bump DESC, tty_bump DESC, concept_id
                ) AS rnc
            FROM cand_exact
        ),
        scored_exact AS (
            SELECT
                Q, NSTR, concept_id, identifier, str, source, name_type, ispref, rank,
                (rank*{RANK_MULTIPLIER} + ispref_bump*{ISPREF_WEIGHT} + stt_bump*{STT_WEIGHT}
                 + tty_bump*{TTY_WEIGHT} + {EXACT_BUMP} + ROUND(coverage * {coverage_weight}))::INTEGER AS total_score,
                TRUE AS is_exact
            FROM dedup_exact
            WHERE rnc = 1
        )
        """

        # Partial path: word overlap via the word index, thresholded by coverage.
        partial_cte = ""
        union_partial = ""
        if allow_partial:
            partial_cte = f"""
        ,
        qn AS (
            SELECT Q, NSTR, COUNT(DISTINCT NWD) AS need
            FROM qwords
            GROUP BY Q, NSTR
        ),
        hits AS (
            SELECT qw.Q, qw.NSTR, nw.string_id, nw.concept_id,
                   COUNT(DISTINCT qw.NWD) AS hits
            FROM qwords qw
            JOIN {NW_TABLE} nw ON nw.nwd = qw.NWD{nw_filter_clause}
            GROUP BY qw.Q, qw.NSTR, nw.string_id, nw.concept_id
        ),
        good AS (
            SELECT h.Q, h.NSTR, h.string_id, h.concept_id, h.hits, qn.need,
                   CAST(h.hits AS DOUBLE)/NULLIF(qn.need,0) AS coverage
            FROM hits h
            JOIN qn ON qn.Q = h.Q AND qn.NSTR = h.NSTR
            WHERE h.hits >= GREATEST({min_hits_sql}, CAST(CEIL(qn.need * {cov_sql}) AS INTEGER))
        ),
        cand_partial AS (
            SELECT
                g.Q, g.NSTR,
                a.concept_id,
                a.identifier,
                a.str,
                a.source,
                a.name_type,
                a.ispref,
                a.rank,
                CASE WHEN a.ispref='Y' THEN 1 ELSE 0 END AS ispref_bump,
                {stt_bump_expr} AS stt_bump,
                {tty_bump_expr} AS tty_bump,
                COALESCE(g.coverage, 0.0) AS coverage
            FROM good g
            JOIN {ATOMS_TABLE} a ON a.string_id = g.string_id
            {tty_join}
            {combined_where}
        ),
        dedup_partial AS (
            SELECT *,
                ROW_NUMBER() OVER (
                    PARTITION BY Q, concept_id
                    ORDER BY rank DESC, ispref_bump DESC, stt_bump DESC, tty_bump DESC, concept_id
                ) AS rnc
            FROM cand_partial
        ),
        scored_partial AS (
            SELECT
                Q, NSTR, concept_id, identifier, str, source, name_type, ispref, rank,
                (rank*{RANK_MULTIPLIER} + ispref_bump*{ISPREF_WEIGHT} + stt_bump*{STT_WEIGHT}
                 + tty_bump*{TTY_WEIGHT} + ROUND(coverage * {coverage_weight}))::INTEGER AS total_score,
                FALSE AS is_exact
            FROM dedup_partial
            WHERE rnc = 1
        )
        """
            union_partial = "UNION ALL SELECT * FROM scored_partial"

        # Merge paths, keep the best-scoring row per (query, concept), take top-k,
        # aggregate to a list of structs, and emit one row per input in caller order.
        return f"""
        WITH
        {exact_cte}
        {partial_cte}
        ,
        scored AS (
            SELECT * FROM scored_exact
            {union_partial}
        ),
        dedup_concept AS (
            SELECT *,
                ROW_NUMBER() OVER (PARTITION BY Q, concept_id ORDER BY total_score DESC) AS rcid
            FROM scored
        ),
        best AS (
            SELECT *,
                ROW_NUMBER() OVER (PARTITION BY Q ORDER BY total_score DESC, concept_id) AS rn
            FROM dedup_concept
            WHERE rcid = 1
        ),
        topk AS (
            SELECT * FROM best WHERE rn <= {top_k}
        ),
        agg AS (
            SELECT
                Q,
                LIST({{
                    'global_identifier': concept_id,
                    'identifier': identifier,
                    'nstr': NSTR,
                    'name': str,
                    'source': source,
                    'name_type': name_type,
                    'score': rank::BIGINT,
                    'total_score': total_score::BIGINT,
                    'match_type': CASE WHEN is_exact THEN 'exact' ELSE 'partial' END
                }} ORDER BY total_score DESC, concept_id) AS hits
            FROM topk
            GROUP BY Q
        )
        SELECT
            aq.Q AS input_string,
            agg.hits
        FROM allq aq
        LEFT JOIN agg ON agg.Q = aq.Q
        ORDER BY aq.QIDX;
        """

    def normalize(
        self,
        strings: Sequence[str],
        top_k: int = 25,
        prefer_ttys: list[str] | None = None,
        filter_sources: list[str] | None = None,
        exclude_sources: list[str] | None = None,
        allow_partial: bool = True,
        min_coverage: float = 0.6,
        min_word_hits: int | None = None,
        coverage_weight: int = 25,
    ) -> pl.DataFrame:
        """
        Normalize input strings to ranked concepts.

        Args:
            strings: Input strings to normalize
            top_k: Maximum number of results per query
            prefer_ttys: Term types to prefer (e.g., ["PT", "MH"]); None uses DEFAULT_PREFER_TTYS
            filter_sources: Restrict to these sources (include only)
            exclude_sources: Exclude these sources
            allow_partial: Enable word-overlap partial matching
            min_coverage: Minimum fraction of query words that must match
            min_word_hits: Minimum absolute word hits required
            coverage_weight: Weight for coverage in scoring

        Returns:
            DataFrame with columns: input_string, hits (list of match structs),
            one row per input string in input order.
        """
        queries = list(strings)
        if prefer_ttys is None:
            prefer_ttys = DEFAULT_PREFER_TTYS

        # Normalize each distinct input once; duplicate inputs share the entry.
        q_to_nstrs: dict[str, list[str]] = {s: list(lvg_normalize(s) or []) for s in dict.fromkeys(queries)}

        return self._lookup(
            q_to_nstrs=q_to_nstrs,
            all_queries=queries,
            prefer_ttys=prefer_ttys,
            filter_sources=filter_sources,
            exclude_sources=exclude_sources,
            top_k=top_k,
            allow_partial=allow_partial,
            min_coverage=min_coverage,
            min_word_hits=min_word_hits,
            coverage_weight=coverage_weight,
        )

    # ------------------------------------------------------------------
    # Concept details
    # ------------------------------------------------------------------

    def concept_info(
        self,
        concept_ids: Sequence[str],
        prefer_ttys: list[str] | None = None,
        prefer_def_sources: list[str] | None = None,
    ) -> dict[str, ConceptInfo]:
        """
        Get detailed information for concepts.

        Args:
            concept_ids: List of concept IDs
            prefer_ttys: Preferred term types
            prefer_def_sources: Preferred sources for definitions

        Returns:
            Dict mapping concept_id to ConceptInfo. Every requested ID is
            present; unknown IDs keep their None/empty defaults.
        """
        if not concept_ids:
            return {}

        id_list = list(dict.fromkeys(concept_ids))

        # Initialize results with defaults
        res: dict[str, ConceptInfo] = {
            cid: ConceptInfo(
                concept_id=cid,
                identifier=None,
                source=None,
                preferred_name=None,
                name_type=None,
                description=None,
                def_source=None,
                synonyms=[],
                semantic_types=[],
            )
            for cid in id_list
        }

        try:
            self.con.register("idmap", pl.DataFrame({"concept_id": id_list}).to_arrow())
            self._populate_concept_info(res, prefer_ttys, prefer_def_sources)
        finally:
            self._unregister("idmap")

        return res

    def _populate_concept_info(
        self,
        res: dict[str, ConceptInfo],
        prefer_ttys: list[str] | None,
        prefer_def_sources: list[str] | None,
    ) -> None:
        """Populate names/synonyms (and definitions/types when available) for concepts in the ``idmap`` view."""
        if prefer_ttys is None:
            prefer_ttys = DEFAULT_PREFER_TTYS

        # Build preference clauses
        tty_join = def_pref_join = ""
        tty_bump = def_pref_bump = "0"

        if prefer_ttys:
            tty_join = f"LEFT JOIN (VALUES {self._sql_values_rows(prefer_ttys)}) AS pt(tty) ON a.name_type = pt.tty"
            tty_bump = "CASE WHEN pt.tty IS NULL THEN 0 ELSE 1 END"

        if prefer_def_sources:
            def_vals = self._sql_values_rows(prefer_def_sources)
            def_pref_join = f"LEFT JOIN (VALUES {def_vals}) AS pds(sab) ON d.source = pds.sab"
            def_pref_bump = "CASE WHEN pds.sab IS NULL THEN 0 ELSE 1 END"

        stt_bump = "CASE WHEN a.stt='PF' THEN 1 ELSE 0 END" if self._has_stt else "0"

        # Preferred name = top atom by (tty, ispref, stt, rank, str); synonyms are the
        # remaining case-insensitively distinct strings in the same order.
        sql = f"""
        WITH
        name_cand AS (
            SELECT
                c.concept_id, a.str, a.source AS sab,
                a.name_type AS tty, a.ispref, a.stt, a.rank,
                CASE WHEN a.ispref='Y' THEN 1 ELSE 0 END AS ispref_bump,
                {stt_bump} AS stt_bump,
                {tty_bump} AS tty_bump,
                a.identifier
            FROM idmap c
            JOIN {ATOMS_TABLE} a ON a.concept_id = c.concept_id
            {tty_join}
        ),
        name_best AS (
            SELECT *,
                ROW_NUMBER() OVER (
                    PARTITION BY concept_id
                    ORDER BY tty_bump DESC, ispref_bump DESC, stt_bump DESC, rank DESC, str
                ) AS rn
            FROM name_cand
        ),
        chosen AS (
            SELECT concept_id, str AS preferred_name, sab AS name_sab, tty AS name_tty, identifier
            FROM name_best WHERE rn=1
        ),

        syn_cand AS (
            SELECT
                c.concept_id, a.str, a.source AS sab,
                a.name_type AS tty, a.ispref, a.stt, a.rank,
                CASE WHEN a.ispref='Y' THEN 1 ELSE 0 END AS ispref_bump,
                {stt_bump} AS stt_bump,
                {tty_bump} AS tty_bump
            FROM idmap c
            JOIN {ATOMS_TABLE} a ON a.concept_id = c.concept_id
            {tty_join}
        ),
        syn_rank AS (
            SELECT sc.*,
                ROW_NUMBER() OVER (
                    PARTITION BY sc.concept_id, LOWER(sc.str)
                    ORDER BY sc.tty_bump DESC, sc.ispref_bump DESC,
                             sc.stt_bump DESC, sc.rank DESC, sc.str
                ) AS rstr
            FROM syn_cand sc
        ),
        syn_best_uniq AS (
            SELECT s.concept_id, s.str, s.tty_bump, s.ispref_bump, s.stt_bump, s.rank
            FROM syn_rank s
            LEFT JOIN chosen ch ON ch.concept_id = s.concept_id
            WHERE s.rstr = 1 AND NOT (s.str = ch.preferred_name)
        ),
        syn_agg AS (
            SELECT concept_id,
                ARRAY_AGG(
                    str ORDER BY
                        tty_bump DESC, ispref_bump DESC, stt_bump DESC, rank DESC, str
                ) AS synonyms
            FROM syn_best_uniq
            GROUP BY concept_id
        )

        SELECT c.concept_id,
               ch.preferred_name, ch.name_sab, ch.name_tty, ch.identifier,
               sa.synonyms
        FROM idmap c
        LEFT JOIN chosen ch ON ch.concept_id = c.concept_id
        LEFT JOIN syn_agg sa ON sa.concept_id = c.concept_id
        ORDER BY c.concept_id;
        """

        out = self.con.execute(sql).pl()

        for row in out.iter_rows(named=True):
            ent = res[row["concept_id"]]

            if row["preferred_name"] is not None:
                ent.preferred_name = row["preferred_name"]
                ent.source = row["name_sab"]
                ent.name_type = row["name_tty"]
                ent.identifier = row["identifier"]

            synonyms = row.get("synonyms")
            if isinstance(synonyms, list):
                ent.synonyms = list(dict.fromkeys(synonyms))

        # Definitions (if available)
        if self._has_defs:
            self._populate_definitions(res, def_pref_join, def_pref_bump)

        # Semantic types (if available)
        if self._has_types:
            self._populate_semantic_types(res)

    def _populate_definitions(
        self,
        res: dict[str, ConceptInfo],
        def_pref_join: str,
        def_pref_bump: str,
    ) -> None:
        """Set the best definition per concept in ``idmap``: preferred source first, then longest text."""
        sql = f"""
        WITH
        def_cand AS (
            SELECT
                d.concept_id, d.source AS sab, d.def_text,
                {def_pref_bump} AS def_pref_bump,
                length(d.def_text) AS def_len
            FROM {DEFS_TABLE} d
            JOIN idmap c ON c.concept_id = d.concept_id
            {def_pref_join}
        ),
        def_best AS (
            SELECT *,
                ROW_NUMBER() OVER (
                    PARTITION BY concept_id
                    ORDER BY def_pref_bump DESC, def_len DESC
                ) AS drn
            FROM def_cand
        )
        SELECT concept_id, def_text, sab AS def_sab
        FROM def_best
        WHERE drn = 1;
        """

        out = self.con.execute(sql).pl()
        for row in out.iter_rows(named=True):
            cid = row["concept_id"]
            if cid in res and row["def_text"]:
                res[cid].description = row["def_text"]
                res[cid].def_source = row["def_sab"]

    def _query_semantic_types(self) -> pl.DataFrame:
        """Fetch distinct (concept_id, type_id, type_name, type_tree) rows for concepts in ``idmap``."""
        sql = f"""
        SELECT DISTINCT t.concept_id, t.type_id, t.type_name, t.type_tree
        FROM {TYPES_TABLE} t
        JOIN idmap c ON c.concept_id = t.concept_id
        ORDER BY t.concept_id, t.type_tree, t.type_id;
        """
        return self.con.execute(sql).pl()

    def _populate_semantic_types(self, res: dict[str, ConceptInfo]) -> None:
        """Append semantic types (skipping rows with empty id/name) for concepts in ``idmap``."""
        for row in self._query_semantic_types().iter_rows(named=True):
            cid = row["concept_id"]
            if cid in res and row["type_id"] and row["type_name"]:
                res[cid].semantic_types.append(SemanticType(type_id=row["type_id"], type_name=row["type_name"]))

    def concept_semantic_types(self, concept_ids: Sequence[str]) -> dict[str, list[dict[str, str]]]:
        """
        Get semantic types for concepts.

        Returns dict mapping concept_id to list of {"tui": ..., "sty": ...};
        every requested ID is present (empty list when no types are known).
        """
        if not self._has_types or not concept_ids:
            return {cid: [] for cid in concept_ids}

        id_list = list(dict.fromkeys(concept_ids))
        try:
            self.con.register("idmap", pl.DataFrame({"concept_id": id_list}).to_arrow())
            out = self._query_semantic_types()
        finally:
            self._unregister("idmap")

        res: dict[str, list[dict[str, str]]] = {cid: [] for cid in id_list}
        for row in out.iter_rows(named=True):
            res[row["concept_id"]].append({"tui": row["type_id"], "sty": row["type_name"]})

        return res

    # ------------------------------------------------------------------
    # Hierarchy
    # ------------------------------------------------------------------

    def get_narrower_concepts(
        self,
        concept_id: str,
        max_depth: int | None = 10,
        filter_sources: list[str] | None = None,
    ) -> list[str]:
        """
        Get all narrower (descendant) concept IDs using recursive traversal.

        Uses the hierarchy edges to walk down the tree/DAG from the given concept.
        The walk uses ``UNION`` (set semantics), so already-visited nodes are not
        re-expanded: an unbounded walk terminates even if the edges contain a
        cycle, and diamond-shaped DAGs are not enumerated path by path.

        Args:
            concept_id: Starting concept ID (broader term)
            max_depth: Maximum depth to traverse (1 = direct children only, None = all descendants)
            filter_sources: Only follow edges from these sources (e.g., ["SNOMEDCT_US"])

        Returns:
            List of distinct descendant concept IDs (excludes the starting
            concept); order is unspecified.
        """
        if not self._has_edges:
            return []

        # Build source filter clause (values escaped as SQL literals).
        source_filter = ""
        if filter_sources:
            source_filter = f" AND e.source IN ({self._sql_in_list(filter_sources)})"

        if max_depth is None:
            # Unbounded: recurse on the node alone so UNION stops at visited nodes.
            query = f"""
            WITH RECURSIVE walk(concept_id) AS (
                SELECT $1::VARCHAR

                UNION

                SELECT e.child_id
                FROM walk w
                JOIN {EDGES_TABLE} e ON e.parent_id = w.concept_id
                WHERE TRUE{source_filter}
            )
            SELECT concept_id
            FROM walk
            WHERE concept_id != $1
            """
            params: list[object] = [concept_id]
        else:
            # Bounded: depth limits recursion; UNION collapses duplicate (node, depth) paths.
            query = f"""
            WITH RECURSIVE walk(concept_id, depth) AS (
                SELECT $1::VARCHAR, 0

                UNION

                SELECT e.child_id, w.depth + 1
                FROM walk w
                JOIN {EDGES_TABLE} e ON e.parent_id = w.concept_id
                WHERE w.depth < $2{source_filter}
            )
            SELECT DISTINCT concept_id
            FROM walk
            WHERE concept_id != $1
            """
            params = [concept_id, int(max_depth)]

        result = self.con.execute(query, params).fetchall()
        return [r[0] for r in result]

    def close(self) -> None:
        """Close the database connection."""
        self.con.close()