aiagents4pharma 1.17.1__py3-none-any.whl → 1.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2biomodels/agents/t2b_agent.py +4 -4
- aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml +7 -15
- aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +4 -1
- aiagents4pharma/talk2biomodels/tests/test_ask_question.py +4 -2
- aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +4 -2
- aiagents4pharma/talk2biomodels/tests/test_integration.py +34 -30
- aiagents4pharma/talk2biomodels/tests/test_query_article.py +7 -1
- aiagents4pharma/talk2biomodels/tests/test_search_models.py +3 -1
- aiagents4pharma/talk2biomodels/tests/test_steady_state.py +6 -3
- aiagents4pharma/talk2biomodels/tools/ask_question.py +1 -2
- aiagents4pharma/talk2biomodels/tools/custom_plotter.py +23 -10
- aiagents4pharma/talk2biomodels/tools/get_annotation.py +11 -10
- aiagents4pharma/talk2biomodels/tools/query_article.py +6 -2
- aiagents4pharma/talk2biomodels/tools/search_models.py +8 -2
- aiagents4pharma/talk2knowledgegraphs/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +85 -0
- aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +7 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +31 -0
- aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +7 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +6 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
- aiagents4pharma/talk2knowledgegraphs/states/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +38 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +110 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +210 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +174 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +154 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +0 -1
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +56 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +18 -42
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +79 -0
- aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +6 -0
- aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +143 -0
- aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +305 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +126 -0
- aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +4 -2
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +81 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +225 -0
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/METADATA +12 -3
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/RECORD +56 -24
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,143 @@
|
|
1
|
+
"""
|
2
|
+
Tool for performing Graph RAG reasoning.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import logging
|
6
|
+
from typing import Type, Annotated
|
7
|
+
from pydantic import BaseModel, Field
|
8
|
+
from langchain_core.prompts import ChatPromptTemplate
|
9
|
+
from langchain_core.messages import ToolMessage
|
10
|
+
from langchain_core.tools.base import InjectedToolCallId
|
11
|
+
from langchain_core.tools import BaseTool
|
12
|
+
from langchain_core.vectorstores import InMemoryVectorStore
|
13
|
+
from langchain.chains.retrieval import create_retrieval_chain
|
14
|
+
from langchain.chains.combine_documents import create_stuff_documents_chain
|
15
|
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
16
|
+
from langchain_community.document_loaders import PyPDFLoader
|
17
|
+
from langgraph.types import Command
|
18
|
+
from langgraph.prebuilt import InjectedState
|
19
|
+
import hydra
|
20
|
+
|
21
|
+
# Initialize logger
|
22
|
+
logging.basicConfig(level=logging.INFO)
|
23
|
+
logger = logging.getLogger(__name__)
|
24
|
+
|
25
|
+
|
26
|
+
class GraphRAGReasoningInput(BaseModel):
    """
    GraphRAGReasoningInput is a Pydantic model representing an input for Graph RAG reasoning.

    Args:
        tool_call_id: Tool call ID (annotated with InjectedToolCallId).
        state: Injected state (annotated with InjectedState).
        prompt: Prompt to interact with the backend.
        extraction_name: Name assigned to the subgraph extraction process
            when the subgraph_extraction tool is invoked.
    """

    # Fields annotated with Injected* markers are declared first, followed by
    # the plain `prompt` and `extraction_name` arguments.
    tool_call_id: Annotated[str, InjectedToolCallId] = Field(description="Tool call ID.")
    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
    prompt: str = Field(description="Prompt to interact with the backend.")
    extraction_name: str = Field(
        description="""Name assigned to the subgraph extraction process
        when the subgraph_extraction tool is invoked."""
    )
|
45
|
+
|
46
|
+
|
47
|
+
class GraphRAGReasoningTool(BaseTool):
    """
    This tool performs reasoning using a Graph Retrieval-Augmented Generation (RAG) approach
    over user's request by considering textualized subgraph context and document context.
    """

    name: str = "graphrag_reasoning"
    description: str = """A tool to perform reasoning using a Graph RAG approach
    by considering textualized subgraph context and document context."""
    args_schema: Type[BaseModel] = GraphRAGReasoningInput

    def _run(
        self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        prompt: str,
        extraction_name: str,
    ):
        """
        Answer ``prompt`` with a retrieval chain over the uploaded "drug_data"
        PDFs, passing the stored summary of the named subgraph to the chain.

        Args:
            tool_call_id: The tool call ID.
            state: The injected state. This method reads the keys
                "uploaded_files", "dic_extracted_graph", "llm_model" and
                "embedding_model".
            prompt: The prompt to interact with the backend.
            extraction_name: The name assigned to the subgraph extraction process.

        Returns:
            A Command whose update appends one ToolMessage to "messages".

        Raises:
            KeyError: If no extracted graph is stored under ``extraction_name``.
        """
        logger.log(logging.INFO, "Invoking graphrag_reasoning tool for %s", extraction_name)

        # Load Hydra configuration for this tool
        with hydra.initialize(version_base=None, config_path="../configs"):
            composed = hydra.compose(
                config_name="config", overrides=["tools/graphrag_reasoning=default"]
            )
        cfg = composed.tools.graphrag_reasoning

        # Collect chunks from every uploaded file tagged as drug data
        all_docs = []
        for uploaded_file in state["uploaded_files"]:
            if uploaded_file["file_type"] != "drug_data":
                continue
            pages = PyPDFLoader(file_path=uploaded_file["file_path"]).load()
            # May need to find an optimal chunk size and overlap configuration
            splitter = RecursiveCharacterTextSplitter(
                chunk_size=cfg.splitter_chunk_size,
                chunk_overlap=cfg.splitter_chunk_overlap,
            )
            all_docs.extend(splitter.split_documents(pages))

        # Index the extracted graphs by their assigned name
        extracted_graph = {}
        for graph in state["dic_extracted_graph"]:
            extracted_graph[graph["name"]] = graph

        # Prompt template: system prompt from the configuration + human input
        prompt_template = ChatPromptTemplate.from_messages(
            [("system", cfg.prompt_graphrag_w_docs), ("human", "{input}")]
        )

        # Prepare chain with retrieved documents
        qa_chain = create_stuff_documents_chain(state["llm_model"], prompt_template)
        vector_store = InMemoryVectorStore.from_documents(
            documents=all_docs, embedding=state["embedding_model"]
        )
        retriever = vector_store.as_retriever(
            search_type=cfg.retriever_search_type,
            search_kwargs={
                "k": cfg.retriever_k,
                "fetch_k": cfg.retriever_fetch_k,
                "lambda_mult": cfg.retriever_lambda_mult,
            },
        )
        rag_chain = create_retrieval_chain(retriever, qa_chain)

        # Invoke the chain
        chain_inputs = {
            "input": prompt,
            "subgraph_summary": extracted_graph[extraction_name]["graph_summary"],
        }
        response = rag_chain.invoke(chain_inputs)

        # NOTE(review): the chain's whole return value is used as the message
        # content, not only an "answer" entry - confirm this is intended.
        tool_message = ToolMessage(content=response, tool_call_id=tool_call_id)

        # update the message history
        return Command(update={"messages": [tool_message]})
|
@@ -0,0 +1,22 @@
|
|
1
|
+
"""
|
2
|
+
A utility module for defining the dataclasses
|
3
|
+
for the arguments to set up initial settings
|
4
|
+
"""
|
5
|
+
|
6
|
+
from dataclasses import dataclass
|
7
|
+
from typing import Annotated
|
8
|
+
|
9
|
+
|
10
|
+
@dataclass
class ArgumentData:
    """
    Dataclass for storing the argument data.

    Attributes:
        extraction_name: Name of the subgraph extraction; the annotation
            metadata below carries the human-readable description.
    """

    extraction_name: Annotated[
        str,
        """An AI assigned _ separated name of the subgraph extraction
        based on human query and the context of the graph reasoning
        experiment.
        This must be set before the subgraph extraction is invoked.""",
    ]
|
@@ -0,0 +1,305 @@
|
|
1
|
+
"""
|
2
|
+
Tool for performing subgraph extraction.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from typing import Type, Annotated
|
6
|
+
import logging
|
7
|
+
import pickle
|
8
|
+
import numpy as np
|
9
|
+
import pandas as pd
|
10
|
+
import hydra
|
11
|
+
import networkx as nx
|
12
|
+
from pydantic import BaseModel, Field
|
13
|
+
from langchain.chains.retrieval import create_retrieval_chain
|
14
|
+
from langchain.chains.combine_documents import create_stuff_documents_chain
|
15
|
+
from langchain_core.prompts import ChatPromptTemplate
|
16
|
+
from langchain_core.vectorstores import InMemoryVectorStore
|
17
|
+
from langchain_core.tools import BaseTool
|
18
|
+
from langchain_core.messages import ToolMessage
|
19
|
+
from langchain_core.tools.base import InjectedToolCallId
|
20
|
+
from langchain_community.document_loaders import PyPDFLoader
|
21
|
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
22
|
+
from langgraph.types import Command
|
23
|
+
from langgraph.prebuilt import InjectedState
|
24
|
+
import torch
|
25
|
+
from torch_geometric.data import Data
|
26
|
+
from ..utils.extractions.pcst import PCSTPruning
|
27
|
+
from ..utils.embeddings.ollama import EmbeddingWithOllama
|
28
|
+
from .load_arguments import ArgumentData
|
29
|
+
|
30
|
+
# Initialize logger
|
31
|
+
logging.basicConfig(level=logging.INFO)
|
32
|
+
logger = logging.getLogger(__name__)
|
33
|
+
|
34
|
+
|
35
|
+
class SubgraphExtractionInput(BaseModel):
    """
    SubgraphExtractionInput is a Pydantic model representing an input for extracting a subgraph.

    Args:
        tool_call_id: Tool call ID (annotated with InjectedToolCallId).
        state: Injected state (annotated with InjectedState).
        prompt: Prompt to interact with the backend.
        arg_data: Argument for analytical process over graph data
            (defaults to None in this schema).
    """

    tool_call_id: Annotated[str, InjectedToolCallId] = Field(description="Tool call ID.")
    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
    prompt: str = Field(description="Prompt to interact with the backend.")
    arg_data: ArgumentData = Field(description="Experiment over graph data.", default=None)
