aiagents4pharma 1.35.0__py3-none-any.whl → 1.37.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py +12 -4
- aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +2 -2
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +7 -6
- aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +1 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/__init__.py +0 -0
- aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +12 -11
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_multimodal_subgraph_extraction.py +152 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +36 -65
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_sentencetransformer.py +9 -2
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ols.py +20 -13
- aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/tools/multimodal_subgraph_extraction.py +374 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/multimodal_pcst.py +292 -0
- {aiagents4pharma-1.35.0.dist-info → aiagents4pharma-1.37.0.dist-info}/METADATA +24 -8
- {aiagents4pharma-1.35.0.dist-info → aiagents4pharma-1.37.0.dist-info}/RECORD +20 -16
- {aiagents4pharma-1.35.0.dist-info → aiagents4pharma-1.37.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.35.0.dist-info → aiagents4pharma-1.37.0.dist-info}/licenses/LICENSE +0 -0
- {aiagents4pharma-1.35.0.dist-info → aiagents4pharma-1.37.0.dist-info}/top_level.txt +0 -0
@@ -20,12 +20,20 @@ def input_dict_fixture():
|
|
20
20
|
input_dict = {
|
21
21
|
"topk_nodes": 3,
|
22
22
|
"topk_edges": 3,
|
23
|
+
"selections": {
|
24
|
+
"gene/protein": [],
|
25
|
+
"molecular_function": [],
|
26
|
+
"cellular_component": [],
|
27
|
+
"biological_process": [],
|
28
|
+
"drug": [],
|
29
|
+
"disease": []
|
30
|
+
},
|
23
31
|
"uploaded_files": [],
|
24
32
|
"dic_source_graph": [
|
25
33
|
{
|
26
|
-
"name": "
|
27
|
-
"kg_pyg_path": f"{DATA_PATH}/
|
28
|
-
"kg_text_path": f"{DATA_PATH}/
|
34
|
+
"name": "BioBridge",
|
35
|
+
"kg_pyg_path": f"{DATA_PATH}/biobridge_multimodal_pyg_graph.pkl",
|
36
|
+
"kg_text_path": f"{DATA_PATH}/biobridge_multimodal_text_graph.pkl",
|
29
37
|
}
|
30
38
|
],
|
31
39
|
"dic_extracted_graph": []
|
@@ -70,7 +78,7 @@ def test_main_agent_invokes_t2kg(input_dict):
|
|
70
78
|
current_state = app.get_state(config)
|
71
79
|
dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
|
72
80
|
assert isinstance(dic_extracted_graph, dict)
|
73
|
-
assert dic_extracted_graph["graph_source"] == "
|
81
|
+
assert dic_extracted_graph["graph_source"] == "BioBridge"
|
74
82
|
assert dic_extracted_graph["topk_nodes"] == 3
|
75
83
|
assert dic_extracted_graph["topk_edges"] == 3
|
76
84
|
assert isinstance(dic_extracted_graph["graph_dict"], dict)
|
@@ -9,7 +9,7 @@ from langchain_core.language_models.chat_models import BaseChatModel
|
|
9
9
|
from langgraph.checkpoint.memory import MemorySaver
|
10
10
|
from langgraph.graph import START, StateGraph
|
11
11
|
from langgraph.prebuilt import create_react_agent, ToolNode, InjectedState
|
12
|
-
from ..tools.
|
12
|
+
from ..tools.multimodal_subgraph_extraction import MultimodalSubgraphExtractionTool
|
13
13
|
from ..tools.subgraph_summarization import SubgraphSummarizationTool
|
14
14
|
from ..tools.graphrag_reasoning import GraphRAGReasoningTool
|
15
15
|
from ..states.state_talk2knowledgegraphs import Talk2KnowledgeGraphs
|
@@ -39,7 +39,7 @@ def get_app(uniq_id, llm_model: BaseChatModel):
|
|
39
39
|
cfg = cfg.agents.t2kg_agent
|
40
40
|
|
41
41
|
# Define the tools
|
42
|
-
subgraph_extraction =
|
42
|
+
subgraph_extraction = MultimodalSubgraphExtractionTool()
|
43
43
|
subgraph_summarization = SubgraphSummarizationTool()
|
44
44
|
graphrag_reasoning = GraphRAGReasoningTool()
|
45
45
|
tools = ToolNode([
|
@@ -2,12 +2,13 @@ _target_: app.frontend.streamlit_app_talk2knowledgegraphs
|
|
2
2
|
default_user: "talk2kg_user"
|
3
3
|
data_package_allowed_file_types:
|
4
4
|
- "pdf"
|
5
|
-
|
6
|
-
- "
|
5
|
+
multimodal_allowed_file_types:
|
6
|
+
- "xls"
|
7
|
+
- "xlsx"
|
7
8
|
upload_data_dir: "../files"
|
8
9
|
kg_name: "PrimeKG"
|
9
|
-
kg_pyg_path: "aiagents4pharma/talk2knowledgegraphs/tests/files/
|
10
|
-
kg_text_path: "aiagents4pharma/talk2knowledgegraphs/tests/files/
|
10
|
+
kg_pyg_path: "aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal_pyg_graph.pkl"
|
11
|
+
kg_text_path: "aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal_text_graph.pkl"
|
11
12
|
openai_api_key: ${oc.env:OPENAI_API_KEY}
|
12
13
|
openai_llms:
|
13
14
|
- "gpt-4o-mini"
|
@@ -23,9 +24,9 @@ ollama_embeddings:
|
|
23
24
|
- "nomic-embed-text"
|
24
25
|
temperature: 0.1
|
25
26
|
streaming: False
|
26
|
-
reasoning_subgraph_topk_nodes:
|
27
|
+
reasoning_subgraph_topk_nodes: 5
|
27
28
|
reasoning_subgraph_topk_nodes_min: 1
|
28
29
|
reasoning_subgraph_topk_nodes_max: 100
|
29
|
-
reasoning_subgraph_topk_edges:
|
30
|
+
reasoning_subgraph_topk_edges: 5
|
30
31
|
reasoning_subgraph_topk_edges_min: 1
|
31
32
|
reasoning_subgraph_topk_edges_max: 100
|
File without changes
|
@@ -19,6 +19,14 @@ def input_dict_fixture():
|
|
19
19
|
input_dict = {
|
20
20
|
"llm_model": None, # TBA for each test case
|
21
21
|
"embedding_model": None, # TBA for each test case
|
22
|
+
"selections": {
|
23
|
+
"gene/protein": [],
|
24
|
+
"molecular_function": [],
|
25
|
+
"cellular_component": [],
|
26
|
+
"biological_process": [],
|
27
|
+
"drug": [],
|
28
|
+
"disease": []
|
29
|
+
},
|
22
30
|
"uploaded_files": [
|
23
31
|
{
|
24
32
|
"file_name": "adalimumab.pdf",
|
@@ -27,21 +35,14 @@ def input_dict_fixture():
|
|
27
35
|
"uploaded_by": "VPEUser",
|
28
36
|
"uploaded_timestamp": "2024-11-05 00:00:00",
|
29
37
|
},
|
30
|
-
{
|
31
|
-
"file_name": "DGE_human_Colon_UC-vs-Colon_Control.pdf",
|
32
|
-
"file_path": f"{DATA_PATH}/DGE_human_Colon_UC-vs-Colon_Control.pdf",
|
33
|
-
"file_type": "endotype",
|
34
|
-
"uploaded_by": "VPEUser",
|
35
|
-
"uploaded_timestamp": "2024-11-05 00:00:00",
|
36
|
-
},
|
37
38
|
],
|
38
39
|
"topk_nodes": 3,
|
39
40
|
"topk_edges": 3,
|
40
41
|
"dic_source_graph": [
|
41
42
|
{
|
42
|
-
"name": "
|
43
|
-
"kg_pyg_path": f"{DATA_PATH}/
|
44
|
-
"kg_text_path": f"{DATA_PATH}/
|
43
|
+
"name": "BioBridge",
|
44
|
+
"kg_pyg_path": f"{DATA_PATH}/biobridge_multimodal_pyg_graph.pkl",
|
45
|
+
"kg_text_path": f"{DATA_PATH}/biobridge_multimodal_text_graph.pkl",
|
45
46
|
}
|
46
47
|
],
|
47
48
|
"dic_extracted_graph": []
|
@@ -96,7 +97,7 @@ def test_t2kg_agent_openai(input_dict):
|
|
96
97
|
dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
|
97
98
|
assert isinstance(dic_extracted_graph, dict)
|
98
99
|
assert dic_extracted_graph["name"] == "subkg_12345"
|
99
|
-
assert dic_extracted_graph["graph_source"] == "
|
100
|
+
assert dic_extracted_graph["graph_source"] == "BioBridge"
|
100
101
|
assert dic_extracted_graph["topk_nodes"] == 3
|
101
102
|
assert dic_extracted_graph["topk_edges"] == 3
|
102
103
|
assert isinstance(dic_extracted_graph["graph_dict"], dict)
|
@@ -0,0 +1,152 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for tools/subgraph_extraction.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import pytest
|
6
|
+
# from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
7
|
+
from ..tools.multimodal_subgraph_extraction import MultimodalSubgraphExtractionTool
|
8
|
+
|
9
|
+
# Define the data path
|
10
|
+
DATA_PATH = "aiagents4pharma/talk2knowledgegraphs/tests/files"
|
11
|
+
|
12
|
+
|
13
|
+
@pytest.fixture(name="agent_state")
|
14
|
+
def agent_state_fixture():
|
15
|
+
"""
|
16
|
+
Agent state fixture.
|
17
|
+
"""
|
18
|
+
agent_state = {
|
19
|
+
# "llm_model": ChatOpenAI(model="gpt-4o-mini", temperature=0.0),
|
20
|
+
# "embedding_model": OpenAIEmbeddings(model="text-embedding-3-small"),
|
21
|
+
"selections": {
|
22
|
+
"gene/protein": [],
|
23
|
+
"molecular_function": [],
|
24
|
+
"cellular_component": [],
|
25
|
+
"biological_process": [],
|
26
|
+
"drug": [],
|
27
|
+
"disease": []
|
28
|
+
},
|
29
|
+
"uploaded_files": [],
|
30
|
+
"topk_nodes": 3,
|
31
|
+
"topk_edges": 3,
|
32
|
+
"dic_source_graph": [
|
33
|
+
{
|
34
|
+
"name": "BioBridge",
|
35
|
+
"kg_pyg_path": f"{DATA_PATH}/biobridge_multimodal_pyg_graph.pkl",
|
36
|
+
"kg_text_path": f"{DATA_PATH}/biobridge_multimodal_text_graph.pkl",
|
37
|
+
}
|
38
|
+
],
|
39
|
+
}
|
40
|
+
|
41
|
+
return agent_state
|
42
|
+
|
43
|
+
|
44
|
+
def test_extract_multimodal_subgraph_wo_doc(agent_state):
|
45
|
+
"""
|
46
|
+
Test the multimodal subgraph extraction tool for only text as modality.
|
47
|
+
|
48
|
+
Args:
|
49
|
+
agent_state: Agent state in the form of a dictionary.
|
50
|
+
"""
|
51
|
+
prompt = """
|
52
|
+
Extract all relevant information related to nodes of genes related to inflammatory bowel disease
|
53
|
+
(IBD) that existed in the knowledge graph.
|
54
|
+
Please set the extraction name for this process as `subkg_12345`.
|
55
|
+
"""
|
56
|
+
|
57
|
+
# Instantiate the MultimodalSubgraphExtractionTool
|
58
|
+
subgraph_extraction_tool = MultimodalSubgraphExtractionTool()
|
59
|
+
|
60
|
+
# Invoking the subgraph_extraction_tool
|
61
|
+
response = subgraph_extraction_tool.invoke(
|
62
|
+
input={"prompt": prompt,
|
63
|
+
"tool_call_id": "subgraph_extraction_tool",
|
64
|
+
"state": agent_state,
|
65
|
+
"arg_data": {"extraction_name": "subkg_12345"}})
|
66
|
+
|
67
|
+
# Check tool message
|
68
|
+
assert response.update["messages"][-1].tool_call_id == "subgraph_extraction_tool"
|
69
|
+
|
70
|
+
# Check extracted subgraph dictionary
|
71
|
+
dic_extracted_graph = response.update["dic_extracted_graph"][0]
|
72
|
+
assert isinstance(dic_extracted_graph, dict)
|
73
|
+
assert dic_extracted_graph["name"] == "subkg_12345"
|
74
|
+
assert dic_extracted_graph["graph_source"] == "BioBridge"
|
75
|
+
assert dic_extracted_graph["topk_nodes"] == 3
|
76
|
+
assert dic_extracted_graph["topk_edges"] == 3
|
77
|
+
assert isinstance(dic_extracted_graph["graph_dict"], dict)
|
78
|
+
assert len(dic_extracted_graph["graph_dict"]["nodes"]) > 0
|
79
|
+
assert len(dic_extracted_graph["graph_dict"]["edges"]) > 0
|
80
|
+
assert isinstance(dic_extracted_graph["graph_text"], str)
|
81
|
+
# Check if the nodes are in the graph_text
|
82
|
+
assert all(
|
83
|
+
n[0] in dic_extracted_graph["graph_text"].replace('"', '')
|
84
|
+
for n in dic_extracted_graph["graph_dict"]["nodes"]
|
85
|
+
)
|
86
|
+
# Check if the edges are in the graph_text
|
87
|
+
assert all(
|
88
|
+
",".join([e[0], str(tuple(e[2]["relation"])), e[1]])
|
89
|
+
in dic_extracted_graph["graph_text"].replace('"', '')
|
90
|
+
for e in dic_extracted_graph["graph_dict"]["edges"]
|
91
|
+
)
|
92
|
+
|
93
|
+
|
94
|
+
def test_extract_multimodal_subgraph_w_doc(agent_state):
|
95
|
+
"""
|
96
|
+
Test the multimodal subgraph extraction tool for text as modality, plus genes.
|
97
|
+
|
98
|
+
Args:
|
99
|
+
agent_state: Agent state in the form of a dictionary.
|
100
|
+
"""
|
101
|
+
# Update state
|
102
|
+
agent_state["uploaded_files"] = [
|
103
|
+
{
|
104
|
+
"file_name": "multimodal-analysis.xlsx",
|
105
|
+
"file_path": f"{DATA_PATH}/multimodal-analysis.xlsx",
|
106
|
+
"file_type": "multimodal",
|
107
|
+
"uploaded_by": "VPEUser",
|
108
|
+
"uploaded_timestamp": "2025-05-12 00:00:00",
|
109
|
+
}
|
110
|
+
]
|
111
|
+
|
112
|
+
prompt = """
|
113
|
+
Extract all relevant information related to nodes of genes related to inflammatory bowel disease
|
114
|
+
(IBD) that existed in the knowledge graph.
|
115
|
+
Please set the extraction name for this process as `subkg_12345`.
|
116
|
+
"""
|
117
|
+
|
118
|
+
# Instantiate the SubgraphExtractionTool
|
119
|
+
subgraph_extraction_tool = MultimodalSubgraphExtractionTool()
|
120
|
+
|
121
|
+
# Invoking the subgraph_extraction_tool
|
122
|
+
response = subgraph_extraction_tool.invoke(
|
123
|
+
input={"prompt": prompt,
|
124
|
+
"tool_call_id": "subgraph_extraction_tool",
|
125
|
+
"state": agent_state,
|
126
|
+
"arg_data": {"extraction_name": "subkg_12345"}})
|
127
|
+
|
128
|
+
# Check tool message
|
129
|
+
assert response.update["messages"][-1].tool_call_id == "subgraph_extraction_tool"
|
130
|
+
|
131
|
+
# Check extracted subgraph dictionary
|
132
|
+
dic_extracted_graph = response.update["dic_extracted_graph"][0]
|
133
|
+
assert isinstance(dic_extracted_graph, dict)
|
134
|
+
assert dic_extracted_graph["name"] == "subkg_12345"
|
135
|
+
assert dic_extracted_graph["graph_source"] == "BioBridge"
|
136
|
+
assert dic_extracted_graph["topk_nodes"] == 3
|
137
|
+
assert dic_extracted_graph["topk_edges"] == 3
|
138
|
+
assert isinstance(dic_extracted_graph["graph_dict"], dict)
|
139
|
+
assert len(dic_extracted_graph["graph_dict"]["nodes"]) > 0
|
140
|
+
assert len(dic_extracted_graph["graph_dict"]["edges"]) > 0
|
141
|
+
assert isinstance(dic_extracted_graph["graph_text"], str)
|
142
|
+
# Check if the nodes are in the graph_text
|
143
|
+
assert all(
|
144
|
+
n[0] in dic_extracted_graph["graph_text"].replace('"', '')
|
145
|
+
for n in dic_extracted_graph["graph_dict"]["nodes"]
|
146
|
+
)
|
147
|
+
# Check if the edges are in the graph_text
|
148
|
+
assert all(
|
149
|
+
",".join([e[0], str(tuple(e[2]["relation"])), e[1]])
|
150
|
+
in dic_extracted_graph["graph_text"].replace('"', '')
|
151
|
+
for e in dic_extracted_graph["graph_dict"]["edges"]
|
152
|
+
)
|
@@ -3,22 +3,21 @@ Test cases for tools/subgraph_extraction.py
|
|
3
3
|
"""
|
4
4
|
|
5
5
|
import pytest
|
6
|
-
from langchain_core.messages import HumanMessage
|
7
6
|
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
8
|
-
from ..
|
7
|
+
from ..tools.subgraph_extraction import SubgraphExtractionTool
|
9
8
|
|
10
9
|
# Define the data path
|
11
10
|
DATA_PATH = "aiagents4pharma/talk2knowledgegraphs/tests/files"
|
12
11
|
|
13
12
|
|
14
|
-
@pytest.fixture(name="
|
15
|
-
def
|
13
|
+
@pytest.fixture(name="agent_state")
|
14
|
+
def agent_state_fixture():
|
16
15
|
"""
|
17
|
-
|
16
|
+
Agent state fixture.
|
18
17
|
"""
|
19
|
-
|
20
|
-
"llm_model":
|
21
|
-
"embedding_model":
|
18
|
+
agent_state = {
|
19
|
+
"llm_model": ChatOpenAI(model="gpt-4o-mini", temperature=0.0),
|
20
|
+
"embedding_model": OpenAIEmbeddings(model="text-embedding-3-small"),
|
22
21
|
"uploaded_files": [],
|
23
22
|
"topk_nodes": 3,
|
24
23
|
"topk_edges": 3,
|
@@ -31,52 +30,37 @@ def input_dict_fixture():
|
|
31
30
|
],
|
32
31
|
}
|
33
32
|
|
34
|
-
return
|
33
|
+
return agent_state
|
35
34
|
|
36
35
|
|
37
|
-
def test_extract_subgraph_wo_docs(
|
36
|
+
def test_extract_subgraph_wo_docs(agent_state):
|
38
37
|
"""
|
39
38
|
Test the subgraph extraction tool without any documents using OpenAI model.
|
40
39
|
|
41
40
|
Args:
|
42
|
-
|
41
|
+
agent_state: Agent state in the form of a dictionary.
|
43
42
|
"""
|
44
|
-
# Prepare LLM and embedding model
|
45
|
-
input_dict["llm_model"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
|
46
|
-
input_dict["embedding_model"] = OpenAIEmbeddings(model="text-embedding-3-small")
|
47
|
-
|
48
|
-
# Setup the app
|
49
|
-
unique_id = 12345
|
50
|
-
app = get_app(unique_id, llm_model=input_dict["llm_model"])
|
51
|
-
config = {"configurable": {"thread_id": unique_id}}
|
52
|
-
# Update state
|
53
|
-
app.update_state(
|
54
|
-
config,
|
55
|
-
input_dict,
|
56
|
-
)
|
57
43
|
prompt = """
|
58
|
-
Please directly invoke `subgraph_extraction` tool without calling any other tools
|
59
|
-
to respond to the following prompt:
|
60
|
-
|
61
44
|
Extract all relevant information related to nodes of genes related to inflammatory bowel disease
|
62
45
|
(IBD) that existed in the knowledge graph.
|
63
46
|
Please set the extraction name for this process as `subkg_12345`.
|
64
47
|
"""
|
65
48
|
|
66
|
-
#
|
67
|
-
|
49
|
+
# Instantiate the SubgraphExtractionTool
|
50
|
+
subgraph_extraction_tool = SubgraphExtractionTool()
|
68
51
|
|
69
|
-
#
|
70
|
-
|
71
|
-
|
52
|
+
# Invoking the subgraph_extraction_tool
|
53
|
+
response = subgraph_extraction_tool.invoke(
|
54
|
+
input={"prompt": prompt,
|
55
|
+
"tool_call_id": "subgraph_extraction_tool",
|
56
|
+
"state": agent_state,
|
57
|
+
"arg_data": {"extraction_name": "subkg_12345"}})
|
72
58
|
|
73
59
|
# Check tool message
|
74
|
-
|
75
|
-
assert tool_msg.name == "subgraph_extraction"
|
60
|
+
assert response.update["messages"][-1].tool_call_id == "subgraph_extraction_tool"
|
76
61
|
|
77
62
|
# Check extracted subgraph dictionary
|
78
|
-
|
79
|
-
dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
|
63
|
+
dic_extracted_graph = response.update["dic_extracted_graph"][0]
|
80
64
|
assert isinstance(dic_extracted_graph, dict)
|
81
65
|
assert dic_extracted_graph["name"] == "subkg_12345"
|
82
66
|
assert dic_extracted_graph["graph_source"] == "PrimeKG"
|
@@ -99,24 +83,16 @@ def test_extract_subgraph_wo_docs(input_dict):
|
|
99
83
|
)
|
100
84
|
|
101
85
|
|
102
|
-
def test_extract_subgraph_w_docs(
|
86
|
+
def test_extract_subgraph_w_docs(agent_state):
|
103
87
|
"""
|
104
|
-
|
105
|
-
|
88
|
+
As a knowledge graph agent, I would like you to call a tool called `subgraph_extraction`.
|
89
|
+
After calling the tool, restrain yourself to call any other tool.
|
106
90
|
|
107
91
|
Args:
|
108
|
-
|
92
|
+
agent_state: Agent state in the form of a dictionary.
|
109
93
|
"""
|
110
|
-
# Prepare LLM and embedding model
|
111
|
-
input_dict["llm_model"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
|
112
|
-
input_dict["embedding_model"] = OpenAIEmbeddings(model="text-embedding-3-small")
|
113
|
-
|
114
|
-
# Setup the app
|
115
|
-
unique_id = 12345
|
116
|
-
app = get_app(unique_id, llm_model=input_dict["llm_model"])
|
117
|
-
config = {"configurable": {"thread_id": unique_id}}
|
118
94
|
# Update state
|
119
|
-
|
95
|
+
agent_state["uploaded_files"] = [
|
120
96
|
{
|
121
97
|
"file_name": "DGE_human_Colon_UC-vs-Colon_Control.pdf",
|
122
98
|
"file_path": f"{DATA_PATH}/DGE_human_Colon_UC-vs-Colon_Control.pdf",
|
@@ -125,33 +101,28 @@ def test_extract_subgraph_w_docs(input_dict):
|
|
125
101
|
"uploaded_timestamp": "2024-11-05 00:00:00",
|
126
102
|
}
|
127
103
|
]
|
128
|
-
app.update_state(
|
129
|
-
config,
|
130
|
-
input_dict,
|
131
|
-
)
|
132
|
-
prompt = """
|
133
|
-
Please ONLY invoke `subgraph_extraction` tool without calling any other tools
|
134
|
-
to respond to the following prompt:
|
135
104
|
|
105
|
+
prompt = """
|
136
106
|
Extract all relevant information related to nodes of genes related to inflammatory bowel disease
|
137
107
|
(IBD) that existed in the knowledge graph.
|
138
108
|
Please set the extraction name for this process as `subkg_12345`.
|
139
109
|
"""
|
140
110
|
|
141
|
-
#
|
142
|
-
|
111
|
+
# Instantiate the SubgraphExtractionTool
|
112
|
+
subgraph_extraction_tool = SubgraphExtractionTool()
|
143
113
|
|
144
|
-
#
|
145
|
-
|
146
|
-
|
114
|
+
# Invoking the subgraph_extraction_tool
|
115
|
+
response = subgraph_extraction_tool.invoke(
|
116
|
+
input={"prompt": prompt,
|
117
|
+
"tool_call_id": "subgraph_extraction_tool",
|
118
|
+
"state": agent_state,
|
119
|
+
"arg_data": {"extraction_name": "subkg_12345"}})
|
147
120
|
|
148
121
|
# Check tool message
|
149
|
-
|
150
|
-
assert tool_msg.name == "subgraph_extraction"
|
122
|
+
assert response.update["messages"][-1].tool_call_id == "subgraph_extraction_tool"
|
151
123
|
|
152
124
|
# Check extracted subgraph dictionary
|
153
|
-
|
154
|
-
dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
|
125
|
+
dic_extracted_graph = response.update["dic_extracted_graph"][0]
|
155
126
|
assert isinstance(dic_extracted_graph, dict)
|
156
127
|
assert dic_extracted_graph["name"] == "subkg_12345"
|
157
128
|
assert dic_extracted_graph["graph_source"] == "PrimeKG"
|
@@ -2,17 +2,23 @@
|
|
2
2
|
Test cases for utils/embeddings/sentence_transformer.py
|
3
3
|
"""
|
4
4
|
|
5
|
-
import pytest
|
6
5
|
import numpy as np
|
6
|
+
import pytest
|
7
|
+
|
7
8
|
from ..utils.embeddings.sentence_transformer import EmbeddingWithSentenceTransformer
|
8
9
|
|
10
|
+
|
9
11
|
@pytest.fixture(name="embedding_model")
|
10
12
|
def embedding_model_fixture():
|
11
13
|
"""
|
12
14
|
Fixture for creating an instance of EmbeddingWithSentenceTransformer.
|
13
15
|
"""
|
14
16
|
model_name = "sentence-transformers/all-MiniLM-L6-v1" # Small model for testing
|
15
|
-
|
17
|
+
embedding_model = EmbeddingWithSentenceTransformer(model_name=model_name)
|
18
|
+
# Move underlying model to CPU for testing
|
19
|
+
embedding_model.model.to("cpu")
|
20
|
+
return embedding_model
|
21
|
+
|
16
22
|
|
17
23
|
def test_embed_documents(embedding_model):
|
18
24
|
"""
|
@@ -27,6 +33,7 @@ def test_embed_documents(embedding_model):
|
|
27
33
|
assert len(embeddings[0]) == 384
|
28
34
|
assert embeddings.dtype == np.float32
|
29
35
|
|
36
|
+
|
30
37
|
def test_embed_query(embedding_model):
|
31
38
|
"""
|
32
39
|
Test the embed_query method of EmbeddingWithSentenceTransformer class.
|
@@ -19,24 +19,28 @@ from ..utils.enrichments.ols_terms import EnrichmentWithOLS
|
|
19
19
|
CL_DESC = "CD4-positive, alpha-beta T cell"
|
20
20
|
GO_DESC = "Any process that activates or increases the frequency, rate or extent"
|
21
21
|
UBERON_DESC = "The olfactory organ of vertebrates, consisting of nares"
|
22
|
-
HP_DESC = "
|
22
|
+
HP_DESC = "Developmental hypoplasia of the antihelix"
|
23
23
|
MONDO_DESC = "A gastrointestinal disorder characterized by chronic inflammation"
|
24
24
|
|
25
25
|
# The expected description for the non-existing term is None
|
26
26
|
|
27
|
+
|
27
28
|
@pytest.fixture(name="enrich_obj")
|
28
29
|
def fixture_uniprot_config():
|
29
30
|
"""Return a dictionary with the configuration for OLS enrichment."""
|
30
31
|
return EnrichmentWithOLS()
|
31
32
|
|
33
|
+
|
32
34
|
def test_enrich_documents(enrich_obj):
|
33
35
|
"""Test the enrich_documents method."""
|
34
|
-
ols_terms = [
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
36
|
+
ols_terms = [
|
37
|
+
"CL_0000899",
|
38
|
+
"GO_0046427",
|
39
|
+
"UBERON_0000004",
|
40
|
+
"HP_0009739",
|
41
|
+
"MONDO_0005011",
|
42
|
+
"XYZ_0000000",
|
43
|
+
]
|
40
44
|
descriptions = enrich_obj.enrich_documents(ols_terms)
|
41
45
|
assert descriptions[0].startswith(CL_DESC)
|
42
46
|
assert descriptions[1].startswith(GO_DESC)
|
@@ -45,14 +49,17 @@ def test_enrich_documents(enrich_obj):
|
|
45
49
|
assert descriptions[4].startswith(MONDO_DESC)
|
46
50
|
assert descriptions[5] is None
|
47
51
|
|
52
|
+
|
48
53
|
def test_enrich_documents_with_rag(enrich_obj):
|
49
54
|
"""Test the enrich_documents_with_rag method."""
|
50
|
-
ols_terms = [
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
55
|
+
ols_terms = [
|
56
|
+
"CL_0000899",
|
57
|
+
"GO_0046427",
|
58
|
+
"UBERON_0000004",
|
59
|
+
"HP_0009739",
|
60
|
+
"MONDO_0005011",
|
61
|
+
"XYZ_0000000",
|
62
|
+
]
|
56
63
|
descriptions = enrich_obj.enrich_documents_with_rag(ols_terms, None)
|
57
64
|
assert descriptions[0].startswith(CL_DESC)
|
58
65
|
assert descriptions[1].startswith(GO_DESC)
|