norm_toolkit 1.0.2__tar.gz → 1.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: norm_toolkit
3
- Version: 1.0.2
3
+ Version: 1.2.0
4
4
  Summary: Toolkit to normalize text to UMLS / ontologies
5
5
  Author: Haydn Jones
6
6
  Author-email: Haydn Jones <haydnjonest@gmail.com>
@@ -10,6 +10,7 @@ Requires-Dist: lvg-norm>=1.1.0
10
10
  Requires-Dist: polars[rt64]>=1.36.1
11
11
  Requires-Dist: pyarrow>=20.0.0
12
12
  Requires-Dist: pydantic>=2.12.5
13
+ Requires-Dist: sqlalchemy>=2.0.0
13
14
  Requires-Dist: tqdm>=4.67.1
14
15
  Requires-Python: >=3.12
15
16
  Description-Content-Type: text/markdown
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "norm_toolkit"
3
- version = "1.0.2"
3
+ version = "1.2.0"
4
4
  description = "Toolkit to normalize text to UMLS / ontologies"
5
5
  readme = "README.md"
6
6
  authors = [{ name = "Haydn Jones", email = "haydnjonest@gmail.com" }]
@@ -12,6 +12,7 @@ dependencies = [
12
12
  "polars[rt64]>=1.36.1",
13
13
  "pyarrow>=20.0.0",
14
14
  "pydantic>=2.12.5",
15
+ "sqlalchemy>=2.0.0",
15
16
  "tqdm>=4.67.1",
16
17
  ]
17
18
 
