aiagents4pharma 1.14.0__py3-none-any.whl → 1.14.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/configs/config.yaml +2 -1
- aiagents4pharma/configs/talk2biomodels/__init__.py +1 -0
- aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/default.yaml +2 -3
- aiagents4pharma/configs/talk2biomodels/tools/__init__.py +4 -0
- aiagents4pharma/configs/talk2biomodels/tools/ask_question/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +21 -7
- aiagents4pharma/talk2biomodels/tests/test_ask_question.py +44 -0
- aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +67 -69
- aiagents4pharma/talk2biomodels/tests/test_getmodelinfo.py +26 -0
- aiagents4pharma/talk2biomodels/tests/test_integration.py +126 -0
- aiagents4pharma/talk2biomodels/tests/test_param_scan.py +68 -0
- aiagents4pharma/talk2biomodels/tests/test_search_models.py +28 -0
- aiagents4pharma/talk2biomodels/tests/test_simulate_model.py +39 -0
- aiagents4pharma/talk2biomodels/tests/test_steady_state.py +90 -0
- aiagents4pharma/talk2biomodels/tools/ask_question.py +29 -8
- aiagents4pharma/talk2biomodels/tools/get_annotation.py +24 -9
- aiagents4pharma/talk2biomodels/tools/load_arguments.py +114 -0
- aiagents4pharma/talk2biomodels/tools/parameter_scan.py +91 -96
- aiagents4pharma/talk2biomodels/tools/simulate_model.py +14 -81
- aiagents4pharma/talk2biomodels/tools/steady_state.py +48 -89
- {aiagents4pharma-1.14.0.dist-info → aiagents4pharma-1.14.1.dist-info}/METADATA +1 -1
- {aiagents4pharma-1.14.0.dist-info → aiagents4pharma-1.14.1.dist-info}/RECORD +25 -16
- aiagents4pharma/talk2biomodels/tests/test_langgraph.py +0 -384
- {aiagents4pharma-1.14.0.dist-info → aiagents4pharma-1.14.1.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.14.0.dist-info → aiagents4pharma-1.14.1.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.14.0.dist-info → aiagents4pharma-1.14.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,90 @@
|
|
1
|
+
'''
|
2
|
+
Test cases for Talk2Biomodels steady state tool.
|
3
|
+
'''
|
4
|
+
|
5
|
+
from langchain_core.messages import HumanMessage, ToolMessage
|
6
|
+
from ..agents.t2b_agent import get_app
|
7
|
+
|
8
|
+
def _tool_messages_newest_first(app, config):
    '''
    Return the ToolMessages stored in the thread state, newest first.

    Args:
        app: The compiled agent graph returned by `get_app`.
        config: The config dict carrying the `thread_id` of the thread.

    Returns:
        list: Every `ToolMessage` found in the state's "messages" entry,
        ordered from the most recent to the oldest.
    '''
    messages = app.get_state(config).values["messages"]
    return [msg for msg in reversed(messages) if isinstance(msg, ToolMessage)]


def test_steady_state_tool():
    '''
    Test the steady_state tool.

    Three prompts are sent on the same thread:
    1. a steady state analysis that is expected to fail (model 537);
    2. a steady state analysis that is expected to succeed (model 64);
    3. a follow-up question about the steady state results, which is
       expected to go through the `ask_question` tool.
    '''
    unique_id = 123
    app = get_app(unique_id)
    config = {"configurable": {"thread_id": unique_id}}
    app.update_state(config, {"llm_model": "gpt-4o-mini"})
    #########################################################
    # In this case, we will test if the tool returns an error
    # when the model does not achieve a steady state. The tool
    # status should be "error".
    prompt = """Run a steady state analysis of model 537."""
    # Invoke the agent
    app.invoke(
        {"messages": [HumanMessage(content=prompt)]},
        config=config
    )
    # Only the most recent ToolMessage is inspected here.
    tool_msgs = _tool_messages_newest_first(app, config)
    tool_msg_status = tool_msgs[0].status if tool_msgs else None
    assert tool_msg_status == "error"
    #########################################################
    # In this case, we will test if the tool is indeed invoked
    # successfully
    prompt = """Run a steady state analysis of model 64.
    Set the initial concentration of `Pyruvate` to 0.2. The
    concentration of `NAD` resets to 100 every 2 time units."""
    # Invoke the agent
    app.invoke(
        {"messages": [HumanMessage(content=prompt)]},
        config=config
    )
    # The thread still holds the failed call from the first case, so
    # look for any `steady_state` ToolMessage whose status is NOT "error".
    steady_state_invoked = any(
        msg.name == "steady_state" and msg.status != "error"
        for msg in _tool_messages_newest_first(app, config)
    )
    assert steady_state_invoked
    #########################################################
    # In this case, we will test if the `ask_question` tool is
    # invoked upon asking a question about the already generated
    # steady state results
    prompt = """What is the Phosphoenolpyruvate concentration
    at the steady state? Show only the concentration, rounded
    to 2 decimal places. For example, if the concentration is
    0.123456, your response should be `0.12`. Do not return
    any other information."""
    # Invoke the agent
    response = app.invoke(
        {"messages": [HumanMessage(content=prompt)]},
        config=config
    )
    assistant_msg = response["messages"][-1].content
    # The question must have been routed through the `ask_question` tool.
    ask_questool_invoked = any(
        msg.name == "ask_question"
        for msg in _tool_messages_newest_first(app, config)
    )
    assert ask_questool_invoked
    # The final assistant reply must contain the expected rounded value.
    assert "0.06" in assistant_msg
|
@@ -6,10 +6,11 @@ Tool for asking a question about the simulation results.
|
|
6
6
|
|
7
7
|
import logging
|
8
8
|
from typing import Type, Annotated, Literal
|
9
|
+
import hydra
|
10
|
+
import basico
|
9
11
|
import pandas as pd
|
10
12
|
from pydantic import BaseModel, Field
|
11
13
|
from langchain_core.tools.base import BaseTool
|
12
|
-
from langchain.agents.agent_types import AgentType
|
13
14
|
from langchain_experimental.agents import create_pandas_dataframe_agent
|
14
15
|
from langchain_openai import ChatOpenAI
|
15
16
|
from langgraph.prebuilt import InjectedState
|
@@ -64,31 +65,51 @@ class AskQuestionTool(BaseTool):
|
|
64
65
|
question,
|
65
66
|
question_context,
|
66
67
|
experiment_name)
|
67
|
-
#
|
68
|
+
# Load hydra configuration
|
69
|
+
with hydra.initialize(version_base=None, config_path="../../configs"):
|
70
|
+
cfg = hydra.compose(config_name='config',
|
71
|
+
overrides=['talk2biomodels/tools/ask_question=default'])
|
72
|
+
cfg = cfg.talk2biomodels.tools.ask_question
|
73
|
+
# Get the context of the question
|
74
|
+
# and based on the context, get the data
|
75
|
+
# and prompt content to ask the question
|
68
76
|
if question_context == "steady_state":
|
69
77
|
dic_context = state["dic_steady_state_data"]
|
78
|
+
prompt_content = cfg.steady_state_prompt
|
70
79
|
else:
|
71
80
|
dic_context = state["dic_simulated_data"]
|
81
|
+
prompt_content = cfg.simulation_prompt
|
82
|
+
# Group the values of each key across all entries of the context
|
72
83
|
dic_data = {}
|
73
84
|
for data in dic_context:
|
74
85
|
for key in data:
|
75
86
|
if key not in dic_data:
|
76
87
|
dic_data[key] = []
|
77
88
|
dic_data[key] += [data[key]]
|
78
|
-
#
|
89
|
+
# Create a pandas dataframe of the data
|
79
90
|
df_data = pd.DataFrame.from_dict(dic_data)
|
91
|
+
# Extract the data for the experiment
|
92
|
+
# matching the experiment name
|
80
93
|
df = pd.DataFrame(
|
81
94
|
df_data[df_data['name'] == experiment_name]['data'].iloc[0]
|
82
95
|
)
|
83
|
-
|
84
|
-
#
|
85
|
-
#
|
86
|
-
#
|
96
|
+
logger.log(logging.INFO, "Shape of the dataframe: %s", df.shape)
|
97
|
+
# # Extract the model units
|
98
|
+
# model_units = basico.model_info.get_model_units()
|
99
|
+
# Update the prompt content with the model units
|
100
|
+
prompt_content += "Following are the model units:\n"
|
101
|
+
prompt_content += f"{basico.model_info.get_model_units()}\n\n"
|
102
|
+
# Create a pandas dataframe agent
|
87
103
|
df_agent = create_pandas_dataframe_agent(
|
88
104
|
ChatOpenAI(model=state['llm_model']),
|
89
105
|
allow_dangerous_code=True,
|
90
|
-
agent_type=
|
106
|
+
agent_type='tool-calling',
|
91
107
|
df=df,
|
108
|
+
max_iterations=5,
|
109
|
+
include_df_in_prompt=True,
|
110
|
+
number_of_head_rows=df.shape[0],
|
111
|
+
verbose=True,
|
92
112
|
prefix=prompt_content)
|
113
|
+
# Invoke the agent with the question
|
93
114
|
llm_result = df_agent.invoke(question)
|
94
115
|
return llm_result["output"]
|
@@ -7,6 +7,7 @@ based on the provided model and species names.
|
|
7
7
|
import math
|
8
8
|
from typing import List, Annotated, Type
|
9
9
|
import logging
|
10
|
+
from dataclasses import dataclass
|
10
11
|
from pydantic import BaseModel, Field
|
11
12
|
import basico
|
12
13
|
import pandas as pd
|
@@ -42,18 +43,29 @@ def prepare_content_msg(species_not_found: List[str],
|
|
42
43
|
{", ".join(species_without_description)}.'''
|
43
44
|
return content
|
44
45
|
|
45
|
-
|
46
|
+
@dataclass
|
47
|
+
class ArgumentData:
|
46
48
|
"""
|
47
|
-
|
49
|
+
Dataclass for storing the argument data.
|
48
50
|
"""
|
49
|
-
|
50
|
-
|
51
|
+
experiment_name: Annotated[str, "An AI assigned _ separated name of"
|
52
|
+
" the experiment based on human query"
|
53
|
+
" and the context of the experiment."
|
54
|
+
" This must be set before the experiment is run."]
|
51
55
|
list_species_names: List[str] = Field(
|
52
|
-
default=
|
56
|
+
default=None,
|
53
57
|
description='''List of species names to fetch annotations for.
|
54
58
|
If not provided, annotations for all
|
55
59
|
species in the model will be fetched.'''
|
56
60
|
)
|
61
|
+
|
62
|
+
class GetAnnotationInput(BaseModel):
|
63
|
+
"""
|
64
|
+
Input schema for annotation tool.
|
65
|
+
"""
|
66
|
+
arg_data: ArgumentData = Field(description="argument data")
|
67
|
+
sys_bio_model: ModelData = Field(description="model data")
|
68
|
+
tool_call_id: Annotated[str, InjectedToolCallId]
|
57
69
|
state: Annotated[dict, InjectedState]
|
58
70
|
|
59
71
|
class GetAnnotationTool(BaseTool):
|
@@ -70,14 +82,16 @@ class GetAnnotationTool(BaseTool):
|
|
70
82
|
return_direct: bool = False
|
71
83
|
|
72
84
|
def _run(self,
|
85
|
+
arg_data: ArgumentData,
|
73
86
|
tool_call_id: Annotated[str, InjectedToolCallId],
|
74
87
|
state: Annotated[dict, InjectedState],
|
75
|
-
list_species_names: List[str] = None,
|
76
88
|
sys_bio_model: ModelData = None) -> str:
|
77
89
|
"""
|
78
90
|
Run the tool.
|
79
91
|
"""
|
80
|
-
logger.info("Running the GetAnnotationTool tool for species %s",
|
92
|
+
logger.info("Running the GetAnnotationTool tool for species %s, %s",
|
93
|
+
arg_data.list_species_names,
|
94
|
+
arg_data.experiment_name)
|
81
95
|
|
82
96
|
# Prepare the model object
|
83
97
|
sbml_file_path = state['sbml_file_path'][-1] if state['sbml_file_path'] else None
|
@@ -90,11 +104,11 @@ class GetAnnotationTool(BaseTool):
|
|
90
104
|
# for example this may happen with model 20
|
91
105
|
raise ValueError("Unable to extract species from the model.")
|
92
106
|
# Fetch annotations for the species names
|
93
|
-
list_species_names = list_species_names or df_species.index.tolist()
|
107
|
+
arg_data.list_species_names = arg_data.list_species_names or df_species.index.tolist()
|
94
108
|
|
95
109
|
(annotations_df,
|
96
110
|
species_not_found,
|
97
|
-
species_without_description) = self._fetch_annotations(list_species_names)
|
111
|
+
species_without_description) = self._fetch_annotations(arg_data.list_species_names)
|
98
112
|
|
99
113
|
# Check if annotations are empty
|
100
114
|
# If empty, return a message
|
@@ -107,6 +121,7 @@ class GetAnnotationTool(BaseTool):
|
|
107
121
|
|
108
122
|
# Prepare the simulated data
|
109
123
|
dic_annotations_data = {
|
124
|
+
'name': arg_data.experiment_name,
|
110
125
|
'source': sys_bio_model.biomodel_id if sys_bio_model.biomodel_id else 'upload',
|
111
126
|
'tool_call_id': tool_call_id,
|
112
127
|
'data': annotations_df.to_dict()
|
@@ -0,0 +1,114 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
|
3
|
+
"""
|
4
|
+
A utility module for defining the dataclasses
|
5
|
+
for the arguments to set up initial settings
|
6
|
+
before the experiment is run.
|
7
|
+
"""
|
8
|
+
|
9
|
+
import logging
|
10
|
+
from dataclasses import dataclass
|
11
|
+
from typing import Union, List, Optional, Annotated
|
12
|
+
from pydantic import Field
|
13
|
+
import basico
|
14
|
+
|
15
|
+
# Initialize logger
|
16
|
+
logging.basicConfig(level=logging.INFO)
|
17
|
+
logger = logging.getLogger(__name__)
|
18
|
+
|
19
|
+
@dataclass
class TimeData:
    """
    Dataclass for storing the time data.

    Holds the two time settings of a simulation: its duration
    and its interval.
    """
    # NOTE(review): `Field` is imported from pydantic while `dataclass`
    # is the stdlib decorator; presumably the descriptions and defaults
    # below are consumed when pydantic builds a schema from this
    # class - confirm.
    # Duration of the simulation; the default is 100.
    duration: Union[int, float] = Field(
        description="Duration of the simulation",
        default=100)
    # Interval of the simulation; the default is 100, the same value as
    # the duration default. NOTE(review): confirm this default is intended.
    interval: Union[int, float] = Field(
        description="The interval is the time step or"
        " the step size of the simulation. It is unrelated"
        " to the step size of species concentration and parameter values.",
        default=100)
|
32
|
+
|
33
|
+
@dataclass
class SpeciesInitialData:
    """
    Dataclass for storing the species initial data.

    Holds two lists: the names of the species whose initial
    concentration is to be set, and the concentrations themselves.
    """
    # Names of the species to initialise; the default is an empty list.
    # NOTE(review): presumably paired by position with
    # `species_concentration` - verify against the callers.
    species_name: List[str] = Field(
        description="List of species whose initial concentration is to be set."
        " This does not include species that reoccur or the species whose"
        " concentration is to be determined/observed at the end of the experiment."
        " Do not hallucinate the species name.",
        default=[])
    # Initial concentrations (int or float); the default is an empty list.
    species_concentration: List[Union[int, float]] = Field(
        description="List of initial concentrations of species."
        " This does not include species that reoccur or the species whose"
        " concentration is to be determined/observed at the end of the experiment."
        " Do not hallucinate the species concentration.",
        default=[])
|
50
|
+
|
51
|
+
@dataclass
class TimeSpeciesNameConcentration:
    """
    Dataclass for storing the time,
    species name, and concentration data.

    One instance describes a single entry: a time point, a species
    name, and the concentration of that species at the time point.
    None of the three fields is given a default value.
    """
    # Time point where the event occurs.
    time: Union[int, float] = Field(description="time point where the event occurs")
    # Name of the species the entry refers to.
    species_name: str = Field(description="species name")
    # Concentration of the species at the time point.
    species_concentration: Union[int, float] = Field(
        description="species concentration at the time point")
|
61
|
+
|
62
|
+
@dataclass
class ReocurringData:
    """
    Dataclass for species that reoccur. In other words,
    the concentration of the species resets to a certain
    value after a certain time interval.

    Wraps a list of `TimeSpeciesNameConcentration` entries.
    """
    # Entries of (time, species name, concentration); the default is
    # an empty list.
    data: List[TimeSpeciesNameConcentration] = Field(
        description="List of time, name, and concentration data"
        " of species or parameters that reoccur",
        default=[])
|
73
|
+
|
74
|
+
@dataclass
class ArgumentData:
    """
    Dataclass for storing the argument data.

    Bundles the experiment name with three optional groups of
    settings: time data, initial species concentrations, and
    reoccurring species data. Each optional group defaults to None.
    """
    # Name of the experiment. The description is attached through
    # `Annotated` metadata and no default value is given.
    experiment_name: Annotated[str, "An AI assigned _ separated name of"
                               " the experiment based on human query"
                               " and the context of the experiment."
                               " This must be set before the experiment is run."]
    # Optional time settings (see `TimeData`); None when not provided.
    time_data: Optional[TimeData] = Field(
        description="time data",
        default=None)
    # Optional initial concentrations (see `SpeciesInitialData`);
    # None when not provided.
    species_to_be_analyzed_before_experiment: Optional[SpeciesInitialData] = Field(
        description="Data of species whose initial concentration"
        " is to be set before the experiment. This does not include"
        " species that reoccur or the species whose concentration"
        " is to be determined at the end of the experiment.",
        default=None)
    # Optional reoccurring species data (see `ReocurringData`);
    # None when not provided.
    reocurring_data: Optional[ReocurringData] = Field(
        description="List of concentration and time data of species that"
        " reoccur. For example, a species whose concentration resets"
        " to a certain value after a certain time interval.",
        default=None)
|
97
|
+
|
98
|
+
def add_rec_events(model_object, reocurring_data):
    """
    Register one event on the model for every reoccurring entry.

    For each entry in `reocurring_data.data`, `basico.add_event` is
    called with the name `<species>_<time>`, the trigger expression
    `Time > <time>`, and a single assignment that sets the species to
    the entry's concentration (passed as a string).

    Args:
        model_object: The model object; its `copasi_model` attribute
            is the model the events are added to.
        reocurring_data: The reocurring data; its `data` attribute is
            iterated, and each item must expose `time`,
            `species_name` and `species_concentration`.

    Returns:
        None
    """
    for entry in reocurring_data.data:
        when = entry.time
        species = entry.species_name
        level = entry.species_concentration
        basico.add_event(
            f'{species}_{when}',
            f'Time > {when}',
            [[species, str(level)]],
            model=model_object.copasi_model)
|
@@ -6,7 +6,7 @@ Tool for parameter scan.
|
|
6
6
|
|
7
7
|
import logging
|
8
8
|
from dataclasses import dataclass
|
9
|
-
from typing import Type, Union, List, Annotated
|
9
|
+
from typing import Type, Union, List, Annotated, Optional
|
10
10
|
import pandas as pd
|
11
11
|
import basico
|
12
12
|
from pydantic import BaseModel, Field
|
@@ -16,61 +16,37 @@ from langchain_core.tools import BaseTool
|
|
16
16
|
from langchain_core.messages import ToolMessage
|
17
17
|
from langchain_core.tools.base import InjectedToolCallId
|
18
18
|
from .load_biomodel import ModelData, load_biomodel
|
19
|
+
from .load_arguments import TimeData, SpeciesInitialData
|
19
20
|
|
20
21
|
# Initialize logger
|
21
22
|
logging.basicConfig(level=logging.INFO)
|
22
23
|
logger = logging.getLogger(__name__)
|
23
24
|
|
24
|
-
@dataclass
|
25
|
-
class TimeData:
|
26
|
-
"""
|
27
|
-
Dataclass for storing the time data.
|
28
|
-
"""
|
29
|
-
duration: Union[int, float] = 100
|
30
|
-
interval: Union[int, float] = 10
|
31
|
-
|
32
|
-
@dataclass
|
33
|
-
class SpeciesData:
|
34
|
-
"""
|
35
|
-
Dataclass for storing the species data.
|
36
|
-
"""
|
37
|
-
species_name: List[str] = Field(description="species name", default=[])
|
38
|
-
species_concentration: List[Union[int, float]] = Field(
|
39
|
-
description="initial species concentration",
|
40
|
-
default=[])
|
41
|
-
|
42
|
-
@dataclass
|
43
|
-
class TimeSpeciesNameConcentration:
|
44
|
-
"""
|
45
|
-
Dataclass for storing the time, species name, and concentration data.
|
46
|
-
"""
|
47
|
-
time: Union[int, float] = Field(description="time point where the event occurs")
|
48
|
-
species_name: str = Field(description="species name")
|
49
|
-
species_concentration: Union[int, float] = Field(
|
50
|
-
description="species concentration at the time point")
|
51
|
-
|
52
|
-
@dataclass
|
53
|
-
class ReocurringData:
|
54
|
-
"""
|
55
|
-
Dataclass for species that reoccur. In other words, the concentration
|
56
|
-
of the species resets to a certain value after a certain time interval.
|
57
|
-
"""
|
58
|
-
data: List[TimeSpeciesNameConcentration] = Field(
|
59
|
-
description="time, name, and concentration data of species that reoccur",
|
60
|
-
default=[])
|
61
|
-
|
62
25
|
@dataclass
|
63
26
|
class ParameterScanData(BaseModel):
|
64
27
|
"""
|
65
28
|
Dataclass for storing the parameter scan data.
|
66
29
|
"""
|
67
|
-
species_names: List[str] = Field(
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
30
|
+
species_names: List[str] = Field(
|
31
|
+
description="species to be observed after each scan."
|
32
|
+
" These are the species whose concentration"
|
33
|
+
" will be observed after the parameter scan."
|
34
|
+
" Do not make up this data.",
|
35
|
+
default=[])
|
36
|
+
species_parameter_name: str = Field(
|
37
|
+
description="Species or parameter name to be scanned."
|
38
|
+
" This is the species or parameter whose value will be scanned"
|
39
|
+
" over a range of values. This does not include the species"
|
40
|
+
" that are to be observed after the scan."
|
41
|
+
"Do not make up this data.",
|
42
|
+
default=None)
|
43
|
+
species_parameter_values: List[Union[int, float]] = Field(
|
44
|
+
description="Species or parameter values to be scanned."
|
45
|
+
" These are the values of the species or parameters that will be"
|
46
|
+
" scanned over a range of values. This does not include the "
|
47
|
+
"species that are to be observed after the scan."
|
48
|
+
"Do not make up this data.",
|
49
|
+
default=None)
|
74
50
|
|
75
51
|
@dataclass
|
76
52
|
class ArgumentData:
|
@@ -78,30 +54,20 @@ class ArgumentData:
|
|
78
54
|
Dataclass for storing the argument data.
|
79
55
|
"""
|
80
56
|
time_data: TimeData = Field(description="time data", default=None)
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
57
|
+
species_to_be_analyzed_before_experiment: Optional[SpeciesInitialData] = Field(
|
58
|
+
description=" This is the initial condition of the model."
|
59
|
+
" This does not include species that reoccur or the species"
|
60
|
+
" whose concentration is to be determined/observed at the end"
|
61
|
+
" of the experiment. This also does not include the species"
|
62
|
+
" or the parameter that is to be scanned. Do not make up this data.",
|
63
|
+
default=None)
|
88
64
|
parameter_scan_data: ParameterScanData = Field(
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
def add_rec_events(model_object, reocurring_data):
|
96
|
-
"""
|
97
|
-
Add reocurring events to the model.
|
98
|
-
"""
|
99
|
-
for row in reocurring_data.data:
|
100
|
-
tp, sn, sc = row.time, row.species_name, row.species_concentration
|
101
|
-
basico.add_event(f'{sn}_{tp}',
|
102
|
-
f'Time > {tp}',
|
103
|
-
[[sn, str(sc)]],
|
104
|
-
model=model_object.copasi_model)
|
65
|
+
description="parameter scan data",
|
66
|
+
default=None)
|
67
|
+
experiment_name: str = Field(
|
68
|
+
description="An AI assigned `_` separated unique name of"
|
69
|
+
" the parameter scan experiment based on human query."
|
70
|
+
" This must be unique for each experiment.")
|
105
71
|
|
106
72
|
def make_list_dic_scanned_data(dic_param_scan, arg_data, sys_bio_model, tool_call_id):
|
107
73
|
"""
|
@@ -125,13 +91,18 @@ def make_list_dic_scanned_data(dic_param_scan, arg_data, sys_bio_model, tool_cal
|
|
125
91
|
# Prepare the list dictionary of scanned data
|
126
92
|
# that will be passed to the state of the graph
|
127
93
|
list_dic_scanned_data.append({
|
128
|
-
'name': arg_data.
|
94
|
+
'name': arg_data.experiment_name+':'+species_name,
|
129
95
|
'source': sys_bio_model.biomodel_id if sys_bio_model.biomodel_id else 'upload',
|
130
96
|
'tool_call_id': tool_call_id,
|
131
97
|
'data': df_param_scan.to_dict()
|
132
98
|
})
|
133
99
|
return list_dic_scanned_data
|
134
|
-
|
100
|
+
|
101
|
+
def run_parameter_scan(model_object,
|
102
|
+
arg_data,
|
103
|
+
dic_species_data,
|
104
|
+
duration,
|
105
|
+
interval) -> dict:
|
135
106
|
"""
|
136
107
|
Run parameter scan on the model.
|
137
108
|
|
@@ -146,44 +117,61 @@ def run_parameter_scan(model_object, arg_data, dic_species_data, duration, inter
|
|
146
117
|
dict: Dictionary of parameter scan results. Each key is a species name
|
147
118
|
and each value is a DataFrame containing the results of the parameter scan.
|
148
119
|
"""
|
149
|
-
# Extract all parameter names from the model
|
120
|
+
# Extract all parameter names from the model
|
150
121
|
df_all_parameters = basico.model_info.get_parameters(model=model_object.copasi_model)
|
151
122
|
all_parameters = df_all_parameters.index.tolist()
|
152
|
-
|
153
|
-
|
154
|
-
"Invalid parameter name: %s", arg_data.parameter_scan_data.parameter_name)
|
155
|
-
raise ValueError(
|
156
|
-
f"Invalid parameter name: {arg_data.parameter_scan_data.parameter_name}")
|
157
|
-
# Extract all species name from the model and verify if the given species name is valid
|
123
|
+
|
124
|
+
# Extract all species name from the model
|
158
125
|
df_all_species = basico.model_info.get_species(model=model_object.copasi_model)
|
159
126
|
all_species = df_all_species['display_name'].tolist()
|
127
|
+
|
128
|
+
# Verify if the given species or parameter names to be scanned are valid
|
129
|
+
if arg_data.parameter_scan_data.species_parameter_name not in all_parameters + all_species:
|
130
|
+
logger.error(
|
131
|
+
"Invalid species or parameter name: %s",
|
132
|
+
arg_data.parameter_scan_data.species_parameter_name)
|
133
|
+
raise ValueError(
|
134
|
+
"Invalid species or parameter name: "
|
135
|
+
f"{arg_data.parameter_scan_data.species_parameter_name}.")
|
136
|
+
|
160
137
|
# Dictionary to store the parameter scan results
|
161
138
|
dic_param_scan_results = {}
|
139
|
+
|
140
|
+
# Loop through the species names that are to be observed
|
162
141
|
for species_name in arg_data.parameter_scan_data.species_names:
|
142
|
+
# Verify if the given species name to be observed is valid
|
163
143
|
if species_name not in all_species:
|
164
144
|
logger.error("Invalid species name: %s", species_name)
|
165
|
-
raise ValueError(f"Invalid species name: {species_name}")
|
145
|
+
raise ValueError(f"Invalid species name: {species_name}.")
|
146
|
+
|
147
|
+
# Copy the model object to avoid modifying the original model
|
148
|
+
model_object_copy = model_object.model_copy()
|
149
|
+
|
166
150
|
# Update the fixed model species and parameters
|
167
151
|
# These are the initial conditions of the model
|
168
152
|
# set by the user
|
169
|
-
|
153
|
+
model_object_copy.update_parameters(dic_species_data)
|
154
|
+
|
170
155
|
# Initialize empty DataFrame to store results
|
171
156
|
# of the parameter scan
|
172
157
|
df_param_scan = pd.DataFrame()
|
173
|
-
|
158
|
+
|
159
|
+
# Loop through the parameter values that are to be scanned
|
160
|
+
for param_value in arg_data.parameter_scan_data.species_parameter_values:
|
174
161
|
# Update the parameter value in the model
|
175
|
-
|
176
|
-
{arg_data.parameter_scan_data.
|
162
|
+
model_object_copy.update_parameters(
|
163
|
+
{arg_data.parameter_scan_data.species_parameter_name: param_value})
|
177
164
|
# Simulate the model
|
178
|
-
|
165
|
+
model_object_copy.simulate(duration=duration, interval=interval)
|
179
166
|
# If the column name 'Time' is not present in the results DataFrame
|
180
167
|
if 'Time' not in df_param_scan.columns:
|
181
|
-
df_param_scan['Time'] =
|
168
|
+
df_param_scan['Time'] = model_object_copy.simulation_results['Time']
|
182
169
|
# Add the simulation results to the results DataFrame
|
183
|
-
col_name = f"{arg_data.parameter_scan_data.
|
184
|
-
df_param_scan[col_name] =
|
170
|
+
col_name = f"{arg_data.parameter_scan_data.species_parameter_name}_{param_value}"
|
171
|
+
df_param_scan[col_name] = model_object_copy.simulation_results[species_name]
|
185
172
|
|
186
173
|
logger.log(logging.INFO, "Parameter scan results with shape %s", df_param_scan.shape)
|
174
|
+
|
187
175
|
# Add the results of the parameter scan to the dictionary
|
188
176
|
dic_param_scan_results[species_name] = df_param_scan
|
189
177
|
# return df_param_scan
|
@@ -210,8 +198,9 @@ class ParameterScanTool(BaseTool):
|
|
210
198
|
Tool for parameter scan.
|
211
199
|
"""
|
212
200
|
name: str = "parameter_scan"
|
213
|
-
description: str = """A tool to perform
|
214
|
-
|
201
|
+
description: str = """A tool to perform scanning of a given
|
202
|
+
parameter over a range of values and observe the effect on
|
203
|
+
the concentration of a given species"""
|
215
204
|
args_schema: Type[BaseModel] = ParameterScanInput
|
216
205
|
|
217
206
|
def _run(self,
|
@@ -245,12 +234,18 @@ class ParameterScanTool(BaseTool):
|
|
245
234
|
dic_species_data = {}
|
246
235
|
if arg_data:
|
247
236
|
# Prepare the dictionary of species data
|
248
|
-
if arg_data.
|
249
|
-
dic_species_data = dict(
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
237
|
+
if arg_data.species_to_be_analyzed_before_experiment is not None:
|
238
|
+
dic_species_data = dict(
|
239
|
+
zip(
|
240
|
+
arg_data.species_to_be_analyzed_before_experiment.species_name,
|
241
|
+
arg_data.species_to_be_analyzed_before_experiment.species_concentration
|
242
|
+
)
|
243
|
+
)
|
244
|
+
|
245
|
+
# # Add reocurring events (if any) to the model
|
246
|
+
# if arg_data.reocurring_data is not None:
|
247
|
+
# add_rec_events(model_object, arg_data.reocurring_data)
|
248
|
+
|
254
249
|
# Set the duration and interval
|
255
250
|
if arg_data.time_data is not None:
|
256
251
|
duration = arg_data.time_data.duration
|
@@ -284,7 +279,7 @@ class ParameterScanTool(BaseTool):
|
|
284
279
|
# update the message history
|
285
280
|
"messages": [
|
286
281
|
ToolMessage(
|
287
|
-
content=f"Parameter scan results of {arg_data.
|
282
|
+
content=f"Parameter scan results of {arg_data.experiment_name}",
|
288
283
|
tool_call_id=tool_call_id
|
289
284
|
)
|
290
285
|
],
|