aiagents4pharma 1.18.0__py3-none-any.whl → 1.19.0__py3-none-any.whl
- aiagents4pharma/talk2knowledgegraphs/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +85 -0
- aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +7 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +31 -0
- aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +7 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +6 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
- aiagents4pharma/talk2knowledgegraphs/states/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +38 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +110 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +210 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +174 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +154 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +0 -1
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +56 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +18 -42
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +79 -0
- aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +6 -0
- aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +143 -0
- aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +305 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +126 -0
- aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +4 -2
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +81 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +225 -0
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/METADATA +3 -1
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/RECORD +42 -10
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/top_level.txt +0 -0
--- a/aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py
+++ b/aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py
@@ -6,20 +6,21 @@ import pytest
 import ollama
 from ..utils.enrichments.ollama import EnrichmentWithOllama
 
+
 @pytest.fixture(name="ollama_config")
 def fixture_ollama_config():
     """Return a dictionary with Ollama configuration."""
     return {
-        "model_name": "
+        "model_name": "llama3.2:1b",
         "prompt_enrichment": """
-            Given the input as a list of strings, please return the list of addditional information
-            each input terms using your prior knowledge.
+            Given the input as a list of strings, please return the list of addditional information
+            of each input terms using your prior knowledge.
 
             Example:
             Input: ['acetaminophen', 'aspirin']
-            Ouput: ['acetaminophen is a medication used to treat pain and fever',
+            Ouput: ['acetaminophen is a medication used to treat pain and fever',
             'aspirin is a medication used to treat pain, fever, and inflammation']
-
+
             Do not include any pretext as the output, only the list of strings enriched.
 
             Input: {input}
@@ -28,10 +29,11 @@ def fixture_ollama_config():
         "streaming": False,
     }
 
+
 def test_no_model_ollama(ollama_config):
     """Test the case when the Ollama model is not available."""
     cfg = ollama_config
-    cfg_model = "smollm2:135m"
+    cfg_model = "smollm2:135m"  # Choose a small model
 
     # Delete the Ollama model
     try:
@@ -41,7 +43,8 @@ def test_no_model_ollama(ollama_config):
 
     # Check if the model is available
     with pytest.raises(
-        ValueError,
+        ValueError,
+        match=f"Error: Pulled {cfg_model} model and restarted Ollama server.",
     ):
         EnrichmentWithOllama(
             model_name=cfg_model,
@@ -51,7 +54,8 @@ def test_no_model_ollama(ollama_config):
     )
     ollama.delete(cfg_model)
 
-def test_enrich_nodes_ollama(ollama_config):
+
+def test_enrich_ollama(ollama_config):
     """Test the Ollama textual enrichment class for node enrichment."""
     # Prepare enrichment model
     cfg = ollama_config
@@ -63,37 +67,11 @@ def test_enrich_nodes_ollama(ollama_config):
     )
 
     # Perform enrichment for nodes
-    nodes = ["
+    nodes = ["acetaminophen"]
     enriched_nodes = enr_model.enrich_documents(nodes)
     # Check the enriched nodes
-    assert len(enriched_nodes) ==
-    assert all(
-        enriched_nodes[i] != nodes[i] for i in range(len(nodes))
-    )
-
-
-def test_enrich_relations_ollama(ollama_config):
-    """Test the Ollama textual enrichment class for relation enrichment."""
-    # Prepare enrichment model
-    cfg = ollama_config
-    enr_model = EnrichmentWithOllama(
-        model_name=cfg["model_name"],
-        prompt_enrichment=cfg["prompt_enrichment"],
-        temperature=cfg["temperature"],
-        streaming=cfg["streaming"],
-    )
-    # Perform enrichment for relations
-    relations = [
-        "IL23R-gene causation disease-inflammatory bowel diseases",
-        "NOD2-gene causation disease-inflammatory bowel diseases",
-    ]
-    enriched_relations = enr_model.enrich_documents(relations)
-    # Check the enriched relations
-    assert len(enriched_relations) == 2
-    assert all(
-        enriched_relations[i] != relations[i]
-        for i in range(len(relations))
-    )
+    assert len(enriched_nodes) == 1
+    assert all(enriched_nodes[i] != nodes[i] for i in range(len(nodes)))
 
 
 def test_enrich_ollama_rag(ollama_config):
@@ -107,11 +85,9 @@ def test_enrich_ollama_rag(ollama_config):
         streaming=cfg["streaming"],
     )
     # Perform enrichment for nodes
-    nodes = ["
+    nodes = ["acetaminophen"]
    docs = [r"\path\to\doc1", r"\path\to\doc2"]
     enriched_nodes = enr_model.enrich_documents_with_rag(nodes, docs)
     # Check the enriched nodes
-    assert len(enriched_nodes) ==
-    assert all(
-        enriched_nodes[i] != nodes[i] for i in range(len(nodes))
-    )
+    assert len(enriched_nodes) == 1
+    assert all(enriched_nodes[i] != nodes[i] for i in range(len(nodes)))
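Note: the updated test above pins down the public surface of EnrichmentWithOllama (the constructor keywords plus enrich_documents / enrich_documents_with_rag). A minimal usage sketch follows, assuming a locally running Ollama server; the prompt text and temperature are illustrative values, not the package's defaults.

from aiagents4pharma.talk2knowledgegraphs.utils.enrichments.ollama import (
    EnrichmentWithOllama,
)

enr_model = EnrichmentWithOllama(
    model_name="llama3.2:1b",  # small model, as in the tests
    prompt_enrichment="Enrich each term in the list: {input}",  # any template with {input}
    temperature=0.0,  # illustrative value
    streaming=False,
)
enriched = enr_model.enrich_documents(["acetaminophen"])  # one enriched string per input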
--- /dev/null
+++ b/aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py
@@ -0,0 +1,79 @@
+"""
+Test cases for utils/kg_utils.py
+"""
+
+import pytest
+import networkx as nx
+import pandas as pd
+from ..utils import kg_utils
+
+
+@pytest.fixture(name="sample_graph")
+def make_sample_graph():
+    """Return a sample graph"""
+    sg = nx.Graph()
+    sg.add_node(1, node_id=1, feature_id="A", feature_value="ValueA")
+    sg.add_node(2, node_id=2, feature_id="B", feature_value="ValueB")
+    sg.add_edge(1, 2, edge_id=1, feature_id="E", feature_value="EdgeValue")
+    return sg
+
+
+def test_kg_to_df_pandas(sample_graph):
+    """Test the kg_to_df_pandas function"""
+    df_nodes, df_edges = kg_utils.kg_to_df_pandas(sample_graph)
+    print(df_nodes)
+    expected_nodes_data = {
+        "node_id": [1, 2],
+        "feature_id": ["A", "B"],
+        "feature_value": ["ValueA", "ValueB"],
+    }
+    expected_nodes_df = pd.DataFrame(expected_nodes_data, index=[1, 2])
+    print(expected_nodes_df)
+    expected_edges_data = {
+        "node_source": [1],
+        "node_target": [2],
+        "edge_id": [1],
+        "feature_id": ["E"],
+        "feature_value": ["EdgeValue"],
+    }
+    expected_edges_df = pd.DataFrame(expected_edges_data)
+
+    # Assert that the dataframes are equal but the order of columns may be different
+    # Ignore the index of the dataframes
+    pd.testing.assert_frame_equal(df_nodes, expected_nodes_df, check_like=True)
+    pd.testing.assert_frame_equal(df_edges, expected_edges_df, check_like=True)
+
+
+def test_df_pandas_to_kg():
+    """Test the df_pandas_to_kg function"""
+    nodes_data = {
+        "node_id": [1, 2],
+        "feature_id": ["A", "B"],
+        "feature_value": ["ValueA", "ValueB"],
+    }
+    df_nodes_attrs = pd.DataFrame(nodes_data).set_index("node_id")
+
+    edges_data = {
+        "node_source": [1],
+        "node_target": [2],
+        "edge_id": [1],
+        "feature_id": ["E"],
+        "feature_value": ["EdgeValue"],
+    }
+    df_edges = pd.DataFrame(edges_data)
+
+    kg = kg_utils.df_pandas_to_kg(
+        df_edges, df_nodes_attrs, "node_source", "node_target"
+    )
+
+    assert len(kg.nodes) == 2
+    assert len(kg.edges) == 1
+
+    assert kg.nodes[1]["feature_id"] == "A"
+    assert kg.nodes[1]["feature_value"] == "ValueA"
+    assert kg.nodes[2]["feature_id"] == "B"
+    assert kg.nodes[2]["feature_value"] == "ValueB"
+
+    assert kg.edges[1, 2]["feature_id"] == "E"
+    assert kg.edges[1, 2]["feature_value"] == "EdgeValue"
+    assert kg.edges[1, 2]["edge_id"] == 1
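Note: the two tests above fully determine the round-trip behaviour expected of kg_utils. The sketch below shows what the helpers presumably do, written directly against NetworkX and pandas; the actual implementation in utils/kg_utils.py may differ in details.

import networkx as nx
import pandas as pd

def kg_to_df_pandas_sketch(kg: nx.Graph):
    """Flatten a knowledge graph into a node table and an edge table."""
    df_nodes = pd.DataFrame.from_dict(dict(kg.nodes(data=True)), orient="index")
    df_edges = nx.to_pandas_edgelist(kg, source="node_source", target="node_target")
    return df_nodes, df_edges

def df_pandas_to_kg_sketch(df_edges, df_nodes_attrs, source_col, target_col):
    """Rebuild a graph from an edge table plus a node-attribute table."""
    kg = nx.from_pandas_edgelist(df_edges, source=source_col, target=target_col, edge_attr=True)
    nx.set_node_attributes(kg, df_nodes_attrs.to_dict(orient="index"))
    return kg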
--- /dev/null
+++ b/aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py
@@ -0,0 +1,143 @@
+"""
+Tool for performing Graph RAG reasoning.
+"""
+
+import logging
+from typing import Type, Annotated
+from pydantic import BaseModel, Field
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from langchain_core.tools import BaseTool
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain.chains.retrieval import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import PyPDFLoader
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+import hydra
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class GraphRAGReasoningInput(BaseModel):
+    """
+    GraphRAGReasoningInput is a Pydantic model representing an input for Graph RAG reasoning.
+
+    Args:
+        state: Injected state.
+        prompt: Prompt to interact with the backend.
+        extraction_name: Name assigned to the subgraph extraction process
+    """
+
+    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+        description="Tool call ID."
+    )
+    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
+    prompt: str = Field(description="Prompt to interact with the backend.")
+    extraction_name: str = Field(
+        description="""Name assigned to the subgraph extraction process
+        when the subgraph_extraction tool is invoked."""
+    )
+
+
+class GraphRAGReasoningTool(BaseTool):
+    """
+    This tool performs reasoning using a Graph Retrieval-Augmented Generation (RAG) approach
+    over user's request by considering textualized subgraph context and document context.
+    """
+
+    name: str = "graphrag_reasoning"
+    description: str = """A tool to perform reasoning using a Graph RAG approach
+    by considering textualized subgraph context and document context."""
+    args_schema: Type[BaseModel] = GraphRAGReasoningInput
+
+    def _run(
+        self,
+        tool_call_id: Annotated[str, InjectedToolCallId],
+        state: Annotated[dict, InjectedState],
+        prompt: str,
+        extraction_name: str,
+    ):
+        """
+        Run the Graph RAG reasoning tool.
+
+        Args:
+            tool_call_id: The tool call ID.
+            state: The injected state.
+            prompt: The prompt to interact with the backend.
+            extraction_name: The name assigned to the subgraph extraction process.
+        """
+        logger.log(
+            logging.INFO, "Invoking graphrag_reasoning tool for %s", extraction_name
+        )
+
+        # Load Hydra configuration
+        with hydra.initialize(version_base=None, config_path="../configs"):
+            cfg = hydra.compose(
+                config_name="config", overrides=["tools/graphrag_reasoning=default"]
+            )
+            cfg = cfg.tools.graphrag_reasoning
+
+        # Prepare documents
+        all_docs = []
+        if len(state["uploaded_files"]) != 0:
+            for uploaded_file in state["uploaded_files"]:
+                if uploaded_file["file_type"] == "drug_data":
+                    # Load documents
+                    raw_documents = PyPDFLoader(
+                        file_path=uploaded_file["file_path"]
+                    ).load()
+
+                    # Split documents
+                    # May need to find an optimal chunk size and overlap configuration
+                    documents = RecursiveCharacterTextSplitter(
+                        chunk_size=cfg.splitter_chunk_size,
+                        chunk_overlap=cfg.splitter_chunk_overlap,
+                    ).split_documents(raw_documents)
+
+                    # Add documents to the list
+                    all_docs.extend(documents)
+
+        # Load the extracted graph
+        extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
+        # logger.log(logging.INFO, "Extracted graph: %s", extracted_graph)
+
+        # Set another prompt template
+        prompt_template = ChatPromptTemplate.from_messages(
+            [("system", cfg.prompt_graphrag_w_docs), ("human", "{input}")]
+        )
+
+        # Prepare chain with retrieved documents
+        qa_chain = create_stuff_documents_chain(state["llm_model"], prompt_template)
+        rag_chain = create_retrieval_chain(
+            InMemoryVectorStore.from_documents(
+                documents=all_docs, embedding=state["embedding_model"]
+            ).as_retriever(
+                search_type=cfg.retriever_search_type,
+                search_kwargs={
+                    "k": cfg.retriever_k,
+                    "fetch_k": cfg.retriever_fetch_k,
+                    "lambda_mult": cfg.retriever_lambda_mult,
+                },
+            ),
+            qa_chain,
+        )
+
+        # Invoke the chain
+        response = rag_chain.invoke(
+            {
+                "input": prompt,
+                "subgraph_summary": extracted_graph[extraction_name]["graph_summary"],
+            }
+        )
+
+        return Command(
+            update={
+                # update the message history
+                "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)]
+            }
+        )
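Note: GraphRAGReasoningTool reads everything it needs from the injected LangGraph state. The dictionary below illustrates the shape implied by the code above; the key names come from the diff, while the model classes, paths, and extraction name are assumptions for illustration only.

from langchain_openai import ChatOpenAI, OpenAIEmbeddings  # assumed providers

state = {
    "llm_model": ChatOpenAI(model="gpt-4o-mini"),  # any LangChain chat model
    "embedding_model": OpenAIEmbeddings(),  # any LangChain embeddings model
    "uploaded_files": [
        {"file_type": "drug_data", "file_path": "/path/to/drug_report.pdf"},
    ],
    "dic_extracted_graph": [
        # produced earlier by the subgraph_extraction tool; the name is hypothetical
        {"name": "il23r_ibd_subgraph", "graph_summary": "..."},
    ],
}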
--- /dev/null
+++ b/aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py
@@ -0,0 +1,22 @@
+"""
+A utility module for defining the dataclasses
+for the arguments to set up initial settings
+"""
+
+from dataclasses import dataclass
+from typing import Annotated
+
+
+@dataclass
+class ArgumentData:
+    """
+    Dataclass for storing the argument data.
+    """
+
+    extraction_name: Annotated[
+        str,
+        """An AI assigned _ separated name of the subgraph extraction
+        based on human query and the context of the graph reasoning
+        experiment.
+        This must be set before the subgraph extraction is invoked.""",
+    ]
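Note: ArgumentData is the contract between the agent and the extraction tool: the LLM must supply an underscore-separated extraction_name before subgraph_extraction runs. For example (the name here is hypothetical):

from aiagents4pharma.talk2knowledgegraphs.tools.load_arguments import ArgumentData

arg_data = ArgumentData(extraction_name="il23r_ibd_subgraph")  # AI-assigned, underscore-separated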
--- /dev/null
+++ b/aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py
@@ -0,0 +1,305 @@
+"""
+Tool for performing subgraph extraction.
+"""
+
+from typing import Type, Annotated
+import logging
+import pickle
+import numpy as np
+import pandas as pd
+import hydra
+import networkx as nx
+from pydantic import BaseModel, Field
+from langchain.chains.retrieval import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain_core.tools import BaseTool
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+import torch
+from torch_geometric.data import Data
+from ..utils.extractions.pcst import PCSTPruning
+from ..utils.embeddings.ollama import EmbeddingWithOllama
+from .load_arguments import ArgumentData
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class SubgraphExtractionInput(BaseModel):
+    """
+    SubgraphExtractionInput is a Pydantic model representing an input for extracting a subgraph.
+
+    Args:
+        prompt: Prompt to interact with the backend.
+        tool_call_id: Tool call ID.
+        state: Injected state.
+        arg_data: Argument for analytical process over graph data.
+    """
+
+    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+        description="Tool call ID."
+    )
+    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
+    prompt: str = Field(description="Prompt to interact with the backend.")
+    arg_data: ArgumentData = Field(
+        description="Experiment over graph data.", default=None
+    )
+
+
+class SubgraphExtractionTool(BaseTool):
+    """
+    This tool performs subgraph extraction based on user's prompt by taking into account
+    the top-k nodes and edges.
+    """
+
+    name: str = "subgraph_extraction"
+    description: str = "A tool for subgraph extraction based on user's prompt."
+    args_schema: Type[BaseModel] = SubgraphExtractionInput
+
+    def perform_endotype_filtering(
+        self,
+        prompt: str,
+        state: Annotated[dict, InjectedState],
+        cfg: hydra.core.config_store.ConfigStore,
+    ) -> str:
+        """
+        Perform endotype filtering based on the uploaded files and prepare the prompt.
+
+        Args:
+            prompt: The prompt to interact with the backend.
+            state: Injected state for the tool.
+            cfg: Hydra configuration object.
+        """
+        # Loop through the uploaded files
+        all_genes = []
+        for uploaded_file in state["uploaded_files"]:
+            if uploaded_file["file_type"] == "endotype":
+                # Load the PDF file
+                docs = PyPDFLoader(file_path=uploaded_file["file_path"]).load()
+
+                # Split the text into chunks
+                splits = RecursiveCharacterTextSplitter(
+                    chunk_size=cfg.splitter_chunk_size,
+                    chunk_overlap=cfg.splitter_chunk_overlap,
+                ).split_documents(docs)
+
+                # Create a chat prompt template
+                prompt_template = ChatPromptTemplate.from_messages(
+                    [
+                        ("system", cfg.prompt_endotype_filtering),
+                        ("human", "{input}"),
+                    ]
+                )
+
+                qa_chain = create_stuff_documents_chain(
+                    state["llm_model"], prompt_template
+                )
+                rag_chain = create_retrieval_chain(
+                    InMemoryVectorStore.from_documents(
+                        documents=splits, embedding=state["embedding_model"]
+                    ).as_retriever(
+                        search_type=cfg.retriever_search_type,
+                        search_kwargs={
+                            "k": cfg.retriever_k,
+                            "fetch_k": cfg.retriever_fetch_k,
+                            "lambda_mult": cfg.retriever_lambda_mult,
+                        },
+                    ),
+                    qa_chain,
+                )
+                results = rag_chain.invoke({"input": prompt})
+                all_genes.append(results["answer"])
+
+        # Prepare the prompt
+        if len(all_genes) > 0:
+            prompt = " ".join(
+                [prompt, cfg.prompt_endotype_addition, ", ".join(all_genes)]
+            )
+
+        return prompt
+
+    def prepare_final_subgraph(self,
+                               subgraph: dict,
+                               pyg_graph: Data,
+                               textualized_graph: pd.DataFrame) -> dict:
+        """
+        Prepare the subgraph based on the extracted subgraph.
+
+        Args:
+            subgraph: The extracted subgraph.
+            pyg_graph: The PyTorch Geometric graph.
+            textualized_graph: The textualized graph.
+
+        Returns:
+            A dictionary containing the PyG graph, NetworkX graph, and textualized graph.
+        """
+        # print(subgraph)
+        # Prepare the PyTorch Geometric graph
+        mapping = {n: i for i, n in enumerate(subgraph["nodes"].tolist())}
+        pyg_graph = Data(
+            # Node features
+            x=pyg_graph.x[subgraph["nodes"]],
+            node_id=np.array(pyg_graph.node_id)[subgraph["nodes"]].tolist(),
+            node_name=np.array(pyg_graph.node_id)[subgraph["nodes"]].tolist(),
+            enriched_node=np.array(pyg_graph.enriched_node)[subgraph["nodes"]].tolist(),
+            num_nodes=len(subgraph["nodes"]),
+            # Edge features
+            edge_index=torch.LongTensor(
+                [
+                    [
+                        mapping[i]
+                        for i in pyg_graph.edge_index[:, subgraph["edges"]][0].tolist()
+                    ],
+                    [
+                        mapping[i]
+                        for i in pyg_graph.edge_index[:, subgraph["edges"]][1].tolist()
+                    ],
+                ]
+            ),
+            edge_attr=pyg_graph.edge_attr[subgraph["edges"]],
+            edge_type=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
+            relation=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
+            label=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
+            enriched_edge=np.array(pyg_graph.enriched_edge)[subgraph["edges"]].tolist(),
+        )
+
+        # Networkx DiGraph construction to be visualized in the frontend
+        nx_graph = nx.DiGraph()
+        for n in pyg_graph.node_name:
+            nx_graph.add_node(n)
+        for i, e in enumerate(
+            [
+                [pyg_graph.node_name[i], pyg_graph.node_name[j]]
+                for (i, j) in pyg_graph.edge_index.transpose(1, 0)
+            ]
+        ):
+            nx_graph.add_edge(
+                e[0],
+                e[1],
+                relation=pyg_graph.edge_type[i],
+                label=pyg_graph.edge_type[i],
+            )
+
+        # Prepare the textualized subgraph
+        textualized_graph = (
+            textualized_graph["nodes"].iloc[subgraph["nodes"]].to_csv(index=False)
+            + "\n"
+            + textualized_graph["edges"].iloc[subgraph["edges"]].to_csv(index=False)
+        )
+
+        return {
+            "graph_pyg": pyg_graph,
+            "graph_nx": nx_graph,
+            "graph_text": textualized_graph,
+        }
+
+    def _run(
+        self,
+        tool_call_id: Annotated[str, InjectedToolCallId],
+        state: Annotated[dict, InjectedState],
+        prompt: str,
+        arg_data: ArgumentData = None,
+    ) -> Command:
+        """
+        Run the subgraph extraction tool.
+
+        Args:
+            tool_call_id: The tool call ID for the tool.
+            state: Injected state for the tool.
+            prompt: The prompt to interact with the backend.
+            arg_data (ArgumentData): The argument data.
+
+        Returns:
+            Command: The command to be executed.
+        """
+        logger.log(logging.INFO, "Invoking subgraph_extraction tool")
+
+        # Load hydra configuration
+        with hydra.initialize(version_base=None, config_path="../configs"):
+            cfg = hydra.compose(
+                config_name="config", overrides=["tools/subgraph_extraction=default"]
+            )
+            cfg = cfg.tools.subgraph_extraction
+
+        # Retrieve source graph from the state
+        initial_graph = {}
+        initial_graph["source"] = state["dic_source_graph"][-1]  # The last source graph as of now
+        # logger.log(logging.INFO, "Source graph: %s", source_graph)
+
+        # Load the knowledge graph
+        with open(initial_graph["source"]["kg_pyg_path"], "rb") as f:
+            initial_graph["pyg"] = pickle.load(f)
+        with open(initial_graph["source"]["kg_text_path"], "rb") as f:
+            initial_graph["text"] = pickle.load(f)
+
+        # Prepare prompt construction along with a list of endotypes
+        if len(state["uploaded_files"]) != 0 and "endotype" in [
+            f["file_type"] for f in state["uploaded_files"]
+        ]:
+            prompt = self.perform_endotype_filtering(prompt, state, cfg)
+
+        # Prepare embedding model and embed the user prompt as query
+        query_emb = torch.tensor(
+            EmbeddingWithOllama(model_name=cfg.ollama_embeddings[0]).embed_query(prompt)
+        ).float()
+
+        # Prepare the PCSTPruning object and extract the subgraph
+        # Parameters were set in the configuration file obtained from Hydra
+        subgraph = PCSTPruning(
+            state["topk_nodes"],
+            state["topk_edges"],
+            cfg.cost_e,
+            cfg.c_const,
+            cfg.root,
+            cfg.num_clusters,
+            cfg.pruning,
+            cfg.verbosity_level,
+        ).extract_subgraph(initial_graph["pyg"], query_emb)
+
+        # Prepare subgraph as a NetworkX graph and textualized graph
+        final_subgraph = self.prepare_final_subgraph(
+            subgraph, initial_graph["pyg"], initial_graph["text"]
+        )
+
+        # Prepare the dictionary of extracted graph
+        dic_extracted_graph = {
+            "name": arg_data.extraction_name,
+            "tool_call_id": tool_call_id,
+            "graph_source": initial_graph["source"]["name"],
+            "topk_nodes": state["topk_nodes"],
+            "topk_edges": state["topk_edges"],
+            "graph_dict": {
+                "nodes": list(final_subgraph["graph_nx"].nodes(data=True)),
+                "edges": list(final_subgraph["graph_nx"].edges(data=True)),
+            },
+            "graph_text": final_subgraph["graph_text"],
+            "graph_summary": None,
+        }
+
+        # Prepare the dictionary of updated state
+        dic_updated_state_for_model = {}
+        for key, value in {
+            "dic_extracted_graph": [dic_extracted_graph],
+        }.items():
+            if value:
+                dic_updated_state_for_model[key] = value
+
+        # Return the updated state of the tool
+        return Command(
+            update=dic_updated_state_for_model | {
+                # update the message history
+                "messages": [
+                    ToolMessage(
+                        content=f"Subgraph Extraction Result of {arg_data.extraction_name}",
+                        tool_call_id=tool_call_id,
+                    )
+                ],
+            }
+        )