aiagents4pharma 1.18.0__py3-none-any.whl → 1.19.1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- aiagents4pharma/talk2knowledgegraphs/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +85 -0
- aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +7 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +31 -0
- aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +7 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +6 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
- aiagents4pharma/talk2knowledgegraphs/states/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +38 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +110 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +210 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +174 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +154 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +0 -1
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +56 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +18 -42
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +79 -0
- aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +6 -0
- aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +143 -0
- aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +305 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +126 -0
- aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +4 -2
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +81 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +4 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +225 -0
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.1.dist-info}/METADATA +3 -1
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.1.dist-info}/RECORD +42 -10
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.1.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.1.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,126 @@
|
|
1
|
+
"""
|
2
|
+
Tool for performing subgraph summarization.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import logging
|
6
|
+
from typing import Type, Annotated
|
7
|
+
from pydantic import BaseModel, Field
|
8
|
+
from langchain_core.output_parsers import StrOutputParser
|
9
|
+
from langchain_core.prompts import ChatPromptTemplate
|
10
|
+
from langchain_core.messages import ToolMessage
|
11
|
+
from langchain_core.tools.base import InjectedToolCallId
|
12
|
+
from langchain_core.tools import BaseTool
|
13
|
+
from langgraph.types import Command
|
14
|
+
from langgraph.prebuilt import InjectedState
|
15
|
+
import hydra
|
16
|
+
|
17
|
+
# Initialize logger
|
18
|
+
logging.basicConfig(level=logging.INFO)
|
19
|
+
logger = logging.getLogger(__name__)
|
20
|
+
|
21
|
+
|
22
|
+
class SubgraphSummarizationInput(BaseModel):
    """
    Argument schema of the subgraph summarization tool.

    Fields:
        tool_call_id: Identifier of the tool call (annotated with InjectedToolCallId).
        state: Agent state (annotated with InjectedState).
        prompt: The user prompt the summary should respond to.
        extraction_name: Name of a subgraph extraction produced earlier by the
            subgraph_extraction tool.
    """

    tool_call_id: Annotated[str, InjectedToolCallId] = Field(description="Tool call ID.")
    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
    prompt: str = Field(description="Prompt to interact with the backend.")
    extraction_name: str = Field(
        description="""Name assigned to the subgraph extraction process
when the subgraph_extraction tool is invoked."""
    )
|
43
|
+
|
44
|
+
|
45
|
+
class SubgraphSummarizationTool(BaseTool):
    """
    This tool performs subgraph summarization over textualized graph to highlight the most
    important information in responding to user's prompt.
    """

    name: str = "subgraph_summarization"
    description: str = """A tool to perform subgraph summarization over textualized graph
for responding to user's follow-up prompt(s)."""
    args_schema: Type[BaseModel] = SubgraphSummarizationInput

    def _run(
        self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        prompt: str,
        extraction_name: str,
    ):
        """
        Run the subgraph summarization tool.

        Args:
            tool_call_id: The tool call ID.
            state: The injected state. It must provide ``dic_extracted_graph`` (a list of
                dicts, each with at least ``name`` and ``graph_text`` keys) and
                ``llm_model`` (a runnable that can be composed with ``|``).
            prompt: The prompt to interact with the backend.
            extraction_name: The name assigned to the subgraph extraction process.

        Returns:
            A ``Command`` whose update carries the extraction records (the selected one
            now holding ``graph_summary``) and a ``ToolMessage`` with the summary.

        Raises:
            KeyError: If no extraction named ``extraction_name`` exists in the state.
        """
        logger.log(
            logging.INFO, "Invoking subgraph_summarization tool for %s", extraction_name
        )

        # Load hydra configuration
        with hydra.initialize(version_base=None, config_path="../configs"):
            cfg = hydra.compose(
                config_name="config", overrides=["tools/subgraph_summarization=default"]
            )
            cfg = cfg.tools.subgraph_summarization

        # Index the extracted graphs by name. Each record is shallow-copied so the
        # injected state is not mutated in place; the change is propagated only through
        # the Command returned below.
        extracted_graph = {
            dic["name"]: dict(dic) for dic in state["dic_extracted_graph"]
        }
        if extraction_name not in extracted_graph:
            raise KeyError(
                f"No extracted subgraph named '{extraction_name}' was found in the state."
            )
        target_graph = extracted_graph[extraction_name]

        # Prepare prompt template
        prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", cfg.prompt_subgraph_summarization),
                ("human", "{input}"),
            ]
        )

        # Prepare chain
        chain = prompt_template | state["llm_model"] | StrOutputParser()

        # Summarize the textualized subgraph with respect to the user's prompt
        response = chain.invoke(
            {
                "input": prompt,
                "textualized_subgraph": target_graph["graph_text"],
            }
        )

        # Store the response as graph_summary of the selected extraction (direct lookup,
        # no need to scan every entry)
        target_graph["graph_summary"] = response

        # Return the updated state of the tool
        return Command(
            update={
                "dic_extracted_graph": list(extracted_graph.values()),
                # update the message history
                "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)],
            }
        )
