aiagents4pharma 1.36.0__py3-none-any.whl → 1.37.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,12 +20,20 @@ def input_dict_fixture():
     input_dict = {
         "topk_nodes": 3,
         "topk_edges": 3,
+        "selections": {
+            "gene/protein": [],
+            "molecular_function": [],
+            "cellular_component": [],
+            "biological_process": [],
+            "drug": [],
+            "disease": []
+        },
         "uploaded_files": [],
         "dic_source_graph": [
             {
-                "name": "PrimeKG",
-                "kg_pyg_path": f"{DATA_PATH}/primekg_ibd_pyg_graph.pkl",
-                "kg_text_path": f"{DATA_PATH}/primekg_ibd_text_graph.pkl",
+                "name": "BioBridge",
+                "kg_pyg_path": f"{DATA_PATH}/biobridge_multimodal_pyg_graph.pkl",
+                "kg_text_path": f"{DATA_PATH}/biobridge_multimodal_text_graph.pkl",
             }
         ],
         "dic_extracted_graph": []
@@ -70,7 +78,7 @@ def test_main_agent_invokes_t2kg(input_dict):
     current_state = app.get_state(config)
     dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
     assert isinstance(dic_extracted_graph, dict)
-    assert dic_extracted_graph["graph_source"] == "PrimeKG"
+    assert dic_extracted_graph["graph_source"] == "BioBridge"
     assert dic_extracted_graph["topk_nodes"] == 3
     assert dic_extracted_graph["topk_edges"] == 3
     assert isinstance(dic_extracted_graph["graph_dict"], dict)
@@ -9,7 +9,7 @@ from langchain_core.language_models.chat_models import BaseChatModel
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import START, StateGraph
 from langgraph.prebuilt import create_react_agent, ToolNode, InjectedState
-from ..tools.subgraph_extraction import SubgraphExtractionTool
+from ..tools.multimodal_subgraph_extraction import MultimodalSubgraphExtractionTool
 from ..tools.subgraph_summarization import SubgraphSummarizationTool
 from ..tools.graphrag_reasoning import GraphRAGReasoningTool
 from ..states.state_talk2knowledgegraphs import Talk2KnowledgeGraphs
@@ -39,7 +39,7 @@ def get_app(uniq_id, llm_model: BaseChatModel):
     cfg = cfg.agents.t2kg_agent

     # Define the tools
-    subgraph_extraction = SubgraphExtractionTool()
+    subgraph_extraction = MultimodalSubgraphExtractionTool()
     subgraph_summarization = SubgraphSummarizationTool()
     graphrag_reasoning = GraphRAGReasoningTool()
     tools = ToolNode([
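For orientation, `get_app` composes these tools into a LangGraph ReAct agent. A minimal sketch, assuming the standard `create_react_agent`/`ToolNode` API; only the tool classes and the swap to `MultimodalSubgraphExtractionTool` come from this diff, the rest is illustrative:

    from langchain_openai import ChatOpenAI
    from langgraph.prebuilt import create_react_agent, ToolNode
    from aiagents4pharma.talk2knowledgegraphs.tools.multimodal_subgraph_extraction import (
        MultimodalSubgraphExtractionTool,
    )
    from aiagents4pharma.talk2knowledgegraphs.tools.subgraph_summarization import (
        SubgraphSummarizationTool,
    )
    from aiagents4pharma.talk2knowledgegraphs.tools.graphrag_reasoning import (
        GraphRAGReasoningTool,
    )

    # v1.37.0 swaps SubgraphExtractionTool for the multimodal variant;
    # the summarization and reasoning tools are unchanged.
    tools = ToolNode([
        MultimodalSubgraphExtractionTool(),
        SubgraphSummarizationTool(),
        GraphRAGReasoningTool(),
    ])
    agent = create_react_agent(ChatOpenAI(model="gpt-4o-mini"), tools=tools)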
@@ -2,12 +2,13 @@ _target_: app.frontend.streamlit_app_talk2knowledgegraphs
 default_user: "talk2kg_user"
 data_package_allowed_file_types:
   - "pdf"
-endotype_allowed_file_types:
-  - "pdf"
+multimodal_allowed_file_types:
+  - "xls"
+  - "xlsx"
 upload_data_dir: "../files"
 kg_name: "PrimeKG"
-kg_pyg_path: "aiagents4pharma/talk2knowledgegraphs/tests/files/primekg_ibd_pyg_graph.pkl"
-kg_text_path: "aiagents4pharma/talk2knowledgegraphs/tests/files/primekg_ibd_text_graph.pkl"
+kg_pyg_path: "aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal_pyg_graph.pkl"
+kg_text_path: "aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal_text_graph.pkl"
 openai_api_key: ${oc.env:OPENAI_API_KEY}
 openai_llms:
   - "gpt-4o-mini"
@@ -23,9 +24,9 @@ ollama_embeddings:
   - "nomic-embed-text"
 temperature: 0.1
 streaming: False
-reasoning_subgraph_topk_nodes: 10
+reasoning_subgraph_topk_nodes: 5
 reasoning_subgraph_topk_nodes_min: 1
 reasoning_subgraph_topk_nodes_max: 100
-reasoning_subgraph_topk_edges: 10
+reasoning_subgraph_topk_edges: 5
 reasoning_subgraph_topk_edges_min: 1
 reasoning_subgraph_topk_edges_max: 100
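The `endotype_allowed_file_types` PDF upload is replaced by spreadsheet uploads for multimodal node selections. A hypothetical sketch of how the Streamlit frontend might consume the new key (the frontend code itself is not part of this diff; the `st.file_uploader` call and variable names are assumptions):

    import streamlit as st

    # cfg is the Hydra-loaded frontend config shown above
    uploaded_file = st.file_uploader(
        "Upload multimodal analysis",
        type=cfg.multimodal_allowed_file_types,  # ["xls", "xlsx"]
    )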
@@ -2,6 +2,7 @@ defaults:
   - _self_
   - agents/t2kg_agent: default
   - tools/subgraph_extraction: default
+  - tools/multimodal_subgraph_extraction: default
   - tools/subgraph_summarization: default
   - tools/graphrag_reasoning: default
   - utils/pubchem_utils: default
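Registering `tools/multimodal_subgraph_extraction` under `defaults` lets the tool compose its own config group at runtime, exactly as its `_run` method does later in this diff:

    import hydra

    with hydra.initialize(version_base=None, config_path="../configs"):
        cfg = hydra.compose(
            config_name="config",
            overrides=["tools/multimodal_subgraph_extraction=default"],
        )
        cfg = cfg.tools.multimodal_subgraph_extraction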
@@ -31,6 +31,7 @@ class Talk2KnowledgeGraphs(AgentState):

     llm_model: BaseChatModel
     embedding_model: Embeddings
+    selections: dict
     uploaded_files: list
     topk_nodes: int
     topk_edges: int
@@ -19,6 +19,14 @@ def input_dict_fixture():
     input_dict = {
         "llm_model": None,  # TBA for each test case
         "embedding_model": None,  # TBA for each test case
+        "selections": {
+            "gene/protein": [],
+            "molecular_function": [],
+            "cellular_component": [],
+            "biological_process": [],
+            "drug": [],
+            "disease": []
+        },
         "uploaded_files": [
             {
                 "file_name": "adalimumab.pdf",
@@ -27,21 +35,14 @@ def input_dict_fixture():
                 "uploaded_by": "VPEUser",
                 "uploaded_timestamp": "2024-11-05 00:00:00",
             },
-            {
-                "file_name": "DGE_human_Colon_UC-vs-Colon_Control.pdf",
-                "file_path": f"{DATA_PATH}/DGE_human_Colon_UC-vs-Colon_Control.pdf",
-                "file_type": "endotype",
-                "uploaded_by": "VPEUser",
-                "uploaded_timestamp": "2024-11-05 00:00:00",
-            },
         ],
         "topk_nodes": 3,
         "topk_edges": 3,
         "dic_source_graph": [
             {
-                "name": "PrimeKG",
-                "kg_pyg_path": f"{DATA_PATH}/primekg_ibd_pyg_graph.pkl",
-                "kg_text_path": f"{DATA_PATH}/primekg_ibd_text_graph.pkl",
+                "name": "BioBridge",
+                "kg_pyg_path": f"{DATA_PATH}/biobridge_multimodal_pyg_graph.pkl",
+                "kg_text_path": f"{DATA_PATH}/biobridge_multimodal_text_graph.pkl",
             }
         ],
         "dic_extracted_graph": []
@@ -96,7 +97,7 @@ def test_t2kg_agent_openai(input_dict):
     dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
     assert isinstance(dic_extracted_graph, dict)
     assert dic_extracted_graph["name"] == "subkg_12345"
-    assert dic_extracted_graph["graph_source"] == "PrimeKG"
+    assert dic_extracted_graph["graph_source"] == "BioBridge"
     assert dic_extracted_graph["topk_nodes"] == 3
     assert dic_extracted_graph["topk_edges"] == 3
     assert isinstance(dic_extracted_graph["graph_dict"], dict)
@@ -0,0 +1,152 @@
+"""
+Test cases for tools/multimodal_subgraph_extraction.py
+"""
+
+import pytest
+# from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+from ..tools.multimodal_subgraph_extraction import MultimodalSubgraphExtractionTool
+
+# Define the data path
+DATA_PATH = "aiagents4pharma/talk2knowledgegraphs/tests/files"
+
+
+@pytest.fixture(name="agent_state")
+def agent_state_fixture():
+    """
+    Agent state fixture.
+    """
+    agent_state = {
+        # "llm_model": ChatOpenAI(model="gpt-4o-mini", temperature=0.0),
+        # "embedding_model": OpenAIEmbeddings(model="text-embedding-3-small"),
+        "selections": {
+            "gene/protein": [],
+            "molecular_function": [],
+            "cellular_component": [],
+            "biological_process": [],
+            "drug": [],
+            "disease": []
+        },
+        "uploaded_files": [],
+        "topk_nodes": 3,
+        "topk_edges": 3,
+        "dic_source_graph": [
+            {
+                "name": "BioBridge",
+                "kg_pyg_path": f"{DATA_PATH}/biobridge_multimodal_pyg_graph.pkl",
+                "kg_text_path": f"{DATA_PATH}/biobridge_multimodal_text_graph.pkl",
+            }
+        ],
+    }
+
+    return agent_state
+
+
+def test_extract_multimodal_subgraph_wo_doc(agent_state):
+    """
+    Test the multimodal subgraph extraction tool with text as the only modality.
+
+    Args:
+        agent_state: Agent state in the form of a dictionary.
+    """
+    prompt = """
+    Extract all relevant information related to nodes of genes related to inflammatory bowel disease
+    (IBD) that existed in the knowledge graph.
+    Please set the extraction name for this process as `subkg_12345`.
+    """
+
+    # Instantiate the MultimodalSubgraphExtractionTool
+    subgraph_extraction_tool = MultimodalSubgraphExtractionTool()
+
+    # Invoking the subgraph_extraction_tool
+    response = subgraph_extraction_tool.invoke(
+        input={"prompt": prompt,
+               "tool_call_id": "subgraph_extraction_tool",
+               "state": agent_state,
+               "arg_data": {"extraction_name": "subkg_12345"}})
+
+    # Check tool message
+    assert response.update["messages"][-1].tool_call_id == "subgraph_extraction_tool"
+
+    # Check extracted subgraph dictionary
+    dic_extracted_graph = response.update["dic_extracted_graph"][0]
+    assert isinstance(dic_extracted_graph, dict)
+    assert dic_extracted_graph["name"] == "subkg_12345"
+    assert dic_extracted_graph["graph_source"] == "BioBridge"
+    assert dic_extracted_graph["topk_nodes"] == 3
+    assert dic_extracted_graph["topk_edges"] == 3
+    assert isinstance(dic_extracted_graph["graph_dict"], dict)
+    assert len(dic_extracted_graph["graph_dict"]["nodes"]) > 0
+    assert len(dic_extracted_graph["graph_dict"]["edges"]) > 0
+    assert isinstance(dic_extracted_graph["graph_text"], str)
+    # Check if the nodes are in the graph_text
+    assert all(
+        n[0] in dic_extracted_graph["graph_text"].replace('"', '')
+        for n in dic_extracted_graph["graph_dict"]["nodes"]
+    )
+    # Check if the edges are in the graph_text
+    assert all(
+        ",".join([e[0], str(tuple(e[2]["relation"])), e[1]])
+        in dic_extracted_graph["graph_text"].replace('"', '')
+        for e in dic_extracted_graph["graph_dict"]["edges"]
+    )
+
+
+def test_extract_multimodal_subgraph_w_doc(agent_state):
+    """
+    Test the multimodal subgraph extraction tool with text as a modality, plus genes
+    from an uploaded file.
+
+    Args:
+        agent_state: Agent state in the form of a dictionary.
+    """
+    # Update state
+    agent_state["uploaded_files"] = [
+        {
+            "file_name": "multimodal-analysis.xlsx",
+            "file_path": f"{DATA_PATH}/multimodal-analysis.xlsx",
+            "file_type": "multimodal",
+            "uploaded_by": "VPEUser",
+            "uploaded_timestamp": "2025-05-12 00:00:00",
+        }
+    ]
+
+    prompt = """
+    Extract all relevant information related to nodes of genes related to inflammatory bowel disease
+    (IBD) that existed in the knowledge graph.
+    Please set the extraction name for this process as `subkg_12345`.
+    """
+
+    # Instantiate the MultimodalSubgraphExtractionTool
+    subgraph_extraction_tool = MultimodalSubgraphExtractionTool()
+
+    # Invoking the subgraph_extraction_tool
+    response = subgraph_extraction_tool.invoke(
+        input={"prompt": prompt,
+               "tool_call_id": "subgraph_extraction_tool",
+               "state": agent_state,
+               "arg_data": {"extraction_name": "subkg_12345"}})
+
+    # Check tool message
+    assert response.update["messages"][-1].tool_call_id == "subgraph_extraction_tool"
+
+    # Check extracted subgraph dictionary
+    dic_extracted_graph = response.update["dic_extracted_graph"][0]
+    assert isinstance(dic_extracted_graph, dict)
+    assert dic_extracted_graph["name"] == "subkg_12345"
+    assert dic_extracted_graph["graph_source"] == "BioBridge"
+    assert dic_extracted_graph["topk_nodes"] == 3
+    assert dic_extracted_graph["topk_edges"] == 3
+    assert isinstance(dic_extracted_graph["graph_dict"], dict)
+    assert len(dic_extracted_graph["graph_dict"]["nodes"]) > 0
+    assert len(dic_extracted_graph["graph_dict"]["edges"]) > 0
+    assert isinstance(dic_extracted_graph["graph_text"], str)
+    # Check if the nodes are in the graph_text
+    assert all(
+        n[0] in dic_extracted_graph["graph_text"].replace('"', '')
+        for n in dic_extracted_graph["graph_dict"]["nodes"]
+    )
+    # Check if the edges are in the graph_text
+    assert all(
+        ",".join([e[0], str(tuple(e[2]["relation"])), e[1]])
+        in dic_extracted_graph["graph_text"].replace('"', '')
+        for e in dic_extracted_graph["graph_dict"]["edges"]
+    )
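The substring assertions above rely on `graph_text` being a plain CSV dump of the node table followed by the edge table (see `_prepare_final_subgraph` later in this diff). A toy illustration, with a hypothetical 'ppi' relation, of why quotes are stripped before matching:

    import pandas as pd

    nodes = pd.DataFrame({"node_name": ["TNF", "IL6"]})
    edges = pd.DataFrame({"triple": ["TNF,('gene/protein', 'ppi', 'gene/protein'),IL6"]})
    graph_text = nodes.to_csv(index=False) + "\n" + edges.to_csv(index=False)
    # to_csv wraps comma-containing fields in double quotes, so the tests
    # strip them before checking for the joined node/edge strings
    assert "TNF,('gene/protein', 'ppi', 'gene/protein'),IL6" in graph_text.replace('"', '')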
@@ -3,22 +3,21 @@ Test cases for tools/subgraph_extraction.py
 """

 import pytest
-from langchain_core.messages import HumanMessage
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings
-from ..agents.t2kg_agent import get_app
+from ..tools.subgraph_extraction import SubgraphExtractionTool

 # Define the data path
 DATA_PATH = "aiagents4pharma/talk2knowledgegraphs/tests/files"


-@pytest.fixture(name="input_dict")
-def input_dict_fixture():
+@pytest.fixture(name="agent_state")
+def agent_state_fixture():
     """
-    Input dictionary fixture.
+    Agent state fixture.
     """
-    input_dict = {
-        "llm_model": None,  # TBA for each test case
-        "embedding_model": None,  # TBA for each test case
+    agent_state = {
+        "llm_model": ChatOpenAI(model="gpt-4o-mini", temperature=0.0),
+        "embedding_model": OpenAIEmbeddings(model="text-embedding-3-small"),
         "uploaded_files": [],
         "topk_nodes": 3,
         "topk_edges": 3,
@@ -31,52 +30,37 @@ def input_dict_fixture():
         ],
     }

-    return input_dict
+    return agent_state


-def test_extract_subgraph_wo_docs(input_dict):
+def test_extract_subgraph_wo_docs(agent_state):
     """
     Test the subgraph extraction tool without any documents using OpenAI model.

     Args:
-        input_dict: Input dictionary.
+        agent_state: Agent state in the form of a dictionary.
     """
-    # Prepare LLM and embedding model
-    input_dict["llm_model"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
-    input_dict["embedding_model"] = OpenAIEmbeddings(model="text-embedding-3-small")
-
-    # Setup the app
-    unique_id = 12345
-    app = get_app(unique_id, llm_model=input_dict["llm_model"])
-    config = {"configurable": {"thread_id": unique_id}}
-    # Update state
-    app.update_state(
-        config,
-        input_dict,
-    )
     prompt = """
-    Please directly invoke `subgraph_extraction` tool without calling any other tools
-    to respond to the following prompt:
-
     Extract all relevant information related to nodes of genes related to inflammatory bowel disease
     (IBD) that existed in the knowledge graph.
     Please set the extraction name for this process as `subkg_12345`.
     """

-    # Test the tool subgraph_extraction
-    response = app.invoke({"messages": [HumanMessage(content=prompt)]}, config=config)
+    # Instantiate the SubgraphExtractionTool
+    subgraph_extraction_tool = SubgraphExtractionTool()

-    # Check assistant message
-    assistant_msg = response["messages"][-1].content
-    assert isinstance(assistant_msg, str)
+    # Invoking the subgraph_extraction_tool
+    response = subgraph_extraction_tool.invoke(
+        input={"prompt": prompt,
+               "tool_call_id": "subgraph_extraction_tool",
+               "state": agent_state,
+               "arg_data": {"extraction_name": "subkg_12345"}})

     # Check tool message
-    tool_msg = response["messages"][-2]
-    assert tool_msg.name == "subgraph_extraction"
+    assert response.update["messages"][-1].tool_call_id == "subgraph_extraction_tool"

     # Check extracted subgraph dictionary
-    current_state = app.get_state(config)
-    dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
+    dic_extracted_graph = response.update["dic_extracted_graph"][0]
     assert isinstance(dic_extracted_graph, dict)
     assert dic_extracted_graph["name"] == "subkg_12345"
     assert dic_extracted_graph["graph_source"] == "PrimeKG"
@@ -99,24 +83,16 @@ def test_extract_subgraph_wo_docs(input_dict):
     )


-def test_extract_subgraph_w_docs(input_dict):
+def test_extract_subgraph_w_docs(agent_state):
     """
-    Test the subgraph extraction tool with a document as reference (i.e., endotype document)
-    using OpenAI model.
+    As a knowledge graph agent, I would like you to call a tool called `subgraph_extraction`.
+    After calling the tool, restrain yourself to call any other tool.

     Args:
-        input_dict: Input dictionary.
+        agent_state: Agent state in the form of a dictionary.
     """
-    # Prepare LLM and embedding model
-    input_dict["llm_model"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
-    input_dict["embedding_model"] = OpenAIEmbeddings(model="text-embedding-3-small")
-
-    # Setup the app
-    unique_id = 12345
-    app = get_app(unique_id, llm_model=input_dict["llm_model"])
-    config = {"configurable": {"thread_id": unique_id}}
     # Update state
-    input_dict["uploaded_files"] = [
+    agent_state["uploaded_files"] = [
         {
             "file_name": "DGE_human_Colon_UC-vs-Colon_Control.pdf",
             "file_path": f"{DATA_PATH}/DGE_human_Colon_UC-vs-Colon_Control.pdf",
@@ -125,33 +101,28 @@ def test_extract_subgraph_w_docs(input_dict):
             "uploaded_timestamp": "2024-11-05 00:00:00",
         }
     ]
-    app.update_state(
-        config,
-        input_dict,
-    )
-    prompt = """
-    Please ONLY invoke `subgraph_extraction` tool without calling any other tools
-    to respond to the following prompt:

+    prompt = """
     Extract all relevant information related to nodes of genes related to inflammatory bowel disease
     (IBD) that existed in the knowledge graph.
     Please set the extraction name for this process as `subkg_12345`.
     """

-    # Test the tool subgraph_extraction
-    response = app.invoke({"messages": [HumanMessage(content=prompt)]}, config=config)
+    # Instantiate the SubgraphExtractionTool
+    subgraph_extraction_tool = SubgraphExtractionTool()

-    # Check assistant message
-    assistant_msg = response["messages"][-1].content
-    assert isinstance(assistant_msg, str)
+    # Invoking the subgraph_extraction_tool
+    response = subgraph_extraction_tool.invoke(
+        input={"prompt": prompt,
+               "tool_call_id": "subgraph_extraction_tool",
+               "state": agent_state,
+               "arg_data": {"extraction_name": "subkg_12345"}})

     # Check tool message
-    tool_msg = response["messages"][-2]
-    assert tool_msg.name == "subgraph_extraction"
+    assert response.update["messages"][-1].tool_call_id == "subgraph_extraction_tool"

     # Check extracted subgraph dictionary
-    current_state = app.get_state(config)
-    dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
+    dic_extracted_graph = response.update["dic_extracted_graph"][0]
     assert isinstance(dic_extracted_graph, dict)
     assert dic_extracted_graph["name"] == "subkg_12345"
     assert dic_extracted_graph["graph_source"] == "PrimeKG"
@@ -2,5 +2,6 @@
 This file is used to import all the models in the package.
 '''
 from . import subgraph_extraction
+from . import multimodal_subgraph_extraction
 from . import subgraph_summarization
 from . import graphrag_reasoning
@@ -0,0 +1,374 @@
+"""
+Tool for performing multimodal subgraph extraction.
+"""
+
+from typing import Type, Annotated
+import logging
+import pickle
+import numpy as np
+import pandas as pd
+import hydra
+import networkx as nx
+from pydantic import BaseModel, Field
+from langchain_core.tools import BaseTool
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+import torch
+from torch_geometric.data import Data
+from ..utils.extractions.multimodal_pcst import MultimodalPCSTPruning
+from ..utils.embeddings.ollama import EmbeddingWithOllama
+from .load_arguments import ArgumentData
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class MultimodalSubgraphExtractionInput(BaseModel):
+    """
+    MultimodalSubgraphExtractionInput is a Pydantic model representing an input
+    for extracting a subgraph.
+
+    Args:
+        prompt: Prompt to interact with the backend.
+        tool_call_id: Tool call ID.
+        state: Injected state.
+        arg_data: Argument for analytical process over graph data.
+    """
+
+    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+        description="Tool call ID."
+    )
+    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
+    prompt: str = Field(description="Prompt to interact with the backend.")
+    arg_data: ArgumentData = Field(
+        description="Experiment over graph data.", default=None
+    )
+
+
+class MultimodalSubgraphExtractionTool(BaseTool):
+    """
+    This tool performs subgraph extraction based on the user's prompt by taking into account
+    the top-k nodes and edges.
+    """
+
+    name: str = "subgraph_extraction"
+    description: str = "A tool for subgraph extraction based on user's prompt."
+    args_schema: Type[BaseModel] = MultimodalSubgraphExtractionInput
+
+    def _prepare_query_modalities(self,
+                                  prompt_emb: list,
+                                  state: Annotated[dict, InjectedState],
+                                  pyg_graph: Data) -> pd.DataFrame:
+        """
+        Prepare the modality-specific query for subgraph extraction.
+
+        Args:
+            prompt_emb: The embedding of the user prompt in a list.
+            state: The injected state for the tool.
+            pyg_graph: The PyTorch Geometric graph Data.
+
+        Returns:
+            A DataFrame containing the query embeddings and modalities.
+        """
+        # Initialize dataframes
+        multimodal_df = pd.DataFrame({"name": []})
+        query_df = pd.DataFrame({"node_id": [],
+                                 "node_type": [],
+                                 "x": [],
+                                 "desc_x": [],
+                                 "use_description": []})
+
+        # Loop over the uploaded files and find multimodal files
+        for i in range(len(state["uploaded_files"])):
+            # Check if a multimodal file is uploaded
+            if state["uploaded_files"][i]["file_type"] == "multimodal":
+                # Read the Excel file
+                multimodal_df = pd.read_excel(state["uploaded_files"][i]["file_path"],
+                                              sheet_name=None)
+
+        # Check if the multimodal_df is empty
+        if len(multimodal_df) > 0:
+            # Merge all obtained dataframes into a single dataframe
+            multimodal_df = pd.concat(multimodal_df).reset_index()
+            multimodal_df.drop(columns=["level_1"], inplace=True)
+            multimodal_df.rename(columns={"level_0": "q_node_type",
+                                          "name": "q_node_name"}, inplace=True)
+            # An Excel sheet name cannot contain a `/`, whereas a node type can
+            # (e.g., 'gene/protein' in PrimeKG), so restore it here
+            multimodal_df["q_node_type"] = multimodal_df.q_node_type.apply(
+                lambda x: x.replace('-', '/')
+            )
+
+            # Convert PyG graph to a DataFrame for easier filtering
+            graph_df = pd.DataFrame({
+                "node_id": pyg_graph.node_id,
+                "node_name": pyg_graph.node_name,
+                "node_type": pyg_graph.node_type,
+                "x": pyg_graph.x,
+                "desc_x": pyg_graph.desc_x.tolist(),
+            })
+
+            # Make a query dataframe by merging the graph_df and multimodal_df
+            query_df = graph_df.merge(multimodal_df, how='cross')
+            query_df = query_df[
+                query_df.apply(
+                    lambda x:
+                    (x['q_node_name'].lower() in x['node_name'].lower()) &  # node name
+                    (x['node_type'] == x['q_node_type']),  # node type
+                    axis=1
+                )
+            ]
+            query_df = query_df[['node_id', 'node_type', 'x', 'desc_x']].reset_index(drop=True)
+            query_df['use_description'] = False  # set to False for modal-specific embeddings
+
+            # Update the state by adding the selected node IDs
+            state["selections"] = query_df.groupby("node_type")["node_id"].apply(list).to_dict()
+
+        # Append the user prompt to the query dataframe
+        query_df = pd.concat([
+            query_df,
+            pd.DataFrame({
+                'node_id': 'user_prompt',
+                'node_type': 'prompt',
+                'x': prompt_emb,
+                'desc_x': prompt_emb,
+                'use_description': True  # set to True for user prompt embedding
+            })
+        ]).reset_index(drop=True)
+
+        return query_df
+
+    def _perform_subgraph_extraction(self,
+                                     state: Annotated[dict, InjectedState],
+                                     cfg: dict,
+                                     pyg_graph: Data,
+                                     query_df: pd.DataFrame) -> dict:
+        """
+        Perform multimodal subgraph extraction based on modal-specific embeddings.
+
+        Args:
+            state: The injected state for the tool.
+            cfg: The configuration dictionary.
+            pyg_graph: The PyTorch Geometric graph Data.
+            query_df: The DataFrame containing the query embeddings and modalities.
+
+        Returns:
+            A dictionary containing the extracted subgraph with nodes and edges.
+        """
+        # Initialize the subgraph dictionary
+        subgraphs = {}
+        subgraphs["nodes"] = []
+        subgraphs["edges"] = []
+
+        # Loop over query embeddings and modalities
+        for q in query_df.iterrows():
+            # Prepare the PCSTPruning object and extract the subgraph
+            # Parameters were set in the configuration file obtained from Hydra
+            subgraph = MultimodalPCSTPruning(
+                topk=state["topk_nodes"],
+                topk_e=state["topk_edges"],
+                cost_e=cfg.cost_e,
+                c_const=cfg.c_const,
+                root=cfg.root,
+                num_clusters=cfg.num_clusters,
+                pruning=cfg.pruning,
+                verbosity_level=cfg.verbosity_level,
+                use_description=q[1]['use_description'],
+            ).extract_subgraph(pyg_graph,
+                               torch.tensor(q[1]['desc_x']),  # description embedding
+                               torch.tensor(q[1]['x']),  # modal-specific embedding
+                               q[1]['node_type'])
+
+            # Append the extracted subgraph to the dictionary
+            subgraphs["nodes"].append(subgraph["nodes"].tolist())
+            subgraphs["edges"].append(subgraph["edges"].tolist())
+
+        # Concatenate and get unique node and edge indices
+        subgraphs["nodes"] = np.unique(
+            np.concatenate([np.array(list_) for list_ in subgraphs["nodes"]])
+        )
+        subgraphs["edges"] = np.unique(
+            np.concatenate([np.array(list_) for list_ in subgraphs["edges"]])
+        )
+
+        return subgraphs
+
+    def _prepare_final_subgraph(self,
+                                state: Annotated[dict, InjectedState],
+                                subgraph: dict,
+                                graph: dict,
+                                cfg) -> dict:
+        """
+        Prepare the subgraph based on the extracted subgraph.
+
+        Args:
+            state: The injected state for the tool.
+            subgraph: The extracted subgraph.
+            graph: The initial graph containing PyG and textualized graph.
+            cfg: The configuration dictionary.
+
+        Returns:
+            A dictionary containing the PyG graph, NetworkX graph, and textualized graph.
+        """
+        # Prepare the PyTorch Geometric graph
+        mapping = {n: i for i, n in enumerate(subgraph["nodes"].tolist())}
+        pyg_graph = Data(
+            # Node features
+            x=[graph["pyg"].x[i] for i in subgraph["nodes"]],
+            node_id=np.array(graph["pyg"].node_id)[subgraph["nodes"]].tolist(),
+            node_name=np.array(graph["pyg"].node_id)[subgraph["nodes"]].tolist(),
+            enriched_node=np.array(graph["pyg"].enriched_node)[subgraph["nodes"]].tolist(),
+            num_nodes=len(subgraph["nodes"]),
+            # Edge features
+            edge_index=torch.LongTensor(
+                [
+                    [
+                        mapping[i]
+                        for i in graph["pyg"].edge_index[:, subgraph["edges"]][0].tolist()
+                    ],
+                    [
+                        mapping[i]
+                        for i in graph["pyg"].edge_index[:, subgraph["edges"]][1].tolist()
+                    ],
+                ]
+            ),
+            edge_attr=graph["pyg"].edge_attr[subgraph["edges"]],
+            edge_type=np.array(graph["pyg"].edge_type)[subgraph["edges"]].tolist(),
+            relation=np.array(graph["pyg"].edge_type)[subgraph["edges"]].tolist(),
+            label=np.array(graph["pyg"].edge_type)[subgraph["edges"]].tolist(),
+            enriched_edge=np.array(graph["pyg"].enriched_edge)[subgraph["edges"]].tolist(),
+        )
+
+        # NetworkX DiGraph construction to be visualized in the frontend
+        nx_graph = nx.DiGraph()
+        # Add nodes with attributes
+        node_colors = {n: cfg.node_colors_dict[k]
+                       for k, v in state["selections"].items() for n in v}
+        for n in pyg_graph.node_name:
+            nx_graph.add_node(n, color=node_colors.get(n, None))
+
+        # Add edges with attributes
+        edges = zip(
+            pyg_graph.edge_index[0].tolist(),
+            pyg_graph.edge_index[1].tolist(),
+            pyg_graph.edge_type
+        )
+        for src, dst, edge_type in edges:
+            nx_graph.add_edge(
+                pyg_graph.node_name[src],
+                pyg_graph.node_name[dst],
+                relation=edge_type,
+                label=edge_type,
+            )
+
+        # Prepare the textualized subgraph
+        textualized_graph = (
+            graph["text"]["nodes"].iloc[subgraph["nodes"]].to_csv(index=False)
+            + "\n"
+            + graph["text"]["edges"].iloc[subgraph["edges"]].to_csv(index=False)
+        )
+
+        return {
+            "graph_pyg": pyg_graph,
+            "graph_nx": nx_graph,
+            "graph_text": textualized_graph,
+        }
+
+    def _run(
+        self,
+        tool_call_id: Annotated[str, InjectedToolCallId],
+        state: Annotated[dict, InjectedState],
+        prompt: str,
+        arg_data: ArgumentData = None,
+    ) -> Command:
+        """
+        Run the subgraph extraction tool.
+
+        Args:
+            tool_call_id: The tool call ID for the tool.
+            state: Injected state for the tool.
+            prompt: The prompt to interact with the backend.
+            arg_data (ArgumentData): The argument data.
+
+        Returns:
+            Command: The command to be executed.
+        """
+        logger.log(logging.INFO, "Invoking subgraph_extraction tool")
+
+        # Load hydra configuration
+        with hydra.initialize(version_base=None, config_path="../configs"):
+            cfg = hydra.compose(
+                config_name="config", overrides=["tools/multimodal_subgraph_extraction=default"]
+            )
+            cfg = cfg.tools.multimodal_subgraph_extraction
+
+        # Retrieve source graph from the state
+        initial_graph = {}
+        initial_graph["source"] = state["dic_source_graph"][-1]  # the last source graph as of now
+
+        # Load the knowledge graph
+        with open(initial_graph["source"]["kg_pyg_path"], "rb") as f:
+            initial_graph["pyg"] = pickle.load(f)
+        with open(initial_graph["source"]["kg_text_path"], "rb") as f:
+            initial_graph["text"] = pickle.load(f)
+
+        # Prepare the query embeddings and modalities
+        query_df = self._prepare_query_modalities(
+            [EmbeddingWithOllama(model_name=cfg.ollama_embeddings[0]).embed_query(prompt)],
+            state,
+            initial_graph["pyg"]
+        )
+
+        # Perform subgraph extraction
+        subgraphs = self._perform_subgraph_extraction(state,
+                                                      cfg,
+                                                      initial_graph["pyg"],
+                                                      query_df)
+
+        # Prepare subgraph as a NetworkX graph and textualized graph
+        final_subgraph = self._prepare_final_subgraph(state,
+                                                      subgraphs,
+                                                      initial_graph,
+                                                      cfg)
+
+        # Prepare the dictionary of extracted graph
+        dic_extracted_graph = {
+            "name": arg_data.extraction_name,
+            "tool_call_id": tool_call_id,
+            "graph_source": initial_graph["source"]["name"],
+            "topk_nodes": state["topk_nodes"],
+            "topk_edges": state["topk_edges"],
+            "graph_dict": {
+                "nodes": list(final_subgraph["graph_nx"].nodes(data=True)),
+                "edges": list(final_subgraph["graph_nx"].edges(data=True)),
+            },
+            "graph_text": final_subgraph["graph_text"],
+            "graph_summary": None,
+        }
+
+        # Prepare the dictionary of updated state
+        dic_updated_state_for_model = {}
+        for key, value in {
+            "dic_extracted_graph": [dic_extracted_graph],
+        }.items():
+            if value:
+                dic_updated_state_for_model[key] = value
+
+        # Return the updated state of the tool
+        return Command(
+            update=dic_updated_state_for_model | {
+                # update the message history
+                "messages": [
+                    ToolMessage(
+                        content=f"Subgraph Extraction Result of {arg_data.extraction_name}",
+                        tool_call_id=tool_call_id,
+                    )
+                ],
+            }
+        )
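`_prepare_query_modalities` expects one sheet per node type (with `/` replaced by `-` in the sheet name) and a `name` column listing the entities to match, case-insensitively, against graph node names. A minimal sketch of a compatible input file (the file name and gene symbols are illustrative only):

    import pandas as pd

    # One sheet per node type; '-' stands in for '/' in 'gene/protein'
    with pd.ExcelWriter("multimodal-analysis.xlsx") as writer:
        pd.DataFrame({"name": ["TNF", "IL6"]}).to_excel(
            writer, sheet_name="gene-protein", index=False
        )
        pd.DataFrame({"name": ["adalimumab"]}).to_excel(
            writer, sheet_name="drug", index=False
        )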
@@ -2,3 +2,4 @@
 This file is used to import all the models in the package.
 '''
 from . import pcst
+from . import multimodal_pcst
@@ -0,0 +1,292 @@
+"""
+Extraction of a multimodal subgraph using the Prize-Collecting Steiner Tree (PCST) algorithm.
+"""
+
+from typing import Tuple, NamedTuple
+import numpy as np
+import pandas as pd
+import torch
+import pcst_fast
+from torch_geometric.data.data import Data
+
+
+class MultimodalPCSTPruning(NamedTuple):
+    """
+    Prize-Collecting Steiner Tree (PCST) pruning algorithm implementation inspired by G-Retriever
+    (He et al., 'G-Retriever: Retrieval-Augmented Generation for Textual Graph Understanding and
+    Question Answering', NeurIPS 2024).
+    https://arxiv.org/abs/2402.07630
+    https://github.com/XiaoxinHe/G-Retriever/blob/main/src/dataset/utils/retrieval.py
+
+    Args:
+        topk: The number of top nodes to consider.
+        topk_e: The number of top edges to consider.
+        cost_e: The cost of the edges.
+        c_const: The constant value for the cost of the edges computation.
+        root: The root node of the subgraph, -1 for unrooted.
+        num_clusters: The number of clusters.
+        pruning: The pruning strategy to use.
+        verbosity_level: The verbosity level.
+        use_description: Whether to score nodes with description embeddings instead of
+            modality-specific embeddings.
+    """
+    topk: int = 3
+    topk_e: int = 3
+    cost_e: float = 0.5
+    c_const: float = 0.01
+    root: int = -1
+    num_clusters: int = 1
+    pruning: str = "gw"
+    verbosity_level: int = 0
+    use_description: bool = False
+
+    def _compute_node_prizes(self,
+                             graph: Data,
+                             query_emb: torch.Tensor,
+                             modality: str):
+        """
+        Compute the node prizes based on the cosine similarity between the query and nodes.
+
+        Args:
+            graph: The knowledge graph in PyTorch Geometric Data format.
+            query_emb: The query embedding in PyTorch Tensor format. This can be an embedding of
+                a prompt, sequence, or any other feature to be used for the subgraph extraction.
+            modality: The modality to use for the subgraph extraction based on the node type.
+
+        Returns:
+            The prizes of the nodes.
+        """
+        # Convert PyG graph to a DataFrame
+        graph_df = pd.DataFrame({
+            "node_type": graph.node_type,
+            "desc_x": [x.tolist() for x in graph.desc_x],
+            "x": [list(x) for x in graph.x],
+            "score": [0.0 for _ in range(len(graph.node_id))],
+        })
+
+        # Calculate cosine similarity for text features and update the score
+        if self.use_description:
+            graph_df.loc[:, "score"] = torch.nn.CosineSimilarity(dim=-1)(
+                query_emb,
+                torch.tensor(list(graph_df.desc_x.values))  # textual description features
+            ).tolist()
+        else:
+            graph_df.loc[graph_df["node_type"] == modality,
+                         "score"] = torch.nn.CosineSimilarity(dim=-1)(
+                query_emb,
+                torch.tensor(list(graph_df[graph_df["node_type"] == modality].x.values))
+            ).tolist()
+
+        # Set the prizes for nodes based on the similarity scores
+        n_prizes = torch.tensor(graph_df.score.values, dtype=torch.float32)
+        topk = min(self.topk, graph.num_nodes)
+        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
+        n_prizes = torch.zeros_like(n_prizes)
+        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
+
+        return n_prizes
+
+    def _compute_edge_prizes(self,
+                             graph: Data,
+                             text_emb: torch.Tensor):
+        """
+        Compute the edge prizes based on the cosine similarity between the query and edges.
+
+        Args:
+            graph: The knowledge graph in PyTorch Geometric Data format.
+            text_emb: The textual description embedding in PyTorch Tensor format.
+
+        Returns:
+            The prizes of the edges.
+        """
+        # Note that as of now, the edge features are based on textual features
+        # Compute prizes for edges
+        e_prizes = torch.nn.CosineSimilarity(dim=-1)(text_emb, graph.edge_attr)
+        unique_prizes, inverse_indices = e_prizes.unique(return_inverse=True)
+        topk_e = min(self.topk_e, unique_prizes.size(0))
+        topk_e_values, _ = torch.topk(unique_prizes, topk_e, largest=True)
+        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
+        last_topk_e_value = topk_e
+        for k in range(topk_e):
+            indices = inverse_indices == (
+                unique_prizes == topk_e_values[k]
+            ).nonzero(as_tuple=True)[0]
+            value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
+            e_prizes[indices] = value
+            last_topk_e_value = value * (1 - self.c_const)
+
+        return e_prizes
+
+    def compute_prizes(self,
+                       graph: Data,
+                       text_emb: torch.Tensor,
+                       query_emb: torch.Tensor,
+                       modality: str):
+        """
+        Compute the node prizes based on the cosine similarity between the query and nodes,
+        as well as the edge prizes based on the cosine similarity between the query and edges.
+        Note that the node and edge embeddings shall use the same embedding model and dimensions
+        as the query.
+
+        Args:
+            graph: The knowledge graph in PyTorch Geometric Data format.
+            text_emb: The textual description embedding in PyTorch Tensor format.
+            query_emb: The query embedding in PyTorch Tensor format. This can be an embedding of
+                a prompt, sequence, or any other feature to be used for the subgraph extraction.
+            modality: The modality to use for the subgraph extraction based on node type.
+
+        Returns:
+            The prizes of the nodes and edges.
+        """
+        # Compute prizes for nodes
+        n_prizes = self._compute_node_prizes(graph, query_emb, modality)
+
+        # Compute prizes for edges
+        e_prizes = self._compute_edge_prizes(graph, text_emb)
+
+        return {"nodes": n_prizes, "edges": e_prizes}
+
+    def compute_subgraph_costs(self,
+                               graph: Data,
+                               prizes: dict) -> Tuple[dict, np.ndarray, np.ndarray, dict]:
+        """
+        Compute the costs in constructing the subgraph proposed by the G-Retriever paper.
+
+        Args:
+            graph: The knowledge graph in PyTorch Geometric Data format.
+            prizes: The prizes of the nodes and the edges.
+
+        Returns:
+            edges: The edges of the subgraph, consisting of edges and the number of edges
+                without virtual edges.
+            prizes: The prizes of the subgraph.
+            costs: The costs of the subgraph.
+            mapping: The mapping dictionary of the nodes and edges.
+        """
+        # Logic to reduce the cost of the edges such that at least one edge is selected
+        updated_cost_e = min(
+            self.cost_e,
+            prizes["edges"].max().item() * (1 - self.c_const / 2),
+        )
+
+        # Initialize variables
+        edges = []
+        costs = []
+        virtual = {
+            "n_prizes": [],
+            "edges": [],
+            "costs": [],
+        }
+        mapping = {"nodes": {}, "edges": {}}
+
+        # Compute the costs, edges, and virtual variables based on the prizes
+        for i, (src, dst) in enumerate(graph.edge_index.T.numpy()):
+            prize_e = prizes["edges"][i]
+            if prize_e <= updated_cost_e:
+                mapping["edges"][len(edges)] = i
+                edges.append((src, dst))
+                costs.append(updated_cost_e - prize_e)
+            else:
+                virtual_node_id = graph.num_nodes + len(virtual["n_prizes"])
+                mapping["nodes"][virtual_node_id] = i
+                virtual["edges"].append((src, virtual_node_id))
+                virtual["edges"].append((virtual_node_id, dst))
+                virtual["costs"].append(0)
+                virtual["costs"].append(0)
+                virtual["n_prizes"].append(prize_e - updated_cost_e)
+        prizes = np.concatenate([prizes["nodes"], np.array(virtual["n_prizes"])])
+        edges_dict = {}
+        edges_dict["edges"] = edges
+        edges_dict["num_prior_edges"] = len(edges)
+        # Final computation of the costs and edges based on the virtual costs and virtual edges
+        if len(virtual["costs"]) > 0:
+            costs = np.array(costs + virtual["costs"])
+            edges = np.array(edges + virtual["edges"])
+            edges_dict["edges"] = edges
+
+        return edges_dict, prizes, costs, mapping
+
+    def get_subgraph_nodes_edges(
+        self, graph: Data, vertices: np.ndarray, edges_dict: dict, mapping: dict,
+    ) -> dict:
+        """
+        Get the selected nodes and edges of the subgraph based on the vertices and edges computed
+        by the PCST algorithm.
+
+        Args:
+            graph: The knowledge graph in PyTorch Geometric Data format.
+            vertices: The vertices of the subgraph computed by the PCST algorithm.
+            edges_dict: The dictionary of edges of the subgraph computed by the PCST algorithm,
+                and the number of prior edges (without virtual edges).
+            mapping: The mapping dictionary of the nodes and edges.
+
+        Returns:
+            The selected nodes and edges of the extracted subgraph.
+        """
+        # Get edges information
+        edges = edges_dict["edges"]
+        num_prior_edges = edges_dict["num_prior_edges"]
+        # Retrieve the selected nodes and edges based on the given vertices and edges
+        subgraph_nodes = vertices[vertices < graph.num_nodes]
+        subgraph_edges = [mapping["edges"][e] for e in edges if e < num_prior_edges]
+        virtual_vertices = vertices[vertices >= graph.num_nodes]
+        if len(virtual_vertices) > 0:
+            virtual_edges = [mapping["nodes"][i] for i in virtual_vertices]
+            subgraph_edges = np.array(subgraph_edges + virtual_edges)
+        edge_index = graph.edge_index[:, subgraph_edges]
+        subgraph_nodes = np.unique(
+            np.concatenate(
+                [subgraph_nodes, edge_index[0].numpy(), edge_index[1].numpy()]
+            )
+        )
+
+        return {"nodes": subgraph_nodes, "edges": subgraph_edges}
+
+    def extract_subgraph(self,
+                         graph: Data,
+                         text_emb: torch.Tensor,
+                         query_emb: torch.Tensor,
+                         modality: str) -> dict:
+        """
+        Perform the Prize-Collecting Steiner Tree (PCST) algorithm to extract the subgraph.
+
+        Args:
+            graph: The knowledge graph in PyTorch Geometric Data format.
+            text_emb: The textual description embedding in PyTorch Tensor format.
+            query_emb: The query embedding in PyTorch Tensor format. This can be an embedding of
+                a prompt, sequence, or any other feature to be used for the subgraph extraction.
+            modality: The modality to use for the subgraph extraction
+                (e.g., "text", "sequence", "smiles").
+
+        Returns:
+            The selected nodes and edges of the subgraph.
+        """
+        # Assert the topk and topk_e values for subgraph retrieval
+        assert self.topk > 0, "topk must be greater than 0"
+        assert self.topk_e > 0, "topk_e must be greater than 0"
+
+        # Retrieve the top-k nodes and edges based on the query embedding
+        prizes = self.compute_prizes(graph, text_emb, query_emb, modality)
+
+        # Compute costs in constructing the subgraph
+        edges_dict, prizes, costs, mapping = self.compute_subgraph_costs(
+            graph, prizes
+        )
+
+        # Retrieve the subgraph using the PCST algorithm
+        result_vertices, result_edges = pcst_fast.pcst_fast(
+            edges_dict["edges"],
+            prizes,
+            costs,
+            self.root,
+            self.num_clusters,
+            self.pruning,
+            self.verbosity_level,
+        )
+
+        subgraph = self.get_subgraph_nodes_edges(
+            graph,
+            result_vertices,
+            {"edges": result_edges, "num_prior_edges": edges_dict["num_prior_edges"]},
+            mapping)
+
+        return subgraph
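A minimal end-to-end sketch on a toy graph (all tensors random; `pcst_fast` must be installed). It exercises the description-based scoring path, so only the `desc_x` and `edge_attr` similarities matter here; the graph attributes mirror what the class reads from the pickled knowledge graphs:

    import torch
    from torch_geometric.data import Data
    from aiagents4pharma.talk2knowledgegraphs.utils.extractions.multimodal_pcst import (
        MultimodalPCSTPruning,
    )

    dim = 8
    graph = Data(
        x=torch.randn(4, dim),       # modality-specific node features
        desc_x=torch.randn(4, dim),  # textual description node features
        edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]),
        edge_attr=torch.randn(3, dim),
        node_id=["n0", "n1", "n2", "n3"],
        node_type=["gene/protein"] * 4,
    )
    emb = torch.randn(dim)
    pruner = MultimodalPCSTPruning(topk=2, topk_e=2, use_description=True)
    subgraph = pruner.extract_subgraph(graph, text_emb=emb, query_emb=emb,
                                       modality="gene/protein")
    print(subgraph["nodes"], subgraph["edges"])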
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: aiagents4pharma
-Version: 1.36.0
+Version: 1.37.0
 Summary: AI Agents for drug discovery, drug development, and other pharmaceutical R&D.
 Classifier: Programming Language :: Python :: 3
 Classifier: License :: OSI Approved :: MIT License
@@ -62,6 +62,7 @@ Requires-Dist: umap-learn==0.5.7
 Requires-Dist: plotly-express==0.4.1
 Requires-Dist: seaborn==0.13.2
 Requires-Dist: scanpy==1.11.0
+Requires-Dist: openpyxl==3.1.5
 Dynamic: license-file

 [![Talk2BioModels](https://github.com/VirtualPatientEngine/AIAgents4Pharma/actions/workflows/tests_talk2biomodels.yml/badge.svg)](https://github.com/VirtualPatientEngine/AIAgents4Pharma/actions/workflows/tests_talk2biomodels.yml)
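The new `openpyxl` pin supports the Excel ingestion above: pandas needs it as the engine behind `pd.read_excel` on `.xlsx` files, which the multimodal tool calls when a "multimodal" file is uploaded. For example:

    import pandas as pd

    # Returns a dict of DataFrames keyed by sheet name (one sheet per node type)
    sheets = pd.read_excel("multimodal-analysis.xlsx", sheet_name=None, engine="openpyxl")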
@@ -9,7 +9,7 @@ aiagents4pharma/talk2aiagents4pharma/configs/agents/main_agent/default.yaml,sha2
 aiagents4pharma/talk2aiagents4pharma/states/__init__.py,sha256=3wSvCpM29oqvVjhbhabm7FNm9Zt0rHO5tEn63YW6doc,108
 aiagents4pharma/talk2aiagents4pharma/states/state_talk2aiagents4pharma.py,sha256=NxujEBDKubvpV9UG7ERTDRB6psr0XnObCNHyztLAhgo,485
 aiagents4pharma/talk2aiagents4pharma/tests/__init__.py,sha256=Jbw5tJxSrjGoaK5IX3pJWDCNzhrVQ10lkYq2oQ_KQD8,45
-aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py,sha256=ndHQT4ycWy-uKRAND7JmX_SuNg4g9hJw4UCW0CbKSp0,4165
+aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py,sha256=_zUm8i8vrBbcDgpwExa1sVGr1A9FgZFuwoLS395RnhU,4418
 aiagents4pharma/talk2biomodels/__init__.py,sha256=1cq1HX2xoi_a0nDPuXYoSTrnL26OHQBW3zXNwwwjFO0,181
 aiagents4pharma/talk2biomodels/agents/__init__.py,sha256=sn5-fREjMdEvb-OUan3iOqrgYGjplNx3J8hYOaW0Po8,128
 aiagents4pharma/talk2biomodels/agents/t2b_agent.py,sha256=g0DIW5P-dtJoVyG4weFdDgTrJPL_Dx1MMbTWextJDZ4,3455
@@ -73,17 +73,18 @@ aiagents4pharma/talk2cells/tools/scp_agent/display_studies.py,sha256=6q59gh_NQai
 aiagents4pharma/talk2cells/tools/scp_agent/search_studies.py,sha256=MLe-twtFnOu-P8P9diYq7jvHBHbWFRRCZLcfpUzqPMg,2806
 aiagents4pharma/talk2knowledgegraphs/__init__.py,sha256=Z0Eo7LTiKk0STsr8VI7wkCLq7PHrK1vYlH4I1hSNLiA,165
 aiagents4pharma/talk2knowledgegraphs/agents/__init__.py,sha256=iOAzuy_8A03tQDFtSBhC9dldUo62z5gfxcVtXAdLOJs,92
-aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py,sha256=IcXSZ2qQA1m-gS-o0Pj_g1oar8uPdhsbaovloUFka3Q,3058
+aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py,sha256=w4wSSQ9gw_fzpcHZ2Bnqok17YDkFeQ3d72JenvQm6Oc,3089
 aiagents4pharma/talk2knowledgegraphs/configs/__init__.py,sha256=4_DVdpahaJ55yPl0aZotlFA_MYWLFF2cubWyKtBVI_Q,126
-aiagents4pharma/talk2knowledgegraphs/configs/config.yaml,sha256=X91262b-wkygiH4HrEr0bIzHxHDuDWwuxLQAmdUe-E4,367
+aiagents4pharma/talk2knowledgegraphs/configs/config.yaml,sha256=-AJXKnR2z5ig0SK_3vLL9JFjNRri7q7blHYFWxoTDl0,417
 aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
 aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml,sha256=ENCGROwYFpR6g4QD518h73sshdn3vPVpotBMk1QJcpU,4830
 aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py,sha256=fKfc3FR7g5KjY9b6jzrU6cwKTVVpkoVZQS3dvUowu34,69
 aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
-aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml,sha256=4azC4cH-_-zt-bRVgNjkFM24mjNke6Rgn9pNl7XWrPQ,912
+aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml,sha256=WJgd2ZU7_WQ1qlcTfkFlM8u23sH6eU2KgAm0E4kqqfs,941
 aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py,sha256=C1yyRZW8hqWw46p_bh1vAJp2z9aVvn4HpKjKkjlWIqY,150
 aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
 aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml,sha256=Ua99yECXiwp4ZCUDgsDskYbKzcJrv7roQuLj31Zky4c,1037
+aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
 aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml,sha256=U8HvMsYbaOwDwQPATj7EFvLtTy7XZEplE5WMoNjgYYc,1469
 aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
@@ -94,15 +95,16 @@ aiagents4pharma/talk2knowledgegraphs/datasets/dataset.py,sha256=-LaPLse8BkALqwFe
 aiagents4pharma/talk2knowledgegraphs/datasets/primekg.py,sha256=KBMhCJ7yjMWqQJJctFYdpjYAlwv48Jl6i1dddXP4f08,7599
 aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py,sha256=Y-6-nORsnBJlU6rH0skyfr9S9J4PfTWK-af_p5UuknQ,7483
 aiagents4pharma/talk2knowledgegraphs/states/__init__.py,sha256=XaqorSvx634dWRRlXUdzlisHtYMyqgJ2q7TanzsKlhw,108
-aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py,sha256=6HqGo-awqoyNJG0igm5so5A4Tq8RPkCsjPg8Go38csE,1066
+aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py,sha256=y5bp6yObN-AQtTq-m8ml7UnZaeKYUiPV_yjskAzBJaI,1087
 aiagents4pharma/talk2knowledgegraphs/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py,sha256=CCN6cyhEaiXSvIC-4y3ueDSzjDCBYDsmSmOor-DMeF4,3928
+aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py,sha256=PPfHKnfqMbUOBKU7q4VbQvHQymX1M_zTYdysQgVxKCs,3851
 aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py,sha256=crH0eFA3P8P6IYzi1UWNa4YvRVrtlBzoScf9NaE1lDk,9827
 aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py,sha256=NFUlsZvhfIrkF4YenWfahrLK93Xhm5UYEGG_uYN2LVM,566
 aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py,sha256=Pvu0r93CpnhjkfMxc-EiVLpAJ04FdW9iTamCnetu654,2272
 aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py,sha256=TuIsqcN1Mww3DTqGk6ebgJBWzUWdMWEq2yRQuYSFqvA,4416
 aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py,sha256=aOKHTber2Cg3mjNjfIa6RZU7XdFj5C2ps1YEUXw76CI,10650
-aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py,sha256=zRi2j9Dm3VFywhhrPjVoJ7z_zJpAEM74MJRXapnhwVE,6246
+aiagents4pharma/talk2knowledgegraphs/tests/test_tools_multimodal_subgraph_extraction.py,sha256=Da-hXcu41_5Ge4DPlOoY6OqBwYnXPc58Q89wuywqVJM,5806
+aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py,sha256=C2HzmAG1XCeV1hwZzz3-9_2dm_84-i1BvTNWA1pqUwM,5393
 aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py,sha256=oBqfspXXOxH04OQuPb8BCW0liIQTGKXtaPNSrPpQtFc,7597
 aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py,sha256=uYFoE_6zeU10_1mLLAHUr5c4S2XZMSc0Q_860o-KWEw,1517
 aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py,sha256=hzX84pheZdEsTtikF2KtBFiH44_xPjYXxLA6p4Ax1CY,1623
@@ -117,9 +119,10 @@ aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_reactome.py,sh
 aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_uniprot.py,sha256=G13Diw7cA5TGINUNO1CDnN4rM6KbepxRXNjuzY578DI,1611
 aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py,sha256=pal76wi7WgQWUNk56BrzfFV8jKpbDaHHdbwtgx_gXLI,2410
 aiagents4pharma/talk2knowledgegraphs/tests/test_utils_pubchem_utils.py,sha256=K1Y6QM0MDP1IrAdcWkigl8R-O-i-lsL4NCyOrWewhdM,1246
-aiagents4pharma/talk2knowledgegraphs/tools/__init__.py,sha256=zpD4h7EYtyq0QNOqLd6bkxrPlPb2XN64ceI9ncgESrA,171
+aiagents4pharma/talk2knowledgegraphs/tools/__init__.py,sha256=uleTEbhgvlYw4fOqV4NmoFvxGTon2Oim7jTQ5qPmYoU,216
 aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py,sha256=OEuOFncDRdb7TQEGq4rkT5On-jI-R7Nt8K5EBzaND8w,5338
 aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py,sha256=zhmsRp-8vjB5rRekqTA07d3yb-42HWqng9dDMkvK6hM,623
+aiagents4pharma/talk2knowledgegraphs/tools/multimodal_subgraph_extraction.py,sha256=Qjl8hXG8Gv5jQ4pBX8me0pGGakqRZmcDfTGgdEHD9pc,15394
 aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py,sha256=te06QMFQfgJWrjaGrqpcOYeaV38jwm0KY_rXVSMHkeI,11468
 aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py,sha256=mDSBOxopDfNhEJeU8fVI8b5lXTYrRzcc97aLbFgYSy4,4413
 aiagents4pharma/talk2knowledgegraphs/utils/__init__.py,sha256=cZqb3LZLmBnmyAtWFv2Z-4uJvQmx0M4zKsfiWrlM3Pk,195
@@ -138,7 +141,8 @@ aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ols_terms.py,sha256=xSPP-
 aiagents4pharma/talk2knowledgegraphs/utils/enrichments/pubchem_strings.py,sha256=CQEGQ6Qsex2T91Vw7zTrclJBbSGGhxeWaVJb8tnURAQ,1691
 aiagents4pharma/talk2knowledgegraphs/utils/enrichments/reactome_pathways.py,sha256=I0cD0Fk2Uk27_4jEaIhpoGhoMh_RphY1VtkMnk4dkPg,2011
 aiagents4pharma/talk2knowledgegraphs/utils/enrichments/uniprot_proteins.py,sha256=z0Jb3tt8VzRjzqI9oVcUvRlPPg6BUdmslfKDIEFE_h8,3013
-aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py,sha256=7gwwtfzKhB8GuOBD47XRi0NprwEXkOzwNl5eeu-hDTI,86
+aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py,sha256=5bt3H6gGSAwN2K-IG7AHwG2lC4yQeMd2_jbhu2z5XKg,116
+aiagents4pharma/talk2knowledgegraphs/utils/extractions/multimodal_pcst.py,sha256=Irh5JXEhaLZ6Rxv3h5Anif_rGNItyLOGDWg1RACmoDA,12628
 aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py,sha256=m5p0yoJb7I19ua5yeQfXPf7c4r6S1XPwttsrM7Qoy94,9336
 aiagents4pharma/talk2scholars/__init__.py,sha256=NOZxTklAH1j1ggu97Ib8Xn9LCKudEWt-8dx8w7yxVD8,180
 aiagents4pharma/talk2scholars/agents/__init__.py,sha256=c_0Pk85bt-RfK5RMyALM3MXo3qXVMoYS7BOqM9wuFME,317
@@ -225,8 +229,8 @@ aiagents4pharma/talk2scholars/tools/zotero/utils/read_helper.py,sha256=lyrfpx8NH
 aiagents4pharma/talk2scholars/tools/zotero/utils/review_helper.py,sha256=IPD1V9yrBYaDnRe7sR6PrpwR82OBJbA2P_Tc6RbxAbM,2748
 aiagents4pharma/talk2scholars/tools/zotero/utils/write_helper.py,sha256=ALwLecy1QVebbsmXJiDj1GhGmyhq2R2tZlAyEl1vfhw,7410
 aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_path.py,sha256=oIrfbOySgts50ksHKyjcWjRkPRIS88g3Lc0v9mBkU8w,6375
-aiagents4pharma-1.36.0.dist-info/licenses/LICENSE,sha256=IcIbyB1Hyk5ZDah03VNQvJkbNk2hkBCDqQ8qtnCvB4Q,1077
-aiagents4pharma-1.36.0.dist-info/METADATA,sha256=4S4eCTvL7mAxQUmSDp4SIyn2WAYuOivLpcCdL-j5dGQ,16757
-aiagents4pharma-1.36.0.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
-aiagents4pharma-1.36.0.dist-info/top_level.txt,sha256=-AH8rMmrSnJtq7HaAObS78UU-cTCwvX660dSxeM7a0A,16
-aiagents4pharma-1.36.0.dist-info/RECORD,,
+aiagents4pharma-1.37.0.dist-info/licenses/LICENSE,sha256=IcIbyB1Hyk5ZDah03VNQvJkbNk2hkBCDqQ8qtnCvB4Q,1077
+aiagents4pharma-1.37.0.dist-info/METADATA,sha256=F-uncJSmjQ9bOlTHKuLJMa311nbF90UL7aJXwn2zVe0,16788
+aiagents4pharma-1.37.0.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
+aiagents4pharma-1.37.0.dist-info/top_level.txt,sha256=-AH8rMmrSnJtq7HaAObS78UU-cTCwvX660dSxeM7a0A,16
+aiagents4pharma-1.37.0.dist-info/RECORD,,