aiagents4pharma 1.13.1__py3-none-any.whl → 1.14.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- aiagents4pharma/talk2biomodels/__init__.py +1 -0
- aiagents4pharma/talk2biomodels/agents/t2b_agent.py +4 -2
- aiagents4pharma/talk2biomodels/api/__init__.py +6 -0
- aiagents4pharma/talk2biomodels/api/kegg.py +83 -0
- aiagents4pharma/talk2biomodels/api/ols.py +72 -0
- aiagents4pharma/talk2biomodels/api/uniprot.py +35 -0
- aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +1 -0
- aiagents4pharma/talk2biomodels/tests/test_api.py +57 -0
- aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +173 -0
- aiagents4pharma/talk2biomodels/tools/__init__.py +1 -0
- aiagents4pharma/talk2biomodels/tools/get_annotation.py +289 -0
- {aiagents4pharma-1.13.1.dist-info → aiagents4pharma-1.14.0.dist-info}/METADATA +1 -1
- {aiagents4pharma-1.13.1.dist-info → aiagents4pharma-1.14.0.dist-info}/RECORD +16 -9
- {aiagents4pharma-1.13.1.dist-info → aiagents4pharma-1.14.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.13.1.dist-info → aiagents4pharma-1.14.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.13.1.dist-info → aiagents4pharma-1.14.0.dist-info}/top_level.txt +0 -0
aiagents4pharma/talk2biomodels/agents/t2b_agent.py

```diff
@@ -15,6 +15,7 @@ from ..tools.search_models import SearchModelsTool
 from ..tools.get_modelinfo import GetModelInfoTool
 from ..tools.simulate_model import SimulateModelTool
 from ..tools.custom_plotter import CustomPlotterTool
+from ..tools.get_annotation import GetAnnotationTool
 from ..tools.ask_question import AskQuestionTool
 from ..tools.parameter_scan import ParameterScanTool
 from ..tools.steady_state import SteadyStateTool
@@ -44,8 +45,9 @@ def get_app(uniq_id, llm_model='gpt-4o-mini'):
                 SearchModelsTool(),
                 GetModelInfoTool(),
                 SteadyStateTool(),
-                ParameterScanTool()
-                ])
+                ParameterScanTool(),
+                GetAnnotationTool()
+                ])

     # Define the model
     llm = ChatOpenAI(model=llm_model, temperature=0)
```
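Once registered in the tool list, the agent can route annotation requests to the new tool. A minimal usage sketch, distilled from the new test_get_annotation.py shown below (it assumes a configured OpenAI API key, since get_app builds a ChatOpenAI-backed agent):

```python
from langchain_core.messages import HumanMessage
from aiagents4pharma.talk2biomodels.agents.t2b_agent import get_app

unique_id = 1234  # any thread/session identifier
app = get_app(unique_id)
config = {"configurable": {"thread_id": unique_id}}

# Ask the agent to annotate a species; it dispatches to GetAnnotationTool
app.invoke(
    {"messages": [HumanMessage(content="Extract annotations of species IL6 in model 537.")]},
    config=config
)

# The tool writes its results into the shared agent state
annotations = app.get_state(config).values["dic_annotations_data"]
```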
aiagents4pharma/talk2biomodels/api/kegg.py

```diff
@@ -0,0 +1,83 @@
+"""
+This module contains the API for fetching the KEGG database.
+"""
+import re
+from typing import List, Dict
+import requests
+
+def fetch_from_api(base_url: str, query: str) -> str:
+    """Fetch data from the given API endpoint."""
+    try:
+        response = requests.get(base_url + query, timeout=10)
+        response.raise_for_status()
+        return response.text
+    except requests.exceptions.RequestException as e:
+        print(f"Error fetching data for query {query}: {e}")
+        return ""
+
+def fetch_kegg_names(ids: List[str], batch_size: int = 10) -> Dict[str, str]:
+    """
+    Fetch the names of multiple KEGG entries using the KEGG REST API in batches.
+
+    Args:
+        ids (List[str]): List of KEGG IDs.
+        batch_size (int): Maximum number of IDs to include in a single request.
+
+    Returns:
+        Dict[str, str]: A mapping of KEGG IDs to their names.
+    """
+    if not ids:
+        return {}
+
+    base_url = "https://rest.kegg.jp/get/"
+    entry_name_map = {}
+
+    # Process IDs in batches
+    for i in range(0, len(ids), batch_size):
+        batch = ids[i:i + batch_size]
+        query = "+".join(batch)
+        entry_data = fetch_from_api(base_url, query)
+
+        # if not entry_data:
+        #     continue
+        entries = entry_data.split("///")
+        for entry in entries:
+            if not entry.strip():
+                continue
+            lines = entry.strip().split("\n")
+            entry_line = next((line for line in lines
+                               if line.startswith("ENTRY")), None)
+            name_line = next((line for line in lines
+                              if line.startswith("NAME")), None)
+
+            # if not entry_line and not name_line:
+            #     continue
+            entry_id = entry_line.split()[1]
+            # Split multiple names in the NAME field and clean them
+            names = [
+                re.sub(r'[^a-zA-Z0-9\s]', '', name).strip()
+                for name in name_line.replace("NAME", "").strip().split(";")
+            ]
+            # Join cleaned names into a single string
+            entry_name_map[entry_id] = " ".join(names).strip()
+
+    return entry_name_map
+
+def fetch_kegg_annotations(data: List[Dict[str, str]],
+                           batch_size: int = 10) -> Dict[str, Dict[str, str]]:
+    """Fetch KEGG entry descriptions grouped by database type."""
+    grouped_data = {}
+    for entry in data:
+        db_type = entry["Database"].lower()
+        grouped_data.setdefault(db_type, []).append(entry["Id"])
+
+    results = {}
+    for db_type, ids in grouped_data.items():
+        results[db_type] = fetch_kegg_names(ids, batch_size=batch_size)
+
+    return results
+
+# def get_protein_name_or_label(data: List[Dict[str, str]],
+#                               batch_size: int = 10) -> Dict[str, Dict[str, str]]:
+#     """Fetch descriptions for KEGG-related identifiers."""
+#     return fetch_kegg_annotations(data, batch_size=batch_size)
```
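A minimal sketch of how these helpers fit together (live calls to rest.kegg.jp; the IDs and expected names mirror the new test_api.py):

```python
from aiagents4pharma.talk2biomodels.api.kegg import (
    fetch_kegg_names,
    fetch_kegg_annotations,
)

# Batched name lookup: one GET per `batch_size` IDs
names = fetch_kegg_names(["C00001", "C00002"])
print(names)  # {'C00001': 'H2O', 'C00002': 'ATP'}

# Entries are first grouped by their 'Database' key, then resolved
annotations = fetch_kegg_annotations([{"Id": "C00001", "Database": "kegg.compound"}])
print(annotations)  # {'kegg.compound': {'C00001': 'H2O'}}
```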
aiagents4pharma/talk2biomodels/api/ols.py

```diff
@@ -0,0 +1,72 @@
+"""
+This module contains the API for fetching the OLS database.
+"""
+from typing import List, Dict
+import requests
+
+def fetch_from_ols(term: str) -> str:
+    """
+    Fetch the label for a single term from OLS.
+
+    Args:
+        term (str): The term in the format "ONTOLOGY:TERM_ID".
+
+    Returns:
+        str: The label for the term or an error message.
+    """
+    try:
+        ontology, _ = term.split(":")
+        base_url = f"https://www.ebi.ac.uk/ols4/api/ontologies/{ontology.lower()}/terms"
+        params = {"obo_id": term}
+        response = requests.get(
+            base_url,
+            params=params,
+            headers={"Accept": "application/json"},
+            timeout=10
+        )
+        response.raise_for_status()
+        data = response.json()
+        label = '-'
+        # Extract and return the label
+        if "_embedded" in data and "terms" in data["_embedded"] \
+            and len(data["_embedded"]["terms"]) > 0:
+            label = data["_embedded"]["terms"][0].get("label", "Label not found")
+        return label
+    except (requests.exceptions.RequestException, KeyError, IndexError) as e:
+        return f"Error: {str(e)}"
+
+def fetch_ols_labels(terms: List[str]) -> Dict[str, str]:
+    """
+    Fetch labels for multiple terms from OLS.
+
+    Args:
+        terms (List[str]): A list of terms in the format "ONTOLOGY:TERM_ID".
+
+    Returns:
+        Dict[str, str]: A mapping of term IDs to their labels or error messages.
+    """
+    results = {}
+    for term in terms:
+        results[term] = fetch_from_ols(term)
+    return results
+
+def search_ols_labels(data: List[Dict[str, str]]) -> Dict[str, Dict[str, str]]:
+    """
+    Fetch OLS annotations grouped by ontology type.
+
+    Args:
+        data (List[Dict[str, str]]): A list of dictionaries containing 'Id' and 'Database'.
+
+    Returns:
+        Dict[str, Dict[str, str]]: A mapping of ontology type to term labels.
+    """
+    grouped_data = {}
+    for entry in data:
+        ontology = entry["Database"].lower()
+        grouped_data.setdefault(ontology, []).append(entry["Id"])
+
+    results = {}
+    for ontology, terms in grouped_data.items():
+        results[ontology] = fetch_ols_labels(terms)
+
+    return results
```
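For reference, a small sketch of the two entry points (live calls to the EBI OLS4 API; the GO term and its label come from test_api.py):

```python
from aiagents4pharma.talk2biomodels.api.ols import fetch_from_ols, search_ols_labels

# Single lookup: the ontology prefix selects the OLS4 endpoint
print(fetch_from_ols("GO:0005886"))  # "plasma membrane"

# Grouped lookup: results are keyed by the lower-cased ontology name
labels = search_ols_labels([{"Id": "GO:0005886", "Database": "GO"}])
print(labels)  # {'go': {'GO:0005886': 'plasma membrane'}}
```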
aiagents4pharma/talk2biomodels/api/uniprot.py

```diff
@@ -0,0 +1,35 @@
+"""
+This module contains the API for fetching the UniProt database.
+"""
+from typing import List, Dict
+import requests
+
+def search_uniprot_labels(identifiers: List[str]) -> Dict[str, str]:
+    """
+    Fetch protein names or labels for a list of UniProt identifiers by making sequential requests.
+
+    Args:
+        identifiers (List[str]): A list of UniProt identifiers.
+
+    Returns:
+        Dict[str, str]: A mapping of UniProt identifiers to their protein names or error messages.
+    """
+    results = {}
+    base_url = "https://www.uniprot.org/uniprot/"
+
+    for identifier in identifiers:
+        url = f"{base_url}{identifier}.json"
+        try:
+            response = requests.get(url, timeout=10)
+            response.raise_for_status()
+            data = response.json()
+            protein_name = (
+                data.get('proteinDescription', {})
+                .get('recommendedName', {})
+                .get('fullName', {})
+                .get('value', 'Name not found')
+            )
+            results[identifier] = protein_name
+        except requests.exceptions.RequestException as e:
+            results[identifier] = f"Error: {str(e)}"
+    return results
```
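And the UniProt helper, which issues one request per identifier (the example identifier and name are taken from test_api.py):

```python
from aiagents4pharma.talk2biomodels.api.uniprot import search_uniprot_labels

results = search_uniprot_labels(["P61764"])
print(results)  # {'P61764': 'Syntaxin-binding protein 1'}
# Failed lookups map to an "Error: ..." string instead of raising
```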
aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py

```diff
@@ -23,3 +23,4 @@ class Talk2Biomodels(AgentState):
     dic_simulated_data: Annotated[list[dict], operator.add]
     dic_scanned_data: Annotated[list[dict], operator.add]
     dic_steady_state_data: Annotated[list[dict], operator.add]
+    dic_annotations_data : Annotated[list[dict], operator.add]
```
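The `operator.add` reducer matters here: LangGraph merges each state update by list concatenation, so every get_annotation call appends a new entry instead of overwriting the previous one. A toy illustration of the merge semantics (the dict contents are illustrative):

```python
import operator

state = [{"source": 537, "tool_call_id": "call_1", "data": {}}]
update = [{"source": "upload", "tool_call_id": "call_2", "data": {}}]

# LangGraph applies the annotated reducer when merging state updates
merged = operator.add(state, update)
print(len(merged))  # 2 -- both annotation records are retained
```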
aiagents4pharma/talk2biomodels/tests/test_api.py

```diff
@@ -0,0 +1,57 @@
+'''
+Test cases for Talk2Biomodels.
+'''
+
+from ..api.uniprot import search_uniprot_labels
+from ..api.ols import fetch_from_ols
+from ..api.kegg import fetch_kegg_names, fetch_from_api
+
+def test_search_uniprot_labels():
+    '''
+    Test the search_uniprot_labels function.
+    '''
+    # "P61764" = positive result, "P0000Q" = negative result
+    identifiers = ["P61764", "P0000Q"]
+    results = search_uniprot_labels(identifiers)
+    assert results["P61764"] == "Syntaxin-binding protein 1"
+    assert results["P0000Q"].startswith("Error: 400")
+
+def test_fetch_from_ols():
+    '''
+    Test the fetch_from_ols function.
+    '''
+    term_1 = "GO:0005886"  # Positive result
+    term_2 = "GO:ABC123"   # Negative result
+    label_1 = fetch_from_ols(term_1)
+    label_2 = fetch_from_ols(term_2)
+    assert isinstance(label_1, str), f"Expected string, got {type(label_1)}"
+    assert isinstance(label_2, str), f"Expected string, got {type(label_2)}"
+    assert label_1 == "plasma membrane"
+    assert label_2.startswith("Error: 404")
+
+def test_fetch_kegg_names():
+    '''
+    Test the fetch_kegg_names function.
+    '''
+    ids = ["C00001", "C00002"]
+    results = fetch_kegg_names(ids)
+    assert results["C00001"] == "H2O"
+    assert results["C00002"] == "ATP"
+
+    # Try with an empty list
+    results = fetch_kegg_names([])
+    assert not results
+
+def test_fetch_from_api():
+    '''
+    Test the fetch_from_api function.
+    '''
+    base_url = "https://rest.kegg.jp/get/"
+    query = "C00001"
+    entry_data = fetch_from_api(base_url, query)
+    assert entry_data.startswith("ENTRY       C00001")
+
+    # Try with an invalid query
+    query = "C0000Q"
+    entry_data = fetch_from_api(base_url, query)
+    assert not entry_data
```
aiagents4pharma/talk2biomodels/tests/test_get_annotation.py

```diff
@@ -0,0 +1,173 @@
+'''
+Test cases for Talk2Biomodels.
+'''
+import random
+import pytest
+from langchain_core.messages import HumanMessage, ToolMessage
+from ..agents.t2b_agent import get_app
+from ..tools.get_annotation import prepare_content_msg
+
+@pytest.fixture(name="make_graph")
+def make_graph_fixture():
+    '''
+    Create an instance of the talk2biomodels agent.
+    '''
+    unique_id = random.randint(1000, 9999)
+    graph = get_app(unique_id)
+    config = {"configurable": {"thread_id": unique_id}}
+    return graph, config
+
+def test_species_list(make_graph):
+    '''
+    Test the tool by passing species names.
+    '''
+    # Test with a valid species name
+    app, config = make_graph
+    prompt = "Extract annotations of species IL6 in model 537."
+    app.invoke(
+        {"messages": [HumanMessage(content=prompt)]},
+        config=config
+    )
+    current_state = app.get_state(config)
+    # print (current_state.values["dic_annotations_data"])
+    dic_annotations_data = current_state.values["dic_annotations_data"]
+
+    # The assert statement checks if IL6 is present in the returned annotations.
+    assert dic_annotations_data[0]['data']["Species Name"][0] == "IL6"
+
+    # Test with an invalid species name
+    app, config = make_graph
+    prompt = "Extract annotations of species NADH in model 537."
+    app.invoke(
+        {"messages": [HumanMessage(content=prompt)]},
+        config=config
+    )
+    current_state = app.get_state(config)
+    reversed_messages = current_state.values["messages"][::-1]
+    # Loop through the reversed messages until a
+    # ToolMessage is found.
+
+    test_condition = False
+    for msg in reversed_messages:
+        # Assert that one of the messages is a ToolMessage
+        # and its artifact is None.
+        if isinstance(msg, ToolMessage) and msg.name == "get_annotation":
+            # If a ToolMessage exists and artifact is None (meaning no valid annotation was found)
+            # and the rejected species (NADH) is mentioned, the test passes.
+            if msg.artifact is None and 'NADH' in msg.content:
+                # If artifact is None, it means no annotation was found
+                # (likely due to an invalid species).
+                # If artifact contains data, the tool successfully retrieved annotations.
+                test_condition = True
+                break
+    # assert test_condition
+    assert test_condition, "Expected rejection message for NADH but did not find it."
+
+    # Test with an invalid species name and a valid species name
+    app, config = make_graph
+    prompt = "Extract annotations of species NADH, NAD, and IL7 in model 64."
+    app.invoke(
+        {"messages": [HumanMessage(content=prompt)]},
+        config=config
+    )
+    current_state = app.get_state(config)
+    # dic_annotations_data = current_state.values["dic_annotations_data"]
+    reversed_messages = current_state.values["messages"][::-1]
+    # Loop through the reversed messages until a
+    # ToolMessage is found.
+    artifact_was_none = False
+    for msg in reversed_messages:
+        # Assert that one of the messages is a ToolMessage
+        # and its artifact is None.
+        if isinstance(msg, ToolMessage) and msg.name == "get_annotation":
+            # print (msg.artifact, msg.content)
+
+            if msg.artifact is True and 'IL7' in msg.content:
+                artifact_was_none = True
+                break
+    assert artifact_was_none
+
+def test_all_species(make_graph):
+    '''
+    Test the tool by asking for annotations of all species in specific models.
+
+    Model 12 contains species with no URL.
+    Model 20 contains species without a description.
+    Model 56 contains species with a database outside of KEGG, UniProt, and OLS.
+    '''
+    # Test valid models
+    for model_id in [12, 20, 56]:
+        app, config = make_graph
+        prompt = f"Extract annotations of all species model {model_id}."
+        # Test the tool get_modelinfo
+        app.invoke({"messages": [HumanMessage(content=prompt)]},
+                   config=config
+                   )
+        # print(response["messages"])
+        # assistant_msg = response["messages"][-1].content
+
+        current_state = app.get_state(config)
+
+        reversed_messages = current_state.values["messages"][::-1]
+        # Covers all of the use cases for the expected string on all the species
+        test_condition = False
+        for msg in reversed_messages:
+            if isinstance(msg, ToolMessage) and msg.name == "get_annotation":
+                if model_id == 12:
+                    # For model 12:
+                    # Expect a successful extraction (artifact is True) and that the content
+                    # matches what is returned by prepare_content_msg for species ['lac'].
+                    if (msg.artifact is True and msg.content == prepare_content_msg(['lac'],[])
+                        and msg.status=="success"):
+                        test_condition = True
+                        break
+
+                if model_id == 20:
+                    # For model 20:
+                    # Expect an error message containing a note that species extraction failed.
+                    if ("Unable to extract species from the model"
+                        in msg.content and msg.status == "error"):
+                        test_condition = True
+                        break
+
+                if model_id == 56:
+                    # For model 56:
+                    # Expect a successful extraction (artifact is True) and that the content
+                    # matches for the missing description ['ORI'].
+                    if (msg.artifact is True and
+                        msg.content == prepare_content_msg([],['ORI'])
+                        and msg.status == "success"):
+                        test_condition = True
+                        break
+
+        # Retrieve the dictionary that holds all the annotation data from the app's state
+        dic_annotations_data = current_state.values["dic_annotations_data"]
+
+        assert isinstance(dic_annotations_data, list),\
+            f"Expected a list for model {model_id}, got {type(dic_annotations_data)}"
+        assert len(dic_annotations_data) > 0,\
+            f"Expected species data for model {model_id}, but got empty list"
+        assert test_condition # Expected output is validated
+
+    # Test case where no model is specified
+    app, config = make_graph
+    prompt = "Extract annotations of all species."
+    app.invoke({"messages": [HumanMessage(content=prompt)]},
+               config=config
+               )
+    current_state = app.get_state(config)
+    # dic_annotations_data = current_state.values["dic_annotations_data"]
+    reversed_messages = current_state.values["messages"][::-1]
+    print(reversed_messages)
+
+    test_condition = False
+    for msg in reversed_messages:
+        # Assert that one of the messages is a ToolMessage
+        if isinstance(msg, ToolMessage) and msg.name == "get_annotation":
+            if "Error:" in msg.content and msg.status == "error":
+                test_condition = True
+                break
+    # Loop through the reversed messages until a
+    # ToolMessage is found.
+    # Ensure the system correctly informs the user to specify a model
+    assert test_condition, "Expected error message when no model is specified was not found."
```
aiagents4pharma/talk2biomodels/tools/get_annotation.py

```diff
@@ -0,0 +1,289 @@
+#!/usr/bin/env python3
+
+"""
+This module contains the `GetAnnotationTool` for fetching species annotations
+based on the provided model and species names.
+"""
+import math
+from typing import List, Annotated, Type
+import logging
+from pydantic import BaseModel, Field
+import basico
+import pandas as pd
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+from langchain_core.tools.base import BaseTool
+from langchain_core.tools.base import InjectedToolCallId
+from langchain_core.messages import ToolMessage
+from .load_biomodel import ModelData, load_biomodel
+from ..api.uniprot import search_uniprot_labels
+from ..api.ols import search_ols_labels
+from ..api.kegg import fetch_kegg_annotations
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+ols_ontology_abbreviations = {'pato', 'chebi', 'sbo', 'fma', 'pr','go'}
+
+def prepare_content_msg(species_not_found: List[str],
+                        species_without_description: List[str]):
+    """
+    Prepare the content message.
+    """
+    content = 'Successfully extracted annotations for the species.'
+    if species_not_found:
+        content += f'''The following species do not exist, and
+                hence their annotations were not extracted:
+                {', '.join(species_not_found)}.'''
+    if species_without_description:
+        content += f'''The descriptions for the following species
+                were not found:
+                {", ".join(species_without_description)}.'''
+    return content
+
+class GetAnnotationInput(BaseModel):
+    """
+    Input schema for annotation tool.
+    """
+    sys_bio_model: ModelData = Field(description="model data")
+    tool_call_id: Annotated[str, InjectedToolCallId]
+    list_species_names: List[str] = Field(
+        default=[],
+        description='''List of species names to fetch annotations for.
+                    If not provided, annotations for all
+                    species in the model will be fetched.'''
+    )
+    state: Annotated[dict, InjectedState]
+
+class GetAnnotationTool(BaseTool):
+    """
+    Tool for fetching species annotations based on the provided model and species names.
+    """
+    name: str = "get_annotation"
+    description: str = '''A tool to extract annotations for a list of species names
+                        based on the provided model. Annotations include
+                        the species name, description, database, ID, link,
+                        and qualifier. The tool can handle multiple species
+                        in a single invoke.'''
+    args_schema: Type[BaseModel] = GetAnnotationInput
+    return_direct: bool = False
+
+    def _run(self,
+             tool_call_id: Annotated[str, InjectedToolCallId],
+             state: Annotated[dict, InjectedState],
+             list_species_names: List[str] = None,
+             sys_bio_model: ModelData = None) -> str:
+        """
+        Run the tool.
+        """
+        logger.info("Running the GetAnnotationTool tool for species %s", list_species_names)
+
+        # Prepare the model object
+        sbml_file_path = state['sbml_file_path'][-1] if state['sbml_file_path'] else None
+        model_object = load_biomodel(sys_bio_model, sbml_file_path=sbml_file_path)
+
+        # Extract all the species names from the model
+        df_species = basico.model_info.get_species(model=model_object.copasi_model)
+
+        if df_species is None:
+            # for example this may happen with model 20
+            raise ValueError("Unable to extract species from the model.")
+        # Fetch annotations for the species names
+        list_species_names = list_species_names or df_species.index.tolist()
+
+        (annotations_df,
+         species_not_found,
+         species_without_description) = self._fetch_annotations(list_species_names)
+
+        # Check if annotations are empty
+        # If empty, return a message
+        if annotations_df.empty:
+            logger.warning("The annotations dataframe is empty.")
+            return prepare_content_msg(species_not_found, species_without_description)
+
+        # Process annotations
+        annotations_df = self._process_annotations(annotations_df)
+
+        # Prepare the simulated data
+        dic_annotations_data = {
+            'source': sys_bio_model.biomodel_id if sys_bio_model.biomodel_id else 'upload',
+            'tool_call_id': tool_call_id,
+            'data': annotations_df.to_dict()
+        }
+
+        # Update the state with the annotations data
+        dic_updated_state_for_model = {}
+        for key, value in {
+            "model_id": [sys_bio_model.biomodel_id],
+            "sbml_file_path": [sbml_file_path],
+            "dic_annotations_data": [dic_annotations_data]
+        }.items():
+            if value:
+                dic_updated_state_for_model[key] = value
+
+        return Command(
+            update=dic_updated_state_for_model | {
+                "messages": [
+                    ToolMessage(
+                        content=prepare_content_msg(species_not_found,
+                                                    species_without_description),
+                        artifact=True,
+                        tool_call_id=tool_call_id
+                    )
+                ]
+            }
+        )
+
+    def _fetch_annotations(self, list_species_names: List[str]) -> tuple:
+        """
+        Fetch annotations for the given species names from the model.
+        In this method, we fetch the MIRIAM annotations for the species names.
+        If the annotation is not found, we add the species to the list of
+        species not found. If the annotation is found, we extract the descriptions
+        from the annotation and add them to the data list.
+
+        Args:
+            list_species_names (List[str]): List of species names to fetch annotations for.
+
+        Returns:
+            tuple: A tuple containing the annotations dataframe, species not found list,
+                and description not found list.
+        """
+        species_not_found = []
+        description_not_found = []
+        data = []
+
+        # Loop through the species names
+        for species in list_species_names:
+            # Get the MIRIAM annotation for the species
+            annotation = basico.get_miriam_annotation(name=species)
+            # If the annotation is not found, add the species to the list
+            if annotation is None:
+                species_not_found.append(species)
+                continue
+
+            # Extract the descriptions from the annotation
+            descriptions = annotation.get("descriptions", [])
+
+            if descriptions == []:
+                description_not_found.append(species)
+                continue
+
+            # Loop through the descriptions and add them to the data list
+            for desc in descriptions:
+                data.append({
+                    "Species Name": species,
+                    "Link": desc["id"],
+                    "Qualifier": desc["qualifier"]
+                })
+
+        # Create a dataframe from the data list
+        annotations_df = pd.DataFrame(data)
+
+        # Return the annotations dataframe and the species not found list
+        return annotations_df, species_not_found, description_not_found
+
+    def _process_annotations(self, annotations_df: pd.DataFrame) -> pd.DataFrame:
+        """
+        Process annotations dataframe to add additional information.
+        In this method, we add a new column for the ID, a new column for the database,
+        and a new column for the description. We then reorder the columns and process
+        the link to format it correctly.
+
+        Args:
+            annotations_df (pd.DataFrame): Annotations dataframe to process.
+
+        Returns:
+            pd.DataFrame: Processed annotations dataframe
+        """
+        logger.info("Processing annotations.")
+        # Add a new column for the ID
+        # Get the ID from the link key
+        annotations_df['Id'] = annotations_df['Link'].str.split('/').str[-1]
+
+        # Add a new column for the database
+        # Get the database from the link key
+        annotations_df['Database'] = annotations_df['Link'].str.split('/').str[-2]
+
+        # Fetch descriptions for the IDs based on the database type
+        # by querying the respective APIs
+        identifiers = annotations_df[['Id', 'Database']].to_dict(orient='records')
+        descriptions = self._fetch_descriptions(identifiers)
+
+        # Add a new column for the description
+        # Get the description from the descriptions dictionary
+        # based on the ID. If the description is not found, use '-'
+        annotations_df['Description'] = annotations_df['Id'].apply(lambda x:
+                                                                   descriptions.get(x, '-'))
+        # annotations_df.index = annotations_df.index + 1
+
+        # Reorder the columns
+        annotations_df = annotations_df[
+            ["Species Name", "Description", "Database", "Id", "Link", "Qualifier"]
+        ]
+
+        # Process the link to format it correctly
+        annotations_df["Link"] = annotations_df["Link"].apply(self._process_link)
+
+        # Return the processed annotations dataframe
+        return annotations_df
+
+    def _process_link(self, link: str) -> str:
+        """
+        Process link to format it correctly.
+        """
+        for ols_ontology_abbreviation in ols_ontology_abbreviations:
+            if ols_ontology_abbreviation +'/' in link:
+                link = link.replace(f"{ols_ontology_abbreviation}/", "")
+        if "kegg.compound" in link:
+            link = link.replace("kegg.compound/", "kegg.compound:")
+        return link
+
+    def _fetch_descriptions(self, data: List[dict[str, str]]) -> dict[str, str]:
+        """
+        Fetch protein names or labels based on the database type.
+        """
+        logger.info("Fetching descriptions for the IDs.")
+        results = {}
+        grouped_data = {}
+
+        # In the following loop, we create a dictionary with database as the key
+        # and a list of identifiers as the value. If either the database or the
+        # identifier is NaN, we set it to None.
+        for entry in data:
+            identifier = entry.get('Id')
+            database = entry.get('Database')
+            # Check if database is NaN
+            if isinstance(database, float):
+                if math.isnan(database):
+                    database = None
+                    results[identifier or "unknown"] = "-"
+            else:
+                database = database.lower()
+                grouped_data.setdefault(database, []).append(identifier)
+
+        # In the following loop, we fetch the descriptions for the identifiers
+        # based on the database type.
+        # Constants
+
+        for database, identifiers in grouped_data.items():
+            if database == 'uniprot':
+                results.update(search_uniprot_labels(identifiers))
+            elif database in ols_ontology_abbreviations:
+                annotations = search_ols_labels([
+                    {"Id": id_, "Database": database}
+                    for id_ in identifiers
+                ])
+                for identifier in identifiers:
+                    results[identifier] = annotations.get(database, {}).get(identifier, "-")
+            elif database == 'kegg.compound':
+                data = [{"Id": identifier, "Database": "kegg.compound"}
+                        for identifier in identifiers]
+                annotations = fetch_kegg_annotations(data)
+                for identifier in identifiers:
+                    results[identifier] = annotations.get(database, {}).get(identifier, "-")
+            else:
+                for identifier in identifiers:
+                    results[identifier] = "-"
+        return results
```
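The Id/Database extraction in `_process_annotations` relies on the shape of the MIRIAM annotation links returned by basico; a sketch of the transformation (the identifiers.org URL below is illustrative, not taken from a specific model):

```python
link = "https://identifiers.org/kegg.compound/C00001"  # hypothetical MIRIAM link

identifier = link.split('/')[-1]  # 'C00001'        -> the 'Id' column
database = link.split('/')[-2]    # 'kegg.compound' -> the 'Database' column

# _process_link then rewrites KEGG links into the colon form
fixed = link.replace("kegg.compound/", "kegg.compound:")
print(identifier, database, fixed)
```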
{aiagents4pharma-1.13.1.dist-info → aiagents4pharma-1.14.0.dist-info}/METADATA

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: aiagents4pharma
-Version: 1.13.1
+Version: 1.14.0
 Summary: AI Agents for drug discovery, drug development, and other pharmaceutical R&D
 Classifier: Programming Language :: Python :: 3
 Classifier: License :: OSI Approved :: MIT License
```
{aiagents4pharma-1.13.1.dist-info → aiagents4pharma-1.14.0.dist-info}/RECORD

```diff
@@ -5,21 +5,28 @@ aiagents4pharma/configs/talk2biomodels/__init__.py,sha256=5ah__-8XyRblwT0U1ByRig
 aiagents4pharma/configs/talk2biomodels/agents/__init__.py,sha256=_ZoG8snICK2bidWtc2KOGs738LWg9_r66V9mOMnEb-E,71
 aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
 aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/default.yaml,sha256=Oi89_BbxfQc6SGW1pC-hyZMqOIkiAMOlNwpCa4VCXk0,327
-aiagents4pharma/talk2biomodels/__init__.py,sha256=
+aiagents4pharma/talk2biomodels/__init__.py,sha256=2ICwVh1u07SZv31Jd2DKHobauOxWNWY29_Gqq3kOnNQ,159
 aiagents4pharma/talk2biomodels/agents/__init__.py,sha256=sn5-fREjMdEvb-OUan3iOqrgYGjplNx3J8hYOaW0Po8,128
-aiagents4pharma/talk2biomodels/agents/t2b_agent.py,sha256=
+aiagents4pharma/talk2biomodels/agents/t2b_agent.py,sha256=13aSlBZBWtjXOLq7c99u33c923fi2Ab0VW--eX5gF-o,3366
+aiagents4pharma/talk2biomodels/api/__init__.py,sha256=_GmDQqDLYpsUPUeE1nBNlT5AI9oTXIcqgOfNfvmonqA,123
+aiagents4pharma/talk2biomodels/api/kegg.py,sha256=QzYDAfJ16E7tbHGxP8ZNWRizMkMRS_HJuucueXEC1Gg,2943
+aiagents4pharma/talk2biomodels/api/ols.py,sha256=qq0Qy-gJDxanQW-HfCChDsTQsY1M41ua8hMlTnfuzrA,2202
+aiagents4pharma/talk2biomodels/api/uniprot.py,sha256=aPUAVBR7UYXDuuhDpKezAK2aTMzo-NxFYFq6C0W5u6U,1175
 aiagents4pharma/talk2biomodels/models/__init__.py,sha256=5fTHHm3PVloYPNKXbgNlcPgv3-u28ZquxGydFYDfhJA,122
 aiagents4pharma/talk2biomodels/models/basico_model.py,sha256=PH25FTOuUjsmw_UUxoRb-4kptOYpicEn4GqS0phS3nk,4807
 aiagents4pharma/talk2biomodels/models/sys_bio_model.py,sha256=JeoiGQAvQABHnG0wKR2XBmmxqQdtgO6kxaLDUTUmr1s,2001
 aiagents4pharma/talk2biomodels/states/__init__.py,sha256=YLg1-N0D9qyRRLRqwqfLCLAqZYDtMVZTfI8Y0b_4tbA,139
-aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py,sha256=
+aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py,sha256=44vRJwc2NlVLQQr1Smipr02YzXHOyqUVSNlg7rFjli0,950
 aiagents4pharma/talk2biomodels/tests/__init__.py,sha256=Jbw5tJxSrjGoaK5IX3pJWDCNzhrVQ10lkYq2oQ_KQD8,45
+aiagents4pharma/talk2biomodels/tests/test_api.py,sha256=7Kz2r5F5tjmn3F0LoM33oP-21W633936YHiyf5toGg0,1716
 aiagents4pharma/talk2biomodels/tests/test_basico_model.py,sha256=y82fpTJMPHwtXxlle1cGQ_2Bewwpxi0aJSVrVAYLhN0,2060
+aiagents4pharma/talk2biomodels/tests/test_get_annotation.py,sha256=lgdWNl6g1hiTcUbcmgn2bUk5_-8EUpSNa0MMpIMGeDA,7301
 aiagents4pharma/talk2biomodels/tests/test_langgraph.py,sha256=QLAL4nmHrioTD-w-9OE0wQi5JdWJJ59PejNbDzCSvw4,15170
 aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py,sha256=HSmBBViMi0jYf4gWX21IbppAfDzG0nr_S3KtKS9fZVQ,2165
-aiagents4pharma/talk2biomodels/tools/__init__.py,sha256
+aiagents4pharma/talk2biomodels/tools/__init__.py,sha256=ZiOdSFaeHW6y3hdtBfsKf0vSb3MuCLuy9MDyjARggb4,322
 aiagents4pharma/talk2biomodels/tools/ask_question.py,sha256=qpltsgyLFFwLYQeapQHASFRDCNiWsJkmTH_sUrfJ_Fg,3708
 aiagents4pharma/talk2biomodels/tools/custom_plotter.py,sha256=HWwKTX3o4dk0GcRVTO2hPrFSu98mtJ4TKC_hbHXOe1c,4018
+aiagents4pharma/talk2biomodels/tools/get_annotation.py,sha256=Ifbbz08YFI1ifAy3t0tYkr45-k7inO8lZePvCSe5ZfA,11835
 aiagents4pharma/talk2biomodels/tools/get_modelinfo.py,sha256=qA-4FOI-O728Nmn7s8JJ8HKwxvA9MZbst7NkPKTAMV4,5391
 aiagents4pharma/talk2biomodels/tools/load_biomodel.py,sha256=pyVzLQoMnuJYEwsjeOlqcUrbU1F1Z-pNlgkhFaoKpy0,689
 aiagents4pharma/talk2biomodels/tools/parameter_scan.py,sha256=aIyL_m46s3Q74ieJOZjZBM34VCjBKSMpEtckhdZofbE,12139
@@ -77,8 +84,8 @@ aiagents4pharma/talk2knowledgegraphs/utils/embeddings/sentence_transformer.py,sh
 aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py,sha256=tW426knki2DBIHcWyF_K04iMMdbpIn_e_TpPmTgz2dI,113
 aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py,sha256=Bx8x6zzk5614ApWB90N_iv4_Y_Uq0-KwUeBwYSdQMU4,924
 aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py,sha256=8eoxR-VHo0G7ReQIwje7xEhE-SJlHdef7_wJRpnvFIc,4116
-aiagents4pharma-1.
-aiagents4pharma-1.
-aiagents4pharma-1.
-aiagents4pharma-1.
-aiagents4pharma-1.
+aiagents4pharma-1.14.0.dist-info/LICENSE,sha256=IcIbyB1Hyk5ZDah03VNQvJkbNk2hkBCDqQ8qtnCvB4Q,1077
+aiagents4pharma-1.14.0.dist-info/METADATA,sha256=yGCN0sRc2dixfRh-UAPk2uLZLBjb5WUaSd0VwROU7qY,8609
+aiagents4pharma-1.14.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+aiagents4pharma-1.14.0.dist-info/top_level.txt,sha256=-AH8rMmrSnJtq7HaAObS78UU-cTCwvX660dSxeM7a0A,16
+aiagents4pharma-1.14.0.dist-info/RECORD,,
```
LICENSE, WHEEL, and top_level.txt: files without changes.