aiagents4pharma 1.36.0__py3-none-any.whl → 1.38.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py +12 -4
  2. aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +2 -2
  3. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +7 -6
  4. aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +1 -0
  5. aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/__init__.py +0 -0
  6. aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +1 -0
  7. aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +12 -11
  8. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_multimodal_subgraph_extraction.py +152 -0
  9. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +36 -65
  10. aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +1 -0
  11. aiagents4pharma/talk2knowledgegraphs/tools/multimodal_subgraph_extraction.py +374 -0
  12. aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +1 -0
  13. aiagents4pharma/talk2knowledgegraphs/utils/extractions/multimodal_pcst.py +292 -0
  14. aiagents4pharma/talk2scholars/configs/tools/zotero_read/default.yaml +1 -0
  15. aiagents4pharma/talk2scholars/state/state_talk2scholars.py +33 -7
  16. aiagents4pharma/talk2scholars/tests/test_question_and_answer_tool.py +59 -3
  17. aiagents4pharma/talk2scholars/tests/test_read_helper_utils.py +110 -0
  18. aiagents4pharma/talk2scholars/tests/test_s2_display.py +20 -1
  19. aiagents4pharma/talk2scholars/tests/test_s2_query.py +17 -0
  20. aiagents4pharma/talk2scholars/tests/test_state.py +25 -1
  21. aiagents4pharma/talk2scholars/tests/test_zotero_pdf_downloader_utils.py +46 -0
  22. aiagents4pharma/talk2scholars/tests/test_zotero_read.py +35 -40
  23. aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +62 -40
  24. aiagents4pharma/talk2scholars/tools/s2/display_dataframe.py +6 -2
  25. aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +2 -1
  26. aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +7 -3
  27. aiagents4pharma/talk2scholars/tools/s2/search.py +2 -1
  28. aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +2 -1
  29. aiagents4pharma/talk2scholars/tools/zotero/utils/read_helper.py +79 -136
  30. aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_pdf_downloader.py +147 -0
  31. aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +42 -9
  32. {aiagents4pharma-1.36.0.dist-info → aiagents4pharma-1.38.0.dist-info}/METADATA +2 -1
  33. {aiagents4pharma-1.36.0.dist-info → aiagents4pharma-1.38.0.dist-info}/RECORD +36 -29
  34. {aiagents4pharma-1.36.0.dist-info → aiagents4pharma-1.38.0.dist-info}/WHEEL +1 -1
  35. {aiagents4pharma-1.36.0.dist-info → aiagents4pharma-1.38.0.dist-info}/licenses/LICENSE +0 -0
  36. {aiagents4pharma-1.36.0.dist-info → aiagents4pharma-1.38.0.dist-info}/top_level.txt +0 -0