|
54
|
+
|
55
|
+
|
56
|
+
class SubgraphExtractionTool(BaseTool):
    """
    This tool performs subgraph extraction based on user's prompt by taking into account
    the top-k nodes and edges.
    """

    name: str = "subgraph_extraction"
    description: str = "A tool for subgraph extraction based on user's prompt."
    args_schema: Type[BaseModel] = SubgraphExtractionInput

    def perform_endotype_filtering(
        self,
        prompt: str,
        state: Annotated[dict, InjectedState],
        cfg: hydra.core.config_store.ConfigStore,
    ) -> str:
        """
        Perform endotype filtering based on the uploaded files and prepare the prompt.

        For every uploaded file whose "file_type" is "endotype", the PDF is
        loaded, split into chunks, indexed in an in-memory vector store and
        queried with ``prompt`` through a retrieval chain. The collected
        answers are appended to the prompt.

        Args:
            prompt: The prompt to interact with the backend.
            state: Injected state for the tool. Reads "uploaded_files",
                "llm_model" and "embedding_model".
            cfg: Hydra configuration object for this tool.

        Returns:
            The original prompt when no endotype file produced an answer;
            otherwise the prompt joined (space separated) with
            ``cfg.prompt_endotype_addition`` and the comma-separated answers.
        """
        all_genes = []
        for uploaded_file in state["uploaded_files"]:
            if uploaded_file["file_type"] != "endotype":
                continue

            # Load the PDF file
            docs = PyPDFLoader(file_path=uploaded_file["file_path"]).load()

            # Split the text into chunks
            splits = RecursiveCharacterTextSplitter(
                chunk_size=cfg.splitter_chunk_size,
                chunk_overlap=cfg.splitter_chunk_overlap,
            ).split_documents(docs)

            # Create a chat prompt template
            prompt_template = ChatPromptTemplate.from_messages(
                [
                    ("system", cfg.prompt_endotype_filtering),
                    ("human", "{input}"),
                ]
            )

            qa_chain = create_stuff_documents_chain(state["llm_model"], prompt_template)
            rag_chain = create_retrieval_chain(
                InMemoryVectorStore.from_documents(
                    documents=splits, embedding=state["embedding_model"]
                ).as_retriever(
                    search_type=cfg.retriever_search_type,
                    search_kwargs={
                        "k": cfg.retriever_k,
                        "fetch_k": cfg.retriever_fetch_k,
                        "lambda_mult": cfg.retriever_lambda_mult,
                    },
                ),
                qa_chain,
            )
            results = rag_chain.invoke({"input": prompt})
            all_genes.append(results["answer"])

        # Prepare the prompt
        if all_genes:
            prompt = " ".join(
                [prompt, cfg.prompt_endotype_addition, ", ".join(all_genes)]
            )

        return prompt

    def prepare_final_subgraph(self,
                               subgraph: dict,
                               pyg_graph: Data,
                               textualized_graph: pd.DataFrame) -> dict:
        """
        Prepare the subgraph based on the extracted subgraph.

        Args:
            subgraph: The extracted subgraph; its "nodes" and "edges" entries
                are used as index arrays into the source graph.
            pyg_graph: The PyTorch Geometric graph the indices refer to.
            textualized_graph: The textualized graph; indexed with "nodes" and
                "edges", each of which is sliced with ``.iloc``.
                NOTE(review): despite the DataFrame annotation this is used as
                a mapping of two frames - confirm against the caller.

        Returns:
            A dictionary containing the PyG graph ("graph_pyg"), the NetworkX
            graph ("graph_nx") and the textualized graph ("graph_text").
        """
        node_idx = subgraph["nodes"]
        edge_idx = subgraph["edges"]

        # Map original node indices to positions inside the subgraph
        mapping = {n: i for i, n in enumerate(node_idx.tolist())}

        # Compute each selection once; the same values feed several attributes
        node_ids = np.array(pyg_graph.node_id)[node_idx].tolist()
        edge_types = np.array(pyg_graph.edge_type)[edge_idx].tolist()
        src_orig, dst_orig = pyg_graph.edge_index[:, edge_idx].tolist()
        src_local = [mapping[i] for i in src_orig]
        dst_local = [mapping[i] for i in dst_orig]

        # Prepare the PyTorch Geometric graph. Separate list copies are passed
        # so the attributes do not alias one another.
        sub_pyg = Data(
            # Node features
            x=pyg_graph.x[node_idx],
            node_id=node_ids,
            node_name=list(node_ids),
            enriched_node=np.array(pyg_graph.enriched_node)[node_idx].tolist(),
            num_nodes=len(node_idx),
            # Edge features
            edge_index=torch.LongTensor([src_local, dst_local]),
            edge_attr=pyg_graph.edge_attr[edge_idx],
            edge_type=edge_types,
            relation=list(edge_types),
            label=list(edge_types),
            enriched_edge=np.array(pyg_graph.enriched_edge)[edge_idx].tolist(),
        )

        # Networkx DiGraph construction to be visualized in the frontend
        nx_graph = nx.DiGraph()
        for n in node_ids:
            nx_graph.add_node(n)
        for src, dst, edge_type in zip(src_local, dst_local, edge_types):
            nx_graph.add_edge(
                node_ids[src],
                node_ids[dst],
                relation=edge_type,
                label=edge_type,
            )

        # Prepare the textualized subgraph: node CSV, newline, edge CSV
        graph_text = (
            textualized_graph["nodes"].iloc[node_idx].to_csv(index=False)
            + "\n"
            + textualized_graph["edges"].iloc[edge_idx].to_csv(index=False)
        )

        return {
            "graph_pyg": sub_pyg,
            "graph_nx": nx_graph,
            "graph_text": graph_text,
        }

    def _run(
        self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        prompt: str,
        arg_data: ArgumentData = None,
    ) -> Command:
        """
        Run the subgraph extraction tool.

        Args:
            tool_call_id: The tool call ID for the tool.
            state: Injected state for the tool. Reads "dic_source_graph",
                "uploaded_files", "topk_nodes" and "topk_edges".
            prompt: The prompt to interact with the backend.
            arg_data (ArgumentData): The argument data. Although the signature
                defaults it to None, it is required: its ``extraction_name``
                names the stored result.

        Returns:
            Command: The command to be executed; its update carries the new
            "dic_extracted_graph" entry and a ToolMessage.

        Raises:
            ValueError: If ``arg_data`` is None.
        """
        logger.log(logging.INFO, "Invoking subgraph_extraction tool")

        # Fail fast with an actionable message instead of an AttributeError
        # raised only after the expensive extraction has already run.
        if arg_data is None:
            raise ValueError(
                "arg_data with an extraction_name is required to run subgraph_extraction."
            )

        # Load hydra configuration
        with hydra.initialize(version_base=None, config_path="../configs"):
            cfg = hydra.compose(
                config_name="config", overrides=["tools/subgraph_extraction=default"]
            )
            cfg = cfg.tools.subgraph_extraction

        # Retrieve source graph from the state
        initial_graph = {}
        initial_graph["source"] = state["dic_source_graph"][-1]  # The last source graph as of now

        # Load the knowledge graph
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # these paths must only ever point to trusted artifacts.
        with open(initial_graph["source"]["kg_pyg_path"], "rb") as f:
            initial_graph["pyg"] = pickle.load(f)
        with open(initial_graph["source"]["kg_text_path"], "rb") as f:
            initial_graph["text"] = pickle.load(f)

        # Prepare prompt construction along with a list of endotypes
        if any(f["file_type"] == "endotype" for f in state["uploaded_files"]):
            prompt = self.perform_endotype_filtering(prompt, state, cfg)

        # Prepare embedding model and embed the user prompt as query
        query_emb = torch.tensor(
            EmbeddingWithOllama(model_name=cfg.ollama_embeddings[0]).embed_query(prompt)
        ).float()

        # Prepare the PCSTPruning object and extract the subgraph
        # Parameters were set in the configuration file obtained from Hydra
        subgraph = PCSTPruning(
            state["topk_nodes"],
            state["topk_edges"],
            cfg.cost_e,
            cfg.c_const,
            cfg.root,
            cfg.num_clusters,
            cfg.pruning,
            cfg.verbosity_level,
        ).extract_subgraph(initial_graph["pyg"], query_emb)

        # Prepare subgraph as a NetworkX graph and textualized graph
        final_subgraph = self.prepare_final_subgraph(
            subgraph, initial_graph["pyg"], initial_graph["text"]
        )

        # Prepare the dictionary of extracted graph
        dic_extracted_graph = {
            "name": arg_data.extraction_name,
            "tool_call_id": tool_call_id,
            "graph_source": initial_graph["source"]["name"],
            "topk_nodes": state["topk_nodes"],
            "topk_edges": state["topk_edges"],
            "graph_dict": {
                "nodes": list(final_subgraph["graph_nx"].nodes(data=True)),
                "edges": list(final_subgraph["graph_nx"].edges(data=True)),
            },
            "graph_text": final_subgraph["graph_text"],
            # No summary is produced by this tool
            "graph_summary": None,
        }

        # Return the updated state of the tool
        return Command(
            update={
                "dic_extracted_graph": [dic_extracted_graph],
                # update the message history
                "messages": [
                    ToolMessage(
                        content=f"Subgraph Extraction Result of {arg_data.extraction_name}",
                        tool_call_id=tool_call_id,
                    )
                ],
            }
        )
