aiagents4pharma 1.11.0__py3-none-any.whl → 1.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,6 +3,7 @@ state_modifier: >
3
3
  You are Talk2BioModels agent.
4
4
  If the user asks for the uploaded model,
5
5
  then pass the use_uploaded_model argument
6
- as True. If the user asks for simulation,
7
- then suggest a value for the `simulation_name`
6
+ as True. If the user asks for simulation
7
+ or steady state, suggest a value for the
8
+ `simulation_name` or `steadystate_name`
8
9
  argument.
@@ -17,6 +17,7 @@ from ..tools.simulate_model import SimulateModelTool
17
17
  from ..tools.custom_plotter import CustomPlotterTool
18
18
  from ..tools.ask_question import AskQuestionTool
19
19
  from ..tools.parameter_scan import ParameterScanTool
20
+ from ..tools.steady_state import SteadyStateTool
20
21
  from ..states.state_talk2biomodels import Talk2Biomodels
21
22
 
22
23
  # Initialize logger
@@ -42,6 +43,7 @@ def get_app(uniq_id, llm_model='gpt-4o-mini'):
42
43
  CustomPlotterTool(),
43
44
  SearchModelsTool(),
44
45
  GetModelInfoTool(),
46
+ SteadyStateTool(),
45
47
  ParameterScanTool()
46
48
  ])
47
49
 
@@ -22,3 +22,4 @@ class Talk2Biomodels(AgentState):
22
22
  sbml_file_path: Annotated[list, operator.add]
23
23
  dic_simulated_data: Annotated[list[dict], operator.add]
24
24
  dic_scanned_data: Annotated[list[dict], operator.add]
25
+ dic_steady_state_data: Annotated[list[dict], operator.add]
@@ -98,7 +98,7 @@ def test_simulate_model_tool():
98
98
  # Upload a model to the state
99
99
  app.update_state(config,
100
100
  {"sbml_file_path": ["aiagents4pharma/talk2biomodels/tests/BIOMD0000000449_url.xml"]})
101
- prompt = "Simulate models 64 and the uploaded model"
101
+ prompt = "Simulate model 64 and the uploaded model"
102
102
  # Invoke the agent
