groundworkers 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,721 @@
1
+ from __future__ import annotations
2
+
3
+ from collections import deque
4
+ from collections.abc import Callable
5
+ from datetime import date
6
+ from typing import Any
7
+
8
+ from omop_graph.extensions.omop_alchemy import PredicateKind
9
+ from omop_graph.graph.constraints import SearchConstraintConcept
10
+ from omop_graph.graph.kg import KnowledgeGraph
11
+ from omop_graph.graph.paths import find_shortest_paths_batch
12
+ from omop_graph.graph.traverse import traverse
13
+ from omop_graph.reasoning.grounding import GroundingConstraints, ground_term
14
+ from omop_graph.reasoning.resolvers import ResolverPipeline
15
+ from omop_graph.reasoning.resolvers.resolvers import (
16
+ EmbeddingResolver,
17
+ ExactLabelResolver,
18
+ ExactSynonymResolver,
19
+ FullTextResolver,
20
+ FullTextSynonymResolver,
21
+ PartialLabelResolver,
22
+ PartialSynonymResolver,
23
+ )
24
+ from omop_alchemy.cdm.model.vocabulary import (
25
+ Concept,
26
+ Concept_Ancestor,
27
+ Concept_Class,
28
+ Domain,
29
+ Vocabulary,
30
+ )
31
+ from sqlalchemy import func, select, text
32
+ from sqlalchemy.engine import Engine
33
+ from sqlalchemy.exc import NoResultFound
34
+
35
+ from groundworkers.base.errors import GroundworkersError
36
+
37
+ # TODO: some of this adapter logic really should be pushed back into
38
+ # the core omop-graph library, but waiting for the use-cases and paths
39
+ # to stabilise first.
40
+
41
+ class OmopGraphAdapter:
42
def __init__(
    self,
    engine: Engine,
    *,
    vocab_schema: str = "omop_vocab",
    emb_model_name: str | None = None,
) -> None:
    """Record the adapter's configuration without touching the database.

    Args:
        engine: SQLAlchemy engine stored on the instance as ``engine``.
        vocab_schema: Schema name stored on the instance as ``vocab_schema``.
        emb_model_name: Optional embedding model name stored on the instance
            as ``emb_model_name``.
    """
    # No knowledge graph has been built yet, so the slot starts empty.
    self._kg: KnowledgeGraph | None = None
    self.emb_model_name = emb_model_name
    self.vocab_schema = vocab_schema
    self.engine = engine
53
+
54
def is_available(self) -> bool:
    """Report whether the knowledge graph backend can be obtained.

    Returns:
        ``True`` when ``_get_kg()`` succeeds, ``False`` when it raises
        ``GroundworkersError``. Any other exception propagates.
    """
    try:
        self._get_kg()
    except GroundworkersError:
        return False
    else:
        return True
60
+
61
def close(self) -> None:
    """Dispose of the engine and drop the cached knowledge graph."""
    engine = self.engine
    engine.dispose()
    # Forget the cached graph so the next _get_kg() call builds a fresh one.
    self._kg = None
64
+
65
def get_concept(self, concept_id: int) -> dict[str, Any] | None:
    """Look up a single concept by id.

    Returns:
        The serialised concept, or ``None`` when the backend reports that
        the concept does not exist.

    Raises:
        GroundworkersError: Any other backend failure, wrapped with
            ``QUERY_ERROR`` as the default code.
    """
    try:
        view = self._get_kg().concept_view(concept_id)
    except Exception as exc:
        if not self._is_not_found(exc):
            raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")
        return None
    return self._serialise_concept_view(view)
73
+
74
def get_concept_by_code(self, vocabulary_id: str, code: str) -> list[dict[str, Any]]:
    """Look up a concept by its vocabulary and source code.

    Returns:
        A one-element list holding the serialised concept, or an empty list
        when the backend reports that nothing matches.

    Raises:
        GroundworkersError: Any other backend failure, wrapped with
            ``QUERY_ERROR`` as the default code.
    """
    try:
        graph = self._get_kg()
        resolved_id = graph.concept_id_by_code(vocabulary_id, code)
        view = graph.concept_view(resolved_id)
    except Exception as exc:
        if not self._is_not_found(exc):
            raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")
        return []
    return [self._serialise_concept_view(view)]
83
+
84
def get_ancestors(self, concept_id: int, max_depth: int) -> list[dict[str, Any]]:
    """Return the ancestors of a concept, walking parent links breadth-first.

    Args:
        concept_id: Concept whose ancestors are requested.
        max_depth: Maximum number of parent hops to follow; direct parents
            are depth 1.

    Returns:
        The list produced by ``_walk_hierarchy`` for the parent direction.

    Raises:
        GroundworkersError: ``NOT_FOUND`` when the concept does not exist;
            a wrapped error (default code ``QUERY_ERROR``) when fetching
            the direct parents fails.
    """
    kg = self._get_kg()
    if self.get_concept(concept_id) is None:
        raise GroundworkersError("NOT_FOUND", f"Concept {concept_id} was not found")

    # Fetch the first level inside a try so backend failures are wrapped the
    # same way as the later levels fetched by _walk_hierarchy; ids are
    # int-normalised exactly as _walk_hierarchy does for deeper levels.
    try:
        seeds = [(int(parent_id), 1) for parent_id in kg.parents(concept_id)]
    except Exception as exc:
        raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")

    queue: deque[tuple[int, int]] = deque(seeds)
    return self._walk_hierarchy(queue=queue, neighbour_getter=kg.parents, max_depth=max_depth)