|
@@ -0,0 +1,81 @@
|
|
1
|
+
"""
|
2
|
+
Embedding class using Ollama model based on LangChain Embeddings class.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import time
|
6
|
+
from typing import List
|
7
|
+
import subprocess
|
8
|
+
import ollama
|
9
|
+
from langchain_ollama import OllamaEmbeddings
|
10
|
+
from .embeddings import Embeddings
|
11
|
+
|
12
|
+
class EmbeddingWithOllama(Embeddings):
    """
    Embedding class using Ollama model based on LangChain Embeddings class.
    """

    def __init__(self, model_name: str):
        """
        Initialize the EmbeddingWithOllama class.

        Args:
            model_name: The name of the Ollama model to be used.

        Raises:
            ValueError: If the Ollama server could not be queried, or if the model was
                not available locally and had to be pulled first. In both cases an
                Ollama server is started in the background and the caller is expected
                to retry.
        """
        # Setup the Ollama server
        self.__setup(model_name)

        # Set parameters
        self.model_name = model_name

        # Prepare model
        self.model = OllamaEmbeddings(model=self.model_name)

    def __setup(self, model_name: str) -> None:
        """
        Check if the Ollama model is available and run the Ollama server if needed.

        Args:
            model_name: The name of the Ollama model to be used.

        Raises:
            ValueError: Always raised when the model listing fails or when the model had
                to be pulled. The message starts with ``"Error: "`` and embeds the
                original error.
        """
        try:
            models_list = ollama.list()["models"]
            # Strip the implicit ":latest" tag on BOTH sides so that "name" and
            # "name:latest" are recognised as the same local model.
            available_models = {m["model"].replace(":latest", "") for m in models_list}
            if model_name.replace(":latest", "") not in available_models:
                ollama.pull(model_name)
                time.sleep(30)
                raise ValueError(f"Pulled {model_name} model")
        except Exception as e:
            # Launch `ollama serve` as a detached background process. It must not be used
            # as a context manager: Popen.__exit__ closes the pipes and waits for the
            # child, which would block on (or break) a long-running server. Its output
            # is discarded rather than piped into buffers nobody reads.
            try:
                subprocess.Popen(  # pylint: disable=consider-using-with
                    ["ollama", "serve"],
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                    start_new_session=True,
                )
            except OSError as launch_error:
                raise ValueError(
                    f"Error: {e} and failed to start Ollama server ({launch_error})."
                ) from e
            time.sleep(10)
            raise ValueError(f"Error: {e} and restarted Ollama server.") from e

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embedding for a list of input texts using Ollama model.

        Args:
            texts: The list of texts to be embedded.

        Returns:
            The list of embeddings for the given texts, one vector per text.
        """

        # Generate the embedding
        embeddings = self.model.embed_documents(texts)

        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """
        Generate embeddings for an input text using Ollama model.

        Args:
            text: A query to be embedded.
        Returns:
            The embeddings for the given query.
        """

        # Generate the embedding
        embeddings = self.model.embed_query(text)

        return embeddings
