aiagents4pharma 1.17.1__py3-none-any.whl → 1.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2biomodels/agents/t2b_agent.py +4 -4
- aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml +7 -15
- aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +4 -1
- aiagents4pharma/talk2biomodels/tests/test_ask_question.py +4 -2
- aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +4 -2
- aiagents4pharma/talk2biomodels/tests/test_integration.py +34 -30
- aiagents4pharma/talk2biomodels/tests/test_query_article.py +7 -1
- aiagents4pharma/talk2biomodels/tests/test_search_models.py +3 -1
- aiagents4pharma/talk2biomodels/tests/test_steady_state.py +6 -3
- aiagents4pharma/talk2biomodels/tools/ask_question.py +1 -2
- aiagents4pharma/talk2biomodels/tools/custom_plotter.py +23 -10
- aiagents4pharma/talk2biomodels/tools/get_annotation.py +11 -10
- aiagents4pharma/talk2biomodels/tools/query_article.py +6 -2
- aiagents4pharma/talk2biomodels/tools/search_models.py +8 -2
- aiagents4pharma/talk2knowledgegraphs/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +85 -0
- aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +7 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +31 -0
- aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +7 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +6 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
- aiagents4pharma/talk2knowledgegraphs/states/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +38 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +110 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +210 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +174 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +154 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +0 -1
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +56 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +18 -42
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +79 -0
- aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +6 -0
- aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +143 -0
- aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +305 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +126 -0
- aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +4 -2
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +81 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +225 -0
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/METADATA +12 -3
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/RECORD +56 -24
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/top_level.txt +0 -0
aiagents4pharma/talk2biomodels/agents/t2b_agent.py

```diff
@@ -8,6 +8,7 @@ import logging
 from typing import Annotated
 import hydra
 from langchain_openai import ChatOpenAI
+from langchain_core.language_models.chat_models import BaseChatModel
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import START, StateGraph
 from langgraph.prebuilt import create_react_agent, ToolNode, InjectedState
@@ -26,7 +27,8 @@ from ..states.state_talk2biomodels import Talk2Biomodels
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-def get_app(uniq_id, llm_model='gpt-4o-mini'):
+def get_app(uniq_id,
+            llm_model: BaseChatModel = ChatOpenAI(model='gpt-4o-mini', temperature=0)):
     '''
     This function returns the langraph app.
     '''
@@ -51,8 +53,6 @@ def get_app(uniq_id, llm_model='gpt-4o-mini'):
         QueryArticle()
     ])
 
-    # Define the model
-    llm = ChatOpenAI(model=llm_model, temperature=0)
     # Load hydra configuration
     logger.log(logging.INFO, "Load Hydra configuration for Talk2BioModels agent.")
     with hydra.initialize(version_base=None, config_path="../configs"):
@@ -62,7 +62,7 @@ def get_app(uniq_id, llm_model='gpt-4o-mini'):
     logger.log(logging.INFO, "state_modifier: %s", cfg.state_modifier)
    # Create the agent
     model = create_react_agent(
-        llm,
+        llm_model,
         tools=tools,
         state_schema=Talk2Biomodels,
         state_modifier=cfg.state_modifier,
```
aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml

```diff
@@ -10,22 +10,14 @@ steady_state_prompt: >
 
   Here are some instructions to help you answer questions:
 
-  1.
-
-
-
-
+  1. If the user wants to know the time taken by the model to reach
+     steady state, you should look at the `steady_state_transition_time`
+     column of the data for the model species.
+
+  2. The highest value in the column `steady_state_transition_time`
+     is the time taken by the model to reach steady state.
 
-
-  steady state, you should look at the steady_state_transition_time
-  column of the data for the model species. The highest value in
-  this column is the time taken by the model to reach steady state.
-
-  3. To get accurate results, trim the data to the relevant columns
-     before performing any calculations. This will help you avoid
-     errors in your calculations, and ignore irrelevant data.
-
-  4. Please use the units provided below to answer the questions.
+  3. Please use the units provided below to answer the questions.
 simulation_prompt: >
   Following is the information about the data frame:
   1. First column is the time column, and the rest of the columns
```
aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py