91
+
92
def ground(
    self,
    query: str,
    limit: int,
    domain: str | None,
    vocabulary_id: str | None,
    parent_ids: tuple[int, ...] | None = None,
) -> dict[str, Any]:
    """Ground free text to ranked standard OMOP concepts.

    Args:
        query: Free-text term to ground.
        limit: Passed to ``ground_term`` as ``max_candidates``.
        domain: Optional domain filter. Domains listed in
            ``_DOMAIN_ROOT_CODES`` are matched case-insensitively and
            rewritten to title case (e.g. "condition" -> "Condition");
            other values pass through unchanged. A blank string is treated
            as "no domain".
        vocabulary_id: Optional vocabulary filter.
        parent_ids: Explicit hierarchy anchors. When omitted, anchors are
            resolved from the domain root, or from every known domain root
            when no domain is given.

    Returns a dict with keys:
        results — ranked list of grounded concepts with scoring fields
        grounding_explanation — summary of which tier matched and what constraints ran

    Raises:
        GroundworkersError: ``QUERY_ERROR`` when no hierarchy anchors can be
            resolved; a wrapped error (default code ``QUERY_ERROR``) when a
            resolver tier fails.
    """
    kg = self._get_kg()

    # A blank domain carries no filter. The search constraint below already
    # ignores it (truthiness test), so make anchor selection agree instead of
    # looking up roots for an empty domain name.
    if domain is not None and not domain.strip():
        domain = None

    # Normalise domain to its canonical OMOP casing (e.g. "condition" → "Condition").
    # OMOP domain_id values are title-cased, whereas the known-roots table is
    # keyed in lowercase, so the table is used only to recognise the domain and
    # the title-cased spelling is what gets passed on.
    if domain is not None:
        lowered = domain.lower()
        if any(key.lower() == lowered for key in self._DOMAIN_ROOT_CODES):
            domain = lowered.capitalize()
        # unknown domain: pass through unchanged

    search_constraint = None
    if domain or vocabulary_id:
        search_constraint = SearchConstraintConcept(
            domains=(domain,) if domain else None,
            vocabularies=(vocabulary_id,) if vocabulary_id else None,
        )

    if parent_ids is not None:
        resolved_parent_ids: tuple[int, ...] = parent_ids
        parent_ids_source = "explicit"
    elif domain is not None:
        resolved_parent_ids = self._get_domain_root_ids(domain)
        parent_ids_source = "domain_root"
    else:
        # No domain filter: collect roots across all known domains so hierarchy
        # anchoring doesn't silently drop every candidate.
        all_roots: list[int] = []
        for known_domain in self._DOMAIN_ROOT_CODES:
            all_roots.extend(self._get_domain_root_ids(known_domain))
        # dict.fromkeys drops repeated anchors while keeping first-seen order.
        resolved_parent_ids = tuple(dict.fromkeys(all_roots))
        parent_ids_source = "all_domain_roots"

    if not resolved_parent_ids:
        raise GroundworkersError(
            "QUERY_ERROR",
            "No hierarchy anchors found — ensure the OMOP vocabulary is bootstrapped "
            "(concept and concept_ancestor tables must be populated).",
        )

    constraints = GroundingConstraints(parent_ids=resolved_parent_ids, search_constraint=search_constraint)

    # Tiered pipeline — short-circuit on the first tier that returns results,
    # avoiding lower-quality resolvers when a better match exists.
    # Each tier pairs the label resolver with its synonym counterpart so that
    # abbreviations, trade names, and alternate spellings are matched at the
    # same confidence level as the primary concept name.
    tiers: list[tuple[Any, ...]] = [
        (ExactLabelResolver(), ExactSynonymResolver()),
        (FullTextResolver(), FullTextSynonymResolver()),
    ]
    if self.emb_model_name:
        tiers.append((EmbeddingResolver(),))
    tiers.append((PartialLabelResolver(), PartialSynonymResolver()))

    results: list[Any] = []
    for tier in tiers:
        pipeline = ResolverPipeline(resolvers=tier)
        try:
            results = ground_term(
                pipeline, kg, query,
                query_embedding=None,
                constraints=constraints,
                max_candidates=limit,
            )
        except Exception as exc:
            raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")
        if results:
            break

    # Concept views only enrich the output (vocabulary/domain/class), so a
    # failure here deliberately degrades to "no enrichment" rather than an error.
    concept_ids = tuple(dict.fromkeys(r.concept_id for r in results))
    try:
        views = {v.concept_id: v for v in kg.concept_views(concept_ids, sort=False)} if concept_ids else {}
    except Exception:
        views = {}

    matched_tier = self._label_match_kind_name(results[0].match_kind) if results else None
    used_embedding = any(getattr(r, "embedding_score", None) is not None for r in results)

    return {
        "results": [self._serialise_ground_result(r, views) for r in results],
        "grounding_explanation": {
            "matched_tier": matched_tier,
            "used_embedding": used_embedding,
            "effective_parent_ids": list(resolved_parent_ids),
            "parent_ids_source": parent_ids_source,
        },
    }
