aiagents4pharma 1.14.0__py3-none-any.whl → 1.14.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. aiagents4pharma/configs/config.yaml +2 -1
  2. aiagents4pharma/configs/talk2biomodels/__init__.py +1 -0
  3. aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/default.yaml +2 -3
  4. aiagents4pharma/configs/talk2biomodels/tools/__init__.py +4 -0
  5. aiagents4pharma/configs/talk2biomodels/tools/ask_question/__init__.py +3 -0
  6. aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +21 -7
  7. aiagents4pharma/talk2biomodels/tests/test_ask_question.py +44 -0
  8. aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +67 -69
  9. aiagents4pharma/talk2biomodels/tests/test_getmodelinfo.py +26 -0
  10. aiagents4pharma/talk2biomodels/tests/test_integration.py +126 -0
  11. aiagents4pharma/talk2biomodels/tests/test_param_scan.py +68 -0
  12. aiagents4pharma/talk2biomodels/tests/test_search_models.py +28 -0
  13. aiagents4pharma/talk2biomodels/tests/test_simulate_model.py +39 -0
  14. aiagents4pharma/talk2biomodels/tests/test_steady_state.py +90 -0
  15. aiagents4pharma/talk2biomodels/tools/ask_question.py +29 -8
  16. aiagents4pharma/talk2biomodels/tools/get_annotation.py +24 -9
  17. aiagents4pharma/talk2biomodels/tools/load_arguments.py +114 -0
  18. aiagents4pharma/talk2biomodels/tools/parameter_scan.py +91 -96
  19. aiagents4pharma/talk2biomodels/tools/simulate_model.py +14 -81
  20. aiagents4pharma/talk2biomodels/tools/steady_state.py +48 -89
  21. {aiagents4pharma-1.14.0.dist-info → aiagents4pharma-1.14.1.dist-info}/METADATA +1 -1
  22. {aiagents4pharma-1.14.0.dist-info → aiagents4pharma-1.14.1.dist-info}/RECORD +25 -16
  23. aiagents4pharma/talk2biomodels/tests/test_langgraph.py +0 -384
  24. {aiagents4pharma-1.14.0.dist-info → aiagents4pharma-1.14.1.dist-info}/LICENSE +0 -0
  25. {aiagents4pharma-1.14.0.dist-info → aiagents4pharma-1.14.1.dist-info}/WHEEL +0 -0
  26. {aiagents4pharma-1.14.0.dist-info → aiagents4pharma-1.14.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,90 @@
1
+ '''
2
+ Test cases for Talk2Biomodels steady state tool.
3
+ '''
4
+
5
+ from langchain_core.messages import HumanMessage, ToolMessage
6
+ from ..agents.t2b_agent import get_app
7
+
8
def _latest_tool_message(app, config, predicate=None):
    '''
    Return the most recent ToolMessage in the thread's message history.

    Args:
        app: The agent application returned by ``get_app``.
        config: Config dict identifying the conversation thread.
        predicate: Optional callable taking a ToolMessage; when given,
            only messages for which it returns True are considered.

    Returns:
        The newest matching ToolMessage, or None if there is none.
    '''
    messages = app.get_state(config).values["messages"]
    # Walk the history newest-first so the latest match wins.
    for msg in reversed(messages):
        if not isinstance(msg, ToolMessage):
            continue
        if predicate is None or predicate(msg):
            return msg
    return None


def test_steady_state_tool():
    '''
    Test the steady_state tool.

    The three cases share one conversation thread and must run in
    order: the last case asks about the results produced by the second.
    '''
    unique_id = 123
    app = get_app(unique_id)
    config = {"configurable": {"thread_id": unique_id}}
    app.update_state(config, {"llm_model": "gpt-4o-mini"})
    #########################################################
    # In this case, we will test if the tool returns an error
    # when the model does not achieve a steady state. The
    # status of the most recent ToolMessage should be "error".
    prompt = """Run a steady state analysis of model 537."""
    # Invoke the agent
    app.invoke(
        {"messages": [HumanMessage(content=prompt)]},
        config=config
    )
    tool_msg = _latest_tool_message(app, config)
    assert tool_msg is not None
    assert tool_msg.status == "error"
    #########################################################
    # In this case, we will test if the tool is indeed invoked
    # successfully, i.e. the history contains a `steady_state`
    # ToolMessage whose status is not "error".
    prompt = """Run a steady state analysis of model 64.
    Set the initial concentration of `Pyruvate` to 0.2. The
    concentration of `NAD` resets to 100 every 2 time units."""
    # Invoke the agent
    app.invoke(
        {"messages": [HumanMessage(content=prompt)]},
        config=config
    )
    steady_state_msg = _latest_tool_message(
        app,
        config,
        lambda msg: msg.name == "steady_state" and msg.status != "error")
    assert steady_state_msg is not None
    #########################################################
    # In this case, we will test if the `ask_question` tool is
    # invoked upon asking a question about the already generated
    # steady state results
    prompt = """What is the Phosphoenolpyruvate concentration
    at the steady state? Show only the concentration, rounded
    to 2 decimal places. For example, if the concentration is
    0.123456, your response should be `0.12`. Do not return
    any other information."""
    # Invoke the agent
    response = app.invoke(
        {"messages": [HumanMessage(content=prompt)]},
        config=config
    )
    assistant_msg = response["messages"][-1].content
    # The history must contain an `ask_question` ToolMessage.
    ask_question_msg = _latest_tool_message(
        app,
        config,
        lambda msg: msg.name == "ask_question")
    assert ask_question_msg is not None
    assert "0.06" in assistant_msg
@@ -6,10 +6,11 @@ Tool for asking a question about the simulation results.
6
6
 
7
7
  import logging
8
8
  from typing import Type, Annotated, Literal
9
+ import hydra
10
+ import basico
9
11
  import pandas as pd
10
12
  from pydantic import BaseModel, Field
11
13
  from langchain_core.tools.base import BaseTool
12
- from langchain.agents.agent_types import AgentType
13
14
  from langchain_experimental.agents import create_pandas_dataframe_agent
14
15
  from langchain_openai import ChatOpenAI
15
16
  from langgraph.prebuilt import InjectedState
@@ -64,31 +65,51 @@ class AskQuestionTool(BaseTool):
64
65
  question,
65
66
  question_context,
66
67
  experiment_name)
