aiagents4pharma 1.18.0__py3-none-any.whl → 1.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2knowledgegraphs/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +85 -0
- aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +7 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +31 -0
- aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +7 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +6 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
- aiagents4pharma/talk2knowledgegraphs/states/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +38 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +110 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +210 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +174 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +154 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +0 -1
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +56 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +18 -42
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +79 -0
- aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +6 -0
- aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +143 -0
- aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +305 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +126 -0
- aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +4 -2
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +81 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +225 -0
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/METADATA +3 -1
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/RECORD +42 -10
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,126 @@
|
|
1
|
+
"""
|
2
|
+
Tool for performing subgraph summarization.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import logging
|
6
|
+
from typing import Type, Annotated
|
7
|
+
from pydantic import BaseModel, Field
|
8
|
+
from langchain_core.output_parsers import StrOutputParser
|
9
|
+
from langchain_core.prompts import ChatPromptTemplate
|
10
|
+
from langchain_core.messages import ToolMessage
|
11
|
+
from langchain_core.tools.base import InjectedToolCallId
|
12
|
+
from langchain_core.tools import BaseTool
|
13
|
+
from langgraph.types import Command
|
14
|
+
from langgraph.prebuilt import InjectedState
|
15
|
+
import hydra
|
16
|
+
|
17
|
+
# Initialize logger
|
18
|
+
logging.basicConfig(level=logging.INFO)
|
19
|
+
logger = logging.getLogger(__name__)
|
20
|
+
|
21
|
+
|
22
|
+
class SubgraphSummarizationInput(BaseModel):
    """
    Argument schema of the subgraph summarization tool.

    Args:
        tool_call_id: Tool call ID (annotated with InjectedToolCallId).
        state: Injected state (annotated with InjectedState).
        prompt: Prompt to interact with the backend.
        extraction_name: Name assigned to the subgraph extraction process
            when the subgraph_extraction tool is invoked.
    """

    # Fields annotated as injected are not meant to be supplied in the tool call.
    tool_call_id: Annotated[str, InjectedToolCallId] = Field(description="Tool call ID.")
    state: Annotated[dict, InjectedState] = Field(
        description="Injected state."
    )

    # Fields carrying the actual request.
    prompt: str = Field(
        description="Prompt to interact with the backend."
    )
    extraction_name: str = Field(
        description="""Name assigned to the subgraph extraction process
        when the subgraph_extraction tool is invoked."""
    )
|
43
|
+
|
44
|
+
|
45
|
+
class SubgraphSummarizationTool(BaseTool):
    """
    This tool performs subgraph summarization over textualized graph to highlight the most
    important information in responding to user's prompt.
    """

    name: str = "subgraph_summarization"
    description: str = """A tool to perform subgraph summarization over textualized graph
    for responding to user's follow-up prompt(s)."""
    args_schema: Type[BaseModel] = SubgraphSummarizationInput

    def _run(
        self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        prompt: str,
        extraction_name: str,
    ):
        """
        Run the subgraph summarization tool.

        Args:
            tool_call_id: The tool call ID.
            state: The injected state. This method reads ``state["dic_extracted_graph"]``
                (an iterable of dicts with at least the keys ``"name"`` and
                ``"graph_text"``) and ``state["llm_model"]``.
            prompt: The prompt to interact with the backend.
            extraction_name: The name assigned to the subgraph extraction process.

        Returns:
            A ``Command`` whose update carries the list of extracted graphs (the
            selected one now holding the summary under ``"graph_summary"``) and a
            ``ToolMessage`` with the summary text.

        Raises:
            KeyError: If no extracted graph is stored under ``extraction_name``.
        """
        logger.info("Invoking subgraph_summarization tool for %s", extraction_name)

        # Load hydra configuration
        with hydra.initialize(version_base=None, config_path="../configs"):
            cfg = hydra.compose(
                config_name="config", overrides=["tools/subgraph_summarization=default"]
            )
            cfg = cfg.tools.subgraph_summarization

        # Index the extracted graphs by name and pick the requested one.
        extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
        if extraction_name not in extracted_graph:
            # Same exception type as the plain lookup would raise, but with a message
            # that tells the caller which extraction names actually exist.
            raise KeyError(
                f"No extracted subgraph named {extraction_name!r}; "
                f"available names: {sorted(extracted_graph)}"
            )
        selected_graph = extracted_graph[extraction_name]

        # Prepare prompt template
        prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", cfg.prompt_subgraph_summarization),
                ("human", "{input}"),
            ]
        )

        # Prepare chain
        chain = prompt_template | state["llm_model"] | StrOutputParser()

        # Ask the model to summarize the textualized subgraph for this prompt
        response = chain.invoke(
            {
                "input": prompt,
                "textualized_subgraph": selected_graph["graph_text"],
            }
        )

        # Store the response as graph_summary in the selected extracted graph.
        # The dict is the same object that came from the state, as before.
        selected_graph["graph_summary"] = response

        # Return the updated state of the tool together with the message history
        return Command(
            update={
                "dic_extracted_graph": list(extracted_graph.values()),
                "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)],
            }
        )