197
+
198
# Lookup table from upper-cased PredicateKind member name to the member itself.
# get_neighbors upper-cases caller input before consulting this table, which is
# what makes the accepted predicate kind names case-insensitive.
_PREDICATE_KIND_NAMES: dict[str, PredicateKind] = {
    kind.name.upper(): kind for kind in PredicateKind
}
200
+
201
def get_neighbors(
    self,
    concept_id: int,
    max_depth: int,
    predicate_kinds: list[str] | None,
    max_nodes: int,
    include_edges: bool,
) -> dict[str, Any]:
    """Explore the neighbourhood of a concept with a bounded traversal.

    The seed concept is handed to ``traverse`` together with the depth and
    node limits; every other node of the returned subgraph is reported as a
    neighbour, and the subgraph's edges are included on request.

    Args:
        concept_id: Seed concept.
        max_depth: Hop limit passed to ``traverse``.
        predicate_kinds: Optional predicate kind names (case-insensitive);
            ``None`` applies no predicate filter.
        max_nodes: Node limit passed to ``traverse``.
        include_edges: When true, serialise the subgraph edges as well.

    Raises:
        GroundworkersError: ``NOT_FOUND`` for an unknown seed,
            ``INVALID_INPUT`` for an unknown predicate kind name, or a
            wrapped error (default code ``QUERY_ERROR``) when traversal fails.
    """
    kg = self._get_kg()
    if self.get_concept(concept_id) is None:
        raise GroundworkersError("NOT_FOUND", f"Concept {concept_id} was not found")

    allowed_kinds: set[PredicateKind] | None = None
    if predicate_kinds is not None:
        allowed_kinds = set()
        for pk_name in predicate_kinds:
            kind = self._PREDICATE_KIND_NAMES.get(pk_name.upper())
            if kind is None:
                valid = sorted(self._PREDICATE_KIND_NAMES)
                raise GroundworkersError(
                    "INVALID_INPUT",
                    f"Unknown predicate_kind {pk_name!r}. Valid values: {valid}",
                )
            allowed_kinds.add(kind)

    try:
        subgraph, graph_trace = traverse(
            kg=kg,
            seeds=(concept_id,),
            predicate_kinds=allowed_kinds,
            max_depth=max_depth,
            on=None,
            max_nodes=max_nodes,
            trace=True,  # tracing is always on so terminated_reason can be reported
        )
    except Exception as exc:
        raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")

    # Everything in the subgraph except the seed, in ascending id order.
    neighbor_ids = tuple(sorted(node for node in subgraph.nodes if node != concept_id))
    try:
        if neighbor_ids:
            views = {v.concept_id: v for v in kg.concept_views(neighbor_ids, sort=False)}
        else:
            views = {}
    except Exception:
        # View lookup is best-effort: on failure no neighbour can be described.
        views = {}

    looked_up = (views.get(nid) for nid in neighbor_ids)
    neighbors: list[dict[str, Any]] = [
        {
            "concept_id": int(view.concept_id),
            "concept_name": view.concept_name,
            "vocabulary_id": view.vocabulary_id,
            "domain_id": view.domain_id,
            "concept_class_id": view.concept_class_id,
            "standard_concept": bool(view.standard_concept),
        }
        for view in looked_up
        if view
    ]

    edges: list[dict[str, Any]] = []
    if include_edges:
        edges = [
            {
                "subject_id": int(edge.subject_id),
                "predicate_id": edge.predicate_id,
                "predicate_kind": edge.predicate_kind.name,
                "object_id": int(edge.object_id),
            }
            for edge in subgraph.edges
        ]

    terminated_reason = None
    if graph_trace:
        terminated_reason = graph_trace.terminated_reason

    return {
        "concept_id": concept_id,
        "neighbor_count": len(neighbors),
        # edge_count reflects the subgraph even when edges are not serialised.
        "edge_count": len(subgraph.edges),
        "neighbors": neighbors,
        "edges": edges,
        "terminated_early": terminated_reason is not None,
        "terminated_reason": terminated_reason,
    }
