aiagents4pharma 1.17.1__py3-none-any.whl → 1.19.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (56) hide show
  1. aiagents4pharma/talk2biomodels/agents/t2b_agent.py +4 -4
  2. aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml +7 -15
  3. aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +4 -1
  4. aiagents4pharma/talk2biomodels/tests/test_ask_question.py +4 -2
  5. aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +4 -2
  6. aiagents4pharma/talk2biomodels/tests/test_integration.py +34 -30
  7. aiagents4pharma/talk2biomodels/tests/test_query_article.py +7 -1
  8. aiagents4pharma/talk2biomodels/tests/test_search_models.py +3 -1
  9. aiagents4pharma/talk2biomodels/tests/test_steady_state.py +6 -3
  10. aiagents4pharma/talk2biomodels/tools/ask_question.py +1 -2
  11. aiagents4pharma/talk2biomodels/tools/custom_plotter.py +23 -10
  12. aiagents4pharma/talk2biomodels/tools/get_annotation.py +11 -10
  13. aiagents4pharma/talk2biomodels/tools/query_article.py +6 -2
  14. aiagents4pharma/talk2biomodels/tools/search_models.py +8 -2
  15. aiagents4pharma/talk2knowledgegraphs/__init__.py +3 -0
  16. aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +4 -0
  17. aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +85 -0
  18. aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +7 -0
  19. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
  20. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
  21. aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +4 -0
  22. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
  23. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +31 -0
  24. aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +7 -0
  25. aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +6 -0
  26. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
  27. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
  28. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
  29. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
  30. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
  31. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
  32. aiagents4pharma/talk2knowledgegraphs/states/__init__.py +4 -0
  33. aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +38 -0
  34. aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +110 -0
  35. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +210 -0
  36. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +174 -0
  37. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +154 -0
  38. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +0 -1
  39. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +56 -0
  40. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +18 -42
  41. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +79 -0
  42. aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +6 -0
  43. aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +143 -0
  44. aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
  45. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +305 -0
  46. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +126 -0
  47. aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +4 -2
  48. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +1 -0
  49. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +81 -0
  50. aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +4 -0
  51. aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +225 -0
  52. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/METADATA +12 -3
  53. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/RECORD +56 -24
  54. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/LICENSE +0 -0
  55. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/WHEEL +0 -0
  56. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,143 @@
1
+ """
2
+ Tool for performing Graph RAG reasoning.
3
+ """
4
+
5
+ import logging
6
+ from typing import Type, Annotated
7
+ from pydantic import BaseModel, Field
8
+ from langchain_core.prompts import ChatPromptTemplate
9
+ from langchain_core.messages import ToolMessage
10
+ from langchain_core.tools.base import InjectedToolCallId
11
+ from langchain_core.tools import BaseTool
12
+ from langchain_core.vectorstores import InMemoryVectorStore
13
+ from langchain.chains.retrieval import create_retrieval_chain
14
+ from langchain.chains.combine_documents import create_stuff_documents_chain
15
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
16
+ from langchain_community.document_loaders import PyPDFLoader
17
+ from langgraph.types import Command
18
+ from langgraph.prebuilt import InjectedState
19
+ import hydra
20
+
21
+ # Initialize logger
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class GraphRAGReasoningInput(BaseModel):
    """
    Pydantic argument schema for the Graph RAG reasoning tool.

    Attributes:
        tool_call_id: Tool call ID (annotated with InjectedToolCallId).
        state: Injected state (annotated with InjectedState).
        prompt: Prompt to interact with the backend.
        extraction_name: Name assigned to the subgraph extraction process
            when the subgraph_extraction tool is invoked.
    """

    # Fields annotated with Injected* markers are not written by the caller.
    tool_call_id: Annotated[str, InjectedToolCallId] = Field(description="Tool call ID.")
    state: Annotated[dict, InjectedState] = Field(
        description="Injected state."
    )

    # Caller-provided fields.
    prompt: str = Field(
        description="Prompt to interact with the backend."
    )
    extraction_name: str = Field(
        description="""Name assigned to the subgraph extraction process
        when the subgraph_extraction tool is invoked."""
    )
