aiagents4pharma 1.17.1__py3-none-any.whl → 1.19.0__py3-none-any.whl
- aiagents4pharma/talk2biomodels/agents/t2b_agent.py +4 -4
- aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml +7 -15
- aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +4 -1
- aiagents4pharma/talk2biomodels/tests/test_ask_question.py +4 -2
- aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +4 -2
- aiagents4pharma/talk2biomodels/tests/test_integration.py +34 -30
- aiagents4pharma/talk2biomodels/tests/test_query_article.py +7 -1
- aiagents4pharma/talk2biomodels/tests/test_search_models.py +3 -1
- aiagents4pharma/talk2biomodels/tests/test_steady_state.py +6 -3
- aiagents4pharma/talk2biomodels/tools/ask_question.py +1 -2
- aiagents4pharma/talk2biomodels/tools/custom_plotter.py +23 -10
- aiagents4pharma/talk2biomodels/tools/get_annotation.py +11 -10
- aiagents4pharma/talk2biomodels/tools/query_article.py +6 -2
- aiagents4pharma/talk2biomodels/tools/search_models.py +8 -2
- aiagents4pharma/talk2knowledgegraphs/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +85 -0
- aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +7 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +31 -0
- aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +7 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +6 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
- aiagents4pharma/talk2knowledgegraphs/states/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +38 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +110 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +210 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +174 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +154 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +0 -1
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +56 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +18 -42
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +79 -0
- aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +6 -0
- aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +143 -0
- aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +305 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +126 -0
- aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +4 -2
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +81 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +225 -0
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/METADATA +12 -3
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/RECORD +56 -24
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/top_level.txt +0 -0
aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py
@@ -0,0 +1,143 @@
"""
Tool for performing Graph RAG reasoning.
"""

import logging
from typing import Type, Annotated
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import InjectedToolCallId
from langchain_core.tools import BaseTool
from langchain_core.vectorstores import InMemoryVectorStore
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langgraph.types import Command
from langgraph.prebuilt import InjectedState
import hydra

# Initialize logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class GraphRAGReasoningInput(BaseModel):
    """
    GraphRAGReasoningInput is a Pydantic model representing an input for Graph RAG reasoning.

    Args:
        state: Injected state.
        prompt: Prompt to interact with the backend.
        extraction_name: Name assigned to the subgraph extraction process.
    """

    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
        description="Tool call ID."
    )
    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
    prompt: str = Field(description="Prompt to interact with the backend.")
    extraction_name: str = Field(
        description="""Name assigned to the subgraph extraction process
        when the subgraph_extraction tool is invoked."""
    )


class GraphRAGReasoningTool(BaseTool):
    """
    This tool performs reasoning using a Graph Retrieval-Augmented Generation (RAG)
    approach over the user's request, considering textualized subgraph and document context.
    """

    name: str = "graphrag_reasoning"
    description: str = """A tool to perform reasoning using a Graph RAG approach
    by considering textualized subgraph context and document context."""
    args_schema: Type[BaseModel] = GraphRAGReasoningInput

    def _run(
        self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        prompt: str,
        extraction_name: str,
    ):
        """
        Run the Graph RAG reasoning tool.

        Args:
            tool_call_id: The tool call ID.
            state: The injected state.
            prompt: The prompt to interact with the backend.
            extraction_name: The name assigned to the subgraph extraction process.
        """
        logger.log(
            logging.INFO, "Invoking graphrag_reasoning tool for %s", extraction_name
        )

        # Load Hydra configuration
        with hydra.initialize(version_base=None, config_path="../configs"):
            cfg = hydra.compose(
                config_name="config", overrides=["tools/graphrag_reasoning=default"]
            )
            cfg = cfg.tools.graphrag_reasoning

        # Prepare documents
        all_docs = []
        if len(state["uploaded_files"]) != 0:
            for uploaded_file in state["uploaded_files"]:
                if uploaded_file["file_type"] == "drug_data":
                    # Load documents
                    raw_documents = PyPDFLoader(
                        file_path=uploaded_file["file_path"]
                    ).load()

                    # Split documents
                    # May need to find an optimal chunk size and overlap configuration
                    documents = RecursiveCharacterTextSplitter(
                        chunk_size=cfg.splitter_chunk_size,
                        chunk_overlap=cfg.splitter_chunk_overlap,
                    ).split_documents(raw_documents)

                    # Add documents to the list
                    all_docs.extend(documents)

        # Load the extracted graph
        extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
        # logger.log(logging.INFO, "Extracted graph: %s", extracted_graph)

        # Set another prompt template
        prompt_template = ChatPromptTemplate.from_messages(
            [("system", cfg.prompt_graphrag_w_docs), ("human", "{input}")]
        )

        # Prepare chain with retrieved documents
        qa_chain = create_stuff_documents_chain(state["llm_model"], prompt_template)
        rag_chain = create_retrieval_chain(
            InMemoryVectorStore.from_documents(
                documents=all_docs, embedding=state["embedding_model"]
            ).as_retriever(
                search_type=cfg.retriever_search_type,
                search_kwargs={
                    "k": cfg.retriever_k,
                    "fetch_k": cfg.retriever_fetch_k,
                    "lambda_mult": cfg.retriever_lambda_mult,
                },
            ),
            qa_chain,
        )

        # Invoke the chain
        response = rag_chain.invoke(
            {
                "input": prompt,
                "subgraph_summary": extracted_graph[extraction_name]["graph_summary"],
            }
        )

        return Command(
            update={
                # update the message history; the retrieval chain returns a dict, so keep the answer text
                "messages": [ToolMessage(content=response["answer"], tool_call_id=tool_call_id)]
            }
        )
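
The tool reads its retriever and prompt settings from the Hydra config shipped in configs/tools/graphrag_reasoning/default.yaml and pulls everything else (LLM, embedding model, uploaded files, extracted-graph records) from the injected agent state. Below is a minimal sketch of driving it directly, outside the LangGraph agent; the model choices, file path, and graph entry are placeholders, and an OpenAI key plus the packaged configs are assumed.

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from aiagents4pharma.talk2knowledgegraphs.tools.graphrag_reasoning import (
    GraphRAGReasoningTool,
)

# Placeholder state mirroring the keys the tool reads; in practice these are
# set by the Talk2KnowledgeGraphs agent and its frontend.
state = {
    "llm_model": ChatOpenAI(model="gpt-4o-mini", temperature=0),
    "embedding_model": OpenAIEmbeddings(),
    "uploaded_files": [
        {"file_type": "drug_data", "file_path": "docs/drug_profile.pdf"}
    ],
    "dic_extracted_graph": [
        {"name": "adalimumab_graph", "graph_summary": "TNF -> inflammation ..."}
    ],
}

command = GraphRAGReasoningTool()._run(
    tool_call_id="call_001",
    state=state,
    prompt="Which targets of adalimumab appear in the extracted subgraph?",
    extraction_name="adalimumab_graph",
)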
aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py
@@ -0,0 +1,22 @@
"""
A utility module defining the dataclasses
for the arguments used to set up initial settings.
"""

from dataclasses import dataclass
from typing import Annotated


@dataclass
class ArgumentData:
    """
    Dataclass for storing the argument data.
    """

    extraction_name: Annotated[
        str,
        """An AI-assigned, underscore-separated name for the subgraph extraction,
        based on the human query and the context of the graph reasoning
        experiment.
        This must be set before the subgraph extraction is invoked.""",
    ]
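
The Annotated metadata doubles as the LLM-facing description of the field, so the agent knows to pick a name before calling the extraction tool. A trivial usage sketch with an invented name:

from aiagents4pharma.talk2knowledgegraphs.tools.load_arguments import ArgumentData

# The underscore-separated name is normally chosen by the agent's LLM from
# the user query; this one is made up for illustration.
arg_data = ArgumentData(extraction_name="il6_signaling_subgraph")
print(arg_data.extraction_name)  # il6_signaling_subgraph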
aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py
@@ -0,0 +1,305 @@
"""
Tool for performing subgraph extraction.
"""

from typing import Type, Annotated
import logging
import pickle
import numpy as np
import pandas as pd
import hydra
import networkx as nx
from pydantic import BaseModel, Field
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.tools import BaseTool
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import InjectedToolCallId
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.types import Command
from langgraph.prebuilt import InjectedState
import torch
from torch_geometric.data import Data
from ..utils.extractions.pcst import PCSTPruning
from ..utils.embeddings.ollama import EmbeddingWithOllama
from .load_arguments import ArgumentData

# Initialize logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SubgraphExtractionInput(BaseModel):
    """
    SubgraphExtractionInput is a Pydantic model representing an input for extracting a subgraph.

    Args:
        prompt: Prompt to interact with the backend.
        tool_call_id: Tool call ID.
        state: Injected state.
        arg_data: Argument for analytical process over graph data.
    """

    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
        description="Tool call ID."
    )
    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
    prompt: str = Field(description="Prompt to interact with the backend.")
    arg_data: ArgumentData = Field(
        description="Experiment over graph data.", default=None
    )


class SubgraphExtractionTool(BaseTool):
    """
    This tool performs subgraph extraction based on the user's prompt, taking into account
    the top-k nodes and edges.
    """

    name: str = "subgraph_extraction"
    description: str = "A tool for subgraph extraction based on user's prompt."
    args_schema: Type[BaseModel] = SubgraphExtractionInput

    def perform_endotype_filtering(
        self,
        prompt: str,
        state: Annotated[dict, InjectedState],
        cfg: hydra.core.config_store.ConfigStore,
    ) -> str:
        """
        Perform endotype filtering based on the uploaded files and prepare the prompt.

        Args:
            prompt: The prompt to interact with the backend.
            state: Injected state for the tool.
            cfg: Hydra configuration object.
        """
        # Loop through the uploaded files
        all_genes = []
        for uploaded_file in state["uploaded_files"]:
            if uploaded_file["file_type"] == "endotype":
                # Load the PDF file
                docs = PyPDFLoader(file_path=uploaded_file["file_path"]).load()

                # Split the text into chunks
                splits = RecursiveCharacterTextSplitter(
                    chunk_size=cfg.splitter_chunk_size,
                    chunk_overlap=cfg.splitter_chunk_overlap,
                ).split_documents(docs)

                # Create a chat prompt template
                prompt_template = ChatPromptTemplate.from_messages(
                    [
                        ("system", cfg.prompt_endotype_filtering),
                        ("human", "{input}"),
                    ]
                )

                qa_chain = create_stuff_documents_chain(
                    state["llm_model"], prompt_template
                )
                rag_chain = create_retrieval_chain(
                    InMemoryVectorStore.from_documents(
                        documents=splits, embedding=state["embedding_model"]
                    ).as_retriever(
                        search_type=cfg.retriever_search_type,
                        search_kwargs={
                            "k": cfg.retriever_k,
                            "fetch_k": cfg.retriever_fetch_k,
                            "lambda_mult": cfg.retriever_lambda_mult,
                        },
                    ),
                    qa_chain,
                )
                results = rag_chain.invoke({"input": prompt})
                all_genes.append(results["answer"])

        # Prepare the prompt
        if len(all_genes) > 0:
            prompt = " ".join(
                [prompt, cfg.prompt_endotype_addition, ", ".join(all_genes)]
            )

        return prompt

    def prepare_final_subgraph(self,
                               subgraph: dict,
                               pyg_graph: Data,
                               textualized_graph: pd.DataFrame) -> dict:
        """
        Prepare the subgraph based on the extracted subgraph.

        Args:
            subgraph: The extracted subgraph.
            pyg_graph: The PyTorch Geometric graph.
            textualized_graph: The textualized graph.

        Returns:
            A dictionary containing the PyG graph, NetworkX graph, and textualized graph.
        """
        # print(subgraph)
        # Prepare the PyTorch Geometric graph
        mapping = {n: i for i, n in enumerate(subgraph["nodes"].tolist())}
        pyg_graph = Data(
            # Node features
            x=pyg_graph.x[subgraph["nodes"]],
            node_id=np.array(pyg_graph.node_id)[subgraph["nodes"]].tolist(),
            node_name=np.array(pyg_graph.node_id)[subgraph["nodes"]].tolist(),
            enriched_node=np.array(pyg_graph.enriched_node)[subgraph["nodes"]].tolist(),
            num_nodes=len(subgraph["nodes"]),
            # Edge features
            edge_index=torch.LongTensor(
                [
                    [
                        mapping[i]
                        for i in pyg_graph.edge_index[:, subgraph["edges"]][0].tolist()
                    ],
                    [
                        mapping[i]
                        for i in pyg_graph.edge_index[:, subgraph["edges"]][1].tolist()
                    ],
                ]
            ),
            edge_attr=pyg_graph.edge_attr[subgraph["edges"]],
            edge_type=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
            relation=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
            label=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
            enriched_edge=np.array(pyg_graph.enriched_edge)[subgraph["edges"]].tolist(),
        )

        # NetworkX DiGraph construction to be visualized in the frontend
        nx_graph = nx.DiGraph()
        for n in pyg_graph.node_name:
            nx_graph.add_node(n)
        for i, e in enumerate(
            [
                [pyg_graph.node_name[i], pyg_graph.node_name[j]]
                for (i, j) in pyg_graph.edge_index.transpose(1, 0)
            ]
        ):
            nx_graph.add_edge(
                e[0],
                e[1],
                relation=pyg_graph.edge_type[i],
                label=pyg_graph.edge_type[i],
            )

        # Prepare the textualized subgraph
        textualized_graph = (
            textualized_graph["nodes"].iloc[subgraph["nodes"]].to_csv(index=False)
            + "\n"
            + textualized_graph["edges"].iloc[subgraph["edges"]].to_csv(index=False)
        )

        return {
            "graph_pyg": pyg_graph,
            "graph_nx": nx_graph,
            "graph_text": textualized_graph,
        }

    def _run(
        self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        prompt: str,
        arg_data: ArgumentData = None,
    ) -> Command:
        """
        Run the subgraph extraction tool.

        Args:
            tool_call_id: The tool call ID for the tool.
            state: Injected state for the tool.
            prompt: The prompt to interact with the backend.
            arg_data (ArgumentData): The argument data.

        Returns:
            Command: The command to be executed.
        """
        logger.log(logging.INFO, "Invoking subgraph_extraction tool")

        # Load hydra configuration
        with hydra.initialize(version_base=None, config_path="../configs"):
            cfg = hydra.compose(
                config_name="config", overrides=["tools/subgraph_extraction=default"]
            )
            cfg = cfg.tools.subgraph_extraction

        # Retrieve source graph from the state
        initial_graph = {}
        initial_graph["source"] = state["dic_source_graph"][-1]  # The last source graph as of now
        # logger.log(logging.INFO, "Source graph: %s", source_graph)

        # Load the knowledge graph
        with open(initial_graph["source"]["kg_pyg_path"], "rb") as f:
            initial_graph["pyg"] = pickle.load(f)
        with open(initial_graph["source"]["kg_text_path"], "rb") as f:
            initial_graph["text"] = pickle.load(f)

        # Prepare prompt construction along with a list of endotypes
        if len(state["uploaded_files"]) != 0 and "endotype" in [
            f["file_type"] for f in state["uploaded_files"]
        ]:
            prompt = self.perform_endotype_filtering(prompt, state, cfg)

        # Prepare embedding model and embed the user prompt as query
        query_emb = torch.tensor(
            EmbeddingWithOllama(model_name=cfg.ollama_embeddings[0]).embed_query(prompt)
        ).float()

        # Prepare the PCSTPruning object and extract the subgraph
        # Parameters were set in the configuration file obtained from Hydra
        subgraph = PCSTPruning(
            state["topk_nodes"],
            state["topk_edges"],
            cfg.cost_e,
            cfg.c_const,
            cfg.root,
            cfg.num_clusters,
            cfg.pruning,
            cfg.verbosity_level,
        ).extract_subgraph(initial_graph["pyg"], query_emb)

        # Prepare subgraph as a NetworkX graph and textualized graph
        final_subgraph = self.prepare_final_subgraph(
            subgraph, initial_graph["pyg"], initial_graph["text"]
        )

        # Prepare the dictionary of extracted graph
        dic_extracted_graph = {
            "name": arg_data.extraction_name,
            "tool_call_id": tool_call_id,
            "graph_source": initial_graph["source"]["name"],
            "topk_nodes": state["topk_nodes"],
            "topk_edges": state["topk_edges"],
            "graph_dict": {
                "nodes": list(final_subgraph["graph_nx"].nodes(data=True)),
                "edges": list(final_subgraph["graph_nx"].edges(data=True)),
            },
            "graph_text": final_subgraph["graph_text"],
            "graph_summary": None,
        }

        # Prepare the dictionary of updated state
        dic_updated_state_for_model = {}
        for key, value in {
            "dic_extracted_graph": [dic_extracted_graph],
        }.items():
            if value:
                dic_updated_state_for_model[key] = value

        # Return the updated state of the tool
        return Command(
            update=dic_updated_state_for_model | {
                # update the message history
                "messages": [
                    ToolMessage(
                        content=f"Subgraph Extraction Result of {arg_data.extraction_name}",
                        tool_call_id=tool_call_id,
                    )
                ],
            }
        )
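
The subtlest step in prepare_final_subgraph is the re-indexing: PCST returns node and edge positions relative to the full graph, so the kept node ids are remapped to a dense 0..n-1 range before the new edge_index is built. A self-contained toy illustration of that trick, with invented tensor values:

import torch

nodes = torch.tensor([7, 2, 9])                # node ids kept by PCST
edges = torch.tensor([0, 2])                   # edge positions kept by PCST
full_edge_index = torch.tensor([[7, 2, 9],     # source nodes of all edges
                                [2, 9, 7]])    # target nodes of all edges

# Dense re-indexing: old node id -> new position, i.e. {7: 0, 2: 1, 9: 2}
mapping = {n: i for i, n in enumerate(nodes.tolist())}
sub = full_edge_index[:, edges]                # keep only the selected edges
edge_index = torch.LongTensor(
    [[mapping[i] for i in sub[0].tolist()],
     [mapping[j] for j in sub[1].tolist()]]
)
print(edge_index)  # tensor([[0, 2], [1, 0]])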
aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py
@@ -0,0 +1,126 @@
"""
Tool for performing subgraph summarization.
"""

import logging
from typing import Type, Annotated
from pydantic import BaseModel, Field
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import InjectedToolCallId
from langchain_core.tools import BaseTool
from langgraph.types import Command
from langgraph.prebuilt import InjectedState
import hydra

# Initialize logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SubgraphSummarizationInput(BaseModel):
    """
    SubgraphSummarizationInput is a Pydantic model representing an input for
    summarizing a given textualized subgraph.

    Args:
        tool_call_id: Tool call ID.
        state: Injected state.
        prompt: Prompt to interact with the backend.
        extraction_name: Name assigned to the subgraph extraction process.
    """

    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
        description="Tool call ID."
    )
    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
    prompt: str = Field(description="Prompt to interact with the backend.")
    extraction_name: str = Field(
        description="""Name assigned to the subgraph extraction process
        when the subgraph_extraction tool is invoked."""
    )


class SubgraphSummarizationTool(BaseTool):
    """
    This tool performs subgraph summarization over a textualized graph to highlight the
    most important information in response to the user's prompt.
    """

    name: str = "subgraph_summarization"
    description: str = """A tool to perform subgraph summarization over a textualized graph
    for responding to the user's follow-up prompt(s)."""
    args_schema: Type[BaseModel] = SubgraphSummarizationInput

    def _run(
        self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        prompt: str,
        extraction_name: str,
    ):
        """
        Run the subgraph summarization tool.

        Args:
            tool_call_id: The tool call ID.
            state: The injected state.
            prompt: The prompt to interact with the backend.
            extraction_name: The name assigned to the subgraph extraction process.
        """
        logger.log(
            logging.INFO, "Invoking subgraph_summarization tool for %s", extraction_name
        )

        # Load hydra configuration
        with hydra.initialize(version_base=None, config_path="../configs"):
            cfg = hydra.compose(
                config_name="config", overrides=["tools/subgraph_summarization=default"]
            )
            cfg = cfg.tools.subgraph_summarization

        # Load the extracted graph
        extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
        # logger.log(logging.INFO, "Extracted graph: %s", extracted_graph)

        # Prepare prompt template
        prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", cfg.prompt_subgraph_summarization),
                ("human", "{input}"),
            ]
        )

        # Prepare chain
        chain = prompt_template | state["llm_model"] | StrOutputParser()

        # Invoke the chain to summarize the textualized subgraph
        response = chain.invoke(
            {
                "input": prompt,
                "textualized_subgraph": extracted_graph[extraction_name]["graph_text"],
            }
        )

        # Store the response as graph_summary in the extracted graph
        for key, value in extracted_graph.items():
            if key == extraction_name:
                value["graph_summary"] = response

        # Prepare the dictionary of updated state
        dic_updated_state_for_model = {}
        for key, value in {
            "dic_extracted_graph": list(extracted_graph.values()),
        }.items():
            if value:
                dic_updated_state_for_model[key] = value

        # Return the updated state of the tool
        return Command(
            update=dic_updated_state_for_model
            | {
                # update the message history
                "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)]
            }
        )
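
The summarizer itself is a plain prompt | llm | parser pipeline; only the system prompt (shipped in configs/tools/subgraph_summarization/default.yaml) and the injected textualized subgraph are package-specific. A minimal standalone sketch, with an assumed stand-in system prompt and toy inputs:

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

# Assumed stand-in for cfg.prompt_subgraph_summarization; the shipped prompt
# likewise exposes a {textualized_subgraph} placeholder.
system_prompt = (
    "Summarize the following subgraph, keeping only facts relevant to the "
    "user's question:\n{textualized_subgraph}"
)

chain = (
    ChatPromptTemplate.from_messages(
        [("system", system_prompt), ("human", "{input}")]
    )
    | ChatOpenAI(model="gpt-4o-mini", temperature=0)
    | StrOutputParser()
)

summary = chain.invoke(
    {
        "input": "Focus on IL-6 interactions.",
        "textualized_subgraph": "node_id,node_attr\nIL6,cytokine\n...",
    }
)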
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py
@@ -0,0 +1,81 @@
"""
Embedding class using an Ollama model, based on the LangChain Embeddings class.
"""

import time
from typing import List
import subprocess
import ollama
from langchain_ollama import OllamaEmbeddings
from .embeddings import Embeddings

class EmbeddingWithOllama(Embeddings):
    """
    Embedding class using an Ollama model, based on the LangChain Embeddings class.
    """
    def __init__(self, model_name: str):
        """
        Initialize the EmbeddingWithOllama class.

        Args:
            model_name: The name of the Ollama model to be used.
        """
        # Setup the Ollama server
        self.__setup(model_name)

        # Set parameters
        self.model_name = model_name

        # Prepare model
        self.model = OllamaEmbeddings(model=self.model_name)

    def __setup(self, model_name: str) -> None:
        """
        Check if the Ollama model is available and run the Ollama server if needed.

        Args:
            model_name: The name of the Ollama model to be used.
        """
        try:
            models_list = ollama.list()["models"]
            if model_name not in [m["model"].replace(":latest", "") for m in models_list]:
                ollama.pull(model_name)
                time.sleep(30)
                raise ValueError(f"Pulled {model_name} model")
        except Exception as e:
            with subprocess.Popen(
                "ollama serve", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            ):
                time.sleep(10)
            raise ValueError(f"Error: {e} and restarted Ollama server.") from e

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a list of input texts using the Ollama model.

        Args:
            texts: The list of texts to be embedded.

        Returns:
            The list of embeddings for the given texts.
        """

        # Generate the embeddings
        embeddings = self.model.embed_documents(texts)

        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """
        Generate an embedding for an input text using the Ollama model.

        Args:
            text: A query to be embedded.
        Returns:
            The embedding for the given query.
        """

        # Generate the embedding
        embeddings = self.model.embed_query(text)

        return embeddings
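
A quick usage sketch follows; it assumes Ollama is installed locally and uses nomic-embed-text as an example model name (the model actually used by subgraph_extraction comes from cfg.ollama_embeddings in the Hydra config):

from aiagents4pharma.talk2knowledgegraphs.utils.embeddings.ollama import (
    EmbeddingWithOllama,
)

# Pulls the model on first use and (re)starts the local server if needed.
emb = EmbeddingWithOllama(model_name="nomic-embed-text")

vec = emb.embed_query("protein-protein interaction")
print(len(vec))  # embedding dimensionality

docs = emb.embed_documents(["TNF inhibits ...", "IL-6 activates ..."])
print(len(docs), len(docs[0]))  # number of texts, dimensionality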