|
@@ -0,0 +1,81 @@
|
|
1
|
+
"""
|
2
|
+
Embedding class using Ollama model based on LangChain Embeddings class.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import time
|
6
|
+
from typing import List
|
7
|
+
import subprocess
|
8
|
+
import ollama
|
9
|
+
from langchain_ollama import OllamaEmbeddings
|
10
|
+
from .embeddings import Embeddings
|
11
|
+
|
12
|
+
class EmbeddingWithOllama(Embeddings):
    """
    Embedding class using Ollama model based on LangChain Embeddings class.
    """

    def __init__(self, model_name: str):
        """
        Initialize the EmbeddingWithOllama class.

        Args:
            model_name: The name of the Ollama model to be used.

        Raises:
            ValueError: Propagated from the setup step (see ``__setup``).
        """
        # Setup the Ollama server
        self.__setup(model_name)

        # Set parameters
        self.model_name = model_name

        # Prepare model
        self.model = OllamaEmbeddings(model=self.model_name)

    def __setup(self, model_name: str) -> None:
        """
        Check if the Ollama model is available and run the Ollama server if needed.

        The model name is compared against the locally listed models with any
        ``":latest"`` suffix removed. If it is missing, the model is pulled.

        Args:
            model_name: The name of the Ollama model to be used.

        Raises:
            ValueError: If anything in the check fails, and also right after a
                model had to be pulled. The "Pulled ..." error is raised inside
                the ``try`` and is therefore handled by the ``except`` branch
                below, which starts ``ollama serve`` and re-raises.
        """
        try:
            models_list = ollama.list()["models"]
            if model_name not in [m['model'].replace(":latest", "") for m in models_list]:
                ollama.pull(model_name)
                time.sleep(30)
                raise ValueError(f"Pulled {model_name} model")
        except Exception as e:
            # NOTE(review): leaving a Popen ``with`` block closes the pipes and waits
            # for the child process - confirm this is the intended lifetime of the
            # spawned server.
            with subprocess.Popen(
                "ollama serve", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            ):
                time.sleep(10)
            raise ValueError(f"Error: {e} and restarted Ollama server.") from e

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embedding for a list of input texts using Ollama model.

        Args:
            texts: The list of texts to be embedded.

        Returns:
            The list of embeddings for the given texts, as returned by the
            underlying ``OllamaEmbeddings.embed_documents`` call.
        """

        # Generate the embedding
        embeddings = self.model.embed_documents(texts)

        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """
        Generate embeddings for an input text using Ollama model.

        Args:
            text: A query to be embedded.

        Returns:
            The embeddings for the given query, as returned by the underlying
            ``OllamaEmbeddings.embed_query`` call.
        """

        # Generate the embedding
        embeddings = self.model.embed_query(text)

        return embeddings
