aatm 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aatm/__init__.py +3 -0
- aatm/aio/__init__.py +0 -0
- aatm/aio/selectors.py +173 -0
- aatm/aio/translators.py +79 -0
- aatm/api/__init__.py +0 -0
- aatm/api/config.py +68 -0
- aatm/api/data_models.py +85 -0
- aatm/api/main.py +160 -0
- aatm/data_models.py +921 -0
- aatm/debug.py +91 -0
- aatm/embedding_functions.py +290 -0
- aatm/extractors.py +199 -0
- aatm/local_database_utils.py +334 -0
- aatm/logs.py +111 -0
- aatm/main.py +658 -0
- aatm/omop/__init__.py +4 -0
- aatm/omop/condition_occurrence.py +227 -0
- aatm/omop/device_exposure.py +149 -0
- aatm/omop/drug_exposure.py +249 -0
- aatm/omop/registry.py +23 -0
- aatm/pipeline.py +78 -0
- aatm/prompt_helpers.py +48 -0
- aatm/registries/__init__.py +0 -0
- aatm/registries/rerankers.py +114 -0
- aatm/registries/retrievers.py +175 -0
- aatm/registries/selectors.py +157 -0
- aatm/registries/translators.py +93 -0
- aatm/rerankers.py +368 -0
- aatm/retrievers.py +200 -0
- aatm/search_ui.py +129 -0
- aatm/selectors.py +448 -0
- aatm/sql_commands.yaml +78 -0
- aatm/terminology_mapper.py +594 -0
- aatm/time.py +18 -0
- aatm/translators.py +294 -0
- aatm-0.1.0.dist-info/METADATA +241 -0
- aatm-0.1.0.dist-info/RECORD +40 -0
- aatm-0.1.0.dist-info/WHEEL +4 -0
- aatm-0.1.0.dist-info/entry_points.txt +3 -0
- aatm-0.1.0.dist-info/licenses/LICENSE +21 -0
aatm/__init__.py
ADDED
aatm/aio/__init__.py
ADDED
|
File without changes
|
aatm/aio/selectors.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from typing import List
|
|
3
|
+
import dotenv
|
|
4
|
+
from openai import AsyncOpenAI
|
|
5
|
+
from google.genai import types
|
|
6
|
+
|
|
7
|
+
from aatm.data_models import (
|
|
8
|
+
EmptySelectionMetadata,
|
|
9
|
+
RetrieverResults,
|
|
10
|
+
SelectedExpressionMetadata,
|
|
11
|
+
SelectedResult,
|
|
12
|
+
SelectorResults,
|
|
13
|
+
)
|
|
14
|
+
from aatm.prompt_helpers import format_prompt
|
|
15
|
+
from aatm.selectors import OpenAILLMSelector, GeminiLLMSelector
|
|
16
|
+
from aatm.debug import DebugMode, get_debug_mode
|
|
17
|
+
from aatm.logs import get_logger
|
|
18
|
+
|
|
19
|
+
# Read a local .env file into the process environment at import time, so that
# settings such as the credentials ``AsyncOpenAI()`` picks up from the
# environment are available before any selector is constructed.
dotenv.load_dotenv()