|
@@ -0,0 +1,225 @@
|
|
1
|
+
"""
|
2
|
+
Extraction of subgraph using Prize-Collecting Steiner Tree (PCST) algorithm.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from typing import Tuple, NamedTuple
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
8
|
+
import pcst_fast
|
9
|
+
from torch_geometric.data.data import Data
|
10
|
+
|
11
|
+
class PCSTPruning(NamedTuple):
    """
    Prize-Collecting Steiner Tree (PCST) pruning algorithm implementation inspired by G-Retriever
    (He et al., 'G-Retriever: Retrieval-Augmented Generation for Textual Graph Understanding and
    Question Answering', NeurIPS 2024) paper.
    https://arxiv.org/abs/2402.07630
    https://github.com/XiaoxinHe/G-Retriever/blob/main/src/dataset/utils/retrieval.py

    Args:
        topk: The number of top nodes to consider.
        topk_e: The number of top edges to consider.
        cost_e: The cost of the edges.
        c_const: The constant value for the cost of the edges computation.
        root: The root node of the subgraph, -1 for unrooted.
        num_clusters: The number of clusters.
        pruning: The pruning strategy to use.
        verbosity_level: The verbosity level.
    """

    topk: int = 3
    topk_e: int = 3
    cost_e: float = 0.5
    c_const: float = 0.01
    root: int = -1
    num_clusters: int = 1
    pruning: str = "gw"
    verbosity_level: int = 0

    def compute_prizes(self, graph: "Data", query_emb: torch.Tensor) -> dict:
        """
        Compute the node prizes based on the cosine similarity between the query and nodes,
        as well as the edge prizes based on the cosine similarity between the query and edges.
        Note that the node and edge embeddings shall use the same embedding model and dimensions
        with the query.

        Node prizes: the ``topk`` most similar nodes receive ``topk, topk-1, ..., 1``;
        every other node receives 0.
        Edge prizes: the ``topk_e`` largest distinct similarity values are ranked. Edges
        sharing the rank-k value split ``topk_e - k`` between them, capped so that each
        rank is worth at most ``(1 - c_const)`` times the previous one. Edges below the
        top-k values receive 0.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            query_emb: The query embedding in PyTorch Tensor format.

        Returns:
            A dict with ``"nodes"`` and ``"edges"`` prize tensors.
        """
        cosine = torch.nn.CosineSimilarity(dim=-1)

        # Compute prizes for nodes
        n_scores = cosine(query_emb, graph.x)
        topk = min(self.topk, graph.num_nodes)
        _, topk_n_indices = torch.topk(n_scores, topk, largest=True)
        n_prizes = torch.zeros_like(n_scores)
        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()

        # Compute prizes for edges (ranking over distinct similarity values)
        e_prizes = cosine(query_emb, graph.edge_attr)
        unique_prizes, inverse_indices = e_prizes.unique(return_inverse=True)
        topk_e = min(self.topk_e, unique_prizes.size(0))
        if topk_e > 0:  # a graph without edges has nothing to rank
            topk_e_values, topk_e_unique_idx = torch.topk(
                unique_prizes, topk_e, largest=True
            )
            e_prizes[e_prizes < topk_e_values[-1]] = 0.0
            last_topk_e_value = topk_e
            for k in range(topk_e):
                # Edges whose (original) similarity equals the k-th largest distinct value;
                # topk indices point straight into unique_prizes.
                indices = inverse_indices == topk_e_unique_idx[k]
                value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
                e_prizes[indices] = value
                last_topk_e_value = value * (1 - self.c_const)

        return {"nodes": n_prizes, "edges": e_prizes}

    def compute_subgraph_costs(
        self, graph: "Data", prizes: dict
    ) -> Tuple[dict, np.ndarray, np.ndarray, dict]:
        """
        Compute the costs in constructing the subgraph proposed by G-Retriever paper.

        Edges whose prize does not exceed the (adjusted) edge cost are kept as ordinary
        edges with cost ``cost - prize``. Every other edge is replaced by a virtual node
        (id ``num_nodes + j``) carrying the surplus prize ``prize - cost``. That node is
        connected to the edge's endpoints by two zero-cost edges.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            prizes: The prizes of the nodes and the edges.

        Returns:
            edges_dict: ``{"edges": (m, 2) int array of real then virtual edges,
                "num_prior_edges": number of real (non-virtual) edges}``.
            prizes: Node prizes followed by the virtual-node prizes.
            costs: Edge costs aligned with ``edges_dict["edges"]``.
            mapping: ``{"edges": {position among real edges: original edge index},
                "nodes": {virtual node id: original edge index}}``.
        """
        e_prizes = torch.as_tensor(prizes["edges"])
        num_nodes = graph.num_nodes

        # Logic to reduce the cost of the edges such that at least one edge is selected
        if e_prizes.numel() > 0:
            updated_cost_e = min(
                self.cost_e,
                e_prizes.max().item() * (1 - self.c_const / 2),
            )
        else:
            updated_cost_e = self.cost_e

        # Vectorized split of the edges into real and virtual ones (original index order
        # is preserved in both groups).
        real_mask = e_prizes <= updated_cost_e
        real_idx = torch.nonzero(real_mask, as_tuple=True)[0].numpy()
        virtual_idx = torch.nonzero(~real_mask, as_tuple=True)[0].numpy()
        src, dst = graph.edge_index.numpy()

        real_edges = np.stack([src[real_idx], dst[real_idx]], axis=1)
        real_costs = (updated_cost_e - e_prizes[real_mask]).numpy()

        # Each virtual node contributes the pair (src, v), (v, dst) in that order.
        virtual_ids = np.arange(num_nodes, num_nodes + virtual_idx.size, dtype=np.int64)
        virtual_edges = np.stack(
            [
                np.stack([src[virtual_idx], virtual_ids], axis=1),
                np.stack([virtual_ids, dst[virtual_idx]], axis=1),
            ],
            axis=1,
        ).reshape(-1, 2)
        virtual_costs = np.zeros(virtual_edges.shape[0], dtype=real_costs.dtype)
        virtual_n_prizes = (e_prizes[~real_mask] - updated_cost_e).numpy()

        # Always hand numpy arrays to pcst_fast, with or without virtual edges
        edges_dict = {
            "edges": np.concatenate([real_edges, virtual_edges], axis=0),
            "num_prior_edges": int(real_idx.size),
        }
        all_prizes = np.concatenate([np.asarray(prizes["nodes"]), virtual_n_prizes])
        costs = np.concatenate([real_costs, virtual_costs])
        mapping = {
            "nodes": dict(zip(virtual_ids.tolist(), virtual_idx.tolist())),
            "edges": dict(enumerate(real_idx.tolist())),
        }

        return edges_dict, all_prizes, costs, mapping

    def get_subgraph_nodes_edges(
        self, graph: "Data", vertices: np.ndarray, edges_dict: dict, mapping: dict,
    ) -> dict:
        """
        Get the selected nodes and edges of the subgraph based on the vertices and edges computed
        by the PCST algorithm.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            vertices: The vertices of the subgraph computed by the PCST algorithm.
            edges_dict: The dictionary of edges of the subgraph computed by the PCST algorithm,
                and the number of prior edges (without virtual edges).
            mapping: The mapping dictionary of the nodes and edges.

        Returns:
            ``{"nodes": sorted unique int64 node ids, "edges": int64 original edge ids}``.
            Edge ids list the selected real edges first, then the edges represented by
            the selected virtual nodes.
        """
        num_nodes = graph.num_nodes
        vertices = np.asarray(vertices, dtype=np.int64)
        num_prior_edges = edges_dict["num_prior_edges"]

        # Real vertices / real edges map straight back to the original graph
        subgraph_nodes = vertices[vertices < num_nodes]
        subgraph_edges = [
            mapping["edges"][e] for e in edges_dict["edges"] if e < num_prior_edges
        ]
        # Every selected virtual vertex stands for one original edge
        subgraph_edges.extend(
            mapping["nodes"][v] for v in vertices[vertices >= num_nodes]
        )
        # Always an int64 array (an empty list would otherwise become a float array)
        subgraph_edges = np.asarray(subgraph_edges, dtype=np.int64)

        # Include the endpoints of every selected edge
        edge_index = graph.edge_index[:, torch.as_tensor(subgraph_edges, dtype=torch.long)]
        subgraph_nodes = np.unique(
            np.concatenate(
                [subgraph_nodes, edge_index[0].numpy(), edge_index[1].numpy()]
            )
        )

        return {"nodes": subgraph_nodes, "edges": subgraph_edges}

    def extract_subgraph(self, graph: "Data", query_emb: torch.Tensor) -> dict:
        """
        Perform the Prize-Collecting Steiner Tree (PCST) algorithm to extract the subgraph.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            query_emb: The query embedding.

        Returns:
            The selected nodes and edges of the subgraph.

        Raises:
            AssertionError: If ``topk`` or ``topk_e`` is not positive.
        """
        # Assert the topk and topk_e values for subgraph retrieval
        assert self.topk > 0, "topk must be greater than 0"
        assert self.topk_e > 0, "topk_e must be greater than 0"

        # Retrieve the top-k nodes and edges based on the query embedding
        prizes = self.compute_prizes(graph, query_emb)

        # Compute costs in constructing the subgraph
        edges_dict, prizes, costs, mapping = self.compute_subgraph_costs(
            graph, prizes
        )

        # Retrieve the subgraph using the PCST algorithm
        result_vertices, result_edges = pcst_fast.pcst_fast(
            edges_dict["edges"],
            prizes,
            costs,
            self.root,
            self.num_clusters,
            self.pruning,
            self.verbosity_level,
        )

        subgraph = self.get_subgraph_nodes_edges(
            graph,
            result_vertices,
            {"edges": result_edges, "num_prior_edges": edges_dict["num_prior_edges"]},
            mapping,
        )

        return subgraph
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: aiagents4pharma
|
3
|
-
Version: 1.
|
3
|
+
Version: 1.19.1
|
4
4
|
Summary: AI Agents for drug discovery, drug development, and other pharmaceutical R&D.
|
5
5
|
Classifier: Programming Language :: Python :: 3
|
6
6
|
Classifier: License :: OSI Approved :: MIT License
|
@@ -12,6 +12,7 @@ Requires-Dist: copasi_basico==0.78
|
|
12
12
|
Requires-Dist: coverage==7.6.4
|
13
13
|
Requires-Dist: einops==0.8.0
|
14
14
|
Requires-Dist: gdown==5.2.0
|
15
|
+
Requires-Dist: gravis==0.1.0
|
15
16
|
Requires-Dist: huggingface_hub==0.26.5
|
16
17
|
Requires-Dist: hydra-core==1.3.2
|
17
18
|
Requires-Dist: joblib==1.4.2
|
@@ -27,6 +28,7 @@ Requires-Dist: matplotlib==3.9.2
|
|
27
28
|
Requires-Dist: openai==1.59.4
|
28
29
|
Requires-Dist: ollama==0.4.6
|
29
30
|
Requires-Dist: pandas==2.2.3
|
31
|
+
Requires-Dist: pcst_fast==1.0.10
|
30
32
|
Requires-Dist: plotly==5.24.1
|
31
33
|
Requires-Dist: pydantic==2.9.2
|
32
34
|
Requires-Dist: pylint==3.3.1
|
@@ -55,31 +55,63 @@ aiagents4pharma/talk2cells/tools/__init__.py,sha256=38nK2a_lEFRjO3qD6Fo9a3983ZCY
|
|
55
55
|
aiagents4pharma/talk2cells/tools/scp_agent/__init__.py,sha256=s7g0lyH1lMD9pcWHLPtwRJRvzmTh2II7DrxyLulpjmQ,163
|
56
56
|
aiagents4pharma/talk2cells/tools/scp_agent/display_studies.py,sha256=6q59gh_NQaiOU2rn55A3sIIFKlXi4SK3iKgySvUDrtQ,600
|
57
57
|
aiagents4pharma/talk2cells/tools/scp_agent/search_studies.py,sha256=MLe-twtFnOu-P8P9diYq7jvHBHbWFRRCZLcfpUzqPMg,2806
|
58
|
-
aiagents4pharma/talk2knowledgegraphs/__init__.py,sha256=
|
58
|
+
aiagents4pharma/talk2knowledgegraphs/__init__.py,sha256=Z0Eo7LTiKk0STsr8VI7wkCLq7PHrK1vYlH4I1hSNLiA,165
|
59
|
+
aiagents4pharma/talk2knowledgegraphs/agents/__init__.py,sha256=iOAzuy_8A03tQDFtSBhC9dldUo62z5gfxcVtXAdLOJs,92
|
60
|
+
aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py,sha256=j6MA1LB28mqpb6ZEmNLGcvDZvOnlGbJB9r7VXyEGask,3079
|
61
|
+
aiagents4pharma/talk2knowledgegraphs/configs/__init__.py,sha256=Y49ucO22v9oe9EwFiXN6MU2wvyB3_ZBpmHwHbeh-ZVQ,106
|
62
|
+
aiagents4pharma/talk2knowledgegraphs/configs/config.yaml,sha256=rwUIZ2t5j5hlFyre7VnV8zMsP0qpPTwvAFExgvQD6q0,196
|
63
|
+
aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
|
64
|
+
aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml,sha256=ENCGROwYFpR6g4QD518h73sshdn3vPVpotBMk1QJcpU,4830
|
65
|
+
aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py,sha256=fKfc3FR7g5KjY9b6jzrU6cwKTVVpkoVZQS3dvUowu34,69
|
66
|
+
aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
|
67
|
+
aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml,sha256=4azC4cH-_-zt-bRVgNjkFM24mjNke6Rgn9pNl7XWrPQ,912
|
68
|
+
aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py,sha256=C1yyRZW8hqWw46p_bh1vAJp2z9aVvn4HpKjKkjlWIqY,150
|
69
|
+
aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
|
70
|
+
aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml,sha256=Ua99yECXiwp4ZCUDgsDskYbKzcJrv7roQuLj31Zky4c,1037
|
71
|
+
aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
|
72
|
+
aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml,sha256=U8HvMsYbaOwDwQPATj7EFvLtTy7XZEplE5WMoNjgYYc,1469
|
73
|
+
aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
|
74
|
+
aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml,sha256=OOSlPpJVwJK4_lu4lhA2E48yhFFbEYpyHsoi9Orgm00,561
|
59
75
|
aiagents4pharma/talk2knowledgegraphs/datasets/__init__.py,sha256=L3gPuHskSegmtXskVrLIYr7FXe_ibKgJ2GGr1_Wok6k,173
|
60
76
|
aiagents4pharma/talk2knowledgegraphs/datasets/biobridge_primekg.py,sha256=QlzDXmXREoa9MA6-GwzqRjdzndQeGBAF11Td6NFk_9Y,23426
|
61
77
|
aiagents4pharma/talk2knowledgegraphs/datasets/dataset.py,sha256=-LaPLse8BkALqwFetNK7wch2dt9Dz6QKGKZKBKM6bIk,409
|
62
78
|
aiagents4pharma/talk2knowledgegraphs/datasets/primekg.py,sha256=KBMhCJ7yjMWqQJJctFYdpjYAlwv48Jl6i1dddXP4f08,7599
|
63
79
|
aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py,sha256=Y-6-nORsnBJlU6rH0skyfr9S9J4PfTWK-af_p5UuknQ,7483
|
80
|
+
aiagents4pharma/talk2knowledgegraphs/states/__init__.py,sha256=XaqorSvx634dWRRlXUdzlisHtYMyqgJ2q7TanzsKlhw,108
|
81
|
+
aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py,sha256=6HqGo-awqoyNJG0igm5so5A4Tq8RPkCsjPg8Go38csE,1066
|
64
82
|
aiagents4pharma/talk2knowledgegraphs/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
83
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py,sha256=CCN6cyhEaiXSvIC-4y3ueDSzjDCBYDsmSmOor-DMeF4,3928
|
65
84
|
aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py,sha256=crH0eFA3P8P6IYzi1UWNa4YvRVrtlBzoScf9NaE1lDk,9827
|
66
85
|
aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py,sha256=NFUlsZvhfIrkF4YenWfahrLK93Xhm5UYEGG_uYN2LVM,566
|
67
86
|
aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py,sha256=Pvu0r93CpnhjkfMxc-EiVLpAJ04FdW9iTamCnetu654,2272
|
68
87
|
aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py,sha256=TuIsqcN1Mww3DTqGk6ebgJBWzUWdMWEq2yRQuYSFqvA,4416
|
88
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py,sha256=aOKHTber2Cg3mjNjfIa6RZU7XdFj5C2ps1YEUXw76CI,10650
|
89
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py,sha256=zRi2j9Dm3VFywhhrPjVoJ7z_zJpAEM74MJRXapnhwVE,6246
|
90
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py,sha256=oBqfspXXOxH04OQuPb8BCW0liIQTGKXtaPNSrPpQtFc,7597
|
69
91
|
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py,sha256=uYFoE_6zeU10_1mLLAHUr5c4S2XZMSc0Q_860o-KWEw,1517
|
70
|
-
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py,sha256=
|
92
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py,sha256=hzX84pheZdEsTtikF2KtBFiH44_xPjYXxLA6p4Ax1CY,1623
|
93
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py,sha256=jn-TrPwF0aR9kVoerwkbMZa3U6Hc6HjV6Zoau4qSH4g,1834
|
71
94
|
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_sentencetransformer.py,sha256=Qxo6WeIDRy8aLh1tNKw0kSlzmUj3MtTak63oW2YwB24,1327
|
72
95
|
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_enrichments.py,sha256=N6HRr4lWHXY7bTHe2uXJe4D_EG9WqZPibZne6qLl9_k,1447
|
73
|
-
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py,sha256=
|
74
|
-
aiagents4pharma/talk2knowledgegraphs/
|
96
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py,sha256=JhY7axvVULLywDJ2ctA-gob5YPeaJYWsaMNjHT6L9CU,3021
|
97
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py,sha256=pal76wi7WgQWUNk56BrzfFV8jKpbDaHHdbwtgx_gXLI,2410
|
98
|
+
aiagents4pharma/talk2knowledgegraphs/tools/__init__.py,sha256=zpD4h7EYtyq0QNOqLd6bkxrPlPb2XN64ceI9ncgESrA,171
|
99
|
+
aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py,sha256=OEuOFncDRdb7TQEGq4rkT5On-jI-R7Nt8K5EBzaND8w,5338
|
100
|
+
aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py,sha256=zhmsRp-8vjB5rRekqTA07d3yb-42HWqng9dDMkvK6hM,623
|
101
|
+
aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py,sha256=te06QMFQfgJWrjaGrqpcOYeaV38jwm0KY_rXVSMHkeI,11468
|
102
|
+
aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py,sha256=mDSBOxopDfNhEJeU8fVI8b5lXTYrRzcc97aLbFgYSy4,4413
|
103
|
+
aiagents4pharma/talk2knowledgegraphs/utils/__init__.py,sha256=Q9mzcSmkmhdnOn13fxGh1fNECYoUR5Y5CCuEJTIxwAI,167
|
75
104
|
aiagents4pharma/talk2knowledgegraphs/utils/kg_utils.py,sha256=6vQnPkeOWae_8jePjhma3sJuMTngy0I0tqzdFt6OqKg,2507
|
76
|
-
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py,sha256=
|
105
|
+
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py,sha256=4TGK0XIVkkfGOyrSVwFQ-Lp-rzH9CCl-fWcqkFJKRLc,174
|
77
106
|
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/embeddings.py,sha256=1nGznrAj-xT0xuSMBGz2dOujJ7M_IwSR84njxtxsy9A,2523
|
78
107
|
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/huggingface.py,sha256=2vi_elf6EgzfagFAO5QnL3a_aXZyN7B1EBziu44MTfM,3806
|
108
|
+
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py,sha256=8w0sjt3Ex5YJ_XvpKl9UbhdTiiaoMIarbPUxLBU-1Uw,2378
|
79
109
|
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/sentence_transformer.py,sha256=36iKlisOpMtGR5xfTAlSHXWvPqVC_Jbezod8kbBBMVg,2136
|
80
110
|
aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py,sha256=tW426knki2DBIHcWyF_K04iMMdbpIn_e_TpPmTgz2dI,113
|
81
111
|
aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py,sha256=Bx8x6zzk5614ApWB90N_iv4_Y_Uq0-KwUeBwYSdQMU4,924
|
82
112
|
aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py,sha256=8eoxR-VHo0G7ReQIwje7xEhE-SJlHdef7_wJRpnvFIc,4116
|
113
|
+
aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py,sha256=7gwwtfzKhB8GuOBD47XRi0NprwEXkOzwNl5eeu-hDTI,86
|
114
|
+
aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py,sha256=m5p0yoJb7I19ua5yeQfXPf7c4r6S1XPwttsrM7Qoy94,9336
|
83
115
|
aiagents4pharma/talk2scholars/__init__.py,sha256=gphERyVKZHvOnMQsml7TIHlaIshHJ75R1J3FKExkfuY,120
|
84
116
|
aiagents4pharma/talk2scholars/agents/__init__.py,sha256=ykszlVGxz3egLHZAttlNoTPxIrnQJZYva_ssR8fwIFk,117
|
85
117
|
aiagents4pharma/talk2scholars/agents/main_agent.py,sha256=etPQUCjHtD-in-kD7Wg_UD6jRtCHj-mj41y03PYbAQM,4616
|
@@ -112,8 +144,8 @@ aiagents4pharma/talk2scholars/tools/s2/display_results.py,sha256=B8JJGohi1Eyx8C3
|
|
112
144
|
aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py,sha256=0Y3q8TkF_Phng9L7g1kk9Fhyit9UNitWurp03H0GZv8,4455
|
113
145
|
aiagents4pharma/talk2scholars/tools/s2/search.py,sha256=CcgFN7YuuQ9Vl1DJcldnnvPrswABKjNxeauK1rABps8,4176
|
114
146
|
aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py,sha256=irS-igdG8BZbVb0Z4VlIjzsyBlUfREd0v0_RlUM-0_U,4994
|
115
|
-
aiagents4pharma-1.
|
116
|
-
aiagents4pharma-1.
|
117
|
-
aiagents4pharma-1.
|
118
|
-
aiagents4pharma-1.
|
119
|
-
aiagents4pharma-1.
|
147
|
+
aiagents4pharma-1.19.1.dist-info/LICENSE,sha256=IcIbyB1Hyk5ZDah03VNQvJkbNk2hkBCDqQ8qtnCvB4Q,1077
|
148
|
+
aiagents4pharma-1.19.1.dist-info/METADATA,sha256=lVPldiDbObNVA1dfpFA8wnLjh-4s3oX8O5eZSkApcL4,7053
|
149
|
+
aiagents4pharma-1.19.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
150
|
+
aiagents4pharma-1.19.1.dist-info/top_level.txt,sha256=-AH8rMmrSnJtq7HaAObS78UU-cTCwvX660dSxeM7a0A,16
|
151
|
+
aiagents4pharma-1.19.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|