103
103
  app.invoke(
104
104
  {"messages": [HumanMessage(content=prompt)]},
@@ -145,10 +145,9 @@ def test_param_scan_tool():
145
145
  app.update_state(config, {"llm_model": "gpt-4o-mini"})
146
146
  prompt = """How will the value of Ab in model 537 change
147
147
  if the param kIL6Rbind is varied from 1 to 100 in steps of 10?
148
- Set the initial `DoseQ2W` concentration to 300.
149
- Reset the IL6{serum} concentration to 100 every 500 hours.
150
- Assume that the model is simulated for 2016 hours with
151
- an interval of 2016."""
148
+ Set the initial `DoseQ2W` concentration to 300. Also, reset
149
+ the IL6{serum} concentration to 100 every 500 hours and assume
150
+ that the model is simulated for 2016 hours with an interval of 2016."""
152
151
  # Invoke the agent
153
152
  app.invoke(
154
153
  {"messages": [HumanMessage(content=prompt)]},
@@ -181,11 +180,95 @@ def test_param_scan_tool():
181
180
  assert any((df["status"] == "success") &
182
181
  (df["name"] == "get_modelinfo"))
183
182
 
183
+ def test_steady_state_tool():
184
+ '''
185
+ Test the steady_state tool.
186
+ '''
187
+ unique_id = 123
188
+ app = get_app(unique_id)
189
+ config = {"configurable": {"thread_id": unique_id}}
190
+ app.update_state(config, {"llm_model": "gpt-4o-mini"})
191
+ #########################################################
192
+ # In this case, we will test if the tool returns an error
193
+ # when the model does not achieve a steady state. The tool
194
+ # status should be "error".
195
+ prompt = """Run a steady state analysis of model 537."""
196
+ # Invoke the agent
197
+ app.invoke(
198
+ {"messages": [HumanMessage(content=prompt)]},
199
+ config=config
200
+ )
201
+ current_state = app.get_state(config)
202
+ reversed_messages = current_state.values["messages"][::-1]
203
+ tool_msg_status = None
204
+ for msg in reversed_messages:
205
+ # Assert that the status of the
206
+ # ToolMessage is "error"
207
+ if isinstance(msg, ToolMessage):
208
+ # print (msg)
209
+ tool_msg_status = msg.status
210
+ break
211
+ assert tool_msg_status == "error"
212
+ #########################################################
213
+ # In this case, we will test if the tool is indeed invoked
214
+ # successfully
215
+ prompt = """Run a steady state analysis of model 64.
216
+ Set the initial concentration of `Pyruvate` to 0.2. The
217
+ concentration of `NAD` resets to 100 every 2 time units."""
218
+ # Invoke the agent
219
+ app.invoke(
220
+ {"messages": [HumanMessage(content=prompt)]},
221
+ config=config
222
+ )
223
+ # Loop through the reversed messages until a
224
+ # ToolMessage is found.
225
+ current_state = app.get_state(config)
226
+ reversed_messages = current_state.values["messages"][::-1]
227
+ steady_state_invoked = False
228
+ for msg in reversed_messages:
229
+ # Assert that the message is a ToolMessage
230
+ # from the "steady_state" tool whose status is not "error"
231
+ if isinstance(msg, ToolMessage):
232
+ print (msg)
233
+ if msg.name == "steady_state" and msg.status != "error":
234
+ steady_state_invoked = True
235
+ break
236
+ assert steady_state_invoked
237
+ #########################################################
238
+ # In this case, we will test if the `ask_question` tool is
239
+ # invoked upon asking a question about the already generated
240
+ # steady state results
241
+ prompt = """What is the Phosphoenolpyruvate concentration
242
+ at the steady state? Show onlyconcentration, rounded to
243
+ 2 decimal places. For example, if the concentration is
244
+ 0.123456, your response should be `0.12`. Do not return
245
+ any other information."""
246
+ # Invoke the agent
247
+ response = app.invoke(
248
+ {"messages": [HumanMessage(content=prompt)]},
249
+ config=config
250
+ )
251
+ assistant_msg = response["messages"][-1].content
252
+ current_state = app.get_state(config)
253
+ reversed_messages = current_state.values["messages"][::-1]
254
+ # Loop through the reversed messages until a
255
+ # ToolMessage is found.
256
+ ask_questool_invoked = False
257
+ for msg in reversed_messages:
258
+ # Assert that the message is a ToolMessage
259
+ # and that it comes from the "ask_question" tool
260
+ if isinstance(msg, ToolMessage):
261
+ if msg.name == "ask_question":
262
+ ask_questool_invoked = True
263
+ break
264
+ assert ask_questool_invoked
265
+ assert "0.06" in assistant_msg
266
+
184
267
  def test_integration():
185
268
  '''
186
269
  Test the integration of the tools.
187
270
  '''
188
- unique_id = 123
271
+ unique_id = 1234567
189
272
  app = get_app(unique_id)
190
273
  config = {"configurable": {"thread_id": unique_id}}
191
274
  app.update_state(config, {"llm_model": "gpt-4o-mini"})
@@ -211,10 +294,8 @@ def test_integration():
211
294
  ##########################################
212
295
  # Update state
213
296
  app.update_state(config, {"llm_model": "gpt-4o-mini"})
214
- prompt = "What is the concentration of CRP in serum at 1000 hours? "
215
- # prompt += "Show only the concentration, rounded to 1 decimal place."
216
- # prompt += "For example, if the concentration is 0.123456, "
217
- # prompt += "your response should be `0.1`. Do not return any other information."
297
+ prompt = """What is the concentration of CRP
298
+ in serum after 1000 time points?"""
218
299
  # Test the tool get_modelinfo
219
300
  response = app.invoke(
220
301
  {"messages": [HumanMessage(content=prompt)]},
@@ -271,8 +352,9 @@ def test_integration():
271
352
  # simulation results are available but
272
353
  # the species is not available
273
354
  ##########################################
274
- prompt = "Plot the species `TP53`."
275
-
355
+ prompt = """Make a custom plot showing the
356
+ concentration of the species `TP53` over
357
+ time. Do not show any other species."""
276
358
  # Update state
277
359
  app.update_state(config, {"llm_model": "gpt-4o-mini"}
278
360
  )
@@ -5,7 +5,7 @@ Tool for asking a question about the simulation results.
5
5
  """
6
6
 
7
7
  import logging
8
- from typing import Type, Annotated
8
+ from typing import Type, Annotated, Literal
9
9
  import pandas as pd
10
10
  from pydantic import BaseModel, Field
11
11
  from langchain_core.tools.base import BaseTool
@@ -22,9 +22,12 @@ class AskQuestionInput(BaseModel):
22
22
  """
23
23
  Input schema for the AskQuestion tool.
24
24
  """
25
- question: str = Field(description="question about the simulation results")
26
- simulation_name: str = Field(description="""Name assigned to the simulation
27
- when the tool simulate_model was invoked.""")
25
+ question: str = Field(description="question about the simulation and steady state results")
26
+ experiment_name: str = Field(description="""Name assigned to the simulation
27
+ or steady state analysis when the tool
28
+ simulate_model or steady_state is invoked.""")
29
+ question_context: Literal["simulation", "steady_state"] = Field(
30
+ description="Context of the question")
28
31
  state: Annotated[dict, InjectedState]
29
32
 
30
33
  # Note: It's important that every field has type hints.
@@ -32,41 +35,51 @@ class AskQuestionInput(BaseModel):
32
35
  # can lead to unexpected behavior.
33
36
  class AskQuestionTool(BaseTool):
34
37
  """
35
- Tool for calculating the product of two numbers.
38
+ Tool for asking a question about the simulation or steady state results.
36
39
  """
37
40
  name: str = "ask_question"
38
- description: str = "A tool to ask question about the simulation results."
41
+ description: str = """A tool to ask question about the
42
+ simulation or steady state results."""
39
43
  args_schema: Type[BaseModel] = AskQuestionInput
40
44
  return_direct: bool = False
41
45
 
42
46
  def _run(self,
43
47
  question: str,
44
- simulation_name: str,
48
+ experiment_name: str,
49
+ question_context: Literal["simulation", "steady_state"],
45
50
  state: Annotated[dict, InjectedState]) -> str:
46
51
  """
47
52
  Run the tool.
48
53
 
49
54
  Args:
50
- question (str): The question to ask about the simulation results.
55
+ question (str): The question to ask about the simulation or steady state results.
51
56
  state (dict): The state of the graph.
52
- simulation_name (str): The name assigned to the simulation.
57
+ experiment_name (str): The name assigned to the simulation or steady state analysis.
53
58
 
54
59
  Returns:
55
60
  str: The answer to the question.
56
61
  """
57
62
  logger.log(logging.INFO,
58
- "Calling ask_question tool %s, %s", question, simulation_name)
59
- dic_simulated_data = {}
60
- for data in state["dic_simulated_data"]:
63
+ "Calling ask_question tool %s, %s, %s",
64
+ question,
65
+ question_context,
66
+ experiment_name)
67
+ # print (f'Calling ask_question tool {question}, {question_context}, {experiment_name}')
68
+ if question_context == "steady_state":
69
+ dic_context = state["dic_steady_state_data"]
70
+ else:
71
+ dic_context = state["dic_simulated_data"]
72
+ dic_data = {}
73
+ for data in dic_context:
61
74
  for key in data:
62
- if key not in dic_simulated_data:
63
- dic_simulated_data[key] = []
64
- dic_simulated_data[key] += [data[key]]
65
- # print (dic_simulated_data)
66
- df_simulated_data = pd.DataFrame.from_dict(dic_simulated_data)
75
+ if key not in dic_data:
76
+ dic_data[key] = []
77
+ dic_data[key] += [data[key]]
78
+ # print (dic_data)
79
+ df_data = pd.DataFrame.from_dict(dic_data)
67
80
  df = pd.DataFrame(
68
- df_simulated_data[df_simulated_data['name'] == simulation_name]['data'].iloc[0]
69
- )
81
+ df_data[df_data['name'] == experiment_name]['data'].iloc[0]
82
+ )
70
83
  prompt_content = None
71
84
  # if run_manager and 'prompt' in run_manager.metadata:
72
85
  # prompt_content = run_manager.metadata['prompt']
@@ -0,0 +1,208 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ Tool for steady state analysis.
5
+ """
6
+
7
+ import logging
8
+ from dataclasses import dataclass
9
+ from typing import Type, Union, List, Annotated
10
+ import basico
11
+ from pydantic import BaseModel, Field
12
+ from langgraph.types import Command
13
+ from langgraph.prebuilt import InjectedState
14
+ from langchain_core.tools import BaseTool
15
+ from langchain_core.messages import ToolMessage
16
+ from langchain_core.tools.base import InjectedToolCallId
17
+ from .load_biomodel import ModelData, load_biomodel
18
+
19
+ # Initialize logger
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = logging.getLogger(__name__)
22
+
23
+ @dataclass
24
+ class TimeData:
25
+ """
26
+ Dataclass for storing the time data.
27
+ """
28
+ duration: Union[int, float] = 100
29
+ interval: Union[int, float] = 10
30
+
31
+ @dataclass
32
+ class SpeciesData:
33
+ """
34
+ Dataclass for storing the species data.
35
+ """
36
+ species_name: List[str] = Field(description="species name", default=[])
37
+ species_concentration: List[Union[int, float]] = Field(
38
+ description="initial species concentration",
39
+ default=[])
40
+
41
+ @dataclass
42
+ class TimeSpeciesNameConcentration:
43
+ """
44
+ Dataclass for storing the time, species name, and concentration data.
45
+ """
46
+ time: Union[int, float] = Field(description="time point where the event occurs")
47
+ species_name: str = Field(description="species name")
48
+ species_concentration: Union[int, float] = Field(
49
+ description="species concentration at the time point")
50
+
51
+ @dataclass
52
+ class ReocurringData:
53
+ """
54
+ Dataclass for species that reoccur. In other words, the concentration
55
+ of the species resets to a certain value after a certain time interval.
56
+ """
57
+ data: List[TimeSpeciesNameConcentration] = Field(
58
+ description="time, name, and concentration data of species that reoccur",
59
+ default=[])
60
+
61
+ @dataclass
62
+ class ArgumentData:
63
+ """
64
+ Dataclass for storing the argument data.
65
+ """
66
+ time_data: TimeData = Field(description="time data", default=None)
67
+ species_data: SpeciesData = Field(
68
+ description="species name and initial concentration data")
69
+ reocurring_data: ReocurringData = Field(
70
+ description="""Concentration and time data of species that reoccur
71
+ For example, a species whose concentration resets to a certain value
72
+ after a certain time interval""")
73
+ steadystate_name: str = Field(
74
+ description="""An AI assigned `_` separated name of
75
+ the steady state experiment based on human query""")
76
+
77
+ def add_rec_events(model_object, reocurring_data):
78
+ """
79
+ Add reocurring events to the model.
80
+ """
81
+ for row in reocurring_data.data:
82
+ tp, sn, sc = row.time, row.species_name, row.species_concentration
83
+ basico.add_event(f'{sn}_{tp}',
84
+ f'Time > {tp}',
85
+ [[sn, str(sc)]],
86
+ model=model_object.copasi_model)
87
+
88
+ def run_steady_state(model_object, dic_species_data):
89
+ """
90
+ Run the steady state analysis.
91
+
92
+ Args:
93
+ model_object: The model object.
94
+ dic_species_data: Dictionary of species data.
95
+
96
+ Returns:
97
+ DataFrame: The results of the steady state analysis.
98
+ """
99
+ # Update the fixed model species and parameters
100
+ # These are the initial conditions of the model
101
+ # set by the user
102
+ model_object.update_parameters(dic_species_data)
103
+ logger.log(logging.INFO, "Running steady state analysis")
104
+ # Run the steady state analysis
105
+ output = basico.task_steadystate.run_steadystate(model=model_object.copasi_model)
106
+ if output == 0:
107
+ logger.error("Steady state analysis failed")
108
+ raise ValueError("A steady state was not found")
109
+ logger.log(logging.INFO, "Steady state analysis successful")
110
+ # Store the steady state results in a DataFrame
111
+ df_steady_state = basico.model_info.get_species(model=model_object.copasi_model)
112
+ logger.log(logging.INFO, "Steady state results with shape %s", df_steady_state.shape)
113
+ return df_steady_state
114
+
115
+ class SteadyStateInput(BaseModel):
116
+ """
117
+ Input schema for the steady state tool.
118
+ """
119
+ sys_bio_model: ModelData = Field(description="model data",
120
+ default=None)
121
+ arg_data: ArgumentData = Field(description=
122
+ """time, species, and reocurring data
123
+ as well as the steady state data""",
124
+ default=None)
125
+ tool_call_id: Annotated[str, InjectedToolCallId]
126
+ state: Annotated[dict, InjectedState]
127
+
128
+ # Note: It's important that every field has type hints. BaseTool is a
129
+ # Pydantic class and not having type hints can lead to unexpected behavior.
130
+ class SteadyStateTool(BaseTool):
131
+ """
132
+ Tool for steady state analysis.
133
+ """
134
+ name: str = "steady_state"
135
+ description: str = """A tool to simulate a model and perform
136
+ steady state analysisto answer questions
137
+ about the steady state of species."""
138
+ args_schema: Type[BaseModel] = SteadyStateInput
139
+
140
+ def _run(self,
141
+ tool_call_id: Annotated[str, InjectedToolCallId],
142
+ state: Annotated[dict, InjectedState],
143
+ sys_bio_model: ModelData = None,
144
+ arg_data: ArgumentData = None
145
+ ) -> Command:
146
+ """
147
+ Run the tool.
148
+
149
+ Args:
150
+ tool_call_id (str): The tool call ID. This is injected by the system.
151
+ state (dict): The state of the tool.
152
+ sys_bio_model (ModelData): The model data.
153
+ arg_data (ArgumentData): The argument data.
154
+
155
+ Returns:
156
+ Command: The updated state of the tool.
157
+ """
158
+ logger.log(logging.INFO, "Calling steady_state tool %s, %s",
159
+ sys_bio_model, arg_data)
160
+ # print (f'Calling steady_state tool {sys_bio_model}, {arg_data}, {tool_call_id}')
161
+ sbml_file_path = state['sbml_file_path'][-1] if len(state['sbml_file_path']) > 0 else None
162
+ model_object = load_biomodel(sys_bio_model,
163
+ sbml_file_path=sbml_file_path)
164
+ # Prepare the dictionary of species data
165
+ # that will be passed to the simulate method
166
+ # of the BasicoModel class
167
+ dic_species_data = {}
168
+ if arg_data:
169
+ # Prepare the dictionary of species data
170
+ if arg_data.species_data is not None:
171
+ dic_species_data = dict(zip(arg_data.species_data.species_name,
172
+ arg_data.species_data.species_concentration))
173
+ # Add reocurring events (if any) to the model
174
+ if arg_data.reocurring_data is not None:
175
+ add_rec_events(model_object, arg_data.reocurring_data)
176
+ # Run the parameter scan
177
+ df_steady_state = run_steady_state(model_object, dic_species_data)
178
+ # Prepare the dictionary of steady state data
179
+ # that will be passed to the state of the graph
180
+ dic_steady_state_data = {
181
+ 'name': arg_data.steadystate_name,
182
+ 'source': sys_bio_model.biomodel_id if sys_bio_model.biomodel_id else 'upload',
183
+ 'tool_call_id': tool_call_id,
184
+ 'data': df_steady_state.to_dict(orient='records')
185
+ }
186
+ # Prepare the dictionary of updated state for the model
187
+ dic_updated_state_for_model = {}
188
+ for key, value in {
189
+ "model_id": [sys_bio_model.biomodel_id],
190
+ "sbml_file_path": [sbml_file_path],
191
+ "dic_steady_state_data": [dic_steady_state_data]
192
+ }.items():
193
+ if value:
194
+ dic_updated_state_for_model[key] = value
195
+ # Return the updated state
196
+ return Command(
197
+ update=dic_updated_state_for_model|{
198
+ # Update the message history
199
+ "messages": [
200
+ ToolMessage(
201
+ content=f'''Steady state analysis of
202
+ {arg_data.steadystate_name}
203
+ are ready''',
204
+ tool_call_id=tool_call_id
205
+ )
206
+ ],
207
+ }
208
+ )
@@ -1,4 +1,5 @@
1
1
  '''
2
- This file is used to import the datasets, utils, and tools.
2
+ This file is used to import the datasets and utils.
3
3
  '''
4
4
  from . import datasets
5
+ from . import utils
@@ -0,0 +1,39 @@
1
+ """
2
+ Test cases for utils/enrichments/enrichments.py
3
+ """
4
+
5
+ from ..utils.enrichments.enrichments import Enrichments
6
+
7
+ class TestEnrichments(Enrichments):
8
+ """Test implementation of the Enrichments interface for testing purposes."""
9
+
10
+ def enrich_documents(self, texts: list[str]) -> list[list[float]]:
11
+ return [
12
+ f"Additional text description of {text} as the input." for text in texts
13
+ ]
14
+
15
+ def enrich_documents_with_rag(self, texts, docs):
16
+ # Currently we don't have a RAG model to test this method.
17
+ # Thus, we will just call the enrich_documents method instead.
18
+ return self.enrich_documents(texts)
19
+
20
+ def test_enrich_documents():
21
+ """Test enriching documents using the Enrichments interface."""
22
+ enrichments = TestEnrichments()
23
+ texts = ["text1", "text2"]
24
+ result = enrichments.enrich_documents(texts)
25
+ assert result == [
26
+ "Additional text description of text1 as the input.",
27
+ "Additional text description of text2 as the input.",
28
+ ]
29
+
30
+ def test_enrich_documents_with_rag():
31
+ """Test enriching documents with RAG using the Enrichments interface."""
32
+ enrichments = TestEnrichments()
33
+ texts = ["text1", "text2"]
34
+ docs = ["doc1", "doc2"]
35
+ result = enrichments.enrich_documents_with_rag(texts, docs)
36
+ assert result == [
37
+ "Additional text description of text1 as the input.",
38
+ "Additional text description of text2 as the input.",
39
+ ]
@@ -0,0 +1,117 @@
1
+ """
2
+ Test cases for utils/enrichments/ollama.py
3
+ """
4
+
5
+ import pytest
6
+ import ollama
7
+ from ..utils.enrichments.ollama import EnrichmentWithOllama
8
+
9
+ @pytest.fixture(name="ollama_config")
10
+ def fixture_ollama_config():
11
+ """Return a dictionary with Ollama configuration."""
12
+ return {
13
+ "model_name": "smollm2:360m",
14
+ "prompt_enrichment": """
15
+ Given the input as a list of strings, please return the list of addditional information of
16
+ each input terms using your prior knowledge.
17
+
18
+ Example:
19
+ Input: ['acetaminophen', 'aspirin']
20
+ Ouput: ['acetaminophen is a medication used to treat pain and fever',
21
+ 'aspirin is a medication used to treat pain, fever, and inflammation']
22
+
23
+ Do not include any pretext as the output, only the list of strings enriched.
24
+
25
+ Input: {input}
26
+ """,
27
+ "temperature": 0.0,
28
+ "streaming": False,
29
+ }
30
+
31
+ def test_no_model_ollama(ollama_config):
32
+ """Test the case when the Ollama model is not available."""
33
+ cfg = ollama_config
34
+ cfg_model = "smollm2:135m" # Choose a small model
35
+
36
+ # Delete the Ollama model
37
+ try:
38
+ ollama.delete(cfg_model)
39
+ except ollama.ResponseError:
40
+ pass
41
+
42
+ # Check if the model is available
43
+ with pytest.raises(
44
+ ValueError, match=f"Error: Pulled {cfg_model} model and restarted Ollama server."
45
+ ):
46
+ EnrichmentWithOllama(
47
+ model_name=cfg_model,
48
+ prompt_enrichment=cfg["prompt_enrichment"],
49
+ temperature=cfg["temperature"],
50
+ streaming=cfg["streaming"],
51
+ )
52
+ ollama.delete(cfg_model)
53
+
54
+ def test_enrich_nodes_ollama(ollama_config):
55
+ """Test the Ollama textual enrichment class for node enrichment."""
56
+ # Prepare enrichment model
57
+ cfg = ollama_config
58
+ enr_model = EnrichmentWithOllama(
59
+ model_name=cfg["model_name"],
60
+ prompt_enrichment=cfg["prompt_enrichment"],
61
+ temperature=cfg["temperature"],
62
+ streaming=cfg["streaming"],
63
+ )
64
+
65
+ # Perform enrichment for nodes
66
+ nodes = ["Adalimumab", "Infliximab"]
67
+ enriched_nodes = enr_model.enrich_documents(nodes)
68
+ # Check the enriched nodes
69
+ assert len(enriched_nodes) == 2
70
+ assert all(
71
+ enriched_nodes[i] != nodes[i] for i in range(len(nodes))
72
+ )
73
+
74
+
75
+ def test_enrich_relations_ollama(ollama_config):
76
+ """Test the Ollama textual enrichment class for relation enrichment."""
77
+ # Prepare enrichment model
78
+ cfg = ollama_config
79
+ enr_model = EnrichmentWithOllama(
80
+ model_name=cfg["model_name"],
81
+ prompt_enrichment=cfg["prompt_enrichment"],
82
+ temperature=cfg["temperature"],
83
+ streaming=cfg["streaming"],
84
+ )
85
+ # Perform enrichment for relations
86
+ relations = [
87
+ "IL23R-gene causation disease-inflammatory bowel diseases",
88
+ "NOD2-gene causation disease-inflammatory bowel diseases",
89
+ ]
90
+ enriched_relations = enr_model.enrich_documents(relations)
91
+ # Check the enriched relations
92
+ assert len(enriched_relations) == 2
93
+ assert all(
94
+ enriched_relations[i] != relations[i]
95
+ for i in range(len(relations))
96
+ )
97
+
98
+
99
+ def test_enrich_ollama_rag(ollama_config):
100
+ """Test the Ollama textual enrichment class for enrichment with RAG (not implemented)."""
101
+ # Prepare enrichment model
102
+ cfg = ollama_config
103
+ enr_model = EnrichmentWithOllama(
104
+ model_name=cfg["model_name"],
105
+ prompt_enrichment=cfg["prompt_enrichment"],
106
+ temperature=cfg["temperature"],
107
+ streaming=cfg["streaming"],
108
+ )
109
+ # Perform enrichment for nodes
110
+ nodes = ["Adalimumab", "Infliximab"]
111
+ docs = [r"\path\to\doc1", r"\path\to\doc2"]
112
+ enriched_nodes = enr_model.enrich_documents_with_rag(nodes, docs)
113
+ # Check the enriched nodes
114
+ assert len(enriched_nodes) == 2
115
+ assert all(
116
+ enriched_nodes[i] != nodes[i] for i in range(len(nodes))
117
+ )
@@ -0,0 +1,5 @@
1
+ '''
2
+ This file is used to import utilities.
3
+ '''
4
+ from . import enrichments
5
+ from . import embeddings
@@ -0,0 +1,5 @@
1
+ """
2
+ This package contains modules to use the enrichment model
3
+ """
4
+ from . import enrichments
5
+ from . import ollama
@@ -0,0 +1,36 @@
1
+ """
2
+ Enrichments interface
3
+ """
4
+
5
+ from abc import ABC, abstractmethod
6
+
7
+ class Enrichments(ABC):
8
+ """Interface for enrichment models.
9
+
10
+ This is an interface meant for implementing text enrichment models.
11
+
12
+ Enrichment models are used to enrich node or relation features in a given knowledge graph.
13
+ """
14
+
15
+ @abstractmethod
16
+ def enrich_documents(self, texts: list[str]) -> list[list[str]]:
17
+ """Enrich documents.
18
+
19
+ Args:
20
+ texts: List of documents to enrich.
21
+
22
+ Returns:
23
+ List of enriched documents.
24
+ """
25
+
26
+ @abstractmethod
27
+ def enrich_documents_with_rag(self, texts: list[str], docs: list[str]) -> list[str]:
28
+ """Enrich documents with RAG.
29
+
30
+ Args:
31
+ texts: List of documents to enrich.
32
+ docs: List of reference documents to enrich the input texts.
33
+
34
+ Returns:
35
+ List of enriched documents with RAG.
36
+ """
@@ -0,0 +1,123 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ Enrichment class using Ollama model based on LangChain Enrichment class.
5
+ """
6
+
7
+ import time
8
+ from typing import List
9
+ import subprocess
10
+ import ast
11
+ import ollama
12
+ from langchain_ollama import ChatOllama
13
+ from langchain_core.prompts import ChatPromptTemplate
14
+ from langchain_core.output_parsers import StrOutputParser
15
+ from .enrichments import Enrichments
16
+
17
+ class EnrichmentWithOllama(Enrichments):
18
+ """
19
+ Enrichment class using Ollama model based on the Enrichment abstract class.
20
+ """
21
+ def __init__(
22
+ self,
23
+ model_name: str,
24
+ prompt_enrichment: str,
25
+ temperature: float,
26
+ streaming: bool,
27
+ ):
28
+ """
29
+ Initialize the EnrichmentWithOllama class.
30
+
31
+ Args:
32
+ model_name: The name of the Ollama model to be used.
33
+ prompt_enrichment: The prompt enrichment template.
34
+ temperature: The temperature for the Ollama model.
35
+ streaming: The streaming flag for the Ollama model.
36
+ """
37
+ # Setup the Ollama server
38
+ self.__setup(model_name)
39
+
40
+ # Set parameters
41
+ self.model_name = model_name
42
+ self.prompt_enrichment = prompt_enrichment
43
+ self.temperature = temperature
44
+ self.streaming = streaming
45
+
46
+ # Prepare prompt template
47
+ self.prompt_template = ChatPromptTemplate.from_messages(
48
+ [
49
+ ("system", self.prompt_enrichment),
50
+ ("human", "{input}"),
51
+ ]
52
+ )
53
+
54
+ # Prepare model
55
+ self.model = ChatOllama(
56
+ model=self.model_name,
57
+ temperature=self.temperature,
58
+ streaming=self.streaming,
59
+ )
60
+
61
+ def __setup(self, model_name: str) -> None:
62
+ """
63
+ Check if the Ollama model is available and run the Ollama server if needed.
64
+
65
+ Args:
66
+ model_name: The name of the Ollama model to be used.
67
+ """
68
+ try:
69
+ models_list = ollama.list()["models"]
70
+ if model_name not in [m['model'].replace(":latest", "") for m in models_list]:
71
+ ollama.pull(model_name)
72
+ time.sleep(30)
73
+ raise ValueError(f"Pulled {model_name} model")
74
+ except Exception as e:
75
+ with subprocess.Popen(
76
+ "ollama serve", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
77
+ ):
78
+ time.sleep(10)
79
+ raise ValueError(f"Error: {e} and restarted Ollama server.") from e
80
+
81
+ def enrich_documents(self, texts: List[str]) -> List[str]:
82
+ """
83
+ Enrich a list of input texts with additional textual features using OLLAMA model.
84
+ Important: Make sure the input is a list of texts based on the defined prompt template
85
+ with 'input' as the variable name.
86
+
87
+ Args:
88
+ texts: The list of texts to be enriched.
89
+
90
+ Returns:
91
+ The list of enriched texts.
92
+ """
93
+
94
+ # Perform enrichment
95
+ chain = self.prompt_template | self.model | StrOutputParser()
96
+
97
+ # Generate the enriched node
98
+ # Important: Make sure the input is a list of texts based on the defined prompt template
99
+ # with 'input' as the variable name
100
+ enriched_texts = chain.invoke({"input": "[" + ", ".join(texts) + "]"})
101
+
102
+ # Parse the model output (a list literal) into a Python list
103
+ enriched_texts = ast.literal_eval(enriched_texts.replace("```", ""))
104
+
105
+ # Final check for the enriched texts
106
+ assert len(enriched_texts) == len(texts)
107
+
108
+ return enriched_texts
109
+
110
+ def enrich_documents_with_rag(self, texts, docs):
111
+ """
112
+ Enrich a list of input texts with additional textual features using OLLAMA model with RAG.
113
+ As of now, we don't have a RAG model to test this method yet.
114
+ Thus, we will just call the enrich_documents method instead.
115
+
116
+ Args:
117
+ texts: The list of texts to be enriched.
118
+ docs: The list of reference documents to enrich the input texts.
119
+
120
+ Returns:
121
+ The list of enriched texts
122
+ """
123
+ return self.enrich_documents(texts)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: aiagents4pharma
3
- Version: 1.11.0
3
+ Version: 1.13.0
4
4
  Summary: AI Agents for drug discovery, drug development, and other pharmaceutical R&D
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: License :: OSI Approved :: MIT License
@@ -20,9 +20,11 @@ Requires-Dist: langchain-community==0.3.5
20
20
  Requires-Dist: langchain-core==0.3.31
