aiagents4pharma 1.17.1__py3-none-any.whl → 1.19.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- aiagents4pharma/talk2biomodels/agents/t2b_agent.py +4 -4
- aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml +7 -15
- aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +4 -1
- aiagents4pharma/talk2biomodels/tests/test_ask_question.py +4 -2
- aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +4 -2
- aiagents4pharma/talk2biomodels/tests/test_integration.py +34 -30
- aiagents4pharma/talk2biomodels/tests/test_query_article.py +7 -1
- aiagents4pharma/talk2biomodels/tests/test_search_models.py +3 -1
- aiagents4pharma/talk2biomodels/tests/test_steady_state.py +6 -3
- aiagents4pharma/talk2biomodels/tools/ask_question.py +1 -2
- aiagents4pharma/talk2biomodels/tools/custom_plotter.py +23 -10
- aiagents4pharma/talk2biomodels/tools/get_annotation.py +11 -10
- aiagents4pharma/talk2biomodels/tools/query_article.py +6 -2
- aiagents4pharma/talk2biomodels/tools/search_models.py +8 -2
- aiagents4pharma/talk2knowledgegraphs/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +85 -0
- aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +7 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +31 -0
- aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +7 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +6 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
- aiagents4pharma/talk2knowledgegraphs/states/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +38 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +110 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +210 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +174 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +154 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +0 -1
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +56 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +18 -42
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +79 -0
- aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +6 -0
- aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +143 -0
- aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +305 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +126 -0
- aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +4 -2
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +81 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +225 -0
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/METADATA +12 -3
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/RECORD +56 -24
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,174 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for tools/subgraph_extraction.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import pytest
|
6
|
+
from langchain_core.messages import HumanMessage
|
7
|
+
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
8
|
+
from ..agents.t2kg_agent import get_app
|
9
|
+
|
10
|
+
# Define the data path
|
11
|
+
DATA_PATH = "aiagents4pharma/talk2knowledgegraphs/tests/files"
|
12
|
+
|
13
|
+
|
14
|
+
@pytest.fixture(name="input_dict")
|
15
|
+
def input_dict_fixture():
|
16
|
+
"""
|
17
|
+
Input dictionary fixture.
|
18
|
+
"""
|
19
|
+
input_dict = {
|
20
|
+
"llm_model": None, # TBA for each test case
|
21
|
+
"embedding_model": None, # TBA for each test case
|
22
|
+
"uploaded_files": [],
|
23
|
+
"topk_nodes": 3,
|
24
|
+
"topk_edges": 3,
|
25
|
+
"dic_source_graph": [
|
26
|
+
{
|
27
|
+
"name": "PrimeKG",
|
28
|
+
"kg_pyg_path": f"{DATA_PATH}/primekg_ibd_pyg_graph.pkl",
|
29
|
+
"kg_text_path": f"{DATA_PATH}/primekg_ibd_text_graph.pkl",
|
30
|
+
}
|
31
|
+
],
|
32
|
+
}
|
33
|
+
|
34
|
+
return input_dict
|
35
|
+
|
36
|
+
|
37
|
+
def test_extract_subgraph_wo_docs(input_dict):
|
38
|
+
"""
|
39
|
+
Test the subgraph extraction tool without any documents using OpenAI model.
|
40
|
+
|
41
|
+
Args:
|
42
|
+
input_dict: Input dictionary.
|
43
|
+
"""
|
44
|
+
# Prepare LLM and embedding model
|
45
|
+
input_dict["llm_model"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
|
46
|
+
input_dict["embedding_model"] = OpenAIEmbeddings(model="text-embedding-3-small")
|
47
|
+
|
48
|
+
# Setup the app
|
49
|
+
unique_id = 12345
|
50
|
+
app = get_app(unique_id, llm_model=input_dict["llm_model"])
|
51
|
+
config = {"configurable": {"thread_id": unique_id}}
|
52
|
+
# Update state
|
53
|
+
app.update_state(
|
54
|
+
config,
|
55
|
+
input_dict,
|
56
|
+
)
|
57
|
+
prompt = """
|
58
|
+
Please directly invoke `subgraph_extraction` tool without calling any other tools
|
59
|
+
to respond to the following prompt:
|
60
|
+
|
61
|
+
Extract all relevant information related to nodes of genes related to inflammatory bowel disease
|
62
|
+
(IBD) that existed in the knowledge graph.
|
63
|
+
Please set the extraction name for this process as `subkg_12345`.
|
64
|
+
"""
|
65
|
+
|
66
|
+
# Test the tool subgraph_extraction
|
67
|
+
response = app.invoke({"messages": [HumanMessage(content=prompt)]}, config=config)
|
68
|
+
|
69
|
+
# Check assistant message
|
70
|
+
assistant_msg = response["messages"][-1].content
|
71
|
+
assert isinstance(assistant_msg, str)
|
72
|
+
|
73
|
+
# Check tool message
|
74
|
+
tool_msg = response["messages"][-2]
|
75
|
+
assert tool_msg.name == "subgraph_extraction"
|
76
|
+
|
77
|
+
# Check extracted subgraph dictionary
|
78
|
+
current_state = app.get_state(config)
|
79
|
+
dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
|
80
|
+
assert isinstance(dic_extracted_graph, dict)
|
81
|
+
assert dic_extracted_graph["name"] == "subkg_12345"
|
82
|
+
assert dic_extracted_graph["graph_source"] == "PrimeKG"
|
83
|
+
assert dic_extracted_graph["topk_nodes"] == 3
|
84
|
+
assert dic_extracted_graph["topk_edges"] == 3
|
85
|
+
assert isinstance(dic_extracted_graph["graph_dict"], dict)
|
86
|
+
assert len(dic_extracted_graph["graph_dict"]["nodes"]) > 0
|
87
|
+
assert len(dic_extracted_graph["graph_dict"]["edges"]) > 0
|
88
|
+
assert isinstance(dic_extracted_graph["graph_text"], str)
|
89
|
+
# Check if the nodes are in the graph_text
|
90
|
+
assert all(
|
91
|
+
n[0] in dic_extracted_graph["graph_text"]
|
92
|
+
for n in dic_extracted_graph["graph_dict"]["nodes"]
|
93
|
+
)
|
94
|
+
# Check if the edges are in the graph_text
|
95
|
+
assert all(
|
96
|
+
",".join([e[0], '"' + str(tuple(e[2]["relation"])) + '"', e[1]])
|
97
|
+
in dic_extracted_graph["graph_text"]
|
98
|
+
for e in dic_extracted_graph["graph_dict"]["edges"]
|
99
|
+
)
|
100
|
+
|
101
|
+
|
102
|
+
def test_extract_subgraph_w_docs(input_dict):
|
103
|
+
"""
|
104
|
+
Test the subgraph extraction tool with a document as reference (i.e., endotype document)
|
105
|
+
using OpenAI model.
|
106
|
+
|
107
|
+
Args:
|
108
|
+
input_dict: Input dictionary.
|
109
|
+
"""
|
110
|
+
# Prepare LLM and embedding model
|
111
|
+
input_dict["llm_model"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
|
112
|
+
input_dict["embedding_model"] = OpenAIEmbeddings(model="text-embedding-3-small")
|
113
|
+
|
114
|
+
# Setup the app
|
115
|
+
unique_id = 12345
|
116
|
+
app = get_app(unique_id, llm_model=input_dict["llm_model"])
|
117
|
+
config = {"configurable": {"thread_id": unique_id}}
|
118
|
+
# Update state
|
119
|
+
input_dict["uploaded_files"] = [
|
120
|
+
{
|
121
|
+
"file_name": "DGE_human_Colon_UC-vs-Colon_Control.pdf",
|
122
|
+
"file_path": f"{DATA_PATH}/DGE_human_Colon_UC-vs-Colon_Control.pdf",
|
123
|
+
"file_type": "endotype",
|
124
|
+
"uploaded_by": "VPEUser",
|
125
|
+
"uploaded_timestamp": "2024-11-05 00:00:00",
|
126
|
+
}
|
127
|
+
]
|
128
|
+
app.update_state(
|
129
|
+
config,
|
130
|
+
input_dict,
|
131
|
+
)
|
132
|
+
prompt = """
|
133
|
+
Please ONLY invoke `subgraph_extraction` tool without calling any other tools
|
134
|
+
to respond to the following prompt:
|
135
|
+
|
136
|
+
Extract all relevant information related to nodes of genes related to inflammatory bowel disease
|
137
|
+
(IBD) that existed in the knowledge graph.
|
138
|
+
Please set the extraction name for this process as `subkg_12345`.
|
139
|
+
"""
|
140
|
+
|
141
|
+
# Test the tool subgraph_extraction
|
142
|
+
response = app.invoke({"messages": [HumanMessage(content=prompt)]}, config=config)
|
143
|
+
|
144
|
+
# Check assistant message
|
145
|
+
assistant_msg = response["messages"][-1].content
|
146
|
+
assert isinstance(assistant_msg, str)
|
147
|
+
|
148
|
+
# Check tool message
|
149
|
+
tool_msg = response["messages"][-2]
|
150
|
+
assert tool_msg.name == "subgraph_extraction"
|
151
|
+
|
152
|
+
# Check extracted subgraph dictionary
|
153
|
+
current_state = app.get_state(config)
|
154
|
+
dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
|
155
|
+
assert isinstance(dic_extracted_graph, dict)
|
156
|
+
assert dic_extracted_graph["name"] == "subkg_12345"
|
157
|
+
assert dic_extracted_graph["graph_source"] == "PrimeKG"
|
158
|
+
assert dic_extracted_graph["topk_nodes"] == 3
|
159
|
+
assert dic_extracted_graph["topk_edges"] == 3
|
160
|
+
assert isinstance(dic_extracted_graph["graph_dict"], dict)
|
161
|
+
assert len(dic_extracted_graph["graph_dict"]["nodes"]) > 0
|
162
|
+
assert len(dic_extracted_graph["graph_dict"]["edges"]) > 0
|
163
|
+
assert isinstance(dic_extracted_graph["graph_text"], str)
|
164
|
+
# Check if the nodes are in the graph_text
|
165
|
+
assert all(
|
166
|
+
n[0] in dic_extracted_graph["graph_text"]
|
167
|
+
for n in dic_extracted_graph["graph_dict"]["nodes"]
|
168
|
+
)
|
169
|
+
# Check if the edges are in the graph_text
|
170
|
+
assert all(
|
171
|
+
",".join([e[0], '"' + str(tuple(e[2]["relation"])) + '"', e[1]])
|
172
|
+
in dic_extracted_graph["graph_text"]
|
173
|
+
for e in dic_extracted_graph["graph_dict"]["edges"]
|
174
|
+
)
|
@@ -0,0 +1,154 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for tools/subgraph_summarization.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import pytest
|
6
|
+
from langchain_core.messages import HumanMessage
|
7
|
+
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
8
|
+
from ..agents.t2kg_agent import get_app
|
9
|
+
|
10
|
+
# Define the data path
|
11
|
+
DATA_PATH = "aiagents4pharma/talk2knowledgegraphs/tests/files"
|
12
|
+
|
13
|
+
|
14
|
+
@pytest.fixture(name="input_dict")
|
15
|
+
def input_dict_fixture():
|
16
|
+
"""
|
17
|
+
Input dictionary fixture.
|
18
|
+
"""
|
19
|
+
input_dict = {
|
20
|
+
"llm_model": None, # TBA for each test case
|
21
|
+
"embedding_model": None, # TBA for each test case
|
22
|
+
"uploaded_files": [],
|
23
|
+
"topk_nodes": 3,
|
24
|
+
"topk_edges": 3,
|
25
|
+
"dic_source_graph": [
|
26
|
+
{
|
27
|
+
"name": "PrimeKG",
|
28
|
+
"kg_pyg_path": f"{DATA_PATH}/primekg_ibd_pyg_graph.pkl",
|
29
|
+
"kg_text_path": f"{DATA_PATH}/primekg_ibd_text_graph.pkl",
|
30
|
+
}
|
31
|
+
],
|
32
|
+
"dic_extracted_graph": [
|
33
|
+
{
|
34
|
+
"name": "subkg_12345",
|
35
|
+
"tool_call_id": "tool_12345",
|
36
|
+
"graph_source": "PrimeKG",
|
37
|
+
"topk_nodes": 3,
|
38
|
+
"topk_edges": 3,
|
39
|
+
"graph_dict": {
|
40
|
+
'nodes': [('IFNG_(3495)', {}),
|
41
|
+
('IKBKG_(3672)', {}),
|
42
|
+
('ATG16L1_(6661)', {}),
|
43
|
+
('inflammatory bowel disease_(28158)', {}),
|
44
|
+
('Crohn ileitis and jejunitis_(35814)', {}),
|
45
|
+
("Crohn's colitis_(83770)", {})],
|
46
|
+
'edges': [('IFNG_(3495)', 'inflammatory bowel disease_(28158)',
|
47
|
+
{'relation': ['gene/protein', 'associated with', 'disease'],
|
48
|
+
'label': ['gene/protein', 'associated with', 'disease']}),
|
49
|
+
('IFNG_(3495)', "Crohn's colitis_(83770)",
|
50
|
+
{'relation': ['gene/protein', 'associated with', 'disease'],
|
51
|
+
'label': ['gene/protein', 'associated with', 'disease']}),
|
52
|
+
('IFNG_(3495)', 'Crohn ileitis and jejunitis_(35814)',
|
53
|
+
{'relation': ['gene/protein', 'associated with', 'disease'],
|
54
|
+
'label': ['gene/protein', 'associated with', 'disease']}),
|
55
|
+
('ATG16L1_(6661)', 'IKBKG_(3672)',
|
56
|
+
{'relation': ['gene/protein', 'ppi', 'gene/protein'],
|
57
|
+
'label': ['gene/protein', 'ppi', 'gene/protein']}),
|
58
|
+
("Crohn's colitis_(83770)", 'ATG16L1_(6661)',
|
59
|
+
{'relation': ['disease', 'associated with', 'gene/protein'],
|
60
|
+
'label': ['disease', 'associated with', 'gene/protein']})]},
|
61
|
+
"graph_text": """
|
62
|
+
node_id,node_attr
|
63
|
+
IFNG_(3495),"IFNG belongs to gene/protein category.
|
64
|
+
This gene encodes a soluble cytokine that is a member of the type II interferon class.
|
65
|
+
The encoded protein is secreted by cells of both the innate and adaptive immune systems.
|
66
|
+
The active protein is a homodimer that binds to the interferon gamma receptor
|
67
|
+
which triggers a cellular response to viral and microbial infections.
|
68
|
+
Mutations in this gene are associated with an increased susceptibility to viral,
|
69
|
+
bacterial and parasitic infections and to several autoimmune diseases.
|
70
|
+
[provided by RefSeq, Dec 2015]."
|
71
|
+
IKBKG_(3672),"IKBKG belongs to gene/protein category. This gene encodes the regulatory
|
72
|
+
subunit of the inhibitor of kappaB kinase (IKK) complex, which activates NF-kappaB
|
73
|
+
resulting in activation of genes involved in inflammation, immunity, cell survival,
|
74
|
+
and other pathways. Mutations in this gene result in incontinentia pigmenti,
|
75
|
+
hypohidrotic ectodermal dysplasia, and several other types of immunodeficiencies.
|
76
|
+
A pseudogene highly similar to this locus is located in an adjacent region of the
|
77
|
+
X chromosome. [provided by RefSeq, Mar 2016]."
|
78
|
+
ATG16L1_(6661),"ATG16L1 belongs to gene/protein category. The protein encoded
|
79
|
+
by this gene is part of a large protein complex that is necessary for autophagy,
|
80
|
+
the major process by which intracellular components are targeted to lysosomes
|
81
|
+
for degradation. Defects in this gene are a cause of susceptibility to inflammatory
|
82
|
+
bowel disease type 10 (IBD10). Several transcript variants encoding different
|
83
|
+
isoforms have been found for this gene.[provided by RefSeq, Jun 2010]."
|
84
|
+
inflammatory bowel disease_(28158),inflammatory bowel disease belongs to disease
|
85
|
+
category. Any inflammatory bowel disease in which the cause of the disease
|
86
|
+
is a mutation in the NOD2 gene.
|
87
|
+
Crohn ileitis and jejunitis_(35814),Crohn ileitis and jejunitis belongs to
|
88
|
+
disease category. An Crohn disease involving a pathogenic inflammatory
|
89
|
+
response in the ileum.
|
90
|
+
Crohn's colitis_(83770),Crohn's colitis belongs to disease category.
|
91
|
+
Crohn's disease affecting the colon.
|
92
|
+
|
93
|
+
head_id,edge_type,tail_id
|
94
|
+
Crohn's colitis_(83770),"('disease', 'associated with', 'gene/protein')",
|
95
|
+
ATG16L1_(6661)
|
96
|
+
ATG16L1_(6661),"('gene/protein', 'ppi', 'gene/protein')",IKBKG_(3672)
|
97
|
+
IFNG_(3495),"('gene/protein', 'associated with', 'disease')",
|
98
|
+
inflammatory bowel disease_(28158)
|
99
|
+
IFNG_(3495),"('gene/protein', 'associated with', 'disease')",Crohn's colitis_(83770)
|
100
|
+
IFNG_(3495),"('gene/protein', 'associated with', 'disease')",
|
101
|
+
Crohn ileitis and jejunitis_(35814)
|
102
|
+
""",
|
103
|
+
"graph_summary": None,
|
104
|
+
}
|
105
|
+
],
|
106
|
+
}
|
107
|
+
|
108
|
+
return input_dict
|
109
|
+
|
110
|
+
|
111
|
+
def test_summarize_subgraph(input_dict):
|
112
|
+
"""
|
113
|
+
Test the subgraph summarization tool without any documents using Ollama model.
|
114
|
+
|
115
|
+
Args:
|
116
|
+
input_dict: Input dictionary fixture.
|
117
|
+
"""
|
118
|
+
# Prepare LLM and embedding model
|
119
|
+
input_dict["llm_model"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
|
120
|
+
input_dict["embedding_model"] = OpenAIEmbeddings(model="text-embedding-3-small")
|
121
|
+
|
122
|
+
# Setup the app
|
123
|
+
unique_id = 12345
|
124
|
+
app = get_app(unique_id, llm_model=input_dict["llm_model"])
|
125
|
+
config = {"configurable": {"thread_id": unique_id}}
|
126
|
+
# Update state
|
127
|
+
app.update_state(
|
128
|
+
config,
|
129
|
+
input_dict,
|
130
|
+
)
|
131
|
+
prompt = """
|
132
|
+
Please directly invoke `subgraph_summarization` tool without calling any other tools
|
133
|
+
to respond to the following prompt:
|
134
|
+
|
135
|
+
You are given a subgraph in the forms of textualized subgraph representing
|
136
|
+
nodes and edges (triples) obtained from extraction_name `subkg_12345`.
|
137
|
+
Summarize the given subgraph and higlight the importance nodes and edges.
|
138
|
+
"""
|
139
|
+
|
140
|
+
# Test the tool subgraph_summarization
|
141
|
+
response = app.invoke({"messages": [HumanMessage(content=prompt)]}, config=config)
|
142
|
+
|
143
|
+
# Check assistant message
|
144
|
+
assistant_msg = response["messages"][-1].content
|
145
|
+
assert isinstance(assistant_msg, str)
|
146
|
+
|
147
|
+
# Check tool message
|
148
|
+
tool_msg = response["messages"][-2]
|
149
|
+
assert tool_msg.name == "subgraph_summarization"
|
150
|
+
|
151
|
+
# Check summarized subgraph
|
152
|
+
current_state = app.get_state(config)
|
153
|
+
dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
|
154
|
+
assert isinstance(dic_extracted_graph["graph_summary"], str)
|
@@ -31,7 +31,6 @@ def test_embedding_with_huggingface_embed_query(embedding_model):
|
|
31
31
|
# Check the result
|
32
32
|
assert len(result) == 768
|
33
33
|
|
34
|
-
|
35
34
|
def test_embedding_with_huggingface_failed():
|
36
35
|
"""Test embedding documents using the EmbeddingWithHuggingFace class."""
|
37
36
|
# Check if the model is available on HuggingFace Hub
|
@@ -0,0 +1,56 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for utils/embeddings/ollama.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import pytest
|
6
|
+
import ollama
|
7
|
+
from ..utils.embeddings.ollama import EmbeddingWithOllama
|
8
|
+
|
9
|
+
@pytest.fixture(name="ollama_config")
|
10
|
+
def fixture_ollama_config():
|
11
|
+
"""Return a dictionary with Ollama configuration."""
|
12
|
+
return {
|
13
|
+
"model_name": "all-minilm", # Choose a small model
|
14
|
+
}
|
15
|
+
|
16
|
+
def test_no_model_ollama(ollama_config):
|
17
|
+
"""Test the case when the Ollama model is not available."""
|
18
|
+
cfg = ollama_config
|
19
|
+
|
20
|
+
# Delete the Ollama model
|
21
|
+
try:
|
22
|
+
ollama.delete(cfg["model_name"])
|
23
|
+
except ollama.ResponseError:
|
24
|
+
pass
|
25
|
+
|
26
|
+
# Check if the model is available
|
27
|
+
with pytest.raises(
|
28
|
+
ValueError, match=f"Error: Pulled {cfg["model_name"]} model and restarted Ollama server."
|
29
|
+
):
|
30
|
+
EmbeddingWithOllama(model_name=cfg["model_name"])
|
31
|
+
|
32
|
+
@pytest.fixture(name="embedding_model")
|
33
|
+
def embedding_model_fixture(ollama_config):
|
34
|
+
"""Return the configuration object for the Ollama embedding model and model object"""
|
35
|
+
cfg = ollama_config
|
36
|
+
return EmbeddingWithOllama(model_name=cfg["model_name"])
|
37
|
+
|
38
|
+
def test_embedding_with_ollama_embed_documents(embedding_model):
|
39
|
+
"""Test embedding documents using the EmbeddingWithOllama class."""
|
40
|
+
# Perform embedding
|
41
|
+
texts = ["Adalimumab", "Infliximab", "Vedolizumab"]
|
42
|
+
result = embedding_model.embed_documents(texts)
|
43
|
+
# Check the result
|
44
|
+
assert len(result) == 3
|
45
|
+
assert len(result[0]) == 384
|
46
|
+
|
47
|
+
def test_embedding_with_ollama_embed_query(embedding_model):
|
48
|
+
"""Test embedding a query using the EmbeddingWithOllama class."""
|
49
|
+
# Perform embedding
|
50
|
+
text = "Adalimumab"
|
51
|
+
result = embedding_model.embed_query(text)
|
52
|
+
# Check the result
|
53
|
+
assert len(result) == 384
|
54
|
+
|
55
|
+
# Delete the Ollama model so that it will not be cached afterward
|
56
|
+
ollama.delete(embedding_model.model_name)
|
@@ -6,20 +6,21 @@ import pytest
|
|
6
6
|
import ollama
|
7
7
|
from ..utils.enrichments.ollama import EnrichmentWithOllama
|
8
8
|
|
9
|
+
|
9
10
|
@pytest.fixture(name="ollama_config")
|
10
11
|
def fixture_ollama_config():
|
11
12
|
"""Return a dictionary with Ollama configuration."""
|
12
13
|
return {
|
13
|
-
"model_name": "
|
14
|
+
"model_name": "llama3.2:1b",
|
14
15
|
"prompt_enrichment": """
|
15
|
-
Given the input as a list of strings, please return the list of addditional information
|
16
|
-
each input terms using your prior knowledge.
|
16
|
+
Given the input as a list of strings, please return the list of addditional information
|
17
|
+
of each input terms using your prior knowledge.
|
17
18
|
|
18
19
|
Example:
|
19
20
|
Input: ['acetaminophen', 'aspirin']
|
20
|
-
Ouput: ['acetaminophen is a medication used to treat pain and fever',
|
21
|
+
Ouput: ['acetaminophen is a medication used to treat pain and fever',
|
21
22
|
'aspirin is a medication used to treat pain, fever, and inflammation']
|
22
|
-
|
23
|
+
|
23
24
|
Do not include any pretext as the output, only the list of strings enriched.
|
24
25
|
|
25
26
|
Input: {input}
|
@@ -28,10 +29,11 @@ def fixture_ollama_config():
|
|
28
29
|
"streaming": False,
|
29
30
|
}
|
30
31
|
|
32
|
+
|
31
33
|
def test_no_model_ollama(ollama_config):
|
32
34
|
"""Test the case when the Ollama model is not available."""
|
33
35
|
cfg = ollama_config
|
34
|
-
cfg_model = "smollm2:135m"
|
36
|
+
cfg_model = "smollm2:135m" # Choose a small model
|
35
37
|
|
36
38
|
# Delete the Ollama model
|
37
39
|
try:
|
@@ -41,7 +43,8 @@ def test_no_model_ollama(ollama_config):
|
|
41
43
|
|
42
44
|
# Check if the model is available
|
43
45
|
with pytest.raises(
|
44
|
-
ValueError,
|
46
|
+
ValueError,
|
47
|
+
match=f"Error: Pulled {cfg_model} model and restarted Ollama server.",
|
45
48
|
):
|
46
49
|
EnrichmentWithOllama(
|
47
50
|
model_name=cfg_model,
|
@@ -51,7 +54,8 @@ def test_no_model_ollama(ollama_config):
|
|
51
54
|
)
|
52
55
|
ollama.delete(cfg_model)
|
53
56
|
|
54
|
-
|
57
|
+
|
58
|
+
def test_enrich_ollama(ollama_config):
|
55
59
|
"""Test the Ollama textual enrichment class for node enrichment."""
|
56
60
|
# Prepare enrichment model
|
57
61
|
cfg = ollama_config
|
@@ -63,37 +67,11 @@ def test_enrich_nodes_ollama(ollama_config):
|
|
63
67
|
)
|
64
68
|
|
65
69
|
# Perform enrichment for nodes
|
66
|
-
nodes = ["
|
70
|
+
nodes = ["acetaminophen"]
|
67
71
|
enriched_nodes = enr_model.enrich_documents(nodes)
|
68
72
|
# Check the enriched nodes
|
69
|
-
assert len(enriched_nodes) ==
|
70
|
-
assert all(
|
71
|
-
enriched_nodes[i] != nodes[i] for i in range(len(nodes))
|
72
|
-
)
|
73
|
-
|
74
|
-
|
75
|
-
def test_enrich_relations_ollama(ollama_config):
|
76
|
-
"""Test the Ollama textual enrichment class for relation enrichment."""
|
77
|
-
# Prepare enrichment model
|
78
|
-
cfg = ollama_config
|
79
|
-
enr_model = EnrichmentWithOllama(
|
80
|
-
model_name=cfg["model_name"],
|
81
|
-
prompt_enrichment=cfg["prompt_enrichment"],
|
82
|
-
temperature=cfg["temperature"],
|
83
|
-
streaming=cfg["streaming"],
|
84
|
-
)
|
85
|
-
# Perform enrichment for relations
|
86
|
-
relations = [
|
87
|
-
"IL23R-gene causation disease-inflammatory bowel diseases",
|
88
|
-
"NOD2-gene causation disease-inflammatory bowel diseases",
|
89
|
-
]
|
90
|
-
enriched_relations = enr_model.enrich_documents(relations)
|
91
|
-
# Check the enriched relations
|
92
|
-
assert len(enriched_relations) == 2
|
93
|
-
assert all(
|
94
|
-
enriched_relations[i] != relations[i]
|
95
|
-
for i in range(len(relations))
|
96
|
-
)
|
73
|
+
assert len(enriched_nodes) == 1
|
74
|
+
assert all(enriched_nodes[i] != nodes[i] for i in range(len(nodes)))
|
97
75
|
|
98
76
|
|
99
77
|
def test_enrich_ollama_rag(ollama_config):
|
@@ -107,11 +85,9 @@ def test_enrich_ollama_rag(ollama_config):
|
|
107
85
|
streaming=cfg["streaming"],
|
108
86
|
)
|
109
87
|
# Perform enrichment for nodes
|
110
|
-
nodes = ["
|
88
|
+
nodes = ["acetaminophen"]
|
111
89
|
docs = [r"\path\to\doc1", r"\path\to\doc2"]
|
112
90
|
enriched_nodes = enr_model.enrich_documents_with_rag(nodes, docs)
|
113
91
|
# Check the enriched nodes
|
114
|
-
assert len(enriched_nodes) ==
|
115
|
-
assert all(
|
116
|
-
enriched_nodes[i] != nodes[i] for i in range(len(nodes))
|
117
|
-
)
|
92
|
+
assert len(enriched_nodes) == 1
|
93
|
+
assert all(enriched_nodes[i] != nodes[i] for i in range(len(nodes)))
|
@@ -0,0 +1,79 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for utils/kg_utils.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import pytest
|
6
|
+
import networkx as nx
|
7
|
+
import pandas as pd
|
8
|
+
from ..utils import kg_utils
|
9
|
+
|
10
|
+
|
11
|
+
@pytest.fixture(name="sample_graph")
|
12
|
+
def make_sample_graph():
|
13
|
+
"""Return a sample graph"""
|
14
|
+
sg = nx.Graph()
|
15
|
+
sg.add_node(1, node_id=1, feature_id="A", feature_value="ValueA")
|
16
|
+
sg.add_node(2, node_id=2, feature_id="B", feature_value="ValueB")
|
17
|
+
sg.add_edge(1, 2, edge_id=1, feature_id="E", feature_value="EdgeValue")
|
18
|
+
return sg
|
19
|
+
|
20
|
+
|
21
|
+
def test_kg_to_df_pandas(sample_graph):
|
22
|
+
"""Test the kg_to_df_pandas function"""
|
23
|
+
df_nodes, df_edges = kg_utils.kg_to_df_pandas(sample_graph)
|
24
|
+
print(df_nodes)
|
25
|
+
expected_nodes_data = {
|
26
|
+
"node_id": [1, 2],
|
27
|
+
"feature_id": ["A", "B"],
|
28
|
+
"feature_value": ["ValueA", "ValueB"],
|
29
|
+
}
|
30
|
+
expected_nodes_df = pd.DataFrame(expected_nodes_data, index=[1, 2])
|
31
|
+
print(expected_nodes_df)
|
32
|
+
expected_edges_data = {
|
33
|
+
"node_source": [1],
|
34
|
+
"node_target": [2],
|
35
|
+
"edge_id": [1],
|
36
|
+
"feature_id": ["E"],
|
37
|
+
"feature_value": ["EdgeValue"],
|
38
|
+
}
|
39
|
+
expected_edges_df = pd.DataFrame(expected_edges_data)
|
40
|
+
|
41
|
+
# Assert that the dataframes are equal but the order of columns may be different
|
42
|
+
# Ignore the index of the dataframes
|
43
|
+
pd.testing.assert_frame_equal(df_nodes, expected_nodes_df, check_like=True)
|
44
|
+
pd.testing.assert_frame_equal(df_edges, expected_edges_df, check_like=True)
|
45
|
+
|
46
|
+
|
47
|
+
def test_df_pandas_to_kg():
|
48
|
+
"""Test the df_pandas_to_kg function"""
|
49
|
+
nodes_data = {
|
50
|
+
"node_id": [1, 2],
|
51
|
+
"feature_id": ["A", "B"],
|
52
|
+
"feature_value": ["ValueA", "ValueB"],
|
53
|
+
}
|
54
|
+
df_nodes_attrs = pd.DataFrame(nodes_data).set_index("node_id")
|
55
|
+
|
56
|
+
edges_data = {
|
57
|
+
"node_source": [1],
|
58
|
+
"node_target": [2],
|
59
|
+
"edge_id": [1],
|
60
|
+
"feature_id": ["E"],
|
61
|
+
"feature_value": ["EdgeValue"],
|
62
|
+
}
|
63
|
+
df_edges = pd.DataFrame(edges_data)
|
64
|
+
|
65
|
+
kg = kg_utils.df_pandas_to_kg(
|
66
|
+
df_edges, df_nodes_attrs, "node_source", "node_target"
|
67
|
+
)
|
68
|
+
|
69
|
+
assert len(kg.nodes) == 2
|
70
|
+
assert len(kg.edges) == 1
|
71
|
+
|
72
|
+
assert kg.nodes[1]["feature_id"] == "A"
|
73
|
+
assert kg.nodes[1]["feature_value"] == "ValueA"
|
74
|
+
assert kg.nodes[2]["feature_id"] == "B"
|
75
|
+
assert kg.nodes[2]["feature_value"] == "ValueB"
|
76
|
+
|
77
|
+
assert kg.edges[1, 2]["feature_id"] == "E"
|
78
|
+
assert kg.edges[1, 2]["feature_value"] == "EdgeValue"
|
79
|
+
assert kg.edges[1, 2]["edge_id"] == 1
|