groundworkers 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- groundworkers/__init__.py +3 -0
- groundworkers/adapters/__init__.py +1 -0
- groundworkers/adapters/omop_emb.py +251 -0
- groundworkers/adapters/omop_graph.py +721 -0
- groundworkers/adapters/omop_vocab.py +582 -0
- groundworkers/base/__init__.py +17 -0
- groundworkers/base/errors.py +19 -0
- groundworkers/base/results.py +38 -0
- groundworkers/base/server.py +52 -0
- groundworkers/base/sql.py +109 -0
- groundworkers/config.py +139 -0
- groundworkers/server.py +127 -0
- groundworkers/tools/__init__.py +1 -0
- groundworkers/tools/concept_tools.py +237 -0
- groundworkers/tools/embedding_tools.py +83 -0
- groundworkers/tools/resolver_tools.py +90 -0
- groundworkers/tools/search_tools.py +163 -0
- groundworkers/tools/system_tools.py +67 -0
- groundworkers-0.1.0.dist-info/METADATA +116 -0
- groundworkers-0.1.0.dist-info/RECORD +23 -0
- groundworkers-0.1.0.dist-info/WHEEL +5 -0
- groundworkers-0.1.0.dist-info/entry_points.txt +2 -0
- groundworkers-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from omop_emb import EmbeddingConceptFilter, EmbeddingReaderInterface
|
|
8
|
+
from omop_emb.config import MetricType
|
|
9
|
+
from omop_emb.embeddings.embedding_client import EmbeddingRole
|
|
10
|
+
from omop_emb.interface import list_registered_models
|
|
11
|
+
|
|
12
|
+
from groundworkers.base.errors import GroundworkersError
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class OmopEmbAdapter:
    """Adapter that exposes an ``omop_emb`` embedding backend as plain dicts.

    The backend and the per-model embedding clients are built lazily from the
    factories given to the constructor and cached on the instance until
    :meth:`close` is called.
    """

    def __init__(
        self,
        *,
        backend_factory: Callable[[], object],
        backend_type: str | None,
        default_model_name: str | None = None,
        client_factory: Callable[[str], object] | None = None,
        cdm_engine: object | None = None,
        faiss_cache_dir: str | None = None,
    ) -> None:
        """Store the configuration; no factory is invoked here.

        Args:
            backend_factory: Zero-argument callable that builds the backend.
            backend_type: Label reported by :meth:`index_status`. When falsy,
                a successful status call reports the backend's own
                ``backend_type`` attribute instead.
            default_model_name: Model used when a call passes no ``model_name``.
            client_factory: Callable that builds an embedding client for a
                model name. ``None`` disables :meth:`search` and :meth:`encode`.
            cdm_engine: Forwarded to the reader as ``omop_cdm_engine``.
            faiss_cache_dir: Forwarded to the reader as ``faiss_cache_dir``.
        """
        self._backend_factory = backend_factory
        self._backend_type = backend_type
        self._default_model_name = default_model_name
        self._client_factory = client_factory
        self._cdm_engine = cdm_engine
        self._faiss_cache_dir = faiss_cache_dir
        # Lazily populated caches; see _get_backend() and _get_client().
        self._backend: object | None = None
        self._clients: dict[str, object] = {}

    def is_available(self) -> bool:
        """Return the ``"available"`` flag computed by :meth:`index_status`."""
        return self.index_status()["available"]

    def has_client(self) -> bool:
        """Return ``True`` when a client factory was configured."""
        return self._client_factory is not None

    def close(self) -> None:
        """Close the cached backend (if any) and drop all cached state.

        An exception raised by ``backend.close()`` still propagates, but the
        cached backend and clients are released first so the adapter never
        keeps a reference to a half-closed backend.
        """
        backend = self._backend
        try:
            if backend is not None and hasattr(backend, "close"):
                backend.close()
        finally:
            # Reset unconditionally: a failing close() must not leave stale
            # state behind for the next call to reuse.
            self._backend = None
            self._clients.clear()

    def index_status(self) -> dict[str, Any]:
        """Describe the registered embedding models.

        Returns:
            A dict with keys ``"available"`` (``True`` when at least one model
            is registered), ``"backend_type"`` and ``"models"`` (one dict per
            registered model with ``model_name``, ``provider``, ``dimensions``,
            ``index_type`` and ``concept_count``).

        This is a best-effort probe: any exception raised while querying the
        backend is swallowed and reported as an unavailable, empty status.
        """
        try:
            backend = self._get_backend()
            models: list[dict[str, Any]] = []
            for record in list_registered_models(backend=backend):
                # Records without a metric are counted under COSINE.
                metric_type = record.metric_type or MetricType.COSINE
                concept_count = backend.get_embedding_count(
                    model_name=record.model_name,
                    metric_type=metric_type,
                )
                models.append(
                    {
                        "model_name": record.model_name,
                        "provider": self._enum_value(record.provider_type),
                        "dimensions": int(record.dimensions),
                        "index_type": self._enum_value(record.index_type),
                        "concept_count": int(concept_count),
                    }
                )
            return {
                "available": bool(models),
                "backend_type": self._backend_type or self._backend_type_from_backend(backend),
                "models": models,
            }
        except Exception:
            # Deliberately broad: a status probe must not raise to its caller.
            return {
                "available": False,
                "backend_type": self._backend_type,
                "models": [],
            }

    def get_neighbours(
        self,
        concept_id: int,
        limit: int,
        model_name: str | None,
    ) -> dict[str, Any]:
        """Return up to ``limit`` nearest concepts to an indexed concept.

        Args:
            concept_id: Concept whose stored embedding is used as the query.
            limit: Maximum number of neighbours to return.
            model_name: Model to use; ``None`` falls back to the default model.

        Returns:
            A dict with ``"query_concept_id"``, ``"model_name"`` and
            ``"results"`` (serialised matches, the query concept excluded).

        Raises:
            GroundworkersError: ``NOT_FOUND`` when the reader returns no
                embedding for ``concept_id``; model/backend resolution errors
                are raised by :meth:`_resolve_model_record`.
        """
        record = self._resolve_model_record(model_name)
        reader = self._build_reader(record)
        vectors = reader.get_embeddings_by_concept_ids((concept_id,))
        if concept_id not in vectors:
            raise GroundworkersError("NOT_FOUND", f"Concept {concept_id} is not present in the embedding index")

        # The reader is queried with a single-row 2-D array.
        vector = np.asarray(vectors[concept_id], dtype=float).reshape(1, -1)
        # Request limit+1 so that self-exclusion below still yields `limit` results.
        concept_filter = self._build_concept_filter(limit=limit + 1)
        raw = reader.get_nearest_concepts(
            query_embedding=vector,
            concept_filter=concept_filter,
            k=limit + 1,
        )
        # Only one query row was sent, so only the first result entry is used.
        matches = raw[0] if raw else ()
        results = [
            self._serialise_nearest_match(match)
            for match in matches
            if getattr(match, "concept_id", None) != concept_id
        ][:limit]
        return {
            "query_concept_id": concept_id,
            "model_name": record.model_name,
            "results": results,
        }

    def search(
        self,
        query: str,
        limit: int,
        domain: str | None,
        vocabulary: str | None,
        standard_only: bool,
        active_only: bool,
        model_name: str | None,
    ) -> dict[str, Any]:
        """Embed ``query`` with the model client and return the nearest concepts.

        Args:
            query: Free text to embed and search for.
            limit: Maximum number of matches (used as filter limit and ``k``).
            domain: Optional single domain to restrict matches to.
            vocabulary: Optional single vocabulary to restrict matches to.
            standard_only: Forwarded to the filter as ``require_standard``.
            active_only: Forwarded to the filter as ``require_active``.
            model_name: Model to use; ``None`` falls back to the default model.

        Returns:
            A dict with ``"query_text"``, ``"model_name"`` and ``"results"``.

        Raises:
            GroundworkersError: ``BACKEND_UNAVAIL`` when no client factory is
                configured or the client cannot be built; model resolution
                errors are raised by :meth:`_resolve_model_record`.
        """
        if not self.has_client():
            raise GroundworkersError(
                "BACKEND_UNAVAIL",
                "on-the-fly embedding requires a configured model client",
            )
        record = self._resolve_model_record(model_name)
        reader = self._build_reader(record)
        client = self._get_client(record.model_name)
        concept_filter = self._build_concept_filter(
            limit=limit,
            domain=domain,
            vocabulary=vocabulary,
            standard_only=standard_only,
            active_only=active_only,
        )
        raw = reader.get_nearest_concepts_from_query_texts(
            query_texts=(query,),
            embedding_client=client,
            concept_filter=concept_filter,
            k=limit,
        )
        # A single query text was sent, so only the first result entry is used.
        matches = raw[0] if raw else ()
        return {
            "query_text": query,
            "model_name": record.model_name,
            "results": [self._serialise_nearest_match(match) for match in matches],
        }

    def encode(self, text: str, model_name: str | None) -> dict[str, Any]:
        """Embed ``text`` and return the vector as a list of floats.

        Args:
            text: Text to embed (sent with the ``QUERY`` embedding role).
            model_name: Model to use; ``None`` falls back to the default model.

        Returns:
            A dict with ``"text"``, ``"model_name"``, ``"dimensions"`` and
            ``"vector"``.

        Raises:
            GroundworkersError: ``BACKEND_UNAVAIL`` when no client factory is
                configured; ``QUERY_ERROR`` when the client's output is not a
                2-D array with exactly one row.
        """
        if not self.has_client():
            raise GroundworkersError("BACKEND_UNAVAIL", "embedding client is not configured")
        record = self._resolve_model_record(model_name)
        client = self._get_client(record.model_name)
        vector = client.embeddings(text, embedding_role=EmbeddingRole.QUERY)
        array = np.asarray(vector, dtype=float)
        if array.ndim != 2 or array.shape[0] != 1:
            raise GroundworkersError("QUERY_ERROR", f"Expected one embedding vector, got shape {array.shape}")
        row = array[0]
        return {
            "text": text,
            "model_name": record.model_name,
            "dimensions": int(row.shape[0]),
            "vector": row.tolist(),
        }

    def _get_backend(self) -> object:
        """Return the cached backend, building it on first use.

        Raises:
            GroundworkersError: ``BACKEND_UNAVAIL`` when the factory raises.
        """
        if self._backend is None:
            try:
                self._backend = self._backend_factory()
            except Exception as exc:
                raise GroundworkersError("BACKEND_UNAVAIL", f"Embedding backend is unavailable: {exc}") from exc
        return self._backend

    def _resolve_model_record(self, model_name: str | None) -> object:
        """Pick the registered model record to use for a request.

        The explicit ``model_name`` wins, then the configured default. With
        neither, the sole registered model is used if there is exactly one.

        Raises:
            GroundworkersError: ``NOT_FOUND`` when the requested model is not
                registered; ``BACKEND_UNAVAIL`` when no model is registered,
                or when several are and no default is configured.
        """
        backend = self._get_backend()
        requested_name = model_name or self._default_model_name
        records = list_registered_models(backend=backend, model_name=requested_name)
        if requested_name is not None:
            if not records:
                raise GroundworkersError("NOT_FOUND", f"Embedding model {requested_name!r} is not registered")
            return records[0]
        # No specific model requested — records contains all registered models.
        if len(records) == 1:
            return records[0]
        if not records:
            raise GroundworkersError("BACKEND_UNAVAIL", "No embedding models are registered in the backend")
        raise GroundworkersError(
            "BACKEND_UNAVAIL",
            "No default embedding model is configured and multiple registered models are available",
        )

    def _build_reader(self, record: object) -> EmbeddingReaderInterface:
        """Build a reader for ``record``, defaulting its metric to COSINE."""
        return EmbeddingReaderInterface(
            model=record.model_name,
            backend=self._get_backend(),
            metric_type=record.metric_type or MetricType.COSINE,
            omop_cdm_engine=self._cdm_engine,
            provider_name_or_type=record.provider_type,
            faiss_cache_dir=self._faiss_cache_dir,
        )

    def _get_client(self, model_name: str) -> object:
        """Return the cached embedding client for ``model_name``, building it once.

        Raises:
            GroundworkersError: ``BACKEND_UNAVAIL`` when no client factory is
                configured or the factory raises.
        """
        if self._client_factory is None:
            raise GroundworkersError("BACKEND_UNAVAIL", "embedding client is not configured")
        client = self._clients.get(model_name)
        if client is None:
            try:
                client = self._client_factory(model_name)
            except Exception as exc:
                raise GroundworkersError("BACKEND_UNAVAIL", f"Embedding client is unavailable: {exc}") from exc
            self._clients[model_name] = client
        return client

    def _build_concept_filter(
        self,
        *,
        limit: int,
        domain: str | None = None,
        vocabulary: str | None = None,
        standard_only: bool = False,
        active_only: bool = False,
    ) -> EmbeddingConceptFilter:
        """Translate the flat search options into an ``EmbeddingConceptFilter``.

        A truthy ``domain`` / ``vocabulary`` becomes a one-element tuple; a
        falsy one is passed as ``None``.
        """
        domains = (domain,) if domain else None
        vocabularies = (vocabulary,) if vocabulary else None
        return EmbeddingConceptFilter(
            domains=domains,
            vocabularies=vocabularies,
            require_standard=standard_only,
            require_active=active_only,
            limit=limit,
        )

    def _backend_type_from_backend(self, backend: object) -> str | None:
        """Return the backend's ``backend_type`` attribute as a string, if present."""
        backend_type = getattr(backend, "backend_type", None)
        return self._enum_value(backend_type)

    def _serialise_nearest_match(self, match: object) -> dict[str, Any]:
        """Convert a nearest-neighbour match object into a plain dict.

        ``concept_id`` and ``similarity`` are required attributes (an
        ``AttributeError`` is raised when missing); the other fields default
        to ``None``. Similarity is rounded to six decimal places.
        """
        return {
            "concept_id": int(getattr(match, "concept_id")),
            "concept_name": getattr(match, "concept_name", None),
            "similarity": round(float(getattr(match, "similarity")), 6),
            "is_standard": getattr(match, "is_standard", None),
            "is_active": getattr(match, "is_active", None),
        }

    @staticmethod
    def _enum_value(value: object) -> str | None:
        """Return ``value.value`` when present, else ``str(value)``; ``None`` stays ``None``."""
        if value is None:
            return None
        return getattr(value, "value", str(value))
|