|
@@ -0,0 +1,126 @@
|
|
1
|
+
"""
|
2
|
+
Tool for performing subgraph summarization.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import logging
|
6
|
+
from typing import Type, Annotated
|
7
|
+
from pydantic import BaseModel, Field
|
8
|
+
from langchain_core.output_parsers import StrOutputParser
|
9
|
+
from langchain_core.prompts import ChatPromptTemplate
|
10
|
+
from langchain_core.messages import ToolMessage
|
11
|
+
from langchain_core.tools.base import InjectedToolCallId
|
12
|
+
from langchain_core.tools import BaseTool
|
13
|
+
from langgraph.types import Command
|
14
|
+
from langgraph.prebuilt import InjectedState
|
15
|
+
import hydra
|
16
|
+
|
17
|
+
# Initialize logger
|
18
|
+
logging.basicConfig(level=logging.INFO)
|
19
|
+
logger = logging.getLogger(__name__)
|
20
|
+
|
21
|
+
|
22
|
+
class SubgraphSummarizationInput(BaseModel):
    """
    SubgraphSummarizationInput is a Pydantic model representing an input for
    summarizing a given textualized subgraph.

    Args:
        tool_call_id: Tool call ID (annotated with InjectedToolCallId).
        state: Injected state (annotated with InjectedState).
        prompt: Prompt to interact with the backend.
        extraction_name: Name assigned to the subgraph extraction process
            when the subgraph_extraction tool is invoked.
    """

    tool_call_id: Annotated[str, InjectedToolCallId] = Field(description="Tool call ID.")
    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
    prompt: str = Field(description="Prompt to interact with the backend.")
    extraction_name: str = Field(
        description="""Name assigned to the subgraph extraction process
        when the subgraph_extraction tool is invoked."""
    )
|
43
|
+
|
44
|
+
|
45
|
+
class SubgraphSummarizationTool(BaseTool):
    """
    This tool performs subgraph summarization over textualized graph to highlight the most
    important information in responding to user's prompt.
    """

    name: str = "subgraph_summarization"
    description: str = """A tool to perform subgraph summarization over textualized graph
    for responding to user's follow-up prompt(s)."""
    args_schema: Type[BaseModel] = SubgraphSummarizationInput

    def _run(
        self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        prompt: str,
        extraction_name: str,
    ):
        """
        Summarize the textualized subgraph stored under ``extraction_name``
        and record the summary on that extracted-graph entry.

        Args:
            tool_call_id: The tool call ID.
            state: The injected state. Reads "dic_extracted_graph" and
                "llm_model".
            prompt: The prompt to interact with the backend.
            extraction_name: The name assigned to the subgraph extraction process.

        Returns:
            A Command whose update carries the extracted graphs (with the new
            "graph_summary" set on the named one) and a ToolMessage holding
            the summary text.

        Raises:
            KeyError: If no extracted graph is stored under ``extraction_name``.
        """
        logger.log(
            logging.INFO, "Invoking subgraph_summarization tool for %s", extraction_name
        )

        # Load hydra configuration
        with hydra.initialize(version_base=None, config_path="../configs"):
            composed = hydra.compose(
                config_name="config", overrides=["tools/subgraph_summarization=default"]
            )
        cfg = composed.tools.subgraph_summarization

        # Index the extracted graphs by name
        graphs_by_name = {}
        for graph in state["dic_extracted_graph"]:
            graphs_by_name[graph["name"]] = graph

        # Prompt -> LLM -> plain string
        prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", cfg.prompt_subgraph_summarization),
                ("human", "{input}"),
            ]
        )
        chain = prompt_template | state["llm_model"] | StrOutputParser()

        # Summarize the textualized subgraph of the requested extraction
        target_graph = graphs_by_name[extraction_name]
        response = chain.invoke(
            {"input": prompt, "textualized_subgraph": target_graph["graph_text"]}
        )

        # Store the response as graph_summary in the extracted graph
        # (this writes into the dict object taken from the state)
        target_graph["graph_summary"] = response

        # Prepare the dictionary of updated state
        state_update = {}
        updated_graphs = list(graphs_by_name.values())
        if updated_graphs:
            state_update["dic_extracted_graph"] = updated_graphs
        # update the message history
        state_update["messages"] = [ToolMessage(content=response, tool_call_id=tool_call_id)]

        # Return the updated state of the tool
        return Command(update=state_update)
