aiagents4pharma-1.8.0-py3-none-any.whl → aiagents4pharma-1.15.0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- aiagents4pharma/__init__.py +9 -5
- aiagents4pharma/configs/__init__.py +5 -0
- aiagents4pharma/configs/config.yaml +4 -0
- aiagents4pharma/configs/talk2biomodels/__init__.py +6 -0
- aiagents4pharma/configs/talk2biomodels/agents/__init__.py +5 -0
- aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/__init__.py +3 -0
- aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/default.yaml +14 -0
- aiagents4pharma/configs/talk2biomodels/tools/__init__.py +4 -0
- aiagents4pharma/configs/talk2biomodels/tools/ask_question/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/agents/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/agents/t2b_agent.py +96 -0
- aiagents4pharma/talk2biomodels/api/__init__.py +6 -0
- aiagents4pharma/talk2biomodels/api/kegg.py +83 -0
- aiagents4pharma/talk2biomodels/api/ols.py +72 -0
- aiagents4pharma/talk2biomodels/api/uniprot.py +35 -0
- aiagents4pharma/talk2biomodels/models/basico_model.py +29 -32
- aiagents4pharma/talk2biomodels/models/sys_bio_model.py +9 -6
- aiagents4pharma/talk2biomodels/states/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +41 -0
- aiagents4pharma/talk2biomodels/tests/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/tests/test_api.py +57 -0
- aiagents4pharma/talk2biomodels/tests/test_ask_question.py +44 -0
- aiagents4pharma/talk2biomodels/tests/test_basico_model.py +54 -0
- aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +171 -0
- aiagents4pharma/talk2biomodels/tests/test_getmodelinfo.py +26 -0
- aiagents4pharma/talk2biomodels/tests/test_integration.py +126 -0
- aiagents4pharma/talk2biomodels/tests/test_param_scan.py +68 -0
- aiagents4pharma/talk2biomodels/tests/test_query_article.py +76 -0
- aiagents4pharma/talk2biomodels/tests/test_search_models.py +28 -0
- aiagents4pharma/talk2biomodels/tests/test_simulate_model.py +39 -0
- aiagents4pharma/talk2biomodels/tests/test_steady_state.py +90 -0
- aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +63 -0
- aiagents4pharma/talk2biomodels/tools/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/tools/ask_question.py +61 -18
- aiagents4pharma/talk2biomodels/tools/custom_plotter.py +20 -14
- aiagents4pharma/talk2biomodels/tools/get_annotation.py +304 -0
- aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +11 -9
- aiagents4pharma/talk2biomodels/tools/load_arguments.py +114 -0
- aiagents4pharma/talk2biomodels/tools/load_biomodel.py +0 -1
- aiagents4pharma/talk2biomodels/tools/parameter_scan.py +287 -0
- aiagents4pharma/talk2biomodels/tools/query_article.py +59 -0
- aiagents4pharma/talk2biomodels/tools/simulate_model.py +35 -90
- aiagents4pharma/talk2biomodels/tools/steady_state.py +167 -0
- aiagents4pharma/talk2cells/tests/scp_agent/test_scp_agent.py +23 -0
- aiagents4pharma/talk2cells/tools/scp_agent/__init__.py +6 -0
- aiagents4pharma/talk2cells/tools/scp_agent/display_studies.py +25 -0
- aiagents4pharma/talk2cells/tools/scp_agent/search_studies.py +79 -0
- aiagents4pharma/talk2competitors/__init__.py +5 -0
- aiagents4pharma/talk2competitors/agents/__init__.py +6 -0
- aiagents4pharma/talk2competitors/agents/main_agent.py +130 -0
- aiagents4pharma/talk2competitors/agents/s2_agent.py +75 -0
- aiagents4pharma/talk2competitors/config/__init__.py +5 -0
- aiagents4pharma/talk2competitors/config/config.py +110 -0
- aiagents4pharma/talk2competitors/state/__init__.py +5 -0
- aiagents4pharma/talk2competitors/state/state_talk2competitors.py +32 -0
- aiagents4pharma/talk2competitors/tests/__init__.py +3 -0
- aiagents4pharma/talk2competitors/tests/test_langgraph.py +274 -0
- aiagents4pharma/talk2competitors/tools/__init__.py +7 -0
- aiagents4pharma/talk2competitors/tools/s2/__init__.py +8 -0
- aiagents4pharma/talk2competitors/tools/s2/display_results.py +25 -0
- aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py +132 -0
- aiagents4pharma/talk2competitors/tools/s2/search.py +119 -0
- aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py +141 -0
- aiagents4pharma/talk2knowledgegraphs/__init__.py +2 -1
- aiagents4pharma/talk2knowledgegraphs/tests/__init__.py +0 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py +242 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py +29 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py +73 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py +116 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py +47 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +45 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_sentencetransformer.py +40 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_enrichments.py +39 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +117 -0
- aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py +36 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py +123 -0
- {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/METADATA +44 -25
- aiagents4pharma-1.15.0.dist-info/RECORD +102 -0
- aiagents4pharma-1.8.0.dist-info/RECORD +0 -35
- {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/top_level.txt +0 -0
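A quick way to sanity-check the new layout after upgrading is to import the freshly added tool modules. This is a hypothetical smoke test, not part of the package; it assumes aiagents4pharma 1.15.0 is installed and uses only module and class names that appear in the listing and hunks below.

# Hypothetical smoke test: the new 1.15.0 tool modules should import cleanly.
from aiagents4pharma.talk2biomodels.tools import (
    get_annotation,
    parameter_scan,
    steady_state,
    query_article,
)

print(get_annotation.GetAnnotationTool().name)  # "get_annotation", per the hunk below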
aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py
@@ -0,0 +1,63 @@
+'''
+This file contains the unit tests for the BioModel class.
+'''
+
+from typing import List, Dict, Union, Optional
+from pydantic import Field
+import pytest
+from ..models.sys_bio_model import SysBioModel
+
+class TestBioModel(SysBioModel):
+    '''
+    A test BioModel class for unit testing.
+    '''
+
+    biomodel_id: Optional[int] = Field(None, description="BioModel ID of the model")
+    sbml_file_path: Optional[str] = Field(None, description="Path to an SBML file")
+    name: Optional[str] = Field(..., description="Name of the model")
+    description: Optional[str] = Field("", description="Description of the model")
+    param1: Optional[float] = Field(0.0, description="Parameter 1")
+    param2: Optional[float] = Field(0.0, description="Parameter 2")
+
+    def get_model_metadata(self) -> Dict[str, Union[str, int]]:
+        '''
+        Get the metadata of the model.
+        '''
+        return self.biomodel_id
+
+    def update_parameters(self, parameters):
+        '''
+        Update the model parameters.
+        '''
+        self.param1 = parameters.get('param1', 0.0)
+        self.param2 = parameters.get('param2', 0.0)
+
+    def simulate(self, duration: Union[int, float]) -> List[float]:
+        '''
+        Simulate the model.
+        '''
+        return [self.param1 + self.param2 * t for t in range(int(duration))]
+
+def test_get_model_metadata():
+    '''
+    Test the get_model_metadata method of the BioModel class.
+    '''
+    model = TestBioModel(biomodel_id=123, name="Test Model", description="A test model")
+    metadata = model.get_model_metadata()
+    assert metadata == 123
+
+def test_check_biomodel_id_or_sbml_file_path():
+    '''
+    Test the check_biomodel_id_or_sbml_file_path method of the BioModel class.
+    '''
+    with pytest.raises(ValueError):
+        TestBioModel(name="Test Model", description="A test model")
+
+def test_simulate():
+    '''
+    Test the simulate method of the BioModel class.
+    '''
+    model = TestBioModel(biomodel_id=123, name="Test Model", description="A test model")
+    model.update_parameters({'param1': 1.0, 'param2': 2.0})
+    results = model.simulate(duration=4.0)
+    assert results == [1.0, 3.0, 5.0, 7.0]
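The ValueError test above relies on SysBioModel enforcing that at least one of biomodel_id or sbml_file_path is supplied. The base class itself is not shown in this diff (see models/sys_bio_model.py in the file listing), so the following is only a minimal sketch of the kind of validator the test exercises, not the packaged code.

# Minimal sketch, assuming SysBioModel rejects instances where neither
# biomodel_id nor sbml_file_path is provided (the behavior checked by
# test_check_biomodel_id_or_sbml_file_path above).
from typing import Optional
from pydantic import BaseModel, model_validator

class SysBioModelSketch(BaseModel):
    biomodel_id: Optional[int] = None
    sbml_file_path: Optional[str] = None

    @model_validator(mode="after")
    def check_biomodel_id_or_sbml_file_path(self):
        if self.biomodel_id is None and self.sbml_file_path is None:
            raise ValueError("Either biomodel_id or sbml_file_path must be provided.")
        return self

SysBioModelSketch(biomodel_id=123)   # ok
# SysBioModelSketch()                # would raise ValueError, as in the test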
aiagents4pharma/talk2biomodels/tools/__init__.py
@@ -6,3 +6,8 @@ from . import simulate_model
 from . import ask_question
 from . import custom_plotter
 from . import get_modelinfo
+from . import parameter_scan
+from . import steady_state
+from . import load_biomodel
+from . import get_annotation
+from . import query_article
aiagents4pharma/talk2biomodels/tools/ask_question.py
@@ -5,11 +5,12 @@ Tool for asking a question about the simulation results.
 """

 import logging
-from typing import Type, Annotated
+from typing import Type, Annotated, Literal
+import hydra
+import basico
 import pandas as pd
 from pydantic import BaseModel, Field
 from langchain_core.tools.base import BaseTool
-from langchain.agents.agent_types import AgentType
 from langchain_experimental.agents import create_pandas_dataframe_agent
 from langchain_openai import ChatOpenAI
 from langgraph.prebuilt import InjectedState
@@ -22,7 +23,12 @@ class AskQuestionInput(BaseModel):
     """
     Input schema for the AskQuestion tool.
     """
-    question: str = Field(description="question about the simulation results")
+    question: str = Field(description="question about the simulation and steady state results")
+    experiment_name: str = Field(description="""Name assigned to the simulation
+                            or steady state analysis when the tool
+                            simulate_model or steady_state is invoked.""")
+    question_context: Literal["simulation", "steady_state"] = Field(
+        description="Context of the question")
     state: Annotated[dict, InjectedState]

 # Note: It's important that every field has type hints.
@@ -30,43 +36,80 @@ class AskQuestionInput(BaseModel):
 # can lead to unexpected behavior.
 class AskQuestionTool(BaseTool):
     """
-    Tool for
+    Tool for asking a question about the simulation or steady state results.
     """
     name: str = "ask_question"
-    description: str = "A tool to ask question about the
+    description: str = """A tool to ask question about the
+        simulation or steady state results."""
     args_schema: Type[BaseModel] = AskQuestionInput
     return_direct: bool = False

     def _run(self,
              question: str,
+             experiment_name: str,
+             question_context: Literal["simulation", "steady_state"],
              state: Annotated[dict, InjectedState]) -> str:
         """
         Run the tool.

         Args:
-            question (str): The question to ask about the simulation results.
+            question (str): The question to ask about the simulation or steady state results.
             state (dict): The state of the graph.
-
+            experiment_name (str): The name assigned to the simulation or steady state analysis.

         Returns:
             str: The answer to the question.
         """
         logger.log(logging.INFO,
-                   "Calling ask_question tool %s",
-
-
-
-
-
-
-
-
-        #
+                   "Calling ask_question tool %s, %s, %s",
+                   question,
+                   question_context,
+                   experiment_name)
+        # Load hydra configuration
+        with hydra.initialize(version_base=None, config_path="../../configs"):
+            cfg = hydra.compose(config_name='config',
+                                overrides=['talk2biomodels/tools/ask_question=default'])
+            cfg = cfg.talk2biomodels.tools.ask_question
+        # Get the context of the question
+        # and based on the context, get the data
+        # and prompt content to ask the question
+        if question_context == "steady_state":
+            dic_context = state["dic_steady_state_data"]
+            prompt_content = cfg.steady_state_prompt
+        else:
+            dic_context = state["dic_simulated_data"]
+            prompt_content = cfg.simulation_prompt
+        # Extract the
+        dic_data = {}
+        for data in dic_context:
+            for key in data:
+                if key not in dic_data:
+                    dic_data[key] = []
+                dic_data[key] += [data[key]]
+        # Create a pandas dataframe of the data
+        df_data = pd.DataFrame.from_dict(dic_data)
+        # Extract the data for the experiment
+        # matching the experiment name
+        df = pd.DataFrame(
+            df_data[df_data['name'] == experiment_name]['data'].iloc[0]
+        )
+        logger.log(logging.INFO, "Shape of the dataframe: %s", df.shape)
+        # # Extract the model units
+        # model_units = basico.model_info.get_model_units()
+        # Update the prompt content with the model units
+        prompt_content += "Following are the model units:\n"
+        prompt_content += f"{basico.model_info.get_model_units()}\n\n"
+        # Create a pandas dataframe agent
         df_agent = create_pandas_dataframe_agent(
                         ChatOpenAI(model=state['llm_model']),
                         allow_dangerous_code=True,
-                        agent_type=
+                        agent_type='tool-calling',
                         df=df,
+                        max_iterations=5,
+                        include_df_in_prompt=True,
+                        number_of_head_rows=df.shape[0],
+                        verbose=True,
                         prefix=prompt_content)
+        # Invoke the agent with the question
         llm_result = df_agent.invoke(question)
         return llm_result["output"]
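The reshaping step in the new _run body above is easiest to see on a toy example. The records below are invented; in the package they come from earlier simulate_model or steady_state tool calls stored in the graph state.

# Standalone illustration (sample data, not from the package) of how the tool
# turns the list of per-experiment dictionaries into a dataframe and then
# selects the data for one named experiment.
import pandas as pd

dic_context = [
    {"name": "first_run", "source": "upload", "data": {"Time": [0, 1, 2], "A": [1.0, 0.8, 0.6]}},
    {"name": "second_run", "source": "upload", "data": {"Time": [0, 1, 2], "A": [1.0, 0.5, 0.2]}},
]

dic_data = {}
for data in dic_context:
    for key in data:
        dic_data.setdefault(key, []).append(data[key])

df_data = pd.DataFrame.from_dict(dic_data)
df = pd.DataFrame(df_data[df_data["name"] == "first_run"]["data"].iloc[0])
print(df)   # columns Time and A, three rows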
aiagents4pharma/talk2biomodels/tools/custom_plotter.py
@@ -6,14 +6,11 @@ Tool for plotting a custom figure.

 import logging
 from typing import Type, List, TypedDict, Annotated, Tuple, Union, Literal
-from typing import Type, List, TypedDict, Annotated, Tuple, Union, Literal
 from pydantic import BaseModel, Field
 import pandas as pd
-import pandas as pd
 from langchain_openai import ChatOpenAI
 from langchain_core.tools import BaseTool
 from langgraph.prebuilt import InjectedState
-from langgraph.prebuilt import InjectedState

 # Initialize logger
 logging.basicConfig(level=logging.INFO)
@@ -24,7 +21,7 @@ class CustomPlotterInput(BaseModel):
     Input schema for the PlotImage tool.
     """
     question: str = Field(description="Description of the plot")
-
+    simulation_name: str = Field(description="Name assigned to the simulation")
     state: Annotated[dict, InjectedState]

 # Note: It's important that every field has type hints.
@@ -41,10 +38,10 @@ class CustomPlotterTool(BaseTool):
     description: str = "A tool to make custom plots of the simulation results"
     args_schema: Type[BaseModel] = CustomPlotterInput
     response_format: str = "content_and_artifact"
-    response_format: str = "content_and_artifact"

     def _run(self,
              question: str,
+             simulation_name: str,
              state: Annotated[dict, InjectedState]
              ) -> Tuple[str, Union[None, List[str]]]:
         """
@@ -53,17 +50,24 @@ class CustomPlotterTool(BaseTool):
         Args:
             question (str): The question about the custom plot.
             state (dict): The state of the graph.
-            question (str): The question about the custom plot.
-            state (dict): The state of the graph.

         Returns:
             str: The answer to the question
         """
         logger.log(logging.INFO, "Calling custom_plotter tool %s", question)
-
-
-
-
+        dic_simulated_data = {}
+        for data in state["dic_simulated_data"]:
+            for key in data:
+                if key not in dic_simulated_data:
+                    dic_simulated_data[key] = []
+                dic_simulated_data[key] += [data[key]]
+        # Create a pandas dataframe from the dictionary
+        df = pd.DataFrame.from_dict(dic_simulated_data)
+        # Get the simulated data for the current tool call
+        df = pd.DataFrame(
+            df[df['name'] == simulation_name]['data'].iloc[0]
+        )
+        # df = pd.DataFrame.from_dict(state['dic_simulated_data'])
         species_names = df.columns.tolist()
         # Exclude the time column
         species_names.remove('Time')
@@ -76,7 +80,8 @@ class CustomPlotterTool(BaseTool):
             A list of species based on user question.
             """
             relevant_species: Union[None, List[Literal[*species_names]]] = Field(
-                    description="List of species based on user question.
+                    description="""List of species based on user question.
+                    If no relevant species are found, it will be None.""")
         # Create an instance of the LLM model
         llm = ChatOpenAI(model=state['llm_model'], temperature=0)
         llm_with_structured_output = llm.with_structured_output(CustomHeader)
@@ -90,5 +95,6 @@ class CustomPlotterTool(BaseTool):
         logger.info("Extracted species: %s", extracted_species)
         if len(extracted_species) == 0:
             return "No species found in the simulation results that matches the user prompt.", None
-
-
+        # Include the time column
+        extracted_species.insert(0, 'Time')
+        return f"Custom plot {simulation_name}", df[extracted_species].to_dict(orient='records')
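Because custom_plotter returns content_and_artifact, callers receive the filtered records alongside the text message. The sketch below is only an illustration of what a downstream consumer (not part of this diff) could do with that artifact; the records are made up and plotting requires matplotlib.

# Illustrative only: rebuild the artifact returned by CustomPlotterTool
# (df[extracted_species].to_dict(orient='records')) into a dataframe and plot it.
import pandas as pd

records = [
    {"Time": 0.0, "A": 1.0, "B": 0.0},
    {"Time": 1.0, "A": 0.7, "B": 0.3},
    {"Time": 2.0, "A": 0.5, "B": 0.5},
]
df_plot = pd.DataFrame.from_records(records)
df_plot.plot(x="Time")   # one line per selected species (needs matplotlib installed)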
aiagents4pharma/talk2biomodels/tools/get_annotation.py
@@ -0,0 +1,304 @@
+#!/usr/bin/env python3
+
+"""
+This module contains the `GetAnnotationTool` for fetching species annotations
+based on the provided model and species names.
+"""
+import math
+from typing import List, Annotated, Type
+import logging
+from dataclasses import dataclass
+from pydantic import BaseModel, Field
+import basico
+import pandas as pd
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+from langchain_core.tools.base import BaseTool
+from langchain_core.tools.base import InjectedToolCallId
+from langchain_core.messages import ToolMessage
+from .load_biomodel import ModelData, load_biomodel
+from ..api.uniprot import search_uniprot_labels
+from ..api.ols import search_ols_labels
+from ..api.kegg import fetch_kegg_annotations
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+ols_ontology_abbreviations = {'pato', 'chebi', 'sbo', 'fma', 'pr','go'}
+
+def prepare_content_msg(species_not_found: List[str],
+                        species_without_description: List[str]):
+    """
+    Prepare the content message.
+    """
+    content = 'Successfully extracted annotations for the species.'
+    if species_not_found:
+        content += f'''The following species do not exist, and
+                    hence their annotations were not extracted:
+                    {', '.join(species_not_found)}.'''
+    if species_without_description:
+        content += f'''The descriptions for the following species
+                    were not found:
+                    {", ".join(species_without_description)}.'''
+    return content
+
+@dataclass
+class ArgumentData:
+    """
+    Dataclass for storing the argument data.
+    """
+    experiment_name: Annotated[str, "An AI assigned _ separated name of"
+                                    " the experiment based on human query"
+                                    " and the context of the experiment."
+                                    " This must be set before the experiment is run."]
+    list_species_names: List[str] = Field(
+        default=None,
+        description='''List of species names to fetch annotations for.
+                    If not provided, annotations for all
+                    species in the model will be fetched.'''
+        )
+
+class GetAnnotationInput(BaseModel):
+    """
+    Input schema for annotation tool.
+    """
+    arg_data: ArgumentData = Field(description="argument data")
+    sys_bio_model: ModelData = Field(description="model data")
+    tool_call_id: Annotated[str, InjectedToolCallId]
+    state: Annotated[dict, InjectedState]
+
+class GetAnnotationTool(BaseTool):
+    """
+    Tool for fetching species annotations based on the provided model and species names.
+    """
+    name: str = "get_annotation"
+    description: str = '''A tool to extract annotations for a list of species names
+                        based on the provided model. Annotations include
+                        the species name, description, database, ID, link,
+                        and qualifier. The tool can handle multiple species
+                        in a single invoke.'''
+    args_schema: Type[BaseModel] = GetAnnotationInput
+    return_direct: bool = False
+
+    def _run(self,
+             arg_data: ArgumentData,
+             tool_call_id: Annotated[str, InjectedToolCallId],
+             state: Annotated[dict, InjectedState],
+             sys_bio_model: ModelData = None) -> str:
+        """
+        Run the tool.
+        """
+        logger.info("Running the GetAnnotationTool tool for species %s, %s",
+                    arg_data.list_species_names,
+                    arg_data.experiment_name)
+
+        # Prepare the model object
+        sbml_file_path = state['sbml_file_path'][-1] if state['sbml_file_path'] else None
+        model_object = load_biomodel(sys_bio_model, sbml_file_path=sbml_file_path)
+
+        # Extract all the species names from the model
+        df_species = basico.model_info.get_species(model=model_object.copasi_model)
+
+        if df_species is None:
+            # for example this may happen with model 20
+            raise ValueError("Unable to extract species from the model.")
+        # Fetch annotations for the species names
+        arg_data.list_species_names = arg_data.list_species_names or df_species.index.tolist()
+
+        (annotations_df,
+         species_not_found,
+         species_without_description) = self._fetch_annotations(arg_data.list_species_names)
+
+        # Check if annotations are empty
+        # If empty, return a message
+        if annotations_df.empty:
+            logger.warning("The annotations dataframe is empty.")
+            return prepare_content_msg(species_not_found, species_without_description)
+
+        # Process annotations
+        annotations_df = self._process_annotations(annotations_df)
+
+        # Prepare the simulated data
+        dic_annotations_data = {
+            'name': arg_data.experiment_name,
+            'source': sys_bio_model.biomodel_id if sys_bio_model.biomodel_id else 'upload',
+            'tool_call_id': tool_call_id,
+            'data': annotations_df.to_dict()
+        }
+
+        # Update the state with the annotations data
+        dic_updated_state_for_model = {}
+        for key, value in {
+            "model_id": [sys_bio_model.biomodel_id],
+            "sbml_file_path": [sbml_file_path],
+            "dic_annotations_data": [dic_annotations_data]
+        }.items():
+            if value:
+                dic_updated_state_for_model[key] = value
+
+        return Command(
+            update=dic_updated_state_for_model | {
+                "messages": [
+                    ToolMessage(
+                        content=prepare_content_msg(species_not_found,
+                                                    species_without_description),
+                        artifact=True,
+                        tool_call_id=tool_call_id
+                    )
+                ]
+            }
+        )
+
+    def _fetch_annotations(self, list_species_names: List[str]) -> tuple:
+        """
+        Fetch annotations for the given species names from the model.
+        In this method, we fetch the MIRIAM annotations for the species names.
+        If the annotation is not found, we add the species to the list of
+        species not found. If the annotation is found, we extract the descriptions
+        from the annotation and add them to the data list.
+
+        Args:
+            list_species_names (List[str]): List of species names to fetch annotations for.
+
+        Returns:
+            tuple: A tuple containing the annotations dataframe, species not found list,
+                and description not found list.
+        """
+        species_not_found = []
+        description_not_found = []
+        data = []
+
+        # Loop through the species names
+        for species in list_species_names:
+            # Get the MIRIAM annotation for the species
+            annotation = basico.get_miriam_annotation(name=species)
+            # If the annotation is not found, add the species to the list
+            if annotation is None:
+                species_not_found.append(species)
+                continue
+
+            # Extract the descriptions from the annotation
+            descriptions = annotation.get("descriptions", [])
+
+            if descriptions == []:
+                description_not_found.append(species)
+                continue
+
+            # Loop through the descriptions and add them to the data list
+            for desc in descriptions:
+                data.append({
+                    "Species Name": species,
+                    "Link": desc["id"],
+                    "Qualifier": desc["qualifier"]
+                })
+
+        # Create a dataframe from the data list
+        annotations_df = pd.DataFrame(data)
+
+        # Return the annotations dataframe and the species not found list
+        return annotations_df, species_not_found, description_not_found
+
+    def _process_annotations(self, annotations_df: pd.DataFrame) -> pd.DataFrame:
+        """
+        Process annotations dataframe to add additional information.
+        In this method, we add a new column for the ID, a new column for the database,
+        and a new column for the description. We then reorder the columns and process
+        the link to format it correctly.
+
+        Args:
+            annotations_df (pd.DataFrame): Annotations dataframe to process.
+
+        Returns:
+            pd.DataFrame: Processed annotations dataframe
+        """
+        logger.info("Processing annotations.")
+        # Add a new column for the ID
+        # Get the ID from the link key
+        annotations_df['Id'] = annotations_df['Link'].str.split('/').str[-1]
+
+        # Add a new column for the database
+        # Get the database from the link key
+        annotations_df['Database'] = annotations_df['Link'].str.split('/').str[-2]
+
+        # Fetch descriptions for the IDs based on the database type
+        # by qyerying the respective APIs
+        identifiers = annotations_df[['Id', 'Database']].to_dict(orient='records')
+        descriptions = self._fetch_descriptions(identifiers)
+
+        # Add a new column for the description
+        # Get the description from the descriptions dictionary
+        # based on the ID. If the description is not found, use '-'
+        annotations_df['Description'] = annotations_df['Id'].apply(lambda x:
+                                                descriptions.get(x, '-'))
+        # annotations_df.index = annotations_df.index + 1
+
+        # Reorder the columns
+        annotations_df = annotations_df[
+            ["Species Name", "Description", "Database", "Id", "Link", "Qualifier"]
+        ]
+
+        # Process the link to format it correctly
+        annotations_df["Link"] = annotations_df["Link"].apply(self._process_link)
+
+        # Return the processed annotations dataframe
+        return annotations_df
+
+    def _process_link(self, link: str) -> str:
+        """
+        Process link to format it correctly.
+        """
+        for ols_ontology_abbreviation in ols_ontology_abbreviations:
+            if ols_ontology_abbreviation +'/' in link:
+                link = link.replace(f"{ols_ontology_abbreviation}/", "")
+        if "kegg.compound" in link:
+            link = link.replace("kegg.compound/", "kegg.compound:")
+        return link
+
+    def _fetch_descriptions(self, data: List[dict[str, str]]) -> dict[str, str]:
+        """
+        Fetch protein names or labels based on the database type.
+        """
+        logger.info("Fetching descriptions for the IDs.")
+        results = {}
+        grouped_data = {}
+
+        # In the following loop, we create a dictionary with database as the key
+        # and a list of identifiers as the value. If either the database or the
+        # identifier is NaN, we set it to None.
+        for entry in data:
+            identifier = entry.get('Id')
+            database = entry.get('Database')
+            # Check if database is NaN
+            if isinstance(database, float):
+                if math.isnan(database):
+                    database = None
+                    results[identifier or "unknown"] = "-"
+            else:
+                database = database.lower()
+                grouped_data.setdefault(database, []).append(identifier)
+
+        # In the following loop, we fetch the descriptions for the identifiers
+        # based on the database type.
+        # Constants
+
+        for database, identifiers in grouped_data.items():
+            if database == 'uniprot':
+                results.update(search_uniprot_labels(identifiers))
+            elif database in ols_ontology_abbreviations:
+                annotations = search_ols_labels([
+                    {"Id": id_, "Database": database}
+                    for id_ in identifiers
+                ])
+                for identifier in identifiers:
+                    results[identifier] = annotations.get(database, {}).get(identifier, "-")
+            elif database == 'kegg.compound':
+                data = [{"Id": identifier, "Database": "kegg.compound"}
+                        for identifier in identifiers]
+                annotations = fetch_kegg_annotations(data)
+                for identifier in identifiers:
+                    results[identifier] = annotations.get(database, {}).get(identifier, "-")
+            else:
+                for identifier in identifiers:
+                    results[identifier] = "-"
+        return results
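The Id/Database extraction in _process_annotations assumes identifiers.org-style links, where the last two path segments carry the accession and the database prefix. A small worked example follows; the links are illustrative, not taken from a specific model.

# Sample links only; shows the str.split('/') logic used by _process_annotations.
import pandas as pd

annotations_df = pd.DataFrame({
    "Link": [
        "http://identifiers.org/uniprot/P12345",
        "http://identifiers.org/chebi/CHEBI:17234",
    ]
})
annotations_df["Id"] = annotations_df["Link"].str.split("/").str[-1]
annotations_df["Database"] = annotations_df["Link"].str.split("/").str[-2]
# Id becomes "P12345" and "CHEBI:17234"; Database becomes "uniprot" and "chebi",
# which then routes the description lookup to UniProt, OLS, or KEGG.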
aiagents4pharma/talk2biomodels/tools/get_modelinfo.py
@@ -25,12 +25,12 @@ class RequestedModelInfo:
     """
     Dataclass for storing the requested model information.
    """
-    species: bool = Field(description="Get species from the model.")
-    parameters: bool = Field(description="Get parameters from the model.")
-    compartments: bool = Field(description="Get compartments from the model.")
-    units: bool = Field(description="Get units from the model.")
-    description: bool = Field(description="Get description from the model.")
-    name: bool = Field(description="Get name from the model.")
+    species: bool = Field(description="Get species from the model.", default=False)
+    parameters: bool = Field(description="Get parameters from the model.", default=False)
+    compartments: bool = Field(description="Get compartments from the model.", default=False)
+    units: bool = Field(description="Get units from the model.", default=False)
+    description: bool = Field(description="Get description from the model.", default=False)
+    name: bool = Field(description="Get name from the model.", default=False)

 class GetModelInfoInput(BaseModel):
     """
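Adding default=False to every RequestedModelInfo field gives each flag an explicit fallback, so a request that sets only some of them still resolves cleanly. The sketch below uses a plain pydantic model standing in for the packaged dataclass.

# Sketch only: with defaults, a partial request now validates.
from pydantic import BaseModel, Field

class RequestedModelInfoSketch(BaseModel):
    species: bool = Field(description="Get species from the model.", default=False)
    name: bool = Field(description="Get name from the model.", default=False)

info = RequestedModelInfoSketch(species=True)   # 'name' can be omitted
print(info.species, info.name)                  # True False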
@@ -47,8 +47,10 @@ class GetModelInfoTool(BaseTool):
     """
     This tool ise used extract model information.
     """
-    name: str = "
-    description: str = "A tool for extracting
+    name: str = "get_modelinfo"
+    description: str = """A tool for extracting name,
+                    description, species, parameters,
+                    compartments, and units from a model."""
     args_schema: Type[BaseModel] = GetModelInfoInput

     def _run(self,
@@ -81,7 +83,7 @@ class GetModelInfoTool(BaseTool):
         # Extract species from the model
         if requested_model_info.species:
             df_species = basico.model_info.get_species(model=model_obj.copasi_model)
-            dic_results['Species'] = df_species.
+            dic_results['Species'] = df_species['display_name'].tolist()
             dic_results['Species'] = ','.join(dic_results['Species'])

         # Extract parameters from the model
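The corrected line uses the display_name column of the species table that basico returns. Below is a hedged end-to-end sketch; it assumes basico is installed and BioModels is reachable, and model 64 is only an example ID.

# Sketch of the fixed behavior: get_species() returns a dataframe with a
# 'display_name' column, whose values are joined into a comma-separated string.
import basico

model = basico.load_biomodel(64)
df_species = basico.model_info.get_species(model=model)
species = ",".join(df_species["display_name"].tolist())
print(species)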