aiagents4pharma 1.18.0__py3-none-any.whl → 1.19.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- aiagents4pharma/talk2knowledgegraphs/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +85 -0
- aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +7 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +31 -0
- aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +7 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +6 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
- aiagents4pharma/talk2knowledgegraphs/states/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +38 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +110 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +210 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +174 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +154 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +0 -1
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +56 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +18 -42
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +79 -0
- aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +6 -0
- aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +143 -0
- aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +305 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +126 -0
- aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +4 -2
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +81 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +225 -0
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/METADATA +3 -1
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/RECORD +42 -10
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/top_level.txt +0 -0
--- a/aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py
+++ b/aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py
@@ -6,20 +6,21 @@ import pytest
 import ollama
 from ..utils.enrichments.ollama import EnrichmentWithOllama
 
+
 @pytest.fixture(name="ollama_config")
 def fixture_ollama_config():
     """Return a dictionary with Ollama configuration."""
     return {
-        "model_name": "
+        "model_name": "llama3.2:1b",
         "prompt_enrichment": """
-            Given the input as a list of strings, please return the list of addditional information
-            each input terms using your prior knowledge.
+            Given the input as a list of strings, please return the list of addditional information
+            of each input terms using your prior knowledge.
 
             Example:
             Input: ['acetaminophen', 'aspirin']
-            Ouput: ['acetaminophen is a medication used to treat pain and fever',
+            Ouput: ['acetaminophen is a medication used to treat pain and fever',
             'aspirin is a medication used to treat pain, fever, and inflammation']
-
+
             Do not include any pretext as the output, only the list of strings enriched.
 
             Input: {input}
@@ -28,10 +29,11 @@ def fixture_ollama_config():
         "streaming": False,
     }
 
+
 def test_no_model_ollama(ollama_config):
     """Test the case when the Ollama model is not available."""
     cfg = ollama_config
-    cfg_model = "smollm2:135m"
+    cfg_model = "smollm2:135m"  # Choose a small model
 
     # Delete the Ollama model
     try:
@@ -41,7 +43,8 @@ def test_no_model_ollama(ollama_config):
 
     # Check if the model is available
     with pytest.raises(
-        ValueError,
+        ValueError,
+        match=f"Error: Pulled {cfg_model} model and restarted Ollama server.",
     ):
         EnrichmentWithOllama(
             model_name=cfg_model,
@@ -51,7 +54,8 @@ def test_no_model_ollama(ollama_config):
         )
     ollama.delete(cfg_model)
 
-def test_enrich_nodes_ollama(ollama_config):
+
+def test_enrich_ollama(ollama_config):
     """Test the Ollama textual enrichment class for node enrichment."""
     # Prepare enrichment model
     cfg = ollama_config
@@ -63,37 +67,11 @@ def test_enrich_nodes_ollama(ollama_config):
     )
 
     # Perform enrichment for nodes
-    nodes = ["
+    nodes = ["acetaminophen"]
     enriched_nodes = enr_model.enrich_documents(nodes)
     # Check the enriched nodes
-    assert len(enriched_nodes) ==
-    assert all(
-        enriched_nodes[i] != nodes[i] for i in range(len(nodes))
-    )
-
-
-def test_enrich_relations_ollama(ollama_config):
-    """Test the Ollama textual enrichment class for relation enrichment."""
-    # Prepare enrichment model
-    cfg = ollama_config
-    enr_model = EnrichmentWithOllama(
-        model_name=cfg["model_name"],
-        prompt_enrichment=cfg["prompt_enrichment"],
-        temperature=cfg["temperature"],
-        streaming=cfg["streaming"],
-    )
-    # Perform enrichment for relations
-    relations = [
-        "IL23R-gene causation disease-inflammatory bowel diseases",
-        "NOD2-gene causation disease-inflammatory bowel diseases",
-    ]
-    enriched_relations = enr_model.enrich_documents(relations)
-    # Check the enriched relations
-    assert len(enriched_relations) == 2
-    assert all(
-        enriched_relations[i] != relations[i]
-        for i in range(len(relations))
-    )
+    assert len(enriched_nodes) == 1
+    assert all(enriched_nodes[i] != nodes[i] for i in range(len(nodes)))
 
 
 def test_enrich_ollama_rag(ollama_config):
@@ -107,11 +85,9 @@ def test_enrich_ollama_rag(ollama_config):
         streaming=cfg["streaming"],
     )
     # Perform enrichment for nodes
-    nodes = ["
+    nodes = ["acetaminophen"]
     docs = [r"\path\to\doc1", r"\path\to\doc2"]
     enriched_nodes = enr_model.enrich_documents_with_rag(nodes, docs)
     # Check the enriched nodes
-    assert len(enriched_nodes) ==
-    assert all(
-        enriched_nodes[i] != nodes[i] for i in range(len(nodes))
-    )
+    assert len(enriched_nodes) == 1
+    assert all(enriched_nodes[i] != nodes[i] for i in range(len(nodes)))
--- /dev/null
+++ b/aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py
@@ -0,0 +1,79 @@
+"""
+Test cases for utils/kg_utils.py
+"""
+
+import pytest
+import networkx as nx
+import pandas as pd
+from ..utils import kg_utils
+
+
+@pytest.fixture(name="sample_graph")
+def make_sample_graph():
+    """Return a sample graph"""
+    sg = nx.Graph()
+    sg.add_node(1, node_id=1, feature_id="A", feature_value="ValueA")
+    sg.add_node(2, node_id=2, feature_id="B", feature_value="ValueB")
+    sg.add_edge(1, 2, edge_id=1, feature_id="E", feature_value="EdgeValue")
+    return sg
+
+
+def test_kg_to_df_pandas(sample_graph):
+    """Test the kg_to_df_pandas function"""
+    df_nodes, df_edges = kg_utils.kg_to_df_pandas(sample_graph)
+    print(df_nodes)
+    expected_nodes_data = {
+        "node_id": [1, 2],
+        "feature_id": ["A", "B"],
+        "feature_value": ["ValueA", "ValueB"],
+    }
+    expected_nodes_df = pd.DataFrame(expected_nodes_data, index=[1, 2])
+    print(expected_nodes_df)
+    expected_edges_data = {
+        "node_source": [1],
+        "node_target": [2],
+        "edge_id": [1],
+        "feature_id": ["E"],
+        "feature_value": ["EdgeValue"],
+    }
+    expected_edges_df = pd.DataFrame(expected_edges_data)
+
+    # Assert that the dataframes are equal but the order of columns may be different
+    # Ignore the index of the dataframes
+    pd.testing.assert_frame_equal(df_nodes, expected_nodes_df, check_like=True)
+    pd.testing.assert_frame_equal(df_edges, expected_edges_df, check_like=True)
+
+
+def test_df_pandas_to_kg():
+    """Test the df_pandas_to_kg function"""
+    nodes_data = {
+        "node_id": [1, 2],
+        "feature_id": ["A", "B"],
+        "feature_value": ["ValueA", "ValueB"],
+    }
+    df_nodes_attrs = pd.DataFrame(nodes_data).set_index("node_id")
+
+    edges_data = {
+        "node_source": [1],
+        "node_target": [2],
+        "edge_id": [1],
+        "feature_id": ["E"],
+        "feature_value": ["EdgeValue"],
+    }
+    df_edges = pd.DataFrame(edges_data)
+
+    kg = kg_utils.df_pandas_to_kg(
+        df_edges, df_nodes_attrs, "node_source", "node_target"
+    )
+
+    assert len(kg.nodes) == 2
+    assert len(kg.edges) == 1
+
+    assert kg.nodes[1]["feature_id"] == "A"
+    assert kg.nodes[1]["feature_value"] == "ValueA"
+    assert kg.nodes[2]["feature_id"] == "B"
+    assert kg.nodes[2]["feature_value"] == "ValueB"
+
+    assert kg.edges[1, 2]["feature_id"] == "E"
+    assert kg.edges[1, 2]["feature_value"] == "EdgeValue"
+    assert kg.edges[1, 2]["edge_id"] == 1
--- /dev/null
+++ b/aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py
@@ -0,0 +1,143 @@
+"""
+Tool for performing Graph RAG reasoning.
+"""
+
+import logging
+from typing import Type, Annotated
+from pydantic import BaseModel, Field
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from langchain_core.tools import BaseTool
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain.chains.retrieval import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import PyPDFLoader
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+import hydra
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class GraphRAGReasoningInput(BaseModel):
+    """
+    GraphRAGReasoningInput is a Pydantic model representing an input for Graph RAG reasoning.
+
+    Args:
+        state: Injected state.
+        prompt: Prompt to interact with the backend.
+        extraction_name: Name assigned to the subgraph extraction process
+    """
+
+    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+        description="Tool call ID."
+    )
+    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
+    prompt: str = Field(description="Prompt to interact with the backend.")
+    extraction_name: str = Field(
+        description="""Name assigned to the subgraph extraction process
+        when the subgraph_extraction tool is invoked."""
+    )
+
+
+class GraphRAGReasoningTool(BaseTool):
+    """
+    This tool performs reasoning using a Graph Retrieval-Augmented Generation (RAG) approach
+    over user's request by considering textualized subgraph context and document context.
+    """
+
+    name: str = "graphrag_reasoning"
+    description: str = """A tool to perform reasoning using a Graph RAG approach
+    by considering textualized subgraph context and document context."""
+    args_schema: Type[BaseModel] = GraphRAGReasoningInput
+
+    def _run(
+        self,
+        tool_call_id: Annotated[str, InjectedToolCallId],
+        state: Annotated[dict, InjectedState],
+        prompt: str,
+        extraction_name: str,
+    ):
+        """
+        Run the Graph RAG reasoning tool.
+
+        Args:
+            tool_call_id: The tool call ID.
+            state: The injected state.
+            prompt: The prompt to interact with the backend.
+            extraction_name: The name assigned to the subgraph extraction process.
+        """
+        logger.log(
+            logging.INFO, "Invoking graphrag_reasoning tool for %s", extraction_name
+        )
+
+        # Load Hydra configuration
+        with hydra.initialize(version_base=None, config_path="../configs"):
+            cfg = hydra.compose(
+                config_name="config", overrides=["tools/graphrag_reasoning=default"]
+            )
+            cfg = cfg.tools.graphrag_reasoning
+
+        # Prepare documents
+        all_docs = []
+        if len(state["uploaded_files"]) != 0:
+            for uploaded_file in state["uploaded_files"]:
+                if uploaded_file["file_type"] == "drug_data":
+                    # Load documents
+                    raw_documents = PyPDFLoader(
+                        file_path=uploaded_file["file_path"]
+                    ).load()
+
+                    # Split documents
+                    # May need to find an optimal chunk size and overlap configuration
+                    documents = RecursiveCharacterTextSplitter(
+                        chunk_size=cfg.splitter_chunk_size,
+                        chunk_overlap=cfg.splitter_chunk_overlap,
+                    ).split_documents(raw_documents)
+
+                    # Add documents to the list
+                    all_docs.extend(documents)
+
+        # Load the extracted graph
+        extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
+        # logger.log(logging.INFO, "Extracted graph: %s", extracted_graph)
+
+        # Set another prompt template
+        prompt_template = ChatPromptTemplate.from_messages(
+            [("system", cfg.prompt_graphrag_w_docs), ("human", "{input}")]
+        )
+
+        # Prepare chain with retrieved documents
+        qa_chain = create_stuff_documents_chain(state["llm_model"], prompt_template)
+        rag_chain = create_retrieval_chain(
+            InMemoryVectorStore.from_documents(
+                documents=all_docs, embedding=state["embedding_model"]
+            ).as_retriever(
+                search_type=cfg.retriever_search_type,
+                search_kwargs={
+                    "k": cfg.retriever_k,
+                    "fetch_k": cfg.retriever_fetch_k,
+                    "lambda_mult": cfg.retriever_lambda_mult,
+                },
+            ),
+            qa_chain,
+        )
+
+        # Invoke the chain
+        response = rag_chain.invoke(
+            {
+                "input": prompt,
+                "subgraph_summary": extracted_graph[extraction_name]["graph_summary"],
+            }
+        )
+
+        return Command(
+            update={
+                # update the message history
+                "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)]
+            }
+        )
--- /dev/null
+++ b/aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py
@@ -0,0 +1,22 @@
+"""
+A utility module for defining the dataclasses
+for the arguments to set up initial settings
+"""
+
+from dataclasses import dataclass
+from typing import Annotated
+
+
+@dataclass
+class ArgumentData:
+    """
+    Dataclass for storing the argument data.
+    """
+
+    extraction_name: Annotated[
+        str,
+        """An AI assigned _ separated name of the subgraph extraction
+        based on human query and the context of the graph reasoning
+        experiment.
+        This must be set before the subgraph extraction is invoked.""",
+    ]
--- /dev/null
+++ b/aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py
@@ -0,0 +1,305 @@
+"""
+Tool for performing subgraph extraction.
+"""
+
+from typing import Type, Annotated
+import logging
+import pickle
+import numpy as np
+import pandas as pd
+import hydra
+import networkx as nx
+from pydantic import BaseModel, Field
+from langchain.chains.retrieval import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain_core.tools import BaseTool
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+import torch
+from torch_geometric.data import Data
+from ..utils.extractions.pcst import PCSTPruning
+from ..utils.embeddings.ollama import EmbeddingWithOllama
+from .load_arguments import ArgumentData
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class SubgraphExtractionInput(BaseModel):
+    """
+    SubgraphExtractionInput is a Pydantic model representing an input for extracting a subgraph.
+
+    Args:
+        prompt: Prompt to interact with the backend.
+        tool_call_id: Tool call ID.
+        state: Injected state.
+        arg_data: Argument for analytical process over graph data.
+    """
+
+    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+        description="Tool call ID."
+    )
+    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
+    prompt: str = Field(description="Prompt to interact with the backend.")
+    arg_data: ArgumentData = Field(
+        description="Experiment over graph data.", default=None
+    )
+
+
+class SubgraphExtractionTool(BaseTool):
+    """
+    This tool performs subgraph extraction based on user's prompt by taking into account
+    the top-k nodes and edges.
+    """
+
+    name: str = "subgraph_extraction"
+    description: str = "A tool for subgraph extraction based on user's prompt."
+    args_schema: Type[BaseModel] = SubgraphExtractionInput
+
+    def perform_endotype_filtering(
+        self,
+        prompt: str,
+        state: Annotated[dict, InjectedState],
+        cfg: hydra.core.config_store.ConfigStore,
+    ) -> str:
+        """
+        Perform endotype filtering based on the uploaded files and prepare the prompt.
+
+        Args:
+            prompt: The prompt to interact with the backend.
+            state: Injected state for the tool.
+            cfg: Hydra configuration object.
+        """
+        # Loop through the uploaded files
+        all_genes = []
+        for uploaded_file in state["uploaded_files"]:
+            if uploaded_file["file_type"] == "endotype":
+                # Load the PDF file
+                docs = PyPDFLoader(file_path=uploaded_file["file_path"]).load()
+
+                # Split the text into chunks
+                splits = RecursiveCharacterTextSplitter(
+                    chunk_size=cfg.splitter_chunk_size,
+                    chunk_overlap=cfg.splitter_chunk_overlap,
+                ).split_documents(docs)
+
+                # Create a chat prompt template
+                prompt_template = ChatPromptTemplate.from_messages(
+                    [
+                        ("system", cfg.prompt_endotype_filtering),
+                        ("human", "{input}"),
+                    ]
+                )
+
+                qa_chain = create_stuff_documents_chain(
+                    state["llm_model"], prompt_template
+                )
+                rag_chain = create_retrieval_chain(
+                    InMemoryVectorStore.from_documents(
+                        documents=splits, embedding=state["embedding_model"]
+                    ).as_retriever(
+                        search_type=cfg.retriever_search_type,
+                        search_kwargs={
+                            "k": cfg.retriever_k,
+                            "fetch_k": cfg.retriever_fetch_k,
+                            "lambda_mult": cfg.retriever_lambda_mult,
+                        },
+                    ),
+                    qa_chain,
+                )
+                results = rag_chain.invoke({"input": prompt})
+                all_genes.append(results["answer"])
+
+        # Prepare the prompt
+        if len(all_genes) > 0:
+            prompt = " ".join(
+                [prompt, cfg.prompt_endotype_addition, ", ".join(all_genes)]
+            )
+
+        return prompt
+
+    def prepare_final_subgraph(self,
+                               subgraph: dict,
+                               pyg_graph: Data,
+                               textualized_graph: pd.DataFrame) -> dict:
+        """
+        Prepare the subgraph based on the extracted subgraph.
+
+        Args:
+            subgraph: The extracted subgraph.
+            pyg_graph: The PyTorch Geometric graph.
+            textualized_graph: The textualized graph.
+
+        Returns:
+            A dictionary containing the PyG graph, NetworkX graph, and textualized graph.
+        """
+        # print(subgraph)
+        # Prepare the PyTorch Geometric graph
+        mapping = {n: i for i, n in enumerate(subgraph["nodes"].tolist())}
+        pyg_graph = Data(
+            # Node features
+            x=pyg_graph.x[subgraph["nodes"]],
+            node_id=np.array(pyg_graph.node_id)[subgraph["nodes"]].tolist(),
+            node_name=np.array(pyg_graph.node_id)[subgraph["nodes"]].tolist(),
+            enriched_node=np.array(pyg_graph.enriched_node)[subgraph["nodes"]].tolist(),
+            num_nodes=len(subgraph["nodes"]),
+            # Edge features
+            edge_index=torch.LongTensor(
+                [
+                    [
+                        mapping[i]
+                        for i in pyg_graph.edge_index[:, subgraph["edges"]][0].tolist()
+                    ],
+                    [
+                        mapping[i]
+                        for i in pyg_graph.edge_index[:, subgraph["edges"]][1].tolist()
+                    ],
+                ]
+            ),
+            edge_attr=pyg_graph.edge_attr[subgraph["edges"]],
+            edge_type=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
+            relation=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
+            label=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
+            enriched_edge=np.array(pyg_graph.enriched_edge)[subgraph["edges"]].tolist(),
+        )
+
+        # Networkx DiGraph construction to be visualized in the frontend
+        nx_graph = nx.DiGraph()
+        for n in pyg_graph.node_name:
+            nx_graph.add_node(n)
+        for i, e in enumerate(
+            [
+                [pyg_graph.node_name[i], pyg_graph.node_name[j]]
+                for (i, j) in pyg_graph.edge_index.transpose(1, 0)
+            ]
+        ):
+            nx_graph.add_edge(
+                e[0],
+                e[1],
+                relation=pyg_graph.edge_type[i],
+                label=pyg_graph.edge_type[i],
+            )
+
+        # Prepare the textualized subgraph
+        textualized_graph = (
+            textualized_graph["nodes"].iloc[subgraph["nodes"]].to_csv(index=False)
+            + "\n"
+            + textualized_graph["edges"].iloc[subgraph["edges"]].to_csv(index=False)
+        )
+
+        return {
+            "graph_pyg": pyg_graph,
+            "graph_nx": nx_graph,
+            "graph_text": textualized_graph,
+        }
+
+    def _run(
+        self,
+        tool_call_id: Annotated[str, InjectedToolCallId],
+        state: Annotated[dict, InjectedState],
+        prompt: str,
+        arg_data: ArgumentData = None,
+    ) -> Command:
+        """
+        Run the subgraph extraction tool.
+
+        Args:
+            tool_call_id: The tool call ID for the tool.
+            state: Injected state for the tool.
+            prompt: The prompt to interact with the backend.
+            arg_data (ArgumentData): The argument data.
+
+        Returns:
+            Command: The command to be executed.
+        """
+        logger.log(logging.INFO, "Invoking subgraph_extraction tool")
+
+        # Load hydra configuration
+        with hydra.initialize(version_base=None, config_path="../configs"):
+            cfg = hydra.compose(
+                config_name="config", overrides=["tools/subgraph_extraction=default"]
+            )
+            cfg = cfg.tools.subgraph_extraction
+
+        # Retrieve source graph from the state
+        initial_graph = {}
+        initial_graph["source"] = state["dic_source_graph"][-1]  # The last source graph as of now
+        # logger.log(logging.INFO, "Source graph: %s", source_graph)
+
+        # Load the knowledge graph
+        with open(initial_graph["source"]["kg_pyg_path"], "rb") as f:
+            initial_graph["pyg"] = pickle.load(f)
+        with open(initial_graph["source"]["kg_text_path"], "rb") as f:
+            initial_graph["text"] = pickle.load(f)
+
+        # Prepare prompt construction along with a list of endotypes
+        if len(state["uploaded_files"]) != 0 and "endotype" in [
+            f["file_type"] for f in state["uploaded_files"]
+        ]:
+            prompt = self.perform_endotype_filtering(prompt, state, cfg)
+
+        # Prepare embedding model and embed the user prompt as query
+        query_emb = torch.tensor(
+            EmbeddingWithOllama(model_name=cfg.ollama_embeddings[0]).embed_query(prompt)
+        ).float()
+
+        # Prepare the PCSTPruning object and extract the subgraph
+        # Parameters were set in the configuration file obtained from Hydra
+        subgraph = PCSTPruning(
+            state["topk_nodes"],
+            state["topk_edges"],
+            cfg.cost_e,
+            cfg.c_const,
+            cfg.root,
+            cfg.num_clusters,
+            cfg.pruning,
+            cfg.verbosity_level,
+        ).extract_subgraph(initial_graph["pyg"], query_emb)
+
+        # Prepare subgraph as a NetworkX graph and textualized graph
+        final_subgraph = self.prepare_final_subgraph(
+            subgraph, initial_graph["pyg"], initial_graph["text"]
+        )
+
+        # Prepare the dictionary of extracted graph
+        dic_extracted_graph = {
+            "name": arg_data.extraction_name,
+            "tool_call_id": tool_call_id,
+            "graph_source": initial_graph["source"]["name"],
+            "topk_nodes": state["topk_nodes"],
+            "topk_edges": state["topk_edges"],
+            "graph_dict": {
+                "nodes": list(final_subgraph["graph_nx"].nodes(data=True)),
+                "edges": list(final_subgraph["graph_nx"].edges(data=True)),
+            },
+            "graph_text": final_subgraph["graph_text"],
+            "graph_summary": None,
+        }
+
+        # Prepare the dictionary of updated state
+        dic_updated_state_for_model = {}
+        for key, value in {
+            "dic_extracted_graph": [dic_extracted_graph],
+        }.items():
+            if value:
+                dic_updated_state_for_model[key] = value
+
+        # Return the updated state of the tool
+        return Command(
+            update=dic_updated_state_for_model | {
+                # update the message history
+                "messages": [
+                    ToolMessage(
+                        content=f"Subgraph Extraction Result of {arg_data.extraction_name}",
+                        tool_call_id=tool_call_id,
+                    )
+                ],
+            }
+        )