45
+
46
+
47
class GraphRAGReasoningTool(BaseTool):
    """
    This tool performs reasoning using a Graph Retrieval-Augmented Generation (RAG) approach
    over user's request by considering textualized subgraph context and document context.
    """

    name: str = "graphrag_reasoning"
    description: str = """A tool to perform reasoning using a Graph RAG approach
    by considering textualized subgraph context and document context."""
    args_schema: Type[BaseModel] = GraphRAGReasoningInput

    def _run(
        self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        prompt: str,
        extraction_name: str,
    ):
        """
        Run the Graph RAG reasoning tool.

        Args:
            tool_call_id: The tool call ID.
            state: The injected state. The keys read here are `uploaded_files`,
                `dic_extracted_graph`, `llm_model` and `embedding_model`.
            prompt: The prompt to interact with the backend.
            extraction_name: The name assigned to the subgraph extraction process.

        Returns:
            A Command whose update appends one ToolMessage to `messages`.

        Raises:
            KeyError: If no entry of `state["dic_extracted_graph"]` is named
                `extraction_name`.
        """
        logger.info("Invoking graphrag_reasoning tool for %s", extraction_name)

        # Load Hydra configuration
        with hydra.initialize(version_base=None, config_path="../configs"):
            cfg = hydra.compose(
                config_name="config", overrides=["tools/graphrag_reasoning=default"]
            )
            cfg = cfg.tools.graphrag_reasoning

        # Resolve the requested extraction before the expensive PDF loading and
        # embedding work, so that a wrong name fails fast with a readable message.
        extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
        if extraction_name not in extracted_graph:
            raise KeyError(
                f"No extracted subgraph named {extraction_name!r}; "
                f"available extractions: {list(extracted_graph)}"
            )
        # NOTE(review): subgraph_extraction stores graph_summary as None until the
        # summarization tool has filled it in - confirm that None is acceptable here.
        subgraph_summary = extracted_graph[extraction_name]["graph_summary"]

        # Prepare documents: only uploads flagged as "drug_data" are loaded and split.
        all_docs = []
        for uploaded_file in state["uploaded_files"]:
            if uploaded_file["file_type"] != "drug_data":
                continue

            # Load documents
            raw_documents = PyPDFLoader(file_path=uploaded_file["file_path"]).load()

            # Split documents
            # May need to find an optimal chunk size and overlap configuration
            documents = RecursiveCharacterTextSplitter(
                chunk_size=cfg.splitter_chunk_size,
                chunk_overlap=cfg.splitter_chunk_overlap,
            ).split_documents(raw_documents)

            # Add documents to the list
            all_docs.extend(documents)

        # Set another prompt template
        prompt_template = ChatPromptTemplate.from_messages(
            [("system", cfg.prompt_graphrag_w_docs), ("human", "{input}")]
        )

        # Prepare chain with retrieved documents
        qa_chain = create_stuff_documents_chain(state["llm_model"], prompt_template)
        retriever = InMemoryVectorStore.from_documents(
            documents=all_docs, embedding=state["embedding_model"]
        ).as_retriever(
            search_type=cfg.retriever_search_type,
            search_kwargs={
                "k": cfg.retriever_k,
                "fetch_k": cfg.retriever_fetch_k,
                "lambda_mult": cfg.retriever_lambda_mult,
            },
        )
        rag_chain = create_retrieval_chain(retriever, qa_chain)

        # Invoke the chain
        response = rag_chain.invoke(
            {
                "input": prompt,
                "subgraph_summary": subgraph_summary,
            }
        )

        # NOTE(review): the whole chain output is passed as the message content,
        # not only its answer part - confirm this is intended.
        return Command(
            update={
                # update the message history
                "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)]
            }
        )