@@ -324,6 +324,7 @@ LEFT JOIN agg ON agg.Q = aq.Q;
324
324
  def normalize(
325
325
  self,
326
326
  strings: Sequence[str],
327
+ synonyms: Mapping[str, Sequence[str]] | None = None,
327
328
  top_k: int = 25,
328
329
  prefer_ttys: list[str] | None = None,
329
330
  filter_sources: list[str] | None = None,
@@ -338,6 +339,10 @@ LEFT JOIN agg ON agg.Q = aq.Q;
338
339
 
339
340
  Args:
340
341
  strings: Input strings to normalize
342
+ synonyms: Optional mapping of input strings to their synonyms.
343
+ Synonyms are normalized and used alongside the main string
344
+ to improve matching. Results are still keyed by the original
345
+ input string.
341
346
  top_k: Maximum number of results per query
342
347
  prefer_ttys: Term types to prefer (e.g., ["PT", "MH"])
343
348
  filter_sources: Restrict to these sources (include only)
@@ -348,7 +353,8 @@ LEFT JOIN agg ON agg.Q = aq.Q;
348
353
  coverage_weight: Weight for coverage in scoring
349
354
 
350
355
  Returns:
351
- DataFrame with columns: input_string, hits (list of match structs)
356
+ DataFrame with columns: input_string, hits (list of match structs),
357
+ and synonyms (list of strings) if synonyms were provided.
352
358
  """
353
359
  # Apply defaults
354
360
  if prefer_ttys is None:
@@ -358,9 +364,14 @@ LEFT JOIN agg ON agg.Q = aq.Q;
358
364
  q_to_nstrs: dict[str, list[str]] = {}
359
365
  for s in strings:
360
366
  nstrs = list(lvg_normalize(s) or [])
367
+ # Add normalized forms of synonyms
368
+ if synonyms and s in synonyms:
369
+ for syn in synonyms[s]:
370
+ syn_nstrs = list(lvg_normalize(syn) or [])
371
+ nstrs.extend(syn_nstrs)
361
372
  q_to_nstrs[s] = nstrs
362
373
 
363
- return self._lookup(
374
+ result = self._lookup(
364
375
  q_to_nstrs=q_to_nstrs,
365
376
  all_queries=list(strings),
366
377
  prefer_ttys=prefer_ttys,
@@ -373,6 +384,13 @@ LEFT JOIN agg ON agg.Q = aq.Q;
373
384
  coverage_weight=coverage_weight,
374
385
  )
375
386
 
387
+ # Add synonyms column if synonyms were provided
388
+ if synonyms:
389
+ syn_list = [list(synonyms.get(s, [])) for s in strings]
390
+ result = result.with_columns(pl.Series("synonyms", syn_list))
391
+
392
+ return result
393
+
376
394
  def concept_info(
377
395
  self,
378
396
  concept_ids: Sequence[str],
@@ -7,18 +7,17 @@ built by build_umls_duckdb, build_ontology_duckdb, or build_merged_duckdb.
7
7
 
8
8
  from __future__ import annotations
9
9
 
10
- import asyncio
11
10
  import json
12
11
  from collections.abc import Mapping, Sequence
13
12
  from typing import Any
14
13
 
15
- import asyncpg
16
14
  import polars as pl
17
15
  from lvg_norm import lvg_normalize
16
+ from sqlalchemy import text
17
+ from sqlalchemy.ext.asyncio import AsyncEngine
18
18
 
19
19
  from norm_toolkit.constants import (
20
20
  ATOMS_TABLE,
21
- CONCEPTS_TABLE,
22
21
  DEFAULT_PREFER_TTYS,
23
22
  DEFS_TABLE,
24
23
  EDGES_TABLE,
@@ -37,7 +36,7 @@ from norm_toolkit.models import ConceptInfo, SemanticType
37
36
 
38
37
  class PostgresNormalizer:
39
38
  """
40
- Async normalizer using PostgreSQL via asyncpg.
39
+ Async normalizer using PostgreSQL via SQLAlchemy.
41
40
 
42
41
  Optimized for small batch processing (1-5 strings at a time).
43
42
  Uses VALUES clauses instead of temp tables for efficiency with small batches.
@@ -45,15 +44,15 @@ class PostgresNormalizer:
45
44
 
46
45
  def __init__(
47
46
  self,
48
- pool: asyncpg.Pool,
47
+ engine: AsyncEngine,
49
48
  schema: str = "public",
50
49
  owned_resource: Any | None = None,
51
50
  ) -> None:
52
51
  """
53
- Initialize the normalizer with an external connection pool.
52
+ Initialize the normalizer with an SQLAlchemy AsyncEngine.
54
53
 
55
54
  Args:
56
- pool: asyncpg connection pool (caller manages lifecycle)
55
+ engine: SQLAlchemy AsyncEngine (caller manages lifecycle)
57
56
  schema: PostgreSQL schema where tables are located (default: "public")
58
57
  owned_resource: Optional resource with async close() method to clean up
59
58
  when this normalizer is closed (e.g., AlloyDB AsyncConnector)
@@ -62,9 +61,8 @@ class PostgresNormalizer:
62
61
  After creating the normalizer, call `await normalizer.initialize()`
63
62
  to detect database capabilities before using other methods.
64
63
  """
65
- self._pool = pool
64
+ self._engine = engine
66
65
  self._schema = schema
67
- self._loop: asyncio.AbstractEventLoop | None = None
68
66
  self._owned_resource = owned_resource
69
67
  self._has_types = False
70
68
  self._has_defs = False
@@ -77,48 +75,14 @@ class PostgresNormalizer:
77
75
  self._ns_table = f"{prefix}{NS_TABLE}"
78
76
  self._nw_table = f"{prefix}{NW_TABLE}"
79
77
  self._atoms_table = f"{prefix}{ATOMS_TABLE}"
80
- self._concepts_table = f"{prefix}{CONCEPTS_TABLE}"
81
78
  self._types_table = f"{prefix}{TYPES_TABLE}"
82
79
  self._defs_table = f"{prefix}{DEFS_TABLE}"
83
80
  self._edges_table = f"{prefix}{EDGES_TABLE}"
84
81
 
85
- @classmethod
86
- def create_sync(cls, dsn: str, schema: str = "public", min_size: int = 1, max_size: int = 10) -> PostgresNormalizer:
87
- """
88
- Create a normalizer synchronously with its own event loop.
89
-
90
- Use this factory method for sync-only usage. The normalizer will manage
91
- its own event loop and pool, allowing you to call normalize_sync().
92
-
93
- Args:
94
- dsn: PostgreSQL connection string (e.g., "postgresql://user:pass@host:5432/db")
95
- schema: PostgreSQL schema where tables are located (default: "public")
96
- min_size: Minimum pool connections
97
- max_size: Maximum pool connections
98
-
99
- Example:
100
- >>> normalizer = PostgresNormalizer.create_sync("postgresql://...")
101
- >>> result = normalizer.normalize_sync(["diabetes"])
102
- >>> normalizer.close_sync()
103
- """
104
- loop = asyncio.new_event_loop()
105
-
106
- async def _create():
107
- pool = await asyncpg.create_pool(dsn, min_size=min_size, max_size=max_size)
108
- return pool
109
-
110
- pool = loop.run_until_complete(_create())
111
- instance = cls(pool, schema=schema)
112
- instance._loop = loop
113
- loop.run_until_complete(instance.initialize())
114
- return instance
115
-
116
- async def initialize(self) -> None:
117
- """
118
- Detect database capabilities.
119
-
120
- Must be called after __init__ before using normalize/concept_info methods.
121
- """
82
+ async def _ensure_initialized(self) -> None:
83
+ """Lazily initialize on first use."""
84
+ if self._initialized:
85
+ return
122
86
  self._has_types = await self._table_has_rows(self._types_table)
123
87
  self._has_defs = await self._table_has_rows(self._defs_table)
124
88
  self._has_edges = await self._table_has_rows(self._edges_table)
@@ -128,24 +92,25 @@ class PostgresNormalizer:
128
92
  async def _table_has_rows(self, table: str) -> bool:
129
93
  """Check if a table exists and has rows."""
130
94
  try:
131
- async with self._pool.acquire() as con:
132
- result = await con.fetchval(f"SELECT 1 FROM {table} LIMIT 1")
133
- return result is not None
95
+ async with self._engine.connect() as conn:
96
+ result = await conn.execute(text(f"SELECT 1 FROM {table} LIMIT 1"))
97
+ return result.scalar() is not None
134
98
  except Exception:
135
99
  return False
136
100
 
137
101
  async def _column_has_values(self, table: str, column: str) -> bool:
138
102
  """Check if a column has any non-null values."""
139
103
  try:
140
- async with self._pool.acquire() as con:
141
- result = await con.fetchval(f"SELECT 1 FROM {table} WHERE {column} IS NOT NULL LIMIT 1")
142
- return result is not None
104
+ async with self._engine.connect() as conn:
105
+ result = await conn.execute(text(f"SELECT 1 FROM {table} WHERE {column} IS NOT NULL LIMIT 1"))
106
+ return result.scalar() is not None
143
107
  except Exception:
144
108
  return False
145
109
 
146
110
  async def normalize(
147
111
  self,
148
112
  strings: Sequence[str],
113
+ synonyms: Mapping[str, Sequence[str]] | None = None,
149
114
  top_k: int = 25,
150
115
  prefer_ttys: list[str] | None = None,
151
116
  filter_sources: list[str] | None = None,
@@ -160,6 +125,10 @@ class PostgresNormalizer:
160
125
 
161
126
  Args:
162
127
  strings: Input strings to normalize
128
+ synonyms: Optional mapping of input strings to their synonyms.
129
+ Synonyms are normalized and used alongside the main string
130
+ to improve matching. Results are still keyed by the original
131
+ input string.
163
132
  top_k: Maximum number of results per query
164
133
  prefer_ttys: Term types to prefer (e.g., ["PT", "MH"])
165
134
  filter_sources: Restrict to these sources (include only)
@@ -170,8 +139,11 @@ class PostgresNormalizer:
170
139
  coverage_weight: Weight for coverage in scoring
171
140
 
172
141
  Returns:
173
- DataFrame with columns: input_string, hits (list of match structs)
142
+ DataFrame with columns: input_string, hits (list of match structs),
143
+ and synonyms (list of strings) if synonyms were provided.
174
144
  """
145
+ await self._ensure_initialized()
146
+
175
147
  if prefer_ttys is None:
176
148
  prefer_ttys = DEFAULT_PREFER_TTYS
177
149
 
@@ -179,9 +151,14 @@ class PostgresNormalizer:
179
151
  q_to_nstrs: dict[str, list[str]] = {}
180
152
  for s in strings:
181
153
  nstrs = list(lvg_normalize(s) or [])
154
+ # Add normalized forms of synonyms
155
+ if synonyms and s in synonyms:
156
+ for syn in synonyms[s]:
157
+ syn_nstrs = list(lvg_normalize(syn) or [])
158
+ nstrs.extend(syn_nstrs)
182
159
  q_to_nstrs[s] = nstrs
183
160
 
184
- return await self._lookup(
161
+ result = await self._lookup(
185
162
  q_to_nstrs=q_to_nstrs,
186
163
  all_queries=list(strings),
187
164
  prefer_ttys=prefer_ttys,
@@ -194,6 +171,13 @@ class PostgresNormalizer:
194
171
  coverage_weight=coverage_weight,
195
172
  )
196
173
 
174
+ # Add synonyms column if synonyms were provided
175
+ if synonyms:
176
+ syn_list = [list(synonyms.get(s, [])) for s in strings]
177
+ result = result.with_columns(pl.Series("synonyms", syn_list))
178
+
179
+ return result
180
+
197
181
  async def _lookup(
198
182
  self,
199
183
  q_to_nstrs: Mapping[str, Sequence[str]],
@@ -223,15 +207,18 @@ class PostgresNormalizer:
223
207
  {"hits": pl.List(HIT_STRUCT_TYPE)}
224
208
  )
225
209
 
226
- # Build parameters and VALUES clauses
227
- params: list[str] = []
210
+ # Build parameters and VALUES clauses using named parameters
211
+ params: dict[str, Any] = {}
212
+ param_idx = 0
228
213
 
229
214
  # qmap VALUES clause
230
215
  qmap_placeholders = []
231
216
  for q, nstr in qmap_rows:
232
- idx = len(params)
233
- params.extend([q, nstr])
234
- qmap_placeholders.append(f"(${idx + 1}, ${idx + 2})")
217
+ q_key, nstr_key = f"p{param_idx}", f"p{param_idx + 1}"
218
+ params[q_key] = q
219
+ params[nstr_key] = nstr
220
+ qmap_placeholders.append(f"(:{q_key}, :{nstr_key})")
221
+ param_idx += 2
235
222
  qmap_values = ", ".join(qmap_placeholders)
236
223
 
237
224
  # qwords VALUES clause (for partial path)
@@ -240,36 +227,58 @@ class PostgresNormalizer:
240
227
  qwords_rows = [(q, n, w) for q, n in qmap_rows for w in dict.fromkeys(n.split()) if w]
241
228
  qwords_placeholders = []
242
229
  for q, nstr, nwd in qwords_rows:
243
- idx = len(params)
244
- params.extend([q, nstr, nwd])
245
- qwords_placeholders.append(f"(${idx + 1}, ${idx + 2}, ${idx + 3})")
230
+ q_key, nstr_key, nwd_key = f"p{param_idx}", f"p{param_idx + 1}", f"p{param_idx + 2}"
231
+ params[q_key] = q
232
+ params[nstr_key] = nstr
233
+ params[nwd_key] = nwd
234
+ qwords_placeholders.append(f"(:{q_key}, :{nstr_key}, :{nwd_key})")
235
+ param_idx += 3
246
236
  qwords_values = ", ".join(qwords_placeholders)
247
237
 
248
238
  # allq VALUES clause (preserve order)
249
239
  allq_placeholders = []
250
240
  for q in all_queries:
251
- idx = len(params)
252
- params.append(q)
253
- allq_placeholders.append(f"(${idx + 1})")
241
+ q_key = f"p{param_idx}"
242
+ params[q_key] = q
243
+ allq_placeholders.append(f"(:{q_key})")
244
+ param_idx += 1
254
245
  allq_values = ", ".join(allq_placeholders)
255
246
 
256
- # Build preference clauses
247
+ # Build preference clauses (parameterized to prevent SQL injection)
257
248
  tty_join = ""
258
249
  tty_bump_expr = "0"
259
250
  if prefer_ttys:
260
- tty_vals = ", ".join(f"('{t}')" for t in prefer_ttys)
251
+ tty_placeholders = []
252
+ for tty in prefer_ttys:
253
+ key = f"p{param_idx}"
254
+ params[key] = tty
255
+ tty_placeholders.append(f"(:{key})")
256
+ param_idx += 1
257
+ tty_vals = ", ".join(tty_placeholders)
261
258
  tty_join = f"LEFT JOIN (VALUES {tty_vals}) AS pt(tty) ON a.name_type = pt.tty"
262
259
  tty_bump_expr = "CASE WHEN pt.tty IS NULL THEN 0 ELSE 1 END"
263
260
 
264
- # Source filtering
261
+ # Source filtering (parameterized to prevent SQL injection)
265
262
  source_filter_exprs = []
266
263
  nw_filter_clauses = []
267
264
  if filter_sources:
268
- filt_vals = ", ".join(f"'{src}'" for src in filter_sources)
265
+ filt_placeholders = []
266
+ for src in filter_sources:
267
+ key = f"p{param_idx}"
268
+ params[key] = src
269
+ filt_placeholders.append(f":{key}")
270
+ param_idx += 1
271
+ filt_vals = ", ".join(filt_placeholders)
269
272
  source_filter_exprs.append(f"a.source IN ({filt_vals})")
270
273
  nw_filter_clauses.append(f"nw.source IN ({filt_vals})")
271
274
  if exclude_sources:
272
- excl_vals = ", ".join(f"'{src}'" for src in exclude_sources)
275
+ excl_placeholders = []
276
+ for src in exclude_sources:
277
+ key = f"p{param_idx}"
278
+ params[key] = src
279
+ excl_placeholders.append(f":{key}")
280
+ param_idx += 1
281
+ excl_vals = ", ".join(excl_placeholders)
273
282
  source_filter_exprs.append(f"a.source NOT IN ({excl_vals})")
274
283
  nw_filter_clauses.append(f"nw.source NOT IN ({excl_vals})")
275
284
  nw_filter_clause = (" AND " + " AND ".join(nw_filter_clauses)) if nw_filter_clauses else ""
@@ -447,15 +456,22 @@ FROM allq aq
447
456
  LEFT JOIN agg ON agg.Q = aq.Q;
448
457
  """
449
458
 
450
- async with self._pool.acquire() as con:
451
- rows = await con.fetch(sql, *params)
459
+ async with self._engine.connect() as conn:
460
+ result = await conn.execute(text(sql), params)
461
+ rows = result.mappings().all()
452
462
 
453
- # Parse JSON results into Polars DataFrame
463
+ # Parse results into Polars DataFrame
464
+ # Note: asyncpg auto-deserializes JSON, so hits may already be a list
454
465
  data = []
455
466
  for row in rows:
456
467
  input_string = row["input_string"]
457
- hits_json = row["hits"]
458
- hits = json.loads(hits_json) if hits_json else []
468
+ hits_raw = row["hits"]
469
+ if hits_raw is None:
470
+ hits = []
471
+ elif isinstance(hits_raw, list):
472
+ hits = hits_raw # Already deserialized by asyncpg
473
+ else:
474
+ hits = json.loads(hits_raw) # String, needs parsing
459
475
  data.append({"input_string": input_string, "hits": hits})
460
476
 
461
477
  return pl.DataFrame(data).cast({"hits": pl.List(HIT_STRUCT_TYPE)})
@@ -477,6 +493,8 @@ LEFT JOIN agg ON agg.Q = aq.Q;
477
493
  Returns:
478
494
  Dict mapping concept_id to ConceptInfo
479
495
  """
496
+ await self._ensure_initialized()
497
+
480
498
  if not concept_ids:
481
499
  return {}
482
500
 
@@ -500,20 +518,28 @@ LEFT JOIN agg ON agg.Q = aq.Q;
500
518
  semantic_types=[],
501
519
  )
502
520
 
503
- # Build idmap VALUES clause
504
- params: list[str] = []
521
+ # Build idmap VALUES clause using named parameters
522
+ params: dict[str, Any] = {}
523
+ param_idx = 0
505
524
  idmap_placeholders = []
506
525
  for cid in id_list:
507
- idx = len(params)
508
- params.append(cid)
509
- idmap_placeholders.append(f"(${idx + 1})")
526
+ key = f"p{param_idx}"
527
+ params[key] = cid
528
+ idmap_placeholders.append(f"(:{key})")
529
+ param_idx += 1
510
530
  idmap_values = ", ".join(idmap_placeholders)
511
531
 
512
532
  # Build preference clauses
513
533
  tty_join = ""
514
534
  tty_bump = "0"
515
535
  if prefer_ttys:
516
- tty_vals = ", ".join(f"('{t}')" for t in prefer_ttys)
536
+ tty_placeholders = []
537
+ for tty in prefer_ttys:
538
+ key = f"p{param_idx}"
539
+ params[key] = tty
540
+ tty_placeholders.append(f"(:{key})")
541
+ param_idx += 1
542
+ tty_vals = ", ".join(tty_placeholders)
517
543
  tty_join = f"LEFT JOIN (VALUES {tty_vals}) AS pt(tty) ON a.name_type = pt.tty"
518
544
  tty_bump = "CASE WHEN pt.tty IS NULL THEN 0 ELSE 1 END"
519
545
 
@@ -591,8 +617,9 @@ LEFT JOIN syn_agg sa ON sa.concept_id = c.concept_id
591
617
  ORDER BY c.concept_id;
592
618
  """
593
619
 
594
- async with self._pool.acquire() as con:
595
- rows = await con.fetch(sql, *params)
620
+ async with self._engine.connect() as conn:
621
+ result = await conn.execute(text(sql), params)
622
+ rows = result.mappings().all()
596
623
 
597
624
  for row in rows:
598
625
  cid = row["concept_id"]
@@ -625,18 +652,26 @@ ORDER BY c.concept_id;
625
652
  prefer_def_sources: list[str] | None,
626
653
  ) -> None:
627
654
  """Populate definitions for concepts."""
628
- params: list[str] = []
655
+ params: dict[str, Any] = {}
656
+ param_idx = 0
629
657
  idmap_placeholders = []
630
658
  for cid in id_list:
631
- idx = len(params)
632
- params.append(cid)
633
- idmap_placeholders.append(f"(${idx + 1})")
659
+ key = f"p{param_idx}"
660
+ params[key] = cid
661
+ idmap_placeholders.append(f"(:{key})")
662
+ param_idx += 1
634
663
  idmap_values = ", ".join(idmap_placeholders)
635
664
 
636
665
  def_pref_join = ""
637
666
  def_pref_bump = "0"
638
667
  if prefer_def_sources:
639
- def_vals = ", ".join(f"('{src}')" for src in prefer_def_sources)
668
+ def_placeholders = []
669
+ for src in prefer_def_sources:
670
+ key = f"p{param_idx}"
671
+ params[key] = src
672
+ def_placeholders.append(f"(:{key})")
673
+ param_idx += 1
674
+ def_vals = ", ".join(def_placeholders)
640
675
  def_pref_join = f"LEFT JOIN (VALUES {def_vals}) AS pds(sab) ON d.source = pds.sab"
641
676
  def_pref_bump = "CASE WHEN pds.sab IS NULL THEN 0 ELSE 1 END"
642
677
 
@@ -665,8 +700,9 @@ FROM def_best
665
700
  WHERE drn = 1;
666
701
  """
667
702
 
668
- async with self._pool.acquire() as con:
669
- rows = await con.fetch(sql, *params)
703
+ async with self._engine.connect() as conn:
704
+ result = await conn.execute(text(sql), params)
705
+ rows = result.mappings().all()
670
706
 
671
707
  for row in rows:
672
708
  cid = row["concept_id"]
@@ -680,12 +716,12 @@ WHERE drn = 1;
680
716
  id_list: list[str],
681
717
  ) -> None:
682
718
  """Populate semantic types for concepts."""
683
- params: list[str] = []
719
+ params: dict[str, Any] = {}
684
720
  idmap_placeholders = []
685
- for cid in id_list:
686
- idx = len(params)
687
- params.append(cid)
688
- idmap_placeholders.append(f"(${idx + 1})")
721
+ for i, cid in enumerate(id_list):
722
+ key = f"p{i}"
723
+ params[key] = cid
724
+ idmap_placeholders.append(f"(:{key})")
689
725
  idmap_values = ", ".join(idmap_placeholders)
690
726
 
691
727
  sql = f"""
@@ -696,8 +732,9 @@ JOIN idmap c ON c.concept_id = t.concept_id
696
732
  ORDER BY t.concept_id, t.type_tree, t.type_id;
697
733
  """
698
734
 
699
- async with self._pool.acquire() as con:
700
- rows = await con.fetch(sql, *params)
735
+ async with self._engine.connect() as conn:
736
+ result = await conn.execute(text(sql), params)
737
+ rows = result.mappings().all()
701
738
 
702
739
  for row in rows:
703
740
  cid = row["concept_id"]
@@ -713,17 +750,19 @@ ORDER BY t.concept_id, t.type_tree, t.type_id;
713
750
 
714
751
  Returns dict mapping concept_id to list of {"tui": ..., "sty": ...}
715
752
  """
753
+ await self._ensure_initialized()
754
+
716
755
  if not self._has_types or not concept_ids:
717
756
  return {cid: [] for cid in concept_ids}
718
757
 
719
758
  id_list = list(dict.fromkeys(concept_ids))
720
759
 
721
- params: list[str] = []
760
+ params: dict[str, Any] = {}
722
761
  idmap_placeholders = []
723
- for cid in id_list:
724
- idx = len(params)
725
- params.append(cid)
726
- idmap_placeholders.append(f"(${idx + 1})")
762
+ for i, cid in enumerate(id_list):
763
+ key = f"p{i}"
764
+ params[key] = cid
765
+ idmap_placeholders.append(f"(:{key})")
727
766
  idmap_values = ", ".join(idmap_placeholders)
728
767
 
729
768
  sql = f"""
@@ -734,8 +773,9 @@ JOIN idmap c ON c.concept_id = t.concept_id
734
773
  ORDER BY t.concept_id, t.type_tree, t.type_id;
735
774
  """
736
775
 
737
- async with self._pool.acquire() as con:
738
- rows = await con.fetch(sql, *params)
776
+ async with self._engine.connect() as conn:
777
+ result = await conn.execute(text(sql), params)
778
+ rows = result.mappings().all()
739
779
 
740
780
  res: dict[str, list[dict[str, str]]] = {cid: [] for cid in id_list}
741
781
  for row in rows:
@@ -762,90 +802,57 @@ ORDER BY t.concept_id, t.type_tree, t.type_id;
762
802
  Returns:
763
803
  List of descendant concept IDs (excludes the starting concept)
764
804
  """
805
+ await self._ensure_initialized()
806
+
765
807
  if not self._has_edges:
766
808
  return []
767
809
 
810
+ params: dict[str, Any] = {"concept_id": concept_id, "max_depth": max_depth}
811
+
768
812
  # Build source filter clause
769
813
  source_filter = ""
770
814
  if filter_sources:
771
- sources_sql = ", ".join(f"'{src}'" for src in filter_sources)
815
+ src_placeholders = []
816
+ for i, src in enumerate(filter_sources):
817
+ key = f"src{i}"
818
+ params[key] = src
819
+ src_placeholders.append(f":{key}")
820
+ sources_sql = ", ".join(src_placeholders)
772
821
  source_filter = f" AND e.source IN ({sources_sql})"
773
822
 
774
- # PostgreSQL recursive CTE
823
+ # PostgreSQL recursive CTE with named parameters
824
+ # Use CAST() instead of :: to avoid conflicts with SQLAlchemy named params
825
+ # UNION (not UNION ALL) deduplicates on (concept_id, depth) during recursion
826
+ # DISTINCT in output needed since same concept can be reached at different depths
775
827
  query = f"""
776
828
  WITH RECURSIVE walk(concept_id, depth) AS (
777
- SELECT $1::VARCHAR, 0
829
+ SELECT CAST(:concept_id AS VARCHAR), 0
778
830
 
779
- UNION ALL
831
+ UNION
780
832
 
781
833
  SELECT e.child_id, w.depth + 1
782
834
  FROM walk w
783
835
  JOIN {self._edges_table} e ON e.parent_id = w.concept_id
784
- WHERE ($2::INTEGER IS NULL OR w.depth < $2){source_filter}
836
+ WHERE (CAST(:max_depth AS INTEGER) IS NULL OR w.depth < :max_depth){source_filter}
785
837
  )
786
838
  SELECT DISTINCT concept_id
787
839
  FROM walk
788
- WHERE concept_id != $1
840
+ WHERE concept_id != :concept_id
789
841
  """
790
842
 
791
- async with self._pool.acquire() as con:
792
- rows = await con.fetch(query, concept_id, max_depth)
843
+ async with self._engine.connect() as conn:
844
+ result = await conn.execute(text(query), params)
845
+ rows = result.mappings().all()
793
846
 
794
847
  return [r["concept_id"] for r in rows]
795
848
 
796
- def normalize_sync(
797
- self,
798
- strings: Sequence[str],
799
- top_k: int = 25,
800
- prefer_ttys: list[str] | None = None,
801
- filter_sources: list[str] | None = None,
802
- exclude_sources: list[str] | None = None,
803
- allow_partial: bool = True,
804
- min_coverage: float = 0.6,
805
- min_word_hits: int | None = None,
806
- coverage_weight: int = 25,
807
- ) -> pl.DataFrame:
808
- """
809
- Synchronous wrapper around normalize().
810
-
811
- Requires the normalizer to be created with create_sync() factory method.
812
- """
813
- if self._loop is None:
814
- raise RuntimeError("normalize_sync() requires normalizer created with create_sync()")
815
-
816
- return self._loop.run_until_complete(
817
- self.normalize(
818
- strings=strings,
819
- top_k=top_k,
820
- prefer_ttys=prefer_ttys,
821
- filter_sources=filter_sources,
822
- exclude_sources=exclude_sources,
823
- allow_partial=allow_partial,
824
- min_coverage=min_coverage,
825
- min_word_hits=min_word_hits,
826
- coverage_weight=coverage_weight,
827
- )
828
- )
829
-
830
849
  async def close(self) -> None:
831
850
  """
832
- Close the connection pool and any owned resources.
851
+ Close the engine and any owned resources.
833
852
 
834
- Note: Only call this if you want to close the pool. If the pool
853
+ Note: Only call this if you want to close the engine. If the engine
835
854
  is managed externally, the caller should close it instead.
836
855
  """
837
- await self._pool.close()
856
+ await self._engine.dispose()
838
857
  if self._owned_resource is not None:
839
858
  await self._owned_resource.close()
840
-
841
- def close_sync(self) -> None:
842
- """
843
- Synchronously close the connection pool and event loop.
844
-
845
- Use this when the normalizer was created with create_sync().
846
- """
847
- if self._loop is None:
848
- raise RuntimeError("close_sync() requires normalizer created with create_sync()")
849
-
850
- self._loop.run_until_complete(self._pool.close())
851
- self._loop.close()
File without changes