284
+
285
def get_edges(self, concept_id: int) -> dict[str, Any]:
    """List every edge touching a concept, in both directions.

    Edges are requested with ``active_only=False``.

    Returns:
        A dict with ``outbound`` and ``inbound`` lists of serialised edges.

    Raises:
        GroundworkersError: ``NOT_FOUND`` for an unknown concept, or a
            wrapped error (default code ``QUERY_ERROR``) when the edge
            lookup fails.
    """
    kg = self._get_kg()
    if self.get_concept(concept_id) is None:
        raise GroundworkersError("NOT_FOUND", f"Concept {concept_id} was not found")

    try:
        outbound = kg.edges(concept_id, direction="out", active_only=False)
        inbound = kg.edges(concept_id, direction="in", active_only=False)
    except Exception as exc:
        raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")

    # The concept at the far end of each edge, de-duplicated in first-seen order.
    far_ends = [edge.object_id for edge in outbound]
    far_ends += [edge.subject_id for edge in inbound]
    other_ids = tuple(dict.fromkeys(far_ends))

    views = {}
    try:
        if other_ids:
            views = {v.concept_id: v for v in kg.concept_views(other_ids, sort=False)}
    except Exception:
        # Names are optional decoration; fall back to edges without names.
        views = {}

    outbound_payload = [self._serialise_edge_out(edge, views) for edge in outbound]
    inbound_payload = [self._serialise_edge_in(edge, views) for edge in inbound]
    return {"outbound": outbound_payload, "inbound": inbound_payload}
305
+
306
def find_path(
    self,
    source_id: int,
    target_id: int,
    max_depth: int,
    predicate_kinds: frozenset | None = None,
    within_domain: bool = True,
) -> dict[str, Any]:
    """Find the shortest paths between two concepts.

    Args:
        source_id: Start concept.
        target_id: End concept. When equal to ``source_id`` a single
            zero-length path is returned without searching.
        max_depth: Depth limit passed to ``find_shortest_paths_batch``.
        predicate_kinds: Optional predicate kind filter passed to the search.
        within_domain: Passed to the search unchanged.

    Returns:
        ``{"found": bool, "paths": [...]}`` where each path carries its
        ``length`` and serialised ``steps``, ordered by step count.

    Raises:
        GroundworkersError: ``NOT_FOUND`` for an unknown endpoint, or a
            wrapped error (default code ``QUERY_ERROR``) when the search fails.
    """
    kg = self._get_kg()

    if self.get_concept(source_id) is None:
        raise GroundworkersError("NOT_FOUND", f"Concept {source_id} was not found")
    if source_id == target_id:
        return {"found": True, "paths": [{"length": 0, "steps": []}]}
    if self.get_concept(target_id) is None:
        raise GroundworkersError("NOT_FOUND", f"Concept {target_id} was not found")

    try:
        paths = find_shortest_paths_batch(
            kg,
            source_id,
            target_id,
            max_depth=max_depth,
            predicate_kinds=predicate_kinds,
            within_domain=within_domain,
        )
    except Exception as exc:
        raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")

    if not paths:
        return {"found": False, "paths": []}

    # Every concept mentioned on any path, so names are fetched in one batch.
    mentioned: set[int] = {
        endpoint.concept_id
        for path in paths
        for step in path.steps
        for endpoint in (step.subject, step.object)
    }
    try:
        if mentioned:
            views = {v.concept_id: v for v in kg.concept_views(tuple(mentioned), sort=False)}
        else:
            views = {}
    except Exception:
        # Names are optional; report ids only when the batch lookup fails.
        views = {}

    def describe(step: Any) -> dict[str, Any]:
        """Serialise one path step, tolerating unknown predicates and names."""
        try:
            kind_name = kg.predicate_kind(step.predicate).name
        except Exception:
            kind_name = "UNKNOWN"
        subject_view = views.get(step.subject.concept_id)
        object_view = views.get(step.object.concept_id)
        return {
            "subject_id": int(step.subject.concept_id),
            "subject_name": subject_view.concept_name if subject_view else None,
            "predicate": step.predicate,
            "predicate_kind": kind_name,
            "object_id": int(step.object.concept_id),
            "object_name": object_view.concept_name if object_view else None,
        }

    serialised: list[dict[str, Any]] = []
    for path in sorted(paths, key=lambda p: len(p.steps)):
        steps = [describe(step) for step in path.steps]
        serialised.append({"length": len(steps), "steps": steps})

    return {"found": True, "paths": serialised}
368
+
369
# Predicate-kind presets used by find_equivalency_path: identity edges only,
# or identity plus hierarchy edges.
_IDENTITY_KINDS: frozenset = frozenset((PredicateKind.IDENTITY,))
_IDENTITY_AND_HIERARCHY_KINDS: frozenset = frozenset(
    (PredicateKind.IDENTITY, PredicateKind.HIERARCHY)
)
372
+
373
def find_equivalency_path(
    self,
    source_id: int,
    target_id: int,
    max_depth: int,
    allow_hierarchical_traversal: bool = False,
) -> dict[str, Any]:
    """Find paths that use only identity (and optionally hierarchy) edges.

    With ``allow_hierarchical_traversal=False`` the search is limited to
    IDENTITY predicates (Maps to, Concept same_as, Concept poss_eq, etc.),
    so a hit is a direct cross-vocabulary equivalence with no loss of
    specificity.

    With ``allow_hierarchical_traversal=True`` HIERARCHY predicates (Is a /
    Subsumes) are permitted too. The path may then move up or down the
    ancestry chain, so the target can be an ancestor of the source —
    equivalence at a broader level.

    ``within_domain`` is always passed as ``False``: identity relationships
    are designed to cross vocabulary/domain boundaries.

    Returns:
        Whatever ``find_path`` returns for the restricted predicate set.
    """
    if allow_hierarchical_traversal:
        kinds = self._IDENTITY_AND_HIERARCHY_KINDS
    else:
        kinds = self._IDENTITY_KINDS

    return self.find_path(
        source_id=source_id,
        target_id=target_id,
        max_depth=max_depth,
        predicate_kinds=kinds,
        within_domain=False,
    )