@@ -0,0 +1,22 @@
1
+ """
2
+ A utility module for defining the dataclasses
3
+ for the arguments to set up initial settings
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Annotated
8
+
9
+
10
@dataclass
class ArgumentData:
    """
    Container for the arguments that set up a subgraph extraction.
    """

    # The human-readable description travels as Annotated metadata on the type.
    extraction_name: Annotated[
        str,
        """An AI assigned _ separated name of the subgraph extraction
        based on human query and the context of the graph reasoning
        experiment.
        This must be set before the subgraph extraction is invoked.""",
    ]
@@ -0,0 +1,305 @@
1
+ """
2
+ Tool for performing subgraph extraction.
3
+ """
4
+
5
+ from typing import Type, Annotated
6
+ import logging
7
+ import pickle
8
+ import numpy as np
9
+ import pandas as pd
10
+ import hydra
11
+ import networkx as nx
12
+ from pydantic import BaseModel, Field
13
+ from langchain.chains.retrieval import create_retrieval_chain
14
+ from langchain.chains.combine_documents import create_stuff_documents_chain
15
+ from langchain_core.prompts import ChatPromptTemplate
16
+ from langchain_core.vectorstores import InMemoryVectorStore
17
+ from langchain_core.tools import BaseTool
18
+ from langchain_core.messages import ToolMessage
19
+ from langchain_core.tools.base import InjectedToolCallId
20
+ from langchain_community.document_loaders import PyPDFLoader
21
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
22
+ from langgraph.types import Command
23
+ from langgraph.prebuilt import InjectedState
24
+ import torch
25
+ from torch_geometric.data import Data
26
+ from ..utils.extractions.pcst import PCSTPruning
27
+ from ..utils.embeddings.ollama import EmbeddingWithOllama
28
+ from .load_arguments import ArgumentData
29
+
30
+ # Initialize logger
31
+ logging.basicConfig(level=logging.INFO)
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
class SubgraphExtractionInput(BaseModel):
    """
    Pydantic argument schema for the subgraph extraction tool.

    Attributes:
        tool_call_id: Tool call ID (annotated with InjectedToolCallId).
        state: Injected state (annotated with InjectedState).
        prompt: Prompt to interact with the backend.
        arg_data: Argument for analytical process over graph data; the field
            default is None.
    """

    # Fields annotated with Injected* markers are not written by the caller.
    tool_call_id: Annotated[str, InjectedToolCallId] = Field(description="Tool call ID.")
    state: Annotated[dict, InjectedState] = Field(
        description="Injected state."
    )

    # Caller-provided fields.
    prompt: str = Field(
        description="Prompt to interact with the backend."
    )
    arg_data: ArgumentData = Field(description="Experiment over graph data.", default=None)
54
+
55
+
56
class SubgraphExtractionTool(BaseTool):
    """
    This tool performs subgraph extraction based on user's prompt by taking into account
    the top-k nodes and edges.
    """

    name: str = "subgraph_extraction"
    description: str = "A tool for subgraph extraction based on user's prompt."
    args_schema: Type[BaseModel] = SubgraphExtractionInput

    def perform_endotype_filtering(
        self,
        prompt: str,
        state: Annotated[dict, InjectedState],
        cfg: hydra.core.config_store.ConfigStore,
    ) -> str:
        """
        Perform endotype filtering based on the uploaded files and prepare the prompt.

        Every uploaded file whose `file_type` is "endotype" is loaded as a PDF,
        split into chunks, and queried through a retrieval chain with `prompt`;
        the chain answers are then appended to the prompt.

        Args:
            prompt: The prompt to interact with the backend.
            state: Injected state for the tool.
            cfg: Hydra configuration object.
                NOTE(review): `_run` passes the object obtained from
                `hydra.compose(...)`, which is presumably not a ConfigStore -
                confirm the annotation.

        Returns:
            The prompt extended with `cfg.prompt_endotype_addition` and the
            comma-joined chain answers, or the unchanged prompt when no
            endotype file was uploaded.
        """
        # Loop through the uploaded files
        all_genes = []
        for uploaded_file in state["uploaded_files"]:
            if uploaded_file["file_type"] != "endotype":
                continue

            # Load the PDF file
            docs = PyPDFLoader(file_path=uploaded_file["file_path"]).load()

            # Split the text into chunks
            splits = RecursiveCharacterTextSplitter(
                chunk_size=cfg.splitter_chunk_size,
                chunk_overlap=cfg.splitter_chunk_overlap,
            ).split_documents(docs)

            # Create a chat prompt template
            prompt_template = ChatPromptTemplate.from_messages(
                [
                    ("system", cfg.prompt_endotype_filtering),
                    ("human", "{input}"),
                ]
            )

            qa_chain = create_stuff_documents_chain(state["llm_model"], prompt_template)
            retriever = InMemoryVectorStore.from_documents(
                documents=splits, embedding=state["embedding_model"]
            ).as_retriever(
                search_type=cfg.retriever_search_type,
                search_kwargs={
                    "k": cfg.retriever_k,
                    "fetch_k": cfg.retriever_fetch_k,
                    "lambda_mult": cfg.retriever_lambda_mult,
                },
            )
            rag_chain = create_retrieval_chain(retriever, qa_chain)
            results = rag_chain.invoke({"input": prompt})
            all_genes.append(results["answer"])

        # Prepare the prompt
        if all_genes:
            prompt = " ".join(
                [prompt, cfg.prompt_endotype_addition, ", ".join(all_genes)]
            )

        return prompt

    def prepare_final_subgraph(self,
                               subgraph: dict,
                               pyg_graph: Data,
                               textualized_graph: pd.DataFrame) -> dict:
        """
        Prepare the subgraph based on the extracted subgraph.

        Args:
            subgraph: The extracted subgraph; its "nodes" and "edges" entries are
                used as positional indices into the source graph.
            pyg_graph: The PyTorch Geometric graph.
            textualized_graph: The textualized graph, indexed here by "nodes"
                and "edges".

        Returns:
            A dictionary containing the PyG graph, NetworkX graph, and textualized graph.
        """
        node_idx = subgraph["nodes"]
        edge_idx = subgraph["edges"]

        # Map original node indices to their position inside the subgraph so that
        # the edge index can be re-expressed in subgraph coordinates.
        mapping = {n: i for i, n in enumerate(node_idx.tolist())}

        # Slice the shared arrays once; each .tolist() call below still yields
        # its own independent list.
        node_ids = np.array(pyg_graph.node_id)[node_idx]
        edge_types = np.array(pyg_graph.edge_type)[edge_idx]
        sub_edge_index = pyg_graph.edge_index[:, edge_idx]

        # Prepare the PyTorch Geometric graph
        sub_pyg = Data(
            # Node features
            x=pyg_graph.x[node_idx],
            node_id=node_ids.tolist(),
            # NOTE(review): node_name is filled from node_id, not from a separate
            # name attribute - confirm this is intended.
            node_name=node_ids.tolist(),
            enriched_node=np.array(pyg_graph.enriched_node)[node_idx].tolist(),
            num_nodes=len(node_idx),
            # Edge features
            edge_index=torch.LongTensor(
                [
                    [mapping[i] for i in sub_edge_index[0].tolist()],
                    [mapping[i] for i in sub_edge_index[1].tolist()],
                ]
            ),
            edge_attr=pyg_graph.edge_attr[edge_idx],
            edge_type=edge_types.tolist(),
            relation=edge_types.tolist(),
            label=edge_types.tolist(),
            enriched_edge=np.array(pyg_graph.enriched_edge)[edge_idx].tolist(),
        )

        # Networkx DiGraph construction to be visualized in the frontend
        nx_graph = nx.DiGraph()
        for n in sub_pyg.node_name:
            nx_graph.add_node(n)
        for i, (src, dst) in enumerate(sub_pyg.edge_index.transpose(1, 0).tolist()):
            nx_graph.add_edge(
                sub_pyg.node_name[src],
                sub_pyg.node_name[dst],
                relation=sub_pyg.edge_type[i],
                label=sub_pyg.edge_type[i],
            )

        # Prepare the textualized subgraph: node rows, a newline, then edge rows (CSV)
        graph_text = (
            textualized_graph["nodes"].iloc[node_idx].to_csv(index=False)
            + "\n"
            + textualized_graph["edges"].iloc[edge_idx].to_csv(index=False)
        )

        return {
            "graph_pyg": sub_pyg,
            "graph_nx": nx_graph,
            "graph_text": graph_text,
        }

    def _run(
        self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        prompt: str,
        arg_data: ArgumentData = None,
    ) -> Command:
        """
        Run the subgraph extraction tool.

        Args:
            tool_call_id: The tool call ID for the tool.
            state: Injected state for the tool. The keys read here are
                `dic_source_graph`, `uploaded_files`, `topk_nodes`, `topk_edges`
                and, through endotype filtering, `llm_model` and `embedding_model`.
            prompt: The prompt to interact with the backend.
            arg_data (ArgumentData): The argument data; must be provided because
                its `extraction_name` names the stored result.

        Returns:
            Command: The command to be executed. Its update carries the new
            `dic_extracted_graph` entry and one ToolMessage.

        Raises:
            ValueError: If `arg_data` is None.
        """
        logger.info("Invoking subgraph_extraction tool")

        # The signature allows None, but the extraction name is required to label
        # the result. Fail before any file loading or embedding happens.
        if arg_data is None:
            raise ValueError(
                "arg_data with an extraction_name is required for subgraph extraction."
            )

        # Load hydra configuration
        with hydra.initialize(version_base=None, config_path="../configs"):
            cfg = hydra.compose(
                config_name="config", overrides=["tools/subgraph_extraction=default"]
            )
            cfg = cfg.tools.subgraph_extraction

        # Retrieve source graph from the state
        initial_graph = {}
        initial_graph["source"] = state["dic_source_graph"][-1]  # The last source graph as of now

        # Load the knowledge graph
        # SECURITY NOTE: pickle.load executes arbitrary code from the file; these
        # paths must only ever point at trusted, locally produced artifacts.
        with open(initial_graph["source"]["kg_pyg_path"], "rb") as f:
            initial_graph["pyg"] = pickle.load(f)
        with open(initial_graph["source"]["kg_text_path"], "rb") as f:
            initial_graph["text"] = pickle.load(f)

        # Prepare prompt construction along with a list of endotypes
        if any(f["file_type"] == "endotype" for f in state["uploaded_files"]):
            prompt = self.perform_endotype_filtering(prompt, state, cfg)

        # Prepare embedding model and embed the user prompt as query
        query_emb = torch.tensor(
            EmbeddingWithOllama(model_name=cfg.ollama_embeddings[0]).embed_query(prompt)
        ).float()

        # Prepare the PCSTPruning object and extract the subgraph
        # Parameters were set in the configuration file obtained from Hydra
        subgraph = PCSTPruning(
            state["topk_nodes"],
            state["topk_edges"],
            cfg.cost_e,
            cfg.c_const,
            cfg.root,
            cfg.num_clusters,
            cfg.pruning,
            cfg.verbosity_level,
        ).extract_subgraph(initial_graph["pyg"], query_emb)

        # Prepare subgraph as a NetworkX graph and textualized graph
        final_subgraph = self.prepare_final_subgraph(
            subgraph, initial_graph["pyg"], initial_graph["text"]
        )

        # Prepare the dictionary of extracted graph
        dic_extracted_graph = {
            "name": arg_data.extraction_name,
            "tool_call_id": tool_call_id,
            "graph_source": initial_graph["source"]["name"],
            "topk_nodes": state["topk_nodes"],
            "topk_edges": state["topk_edges"],
            "graph_dict": {
                "nodes": list(final_subgraph["graph_nx"].nodes(data=True)),
                "edges": list(final_subgraph["graph_nx"].edges(data=True)),
            },
            "graph_text": final_subgraph["graph_text"],
            # No summary is produced by this tool.
            "graph_summary": None,
        }

        # Return the updated state of the tool
        return Command(
            update={
                "dic_extracted_graph": [dic_extracted_graph],
                # update the message history
                "messages": [
                    ToolMessage(
                        content=f"Subgraph Extraction Result of {arg_data.extraction_name}",
                        tool_call_id=tool_call_id,
                    )
                ],
            }
        )
@@ -0,0 +1,126 @@
1
+ """
2
+ Tool for performing subgraph summarization.
3
+ """
4
+
5
+ import logging
6
+ from typing import Type, Annotated
7
+ from pydantic import BaseModel, Field
8
+ from langchain_core.output_parsers import StrOutputParser
9
+ from langchain_core.prompts import ChatPromptTemplate
10
+ from langchain_core.messages import ToolMessage
11
+ from langchain_core.tools.base import InjectedToolCallId
12
+ from langchain_core.tools import BaseTool
13
+ from langgraph.types import Command
14
+ from langgraph.prebuilt import InjectedState
15
+ import hydra
16
+
17
+ # Initialize logger
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class SubgraphSummarizationInput(BaseModel):
    """
    Pydantic argument schema for summarizing a given textualized subgraph.

    Attributes:
        tool_call_id: Tool call ID (annotated with InjectedToolCallId).
        state: Injected state (annotated with InjectedState).
        prompt: Prompt to interact with the backend.
        extraction_name: Name assigned to the subgraph extraction process
            when the subgraph_extraction tool is invoked.
    """

    # Fields annotated with Injected* markers are not written by the caller.
    tool_call_id: Annotated[str, InjectedToolCallId] = Field(description="Tool call ID.")
    state: Annotated[dict, InjectedState] = Field(
        description="Injected state."
    )

    # Caller-provided fields.
    prompt: str = Field(
        description="Prompt to interact with the backend."
    )
    extraction_name: str = Field(
        description="""Name assigned to the subgraph extraction process
        when the subgraph_extraction tool is invoked."""
    )
43
+
44
+
45
class SubgraphSummarizationTool(BaseTool):
    """
    This tool performs subgraph summarization over textualized graph to highlight the most
    important information in responding to user's prompt.
    """

    name: str = "subgraph_summarization"
    description: str = """A tool to perform subgraph summarization over textualized graph
    for responding to user's follow-up prompt(s)."""
    args_schema: Type[BaseModel] = SubgraphSummarizationInput

    def _run(
        self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        prompt: str,
        extraction_name: str,
    ):
        """
        Run the subgraph summarization tool.

        Args:
            tool_call_id: The tool call ID.
            state: The injected state. The keys read here are
                `dic_extracted_graph` and `llm_model`.
            prompt: The prompt to interact with the backend.
            extraction_name: The name assigned to the subgraph extraction process.

        Returns:
            A Command whose update carries the list of extracted graphs (with the
            new summary stored on the selected one) and one ToolMessage holding
            the summary text.

        Raises:
            KeyError: If no entry of `state["dic_extracted_graph"]` is named
                `extraction_name`.
        """
        logger.info("Invoking subgraph_summarization tool for %s", extraction_name)

        # Load hydra configuration
        with hydra.initialize(version_base=None, config_path="../configs"):
            cfg = hydra.compose(
                config_name="config", overrides=["tools/subgraph_summarization=default"]
            )
            cfg = cfg.tools.subgraph_summarization

        # Load the extracted graph, keyed by extraction name
        extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
        if extraction_name not in extracted_graph:
            raise KeyError(
                f"No extracted subgraph named {extraction_name!r}; "
                f"available extractions: {list(extracted_graph)}"
            )
        selected_graph = extracted_graph[extraction_name]

        # Prepare prompt template
        prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", cfg.prompt_subgraph_summarization),
                ("human", "{input}"),
            ]
        )

        # Prepare chain
        chain = prompt_template | state["llm_model"] | StrOutputParser()

        # Ask the LLM to summarize the textualized subgraph for this prompt
        response = chain.invoke(
            {
                "input": prompt,
                "textualized_subgraph": selected_graph["graph_text"],
            }
        )

        # Store the response as graph_summary in the extracted graph.
        # This writes to the same dict object that came from the state.
        selected_graph["graph_summary"] = response

        # Return the updated state of the tool
        return Command(
            update={
                "dic_extracted_graph": list(extracted_graph.values()),
                # update the message history
                "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)],
            }
        )