|
@@ -0,0 +1,225 @@
|
|
1
|
+
"""
|
2
|
+
Extraction of subgraph using Prize-Collecting Steiner Tree (PCST) algorithm.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from typing import Tuple, NamedTuple
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
8
|
+
import pcst_fast
|
9
|
+
from torch_geometric.data.data import Data
|
10
|
+
|
11
|
+
class PCSTPruning(NamedTuple):
    """
    Prize-Collecting Steiner Tree (PCST) pruning algorithm implementation inspired by G-Retriever
    (He et al., 'G-Retriever: Retrieval-Augmented Generation for Textual Graph Understanding and
    Question Answering', NeurIPS 2024) paper.
    https://arxiv.org/abs/2402.07630
    https://github.com/XiaoxinHe/G-Retriever/blob/main/src/dataset/utils/retrieval.py

    Args:
        topk: The number of top nodes to consider.
        topk_e: The number of top edges to consider.
        cost_e: The cost of the edges.
        c_const: The constant value for the cost of the edges computation.
        root: The root node of the subgraph, -1 for unrooted.
        num_clusters: The number of clusters.
        pruning: The pruning strategy to use.
        verbosity_level: The verbosity level.
    """
    topk: int = 3
    topk_e: int = 3
    cost_e: float = 0.5
    c_const: float = 0.01
    root: int = -1
    num_clusters: int = 1
    pruning: str = "gw"
    verbosity_level: int = 0

    def compute_prizes(self, graph: Data, query_emb: torch.Tensor) -> dict:
        """
        Compute the node prizes based on the cosine similarity between the query and nodes,
        as well as the edge prizes based on the cosine similarity between the query and edges.
        Note that the node and edge embeddings shall use the same embedding model and dimensions
        with the query.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format. This method reads
                ``graph.x``, ``graph.edge_attr`` and ``graph.num_nodes``.
            query_emb: The query embedding in PyTorch Tensor format.

        Returns:
            A dictionary with the keys ``"nodes"`` and ``"edges"`` holding the prize
            tensors of the nodes and of the edges.
        """
        # Compute prizes for nodes: the k most similar nodes get the prizes
        # k, k-1, ..., 1 (best first); every other node gets 0.
        n_prizes = torch.nn.CosineSimilarity(dim=-1)(query_emb, graph.x)
        topk = min(self.topk, graph.num_nodes)
        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
        n_prizes = torch.zeros_like(n_prizes)
        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()

        # Compute prizes for edges: rank the distinct similarity values and zero
        # out everything below the k-th best distinct value.
        e_prizes = torch.nn.CosineSimilarity(dim=-1)(query_emb, graph.edge_attr)
        unique_prizes, inverse_indices = e_prizes.unique(return_inverse=True)
        topk_e = min(self.topk_e, unique_prizes.size(0))
        topk_e_values, _ = torch.topk(unique_prizes, topk_e, largest=True)
        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
        last_topk_e_value = topk_e
        for k in range(topk_e):
            # Boolean mask of the edges whose similarity equals the k-th best value
            indices = inverse_indices == (
                unique_prizes == topk_e_values[k]
            ).nonzero(as_tuple=True)[0]
            # Edges sharing a value split the rank prize between them; the prize is
            # also capped so that it stays below the one of the previous rank.
            value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
            e_prizes[indices] = value
            last_topk_e_value = value * (1 - self.c_const)

        return {"nodes": n_prizes, "edges": e_prizes}

    def compute_subgraph_costs(
        self, graph: Data, prizes: dict
    ) -> Tuple[dict, np.ndarray, np.ndarray, dict]:
        """
        Compute the costs in constructing the subgraph proposed by G-Retriever paper.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            prizes: The prizes of the nodes and the edges.

        Returns:
            edges_dict: The edges of the subgraph (key ``"edges"``) and the number of
                edges without virtual edges (key ``"num_prior_edges"``).
            prizes: The prizes of the nodes followed by those of the virtual nodes.
            costs: The costs of the edges (a plain list when no virtual edge was created).
            mapping: The mapping from the new edge positions (key ``"edges"``) and from
                the virtual node ids (key ``"nodes"``) to the original edge indices.
        """
        # Logic to reduce the cost of the edges such that at least one edge is selected
        updated_cost_e = min(
            self.cost_e,
            prizes["edges"].max().item() * (1 - self.c_const / 2),
        )

        # Initialize variables
        edges = []
        costs = []
        virtual = {
            "n_prizes": [],
            "edges": [],
            "costs": [],
        }
        mapping = {"nodes": {}, "edges": {}}

        # Compute the costs, edges, and virtual variables based on the prizes
        for i, (src, dst) in enumerate(graph.edge_index.T.numpy()):
            prize_e = prizes["edges"][i]
            if prize_e <= updated_cost_e:
                # Cheap edge: keep it as a regular edge with the remaining cost
                mapping["edges"][len(edges)] = i
                edges.append((src, dst))
                costs.append(updated_cost_e - prize_e)
            else:
                # Valuable edge: replace it by a virtual node that carries the
                # surplus prize and is linked to both endpoints at zero cost
                virtual_node_id = graph.num_nodes + len(virtual["n_prizes"])
                mapping["nodes"][virtual_node_id] = i
                virtual["edges"].append((src, virtual_node_id))
                virtual["edges"].append((virtual_node_id, dst))
                virtual["costs"].append(0)
                virtual["costs"].append(0)
                virtual["n_prizes"].append(prize_e - updated_cost_e)
        prizes = np.concatenate([prizes["nodes"], np.array(virtual["n_prizes"])])
        edges_dict = {}
        edges_dict["edges"] = edges
        edges_dict["num_prior_edges"] = len(edges)
        # Final computation of the costs and edges based on the virtual costs and virtual edges
        if len(virtual["costs"]) > 0:
            costs = np.array(costs + virtual["costs"])
            edges = np.array(edges + virtual["edges"])
            edges_dict["edges"] = edges

        return edges_dict, prizes, costs, mapping

    def get_subgraph_nodes_edges(
        self, graph: Data, vertices: np.ndarray, edges_dict: dict, mapping: dict,
    ) -> dict:
        """
        Get the selected nodes and edges of the subgraph based on the vertices and edges computed
        by the PCST algorithm.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            vertices: The vertices of the subgraph computed by the PCST algorithm.
            edges_dict: The dictionary of edges of the subgraph computed by the PCST algorithm,
                and the number of prior edges (without virtual edges).
            mapping: The mapping dictionary of the nodes and edges.

        Returns:
            The selected nodes and edges of the extracted subgraph.
        """
        # Get edges information
        edges = edges_dict["edges"]
        num_prior_edges = edges_dict["num_prior_edges"]
        # Retrieve the selected nodes and edges based on the given vertices and edges
        subgraph_nodes = vertices[vertices < graph.num_nodes]
        subgraph_edges = [mapping["edges"][e] for e in edges if e < num_prior_edges]
        # Vertices beyond the real node range are virtual nodes standing for an edge
        virtual_vertices = vertices[vertices >= graph.num_nodes]
        if len(virtual_vertices) > 0:
            virtual_edges = [mapping["nodes"][i] for i in virtual_vertices]
            subgraph_edges = np.array(subgraph_edges + virtual_edges)
        # Add the endpoints of every selected edge to the selected nodes
        edge_index = graph.edge_index[:, subgraph_edges]
        subgraph_nodes = np.unique(
            np.concatenate(
                [subgraph_nodes, edge_index[0].numpy(), edge_index[1].numpy()]
            )
        )

        return {"nodes": subgraph_nodes, "edges": subgraph_edges}

    def extract_subgraph(self, graph: Data, query_emb: torch.Tensor) -> dict:
        """
        Perform the Prize-Collecting Steiner Tree (PCST) algorithm to extract the subgraph.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            query_emb: The query embedding.

        Returns:
            The selected nodes and edges of the subgraph.

        Raises:
            AssertionError: If ``topk`` or ``topk_e`` is not greater than 0.
        """
        # Assert the topk and topk_e values for subgraph retrieval
        assert self.topk > 0, "topk must be greater than 0"
        assert self.topk_e > 0, "topk_e must be greater than 0"

        # Retrieve the top-k nodes and edges based on the query embedding
        prizes = self.compute_prizes(graph, query_emb)

        # Compute costs in constructing the subgraph
        edges_dict, prizes, costs, mapping = self.compute_subgraph_costs(
            graph, prizes
        )

        # Retrieve the subgraph using the PCST algorithm
        result_vertices, result_edges = pcst_fast.pcst_fast(
            edges_dict["edges"],
            prizes,
            costs,
            self.root,
            self.num_clusters,
            self.pruning,
            self.verbosity_level,
        )

        # Translate the PCST result back to node and edge indices of the graph
        subgraph = self.get_subgraph_nodes_edges(
            graph,
            result_vertices,
            {"edges": result_edges, "num_prior_edges": edges_dict["num_prior_edges"]},
            mapping,
        )

        return subgraph
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: aiagents4pharma
|
3
|
-
Version: 1.
|
3
|
+
Version: 1.19.0
|
4
4
|
Summary: AI Agents for drug discovery, drug development, and other pharmaceutical R&D.
|
5
5
|
Classifier: Programming Language :: Python :: 3
|
6
6
|
Classifier: License :: OSI Approved :: MIT License
|
@@ -12,6 +12,7 @@ Requires-Dist: copasi_basico==0.78
|
|
12
12
|
Requires-Dist: coverage==7.6.4
|
13
13
|
Requires-Dist: einops==0.8.0
|
14
14
|
Requires-Dist: gdown==5.2.0
|
15
|
+
Requires-Dist: gravis==0.1.0
|
15
16
|
Requires-Dist: huggingface_hub==0.26.5
|
16
17
|
Requires-Dist: hydra-core==1.3.2
|
17
18
|
Requires-Dist: joblib==1.4.2
|
@@ -27,6 +28,7 @@ Requires-Dist: matplotlib==3.9.2
|
|
27
28
|
Requires-Dist: openai==1.59.4
|
28
29
|
Requires-Dist: ollama==0.4.6
|
29
30
|
Requires-Dist: pandas==2.2.3
|
31
|
+
Requires-Dist: pcst_fast==1.0.10
|
30
32
|
Requires-Dist: plotly==5.24.1
|
31
33
|
Requires-Dist: pydantic==2.9.2
|
32
34
|
Requires-Dist: pylint==3.3.1
|
@@ -55,31 +55,63 @@ aiagents4pharma/talk2cells/tools/__init__.py,sha256=38nK2a_lEFRjO3qD6Fo9a3983ZCY
|
|
55
55
|
aiagents4pharma/talk2cells/tools/scp_agent/__init__.py,sha256=s7g0lyH1lMD9pcWHLPtwRJRvzmTh2II7DrxyLulpjmQ,163
|
56
56
|
aiagents4pharma/talk2cells/tools/scp_agent/display_studies.py,sha256=6q59gh_NQaiOU2rn55A3sIIFKlXi4SK3iKgySvUDrtQ,600
|
57
57
|
aiagents4pharma/talk2cells/tools/scp_agent/search_studies.py,sha256=MLe-twtFnOu-P8P9diYq7jvHBHbWFRRCZLcfpUzqPMg,2806
|
58
|
-
aiagents4pharma/talk2knowledgegraphs/__init__.py,sha256=
|
58
|
+
aiagents4pharma/talk2knowledgegraphs/__init__.py,sha256=Z0Eo7LTiKk0STsr8VI7wkCLq7PHrK1vYlH4I1hSNLiA,165
|
59
|
+
aiagents4pharma/talk2knowledgegraphs/agents/__init__.py,sha256=iOAzuy_8A03tQDFtSBhC9dldUo62z5gfxcVtXAdLOJs,92
|
60
|
+
aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py,sha256=j6MA1LB28mqpb6ZEmNLGcvDZvOnlGbJB9r7VXyEGask,3079
|
61
|
+
aiagents4pharma/talk2knowledgegraphs/configs/__init__.py,sha256=Y49ucO22v9oe9EwFiXN6MU2wvyB3_ZBpmHwHbeh-ZVQ,106
|
62
|
+
aiagents4pharma/talk2knowledgegraphs/configs/config.yaml,sha256=rwUIZ2t5j5hlFyre7VnV8zMsP0qpPTwvAFExgvQD6q0,196
|
63
|
+
aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
|
64
|
+
aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml,sha256=ENCGROwYFpR6g4QD518h73sshdn3vPVpotBMk1QJcpU,4830
|
65
|
+
aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py,sha256=fKfc3FR7g5KjY9b6jzrU6cwKTVVpkoVZQS3dvUowu34,69
|
66
|
+
aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
|
67
|
+
aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml,sha256=4azC4cH-_-zt-bRVgNjkFM24mjNke6Rgn9pNl7XWrPQ,912
|
68
|
+
aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py,sha256=C1yyRZW8hqWw46p_bh1vAJp2z9aVvn4HpKjKkjlWIqY,150
|
69
|
+
aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
|
70
|
+
aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml,sha256=Ua99yECXiwp4ZCUDgsDskYbKzcJrv7roQuLj31Zky4c,1037
|
71
|
+
aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
|
72
|
+
aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml,sha256=U8HvMsYbaOwDwQPATj7EFvLtTy7XZEplE5WMoNjgYYc,1469
|
73
|
+
aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
|
74
|
+
aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml,sha256=OOSlPpJVwJK4_lu4lhA2E48yhFFbEYpyHsoi9Orgm00,561
|
59
75
|
aiagents4pharma/talk2knowledgegraphs/datasets/__init__.py,sha256=L3gPuHskSegmtXskVrLIYr7FXe_ibKgJ2GGr1_Wok6k,173
|
60
76
|
aiagents4pharma/talk2knowledgegraphs/datasets/biobridge_primekg.py,sha256=QlzDXmXREoa9MA6-GwzqRjdzndQeGBAF11Td6NFk_9Y,23426
|
61
77
|
aiagents4pharma/talk2knowledgegraphs/datasets/dataset.py,sha256=-LaPLse8BkALqwFetNK7wch2dt9Dz6QKGKZKBKM6bIk,409
|
62
78
|
aiagents4pharma/talk2knowledgegraphs/datasets/primekg.py,sha256=KBMhCJ7yjMWqQJJctFYdpjYAlwv48Jl6i1dddXP4f08,7599
|
63
79
|
aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py,sha256=Y-6-nORsnBJlU6rH0skyfr9S9J4PfTWK-af_p5UuknQ,7483
|
80
|
+
aiagents4pharma/talk2knowledgegraphs/states/__init__.py,sha256=XaqorSvx634dWRRlXUdzlisHtYMyqgJ2q7TanzsKlhw,108
|
81
|
+
aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py,sha256=6HqGo-awqoyNJG0igm5so5A4Tq8RPkCsjPg8Go38csE,1066
|
64
82
|
aiagents4pharma/talk2knowledgegraphs/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
83
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py,sha256=CCN6cyhEaiXSvIC-4y3ueDSzjDCBYDsmSmOor-DMeF4,3928
|
65
84
|
aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py,sha256=crH0eFA3P8P6IYzi1UWNa4YvRVrtlBzoScf9NaE1lDk,9827
|
66
85
|
aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py,sha256=NFUlsZvhfIrkF4YenWfahrLK93Xhm5UYEGG_uYN2LVM,566
|
67
86
|
aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py,sha256=Pvu0r93CpnhjkfMxc-EiVLpAJ04FdW9iTamCnetu654,2272
|
68
87
|
aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py,sha256=TuIsqcN1Mww3DTqGk6ebgJBWzUWdMWEq2yRQuYSFqvA,4416
|
88
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py,sha256=aOKHTber2Cg3mjNjfIa6RZU7XdFj5C2ps1YEUXw76CI,10650
|
89
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py,sha256=zRi2j9Dm3VFywhhrPjVoJ7z_zJpAEM74MJRXapnhwVE,6246
|
90
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py,sha256=oBqfspXXOxH04OQuPb8BCW0liIQTGKXtaPNSrPpQtFc,7597
|
69
91
|
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py,sha256=uYFoE_6zeU10_1mLLAHUr5c4S2XZMSc0Q_860o-KWEw,1517
|
70
|
-
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py,sha256=
|
92
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py,sha256=hzX84pheZdEsTtikF2KtBFiH44_xPjYXxLA6p4Ax1CY,1623
|
93
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py,sha256=jn-TrPwF0aR9kVoerwkbMZa3U6Hc6HjV6Zoau4qSH4g,1834
|
71
94
|
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_sentencetransformer.py,sha256=Qxo6WeIDRy8aLh1tNKw0kSlzmUj3MtTak63oW2YwB24,1327
|
72
95
|
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_enrichments.py,sha256=N6HRr4lWHXY7bTHe2uXJe4D_EG9WqZPibZne6qLl9_k,1447
|
73
|
-
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py,sha256=
|
74
|
-
aiagents4pharma/talk2knowledgegraphs/
|
96
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py,sha256=JhY7axvVULLywDJ2ctA-gob5YPeaJYWsaMNjHT6L9CU,3021
|
97
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py,sha256=pal76wi7WgQWUNk56BrzfFV8jKpbDaHHdbwtgx_gXLI,2410
|
98
|
+
aiagents4pharma/talk2knowledgegraphs/tools/__init__.py,sha256=zpD4h7EYtyq0QNOqLd6bkxrPlPb2XN64ceI9ncgESrA,171
|
99
|
+
aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py,sha256=OEuOFncDRdb7TQEGq4rkT5On-jI-R7Nt8K5EBzaND8w,5338
|
100
|
+
aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py,sha256=zhmsRp-8vjB5rRekqTA07d3yb-42HWqng9dDMkvK6hM,623
|
101
|
+
aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py,sha256=te06QMFQfgJWrjaGrqpcOYeaV38jwm0KY_rXVSMHkeI,11468
|
102
|
+
aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py,sha256=mDSBOxopDfNhEJeU8fVI8b5lXTYrRzcc97aLbFgYSy4,4413
|
103
|
+
aiagents4pharma/talk2knowledgegraphs/utils/__init__.py,sha256=Q9mzcSmkmhdnOn13fxGh1fNECYoUR5Y5CCuEJTIxwAI,167
|
75
104
|
aiagents4pharma/talk2knowledgegraphs/utils/kg_utils.py,sha256=6vQnPkeOWae_8jePjhma3sJuMTngy0I0tqzdFt6OqKg,2507
|
76
|
-
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py,sha256=
|
105
|
+
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py,sha256=4TGK0XIVkkfGOyrSVwFQ-Lp-rzH9CCl-fWcqkFJKRLc,174
|
77
106
|
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/embeddings.py,sha256=1nGznrAj-xT0xuSMBGz2dOujJ7M_IwSR84njxtxsy9A,2523
|
78
107
|
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/huggingface.py,sha256=2vi_elf6EgzfagFAO5QnL3a_aXZyN7B1EBziu44MTfM,3806
|
108
|
+
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py,sha256=8w0sjt3Ex5YJ_XvpKl9UbhdTiiaoMIarbPUxLBU-1Uw,2378
|
79
109
|
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/sentence_transformer.py,sha256=36iKlisOpMtGR5xfTAlSHXWvPqVC_Jbezod8kbBBMVg,2136
|
80
110
|
aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py,sha256=tW426knki2DBIHcWyF_K04iMMdbpIn_e_TpPmTgz2dI,113
|
81
111
|
aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py,sha256=Bx8x6zzk5614ApWB90N_iv4_Y_Uq0-KwUeBwYSdQMU4,924
|
82
112
|
aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py,sha256=8eoxR-VHo0G7ReQIwje7xEhE-SJlHdef7_wJRpnvFIc,4116
|
113
|
+
aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py,sha256=7gwwtfzKhB8GuOBD47XRi0NprwEXkOzwNl5eeu-hDTI,86
|
114
|
+
aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py,sha256=m5p0yoJb7I19ua5yeQfXPf7c4r6S1XPwttsrM7Qoy94,9336
|
83
115
|
aiagents4pharma/talk2scholars/__init__.py,sha256=gphERyVKZHvOnMQsml7TIHlaIshHJ75R1J3FKExkfuY,120
|
84
116
|
aiagents4pharma/talk2scholars/agents/__init__.py,sha256=ykszlVGxz3egLHZAttlNoTPxIrnQJZYva_ssR8fwIFk,117
|
85
117
|
aiagents4pharma/talk2scholars/agents/main_agent.py,sha256=etPQUCjHtD-in-kD7Wg_UD6jRtCHj-mj41y03PYbAQM,4616
|
@@ -112,8 +144,8 @@ aiagents4pharma/talk2scholars/tools/s2/display_results.py,sha256=B8JJGohi1Eyx8C3
|
|
112
144
|
aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py,sha256=0Y3q8TkF_Phng9L7g1kk9Fhyit9UNitWurp03H0GZv8,4455
|
113
145
|
aiagents4pharma/talk2scholars/tools/s2/search.py,sha256=CcgFN7YuuQ9Vl1DJcldnnvPrswABKjNxeauK1rABps8,4176
|
114
146
|
aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py,sha256=irS-igdG8BZbVb0Z4VlIjzsyBlUfREd0v0_RlUM-0_U,4994
|
115
|
-
aiagents4pharma-1.
|
116
|
-
aiagents4pharma-1.
|
117
|
-
aiagents4pharma-1.
|
118
|
-
aiagents4pharma-1.
|
119
|
-
aiagents4pharma-1.
|
147
|
+
aiagents4pharma-1.19.0.dist-info/LICENSE,sha256=IcIbyB1Hyk5ZDah03VNQvJkbNk2hkBCDqQ8qtnCvB4Q,1077
|
148
|
+
aiagents4pharma-1.19.0.dist-info/METADATA,sha256=jMpgcCw7eRa0gUr_cCMCX38iit_D0LQE2VdzHliAk_M,7053
|
149
|
+
aiagents4pharma-1.19.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
150
|
+
aiagents4pharma-1.19.0.dist-info/top_level.txt,sha256=-AH8rMmrSnJtq7HaAObS78UU-cTCwvX660dSxeM7a0A,16
|
151
|
+
aiagents4pharma-1.19.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|