aiagents4pharma 1.18.0__py3-none-any.whl → 1.19.0__py3-none-any.whl

Files changed (42)
  1. aiagents4pharma/talk2knowledgegraphs/__init__.py +3 -0
  2. aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +4 -0
  3. aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +85 -0
  4. aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +7 -0
  5. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
  6. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
  7. aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +4 -0
  8. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
  9. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +31 -0
  10. aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +7 -0
  11. aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +6 -0
  12. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
  13. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
  14. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
  15. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
  16. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
  17. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
  18. aiagents4pharma/talk2knowledgegraphs/states/__init__.py +4 -0
  19. aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +38 -0
  20. aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +110 -0
  21. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +210 -0
  22. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +174 -0
  23. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +154 -0
  24. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +0 -1
  25. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +56 -0
  26. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +18 -42
  27. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +79 -0
  28. aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +6 -0
  29. aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +143 -0
  30. aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
  31. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +305 -0
  32. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +126 -0
  33. aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +4 -2
  34. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +1 -0
  35. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +81 -0
  36. aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +4 -0
  37. aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +225 -0
  38. {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/METADATA +3 -1
  39. {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/RECORD +42 -10
  40. {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/LICENSE +0 -0
  41. {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/WHEEL +0 -0
  42. {aiagents4pharma-1.18.0.dist-info → aiagents4pharma-1.19.0.dist-info}/top_level.txt +0 -0
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py
@@ -6,20 +6,21 @@ import pytest
 import ollama
 from ..utils.enrichments.ollama import EnrichmentWithOllama
 
+
 @pytest.fixture(name="ollama_config")
 def fixture_ollama_config():
     """Return a dictionary with Ollama configuration."""
     return {
-        "model_name": "smollm2:360m",
+        "model_name": "llama3.2:1b",
         "prompt_enrichment": """
-            Given the input as a list of strings, please return the list of addditional information of
-            each input terms using your prior knowledge.
+            Given the input as a list of strings, please return the list of addditional information
+            of each input terms using your prior knowledge.
 
             Example:
             Input: ['acetaminophen', 'aspirin']
-            Ouput: ['acetaminophen is a medication used to treat pain and fever',
+            Ouput: ['acetaminophen is a medication used to treat pain and fever',
             'aspirin is a medication used to treat pain, fever, and inflammation']
-
+
             Do not include any pretext as the output, only the list of strings enriched.
 
             Input: {input}
@@ -28,10 +29,11 @@ def fixture_ollama_config():
         "streaming": False,
     }
 
+
 def test_no_model_ollama(ollama_config):
     """Test the case when the Ollama model is not available."""
     cfg = ollama_config
-    cfg_model = "smollm2:135m" # Choose a small model
+    cfg_model = "smollm2:135m"  # Choose a small model
 
     # Delete the Ollama model
     try:
@@ -41,7 +43,8 @@ def test_no_model_ollama(ollama_config):
 
     # Check if the model is available
     with pytest.raises(
-        ValueError, match=f"Error: Pulled {cfg_model} model and restarted Ollama server."
+        ValueError,
+        match=f"Error: Pulled {cfg_model} model and restarted Ollama server.",
     ):
         EnrichmentWithOllama(
             model_name=cfg_model,
@@ -51,7 +54,8 @@ def test_no_model_ollama(ollama_config):
         )
     ollama.delete(cfg_model)
 
-def test_enrich_nodes_ollama(ollama_config):
+
+def test_enrich_ollama(ollama_config):
     """Test the Ollama textual enrichment class for node enrichment."""
     # Prepare enrichment model
     cfg = ollama_config
@@ -63,37 +67,11 @@ def test_enrich_nodes_ollama(ollama_config):
     )
 
     # Perform enrichment for nodes
-    nodes = ["Adalimumab", "Infliximab"]
+    nodes = ["acetaminophen"]
     enriched_nodes = enr_model.enrich_documents(nodes)
     # Check the enriched nodes
-    assert len(enriched_nodes) == 2
-    assert all(
-        enriched_nodes[i] != nodes[i] for i in range(len(nodes))
-    )
-
-
-def test_enrich_relations_ollama(ollama_config):
-    """Test the Ollama textual enrichment class for relation enrichment."""
-    # Prepare enrichment model
-    cfg = ollama_config
-    enr_model = EnrichmentWithOllama(
-        model_name=cfg["model_name"],
-        prompt_enrichment=cfg["prompt_enrichment"],
-        temperature=cfg["temperature"],
-        streaming=cfg["streaming"],
-    )
-    # Perform enrichment for relations
-    relations = [
-        "IL23R-gene causation disease-inflammatory bowel diseases",
-        "NOD2-gene causation disease-inflammatory bowel diseases",
-    ]
-    enriched_relations = enr_model.enrich_documents(relations)
-    # Check the enriched relations
-    assert len(enriched_relations) == 2
-    assert all(
-        enriched_relations[i] != relations[i]
-        for i in range(len(relations))
-    )
+    assert len(enriched_nodes) == 1
+    assert all(enriched_nodes[i] != nodes[i] for i in range(len(nodes)))
 
 
 def test_enrich_ollama_rag(ollama_config):
@@ -107,11 +85,9 @@ def test_enrich_ollama_rag(ollama_config):
         streaming=cfg["streaming"],
    )
     # Perform enrichment for nodes
-    nodes = ["Adalimumab", "Infliximab"]
+    nodes = ["acetaminophen"]
     docs = [r"\path\to\doc1", r"\path\to\doc2"]
     enriched_nodes = enr_model.enrich_documents_with_rag(nodes, docs)
     # Check the enriched nodes
-    assert len(enriched_nodes) == 2
-    assert all(
-        enriched_nodes[i] != nodes[i] for i in range(len(nodes))
-    )
+    assert len(enriched_nodes) == 1
+    assert all(enriched_nodes[i] != nodes[i] for i in range(len(nodes)))
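
For reference, the usage pattern these tests exercise boils down to the following minimal sketch; it assumes a local Ollama server with the llama3.2:1b model already pulled, and abbreviates the few-shot prompt shown in the fixture:

# Minimal sketch of the enrichment API exercised by the tests above.
# Assumes: local Ollama server, llama3.2:1b pulled, prompt abbreviated.
from aiagents4pharma.talk2knowledgegraphs.utils.enrichments.ollama import (
    EnrichmentWithOllama,
)

enr_model = EnrichmentWithOllama(
    model_name="llama3.2:1b",
    prompt_enrichment="Given the input as a list of strings, ... Input: {input}",
    temperature=0.0,  # the fixture also sets a "temperature" key; the value here is assumed
    streaming=False,
)
enriched = enr_model.enrich_documents(["acetaminophen"])
# enrich_documents returns one free-text enrichment per input term,
# so each output differs from its input string.
assert len(enriched) == 1 and enriched[0] != "acetaminophen"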

aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py
@@ -0,0 +1,79 @@
+"""
+Test cases for utils/kg_utils.py
+"""
+
+import pytest
+import networkx as nx
+import pandas as pd
+from ..utils import kg_utils
+
+
+@pytest.fixture(name="sample_graph")
+def make_sample_graph():
+    """Return a sample graph"""
+    sg = nx.Graph()
+    sg.add_node(1, node_id=1, feature_id="A", feature_value="ValueA")
+    sg.add_node(2, node_id=2, feature_id="B", feature_value="ValueB")
+    sg.add_edge(1, 2, edge_id=1, feature_id="E", feature_value="EdgeValue")
+    return sg
+
+
+def test_kg_to_df_pandas(sample_graph):
+    """Test the kg_to_df_pandas function"""
+    df_nodes, df_edges = kg_utils.kg_to_df_pandas(sample_graph)
+    print(df_nodes)
+    expected_nodes_data = {
+        "node_id": [1, 2],
+        "feature_id": ["A", "B"],
+        "feature_value": ["ValueA", "ValueB"],
+    }
+    expected_nodes_df = pd.DataFrame(expected_nodes_data, index=[1, 2])
+    print(expected_nodes_df)
+    expected_edges_data = {
+        "node_source": [1],
+        "node_target": [2],
+        "edge_id": [1],
+        "feature_id": ["E"],
+        "feature_value": ["EdgeValue"],
+    }
+    expected_edges_df = pd.DataFrame(expected_edges_data)
+
+    # Assert that the dataframes are equal but the order of columns may be different
+    # Ignore the index of the dataframes
+    pd.testing.assert_frame_equal(df_nodes, expected_nodes_df, check_like=True)
+    pd.testing.assert_frame_equal(df_edges, expected_edges_df, check_like=True)
+
+
+def test_df_pandas_to_kg():
+    """Test the df_pandas_to_kg function"""
+    nodes_data = {
+        "node_id": [1, 2],
+        "feature_id": ["A", "B"],
+        "feature_value": ["ValueA", "ValueB"],
+    }
+    df_nodes_attrs = pd.DataFrame(nodes_data).set_index("node_id")
+
+    edges_data = {
+        "node_source": [1],
+        "node_target": [2],
+        "edge_id": [1],
+        "feature_id": ["E"],
+        "feature_value": ["EdgeValue"],
+    }
+    df_edges = pd.DataFrame(edges_data)
+
+    kg = kg_utils.df_pandas_to_kg(
+        df_edges, df_nodes_attrs, "node_source", "node_target"
+    )
+
+    assert len(kg.nodes) == 2
+    assert len(kg.edges) == 1
+
+    assert kg.nodes[1]["feature_id"] == "A"
+    assert kg.nodes[1]["feature_value"] == "ValueA"
+    assert kg.nodes[2]["feature_id"] == "B"
+    assert kg.nodes[2]["feature_value"] == "ValueB"
+
+    assert kg.edges[1, 2]["feature_id"] == "E"
+    assert kg.edges[1, 2]["feature_value"] == "EdgeValue"
+    assert kg.edges[1, 2]["edge_id"] == 1
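
The two helpers tested above convert between node/edge DataFrames and a NetworkX graph; a round-trip sketch, with names and argument order taken from the tests:

import pandas as pd
from aiagents4pharma.talk2knowledgegraphs.utils import kg_utils

# Node attributes, indexed by node_id (as in test_df_pandas_to_kg above)
df_nodes_attrs = pd.DataFrame(
    {"node_id": [1, 2], "feature_id": ["A", "B"], "feature_value": ["ValueA", "ValueB"]}
).set_index("node_id")
# One edge with its attributes
df_edges = pd.DataFrame(
    {"node_source": [1], "node_target": [2], "edge_id": [1],
     "feature_id": ["E"], "feature_value": ["EdgeValue"]}
)

# DataFrames -> NetworkX graph, then back to DataFrames
kg = kg_utils.df_pandas_to_kg(df_edges, df_nodes_attrs, "node_source", "node_target")
df_nodes_rt, df_edges_rt = kg_utils.kg_to_df_pandas(kg)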

aiagents4pharma/talk2knowledgegraphs/tools/__init__.py
@@ -0,0 +1,6 @@
+'''
+This file is used to import all the models in the package.
+'''
+from . import subgraph_extraction
+from . import subgraph_summarization
+from . import graphrag_reasoning

aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py
@@ -0,0 +1,143 @@
+"""
+Tool for performing Graph RAG reasoning.
+"""
+
+import logging
+from typing import Type, Annotated
+from pydantic import BaseModel, Field
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from langchain_core.tools import BaseTool
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain.chains.retrieval import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import PyPDFLoader
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+import hydra
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class GraphRAGReasoningInput(BaseModel):
+    """
+    GraphRAGReasoningInput is a Pydantic model representing an input for Graph RAG reasoning.
+
+    Args:
+        state: Injected state.
+        prompt: Prompt to interact with the backend.
+        extraction_name: Name assigned to the subgraph extraction process
+    """
+
+    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+        description="Tool call ID."
+    )
+    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
+    prompt: str = Field(description="Prompt to interact with the backend.")
+    extraction_name: str = Field(
+        description="""Name assigned to the subgraph extraction process
+        when the subgraph_extraction tool is invoked."""
+    )
+
+
+class GraphRAGReasoningTool(BaseTool):
+    """
+    This tool performs reasoning using a Graph Retrieval-Augmented Generation (RAG) approach
+    over user's request by considering textualized subgraph context and document context.
+    """
+
+    name: str = "graphrag_reasoning"
+    description: str = """A tool to perform reasoning using a Graph RAG approach
+    by considering textualized subgraph context and document context."""
+    args_schema: Type[BaseModel] = GraphRAGReasoningInput
+
+    def _run(
+        self,
+        tool_call_id: Annotated[str, InjectedToolCallId],
+        state: Annotated[dict, InjectedState],
+        prompt: str,
+        extraction_name: str,
+    ):
+        """
+        Run the Graph RAG reasoning tool.
+
+        Args:
+            tool_call_id: The tool call ID.
+            state: The injected state.
+            prompt: The prompt to interact with the backend.
+            extraction_name: The name assigned to the subgraph extraction process.
+        """
+        logger.log(
+            logging.INFO, "Invoking graphrag_reasoning tool for %s", extraction_name
+        )
+
+        # Load Hydra configuration
+        with hydra.initialize(version_base=None, config_path="../configs"):
+            cfg = hydra.compose(
+                config_name="config", overrides=["tools/graphrag_reasoning=default"]
+            )
+            cfg = cfg.tools.graphrag_reasoning
+
+        # Prepare documents
+        all_docs = []
+        if len(state["uploaded_files"]) != 0:
+            for uploaded_file in state["uploaded_files"]:
+                if uploaded_file["file_type"] == "drug_data":
+                    # Load documents
+                    raw_documents = PyPDFLoader(
+                        file_path=uploaded_file["file_path"]
+                    ).load()
+
+                    # Split documents
+                    # May need to find an optimal chunk size and overlap configuration
+                    documents = RecursiveCharacterTextSplitter(
+                        chunk_size=cfg.splitter_chunk_size,
+                        chunk_overlap=cfg.splitter_chunk_overlap,
+                    ).split_documents(raw_documents)
+
+                    # Add documents to the list
+                    all_docs.extend(documents)
+
+        # Load the extracted graph
+        extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
+        # logger.log(logging.INFO, "Extracted graph: %s", extracted_graph)
+
+        # Set another prompt template
+        prompt_template = ChatPromptTemplate.from_messages(
+            [("system", cfg.prompt_graphrag_w_docs), ("human", "{input}")]
+        )
+
+        # Prepare chain with retrieved documents
+        qa_chain = create_stuff_documents_chain(state["llm_model"], prompt_template)
+        rag_chain = create_retrieval_chain(
+            InMemoryVectorStore.from_documents(
+                documents=all_docs, embedding=state["embedding_model"]
+            ).as_retriever(
+                search_type=cfg.retriever_search_type,
+                search_kwargs={
+                    "k": cfg.retriever_k,
+                    "fetch_k": cfg.retriever_fetch_k,
+                    "lambda_mult": cfg.retriever_lambda_mult,
+                },
+            ),
+            qa_chain,
+        )
+
+        # Invoke the chain
+        response = rag_chain.invoke(
+            {
+                "input": prompt,
+                "subgraph_summary": extracted_graph[extraction_name]["graph_summary"],
+            }
+        )
+
+        return Command(
+            update={
+                # update the message history
+                "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)]
+            }
+        )
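
For orientation, the injected state this tool reads can be sketched as follows; the key names are taken from _run above, while the model objects and values are placeholders:

# Shape of the state consumed by GraphRAGReasoningTool._run (a sketch).
state = {
    "llm_model": ...,        # a LangChain chat model instance
    "embedding_model": ...,  # a LangChain embeddings instance
    "uploaded_files": [
        {"file_type": "drug_data", "file_path": "/path/to/drug_report.pdf"},
    ],
    "dic_extracted_graph": [
        # written earlier by the subgraph_extraction tool
        {"name": "extraction_1", "graph_summary": "textual subgraph summary ..."},
    ],
}
# The tool retrieves chunks from the PDFs, injects the summary of the named
# extraction into the prompt, and returns the LLM answer as a ToolMessage.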

aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py
@@ -0,0 +1,22 @@
+"""
+A utility module for defining the dataclasses
+for the arguments to set up initial settings
+"""
+
+from dataclasses import dataclass
+from typing import Annotated
+
+
+@dataclass
+class ArgumentData:
+    """
+    Dataclass for storing the argument data.
+    """
+
+    extraction_name: Annotated[
+        str,
+        """An AI assigned _ separated name of the subgraph extraction
+        based on human query and the context of the graph reasoning
+        experiment.
+        This must be set before the subgraph extraction is invoked.""",
+    ]
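
In a tool call the agent populates this dataclass with an underscore-separated name of its own choosing, e.g. (the name here is hypothetical):

arg_data = ArgumentData(extraction_name="adalimumab_ibd_subgraph")  # hypothetical name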

aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py
@@ -0,0 +1,305 @@
+"""
+Tool for performing subgraph extraction.
+"""
+
+from typing import Type, Annotated
+import logging
+import pickle
+import numpy as np
+import pandas as pd
+import hydra
+import networkx as nx
+from pydantic import BaseModel, Field
+from langchain.chains.retrieval import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain_core.tools import BaseTool
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+import torch
+from torch_geometric.data import Data
+from ..utils.extractions.pcst import PCSTPruning
+from ..utils.embeddings.ollama import EmbeddingWithOllama
+from .load_arguments import ArgumentData
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class SubgraphExtractionInput(BaseModel):
+    """
+    SubgraphExtractionInput is a Pydantic model representing an input for extracting a subgraph.
+
+    Args:
+        prompt: Prompt to interact with the backend.
+        tool_call_id: Tool call ID.
+        state: Injected state.
+        arg_data: Argument for analytical process over graph data.
+    """
+
+    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+        description="Tool call ID."
+    )
+    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
+    prompt: str = Field(description="Prompt to interact with the backend.")
+    arg_data: ArgumentData = Field(
+        description="Experiment over graph data.", default=None
+    )
+
+
+class SubgraphExtractionTool(BaseTool):
+    """
+    This tool performs subgraph extraction based on user's prompt by taking into account
+    the top-k nodes and edges.
+    """
+
+    name: str = "subgraph_extraction"
+    description: str = "A tool for subgraph extraction based on user's prompt."
+    args_schema: Type[BaseModel] = SubgraphExtractionInput
+
+    def perform_endotype_filtering(
+        self,
+        prompt: str,
+        state: Annotated[dict, InjectedState],
+        cfg: hydra.core.config_store.ConfigStore,
+    ) -> str:
+        """
+        Perform endotype filtering based on the uploaded files and prepare the prompt.
+
+        Args:
+            prompt: The prompt to interact with the backend.
+            state: Injected state for the tool.
+            cfg: Hydra configuration object.
+        """
+        # Loop through the uploaded files
+        all_genes = []
+        for uploaded_file in state["uploaded_files"]:
+            if uploaded_file["file_type"] == "endotype":
+                # Load the PDF file
+                docs = PyPDFLoader(file_path=uploaded_file["file_path"]).load()
+
+                # Split the text into chunks
+                splits = RecursiveCharacterTextSplitter(
+                    chunk_size=cfg.splitter_chunk_size,
+                    chunk_overlap=cfg.splitter_chunk_overlap,
+                ).split_documents(docs)
+
+                # Create a chat prompt template
+                prompt_template = ChatPromptTemplate.from_messages(
+                    [
+                        ("system", cfg.prompt_endotype_filtering),
+                        ("human", "{input}"),
+                    ]
+                )
+
+                qa_chain = create_stuff_documents_chain(
+                    state["llm_model"], prompt_template
+                )
+                rag_chain = create_retrieval_chain(
+                    InMemoryVectorStore.from_documents(
+                        documents=splits, embedding=state["embedding_model"]
+                    ).as_retriever(
+                        search_type=cfg.retriever_search_type,
+                        search_kwargs={
+                            "k": cfg.retriever_k,
+                            "fetch_k": cfg.retriever_fetch_k,
+                            "lambda_mult": cfg.retriever_lambda_mult,
+                        },
+                    ),
+                    qa_chain,
+                )
+                results = rag_chain.invoke({"input": prompt})
+                all_genes.append(results["answer"])
+
+        # Prepare the prompt
+        if len(all_genes) > 0:
+            prompt = " ".join(
+                [prompt, cfg.prompt_endotype_addition, ", ".join(all_genes)]
+            )
+
+        return prompt
+
+    def prepare_final_subgraph(self,
+                               subgraph: dict,
+                               pyg_graph: Data,
+                               textualized_graph: pd.DataFrame) -> dict:
+        """
+        Prepare the subgraph based on the extracted subgraph.
+
+        Args:
+            subgraph: The extracted subgraph.
+            pyg_graph: The PyTorch Geometric graph.
+            textualized_graph: The textualized graph.
+
+        Returns:
+            A dictionary containing the PyG graph, NetworkX graph, and textualized graph.
+        """
+        # print(subgraph)
+        # Prepare the PyTorch Geometric graph
+        mapping = {n: i for i, n in enumerate(subgraph["nodes"].tolist())}
+        pyg_graph = Data(
+            # Node features
+            x=pyg_graph.x[subgraph["nodes"]],
+            node_id=np.array(pyg_graph.node_id)[subgraph["nodes"]].tolist(),
+            node_name=np.array(pyg_graph.node_id)[subgraph["nodes"]].tolist(),
+            enriched_node=np.array(pyg_graph.enriched_node)[subgraph["nodes"]].tolist(),
+            num_nodes=len(subgraph["nodes"]),
+            # Edge features
+            edge_index=torch.LongTensor(
+                [
+                    [
+                        mapping[i]
+                        for i in pyg_graph.edge_index[:, subgraph["edges"]][0].tolist()
+                    ],
+                    [
+                        mapping[i]
+                        for i in pyg_graph.edge_index[:, subgraph["edges"]][1].tolist()
+                    ],
+                ]
+            ),
+            edge_attr=pyg_graph.edge_attr[subgraph["edges"]],
+            edge_type=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
+            relation=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
+            label=np.array(pyg_graph.edge_type)[subgraph["edges"]].tolist(),
+            enriched_edge=np.array(pyg_graph.enriched_edge)[subgraph["edges"]].tolist(),
+        )
+
+        # Networkx DiGraph construction to be visualized in the frontend
+        nx_graph = nx.DiGraph()
+        for n in pyg_graph.node_name:
+            nx_graph.add_node(n)
+        for i, e in enumerate(
+            [
+                [pyg_graph.node_name[i], pyg_graph.node_name[j]]
+                for (i, j) in pyg_graph.edge_index.transpose(1, 0)
+            ]
+        ):
+            nx_graph.add_edge(
+                e[0],
+                e[1],
+                relation=pyg_graph.edge_type[i],
+                label=pyg_graph.edge_type[i],
+            )
+
+        # Prepare the textualized subgraph
+        textualized_graph = (
+            textualized_graph["nodes"].iloc[subgraph["nodes"]].to_csv(index=False)
+            + "\n"
+            + textualized_graph["edges"].iloc[subgraph["edges"]].to_csv(index=False)
+        )
+
+        return {
+            "graph_pyg": pyg_graph,
+            "graph_nx": nx_graph,
+            "graph_text": textualized_graph,
+        }
+
+    def _run(
+        self,
+        tool_call_id: Annotated[str, InjectedToolCallId],
+        state: Annotated[dict, InjectedState],
+        prompt: str,
+        arg_data: ArgumentData = None,
+    ) -> Command:
+        """
+        Run the subgraph extraction tool.
+
+        Args:
+            tool_call_id: The tool call ID for the tool.
+            state: Injected state for the tool.
+            prompt: The prompt to interact with the backend.
+            arg_data (ArgumentData): The argument data.
+
+        Returns:
+            Command: The command to be executed.
+        """
+        logger.log(logging.INFO, "Invoking subgraph_extraction tool")
+
+        # Load hydra configuration
+        with hydra.initialize(version_base=None, config_path="../configs"):
+            cfg = hydra.compose(
+                config_name="config", overrides=["tools/subgraph_extraction=default"]
+            )
+            cfg = cfg.tools.subgraph_extraction
+
+        # Retrieve source graph from the state
+        initial_graph = {}
+        initial_graph["source"] = state["dic_source_graph"][-1]  # The last source graph as of now
+        # logger.log(logging.INFO, "Source graph: %s", source_graph)
+
+        # Load the knowledge graph
+        with open(initial_graph["source"]["kg_pyg_path"], "rb") as f:
+            initial_graph["pyg"] = pickle.load(f)
+        with open(initial_graph["source"]["kg_text_path"], "rb") as f:
+            initial_graph["text"] = pickle.load(f)
+
+        # Prepare prompt construction along with a list of endotypes
+        if len(state["uploaded_files"]) != 0 and "endotype" in [
+            f["file_type"] for f in state["uploaded_files"]
+        ]:
+            prompt = self.perform_endotype_filtering(prompt, state, cfg)
+
+        # Prepare embedding model and embed the user prompt as query
+        query_emb = torch.tensor(
+            EmbeddingWithOllama(model_name=cfg.ollama_embeddings[0]).embed_query(prompt)
+        ).float()
+
+        # Prepare the PCSTPruning object and extract the subgraph
+        # Parameters were set in the configuration file obtained from Hydra
+        subgraph = PCSTPruning(
+            state["topk_nodes"],
+            state["topk_edges"],
+            cfg.cost_e,
+            cfg.c_const,
+            cfg.root,
+            cfg.num_clusters,
+            cfg.pruning,
+            cfg.verbosity_level,
+        ).extract_subgraph(initial_graph["pyg"], query_emb)
+
+        # Prepare subgraph as a NetworkX graph and textualized graph
+        final_subgraph = self.prepare_final_subgraph(
+            subgraph, initial_graph["pyg"], initial_graph["text"]
+        )
+
+        # Prepare the dictionary of extracted graph
+        dic_extracted_graph = {
+            "name": arg_data.extraction_name,
+            "tool_call_id": tool_call_id,
+            "graph_source": initial_graph["source"]["name"],
+            "topk_nodes": state["topk_nodes"],
+            "topk_edges": state["topk_edges"],
+            "graph_dict": {
+                "nodes": list(final_subgraph["graph_nx"].nodes(data=True)),
+                "edges": list(final_subgraph["graph_nx"].edges(data=True)),
+            },
+            "graph_text": final_subgraph["graph_text"],
+            "graph_summary": None,
+        }
+
+        # Prepare the dictionary of updated state
+        dic_updated_state_for_model = {}
+        for key, value in {
+            "dic_extracted_graph": [dic_extracted_graph],
+        }.items():
+            if value:
+                dic_updated_state_for_model[key] = value
+
+        # Return the updated state of the tool
+        return Command(
+            update=dic_updated_state_for_model | {
+                # update the message history
+                "messages": [
+                    ToolMessage(
+                        content=f"Subgraph Extraction Result of {arg_data.extraction_name}",
+                        tool_call_id=tool_call_id,
+                    )
+                ],
+            }
+        )
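
The heart of _run above is the embed-then-prune step; isolated, it looks like the following sketch. The numeric parameters come from the agent state and the Hydra config in the real tool, so the values shown here are illustrative, as are the embedding model name and file path:

import pickle
import torch
from aiagents4pharma.talk2knowledgegraphs.utils.extractions.pcst import PCSTPruning
from aiagents4pharma.talk2knowledgegraphs.utils.embeddings.ollama import EmbeddingWithOllama

# The pickled PyG source graph, as loaded in _run above (path illustrative)
with open("path/to/kg_pyg.pkl", "rb") as f:
    pyg_graph = pickle.load(f)

# Embed the (possibly endotype-augmented) prompt as the query vector
query_emb = torch.tensor(
    EmbeddingWithOllama(model_name="nomic-embed-text").embed_query(
        "Extract a subgraph around acetaminophen"
    )
).float()

# Prize-Collecting Steiner Tree pruning against the source graph
subgraph = PCSTPruning(
    50,     # topk_nodes (from state)
    50,     # topk_edges (from state)
    0.5,    # cost_e (from cfg; illustrative)
    0.01,   # c_const (from cfg; illustrative)
    -1,     # root (from cfg; illustrative)
    1,      # num_clusters (from cfg; illustrative)
    "gw",   # pruning (from cfg; illustrative)
    0,      # verbosity_level (from cfg; illustrative)
).extract_subgraph(pyg_graph, query_emb)
# subgraph["nodes"] and subgraph["edges"] index into the source graph and feed
# prepare_final_subgraph above.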