aatm 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aatm/__init__.py ADDED
@@ -0,0 +1,3 @@
1
"""Package initialiser for ``aatm``.

Importing ``aatm`` (or any of its submodules, which imports this package
first) calls ``configure_logging()`` from ``aatm.logs`` as an import-time
side effect.
"""

from .logs import configure_logging

configure_logging()
aatm/aio/__init__.py ADDED
File without changes
aatm/aio/selectors.py ADDED
@@ -0,0 +1,173 @@
1
+ import asyncio
2
+ from typing import List
3
+ import dotenv
4
+ from openai import AsyncOpenAI
5
+ from google.genai import types
6
+
7
+ from aatm.data_models import (
8
+ EmptySelectionMetadata,
9
+ RetrieverResults,
10
+ SelectedExpressionMetadata,
11
+ SelectedResult,
12
+ SelectorResults,
13
+ )
14
+ from aatm.prompt_helpers import format_prompt
15
+ from aatm.selectors import OpenAILLMSelector, GeminiLLMSelector
16
+ from aatm.debug import DebugMode, get_debug_mode
17
+ from aatm.logs import get_logger
18
+
19
# Load variables from a .env file into the process environment (python-dotenv);
# presumably this is where API credentials for the clients come from — confirm.
dotenv.load_dotenv()

