groundworkers 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- groundworkers/__init__.py +3 -0
- groundworkers/adapters/__init__.py +1 -0
- groundworkers/adapters/omop_emb.py +251 -0
- groundworkers/adapters/omop_graph.py +721 -0
- groundworkers/adapters/omop_vocab.py +582 -0
- groundworkers/base/__init__.py +17 -0
- groundworkers/base/errors.py +19 -0
- groundworkers/base/results.py +38 -0
- groundworkers/base/server.py +52 -0
- groundworkers/base/sql.py +109 -0
- groundworkers/config.py +139 -0
- groundworkers/server.py +127 -0
- groundworkers/tools/__init__.py +1 -0
- groundworkers/tools/concept_tools.py +237 -0
- groundworkers/tools/embedding_tools.py +83 -0
- groundworkers/tools/resolver_tools.py +90 -0
- groundworkers/tools/search_tools.py +163 -0
- groundworkers/tools/system_tools.py +67 -0
- groundworkers-0.1.0.dist-info/METADATA +116 -0
- groundworkers-0.1.0.dist-info/RECORD +23 -0
- groundworkers-0.1.0.dist-info/WHEEL +5 -0
- groundworkers-0.1.0.dist-info/entry_points.txt +2 -0
- groundworkers-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,721 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections import deque
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from datetime import date
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from omop_graph.extensions.omop_alchemy import PredicateKind
|
|
9
|
+
from omop_graph.graph.constraints import SearchConstraintConcept
|
|
10
|
+
from omop_graph.graph.kg import KnowledgeGraph
|
|
11
|
+
from omop_graph.graph.paths import find_shortest_paths_batch
|
|
12
|
+
from omop_graph.graph.traverse import traverse
|
|
13
|
+
from omop_graph.reasoning.grounding import GroundingConstraints, ground_term
|
|
14
|
+
from omop_graph.reasoning.resolvers import ResolverPipeline
|
|
15
|
+
from omop_graph.reasoning.resolvers.resolvers import (
|
|
16
|
+
EmbeddingResolver,
|
|
17
|
+
ExactLabelResolver,
|
|
18
|
+
ExactSynonymResolver,
|
|
19
|
+
FullTextResolver,
|
|
20
|
+
FullTextSynonymResolver,
|
|
21
|
+
PartialLabelResolver,
|
|
22
|
+
PartialSynonymResolver,
|
|
23
|
+
)
|
|
24
|
+
from omop_alchemy.cdm.model.vocabulary import (
|
|
25
|
+
Concept,
|
|
26
|
+
Concept_Ancestor,
|
|
27
|
+
Concept_Class,
|
|
28
|
+
Domain,
|
|
29
|
+
Vocabulary,
|
|
30
|
+
)
|
|
31
|
+
from sqlalchemy import func, select, text
|
|
32
|
+
from sqlalchemy.engine import Engine
|
|
33
|
+
from sqlalchemy.exc import NoResultFound
|
|
34
|
+
|
|
35
|
+
from groundworkers.base.errors import GroundworkersError
|
|
36
|
+
|
|
37
|
+
# TODO: some of this adapter logic really should be pushed back into
|
|
38
|
+
# the core omop-graph library, but waiting for the use-cases and paths
|
|
39
|
+
# to stabilise first.
|
|
40
|
+
|
|
41
|
+
class OmopGraphAdapter:
    """Adapter exposing omop-graph ``KnowledgeGraph`` operations as plain dicts.

    Public methods return JSON-friendly ``dict``/``list`` structures and convert
    backend failures into :class:`GroundworkersError` carrying one of the codes
    ``NOT_FOUND``, ``INVALID_INPUT``, ``QUERY_ERROR``, ``DB_UNAVAILABLE`` or
    ``BACKEND_UNAVAIL``.  The ``KnowledgeGraph`` is built lazily on first use
    (see ``_get_kg``) and dropped by :meth:`close`.
    """

    # Valid predicate kind names accepted by get_neighbors (case-insensitive).
    _PREDICATE_KIND_NAMES: dict[str, PredicateKind] = {pk.name.upper(): pk for pk in PredicateKind}

    # Predicate-kind presets for equivalency path tools.
    _IDENTITY_KINDS: frozenset = frozenset({PredicateKind.IDENTITY})
    _IDENTITY_AND_HIERARCHY_KINDS: frozenset = frozenset({PredicateKind.IDENTITY, PredicateKind.HIERARCHY})

    # Stable SNOMED concept codes for the top-level concept in each standard OMOP domain.
    # These are consistent across all Athena vocabulary releases (concept_ids may differ
    # between instances, but concept_codes are stable).  Keys are lower-cased domain names.
    _DOMAIN_ROOT_CODES: dict[str, tuple[str, str]] = {
        "condition": ("SNOMED", "404684003"),  # Clinical finding
        "procedure": ("SNOMED", "71388002"),  # Procedure
        "drug": ("SNOMED", "373873005"),  # Pharmaceutical / biologic product
        "measurement": ("SNOMED", "363787002"),  # Observable entity
        "device": ("SNOMED", "260787004"),  # Physical object
    }

    # Canonical OMOP domain_id spelling (title-cased) for each key of _DOMAIN_ROOT_CODES.
    _CANONICAL_DOMAIN_IDS: dict[str, str] = {
        "condition": "Condition",
        "procedure": "Procedure",
        "drug": "Drug",
        "measurement": "Measurement",
        "device": "Device",
    }

    # Display names for integer-valued resolver match kinds.
    _MATCH_KIND_NAMES: dict[int, str] = {0: "EXACT", 1: "FULLTEXT", 2: "PARTIAL", 3: "EMBEDDING_NEAREST"}

    def __init__(
        self,
        engine: Engine,
        *,
        vocab_schema: str = "omop_vocab",
        emb_model_name: str | None = None,
    ) -> None:
        """Store connection settings; no database work happens here.

        Args:
            engine: SQLAlchemy engine for the OMOP CDM database.
            vocab_schema: Vocabulary schema name (stored; not read within this class).
            emb_model_name: When truthy, enables the embedding tier in :meth:`ground`.
        """
        self.engine = engine
        self.vocab_schema = vocab_schema
        self.emb_model_name = emb_model_name
        self._kg: KnowledgeGraph | None = None
        # Hierarchy-anchor IDs keyed by lower-cased domain; only non-empty lookups are stored.
        self._root_ids_cache: dict[str, tuple[int, ...]] = {}

    def is_available(self) -> bool:
        """Return True if the database is reachable and the KnowledgeGraph can be built."""
        try:
            self._get_kg()
        except GroundworkersError:
            return False
        return True

    def close(self) -> None:
        """Dispose the engine's connection pool and drop the cached KnowledgeGraph."""
        self.engine.dispose()
        self._kg = None

    def get_concept(self, concept_id: int) -> dict[str, Any] | None:
        """Return the serialised concept ``concept_id``, or None if it does not exist.

        Raises:
            GroundworkersError: for any backend failure other than "not found".
        """
        try:
            concept_view = self._get_kg().concept_view(concept_id)
        except Exception as exc:
            if self._is_not_found(exc):
                return None
            raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")
        return self._serialise_concept_view(concept_view)

    def get_concept_by_code(self, vocabulary_id: str, code: str) -> list[dict[str, Any]]:
        """Look up a concept by ``(vocabulary_id, concept_code)``.

        Returns a one-element list on success and an empty list when no such
        concept exists.

        Raises:
            GroundworkersError: for any backend failure other than "not found".
        """
        try:
            kg = self._get_kg()
            concept_id = kg.concept_id_by_code(vocabulary_id, code)
            concept_view = kg.concept_view(concept_id)
        except Exception as exc:
            if self._is_not_found(exc):
                return []
            raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")
        return [self._serialise_concept_view(concept_view)]

    def get_ancestors(self, concept_id: int, max_depth: int) -> list[dict[str, Any]]:
        """Return ancestors of ``concept_id`` within ``max_depth`` hops via ``kg.parents``.

        Each entry carries its hop ``depth``; results are sorted by (depth, concept_id).

        Raises:
            GroundworkersError: ``NOT_FOUND`` if the concept does not exist.
        """
        kg = self._get_kg()
        return self._walk_from(concept_id, neighbour_getter=kg.parents, max_depth=max_depth)

    def get_descendants(self, concept_id: int, max_depth: int) -> list[dict[str, Any]]:
        """Return descendants of ``concept_id`` within ``max_depth`` hops via ``kg.children``.

        Each entry carries its hop ``depth``; results are sorted by (depth, concept_id).

        Raises:
            GroundworkersError: ``NOT_FOUND`` if the concept does not exist.
        """
        kg = self._get_kg()
        return self._walk_from(concept_id, neighbour_getter=kg.children, max_depth=max_depth)

    def ground(
        self,
        query: str,
        limit: int,
        domain: str | None,
        vocabulary_id: str | None,
        parent_ids: tuple[int, ...] | None = None,
    ) -> dict[str, Any]:
        """Ground free text to ranked standard OMOP concepts.

        Resolver tiers are tried in order (exact, full-text, embedding when
        ``emb_model_name`` is set, partial); the first tier returning any
        results wins, so lower-quality resolvers only run when needed.

        Args:
            query: Free text to ground.
            limit: Forwarded to ``ground_term`` as ``max_candidates``.
            domain: Optional domain filter. Known domains are matched
                case-insensitively and normalised to canonical OMOP casing
                (e.g. ``"condition"`` -> ``"Condition"``); blank means no filter.
            vocabulary_id: Optional vocabulary filter.
            parent_ids: Explicit hierarchy anchors. When omitted, anchors come
                from ``domain``, or from the roots of every known domain.

        Returns a dict with keys:
            results — ranked list of grounded concepts with scoring fields
            grounding_explanation — summary of which tier matched and what constraints ran

        Raises:
            GroundworkersError: ``QUERY_ERROR`` when no hierarchy anchors can be
                resolved or a resolver fails.
        """
        kg = self._get_kg()
        domain = self._canonical_domain_id(domain)

        search_constraint = None
        if domain or vocabulary_id:
            search_constraint = SearchConstraintConcept(
                domains=(domain,) if domain else None,
                vocabularies=(vocabulary_id,) if vocabulary_id else None,
            )

        resolved_parent_ids, parent_ids_source = self._resolve_anchor_ids(domain, parent_ids)
        if not resolved_parent_ids:
            raise GroundworkersError(
                "QUERY_ERROR",
                "No hierarchy anchors found — ensure the OMOP vocabulary is bootstrapped "
                "(concept and concept_ancestor tables must be populated).",
            )

        constraints = GroundingConstraints(parent_ids=resolved_parent_ids, search_constraint=search_constraint)

        results: list[Any] = []
        for tier in self._resolver_tiers():
            pipeline = ResolverPipeline(resolvers=tier)
            try:
                results = ground_term(
                    pipeline, kg, query,
                    query_embedding=None,
                    constraints=constraints,
                    max_candidates=limit,
                )
            except Exception as exc:
                raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")
            if results:
                break

        concept_ids = tuple(dict.fromkeys(r.concept_id for r in results))
        views = self._fetch_views(kg, concept_ids)

        matched_tier = self._label_match_kind_name(results[0].match_kind) if results else None
        used_embedding = any(getattr(r, "embedding_score", None) is not None for r in results)

        return {
            "results": [self._serialise_ground_result(r, views) for r in results],
            "grounding_explanation": {
                "matched_tier": matched_tier,
                "used_embedding": used_embedding,
                "effective_parent_ids": list(resolved_parent_ids),
                "parent_ids_source": parent_ids_source,
            },
        }

    def get_neighbors(
        self,
        concept_id: int,
        max_depth: int,
        predicate_kinds: list[str] | None,
        max_nodes: int,
        include_edges: bool,
    ) -> dict[str, Any]:
        """Bounded multi-hop neighborhood exploration via omop-graph ``traverse``.

        Explores from the seed concept up to ``max_depth`` hops / ``max_nodes``
        nodes, optionally restricted to ``predicate_kinds`` (names matched
        case-insensitively against ``PredicateKind``).  Neighbours whose
        concept view cannot be loaded are omitted from ``neighbors``;
        ``edge_count`` always reflects the full traversed subgraph.

        Raises:
            GroundworkersError: ``NOT_FOUND`` for an unknown seed, ``INVALID_INPUT``
                for an unknown predicate kind, ``QUERY_ERROR`` on traversal failure.
        """
        kg = self._get_kg()
        if self.get_concept(concept_id) is None:
            raise GroundworkersError("NOT_FOUND", f"Concept {concept_id} was not found")

        pk_set = self._parse_predicate_kinds(predicate_kinds)

        try:
            subgraph, graph_trace = traverse(
                kg=kg,
                seeds=(concept_id,),
                predicate_kinds=pk_set,
                max_depth=max_depth,
                on=None,
                max_nodes=max_nodes,
                trace=True,  # always trace so we can report terminated_reason
            )
        except Exception as exc:
            raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")

        neighbor_ids = tuple(n for n in sorted(subgraph.nodes) if n != concept_id)
        views = self._fetch_views(kg, neighbor_ids)

        neighbors: list[dict[str, Any]] = []
        for nid in neighbor_ids:
            view = views.get(nid)
            if view is None:
                continue
            neighbors.append({
                "concept_id": int(view.concept_id),
                "concept_name": view.concept_name,
                "vocabulary_id": view.vocabulary_id,
                "domain_id": view.domain_id,
                "concept_class_id": view.concept_class_id,
                "standard_concept": self._is_standard(view.standard_concept),
            })

        edges: list[dict[str, Any]] = []
        if include_edges:
            edges = [
                {
                    "subject_id": int(edge.subject_id),
                    "predicate_id": edge.predicate_id,
                    "predicate_kind": edge.predicate_kind.name,
                    "object_id": int(edge.object_id),
                }
                for edge in subgraph.edges
            ]

        terminated_reason = graph_trace.terminated_reason if graph_trace else None
        return {
            "concept_id": concept_id,
            "neighbor_count": len(neighbors),
            "edge_count": len(subgraph.edges),
            "neighbors": neighbors,
            "edges": edges,
            "terminated_early": terminated_reason is not None,
            "terminated_reason": terminated_reason,
        }

    def get_edges(self, concept_id: int) -> dict[str, Any]:
        """Return all outbound and inbound relationship edges of ``concept_id``.

        Inactive edges are included (``active_only=False``); each edge reports
        ``valid`` from its ``invalid_reason``.

        Raises:
            GroundworkersError: ``NOT_FOUND`` for an unknown concept, ``QUERY_ERROR``
                if the edge lookup fails.
        """
        kg = self._get_kg()
        if self.get_concept(concept_id) is None:
            raise GroundworkersError("NOT_FOUND", f"Concept {concept_id} was not found")
        try:
            outbound = kg.edges(concept_id, direction="out", active_only=False)
            inbound = kg.edges(concept_id, direction="in", active_only=False)
        except Exception as exc:
            raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")

        other_ids = tuple(dict.fromkeys([e.object_id for e in outbound] + [e.subject_id for e in inbound]))
        views = self._fetch_views(kg, other_ids)

        return {
            "outbound": [self._serialise_edge_out(e, views) for e in outbound],
            "inbound": [self._serialise_edge_in(e, views) for e in inbound],
        }

    def find_path(
        self,
        source_id: int,
        target_id: int,
        max_depth: int,
        predicate_kinds: frozenset | None = None,
        within_domain: bool = True,
    ) -> dict[str, Any]:
        """Return paths from ``source_id`` to ``target_id`` found by ``find_shortest_paths_batch``.

        Paths are serialised shortest first.  ``source_id == target_id`` yields a
        single zero-length path without querying the graph.

        Raises:
            GroundworkersError: ``NOT_FOUND`` for an unknown source/target,
                ``QUERY_ERROR`` if path search fails.
        """
        kg = self._get_kg()
        if self.get_concept(source_id) is None:
            raise GroundworkersError("NOT_FOUND", f"Concept {source_id} was not found")
        if source_id == target_id:
            return {"found": True, "paths": [{"length": 0, "steps": []}]}
        if self.get_concept(target_id) is None:
            raise GroundworkersError("NOT_FOUND", f"Concept {target_id} was not found")

        try:
            paths = find_shortest_paths_batch(
                kg,
                source_id,
                target_id,
                max_depth=max_depth,
                predicate_kinds=predicate_kinds,
                within_domain=within_domain,
            )
        except Exception as exc:
            raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")

        if not paths:
            return {"found": False, "paths": []}

        all_concept_ids: set[int] = set()
        for path in paths:
            for step in path.steps:
                all_concept_ids.add(step.subject.concept_id)
                all_concept_ids.add(step.object.concept_id)
        views = self._fetch_views(kg, tuple(all_concept_ids))

        serialised: list[dict[str, Any]] = []
        for path in sorted(paths, key=lambda p: len(p.steps)):
            steps = []
            for step in path.steps:
                try:
                    pred_kind = kg.predicate_kind(step.predicate).name
                except Exception:
                    # Unclassified predicate: report it rather than failing the whole path.
                    pred_kind = "UNKNOWN"
                subj_view = views.get(step.subject.concept_id)
                obj_view = views.get(step.object.concept_id)
                steps.append({
                    "subject_id": int(step.subject.concept_id),
                    "subject_name": subj_view.concept_name if subj_view else None,
                    "predicate": step.predicate,
                    "predicate_kind": pred_kind,
                    "object_id": int(step.object.concept_id),
                    "object_name": obj_view.concept_name if obj_view else None,
                })
            serialised.append({"length": len(steps), "steps": steps})

        return {"found": True, "paths": serialised}

    def find_equivalency_path(
        self,
        source_id: int,
        target_id: int,
        max_depth: int,
        allow_hierarchical_traversal: bool = False,
    ) -> dict[str, Any]:
        """Find paths restricted to identity (and optionally hierarchy) edges.

        When allow_hierarchical_traversal=False only IDENTITY predicates are
        traversed (Maps to, Concept same_as, Concept poss_eq, etc.) — the
        result represents a direct cross-vocabulary equivalence with no loss
        of specificity.

        When allow_hierarchical_traversal=True HIERARCHY predicates (Is a /
        Subsumes) are also allowed. A path may then step up or down the
        ancestry chain to find a connection, meaning the target may be an
        ancestor of the source — equivalence at a broader level.

        within_domain is always False for equivalency paths: identity
        relationships are designed to cross vocabulary/domain boundaries.
        """
        kinds = self._IDENTITY_AND_HIERARCHY_KINDS if allow_hierarchical_traversal else self._IDENTITY_KINDS
        return self.find_path(
            source_id=source_id,
            target_id=target_id,
            max_depth=max_depth,
            predicate_kinds=kinds,
            within_domain=False,
        )

    def map_to_standard(self, vocabulary_id: str, code: str) -> dict[str, Any]:
        """Map a source code to standard concepts via active IDENTITY edges.

        A source concept that is itself standard maps to itself.  Otherwise each
        distinct standard target of an active outbound IDENTITY edge is returned
        once, in edge order.

        Raises:
            GroundworkersError: ``NOT_FOUND`` if the code does not exist,
                ``QUERY_ERROR`` if the edge lookup fails.
        """
        source_list = self.get_concept_by_code(vocabulary_id, code)
        if not source_list:
            raise GroundworkersError("NOT_FOUND", f"Concept {vocabulary_id}:{code} was not found")
        source = source_list[0]

        if source["standard_concept"]:
            return {"source": source, "standard_concepts": [source]}

        kg = self._get_kg()
        try:
            edges = kg.edges(
                source["concept_id"],
                direction="out",
                predicate_kinds=frozenset({PredicateKind.IDENTITY}),
                active_only=True,
            )
        except Exception as exc:
            raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")

        standard_concepts: list[dict[str, Any]] = []
        seen_targets: set[int] = set()
        for edge in edges:
            target_id = int(edge.object_id)
            # Several identity predicates can point at the same target; report it once.
            if target_id in seen_targets:
                continue
            seen_targets.add(target_id)
            target = self.get_concept(target_id)
            if target and target["standard_concept"]:
                standard_concepts.append(target)

        return {"source": source, "standard_concepts": standard_concepts}

    def get_vocabulary_catalogue(self) -> dict[str, Any]:
        """Return vocabularies and domains with per-entry concept counts, plus all concept classes.

        Outer joins keep vocabularies/domains with zero concepts (count 0).

        Raises:
            GroundworkersError: ``QUERY_ERROR`` if any catalogue query fails.
        """
        vocab_stmt = (
            select(
                Vocabulary.vocabulary_id,
                Vocabulary.vocabulary_name,
                func.count(Concept.concept_id).label("concept_count"),
            )
            .outerjoin(Concept, Concept.vocabulary_id == Vocabulary.vocabulary_id)
            .group_by(Vocabulary.vocabulary_id, Vocabulary.vocabulary_name)
            .order_by(Vocabulary.vocabulary_id)
        )
        domain_stmt = (
            select(
                Domain.domain_id,
                Domain.domain_name,
                func.count(Concept.concept_id).label("concept_count"),
            )
            .outerjoin(Concept, Concept.domain_id == Domain.domain_id)
            .group_by(Domain.domain_id, Domain.domain_name)
            .order_by(Domain.domain_id)
        )
        class_stmt = (
            select(Concept_Class.concept_class_id, Concept_Class.concept_class_name)
            .order_by(Concept_Class.concept_class_id)
        )
        kg = self._get_kg()
        try:
            with kg.session_factory() as session:
                vocab_rows = session.execute(vocab_stmt).all()
                domain_rows = session.execute(domain_stmt).all()
                class_rows = session.execute(class_stmt).all()
        except Exception as exc:
            raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")

        return {
            "vocabularies": [
                {"vocabulary_id": r[0], "vocabulary_name": r[1], "concept_count": int(r[2])}
                for r in vocab_rows
            ],
            "domains": [
                {"domain_id": r[0], "domain_name": r[1], "concept_count": int(r[2])}
                for r in domain_rows
            ],
            "concept_classes": [
                {"concept_class_id": r[0], "concept_class_name": r[1]}
                for r in class_rows
            ],
        }

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _canonical_domain_id(cls, domain: str | None) -> str | None:
        """Normalise a user-supplied domain to OMOP's canonical ``domain_id`` casing.

        Known domains (the keys of ``_DOMAIN_ROOT_CODES``) are matched
        case-insensitively and returned title-cased (``"condition"`` ->
        ``"Condition"``); unknown domains pass through stripped; None or
        blank input returns None.
        """
        if domain is None:
            return None
        stripped = domain.strip()
        if not stripped:
            return None
        return cls._CANONICAL_DOMAIN_IDS.get(stripped.lower(), stripped)

    def _resolve_anchor_ids(
        self, domain: str | None, parent_ids: tuple[int, ...] | None
    ) -> tuple[tuple[int, ...], str]:
        """Pick hierarchy anchors for grounding and report where they came from.

        Returns ``(anchor_ids, source)`` where source is ``"explicit"``,
        ``"domain_root"`` or ``"all_domain_roots"``.
        """
        if parent_ids is not None:
            return tuple(parent_ids), "explicit"
        if domain is not None:
            return self._get_domain_root_ids(domain), "domain_root"
        # No domain filter: collect roots across all known domains so hierarchy
        # anchoring doesn't silently drop every candidate.
        all_roots: dict[int, None] = {}
        for known_domain in self._DOMAIN_ROOT_CODES:
            all_roots.update(dict.fromkeys(self._get_domain_root_ids(known_domain)))
        return tuple(all_roots), "all_domain_roots"

    def _resolver_tiers(self) -> list[tuple[Any, ...]]:
        """Build the ordered resolver tiers used by :meth:`ground`.

        Each tier pairs the label resolver with its synonym counterpart so that
        abbreviations, trade names, and alternate spellings are matched at the
        same confidence level as the primary concept name.
        FullTextSynonymResolver degrades gracefully (returns nothing) when the
        tsvector sidecar columns have not been installed.
        """
        tiers: list[tuple[Any, ...]] = [
            (ExactLabelResolver(), ExactSynonymResolver()),
            (FullTextResolver(), FullTextSynonymResolver()),
        ]
        if self.emb_model_name:
            tiers.append((EmbeddingResolver(),))
        tiers.append((PartialLabelResolver(), PartialSynonymResolver()))
        return tiers

    @classmethod
    def _parse_predicate_kinds(cls, predicate_kinds: list[str] | None) -> set[PredicateKind] | None:
        """Translate predicate-kind names (case-insensitive) into ``PredicateKind`` members.

        Returns None when ``predicate_kinds`` is None (no restriction).

        Raises:
            GroundworkersError: ``INVALID_INPUT`` for an unknown name.
        """
        if predicate_kinds is None:
            return None
        pk_set: set[PredicateKind] = set()
        for pk_name in predicate_kinds:
            kind = cls._PREDICATE_KIND_NAMES.get(pk_name.upper())
            if kind is None:
                valid = sorted(cls._PREDICATE_KIND_NAMES)
                raise GroundworkersError(
                    "INVALID_INPUT",
                    f"Unknown predicate_kind {pk_name!r}. Valid values: {valid}",
                )
            pk_set.add(kind)
        return pk_set

    @staticmethod
    def _fetch_views(kg: KnowledgeGraph, concept_ids: tuple[int, ...]) -> dict[Any, Any]:
        """Batch-load concept views keyed by ``concept_id``, best-effort.

        Any failure yields an empty mapping instead of failing the request;
        callers then report None names (get_neighbors omits such neighbours).
        """
        if not concept_ids:
            return {}
        try:
            return {v.concept_id: v for v in kg.concept_views(concept_ids, sort=False)}
        except Exception:
            return {}

    def _serialise_ground_result(self, result: object, views: dict) -> dict[str, Any]:
        """Serialise one grounding result, enriched from ``views`` when available."""
        view = views.get(getattr(result, "concept_id", None))
        concept_id = int(result.concept_id)  # type: ignore[attr-defined]
        original_id = getattr(result, "original_id", None)
        standardized_from = None
        # A differing original_id means the resolver matched a non-standard concept
        # and followed it to this standard one; expose that provenance.
        if original_id is not None and int(original_id) != concept_id:
            standardized_from = {
                "concept_id": int(original_id),
                "concept_name": getattr(result, "original_name", None),
            }
        emb_score = getattr(result, "embedding_score", None)
        return {
            "concept_id": concept_id,
            "concept_name": result.concept_name,  # type: ignore[attr-defined]
            "vocabulary_id": view.vocabulary_id if view else None,
            "domain_id": view.domain_id if view else None,
            "concept_class_id": view.concept_class_id if view else None,
            "standard_concept": True,
            "match_kind": self._label_match_kind_name(result.match_kind),  # type: ignore[attr-defined]
            "matched_label": getattr(result, "matched_concept_label", None),
            "total_score": round(float(result.total_score), 4),  # type: ignore[attr-defined]
            "relevance": round(float(getattr(result, "relevance", 0.0)), 4),
            "parsimony_penalty": round(float(getattr(result, "parsimony_penalty", 0.0)), 4),
            "broadness_bonus": round(float(getattr(result, "broadness_bonus", 0.0)), 4),
            "embedding_score": round(float(emb_score), 4) if emb_score is not None else None,
            "separation": int(getattr(result, "separation", 0)),
            "standardized_from": standardized_from,
        }

    def _serialise_edge_out(self, edge: object, views: dict) -> dict[str, Any]:
        """Serialise an outbound edge from the target's perspective."""
        view = views.get(int(edge.object_id))  # type: ignore[attr-defined]
        return {
            "relationship_id": edge.predicate_id,  # type: ignore[attr-defined]
            "predicate_kind": edge.predicate_kind.name,  # type: ignore[attr-defined]
            "target_concept_id": int(edge.object_id),  # type: ignore[attr-defined]
            "target_concept_name": view.concept_name if view else None,
            "valid": edge.invalid_reason is None,  # type: ignore[attr-defined]
        }

    def _serialise_edge_in(self, edge: object, views: dict) -> dict[str, Any]:
        """Serialise an inbound edge from the source's perspective."""
        view = views.get(int(edge.subject_id))  # type: ignore[attr-defined]
        return {
            "relationship_id": edge.predicate_id,  # type: ignore[attr-defined]
            "predicate_kind": edge.predicate_kind.name,  # type: ignore[attr-defined]
            "source_concept_id": int(edge.subject_id),  # type: ignore[attr-defined]
            "source_concept_name": view.concept_name if view else None,
            "valid": edge.invalid_reason is None,  # type: ignore[attr-defined]
        }

    @classmethod
    def _label_match_kind_name(cls, match_kind: object) -> str:
        """Map a resolver match kind to a stable display name.

        Integer-valued kinds use ``_MATCH_KIND_NAMES``; anything else falls back to ``str()``.
        """
        value = getattr(match_kind, "value", None)
        if isinstance(value, int):
            return cls._MATCH_KIND_NAMES.get(value, str(match_kind))
        return str(match_kind)

    def _walk_from(
        self,
        concept_id: int,
        *,
        neighbour_getter: Callable[[int], Any],
        max_depth: int,
    ) -> list[dict[str, Any]]:
        """Validate ``concept_id`` then walk its hierarchy via ``neighbour_getter``.

        The seed itself is never reported, even if the hierarchy loops back to it.

        Raises:
            GroundworkersError: ``NOT_FOUND`` for an unknown concept, ``QUERY_ERROR``
                if a neighbour lookup fails.
        """
        if self.get_concept(concept_id) is None:
            raise GroundworkersError("NOT_FOUND", f"Concept {concept_id} was not found")
        if max_depth < 1:
            return []
        try:
            # Materialise inside the try so lazily-evaluated lookups are wrapped too.
            first_hop = [int(n) for n in neighbour_getter(concept_id)]
        except Exception as exc:
            raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")
        queue: deque[tuple[int, int]] = deque((n, 1) for n in first_hop)
        return self._walk_hierarchy(
            queue=queue,
            neighbour_getter=neighbour_getter,
            max_depth=max_depth,
            visited={int(concept_id)},
        )

    def _walk_hierarchy(
        self,
        *,
        queue: deque[tuple[int, int]],
        neighbour_getter: Callable[[int], Any],
        max_depth: int,
        visited: set[int] | None = None,
    ) -> list[dict[str, Any]]:
        """Breadth-first walk from pre-seeded ``(concept_id, depth)`` pairs.

        Each concept is emitted at most once, at the first depth it is dequeued
        (its minimal depth when the queue is seeded at a single depth).
        IDs in ``visited`` are never emitted.  Concepts that no longer exist are
        skipped.  Results are sorted by (depth, concept_id).
        """
        results: list[dict[str, Any]] = []
        seen: set[int] = set(visited) if visited else set()
        kg = self._get_kg()

        while queue:
            current_id, depth = queue.popleft()
            if current_id in seen or depth > max_depth:
                continue
            seen.add(current_id)

            try:
                concept_view = kg.concept_view(current_id)
            except Exception as exc:
                if self._is_not_found(exc):
                    continue
                raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")

            results.append(self._serialise_hierarchy_view(concept_view, depth))

            if depth >= max_depth:
                continue
            try:
                next_ids = [int(n) for n in neighbour_getter(current_id)]
            except Exception as exc:
                raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")
            for next_id in next_ids:
                if next_id not in seen:
                    queue.append((next_id, depth + 1))

        results.sort(key=lambda item: (item["depth"], item["concept_id"]))
        return results

    def _get_kg(self) -> KnowledgeGraph:
        """Return the cached KnowledgeGraph, building it on first use.

        Raises:
            GroundworkersError: ``DB_UNAVAILABLE`` if ``SELECT 1`` fails, otherwise
                the wrapped KnowledgeGraph construction error (default ``BACKEND_UNAVAIL``).
        """
        if self._kg is not None:
            return self._kg

        # Fail fast with a clear message if the database is unreachable before
        # KnowledgeGraph has a chance to raise something opaque.
        try:
            with self.engine.connect() as conn:
                conn.execute(text("SELECT 1"))
        except Exception as exc:
            raise GroundworkersError("DB_UNAVAILABLE", f"Cannot connect to database: {exc}") from exc

        try:
            self._kg = KnowledgeGraph(cdm_engine=self.engine)
        except Exception as exc:
            raise self._wrap_graph_error(exc, default_code="BACKEND_UNAVAIL")
        return self._kg

    def _get_domain_root_ids(self, domain: str | None) -> tuple[int, ...]:
        """Return 1–3 top-level concept IDs to use as hierarchy anchors for grounding.

        Fast path: look up a known SNOMED root by concept_code (single-row lookup).
        Fallback for unknown domains (or a missed fast path): the standard
        ancestors in this domain with the most descendants (one GROUP BY query).
        Non-empty results are cached per lower-cased domain; empty results are
        not cached so a vocabulary loaded later is picked up without a restart.

        Raises:
            GroundworkersError: ``QUERY_ERROR`` if a lookup query fails.
        """
        if not domain:
            return ()
        cache_key = domain.lower()
        cached = self._root_ids_cache.get(cache_key)
        if cached:
            return cached

        kg = self._get_kg()
        result: tuple[int, ...] = ()

        known_root = self._DOMAIN_ROOT_CODES.get(cache_key)
        if known_root is not None:
            vocab_id, code = known_root
            stmt = (
                select(Concept.concept_id)
                .where(
                    Concept.concept_code == code,
                    Concept.vocabulary_id == vocab_id,
                    Concept.standard_concept == "S",
                )
                .limit(1)
            )
            result = self._query_anchor_ids(kg, stmt, domain)

        if not result:
            # The true root of a domain hierarchy has the highest descendant count.
            stmt = (
                select(Concept_Ancestor.ancestor_concept_id)
                .join(Concept, Concept.concept_id == Concept_Ancestor.ancestor_concept_id)
                .where(
                    func.lower(Concept.domain_id) == cache_key,
                    Concept.standard_concept == "S",
                    Concept_Ancestor.min_levels_of_separation > 0,
                )
                .group_by(Concept_Ancestor.ancestor_concept_id)
                .order_by(func.count().desc())
                .limit(3)
            )
            result = self._query_anchor_ids(kg, stmt, domain)

        if result:
            self._root_ids_cache[cache_key] = result
        return result

    @staticmethod
    def _query_anchor_ids(kg: KnowledgeGraph, stmt: Any, domain: str) -> tuple[int, ...]:
        """Run a single-column concept-id query for anchor resolution.

        Raises:
            GroundworkersError: ``QUERY_ERROR`` naming ``domain`` on failure.
        """
        try:
            with kg.session_factory() as session:
                rows = session.execute(stmt).all()
        except Exception as exc:
            raise GroundworkersError(
                "QUERY_ERROR",
                f"Failed to resolve hierarchy anchors for domain {domain!r}: {exc}",
            ) from exc
        return tuple(int(r[0]) for r in rows)

    def _serialise_concept_view(self, concept_view: object) -> dict[str, Any]:
        """Serialise a full concept view, with dates as ISO strings."""
        return {
            "concept_id": int(concept_view.concept_id),  # type: ignore[attr-defined]
            "concept_name": concept_view.concept_name,  # type: ignore[attr-defined]
            "concept_code": concept_view.concept_code,  # type: ignore[attr-defined]
            "vocabulary_id": concept_view.vocabulary_id,  # type: ignore[attr-defined]
            "domain_id": concept_view.domain_id,  # type: ignore[attr-defined]
            "concept_class_id": concept_view.concept_class_id,  # type: ignore[attr-defined]
            "standard_concept": self._is_standard(concept_view.standard_concept),  # type: ignore[attr-defined]
            "valid_start_date": self._date_to_iso(concept_view.valid_start_date),  # type: ignore[attr-defined]
            "valid_end_date": self._date_to_iso(concept_view.valid_end_date),  # type: ignore[attr-defined]
            "invalid_reason": concept_view.invalid_reason,  # type: ignore[attr-defined]
        }

    def _serialise_hierarchy_view(self, concept_view: object, depth: int) -> dict[str, Any]:
        """Serialise a concept view for hierarchy walks, tagged with its hop ``depth``."""
        return {
            "concept_id": int(concept_view.concept_id),  # type: ignore[attr-defined]
            "concept_name": concept_view.concept_name,  # type: ignore[attr-defined]
            "vocabulary_id": concept_view.vocabulary_id,  # type: ignore[attr-defined]
            "domain_id": concept_view.domain_id,  # type: ignore[attr-defined]
            "standard_concept": self._is_standard(concept_view.standard_concept),  # type: ignore[attr-defined]
            "depth": depth,
        }

    @staticmethod
    def _is_standard(value: object) -> bool:
        """Interpret an OMOP ``standard_concept`` value as a boolean.

        The OMOP column holds ``'S'`` (standard), ``'C'`` (classification) or
        NULL, so only ``'S'`` counts as standard; plain truthiness would wrongly
        treat ``'C'``.  Non-string values (e.g. an upstream bool) fall back to
        ``bool()``.
        """
        if isinstance(value, str):
            return value.strip().upper() == "S"
        return bool(value)

    @staticmethod
    def _date_to_iso(value: date | str | None) -> str | None:
        """Return ``value`` as an ISO-8601 string if it is a date; otherwise unchanged."""
        if isinstance(value, date):
            return value.isoformat()
        return value

    @staticmethod
    def _is_not_found(exc: Exception) -> bool:
        """Return True if ``exc`` signals a missing row/concept.

        Matches SQLAlchemy ``NoResultFound`` or any exception whose class
        hierarchy contains a class named ``NotFoundError`` / ``ConceptNotFoundError``
        (matched by name to avoid importing backend exception types).
        """
        if isinstance(exc, NoResultFound):
            return True
        return any(cls.__name__ in {"NotFoundError", "ConceptNotFoundError"} for cls in type(exc).__mro__)

    @staticmethod
    def _wrap_graph_error(exc: Exception, *, default_code: str) -> GroundworkersError:
        """Convert a backend exception into a GroundworkersError (returned, not raised).

        GroundworkersError instances pass through unchanged.  Messages mentioning
        the missing relationship classification map to ``BACKEND_UNAVAIL`` with a
        setup hint; everything else uses ``default_code``.  The original exception
        is attached as ``__cause__`` so tracebacks keep the root cause.
        """
        if isinstance(exc, GroundworkersError):
            return exc
        msg = str(exc)
        if "relationship classification" in msg or "relationship_mapping" in msg:
            wrapped = GroundworkersError(
                "BACKEND_UNAVAIL",
                "omop-graph setup incomplete — run: omop-graph relationship-classification",
            )
        else:
            wrapped = GroundworkersError(default_code, msg or repr(exc))
        wrapped.__cause__ = exc
        return wrapped
|