@@ -1,5 +1,7 @@
1
1
  '''
2
- This file is used to import utlities.
2
+ This file is used to import all the models in the package.
3
3
  '''
4
- from . import enrichments
5
4
  from . import embeddings
5
+ from . import enrichments
6
+ from . import extractions
7
+ from . import kg_utils
@@ -4,3 +4,4 @@ This file is used to import all the models in the package.
4
4
  from . import embeddings
5
5
  from . import sentence_transformer
6
6
  from . import huggingface
7
+ from . import ollama
@@ -0,0 +1,81 @@
1
+ """
2
+ Embedding class using Ollama model based on LangChain Embeddings class.
3
+ """
4
+
5
+ import time
6
+ from typing import List
7
+ import subprocess
8
+ import ollama
9
+ from langchain_ollama import OllamaEmbeddings
10
+ from .embeddings import Embeddings
11
+
12
class EmbeddingWithOllama(Embeddings):
    """
    Embedding class using Ollama model based on LangChain Embeddings class.
    """

    def __init__(self, model_name: str):
        """
        Initialize the EmbeddingWithOllama class.

        Args:
            model_name: The name of the Ollama model to be used.

        Raises:
            ValueError: Propagated from the setup step when the model had to be
                pulled or the model listing failed.
        """
        # Setup the Ollama server
        self.__setup(model_name)

        # Set parameters
        self.model_name = model_name

        # Prepare model
        self.model = OllamaEmbeddings(model=self.model_name)

    def __setup(self, model_name: str) -> None:
        """
        Check if the Ollama model is available and run the Ollama server if needed.

        Args:
            model_name: The name of the Ollama model to be used.

        Raises:
            ValueError: Whenever the model was not already listed (it is pulled
                first) or any step of the check itself failed; in both cases
                `ollama serve` is launched before the error is raised.
        """
        try:
            models_list = ollama.list()["models"]
            # Compare against listed names with a trailing ":latest" tag removed.
            available = {m["model"].replace(":latest", "") for m in models_list}
            if model_name not in available:
                ollama.pull(model_name)
                time.sleep(30)
                # Deliberately raised inside the try: it is caught just below, so
                # a fresh pull also goes through the serve-and-raise path.
                raise ValueError(f"Pulled {model_name} model")
        except Exception as e:
            # NOTE(review): leaving this `with` block closes the pipes and waits
            # for the child process; confirm this cannot block when
            # `ollama serve` keeps running.
            with subprocess.Popen(
                "ollama serve", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            ):
                time.sleep(10)
                raise ValueError(f"Error: {e} and restarted Ollama server.") from e

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embedding for a list of input texts using Ollama model.

        Args:
            texts: The list of texts to be embedded.

        Returns:
            The list of embeddings for the given texts, one vector per text.
        """
        # Delegate to the wrapped OllamaEmbeddings model
        return self.model.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """
        Generate embeddings for an input text using Ollama model.

        Args:
            text: A query to be embedded.

        Returns:
            The embeddings for the given query.
        """
        # Delegate to the wrapped OllamaEmbeddings model
        return self.model.embed_query(text)