logger = get_logger(__name__)
# Evaluated once, when this module is first imported; the selectors below
# compare against this cached value instead of re-reading it on every call.
debug_mode = get_debug_mode()
24
+
25
+
26
class AsyncOpenAILLMSelector(OpenAILLMSelector):
    """Asynchronous variant of :class:`OpenAILLMSelector`.

    Sends one structured-output request per query concurrently through an
    ``AsyncOpenAI`` client and converts each parsed ``SelectedResult`` into
    selector metadata for that query's retrieved candidates.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        """Initialise the parent selector, then install an ``AsyncOpenAI`` client.

        All arguments are forwarded unchanged to ``OpenAILLMSelector.__init__``.
        """
        super().__init__(*args, **kwargs)
        # Overwrite any ``client`` the parent set so ``responses.parse`` below
        # returns an awaitable.
        self.client = AsyncOpenAI()

    async def select(self, results: RetrieverResults) -> SelectorResults:
        """Ask the LLM to pick the best candidate expression for every query.

        All requests run concurrently inside an ``asyncio.TaskGroup``: if any
        request raises, the others are cancelled and the error propagates as an
        ``ExceptionGroup``.

        Args:
            results: Retriever output; ``results.results[i]`` holds the
                candidate expressions retrieved for ``results.queries[i]``.

        Returns:
            A ``SelectorResults`` with one entry per query, in query order: a
            ``SelectedExpressionMetadata`` when the model picked an expression
            present in that query's candidates, otherwise an
            ``EmptySelectionMetadata`` (also used when the response carries no
            parsed output, e.g. a refusal).

        Raises:
            TypeError: If the SDK returns a parsed object that is not a
                ``SelectedResult``.
        """
        selector_results = SelectorResults(results=[], queries=results.queries)
        # The schema is identical for every query, so build it only once.
        json_format = SelectedResult.model_json_schema()

        async with asyncio.TaskGroup() as tg:
            tasks = []
            for query_id, query in enumerate(results.queries):
                prompt = format_prompt(
                    self.prompt_template,
                    {
                        "json_format": json_format,
                        # avoid case sensitivity for some queries like 'cough' and 'Cough'
                        "query": query.capitalize(),
                        "expressions": [
                            r.to_prompt_object() for r in results.results[query_id]
                        ],
                    },
                )

                tasks.append(
                    tg.create_task(
                        self.client.responses.parse(
                            model=self.model_id,
                            input=prompt,
                            text_format=SelectedResult,
                        )
                    )
                )

        # Leaving the TaskGroup without an exception means every task succeeded.
        for query_id, (query, task) in enumerate(zip(results.queries, tasks)):
            selected_result = task.result().output_parsed

            if selected_result is None:
                # No parsed output (e.g. the model refused). Record "no selection"
                # for this query instead of failing the whole batch.
                logger.warning(
                    "OpenAI returned no parsed selection for query %r; "
                    "treating it as no selection.",
                    query,
                )
                selector_results.results.append(EmptySelectionMetadata())
                continue

            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if not isinstance(selected_result, SelectedResult):
                raise TypeError(
                    f"Expected SelectedResult object from OpenAI, but got {type(selected_result)}"
                )

            # case where no expression is selected
            if selected_result.expression_id is None:
                selector_results.results.append(EmptySelectionMetadata())
                continue

            # Single pass: index of the first candidate with the chosen id.
            candidates = results.results[query_id]
            match_index = next(
                (
                    idx
                    for idx, candidate in enumerate(candidates)
                    if candidate.expression_id == selected_result.expression_id
                ),
                None,
            )

            # case where expression is selected but it is not in the results
            if match_index is None:
                selector_results.results.append(EmptySelectionMetadata())
                continue

            # case where expression is selected and it is in the results
            selector_results.results.append(
                SelectedExpressionMetadata(
                    **candidates[match_index].model_dump(),
                    result_list_index=match_index,
                )
            )

        return selector_results
94
+
95
+
96
class AsyncGeminiLLMSelector(GeminiLLMSelector):
    """Asynchronous variant of :class:`GeminiLLMSelector`.

    Uses the ``aio`` interface of the ``client`` provided by the parent class
    to send one JSON-mode ``generate_content`` request per query concurrently.
    Each JSON reply is then converted into selector metadata.
    """

    async def select(self, results: RetrieverResults) -> SelectorResults:
        """Ask Gemini to pick the best candidate expression for every query.

        All requests run inside an ``asyncio.TaskGroup``: if any request raises,
        the others are cancelled and the error propagates as an
        ``ExceptionGroup``. A reply whose text is empty or does not validate as
        a ``SelectedResult`` is logged and recorded as "no selection" for that
        query only.

        Args:
            results: Retriever output; ``results.results[i]`` holds the
                candidate expressions retrieved for ``results.queries[i]``.

        Returns:
            A ``SelectorResults`` with one entry per query, in query order: a
            ``SelectedExpressionMetadata`` when Gemini picked an expression
            present in that query's candidates, otherwise an
            ``EmptySelectionMetadata``.
        """
        selector_results = SelectorResults(results=[], queries=results.queries)
        # The schema is identical for every query, so build it only once.
        json_format = SelectedResult.model_json_schema()
        debugging = debug_mode == DebugMode.GEMINI_LLM_SELECTOR

        async with asyncio.TaskGroup() as tg:
            tasks = []
            for query_id, query in enumerate(results.queries):
                prompt = format_prompt(
                    self.prompt_template,
                    {
                        "json_format": json_format,
                        # avoid case sensitivity for some queries like 'cough' and 'Cough'
                        "query": query.capitalize(),
                        "expressions": [
                            r.to_prompt_object() for r in results.results[query_id]
                        ],
                    },
                )

                # Only each message's "content" is forwarded, always as user
                # content; any other keys of the message are ignored.
                gemini_formatted_prompt = [
                    types.UserContent(msg["content"]) for msg in prompt
                ]

                if debugging:
                    logger.debug(prompt)
                    logger.debug(gemini_formatted_prompt)

                tasks.append(
                    tg.create_task(
                        self.client.aio.models.generate_content(
                            model=self.model_id,
                            contents=gemini_formatted_prompt,
                            config=types.GenerateContentConfig(
                                response_mime_type="application/json",
                                response_schema=SelectedResult,
                            ),
                        )
                    )
                )

        # Leaving the TaskGroup without an exception means every task succeeded.
        responses: List[types.GenerateContentResponse] = [t.result() for t in tasks]

        for query_id, (query, response) in enumerate(zip(results.queries, responses)):
            raw_text = response.text
            if debugging:
                logger.debug(response)
                logger.debug(raw_text)

            if not raw_text:
                logger.warning(
                    "Gemini returned an empty response for query %r; "
                    "treating it as no selection.",
                    query,
                )
                selector_results.results.append(EmptySelectionMetadata())
                continue

            try:
                selected_result = SelectedResult.model_validate_json(raw_text)
            except ValueError as exc:  # pydantic's ValidationError subclasses ValueError
                logger.warning(
                    "Could not parse Gemini response for query %r (%s); "
                    "treating it as no selection.",
                    query,
                    exc,
                )
                selector_results.results.append(EmptySelectionMetadata())
                continue

            if debugging:
                logger.debug(selected_result)

            # case where no expression is selected
            if selected_result.expression_id is None:
                selector_results.results.append(EmptySelectionMetadata())
                continue

            # Single pass: index of the first candidate with the chosen id.
            candidates = results.results[query_id]
            match_index = next(
                (
                    idx
                    for idx, candidate in enumerate(candidates)
                    if candidate.expression_id == selected_result.expression_id
                ),
                None,
            )

            # case where expression is selected but it is not in the results
            if match_index is None:
                selector_results.results.append(EmptySelectionMetadata())
                continue

            # case where expression is selected and it is in the results
            selector_results.results.append(
                SelectedExpressionMetadata(
                    **candidates[match_index].model_dump(),
                    result_list_index=match_index,
                )
            )

        return selector_results
@@ -0,0 +1,79 @@
1
+ import asyncio
2
+ import json
3
+ from google.genai import types
4
+ from typing import List
5
+
6
+ from openai import AsyncOpenAI
7
+
8
+ # Custom modules
9
+ from aatm.data_models import Translation
10
+ from aatm.prompt_helpers import format_prompt
11
+ from aatm.translators import BaseTranslator, GeminiTranslator, OpenAITranslator
12
+
13
+
14
class EmptyTranslator(BaseTranslator):
    """No-op translator: every input text is returned unchanged.

    No model is contacted; each string is simply wrapped in a ``Translation``.
    """

    async def translate(self, texts: List[str]) -> List[Translation]:
        """Wrap each input text in a ``Translation`` object.

        Args:
            texts: Texts to pass through.

        Returns:
            One ``Translation(text=...)`` per input, in input order.
        """
        passthrough: List[Translation] = []
        for original_text in texts:
            passthrough.append(Translation(text=original_text))
        return passthrough
17
+
18
+
19
class AsyncGeminiTranslator(GeminiTranslator):
    """Asynchronous variant of :class:`GeminiTranslator`.

    Sends one JSON-mode ``generate_content`` request per text concurrently via
    the ``aio`` interface of the parent's ``client``. Any reply that cannot be
    turned into a ``Translation`` falls back to the original text.
    """

    async def translate(self, texts: List[str]) -> List[Translation]:
        """Translate every text concurrently with Gemini.

        Requests run inside an ``asyncio.TaskGroup``: if any request raises, the
        others are cancelled and the error propagates as an ``ExceptionGroup``.
        Parsing is best-effort per text.

        Args:
            texts: Texts to translate.

        Returns:
            One ``Translation`` per input, in input order. When a reply cannot
            be parsed, a warning is logged and ``Translation(text=<original>)``
            is used instead.
        """
        # Local import keeps this change self-contained; ``aatm.logs`` is
        # already imported by the package ``__init__``.
        from aatm.logs import get_logger

        logger = get_logger(__name__)

        async with asyncio.TaskGroup() as tg:
            tasks = [
                tg.create_task(
                    self.client.aio.models.generate_content(
                        # NOTE(review): other aio classes read ``self.model_id``;
                        # confirm GeminiTranslator really defines ``self.model``.
                        model=self.model,
                        contents=self.prompt_template.format(text=t),
                        config=types.GenerateContentConfig(
                            response_mime_type="application/json",
                            response_schema=Translation,
                        ),
                    )
                )
                for t in texts
            ]

        processed_results: List[Translation] = []
        for task, text in zip(tasks, texts):
            result = task.result()
            try:
                # ``result.text`` may be None or malformed JSON; any failure here
                # is deliberately handled per text rather than failing the batch.
                translation = Translation(**json.loads(result.text))
            except Exception as e:
                logger.warning(
                    "Error while processing text '%s' and response '%s': %s. "
                    "Original text was maintained.",
                    text,
                    result,
                    e,
                )
                translation = Translation(text=text)
            processed_results.append(translation)
        return processed_results
48
+
49
+
50
class AsyncOpenAITranslator(OpenAITranslator):
    """Asynchronous variant of :class:`OpenAITranslator`.

    Sends one structured-output ``responses.parse`` request per text
    concurrently through an ``AsyncOpenAI`` client.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the parent translator, then install an ``AsyncOpenAI`` client.

        All arguments are forwarded unchanged to ``OpenAITranslator.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.client = AsyncOpenAI()

    async def translate(self, texts: List[str]) -> List[Translation]:
        """Translate every text concurrently with OpenAI structured outputs.

        Requests run inside an ``asyncio.TaskGroup``: if any request raises, the
        others are cancelled and the error propagates as an ``ExceptionGroup``.

        Args:
            texts: Texts to translate.

        Returns:
            One ``Translation`` per input, in input order. When a response has
            no parsed ``Translation`` (``output_parsed`` is ``None``, e.g. on a
            refusal), a warning is logged and ``Translation(text=<original>)``
            is used instead.
        """
        # Local import keeps this change self-contained; ``aatm.logs`` is
        # already imported by the package ``__init__``.
        from aatm.logs import get_logger

        logger = get_logger(__name__)

        async with asyncio.TaskGroup() as tg:
            tasks = [
                tg.create_task(
                    self.client.responses.parse(
                        model=self.model_id,
                        input=format_prompt(self.prompt_template, {"text": t}),
                        text_format=Translation,
                    )
                )
                for t in texts
            ]

        processed_results: List[Translation] = []
        for task, text in zip(tasks, texts):
            result = task.result()
            # ``output_parsed`` returns None instead of raising when nothing
            # could be parsed, so it must be checked explicitly; otherwise
            # ``None`` would leak into the results.
            parsed = result.output_parsed
            if isinstance(parsed, Translation):
                processed_results.append(parsed)
                continue

            logger.warning(
                "Error while processing text '%s' and response '%s': %s. "
                "Original text was maintained.",
                text,
                result,
                f"expected a Translation but got {type(parsed).__name__}",
            )
            processed_results.append(Translation(text=text))
        return processed_results
