norm_toolkit 1.2.0__tar.gz → 1.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: norm_toolkit
3
- Version: 1.2.0
3
+ Version: 1.4.0
4
4
  Summary: Toolkit to normalize text to UMLS / ontologies
5
5
  Author: Haydn Jones
6
6
  Author-email: Haydn Jones <haydnjonest@gmail.com>
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "norm_toolkit"
3
- version = "1.2.0"
3
+ version = "1.4.0"
4
4
  description = "Toolkit to normalize text to UMLS / ontologies"
5
5
  readme = "README.md"
6
6
  authors = [{ name = "Haydn Jones", email = "haydnjonest@gmail.com" }]
@@ -38,6 +38,9 @@ HIT_STRUCT_TYPE = pl.Struct(
38
38
  "score": pl.Int64,
39
39
  "total_score": pl.Int64,
40
40
  "match_type": pl.Utf8,
41
+ "pref_name": pl.Utf8,
42
+ "description": pl.Utf8,
43
+ "synonyms": pl.List(pl.Utf8),
41
44
  }
42
45
  )
43
46
 
@@ -0,0 +1,163 @@
1
+ """
2
+ LRU cache for normalized string lookup results.
3
+
4
+ Caches at the normalized string level to avoid repeated DB round trips
5
+ for the same normalized forms.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import hashlib
11
+ from collections import OrderedDict
12
+ from dataclasses import dataclass
13
+ from typing import Any
14
+
15
+
16
@dataclass(frozen=True)
class CacheKey:
    """
    Hashable identity of a single cached lookup.

    Two lookups share a cache entry only when every field below compares
    equal. Instances are frozen, so they cannot be modified after creation
    and can safely serve as dictionary keys.
    """

    # Digest standing in for the query's (sorted) normalized strings.
    nstrs_hash: str

    # Maximum number of results requested.
    top_k: int

    # Sequence-valued options are stored as tuples (or None) to stay hashable.
    prefer_ttys: tuple[str, ...] | None
    filter_sources: tuple[str, ...] | None
    exclude_sources: tuple[str, ...] | None

    # Matching and scoring options that are part of the key.
    allow_partial: bool
    min_coverage: float
    min_word_hits: int | None
    coverage_weight: int
29
+
30
+
31
class NormalizerCache:
    """
    LRU cache for normalized string lookup results.

    Caches the hits for a given set of normalized strings and query
    parameters. Entries live in an OrderedDict whose insertion order doubles
    as the recency order: the first entry is the least recently used one and
    is the first to be evicted.
    """

    def __init__(self, maxsize: int = 10000) -> None:
        """
        Initialize the cache.

        Args:
            maxsize: Maximum number of entries to cache. When exceeded,
                the least recently used entries are evicted. A value of
                zero or less means nothing is ever stored, so every
                lookup is a miss.
        """
        self._cache: OrderedDict[CacheKey, list[dict[str, Any]]] = OrderedDict()
        self._maxsize = maxsize
        self._hits = 0
        self._misses = 0

    @staticmethod
    def make_key(
        nstrs: tuple[str, ...],
        *,
        top_k: int,
        prefer_ttys: list[str] | None,
        filter_sources: list[str] | None,
        exclude_sources: list[str] | None,
        allow_partial: bool,
        min_coverage: float,
        min_word_hits: int | None,
        coverage_weight: int,
    ) -> CacheKey:
        """
        Create a cache key from normalized strings and query parameters.

        Args:
            nstrs: Tuple of normalized strings for the query
            top_k: Maximum number of results
            prefer_ttys: Preferred term types
            filter_sources: Include only these sources
            exclude_sources: Exclude these sources
            allow_partial: Whether partial matching is enabled
            min_coverage: Minimum coverage threshold
            min_word_hits: Minimum word hits required
            coverage_weight: Weight for coverage in scoring

        Returns:
            Immutable CacheKey instance
        """
        # Sort so the digest does not depend on the order of the strings,
        # and join on NUL so adjacent strings cannot run together.
        joined = "\0".join(sorted(nstrs))
        # MD5 is used only to get a compact key, not for security.
        nstrs_hash = hashlib.md5(joined.encode(), usedforsecurity=False).hexdigest()

        # Lists are not hashable, so they are stored as tuples. An empty
        # list and None both become None and therefore share a key.
        return CacheKey(
            nstrs_hash=nstrs_hash,
            top_k=top_k,
            prefer_ttys=tuple(prefer_ttys) if prefer_ttys else None,
            filter_sources=tuple(filter_sources) if filter_sources else None,
            exclude_sources=tuple(exclude_sources) if exclude_sources else None,
            allow_partial=allow_partial,
            min_coverage=min_coverage,
            min_word_hits=min_word_hits,
            coverage_weight=coverage_weight,
        )

    def get(self, key: CacheKey) -> list[dict[str, Any]] | None:
        """
        Get cached hits for a key.

        A successful lookup marks the entry as most recently used. Every
        call updates the hit/miss counters.

        Args:
            key: Cache key to look up

        Returns:
            Cached hits list if found, None if not in cache. The returned
            list is the stored object itself, not a copy.
        """
        try:
            hits = self._cache[key]
        except KeyError:
            self._misses += 1
            return None

        # Move to end (most recently used)
        self._cache.move_to_end(key)
        self._hits += 1
        return hits

    def set(self, key: CacheKey, hits: list[dict[str, Any]]) -> None:
        """
        Store hits in the cache.

        Storing under an existing key replaces its value and marks it as
        most recently used. Does nothing when maxsize is zero or less.

        Args:
            key: Cache key
            hits: List of hit dictionaries to cache
        """
        if self._maxsize <= 0:
            # Nothing may be stored; also avoids popping from an empty dict.
            return

        self._cache[key] = hits
        self._cache.move_to_end(key)

        # Remove oldest items (LRU eviction) until back within the limit.
        while len(self._cache) > self._maxsize:
            self._cache.popitem(last=False)

    def clear(self) -> None:
        """Clear all cached entries and reset the hit/miss counters."""
        self._cache.clear()
        self._hits = 0
        self._misses = 0

    @property
    def size(self) -> int:
        """Current number of cached entries."""
        return len(self._cache)

    @property
    def hit_rate(self) -> float:
        """Cache hit rate (0.0 to 1.0); 0.0 when no lookups have happened."""
        total = self._hits + self._misses
        return self._hits / total if total > 0 else 0.0

    def stats(self) -> dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dict with size, maxsize, hits, misses, and hit_rate
        """
        return {
            "size": self.size,
            "maxsize": self._maxsize,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": self.hit_rate,
        }