403
+
404
def map_to_standard(self, vocabulary_id: str, code: str) -> dict[str, Any]:
    """Map a source code to its standard concept(s).

    A source concept that is already standard maps to itself. Otherwise the
    active outgoing IDENTITY edges are followed and every target that is a
    standard concept is returned.

    Returns:
        ``{"source": <concept>, "standard_concepts": [<concept>, ...]}``.

    Raises:
        GroundworkersError: ``NOT_FOUND`` when the code is unknown, or a
            wrapped error (default code ``QUERY_ERROR``) when the edge
            lookup fails.
    """
    matches = self.get_concept_by_code(vocabulary_id, code)
    if not matches:
        raise GroundworkersError("NOT_FOUND", f"Concept {vocabulary_id}:{code} was not found")

    source = matches[0]
    if source["standard_concept"]:
        return {"source": source, "standard_concepts": [source]}

    kg = self._get_kg()
    try:
        identity_edges = kg.edges(
            source["concept_id"],
            direction="out",
            predicate_kinds=frozenset({PredicateKind.IDENTITY}),
            active_only=True,
        )
    except Exception as exc:
        raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")

    # Resolve each edge target lazily and keep only existing standard concepts.
    targets = (self.get_concept(int(edge.object_id)) for edge in identity_edges)
    standard_concepts = [
        target for target in targets if target and target["standard_concept"]
    ]
    return {"source": source, "standard_concepts": standard_concepts}
431
+
432
def get_vocabulary_catalogue(self) -> dict[str, Any]:
    """Summarise the vocabularies, domains and concept classes in the database.

    Vocabularies and domains are outer-joined to the concept table so that
    entries without any concepts still appear with a count of zero.

    Returns:
        A dict with ``vocabularies`` and ``domains`` (each entry carrying a
        ``concept_count``) and ``concept_classes``, each ordered by id.

    Raises:
        GroundworkersError: A wrapped error (default code ``QUERY_ERROR``)
            when any of the three queries fails.
    """
    vocab_columns = (Vocabulary.vocabulary_id, Vocabulary.vocabulary_name)
    vocab_stmt = (
        select(*vocab_columns, func.count(Concept.concept_id).label("concept_count"))
        .outerjoin(Concept, Concept.vocabulary_id == Vocabulary.vocabulary_id)
        .group_by(*vocab_columns)
        .order_by(Vocabulary.vocabulary_id)
    )

    domain_columns = (Domain.domain_id, Domain.domain_name)
    domain_stmt = (
        select(*domain_columns, func.count(Concept.concept_id).label("concept_count"))
        .outerjoin(Concept, Concept.domain_id == Domain.domain_id)
        .group_by(*domain_columns)
        .order_by(Domain.domain_id)
    )

    class_stmt = select(
        Concept_Class.concept_class_id, Concept_Class.concept_class_name
    ).order_by(Concept_Class.concept_class_id)

    kg = self._get_kg()
    try:
        # All three queries share one session.
        with kg.session_factory() as session:
            vocab_rows = session.execute(vocab_stmt).all()
            domain_rows = session.execute(domain_stmt).all()
            class_rows = session.execute(class_stmt).all()
    except Exception as exc:
        raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")

    vocabularies = [
        {"vocabulary_id": row[0], "vocabulary_name": row[1], "concept_count": int(row[2])}
        for row in vocab_rows
    ]
    domains = [
        {"domain_id": row[0], "domain_name": row[1], "concept_count": int(row[2])}
        for row in domain_rows
    ]
    concept_classes = [
        {"concept_class_id": row[0], "concept_class_name": row[1]}
        for row in class_rows
    ]
    return {
        "vocabularies": vocabularies,
        "domains": domains,
        "concept_classes": concept_classes,
    }
480
+
481
def get_descendants(self, concept_id: int, max_depth: int) -> list[dict[str, Any]]:
    """Return the descendants of a concept, walking child links breadth-first.

    Args:
        concept_id: Concept whose descendants are requested.
        max_depth: Maximum number of child hops to follow; direct children
            are depth 1.

    Returns:
        The list produced by ``_walk_hierarchy`` for the child direction.

    Raises:
        GroundworkersError: ``NOT_FOUND`` when the concept does not exist;
            a wrapped error (default code ``QUERY_ERROR``) when fetching
            the direct children fails.
    """
    kg = self._get_kg()
    if self.get_concept(concept_id) is None:
        raise GroundworkersError("NOT_FOUND", f"Concept {concept_id} was not found")

    # Fetch the first level inside a try so backend failures are wrapped the
    # same way as the later levels fetched by _walk_hierarchy; ids are
    # int-normalised exactly as _walk_hierarchy does for deeper levels.
    try:
        seeds = [(int(child_id), 1) for child_id in kg.children(concept_id)]
    except Exception as exc:
        raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")

    queue: deque[tuple[int, int]] = deque(seeds)
    return self._walk_hierarchy(queue=queue, neighbour_getter=kg.children, max_depth=max_depth)
