norm_toolkit 1.6.0__tar.gz → 1.8.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: norm_toolkit
3
- Version: 1.6.0
3
+ Version: 1.8.0
4
4
  Summary: Toolkit to normalize text to UMLS / ontologies
5
5
  Author: Haydn Jones
6
6
  Author-email: Haydn Jones <haydnjonest@gmail.com>
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "norm_toolkit"
3
- version = "1.6.0"
3
+ version = "1.8.0"
4
4
  description = "Toolkit to normalize text to UMLS / ontologies"
5
5
  readme = "README.md"
6
6
  authors = [{ name = "Haydn Jones", email = "haydnjonest@gmail.com" }]
@@ -24,6 +24,8 @@ dev = [
24
24
  "pytest>=8.3",
25
25
  "rdkit>=2025.9.3",
26
26
  "ruff>=0.6.9",
27
+ "fire>=0.7.1",
28
+ "joblib>=1.5.3",
27
29
  ]
28
30
 
29
31
  [build-system]
@@ -33,7 +33,7 @@ from norm_toolkit.normalizer_utils import (
33
33
  build_definitions_sql,
34
34
  build_hits_agg_expr,
35
35
  build_lookup_sql,
36
- build_normalized_string_map,
36
+ build_normalized_query_map,
37
37
  build_ontology_filter_clauses,
38
38
  build_pref_join,
39
39
  build_query_rows,
@@ -206,7 +206,7 @@ class DuckDBNormalizer:
206
206
  def normalize(
207
207
  self,
208
208
  strings: Sequence[str],
209
- synonyms: Mapping[str, Sequence[str]] | None = None,
209
+ synonyms: Sequence[Sequence[str] | None] | None = None,
210
210
  top_k: int | None = 25,
211
211
  ont_top_k: int | None = None,
212
212
  prefer_ttys: list[str] | None = None,
@@ -222,10 +222,10 @@ class DuckDBNormalizer:
222
222
 
223
223
  Args:
224
224
  strings: Input strings to normalize
225
- synonyms: Optional mapping of input strings to their synonyms.
226
- Synonyms are normalized and used alongside the main string
227
- to improve matching. Results are still keyed by the original
228
- input string.
225
+ synonyms: Optional list of synonym lists aligned with `strings`
226
+ (same length required). Synonyms are normalized and used
227
+ alongside the main string to improve matching. Results are
228
+ still keyed by the original input string.
229
229
  top_k: Maximum number of results per query (mutually exclusive with ont_top_k)
230
230
  ont_top_k: Maximum number of results per ontology (mutually exclusive with top_k)
231
231
  prefer_ttys: Term types to prefer (e.g., ["PT", "MH"])
@@ -244,12 +244,15 @@ class DuckDBNormalizer:
244
244
  if prefer_ttys is None:
245
245
  prefer_ttys = DEFAULT_PREFER_TTYS
246
246
 
247
- # Build normalized string map
248
- q_to_nstrs = build_normalized_string_map(strings, synonyms)
247
+ strings_list = list(strings)
248
+ query_keys = [f"q{i}" for i in range(len(strings_list))] if synonyms is not None else strings_list
249
+
250
+ # Build normalized string map with per-entry keys
251
+ q_to_nstrs, syn_list = build_normalized_query_map(strings_list, synonyms, query_keys=query_keys)
249
252
 
250
253
  result = self._lookup(
251
254
  q_to_nstrs=q_to_nstrs,
252
- all_queries=list(strings),
255
+ all_queries=query_keys,
253
256
  prefer_ttys=prefer_ttys,
254
257
  filter_ontologies=filter_ontologies,
255
258
  exclude_ontologies=exclude_ontologies,
@@ -261,9 +264,11 @@ class DuckDBNormalizer:
261
264
  coverage_weight=coverage_weight,
262
265
  )
263
266
 
267
+ result = result.with_columns(pl.Series("input_string", strings_list))
268
+
264
269
  # Add synonyms column if synonyms were provided
265
- if synonyms:
266
- syn_list = [list(synonyms.get(s, [])) for s in strings]
270
+ if synonyms is not None:
271
+ syn_list = syn_list if syn_list is not None else [[] for _ in strings_list]
267
272
  result = result.with_columns(pl.Series("synonyms", syn_list))
268
273
 
269
274
  return result
@@ -1,5 +1,5 @@
1
1
  """
2
- LRU cache for normalized string lookup results.
2
+ LRU caches for normalized string lookups and entity expansion results.
3
3
 
4
4
  Caches at the normalized string level to avoid repeated DB round trips
5
5
  for the same normalized forms.
@@ -10,7 +10,10 @@ from __future__ import annotations
10
10
  import hashlib
11
11
  from collections import OrderedDict
12
12
  from dataclasses import dataclass
13
- from typing import Any
13
+ from typing import Any, Generic, TypeVar
14
+
15
+ K = TypeVar("K")
16
+ V = TypeVar("V")
14
17
 
15
18
 
16
19
  @dataclass(frozen=True)
@@ -29,12 +32,21 @@ class CacheKey:
29
32
  coverage_weight: int
30
33
 
31
34
 
32
- class NormalizerCache:
35
+ @dataclass(frozen=True)
36
+ class ExpansionCacheKey:
37
+ """Immutable cache key for entity expansion results."""
38
+
39
+ concept_id: str
40
+ max_depth: int | None
41
+ filter_ontologies: tuple[str, ...] | None
42
+ max_ids: int | None
43
+
44
+
45
+ class LRUCache(Generic[K, V]):
33
46
  """
34
- LRU cache for normalized string lookup results.
47
+ LRU cache with basic hit/miss statistics.
35
48
 
36
- Caches the fully enriched hits for a given tuple of normalized strings
37
- and query parameters. Uses an OrderedDict for O(1) LRU eviction.
49
+ Uses an OrderedDict for O(1) LRU eviction.
38
50
  """
39
51
 
40
52
  def __init__(self, maxsize: int = 10000) -> None:
@@ -45,11 +57,86 @@ class NormalizerCache:
45
57
  maxsize: Maximum number of entries to cache. When exceeded,
46
58
  the least recently used entries are evicted.
47
59
  """
48
- self._cache: OrderedDict[CacheKey, list[dict[str, Any]]] = OrderedDict()
60
+ self._cache: OrderedDict[K, V] = OrderedDict()
49
61
  self._maxsize = maxsize
50
62
  self._hits = 0
51
63
  self._misses = 0
52
64
 
65
+ def get(self, key: K) -> V | None:
66
+ """
67
+ Get cached value for a key.
68
+
69
+ Args:
70
+ key: Cache key to look up
71
+
72
+ Returns:
73
+ Cached value if found, None if not in cache
74
+ """
75
+ if key in self._cache:
76
+ # Move to end (most recently used)
77
+ self._cache.move_to_end(key)
78
+ self._hits += 1
79
+ return self._cache[key]
80
+ self._misses += 1
81
+ return None
82
+
83
+ def set(self, key: K, value: V) -> None:
84
+ """
85
+ Store a value in the cache.
86
+
87
+ Args:
88
+ key: Cache key
89
+ value: Value to cache
90
+ """
91
+ if key in self._cache:
92
+ self._cache.move_to_end(key)
93
+ else:
94
+ if len(self._cache) >= self._maxsize:
95
+ # Remove oldest item (LRU eviction)
96
+ self._cache.popitem(last=False)
97
+ self._cache[key] = value
98
+
99
+ def clear(self) -> None:
100
+ """Clear all cached entries."""
101
+ self._cache.clear()
102
+ self._hits = 0
103
+ self._misses = 0
104
+
105
+ @property
106
+ def size(self) -> int:
107
+ """Current number of cached entries."""
108
+ return len(self._cache)
109
+
110
+ @property
111
+ def hit_rate(self) -> float:
112
+ """Cache hit rate (0.0 to 1.0)."""
113
+ total = self._hits + self._misses
114
+ return self._hits / total if total > 0 else 0.0
115
+
116
+ def stats(self) -> dict[str, Any]:
117
+ """
118
+ Get cache statistics.
119
+
120
+ Returns:
121
+ Dict with size, maxsize, hits, misses, and hit_rate
122
+ """
123
+ return {
124
+ "size": self.size,
125
+ "maxsize": self._maxsize,
126
+ "hits": self._hits,
127
+ "misses": self._misses,
128
+ "hit_rate": self.hit_rate,
129
+ }
130
+
131
+
132
+ class NormalizerCache(LRUCache[CacheKey, list[dict[str, Any]]]):
133
+ """
134
+ LRU cache for normalized string lookup results.
135
+
136
+ Caches the fully enriched hits for a given tuple of normalized strings
137
+ and query parameters.
138
+ """
139
+
53
140
  @staticmethod
54
141
  def make_key(
55
142
  nstrs: tuple[str, ...],
@@ -100,68 +187,37 @@ class NormalizerCache:
100
187
  coverage_weight=coverage_weight,
101
188
  )
102
189
 
103
- def get(self, key: CacheKey) -> list[dict[str, Any]] | None:
104
- """
105
- Get cached hits for a key.
106
190
 
107
- Args:
108
- key: Cache key to look up
191
+ class ExpansionCache(LRUCache[ExpansionCacheKey, list[str]]):
192
+ """
193
+ LRU cache for entity expansion results.
109
194
 
110
- Returns:
111
- Cached hits list if found, None if not in cache
112
- """
113
- if key in self._cache:
114
- # Move to end (most recently used)
115
- self._cache.move_to_end(key)
116
- self._hits += 1
117
- return self._cache[key]
118
- self._misses += 1
119
- return None
195
+ Caches expanded concept IDs for a given concept and traversal parameters.
196
+ """
120
197
 
121
- def set(self, key: CacheKey, hits: list[dict[str, Any]]) -> None:
198
+ @staticmethod
199
+ def make_key(
200
+ concept_id: str,
201
+ *,
202
+ max_depth: int | None,
203
+ filter_ontologies: list[str] | None,
204
+ max_ids: int | None,
205
+ ) -> ExpansionCacheKey:
122
206
  """
123
- Store hits in the cache.
207
+ Create a cache key from entity expansion parameters.
124
208
 
125
209
  Args:
126
- key: Cache key
127
- hits: List of hit dictionaries to cache
128
- """
129
- if key in self._cache:
130
- self._cache.move_to_end(key)
131
- else:
132
- if len(self._cache) >= self._maxsize:
133
- # Remove oldest item (LRU eviction)
134
- self._cache.popitem(last=False)
135
- self._cache[key] = hits
136
-
137
- def clear(self) -> None:
138
- """Clear all cached entries."""
139
- self._cache.clear()
140
- self._hits = 0
141
- self._misses = 0
142
-
143
- @property
144
- def size(self) -> int:
145
- """Current number of cached entries."""
146
- return len(self._cache)
147
-
148
- @property
149
- def hit_rate(self) -> float:
150
- """Cache hit rate (0.0 to 1.0)."""
151
- total = self._hits + self._misses
152
- return self._hits / total if total > 0 else 0.0
153
-
154
- def stats(self) -> dict[str, Any]:
155
- """
156
- Get cache statistics.
210
+ concept_id: Starting concept ID
211
+ max_depth: Maximum depth to traverse
212
+ filter_ontologies: Ontologies to include
213
+ max_ids: Maximum number of IDs to return
157
214
 
158
215
  Returns:
159
- Dict with size, maxsize, hits, misses, and hit_rate
216
+ Immutable ExpansionCacheKey instance
160
217
  """
161
- return {
162
- "size": self.size,
163
- "maxsize": self._maxsize,
164
- "hits": self._hits,
165
- "misses": self._misses,
166
- "hit_rate": self.hit_rate,
167
- }
218
+ return ExpansionCacheKey(
219
+ concept_id=concept_id,
220
+ max_depth=max_depth,
221
+ filter_ontologies=tuple(filter_ontologies) if filter_ontologies else None,
222
+ max_ids=max_ids,
223
+ )
@@ -25,7 +25,7 @@ from norm_toolkit.constants import (
25
25
  TYPES_TABLE,
26
26
  )
27
27
  from norm_toolkit.models import ConceptInfo
28
- from norm_toolkit.normalizer_cache import NormalizerCache
28
+ from norm_toolkit.normalizer_cache import ExpansionCache, NormalizerCache
29
29
  from norm_toolkit.normalizer_utils import (
30
30
  apply_concept_name_rows,
31
31
  apply_definition_rows,
@@ -34,7 +34,7 @@ from norm_toolkit.normalizer_utils import (
34
34
  build_definitions_sql,
35
35
  build_hits_agg_expr,
36
36
  build_lookup_sql,
37
- build_normalized_string_map,
37
+ build_normalized_query_map,
38
38
  build_ontology_filter_clauses,
39
39
  build_pref_join,
40
40
  build_query_rows,
@@ -94,8 +94,8 @@ class PostgresNormalizer:
94
94
  schema: PostgreSQL schema where tables are located (default: "public")
95
95
  owned_resource: Optional resource with async close() method to clean up
96
96
  when this normalizer is closed (e.g., AlloyDB AsyncConnector)
97
- cache_maxsize: Maximum number of entries in the normalized string cache
98
- enable_cache: Whether to enable caching of normalized string lookups
97
+ cache_maxsize: Maximum number of entries in each cache
98
+ enable_cache: Whether to enable caching for normalization and expansion
99
99
 
100
100
  Note:
101
101
  After creating the normalizer, call `await normalizer.initialize()`
@@ -110,8 +110,9 @@ class PostgresNormalizer:
110
110
  self._has_stt = False
111
111
  self._initialized = False
112
112
 
113
- # Initialize cache
113
+ # Initialize caches
114
114
  self._cache: NormalizerCache | None = NormalizerCache(maxsize=cache_maxsize) if enable_cache else None
115
+ self._expansion_cache: ExpansionCache | None = ExpansionCache(maxsize=cache_maxsize) if enable_cache else None
115
116
 
116
117
  # Build qualified table names
117
118
  prefix = f"{schema}." if schema else ""
@@ -158,7 +159,7 @@ class PostgresNormalizer:
158
159
  async def normalize(
159
160
  self,
160
161
  strings: Sequence[str],
161
- synonyms: Mapping[str, Sequence[str]] | None = None,
162
+ synonyms: Sequence[Sequence[str] | None] | None = None,
162
163
  top_k: int | None = 25,
163
164
  ont_top_k: int | None = None,
164
165
  prefer_ttys: list[str] | None = None,
@@ -174,10 +175,10 @@ class PostgresNormalizer:
174
175
 
175
176
  Args:
176
177
  strings: Input strings to normalize
177
- synonyms: Optional mapping of input strings to their synonyms.
178
- Synonyms are normalized and used alongside the main string
179
- to improve matching. Results are still keyed by the original
180
- input string.
178
+ synonyms: Optional list of synonym lists aligned with `strings`
179
+ (same length required). Synonyms are normalized and used
180
+ alongside the main string to improve matching. Results are
181
+ still keyed by the original input string.
181
182
  top_k: Maximum number of results per query (mutually exclusive with ont_top_k)
182
183
  ont_top_k: Maximum number of results per ontology (mutually exclusive with top_k)
183
184
  prefer_ttys: Term types to prefer (e.g., ["PT", "MH"])
@@ -204,6 +205,9 @@ class PostgresNormalizer:
204
205
  if ont_top_k is not None:
205
206
  ont_top_k = max(1, int(ont_top_k))
206
207
 
208
+ strings_list = list(strings)
209
+ query_keys = [f"q{i}" for i in range(len(strings_list))] if synonyms is not None else strings_list
210
+
207
211
  def make_cache_key(nstrs: tuple[str, ...]) -> Any:
208
212
  return NormalizerCache.make_key(
209
213
  nstrs,
@@ -218,8 +222,8 @@ class PostgresNormalizer:
218
222
  coverage_weight=coverage_weight,
219
223
  )
220
224
 
221
- # Build normalized string map (use tuple for hashable cache keys)
222
- q_to_nstrs = build_normalized_string_map(strings, synonyms)
225
+ # Build normalized string map with per-entry keys (tuples for cache keys)
226
+ q_to_nstrs, syn_list = build_normalized_query_map(strings_list, synonyms, query_keys=query_keys)
223
227
 
224
228
  # Check cache for each input
225
229
  cached_hits: dict[str, list[dict[str, Any]]] = {}
@@ -271,12 +275,15 @@ class PostgresNormalizer:
271
275
  self._cache.set(cache_key, hits)
272
276
 
273
277
  # Build final result in original order
274
- result_data = [{"input_string": s, "hits": cached_hits.get(s, [])} for s in strings]
278
+ result_data = [
279
+ {"input_string": strings_list[i], "hits": cached_hits.get(query_keys[i], [])}
280
+ for i in range(len(strings_list))
281
+ ]
275
282
  result = pl.DataFrame(result_data).cast({"hits": pl.List(HIT_STRUCT_TYPE)})
276
283
 
277
284
  # Add synonyms column if synonyms were provided
278
- if synonyms:
279
- syn_list = [list(synonyms.get(s, [])) for s in strings]
285
+ if synonyms is not None:
286
+ syn_list = syn_list if syn_list is not None else [[] for _ in strings_list]
280
287
  result = result.with_columns(pl.Series("synonyms", syn_list))
281
288
 
282
289
  return result
@@ -319,8 +326,7 @@ class PostgresNormalizer:
319
326
  qwords_values = sql_params.add_rows(qword_rows) if qword_rows else ""
320
327
 
321
328
  allq_values = ", ".join(
322
- f"({sql_params.add(q)}, {sql_params.add_cast(i, 'INTEGER')})"
323
- for i, q in enumerate(all_queries)
329
+ f"({sql_params.add(q)}, {sql_params.add_cast(i, 'INTEGER')})" for i, q in enumerate(all_queries)
324
330
  )
325
331
 
326
332
  # Build preference clauses (parameterized to prevent SQL injection)
@@ -596,61 +602,138 @@ class PostgresNormalizer:
596
602
  List of descendant concept IDs ordered by depth (shallowest first),
597
603
  excludes the starting concept
598
604
  """
605
+ results = await self.get_narrower_concepts_many(
606
+ [concept_id],
607
+ max_depth=max_depth,
608
+ filter_ontologies=filter_ontologies,
609
+ max_ids=max_ids,
610
+ )
611
+ return results.get(concept_id, [])
612
+
613
+ async def get_narrower_concepts_many(
614
+ self,
615
+ concept_ids: Sequence[str],
616
+ max_depth: int | None = 10,
617
+ filter_ontologies: list[str] | None = None,
618
+ max_ids: int | None = None,
619
+ ) -> dict[str, list[str]]:
620
+ """
621
+ Get narrower (descendant) concept IDs for many roots in one query.
622
+
623
+ Uses the hierarchy edges to walk down the tree/DAG from each root concept.
624
+
625
+ Args:
626
+ concept_ids: Starting concept IDs (broader terms)
627
+ max_depth: Maximum depth to traverse (1 = direct children only, None = all descendants)
628
+ filter_ontologies: Only follow edges from these ontologies (e.g., ["UMLS", "CHEBI"])
629
+ max_ids: Maximum number of concept IDs to return (None = no limit)
630
+
631
+ Returns:
632
+ Dict mapping each concept ID to descendant IDs ordered by depth
633
+ (shallowest first), excluding the starting concept.
634
+ """
599
635
  await self._ensure_initialized()
600
636
 
601
- if not self._has_edges:
602
- return []
637
+ if not self._has_edges or not concept_ids:
638
+ return {cid: [] for cid in concept_ids}
603
639
 
604
- params: dict[str, Any] = {"concept_id": concept_id, "max_depth": max_depth}
640
+ id_list = list(dict.fromkeys(concept_ids))
641
+
642
+ res: dict[str, list[str]] = {}
643
+ missing: list[str] = []
644
+ cache_keys: dict[str, Any] = {}
645
+
646
+ if self._expansion_cache is not None:
647
+ for cid in id_list:
648
+ cache_key = ExpansionCache.make_key(
649
+ cid,
650
+ max_depth=max_depth,
651
+ filter_ontologies=filter_ontologies,
652
+ max_ids=max_ids,
653
+ )
654
+ cache_keys[cid] = cache_key
655
+ cached = self._expansion_cache.get(cache_key)
656
+ if cached is not None:
657
+ res[cid] = cached
658
+ else:
659
+ res[cid] = []
660
+ missing.append(cid)
661
+ else:
662
+ for cid in id_list:
663
+ res[cid] = []
664
+ missing = id_list
665
+
666
+ if not missing:
667
+ return res
668
+
669
+ sql_params = _SqlParams()
670
+ idmap_values = sql_params.add_single_column_values(missing)
671
+ params = sql_params.params
672
+ params["max_depth"] = max_depth
605
673
 
606
- # Build ontology filter clause
607
674
  ontology_filter = ""
608
675
  if filter_ontologies:
609
- ont_placeholders = []
610
- for i, ont in enumerate(filter_ontologies):
611
- key = f"ont{i}"
612
- params[key] = ont
613
- ont_placeholders.append(f":{key}")
614
- ontologies_sql = ", ".join(ont_placeholders)
676
+ ontologies_sql = sql_params.add_values(filter_ontologies)
615
677
  ontology_filter = f" AND e.ontology IN ({ontologies_sql})"
616
678
 
617
- # Build optional LIMIT clause
618
- limit_clause = ""
619
- if max_ids is not None:
679
+ if max_ids is None:
680
+ select_sql = """
681
+ SELECT root_id, concept_id, MIN(depth) AS min_depth
682
+ FROM walk
683
+ WHERE concept_id != root_id
684
+ GROUP BY root_id, concept_id
685
+ ORDER BY root_id, min_depth, concept_id
686
+ """
687
+ else:
620
688
  params["max_ids"] = max_ids
621
- limit_clause = "\nLIMIT :max_ids"
689
+ select_sql = """
690
+ SELECT root_id, concept_id, min_depth
691
+ FROM (
692
+ SELECT root_id, concept_id, min_depth,
693
+ ROW_NUMBER() OVER (PARTITION BY root_id ORDER BY min_depth, concept_id) AS rn
694
+ FROM (
695
+ SELECT root_id, concept_id, MIN(depth) AS min_depth
696
+ FROM walk
697
+ WHERE concept_id != root_id
698
+ GROUP BY root_id, concept_id
699
+ ) base
700
+ ) ranked
701
+ WHERE rn <= :max_ids
702
+ ORDER BY root_id, min_depth, concept_id
703
+ """
622
704
 
623
- # PostgreSQL recursive CTE with named parameters
624
- # Use CAST() instead of :: to avoid conflicts with SQLAlchemy named params
625
- # UNION (not UNION ALL) deduplicates on (concept_id, depth) during recursion
626
- # GROUP BY with MIN(depth) gets shortest path depth for each concept
627
705
  query = dedent(
628
706
  f"""
629
- WITH RECURSIVE walk(concept_id, depth) AS (
630
- SELECT CAST(:concept_id AS VARCHAR), 0
707
+ WITH RECURSIVE idmap(root_id) AS (VALUES {idmap_values}),
708
+ walk(root_id, concept_id, depth) AS (
709
+ SELECT root_id, root_id, 0
710
+ FROM idmap
631
711
 
632
712
  UNION
633
713
 
634
- SELECT e.child_id, w.depth + 1
714
+ SELECT w.root_id, e.child_id, w.depth + 1
635
715
  FROM walk w
636
716
  JOIN {self._edges_table} e ON e.parent_id = w.concept_id
637
717
  WHERE (CAST(:max_depth AS INTEGER) IS NULL OR w.depth < :max_depth){ontology_filter}
638
718
  )
639
- SELECT concept_id, MIN(depth) AS min_depth
640
- FROM walk
641
- WHERE concept_id != :concept_id
642
- GROUP BY concept_id
643
- ORDER BY min_depth, concept_id{limit_clause}
719
+ {select_sql}
644
720
  """
645
721
  )
646
722
 
647
723
  rows = await self._fetch_rows(query, params)
648
724
 
649
- return [r["concept_id"] for r in rows]
725
+ for row in rows:
726
+ res[row["root_id"]].append(row["concept_id"])
727
+
728
+ if self._expansion_cache is not None:
729
+ for cid in missing:
730
+ self._expansion_cache.set(cache_keys[cid], res[cid])
731
+
732
+ return res
650
733
 
651
734
  def cache_stats(self) -> dict[str, Any] | None:
652
735
  """
653
- Get cache statistics.
736
+ Get normalization cache statistics.
654
737
 
655
738
  Returns:
656
739
  Dict with size, maxsize, hits, misses, and hit_rate,
@@ -660,10 +743,24 @@ class PostgresNormalizer:
660
743
  return None
661
744
  return self._cache.stats()
662
745
 
746
+ def expansion_cache_stats(self) -> dict[str, Any] | None:
747
+ """
748
+ Get entity expansion cache statistics.
749
+
750
+ Returns:
751
+ Dict with size, maxsize, hits, misses, and hit_rate,
752
+ or None if caching is disabled.
753
+ """
754
+ if self._expansion_cache is None:
755
+ return None
756
+ return self._expansion_cache.stats()
757
+
663
758
  def clear_cache(self) -> None:
664
759
  """Clear all cached entries."""
665
760
  if self._cache is not None:
666
761
  self._cache.clear()
762
+ if self._expansion_cache is not None:
763
+ self._expansion_cache.clear()
667
764
 
668
765
  async def close(self) -> None:
669
766
  """
@@ -21,24 +21,81 @@ from norm_toolkit.constants import (
21
21
  from norm_toolkit.models import ConceptInfo, SemanticType
22
22
 
23
23
 
24
+ def _coerce_synonyms_list(
25
+ strings: Sequence[str],
26
+ synonyms: Sequence[Sequence[str] | None] | None,
27
+ ) -> list[list[str]] | None:
28
+ if synonyms is None:
29
+ return None
30
+ if not isinstance(synonyms, Sequence) or isinstance(synonyms, (str, bytes)):
31
+ raise TypeError("synonyms must be a sequence of sequences aligned with strings")
32
+ if len(synonyms) != len(strings):
33
+ raise ValueError("synonyms must have the same length as strings")
34
+ out: list[list[str]] = []
35
+ for i, syns in enumerate(synonyms):
36
+ if syns is None:
37
+ out.append([])
38
+ continue
39
+ if not isinstance(syns, Sequence) or isinstance(syns, (str, bytes)):
40
+ raise ValueError(f"synonyms[{i}] must be a sequence of strings")
41
+ out.append(list(syns))
42
+ return out
43
+
44
+
24
45
  def build_normalized_string_map(
25
46
  strings: Sequence[str],
26
- synonyms: Mapping[str, Sequence[str]] | None = None,
47
+ synonyms: Sequence[Sequence[str] | None] | None = None,
27
48
  ) -> dict[str, tuple[str, ...]]:
28
49
  """
29
50
  Build a mapping of input string -> normalized string variants.
30
51
 
31
52
  Normalized variants are deduplicated while preserving order.
53
+ Duplicate input strings will collapse to the last entry.
54
+ Synonyms must be aligned with `strings` when provided.
32
55
  """
56
+ synonyms_list = _coerce_synonyms_list(strings, synonyms)
57
+ syns_iter = synonyms_list if synonyms_list is not None else [None] * len(strings)
58
+
33
59
  q_to_nstrs: dict[str, tuple[str, ...]] = {}
34
- for s in strings:
60
+ for s, syns in zip(strings, syns_iter):
35
61
  nstrs = list(lvg_normalize(s) or [])
36
- for syn in (synonyms or {}).get(s, []):
37
- nstrs.extend(lvg_normalize(syn) or [])
62
+ if syns:
63
+ for syn in syns:
64
+ nstrs.extend(lvg_normalize(syn) or [])
38
65
  q_to_nstrs[s] = tuple(dict.fromkeys(nstrs))
39
66
  return q_to_nstrs
40
67
 
41
68
 
69
+ def build_normalized_query_map(
70
+ strings: Sequence[str],
71
+ synonyms: Sequence[Sequence[str] | None] | None = None,
72
+ *,
73
+ query_keys: Sequence[str] | None = None,
74
+ ) -> tuple[dict[str, tuple[str, ...]], list[list[str]] | None]:
75
+ """
76
+ Build a mapping of query key -> normalized string variants.
77
+
78
+ Normalized variants are deduplicated while preserving order.
79
+ Synonyms must be aligned with `strings` when provided.
80
+ """
81
+ if query_keys is None:
82
+ query_keys = list(strings)
83
+ if len(query_keys) != len(strings):
84
+ raise ValueError("query_keys must have the same length as strings")
85
+
86
+ synonyms_list = _coerce_synonyms_list(strings, synonyms)
87
+ syns_iter = synonyms_list if synonyms_list is not None else [None] * len(strings)
88
+
89
+ q_to_nstrs: dict[str, tuple[str, ...]] = {}
90
+ for key, s, syns in zip(query_keys, strings, syns_iter):
91
+ nstrs = list(lvg_normalize(s) or [])
92
+ if syns:
93
+ for syn in syns:
94
+ nstrs.extend(lvg_normalize(syn) or [])
95
+ q_to_nstrs[key] = tuple(dict.fromkeys(nstrs))
96
+ return q_to_nstrs, synonyms_list
97
+
98
+
42
99
  def build_query_rows(
43
100
  q_to_nstrs: Mapping[str, Sequence[str]],
44
101
  *,
File without changes