21
21
  Requires-Dist: langchain-experimental==0.3.3
22
22
  Requires-Dist: langchain-openai==0.2.5
23
+ Requires-Dist: langchain_ollama==0.2.2
23
24
  Requires-Dist: langgraph==0.2.66
24
25
  Requires-Dist: matplotlib==3.9.2
25
26
  Requires-Dist: openai==1.59.4
27
+ Requires-Dist: ollama==0.4.6
26
28
  Requires-Dist: pandas==2.2.3
27
29
  Requires-Dist: plotly==5.24.1
28
30
  Requires-Dist: pydantic==2.9.2
@@ -4,27 +4,28 @@ aiagents4pharma/configs/config.yaml,sha256=8y8uG6Dzx4-9jyb6hZ8r4lOJz5gA_sQhCiSCg
4
4
  aiagents4pharma/configs/talk2biomodels/__init__.py,sha256=5ah__-8XyRblwT0U1ByRigNjt_GyCheu7zce4aM-eZE,68
5
5
  aiagents4pharma/configs/talk2biomodels/agents/__init__.py,sha256=_ZoG8snICK2bidWtc2KOGs738LWg9_r66V9mOMnEb-E,71
6
6
  aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
7
- aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/default.yaml,sha256=yD7qZCneaM-JE5PdZjDmDoTRUdsFrzeCKZsBx1b-f20,293
7
+ aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/default.yaml,sha256=Oi89_BbxfQc6SGW1pC-hyZMqOIkiAMOlNwpCa4VCXk0,327
8
8
  aiagents4pharma/talk2biomodels/__init__.py,sha256=qUw3qXrENqSCLIKSLy_qtNPwPDTb1wdZ8fZispcHb3g,141
