aiagents4pharma 1.18.0__py3-none-any.whl → 1.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. aiagents4pharma/talk2knowledgegraphs/__init__.py +3 -0
  2. aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +4 -0
  3. aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +85 -0
  4. aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +7 -0
  5. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
  6. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
  7. aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +4 -0
  8. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
  9. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +31 -0
  10. aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +7 -0
  11. aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +6 -0
  12. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
  13. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
  14. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
  15. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
  16. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
  17. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
  18. aiagents4pharma/talk2knowledgegraphs/states/__init__.py +4 -0
  19. aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +38 -0
  20. aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +110 -0
  21. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +210 -0
  22. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +174 -0
  23. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +154 -0
  24. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +0 -1
  25. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +56 -0
  26. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +18 -42
  27. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +79 -0
  28. aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +6 -0
  29. aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +143 -0
  30. aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
  31. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +305 -0
  32. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +126 -0
  33. aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +4 -2
  34. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +1 -0
  35. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +81 -0
  36. aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +4 -0
  37. aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +225 -0
  38. {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/METADATA +3 -1
  39. {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/RECORD +42 -10
  40. {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/LICENSE +0 -0
  41. {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/WHEEL +0 -0
  42. {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/top_level.txt +0 -0
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py
@@ -6,20 +6,21 @@ import pytest
 import ollama
 from ..utils.enrichments.ollama import EnrichmentWithOllama
 
+
 @pytest.fixture(name="ollama_config")
 def fixture_ollama_config():
     """Return a dictionary with Ollama configuration."""
     return {
-        "model_name": "smollm2:360m",
+        "model_name": "llama3.2:1b",
         "prompt_enrichment": """
-            Given the input as a list of strings, please return the list of addditional information of
-            each input terms using your prior knowledge.
+            Given the input as a list of strings, please return the list of addditional information
+            of each input terms using your prior knowledge.
 
             Example:
             Input: ['acetaminophen', 'aspirin']
-            Ouput: ['acetaminophen is a medication used to treat pain and fever',
+            Ouput: ['acetaminophen is a medication used to treat pain and fever',
            'aspirin is a medication used to treat pain, fever, and inflammation']
-
+
            Do not include any pretext as the output, only the list of strings enriched.
 
            Input: {input}
@@ -28,10 +29,11 @@ def fixture_ollama_config():
         "streaming": False,
     }
 
+
 def test_no_model_ollama(ollama_config):
     """Test the case when the Ollama model is not available."""
     cfg = ollama_config
-    cfg_model = "smollm2:135m" # Choose a small model
+    cfg_model = "smollm2:135m"  # Choose a small model
 
     # Delete the Ollama model
     try:
@@ -41,7 +43,8 @@ def test_no_model_ollama(ollama_config):
 
     # Check if the model is available
     with pytest.raises(
-        ValueError, match=f"Error: Pulled {cfg_model} model and restarted Ollama server."
+        ValueError,
+        match=f"Error: Pulled {cfg_model} model and restarted Ollama server.",
     ):
         EnrichmentWithOllama(
             model_name=cfg_model,
@@ -51,7 +54,8 @@ def test_no_model_ollama(ollama_config):
         )
     ollama.delete(cfg_model)
 
-def test_enrich_nodes_ollama(ollama_config):
+
+def test_enrich_ollama(ollama_config):
     """Test the Ollama textual enrichment class for node enrichment."""
     # Prepare enrichment model
     cfg = ollama_config
@@ -63,37 +67,11 @@ def test_enrich_nodes_ollama(ollama_config):
     )
 
     # Perform enrichment for nodes
-    nodes = ["Adalimumab", "Infliximab"]
+    nodes = ["acetaminophen"]
     enriched_nodes = enr_model.enrich_documents(nodes)
     # Check the enriched nodes
-    assert len(enriched_nodes) == 2
-    assert all(
-        enriched_nodes[i] != nodes[i] for i in range(len(nodes))
-    )
-
-
-def test_enrich_relations_ollama(ollama_config):
-    """Test the Ollama textual enrichment class for relation enrichment."""
-    # Prepare enrichment model
-    cfg = ollama_config
-    enr_model = EnrichmentWithOllama(
-        model_name=cfg["model_name"],
-        prompt_enrichment=cfg["prompt_enrichment"],
-        temperature=cfg["temperature"],
-        streaming=cfg["streaming"],
-    )
-    # Perform enrichment for relations
-    relations = [
-        "IL23R-gene causation disease-inflammatory bowel diseases",
-        "NOD2-gene causation disease-inflammatory bowel diseases",
-    ]
-    enriched_relations = enr_model.enrich_documents(relations)
-    # Check the enriched relations
-    assert len(enriched_relations) == 2
-    assert all(
-        enriched_relations[i] != relations[i]
-        for i in range(len(relations))
-    )
+    assert len(enriched_nodes) == 1
+    assert all(enriched_nodes[i] != nodes[i] for i in range(len(nodes)))
 
 
 def test_enrich_ollama_rag(ollama_config):
@@ -107,11 +85,9 @@ def test_enrich_ollama_rag(ollama_config):
         streaming=cfg["streaming"],
     )
     # Perform enrichment for nodes
-    nodes = ["Adalimumab", "Infliximab"]
+    nodes = ["acetaminophen"]
     docs = [r"\path\to\doc1", r"\path\to\doc2"]
     enriched_nodes = enr_model.enrich_documents_with_rag(nodes, docs)
     # Check the enriched nodes
-    assert len(enriched_nodes) == 2
-    assert all(
-        enriched_nodes[i] != nodes[i] for i in range(len(nodes))
-    )
+    assert len(enriched_nodes) == 1
+    assert all(enriched_nodes[i] != nodes[i] for i in range(len(nodes)))
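
Note: for orientation, a minimal sketch of driving the enrichment class exercised above outside pytest. The constructor keywords and method name come from the tests; the temperature value and prompt are placeholders, and a running local Ollama server with the model already pulled is assumed:

    from aiagents4pharma.talk2knowledgegraphs.utils.enrichments.ollama import EnrichmentWithOllama

    enr_model = EnrichmentWithOllama(
        model_name="llama3.2:1b",   # as in the updated fixture
        prompt_enrichment="...",    # same template as the fixture above
        temperature=0.0,            # illustrative value
        streaming=False,
    )
    # Returns one free-text enrichment per input term
    enriched = enr_model.enrich_documents(["acetaminophen"])
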
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py
@@ -0,0 +1,79 @@
+"""
+Test cases for utils/kg_utils.py
+"""
+
+import pytest
+import networkx as nx
+import pandas as pd
+from ..utils import kg_utils
+
+
+@pytest.fixture(name="sample_graph")
+def make_sample_graph():
+    """Return a sample graph"""
+    sg = nx.Graph()
+    sg.add_node(1, node_id=1, feature_id="A", feature_value="ValueA")
+    sg.add_node(2, node_id=2, feature_id="B", feature_value="ValueB")
+    sg.add_edge(1, 2, edge_id=1, feature_id="E", feature_value="EdgeValue")
+    return sg
+
+
+def test_kg_to_df_pandas(sample_graph):
+    """Test the kg_to_df_pandas function"""
+    df_nodes, df_edges = kg_utils.kg_to_df_pandas(sample_graph)
+    print(df_nodes)
+    expected_nodes_data = {
+        "node_id": [1, 2],
+        "feature_id": ["A", "B"],
+        "feature_value": ["ValueA", "ValueB"],
+    }
+    expected_nodes_df = pd.DataFrame(expected_nodes_data, index=[1, 2])
+    print(expected_nodes_df)
+    expected_edges_data = {
+        "node_source": [1],
+        "node_target": [2],
+        "edge_id": [1],
+        "feature_id": ["E"],
+        "feature_value": ["EdgeValue"],
+    }
+    expected_edges_df = pd.DataFrame(expected_edges_data)
+
+    # Assert that the dataframes are equal but the order of columns may be different
+    # Ignore the index of the dataframes
+    pd.testing.assert_frame_equal(df_nodes, expected_nodes_df, check_like=True)
+    pd.testing.assert_frame_equal(df_edges, expected_edges_df, check_like=True)
+
+
+def test_df_pandas_to_kg():
+    """Test the df_pandas_to_kg function"""
+    nodes_data = {
+        "node_id": [1, 2],
+        "feature_id": ["A", "B"],
+        "feature_value": ["ValueA", "ValueB"],
+    }
+    df_nodes_attrs = pd.DataFrame(nodes_data).set_index("node_id")
+
+    edges_data = {
+        "node_source": [1],
+        "node_target": [2],
+        "edge_id": [1],
+        "feature_id": ["E"],
+        "feature_value": ["EdgeValue"],
+    }
+    df_edges = pd.DataFrame(edges_data)
+
+    kg = kg_utils.df_pandas_to_kg(
+        df_edges, df_nodes_attrs, "node_source", "node_target"
+    )
+
+    assert len(kg.nodes) == 2
+    assert len(kg.edges) == 1
+
+    assert kg.nodes[1]["feature_id"] == "A"
+    assert kg.nodes[1]["feature_value"] == "ValueA"
+    assert kg.nodes[2]["feature_id"] == "B"
+    assert kg.nodes[2]["feature_value"] == "ValueB"
+
+    assert kg.edges[1, 2]["feature_id"] == "E"
+    assert kg.edges[1, 2]["feature_value"] == "EdgeValue"
+    assert kg.edges[1, 2]["edge_id"] == 1
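
Note: the two helpers tested above are inverses of each other. A minimal round-trip sketch using the same toy frames as the tests (absolute import path per the file list):

    import pandas as pd
    from aiagents4pharma.talk2knowledgegraphs.utils import kg_utils

    nodes = pd.DataFrame(
        {"node_id": [1, 2], "feature_id": ["A", "B"], "feature_value": ["ValueA", "ValueB"]}
    ).set_index("node_id")
    edges = pd.DataFrame(
        {"node_source": [1], "node_target": [2], "edge_id": [1],
         "feature_id": ["E"], "feature_value": ["EdgeValue"]}
    )

    # DataFrames -> NetworkX graph -> DataFrames
    kg = kg_utils.df_pandas_to_kg(edges, nodes, "node_source", "node_target")
    df_nodes, df_edges = kg_utils.kg_to_df_pandas(kg)
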
aiagents4pharma/talk2knowledgegraphs/tools/__init__.py
@@ -0,0 +1,6 @@
+'''
+This file is used to import all the models in the package.
+'''
+from . import subgraph_extraction
+from . import subgraph_summarization
+from . import graphrag_reasoning
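
Note: with these re-exports in place, downstream code can address the three tools as submodules; a one-line sketch (package path per the file list above):

    from aiagents4pharma.talk2knowledgegraphs.tools import subgraph_extraction, subgraph_summarization, graphrag_reasoning
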
aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py
@@ -0,0 +1,143 @@
+"""
+Tool for performing Graph RAG reasoning.
+"""
+
+import logging
+from typing import Type, Annotated
+from pydantic import BaseModel, Field
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from langchain_core.tools import BaseTool
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain.chains.retrieval import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import PyPDFLoader
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+import hydra
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class GraphRAGReasoningInput(BaseModel):
+    """
+    GraphRAGReasoningInput is a Pydantic model representing an input for Graph RAG reasoning.
+
+    Args:
+        state: Injected state.
+        prompt: Prompt to interact with the backend.
+        extraction_name: Name assigned to the subgraph extraction process
+    """
+
+    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+        description="Tool call ID."
+    )
+    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
+    prompt: str = Field(description="Prompt to interact with the backend.")
+    extraction_name: str = Field(
+        description="""Name assigned to the subgraph extraction process
+        when the subgraph_extraction tool is invoked."""
+    )
+
+
+class GraphRAGReasoningTool(BaseTool):
+    """
+    This tool performs reasoning using a Graph Retrieval-Augmented Generation (RAG) approach
+    over user's request by considering textualized subgraph context and document context.
+    """
+
+    name: str = "graphrag_reasoning"
+    description: str = """A tool to perform reasoning using a Graph RAG approach
+    by considering textualized subgraph context and document context."""
+    args_schema: Type[BaseModel] = GraphRAGReasoningInput
+
+    def _run(
+        self,
+        tool_call_id: Annotated[str, InjectedToolCallId],
+        state: Annotated[dict, InjectedState],
+        prompt: str,
+        extraction_name: str,
+    ):
+        """
+        Run the Graph RAG reasoning tool.
+
+        Args:
+            tool_call_id: The tool call ID.
+            state: The injected state.
+            prompt: The prompt to interact with the backend.
+            extraction_name: The name assigned to the subgraph extraction process.
+        """
+        logger.log(
+            logging.INFO, "Invoking graphrag_reasoning tool for %s", extraction_name
+        )
+
+        # Load Hydra configuration
+        with hydra.initialize(version_base=None, config_path="../configs"):
+            cfg = hydra.compose(
+                config_name="config", overrides=["tools/graphrag_reasoning=default"]
+            )
+            cfg = cfg.tools.graphrag_reasoning
+
+        # Prepare documents
+        all_docs = []
+        if len(state["uploaded_files"]) != 0:
+            for uploaded_file in state["uploaded_files"]:
+                if uploaded_file["file_type"] == "drug_data":
+                    # Load documents
+                    raw_documents = PyPDFLoader(
+                        file_path=uploaded_file["file_path"]
+                    ).load()
+
+                    # Split documents
+                    # May need to find an optimal chunk size and overlap configuration
+                    documents = RecursiveCharacterTextSplitter(
+                        chunk_size=cfg.splitter_chunk_size,
+                        chunk_overlap=cfg.splitter_chunk_overlap,
+                    ).split_documents(raw_documents)
+
+                    # Add documents to the list
+                    all_docs.extend(documents)
+
+        # Load the extracted graph
+        extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
+        # logger.log(logging.INFO, "Extracted graph: %s", extracted_graph)
+
+        # Set another prompt template
+        prompt_template = ChatPromptTemplate.from_messages(
+            [("system", cfg.prompt_graphrag_w_docs), ("human", "{input}")]
+        )
+
+        # Prepare chain with retrieved documents
+        qa_chain = create_stuff_documents_chain(state["llm_model"], prompt_template)
+        rag_chain = create_retrieval_chain(
+            InMemoryVectorStore.from_documents(
+                documents=all_docs, embedding=state["embedding_model"]
+            ).as_retriever(
+                search_type=cfg.retriever_search_type,
+                search_kwargs={
+                    "k": cfg.retriever_k,
+                    "fetch_k": cfg.retriever_fetch_k,
+                    "lambda_mult": cfg.retriever_lambda_mult,
+                },
+            ),
+            qa_chain,
+        )
+
+        # Invoke the chain
+        response = rag_chain.invoke(
+            {
+                "input": prompt,
+                "subgraph_summary": extracted_graph[extraction_name]["graph_summary"],
+            }
+        )
+
+        return Command(
+            update={
+                # update the message history
+                "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)]
+            }
+        )
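
Note: a minimal sketch of the state this tool reads and one way to call it directly. The dictionary keys follow _run above; the values, file path, and the direct-invocation style are assumptions (whether a plain .invoke works with injected arguments depends on the langchain-core version; inside a LangGraph run, tool_call_id and state are injected automatically):

    state = {
        "uploaded_files": [
            {"file_type": "drug_data", "file_path": "docs/drug_label.pdf"},  # hypothetical PDF
        ],
        "dic_extracted_graph": [
            {"name": "subkg_1", "graph_summary": "..."},  # filled in by subgraph_summarization
        ],
        "llm_model": llm,               # placeholder: a LangChain chat model instance
        "embedding_model": embeddings,  # placeholder: a LangChain embeddings instance
    }

    tool = GraphRAGReasoningTool()
    command = tool.invoke({
        "tool_call_id": "call_1",
        "state": state,
        "prompt": "Summarize how this drug relates to the extracted subgraph.",
        "extraction_name": "subkg_1",
    })  # returns a Command whose update appends a ToolMessage
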
aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py
@@ -0,0 +1,22 @@
+"""
+A utility module for defining the dataclasses
+for the arguments to set up initial settings
+"""
+
+from dataclasses import dataclass
+from typing import Annotated
+
+
+@dataclass
+class ArgumentData:
+    """
+    Dataclass for storing the argument data.
+    """
+
+    extraction_name: Annotated[
+        str,
+        """An AI assigned _ separated name of the subgraph extraction
+        based on human query and the context of the graph reasoning
+        experiment.
+        This must be set before the subgraph extraction is invoked.""",
+    ]
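
Note: ArgumentData is a plain dataclass, so constructing it is trivial; the annotation above documents the naming convention the agent follows when it picks the value at tool-call time. A sketch with an illustrative, underscore-separated name:

    from aiagents4pharma.talk2knowledgegraphs.tools.load_arguments import ArgumentData

    # "ibd_gene_subgraph" is a hypothetical example of an AI-assigned name.
    arg_data = ArgumentData(extraction_name="ibd_gene_subgraph")
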
aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py
@@ -0,0 +1,305 @@
+"""
+Tool for performing subgraph extraction.
+"""
+
+from typing import Type, Annotated
+import logging
+import pickle
+import numpy as np
+import pandas as pd
+import hydra
+import networkx as nx
+from pydantic import BaseModel, Field
+from langchain.chains.retrieval import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain_core.tools import BaseTool
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+import torch
+from torch_geometric.data import Data
+from ..utils.extractions.pcst import PCSTPruning
+from ..utils.embeddings.ollama import EmbeddingWithOllama
+from .load_arguments import ArgumentData
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class SubgraphExtractionInput(BaseModel):
+    """
+    SubgraphExtractionInput is a Pydantic model representing an input for extracting a subgraph.
+
+    Args:
+        prompt: Prompt to interact with the backend.
+        tool_call_id: Tool call ID.
+        state: Injected state.
+        arg_data: Argument for analytical process over graph data.
+    """
+
+    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+        description="Tool call ID."
+    )
+    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
+    prompt: str = Field(description="Prompt to interact with the backend.")
+    arg_data: ArgumentData = Field(
+        description="Experiment over graph data.", default=None
+    )
+
+
+class SubgraphExtractionTool(BaseTool):
+    """
+    This tool performs subgraph extraction based on user's prompt by taking into account
+    the top-k nodes and edges.
+    """
+
+    name: str = "subgraph_extraction"
+    description: str = "A tool for subgraph extraction based on user's prompt."
+    args_schema: Type[BaseModel] = SubgraphExtractionInput
+
+    def perform_endotype_filtering(
+        self,
+        prompt: str,
+        state: Annotated[dict, InjectedState],
+        cfg: hydra.core.config_store.ConfigStore,
+    ) -> str:
+        """
+        Perform endotype filtering based on the uploaded files and prepare the prompt.
+
+        Args:
+            prompt: The prompt to interact with the backend.
+            state: Injected state for the tool.
+            cfg: Hydra configuration object.
+        """
+        # Loop through the uploaded files
+        all_genes = []
+        for uploaded_file in state["uploaded_files"]:
+            if uploaded_file["file_type"] == "endotype":
+                # Load the PDF file
+                docs = PyPDFLoader(file_path=uploaded_file["file_path"]).load()
+
+                # Split the text into chunks
+                splits = RecursiveCharacterTextSplitter(
+                    chunk_size=cfg.splitter_chunk_size,
+                    chunk_overlap=cfg.splitter_chunk_overlap,
+                ).split_documents(docs)
+
+                # Create a chat prompt template
+                prompt_template = ChatPromptTemplate.from_messages(
+                    [
+                        ("system", cfg.prompt_endotype_filtering),
+                        ("human", "{input}"),
+                    ]
+                )
+
+                qa_chain = create_stuff_documents_chain(
+                    state["llm_model"], prompt_template
+                )
+                rag_chain = create_retrieval_chain(
+                    InMemoryVectorStore.from_documents(
+                        documents=splits, embedding=state["embedding_model"]
+                    ).as_retriever(
+                        search_type=cfg.retriever_search_type,
+                        search_kwargs={
+                            "k": cfg.retriever_k,
+                            "fetch_k": cfg.retriever_fetch_k,
+                            "lambda_mult": cfg.retriever_lambda_mult,
+                        },
+                    ),
+                    qa_chain,
+                )
+                results = rag_chain.invoke({"input": prompt})
+                all_genes.append(results["answer"])
+
+        # Prepare the prompt
+        if len(all_genes) > 0:
+            prompt = " ".join(
+                [prompt, cfg.prompt_endotype_addition, ", ".join(all_genes)]
+            )
+
+        return prompt
+
+    def prepare_final_subgraph(self,
+                               subgraph: dict,
+                               pyg_graph: Data,
+                               textualized_graph: pd.DataFrame) -> dict:
+        """
+        Prepare the subgraph based on the extracted subgraph.
+
+        Args:
+            subgraph: The extracted subgraph.
+            pyg_graph: The PyTorch Geometric graph.
+            textualized_graph: The textualized graph.
+
+        Returns:
+            A dictionary containing the PyG graph, NetworkX graph, and textualized graph.
+        """
+        # print(subgraph)
+        # Prepare the PyTorch Geometric graph
+        mapping = {n: i for i, n in enumerate(subgraph["nodes"].tolist())}
+        pyg_graph = Data(
+            # Node features
+            x=pyg_graph.x[subgraph["nodes"]],
+            node_id=np.array(pyg_graph.node_id)[subgraph["nodes"]].tolist(),
+            node_name=np.array(pyg_graph.node_id)[subgraph["nodes"]].tolist(),
+            enriched_node=np.array(pyg_graph.enriched_node)[subgraph["nodes"]].tolist(),
+            num_nodes=len(subgraph["nodes"]),
+            # Edge features
+            edge_index=torch.LongTensor(
+                [
+                    [
+                        mapping[i]
+                        for i in pyg_graph.edge_index[:, subgraph["edges"]][0].tolist()
+                    ],
+                    [
+                        mapping[i]
+                        for i in pyg_graph.edge_index[:, subgraph["edges"]][1].tolist()
+                    ],
+                ]
+            ),
+            edge_attr=pyg_graph.edge_attr[subgraph["edges"]],
+            edge_type=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
+            relation=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
+            label=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
+            enriched_edge=np.array(pyg_graph.enriched_edge)[subgraph["edges"]].tolist(),
+        )
+
+        # Networkx DiGraph construction to be visualized in the frontend
+        nx_graph = nx.DiGraph()
+        for n in pyg_graph.node_name:
+            nx_graph.add_node(n)
+        for i, e in enumerate(
+            [
+                [pyg_graph.node_name[i], pyg_graph.node_name[j]]
+                for (i, j) in pyg_graph.edge_index.transpose(1, 0)
+            ]
+        ):
+            nx_graph.add_edge(
+                e[0],
+                e[1],
+                relation=pyg_graph.edge_type[i],
+                label=pyg_graph.edge_type[i],
+            )
+
+        # Prepare the textualized subgraph
+        textualized_graph = (
+            textualized_graph["nodes"].iloc[subgraph["nodes"]].to_csv(index=False)
+            + "\n"
+            + textualized_graph["edges"].iloc[subgraph["edges"]].to_csv(index=False)
+        )
+
+        return {
+            "graph_pyg": pyg_graph,
+            "graph_nx": nx_graph,
+            "graph_text": textualized_graph,
+        }
+
+    def _run(
+        self,
+        tool_call_id: Annotated[str, InjectedToolCallId],
+        state: Annotated[dict, InjectedState],
+        prompt: str,
+        arg_data: ArgumentData = None,
+    ) -> Command:
+        """
+        Run the subgraph extraction tool.
+
+        Args:
+            tool_call_id: The tool call ID for the tool.
+            state: Injected state for the tool.
+            prompt: The prompt to interact with the backend.
+            arg_data (ArgumentData): The argument data.
+
+        Returns:
+            Command: The command to be executed.
+        """
+        logger.log(logging.INFO, "Invoking subgraph_extraction tool")
+
+        # Load hydra configuration
+        with hydra.initialize(version_base=None, config_path="../configs"):
+            cfg = hydra.compose(
+                config_name="config", overrides=["tools/subgraph_extraction=default"]
+            )
+            cfg = cfg.tools.subgraph_extraction
+
+        # Retrieve source graph from the state
+        initial_graph = {}
+        initial_graph["source"] = state["dic_source_graph"][-1]  # The last source graph as of now
+        # logger.log(logging.INFO, "Source graph: %s", source_graph)
+
+        # Load the knowledge graph
+        with open(initial_graph["source"]["kg_pyg_path"], "rb") as f:
+            initial_graph["pyg"] = pickle.load(f)
+        with open(initial_graph["source"]["kg_text_path"], "rb") as f:
+            initial_graph["text"] = pickle.load(f)
+
+        # Prepare prompt construction along with a list of endotypes
+        if len(state["uploaded_files"]) != 0 and "endotype" in [
+            f["file_type"] for f in state["uploaded_files"]
+        ]:
+            prompt = self.perform_endotype_filtering(prompt, state, cfg)
+
+        # Prepare embedding model and embed the user prompt as query
+        query_emb = torch.tensor(
+            EmbeddingWithOllama(model_name=cfg.ollama_embeddings[0]).embed_query(prompt)
+        ).float()
+
+        # Prepare the PCSTPruning object and extract the subgraph
+        # Parameters were set in the configuration file obtained from Hydra
+        subgraph = PCSTPruning(
+            state["topk_nodes"],
+            state["topk_edges"],
+            cfg.cost_e,
+            cfg.c_const,
+            cfg.root,
+            cfg.num_clusters,
+            cfg.pruning,
+            cfg.verbosity_level,
+        ).extract_subgraph(initial_graph["pyg"], query_emb)
+
+        # Prepare subgraph as a NetworkX graph and textualized graph
+        final_subgraph = self.prepare_final_subgraph(
+            subgraph, initial_graph["pyg"], initial_graph["text"]
+        )
+
+        # Prepare the dictionary of extracted graph
+        dic_extracted_graph = {
+            "name": arg_data.extraction_name,
+            "tool_call_id": tool_call_id,
+            "graph_source": initial_graph["source"]["name"],
+            "topk_nodes": state["topk_nodes"],
+            "topk_edges": state["topk_edges"],
+            "graph_dict": {
+                "nodes": list(final_subgraph["graph_nx"].nodes(data=True)),
+                "edges": list(final_subgraph["graph_nx"].edges(data=True)),
+            },
+            "graph_text": final_subgraph["graph_text"],
+            "graph_summary": None,
+        }
+
+        # Prepare the dictionary of updated state
+        dic_updated_state_for_model = {}
+        for key, value in {
+            "dic_extracted_graph": [dic_extracted_graph],
+        }.items():
+            if value:
+                dic_updated_state_for_model[key] = value
+
+        # Return the updated state of the tool
+        return Command(
+            update=dic_updated_state_for_model | {
+                # update the message history
+                "messages": [
+                    ToolMessage(
+                        content=f"Subgraph Extraction Result of {arg_data.extraction_name}",
+                        tool_call_id=tool_call_id,
+                    )
+                ],
+            }
+        )
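
Note: reading _run above, the tool expects several keys in the injected state before it is invoked. A minimal sketch of that shape; the source-graph name, file paths, and all values are hypothetical placeholders:

    state = {
        "dic_source_graph": [
            {
                "name": "PrimeKG",                   # illustrative source name
                "kg_pyg_path": "data/kg_pyg.pkl",    # pickled torch_geometric.data.Data
                "kg_text_path": "data/kg_text.pkl",  # pickled textualized-graph dataframes
            }
        ],
        "uploaded_files": [],  # optional "endotype" PDFs trigger perform_endotype_filtering
        "topk_nodes": 5,       # forwarded to PCSTPruning
        "topk_edges": 5,
        "llm_model": llm,               # placeholder: used only when endotype filtering runs
        "embedding_model": embeddings,  # placeholder: a LangChain embeddings model
    }

PCSTPruning(...).extract_subgraph(...) then returns a dictionary with "nodes" and "edges" index arrays, which prepare_final_subgraph converts into the PyG, NetworkX, and textualized forms stored in dic_extracted_graph.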