67
- # print (f'Calling ask_question tool {question}, {question_context}, {experiment_name}')
68
+ # Load hydra configuration
69
+ with hydra.initialize(version_base=None, config_path="../../configs"):
70
+ cfg = hydra.compose(config_name='config',
71
+ overrides=['talk2biomodels/tools/ask_question=default'])
72
+ cfg = cfg.talk2biomodels.tools.ask_question
73
+ # Get the context of the question
74
+ # and based on the context, get the data
75
+ # and prompt content to ask the question
68
76
  if question_context == "steady_state":
69
77
  dic_context = state["dic_steady_state_data"]
78
+ prompt_content = cfg.steady_state_prompt
70
79
  else:
71
80
  dic_context = state["dic_simulated_data"]
81
+ prompt_content = cfg.simulation_prompt
82
+ # Extract the
72
83
  dic_data = {}
73
84
  for data in dic_context:
74
85
  for key in data:
75
86
  if key not in dic_data:
76
87
  dic_data[key] = []
77
88
  dic_data[key] += [data[key]]
78
- # print (dic_data)
89
+ # Create a pandas dataframe of the data
79
90
  df_data = pd.DataFrame.from_dict(dic_data)
91
+ # Extract the data for the experiment
92
+ # matching the experiment name
80
93
  df = pd.DataFrame(
81
94
  df_data[df_data['name'] == experiment_name]['data'].iloc[0]
82
95
  )
83
- prompt_content = None
84
- # if run_manager and 'prompt' in run_manager.metadata:
85
- # prompt_content = run_manager.metadata['prompt']
86
- # Create a pandas dataframe agent with OpenAI
96
+ logger.log(logging.INFO, "Shape of the dataframe: %s", df.shape)
97
+ # # Extract the model units
98
+ # model_units = basico.model_info.get_model_units()
99
+ # Update the prompt content with the model units
100
+ prompt_content += "Following are the model units:\n"
101
+ prompt_content += f"{basico.model_info.get_model_units()}\n\n"
102
+ # Create a pandas dataframe agent
87
103
  df_agent = create_pandas_dataframe_agent(
88
104
  ChatOpenAI(model=state['llm_model']),
89
105
  allow_dangerous_code=True,
90
- agent_type=AgentType.OPENAI_FUNCTIONS,
106
+ agent_type='tool-calling',
91
107
  df=df,
108
+ max_iterations=5,
109
+ include_df_in_prompt=True,
110
+ number_of_head_rows=df.shape[0],
111
+ verbose=True,
92
112
  prefix=prompt_content)
