aiagents4pharma 1.17.1__py3-none-any.whl → 1.19.0__py3-none-any.whl

Files changed (56)
  1. aiagents4pharma/talk2biomodels/agents/t2b_agent.py +4 -4
  2. aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml +7 -15
  3. aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +4 -1
  4. aiagents4pharma/talk2biomodels/tests/test_ask_question.py +4 -2
  5. aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +4 -2
  6. aiagents4pharma/talk2biomodels/tests/test_integration.py +34 -30
  7. aiagents4pharma/talk2biomodels/tests/test_query_article.py +7 -1
  8. aiagents4pharma/talk2biomodels/tests/test_search_models.py +3 -1
  9. aiagents4pharma/talk2biomodels/tests/test_steady_state.py +6 -3
  10. aiagents4pharma/talk2biomodels/tools/ask_question.py +1 -2
  11. aiagents4pharma/talk2biomodels/tools/custom_plotter.py +23 -10
  12. aiagents4pharma/talk2biomodels/tools/get_annotation.py +11 -10
  13. aiagents4pharma/talk2biomodels/tools/query_article.py +6 -2
  14. aiagents4pharma/talk2biomodels/tools/search_models.py +8 -2
  15. aiagents4pharma/talk2knowledgegraphs/__init__.py +3 -0
  16. aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +4 -0
  17. aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +85 -0
  18. aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +7 -0
  19. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
  20. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
  21. aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +4 -0
  22. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
  23. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +31 -0
  24. aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +7 -0
  25. aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +6 -0
  26. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
  27. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
  28. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
  29. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
  30. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
  31. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
  32. aiagents4pharma/talk2knowledgegraphs/states/__init__.py +4 -0
  33. aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +38 -0
  34. aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +110 -0
  35. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +210 -0
  36. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +174 -0
  37. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +154 -0
  38. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +0 -1
  39. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +56 -0
  40. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +18 -42
  41. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +79 -0
  42. aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +6 -0
  43. aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +143 -0
  44. aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
  45. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +305 -0
  46. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +126 -0
  47. aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +4 -2
  48. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +1 -0
  49. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +81 -0
  50. aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +4 -0
  51. aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +225 -0
  52. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/METADATA +12 -3
  53. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/RECORD +56 -24
  54. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/LICENSE +0 -0
  55. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/WHEEL +0 -0
  56. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/top_level.txt +0 -0
--- a/aiagents4pharma/talk2biomodels/agents/t2b_agent.py
+++ b/aiagents4pharma/talk2biomodels/agents/t2b_agent.py
@@ -8,6 +8,7 @@ import logging
 from typing import Annotated
 import hydra
 from langchain_openai import ChatOpenAI
+from langchain_core.language_models.chat_models import BaseChatModel
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import START, StateGraph
 from langgraph.prebuilt import create_react_agent, ToolNode, InjectedState
@@ -26,7 +27,8 @@ from ..states.state_talk2biomodels import Talk2Biomodels
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-def get_app(uniq_id, llm_model='gpt-4o-mini'):
+def get_app(uniq_id,
+            llm_model: BaseChatModel = ChatOpenAI(model='gpt-4o-mini', temperature=0)):
     '''
     This function returns the langraph app.
     '''
@@ -51,8 +53,6 @@ def get_app(uniq_id, llm_model='gpt-4o-mini'):
         QueryArticle()
     ])
 
-    # Define the model
-    llm = ChatOpenAI(model=llm_model, temperature=0)
     # Load hydra configuration
     logger.log(logging.INFO, "Load Hydra configuration for Talk2BioModels agent.")
     with hydra.initialize(version_base=None, config_path="../configs"):
@@ -62,7 +62,7 @@ def get_app(uniq_id, llm_model='gpt-4o-mini'):
     logger.log(logging.INFO, "state_modifier: %s", cfg.state_modifier)
     # Create the agent
     model = create_react_agent(
-        llm,
+        llm_model,
         tools=tools,
         state_schema=Talk2Biomodels,
         state_modifier=cfg.state_modifier,
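
Note: `get_app` now accepts any LangChain `BaseChatModel` instance instead of a model-name string. A minimal usage sketch of the new calling convention (the thread id and prompt are illustrative, not taken from the diff):

    # Sketch: injecting a chat model into the Talk2BioModels app.
    from langchain_openai import ChatOpenAI
    from langchain_core.messages import HumanMessage
    from aiagents4pharma.talk2biomodels.agents.t2b_agent import get_app

    app = get_app(12345, llm_model=ChatOpenAI(model='gpt-4o-mini', temperature=0))
    config = {"configurable": {"thread_id": 12345}}
    response = app.invoke(
        {"messages": [HumanMessage(content="Search for models on Crohn's disease.")]},
        config=config)
    print(response["messages"][-1].content)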
--- a/aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml
+++ b/aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml
@@ -10,22 +10,14 @@ steady_state_prompt: >
 
   Here are some instructions to help you answer questions:
 
-  1. Before you answer any question, follow the plan and solve
-  technique. Start by understanding the question, then plan your
-  approach to solve the question, and finally solve the question
-  by following the plan. Always give a brief explanation of your
-  answer to the user.
+  1. If the user wants to know the time taken by the model to reach
+  steady state, you should look at the `steady_state_transition_time`
+  column of the data for the model species.
+
+  2. The highest value in the column `steady_state_transition_time`
+  is the time taken by the model to reach steady state.
 
-  2. If the user wants to know the time taken by the model to reach
-  steady state, you should look at the steady_state_transition_time
-  column of the data for the model species. The highest value in
-  this column is the time taken by the model to reach steady state.
-
-  3. To get accurate results, trim the data to the relevant columns
-  before performing any calculations. This will help you avoid
-  errors in your calculations, and ignore irrelevant data.
-
-  4. Please use the units provided below to answer the questions.
+  3. Please use the units provided below to answer the questions.
 
 simulation_prompt: >
   Following is the information about the data frame:
   1. First column is the time column, and the rest of the columns
--- a/aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py
+++ b/aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py
@@ -7,6 +7,8 @@ This is the state file for the Talk2BioModels agent.
 from typing import Annotated
 import operator
 from langgraph.prebuilt.chat_agent_executor import AgentState
+from langchain_core.language_models import BaseChatModel
+from langchain_core.embeddings import Embeddings
 
 def add_data(data1: dict, data2: dict) -> dict:
     """
@@ -26,7 +28,8 @@ class Talk2Biomodels(AgentState):
     """
     The state for the Talk2BioModels agent.
     """
-    llm_model: str
+    llm_model: BaseChatModel
+    text_embedding_model: Embeddings
     pdf_file_name: str
     # A StateGraph may receive a concurrent updates
     # which is not supported by the StateGraph. Hence,
--- a/aiagents4pharma/talk2biomodels/tests/test_ask_question.py
+++ b/aiagents4pharma/talk2biomodels/tests/test_ask_question.py
@@ -3,6 +3,7 @@ Test cases for Talk2Biomodels.
 '''
 
 from langchain_core.messages import HumanMessage, ToolMessage
+from langchain_openai import ChatOpenAI
 from ..agents.t2b_agent import get_app
 
 def test_ask_question_tool():
@@ -10,7 +11,7 @@ def test_ask_question_tool():
     Test the ask_question tool without the simulation results.
     '''
     unique_id = 12345
-    app = get_app(unique_id, llm_model='gpt-4o-mini')
+    app = get_app(unique_id)
     config = {"configurable": {"thread_id": unique_id}}
 
     ##########################################
@@ -20,7 +21,8 @@ def test_ask_question_tool():
     # case, the tool should return an error
     ##########################################
     # Update state
-    app.update_state(config, {"llm_model": "gpt-4o-mini"})
+    app.update_state(config,
+                     {"llm_model": ChatOpenAI(model='gpt-4o-mini', temperature=0)})
     # Define the prompt
     prompt = "Call the ask_question tool to answer the "
     prompt += "question: What is the concentration of CRP "
--- a/aiagents4pharma/talk2biomodels/tests/test_get_annotation.py
+++ b/aiagents4pharma/talk2biomodels/tests/test_get_annotation.py
@@ -5,6 +5,7 @@ Test cases for Talk2Biomodels get_annotation tool.
 import random
 import pytest
 from langchain_core.messages import HumanMessage, ToolMessage
+from langchain_openai import ChatOpenAI
 from ..agents.t2b_agent import get_app
 from ..tools.get_annotation import prepare_content_msg
 
@@ -16,7 +17,9 @@ def make_graph_fixture():
     unique_id = random.randint(1000, 9999)
     graph = get_app(unique_id)
     config = {"configurable": {"thread_id": unique_id}}
-    graph.update_state(config, {"llm_model": "gpt-4o-mini"})
+    graph.update_state(config, {"llm_model": ChatOpenAI(model='gpt-4o-mini',
+                                                        temperature=0)
+                                })
     return graph, config
 
 def test_no_model_provided(make_graph):
@@ -85,7 +88,6 @@ def test_invalid_species_provided(make_graph):
             # (likely due to an invalid species).
             test_condition = True
             break
-    # assert test_condition
     assert test_condition
 
 def test_invalid_and_valid_species_provided(make_graph):
--- a/aiagents4pharma/talk2biomodels/tests/test_integration.py
+++ b/aiagents4pharma/talk2biomodels/tests/test_integration.py
@@ -4,8 +4,11 @@ Test cases for Talk2Biomodels.
 
 import pandas as pd
 from langchain_core.messages import HumanMessage, ToolMessage
+from langchain_openai import ChatOpenAI
 from ..agents.t2b_agent import get_app
 
+LLM_MODEL = ChatOpenAI(model='gpt-4o-mini', temperature=0)
+
 def test_integration():
     '''
     Test the integration of the tools.
@@ -13,7 +16,7 @@ def test_integration():
     unique_id = 1234567
     app = get_app(unique_id)
     config = {"configurable": {"thread_id": unique_id}}
-    app.update_state(config, {"llm_model": "gpt-4o-mini"})
+    app.update_state(config, {"llm_model": LLM_MODEL})
     # ##########################################
     # ## Test simulate_model tool
     # ##########################################
@@ -34,7 +37,7 @@ def test_integration():
     # results are available
     ##########################################
     # Update state
-    app.update_state(config, {"llm_model": "gpt-4o-mini"})
+    app.update_state(config, {"llm_model": LLM_MODEL})
     prompt = """What is the concentration of CRP in serum after 100 hours?
     Round off the value to 2 decimal places."""
     # Test the tool get_modelinfo
@@ -49,12 +52,15 @@ def test_integration():
 
     ##########################################
     # Test custom_plotter tool when the
-    # simulation results are available
+    # simulation results are available but
+    # the species is not available
     ##########################################
-    prompt = "Plot only CRP related species."
-
+    prompt = """Call the custom_plotter tool to make a plot
+    showing only species `TP53` and `Pyruvate`. Let me
+    know if these species were not found. Do not
+    invoke any other tool."""
     # Update state
-    app.update_state(config, {"llm_model": "gpt-4o-mini"}
+    app.update_state(config, {"llm_model": LLM_MODEL}
                      )
     # Test the tool get_modelinfo
     response = app.invoke(
@@ -66,11 +72,8 @@ def test_integration():
     # Get the messages from the current state
     # and reverse the order
     reversed_messages = current_state.values["messages"][::-1]
-    # Loop through the reversed messages
-    # until a ToolMessage is found.
-    expected_header = ['Time', 'CRP{serum}', 'CRPExtracellular']
-    expected_header += ['CRP Suppression (%)', 'CRP (% of baseline)']
-    expected_header += ['CRP{liver}']
+    # Loop through the reversed messages until a
+    # ToolMessage is found.
     predicted_artifact = []
     for msg in reversed_messages:
         if isinstance(msg, ToolMessage):
@@ -80,24 +83,17 @@ def test_integration():
             if msg.name == "custom_plotter":
                 predicted_artifact = msg.artifact
                 break
-    # Convert the artifact into a pandas dataframe
-    # for easy comparison
-    df = pd.DataFrame(predicted_artifact)
-    # Extract the headers from the dataframe
-    predicted_header = df.columns.tolist()
-    # Check if the header is in the expected_header
-    # assert expected_header in predicted_artifact
-    assert set(expected_header).issubset(set(predicted_header))
+    # Check if the the predicted artifact is `None`
+    assert predicted_artifact is None
+
     ##########################################
     # Test custom_plotter tool when the
-    # simulation results are available but
-    # the species is not available
+    # simulation results are available
     ##########################################
-    prompt = """Make a custom plot showing the
-    concentration of the species `TP53` over
-    time. Do not show any other species."""
+    prompt = "Plot only CRP related species."
+
     # Update state
-    app.update_state(config, {"llm_model": "gpt-4o-mini"}
+    app.update_state(config, {"llm_model": LLM_MODEL}
                      )
     # Test the tool get_modelinfo
     response = app.invoke(
@@ -105,13 +101,15 @@ def test_integration():
         config=config
     )
     assistant_msg = response["messages"][-1].content
-    # print (response["messages"])
     current_state = app.get_state(config)
     # Get the messages from the current state
     # and reverse the order
     reversed_messages = current_state.values["messages"][::-1]
-    # Loop through the reversed messages until a
-    # ToolMessage is found.
+    # Loop through the reversed messages
+    # until a ToolMessage is found.
+    expected_header = ['Time', 'CRP{serum}', 'CRPExtracellular']
+    expected_header += ['CRP Suppression (%)', 'CRP (% of baseline)']
+    expected_header += ['CRP{liver}']
     predicted_artifact = []
     for msg in reversed_messages:
         if isinstance(msg, ToolMessage):
@@ -121,5 +119,11 @@ def test_integration():
             if msg.name == "custom_plotter":
                 predicted_artifact = msg.artifact
                 break
-    # Check if the the predicted artifact is `None`
-    assert predicted_artifact is None
+    # Convert the artifact into a pandas dataframe
+    # for easy comparison
+    df = pd.DataFrame(predicted_artifact)
+    # Extract the headers from the dataframe
+    predicted_header = df.columns.tolist()
+    # Check if the header is in the expected_header
+    # assert expected_header in predicted_artifact
+    assert set(expected_header).issubset(set(predicted_header))
--- a/aiagents4pharma/talk2biomodels/tests/test_query_article.py
+++ b/aiagents4pharma/talk2biomodels/tests/test_query_article.py
@@ -5,6 +5,7 @@ Test cases for Talk2Biomodels query_article tool.
 from pydantic import BaseModel, Field
 from langchain_core.messages import HumanMessage, ToolMessage
 from langchain_openai import ChatOpenAI
+from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
 from ..agents.t2b_agent import get_app
 
 class Article(BaseModel):
@@ -21,8 +22,10 @@ def test_query_article_with_an_article():
     app = get_app(unique_id)
     config = {"configurable": {"thread_id": unique_id}}
     # Update state by providing the pdf file name
+    # and the text embedding model
     app.update_state(config,
-        {"pdf_file_name": "aiagents4pharma/talk2biomodels/tests/article_on_model_537.pdf"})
+        {"pdf_file_name": "aiagents4pharma/talk2biomodels/tests/article_on_model_537.pdf",
+         "text_embedding_model": NVIDIAEmbeddings(model='nvidia/llama-3.2-nv-embedqa-1b-v2')})
     prompt = "What is the title of the article?"
     # Test the tool query_article
     response = app.invoke(
@@ -55,6 +58,9 @@ def test_query_article_without_an_article():
     app = get_app(unique_id)
     config = {"configurable": {"thread_id": unique_id}}
     prompt = "What is the title of the uploaded article?"
+    # Update state by providing the text embedding model
+    app.update_state(config,
+        {"text_embedding_model": NVIDIAEmbeddings(model='nvidia/llama-3.2-nv-embedqa-1b-v2')})
     # Test the tool query_article
     app.invoke(
         {"messages": [HumanMessage(content=prompt)]},
--- a/aiagents4pharma/talk2biomodels/tests/test_search_models.py
+++ b/aiagents4pharma/talk2biomodels/tests/test_search_models.py
@@ -3,6 +3,7 @@ Test cases for Talk2Biomodels search models tool.
 '''
 
 from langchain_core.messages import HumanMessage
+from langchain_nvidia_ai_endpoints import ChatNVIDIA
 from ..agents.t2b_agent import get_app
 
 def test_search_models_tool():
@@ -13,7 +14,8 @@ def test_search_models_tool():
     app = get_app(unique_id)
     config = {"configurable": {"thread_id": unique_id}}
     # Update state
-    app.update_state(config, {"llm_model": "gpt-4o-mini"})
+    app.update_state(config,
+                     {"llm_model": ChatNVIDIA(model="meta/llama-3.3-70b-instruct")})
     prompt = "Search for models on Crohn's disease."
     # Test the tool get_modelinfo
     response = app.invoke(
--- a/aiagents4pharma/talk2biomodels/tests/test_steady_state.py
+++ b/aiagents4pharma/talk2biomodels/tests/test_steady_state.py
@@ -3,8 +3,11 @@ Test cases for Talk2Biomodels steady state tool.
 '''
 
 from langchain_core.messages import HumanMessage, ToolMessage
+from langchain_openai import ChatOpenAI
 from ..agents.t2b_agent import get_app
 
+LLM_MODEL = ChatOpenAI(model='gpt-4o-mini', temperature=0)
+
 def test_steady_state_tool():
     '''
     Test the steady_state tool.
@@ -12,7 +15,7 @@ def test_steady_state_tool():
     unique_id = 123
     app = get_app(unique_id)
     config = {"configurable": {"thread_id": unique_id}}
-    app.update_state(config, {"llm_model": "gpt-4o-mini"})
+    app.update_state(config, {"llm_model": LLM_MODEL})
     #########################################################
     # In this case, we will test if the tool returns an error
     # when the model does not achieve a steady state. The tool
@@ -37,8 +40,8 @@ def test_steady_state_tool():
     #########################################################
     # In this case, we will test if the tool is indeed invoked
     # successfully
-    prompt = """Run a steady state analysis of model 64.
-    Set the initial concentration of `Pyruvate` to 0.2. The
+    prompt = """Bring model 64 to a steady state. Set the
+    initial concentration of `Pyruvate` to 0.2. The
     concentration of `NAD` resets to 100 every 2 time units."""
     # Invoke the agent
     app.invoke(
--- a/aiagents4pharma/talk2biomodels/tools/ask_question.py
+++ b/aiagents4pharma/talk2biomodels/tools/ask_question.py
@@ -12,7 +12,6 @@ import pandas as pd
 from pydantic import BaseModel, Field
 from langchain_core.tools.base import BaseTool
 from langchain_experimental.agents import create_pandas_dataframe_agent
-from langchain_openai import ChatOpenAI
 from langgraph.prebuilt import InjectedState
 
 # Initialize logger
@@ -101,7 +100,7 @@ class AskQuestionTool(BaseTool):
         prompt_content += f"{basico.model_info.get_model_units()}\n\n"
         # Create a pandas dataframe agent
         df_agent = create_pandas_dataframe_agent(
-            ChatOpenAI(model=state['llm_model']),
+            state['llm_model'],
             allow_dangerous_code=True,
             agent_type='tool-calling',
             df=df,
--- a/aiagents4pharma/talk2biomodels/tools/custom_plotter.py
+++ b/aiagents4pharma/talk2biomodels/tools/custom_plotter.py
@@ -5,10 +5,9 @@ Tool for plotting a custom figure.
 """
 
 import logging
-from typing import Type, List, TypedDict, Annotated, Tuple, Union, Literal
+from typing import Type, Annotated, List, Tuple, Union, Literal
 from pydantic import BaseModel, Field
 import pandas as pd
-from langchain_openai import ChatOpenAI
 from langchain_core.tools import BaseTool
 from langgraph.prebuilt import InjectedState
 
@@ -71,30 +70,44 @@ class CustomPlotterTool(BaseTool):
         species_names = df.columns.tolist()
         # Exclude the time column
         species_names.remove('Time')
+        logging.log(logging.INFO, "Species names: %s", species_names)
         # In the following code, we extract the species
         # from the user question. We use Literal to restrict
         # the species names to the ones available in the
         # simulation results.
-        class CustomHeader(TypedDict):
+        class CustomHeader(BaseModel):
             """
             A list of species based on user question.
+
+            This is a Pydantic model that restricts the species
+            names to the ones available in the simulation results.
+
+            If no species is relevant, set the attribute
+            `relevant_species` to None.
             """
             relevant_species: Union[None, List[Literal[*species_names]]] = Field(
-                description="""List of species based on user question.
-                If no relevant species are found, it will be None.""")
+                description="This is a list of species based on the user question."
+                "It is restricted to the species available in the simulation results."
+                "If no species is relevant, set this attribute to None."
+                "If the user asks for very specific species (for example, using the"
+                "keyword `only` in the question), set this attribute to correspond "
+                "to the species available in the simulation results, otherwise set it to None."
+                )
         # Create an instance of the LLM model
-        llm = ChatOpenAI(model=state['llm_model'], temperature=0)
+        logging.log(logging.INFO, "LLM model: %s", state['llm_model'])
+        llm = state['llm_model']
         llm_with_structured_output = llm.with_structured_output(CustomHeader)
         results = llm_with_structured_output.invoke(question)
+        if results.relevant_species is None:
+            raise ValueError("No species found in the simulation results \
+                that matches the user prompt.")
         extracted_species = []
         # Extract the species from the results
         # that are available in the simulation results
-        for species in results['relevant_species']:
+        for species in results.relevant_species:
             if species in species_names:
                 extracted_species.append(species)
-        logger.info("Extracted species: %s", extracted_species)
-        if len(extracted_species) == 0:
-            return "No species found in the simulation results that matches the user prompt.", None
+        logging.info("Extracted species: %s", extracted_species)
         # Include the time column
         extracted_species.insert(0, 'Time')
         return f"Custom plot {simulation_name}", df[extracted_species].to_dict(orient='records')
--- a/aiagents4pharma/talk2biomodels/tools/get_annotation.py
+++ b/aiagents4pharma/talk2biomodels/tools/get_annotation.py
@@ -5,7 +5,7 @@ This module contains the `GetAnnotationTool` for fetching species annotations
 based on the provided model and species names.
 """
 import math
-from typing import List, Annotated, Type, TypedDict, Union, Literal
+from typing import List, Annotated, Type, Union, Literal
 import logging
 from dataclasses import dataclass
 import hydra
@@ -17,7 +17,7 @@ from langgraph.prebuilt import InjectedState
 from langchain_core.tools.base import BaseTool
 from langchain_core.tools.base import InjectedToolCallId
 from langchain_core.messages import ToolMessage
-from langchain_openai import ChatOpenAI
+# from langchain_openai import ChatOpenAI
 from .load_biomodel import ModelData, load_biomodel
 from ..api.uniprot import search_uniprot_labels
 from ..api.ols import search_ols_labels
@@ -49,7 +49,7 @@ def extract_relevant_species_names(model_object, arg_data, state):
     all_species_names = df_species.index.tolist()
 
     # Define a structured output for the LLM model
-    class CustomHeader(TypedDict):
+    class CustomHeader(BaseModel):
         """
         A list of species based on user question.
         """
@@ -58,17 +58,21 @@ def extract_relevant_species_names(model_object, arg_data, state):
            If no relevant species are found, it must be None.""")
 
     # Create an instance of the LLM model
-    llm = ChatOpenAI(model=state['llm_model'], temperature=0)
+    llm = state['llm_model']
     # Get the structured output from the LLM model
     llm_with_structured_output = llm.with_structured_output(CustomHeader)
     # Define the question for the LLM model using the prompt
     question = cfg.prompt
     question += f'Here is the user question: {arg_data.user_question}'
     # Invoke the LLM model with the user question
-    dic = llm_with_structured_output.invoke(question)
+    results = llm_with_structured_output.invoke(question)
+    logging.info("Results from the LLM model: %s", results)
+    # Check if the returned species names are empty
+    if not results.relevant_species:
+        raise ValueError("Model does not contain the requested species.")
     extracted_species = []
     # Extract all the species names from the model
-    for species in dic['relevant_species']:
+    for species in results.relevant_species:
         if species in all_species_names:
             extracted_species.append(species)
     logger.info("Extracted species: %s", extracted_species)
@@ -136,10 +140,7 @@ class GetAnnotationTool(BaseTool):
 
         # Extract relevant species names based on the user question
        list_species_names = extract_relevant_species_names(model_object, arg_data, state)
-
-        # Check if the returned species names are empty
-        if not list_species_names:
-            raise ValueError("Model does not contain the requested species.")
+        print (list_species_names)
 
         (annotations_df,
          species_without_description) = self._fetch_annotations(list_species_names)
--- a/aiagents4pharma/talk2biomodels/tools/query_article.py
+++ b/aiagents4pharma/talk2biomodels/tools/query_article.py
@@ -9,7 +9,6 @@ from typing import Type, Annotated
 from pydantic import BaseModel, Field
 from langchain_core.tools import BaseTool
 from langchain_core.vectorstores import InMemoryVectorStore
-from langchain_openai.embeddings import OpenAIEmbeddings
 from langchain_community.document_loaders import PyPDFLoader
 from langgraph.prebuilt import InjectedState
 
@@ -51,8 +50,13 @@ class QueryArticle(BaseTool):
         pages = []
         for page in loader.lazy_load():
             pages.append(page)
+        # Set up text embedding model
+        text_embedding_model = state['text_embedding_model']
+        logging.info("Loaded text embedding model %s", text_embedding_model)
         # Create a vector store from the pages
-        vector_store = InMemoryVectorStore.from_documents(pages, OpenAIEmbeddings())
+        vector_store = InMemoryVectorStore.from_documents(
+            pages,
+            text_embedding_model)
         # Search the article with the question
         docs = vector_store.similarity_search(question)
         # Return the content of the pages
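
With the embedding model injected through state, any LangChain `Embeddings` implementation can back the in-memory vector store. A sketch under that assumption (the PDF path and query are illustrative):

    # Sketch: building the vector store with an injected embedding model.
    from langchain_community.document_loaders import PyPDFLoader
    from langchain_core.vectorstores import InMemoryVectorStore
    from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings

    pages = list(PyPDFLoader("article.pdf").lazy_load())  # illustrative path
    vector_store = InMemoryVectorStore.from_documents(
        pages,
        NVIDIAEmbeddings(model='nvidia/llama-3.2-nv-embedqa-1b-v2'))
    docs = vector_store.similarity_search("What is the title of the article?")
    print(docs[0].page_content[:200])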
--- a/aiagents4pharma/talk2biomodels/tools/search_models.py
+++ b/aiagents4pharma/talk2biomodels/tools/search_models.py
@@ -5,14 +5,18 @@ Tool for searching models based on search query.
 """
 
 from typing import Type, Annotated
+import logging
 from pydantic import BaseModel, Field
 from basico import biomodels
 from langchain_core.tools import BaseTool
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
-from langchain_openai import ChatOpenAI
 from langgraph.prebuilt import InjectedState
 
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 class SearchModelsInput(BaseModel):
     """
     Input schema for the search models tool.
@@ -43,8 +47,10 @@ class SearchModelsTool(BaseTool):
         Returns:
             dict: The answer to the question in the form of a dictionary.
         """
+        logger.log(logging.INFO, "Searching models with the query and model: %s, %s",
+                   query, state['llm_model'])
         search_results = biomodels.search_for_model(query)
-        llm = ChatOpenAI(model=state['llm_model'])
+        llm = state['llm_model']
         # Check if run_manager's metadata has the key 'prompt_content'
         prompt_content = f'''
         Convert the input into a table.
--- a/aiagents4pharma/talk2knowledgegraphs/__init__.py
+++ b/aiagents4pharma/talk2knowledgegraphs/__init__.py
@@ -1,5 +1,8 @@
 '''
 This file is used to import the datasets and utils.
 '''
+from . import agents
 from . import datasets
+from . import states
+from . import tools
 from . import utils
--- /dev/null
+++ b/aiagents4pharma/talk2knowledgegraphs/agents/__init__.py
@@ -0,0 +1,4 @@
+'''
+This file is used to import all the models in the package.
+'''
+from . import t2kg_agent
--- /dev/null
+++ b/aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py
@@ -0,0 +1,85 @@
+'''
+This is the agent file for the Talk2KnowledgeGraphs agent.
+'''
+
+import logging
+from typing import Annotated
+import hydra
+from langchain_ollama import ChatOllama
+from langchain_core.language_models.chat_models import BaseChatModel
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import START, StateGraph
+from langgraph.prebuilt import create_react_agent, ToolNode, InjectedState
+from ..tools.subgraph_extraction import SubgraphExtractionTool
+from ..tools.subgraph_summarization import SubgraphSummarizationTool
+from ..tools.graphrag_reasoning import GraphRAGReasoningTool
+from ..states.state_talk2knowledgegraphs import Talk2KnowledgeGraphs
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def get_app(uniq_id, llm_model: BaseChatModel=ChatOllama(model='llama3.2:1b', temperature=0.0)):
+    '''
+    This function returns the langraph app.
+    '''
+    def agent_t2kg_node(state: Annotated[dict, InjectedState]):
+        '''
+        This function calls the model.
+        '''
+        logger.log(logging.INFO, "Calling t2kg_agent node with thread_id %s", uniq_id)
+        response = model.invoke(state, {"configurable": {"thread_id": uniq_id}})
+
+        return response
+
+    # Load hydra configuration
+    logger.log(logging.INFO, "Load Hydra configuration for Talk2KnowledgeGraphs agent.")
+    with hydra.initialize(version_base=None, config_path="../configs"):
+        cfg = hydra.compose(config_name='config',
+                            overrides=['agents/t2kg_agent=default'])
+        cfg = cfg.agents.t2kg_agent
+
+    # Define the tools
+    subgraph_extraction = SubgraphExtractionTool()
+    subgraph_summarization = SubgraphSummarizationTool()
+    graphrag_reasoning = GraphRAGReasoningTool()
+    tools = ToolNode([
+        subgraph_extraction,
+        subgraph_summarization,
+        graphrag_reasoning,
+    ])
+
+    # Create the agent
+    model = create_react_agent(
+        llm_model,
+        tools=tools,
+        state_schema=Talk2KnowledgeGraphs,
+        state_modifier=cfg.state_modifier,
+        checkpointer=MemorySaver()
+    )
+
+    # Define a new graph
+    workflow = StateGraph(Talk2KnowledgeGraphs)
+
+    # Define the two nodes we will cycle between
+    workflow.add_node("agent_t2kg", agent_t2kg_node)
+
+    # Set the entrypoint as the first node
+    # This means that this node is the first one called
+    workflow.add_edge(START, "agent_t2kg")
+
+    # Initialize memory to persist state between graph runs
+    checkpointer = MemorySaver()
+
+    # Finally, we compile it!
+    # This compiles it into a LangChain Runnable,
+    # meaning you can use it as you would any other runnable.
+    # Note that we're (optionally) passing the memory
+    # when compiling the graph
+    app = workflow.compile(checkpointer=checkpointer)
+    logger.log(logging.INFO,
+               "Compiled the graph with thread_id %s and llm_model %s",
+               uniq_id,
+               llm_model)
+
+    return app
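
The new agent mirrors the Talk2BioModels entry point. A minimal invocation sketch, assuming a local Ollama server with `llama3.2:1b` pulled; depending on which tool the agent calls, additional state (for example the knowledge graph and embedding models) may need to be seeded via `app.update_state` first:

    # Sketch: invoking the Talk2KnowledgeGraphs app; the prompt is illustrative.
    from langchain_core.messages import HumanMessage
    from aiagents4pharma.talk2knowledgegraphs.agents.t2kg_agent import get_app

    app = get_app(4321)  # defaults to ChatOllama(model='llama3.2:1b')
    config = {"configurable": {"thread_id": 4321}}
    response = app.invoke(
        {"messages": [HumanMessage(content="Summarize the extracted subgraph.")]},
        config=config)
    print(response["messages"][-1].content)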
--- /dev/null
+++ b/aiagents4pharma/talk2knowledgegraphs/configs/__init__.py
@@ -0,0 +1,7 @@
+'''
+Import all the modules in the package
+'''
+
+from . import agents
+from . import tools
+from . import app