9
9
  aiagents4pharma/talk2biomodels/agents/__init__.py,sha256=sn5-fREjMdEvb-OUan3iOqrgYGjplNx3J8hYOaW0Po8,128
10
- aiagents4pharma/talk2biomodels/agents/t2b_agent.py,sha256=6Im4YFcdykN7wpEM8y9qi_x4lTg02WJpb0SEWh8TPLo,3188
10
+ aiagents4pharma/talk2biomodels/agents/t2b_agent.py,sha256=8_2D3uknPjYbzKj7IDhC8xnz_HEvXck0b4RCJxoAxRs,3276
11
11
  aiagents4pharma/talk2biomodels/models/__init__.py,sha256=5fTHHm3PVloYPNKXbgNlcPgv3-u28ZquxGydFYDfhJA,122
12
12
  aiagents4pharma/talk2biomodels/models/basico_model.py,sha256=PH25FTOuUjsmw_UUxoRb-4kptOYpicEn4GqS0phS3nk,4807
13
13
  aiagents4pharma/talk2biomodels/models/sys_bio_model.py,sha256=JeoiGQAvQABHnG0wKR2XBmmxqQdtgO6kxaLDUTUmr1s,2001
14
14
  aiagents4pharma/talk2biomodels/states/__init__.py,sha256=YLg1-N0D9qyRRLRqwqfLCLAqZYDtMVZTfI8Y0b_4tbA,139