```diff
@@ -7,6 +7,8 @@ This is the state file for the Talk2BioModels agent.
 from typing import Annotated
 import operator
 from langgraph.prebuilt.chat_agent_executor import AgentState
+from langchain_core.language_models import BaseChatModel
+from langchain_core.embeddings import Embeddings
 
 def add_data(data1: dict, data2: dict) -> dict:
     """
@@ -26,7 +28,8 @@ class Talk2Biomodels(AgentState):
     """
     The state for the Talk2BioModels agent.
     """
-    llm_model:
+    llm_model: BaseChatModel
+    text_embedding_model: Embeddings
     pdf_file_name: str
     # A StateGraph may receive a concurrent updates
     # which is not supported by the StateGraph. Hence,
```
aiagents4pharma/talk2biomodels/tests/test_ask_question.py

```diff
@@ -3,6 +3,7 @@ Test cases for Talk2Biomodels.
 '''
 
 from langchain_core.messages import HumanMessage, ToolMessage
+from langchain_openai import ChatOpenAI
 from ..agents.t2b_agent import get_app
 
 def test_ask_question_tool():
@@ -10,7 +11,7 @@ def test_ask_question_tool():
     Test the ask_question tool without the simulation results.
     '''
     unique_id = 12345
-    app = get_app(unique_id
+    app = get_app(unique_id)
     config = {"configurable": {"thread_id": unique_id}}
 
     ##########################################
@@ -20,7 +21,8 @@ def test_ask_question_tool():
     # case, the tool should return an error
     ##########################################
     # Update state
-    app.update_state(config,
+    app.update_state(config,
+        {"llm_model": ChatOpenAI(model='gpt-4o-mini', temperature=0)})
     # Define the prompt
     prompt = "Call the ask_question tool to answer the "
     prompt += "question: What is the concentration of CRP "
```
aiagents4pharma/talk2biomodels/tests/test_get_annotation.py

```diff
@@ -5,6 +5,7 @@ Test cases for Talk2Biomodels get_annotation tool.
 import random
 import pytest
 from langchain_core.messages import HumanMessage, ToolMessage
+from langchain_openai import ChatOpenAI
 from ..agents.t2b_agent import get_app
 from ..tools.get_annotation import prepare_content_msg
 
@@ -16,7 +17,9 @@ def make_graph_fixture():
     unique_id = random.randint(1000, 9999)
     graph = get_app(unique_id)
     config = {"configurable": {"thread_id": unique_id}}
-    graph.update_state(config, {"llm_model":
+    graph.update_state(config, {"llm_model": ChatOpenAI(model='gpt-4o-mini',
+                                                        temperature=0)
+                                })
     return graph, config
 
 def test_no_model_provided(make_graph):
@@ -85,7 +88,6 @@ def test_invalid_species_provided(make_graph):
             # (likely due to an invalid species).
             test_condition = True
             break
-    # assert test_condition
     assert test_condition
 
 def test_invalid_and_valid_species_provided(make_graph):
```
aiagents4pharma/talk2biomodels/tests/test_integration.py

