aiagents4pharma 1.36.0__py3-none-any.whl → 1.38.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py +12 -4
  2. aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +2 -2
  3. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +7 -6
  4. aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +1 -0
  5. aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/__init__.py +0 -0
  6. aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +1 -0
  7. aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +12 -11
  8. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_multimodal_subgraph_extraction.py +152 -0
  9. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +36 -65
  10. aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +1 -0
  11. aiagents4pharma/talk2knowledgegraphs/tools/multimodal_subgraph_extraction.py +374 -0
  12. aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +1 -0
  13. aiagents4pharma/talk2knowledgegraphs/utils/extractions/multimodal_pcst.py +292 -0
  14. aiagents4pharma/talk2scholars/configs/tools/zotero_read/default.yaml +1 -0
  15. aiagents4pharma/talk2scholars/state/state_talk2scholars.py +33 -7
  16. aiagents4pharma/talk2scholars/tests/test_question_and_answer_tool.py +59 -3
  17. aiagents4pharma/talk2scholars/tests/test_read_helper_utils.py +110 -0
  18. aiagents4pharma/talk2scholars/tests/test_s2_display.py +20 -1
  19. aiagents4pharma/talk2scholars/tests/test_s2_query.py +17 -0
  20. aiagents4pharma/talk2scholars/tests/test_state.py +25 -1
  21. aiagents4pharma/talk2scholars/tests/test_zotero_pdf_downloader_utils.py +46 -0
  22. aiagents4pharma/talk2scholars/tests/test_zotero_read.py +35 -40
  23. aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +62 -40
  24. aiagents4pharma/talk2scholars/tools/s2/display_dataframe.py +6 -2
  25. aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +2 -1
  26. aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +7 -3
  27. aiagents4pharma/talk2scholars/tools/s2/search.py +2 -1
  28. aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +2 -1
  29. aiagents4pharma/talk2scholars/tools/zotero/utils/read_helper.py +79 -136
  30. aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_pdf_downloader.py +147 -0
  31. aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +42 -9
  32. {aiagents4pharma-1.36.0.dist-info → aiagents4pharma-1.38.0.dist-info}/METADATA +2 -1
  33. {aiagents4pharma-1.36.0.dist-info → aiagents4pharma-1.38.0.dist-info}/RECORD +36 -29
  34. {aiagents4pharma-1.36.0.dist-info → aiagents4pharma-1.38.0.dist-info}/WHEEL +1 -1
  35. {aiagents4pharma-1.36.0.dist-info → aiagents4pharma-1.38.0.dist-info}/licenses/LICENSE +0 -0
  36. {aiagents4pharma-1.36.0.dist-info → aiagents4pharma-1.38.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,374 @@
+ """
+ Tool for performing multimodal subgraph extraction.
+ """
+
+ from typing import Type, Annotated
+ import logging
+ import pickle
+ import numpy as np
+ import pandas as pd
+ import hydra
+ import networkx as nx
+ from pydantic import BaseModel, Field
+ from langchain_core.tools import BaseTool
+ from langchain_core.messages import ToolMessage
+ from langchain_core.tools.base import InjectedToolCallId
+ from langgraph.types import Command
+ from langgraph.prebuilt import InjectedState
+ import torch
+ from torch_geometric.data import Data
+ from ..utils.extractions.multimodal_pcst import MultimodalPCSTPruning
+ from ..utils.embeddings.ollama import EmbeddingWithOllama
+ from .load_arguments import ArgumentData
+
+ # Initialize logger
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ class MultimodalSubgraphExtractionInput(BaseModel):
+     """
+     MultimodalSubgraphExtractionInput is a Pydantic model representing an input
+     for extracting a subgraph.
+
+     Args:
+         prompt: Prompt to interact with the backend.
+         tool_call_id: Tool call ID.
+         state: Injected state.
+         arg_data: Argument for the analytical process over graph data.
+     """
+
+     tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+         description="Tool call ID."
+     )
+     state: Annotated[dict, InjectedState] = Field(description="Injected state.")
+     prompt: str = Field(description="Prompt to interact with the backend.")
+     arg_data: ArgumentData = Field(
+         description="Experiment over graph data.", default=None
+     )
+
+
+ class MultimodalSubgraphExtractionTool(BaseTool):
+     """
+     This tool performs subgraph extraction based on the user's prompt by taking into account
+     the top-k nodes and edges.
+     """
+
+     name: str = "subgraph_extraction"
+     description: str = "A tool for subgraph extraction based on the user's prompt."
+     args_schema: Type[BaseModel] = MultimodalSubgraphExtractionInput
+
+     def _prepare_query_modalities(self,
+                                   prompt_emb: list,
+                                   state: Annotated[dict, InjectedState],
+                                   pyg_graph: Data) -> pd.DataFrame:
+         """
+         Prepare the modality-specific query for subgraph extraction.
+
+         Args:
+             prompt_emb: The embedding of the user prompt in a list.
+             state: The injected state for the tool.
+             pyg_graph: The PyTorch Geometric graph Data.
+
+         Returns:
+             A DataFrame containing the query embeddings and modalities.
+         """
+         # Initialize dataframes
+         multimodal_df = pd.DataFrame({"name": []})
+         query_df = pd.DataFrame({"node_id": [],
+                                  "node_type": [],
+                                  "x": [],
+                                  "desc_x": [],
+                                  "use_description": []})
+
+         # Loop over the uploaded files and find multimodal files
+         for i in range(len(state["uploaded_files"])):
+             # Check if a multimodal file has been uploaded
+             if state["uploaded_files"][i]["file_type"] == "multimodal":
+                 # Read the Excel file
+                 multimodal_df = pd.read_excel(state["uploaded_files"][i]["file_path"],
+                                               sheet_name=None)
+
+         # Check if the multimodal_df is empty
+         if len(multimodal_df) > 0:
+             # Merge all obtained dataframes into a single dataframe
+             multimodal_df = pd.concat(multimodal_df).reset_index()
+             multimodal_df.drop(columns=["level_1"], inplace=True)
+             multimodal_df.rename(columns={"level_0": "q_node_type",
+                                           "name": "q_node_name"}, inplace=True)
+             # An Excel sheet name cannot contain a `/`, whereas a node type
+             # can, e.g. 'gene/protein' as it exists in PrimeKG
+             multimodal_df["q_node_type"] = multimodal_df.q_node_type.apply(
+                 lambda x: x.replace('-', '/')
+             )
+
+             # Convert the PyG graph to a DataFrame for easier filtering
+             graph_df = pd.DataFrame({
+                 "node_id": pyg_graph.node_id,
+                 "node_name": pyg_graph.node_name,
+                 "node_type": pyg_graph.node_type,
+                 "x": pyg_graph.x,
+                 "desc_x": pyg_graph.desc_x.tolist(),
+             })
+
+             # Make a query dataframe by merging graph_df and multimodal_df
+             query_df = graph_df.merge(multimodal_df, how='cross')
+             query_df = query_df[
+                 query_df.apply(
+                     lambda x:
+                     (x['q_node_name'].lower() in x['node_name'].lower()) &  # node name
+                     (x['node_type'] == x['q_node_type']),  # node type
+                     axis=1
+                 )
+             ]
+             query_df = query_df[['node_id', 'node_type', 'x', 'desc_x']].reset_index(drop=True)
+             query_df['use_description'] = False  # set to False for modal-specific embeddings
+
+             # Update the state by adding the selected node IDs
+             state["selections"] = query_df.groupby("node_type")["node_id"].apply(list).to_dict()
+
+         # Append the user prompt to the query dataframe
+         query_df = pd.concat([
+             query_df,
+             pd.DataFrame({
+                 'node_id': 'user_prompt',
+                 'node_type': 'prompt',
+                 'x': prompt_emb,
+                 'desc_x': prompt_emb,
+                 'use_description': True  # set to True for the user prompt embedding
+             })
+         ]).reset_index(drop=True)
+
+         return query_df
+
+     def _perform_subgraph_extraction(self,
+                                      state: Annotated[dict, InjectedState],
+                                      cfg: dict,
+                                      pyg_graph: Data,
+                                      query_df: pd.DataFrame) -> dict:
+         """
+         Perform multimodal subgraph extraction based on modal-specific embeddings.
+
+         Args:
+             state: The injected state for the tool.
+             cfg: The configuration dictionary.
+             pyg_graph: The PyTorch Geometric graph Data.
+             query_df: The DataFrame containing the query embeddings and modalities.
+
+         Returns:
+             A dictionary containing the extracted subgraph with nodes and edges.
+         """
+         # Initialize the subgraph dictionary
+         subgraphs = {}
+         subgraphs["nodes"] = []
+         subgraphs["edges"] = []
+
+         # Loop over query embeddings and modalities
+         for q in query_df.iterrows():
+             # Prepare the MultimodalPCSTPruning object and extract the subgraph
+             # Parameters were set in the configuration file obtained from Hydra
+             subgraph = MultimodalPCSTPruning(
+                 topk=state["topk_nodes"],
+                 topk_e=state["topk_edges"],
+                 cost_e=cfg.cost_e,
+                 c_const=cfg.c_const,
+                 root=cfg.root,
+                 num_clusters=cfg.num_clusters,
+                 pruning=cfg.pruning,
+                 verbosity_level=cfg.verbosity_level,
+                 use_description=q[1]['use_description'],
+             ).extract_subgraph(pyg_graph,
+                                torch.tensor(q[1]['desc_x']),  # description embedding
+                                torch.tensor(q[1]['x']),  # modal-specific embedding
+                                q[1]['node_type'])
+
+             # Append the extracted subgraph to the dictionary
+             subgraphs["nodes"].append(subgraph["nodes"].tolist())
+             subgraphs["edges"].append(subgraph["edges"].tolist())
+
+         # Concatenate and get unique node and edge indices
+         subgraphs["nodes"] = np.unique(
+             np.concatenate([np.array(list_) for list_ in subgraphs["nodes"]])
+         )
+         subgraphs["edges"] = np.unique(
+             np.concatenate([np.array(list_) for list_ in subgraphs["edges"]])
+         )
+
+         return subgraphs
+
+     def _prepare_final_subgraph(self,
+                                 state: Annotated[dict, InjectedState],
+                                 subgraph: dict,
+                                 graph: dict,
+                                 cfg) -> dict:
+         """
+         Prepare the final subgraph based on the extracted subgraph.
+
+         Args:
+             state: The injected state for the tool.
+             subgraph: The extracted subgraph.
+             graph: The initial graph containing the PyG and textualized graph.
+             cfg: The configuration dictionary.
+
+         Returns:
+             A dictionary containing the PyG graph, NetworkX graph, and textualized graph.
+         """
+         # Prepare the PyTorch Geometric graph
+         mapping = {n: i for i, n in enumerate(subgraph["nodes"].tolist())}
+         pyg_graph = Data(
+             # Node features
+             x=[graph["pyg"].x[i] for i in subgraph["nodes"]],
+             node_id=np.array(graph["pyg"].node_id)[subgraph["nodes"]].tolist(),
+             node_name=np.array(graph["pyg"].node_id)[subgraph["nodes"]].tolist(),
+             enriched_node=np.array(graph["pyg"].enriched_node)[subgraph["nodes"]].tolist(),
+             num_nodes=len(subgraph["nodes"]),
+             # Edge features
+             edge_index=torch.LongTensor(
+                 [
+                     [
+                         mapping[i]
+                         for i in graph["pyg"].edge_index[:, subgraph["edges"]][0].tolist()
+                     ],
+                     [
+                         mapping[i]
+                         for i in graph["pyg"].edge_index[:, subgraph["edges"]][1].tolist()
+                     ],
+                 ]
+             ),
+             edge_attr=graph["pyg"].edge_attr[subgraph["edges"]],
+             edge_type=np.array(graph["pyg"].edge_type)[subgraph["edges"]].tolist(),
+             relation=np.array(graph["pyg"].edge_type)[subgraph["edges"]].tolist(),
+             label=np.array(graph["pyg"].edge_type)[subgraph["edges"]].tolist(),
+             enriched_edge=np.array(graph["pyg"].enriched_edge)[subgraph["edges"]].tolist(),
+         )
+
+         # NetworkX DiGraph construction to be visualized in the frontend
+         nx_graph = nx.DiGraph()
+         # Add nodes with attributes
+         node_colors = {n: cfg.node_colors_dict[k]
+                        for k, v in state["selections"].items() for n in v}
+         for n in pyg_graph.node_name:
+             nx_graph.add_node(n, color=node_colors.get(n, None))
+
+         # Add edges with attributes
+         edges = zip(
+             pyg_graph.edge_index[0].tolist(),
+             pyg_graph.edge_index[1].tolist(),
+             pyg_graph.edge_type
+         )
+         for src, dst, edge_type in edges:
+             nx_graph.add_edge(
+                 pyg_graph.node_name[src],
+                 pyg_graph.node_name[dst],
+                 relation=edge_type,
+                 label=edge_type,
+             )
+
+         # Prepare the textualized subgraph
+         textualized_graph = (
+             graph["text"]["nodes"].iloc[subgraph["nodes"]].to_csv(index=False)
+             + "\n"
+             + graph["text"]["edges"].iloc[subgraph["edges"]].to_csv(index=False)
+         )
+
+         return {
+             "graph_pyg": pyg_graph,
+             "graph_nx": nx_graph,
+             "graph_text": textualized_graph,
+         }
+
+     def _run(
+         self,
+         tool_call_id: Annotated[str, InjectedToolCallId],
+         state: Annotated[dict, InjectedState],
+         prompt: str,
+         arg_data: ArgumentData = None,
+     ) -> Command:
+         """
+         Run the subgraph extraction tool.
+
+         Args:
+             tool_call_id: The tool call ID for the tool.
+             state: Injected state for the tool.
+             prompt: The prompt to interact with the backend.
+             arg_data (ArgumentData): The argument data.
+
+         Returns:
+             Command: The command to be executed.
+         """
+         logger.log(logging.INFO, "Invoking subgraph_extraction tool")
+
+         # Load hydra configuration
+         with hydra.initialize(version_base=None, config_path="../configs"):
+             cfg = hydra.compose(
+                 config_name="config", overrides=["tools/multimodal_subgraph_extraction=default"]
+             )
+             cfg = cfg.tools.multimodal_subgraph_extraction
+
+         # Retrieve the source graph from the state
+         initial_graph = {}
+         initial_graph["source"] = state["dic_source_graph"][-1]  # the last source graph as of now
+
+         # Load the knowledge graph
+         with open(initial_graph["source"]["kg_pyg_path"], "rb") as f:
+             initial_graph["pyg"] = pickle.load(f)
+         with open(initial_graph["source"]["kg_text_path"], "rb") as f:
+             initial_graph["text"] = pickle.load(f)
+
+         # Prepare the query embeddings and modalities
+         query_df = self._prepare_query_modalities(
+             [EmbeddingWithOllama(model_name=cfg.ollama_embeddings[0]).embed_query(prompt)],
+             state,
+             initial_graph["pyg"]
+         )
+
+         # Perform subgraph extraction
+         subgraphs = self._perform_subgraph_extraction(state,
+                                                       cfg,
+                                                       initial_graph["pyg"],
+                                                       query_df)
+
+         # Prepare the subgraph as a NetworkX graph and textualized graph
+         final_subgraph = self._prepare_final_subgraph(state,
+                                                       subgraphs,
+                                                       initial_graph,
+                                                       cfg)
+
+         # Prepare the dictionary of the extracted graph
+         dic_extracted_graph = {
+             "name": arg_data.extraction_name,
+             "tool_call_id": tool_call_id,
+             "graph_source": initial_graph["source"]["name"],
+             "topk_nodes": state["topk_nodes"],
+             "topk_edges": state["topk_edges"],
+             "graph_dict": {
+                 "nodes": list(final_subgraph["graph_nx"].nodes(data=True)),
+                 "edges": list(final_subgraph["graph_nx"].edges(data=True)),
+             },
+             "graph_text": final_subgraph["graph_text"],
+             "graph_summary": None,
+         }
+
+         # Prepare the dictionary of the updated state
+         dic_updated_state_for_model = {}
+         for key, value in {
+             "dic_extracted_graph": [dic_extracted_graph],
+         }.items():
+             if value:
+                 dic_updated_state_for_model[key] = value
+
+         # Return the updated state of the tool
+         return Command(
+             update=dic_updated_state_for_model | {
+                 # Update the message history
+                 "messages": [
+                     ToolMessage(
+                         content=f"Subgraph Extraction Result of {arg_data.extraction_name}",
+                         tool_call_id=tool_call_id,
+                     )
+                 ],
+             }
+         )
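To make the node-matching step in `_prepare_query_modalities` above concrete, here is a small self-contained pandas sketch (toy data, not taken from the package) of the cross merge that pairs every graph node with every queried node and keeps only name-substring, type-exact matches:

import pandas as pd

# Toy graph nodes and one queried node (illustrative values only)
graph_df = pd.DataFrame({
    "node_id": [0, 1, 2],
    "node_name": ["TP53", "IL6", "aspirin"],
    "node_type": ["gene/protein", "gene/protein", "drug"],
})
query = pd.DataFrame({"q_node_name": ["il6"], "q_node_type": ["gene/protein"]})

# Cross merge, then keep rows where the queried name is a substring of the
# node name (case-insensitive) and the node types match exactly
matches = graph_df.merge(query, how="cross")
matches = matches[matches.apply(
    lambda r: (r["q_node_name"].lower() in r["node_name"].lower())
    & (r["node_type"] == r["q_node_type"]),
    axis=1,
)]
print(matches[["node_id", "node_name"]])  # keeps only the IL6 row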
@@ -2,3 +2,4 @@
  This file is used to import all the models in the package.
  '''
  from . import pcst
+ from . import multimodal_pcst
@@ -0,0 +1,292 @@
+ """
+ Extraction of a multimodal subgraph using the Prize-Collecting Steiner Tree (PCST) algorithm.
+ """
+
+ from typing import Tuple, NamedTuple
+ import numpy as np
+ import pandas as pd
+ import torch
+ import pcst_fast
+ from torch_geometric.data.data import Data
+
+
+ class MultimodalPCSTPruning(NamedTuple):
+     """
+     Prize-Collecting Steiner Tree (PCST) pruning algorithm implementation inspired by G-Retriever
+     (He et al., 'G-Retriever: Retrieval-Augmented Generation for Textual Graph Understanding and
+     Question Answering', NeurIPS 2024).
+     https://arxiv.org/abs/2402.07630
+     https://github.com/XiaoxinHe/G-Retriever/blob/main/src/dataset/utils/retrieval.py
+
+     Args:
+         topk: The number of top nodes to consider.
+         topk_e: The number of top edges to consider.
+         cost_e: The cost of the edges.
+         c_const: The constant value for the cost of the edges computation.
+         root: The root node of the subgraph, -1 for unrooted.
+         num_clusters: The number of clusters.
+         pruning: The pruning strategy to use.
+         verbosity_level: The verbosity level.
+         use_description: Whether to score nodes with the textual description embedding
+             instead of the modality-specific embedding.
+     """
+     topk: int = 3
+     topk_e: int = 3
+     cost_e: float = 0.5
+     c_const: float = 0.01
+     root: int = -1
+     num_clusters: int = 1
+     pruning: str = "gw"
+     verbosity_level: int = 0
+     use_description: bool = False
+
+     def _compute_node_prizes(self,
+                              graph: Data,
+                              query_emb: torch.Tensor,
+                              modality: str):
+         """
+         Compute the node prizes based on the cosine similarity between the query and nodes.
+
+         Args:
+             graph: The knowledge graph in PyTorch Geometric Data format.
+             query_emb: The query embedding in PyTorch Tensor format. This can be an embedding of
+                 a prompt, sequence, or any other feature to be used for the subgraph extraction.
+             modality: The modality to use for the subgraph extraction based on the node type.
+
+         Returns:
+             The prizes of the nodes.
+         """
+         # Convert the PyG graph to a DataFrame
+         graph_df = pd.DataFrame({
+             "node_type": graph.node_type,
+             "desc_x": [x.tolist() for x in graph.desc_x],
+             "x": [list(x) for x in graph.x],
+             "score": [0.0 for _ in range(len(graph.node_id))],
+         })
+
+         # Calculate the cosine similarity and update the score
+         if self.use_description:
+             graph_df.loc[:, "score"] = torch.nn.CosineSimilarity(dim=-1)(
+                 query_emb,
+                 torch.tensor(list(graph_df.desc_x.values))  # textual description features
+             ).tolist()
+         else:
+             graph_df.loc[graph_df["node_type"] == modality,
+                          "score"] = torch.nn.CosineSimilarity(dim=-1)(
+                 query_emb,
+                 torch.tensor(list(graph_df[graph_df["node_type"] == modality].x.values))
+             ).tolist()
+
+         # Set the prizes for nodes based on the similarity scores
+         n_prizes = torch.tensor(graph_df.score.values, dtype=torch.float32)
+         topk = min(self.topk, graph.num_nodes)
+         _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
+         n_prizes = torch.zeros_like(n_prizes)
+         n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
+
+         return n_prizes
+
+     def _compute_edge_prizes(self,
+                              graph: Data,
+                              text_emb: torch.Tensor):
+         """
+         Compute the edge prizes based on the cosine similarity between the query and edges.
+
+         Args:
+             graph: The knowledge graph in PyTorch Geometric Data format.
+             text_emb: The textual description embedding in PyTorch Tensor format.
+
+         Returns:
+             The prizes of the edges.
+         """
+         # Note that as of now, the edge features are based on textual features
+         # Compute prizes for edges
+         e_prizes = torch.nn.CosineSimilarity(dim=-1)(text_emb, graph.edge_attr)
+         unique_prizes, inverse_indices = e_prizes.unique(return_inverse=True)
+         topk_e = min(self.topk_e, unique_prizes.size(0))
+         topk_e_values, _ = torch.topk(unique_prizes, topk_e, largest=True)
+         e_prizes[e_prizes < topk_e_values[-1]] = 0.0
+         last_topk_e_value = topk_e
+         for k in range(topk_e):
+             indices = inverse_indices == (
+                 unique_prizes == topk_e_values[k]
+             ).nonzero(as_tuple=True)[0]
+             value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
+             e_prizes[indices] = value
+             last_topk_e_value = value * (1 - self.c_const)
+
+         return e_prizes
+
+     def compute_prizes(self,
+                        graph: Data,
+                        text_emb: torch.Tensor,
+                        query_emb: torch.Tensor,
+                        modality: str):
+         """
+         Compute the node prizes based on the cosine similarity between the query and nodes,
+         as well as the edge prizes based on the cosine similarity between the query and edges.
+         Note that the node and edge embeddings must use the same embedding model and dimensions
+         as the query.
+
+         Args:
+             graph: The knowledge graph in PyTorch Geometric Data format.
+             text_emb: The textual description embedding in PyTorch Tensor format.
+             query_emb: The query embedding in PyTorch Tensor format. This can be an embedding of
+                 a prompt, sequence, or any other feature to be used for the subgraph extraction.
+             modality: The modality to use for the subgraph extraction based on the node type.
+
+         Returns:
+             The prizes of the nodes and edges.
+         """
+         # Compute prizes for nodes
+         n_prizes = self._compute_node_prizes(graph, query_emb, modality)
+
+         # Compute prizes for edges
+         e_prizes = self._compute_edge_prizes(graph, text_emb)
+
+         return {"nodes": n_prizes, "edges": e_prizes}
+
+     def compute_subgraph_costs(self,
+                                graph: Data,
+                                prizes: dict) -> Tuple[dict, np.ndarray, np.ndarray, dict]:
+         """
+         Compute the costs in constructing the subgraph proposed by the G-Retriever paper.
+
+         Args:
+             graph: The knowledge graph in PyTorch Geometric Data format.
+             prizes: The prizes of the nodes and the edges.
+
+         Returns:
+             edges_dict: The edges of the subgraph, consisting of the edges and the number of
+                 edges without virtual edges.
+             prizes: The prizes of the subgraph.
+             costs: The costs of the subgraph.
+             mapping: The mapping of the nodes and edges.
+         """
+         # Logic to reduce the cost of the edges such that at least one edge is selected
+         updated_cost_e = min(
+             self.cost_e,
+             prizes["edges"].max().item() * (1 - self.c_const / 2),
+         )
+
+         # Initialize variables
+         edges = []
+         costs = []
+         virtual = {
+             "n_prizes": [],
+             "edges": [],
+             "costs": [],
+         }
+         mapping = {"nodes": {}, "edges": {}}
+
+         # Compute the costs, edges, and virtual variables based on the prizes
+         for i, (src, dst) in enumerate(graph.edge_index.T.numpy()):
+             prize_e = prizes["edges"][i]
+             if prize_e <= updated_cost_e:
+                 mapping["edges"][len(edges)] = i
+                 edges.append((src, dst))
+                 costs.append(updated_cost_e - prize_e)
+             else:
+                 virtual_node_id = graph.num_nodes + len(virtual["n_prizes"])
+                 mapping["nodes"][virtual_node_id] = i
+                 virtual["edges"].append((src, virtual_node_id))
+                 virtual["edges"].append((virtual_node_id, dst))
+                 virtual["costs"].append(0)
+                 virtual["costs"].append(0)
+                 virtual["n_prizes"].append(prize_e - updated_cost_e)
+         prizes = np.concatenate([prizes["nodes"], np.array(virtual["n_prizes"])])
+         edges_dict = {}
+         edges_dict["edges"] = edges
+         edges_dict["num_prior_edges"] = len(edges)
+         # Final computation of the costs and edges based on the virtual costs and virtual edges
+         if len(virtual["costs"]) > 0:
+             costs = np.array(costs + virtual["costs"])
+             edges = np.array(edges + virtual["edges"])
+             edges_dict["edges"] = edges
+
+         return edges_dict, prizes, costs, mapping
+
+     def get_subgraph_nodes_edges(
+         self, graph: Data, vertices: np.ndarray, edges_dict: dict, mapping: dict,
+     ) -> dict:
+         """
+         Get the selected nodes and edges of the subgraph based on the vertices and edges computed
+         by the PCST algorithm.
+
+         Args:
+             graph: The knowledge graph in PyTorch Geometric Data format.
+             vertices: The vertices of the subgraph computed by the PCST algorithm.
+             edges_dict: The dictionary of edges of the subgraph computed by the PCST algorithm,
+                 and the number of prior edges (without virtual edges).
+             mapping: The mapping dictionary of the nodes and edges.
+
+         Returns:
+             The selected nodes and edges of the extracted subgraph.
+         """
+         # Get edges information
+         edges = edges_dict["edges"]
+         num_prior_edges = edges_dict["num_prior_edges"]
+         # Retrieve the selected nodes and edges based on the given vertices and edges
+         subgraph_nodes = vertices[vertices < graph.num_nodes]
+         subgraph_edges = [mapping["edges"][e] for e in edges if e < num_prior_edges]
+         virtual_vertices = vertices[vertices >= graph.num_nodes]
+         if len(virtual_vertices) > 0:
+             virtual_edges = [mapping["nodes"][i] for i in virtual_vertices]
+             subgraph_edges = np.array(subgraph_edges + virtual_edges)
+         edge_index = graph.edge_index[:, subgraph_edges]
+         subgraph_nodes = np.unique(
+             np.concatenate(
+                 [subgraph_nodes, edge_index[0].numpy(), edge_index[1].numpy()]
+             )
+         )
+
+         return {"nodes": subgraph_nodes, "edges": subgraph_edges}
+
+     def extract_subgraph(self,
+                          graph: Data,
+                          text_emb: torch.Tensor,
+                          query_emb: torch.Tensor,
+                          modality: str) -> dict:
+         """
+         Perform the Prize-Collecting Steiner Tree (PCST) algorithm to extract the subgraph.
+
+         Args:
+             graph: The knowledge graph in PyTorch Geometric Data format.
+             text_emb: The textual description embedding in PyTorch Tensor format.
+             query_emb: The query embedding in PyTorch Tensor format. This can be an embedding of
+                 a prompt, sequence, or any other feature to be used for the subgraph extraction.
+             modality: The modality to use for the subgraph extraction
+                 (e.g., "text", "sequence", "smiles").
+
+         Returns:
+             The selected nodes and edges of the subgraph.
+         """
+         # Assert the topk and topk_e values for subgraph retrieval
+         assert self.topk > 0, "topk must be greater than 0"
+         assert self.topk_e > 0, "topk_e must be greater than 0"
+
+         # Retrieve the top-k nodes and edges based on the query embedding
+         prizes = self.compute_prizes(graph, text_emb, query_emb, modality)
+
+         # Compute costs in constructing the subgraph
+         edges_dict, prizes, costs, mapping = self.compute_subgraph_costs(
+             graph, prizes
+         )
+
+         # Retrieve the subgraph using the PCST algorithm
+         result_vertices, result_edges = pcst_fast.pcst_fast(
+             edges_dict["edges"],
+             prizes,
+             costs,
+             self.root,
+             self.num_clusters,
+             self.pruning,
+             self.verbosity_level,
+         )
+
+         subgraph = self.get_subgraph_nodes_edges(
+             graph,
+             result_vertices,
+             {"edges": result_edges, "num_prior_edges": edges_dict["num_prior_edges"]},
+             mapping)
+
+         return subgraph
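As a quick illustration of the rank-based node prizes computed in `_compute_node_prizes` above, a minimal self-contained sketch (toy similarity scores, independent of the package):

import torch

# Toy cosine-similarity scores for four nodes (illustrative values only)
scores = torch.tensor([0.91, 0.10, 0.75, 0.33])
topk = 2

# The top-k nodes receive prizes topk..1 by rank; all other nodes get 0,
# mirroring the prize assignment in _compute_node_prizes
_, topk_idx = torch.topk(scores, topk, largest=True)
prizes = torch.zeros_like(scores)
prizes[topk_idx] = torch.arange(topk, 0, -1).float()
print(prizes)  # tensor([2., 0., 1., 0.])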
@@ -2,6 +2,7 @@
  library_type: "user" # Type of library ('user' or 'group')
  default_limit: 2
  request_timeout: 10
+ chunk_size: 16384 # Size (in bytes) for streaming PDF download chunks
  user_id: ${oc.env:ZOTERO_USER_ID} # Load from environment variable
  api_key: ${oc.env:ZOTERO_API_KEY} # Load from environment variable
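The new `chunk_size` setting sizes the pieces written to disk while streaming PDFs from Zotero. A minimal sketch of how such a value is typically consumed with `requests` (a hypothetical helper for illustration, not the package's actual `zotero_pdf_downloader` API):

import requests

def download_pdf(url: str, dest: str, chunk_size: int = 16384, timeout: int = 10) -> None:
    # Stream the response and write it in chunk_size-byte pieces so large
    # PDFs never have to fit in memory at once (illustrative helper)
    with requests.get(url, stream=True, timeout=timeout) as resp:
        resp.raise_for_status()
        with open(dest, "wb") as fh:
            for chunk in resp.iter_content(chunk_size=chunk_size):
                fh.write(chunk)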