15
- aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py,sha256=Dlsnh9dW1mCXTBXmlDAlOox7f4azFbLBG_2k3YPielM,824
15
+ aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py,sha256=XcrCsBbQgc4LCZ7ihpY5tcLQ82Mq17XhRXD1DW8Ep7U,887
16
16
  aiagents4pharma/talk2biomodels/tests/__init__.py,sha256=Jbw5tJxSrjGoaK5IX3pJWDCNzhrVQ10lkYq2oQ_KQD8,45
17
17
  aiagents4pharma/talk2biomodels/tests/test_basico_model.py,sha256=y82fpTJMPHwtXxlle1cGQ_2Bewwpxi0aJSVrVAYLhN0,2060
18
- aiagents4pharma/talk2biomodels/tests/test_langgraph.py,sha256=_71UZS1zucn6Nus4oCH7tINQVRvJEFnL0UIZ6-sUd3I,11967
18
+ aiagents4pharma/talk2biomodels/tests/test_langgraph.py,sha256=QLAL4nmHrioTD-w-9OE0wQi5JdWJJ59PejNbDzCSvw4,15170
19
19
  aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py,sha256=HSmBBViMi0jYf4gWX21IbppAfDzG0nr_S3KtKS9fZVQ,2165
20
20
  aiagents4pharma/talk2biomodels/tools/__init__.py,sha256=SMTMlGHxTuJI7gjwGaTMv0XmilJ73-r5dp568hD3Fw0,266
