aiagents4pharma 1.8.2__tar.gz → 1.9.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/PKG-INFO +1 -1
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/__init__.py +1 -0
- aiagents4pharma-1.9.0/aiagents4pharma/configs/__init__.py +5 -0
- aiagents4pharma-1.9.0/aiagents4pharma/configs/config.yaml +3 -0
- aiagents4pharma-1.9.0/aiagents4pharma/configs/talk2biomodels/__init__.py +5 -0
- aiagents4pharma-1.9.0/aiagents4pharma/configs/talk2biomodels/agents/__init__.py +5 -0
- aiagents4pharma-1.9.0/aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/__init__.py +3 -0
- aiagents4pharma-1.9.0/aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/default.yaml +8 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2biomodels/__init__.py +1 -1
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2biomodels/agents/t2b_agent.py +3 -3
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +1 -1
- aiagents4pharma-1.9.0/aiagents4pharma/talk2biomodels/tests/__init__.py +3 -0
- aiagents4pharma-1.9.0/aiagents4pharma/talk2biomodels/tests/test_basico_model.py +55 -0
- aiagents4pharma-1.9.0/aiagents4pharma/talk2biomodels/tests/test_langgraph.py +240 -0
- aiagents4pharma-1.9.0/aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +57 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2biomodels/tools/ask_question.py +16 -7
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2biomodels/tools/custom_plotter.py +20 -14
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +6 -6
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2biomodels/tools/simulate_model.py +26 -12
- aiagents4pharma-1.9.0/aiagents4pharma/talk2cells/tests/scp_agent/test_scp_agent.py +23 -0
- aiagents4pharma-1.9.0/aiagents4pharma/talk2knowledgegraphs/tests/__init__.py +0 -0
- aiagents4pharma-1.9.0/aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py +242 -0
- aiagents4pharma-1.9.0/aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py +29 -0
- aiagents4pharma-1.9.0/aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py +73 -0
- aiagents4pharma-1.9.0/aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py +116 -0
- aiagents4pharma-1.9.0/aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py +47 -0
- aiagents4pharma-1.9.0/aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +45 -0
- aiagents4pharma-1.9.0/aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_sentencetransformer.py +40 -0
- aiagents4pharma-1.9.0/aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma.egg-info/PKG-INFO +1 -1
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma.egg-info/SOURCES.txt +20 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/pyproject.toml +9 -19
- aiagents4pharma-1.9.0/release_version.txt +1 -0
- aiagents4pharma-1.8.2/release_version.txt +0 -1
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/LICENSE +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/README.md +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2biomodels/agents/__init__.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2biomodels/models/__init__.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2biomodels/models/basico_model.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2biomodels/models/sys_bio_model.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2biomodels/states/__init__.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2biomodels/tools/__init__.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2biomodels/tools/load_biomodel.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2biomodels/tools/search_models.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2cells/__init__.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2cells/agents/__init__.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2cells/agents/scp_agent.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2cells/states/__init__.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2cells/states/state_talk2cells.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2cells/tools/__init__.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2cells/tools/scp_agent/__init__.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2cells/tools/scp_agent/display_studies.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2cells/tools/scp_agent/search_studies.py +0 -0
- {aiagents4pharma-1.8.2/aiagents4pharma/talk2knowledgegraphs/utils → aiagents4pharma-1.9.0/aiagents4pharma/talk2competitors}/__init__.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2knowledgegraphs/__init__.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2knowledgegraphs/datasets/__init__.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2knowledgegraphs/datasets/biobridge_primekg.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2knowledgegraphs/datasets/dataset.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2knowledgegraphs/datasets/primekg.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2knowledgegraphs/utils/embeddings/embeddings.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2knowledgegraphs/utils/embeddings/huggingface.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2knowledgegraphs/utils/embeddings/sentence_transformer.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2knowledgegraphs/utils/kg_utils.py +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma.egg-info/dependency_links.txt +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma.egg-info/requires.txt +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma.egg-info/top_level.txt +0 -0
- {aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: aiagents4pharma
|
3
|
-
Version: 1.
|
3
|
+
Version: 1.9.0
|
4
4
|
Summary: AI Agents for drug discovery, drug development, and other pharmaceutical R&D
|
5
5
|
Classifier: Programming Language :: Python :: 3
|
6
6
|
Classifier: License :: OSI Approved :: MIT License
|
@@ -0,0 +1,8 @@
|
|
1
|
+
_target_: talk2biomodels.agents.t2b_agent.get_app
|
2
|
+
state_modifier: >
|
3
|
+
You are Talk2BioModels agent.
|
4
|
+
If the user asks for the uploaded model,
|
5
|
+
then pass the use_uploaded_model argument
|
6
|
+
as True. If the user asks for simulation,
|
7
|
+
then suggest a value for the `simulation_name`
|
8
|
+
argument.
|
{aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2biomodels/agents/t2b_agent.py
RENAMED
@@ -52,10 +52,10 @@ def get_app(uniq_id, llm_model='gpt-4o-mini'):
|
|
52
52
|
llm = ChatOpenAI(model=llm_model, temperature=0)
|
53
53
|
# Load hydra configuration
|
54
54
|
logger.log(logging.INFO, "Load Hydra configuration for Talk2BioModels agent.")
|
55
|
-
with hydra.initialize(version_base=None, config_path="
|
55
|
+
with hydra.initialize(version_base=None, config_path="../../configs"):
|
56
56
|
cfg = hydra.compose(config_name='config',
|
57
|
-
overrides=['
|
58
|
-
cfg = cfg.
|
57
|
+
overrides=['talk2biomodels/agents/t2b_agent=default'])
|
58
|
+
cfg = cfg.talk2biomodels.agents.t2b_agent
|
59
59
|
logger.log(logging.INFO, "state_modifier: %s", cfg.state_modifier)
|
60
60
|
# Create the agent
|
61
61
|
model = create_react_agent(
|
@@ -20,5 +20,5 @@ class Talk2Biomodels(AgentState):
|
|
20
20
|
# the operator for the sbml_file_path field.
|
21
21
|
# https://langchain-ai.github.io/langgraph/troubleshooting/errors/INVALID_CONCURRENT_GRAPH_UPDATE/
|
22
22
|
sbml_file_path: Annotated[list, operator.add]
|
23
|
-
dic_simulated_data: dict
|
23
|
+
dic_simulated_data: Annotated[list[dict], operator.add]
|
24
24
|
llm_model: str
|
@@ -0,0 +1,55 @@
|
|
1
|
+
'''
|
2
|
+
A test BasicoModel class for pytest unit testing.
|
3
|
+
'''
|
4
|
+
|
5
|
+
import pandas as pd
|
6
|
+
import pytest
|
7
|
+
import basico
|
8
|
+
from ..models.basico_model import BasicoModel
|
9
|
+
|
10
|
+
@pytest.fixture(name="model")
|
11
|
+
def model_fixture():
|
12
|
+
"""
|
13
|
+
A fixture for the BasicoModel class.
|
14
|
+
"""
|
15
|
+
return BasicoModel(biomodel_id=64, species={"Pyruvate": 100}, duration=2, interval=2)
|
16
|
+
|
17
|
+
def test_with_biomodel_id(model):
|
18
|
+
"""
|
19
|
+
Test initialization of BasicoModel with biomodel_id.
|
20
|
+
"""
|
21
|
+
assert model.biomodel_id == 64
|
22
|
+
# check if the simulation results are a pandas DataFrame object
|
23
|
+
assert isinstance(model.simulate(parameters={'Pyruvate': 0.5, 'KmPFKF6P': 1.5},
|
24
|
+
duration=2,
|
25
|
+
interval=2),
|
26
|
+
pd.DataFrame)
|
27
|
+
assert isinstance(model.simulate(parameters={None: None}, duration=2, interval=2),
|
28
|
+
pd.DataFrame)
|
29
|
+
assert model.description == basico.biomodels.get_model_info(model.biomodel_id)["description"]
|
30
|
+
|
31
|
+
def test_with_sbml_file():
|
32
|
+
"""
|
33
|
+
Test initialization of BasicoModel with sbml_file_path.
|
34
|
+
"""
|
35
|
+
model_object = BasicoModel(sbml_file_path="./BIOMD0000000064_url.xml")
|
36
|
+
assert model_object.sbml_file_path == "./BIOMD0000000064_url.xml"
|
37
|
+
assert isinstance(model_object.simulate(duration=2, interval=2), pd.DataFrame)
|
38
|
+
assert isinstance(model_object.simulate(parameters={'NADH': 0.5}, duration=2, interval=2),
|
39
|
+
pd.DataFrame)
|
40
|
+
|
41
|
+
def test_check_biomodel_id_or_sbml_file_path():
|
42
|
+
'''
|
43
|
+
Test the check_biomodel_id_or_sbml_file_path method of the BioModel class.
|
44
|
+
'''
|
45
|
+
with pytest.raises(ValueError):
|
46
|
+
BasicoModel(species={"Pyruvate": 100}, duration=2, interval=2)
|
47
|
+
|
48
|
+
def test_get_model_metadata():
|
49
|
+
"""
|
50
|
+
Test the get_model_metadata method of the BasicoModel class.
|
51
|
+
"""
|
52
|
+
model = BasicoModel(biomodel_id=64)
|
53
|
+
metadata = model.get_model_metadata()
|
54
|
+
assert metadata["Model Type"] == "SBML Model (COPASI)"
|
55
|
+
assert metadata["Parameter Count"] == len(basico.get_parameters())
|
@@ -0,0 +1,240 @@
|
|
1
|
+
'''
|
2
|
+
Test cases for Talk2Biomodels.
|
3
|
+
'''
|
4
|
+
|
5
|
+
import pandas as pd
|
6
|
+
from langchain_core.messages import HumanMessage, ToolMessage
|
7
|
+
from ..agents.t2b_agent import get_app
|
8
|
+
|
9
|
+
def test_get_modelinfo_tool():
|
10
|
+
'''
|
11
|
+
Test the get_modelinfo tool.
|
12
|
+
'''
|
13
|
+
unique_id = 12345
|
14
|
+
app = get_app(unique_id)
|
15
|
+
config = {"configurable": {"thread_id": unique_id}}
|
16
|
+
# Update state
|
17
|
+
app.update_state(config,
|
18
|
+
{"sbml_file_path": ["aiagents4pharma/talk2biomodels/tests/BIOMD0000000449_url.xml"]})
|
19
|
+
prompt = "Extract all relevant information from the uploaded model."
|
20
|
+
# Test the tool get_modelinfo
|
21
|
+
response = app.invoke(
|
22
|
+
{"messages": [HumanMessage(content=prompt)]},
|
23
|
+
config=config
|
24
|
+
)
|
25
|
+
assistant_msg = response["messages"][-1].content
|
26
|
+
# Check if the assistant message is a string
|
27
|
+
assert isinstance(assistant_msg, str)
|
28
|
+
|
29
|
+
def test_search_models_tool():
|
30
|
+
'''
|
31
|
+
Test the search_models tool.
|
32
|
+
'''
|
33
|
+
unique_id = 12345
|
34
|
+
app = get_app(unique_id)
|
35
|
+
config = {"configurable": {"thread_id": unique_id}}
|
36
|
+
# Update state
|
37
|
+
app.update_state(config, {"llm_model": "gpt-4o-mini"})
|
38
|
+
prompt = "Search for models on Crohn's disease."
|
39
|
+
# Test the tool get_modelinfo
|
40
|
+
response = app.invoke(
|
41
|
+
{"messages": [HumanMessage(content=prompt)]},
|
42
|
+
config=config
|
43
|
+
)
|
44
|
+
assistant_msg = response["messages"][-1].content
|
45
|
+
# Check if the assistant message is a string
|
46
|
+
assert isinstance(assistant_msg, str)
|
47
|
+
# Check if the assistant message contains the
|
48
|
+
# biomodel id BIO0000000537
|
49
|
+
assert "BIOMD0000000537" in assistant_msg
|
50
|
+
|
51
|
+
def test_ask_question_tool():
|
52
|
+
'''
|
53
|
+
Test the ask_question tool without the simulation results.
|
54
|
+
'''
|
55
|
+
unique_id = 12345
|
56
|
+
app = get_app(unique_id, llm_model='gpt-4o-mini')
|
57
|
+
config = {"configurable": {"thread_id": unique_id}}
|
58
|
+
|
59
|
+
##########################################
|
60
|
+
# Test ask_question tool when simulation
|
61
|
+
# results are not available i.e. the
|
62
|
+
# simulation has not been run. In this
|
63
|
+
# case, the tool should return an error
|
64
|
+
##########################################
|
65
|
+
# Update state
|
66
|
+
app.update_state(config, {"llm_model": "gpt-4o-mini"})
|
67
|
+
# Define the prompt
|
68
|
+
prompt = "Call the ask_question tool to answer the "
|
69
|
+
prompt += "question: What is the concentration of CRP "
|
70
|
+
prompt += "in serum at 1000 hours? The simulation name "
|
71
|
+
prompt += "is `simulation_name`."
|
72
|
+
# Invoke the tool
|
73
|
+
app.invoke(
|
74
|
+
{"messages": [HumanMessage(content=prompt)]},
|
75
|
+
config=config
|
76
|
+
)
|
77
|
+
# Get the messages from the current state
|
78
|
+
# and reverse the order
|
79
|
+
current_state = app.get_state(config)
|
80
|
+
reversed_messages = current_state.values["messages"][::-1]
|
81
|
+
# Loop through the reversed messages until a
|
82
|
+
# ToolMessage is found.
|
83
|
+
for msg in reversed_messages:
|
84
|
+
# Assert that the message is a ToolMessage
|
85
|
+
# and its status is "error"
|
86
|
+
if isinstance(msg, ToolMessage):
|
87
|
+
assert msg.status == "error"
|
88
|
+
|
89
|
+
def test_simulate_model_tool():
|
90
|
+
'''
|
91
|
+
Test the simulate_model tool when simulating
|
92
|
+
multiple models.
|
93
|
+
'''
|
94
|
+
unique_id = 123
|
95
|
+
app = get_app(unique_id)
|
96
|
+
config = {"configurable": {"thread_id": unique_id}}
|
97
|
+
app.update_state(config, {"llm_model": "gpt-4o-mini"})
|
98
|
+
# Upload a model to the state
|
99
|
+
app.update_state(config,
|
100
|
+
{"sbml_file_path": ["aiagents4pharma/talk2biomodels/tests/BIOMD0000000449_url.xml"]})
|
101
|
+
prompt = "Simulate models 64 and the uploaded model"
|
102
|
+
# Invoke the agent
|
103
|
+
app.invoke(
|
104
|
+
{"messages": [HumanMessage(content=prompt)]},
|
105
|
+
config=config
|
106
|
+
)
|
107
|
+
current_state = app.get_state(config)
|
108
|
+
dic_simulated_data = current_state.values["dic_simulated_data"]
|
109
|
+
# Check if the dic_simulated_data is a list
|
110
|
+
assert isinstance(dic_simulated_data, list)
|
111
|
+
# Check if the length of the dic_simulated_data is 2
|
112
|
+
assert len(dic_simulated_data) == 2
|
113
|
+
# Check if the source of the first model is 64
|
114
|
+
assert dic_simulated_data[0]['source'] == 64
|
115
|
+
# Check if the source of the second model is upload
|
116
|
+
assert dic_simulated_data[1]['source'] == "upload"
|
117
|
+
# Check if the data of the first model contains
|
118
|
+
assert '1,3-bisphosphoglycerate' in dic_simulated_data[0]['data']
|
119
|
+
# Check if the data of the second model contains
|
120
|
+
assert 'mTORC2' in dic_simulated_data[1]['data']
|
121
|
+
|
122
|
+
def test_integration():
|
123
|
+
'''
|
124
|
+
Test the integration of the tools.
|
125
|
+
'''
|
126
|
+
unique_id = 123
|
127
|
+
app = get_app(unique_id)
|
128
|
+
config = {"configurable": {"thread_id": unique_id}}
|
129
|
+
app.update_state(config, {"llm_model": "gpt-4o-mini"})
|
130
|
+
# ##########################################
|
131
|
+
# ## Test simulate_model tool
|
132
|
+
# ##########################################
|
133
|
+
prompt = "Simulate the model 537 for 2016 hours and intervals"
|
134
|
+
prompt += " 2016 with an initial concentration of `DoseQ2W` "
|
135
|
+
prompt += "set to 300 and `Dose` set to 0. Reset the concentration"
|
136
|
+
prompt += " of `NAD` to 100 every 500 hours."
|
137
|
+
# Test the tool get_modelinfo
|
138
|
+
response = app.invoke(
|
139
|
+
{"messages": [HumanMessage(content=prompt)]},
|
140
|
+
config=config
|
141
|
+
)
|
142
|
+
assistant_msg = response["messages"][-1].content
|
143
|
+
print (assistant_msg)
|
144
|
+
# Check if the assistant message is a string
|
145
|
+
assert isinstance(assistant_msg, str)
|
146
|
+
##########################################
|
147
|
+
# Test ask_question tool when simulation
|
148
|
+
# results are available
|
149
|
+
##########################################
|
150
|
+
# Update state
|
151
|
+
app.update_state(config, {"llm_model": "gpt-4o-mini"})
|
152
|
+
prompt = "What is the concentration of CRP in serum at 1000 hours? "
|
153
|
+
# prompt += "Show only the concentration, rounded to 1 decimal place."
|
154
|
+
# prompt += "For example, if the concentration is 0.123456, "
|
155
|
+
# prompt += "your response should be `0.1`. Do not return any other information."
|
156
|
+
# Test the tool get_modelinfo
|
157
|
+
response = app.invoke(
|
158
|
+
{"messages": [HumanMessage(content=prompt)]},
|
159
|
+
config=config
|
160
|
+
)
|
161
|
+
assistant_msg = response["messages"][-1].content
|
162
|
+
# print (assistant_msg)
|
163
|
+
# Check if the assistant message is a string
|
164
|
+
assert "1.7" in assistant_msg
|
165
|
+
|
166
|
+
##########################################
|
167
|
+
# Test custom_plotter tool when the
|
168
|
+
# simulation results are available
|
169
|
+
##########################################
|
170
|
+
prompt = "Plot only CRP related species."
|
171
|
+
|
172
|
+
# Update state
|
173
|
+
app.update_state(config, {"llm_model": "gpt-4o-mini"}
|
174
|
+
)
|
175
|
+
# Test the tool get_modelinfo
|
176
|
+
response = app.invoke(
|
177
|
+
{"messages": [HumanMessage(content=prompt)]},
|
178
|
+
config=config
|
179
|
+
)
|
180
|
+
assistant_msg = response["messages"][-1].content
|
181
|
+
current_state = app.get_state(config)
|
182
|
+
# Get the messages from the current state
|
183
|
+
# and reverse the order
|
184
|
+
reversed_messages = current_state.values["messages"][::-1]
|
185
|
+
# Loop through the reversed messages
|
186
|
+
# until a ToolMessage is found.
|
187
|
+
expected_header = ['Time', 'CRP[serum]', 'CRPExtracellular']
|
188
|
+
expected_header += ['CRP Suppression (%)', 'CRP (% of baseline)']
|
189
|
+
expected_header += ['CRP[liver]']
|
190
|
+
predicted_artifact = []
|
191
|
+
for msg in reversed_messages:
|
192
|
+
if isinstance(msg, ToolMessage):
|
193
|
+
# Work on the message if it is a ToolMessage
|
194
|
+
# These may contain additional visuals that
|
195
|
+
# need to be displayed to the user.
|
196
|
+
if msg.name == "custom_plotter":
|
197
|
+
predicted_artifact = msg.artifact
|
198
|
+
break
|
199
|
+
# Convert the artifact into a pandas dataframe
|
200
|
+
# for easy comparison
|
201
|
+
df = pd.DataFrame(predicted_artifact)
|
202
|
+
# Extract the headers from the dataframe
|
203
|
+
predicted_header = df.columns.tolist()
|
204
|
+
# Check if the header is in the expected_header
|
205
|
+
# assert expected_header in predicted_artifact
|
206
|
+
assert set(expected_header).issubset(set(predicted_header))
|
207
|
+
##########################################
|
208
|
+
# Test custom_plotter tool when the
|
209
|
+
# simulation results are available but
|
210
|
+
# the species is not available
|
211
|
+
##########################################
|
212
|
+
prompt = "Plot the species `TP53`."
|
213
|
+
|
214
|
+
# Update state
|
215
|
+
app.update_state(config, {"llm_model": "gpt-4o-mini"}
|
216
|
+
)
|
217
|
+
# Test the tool get_modelinfo
|
218
|
+
response = app.invoke(
|
219
|
+
{"messages": [HumanMessage(content=prompt)]},
|
220
|
+
config=config
|
221
|
+
)
|
222
|
+
assistant_msg = response["messages"][-1].content
|
223
|
+
# print (response["messages"])
|
224
|
+
current_state = app.get_state(config)
|
225
|
+
# Get the messages from the current state
|
226
|
+
# and reverse the order
|
227
|
+
reversed_messages = current_state.values["messages"][::-1]
|
228
|
+
# Loop through the reversed messages until a
|
229
|
+
# ToolMessage is found.
|
230
|
+
predicted_artifact = []
|
231
|
+
for msg in reversed_messages:
|
232
|
+
if isinstance(msg, ToolMessage):
|
233
|
+
# Work on the message if it is a ToolMessage
|
234
|
+
# These may contain additional visuals that
|
235
|
+
# need to be displayed to the user.
|
236
|
+
if msg.name == "custom_plotter":
|
237
|
+
predicted_artifact = msg.artifact
|
238
|
+
break
|
239
|
+
# Check if the the predicted artifact is `None`
|
240
|
+
assert predicted_artifact is None
|
@@ -0,0 +1,57 @@
|
|
1
|
+
'''
|
2
|
+
This file contains the unit tests for the BioModel class.
|
3
|
+
'''
|
4
|
+
|
5
|
+
from typing import List, Dict, Union, Optional
|
6
|
+
from pydantic import Field
|
7
|
+
import pytest
|
8
|
+
from ..models.sys_bio_model import SysBioModel
|
9
|
+
|
10
|
+
class TestBioModel(SysBioModel):
|
11
|
+
'''
|
12
|
+
A test BioModel class for unit testing.
|
13
|
+
'''
|
14
|
+
|
15
|
+
biomodel_id: Optional[int] = Field(None, description="BioModel ID of the model")
|
16
|
+
sbml_file_path: Optional[str] = Field(None, description="Path to an SBML file")
|
17
|
+
name: Optional[str] = Field(..., description="Name of the model")
|
18
|
+
description: Optional[str] = Field("", description="Description of the model")
|
19
|
+
|
20
|
+
def get_model_metadata(self) -> Dict[str, Union[str, int]]:
|
21
|
+
'''
|
22
|
+
Get the metadata of the model.
|
23
|
+
'''
|
24
|
+
return self.biomodel_id
|
25
|
+
|
26
|
+
def simulate(self,
|
27
|
+
parameters: Dict[str, Union[float, int]],
|
28
|
+
duration: Union[int, float]) -> List[float]:
|
29
|
+
'''
|
30
|
+
Simulate the model.
|
31
|
+
'''
|
32
|
+
param1 = parameters.get('param1', 0.0)
|
33
|
+
param2 = parameters.get('param2', 0.0)
|
34
|
+
return [param1 + param2 * t for t in range(int(duration))]
|
35
|
+
|
36
|
+
def test_get_model_metadata():
|
37
|
+
'''
|
38
|
+
Test the get_model_metadata method of the BioModel class.
|
39
|
+
'''
|
40
|
+
model = TestBioModel(biomodel_id=123, name="Test Model", description="A test model")
|
41
|
+
metadata = model.get_model_metadata()
|
42
|
+
assert metadata == 123
|
43
|
+
|
44
|
+
def test_check_biomodel_id_or_sbml_file_path():
|
45
|
+
'''
|
46
|
+
Test the check_biomodel_id_or_sbml_file_path method of the BioModel class.
|
47
|
+
'''
|
48
|
+
with pytest.raises(ValueError):
|
49
|
+
TestBioModel(name="Test Model", description="A test model")
|
50
|
+
|
51
|
+
def test_simulate():
|
52
|
+
'''
|
53
|
+
Test the simulate method of the BioModel class.
|
54
|
+
'''
|
55
|
+
model = TestBioModel(biomodel_id=123, name="Test Model", description="A test model")
|
56
|
+
results = model.simulate(parameters={'param1': 1.0, 'param2': 2.0}, duration=4.0)
|
57
|
+
assert results == [1.0, 3.0, 5.0, 7.0]
|
{aiagents4pharma-1.8.2 → aiagents4pharma-1.9.0}/aiagents4pharma/talk2biomodels/tools/ask_question.py
RENAMED
@@ -23,6 +23,8 @@ class AskQuestionInput(BaseModel):
|
|
23
23
|
Input schema for the AskQuestion tool.
|
24
24
|
"""
|
25
25
|
question: str = Field(description="question about the simulation results")
|
26
|
+
simulation_name: str = Field(description="""Name assigned to the simulation
|
27
|
+
when the tool simulate_model was invoked.""")
|
26
28
|
state: Annotated[dict, InjectedState]
|
27
29
|
|
28
30
|
# Note: It's important that every field has type hints.
|
@@ -39,6 +41,7 @@ class AskQuestionTool(BaseTool):
|
|
39
41
|
|
40
42
|
def _run(self,
|
41
43
|
question: str,
|
44
|
+
simulation_name: str,
|
42
45
|
state: Annotated[dict, InjectedState]) -> str:
|
43
46
|
"""
|
44
47
|
Run the tool.
|
@@ -46,18 +49,24 @@ class AskQuestionTool(BaseTool):
|
|
46
49
|
Args:
|
47
50
|
question (str): The question to ask about the simulation results.
|
48
51
|
state (dict): The state of the graph.
|
49
|
-
|
52
|
+
simulation_name (str): The name assigned to the simulation.
|
50
53
|
|
51
54
|
Returns:
|
52
55
|
str: The answer to the question.
|
53
56
|
"""
|
54
57
|
logger.log(logging.INFO,
|
55
|
-
"Calling ask_question tool %s", question)
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
58
|
+
"Calling ask_question tool %s, %s", question, simulation_name)
|
59
|
+
dic_simulated_data = {}
|
60
|
+
for data in state["dic_simulated_data"]:
|
61
|
+
for key in data:
|
62
|
+
if key not in dic_simulated_data:
|
63
|
+
dic_simulated_data[key] = []
|
64
|
+
dic_simulated_data[key] += [data[key]]
|
65
|
+
# print (dic_simulated_data)
|
66
|
+
df_simulated_data = pd.DataFrame.from_dict(dic_simulated_data)
|
67
|
+
df = pd.DataFrame(
|
68
|
+
df_simulated_data[df_simulated_data['name'] == simulation_name]['data'].iloc[0]
|
69
|
+
)
|
61
70
|
prompt_content = None
|
62
71
|
# if run_manager and 'prompt' in run_manager.metadata:
|
63
72
|
# prompt_content = run_manager.metadata['prompt']
|
@@ -6,14 +6,11 @@ Tool for plotting a custom figure.
|
|
6
6
|
|
7
7
|
import logging
|
8
8
|
from typing import Type, List, TypedDict, Annotated, Tuple, Union, Literal
|
9
|
-
from typing import Type, List, TypedDict, Annotated, Tuple, Union, Literal
|
10
9
|
from pydantic import BaseModel, Field
|
11
10
|
import pandas as pd
|
12
|
-
import pandas as pd
|
13
11
|
from langchain_openai import ChatOpenAI
|
14
12
|
from langchain_core.tools import BaseTool
|
15
13
|
from langgraph.prebuilt import InjectedState
|
16
|
-
from langgraph.prebuilt import InjectedState
|
17
14
|
|
18
15
|
# Initialize logger
|
19
16
|
logging.basicConfig(level=logging.INFO)
|
@@ -24,7 +21,7 @@ class CustomPlotterInput(BaseModel):
|
|
24
21
|
Input schema for the PlotImage tool.
|
25
22
|
"""
|
26
23
|
question: str = Field(description="Description of the plot")
|
27
|
-
|
24
|
+
simulation_name: str = Field(description="Name assigned to the simulation")
|
28
25
|
state: Annotated[dict, InjectedState]
|
29
26
|
|
30
27
|
# Note: It's important that every field has type hints.
|
@@ -41,10 +38,10 @@ class CustomPlotterTool(BaseTool):
|
|
41
38
|
description: str = "A tool to make custom plots of the simulation results"
|
42
39
|
args_schema: Type[BaseModel] = CustomPlotterInput
|
43
40
|
response_format: str = "content_and_artifact"
|
44
|
-
response_format: str = "content_and_artifact"
|
45
41
|
|
46
42
|
def _run(self,
|
47
43
|
question: str,
|
44
|
+
simulation_name: str,
|
48
45
|
state: Annotated[dict, InjectedState]
|
49
46
|
) -> Tuple[str, Union[None, List[str]]]:
|
50
47
|
"""
|
@@ -53,17 +50,24 @@ class CustomPlotterTool(BaseTool):
|
|
53
50
|
Args:
|
54
51
|
question (str): The question about the custom plot.
|
55
52
|
state (dict): The state of the graph.
|
56
|
-
question (str): The question about the custom plot.
|
57
|
-
state (dict): The state of the graph.
|
58
53
|
|
59
54
|
Returns:
|
60
55
|
str: The answer to the question
|
61
56
|
"""
|
62
57
|
logger.log(logging.INFO, "Calling custom_plotter tool %s", question)
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
58
|
+
dic_simulated_data = {}
|
59
|
+
for data in state["dic_simulated_data"]:
|
60
|
+
for key in data:
|
61
|
+
if key not in dic_simulated_data:
|
62
|
+
dic_simulated_data[key] = []
|
63
|
+
dic_simulated_data[key] += [data[key]]
|
64
|
+
# Create a pandas dataframe from the dictionary
|
65
|
+
df = pd.DataFrame.from_dict(dic_simulated_data)
|
66
|
+
# Get the simulated data for the current tool call
|
67
|
+
df = pd.DataFrame(
|
68
|
+
df[df['name'] == simulation_name]['data'].iloc[0]
|
69
|
+
)
|
70
|
+
# df = pd.DataFrame.from_dict(state['dic_simulated_data'])
|
67
71
|
species_names = df.columns.tolist()
|
68
72
|
# Exclude the time column
|
69
73
|
species_names.remove('Time')
|
@@ -76,7 +80,8 @@ class CustomPlotterTool(BaseTool):
|
|
76
80
|
A list of species based on user question.
|
77
81
|
"""
|
78
82
|
relevant_species: Union[None, List[Literal[*species_names]]] = Field(
|
79
|
-
description="List of species based on user question.
|
83
|
+
description="""List of species based on user question.
|
84
|
+
If no relevant species are found, it will be None.""")
|
80
85
|
# Create an instance of the LLM model
|
81
86
|
llm = ChatOpenAI(model=state['llm_model'], temperature=0)
|
82
87
|
llm_with_structured_output = llm.with_structured_output(CustomHeader)
|
@@ -90,5 +95,6 @@ class CustomPlotterTool(BaseTool):
|
|
90
95
|
logger.info("Extracted species: %s", extracted_species)
|
91
96
|
if len(extracted_species) == 0:
|
92
97
|
return "No species found in the simulation results that matches the user prompt.", None
|
93
|
-
|
94
|
-
|
98
|
+
# Include the time column
|
99
|
+
extracted_species.insert(0, 'Time')
|
100
|
+
return f"Custom plot {simulation_name}", df[extracted_species].to_dict(orient='records')
|
@@ -25,12 +25,12 @@ class RequestedModelInfo:
|
|
25
25
|
"""
|
26
26
|
Dataclass for storing the requested model information.
|
27
27
|
"""
|
28
|
-
species: bool = Field(description="Get species from the model.")
|
29
|
-
parameters: bool = Field(description="Get parameters from the model.")
|
30
|
-
compartments: bool = Field(description="Get compartments from the model.")
|
31
|
-
units: bool = Field(description="Get units from the model.")
|
32
|
-
description: bool = Field(description="Get description from the model.")
|
33
|
-
name: bool = Field(description="Get name from the model.")
|
28
|
+
species: bool = Field(description="Get species from the model.", default=False)
|
29
|
+
parameters: bool = Field(description="Get parameters from the model.", default=False)
|
30
|
+
compartments: bool = Field(description="Get compartments from the model.", default=False)
|
31
|
+
units: bool = Field(description="Get units from the model.", default=False)
|
32
|
+
description: bool = Field(description="Get description from the model.", default=False)
|
33
|
+
name: bool = Field(description="Get name from the model.", default=False)
|
34
34
|
|
35
35
|
class GetModelInfoInput(BaseModel):
|
36
36
|
"""
|