488
+
489
def _serialise_ground_result(self, result: object, views: dict) -> dict[str, Any]:
    """Turn one grounding result into a plain dict.

    Args:
        result: Grounding result; ``concept_id``, ``concept_name``,
            ``match_kind`` and ``total_score`` are required, the remaining
            scoring fields are read with defaults.
        views: Concept views keyed by concept id, used for the vocabulary,
            domain and concept class fields (``None`` when absent).

    Returns:
        The serialised result. ``standardized_from`` is filled only when the
        result carries an ``original_id`` different from its concept id.
    """
    view = views.get(getattr(result, "concept_id", None))
    concept_id = int(result.concept_id)  # type: ignore[attr-defined]

    def rounded(attr: str) -> float:
        """Read an optional score (default 0.0) rounded to four decimals."""
        return round(float(getattr(result, attr, 0.0)), 4)

    original_id = getattr(result, "original_id", None)
    if original_id is None or int(original_id) == concept_id:
        standardized_from = None
    else:
        standardized_from = {
            "concept_id": int(original_id),
            "concept_name": getattr(result, "original_name", None),
        }

    emb_score = getattr(result, "embedding_score", None)
    if emb_score is not None:
        emb_score = round(float(emb_score), 4)

    return {
        "concept_id": concept_id,
        "concept_name": result.concept_name,  # type: ignore[attr-defined]
        "vocabulary_id": view.vocabulary_id if view else None,
        "domain_id": view.domain_id if view else None,
        "concept_class_id": view.concept_class_id if view else None,
        "standard_concept": True,
        "match_kind": self._label_match_kind_name(result.match_kind),  # type: ignore[attr-defined]
        "matched_label": getattr(result, "matched_concept_label", None),
        "total_score": round(float(result.total_score), 4),  # type: ignore[attr-defined]
        "relevance": rounded("relevance"),
        "parsimony_penalty": rounded("parsimony_penalty"),
        "broadness_bonus": rounded("broadness_bonus"),
        "embedding_score": emb_score,
        "separation": int(getattr(result, "separation", 0)),
        "standardized_from": standardized_from,
    }
517
+
518
def _serialise_edge_out(self, edge: object, views: dict) -> dict[str, Any]:
    """Serialise an outgoing edge, described from the source's point of view.

    The target concept is the edge's object; its name comes from *views*
    (keyed by integer concept id) and is ``None`` when no view is available.
    """
    target_id = int(edge.object_id)  # type: ignore[attr-defined]
    target_view = views.get(target_id)
    target_name = target_view.concept_name if target_view else None
    return {
        "relationship_id": edge.predicate_id,  # type: ignore[attr-defined]
        "predicate_kind": edge.predicate_kind.name,  # type: ignore[attr-defined]
        "target_concept_id": target_id,
        "target_concept_name": target_name,
        # An edge counts as valid when it has no invalid_reason.
        "valid": edge.invalid_reason is None,  # type: ignore[attr-defined]
    }
527
+
528
def _serialise_edge_in(self, edge: object, views: dict) -> dict[str, Any]:
    """Serialise an incoming edge, described from the target's point of view.

    The source concept is the edge's subject; its name comes from *views*
    (keyed by integer concept id) and is ``None`` when no view is available.
    """
    source_id = int(edge.subject_id)  # type: ignore[attr-defined]
    source_view = views.get(source_id)
    source_name = source_view.concept_name if source_view else None
    return {
        "relationship_id": edge.predicate_id,  # type: ignore[attr-defined]
        "predicate_kind": edge.predicate_kind.name,  # type: ignore[attr-defined]
        "source_concept_id": source_id,
        "source_concept_name": source_name,
        # An edge counts as valid when it has no invalid_reason.
        "valid": edge.invalid_reason is None,  # type: ignore[attr-defined]
    }
537
+
538
@staticmethod
def _label_match_kind_name(match_kind: object) -> str:
    """Translate a match-kind value into its display name.

    An object whose ``value`` attribute is an int listed in the table below
    yields the corresponding name; anything else falls back to
    ``str(match_kind)``.
    """
    fallback = str(match_kind)
    code = getattr(match_kind, "value", None)
    if not isinstance(code, int):
        return fallback
    names = {0: "EXACT", 1: "FULLTEXT", 2: "PARTIAL", 3: "EMBEDDING_NEAREST"}
    return names.get(code, fallback)