# Module-wide logger and debug switch. The debug mode is resolved once, here at
# import time (after .env has been loaded), not per call.
logger = get_logger(__name__)
debug_mode = get_debug_mode()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class AsyncOpenAILLMSelector(OpenAILLMSelector):
    """Asynchronous variant of :class:`OpenAILLMSelector`.

    One structured-output request (``responses.parse`` with
    ``text_format=SelectedResult``) is issued per query, all concurrently
    inside an ``asyncio.TaskGroup``. Each parsed answer is then mapped back
    onto the retriever candidates for that query.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # ``select`` awaits ``self.client.responses.parse``, so an async client
        # is required here.
        self.client = AsyncOpenAI()

    def _format_selection_prompt(self, query: str, candidates) -> object:
        """Render the selection prompt for a single query and its candidates."""
        return format_prompt(
            self.prompt_template,
            {
                "json_format": SelectedResult.model_json_schema(),
                # Avoid case sensitivity for some queries like 'cough' and
                # 'Cough'. Note: str.capitalize() also lower-cases every other
                # character (e.g. 'COPD' -> 'Copd').
                "query": query.capitalize(),
                "expressions": [c.to_prompt_object() for c in candidates],
            },
        )

    @staticmethod
    def _resolve_selection(selected_result: SelectedResult, candidates):
        """Map the model's chosen ``expression_id`` onto the candidate list.

        Returns:
            ``SelectedExpressionMetadata`` for the first candidate whose
            ``expression_id`` equals the selected one (carrying its index in
            ``candidates``). Returns ``EmptySelectionMetadata`` when the model
            selected nothing or named an id that is not among the candidates.
        """
        if selected_result.expression_id is None:
            return EmptySelectionMetadata()

        for result_idx, candidate in enumerate(candidates):
            if candidate.expression_id == selected_result.expression_id:
                return SelectedExpressionMetadata(
                    **candidate.model_dump(), result_list_index=result_idx
                )

        # The model returned an id that was not offered to it.
        return EmptySelectionMetadata()

    async def select(self, results: RetrieverResults) -> SelectorResults:
        """Select at most one candidate expression per query using the LLM.

        Args:
            results: Retriever output; ``results.results[i]`` holds the
                candidates for ``results.queries[i]``.

        Returns:
            A ``SelectorResults`` with one metadata entry per query, in query
            order.

        Raises:
            TypeError: If the OpenAI response has no parsed ``SelectedResult``
                (e.g. ``output_parsed`` is ``None``). This is an explicit
                check rather than an ``assert``, so it is not stripped
                under ``python -O``.
        """
        selector_results = SelectorResults(results=[], queries=results.queries)

        async with asyncio.TaskGroup() as tg:
            tasks = [
                tg.create_task(
                    self.client.responses.parse(
                        model=self.model_id,
                        input=self._format_selection_prompt(
                            query, results.results[query_id]
                        ),
                        text_format=SelectedResult,
                    )
                )
                for query_id, query in enumerate(results.queries)
            ]

        # Leaving the TaskGroup means every task finished successfully, so
        # ``task.result()`` cannot raise here.
        for query_id, task in enumerate(tasks):
            selected_result = task.result().output_parsed
            if not isinstance(selected_result, SelectedResult):
                raise TypeError(
                    f"Expected SelectedResult object from OpenAI, but got {type(selected_result)}"
                )
            selector_results.results.append(
                self._resolve_selection(selected_result, results.results[query_id])
            )

        return selector_results
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class AsyncGeminiLLMSelector(GeminiLLMSelector):
    """Asynchronous variant of :class:`GeminiLLMSelector`.

    One ``generate_content`` request per query is sent concurrently through
    the client's ``aio`` interface. Each request asks for JSON matching
    :class:`SelectedResult`, and every answer is mapped back onto that
    query's retriever candidates.
    """

    def _format_selection_contents(self, query: str, candidates) -> list:
        """Render the selection prompt for one query as Gemini contents.

        Every message produced by ``format_prompt`` becomes a
        ``types.UserContent`` built from its ``"content"`` entry. The
        messages' roles are not forwarded.
        """
        prompt = format_prompt(
            self.prompt_template,
            {
                "json_format": SelectedResult.model_json_schema(),
                # Avoid case sensitivity for some queries like 'cough' and
                # 'Cough'. Note: str.capitalize() also lower-cases every other
                # character (e.g. 'COPD' -> 'Copd').
                "query": query.capitalize(),
                "expressions": [c.to_prompt_object() for c in candidates],
            },
        )

        gemini_formatted_prompt = [types.UserContent(msg["content"]) for msg in prompt]

        if debug_mode == DebugMode.GEMINI_LLM_SELECTOR:
            logger.debug(prompt)
            logger.debug(gemini_formatted_prompt)

        return gemini_formatted_prompt

    @staticmethod
    def _resolve_selection(selected_result: SelectedResult, candidates):
        """Map the model's chosen ``expression_id`` onto the candidate list.

        Returns:
            ``SelectedExpressionMetadata`` for the first candidate whose
            ``expression_id`` equals the selected one (carrying its index in
            ``candidates``). Returns ``EmptySelectionMetadata`` when the model
            selected nothing or named an id that is not among the candidates.
        """
        if selected_result.expression_id is None:
            return EmptySelectionMetadata()

        for result_idx, candidate in enumerate(candidates):
            if candidate.expression_id == selected_result.expression_id:
                return SelectedExpressionMetadata(
                    **candidate.model_dump(), result_list_index=result_idx
                )

        # The model returned an id that was not offered to it.
        return EmptySelectionMetadata()

    async def select(self, results: RetrieverResults) -> SelectorResults:
        """Select at most one candidate expression per query using Gemini.

        Args:
            results: Retriever output; ``results.results[i]`` holds the
                candidates for ``results.queries[i]``.

        Returns:
            A ``SelectorResults`` with one metadata entry per query, in query
            order.

        Raises:
            pydantic.ValidationError: If a response body is not valid JSON
                for ``SelectedResult`` (raised by ``model_validate_json``).
        """
        selector_results = SelectorResults(results=[], queries=results.queries)

        async with asyncio.TaskGroup() as tg:
            tasks = []
            for query_id, query in enumerate(results.queries):
                contents = self._format_selection_contents(
                    query, results.results[query_id]
                )
                tasks.append(
                    tg.create_task(
                        self.client.aio.models.generate_content(
                            model=self.model_id,
                            contents=contents,
                            config=types.GenerateContentConfig(
                                response_mime_type="application/json",
                                response_schema=SelectedResult,
                            ),
                        )
                    )
                )

        # Leaving the TaskGroup means every task finished successfully.
        responses: List[types.GenerateContentResponse] = [t.result() for t in tasks]

        for query_id, response in enumerate(responses):
            if debug_mode == DebugMode.GEMINI_LLM_SELECTOR:
                # Log the raw response *before* parsing, so malformed output
                # is still visible when validation below fails.
                logger.debug(response)
                logger.debug(response.text)

            # model_validate_json returns a SelectedResult or raises, so no
            # extra type check is needed afterwards.
            selected_result = SelectedResult.model_validate_json(response.text)

            if debug_mode == DebugMode.GEMINI_LLM_SELECTOR:
                logger.debug(selected_result)

            selector_results.results.append(
                self._resolve_selection(selected_result, results.results[query_id])
            )

        return selector_results
|
aatm/aio/translators.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
from google.genai import types
|
|
4
|
+
from typing import List
|
|
5
|
+
|
|
6
|
+
from openai import AsyncOpenAI
|
|
7
|
+
|
|
8
|
+
# Custom modules
|
|
9
|
+
from aatm.data_models import Translation
|
|
10
|
+
from aatm.prompt_helpers import format_prompt
|
|
11
|
+
from aatm.translators import BaseTranslator, GeminiTranslator, OpenAITranslator
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class EmptyTranslator(BaseTranslator):
    """No-op translator: every input text is returned untranslated."""

    async def translate(self, texts: List[str]) -> List[Translation]:
        """Wrap each input string in a ``Translation`` unchanged, in order."""
        translations: List[Translation] = []
        for original in texts:
            translations.append(Translation(text=original))
        return translations
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AsyncGeminiTranslator(GeminiTranslator):
    """Asynchronous variant of :class:`GeminiTranslator`.

    Each input text gets its own concurrent ``generate_content`` call, which
    asks for JSON matching :class:`Translation`. If a response cannot be
    turned into a ``Translation``, the untranslated text is kept.
    """

    async def translate(self, texts: List[str]) -> List[Translation]:
        """Translate ``texts`` concurrently, preserving input order.

        Args:
            texts: Source strings to translate.

        Returns:
            One ``Translation`` per input text, in input order. When a response
            cannot be parsed, a message is printed and
            ``Translation(text=<original text>)`` is used in its place.
        """
        async with asyncio.TaskGroup() as group:
            pending = []
            for source_text in texts:
                request = self.client.aio.models.generate_content(
                    model=self.model,
                    contents=self.prompt_template.format(text=source_text),
                    config=types.GenerateContentConfig(
                        response_mime_type="application/json",
                        response_schema=Translation,
                    ),
                )
                pending.append(group.create_task(request))

        # Every task completed successfully once the TaskGroup has exited.
        translations = []
        for source_text, task in zip(texts, pending):
            response = task.result()
            try:
                payload = json.loads(response.text)
                translations.append(Translation(**payload))
            except Exception as e:
                # Best-effort: keep the original text rather than failing the batch.
                print(
                    f"Error while processing text '{source_text}' and response '{response}': {e}. Original text was maintained."
                )
                translations.append(Translation(text=source_text))
        return translations
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class AsyncOpenAITranslator(OpenAITranslator):
    """Asynchronous variant of :class:`OpenAITranslator`.

    One structured-output request (``responses.parse`` with
    ``text_format=Translation``) is issued per input text, all concurrently
    inside an ``asyncio.TaskGroup``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # ``translate`` awaits ``self.client.responses.parse``, so an async
        # client is required here.
        self.client = AsyncOpenAI()

    async def translate(self, texts: List[str]) -> List[Translation]:
        """Translate ``texts`` concurrently, preserving input order.

        Args:
            texts: Source strings to translate.

        Returns:
            One ``Translation`` per input text, in input order. If a response
            carries no parsed ``Translation`` (e.g. ``output_parsed`` is
            ``None`` after a refusal), a message is printed and
            ``Translation(text=<original text>)`` is used instead. The
            previous code appended ``None`` in that case, because reading
            ``output_parsed`` never raised.
        """
        async with asyncio.TaskGroup() as tg:
            tasks = [
                tg.create_task(
                    self.client.responses.parse(
                        model=self.model_id,
                        input=format_prompt(self.prompt_template, {"text": text}),
                        text_format=Translation,
                    )
                )
                for text in texts
            ]

        # Every task completed successfully once the TaskGroup has exited.
        processed_results: List[Translation] = []
        for task, text in zip(tasks, texts):
            result = task.result()
            try:
                parsed = result.output_parsed
                if not isinstance(parsed, Translation):
                    raise TypeError(
                        f"expected a parsed Translation, got {type(parsed).__name__}"
                    )
            except Exception as e:
                print(
                    f"Error while processing text '{text}' and response '{result}': {e}. Original text was maintained."
                )
                processed_results.append(Translation(text=text))
            else:
                processed_results.append(parsed)
        return processed_results