aatm/api/__init__.py ADDED
File without changes
aatm/api/config.py ADDED
@@ -0,0 +1,68 @@
1
+ """Configuration model and persistence utilities for the AATM API.
2
+
3
+ This module defines the `APIConfig` model, which stores runtime configuration
4
+ for the API server and provides helper methods to save the configuration to
5
+ disk and load it back from a YAML file.
6
+ """
7
+
8
from pathlib import Path
from typing import ClassVar, Optional

from pydantic import BaseModel, ConfigDict
import yaml
13
+
14
+
15
class APIConfig(BaseModel):
    """Configuration model for the AATM API server.

    This model stores the runtime settings used to serve the API, including host,
    port, batching behavior, optional rate limiting, and worker configuration. It
    also supports persistence to and from a YAML file. Unknown keys are accepted
    and kept (``extra="allow"``).

    Attributes:
        DEFAULT_PATH: Class-level default path used to save and load the API
            configuration. Declared as a ``ClassVar`` so it is not a model
            field and is therefore not written into the saved YAML.
        host: Host interface on which the API server listens.
        port: Port on which the API server listens (stored as a string).
        batch_size: Batch size used by the API processing pipeline.
        rate_limit: Optional maximum number of documents allowed per minute.
        workers: Optional number of worker processes.
    """

    DEFAULT_PATH: ClassVar[Path] = Path(".aatm/api_config.yaml")
    host: str
    # NOTE(review): a port is numeric, but this field is typed ``str``; it is
    # kept unchanged because saved files and callers may rely on it — confirm.
    port: str
    batch_size: int
    rate_limit: Optional[int] = None
    workers: Optional[int] = None

    model_config = ConfigDict(extra="allow")

    def save_to_disk(self, path: str | Path = DEFAULT_PATH) -> None:
        """Save the API configuration to a YAML file on disk.

        Missing parent directories (e.g. ``.aatm/`` for the default path) are
        created first, so the first save on a fresh checkout does not fail.

        Args:
            path: Destination path where the configuration should be written. If not
                provided, the default configuration path is used.

        Returns:
            None.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(
            yaml.safe_dump(self.model_dump(mode="json")), encoding="utf-8"
        )

    @classmethod
    def load_from_disk(cls, path: str | Path = DEFAULT_PATH) -> "APIConfig":
        """Load the API configuration from a YAML file on disk.

        Args:
            path: Path to the YAML configuration file. If not provided, the default
                configuration path is used.

        Returns:
            An `APIConfig` instance initialized from the contents of the file.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            TypeError: If the YAML document is not a mapping.
            pydantic.ValidationError: If required fields are missing or invalid,
                including when the file is empty.
        """
        data = yaml.safe_load(Path(path).read_text(encoding="utf-8"))
        if data is None:
            # An empty file parses to None. Validate it as an empty mapping so
            # the caller gets a clear "field required" error.
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"Expected a YAML mapping in {path}, got {type(data).__name__}"
            )
        # Files saved while DEFAULT_PATH was (accidentally) a model field
        # contain it as a key; drop it so it is not kept as an extra field.
        data.pop("DEFAULT_PATH", None)
        return cls(**data)
@@ -0,0 +1,85 @@
1
+ """Request models for the terminology mapping API.
2
+
3
+ This module defines Pydantic models used to validate incoming API requests for
4
+ terminology mapping and retrieval. It includes request schemas for mapping
5
+ source concepts through the terminology pipeline and for performing retriever-
6
+ based searches.
7
+ """
8
+
9
+ from typing import Any, List, Optional
10
+ from fastapi import HTTPException
11
+ from pydantic import BaseModel, Field, field_validator
12
+
13
+ from aatm.data_models import SourceConcept
14
+
15
+
16
class TerminologyMappingRequest(BaseModel):
    """Request body for the terminology-mapping endpoint.

    Bundles the source concepts to map with optional identifiers that choose
    which translator, retriever, selector and reranker the pipeline uses.

    Attributes:
        source_concepts: Source concepts to map to a target terminology.
        translator_id: Optional identifier of the translator component.
        retriever_id: Optional identifier of the retriever component.
        selector_id: Optional identifier of the selector component.
        reranker_id: Optional identifier of the reranker component.
    """

    source_concepts: List[SourceConcept]
    translator_id: Optional[str] = Field(default=None, examples=["gemini-2.5-flash"])
    retriever_id: Optional[str] = Field(default=None, examples=["embeddinggemma-300M"])
    selector_id: Optional[str] = Field(default=None, examples=["first-result-selector"])
    reranker_id: Optional[str] = Field(default=None, examples=[None])

    @field_validator("source_concepts", mode="after")
    @classmethod
    def validate_source_concepts(cls, v: List[SourceConcept]):
        """Reject empty requests and concepts lacking a description.

        Raises ``HTTPException`` (status 400) rather than ``ValueError``, so the
        client gets a 400 with the message below instead of a pydantic
        validation error.

        Args:
            v: Source concepts after standard field validation.

        Returns:
            ``v`` unchanged when every check passes.

        Raises:
            HTTPException: If the list is empty, or if any concept has
                ``source_code_description`` set to ``None``.
        """
        if not v:
            raise HTTPException(status_code=400, detail="No source concepts provided")

        missing_description_count = sum(
            1 for concept in v if concept.source_code_description is None
        )
        if missing_description_count:
            raise HTTPException(
                status_code=400,
                detail=f"The field source_code_description is required for all source concepts. A total of {missing_description_count} source concepts are missing this field.",
            )

        return v
70
+
71
+
72
class SearchRequest(BaseModel):
    """Request body for the retriever search endpoint.

    Attributes:
        queries: Query strings to search for (required).
        retriever_id: Identifier of the retriever to use (required).
        top_k: Maximum number of results to return per query; defaults to 10.
        where: Optional metadata filter applied during retrieval; ``None`` by
            default.
    """

    queries: List[str] = Field(examples=[["Cardiovascular disease"]])
    retriever_id: str = Field(examples=["embeddinggemma-300M"])
    top_k: int = Field(default=10, examples=[10])
    where: Optional[dict[str, Any]] = Field(default=None, examples=[None])
aatm/api/main.py ADDED
@@ -0,0 +1,160 @@
1
+ """FastAPI application for terminology mapping and retrieval workflows.
2
+
3
+ This module defines API endpoints for terminology mapping and retrieval, along
4
+ with an in-memory least-recently-used registry for caching pipeline components.
5
+ The registry avoids repeatedly instantiating expensive retrievers and
6
+ terminology mappers across requests.
7
+ """
8
+
9
+ from collections import OrderedDict
10
+ from typing import List
11
+ from fastapi import FastAPI
12
+ from aatm.api.config import APIConfig
13
+ from aatm.api.data_models import SearchRequest, TerminologyMappingRequest
14
+ from aatm.data_models import MappedSourceConcept, RetrieverResults
15
+ from aatm.pipeline import PipelineBaseClass
16
+ from aatm.registries.retrievers import load_retriever
17
+ from aatm.retrievers import ChromaDBRetriever
18
+ from aatm.terminology_mapper import TerminologyMapper
19
+
20
# FastAPI application; the endpoints below register on it via ``@app.post``.
app = FastAPI()

# Read once at import time from APIConfig's default path
# (".aatm/api_config.yaml"), so importing this module raises if that file is
# missing or invalid.
api_config = APIConfig.load_from_disk()
23
+
24
+
25
+ class ComponentRegistry:
26
+ """Least-recently-used in-memory registry for pipeline components.
27
+
28
+ This registry stores instantiated terminology mappers and other pipeline
29
+ components keyed by configuration tuples. When the registry reaches its maximum
30
+ capacity, the least recently used item is evicted.
31
+
32
+ Args:
33
+ max_size: Maximum number of cached components to retain.
34
+
35
+ Attributes:
36
+ max_size: Maximum number of cached components.
37
+ _store: Ordered mapping from cache keys to instantiated pipeline
38
+ components.
39
+ """
40
+
41
+ def __init__(self, max_size: int = 10):
42
+ """Initialize the component registry.
43
+
44
+ Args:
45
+ max_size: Maximum number of cached components to retain before evicting
46
+ the least recently used entry.
47
+ """
48
+ self.max_size = max_size
49
+ self._store: OrderedDict[tuple, TerminologyMapper | PipelineBaseClass] = (
50
+ OrderedDict()
51
+ )
52
+
53
+ def get(self, key: tuple) -> TerminologyMapper | PipelineBaseClass | None:
54
+ """Retrieve a cached component by key.
55
+
56
+ If the key is present, the corresponding component is marked as recently used
57
+ before being returned.
58
+
59
+ Args:
60
+ key: Cache key identifying the component.
61
+
62
+ Returns:
63
+ The cached `TerminologyMapper` or `PipelineBaseClass` instance associated
64
+ with the key, or `None` if the key is not present.
65
+ """
66
+ if key not in self._store:
67
+ return None
68
+
69
+ # Mark as recently used
70
+ self._store.move_to_end(key)
71
+ return self._store[key]
72
+
73
+ def set(self, key: tuple, value: TerminologyMapper | PipelineBaseClass) -> None:
74
+ """Store a component in the registry.
75
+
76
+ If the key already exists, the existing entry is updated and marked as recently
77
+ used. If the registry is full, the least recently used entry is removed before
78
+ adding the new component.
79
+
80
+ Args:
81
+ key: Cache key identifying the component.
82
+ value: Instantiated terminology mapper or pipeline component to cache.
83
+
84
+ Returns:
85
+ None.
86
+ """
87
+ if key in self._store:
88
+ self._store.move_to_end(key)
89
+ self._store[key] = value
90
+ return
91
+
92
+ if len(self._store) >= self.max_size:
93
+ # Remove least recently used item
94
+ self._store.popitem(last=False)
95
+
96
+ self._store[key] = value
97
+
98
+
99
# Module-level cache (one per Python process) shared by the /map and /search
# endpoints. It holds at most 20 entries: terminology mappers keyed by tuples
# of component ids, and retrievers keyed by "retriever-<id>" strings.
PIPELINE_COMPONENTS_REGISTRY = ComponentRegistry(max_size=20)
100
+
101
+
102
@app.post("/map", response_model=List[MappedSourceConcept])
def map(request: TerminologyMappingRequest):
    """Map the request's source concepts onto the target terminology.

    A ``TerminologyMapper`` is cached per combination of translator, retriever,
    selector and reranker ids, and built on first use from the request and the
    module-level ``api_config``. Note that this endpoint function shadows the
    builtin ``map`` within this module.

    Args:
        request: Source concepts plus the optional component identifiers.

    Returns:
        The mapped source concepts produced by the mapper (nothing is saved to
        disk).
    """
    component_fields = (
        "translator_id",
        "retriever_id",
        "selector_id",
        "reranker_id",
    )
    cache_key = tuple(getattr(request, field_name) for field_name in component_fields)

    mapper = PIPELINE_COMPONENTS_REGISTRY.get(cache_key)
    if mapper is None:
        # First request for this component combination: build and cache it.
        mapper = TerminologyMapper.from_task_request(request, api_config)
        PIPELINE_COMPONENTS_REGISTRY.set(cache_key, mapper)

    return mapper.map(
        request.source_concepts,
        save_to_disk=False,
        return_as="mapped_source_concepts",
    )
137
+
138
+
139
@app.post("/search", response_model=RetrieverResults)
def search(request: SearchRequest) -> RetrieverResults:
    """Run a retriever search for the request's queries.

    Retrievers are cached under the key ``"retriever-<retriever_id>"`` and
    loaded with ``load_retriever`` on first use. Every field of the request,
    including ``retriever_id``, is passed to the retriever call as a keyword
    argument.

    Args:
        request: Queries, retriever identifier, ``top_k`` and optional
            ``where`` filter.

    Returns:
        The ``RetrieverResults`` produced by the retriever.
    """
    cache_key = f"retriever-{request.retriever_id}"

    cached_retriever = PIPELINE_COMPONENTS_REGISTRY.get(cache_key)
    if cached_retriever is not None:
        return cached_retriever(**request.model_dump())

    fresh_retriever: ChromaDBRetriever = load_retriever(request.retriever_id)
    PIPELINE_COMPONENTS_REGISTRY.set(cache_key, fresh_retriever)
    return fresh_retriever(**request.model_dump())