norm_toolkit 1.6.0__tar.gz → 1.8.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {norm_toolkit-1.6.0 → norm_toolkit-1.8.0}/PKG-INFO +1 -1
- {norm_toolkit-1.6.0 → norm_toolkit-1.8.0}/pyproject.toml +3 -1
- {norm_toolkit-1.6.0 → norm_toolkit-1.8.0}/src/norm_toolkit/normalizer.py +16 -11
- {norm_toolkit-1.6.0 → norm_toolkit-1.8.0}/src/norm_toolkit/normalizer_cache.py +119 -63
- {norm_toolkit-1.6.0 → norm_toolkit-1.8.0}/src/norm_toolkit/normalizer_postgres.py +142 -45
- {norm_toolkit-1.6.0 → norm_toolkit-1.8.0}/src/norm_toolkit/normalizer_utils.py +61 -4
- {norm_toolkit-1.6.0 → norm_toolkit-1.8.0}/README.md +0 -0
- {norm_toolkit-1.6.0 → norm_toolkit-1.8.0}/src/norm_toolkit/__init__.py +0 -0
- {norm_toolkit-1.6.0 → norm_toolkit-1.8.0}/src/norm_toolkit/build_merged.py +0 -0
- {norm_toolkit-1.6.0 → norm_toolkit-1.8.0}/src/norm_toolkit/build_ontology.py +0 -0
- {norm_toolkit-1.6.0 → norm_toolkit-1.8.0}/src/norm_toolkit/build_umls.py +0 -0
- {norm_toolkit-1.6.0 → norm_toolkit-1.8.0}/src/norm_toolkit/constants.py +0 -0
- {norm_toolkit-1.6.0 → norm_toolkit-1.8.0}/src/norm_toolkit/models.py +0 -0
- {norm_toolkit-1.6.0 → norm_toolkit-1.8.0}/src/norm_toolkit/utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "norm_toolkit"
|
|
3
|
-
version = "1.
|
|
3
|
+
version = "1.8.0"
|
|
4
4
|
description = "Toolkit to normalize text to UMLS / ontologies"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
authors = [{ name = "Haydn Jones", email = "haydnjonest@gmail.com" }]
|
|
@@ -24,6 +24,8 @@ dev = [
|
|
|
24
24
|
"pytest>=8.3",
|
|
25
25
|
"rdkit>=2025.9.3",
|
|
26
26
|
"ruff>=0.6.9",
|
|
27
|
+
"fire>=0.7.1",
|
|
28
|
+
"joblib>=1.5.3",
|
|
27
29
|
]
|
|
28
30
|
|
|
29
31
|
[build-system]
|
|
@@ -33,7 +33,7 @@ from norm_toolkit.normalizer_utils import (
|
|
|
33
33
|
build_definitions_sql,
|
|
34
34
|
build_hits_agg_expr,
|
|
35
35
|
build_lookup_sql,
|
|
36
|
-
|
|
36
|
+
build_normalized_query_map,
|
|
37
37
|
build_ontology_filter_clauses,
|
|
38
38
|
build_pref_join,
|
|
39
39
|
build_query_rows,
|
|
@@ -206,7 +206,7 @@ class DuckDBNormalizer:
|
|
|
206
206
|
def normalize(
|
|
207
207
|
self,
|
|
208
208
|
strings: Sequence[str],
|
|
209
|
-
synonyms:
|
|
209
|
+
synonyms: Sequence[Sequence[str] | None] | None = None,
|
|
210
210
|
top_k: int | None = 25,
|
|
211
211
|
ont_top_k: int | None = None,
|
|
212
212
|
prefer_ttys: list[str] | None = None,
|
|
@@ -222,10 +222,10 @@ class DuckDBNormalizer:
|
|
|
222
222
|
|
|
223
223
|
Args:
|
|
224
224
|
strings: Input strings to normalize
|
|
225
|
-
synonyms: Optional
|
|
226
|
-
Synonyms are normalized and used
|
|
227
|
-
to improve matching. Results are
|
|
228
|
-
input string.
|
|
225
|
+
synonyms: Optional list of synonym lists aligned with `strings`
|
|
226
|
+
(same length required). Synonyms are normalized and used
|
|
227
|
+
alongside the main string to improve matching. Results are
|
|
228
|
+
still keyed by the original input string.
|
|
229
229
|
top_k: Maximum number of results per query (mutually exclusive with ont_top_k)
|
|
230
230
|
ont_top_k: Maximum number of results per ontology (mutually exclusive with top_k)
|
|
231
231
|
prefer_ttys: Term types to prefer (e.g., ["PT", "MH"])
|
|
@@ -244,12 +244,15 @@ class DuckDBNormalizer:
|
|
|
244
244
|
if prefer_ttys is None:
|
|
245
245
|
prefer_ttys = DEFAULT_PREFER_TTYS
|
|
246
246
|
|
|
247
|
-
|
|
248
|
-
|
|
247
|
+
strings_list = list(strings)
|
|
248
|
+
query_keys = [f"q{i}" for i in range(len(strings_list))] if synonyms is not None else strings_list
|
|
249
|
+
|
|
250
|
+
# Build normalized string map with per-entry keys
|
|
251
|
+
q_to_nstrs, syn_list = build_normalized_query_map(strings_list, synonyms, query_keys=query_keys)
|
|
249
252
|
|
|
250
253
|
result = self._lookup(
|
|
251
254
|
q_to_nstrs=q_to_nstrs,
|
|
252
|
-
all_queries=
|
|
255
|
+
all_queries=query_keys,
|
|
253
256
|
prefer_ttys=prefer_ttys,
|
|
254
257
|
filter_ontologies=filter_ontologies,
|
|
255
258
|
exclude_ontologies=exclude_ontologies,
|
|
@@ -261,9 +264,11 @@ class DuckDBNormalizer:
|
|
|
261
264
|
coverage_weight=coverage_weight,
|
|
262
265
|
)
|
|
263
266
|
|
|
267
|
+
result = result.with_columns(pl.Series("input_string", strings_list))
|
|
268
|
+
|
|
264
269
|
# Add synonyms column if synonyms were provided
|
|
265
|
-
if synonyms:
|
|
266
|
-
syn_list =
|
|
270
|
+
if synonyms is not None:
|
|
271
|
+
syn_list = syn_list if syn_list is not None else [[] for _ in strings_list]
|
|
267
272
|
result = result.with_columns(pl.Series("synonyms", syn_list))
|
|
268
273
|
|
|
269
274
|
return result
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
"""
|
|
2
|
-
LRU
|
|
2
|
+
LRU caches for normalized string lookups and entity expansion results.
|
|
3
3
|
|
|
4
4
|
Caches at the normalized string level to avoid repeated DB round trips
|
|
5
5
|
for the same normalized forms.
|
|
@@ -10,7 +10,10 @@ from __future__ import annotations
|
|
|
10
10
|
import hashlib
|
|
11
11
|
from collections import OrderedDict
|
|
12
12
|
from dataclasses import dataclass
|
|
13
|
-
from typing import Any
|
|
13
|
+
from typing import Any, Generic, TypeVar
|
|
14
|
+
|
|
15
|
+
K = TypeVar("K")
|
|
16
|
+
V = TypeVar("V")
|
|
14
17
|
|
|
15
18
|
|
|
16
19
|
@dataclass(frozen=True)
|
|
@@ -29,12 +32,21 @@ class CacheKey:
|
|
|
29
32
|
coverage_weight: int
|
|
30
33
|
|
|
31
34
|
|
|
32
|
-
|
|
35
|
+
@dataclass(frozen=True)
|
|
36
|
+
class ExpansionCacheKey:
|
|
37
|
+
"""Immutable cache key for entity expansion results."""
|
|
38
|
+
|
|
39
|
+
concept_id: str
|
|
40
|
+
max_depth: int | None
|
|
41
|
+
filter_ontologies: tuple[str, ...] | None
|
|
42
|
+
max_ids: int | None
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class LRUCache(Generic[K, V]):
|
|
33
46
|
"""
|
|
34
|
-
LRU cache
|
|
47
|
+
LRU cache with basic hit/miss statistics.
|
|
35
48
|
|
|
36
|
-
|
|
37
|
-
and query parameters. Uses an OrderedDict for O(1) LRU eviction.
|
|
49
|
+
Uses an OrderedDict for O(1) LRU eviction.
|
|
38
50
|
"""
|
|
39
51
|
|
|
40
52
|
def __init__(self, maxsize: int = 10000) -> None:
|
|
@@ -45,11 +57,86 @@ class NormalizerCache:
|
|
|
45
57
|
maxsize: Maximum number of entries to cache. When exceeded,
|
|
46
58
|
the least recently used entries are evicted.
|
|
47
59
|
"""
|
|
48
|
-
self._cache: OrderedDict[
|
|
60
|
+
self._cache: OrderedDict[K, V] = OrderedDict()
|
|
49
61
|
self._maxsize = maxsize
|
|
50
62
|
self._hits = 0
|
|
51
63
|
self._misses = 0
|
|
52
64
|
|
|
65
|
+
def get(self, key: K) -> V | None:
|
|
66
|
+
"""
|
|
67
|
+
Get cached value for a key.
|
|
68
|
+
|
|
69
|
+
Args:
|
|
70
|
+
key: Cache key to look up
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
Cached value if found, None if not in cache
|
|
74
|
+
"""
|
|
75
|
+
if key in self._cache:
|
|
76
|
+
# Move to end (most recently used)
|
|
77
|
+
self._cache.move_to_end(key)
|
|
78
|
+
self._hits += 1
|
|
79
|
+
return self._cache[key]
|
|
80
|
+
self._misses += 1
|
|
81
|
+
return None
|
|
82
|
+
|
|
83
|
+
def set(self, key: K, value: V) -> None:
|
|
84
|
+
"""
|
|
85
|
+
Store a value in the cache.
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
key: Cache key
|
|
89
|
+
value: Value to cache
|
|
90
|
+
"""
|
|
91
|
+
if key in self._cache:
|
|
92
|
+
self._cache.move_to_end(key)
|
|
93
|
+
else:
|
|
94
|
+
if len(self._cache) >= self._maxsize:
|
|
95
|
+
# Remove oldest item (LRU eviction)
|
|
96
|
+
self._cache.popitem(last=False)
|
|
97
|
+
self._cache[key] = value
|
|
98
|
+
|
|
99
|
+
def clear(self) -> None:
|
|
100
|
+
"""Clear all cached entries."""
|
|
101
|
+
self._cache.clear()
|
|
102
|
+
self._hits = 0
|
|
103
|
+
self._misses = 0
|
|
104
|
+
|
|
105
|
+
@property
|
|
106
|
+
def size(self) -> int:
|
|
107
|
+
"""Current number of cached entries."""
|
|
108
|
+
return len(self._cache)
|
|
109
|
+
|
|
110
|
+
@property
|
|
111
|
+
def hit_rate(self) -> float:
|
|
112
|
+
"""Cache hit rate (0.0 to 1.0)."""
|
|
113
|
+
total = self._hits + self._misses
|
|
114
|
+
return self._hits / total if total > 0 else 0.0
|
|
115
|
+
|
|
116
|
+
def stats(self) -> dict[str, Any]:
|
|
117
|
+
"""
|
|
118
|
+
Get cache statistics.
|
|
119
|
+
|
|
120
|
+
Returns:
|
|
121
|
+
Dict with size, maxsize, hits, misses, and hit_rate
|
|
122
|
+
"""
|
|
123
|
+
return {
|
|
124
|
+
"size": self.size,
|
|
125
|
+
"maxsize": self._maxsize,
|
|
126
|
+
"hits": self._hits,
|
|
127
|
+
"misses": self._misses,
|
|
128
|
+
"hit_rate": self.hit_rate,
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class NormalizerCache(LRUCache[CacheKey, list[dict[str, Any]]]):
|
|
133
|
+
"""
|
|
134
|
+
LRU cache for normalized string lookup results.
|
|
135
|
+
|
|
136
|
+
Caches the fully enriched hits for a given tuple of normalized strings
|
|
137
|
+
and query parameters.
|
|
138
|
+
"""
|
|
139
|
+
|
|
53
140
|
@staticmethod
|
|
54
141
|
def make_key(
|
|
55
142
|
nstrs: tuple[str, ...],
|
|
@@ -100,68 +187,37 @@ class NormalizerCache:
|
|
|
100
187
|
coverage_weight=coverage_weight,
|
|
101
188
|
)
|
|
102
189
|
|
|
103
|
-
def get(self, key: CacheKey) -> list[dict[str, Any]] | None:
|
|
104
|
-
"""
|
|
105
|
-
Get cached hits for a key.
|
|
106
190
|
|
|
107
|
-
|
|
108
|
-
|
|
191
|
+
class ExpansionCache(LRUCache[ExpansionCacheKey, list[str]]):
|
|
192
|
+
"""
|
|
193
|
+
LRU cache for entity expansion results.
|
|
109
194
|
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
"""
|
|
113
|
-
if key in self._cache:
|
|
114
|
-
# Move to end (most recently used)
|
|
115
|
-
self._cache.move_to_end(key)
|
|
116
|
-
self._hits += 1
|
|
117
|
-
return self._cache[key]
|
|
118
|
-
self._misses += 1
|
|
119
|
-
return None
|
|
195
|
+
Caches expanded concept IDs for a given concept and traversal parameters.
|
|
196
|
+
"""
|
|
120
197
|
|
|
121
|
-
|
|
198
|
+
@staticmethod
|
|
199
|
+
def make_key(
|
|
200
|
+
concept_id: str,
|
|
201
|
+
*,
|
|
202
|
+
max_depth: int | None,
|
|
203
|
+
filter_ontologies: list[str] | None,
|
|
204
|
+
max_ids: int | None,
|
|
205
|
+
) -> ExpansionCacheKey:
|
|
122
206
|
"""
|
|
123
|
-
|
|
207
|
+
Create a cache key from entity expansion parameters.
|
|
124
208
|
|
|
125
209
|
Args:
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
self._cache.move_to_end(key)
|
|
131
|
-
else:
|
|
132
|
-
if len(self._cache) >= self._maxsize:
|
|
133
|
-
# Remove oldest item (LRU eviction)
|
|
134
|
-
self._cache.popitem(last=False)
|
|
135
|
-
self._cache[key] = hits
|
|
136
|
-
|
|
137
|
-
def clear(self) -> None:
|
|
138
|
-
"""Clear all cached entries."""
|
|
139
|
-
self._cache.clear()
|
|
140
|
-
self._hits = 0
|
|
141
|
-
self._misses = 0
|
|
142
|
-
|
|
143
|
-
@property
|
|
144
|
-
def size(self) -> int:
|
|
145
|
-
"""Current number of cached entries."""
|
|
146
|
-
return len(self._cache)
|
|
147
|
-
|
|
148
|
-
@property
|
|
149
|
-
def hit_rate(self) -> float:
|
|
150
|
-
"""Cache hit rate (0.0 to 1.0)."""
|
|
151
|
-
total = self._hits + self._misses
|
|
152
|
-
return self._hits / total if total > 0 else 0.0
|
|
153
|
-
|
|
154
|
-
def stats(self) -> dict[str, Any]:
|
|
155
|
-
"""
|
|
156
|
-
Get cache statistics.
|
|
210
|
+
concept_id: Starting concept ID
|
|
211
|
+
max_depth: Maximum depth to traverse
|
|
212
|
+
filter_ontologies: Ontologies to include
|
|
213
|
+
max_ids: Maximum number of IDs to return
|
|
157
214
|
|
|
158
215
|
Returns:
|
|
159
|
-
|
|
216
|
+
Immutable ExpansionCacheKey instance
|
|
160
217
|
"""
|
|
161
|
-
return
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
}
|
|
218
|
+
return ExpansionCacheKey(
|
|
219
|
+
concept_id=concept_id,
|
|
220
|
+
max_depth=max_depth,
|
|
221
|
+
filter_ontologies=tuple(filter_ontologies) if filter_ontologies else None,
|
|
222
|
+
max_ids=max_ids,
|
|
223
|
+
)
|
|
@@ -25,7 +25,7 @@ from norm_toolkit.constants import (
|
|
|
25
25
|
TYPES_TABLE,
|
|
26
26
|
)
|
|
27
27
|
from norm_toolkit.models import ConceptInfo
|
|
28
|
-
from norm_toolkit.normalizer_cache import NormalizerCache
|
|
28
|
+
from norm_toolkit.normalizer_cache import ExpansionCache, NormalizerCache
|
|
29
29
|
from norm_toolkit.normalizer_utils import (
|
|
30
30
|
apply_concept_name_rows,
|
|
31
31
|
apply_definition_rows,
|
|
@@ -34,7 +34,7 @@ from norm_toolkit.normalizer_utils import (
|
|
|
34
34
|
build_definitions_sql,
|
|
35
35
|
build_hits_agg_expr,
|
|
36
36
|
build_lookup_sql,
|
|
37
|
-
|
|
37
|
+
build_normalized_query_map,
|
|
38
38
|
build_ontology_filter_clauses,
|
|
39
39
|
build_pref_join,
|
|
40
40
|
build_query_rows,
|
|
@@ -94,8 +94,8 @@ class PostgresNormalizer:
|
|
|
94
94
|
schema: PostgreSQL schema where tables are located (default: "public")
|
|
95
95
|
owned_resource: Optional resource with async close() method to clean up
|
|
96
96
|
when this normalizer is closed (e.g., AlloyDB AsyncConnector)
|
|
97
|
-
cache_maxsize: Maximum number of entries in
|
|
98
|
-
enable_cache: Whether to enable caching
|
|
97
|
+
cache_maxsize: Maximum number of entries in each cache
|
|
98
|
+
enable_cache: Whether to enable caching for normalization and expansion
|
|
99
99
|
|
|
100
100
|
Note:
|
|
101
101
|
After creating the normalizer, call `await normalizer.initialize()`
|
|
@@ -110,8 +110,9 @@ class PostgresNormalizer:
|
|
|
110
110
|
self._has_stt = False
|
|
111
111
|
self._initialized = False
|
|
112
112
|
|
|
113
|
-
# Initialize
|
|
113
|
+
# Initialize caches
|
|
114
114
|
self._cache: NormalizerCache | None = NormalizerCache(maxsize=cache_maxsize) if enable_cache else None
|
|
115
|
+
self._expansion_cache: ExpansionCache | None = ExpansionCache(maxsize=cache_maxsize) if enable_cache else None
|
|
115
116
|
|
|
116
117
|
# Build qualified table names
|
|
117
118
|
prefix = f"{schema}." if schema else ""
|
|
@@ -158,7 +159,7 @@ class PostgresNormalizer:
|
|
|
158
159
|
async def normalize(
|
|
159
160
|
self,
|
|
160
161
|
strings: Sequence[str],
|
|
161
|
-
synonyms:
|
|
162
|
+
synonyms: Sequence[Sequence[str] | None] | None = None,
|
|
162
163
|
top_k: int | None = 25,
|
|
163
164
|
ont_top_k: int | None = None,
|
|
164
165
|
prefer_ttys: list[str] | None = None,
|
|
@@ -174,10 +175,10 @@ class PostgresNormalizer:
|
|
|
174
175
|
|
|
175
176
|
Args:
|
|
176
177
|
strings: Input strings to normalize
|
|
177
|
-
synonyms: Optional
|
|
178
|
-
Synonyms are normalized and used
|
|
179
|
-
to improve matching. Results are
|
|
180
|
-
input string.
|
|
178
|
+
synonyms: Optional list of synonym lists aligned with `strings`
|
|
179
|
+
(same length required). Synonyms are normalized and used
|
|
180
|
+
alongside the main string to improve matching. Results are
|
|
181
|
+
still keyed by the original input string.
|
|
181
182
|
top_k: Maximum number of results per query (mutually exclusive with ont_top_k)
|
|
182
183
|
ont_top_k: Maximum number of results per ontology (mutually exclusive with top_k)
|
|
183
184
|
prefer_ttys: Term types to prefer (e.g., ["PT", "MH"])
|
|
@@ -204,6 +205,9 @@ class PostgresNormalizer:
|
|
|
204
205
|
if ont_top_k is not None:
|
|
205
206
|
ont_top_k = max(1, int(ont_top_k))
|
|
206
207
|
|
|
208
|
+
strings_list = list(strings)
|
|
209
|
+
query_keys = [f"q{i}" for i in range(len(strings_list))] if synonyms is not None else strings_list
|
|
210
|
+
|
|
207
211
|
def make_cache_key(nstrs: tuple[str, ...]) -> Any:
|
|
208
212
|
return NormalizerCache.make_key(
|
|
209
213
|
nstrs,
|
|
@@ -218,8 +222,8 @@ class PostgresNormalizer:
|
|
|
218
222
|
coverage_weight=coverage_weight,
|
|
219
223
|
)
|
|
220
224
|
|
|
221
|
-
# Build normalized string map (
|
|
222
|
-
q_to_nstrs =
|
|
225
|
+
# Build normalized string map with per-entry keys (tuples for cache keys)
|
|
226
|
+
q_to_nstrs, syn_list = build_normalized_query_map(strings_list, synonyms, query_keys=query_keys)
|
|
223
227
|
|
|
224
228
|
# Check cache for each input
|
|
225
229
|
cached_hits: dict[str, list[dict[str, Any]]] = {}
|
|
@@ -271,12 +275,15 @@ class PostgresNormalizer:
|
|
|
271
275
|
self._cache.set(cache_key, hits)
|
|
272
276
|
|
|
273
277
|
# Build final result in original order
|
|
274
|
-
result_data = [
|
|
278
|
+
result_data = [
|
|
279
|
+
{"input_string": strings_list[i], "hits": cached_hits.get(query_keys[i], [])}
|
|
280
|
+
for i in range(len(strings_list))
|
|
281
|
+
]
|
|
275
282
|
result = pl.DataFrame(result_data).cast({"hits": pl.List(HIT_STRUCT_TYPE)})
|
|
276
283
|
|
|
277
284
|
# Add synonyms column if synonyms were provided
|
|
278
|
-
if synonyms:
|
|
279
|
-
syn_list =
|
|
285
|
+
if synonyms is not None:
|
|
286
|
+
syn_list = syn_list if syn_list is not None else [[] for _ in strings_list]
|
|
280
287
|
result = result.with_columns(pl.Series("synonyms", syn_list))
|
|
281
288
|
|
|
282
289
|
return result
|
|
@@ -319,8 +326,7 @@ class PostgresNormalizer:
|
|
|
319
326
|
qwords_values = sql_params.add_rows(qword_rows) if qword_rows else ""
|
|
320
327
|
|
|
321
328
|
allq_values = ", ".join(
|
|
322
|
-
f"({sql_params.add(q)}, {sql_params.add_cast(i, 'INTEGER')})"
|
|
323
|
-
for i, q in enumerate(all_queries)
|
|
329
|
+
f"({sql_params.add(q)}, {sql_params.add_cast(i, 'INTEGER')})" for i, q in enumerate(all_queries)
|
|
324
330
|
)
|
|
325
331
|
|
|
326
332
|
# Build preference clauses (parameterized to prevent SQL injection)
|
|
@@ -596,61 +602,138 @@ class PostgresNormalizer:
|
|
|
596
602
|
List of descendant concept IDs ordered by depth (shallowest first),
|
|
597
603
|
excludes the starting concept
|
|
598
604
|
"""
|
|
605
|
+
results = await self.get_narrower_concepts_many(
|
|
606
|
+
[concept_id],
|
|
607
|
+
max_depth=max_depth,
|
|
608
|
+
filter_ontologies=filter_ontologies,
|
|
609
|
+
max_ids=max_ids,
|
|
610
|
+
)
|
|
611
|
+
return results.get(concept_id, [])
|
|
612
|
+
|
|
613
|
+
async def get_narrower_concepts_many(
|
|
614
|
+
self,
|
|
615
|
+
concept_ids: Sequence[str],
|
|
616
|
+
max_depth: int | None = 10,
|
|
617
|
+
filter_ontologies: list[str] | None = None,
|
|
618
|
+
max_ids: int | None = None,
|
|
619
|
+
) -> dict[str, list[str]]:
|
|
620
|
+
"""
|
|
621
|
+
Get narrower (descendant) concept IDs for many roots in one query.
|
|
622
|
+
|
|
623
|
+
Uses the hierarchy edges to walk down the tree/DAG from each root concept.
|
|
624
|
+
|
|
625
|
+
Args:
|
|
626
|
+
concept_ids: Starting concept IDs (broader terms)
|
|
627
|
+
max_depth: Maximum depth to traverse (1 = direct children only, None = all descendants)
|
|
628
|
+
filter_ontologies: Only follow edges from these ontologies (e.g., ["UMLS", "CHEBI"])
|
|
629
|
+
max_ids: Maximum number of concept IDs to return (None = no limit)
|
|
630
|
+
|
|
631
|
+
Returns:
|
|
632
|
+
Dict mapping each concept ID to descendant IDs ordered by depth
|
|
633
|
+
(shallowest first), excluding the starting concept.
|
|
634
|
+
"""
|
|
599
635
|
await self._ensure_initialized()
|
|
600
636
|
|
|
601
|
-
if not self._has_edges:
|
|
602
|
-
return []
|
|
637
|
+
if not self._has_edges or not concept_ids:
|
|
638
|
+
return {cid: [] for cid in concept_ids}
|
|
603
639
|
|
|
604
|
-
|
|
640
|
+
id_list = list(dict.fromkeys(concept_ids))
|
|
641
|
+
|
|
642
|
+
res: dict[str, list[str]] = {}
|
|
643
|
+
missing: list[str] = []
|
|
644
|
+
cache_keys: dict[str, Any] = {}
|
|
645
|
+
|
|
646
|
+
if self._expansion_cache is not None:
|
|
647
|
+
for cid in id_list:
|
|
648
|
+
cache_key = ExpansionCache.make_key(
|
|
649
|
+
cid,
|
|
650
|
+
max_depth=max_depth,
|
|
651
|
+
filter_ontologies=filter_ontologies,
|
|
652
|
+
max_ids=max_ids,
|
|
653
|
+
)
|
|
654
|
+
cache_keys[cid] = cache_key
|
|
655
|
+
cached = self._expansion_cache.get(cache_key)
|
|
656
|
+
if cached is not None:
|
|
657
|
+
res[cid] = cached
|
|
658
|
+
else:
|
|
659
|
+
res[cid] = []
|
|
660
|
+
missing.append(cid)
|
|
661
|
+
else:
|
|
662
|
+
for cid in id_list:
|
|
663
|
+
res[cid] = []
|
|
664
|
+
missing = id_list
|
|
665
|
+
|
|
666
|
+
if not missing:
|
|
667
|
+
return res
|
|
668
|
+
|
|
669
|
+
sql_params = _SqlParams()
|
|
670
|
+
idmap_values = sql_params.add_single_column_values(missing)
|
|
671
|
+
params = sql_params.params
|
|
672
|
+
params["max_depth"] = max_depth
|
|
605
673
|
|
|
606
|
-
# Build ontology filter clause
|
|
607
674
|
ontology_filter = ""
|
|
608
675
|
if filter_ontologies:
|
|
609
|
-
|
|
610
|
-
for i, ont in enumerate(filter_ontologies):
|
|
611
|
-
key = f"ont{i}"
|
|
612
|
-
params[key] = ont
|
|
613
|
-
ont_placeholders.append(f":{key}")
|
|
614
|
-
ontologies_sql = ", ".join(ont_placeholders)
|
|
676
|
+
ontologies_sql = sql_params.add_values(filter_ontologies)
|
|
615
677
|
ontology_filter = f" AND e.ontology IN ({ontologies_sql})"
|
|
616
678
|
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
679
|
+
if max_ids is None:
|
|
680
|
+
select_sql = """
|
|
681
|
+
SELECT root_id, concept_id, MIN(depth) AS min_depth
|
|
682
|
+
FROM walk
|
|
683
|
+
WHERE concept_id != root_id
|
|
684
|
+
GROUP BY root_id, concept_id
|
|
685
|
+
ORDER BY root_id, min_depth, concept_id
|
|
686
|
+
"""
|
|
687
|
+
else:
|
|
620
688
|
params["max_ids"] = max_ids
|
|
621
|
-
|
|
689
|
+
select_sql = """
|
|
690
|
+
SELECT root_id, concept_id, min_depth
|
|
691
|
+
FROM (
|
|
692
|
+
SELECT root_id, concept_id, min_depth,
|
|
693
|
+
ROW_NUMBER() OVER (PARTITION BY root_id ORDER BY min_depth, concept_id) AS rn
|
|
694
|
+
FROM (
|
|
695
|
+
SELECT root_id, concept_id, MIN(depth) AS min_depth
|
|
696
|
+
FROM walk
|
|
697
|
+
WHERE concept_id != root_id
|
|
698
|
+
GROUP BY root_id, concept_id
|
|
699
|
+
) base
|
|
700
|
+
) ranked
|
|
701
|
+
WHERE rn <= :max_ids
|
|
702
|
+
ORDER BY root_id, min_depth, concept_id
|
|
703
|
+
"""
|
|
622
704
|
|
|
623
|
-
# PostgreSQL recursive CTE with named parameters
|
|
624
|
-
# Use CAST() instead of :: to avoid conflicts with SQLAlchemy named params
|
|
625
|
-
# UNION (not UNION ALL) deduplicates on (concept_id, depth) during recursion
|
|
626
|
-
# GROUP BY with MIN(depth) gets shortest path depth for each concept
|
|
627
705
|
query = dedent(
|
|
628
706
|
f"""
|
|
629
|
-
WITH RECURSIVE
|
|
630
|
-
|
|
707
|
+
WITH RECURSIVE idmap(root_id) AS (VALUES {idmap_values}),
|
|
708
|
+
walk(root_id, concept_id, depth) AS (
|
|
709
|
+
SELECT root_id, root_id, 0
|
|
710
|
+
FROM idmap
|
|
631
711
|
|
|
632
712
|
UNION
|
|
633
713
|
|
|
634
|
-
SELECT e.child_id, w.depth + 1
|
|
714
|
+
SELECT w.root_id, e.child_id, w.depth + 1
|
|
635
715
|
FROM walk w
|
|
636
716
|
JOIN {self._edges_table} e ON e.parent_id = w.concept_id
|
|
637
717
|
WHERE (CAST(:max_depth AS INTEGER) IS NULL OR w.depth < :max_depth){ontology_filter}
|
|
638
718
|
)
|
|
639
|
-
|
|
640
|
-
FROM walk
|
|
641
|
-
WHERE concept_id != :concept_id
|
|
642
|
-
GROUP BY concept_id
|
|
643
|
-
ORDER BY min_depth, concept_id{limit_clause}
|
|
719
|
+
{select_sql}
|
|
644
720
|
"""
|
|
645
721
|
)
|
|
646
722
|
|
|
647
723
|
rows = await self._fetch_rows(query, params)
|
|
648
724
|
|
|
649
|
-
|
|
725
|
+
for row in rows:
|
|
726
|
+
res[row["root_id"]].append(row["concept_id"])
|
|
727
|
+
|
|
728
|
+
if self._expansion_cache is not None:
|
|
729
|
+
for cid in missing:
|
|
730
|
+
self._expansion_cache.set(cache_keys[cid], res[cid])
|
|
731
|
+
|
|
732
|
+
return res
|
|
650
733
|
|
|
651
734
|
def cache_stats(self) -> dict[str, Any] | None:
|
|
652
735
|
"""
|
|
653
|
-
Get cache statistics.
|
|
736
|
+
Get normalization cache statistics.
|
|
654
737
|
|
|
655
738
|
Returns:
|
|
656
739
|
Dict with size, maxsize, hits, misses, and hit_rate,
|
|
@@ -660,10 +743,24 @@ class PostgresNormalizer:
|
|
|
660
743
|
return None
|
|
661
744
|
return self._cache.stats()
|
|
662
745
|
|
|
746
|
+
def expansion_cache_stats(self) -> dict[str, Any] | None:
|
|
747
|
+
"""
|
|
748
|
+
Get entity expansion cache statistics.
|
|
749
|
+
|
|
750
|
+
Returns:
|
|
751
|
+
Dict with size, maxsize, hits, misses, and hit_rate,
|
|
752
|
+
or None if caching is disabled.
|
|
753
|
+
"""
|
|
754
|
+
if self._expansion_cache is None:
|
|
755
|
+
return None
|
|
756
|
+
return self._expansion_cache.stats()
|
|
757
|
+
|
|
663
758
|
def clear_cache(self) -> None:
|
|
664
759
|
"""Clear all cached entries."""
|
|
665
760
|
if self._cache is not None:
|
|
666
761
|
self._cache.clear()
|
|
762
|
+
if self._expansion_cache is not None:
|
|
763
|
+
self._expansion_cache.clear()
|
|
667
764
|
|
|
668
765
|
async def close(self) -> None:
|
|
669
766
|
"""
|
|
@@ -21,24 +21,81 @@ from norm_toolkit.constants import (
|
|
|
21
21
|
from norm_toolkit.models import ConceptInfo, SemanticType
|
|
22
22
|
|
|
23
23
|
|
|
24
|
+
def _coerce_synonyms_list(
|
|
25
|
+
strings: Sequence[str],
|
|
26
|
+
synonyms: Sequence[Sequence[str] | None] | None,
|
|
27
|
+
) -> list[list[str]] | None:
|
|
28
|
+
if synonyms is None:
|
|
29
|
+
return None
|
|
30
|
+
if not isinstance(synonyms, Sequence) or isinstance(synonyms, (str, bytes)):
|
|
31
|
+
raise TypeError("synonyms must be a sequence of sequences aligned with strings")
|
|
32
|
+
if len(synonyms) != len(strings):
|
|
33
|
+
raise ValueError("synonyms must have the same length as strings")
|
|
34
|
+
out: list[list[str]] = []
|
|
35
|
+
for i, syns in enumerate(synonyms):
|
|
36
|
+
if syns is None:
|
|
37
|
+
out.append([])
|
|
38
|
+
continue
|
|
39
|
+
if not isinstance(syns, Sequence) or isinstance(syns, (str, bytes)):
|
|
40
|
+
raise ValueError(f"synonyms[{i}] must be a sequence of strings")
|
|
41
|
+
out.append(list(syns))
|
|
42
|
+
return out
|
|
43
|
+
|
|
44
|
+
|
|
24
45
|
def build_normalized_string_map(
|
|
25
46
|
strings: Sequence[str],
|
|
26
|
-
synonyms:
|
|
47
|
+
synonyms: Sequence[Sequence[str] | None] | None = None,
|
|
27
48
|
) -> dict[str, tuple[str, ...]]:
|
|
28
49
|
"""
|
|
29
50
|
Build a mapping of input string -> normalized string variants.
|
|
30
51
|
|
|
31
52
|
Normalized variants are deduplicated while preserving order.
|
|
53
|
+
Duplicate input strings will collapse to the last entry.
|
|
54
|
+
Synonyms must be aligned with `strings` when provided.
|
|
32
55
|
"""
|
|
56
|
+
synonyms_list = _coerce_synonyms_list(strings, synonyms)
|
|
57
|
+
syns_iter = synonyms_list if synonyms_list is not None else [None] * len(strings)
|
|
58
|
+
|
|
33
59
|
q_to_nstrs: dict[str, tuple[str, ...]] = {}
|
|
34
|
-
for s in strings:
|
|
60
|
+
for s, syns in zip(strings, syns_iter):
|
|
35
61
|
nstrs = list(lvg_normalize(s) or [])
|
|
36
|
-
|
|
37
|
-
|
|
62
|
+
if syns:
|
|
63
|
+
for syn in syns:
|
|
64
|
+
nstrs.extend(lvg_normalize(syn) or [])
|
|
38
65
|
q_to_nstrs[s] = tuple(dict.fromkeys(nstrs))
|
|
39
66
|
return q_to_nstrs
|
|
40
67
|
|
|
41
68
|
|
|
69
|
+
def build_normalized_query_map(
|
|
70
|
+
strings: Sequence[str],
|
|
71
|
+
synonyms: Sequence[Sequence[str] | None] | None = None,
|
|
72
|
+
*,
|
|
73
|
+
query_keys: Sequence[str] | None = None,
|
|
74
|
+
) -> tuple[dict[str, tuple[str, ...]], list[list[str]] | None]:
|
|
75
|
+
"""
|
|
76
|
+
Build a mapping of query key -> normalized string variants.
|
|
77
|
+
|
|
78
|
+
Normalized variants are deduplicated while preserving order.
|
|
79
|
+
Synonyms must be aligned with `strings` when provided.
|
|
80
|
+
"""
|
|
81
|
+
if query_keys is None:
|
|
82
|
+
query_keys = list(strings)
|
|
83
|
+
if len(query_keys) != len(strings):
|
|
84
|
+
raise ValueError("query_keys must have the same length as strings")
|
|
85
|
+
|
|
86
|
+
synonyms_list = _coerce_synonyms_list(strings, synonyms)
|
|
87
|
+
syns_iter = synonyms_list if synonyms_list is not None else [None] * len(strings)
|
|
88
|
+
|
|
89
|
+
q_to_nstrs: dict[str, tuple[str, ...]] = {}
|
|
90
|
+
for key, s, syns in zip(query_keys, strings, syns_iter):
|
|
91
|
+
nstrs = list(lvg_normalize(s) or [])
|
|
92
|
+
if syns:
|
|
93
|
+
for syn in syns:
|
|
94
|
+
nstrs.extend(lvg_normalize(syn) or [])
|
|
95
|
+
q_to_nstrs[key] = tuple(dict.fromkeys(nstrs))
|
|
96
|
+
return q_to_nstrs, synonyms_list
|
|
97
|
+
|
|
98
|
+
|
|
42
99
|
def build_query_rows(
|
|
43
100
|
q_to_nstrs: Mapping[str, Sequence[str]],
|
|
44
101
|
*,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|