|
aatm/api/__init__.py
ADDED
|
File without changes
|
aatm/api/config.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""Configuration model and persistence utilities for the AATM API.
|
|
2
|
+
|
|
3
|
+
This module defines the `APIConfig` model, which stores runtime configuration
|
|
4
|
+
for the API server and provides helper methods to save the configuration to
|
|
5
|
+
disk and load it back from a YAML file.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from pathlib import Path
from typing import Any, ClassVar, Optional

import yaml
from pydantic import BaseModel, ConfigDict, field_validator
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class APIConfig(BaseModel):
    """Configuration model for the AATM API server.

    This model stores the runtime settings used to serve the API: host, port,
    batching behavior, optional rate limiting, and worker configuration. It
    can be saved to and loaded from a YAML file.

    Attributes:
        DEFAULT_PATH: Class-level default file used by ``save_to_disk`` and
            ``load_from_disk``; relative, so it resolves against the current
            working directory. It is a ``ClassVar``, so it is not a model
            field and is not written into the YAML file.
        host: Host interface on which the API server listens.
        port: Port on which the API server listens, stored as a string.
            Integer inputs (e.g. ``port: 8000`` in YAML) are converted to
            strings.
        batch_size: Batch size used by the API processing pipeline.
        rate_limit: Optional maximum number of documents allowed per minute.
        workers: Optional number of worker processes.
    """

    # ClassVar keeps this a constant instead of a (serialized) pydantic field.
    DEFAULT_PATH: ClassVar[Path] = Path(".aatm/api_config.yaml")
    host: str
    port: str
    batch_size: int
    rate_limit: Optional[int] = None
    workers: Optional[int] = None

    model_config = ConfigDict(extra="allow")

    @field_validator("port", mode="before")
    @classmethod
    def coerce_port_to_str(cls, value: Any) -> Any:
        """Accept integer ports by converting them to ``str``.

        Pydantic v2 rejects an ``int`` for a ``str`` field, so a hand-written
        ``port: 8000`` would otherwise fail to load. Booleans are left alone
        and still fail validation.
        """
        if isinstance(value, int) and not isinstance(value, bool):
            return str(value)
        return value

    def save_to_disk(self, path: str | Path = DEFAULT_PATH) -> None:
        """Save the API configuration to a YAML file on disk.

        Missing parent directories (e.g. ``.aatm/`` for the default path) are
        created first.

        Args:
            path: Destination path where the configuration should be written. If not
                provided, the default configuration path is used.

        Returns:
            None.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(
            yaml.safe_dump(self.model_dump(mode="json")), encoding="utf-8"
        )

    @classmethod
    def load_from_disk(cls, path: str | Path = DEFAULT_PATH) -> "APIConfig":
        """Load the API configuration from a YAML file on disk.

        Args:
            path: Path to the YAML configuration file. If not provided, the default
                configuration path is used.

        Returns:
            An `APIConfig` instance initialized from the contents of the file.

        Raises:
            FileNotFoundError: If the file does not exist.
            TypeError: If the YAML document is not a mapping.
            pydantic.ValidationError: If required fields are missing or invalid
                (including when the file is empty).
        """
        path = Path(path)
        data = yaml.safe_load(path.read_text(encoding="utf-8"))
        if data is None:
            # Empty file: validate an empty mapping so the error names the
            # missing fields instead of failing on ``cls(**None)``.
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"Expected a YAML mapping in {path}, got {type(data).__name__}"
            )
        # Files written while DEFAULT_PATH was still a model field contain it
        # as a key; drop it so it is not kept as an extra attribute.
        data.pop("DEFAULT_PATH", None)
        return cls(**data)
|
aatm/api/data_models.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""Request models for the terminology mapping API.
|
|
2
|
+
|
|
3
|
+
This module defines Pydantic models used to validate incoming API requests for
|
|
4
|
+
terminology mapping and retrieval. It includes request schemas for mapping
|
|
5
|
+
source concepts through the terminology pipeline and for performing retriever-
|
|
6
|
+
based searches.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Any, List, Optional
|
|
10
|
+
from fastapi import HTTPException
|
|
11
|
+
from pydantic import BaseModel, Field, field_validator
|
|
12
|
+
|
|
13
|
+
from aatm.data_models import SourceConcept
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TerminologyMappingRequest(BaseModel):
    """Body of a terminology-mapping request.

    Carries the source concepts to map, plus optional identifiers selecting
    which translator, retriever, selector and reranker components the mapping
    pipeline should use.

    Attributes:
        source_concepts: Source concepts to map to a target terminology.
        translator_id: Optional identifier of the translator component.
        retriever_id: Optional identifier of the retriever component.
        selector_id: Optional identifier of the selector component.
        reranker_id: Optional identifier of the reranker component.
    """

    source_concepts: List[SourceConcept]
    translator_id: Optional[str] = Field(default=None, examples=["gemini-2.5-flash"])
    retriever_id: Optional[str] = Field(default=None, examples=["embeddinggemma-300M"])
    selector_id: Optional[str] = Field(default=None, examples=["first-result-selector"])
    reranker_id: Optional[str] = Field(default=None, examples=[None])

    @field_validator("source_concepts", mode="after")
    @classmethod
    def validate_source_concepts(cls, concepts: List[SourceConcept]):
        """Reject empty requests and concepts lacking a description.

        Args:
            concepts: Source concepts after standard pydantic validation.

        Returns:
            ``concepts`` unchanged.

        Raises:
            HTTPException: Status 400 when the list is empty, or when any
                concept has ``source_code_description`` set to ``None``
                (the detail message reports how many).
        """
        if not concepts:
            raise HTTPException(status_code=400, detail="No source concepts provided")

        missing_count = sum(
            1 for concept in concepts if concept.source_code_description is None
        )
        if missing_count:
            raise HTTPException(
                status_code=400,
                detail=f"The field source_code_description is required for all source concepts. A total of {missing_count} source concepts are missing this field.",
            )

        return concepts
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class SearchRequest(BaseModel):
    """Body of a retriever search request.

    Attributes:
        queries: Query strings to search for (required).
        retriever_id: Identifier of the retriever to use (required).
        top_k: Maximum number of results per query; defaults to 10.
        where: Optional metadata filter forwarded to the retriever; defaults
            to ``None`` (no filter).
    """

    queries: List[str] = Field(default=..., examples=[["Cardiovascular disease"]])
    retriever_id: str = Field(default=..., examples=["embeddinggemma-300M"])
    top_k: int = Field(default=10, examples=[10])
    where: Optional[dict[str, Any]] = Field(default=None, examples=[None])
|
aatm/api/main.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""FastAPI application for terminology mapping and retrieval workflows.
|
|
2
|
+
|
|
3
|
+
This module defines API endpoints for terminology mapping and retrieval, along
|
|
4
|
+
with an in-memory least-recently-used registry for caching pipeline components.
|
|
5
|
+
The registry avoids repeatedly instantiating expensive retrievers and
|
|
6
|
+
terminology mappers across requests.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from collections import OrderedDict
|
|
10
|
+
from typing import List
|
|
11
|
+
from fastapi import FastAPI
|
|
12
|
+
from aatm.api.config import APIConfig
|
|
13
|
+
from aatm.api.data_models import SearchRequest, TerminologyMappingRequest
|
|
14
|
+
from aatm.data_models import MappedSourceConcept, RetrieverResults
|
|
15
|
+
from aatm.pipeline import PipelineBaseClass
|
|
16
|
+
from aatm.registries.retrievers import load_retriever
|
|
17
|
+
from aatm.retrievers import ChromaDBRetriever
|
|
18
|
+
from aatm.terminology_mapper import TerminologyMapper
|
|
19
|
+
|
|
20
|
+
# FastAPI application on which the /map and /search routes are registered.
app = FastAPI()

# Server configuration, loaded from APIConfig's default path when this module
# is imported. If that file is missing, importing the module fails.
api_config: APIConfig = APIConfig.load_from_disk()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ComponentRegistry:
    """Thread-safe, least-recently-used in-memory registry for pipeline components.

    Stores instantiated terminology mappers and other pipeline components
    under hashable keys: tuples of component ids for mappers, and
    ``"retriever-<id>"`` strings for retrievers. When the registry is full,
    the least recently used entry is evicted.

    FastAPI runs plain ``def`` path operations in a worker thread pool, so the
    endpoints may call ``get``/``set`` concurrently. A lock serializes each
    operation, so a concurrent eviction cannot break the lookup in ``get``.

    Args:
        max_size: Maximum number of cached components to retain (at least 1).

    Attributes:
        max_size: Maximum number of cached components.
        _store: Ordered mapping from cache keys to instantiated pipeline
            components, oldest (least recently used) first.
    """

    def __init__(self, max_size: int = 10):
        """Initialize the component registry.

        Args:
            max_size: Maximum number of cached components to retain before
                evicting the least recently used entry.

        Raises:
            ValueError: If ``max_size`` is smaller than 1. With such a limit
                every ``set`` would try to evict from an empty store and
                raise ``KeyError``.
        """
        import threading

        if max_size < 1:
            raise ValueError(f"max_size must be at least 1, got {max_size}")

        self.max_size = max_size
        self._store: "OrderedDict[tuple | str, TerminologyMapper | PipelineBaseClass]" = (
            OrderedDict()
        )
        self._lock = threading.Lock()

    def get(self, key: "tuple | str") -> "TerminologyMapper | PipelineBaseClass | None":
        """Retrieve a cached component by key.

        If the key is present, the component is marked as most recently used
        before being returned.

        Args:
            key: Cache key identifying the component.

        Returns:
            The cached component associated with the key, or ``None`` if the
            key is not present.
        """
        with self._lock:
            try:
                # Mark as recently used; raises KeyError when absent.
                self._store.move_to_end(key)
            except KeyError:
                return None
            return self._store[key]

    def set(self, key: "tuple | str", value: "TerminologyMapper | PipelineBaseClass") -> None:
        """Store a component in the registry.

        An existing key is updated in place and marked as most recently used.
        Otherwise, if the registry is full, the least recently used entry is
        evicted before the new component is added.

        Args:
            key: Cache key identifying the component.
            value: Instantiated terminology mapper or pipeline component to cache.

        Returns:
            None.
        """
        with self._lock:
            if key in self._store:
                self._store.move_to_end(key)
            elif len(self._store) >= self.max_size:
                # Remove least recently used item
                self._store.popitem(last=False)
            self._store[key] = value
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
# Module-level cache shared by the /map endpoint (tuple keys) and the /search
# endpoint ("retriever-<id>" string keys); holds at most 20 components.
PIPELINE_COMPONENTS_REGISTRY: ComponentRegistry = ComponentRegistry(max_size=20)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@app.post("/map", response_model=List[MappedSourceConcept])
def map(request: TerminologyMappingRequest):
    """Map source concepts to target terminology concepts.

    A ``TerminologyMapper`` is cached per combination of translator,
    retriever, selector and reranker identifiers. It is built on first use
    via ``TerminologyMapper.from_task_request`` and then reused for later
    requests with the same combination.

    Args:
        request: Source concepts plus the pipeline component identifiers.

    Returns:
        The mapped source concepts produced by the mapper (nothing is written
        to disk).
    """
    cache_key = (
        request.translator_id,
        request.retriever_id,
        request.selector_id,
        request.reranker_id,
    )

    mapper = PIPELINE_COMPONENTS_REGISTRY.get(cache_key)
    if mapper is None:
        mapper = TerminologyMapper.from_task_request(request, api_config)
        PIPELINE_COMPONENTS_REGISTRY.set(cache_key, mapper)

    return mapper.map(
        request.source_concepts,
        save_to_disk=False,
        return_as="mapped_source_concepts",
    )
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@app.post("/search", response_model=RetrieverResults)
def search(request: SearchRequest) -> RetrieverResults:
    """Search terminology candidates with the requested retriever.

    The retriever is cached under ``"retriever-<retriever_id>"``. It is
    loaded via ``load_retriever`` on first use and reused afterwards.

    Args:
        request: Retriever identifier plus query parameters.

    Returns:
        The ``RetrieverResults`` produced by calling the retriever.
    """
    cache_key = f"retriever-{request.retriever_id}"

    component = PIPELINE_COMPONENTS_REGISTRY.get(cache_key)
    if component is None:
        component: ChromaDBRetriever = load_retriever(request.retriever_id)
        PIPELINE_COMPONENTS_REGISTRY.set(cache_key, component)

    # NOTE(review): every request field, including ``retriever_id``, is passed
    # as a keyword argument. Confirm the retriever's ``__call__`` accepts it.
    search_kwargs = request.model_dump()
    return component(**search_kwargs)
|