```diff
@@ -4,8 +4,11 @@ Test cases for Talk2Biomodels.
 
 import pandas as pd
 from langchain_core.messages import HumanMessage, ToolMessage
+from langchain_openai import ChatOpenAI
 from ..agents.t2b_agent import get_app
 
+LLM_MODEL = ChatOpenAI(model='gpt-4o-mini', temperature=0)
+
 def test_integration():
     '''
     Test the integration of the tools.
@@ -13,7 +16,7 @@ def test_integration():
     unique_id = 1234567
     app = get_app(unique_id)
     config = {"configurable": {"thread_id": unique_id}}
-    app.update_state(config, {"llm_model":
+    app.update_state(config, {"llm_model": LLM_MODEL})
     # ##########################################
     # ## Test simulate_model tool
     # ##########################################
@@ -34,7 +37,7 @@ def test_integration():
     # results are available
     ##########################################
     # Update state
-    app.update_state(config, {"llm_model":
+    app.update_state(config, {"llm_model": LLM_MODEL})
     prompt = """What is the concentration of CRP in serum after 100 hours?
     Round off the value to 2 decimal places."""
     # Test the tool get_modelinfo
@@ -49,12 +52,15 @@ def test_integration():
 
     ##########################################
     # Test custom_plotter tool when the
-    # simulation results are available
+    # simulation results are available but
+    # the species is not available
     ##########################################
-    prompt = "
-
+    prompt = """Call the custom_plotter tool to make a plot
+    showing only species `TP53` and `Pyruvate`. Let me
+    know if these species were not found. Do not
+    invoke any other tool."""
     # Update state
-    app.update_state(config, {"llm_model":
+    app.update_state(config, {"llm_model": LLM_MODEL}
     )
     # Test the tool get_modelinfo
     response = app.invoke(
@@ -66,11 +72,8 @@ def test_integration():
     # Get the messages from the current state
     # and reverse the order
     reversed_messages = current_state.values["messages"][::-1]
-    # Loop through the reversed messages
-    #
-    expected_header = ['Time', 'CRP{serum}', 'CRPExtracellular']
-    expected_header += ['CRP Suppression (%)', 'CRP (% of baseline)']
-    expected_header += ['CRP{liver}']
+    # Loop through the reversed messages until a
+    # ToolMessage is found.
     predicted_artifact = []
     for msg in reversed_messages:
         if isinstance(msg, ToolMessage):
@@ -80,24 +83,17 @@ def test_integration():
             if msg.name == "custom_plotter":
                 predicted_artifact = msg.artifact
                 break
-    #
-
-
-    # Extract the headers from the dataframe
-    predicted_header = df.columns.tolist()
-    # Check if the header is in the expected_header
-    # assert expected_header in predicted_artifact
-    assert set(expected_header).issubset(set(predicted_header))
+    # Check if the the predicted artifact is `None`
+    assert predicted_artifact is None
+
     ##########################################
     # Test custom_plotter tool when the
-    # simulation results are available
-    # the species is not available
+    # simulation results are available
     ##########################################
-    prompt = "
-
-    time. Do not show any other species."""
+    prompt = "Plot only CRP related species."
+
     # Update state
-    app.update_state(config, {"llm_model":
+    app.update_state(config, {"llm_model": LLM_MODEL}
     )
     # Test the tool get_modelinfo
     response = app.invoke(
@@ -105,13 +101,15 @@ def test_integration():
         config=config
     )
     assistant_msg = response["messages"][-1].content
-    # print (response["messages"])
     current_state = app.get_state(config)
     # Get the messages from the current state
     # and reverse the order
     reversed_messages = current_state.values["messages"][::-1]
-    # Loop through the reversed messages
-    # ToolMessage is found.
+    # Loop through the reversed messages
+    # until a ToolMessage is found.
+    expected_header = ['Time', 'CRP{serum}', 'CRPExtracellular']
+    expected_header += ['CRP Suppression (%)', 'CRP (% of baseline)']
+    expected_header += ['CRP{liver}']
     predicted_artifact = []
     for msg in reversed_messages:
         if isinstance(msg, ToolMessage):
@@ -121,5 +119,11 @@ def test_integration():
             if msg.name == "custom_plotter":
                 predicted_artifact = msg.artifact
                 break
-    #
-
+    # Convert the artifact into a pandas dataframe
+    # for easy comparison
+    df = pd.DataFrame(predicted_artifact)
+    # Extract the headers from the dataframe
+    predicted_header = df.columns.tolist()
+    # Check if the header is in the expected_header
+    # assert expected_header in predicted_artifact
+    assert set(expected_header).issubset(set(predicted_header))
```
aiagents4pharma/talk2biomodels/tests/test_query_article.py

```diff
@@ -5,6 +5,7 @@ Test cases for Talk2Biomodels query_article tool.
 from pydantic import BaseModel, Field
 from langchain_core.messages import HumanMessage, ToolMessage
 from langchain_openai import ChatOpenAI
+from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
 from ..agents.t2b_agent import get_app
 
 class Article(BaseModel):
@@ -21,8 +22,10 @@ def test_query_article_with_an_article():
     app = get_app(unique_id)
     config = {"configurable": {"thread_id": unique_id}}
     # Update state by providing the pdf file name
+    # and the text embedding model
     app.update_state(config,
-        {"pdf_file_name": "aiagents4pharma/talk2biomodels/tests/article_on_model_537.pdf"
+        {"pdf_file_name": "aiagents4pharma/talk2biomodels/tests/article_on_model_537.pdf",
+         "text_embedding_model": NVIDIAEmbeddings(model='nvidia/llama-3.2-nv-embedqa-1b-v2')})
     prompt = "What is the title of the article?"
     # Test the tool query_article
     response = app.invoke(
@@ -55,6 +58,9 @@ def test_query_article_without_an_article():
     app = get_app(unique_id)
     config = {"configurable": {"thread_id": unique_id}}
     prompt = "What is the title of the uploaded article?"
+    # Update state by providing the text embedding model
+    app.update_state(config,
+        {"text_embedding_model": NVIDIAEmbeddings(model='nvidia/llama-3.2-nv-embedqa-1b-v2')})
     # Test the tool query_article
     app.invoke(
         {"messages": [HumanMessage(content=prompt)]},
```
aiagents4pharma/talk2biomodels/tests/test_search_models.py