113
+ # Invoke the agent with the question
93
114
  llm_result = df_agent.invoke(question)
94
115
  return llm_result["output"]
@@ -7,6 +7,7 @@ based on the provided model and species names.
7
7
  import math
8
8
  from typing import List, Annotated, Type
9
9
  import logging
10
+ from dataclasses import dataclass
10
11
  from pydantic import BaseModel, Field
11
12
  import basico
12
13
  import pandas as pd
@@ -42,18 +43,29 @@ def prepare_content_msg(species_not_found: List[str],
42
43
  {", ".join(species_without_description)}.'''
43
44
  return content
44
45
 
45
@dataclass
class ArgumentData:
    """
    Dataclass for storing the argument data.

    Holds the experiment name and, optionally, the names of the species
    whose annotations should be fetched.
    """
    # The description is wrapped in ``Field`` rather than given as a bare
    # string in ``Annotated``: pydantic ignores plain-string metadata, so the
    # text never reached the generated JSON schema. The field has no default,
    # so it remains required.
    experiment_name: Annotated[str, Field(
        description="An AI assigned _ separated name of"
        " the experiment based on human query"
        " and the context of the experiment."
        " This must be set before the experiment is run.")]
    # NOTE(review): the default is None although the annotation is List[str];
    # the tool substitutes all species when this is falsy. Consider
    # Optional[List[str]] so an explicit null also validates.
    list_species_names: List[str] = Field(
        default=None,
        description='''List of species names to fetch annotations for.
                    If not provided, annotations for all
                    species in the model will be fetched.'''
    )


class GetAnnotationInput(BaseModel):
    """
    Input schema for annotation tool.

    ``tool_call_id`` and ``state`` carry the ``InjectedToolCallId`` and
    ``InjectedState`` markers respectively.
    """
    arg_data: ArgumentData = Field(description="argument data")
    sys_bio_model: ModelData = Field(description="model data")
    tool_call_id: Annotated[str, InjectedToolCallId]
    state: Annotated[dict, InjectedState]
58
70
 
59
71
  class GetAnnotationTool(BaseTool):
@@ -70,14 +82,16 @@ class GetAnnotationTool(BaseTool):
70
82
  return_direct: bool = False
71
83
 
72
84
  def _run(self,
85
+ arg_data: ArgumentData,
73
86
  tool_call_id: Annotated[str, InjectedToolCallId],
74
87
  state: Annotated[dict, InjectedState],
75
- list_species_names: List[str] = None,
76
88
  sys_bio_model: ModelData = None) -> str:
77
89
  """
78
90
  Run the tool.
79
91
  """
80
- logger.info("Running the GetAnnotationTool tool for species %s", list_species_names)
92
+ logger.info("Running the GetAnnotationTool tool for species %s, %s",
93
+ arg_data.list_species_names,
94
+ arg_data.experiment_name)
81
95
 
82
96
  # Prepare the model object
83
97
  sbml_file_path = state['sbml_file_path'][-1] if state['sbml_file_path'] else None
@@ -90,11 +104,11 @@ class GetAnnotationTool(BaseTool):
90
104
  # for example this may happen with model 20
91
105
  raise ValueError("Unable to extract species from the model.")
92
106
  # Fetch annotations for the species names
93
- list_species_names = list_species_names or df_species.index.tolist()
107
+ arg_data.list_species_names = arg_data.list_species_names or df_species.index.tolist()
94
108
 
95
109
  (annotations_df,
96
110
  species_not_found,
97
- species_without_description) = self._fetch_annotations(list_species_names)
111
+ species_without_description) = self._fetch_annotations(arg_data.list_species_names)
98
112
 
99
113
  # Check if annotations are empty
100
114
  # If empty, return a message
@@ -107,6 +121,7 @@ class GetAnnotationTool(BaseTool):
107
121
 
108
122
  # Prepare the simulated data
109
123
  dic_annotations_data = {
124
+ 'name': arg_data.experiment_name,
110
125
  'source': sys_bio_model.biomodel_id if sys_bio_model.biomodel_id else 'upload',
111
126
  'tool_call_id': tool_call_id,
112
127
  'data': annotations_df.to_dict()
@@ -0,0 +1,114 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ A utility module for defining the dataclasses
5
+ for the arguments to set up initial settings
6
+ before the experiment is run.
7
+ """
8
+
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from typing import Union, List, Optional, Annotated
12
+ from pydantic import Field
13
+ import basico
14
+
15
+ # Initialize logger
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
@dataclass
class TimeData:
    """
    Dataclass holding the simulation time settings.

    Both ``duration`` and ``interval`` default to 100.
    """
    duration: Union[int, float] = Field(
        default=100,
        description="Duration of the simulation",
    )
    interval: Union[int, float] = Field(
        default=100,
        description=(
            "The interval is the time step or the step size of the"
            " simulation. It is unrelated to the step size of species"
            " concentration and parameter values."
        ),
    )
32
+
33
# Sentence shared by both field descriptions of SpeciesInitialData.
_SPECIES_INITIAL_EXCLUSION_NOTE = (
    " This does not include species that reoccur or the species whose"
    " concentration is to be determined/observed at the end of the experiment."
)


@dataclass
class SpeciesInitialData:
    """
    Dataclass pairing species names with their initial concentrations.

    The two lists are parallel: the i-th name goes with the i-th
    concentration. Both default to an empty list.
    """
    species_name: List[str] = Field(
        default=[],
        description="List of species whose initial concentration is to be set."
        + _SPECIES_INITIAL_EXCLUSION_NOTE
        + " Do not hallucinate the species name.",
    )
    species_concentration: List[Union[int, float]] = Field(
        default=[],
        description="List of initial concentrations of species."
        + _SPECIES_INITIAL_EXCLUSION_NOTE
        + " Do not hallucinate the species concentration.",
    )
50
+
51
@dataclass
class TimeSpeciesNameConcentration:
    """
    A single (time, species name, concentration) record.

    ``add_rec_events`` in this module turns each record into one event.
    """
    time: Union[int, float] = Field(
        description="time point where the event occurs",
    )
    species_name: str = Field(
        description="species name",
    )
    species_concentration: Union[int, float] = Field(
        description="species concentration at the time point",
    )
61
+
62
@dataclass
class ReocurringData:
    """
    Dataclass for species (or parameters) whose value resets to a given
    value at given time points.

    Each entry of ``data`` supplies a time point, a name and the value
    that name is set to. Defaults to an empty list.
    """
    data: List[TimeSpeciesNameConcentration] = Field(
        default=[],
        description=(
            "List of time, name, and concentration data"
            " of species or parameters that reoccur"
        ),
    )
73
+
74
@dataclass
class ArgumentData:
    """
    Dataclass for storing the argument data.

    Bundles the experiment name with optional time settings, initial
    species concentrations and recurring (reset) species data.
    """
    # The description is wrapped in ``Field`` rather than given as a bare
    # string in ``Annotated``: pydantic ignores plain-string metadata, so the
    # text never reached the generated JSON schema. The field has no default,
    # so it remains required.
    experiment_name: Annotated[str, Field(
        description="An AI assigned _ separated name of"
        " the experiment based on human query"
        " and the context of the experiment."
        " This must be set before the experiment is run.")]
    time_data: Optional[TimeData] = Field(
        description="time data",
        default=None)
    species_to_be_analyzed_before_experiment: Optional[SpeciesInitialData] = Field(
        description="Data of species whose initial concentration"
        " is to be set before the experiment. This does not include"
        " species that reoccur or the species whose concentration"
        " is to be determined at the end of the experiment.",
        default=None)
    reocurring_data: Optional[ReocurringData] = Field(
        description="List of concentration and time data of species that"
        " reoccur. For example, a species whose concentration resets"
        " to a certain value after a certain time interval.",
        default=None)
97
+
98
def add_rec_events(model_object, reocurring_data):
    """
    Register one basico event per entry of ``reocurring_data.data``.

    Each entry yields an event named ``"<species_name>_<time>"`` whose
    trigger is ``Time > <time>`` and whose assignment sets the species to
    ``str(species_concentration)``. Events are added to
    ``model_object.copasi_model``.

    Args:
        model_object: The model object; its ``copasi_model`` attribute is
            passed to basico.
        reocurring_data: Object whose ``data`` attribute iterates rows with
            ``time``, ``species_name`` and ``species_concentration``.

    Returns:
        None
    """
    # NOTE(review): a `Time > t` trigger presumably fires once when time first
    # exceeds t, so each row gives a single reset, not a periodic one. Rows
    # sharing species and time produce the same event name. Confirm both
    # against basico/COPASI semantics.
    for event in reocurring_data.data:
        when = event.time
        species = event.species_name
        value = event.species_concentration
        basico.add_event(
            f'{species}_{when}',
            f'Time > {when}',
            [[species, str(value)]],
            model=model_object.copasi_model,
        )
@@ -6,7 +6,7 @@ Tool for parameter scan.
6
6
 
7
7
  import logging
8
8
  from dataclasses import dataclass
9
- from typing import Type, Union, List, Annotated
9
+ from typing import Type, Union, List, Annotated, Optional
10
10
  import pandas as pd
11
11
  import basico
12
12
  from pydantic import BaseModel, Field
@@ -16,61 +16,37 @@ from langchain_core.tools import BaseTool
16
16
  from langchain_core.messages import ToolMessage
17
17
  from langchain_core.tools.base import InjectedToolCallId
18
18
  from .load_biomodel import ModelData, load_biomodel
19
+ from .load_arguments import TimeData, SpeciesInitialData
19
20
 
20
21
  # Initialize logger
21
22
  logging.basicConfig(level=logging.INFO)
22
23
  logger = logging.getLogger(__name__)
23
24
 
24
- @dataclass
25
- class TimeData:
26
- """
27
- Dataclass for storing the time data.
28
- """
29
- duration: Union[int, float] = 100
30
- interval: Union[int, float] = 10
31
-
32
- @dataclass
33
- class SpeciesData:
34
- """
35
- Dataclass for storing the species data.
36
- """
37
- species_name: List[str] = Field(description="species name", default=[])
38
- species_concentration: List[Union[int, float]] = Field(
39
- description="initial species concentration",
40
- default=[])
41
-
42
- @dataclass
43
- class TimeSpeciesNameConcentration:
44
- """
45
- Dataclass for storing the time, species name, and concentration data.
46
- """
47
- time: Union[int, float] = Field(description="time point where the event occurs")
48
- species_name: str = Field(description="species name")
49
- species_concentration: Union[int, float] = Field(
50
- description="species concentration at the time point")
51
-
52
- @dataclass
53
- class ReocurringData:
54
- """
55
- Dataclass for species that reoccur. In other words, the concentration
56
- of the species resets to a certain value after a certain time interval.
57
- """
58
- data: List[TimeSpeciesNameConcentration] = Field(
59
- description="time, name, and concentration data of species that reoccur",
60
- default=[])
61
-
# NOTE(review): ``@dataclass`` applied to a pydantic ``BaseModel`` subclass is
# unusual; it is kept unchanged here. Confirm it is intentional.
@dataclass
class ParameterScanData(BaseModel):
    """
    Dataclass for storing the parameter scan data.

    Names the species or parameter to scan, the values to scan over, and
    the species to observe after each scan step.
    """
    species_names: List[str] = Field(
        description="species to be observed after each scan."
        " These are the species whose concentration"
        " will be observed after the parameter scan."
        " Do not make up this data.",
        default=[])
    # Each concatenated piece starts with a space so the sentences do not
    # run together (previously "...after the scan.Do not make up...").
    species_parameter_name: str = Field(
        description="Species or parameter name to be scanned."
        " This is the species or parameter whose value will be scanned"
        " over a range of values. This does not include the species"
        " that are to be observed after the scan."
        " Do not make up this data.",
        default=None)
    species_parameter_values: List[Union[int, float]] = Field(
        description="Species or parameter values to be scanned."
        " These are the values of the species or parameters that will be"
        " scanned over a range of values. This does not include the"
        " species that are to be observed after the scan."
        " Do not make up this data.",
        default=None)
74
50
 
75
51
  @dataclass
76
52
  class ArgumentData:
@@ -78,30 +54,20 @@ class ArgumentData:
78
54
  Dataclass for storing the argument data.
79
55
  """
80
56
  time_data: TimeData = Field(description="time data", default=None)
81
- species_data: SpeciesData = Field(
82
- description="species name and initial concentration data",
83
- default=None)
84
- reocurring_data: ReocurringData = Field(
85
- description="""Concentration and time data of species that reoccur
86
- For example, a species whose concentration resets to a certain value
87
- after a certain time interval""")
57
+ species_to_be_analyzed_before_experiment: Optional[SpeciesInitialData] = Field(
58
+ description=" This is the initial condition of the model."
59
+ " This does not include species that reoccur or the species"
60
+ " whose concentration is to be determined/observed at the end"
61
+ " of the experiment. This also does not include the species"
62
+ " or the parameter that is to be scanned. Do not make up this data.",
63
+ default=None)
88
64
  parameter_scan_data: ParameterScanData = Field(
89
- description="parameter scan data",
90
- default=None)
91
- scan_name: str = Field(
92
- description="""An AI assigned `_` separated name of
93
- the parameter scan experiment based on human query""")
94
-
95
- def add_rec_events(model_object, reocurring_data):
96
- """
97
- Add reocurring events to the model.
98
- """
99
- for row in reocurring_data.data:
100
- tp, sn, sc = row.time, row.species_name, row.species_concentration
101
- basico.add_event(f'{sn}_{tp}',
102
- f'Time > {tp}',
103
- [[sn, str(sc)]],
104
- model=model_object.copasi_model)
65
+ description="parameter scan data",
66
+ default=None)
67
+ experiment_name: str = Field(
68
+ description="An AI assigned `_` separated unique name of"
69
+ " the parameter scan experiment based on human query."
70
+ " This must be unique for each experiment.")
105
71
 
106
72
  def make_list_dic_scanned_data(dic_param_scan, arg_data, sys_bio_model, tool_call_id):
107
73
  """
@@ -125,13 +91,18 @@ def make_list_dic_scanned_data(dic_param_scan, arg_data, sys_bio_model, tool_cal
125
91
  # Prepare the list dictionary of scanned data
126
92
  # that will be passed to the state of the graph
127
93
  list_dic_scanned_data.append({
128
- 'name': arg_data.scan_name+':'+species_name,
94
+ 'name': arg_data.experiment_name+':'+species_name,
129
95
  'source': sys_bio_model.biomodel_id if sys_bio_model.biomodel_id else 'upload',
130
96
  'tool_call_id': tool_call_id,
131
97
  'data': df_param_scan.to_dict()
132
98
  })
133
99
  return list_dic_scanned_data
134
- def run_parameter_scan(model_object, arg_data, dic_species_data, duration, interval) -> dict:
100
+
101
+ def run_parameter_scan(model_object,
102
+ arg_data,
103
+ dic_species_data,
104
+ duration,
105
+ interval) -> dict:
135
106
  """
136
107
  Run parameter scan on the model.
137
108
 
@@ -146,44 +117,61 @@ def run_parameter_scan(model_object, arg_data, dic_species_data, duration, inter
146
117
  dict: Dictionary of parameter scan results. Each key is a species name
147
118
  and each value is a DataFrame containing the results of the parameter scan.
148
119
  """
149
- # Extract all parameter names from the model and verify if the given parameter name is valid
120
+ # Extract all parameter names from the model
150
121
  df_all_parameters = basico.model_info.get_parameters(model=model_object.copasi_model)
151
122
  all_parameters = df_all_parameters.index.tolist()
152
- if arg_data.parameter_scan_data.parameter_name not in all_parameters:
153
- logger.error(
154
- "Invalid parameter name: %s", arg_data.parameter_scan_data.parameter_name)
155
- raise ValueError(
156
- f"Invalid parameter name: {arg_data.parameter_scan_data.parameter_name}")
157
- # Extract all species name from the model and verify if the given species name is valid
123
+
124
+ # Extract all species name from the model
158
125
  df_all_species = basico.model_info.get_species(model=model_object.copasi_model)
159
126
  all_species = df_all_species['display_name'].tolist()
127
+
128
+ # Verify if the given species or parameter names to be scanned are valid
129
+ if arg_data.parameter_scan_data.species_parameter_name not in all_parameters + all_species:
130
+ logger.error(
131
+ "Invalid species or parameter name: %s",
132
+ arg_data.parameter_scan_data.species_parameter_name)
133
+ raise ValueError(
134
+ "Invalid species or parameter name: "
135
+ f"{arg_data.parameter_scan_data.species_parameter_name}.")
136
+
160
137
  # Dictionary to store the parameter scan results
161
138
  dic_param_scan_results = {}
139
+
140
+ # Loop through the species names that are to be observed
162
141
  for species_name in arg_data.parameter_scan_data.species_names:
142
+ # Verify if the given species name to be observed is valid
163
143
  if species_name not in all_species:
164
144
  logger.error("Invalid species name: %s", species_name)
165
- raise ValueError(f"Invalid species name: {species_name}")
145
+ raise ValueError(f"Invalid species name: {species_name}.")
146
+
147
+ # Copy the model object to avoid modifying the original model
148
+ model_object_copy = model_object.model_copy()
149
+
166
150
  # Update the fixed model species and parameters
167
151
  # These are the initial conditions of the model
168
152
  # set by the user
169
- model_object.update_parameters(dic_species_data)
153
+ model_object_copy.update_parameters(dic_species_data)
154
+
170
155
  # Initialize empty DataFrame to store results
171
156
  # of the parameter scan
172
157
  df_param_scan = pd.DataFrame()
173
- for param_value in arg_data.parameter_scan_data.parameter_values:
158
+
159
+ # Loop through the parameter that are to be scanned
160
+ for param_value in arg_data.parameter_scan_data.species_parameter_values:
174
161
  # Update the parameter value in the model
175
- model_object.update_parameters(
176
- {arg_data.parameter_scan_data.parameter_name: param_value})
162
+ model_object_copy.update_parameters(
163
+ {arg_data.parameter_scan_data.species_parameter_name: param_value})
177
164
  # Simulate the model
178
- model_object.simulate(duration=duration, interval=interval)
165
+ model_object_copy.simulate(duration=duration, interval=interval)
179
166
  # If the column name 'Time' is not present in the results DataFrame
180
167
  if 'Time' not in df_param_scan.columns:
181
- df_param_scan['Time'] = model_object.simulation_results['Time']
168
+ df_param_scan['Time'] = model_object_copy.simulation_results['Time']
182
169
  # Add the simulation results to the results DataFrame
183
- col_name = f"{arg_data.parameter_scan_data.parameter_name}_{param_value}"
184
- df_param_scan[col_name] = model_object.simulation_results[species_name]
170
+ col_name = f"{arg_data.parameter_scan_data.species_parameter_name}_{param_value}"
171
+ df_param_scan[col_name] = model_object_copy.simulation_results[species_name]
185
172
 
186
173
  logger.log(logging.INFO, "Parameter scan results with shape %s", df_param_scan.shape)
174
+
187
175
  # Add the results of the parameter scan to the dictionary
188
176
  dic_param_scan_results[species_name] = df_param_scan
189
177
  # return df_param_scan
@@ -210,8 +198,9 @@ class ParameterScanTool(BaseTool):
210
198
  Tool for parameter scan.
211
199
  """
212
200
  name: str = "parameter_scan"
213
- description: str = """A tool to perform parameter scan
214
- of a list of parameter values for a given species."""
201
+ description: str = """A tool to perform scanning of a given
202
+ parameter over a range of values and observe the effect on
203
+ the concentration of a given species"""
215
204
  args_schema: Type[BaseModel] = ParameterScanInput
216
205
 
217
206
  def _run(self,
@@ -245,12 +234,18 @@ class ParameterScanTool(BaseTool):
245
234
  dic_species_data = {}
246
235
  if arg_data:
247
236
  # Prepare the dictionary of species data
248
- if arg_data.species_data is not None:
249
- dic_species_data = dict(zip(arg_data.species_data.species_name,
250
- arg_data.species_data.species_concentration))
251
- # Add reocurring events (if any) to the model
252
- if arg_data.reocurring_data is not None:
253
- add_rec_events(model_object, arg_data.reocurring_data)
237
+ if arg_data.species_to_be_analyzed_before_experiment is not None:
238
+ dic_species_data = dict(
239
+ zip(
240
+ arg_data.species_to_be_analyzed_before_experiment.species_name,
241
+ arg_data.species_to_be_analyzed_before_experiment.species_concentration
242
+ )
243
+ )
244
+
245
+ # # Add reocurring events (if any) to the model
246
+ # if arg_data.reocurring_data is not None:
247
+ # add_rec_events(model_object, arg_data.reocurring_data)
248
+
254
249
  # Set the duration and interval
255
250
  if arg_data.time_data is not None:
256
251
  duration = arg_data.time_data.duration
@@ -284,7 +279,7 @@ class ParameterScanTool(BaseTool):
284
279
  # update the message history
285
280
  "messages": [
286
281
  ToolMessage(
287
- content=f"Parameter scan results of {arg_data.scan_name}",
282
+ content=f"Parameter scan results of {arg_data.experiment_name}",
288
283
  tool_call_id=tool_call_id
289
284
  )
290
285
  ],