aiagents4pharma 1.30.0__py3-none-any.whl → 1.30.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2scholars/agents/main_agent.py +18 -10
- aiagents4pharma/talk2scholars/agents/paper_download_agent.py +5 -6
- aiagents4pharma/talk2scholars/agents/pdf_agent.py +4 -10
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +18 -9
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +2 -2
- aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +1 -0
- aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +6 -1
- aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +7 -1
- aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +6 -1
- aiagents4pharma/talk2scholars/configs/tools/zotero_read/default.yaml +1 -1
- aiagents4pharma/talk2scholars/state/state_talk2scholars.py +4 -1
- aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py +84 -53
- aiagents4pharma/talk2scholars/tests/test_main_agent.py +24 -0
- aiagents4pharma/talk2scholars/tests/test_question_and_answer_tool.py +79 -15
- aiagents4pharma/talk2scholars/tests/test_routing_logic.py +12 -8
- aiagents4pharma/talk2scholars/tests/test_s2_multi.py +27 -4
- aiagents4pharma/talk2scholars/tests/test_s2_search.py +19 -3
- aiagents4pharma/talk2scholars/tests/test_s2_single.py +27 -3
- aiagents4pharma/talk2scholars/tests/test_zotero_read.py +17 -10
- aiagents4pharma/talk2scholars/tools/paper_download/abstract_downloader.py +2 -0
- aiagents4pharma/talk2scholars/tools/paper_download/arxiv_downloader.py +11 -4
- aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +5 -1
- aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +73 -26
- aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +46 -22
- aiagents4pharma/talk2scholars/tools/s2/query_results.py +1 -1
- aiagents4pharma/talk2scholars/tools/s2/search.py +40 -12
- aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +42 -16
- aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +33 -16
- aiagents4pharma/talk2scholars/tools/zotero/zotero_write.py +39 -7
- {aiagents4pharma-1.30.0.dist-info → aiagents4pharma-1.30.1.dist-info}/METADATA +2 -2
- {aiagents4pharma-1.30.0.dist-info → aiagents4pharma-1.30.1.dist-info}/RECORD +34 -34
- {aiagents4pharma-1.30.0.dist-info → aiagents4pharma-1.30.1.dist-info}/WHEEL +1 -1
- {aiagents4pharma-1.30.0.dist-info → aiagents4pharma-1.30.1.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.30.0.dist-info → aiagents4pharma-1.30.1.dist-info}/top_level.txt +0 -0
@@ -23,8 +23,7 @@ def mock_router():
|
|
23
23
|
|
24
24
|
def mock_supervisor_node(state):
|
25
25
|
query = state["messages"][-1].content.lower()
|
26
|
-
|
27
|
-
# Expanded keyword matching for S2 Agent
|
26
|
+
# Define keywords for each sub-agent.
|
28
27
|
s2_keywords = [
|
29
28
|
"paper",
|
30
29
|
"research",
|
@@ -34,13 +33,19 @@ def mock_router():
|
|
34
33
|
"references",
|
35
34
|
]
|
36
35
|
zotero_keywords = ["zotero", "library", "saved papers", "academic library"]
|
36
|
+
pdf_keywords = ["pdf", "document", "read pdf"]
|
37
|
+
paper_download_keywords = ["download", "arxiv", "fetch paper", "paper download"]
|
37
38
|
|
39
|
+
# Priority ordering: Zotero, then paper download, then PDF, then S2.
|
38
40
|
if any(keyword in query for keyword in zotero_keywords):
|
39
41
|
return Command(goto="zotero_agent")
|
42
|
+
if any(keyword in query for keyword in paper_download_keywords):
|
43
|
+
return Command(goto="paper_download_agent")
|
44
|
+
if any(keyword in query for keyword in pdf_keywords):
|
45
|
+
return Command(goto="pdf_agent")
|
40
46
|
if any(keyword in query for keyword in s2_keywords):
|
41
47
|
return Command(goto="s2_agent")
|
42
|
-
|
43
|
-
# If no match, default to ending the conversation
|
48
|
+
# Default to end if no keyword matches.
|
44
49
|
return Command(goto=END)
|
45
50
|
|
46
51
|
return mock_supervisor_node
|
@@ -55,10 +60,9 @@ def mock_router():
|
|
55
60
|
("Fetch my academic library.", "zotero_agent"),
|
56
61
|
("Retrieve citations.", "s2_agent"),
|
57
62
|
("Can you get journal articles?", "s2_agent"),
|
58
|
-
(
|
59
|
-
|
60
|
-
|
61
|
-
), # NEW: Should trigger the `END` case
|
63
|
+
("I want to read the PDF document.", "pdf_agent"),
|
64
|
+
("Download the paper from arxiv.", "paper_download_agent"),
|
65
|
+
("Completely unrelated query.", "__end__"),
|
62
66
|
],
|
63
67
|
)
|
64
68
|
def test_routing_logic(mock_state, mock_router, user_query, expected_agent):
|
@@ -94,7 +94,7 @@ def dummy_requests_post_success(url, headers, params, data, timeout):
|
|
94
94
|
{
|
95
95
|
"paperId": "paperA",
|
96
96
|
"title": "Multi Rec Paper A",
|
97
|
-
"authors": ["Author X"],
|
97
|
+
"authors": [{"name": "Author X", "authorId": "AX"}],
|
98
98
|
"year": 2019,
|
99
99
|
"citationCount": 12,
|
100
100
|
"url": "http://paperA",
|
@@ -103,7 +103,7 @@ def dummy_requests_post_success(url, headers, params, data, timeout):
|
|
103
103
|
{
|
104
104
|
"paperId": "paperB",
|
105
105
|
"title": "Multi Rec Paper B",
|
106
|
-
"authors": ["Author Y"],
|
106
|
+
"authors": [{"name": "Author Y", "authorId": "AY"}],
|
107
107
|
"year": 2020,
|
108
108
|
"citationCount": 18,
|
109
109
|
"url": "http://paperB",
|
@@ -112,7 +112,7 @@ def dummy_requests_post_success(url, headers, params, data, timeout):
|
|
112
112
|
{
|
113
113
|
"paperId": "paperC",
|
114
114
|
"title": "Multi Rec Paper C",
|
115
|
-
"authors": None, # This
|
115
|
+
"authors": None, # This paper should be filtered out.
|
116
116
|
"year": 2021,
|
117
117
|
"citationCount": 25,
|
118
118
|
"url": "http://paperC",
|
@@ -277,6 +277,29 @@ def test_multi_paper_rec_requests_exception(monkeypatch):
|
|
277
277
|
}
|
278
278
|
with pytest.raises(
|
279
279
|
RuntimeError,
|
280
|
-
match="Failed to connect to Semantic Scholar API
|
280
|
+
match="Failed to connect to Semantic Scholar API after 10 attempts."
|
281
|
+
"Please retry the same query.",
|
282
|
+
):
|
283
|
+
get_multi_paper_recommendations.run(input_data)
|
284
|
+
|
285
|
+
|
286
|
+
def test_multi_paper_rec_no_response(monkeypatch):
|
287
|
+
"""
|
288
|
+
Test that get_multi_paper_recommendations raises a RuntimeError when no response is obtained.
|
289
|
+
This is simulated by patching 'range' in the underlying function's globals to
|
290
|
+
return an empty iterator,
|
291
|
+
so that the for loop does not iterate and response remains None.
|
292
|
+
"""
|
293
|
+
# Patch 'range' in the underlying function's globals (accessed via .func.__globals__)
|
294
|
+
monkeypatch.setitem(
|
295
|
+
get_multi_paper_recommendations.func.__globals__, "range", lambda x: iter([])
|
296
|
+
)
|
297
|
+
tool_call_id = "test_tool_call_id"
|
298
|
+
input_data = {
|
299
|
+
"paper_ids": ["p1", "p2"],
|
300
|
+
"tool_call_id": tool_call_id,
|
301
|
+
}
|
302
|
+
with pytest.raises(
|
303
|
+
RuntimeError, match="Failed to obtain a response from the Semantic Scholar API."
|
281
304
|
):
|
282
305
|
get_multi_paper_recommendations.run(input_data)
|
@@ -85,7 +85,7 @@ def dummy_requests_get_success(url, params, timeout):
|
|
85
85
|
{
|
86
86
|
"paperId": "1",
|
87
87
|
"title": "Paper 1",
|
88
|
-
"authors": ["Author A"],
|
88
|
+
"authors": [{"name": "Author A", "authorId": "A1"}],
|
89
89
|
"year": 2020,
|
90
90
|
"citationCount": 10,
|
91
91
|
"url": "http://paper1",
|
@@ -94,7 +94,7 @@ def dummy_requests_get_success(url, params, timeout):
|
|
94
94
|
{
|
95
95
|
"paperId": "2",
|
96
96
|
"title": "Paper 2",
|
97
|
-
"authors": ["Author B"],
|
97
|
+
"authors": [{"name": "Author B", "authorId": "B1"}],
|
98
98
|
"year": 2021,
|
99
99
|
"citationCount": 20,
|
100
100
|
"url": "http://paper2",
|
@@ -256,7 +256,8 @@ def test_search_tool_requests_exception(monkeypatch):
|
|
256
256
|
tool_call_id = "test_tool_call_id"
|
257
257
|
with pytest.raises(
|
258
258
|
RuntimeError,
|
259
|
-
match="Failed to connect to Semantic Scholar API
|
259
|
+
match="Failed to connect to Semantic Scholar API after 10 attempts."
|
260
|
+
"Please retry the same query.",
|
260
261
|
):
|
261
262
|
search_tool.run(
|
262
263
|
{
|
@@ -264,3 +265,18 @@ def test_search_tool_requests_exception(monkeypatch):
|
|
264
265
|
"tool_call_id": tool_call_id,
|
265
266
|
}
|
266
267
|
)
|
268
|
+
|
269
|
+
|
270
|
+
def test_search_tool_no_response(monkeypatch):
|
271
|
+
"""
|
272
|
+
Test that search_tool raises a RuntimeError when no response is obtained.
|
273
|
+
This is simulated by patching 'range' in the original function's globals (a dict)
|
274
|
+
so that it returns an empty iterator, leaving response as None.
|
275
|
+
"""
|
276
|
+
# Patch 'range' in the original function's globals using setitem.
|
277
|
+
monkeypatch.setitem(search_tool.func.__globals__, "range", lambda x: iter([]))
|
278
|
+
tool_call_id = "test_tool_call_id"
|
279
|
+
with pytest.raises(
|
280
|
+
RuntimeError, match="Failed to obtain a response from the Semantic Scholar API."
|
281
|
+
):
|
282
|
+
search_tool.run({"query": "test", "tool_call_id": tool_call_id})
|
@@ -92,7 +92,7 @@ def dummy_requests_get_success(url, params, timeout):
|
|
92
92
|
{
|
93
93
|
"paperId": "paper1",
|
94
94
|
"title": "Recommended Paper 1",
|
95
|
-
"authors": ["Author A"],
|
95
|
+
"authors": [{"name": "Author A", "authorId": "A1"}],
|
96
96
|
"year": 2020,
|
97
97
|
"citationCount": 15,
|
98
98
|
"url": "http://paper1",
|
@@ -101,7 +101,7 @@ def dummy_requests_get_success(url, params, timeout):
|
|
101
101
|
{
|
102
102
|
"paperId": "paper2",
|
103
103
|
"title": "Recommended Paper 2",
|
104
|
-
"authors": ["Author B"],
|
104
|
+
"authors": [{"name": "Author B", "authorId": "B1"}],
|
105
105
|
"year": 2021,
|
106
106
|
"citationCount": 25,
|
107
107
|
"url": "http://paper2",
|
@@ -269,6 +269,30 @@ def test_single_paper_rec_requests_exception(monkeypatch):
|
|
269
269
|
}
|
270
270
|
with pytest.raises(
|
271
271
|
RuntimeError,
|
272
|
-
match="Failed to connect to Semantic Scholar API
|
272
|
+
match="Failed to connect to Semantic Scholar API after 10 attempts."
|
273
|
+
"Please retry the same query.",
|
274
|
+
):
|
275
|
+
get_single_paper_recommendations.run(input_data)
|
276
|
+
|
277
|
+
|
278
|
+
def test_single_paper_rec_no_response(monkeypatch):
|
279
|
+
"""
|
280
|
+
Test that get_single_paper_recommendations raises a RuntimeError
|
281
|
+
when no response is obtained from the API.
|
282
|
+
|
283
|
+
This is simulated by patching 'range' in the underlying function's globals
|
284
|
+
to return an empty iterator, so the for-loop never iterates and response remains None.
|
285
|
+
"""
|
286
|
+
# Patch 'range' in the underlying function's globals (accessed via .func.__globals__)
|
287
|
+
monkeypatch.setitem(
|
288
|
+
get_single_paper_recommendations.func.__globals__, "range", lambda x: iter([])
|
289
|
+
)
|
290
|
+
tool_call_id = "test_tool_call_id"
|
291
|
+
input_data = {
|
292
|
+
"paper_id": "12345",
|
293
|
+
"tool_call_id": tool_call_id,
|
294
|
+
}
|
295
|
+
with pytest.raises(
|
296
|
+
RuntimeError, match="Failed to obtain a response from the Semantic Scholar API."
|
273
297
|
):
|
274
298
|
get_single_paper_recommendations.run(input_data)
|
@@ -26,7 +26,7 @@ dummy_cfg = SimpleNamespace(tools=SimpleNamespace(zotero_read=dummy_zotero_read_
|
|
26
26
|
|
27
27
|
|
28
28
|
class TestZoteroSearchTool(unittest.TestCase):
|
29
|
-
"""
|
29
|
+
"""Tests for Zotero search tool."""
|
30
30
|
|
31
31
|
@patch(
|
32
32
|
"aiagents4pharma.talk2scholars.tools.zotero.zotero_read.get_item_collections"
|
@@ -41,7 +41,7 @@ class TestZoteroSearchTool(unittest.TestCase):
|
|
41
41
|
mock_zotero_class,
|
42
42
|
mock_get_item_collections,
|
43
43
|
):
|
44
|
-
"""
|
44
|
+
"""Test valid query returns correct Command output."""
|
45
45
|
# Setup Hydra mocks
|
46
46
|
mock_hydra_compose.return_value = dummy_cfg
|
47
47
|
mock_hydra_init.return_value.__enter__.return_value = None
|
@@ -116,7 +116,7 @@ class TestZoteroSearchTool(unittest.TestCase):
|
|
116
116
|
mock_zotero_class,
|
117
117
|
mock_get_item_collections,
|
118
118
|
):
|
119
|
-
"""
|
119
|
+
"""Test empty query fetches all items."""
|
120
120
|
mock_hydra_compose.return_value = dummy_cfg
|
121
121
|
mock_hydra_init.return_value.__enter__.return_value = None
|
122
122
|
|
@@ -166,7 +166,7 @@ class TestZoteroSearchTool(unittest.TestCase):
|
|
166
166
|
mock_zotero_class,
|
167
167
|
mock_get_item_collections,
|
168
168
|
):
|
169
|
-
"""
|
169
|
+
"""Test no items returned from Zotero."""
|
170
170
|
mock_hydra_compose.return_value = dummy_cfg
|
171
171
|
mock_hydra_init.return_value.__enter__.return_value = None
|
172
172
|
|
@@ -199,7 +199,10 @@ class TestZoteroSearchTool(unittest.TestCase):
|
|
199
199
|
mock_zotero_class,
|
200
200
|
mock_get_item_collections,
|
201
201
|
):
|
202
|
-
"""
|
202
|
+
"""
|
203
|
+
Test that when non-research items (e.g. attachments, notes) are returned,
|
204
|
+
they are still included since filtering is disabled.
|
205
|
+
"""
|
203
206
|
mock_hydra_compose.return_value = dummy_cfg
|
204
207
|
mock_hydra_init.return_value.__enter__.return_value = None
|
205
208
|
|
@@ -240,9 +243,13 @@ class TestZoteroSearchTool(unittest.TestCase):
|
|
240
243
|
"tool_call_id": tool_call_id,
|
241
244
|
"limit": 2,
|
242
245
|
}
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
+
# Instead of expecting a RuntimeError, we now expect both items to be returned.
|
247
|
+
result = zotero_search_tool.run(tool_input)
|
248
|
+
update = result.update
|
249
|
+
filtered_papers = update["zotero_read"]
|
250
|
+
self.assertIn("paper1", filtered_papers)
|
251
|
+
self.assertIn("paper2", filtered_papers)
|
252
|
+
self.assertEqual(len(filtered_papers), 2)
|
246
253
|
|
247
254
|
@patch(
|
248
255
|
"aiagents4pharma.talk2scholars.tools.zotero.zotero_read.get_item_collections"
|
@@ -257,7 +264,7 @@ class TestZoteroSearchTool(unittest.TestCase):
|
|
257
264
|
mock_zotero_class,
|
258
265
|
mock_get_item_collections,
|
259
266
|
):
|
260
|
-
"""
|
267
|
+
"""Test items API exception is properly raised."""
|
261
268
|
mock_hydra_compose.return_value = dummy_cfg
|
262
269
|
mock_hydra_init.return_value.__enter__.return_value = None
|
263
270
|
mock_get_item_collections.return_value = {}
|
@@ -306,7 +313,7 @@ class TestZoteroSearchTool(unittest.TestCase):
|
|
306
313
|
"url": "http://example.com",
|
307
314
|
"itemType": "journalArticle",
|
308
315
|
}
|
309
|
-
}, #
|
316
|
+
}, # Missing 'key' field
|
310
317
|
{
|
311
318
|
"data": {
|
312
319
|
"key": "paper_valid",
|
@@ -8,6 +8,7 @@ downloads the corresponding PDF.
|
|
8
8
|
By using an abstract base class, this implementation is extendable to other
|
9
9
|
APIs like PubMed, IEEE Xplore, etc.
|
10
10
|
"""
|
11
|
+
|
11
12
|
import xml.etree.ElementTree as ET
|
12
13
|
from typing import Any, Dict
|
13
14
|
import logging
|
@@ -19,6 +20,7 @@ from .abstract_downloader import AbstractPaperDownloader
|
|
19
20
|
logging.basicConfig(level=logging.INFO)
|
20
21
|
logger = logging.getLogger(__name__)
|
21
22
|
|
23
|
+
|
22
24
|
class ArxivPaperDownloader(AbstractPaperDownloader):
|
23
25
|
"""
|
24
26
|
Downloader class for arXiv.
|
@@ -35,13 +37,13 @@ class ArxivPaperDownloader(AbstractPaperDownloader):
|
|
35
37
|
"""
|
36
38
|
with hydra.initialize(version_base=None, config_path="../../configs"):
|
37
39
|
cfg = hydra.compose(
|
38
|
-
config_name="config",
|
39
|
-
overrides=["tools/download_arxiv_paper=default"]
|
40
|
+
config_name="config", overrides=["tools/download_arxiv_paper=default"]
|
40
41
|
)
|
41
42
|
self.api_url = cfg.tools.download_arxiv_paper.api_url
|
42
43
|
self.request_timeout = cfg.tools.download_arxiv_paper.request_timeout
|
43
44
|
self.chunk_size = cfg.tools.download_arxiv_paper.chunk_size
|
44
45
|
self.pdf_base_url = cfg.tools.download_arxiv_paper.pdf_base_url
|
46
|
+
|
45
47
|
def fetch_metadata(self, paper_id: str) -> Dict[str, Any]:
|
46
48
|
"""
|
47
49
|
Fetch metadata from arXiv for a given paper ID.
|
@@ -95,11 +97,16 @@ class ArxivPaperDownloader(AbstractPaperDownloader):
|
|
95
97
|
logger.info("Downloading PDF from: %s", pdf_url)
|
96
98
|
pdf_response = requests.get(pdf_url, stream=True, timeout=self.request_timeout)
|
97
99
|
pdf_response.raise_for_status()
|
100
|
+
# print (pdf_response)
|
98
101
|
|
99
102
|
# Combine the PDF data from chunks.
|
100
103
|
pdf_object = b"".join(
|
101
|
-
chunk
|
102
|
-
)
|
104
|
+
chunk
|
105
|
+
for chunk in pdf_response.iter_content(chunk_size=self.chunk_size)
|
106
|
+
if chunk
|
107
|
+
)
|
108
|
+
# print (pdf_object)
|
109
|
+
print("PDF_URL", pdf_url)
|
103
110
|
|
104
111
|
return {
|
105
112
|
"pdf_object": pdf_object,
|
@@ -14,16 +14,19 @@ from langgraph.types import Command
|
|
14
14
|
# Local import from the same package:
|
15
15
|
from .arxiv_downloader import ArxivPaperDownloader
|
16
16
|
|
17
|
+
|
17
18
|
class DownloadArxivPaperInput(BaseModel):
|
18
19
|
"""
|
19
20
|
Input schema for the arXiv paper download tool.
|
20
21
|
(Optional: if you decide to keep Pydantic validation in the future)
|
21
22
|
"""
|
23
|
+
|
22
24
|
arxiv_id: str = Field(
|
23
25
|
description="The arXiv paper ID used to retrieve the paper details and PDF."
|
24
|
-
|
26
|
+
)
|
25
27
|
tool_call_id: Annotated[str, InjectedToolCallId]
|
26
28
|
|
29
|
+
|
27
30
|
@tool(args_schema=DownloadArxivPaperInput, parse_docstring=True)
|
28
31
|
def download_arxiv_paper(
|
29
32
|
arxiv_id: str,
|
@@ -49,6 +52,7 @@ def download_arxiv_paper(
|
|
49
52
|
|
50
53
|
# If the downloader fails or the arxiv_id is invalid, this might raise an error
|
51
54
|
pdf_data = downloader.download_pdf(arxiv_id)
|
55
|
+
# print (pdf_data)
|
52
56
|
|
53
57
|
content = f"Successfully downloaded PDF for arXiv ID {arxiv_id}"
|
54
58
|
|
@@ -2,8 +2,8 @@
|
|
2
2
|
"""
|
3
3
|
question_and_answer: Tool for performing Q&A on PDF documents using retrieval augmented generation.
|
4
4
|
|
5
|
-
This module provides functionality to extract text from PDF binary data, split it into
|
6
|
-
chunks, retrieve relevant segments via a vector store, and generate an answer to a
|
5
|
+
This module provides functionality to extract text from PDF binary data, split it into
|
6
|
+
chunks, retrieve relevant segments via a vector store, and generate an answer to a
|
7
7
|
user-provided question using a language model chain.
|
8
8
|
"""
|
9
9
|
|
@@ -18,13 +18,15 @@ import hydra
|
|
18
18
|
from langchain.chains.question_answering import load_qa_chain
|
19
19
|
from langchain.docstore.document import Document
|
20
20
|
from langchain.text_splitter import CharacterTextSplitter
|
21
|
-
from langchain_community.vectorstores import Annoy
|
22
|
-
from langchain_openai import OpenAIEmbeddings
|
23
21
|
from langchain_core.language_models.chat_models import BaseChatModel
|
24
|
-
|
22
|
+
from langchain_core.vectorstores import InMemoryVectorStore
|
25
23
|
from langchain_core.messages import ToolMessage
|
26
24
|
from langchain_core.tools import tool
|
27
25
|
from langchain_core.tools.base import InjectedToolCallId
|
26
|
+
from langchain_core.embeddings import Embeddings
|
27
|
+
from langchain_community.vectorstores import Annoy
|
28
|
+
from langchain_community.document_loaders import PyPDFLoader
|
29
|
+
from langchain_openai import OpenAIEmbeddings
|
28
30
|
from langgraph.types import Command
|
29
31
|
from langgraph.prebuilt import InjectedState
|
30
32
|
|
@@ -35,10 +37,13 @@ logger.setLevel(logging.INFO)
|
|
35
37
|
|
36
38
|
# Load configuration using Hydra.
|
37
39
|
with hydra.initialize(version_base=None, config_path="../../configs"):
|
38
|
-
cfg = hydra.compose(
|
40
|
+
cfg = hydra.compose(
|
41
|
+
config_name="config", overrides=["tools/question_and_answer=default"]
|
42
|
+
)
|
39
43
|
cfg = cfg.tools.question_and_answer
|
40
44
|
logger.info("Loaded Question and Answer tool configuration.")
|
41
45
|
|
46
|
+
|
42
47
|
class QuestionAndAnswerInput(BaseModel):
|
43
48
|
"""
|
44
49
|
Input schema for the PDF Question and Answer tool.
|
@@ -47,12 +52,12 @@ class QuestionAndAnswerInput(BaseModel):
|
|
47
52
|
question (str): The question to ask regarding the PDF content.
|
48
53
|
tool_call_id (str): Unique identifier for the tool call, injected automatically.
|
49
54
|
"""
|
50
|
-
|
51
|
-
|
52
|
-
)
|
55
|
+
|
56
|
+
question: str = Field(description="The question to ask regarding the PDF content.")
|
53
57
|
tool_call_id: Annotated[str, InjectedToolCallId]
|
54
58
|
state: Annotated[dict, InjectedState]
|
55
59
|
|
60
|
+
|
56
61
|
def extract_text_from_pdf_data(pdf_bytes: bytes) -> str:
|
57
62
|
"""
|
58
63
|
Extract text content from PDF binary data.
|
@@ -73,7 +78,10 @@ def extract_text_from_pdf_data(pdf_bytes: bytes) -> str:
|
|
73
78
|
text += page_text
|
74
79
|
return text
|
75
80
|
|
76
|
-
|
81
|
+
|
82
|
+
def generate_answer(
|
83
|
+
question: str, pdf_bytes: bytes, llm_model: BaseChatModel
|
84
|
+
) -> Dict[str, Any]:
|
77
85
|
"""
|
78
86
|
Generate an answer for a question using retrieval augmented generation on PDF content.
|
79
87
|
|
@@ -92,9 +100,7 @@ def generate_answer(question: str, pdf_bytes: bytes, llm_model: BaseChatModel) -
|
|
92
100
|
text = extract_text_from_pdf_data(pdf_bytes)
|
93
101
|
logger.info("Extracted text from PDF.")
|
94
102
|
text_splitter = CharacterTextSplitter(
|
95
|
-
separator="\n",
|
96
|
-
chunk_size=cfg.chunk_size,
|
97
|
-
chunk_overlap=cfg.chunk_overlap
|
103
|
+
separator="\n", chunk_size=cfg.chunk_size, chunk_overlap=cfg.chunk_overlap
|
98
104
|
)
|
99
105
|
chunks = text_splitter.split_text(text)
|
100
106
|
documents: List[Document] = [Document(page_content=chunk) for chunk in chunks]
|
@@ -102,10 +108,7 @@ def generate_answer(question: str, pdf_bytes: bytes, llm_model: BaseChatModel) -
|
|
102
108
|
|
103
109
|
embeddings = OpenAIEmbeddings(openai_api_key=cfg.openai_api_key)
|
104
110
|
vector_store = Annoy.from_documents(documents, embeddings)
|
105
|
-
search_results = vector_store.similarity_search(
|
106
|
-
question,
|
107
|
-
k=cfg.num_retrievals
|
108
|
-
)
|
111
|
+
search_results = vector_store.similarity_search(question, k=cfg.num_retrievals)
|
109
112
|
logger.info("Retrieved %d relevant document chunks.", len(search_results))
|
110
113
|
# Use the provided llm_model to build the QA chain.
|
111
114
|
qa_chain = load_qa_chain(llm_model, chain_type=cfg.qa_chain_type)
|
@@ -114,6 +117,49 @@ def generate_answer(question: str, pdf_bytes: bytes, llm_model: BaseChatModel) -
|
|
114
117
|
)
|
115
118
|
return answer
|
116
119
|
|
120
|
+
|
121
|
+
def generate_answer2(
|
122
|
+
question: str, pdf_url: str, text_embedding_model: Embeddings
|
123
|
+
) -> Dict[str, Any]:
|
124
|
+
"""
|
125
|
+
Generate an answer for a question using retrieval augmented generation on PDF content.
|
126
|
+
|
127
|
+
This function extracts text from the PDF data, splits the text into manageable chunks,
|
128
|
+
performs a similarity search to retrieve the most relevant segments, and then uses a
|
129
|
+
question-answering chain (built using the provided llm_model) to generate an answer.
|
130
|
+
|
131
|
+
Args:
|
132
|
+
question (str): The question to be answered.
|
133
|
+
pdf_bytes (bytes): The binary content of the PDF document.
|
134
|
+
llm_model (BaseChatModel): The language model instance to use for answering.
|
135
|
+
|
136
|
+
Returns:
|
137
|
+
Dict[str, Any]: A dictionary containing the answer generated by the language model.
|
138
|
+
"""
|
139
|
+
# text = extract_text_from_pdf_data(pdf_bytes)
|
140
|
+
# logger.info("Extracted text from PDF.")
|
141
|
+
logger.log(logging.INFO, "searching the article with the question: %s", question)
|
142
|
+
# Load the article
|
143
|
+
# loader = PyPDFLoader(state['pdf_file_name'])
|
144
|
+
# loader = PyPDFLoader("https://arxiv.org/pdf/2310.08365")
|
145
|
+
loader = PyPDFLoader(pdf_url)
|
146
|
+
# Load the pages of the article
|
147
|
+
pages = []
|
148
|
+
for page in loader.lazy_load():
|
149
|
+
pages.append(page)
|
150
|
+
# Set up text embedding model
|
151
|
+
# text_embedding_model = state['text_embedding_model']
|
152
|
+
# text_embedding_model = OpenAIEmbeddings(openai_api_key=cfg.openai_api_key)
|
153
|
+
logging.info("Loaded text embedding model %s", text_embedding_model)
|
154
|
+
# Create a vector store from the pages
|
155
|
+
vector_store = InMemoryVectorStore.from_documents(pages, text_embedding_model)
|
156
|
+
# Search the article with the question
|
157
|
+
docs = vector_store.similarity_search(question)
|
158
|
+
# Return the content of the pages
|
159
|
+
return "\n".join([doc.page_content for doc in docs])
|
160
|
+
# return answer
|
161
|
+
|
162
|
+
|
117
163
|
@tool(args_schema=QuestionAndAnswerInput)
|
118
164
|
def question_and_answer_tool(
|
119
165
|
question: str,
|
@@ -124,7 +170,7 @@ def question_and_answer_tool(
|
|
124
170
|
Answer a question using PDF content stored in the state via retrieval augmented generation.
|
125
171
|
|
126
172
|
This tool retrieves the PDF binary data from the state (under the key "pdf_data"), extracts its
|
127
|
-
textual content, and generates an answer to the specified question. It also extracts the
|
173
|
+
textual content, and generates an answer to the specified question. It also extracts the
|
128
174
|
llm_model (of type BaseChatModel) from the state to use for answering.
|
129
175
|
|
130
176
|
Args:
|
@@ -138,15 +184,15 @@ def question_and_answer_tool(
|
|
138
184
|
Dict[str, Any]: A dictionary containing the generated answer or an error message.
|
139
185
|
"""
|
140
186
|
logger.info("Starting PDF Question and Answer tool using PDF data from state.")
|
187
|
+
# print (state['text_embedding_model'])
|
188
|
+
text_embedding_model = state["text_embedding_model"]
|
141
189
|
pdf_state = state.get("pdf_data")
|
142
190
|
if not pdf_state:
|
143
191
|
error_msg = "No pdf_data found in state."
|
144
192
|
logger.error(error_msg)
|
145
193
|
return Command(
|
146
194
|
update={
|
147
|
-
"messages": [
|
148
|
-
ToolMessage(content=error_msg, tool_call_id=tool_call_id)
|
149
|
-
]
|
195
|
+
"messages": [ToolMessage(content=error_msg, tool_call_id=tool_call_id)]
|
150
196
|
}
|
151
197
|
)
|
152
198
|
pdf_bytes = pdf_state.get("pdf_object")
|
@@ -155,16 +201,17 @@ def question_and_answer_tool(
|
|
155
201
|
logger.error(error_msg)
|
156
202
|
return Command(
|
157
203
|
update={
|
158
|
-
"messages": [
|
159
|
-
ToolMessage(content=error_msg, tool_call_id=tool_call_id)
|
160
|
-
]
|
204
|
+
"messages": [ToolMessage(content=error_msg, tool_call_id=tool_call_id)]
|
161
205
|
}
|
162
206
|
)
|
207
|
+
pdf_url = pdf_state.get("pdf_url")
|
163
208
|
# Retrieve llm_model from state; use a default if not provided.
|
164
209
|
llm_model = state.get("llm_model")
|
165
210
|
if not llm_model:
|
166
211
|
logger.error("Missing LLM model instance in state.")
|
167
212
|
return {"error": "No LLM model found in state."}
|
168
|
-
answer = generate_answer(question, pdf_bytes, llm_model)
|
169
|
-
|
213
|
+
# answer = generate_answer(question, pdf_bytes, llm_model)
|
214
|
+
print(pdf_url)
|
215
|
+
answer = generate_answer2(question, pdf_url, text_embedding_model)
|
216
|
+
# logger.info("Generated answer: %s", answer)
|
170
217
|
return answer
|