aiagents4pharma 1.17.1__py3-none-any.whl → 1.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. aiagents4pharma/talk2biomodels/agents/t2b_agent.py +4 -4
  2. aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml +7 -15
  3. aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +4 -1
  4. aiagents4pharma/talk2biomodels/tests/test_ask_question.py +4 -2
  5. aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +4 -2
  6. aiagents4pharma/talk2biomodels/tests/test_integration.py +34 -30
  7. aiagents4pharma/talk2biomodels/tests/test_query_article.py +7 -1
  8. aiagents4pharma/talk2biomodels/tests/test_search_models.py +3 -1
  9. aiagents4pharma/talk2biomodels/tests/test_steady_state.py +6 -3
  10. aiagents4pharma/talk2biomodels/tools/ask_question.py +1 -2
  11. aiagents4pharma/talk2biomodels/tools/custom_plotter.py +23 -10
  12. aiagents4pharma/talk2biomodels/tools/get_annotation.py +11 -10
  13. aiagents4pharma/talk2biomodels/tools/query_article.py +6 -2
  14. aiagents4pharma/talk2biomodels/tools/search_models.py +8 -2
  15. aiagents4pharma/talk2knowledgegraphs/__init__.py +3 -0
  16. aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +4 -0
  17. aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +85 -0
  18. aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +7 -0
  19. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
  20. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
  21. aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +4 -0
  22. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
  23. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +31 -0
  24. aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +7 -0
  25. aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +6 -0
  26. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
  27. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
  28. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
  29. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
  30. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
  31. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
  32. aiagents4pharma/talk2knowledgegraphs/states/__init__.py +4 -0
  33. aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +38 -0
  34. aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +110 -0
  35. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +210 -0
  36. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +174 -0
  37. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +154 -0
  38. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +0 -1
  39. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +56 -0
  40. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +18 -42
  41. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +79 -0
  42. aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +6 -0
  43. aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +143 -0
  44. aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
  45. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +305 -0
  46. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +126 -0
  47. aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +4 -2
  48. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +1 -0
  49. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +81 -0
  50. aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +4 -0
  51. aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +225 -0
  52. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/METADATA +12 -3
  53. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/RECORD +56 -24
  54. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/LICENSE +0 -0
  55. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/WHEEL +0 -0
  56. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,143 @@
+"""
+Tool for performing Graph RAG reasoning.
+"""
+
+import logging
+from typing import Type, Annotated
+from pydantic import BaseModel, Field
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from langchain_core.tools import BaseTool
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain.chains.retrieval import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import PyPDFLoader
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+import hydra
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class GraphRAGReasoningInput(BaseModel):
+    """
+    GraphRAGReasoningInput is a Pydantic model representing an input for Graph RAG reasoning.
+
+    Args:
+        state: Injected state.
+        prompt: Prompt to interact with the backend.
+        extraction_name: Name assigned to the subgraph extraction process
+    """
+
+    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+        description="Tool call ID."
+    )
+    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
+    prompt: str = Field(description="Prompt to interact with the backend.")
+    extraction_name: str = Field(
+        description="""Name assigned to the subgraph extraction process
+        when the subgraph_extraction tool is invoked."""
+    )
+
+
+class GraphRAGReasoningTool(BaseTool):
+    """
+    This tool performs reasoning using a Graph Retrieval-Augmented Generation (RAG) approach
+    over user's request by considering textualized subgraph context and document context.
+    """
+
+    name: str = "graphrag_reasoning"
+    description: str = """A tool to perform reasoning using a Graph RAG approach
+    by considering textualized subgraph context and document context."""
+    args_schema: Type[BaseModel] = GraphRAGReasoningInput
+
+    def _run(
+        self,
+        tool_call_id: Annotated[str, InjectedToolCallId],
+        state: Annotated[dict, InjectedState],
+        prompt: str,
+        extraction_name: str,
+    ):
+        """
+        Run the Graph RAG reasoning tool.
+
+        Args:
+            tool_call_id: The tool call ID.
+            state: The injected state.
+            prompt: The prompt to interact with the backend.
+            extraction_name: The name assigned to the subgraph extraction process.
+        """
+        logger.log(
+            logging.INFO, "Invoking graphrag_reasoning tool for %s", extraction_name
+        )
+
+        # Load Hydra configuration
+        with hydra.initialize(version_base=None, config_path="../configs"):
+            cfg = hydra.compose(
+                config_name="config", overrides=["tools/graphrag_reasoning=default"]
+            )
+            cfg = cfg.tools.graphrag_reasoning
+
+        # Prepare documents
+        all_docs = []
+        if len(state["uploaded_files"]) != 0:
+            for uploaded_file in state["uploaded_files"]:
+                if uploaded_file["file_type"] == "drug_data":
+                    # Load documents
+                    raw_documents = PyPDFLoader(
+                        file_path=uploaded_file["file_path"]
+                    ).load()
+
+                    # Split documents
+                    # May need to find an optimal chunk size and overlap configuration
+                    documents = RecursiveCharacterTextSplitter(
+                        chunk_size=cfg.splitter_chunk_size,
+                        chunk_overlap=cfg.splitter_chunk_overlap,
+                    ).split_documents(raw_documents)
+
+                    # Add documents to the list
+                    all_docs.extend(documents)
+
+        # Load the extracted graph
+        extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
+        # logger.log(logging.INFO, "Extracted graph: %s", extracted_graph)
+
+        # Set another prompt template
+        prompt_template = ChatPromptTemplate.from_messages(
+            [("system", cfg.prompt_graphrag_w_docs), ("human", "{input}")]
+        )
+
+        # Prepare chain with retrieved documents
+        qa_chain = create_stuff_documents_chain(state["llm_model"], prompt_template)
+        rag_chain = create_retrieval_chain(
+            InMemoryVectorStore.from_documents(
+                documents=all_docs, embedding=state["embedding_model"]
+            ).as_retriever(
+                search_type=cfg.retriever_search_type,
+                search_kwargs={
+                    "k": cfg.retriever_k,
+                    "fetch_k": cfg.retriever_fetch_k,
+                    "lambda_mult": cfg.retriever_lambda_mult,
+                },
+            ),
+            qa_chain,
+        )
+
+        # Invoke the chain
+        response = rag_chain.invoke(
+            {
+                "input": prompt,
+                "subgraph_summary": extracted_graph[extraction_name]["graph_summary"],
+            }
+        )
+
+        return Command(
+            update={
+                # update the message history
+                "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)]
+            }
+        )
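
For orientation, a minimal sketch of exercising the new tool directly, outside the LangGraph agent. The LLM and embedding models, the file path, and the extraction name below are hypothetical stand-ins for what the agent normally injects via state; the packaged Hydra configs must be importable for _run to load them.

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from aiagents4pharma.talk2knowledgegraphs.tools.graphrag_reasoning import (
    GraphRAGReasoningTool,
)

# Hand-built stand-in for the LangGraph-injected state (hypothetical values)
state = {
    "llm_model": ChatOpenAI(model="gpt-4o-mini"),
    "embedding_model": OpenAIEmbeddings(),
    "uploaded_files": [
        {"file_type": "drug_data", "file_path": "docs/drug.pdf"},  # placeholder path
    ],
    "dic_extracted_graph": [
        {"name": "subkg_1", "graph_summary": "TNF interacts with ..."},  # placeholder
    ],
}

command = GraphRAGReasoningTool()._run(
    tool_call_id="call_1",
    state=state,
    prompt="How do the uploaded drug's targets relate to this subgraph?",
    extraction_name="subkg_1",
)
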
@@ -0,0 +1,22 @@
+"""
+A utility module for defining the dataclasses
+for the arguments to set up initial settings
+"""
+
+from dataclasses import dataclass
+from typing import Annotated
+
+
+@dataclass
+class ArgumentData:
+    """
+    Dataclass for storing the argument data.
+    """
+
+    extraction_name: Annotated[
+        str,
+        """An AI assigned _ separated name of the subgraph extraction
+        based on human query and the context of the graph reasoning
+        experiment.
+        This must be set before the subgraph extraction is invoked.""",
+    ]
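
The Annotated metadata on extraction_name is prompt text consumed by the LLM when it fills tool arguments; at runtime the class behaves like an ordinary dataclass. A one-line sketch (the name is a hypothetical example):

from aiagents4pharma.talk2knowledgegraphs.tools.load_arguments import ArgumentData

arg_data = ArgumentData(extraction_name="psoriasis_gene_subgraph")  # hypothetical name
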
@@ -0,0 +1,305 @@
+"""
+Tool for performing subgraph extraction.
+"""
+
+from typing import Type, Annotated
+import logging
+import pickle
+import numpy as np
+import pandas as pd
+import hydra
+import networkx as nx
+from pydantic import BaseModel, Field
+from langchain.chains.retrieval import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain_core.tools import BaseTool
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+import torch
+from torch_geometric.data import Data
+from ..utils.extractions.pcst import PCSTPruning
+from ..utils.embeddings.ollama import EmbeddingWithOllama
+from .load_arguments import ArgumentData
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class SubgraphExtractionInput(BaseModel):
+    """
+    SubgraphExtractionInput is a Pydantic model representing an input for extracting a subgraph.
+
+    Args:
+        prompt: Prompt to interact with the backend.
+        tool_call_id: Tool call ID.
+        state: Injected state.
+        arg_data: Argument for analytical process over graph data.
+    """
+
+    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+        description="Tool call ID."
+    )
+    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
+    prompt: str = Field(description="Prompt to interact with the backend.")
+    arg_data: ArgumentData = Field(
+        description="Experiment over graph data.", default=None
+    )
+
+
+class SubgraphExtractionTool(BaseTool):
+    """
+    This tool performs subgraph extraction based on user's prompt by taking into account
+    the top-k nodes and edges.
+    """
+
+    name: str = "subgraph_extraction"
+    description: str = "A tool for subgraph extraction based on user's prompt."
+    args_schema: Type[BaseModel] = SubgraphExtractionInput
+
+    def perform_endotype_filtering(
+        self,
+        prompt: str,
+        state: Annotated[dict, InjectedState],
+        cfg: hydra.core.config_store.ConfigStore,
+    ) -> str:
+        """
+        Perform endotype filtering based on the uploaded files and prepare the prompt.
+
+        Args:
+            prompt: The prompt to interact with the backend.
+            state: Injected state for the tool.
+            cfg: Hydra configuration object.
+        """
+        # Loop through the uploaded files
+        all_genes = []
+        for uploaded_file in state["uploaded_files"]:
+            if uploaded_file["file_type"] == "endotype":
+                # Load the PDF file
+                docs = PyPDFLoader(file_path=uploaded_file["file_path"]).load()
+
+                # Split the text into chunks
+                splits = RecursiveCharacterTextSplitter(
+                    chunk_size=cfg.splitter_chunk_size,
+                    chunk_overlap=cfg.splitter_chunk_overlap,
+                ).split_documents(docs)
+
+                # Create a chat prompt template
+                prompt_template = ChatPromptTemplate.from_messages(
+                    [
+                        ("system", cfg.prompt_endotype_filtering),
+                        ("human", "{input}"),
+                    ]
+                )
+
+                qa_chain = create_stuff_documents_chain(
+                    state["llm_model"], prompt_template
+                )
+                rag_chain = create_retrieval_chain(
+                    InMemoryVectorStore.from_documents(
+                        documents=splits, embedding=state["embedding_model"]
+                    ).as_retriever(
+                        search_type=cfg.retriever_search_type,
+                        search_kwargs={
+                            "k": cfg.retriever_k,
+                            "fetch_k": cfg.retriever_fetch_k,
+                            "lambda_mult": cfg.retriever_lambda_mult,
+                        },
+                    ),
+                    qa_chain,
+                )
+                results = rag_chain.invoke({"input": prompt})
+                all_genes.append(results["answer"])
+
+        # Prepare the prompt
+        if len(all_genes) > 0:
+            prompt = " ".join(
+                [prompt, cfg.prompt_endotype_addition, ", ".join(all_genes)]
+            )
+
+        return prompt
+
+    def prepare_final_subgraph(self,
+                               subgraph: dict,
+                               pyg_graph: Data,
+                               textualized_graph: pd.DataFrame) -> dict:
+        """
+        Prepare the subgraph based on the extracted subgraph.
+
+        Args:
+            subgraph: The extracted subgraph.
+            pyg_graph: The PyTorch Geometric graph.
+            textualized_graph: The textualized graph.
+
+        Returns:
+            A dictionary containing the PyG graph, NetworkX graph, and textualized graph.
+        """
+        # print(subgraph)
+        # Prepare the PyTorch Geometric graph
+        mapping = {n: i for i, n in enumerate(subgraph["nodes"].tolist())}
+        pyg_graph = Data(
+            # Node features
+            x=pyg_graph.x[subgraph["nodes"]],
+            node_id=np.array(pyg_graph.node_id)[subgraph["nodes"]].tolist(),
+            node_name=np.array(pyg_graph.node_id)[subgraph["nodes"]].tolist(),
+            enriched_node=np.array(pyg_graph.enriched_node)[subgraph["nodes"]].tolist(),
+            num_nodes=len(subgraph["nodes"]),
+            # Edge features
+            edge_index=torch.LongTensor(
+                [
+                    [
+                        mapping[i]
+                        for i in pyg_graph.edge_index[:, subgraph["edges"]][0].tolist()
+                    ],
+                    [
+                        mapping[i]
+                        for i in pyg_graph.edge_index[:, subgraph["edges"]][1].tolist()
+                    ],
+                ]
+            ),
+            edge_attr=pyg_graph.edge_attr[subgraph["edges"]],
+            edge_type=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
+            relation=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
+            label=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
+            enriched_edge=np.array(pyg_graph.enriched_edge)[subgraph["edges"]].tolist(),
+        )
+
+        # NetworkX DiGraph construction to be visualized in the frontend
+        nx_graph = nx.DiGraph()
+        for n in pyg_graph.node_name:
+            nx_graph.add_node(n)
+        for i, e in enumerate(
+            [
+                [pyg_graph.node_name[i], pyg_graph.node_name[j]]
+                for (i, j) in pyg_graph.edge_index.transpose(1, 0)
+            ]
+        ):
+            nx_graph.add_edge(
+                e[0],
+                e[1],
+                relation=pyg_graph.edge_type[i],
+                label=pyg_graph.edge_type[i],
+            )
+
+        # Prepare the textualized subgraph
+        textualized_graph = (
+            textualized_graph["nodes"].iloc[subgraph["nodes"]].to_csv(index=False)
+            + "\n"
+            + textualized_graph["edges"].iloc[subgraph["edges"]].to_csv(index=False)
+        )
+
+        return {
+            "graph_pyg": pyg_graph,
+            "graph_nx": nx_graph,
+            "graph_text": textualized_graph,
+        }
+
+    def _run(
+        self,
+        tool_call_id: Annotated[str, InjectedToolCallId],
+        state: Annotated[dict, InjectedState],
+        prompt: str,
+        arg_data: ArgumentData = None,
+    ) -> Command:
+        """
+        Run the subgraph extraction tool.
+
+        Args:
+            tool_call_id: The tool call ID for the tool.
+            state: Injected state for the tool.
+            prompt: The prompt to interact with the backend.
+            arg_data (ArgumentData): The argument data.
+
+        Returns:
+            Command: The command to be executed.
+        """
+        logger.log(logging.INFO, "Invoking subgraph_extraction tool")
+
+        # Load hydra configuration
+        with hydra.initialize(version_base=None, config_path="../configs"):
+            cfg = hydra.compose(
+                config_name="config", overrides=["tools/subgraph_extraction=default"]
+            )
+            cfg = cfg.tools.subgraph_extraction
+
+        # Retrieve source graph from the state
+        initial_graph = {}
+        initial_graph["source"] = state["dic_source_graph"][-1]  # The last source graph as of now
+        # logger.log(logging.INFO, "Source graph: %s", source_graph)
+
+        # Load the knowledge graph
+        with open(initial_graph["source"]["kg_pyg_path"], "rb") as f:
+            initial_graph["pyg"] = pickle.load(f)
+        with open(initial_graph["source"]["kg_text_path"], "rb") as f:
+            initial_graph["text"] = pickle.load(f)
+
+        # Prepare prompt construction along with a list of endotypes
+        if len(state["uploaded_files"]) != 0 and "endotype" in [
+            f["file_type"] for f in state["uploaded_files"]
+        ]:
+            prompt = self.perform_endotype_filtering(prompt, state, cfg)
+
+        # Prepare embedding model and embed the user prompt as query
+        query_emb = torch.tensor(
+            EmbeddingWithOllama(model_name=cfg.ollama_embeddings[0]).embed_query(prompt)
+        ).float()
+
+        # Prepare the PCSTPruning object and extract the subgraph
+        # Parameters were set in the configuration file obtained from Hydra
+        subgraph = PCSTPruning(
+            state["topk_nodes"],
+            state["topk_edges"],
+            cfg.cost_e,
+            cfg.c_const,
+            cfg.root,
+            cfg.num_clusters,
+            cfg.pruning,
+            cfg.verbosity_level,
+        ).extract_subgraph(initial_graph["pyg"], query_emb)
+
+        # Prepare subgraph as a NetworkX graph and textualized graph
+        final_subgraph = self.prepare_final_subgraph(
+            subgraph, initial_graph["pyg"], initial_graph["text"]
+        )
+
+        # Prepare the dictionary of extracted graph
+        dic_extracted_graph = {
+            "name": arg_data.extraction_name,
+            "tool_call_id": tool_call_id,
+            "graph_source": initial_graph["source"]["name"],
+            "topk_nodes": state["topk_nodes"],
+            "topk_edges": state["topk_edges"],
+            "graph_dict": {
+                "nodes": list(final_subgraph["graph_nx"].nodes(data=True)),
+                "edges": list(final_subgraph["graph_nx"].edges(data=True)),
+            },
+            "graph_text": final_subgraph["graph_text"],
+            "graph_summary": None,
+        }
+
+        # Prepare the dictionary of updated state
+        dic_updated_state_for_model = {}
+        for key, value in {
+            "dic_extracted_graph": [dic_extracted_graph],
+        }.items():
+            if value:
+                dic_updated_state_for_model[key] = value
+
+        # Return the updated state of the tool
+        return Command(
+            update=dic_updated_state_for_model | {
+                # update the message history
+                "messages": [
+                    ToolMessage(
+                        content=f"Subgraph Extraction Result of {arg_data.extraction_name}",
+                        tool_call_id=tool_call_id,
+                    )
+                ],
+            }
+        )
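
A sketch of the state this tool expects when called directly, inferred from the keys _run reads above. The graph name and pickle paths are placeholders; the pickles must hold a torch_geometric Data object and the textualized node/edge DataFrames, respectively, and a local Ollama server is needed for the query embedding.

from aiagents4pharma.talk2knowledgegraphs.tools.subgraph_extraction import (
    SubgraphExtractionTool,
)
from aiagents4pharma.talk2knowledgegraphs.tools.load_arguments import ArgumentData

state = {
    "uploaded_files": [],
    "topk_nodes": 5,
    "topk_edges": 5,
    "dic_source_graph": [
        {
            "name": "PrimeKG",                   # placeholder graph name
            "kg_pyg_path": "data/kg_pyg.pkl",    # pickled torch_geometric Data
            "kg_text_path": "data/kg_text.pkl",  # pickled textualized graph
        }
    ],
}

command = SubgraphExtractionTool()._run(
    tool_call_id="call_1",
    state=state,
    prompt="Extract genes associated with psoriasis.",
    arg_data=ArgumentData(extraction_name="psoriasis_gene_subgraph"),
)
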
@@ -0,0 +1,126 @@
+"""
+Tool for performing subgraph summarization.
+"""
+
+import logging
+from typing import Type, Annotated
+from pydantic import BaseModel, Field
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from langchain_core.tools import BaseTool
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+import hydra
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class SubgraphSummarizationInput(BaseModel):
+    """
+    SubgraphSummarizationInput is a Pydantic model representing an input for
+    summarizing a given textualized subgraph.
+
+    Args:
+        tool_call_id: Tool call ID.
+        state: Injected state.
+        prompt: Prompt to interact with the backend.
+        extraction_name: Name assigned to the subgraph extraction process
+    """
+
+    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+        description="Tool call ID."
+    )
+    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
+    prompt: str = Field(description="Prompt to interact with the backend.")
+    extraction_name: str = Field(
+        description="""Name assigned to the subgraph extraction process
+        when the subgraph_extraction tool is invoked."""
+    )
+
+
+class SubgraphSummarizationTool(BaseTool):
+    """
+    This tool performs subgraph summarization over textualized graph to highlight the most
+    important information in responding to user's prompt.
+    """
+
+    name: str = "subgraph_summarization"
+    description: str = """A tool to perform subgraph summarization over textualized graph
+    for responding to user's follow-up prompt(s)."""
+    args_schema: Type[BaseModel] = SubgraphSummarizationInput
+
+    def _run(
+        self,
+        tool_call_id: Annotated[str, InjectedToolCallId],
+        state: Annotated[dict, InjectedState],
+        prompt: str,
+        extraction_name: str,
+    ):
+        """
+        Run the subgraph summarization tool.
+
+        Args:
+            tool_call_id: The tool call ID.
+            state: The injected state.
+            prompt: The prompt to interact with the backend.
+            extraction_name: The name assigned to the subgraph extraction process.
+        """
+        logger.log(
+            logging.INFO, "Invoking subgraph_summarization tool for %s", extraction_name
+        )
+
+        # Load hydra configuration
+        with hydra.initialize(version_base=None, config_path="../configs"):
+            cfg = hydra.compose(
+                config_name="config", overrides=["tools/subgraph_summarization=default"]
+            )
+            cfg = cfg.tools.subgraph_summarization
+
+        # Load the extracted graph
+        extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
+        # logger.log(logging.INFO, "Extracted graph: %s", extracted_graph)
+
+        # Prepare prompt template
+        prompt_template = ChatPromptTemplate.from_messages(
+            [
+                ("system", cfg.prompt_subgraph_summarization),
+                ("human", "{input}"),
+            ]
+        )
+
+        # Prepare chain
+        chain = prompt_template | state["llm_model"] | StrOutputParser()
+
+        # Return the subgraph and textualized graph as JSON response
+        response = chain.invoke(
+            {
+                "input": prompt,
+                "textualized_subgraph": extracted_graph[extraction_name]["graph_text"],
+            }
+        )
+
+        # Store the response as graph_summary in the extracted graph
+        for key, value in extracted_graph.items():
+            if key == extraction_name:
+                value["graph_summary"] = response
+
+        # Prepare the dictionary of updated state
+        dic_updated_state_for_model = {}
+        for key, value in {
+            "dic_extracted_graph": list(extracted_graph.values()),
+        }.items():
+            if value:
+                dic_updated_state_for_model[key] = value
+
+        # Return the updated state of the tool
+        return Command(
+            update=dic_updated_state_for_model
+            | {
+                # update the message history
+                "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)]
+            }
+        )
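
The core of the tool is a plain LCEL pipeline. A self-contained sketch of that pattern, with a placeholder system prompt standing in for cfg.prompt_subgraph_summarization and an OpenAI model standing in for state["llm_model"]:

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", "Summarize the subgraph below.\n{textualized_subgraph}"),  # placeholder prompt
        ("human", "{input}"),
    ]
)
chain = prompt_template | ChatOpenAI(model="gpt-4o-mini") | StrOutputParser()
summary = chain.invoke(
    {
        "input": "Highlight drug-gene relations.",
        "textualized_subgraph": "node_id,node_name\nTNF,TNF",  # toy CSV snippet
    }
)
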
@@ -1,5 +1,7 @@
 '''
-This file is used to import utlities.
+This file is used to import all the models in the package.
 '''
-from . import enrichments
 from . import embeddings
+from . import enrichments
+from . import extractions
+from . import kg_utils
@@ -4,3 +4,4 @@ This file is used to import all the models in the package.
 from . import embeddings
 from . import sentence_transformer
 from . import huggingface
+from . import ollama
@@ -0,0 +1,81 @@
+"""
+Embedding class using Ollama model based on LangChain Embeddings class.
+"""
+
+import time
+from typing import List
+import subprocess
+import ollama
+from langchain_ollama import OllamaEmbeddings
+from .embeddings import Embeddings
+
+class EmbeddingWithOllama(Embeddings):
+    """
+    Embedding class using Ollama model based on LangChain Embeddings class.
+    """
+    def __init__(self, model_name: str):
+        """
+        Initialize the EmbeddingWithOllama class.
+
+        Args:
+            model_name: The name of the Ollama model to be used.
+        """
+        # Setup the Ollama server
+        self.__setup(model_name)
+
+        # Set parameters
+        self.model_name = model_name
+
+        # Prepare model
+        self.model = OllamaEmbeddings(model=self.model_name)
+
+    def __setup(self, model_name: str) -> None:
+        """
+        Check if the Ollama model is available and run the Ollama server if needed.
+
+        Args:
+            model_name: The name of the Ollama model to be used.
+        """
+        try:
+            models_list = ollama.list()["models"]
+            if model_name not in [m['model'].replace(":latest", "") for m in models_list]:
+                ollama.pull(model_name)
+                time.sleep(30)
+                raise ValueError(f"Pulled {model_name} model")
+        except Exception as e:
+            with subprocess.Popen(
+                "ollama serve", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+            ):
+                time.sleep(10)
+                raise ValueError(f"Error: {e} and restarted Ollama server.") from e
+
+    def embed_documents(self, texts: List[str]) -> List[float]:
+        """
+        Generate embedding for a list of input texts using Ollama model.
+
+        Args:
+            texts: The list of texts to be embedded.
+
+        Returns:
+            The list of embeddings for the given texts.
+        """
+
+        # Generate the embedding
+        embeddings = self.model.embed_documents(texts)
+
+        return embeddings
+
+    def embed_query(self, text: str) -> List[float]:
+        """
+        Generate embeddings for an input text using Ollama model.
+
+        Args:
+            text: A query to be embedded.
+        Returns:
+            The embeddings for the given query.
+        """
+
+        # Generate the embedding
+        embeddings = self.model.embed_query(text)
+
+        return embeddings
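
Usage follows the standard LangChain Embeddings interface. A short sketch, assuming a local Ollama installation; the model name is a hypothetical choice, and __setup pulls it on first use if it is not already present:

from aiagents4pharma.talk2knowledgegraphs.utils.embeddings.ollama import (
    EmbeddingWithOllama,
)

emb = EmbeddingWithOllama(model_name="nomic-embed-text")  # hypothetical model name
query_vector = emb.embed_query("protein kinase inhibitor")
doc_vectors = emb.embed_documents(["TNF signaling", "IL-6 pathway"])
print(len(query_vector), len(doc_vectors))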