21
- aiagents4pharma/talk2biomodels/tools/ask_question.py,sha256=uxCQ4ON8--D0ACPvT14t6x_aqm9LP6woBA4GM7bPXc4,3061
21
+ aiagents4pharma/talk2biomodels/tools/ask_question.py,sha256=qpltsgyLFFwLYQeapQHASFRDCNiWsJkmTH_sUrfJ_Fg,3708
22
22
  aiagents4pharma/talk2biomodels/tools/custom_plotter.py,sha256=HWwKTX3o4dk0GcRVTO2hPrFSu98mtJ4TKC_hbHXOe1c,4018
23
23
  aiagents4pharma/talk2biomodels/tools/get_modelinfo.py,sha256=qA-4FOI-O728Nmn7s8JJ8HKwxvA9MZbst7NkPKTAMV4,5391
24
24
  aiagents4pharma/talk2biomodels/tools/load_biomodel.py,sha256=pyVzLQoMnuJYEwsjeOlqcUrbU1F1Z-pNlgkhFaoKpy0,689
25
25
  aiagents4pharma/talk2biomodels/tools/parameter_scan.py,sha256=aIyL_m46s3Q74ieJOZjZBM34VCjBKSMpEtckhdZofbE,12139
26
26
  aiagents4pharma/talk2biomodels/tools/search_models.py,sha256=Iq2ddofOOfZYtAurCISq3bAq5rbwB3l_rL1lgEFyFCI,2653
27
27
  aiagents4pharma/talk2biomodels/tools/simulate_model.py,sha256=sWmFVnVvJbdXXTqn_7gQl5UW0tv4FyU5yLXWLweLs_M,7059
28
+ aiagents4pharma/talk2biomodels/tools/steady_state.py,sha256=itUXFTYO525D695LQlGb1ObwSuvHk0A5mXtgdCproqo,8102
28
29
  aiagents4pharma/talk2cells/__init__.py,sha256=zmOP5RAhabgKIQP-W4P4qKME2tG3fhAXM3MeO5_H8kE,120
29
30
  aiagents4pharma/talk2cells/agents/__init__.py,sha256=38nK2a_lEFRjO3qD6Fo9a3983ZCYat6hmJKWY61y2Mo,128