545
+
546
def _walk_hierarchy(self, *, queue: deque[tuple[int, int]], neighbour_getter: Callable[[int], Any], max_depth: int) -> list[dict[str, Any]]:
    """Walk a hierarchy breadth-first and serialise every concept reached.

    Args:
        queue: Pre-seeded work queue of ``(concept_id, depth)`` pairs; it is
            consumed in place.
        neighbour_getter: Returns the next-level concept ids for a concept.
        max_depth: Entries deeper than this are skipped, and concepts at this
            depth are not expanded further.

    Returns:
        Serialised concepts sorted by ``(depth, concept_id)``. Each concept
        appears once, at the depth it was first dequeued. Concepts the
        backend reports as missing are left out.

    Raises:
        GroundworkersError: A wrapped error (default code ``QUERY_ERROR``)
            on any other backend failure.
    """
    collected: list[dict[str, Any]] = []
    seen: set[int] = set()
    kg = self._get_kg()

    while queue:
        node_id, level = queue.popleft()
        if level > max_depth or node_id in seen:
            continue
        seen.add(node_id)

        try:
            view = kg.concept_view(node_id)
        except Exception as exc:
            if not self._is_not_found(exc):
                raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")
            # A missing concept is neither reported nor expanded.
            continue

        collected.append(self._serialise_hierarchy_view(view, level))

        if level >= max_depth:
            continue  # at the depth limit: do not expand this node

        try:
            successors = neighbour_getter(node_id)
        except Exception as exc:
            raise self._wrap_graph_error(exc, default_code="QUERY_ERROR")
        queue.extend(
            (int(successor), level + 1)
            for successor in successors
            if successor not in seen
        )

    return sorted(collected, key=lambda item: (item["depth"], item["concept_id"]))
577
+
578
def _get_kg(self) -> KnowledgeGraph:
    """Return the knowledge graph, building and caching it on first use.

    Raises:
        GroundworkersError: ``DB_UNAVAILABLE`` when the connectivity probe
            fails; a wrapped error (default code ``BACKEND_UNAVAIL``) when
            constructing the knowledge graph fails.
    """
    if self._kg is None:
        # Probe the database first so an unreachable server produces a clear
        # message instead of whatever KnowledgeGraph would raise.
        try:
            with self.engine.connect() as conn:
                conn.execute(text("SELECT 1"))
        except Exception as exc:
            raise GroundworkersError("DB_UNAVAILABLE", f"Cannot connect to database: {exc}") from exc

        try:
            self._kg = KnowledgeGraph(cdm_engine=self.engine)
        except Exception as exc:
            raise self._wrap_graph_error(exc, default_code="BACKEND_UNAVAIL")

    return self._kg
595
+
596
# SNOMED concept codes of the top-level concept for each standard OMOP domain,
# keyed by lowercase domain name. Codes are used rather than concept_ids
# because concept_ids may differ between instances, whereas concept_codes are
# consistent across all Athena vocabulary releases.
_DOMAIN_ROOT_CODES: dict[str, tuple[str, str]] = {
    domain_name: ("SNOMED", root_code)
    for domain_name, root_code in (
        ("condition", "404684003"),    # Clinical finding
        ("procedure", "71388002"),     # Procedure
        ("drug", "373873005"),         # Pharmaceutical / biologic product
        ("measurement", "363787002"),  # Observable entity
        ("device", "260787004"),       # Physical object
    )
}
606
+
607
def _get_domain_root_ids(self, domain: str | None) -> tuple[int, ...]:
    """Return up to three top-level concept IDs to use as hierarchy anchors.

    Fast path: look up a known SNOMED root by concept_code (at most one row).
    Fallback for unknown domains, or when the known-code lookup finds
    nothing: the up-to-three standard ancestors in the domain with the most
    concept_ancestor rows.

    Non-empty results are cached on the adapter instance, keyed by the
    lowercased domain name. Empty results are not cached, so an adapter
    created before the vocabulary was loaded picks the anchors up later.

    Args:
        domain: Domain name in any casing; ``None`` or an empty string
            yields an empty tuple.

    Returns:
        A tuple of concept ids; empty when nothing could be resolved.

    Raises:
        GroundworkersError: ``QUERY_ERROR`` when a lookup query fails.
    """
    if not hasattr(self, "_root_ids_cache"):
        self._root_ids_cache: dict[str, tuple[int, ...]] = {}
    if not domain:
        return ()

    domain_key = domain.lower()
    cached = self._root_ids_cache.get(domain_key)
    if cached:
        return cached

    result: tuple[int, ...] = ()
    kg = self._get_kg()

    known_root = self._DOMAIN_ROOT_CODES.get(domain_key)
    if known_root is not None:
        # Fast path: single-row lookup by the stable SNOMED root concept_code.
        vocab_id, code = known_root
        stmt = (
            select(Concept.concept_id)
            .where(
                Concept.concept_code == code,
                Concept.vocabulary_id == vocab_id,
                Concept.standard_concept == "S",
            )
            .limit(1)
        )
        result = self._fetch_root_ids(kg, stmt, domain)

    if not result:
        # Fallback for unknown domains, or when the known-code lookup missed.
        # Find the ancestor with the most descendants in this domain — the true
        # root of the hierarchy has the highest descendant count.
        stmt = (
            select(Concept_Ancestor.ancestor_concept_id)
            .join(Concept, Concept.concept_id == Concept_Ancestor.ancestor_concept_id)
            .where(
                func.lower(Concept.domain_id) == domain_key,
                Concept.standard_concept == "S",
                Concept_Ancestor.min_levels_of_separation > 0,
            )
            .group_by(Concept_Ancestor.ancestor_concept_id)
            .order_by(func.count().desc())
            .limit(3)
        )
        result = self._fetch_root_ids(kg, stmt, domain)

    if result:
        self._root_ids_cache[domain_key] = result
    return result