@@ -20,12 +20,20 @@ def input_dict_fixture():
20
20
  input_dict = {
21
21
  "topk_nodes": 3,
22
22
  "topk_edges": 3,
23
+ "selections": {
24
+ "gene/protein": [],
25
+ "molecular_function": [],
26
+ "cellular_component": [],
27
+ "biological_process": [],
28
+ "drug": [],
29
+ "disease": []
30
+ },
23
31
  "uploaded_files": [],
24
32
  "dic_source_graph": [
25
33
  {
26
- "name": "PrimeKG",
27
- "kg_pyg_path": f"{DATA_PATH}/primekg_ibd_pyg_graph.pkl",
28
- "kg_text_path": f"{DATA_PATH}/primekg_ibd_text_graph.pkl",
34
+ "name": "BioBridge",
35
+ "kg_pyg_path": f"{DATA_PATH}/biobridge_multimodal_pyg_graph.pkl",
36
+ "kg_text_path": f"{DATA_PATH}/biobridge_multimodal_text_graph.pkl",
29
37
  }
30
38
  ],
31
39
  "dic_extracted_graph": []
@@ -70,7 +78,7 @@ def test_main_agent_invokes_t2kg(input_dict):
70
78
  current_state = app.get_state(config)
71
79
  dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
72
80
  assert isinstance(dic_extracted_graph, dict)
73
- assert dic_extracted_graph["graph_source"] == "PrimeKG"
81
+ assert dic_extracted_graph["graph_source"] == "BioBridge"
74
82
  assert dic_extracted_graph["topk_nodes"] == 3
75
83
  assert dic_extracted_graph["topk_edges"] == 3
76
84
  assert isinstance(dic_extracted_graph["graph_dict"], dict)
@@ -9,7 +9,7 @@ from langchain_core.language_models.chat_models import BaseChatModel
9
9
  from langgraph.checkpoint.memory import MemorySaver
10
10
  from langgraph.graph import START, StateGraph
11
11
  from langgraph.prebuilt import create_react_agent, ToolNode, InjectedState
12
- from ..tools.subgraph_extraction import SubgraphExtractionTool
12
+ from ..tools.multimodal_subgraph_extraction import MultimodalSubgraphExtractionTool
13
13
  from ..tools.subgraph_summarization import SubgraphSummarizationTool
14
14
  from ..tools.graphrag_reasoning import GraphRAGReasoningTool
15
15
  from ..states.state_talk2knowledgegraphs import Talk2KnowledgeGraphs
@@ -39,7 +39,7 @@ def get_app(uniq_id, llm_model: BaseChatModel):
39
39
  cfg = cfg.agents.t2kg_agent
40
40
 
41
41
  # Define the tools
42
- subgraph_extraction = SubgraphExtractionTool()
42
+ subgraph_extraction = MultimodalSubgraphExtractionTool()
43
43
  subgraph_summarization = SubgraphSummarizationTool()
44
44
  graphrag_reasoning = GraphRAGReasoningTool()
45
45
  tools = ToolNode([
@@ -2,12 +2,13 @@ _target_: app.frontend.streamlit_app_talk2knowledgegraphs
2
2
  default_user: "talk2kg_user"
3
3
  data_package_allowed_file_types:
4
4
  - "pdf"
5
- endotype_allowed_file_types:
6
- - "pdf"
5
+ multimodal_allowed_file_types:
6
+ - "xls"
7
+ - "xlsx"
7
8
  upload_data_dir: "../files"
8
9
  kg_name: "PrimeKG"
9
- kg_pyg_path: "aiagents4pharma/talk2knowledgegraphs/tests/files/primekg_ibd_pyg_graph.pkl"
10
- kg_text_path: "aiagents4pharma/talk2knowledgegraphs/tests/files/primekg_ibd_text_graph.pkl"
10
+ kg_pyg_path: "aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal_pyg_graph.pkl"
11
+ kg_text_path: "aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal_text_graph.pkl"
11
12
  openai_api_key: ${oc.env:OPENAI_API_KEY}
12
13
  openai_llms:
13
14
  - "gpt-4o-mini"
@@ -23,9 +24,9 @@ ollama_embeddings:
23
24
  - "nomic-embed-text"
24
25
  temperature: 0.1
25
26
  streaming: False
26
- reasoning_subgraph_topk_nodes: 10
27
+ reasoning_subgraph_topk_nodes: 5
27
28
  reasoning_subgraph_topk_nodes_min: 1
28
29
  reasoning_subgraph_topk_nodes_max: 100
29
- reasoning_subgraph_topk_edges: 10
30
+ reasoning_subgraph_topk_edges: 5
30
31
  reasoning_subgraph_topk_edges_min: 1
31
32
  reasoning_subgraph_topk_edges_max: 100
@@ -2,6 +2,7 @@ defaults:
2
2
  - _self_
3
3
  - agents/t2kg_agent: default
4
4
  - tools/subgraph_extraction: default
5
+ - tools/multimodal_subgraph_extraction: default
5
6
  - tools/subgraph_summarization: default
6
7
  - tools/graphrag_reasoning: default
7
8
  - utils/pubchem_utils: default
@@ -31,6 +31,7 @@ class Talk2KnowledgeGraphs(AgentState):
31
31
 
32
32
  llm_model: BaseChatModel
33
33
  embedding_model: Embeddings
34
+ selections: dict
34
35
  uploaded_files: list
35
36
  topk_nodes: int
36
37
  topk_edges: int
@@ -19,6 +19,14 @@ def input_dict_fixture():
19
19
  input_dict = {
20
20
  "llm_model": None, # TBA for each test case
21
21
  "embedding_model": None, # TBA for each test case
22
+ "selections": {
23
+ "gene/protein": [],
24
+ "molecular_function": [],
25
+ "cellular_component": [],
26
+ "biological_process": [],
27
+ "drug": [],
28
+ "disease": []
29
+ },
22
30
  "uploaded_files": [
23
31
  {
24
32
  "file_name": "adalimumab.pdf",
@@ -27,21 +35,14 @@ def input_dict_fixture():
27
35
  "uploaded_by": "VPEUser",
28
36
  "uploaded_timestamp": "2024-11-05 00:00:00",
29
37
  },
30
- {
31
- "file_name": "DGE_human_Colon_UC-vs-Colon_Control.pdf",
32
- "file_path": f"{DATA_PATH}/DGE_human_Colon_UC-vs-Colon_Control.pdf",
33
- "file_type": "endotype",
34
- "uploaded_by": "VPEUser",
35
- "uploaded_timestamp": "2024-11-05 00:00:00",
36
- },
37
38
  ],
38
39
  "topk_nodes": 3,
39
40
  "topk_edges": 3,
40
41
  "dic_source_graph": [
41
42
  {
42
- "name": "PrimeKG",
43
- "kg_pyg_path": f"{DATA_PATH}/primekg_ibd_pyg_graph.pkl",
44
- "kg_text_path": f"{DATA_PATH}/primekg_ibd_text_graph.pkl",
43
+ "name": "BioBridge",
44
+ "kg_pyg_path": f"{DATA_PATH}/biobridge_multimodal_pyg_graph.pkl",
45
+ "kg_text_path": f"{DATA_PATH}/biobridge_multimodal_text_graph.pkl",
45
46
  }
46
47
  ],
47
48
  "dic_extracted_graph": []
@@ -96,7 +97,7 @@ def test_t2kg_agent_openai(input_dict):
96
97
  dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
97
98
  assert isinstance(dic_extracted_graph, dict)
98
99
  assert dic_extracted_graph["name"] == "subkg_12345"
99
- assert dic_extracted_graph["graph_source"] == "PrimeKG"
100
+ assert dic_extracted_graph["graph_source"] == "BioBridge"
100
101
  assert dic_extracted_graph["topk_nodes"] == 3
101
102
  assert dic_extracted_graph["topk_edges"] == 3
102
103
  assert isinstance(dic_extracted_graph["graph_dict"], dict)
@@ -0,0 +1,152 @@
1
+ """
2
+ Test cases for tools/subgraph_extraction.py
3
+ """
4
+
5
+ import pytest
6
+ # from langchain_openai import ChatOpenAI, OpenAIEmbeddings
7
+ from ..tools.multimodal_subgraph_extraction import MultimodalSubgraphExtractionTool
8
+
9
+ # Define the data path
10
+ DATA_PATH = "aiagents4pharma/talk2knowledgegraphs/tests/files"
11
+
12
+
13
+ @pytest.fixture(name="agent_state")
14
+ def agent_state_fixture():
15
+ """
16
+ Agent state fixture.
17
+ """
18
+ agent_state = {
19
+ # "llm_model": ChatOpenAI(model="gpt-4o-mini", temperature=0.0),
20
+ # "embedding_model": OpenAIEmbeddings(model="text-embedding-3-small"),
21
+ "selections": {
22
+ "gene/protein": [],
23
+ "molecular_function": [],
24
+ "cellular_component": [],
25
+ "biological_process": [],
26
+ "drug": [],
27
+ "disease": []
28
+ },
29
+ "uploaded_files": [],
30
+ "topk_nodes": 3,
31
+ "topk_edges": 3,
32
+ "dic_source_graph": [
33
+ {
34
+ "name": "BioBridge",
35
+ "kg_pyg_path": f"{DATA_PATH}/biobridge_multimodal_pyg_graph.pkl",
36
+ "kg_text_path": f"{DATA_PATH}/biobridge_multimodal_text_graph.pkl",
37
+ }
38
+ ],
39
+ }
40
+
41
+ return agent_state
42
+
43
+
44
+ def test_extract_multimodal_subgraph_wo_doc(agent_state):
45
+ """
46
+ Test the multimodal subgraph extraction tool for only text as modality.
47
+
48
+ Args:
49
+ agent_state: Agent state in the form of a dictionary.
50
+ """
51
+ prompt = """
52
+ Extract all relevant information related to nodes of genes related to inflammatory bowel disease
53
+ (IBD) that existed in the knowledge graph.
54
+ Please set the extraction name for this process as `subkg_12345`.
55
+ """
56
+
57
+ # Instantiate the MultimodalSubgraphExtractionTool
58
+ subgraph_extraction_tool = MultimodalSubgraphExtractionTool()
59
+
60
+ # Invoking the subgraph_extraction_tool
61
+ response = subgraph_extraction_tool.invoke(
62
+ input={"prompt": prompt,
63
+ "tool_call_id": "subgraph_extraction_tool",
64
+ "state": agent_state,
65
+ "arg_data": {"extraction_name": "subkg_12345"}})
66
+
67
+ # Check tool message
68
+ assert response.update["messages"][-1].tool_call_id == "subgraph_extraction_tool"
69
+
70
+ # Check extracted subgraph dictionary
71
+ dic_extracted_graph = response.update["dic_extracted_graph"][0]
72
+ assert isinstance(dic_extracted_graph, dict)
73
+ assert dic_extracted_graph["name"] == "subkg_12345"
74
+ assert dic_extracted_graph["graph_source"] == "BioBridge"
75
+ assert dic_extracted_graph["topk_nodes"] == 3
76
+ assert dic_extracted_graph["topk_edges"] == 3
77
+ assert isinstance(dic_extracted_graph["graph_dict"], dict)
78
+ assert len(dic_extracted_graph["graph_dict"]["nodes"]) > 0
79
+ assert len(dic_extracted_graph["graph_dict"]["edges"]) > 0
80
+ assert isinstance(dic_extracted_graph["graph_text"], str)
81
+ # Check if the nodes are in the graph_text
82
+ assert all(
83
+ n[0] in dic_extracted_graph["graph_text"].replace('"', '')
84
+ for n in dic_extracted_graph["graph_dict"]["nodes"]
85
+ )
86
+ # Check if the edges are in the graph_text
87
+ assert all(
88
+ ",".join([e[0], str(tuple(e[2]["relation"])), e[1]])
89
+ in dic_extracted_graph["graph_text"].replace('"', '')
90
+ for e in dic_extracted_graph["graph_dict"]["edges"]
91
+ )
92
+
93
+
94
+ def test_extract_multimodal_subgraph_w_doc(agent_state):
95
+ """
96
+ Test the multimodal subgraph extraction tool for text as modality, plus genes.
97
+
98
+ Args:
99
+ agent_state: Agent state in the form of a dictionary.
100
+ """
101
+ # Update state
102
+ agent_state["uploaded_files"] = [
103
+ {
104
+ "file_name": "multimodal-analysis.xlsx",
105
+ "file_path": f"{DATA_PATH}/multimodal-analysis.xlsx",
106
+ "file_type": "multimodal",
107
+ "uploaded_by": "VPEUser",
108
+ "uploaded_timestamp": "2025-05-12 00:00:00",
109
+ }
110
+ ]
111
+
112
+ prompt = """
113
+ Extract all relevant information related to nodes of genes related to inflammatory bowel disease
114
+ (IBD) that existed in the knowledge graph.
115
+ Please set the extraction name for this process as `subkg_12345`.
116
+ """
117
+
118
+ # Instantiate the SubgraphExtractionTool
119
+ subgraph_extraction_tool = MultimodalSubgraphExtractionTool()
120
+
121
+ # Invoking the subgraph_extraction_tool
122
+ response = subgraph_extraction_tool.invoke(
123
+ input={"prompt": prompt,
124
+ "tool_call_id": "subgraph_extraction_tool",
125
+ "state": agent_state,
126
+ "arg_data": {"extraction_name": "subkg_12345"}})
127
+
128
+ # Check tool message
129
+ assert response.update["messages"][-1].tool_call_id == "subgraph_extraction_tool"
130
+
131
+ # Check extracted subgraph dictionary
132
+ dic_extracted_graph = response.update["dic_extracted_graph"][0]
133
+ assert isinstance(dic_extracted_graph, dict)
134
+ assert dic_extracted_graph["name"] == "subkg_12345"
135
+ assert dic_extracted_graph["graph_source"] == "BioBridge"
136
+ assert dic_extracted_graph["topk_nodes"] == 3
137
+ assert dic_extracted_graph["topk_edges"] == 3
138
+ assert isinstance(dic_extracted_graph["graph_dict"], dict)
139
+ assert len(dic_extracted_graph["graph_dict"]["nodes"]) > 0
140
+ assert len(dic_extracted_graph["graph_dict"]["edges"]) > 0
141
+ assert isinstance(dic_extracted_graph["graph_text"], str)
142
+ # Check if the nodes are in the graph_text
143
+ assert all(
144
+ n[0] in dic_extracted_graph["graph_text"].replace('"', '')
145
+ for n in dic_extracted_graph["graph_dict"]["nodes"]
146
+ )
147
+ # Check if the edges are in the graph_text
148
+ assert all(
149
+ ",".join([e[0], str(tuple(e[2]["relation"])), e[1]])
150
+ in dic_extracted_graph["graph_text"].replace('"', '')
151
+ for e in dic_extracted_graph["graph_dict"]["edges"]
152
+ )
@@ -3,22 +3,21 @@ Test cases for tools/subgraph_extraction.py
3
3
  """
4
4
 
5
5
  import pytest
6
- from langchain_core.messages import HumanMessage
7
6
  from langchain_openai import ChatOpenAI, OpenAIEmbeddings
8
- from ..agents.t2kg_agent import get_app
7
+ from ..tools.subgraph_extraction import SubgraphExtractionTool
9
8
 
10
9
  # Define the data path
11
10
  DATA_PATH = "aiagents4pharma/talk2knowledgegraphs/tests/files"
12
11
 
13
12
 
14
- @pytest.fixture(name="input_dict")
15
- def input_dict_fixture():
13
+ @pytest.fixture(name="agent_state")
14
+ def agent_state_fixture():
16
15
  """
17
- Input dictionary fixture.
16
+ Agent state fixture.
18
17
  """
19
- input_dict = {
20
- "llm_model": None, # TBA for each test case
21
- "embedding_model": None, # TBA for each test case
18
+ agent_state = {
19
+ "llm_model": ChatOpenAI(model="gpt-4o-mini", temperature=0.0),
20
+ "embedding_model": OpenAIEmbeddings(model="text-embedding-3-small"),
22
21
  "uploaded_files": [],
23
22
  "topk_nodes": 3,
24
23
  "topk_edges": 3,
@@ -31,52 +30,37 @@ def input_dict_fixture():
31
30
  ],
32
31
  }
33
32
 
34
- return input_dict
33
+ return agent_state
35
34
 
36
35
 
37
- def test_extract_subgraph_wo_docs(input_dict):
36
+ def test_extract_subgraph_wo_docs(agent_state):
38
37
  """
39
38
  Test the subgraph extraction tool without any documents using OpenAI model.
40
39
 
41
40
  Args:
42
- input_dict: Input dictionary.
41
+ agent_state: Agent state in the form of a dictionary.
43
42
  """
44
- # Prepare LLM and embedding model
45
- input_dict["llm_model"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
46
- input_dict["embedding_model"] = OpenAIEmbeddings(model="text-embedding-3-small")
47
-
48
- # Setup the app
49
- unique_id = 12345
50
- app = get_app(unique_id, llm_model=input_dict["llm_model"])
51
- config = {"configurable": {"thread_id": unique_id}}
52
- # Update state
53
- app.update_state(
54
- config,
55
- input_dict,
56
- )
57
43
  prompt = """
58
- Please directly invoke `subgraph_extraction` tool without calling any other tools
59
- to respond to the following prompt:
60
-
61
44
  Extract all relevant information related to nodes of genes related to inflammatory bowel disease
62
45
  (IBD) that existed in the knowledge graph.
63
46
  Please set the extraction name for this process as `subkg_12345`.
64
47
  """
65
48
 
66
- # Test the tool subgraph_extraction
67
- response = app.invoke({"messages": [HumanMessage(content=prompt)]}, config=config)
49
+ # Instantiate the SubgraphExtractionTool
50
+ subgraph_extraction_tool = SubgraphExtractionTool()
68
51
 
69
- # Check assistant message
70
- assistant_msg = response["messages"][-1].content
71
- assert isinstance(assistant_msg, str)
52
+ # Invoking the subgraph_extraction_tool
53
+ response = subgraph_extraction_tool.invoke(
54
+ input={"prompt": prompt,
55
+ "tool_call_id": "subgraph_extraction_tool",
56
+ "state": agent_state,
57
+ "arg_data": {"extraction_name": "subkg_12345"}})
72
58
 
73
59
  # Check tool message
74
- tool_msg = response["messages"][-2]
75
- assert tool_msg.name == "subgraph_extraction"
60
+ assert response.update["messages"][-1].tool_call_id == "subgraph_extraction_tool"
76
61
 
77
62
  # Check extracted subgraph dictionary
78
- current_state = app.get_state(config)
79
- dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
63
+ dic_extracted_graph = response.update["dic_extracted_graph"][0]
80
64
  assert isinstance(dic_extracted_graph, dict)
81
65
  assert dic_extracted_graph["name"] == "subkg_12345"
82
66
  assert dic_extracted_graph["graph_source"] == "PrimeKG"
@@ -99,24 +83,16 @@ def test_extract_subgraph_wo_docs(input_dict):
99
83
  )
100
84
 
101
85
 
102
- def test_extract_subgraph_w_docs(input_dict):
86
+ def test_extract_subgraph_w_docs(agent_state):
103
87
  """
104
- Test the subgraph extraction tool with a document as reference (i.e., endotype document)
105
- using OpenAI model.
88
+ As a knowledge graph agent, I would like you to call a tool called `subgraph_extraction`.
89
+ After calling the tool, restrain yourself to call any other tool.
106
90
 
107
91
  Args:
108
- input_dict: Input dictionary.
92
+ agent_state: Agent state in the form of a dictionary.
109
93
  """
110
- # Prepare LLM and embedding model
111
- input_dict["llm_model"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
112
- input_dict["embedding_model"] = OpenAIEmbeddings(model="text-embedding-3-small")
113
-
114
- # Setup the app
115
- unique_id = 12345
116
- app = get_app(unique_id, llm_model=input_dict["llm_model"])
117
- config = {"configurable": {"thread_id": unique_id}}
118
94
  # Update state
119
- input_dict["uploaded_files"] = [
95
+ agent_state["uploaded_files"] = [
120
96
  {
121
97
  "file_name": "DGE_human_Colon_UC-vs-Colon_Control.pdf",
122
98
  "file_path": f"{DATA_PATH}/DGE_human_Colon_UC-vs-Colon_Control.pdf",
@@ -125,33 +101,28 @@ def test_extract_subgraph_w_docs(input_dict):
125
101
  "uploaded_timestamp": "2024-11-05 00:00:00",
126
102
  }
127
103
  ]
128
- app.update_state(
129
- config,
130
- input_dict,
131
- )
132
- prompt = """
133
- Please ONLY invoke `subgraph_extraction` tool without calling any other tools
134
- to respond to the following prompt:
135
104
 
105
+ prompt = """
136
106
  Extract all relevant information related to nodes of genes related to inflammatory bowel disease
137
107
  (IBD) that existed in the knowledge graph.
138
108
  Please set the extraction name for this process as `subkg_12345`.
139
109
  """
140
110
 
141
- # Test the tool subgraph_extraction
142
- response = app.invoke({"messages": [HumanMessage(content=prompt)]}, config=config)
111
+ # Instantiate the SubgraphExtractionTool
112
+ subgraph_extraction_tool = SubgraphExtractionTool()
143
113
 
144
- # Check assistant message
145
- assistant_msg = response["messages"][-1].content
146
- assert isinstance(assistant_msg, str)
114
+ # Invoking the subgraph_extraction_tool
115
+ response = subgraph_extraction_tool.invoke(
116
+ input={"prompt": prompt,
117
+ "tool_call_id": "subgraph_extraction_tool",
118
+ "state": agent_state,
119
+ "arg_data": {"extraction_name": "subkg_12345"}})
147
120
 
148
121
  # Check tool message
149
- tool_msg = response["messages"][-2]
150
- assert tool_msg.name == "subgraph_extraction"
122
+ assert response.update["messages"][-1].tool_call_id == "subgraph_extraction_tool"
151
123
 
152
124
  # Check extracted subgraph dictionary
153
- current_state = app.get_state(config)
154
- dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
125
+ dic_extracted_graph = response.update["dic_extracted_graph"][0]
155
126
  assert isinstance(dic_extracted_graph, dict)
156
127
  assert dic_extracted_graph["name"] == "subkg_12345"
157
128
  assert dic_extracted_graph["graph_source"] == "PrimeKG"
@@ -2,5 +2,6 @@
2
2
  This file is used to import all the models in the package.
3
3
  '''
4
4
  from . import subgraph_extraction
5
+ from . import multimodal_subgraph_extraction
5
6
  from . import subgraph_summarization
6
7
  from . import graphrag_reasoning