30
31
  aiagents4pharma/talk2cells/agents/scp_agent.py,sha256=gDMfhUNWHa_XWOqm1Ql6yLAdI_7bnIk5sRYn43H2sYk,3090
@@ -51,7 +52,7 @@ aiagents4pharma/talk2competitors/tools/s2/display_results.py,sha256=B8JJGohi1Eyx
51
52
  aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py,sha256=FYLt47DAk6WOKfEk1Gj9zVvJGNyxA283PCp8IKW9U5M,4262
52
53
  aiagents4pharma/talk2competitors/tools/s2/search.py,sha256=pppjrQv5-8ep4fnqgTSBNgnbSnQsVIcNrRrH0p2TP1o,4025
53
54
  aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py,sha256=dAfUQxI7T5eu0eDxK8VAl7-JH0Wnw24CVkOQqwj-hXc,4810
54
- aiagents4pharma/talk2knowledgegraphs/__init__.py,sha256=SW7Ys2A4eXyFtizNPdSw91SHOPVUBGBsrCQ7TqwSUL0,91
55
+ aiagents4pharma/talk2knowledgegraphs/__init__.py,sha256=4smVQoSMM6rflVnNkABqlDAAlSn4bYsq7rMVWjRGvis,103
55
56
  aiagents4pharma/talk2knowledgegraphs/datasets/__init__.py,sha256=L3gPuHskSegmtXskVrLIYr7FXe_ibKgJ2GGr1_Wok6k,173
56
57
  aiagents4pharma/talk2knowledgegraphs/datasets/biobridge_primekg.py,sha256=QlzDXmXREoa9MA6-GwzqRjdzndQeGBAF11Td6NFk_9Y,23426
57
58
  aiagents4pharma/talk2knowledgegraphs/datasets/dataset.py,sha256=-LaPLse8BkALqwFetNK7wch2dt9Dz6QKGKZKBKM6bIk,409
@@ -65,14 +66,19 @@ aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py,sha2
65
66
  aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py,sha256=uYFoE_6zeU10_1mLLAHUr5c4S2XZMSc0Q_860o-KWEw,1517
66
67
  aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py,sha256=EINWyXg_3AMHF3WzFLhIUiFDuaEhTVHBvVAJr8VtMDg,1624
67
68
  aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_sentencetransformer.py,sha256=Qxo6WeIDRy8aLh1tNKw0kSlzmUj3MtTak63oW2YwB24,1327
68
- aiagents4pharma/talk2knowledgegraphs/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
69
+ aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_enrichments.py,sha256=N6HRr4lWHXY7bTHe2uXJe4D_EG9WqZPibZne6qLl9_k,1447
70
+ aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py,sha256=kMuB_vci6hKr2qJgXBmcje7yxeJ2nY2ImXw-NSJpts0,3912
71
+ aiagents4pharma/talk2knowledgegraphs/utils/__init__.py,sha256=X8kSpmDTMkeE0fA5D0CWMsQ52YKoM5rRSXrjnali3IM,97
69
72
  aiagents4pharma/talk2knowledgegraphs/utils/kg_utils.py,sha256=6vQnPkeOWae_8jePjhma3sJuMTngy0I0tqzdFt6OqKg,2507
70
73
  aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py,sha256=xRb0x7SoAb0nSVZYgjrqxWvENOMDuqIdL43NMjoOaCs,153
71
74
  aiagents4pharma/talk2knowledgegraphs/utils/embeddings/embeddings.py,sha256=1nGznrAj-xT0xuSMBGz2dOujJ7M_IwSR84njxtxsy9A,2523
72
75
  aiagents4pharma/talk2knowledgegraphs/utils/embeddings/huggingface.py,sha256=2vi_elf6EgzfagFAO5QnL3a_aXZyN7B1EBziu44MTfM,3806
73
76
  aiagents4pharma/talk2knowledgegraphs/utils/embeddings/sentence_transformer.py,sha256=36iKlisOpMtGR5xfTAlSHXWvPqVC_Jbezod8kbBBMVg,2136
74
- aiagents4pharma-1.11.0.dist-info/LICENSE,sha256=IcIbyB1Hyk5ZDah03VNQvJkbNk2hkBCDqQ8qtnCvB4Q,1077
75
- aiagents4pharma-1.11.0.dist-info/METADATA,sha256=GsvCuC24bJ_wwQP8aevOB4Eai8WcUIcr9kmN3EwYP_Y,8541
76
- aiagents4pharma-1.11.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
77
- aiagents4pharma-1.11.0.dist-info/top_level.txt,sha256=-AH8rMmrSnJtq7HaAObS78UU-cTCwvX660dSxeM7a0A,16
78
- aiagents4pharma-1.11.0.dist-info/RECORD,,
77
+ aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py,sha256=tW426knki2DBIHcWyF_K04iMMdbpIn_e_TpPmTgz2dI,113
78
+ aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py,sha256=Bx8x6zzk5614ApWB90N_iv4_Y_Uq0-KwUeBwYSdQMU4,924
79
+ aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py,sha256=8eoxR-VHo0G7ReQIwje7xEhE-SJlHdef7_wJRpnvFIc,4116
80
+ aiagents4pharma-1.13.0.dist-info/LICENSE,sha256=IcIbyB1Hyk5ZDah03VNQvJkbNk2hkBCDqQ8qtnCvB4Q,1077
81
+ aiagents4pharma-1.13.0.dist-info/METADATA,sha256=gr1_xufeepNgQ_LbpMynm6gGpUyGNCcRd5RecsVAfiM,8609
82
+ aiagents4pharma-1.13.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
83
+ aiagents4pharma-1.13.0.dist-info/top_level.txt,sha256=-AH8rMmrSnJtq7HaAObS78UU-cTCwvX660dSxeM7a0A,16
84
+ aiagents4pharma-1.13.0.dist-info/RECORD,,