```diff
@@ -3,6 +3,7 @@ Test cases for Talk2Biomodels search models tool.
 '''
 
 from langchain_core.messages import HumanMessage
+from langchain_nvidia_ai_endpoints import ChatNVIDIA
 from ..agents.t2b_agent import get_app
 
 def test_search_models_tool():
@@ -13,7 +14,8 @@ def test_search_models_tool():
     app = get_app(unique_id)
     config = {"configurable": {"thread_id": unique_id}}
     # Update state
-    app.update_state(config,
+    app.update_state(config,
+        {"llm_model": ChatNVIDIA(model="meta/llama-3.3-70b-instruct")})
     prompt = "Search for models on Crohn's disease."
     # Test the tool get_modelinfo
     response = app.invoke(
```
aiagents4pharma/talk2biomodels/tests/test_steady_state.py

```diff
@@ -3,8 +3,11 @@ Test cases for Talk2Biomodels steady state tool.
 '''
 
 from langchain_core.messages import HumanMessage, ToolMessage
+from langchain_openai import ChatOpenAI
 from ..agents.t2b_agent import get_app
 
+LLM_MODEL = ChatOpenAI(model='gpt-4o-mini', temperature=0)
+
 def test_steady_state_tool():
     '''
     Test the steady_state tool.
@@ -12,7 +15,7 @@ def test_steady_state_tool():
     unique_id = 123
     app = get_app(unique_id)
     config = {"configurable": {"thread_id": unique_id}}
-    app.update_state(config, {"llm_model":
+    app.update_state(config, {"llm_model": LLM_MODEL})
     #########################################################
     # In this case, we will test if the tool returns an error
     # when the model does not achieve a steady state. The tool
@@ -37,8 +40,8 @@ def test_steady_state_tool():
     #########################################################
     # In this case, we will test if the tool is indeed invoked
     # successfully
-    prompt = """
-
+    prompt = """Bring model 64 to a steady state. Set the
+    initial concentration of `Pyruvate` to 0.2. The
     concentration of `NAD` resets to 100 every 2 time units."""
     # Invoke the agent
     app.invoke(
```
aiagents4pharma/talk2biomodels/tools/ask_question.py

```diff
@@ -12,7 +12,6 @@ import pandas as pd
 from pydantic import BaseModel, Field
 from langchain_core.tools.base import BaseTool
 from langchain_experimental.agents import create_pandas_dataframe_agent
-from langchain_openai import ChatOpenAI
 from langgraph.prebuilt import InjectedState
 
 # Initialize logger
@@ -101,7 +100,7 @@ class AskQuestionTool(BaseTool):
         prompt_content += f"{basico.model_info.get_model_units()}\n\n"
         # Create a pandas dataframe agent
         df_agent = create_pandas_dataframe_agent(
-
+            state['llm_model'],
             allow_dangerous_code=True,
             agent_type='tool-calling',
             df=df,
```
aiagents4pharma/talk2biomodels/tools/custom_plotter.py

```diff
@@ -5,10 +5,9 @@ Tool for plotting a custom figure.
 """
 
 import logging
-from typing import Type,
+from typing import Type, Annotated, List, Tuple, Union, Literal
 from pydantic import BaseModel, Field
 import pandas as pd
-from langchain_openai import ChatOpenAI
 from langchain_core.tools import BaseTool
 from langgraph.prebuilt import InjectedState
 
@@ -71,30 +70,44 @@ class CustomPlotterTool(BaseTool):
         species_names = df.columns.tolist()
         # Exclude the time column
         species_names.remove('Time')
+        logging.log(logging.INFO, "Species names: %s", species_names)
         # In the following code, we extract the species
         # from the user question. We use Literal to restrict
         # the species names to the ones available in the
         # simulation results.
-        class CustomHeader(
+        class CustomHeader(BaseModel):
             """
             A list of species based on user question.
+
+            This is a Pydantic model that restricts the species
+            names to the ones available in the simulation results.
+
+            If no species is relevant, set the attribute
+            `relevant_species` to None.
             """
             relevant_species: Union[None, List[Literal[*species_names]]] = Field(
-                description="
-
+                description="This is a list of species based on the user question."
+                "It is restricted to the species available in the simulation results."
+                "If no species is relevant, set this attribute to None."
+                "If the user asks for very specific species (for example, using the"
+                "keyword `only` in the question), set this attribute to correspond "
+                "to the species available in the simulation results, otherwise set it to None."
+                )
         # Create an instance of the LLM model
-
+        logging.log(logging.INFO, "LLM model: %s", state['llm_model'])
+        llm = state['llm_model']
         llm_with_structured_output = llm.with_structured_output(CustomHeader)
         results = llm_with_structured_output.invoke(question)
+        if results.relevant_species is None:
+            raise ValueError("No species found in the simulation results \
+                             that matches the user prompt.")
         extracted_species = []
         # Extract the species from the results
         # that are available in the simulation results
-        for species in results
+        for species in results.relevant_species:
             if species in species_names:
                 extracted_species.append(species)
-
-        if len(extracted_species) == 0:
-            return "No species found in the simulation results that matches the user prompt.", None
+        logging.info("Extracted species: %s", extracted_species)
         # Include the time column
         extracted_species.insert(0, 'Time')
         return f"Custom plot {simulation_name}", df[extracted_species].to_dict(orient='records')
```
aiagents4pharma/talk2biomodels/tools/get_annotation.py