@@ -32,6 +32,7 @@ from norm_toolkit.constants import (
32
32
  TYPES_TABLE,
33
33
  )
34
34
  from norm_toolkit.models import ConceptInfo, SemanticType
35
+ from norm_toolkit.normalizer_cache import NormalizerCache
35
36
 
36
37
 
37
38
  class PostgresNormalizer:
@@ -47,6 +48,8 @@ class PostgresNormalizer:
47
48
  engine: AsyncEngine,
48
49
  schema: str = "public",
49
50
  owned_resource: Any | None = None,
51
+ cache_maxsize: int = 10000,
52
+ enable_cache: bool = True,
50
53
  ) -> None:
51
54
  """
52
55
  Initialize the normalizer with an SQLAlchemy AsyncEngine.
@@ -56,6 +59,8 @@ class PostgresNormalizer:
56
59
  schema: PostgreSQL schema where tables are located (default: "public")
57
60
  owned_resource: Optional resource with async close() method to clean up
58
61
  when this normalizer is closed (e.g., AlloyDB AsyncConnector)
62
+ cache_maxsize: Maximum number of entries in the normalized string cache
63
+ enable_cache: Whether to enable caching of normalized string lookups
59
64
 
60
65
  Note:
61
66
  After creating the normalizer, call `await normalizer.initialize()`
@@ -70,6 +75,11 @@ class PostgresNormalizer:
70
75
  self._has_stt = False
71
76
  self._initialized = False
72
77
 
78
+ # Initialize cache
79
+ self._cache: NormalizerCache | None = (
80
+ NormalizerCache(maxsize=cache_maxsize) if enable_cache else None
81
+ )
82
+
73
83
  # Build qualified table names
74
84
  prefix = f"{schema}." if schema else ""
75
85
  self._ns_table = f"{prefix}{NS_TABLE}"
@@ -147,8 +157,8 @@ class PostgresNormalizer:
147
157
  if prefer_ttys is None:
148
158
  prefer_ttys = DEFAULT_PREFER_TTYS
149
159
 
150
- # Build normalized string map
151
- q_to_nstrs: dict[str, list[str]] = {}
160
+ # Build normalized string map (use tuple for hashable cache keys)
161
+ q_to_nstrs: dict[str, tuple[str, ...]] = {}
152
162
  for s in strings:
153
163
  nstrs = list(lvg_normalize(s) or [])
154
164
  # Add normalized forms of synonyms
@@ -156,25 +166,92 @@ class PostgresNormalizer:
156
166
  for syn in synonyms[s]:
157
167
  syn_nstrs = list(lvg_normalize(syn) or [])
158
168
  nstrs.extend(syn_nstrs)
159
- q_to_nstrs[s] = nstrs
160
-
161
- result = await self._lookup(
162
- q_to_nstrs=q_to_nstrs,
163
- all_queries=list(strings),
164
- prefer_ttys=prefer_ttys,
165
- filter_sources=filter_sources,
166
- exclude_sources=exclude_sources,
167
- top_k=top_k,
168
- allow_partial=allow_partial,
169
- min_coverage=min_coverage,
170
- min_word_hits=min_word_hits,
171
- coverage_weight=coverage_weight,
172
- )
169
+ # Deduplicate while preserving order, then convert to tuple
170
+ q_to_nstrs[s] = tuple(dict.fromkeys(nstrs))
171
+
172
+ # Check cache for each input
173
+ cached_hits: dict[str, list[dict[str, Any]]] = {}
174
+ uncached_queries: list[str] = []
175
+ uncached_q_to_nstrs: dict[str, tuple[str, ...]] = {}
176
+
177
+ for q, nstrs in q_to_nstrs.items():
178
+ if not nstrs:
179
+ # No normalized strings, empty result
180
+ cached_hits[q] = []
181
+ continue
182
+
183
+ if self._cache is not None:
184
+ cache_key = NormalizerCache.make_key(
185
+ nstrs,
186
+ top_k=top_k,
187
+ prefer_ttys=prefer_ttys,
188
+ filter_sources=filter_sources,
189
+ exclude_sources=exclude_sources,
190
+ allow_partial=allow_partial,
191
+ min_coverage=min_coverage,
192
+ min_word_hits=min_word_hits,
193
+ coverage_weight=coverage_weight,
194
+ )
195
+ cached = self._cache.get(cache_key)
196
+ if cached is not None:
197
+ cached_hits[q] = cached
198
+ continue
199
+
200
+ uncached_queries.append(q)
201
+ uncached_q_to_nstrs[q] = nstrs
202
+
203
+ # Query DB for uncached entries
204
+ if uncached_q_to_nstrs:
205
+ # Convert tuples back to lists for _lookup
206
+ uncached_q_to_nstrs_list: dict[str, list[str]] = {
207
+ q: list(nstrs) for q, nstrs in uncached_q_to_nstrs.items()
208
+ }
209
+
210
+ fresh_result = await self._lookup(
211
+ q_to_nstrs=uncached_q_to_nstrs_list,
212
+ all_queries=uncached_queries,
213
+ prefer_ttys=prefer_ttys,
214
+ filter_sources=filter_sources,
215
+ exclude_sources=exclude_sources,
216
+ top_k=top_k,
217
+ allow_partial=allow_partial,
218
+ min_coverage=min_coverage,
219
+ min_word_hits=min_word_hits,
220
+ coverage_weight=coverage_weight,
221
+ )
222
+
223
+ # Enrich fresh results
224
+ fresh_result = await self._enrich_hits_with_concept_info(fresh_result, prefer_ttys)
225
+
226
+ # Cache fresh results and add to cached_hits
227
+ for row in fresh_result.iter_rows(named=True):
228
+ q = row["input_string"]
229
+ hits = row["hits"] or []
230
+ cached_hits[q] = hits
231
+
232
+ if self._cache is not None:
233
+ nstrs = uncached_q_to_nstrs[q]
234
+ cache_key = NormalizerCache.make_key(
235
+ nstrs,
236
+ top_k=top_k,
237
+ prefer_ttys=prefer_ttys,
238
+ filter_sources=filter_sources,
239
+ exclude_sources=exclude_sources,
240
+ allow_partial=allow_partial,
241
+ min_coverage=min_coverage,
242
+ min_word_hits=min_word_hits,
243
+ coverage_weight=coverage_weight,
244
+ )
245
+ self._cache.set(cache_key, hits)
246
+
247
+ # Build final result in original order
248
+ result_data = [{"input_string": s, "hits": cached_hits.get(s, [])} for s in strings]
249
+ result = pl.DataFrame(result_data).cast({"hits": pl.List(HIT_STRUCT_TYPE)})
173
250
 
174
251
  # Add synonyms column if synonyms were provided
175
252
  if synonyms:
176
253
  syn_list = [list(synonyms.get(s, [])) for s in strings]
177
- result = result.with_columns(pl.Series("synonyms", syn_list))
254
+ result = result.with_columns(pl.Series("input_synonyms", syn_list))
178
255
 
179
256
  return result
180
257
 
@@ -476,6 +553,58 @@ LEFT JOIN agg ON agg.Q = aq.Q;
476
553
 
477
554
  return pl.DataFrame(data).cast({"hits": pl.List(HIT_STRUCT_TYPE)})
478
555
 
556
async def _enrich_hits_with_concept_info(
    self,
    result: pl.DataFrame,
    prefer_ttys: list[str] | None,
) -> pl.DataFrame:
    """
    Enrich hits with pref_name, description, and synonyms from concept_info.

    Args:
        result: DataFrame with an "input_string" column and a "hits" column
            holding, per row, a list of hit dicts (or null).
        prefer_ttys: Passed through unchanged to concept_info().

    Returns:
        A DataFrame with the same "input_string" values whose hits each
        carry "pref_name", "description" and "synonyms", with "hits" cast
        to pl.List(HIT_STRUCT_TYPE). Hits whose "global_identifier" is
        missing, empty, or absent from the concept_info() result get
        None / None / [] for those fields. A zero-row input is returned
        as-is.
    """
    rows = list(result.iter_rows(named=True))
    if not rows:
        # Nothing to enrich. Rebuilding a frame from an empty list would
        # leave no "hits" column for the cast below to act on.
        return result

    # Collect each distinct, non-empty identifier once so concept_info()
    # is called a single time for the whole batch.
    concept_ids = {
        hit["global_identifier"]
        for row in rows
        for hit in row["hits"] or []
        if hit and hit.get("global_identifier")
    }

    # Skip the concept_info() call entirely when there is nothing to look up.
    # Sorted so the id list is in a deterministic order; the result is
    # indexed by concept id below.
    concept_infos = (
        await self.concept_info(sorted(concept_ids), prefer_ttys=prefer_ttys)
        if concept_ids
        else {}
    )

    enriched_data = []
    for row in rows:
        enriched_hits = []
        for hit in row["hits"] or []:
            cid = hit.get("global_identifier")
            info = concept_infos[cid] if cid and cid in concept_infos else None

            # Copy so the input hit dict is not modified.
            enriched_hit = dict(hit)
            if info is not None:
                enriched_hit["pref_name"] = info.preferred_name
                enriched_hit["description"] = info.description
                enriched_hit["synonyms"] = info.synonyms or []
            else:
                enriched_hit["pref_name"] = None
                enriched_hit["description"] = None
                enriched_hit["synonyms"] = []
            enriched_hits.append(enriched_hit)
        enriched_data.append({"input_string": row["input_string"], "hits": enriched_hits})

    return pl.DataFrame(enriched_data).cast({"hits": pl.List(HIT_STRUCT_TYPE)})
607
+
479
608
  async def concept_info(
480
609
  self,
481
610
  concept_ids: Sequence[str],
@@ -846,6 +975,23 @@ WHERE concept_id != :concept_id
846
975
 
847
976
  return [r["concept_id"] for r in rows]
848
977
 
978
def cache_stats(self) -> dict[str, Any] | None:
    """
    Report statistics for the lookup cache.

    Returns:
        Whatever the cache's stats() method returns (a dict with size,
        maxsize, hits, misses, and hit_rate), or None when this
        normalizer has no cache.
    """
    cache = self._cache
    return cache.stats() if cache is not None else None
989
+
990
def clear_cache(self) -> None:
    """Empty the lookup cache; does nothing when there is no cache."""
    cache = self._cache
    if cache is None:
        return
    cache.clear()
994
+
849
995
  async def close(self) -> None:
850
996
  """
851
997
  Close the engine and any owned resources.
File without changes