def _fetch_root_ids(self, kg: KnowledgeGraph, stmt: Any, domain: str) -> tuple[int, ...]:
    """Run *stmt* in a fresh session and return its first column as ints.

    Any failure is re-raised as a ``QUERY_ERROR`` GroundworkersError naming
    *domain*.
    """
    try:
        with kg.session_factory() as session:
            rows = session.execute(stmt).all()
        return tuple(int(row[0]) for row in rows)
    except Exception as exc:
        raise GroundworkersError(
            "QUERY_ERROR",
            f"Failed to resolve hierarchy anchors for domain {domain!r}: {exc}",
        ) from exc
674
+
675
def _serialise_concept_view(self, concept_view: object) -> dict[str, Any]:
    """Convert a concept view into a plain dict.

    ``concept_id`` is coerced to int, ``standard_concept`` to bool, and the
    two validity dates go through ``_date_to_iso``; every other field is
    copied as-is.
    """
    payload: dict[str, Any] = {"concept_id": int(concept_view.concept_id)}  # type: ignore[attr-defined]

    # Fields copied verbatim, in output order.
    for field in ("concept_name", "concept_code", "vocabulary_id", "domain_id", "concept_class_id"):
        payload[field] = getattr(concept_view, field)

    payload["standard_concept"] = bool(concept_view.standard_concept)  # type: ignore[attr-defined]
    payload["valid_start_date"] = self._date_to_iso(concept_view.valid_start_date)  # type: ignore[attr-defined]
    payload["valid_end_date"] = self._date_to_iso(concept_view.valid_end_date)  # type: ignore[attr-defined]
    payload["invalid_reason"] = concept_view.invalid_reason  # type: ignore[attr-defined]
    return payload
688
+
689
def _serialise_hierarchy_view(self, concept_view: object, depth: int) -> dict[str, Any]:
    """Convert a concept view into the compact form used for hierarchy walks.

    Args:
        concept_view: View providing the concept's identifying fields.
        depth: Number of hops from the walk's starting concept, stored as-is.
    """
    payload: dict[str, Any] = {"concept_id": int(concept_view.concept_id)}  # type: ignore[attr-defined]

    # Fields copied verbatim, in output order.
    for field in ("concept_name", "vocabulary_id", "domain_id"):
        payload[field] = getattr(concept_view, field)

    payload["standard_concept"] = bool(concept_view.standard_concept)  # type: ignore[attr-defined]
    payload["depth"] = depth
    return payload
698
+
699
@staticmethod
def _date_to_iso(value: date | str) -> str:
    """Return *value* as an ISO-8601 string.

    ``date`` instances (including ``datetime``) are formatted with
    ``isoformat()``; any other value is returned unchanged.
    """
    return value.isoformat() if isinstance(value, date) else value
704
+
705
@staticmethod
def _is_not_found(exc: Exception) -> bool:
    """Tell whether *exc* means "the requested concept does not exist".

    True for SQLAlchemy's ``NoResultFound`` and for any exception whose class
    hierarchy contains a class named ``NotFoundError`` or
    ``ConceptNotFoundError`` (matched by name, not by identity).
    """
    if not isinstance(exc, NoResultFound):
        marker_names = {"NotFoundError", "ConceptNotFoundError"}
        for klass in type(exc).__mro__:
            if klass.__name__ in marker_names:
                return True
        return False
    return True
710
+
711
@staticmethod
def _wrap_graph_error(exc: Exception, *, default_code: str) -> GroundworkersError:
    """Convert a backend exception into a ``GroundworkersError``.

    Args:
        exc: The exception raised by the graph backend.
        default_code: Error code used when no more specific code applies.

    Returns:
        *exc* itself when it already is a ``GroundworkersError``. Otherwise
        a new error: ``BACKEND_UNAVAIL`` with a setup hint when the message
        mentions the relationship classification, else *default_code* with
        the original message. New errors carry *exc* as ``__cause__`` so the
        original traceback stays attached.
    """
    if isinstance(exc, GroundworkersError):
        return exc

    msg = str(exc)
    if "relationship classification" in msg or "relationship_mapping" in msg:
        wrapped = GroundworkersError(
            "BACKEND_UNAVAIL",
            "omop-graph setup incomplete — run: omop-graph relationship-classification",
        )
    else:
        # Fall back to repr() for exceptions that stringify to "".
        wrapped = GroundworkersError(default_code, msg or repr(exc))

    # Callers raise the returned error without "from", so record the cause
    # here to match the explicit chaining used elsewhere in this class.
    wrapped.__cause__ = exc
    return wrapped