```diff
@@ -5,7 +5,7 @@ This module contains the `GetAnnotationTool` for fetching species annotations
 based on the provided model and species names.
 """
 import math
-from typing import List, Annotated, Type,
+from typing import List, Annotated, Type, Union, Literal
 import logging
 from dataclasses import dataclass
 import hydra
@@ -17,7 +17,7 @@ from langgraph.prebuilt import InjectedState
 from langchain_core.tools.base import BaseTool
 from langchain_core.tools.base import InjectedToolCallId
 from langchain_core.messages import ToolMessage
-from langchain_openai import ChatOpenAI
+# from langchain_openai import ChatOpenAI
 from .load_biomodel import ModelData, load_biomodel
 from ..api.uniprot import search_uniprot_labels
 from ..api.ols import search_ols_labels
@@ -49,7 +49,7 @@ def extract_relevant_species_names(model_object, arg_data, state):
     all_species_names = df_species.index.tolist()
 
     # Define a structured output for the LLM model
-    class CustomHeader(
+    class CustomHeader(BaseModel):
         """
         A list of species based on user question.
         """
@@ -58,17 +58,21 @@ def extract_relevant_species_names(model_object, arg_data, state):
         If no relevant species are found, it must be None.""")
 
     # Create an instance of the LLM model
-    llm =
+    llm = state['llm_model']
     # Get the structured output from the LLM model
     llm_with_structured_output = llm.with_structured_output(CustomHeader)
     # Define the question for the LLM model using the prompt
     question = cfg.prompt
     question += f'Here is the user question: {arg_data.user_question}'
     # Invoke the LLM model with the user question
-
+    results = llm_with_structured_output.invoke(question)
+    logging.info("Results from the LLM model: %s", results)
+    # Check if the returned species names are empty
+    if not results.relevant_species:
+        raise ValueError("Model does not contain the requested species.")
     extracted_species = []
     # Extract all the species names from the model
-    for species in
+    for species in results.relevant_species:
         if species in all_species_names:
             extracted_species.append(species)
     logger.info("Extracted species: %s", extracted_species)
@@ -136,10 +140,7 @@ class GetAnnotationTool(BaseTool):
 
         # Extract relevant species names based on the user question
         list_species_names = extract_relevant_species_names(model_object, arg_data, state)
-
-        # Check if the returned species names are empty
-        if not list_species_names:
-            raise ValueError("Model does not contain the requested species.")
+        print (list_species_names)
 
         (annotations_df,
          species_without_description) = self._fetch_annotations(list_species_names)
```
aiagents4pharma/talk2biomodels/tools/query_article.py

```diff
@@ -9,7 +9,6 @@ from typing import Type, Annotated
 from pydantic import BaseModel, Field
 from langchain_core.tools import BaseTool
 from langchain_core.vectorstores import InMemoryVectorStore
-from langchain_openai.embeddings import OpenAIEmbeddings
 from langchain_community.document_loaders import PyPDFLoader
 from langgraph.prebuilt import InjectedState
 
@@ -51,8 +50,13 @@ class QueryArticle(BaseTool):
         pages = []
         for page in loader.lazy_load():
             pages.append(page)
+        # Set up text embedding model
+        text_embedding_model = state['text_embedding_model']
+        logging.info("Loaded text embedding model %s", text_embedding_model)
         # Create a vector store from the pages
-        vector_store = InMemoryVectorStore.from_documents(
+        vector_store = InMemoryVectorStore.from_documents(
+            pages,
+            text_embedding_model)
         # Search the article with the question
         docs = vector_store.similarity_search(question)
         # Return the content of the pages
```
aiagents4pharma/talk2biomodels/tools/search_models.py