|
@@ -0,0 +1,81 @@
|
|
1
|
+
"""
|
2
|
+
Embedding class using Ollama model based on LangChain Embeddings class.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import time
|
6
|
+
from typing import List
|
7
|
+
import subprocess
|
8
|
+
import ollama
|
9
|
+
from langchain_ollama import OllamaEmbeddings
|
10
|
+
from .embeddings import Embeddings
|
11
|
+
|
12
|
+
class EmbeddingWithOllama(Embeddings):
    """
    Embedding class using Ollama model based on LangChain Embeddings class.
    """

    def __init__(self, model_name: str):
        """
        Initialize the EmbeddingWithOllama class.

        Args:
            model_name: The name of the Ollama model to be used.

        Raises:
            ValueError: Propagated from the setup step (see ``__setup``).
        """
        # Setup the Ollama server
        self.__setup(model_name)

        # Set parameters
        self.model_name = model_name

        # Prepare model
        self.model = OllamaEmbeddings(model=self.model_name)

    def __setup(self, model_name: str) -> None:
        """
        Check if the Ollama model is available and run the Ollama server if needed.

        If the model is not in the local list it is pulled, and a ValueError is
        then raised deliberately. That error, like any other raised in the
        ``try`` body, lands in the handler below, which launches
        ``ollama serve`` and re-raises as ValueError. A first call that had to
        pull the model therefore always ends in an exception.

        Args:
            model_name: The name of the Ollama model to be used.

        Raises:
            ValueError: If the model had to be pulled or any step failed.
        """
        try:
            models_list = ollama.list()["models"]
            # Local model names may carry a ":latest" tag; compare without it
            available = {m['model'].replace(":latest", "") for m in models_list}
            if model_name not in available:
                ollama.pull(model_name)
                time.sleep(30)
                raise ValueError(f"Pulled {model_name} model")
        except Exception as e:
            # NOTE(review): leaving this ``with`` block waits for the spawned
            # process to exit (see subprocess.Popen docs) - confirm that this
            # is compatible with keeping a long-running server alive.
            with subprocess.Popen(
                "ollama serve", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            ):
                time.sleep(10)
            raise ValueError(f"Error: {e} and restarted Ollama server.") from e

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embedding for a list of input texts using Ollama model.

        Args:
            texts: The list of texts to be embedded.

        Returns:
            The list of embeddings for the given texts (one vector per text).
        """
        # Generate the embedding
        embeddings = self.model.embed_documents(texts)

        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """
        Generate embeddings for an input text using Ollama model.

        Args:
            text: A query to be embedded.

        Returns:
            The embeddings for the given query.
        """
        # Generate the embedding
        embeddings = self.model.embed_query(text)

        return embeddings
|