norm_toolkit 1.1.0__tar.gz → 1.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: norm_toolkit
3
- Version: 1.1.0
3
+ Version: 1.3.0
4
4
  Summary: Toolkit to normalize text to UMLS / ontologies
5
5
  Author: Haydn Jones
6
6
  Author-email: Haydn Jones <haydnjonest@gmail.com>
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "norm_toolkit"
3
- version = "1.1.0"
3
+ version = "1.3.0"
4
4
  description = "Toolkit to normalize text to UMLS / ontologies"
5
5
  readme = "README.md"
6
6
  authors = [{ name = "Haydn Jones", email = "haydnjonest@gmail.com" }]
@@ -38,6 +38,9 @@ HIT_STRUCT_TYPE = pl.Struct(
38
38
  "score": pl.Int64,
39
39
  "total_score": pl.Int64,
40
40
  "match_type": pl.Utf8,
41
+ "pref_name": pl.Utf8,
42
+ "description": pl.Utf8,
43
+ "synonyms": pl.List(pl.Utf8),
41
44
  }
42
45
  )
43
46
 
@@ -324,6 +324,7 @@ LEFT JOIN agg ON agg.Q = aq.Q;
324
324
  def normalize(
325
325
  self,
326
326
  strings: Sequence[str],
327
+ synonyms: Mapping[str, Sequence[str]] | None = None,
327
328
  top_k: int = 25,
328
329
  prefer_ttys: list[str] | None = None,
329
330
  filter_sources: list[str] | None = None,
@@ -338,6 +339,10 @@ LEFT JOIN agg ON agg.Q = aq.Q;
338
339
 
339
340
  Args:
340
341
  strings: Input strings to normalize
342
+ synonyms: Optional mapping of input strings to their synonyms.
343
+ Synonyms are normalized and used alongside the main string
344
+ to improve matching. Results are still keyed by the original
345
+ input string.
341
346
  top_k: Maximum number of results per query
342
347
  prefer_ttys: Term types to prefer (e.g., ["PT", "MH"])
343
348
  filter_sources: Restrict to these sources (include only)
@@ -348,7 +353,8 @@ LEFT JOIN agg ON agg.Q = aq.Q;
348
353
  coverage_weight: Weight for coverage in scoring
349
354
 
350
355
  Returns:
351
- DataFrame with columns: input_string, hits (list of match structs)
356
+ DataFrame with columns: input_string, hits (list of match structs),
357
+ and synonyms (list of strings) if synonyms were provided.
352
358
  """
353
359
  # Apply defaults
354
360
  if prefer_ttys is None:
@@ -358,9 +364,14 @@ LEFT JOIN agg ON agg.Q = aq.Q;
358
364
  q_to_nstrs: dict[str, list[str]] = {}
359
365
  for s in strings:
360
366
  nstrs = list(lvg_normalize(s) or [])
367
+ # Add normalized forms of synonyms
368
+ if synonyms and s in synonyms:
369
+ for syn in synonyms[s]:
370
+ syn_nstrs = list(lvg_normalize(syn) or [])
371
+ nstrs.extend(syn_nstrs)
361
372
  q_to_nstrs[s] = nstrs
362
373
 
363
- return self._lookup(
374
+ result = self._lookup(
364
375
  q_to_nstrs=q_to_nstrs,
365
376
  all_queries=list(strings),
366
377
  prefer_ttys=prefer_ttys,
@@ -373,6 +384,13 @@ LEFT JOIN agg ON agg.Q = aq.Q;
373
384
  coverage_weight=coverage_weight,
374
385
  )
375
386
 
387
+ # Add synonyms column if synonyms were provided
388
+ if synonyms:
389
+ syn_list = [list(synonyms.get(s, [])) for s in strings]
390
+ result = result.with_columns(pl.Series("synonyms", syn_list))
391
+
392
+ return result
393
+
376
394
  def concept_info(
377
395
  self,
378
396
  concept_ids: Sequence[str],
@@ -110,6 +110,7 @@ class PostgresNormalizer:
110
110
  async def normalize(
111
111
  self,
112
112
  strings: Sequence[str],
113
+ synonyms: Mapping[str, Sequence[str]] | None = None,
113
114
  top_k: int = 25,
114
115
  prefer_ttys: list[str] | None = None,
115
116
  filter_sources: list[str] | None = None,
@@ -124,6 +125,10 @@ class PostgresNormalizer:
124
125
 
125
126
  Args:
126
127
  strings: Input strings to normalize
128
+ synonyms: Optional mapping of input strings to their synonyms.
129
+ Synonyms are normalized and used alongside the main string
130
+ to improve matching. Results are still keyed by the original
131
+ input string.
127
132
  top_k: Maximum number of results per query
128
133
  prefer_ttys: Term types to prefer (e.g., ["PT", "MH"])
129
134
  filter_sources: Restrict to these sources (include only)
@@ -134,7 +139,8 @@ class PostgresNormalizer:
134
139
  coverage_weight: Weight for coverage in scoring
135
140
 
136
141
  Returns:
137
- DataFrame with columns: input_string, hits (list of match structs)
142
+ DataFrame with columns: input_string, hits (list of match structs),
143
+ and synonyms (list of strings) if synonyms were provided.
138
144
  """
139
145
  await self._ensure_initialized()
140
146
 
@@ -145,9 +151,14 @@ class PostgresNormalizer:
145
151
  q_to_nstrs: dict[str, list[str]] = {}
146
152
  for s in strings:
147
153
  nstrs = list(lvg_normalize(s) or [])
154
+ # Add normalized forms of synonyms
155
+ if synonyms and s in synonyms:
156
+ for syn in synonyms[s]:
157
+ syn_nstrs = list(lvg_normalize(syn) or [])
158
+ nstrs.extend(syn_nstrs)
148
159
  q_to_nstrs[s] = nstrs
149
160
 
150
- return await self._lookup(
161
+ result = await self._lookup(
151
162
  q_to_nstrs=q_to_nstrs,
152
163
  all_queries=list(strings),
153
164
  prefer_ttys=prefer_ttys,
@@ -160,6 +171,16 @@ class PostgresNormalizer:
160
171
  coverage_weight=coverage_weight,
161
172
  )
162
173
 
174
+ # Enrich hits with concept info (pref_name, description, synonyms)
175
+ result = await self._enrich_hits_with_concept_info(result, prefer_ttys)
176
+
177
+ # Add synonyms column if synonyms were provided
178
+ if synonyms:
179
+ syn_list = [list(synonyms.get(s, [])) for s in strings]
180
+ result = result.with_columns(pl.Series("input_synonyms", syn_list))
181
+
182
+ return result
183
+
163
184
  async def _lookup(
164
185
  self,
165
186
  q_to_nstrs: Mapping[str, Sequence[str]],
@@ -458,6 +479,58 @@ LEFT JOIN agg ON agg.Q = aq.Q;
458
479
 
459
480
  return pl.DataFrame(data).cast({"hits": pl.List(HIT_STRUCT_TYPE)})
460
481
 
482
+ async def _enrich_hits_with_concept_info(
483
+ self,
484
+ result: pl.DataFrame,
485
+ prefer_ttys: list[str] | None,
486
+ ) -> pl.DataFrame:
487
+ """Enrich hits with pref_name, description, and synonyms from concept_info."""
488
+ # Collect all unique concept_ids from hits
489
+ all_concept_ids: set[str] = set()
490
+ for hits in result["hits"].to_list():
491
+ if hits:
492
+ for hit in hits:
493
+ if hit and "global_identifier" in hit:
494
+ all_concept_ids.add(hit["global_identifier"])
495
+
496
+ if not all_concept_ids:
497
+ # No concepts to enrich, just add empty fields
498
+ enriched_data = []
499
+ for row in result.iter_rows(named=True):
500
+ enriched_hits = []
501
+ for hit in row["hits"] or []:
502
+ enriched_hit = dict(hit)
503
+ enriched_hit["pref_name"] = None
504
+ enriched_hit["description"] = None
505
+ enriched_hit["synonyms"] = []
506
+ enriched_hits.append(enriched_hit)
507
+ enriched_data.append({"input_string": row["input_string"], "hits": enriched_hits})
508
+ return pl.DataFrame(enriched_data).cast({"hits": pl.List(HIT_STRUCT_TYPE)})
509
+
510
+ # Get concept info for all concepts
511
+ concept_infos = await self.concept_info(list(all_concept_ids), prefer_ttys=prefer_ttys)
512
+
513
+ # Enrich each hit
514
+ enriched_data = []
515
+ for row in result.iter_rows(named=True):
516
+ enriched_hits = []
517
+ for hit in row["hits"] or []:
518
+ enriched_hit = dict(hit)
519
+ cid = hit.get("global_identifier")
520
+ if cid and cid in concept_infos:
521
+ info = concept_infos[cid]
522
+ enriched_hit["pref_name"] = info.preferred_name
523
+ enriched_hit["description"] = info.description
524
+ enriched_hit["synonyms"] = info.synonyms or []
525
+ else:
526
+ enriched_hit["pref_name"] = None
527
+ enriched_hit["description"] = None
528
+ enriched_hit["synonyms"] = []
529
+ enriched_hits.append(enriched_hit)
530
+ enriched_data.append({"input_string": row["input_string"], "hits": enriched_hits})
531
+
532
+ return pl.DataFrame(enriched_data).cast({"hits": pl.List(HIT_STRUCT_TYPE)})
533
+
461
534
  async def concept_info(
462
535
  self,
463
536
  concept_ids: Sequence[str],
@@ -804,11 +877,13 @@ ORDER BY t.concept_id, t.type_tree, t.type_id;
804
877
 
805
878
  # PostgreSQL recursive CTE with named parameters
806
879
  # Use CAST() instead of :: to avoid conflicts with SQLAlchemy named params
880
+ # UNION (not UNION ALL) deduplicates on (concept_id, depth) during recursion
881
+ # DISTINCT in output needed since same concept can be reached at different depths
807
882
  query = f"""
808
883
  WITH RECURSIVE walk(concept_id, depth) AS (
809
884
  SELECT CAST(:concept_id AS VARCHAR), 0
810
885
 
811
- UNION ALL
886
+ UNION
812
887
 
813
888
  SELECT e.child_id, w.depth + 1
814
889
  FROM walk w
File without changes