```diff
@@ -5,14 +5,18 @@ Tool for searching models based on search query.
 """
 
 from typing import Type, Annotated
+import logging
 from pydantic import BaseModel, Field
 from basico import biomodels
 from langchain_core.tools import BaseTool
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
-from langchain_openai import ChatOpenAI
 from langgraph.prebuilt import InjectedState
 
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 class SearchModelsInput(BaseModel):
     """
     Input schema for the search models tool.
@@ -43,8 +47,10 @@ class SearchModelsTool(BaseTool):
         Returns:
             dict: The answer to the question in the form of a dictionary.
         """
+        logger.log(logging.INFO, "Searching models with the query and model: %s, %s",
+                   query, state['llm_model'])
         search_results = biomodels.search_for_model(query)
-        llm =
+        llm = state['llm_model']
         # Check if run_manager's metadata has the key 'prompt_content'
         prompt_content = f'''
            Convert the input into a table.
```
aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py (new file)

```diff
@@ -0,0 +1,85 @@
+'''
+This is the agent file for the Talk2KnowledgeGraphs agent.
+'''
+
+import logging
+from typing import Annotated
+import hydra
+from langchain_ollama import ChatOllama
+from langchain_core.language_models.chat_models import BaseChatModel
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import START, StateGraph
+from langgraph.prebuilt import create_react_agent, ToolNode, InjectedState
+from ..tools.subgraph_extraction import SubgraphExtractionTool
+from ..tools.subgraph_summarization import SubgraphSummarizationTool
+from ..tools.graphrag_reasoning import GraphRAGReasoningTool
+from ..states.state_talk2knowledgegraphs import Talk2KnowledgeGraphs
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def get_app(uniq_id, llm_model: BaseChatModel=ChatOllama(model='llama3.2:1b', temperature=0.0)):
+    '''
+    This function returns the langraph app.
+    '''
+    def agent_t2kg_node(state: Annotated[dict, InjectedState]):
+        '''
+        This function calls the model.
+        '''
+        logger.log(logging.INFO, "Calling t2kg_agent node with thread_id %s", uniq_id)
+        response = model.invoke(state, {"configurable": {"thread_id": uniq_id}})
+
+        return response
+
+    # Load hydra configuration
+    logger.log(logging.INFO, "Load Hydra configuration for Talk2KnowledgeGraphs agent.")
+    with hydra.initialize(version_base=None, config_path="../configs"):
+        cfg = hydra.compose(config_name='config',
+                            overrides=['agents/t2kg_agent=default'])
+        cfg = cfg.agents.t2kg_agent
+
+    # Define the tools
+    subgraph_extraction = SubgraphExtractionTool()
+    subgraph_summarization = SubgraphSummarizationTool()
+    graphrag_reasoning = GraphRAGReasoningTool()
+    tools = ToolNode([
+                    subgraph_extraction,
+                    subgraph_summarization,
+                    graphrag_reasoning,
+                    ])
+
+    # Create the agent
+    model = create_react_agent(
+                llm_model,
+                tools=tools,
+                state_schema=Talk2KnowledgeGraphs,
+                state_modifier=cfg.state_modifier,
+                checkpointer=MemorySaver()
+            )
+
+    # Define a new graph
+    workflow = StateGraph(Talk2KnowledgeGraphs)
+
+    # Define the two nodes we will cycle between
+    workflow.add_node("agent_t2kg", agent_t2kg_node)
+
+    # Set the entrypoint as the first node
+    # This means that this node is the first one called
+    workflow.add_edge(START, "agent_t2kg")
+
+    # Initialize memory to persist state between graph runs
+    checkpointer = MemorySaver()
+
+    # Finally, we compile it!
+    # This compiles it into a LangChain Runnable,
+    # meaning you can use it as you would any other runnable.
+    # Note that we're (optionally) passing the memory
+    # when compiling the graph
+    app = workflow.compile(checkpointer=checkpointer)
+    logger.log(logging.INFO,
+               "Compiled the graph with thread_id %s and llm_model %s",
+               uniq_id,
+               llm_model)
+
+    return app
```