aiagents4pharma 1.29.0__py3-none-any.whl → 1.30.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2scholars/agents/__init__.py +1 -0
- aiagents4pharma/talk2scholars/agents/main_agent.py +18 -10
- aiagents4pharma/talk2scholars/agents/paper_download_agent.py +85 -0
- aiagents4pharma/talk2scholars/agents/pdf_agent.py +4 -10
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +18 -9
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +2 -2
- aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +1 -0
- aiagents4pharma/talk2scholars/configs/config.yaml +2 -0
- aiagents4pharma/talk2scholars/configs/tools/download_arxiv_paper/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +6 -1
- aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +7 -1
- aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +6 -1
- aiagents4pharma/talk2scholars/configs/tools/zotero_read/default.yaml +1 -1
- aiagents4pharma/talk2scholars/state/state_talk2scholars.py +4 -0
- aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py +84 -53
- aiagents4pharma/talk2scholars/tests/test_main_agent.py +24 -0
- aiagents4pharma/talk2scholars/tests/test_paper_download_agent.py +142 -0
- aiagents4pharma/talk2scholars/tests/test_paper_download_tools.py +154 -0
- aiagents4pharma/talk2scholars/tests/test_question_and_answer_tool.py +79 -15
- aiagents4pharma/talk2scholars/tests/test_routing_logic.py +12 -8
- aiagents4pharma/talk2scholars/tests/test_s2_multi.py +27 -4
- aiagents4pharma/talk2scholars/tests/test_s2_search.py +19 -3
- aiagents4pharma/talk2scholars/tests/test_s2_single.py +27 -3
- aiagents4pharma/talk2scholars/tests/test_zotero_read.py +17 -10
- aiagents4pharma/talk2scholars/tools/paper_download/__init__.py +17 -0
- aiagents4pharma/talk2scholars/tools/paper_download/abstract_downloader.py +45 -0
- aiagents4pharma/talk2scholars/tools/paper_download/arxiv_downloader.py +115 -0
- aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +64 -0
- aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +73 -26
- aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +46 -22
- aiagents4pharma/talk2scholars/tools/s2/query_results.py +1 -1
- aiagents4pharma/talk2scholars/tools/s2/search.py +40 -12
- aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +42 -16
- aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +33 -16
- aiagents4pharma/talk2scholars/tools/zotero/zotero_write.py +39 -7
- {aiagents4pharma-1.29.0.dist-info → aiagents4pharma-1.30.1.dist-info}/METADATA +2 -2
- {aiagents4pharma-1.29.0.dist-info → aiagents4pharma-1.30.1.dist-info}/RECORD +41 -32
- {aiagents4pharma-1.29.0.dist-info → aiagents4pharma-1.30.1.dist-info}/WHEEL +1 -1
- {aiagents4pharma-1.29.0.dist-info → aiagents4pharma-1.30.1.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.29.0.dist-info → aiagents4pharma-1.30.1.dist-info}/top_level.txt +0 -0
```diff
@@ -92,7 +92,7 @@ def dummy_requests_get_success(url, params, timeout):
         {
             "paperId": "paper1",
             "title": "Recommended Paper 1",
-            "authors": ["Author A"],
+            "authors": [{"name": "Author A", "authorId": "A1"}],
             "year": 2020,
             "citationCount": 15,
             "url": "http://paper1",
```
```diff
@@ -101,7 +101,7 @@ def dummy_requests_get_success(url, params, timeout):
         {
             "paperId": "paper2",
             "title": "Recommended Paper 2",
-            "authors": ["Author B"],
+            "authors": [{"name": "Author B", "authorId": "B1"}],
             "year": 2021,
             "citationCount": 25,
             "url": "http://paper2",
```
```diff
@@ -269,6 +269,30 @@ def test_single_paper_rec_requests_exception(monkeypatch):
     }
     with pytest.raises(
         RuntimeError,
-        match="Failed to connect to Semantic Scholar API
+        match="Failed to connect to Semantic Scholar API after 10 attempts."
+        "Please retry the same query.",
+    ):
+        get_single_paper_recommendations.run(input_data)
+
+
+def test_single_paper_rec_no_response(monkeypatch):
+    """
+    Test that get_single_paper_recommendations raises a RuntimeError
+    when no response is obtained from the API.
+
+    This is simulated by patching 'range' in the underlying function's globals
+    to return an empty iterator, so the for-loop never iterates and response remains None.
+    """
+    # Patch 'range' in the underlying function's globals (accessed via .func.__globals__)
+    monkeypatch.setitem(
+        get_single_paper_recommendations.func.__globals__, "range", lambda x: iter([])
+    )
+    tool_call_id = "test_tool_call_id"
+    input_data = {
+        "paper_id": "12345",
+        "tool_call_id": tool_call_id,
+    }
+    with pytest.raises(
+        RuntimeError, match="Failed to obtain a response from the Semantic Scholar API."
     ):
         get_single_paper_recommendations.run(input_data)
```
```diff
@@ -26,7 +26,7 @@ dummy_cfg = SimpleNamespace(tools=SimpleNamespace(zotero_read=dummy_zotero_read_
 
 
 class TestZoteroSearchTool(unittest.TestCase):
-    """
+    """Tests for Zotero search tool."""
 
     @patch(
         "aiagents4pharma.talk2scholars.tools.zotero.zotero_read.get_item_collections"
```
```diff
@@ -41,7 +41,7 @@ class TestZoteroSearchTool(unittest.TestCase):
         mock_zotero_class,
         mock_get_item_collections,
     ):
-        """
+        """Test valid query returns correct Command output."""
         # Setup Hydra mocks
         mock_hydra_compose.return_value = dummy_cfg
         mock_hydra_init.return_value.__enter__.return_value = None
```
```diff
@@ -116,7 +116,7 @@ class TestZoteroSearchTool(unittest.TestCase):
         mock_zotero_class,
         mock_get_item_collections,
     ):
-        """
+        """Test empty query fetches all items."""
         mock_hydra_compose.return_value = dummy_cfg
         mock_hydra_init.return_value.__enter__.return_value = None
 
```
```diff
@@ -166,7 +166,7 @@ class TestZoteroSearchTool(unittest.TestCase):
         mock_zotero_class,
         mock_get_item_collections,
     ):
-        """
+        """Test no items returned from Zotero."""
         mock_hydra_compose.return_value = dummy_cfg
         mock_hydra_init.return_value.__enter__.return_value = None
 
```
```diff
@@ -199,7 +199,10 @@ class TestZoteroSearchTool(unittest.TestCase):
         mock_zotero_class,
         mock_get_item_collections,
     ):
-        """
+        """
+        Test that when non-research items (e.g. attachments, notes) are returned,
+        they are still included since filtering is disabled.
+        """
         mock_hydra_compose.return_value = dummy_cfg
         mock_hydra_init.return_value.__enter__.return_value = None
 
```
```diff
@@ -240,9 +243,13 @@ class TestZoteroSearchTool(unittest.TestCase):
             "tool_call_id": tool_call_id,
             "limit": 2,
         }
-
-
-
+        # Instead of expecting a RuntimeError, we now expect both items to be returned.
+        result = zotero_search_tool.run(tool_input)
+        update = result.update
+        filtered_papers = update["zotero_read"]
+        self.assertIn("paper1", filtered_papers)
+        self.assertIn("paper2", filtered_papers)
+        self.assertEqual(len(filtered_papers), 2)
 
     @patch(
         "aiagents4pharma.talk2scholars.tools.zotero.zotero_read.get_item_collections"
```
```diff
@@ -257,7 +264,7 @@ class TestZoteroSearchTool(unittest.TestCase):
         mock_zotero_class,
         mock_get_item_collections,
     ):
-        """
+        """Test items API exception is properly raised."""
         mock_hydra_compose.return_value = dummy_cfg
         mock_hydra_init.return_value.__enter__.return_value = None
         mock_get_item_collections.return_value = {}
```
```diff
@@ -306,7 +313,7 @@ class TestZoteroSearchTool(unittest.TestCase):
                     "url": "http://example.com",
                     "itemType": "journalArticle",
                 }
-            },  #
+            },  # Missing 'key' field
             {
                 "data": {
                     "key": "paper_valid",
```
```diff
@@ -0,0 +1,17 @@
+#!/usr/bin/env python3
+"""
+This package provides modules for fetching and downloading academic papers from arXiv.
+"""
+
+# Import modules
+from . import abstract_downloader
+from . import arxiv_downloader
+from . import download_arxiv_input
+from .download_arxiv_input import download_arxiv_paper
+
+__all__ = [
+    "abstract_downloader",
+    "arxiv_downloader",
+    "download_arxiv_input",
+    "download_arxiv_paper",
+]
```
```diff
@@ -0,0 +1,45 @@
+"""
+Abstract Base Class for Paper Downloaders.
+
+This module defines the `AbstractPaperDownloader` class, which serves as a
+base class for downloading scholarly papers from different sources
+(e.g., arXiv, PubMed, IEEE Xplore). Any specific downloader should
+inherit from this class and implement its methods.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict
+
+
+class AbstractPaperDownloader(ABC):
+    """
+    Abstract base class for scholarly paper downloaders.
+
+    This is designed to be extended for different paper sources
+    like arXiv, PubMed, IEEE Xplore, etc. Each implementation
+    must define methods for fetching metadata and downloading PDFs.
+    """
+
+    @abstractmethod
+    def fetch_metadata(self, paper_id: str) -> Dict[str, Any]:
+        """
+        Fetch metadata for a given paper ID.
+
+        Args:
+            paper_id (str): The unique identifier for the paper.
+
+        Returns:
+            Dict[str, Any]: The metadata dictionary (format depends on the data source).
+        """
+
+    @abstractmethod
+    def download_pdf(self, paper_id: str) -> bytes:
+        """
+        Download the PDF for a given paper ID.
+
+        Args:
+            paper_id (str): The unique identifier for the paper.
+
+        Returns:
+            bytes: The binary content of the downloaded PDF.
+        """
```
```diff
@@ -0,0 +1,115 @@
+"""
+Arxiv Paper Downloader
+
+This module provides an implementation of `AbstractPaperDownloader` for arXiv.
+It connects to the arXiv API, retrieves metadata for a research paper, and
+downloads the corresponding PDF.
+
+By using an abstract base class, this implementation is extendable to other
+APIs like PubMed, IEEE Xplore, etc.
+"""
+
+import xml.etree.ElementTree as ET
+from typing import Any, Dict
+import logging
+import hydra
+import requests
+from .abstract_downloader import AbstractPaperDownloader
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class ArxivPaperDownloader(AbstractPaperDownloader):
+    """
+    Downloader class for arXiv.
+
+    This class interfaces with the arXiv API to fetch metadata
+    and retrieve PDFs of academic papers based on their arXiv IDs.
+    """
+
+    def __init__(self):
+        """
+        Initializes the arXiv paper downloader.
+
+        Uses Hydra for configuration management to retrieve API details.
+        """
+        with hydra.initialize(version_base=None, config_path="../../configs"):
+            cfg = hydra.compose(
+                config_name="config", overrides=["tools/download_arxiv_paper=default"]
+            )
+            self.api_url = cfg.tools.download_arxiv_paper.api_url
+            self.request_timeout = cfg.tools.download_arxiv_paper.request_timeout
+            self.chunk_size = cfg.tools.download_arxiv_paper.chunk_size
+            self.pdf_base_url = cfg.tools.download_arxiv_paper.pdf_base_url
+
+    def fetch_metadata(self, paper_id: str) -> Dict[str, Any]:
+        """
+        Fetch metadata from arXiv for a given paper ID.
+
+        Args:
+            paper_id (str): The arXiv ID of the paper.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing metadata, including the XML response.
+        """
+        logger.info("Fetching metadata from arXiv for paper ID: %s", paper_id)
+        api_url = f"{self.api_url}?search_query=id:{paper_id}&start=0&max_results=1"
+        response = requests.get(api_url, timeout=self.request_timeout)
+        response.raise_for_status()
+        return {"xml": response.text}
+
+    def download_pdf(self, paper_id: str) -> Dict[str, Any]:
+        """
+        Download the PDF of a paper from arXiv.
+
+        This function first retrieves the paper's metadata to locate the PDF link
+        before downloading the file.
+
+        Args:
+            paper_id (str): The arXiv ID of the paper.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing:
+                - `pdf_object`: The binary content of the downloaded PDF.
+                - `pdf_url`: The URL from which the PDF was fetched.
+                - `arxiv_id`: The arXiv ID of the downloaded paper.
+        """
+        metadata = self.fetch_metadata(paper_id)
+
+        # Parse the XML response to locate the PDF link.
+        root = ET.fromstring(metadata["xml"])
+        ns = {"atom": "http://www.w3.org/2005/Atom"}
+        pdf_url = next(
+            (
+                link.attrib.get("href")
+                for entry in root.findall("atom:entry", ns)
+                for link in entry.findall("atom:link", ns)
+                if link.attrib.get("title") == "pdf"
+            ),
+            None,
+        )
+
+        if not pdf_url:
+            raise RuntimeError(f"Failed to download PDF for arXiv ID {paper_id}.")
+
+        logger.info("Downloading PDF from: %s", pdf_url)
+        pdf_response = requests.get(pdf_url, stream=True, timeout=self.request_timeout)
+        pdf_response.raise_for_status()
+        # print (pdf_response)
+
+        # Combine the PDF data from chunks.
+        pdf_object = b"".join(
+            chunk
+            for chunk in pdf_response.iter_content(chunk_size=self.chunk_size)
+            if chunk
+        )
+        # print (pdf_object)
+        print("PDF_URL", pdf_url)
+
+        return {
+            "pdf_object": pdf_object,
+            "pdf_url": pdf_url,
+            "arxiv_id": paper_id,
+        }
```
```diff
@@ -0,0 +1,64 @@
+# File: aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py
+"""
+This module defines the `download_arxiv_paper` tool, which leverages the
+`ArxivPaperDownloader` class to fetch and download academic papers from arXiv
+based on their unique arXiv ID.
+"""
+from typing import Annotated, Any
+from pydantic import BaseModel, Field
+from langchain_core.tools import tool
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from langgraph.types import Command
+
+# Local import from the same package:
+from .arxiv_downloader import ArxivPaperDownloader
+
+
+class DownloadArxivPaperInput(BaseModel):
+    """
+    Input schema for the arXiv paper download tool.
+    (Optional: if you decide to keep Pydantic validation in the future)
+    """
+
+    arxiv_id: str = Field(
+        description="The arXiv paper ID used to retrieve the paper details and PDF."
+    )
+    tool_call_id: Annotated[str, InjectedToolCallId]
+
+
+@tool(args_schema=DownloadArxivPaperInput, parse_docstring=True)
+def download_arxiv_paper(
+    arxiv_id: str,
+    tool_call_id: Annotated[str, InjectedToolCallId],
+) -> Command[Any]:
+    """
+    Download an arXiv paper's PDF using its unique arXiv ID.
+
+    This function:
+    1. Creates an `ArxivPaperDownloader` instance.
+    2. Fetches metadata from arXiv using the provided `arxiv_id`.
+    3. Downloads the PDF from the returned link.
+    4. Returns a `Command` object containing the PDF data and a success message.
+
+    Args:
+        arxiv_id (str): The unique arXiv paper ID.
+        tool_call_id (InjectedToolCallId): A unique identifier for tracking this tool call.
+
+    Returns:
+        Command[Any]: Contains metadata and messages about the success of the operation.
+    """
+    downloader = ArxivPaperDownloader()
+
+    # If the downloader fails or the arxiv_id is invalid, this might raise an error
+    pdf_data = downloader.download_pdf(arxiv_id)
+    # print (pdf_data)
+
+    content = f"Successfully downloaded PDF for arXiv ID {arxiv_id}"
+
+    return Command(
+        update={
+            "pdf_data": pdf_data,
+            "messages": [ToolMessage(content=content, tool_call_id=tool_call_id)],
+        }
+    )
```
```diff
@@ -2,8 +2,8 @@
 """
 question_and_answer: Tool for performing Q&A on PDF documents using retrieval augmented generation.
 
-This module provides functionality to extract text from PDF binary data, split it into
-chunks, retrieve relevant segments via a vector store, and generate an answer to a
+This module provides functionality to extract text from PDF binary data, split it into
+chunks, retrieve relevant segments via a vector store, and generate an answer to a
 user-provided question using a language model chain.
 """
 
```
```diff
@@ -18,13 +18,15 @@ import hydra
 from langchain.chains.question_answering import load_qa_chain
 from langchain.docstore.document import Document
 from langchain.text_splitter import CharacterTextSplitter
-from langchain_community.vectorstores import Annoy
-from langchain_openai import OpenAIEmbeddings
 from langchain_core.language_models.chat_models import BaseChatModel
-
+from langchain_core.vectorstores import InMemoryVectorStore
 from langchain_core.messages import ToolMessage
 from langchain_core.tools import tool
 from langchain_core.tools.base import InjectedToolCallId
+from langchain_core.embeddings import Embeddings
+from langchain_community.vectorstores import Annoy
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_openai import OpenAIEmbeddings
 from langgraph.types import Command
 from langgraph.prebuilt import InjectedState
 
```
```diff
@@ -35,10 +37,13 @@ logger.setLevel(logging.INFO)
 
 # Load configuration using Hydra.
 with hydra.initialize(version_base=None, config_path="../../configs"):
-    cfg = hydra.compose(
+    cfg = hydra.compose(
+        config_name="config", overrides=["tools/question_and_answer=default"]
+    )
     cfg = cfg.tools.question_and_answer
     logger.info("Loaded Question and Answer tool configuration.")
 
+
 class QuestionAndAnswerInput(BaseModel):
     """
     Input schema for the PDF Question and Answer tool.
```
```diff
@@ -47,12 +52,12 @@ class QuestionAndAnswerInput(BaseModel):
         question (str): The question to ask regarding the PDF content.
         tool_call_id (str): Unique identifier for the tool call, injected automatically.
     """
-
-
-    )
+
+    question: str = Field(description="The question to ask regarding the PDF content.")
     tool_call_id: Annotated[str, InjectedToolCallId]
     state: Annotated[dict, InjectedState]
 
+
 def extract_text_from_pdf_data(pdf_bytes: bytes) -> str:
     """
     Extract text content from PDF binary data.
```
```diff
@@ -73,7 +78,10 @@ def extract_text_from_pdf_data(pdf_bytes: bytes) -> str:
             text += page_text
     return text
 
-
+
+def generate_answer(
+    question: str, pdf_bytes: bytes, llm_model: BaseChatModel
+) -> Dict[str, Any]:
     """
     Generate an answer for a question using retrieval augmented generation on PDF content.
 
```
```diff
@@ -92,9 +100,7 @@ def generate_answer(question: str, pdf_bytes: bytes, llm_model: BaseChatModel) -
     text = extract_text_from_pdf_data(pdf_bytes)
     logger.info("Extracted text from PDF.")
     text_splitter = CharacterTextSplitter(
-        separator="\n",
-        chunk_size=cfg.chunk_size,
-        chunk_overlap=cfg.chunk_overlap
+        separator="\n", chunk_size=cfg.chunk_size, chunk_overlap=cfg.chunk_overlap
     )
     chunks = text_splitter.split_text(text)
     documents: List[Document] = [Document(page_content=chunk) for chunk in chunks]
```
```diff
@@ -102,10 +108,7 @@ def generate_answer(question: str, pdf_bytes: bytes, llm_model: BaseChatModel) -
 
     embeddings = OpenAIEmbeddings(openai_api_key=cfg.openai_api_key)
     vector_store = Annoy.from_documents(documents, embeddings)
-    search_results = vector_store.similarity_search(
-        question,
-        k=cfg.num_retrievals
-    )
+    search_results = vector_store.similarity_search(question, k=cfg.num_retrievals)
     logger.info("Retrieved %d relevant document chunks.", len(search_results))
     # Use the provided llm_model to build the QA chain.
     qa_chain = load_qa_chain(llm_model, chain_type=cfg.qa_chain_type)
```
```diff
@@ -114,6 +117,49 @@ def generate_answer(question: str, pdf_bytes: bytes, llm_model: BaseChatModel) -
     )
     return answer
 
+
+def generate_answer2(
+    question: str, pdf_url: str, text_embedding_model: Embeddings
+) -> Dict[str, Any]:
+    """
+    Generate an answer for a question using retrieval augmented generation on PDF content.
+
+    This function extracts text from the PDF data, splits the text into manageable chunks,
+    performs a similarity search to retrieve the most relevant segments, and then uses a
+    question-answering chain (built using the provided llm_model) to generate an answer.
+
+    Args:
+        question (str): The question to be answered.
+        pdf_bytes (bytes): The binary content of the PDF document.
+        llm_model (BaseChatModel): The language model instance to use for answering.
+
+    Returns:
+        Dict[str, Any]: A dictionary containing the answer generated by the language model.
+    """
+    # text = extract_text_from_pdf_data(pdf_bytes)
+    # logger.info("Extracted text from PDF.")
+    logger.log(logging.INFO, "searching the article with the question: %s", question)
+    # Load the article
+    # loader = PyPDFLoader(state['pdf_file_name'])
+    # loader = PyPDFLoader("https://arxiv.org/pdf/2310.08365")
+    loader = PyPDFLoader(pdf_url)
+    # Load the pages of the article
+    pages = []
+    for page in loader.lazy_load():
+        pages.append(page)
+    # Set up text embedding model
+    # text_embedding_model = state['text_embedding_model']
+    # text_embedding_model = OpenAIEmbeddings(openai_api_key=cfg.openai_api_key)
+    logging.info("Loaded text embedding model %s", text_embedding_model)
+    # Create a vector store from the pages
+    vector_store = InMemoryVectorStore.from_documents(pages, text_embedding_model)
+    # Search the article with the question
+    docs = vector_store.similarity_search(question)
+    # Return the content of the pages
+    return "\n".join([doc.page_content for doc in docs])
+    # return answer
+
+
 @tool(args_schema=QuestionAndAnswerInput)
 def question_and_answer_tool(
     question: str,
```
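Despite its name, `generate_answer2` as added above returns the concatenated text of the retrieved pages rather than an LLM-generated answer. A minimal sketch of calling it directly, assuming `OPENAI_API_KEY` is set in the environment; the URL reuses the one from the commented-out example in the hunk.

```python
# Sketch only: call generate_answer2 directly. As written in the diff above,
# it returns retrieved page text, not a chat-model answer.
from langchain_openai import OpenAIEmbeddings
from aiagents4pharma.talk2scholars.tools.pdf.question_and_answer import generate_answer2

context = generate_answer2(
    question="What problem does this paper address?",
    pdf_url="https://arxiv.org/pdf/2310.08365",
    text_embedding_model=OpenAIEmbeddings(),  # reads OPENAI_API_KEY from the environment
)
print(context[:500])  # first 500 characters of the retrieved context
```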
```diff
@@ -124,7 +170,7 @@ def question_and_answer_tool(
     Answer a question using PDF content stored in the state via retrieval augmented generation.
 
     This tool retrieves the PDF binary data from the state (under the key "pdf_data"), extracts its
-    textual content, and generates an answer to the specified question. It also extracts the
+    textual content, and generates an answer to the specified question. It also extracts the
     llm_model (of type BaseChatModel) from the state to use for answering.
 
     Args:
```
```diff
@@ -138,15 +184,15 @@ def question_and_answer_tool(
         Dict[str, Any]: A dictionary containing the generated answer or an error message.
     """
     logger.info("Starting PDF Question and Answer tool using PDF data from state.")
+    # print (state['text_embedding_model'])
+    text_embedding_model = state["text_embedding_model"]
     pdf_state = state.get("pdf_data")
     if not pdf_state:
         error_msg = "No pdf_data found in state."
         logger.error(error_msg)
         return Command(
             update={
-                "messages": [
-                    ToolMessage(content=error_msg, tool_call_id=tool_call_id)
-                ]
+                "messages": [ToolMessage(content=error_msg, tool_call_id=tool_call_id)]
             }
         )
     pdf_bytes = pdf_state.get("pdf_object")
```
```diff
@@ -155,16 +201,17 @@ def question_and_answer_tool(
         logger.error(error_msg)
         return Command(
             update={
-                "messages": [
-                    ToolMessage(content=error_msg, tool_call_id=tool_call_id)
-                ]
+                "messages": [ToolMessage(content=error_msg, tool_call_id=tool_call_id)]
             }
         )
+    pdf_url = pdf_state.get("pdf_url")
     # Retrieve llm_model from state; use a default if not provided.
     llm_model = state.get("llm_model")
     if not llm_model:
         logger.error("Missing LLM model instance in state.")
         return {"error": "No LLM model found in state."}
-    answer = generate_answer(question, pdf_bytes, llm_model)
-
+    # answer = generate_answer(question, pdf_bytes, llm_model)
+    print(pdf_url)
+    answer = generate_answer2(question, pdf_url, text_embedding_model)
+    # logger.info("Generated answer: %s", answer)
     return answer
```
```diff
@@ -30,7 +30,7 @@ class MultiPaperRecInput(BaseModel):
         description="List of Semantic Scholar Paper IDs to get recommendations for"
     )
     limit: int = Field(
-        default=
+        default=10,
         description="Maximum total number of recommendations to return",
         ge=1,
         le=500,
```
```diff
@@ -90,23 +90,33 @@ def get_multi_paper_recommendations(
         params["year"] = year
 
     # Wrap API call in try/except to catch connectivity issues and validate response format
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    response = None
+    for attempt in range(10):
+        try:
+            response = requests.post(
+                endpoint,
+                headers=headers,
+                params=params,
+                data=json.dumps(payload),
+                timeout=cfg.request_timeout,
+            )
+            response.raise_for_status()  # Raises HTTPError for bad responses
+            break  # Exit loop if request is successful
+        except requests.exceptions.RequestException as e:
+            logger.error(
+                "Attempt %d: Failed to connect to Semantic Scholar API for "
+                "multi-paper recommendations: %s",
+                attempt + 1,
+                e,
+            )
+            if attempt == 9:  # Last attempt
+                raise RuntimeError(
+                    "Failed to connect to Semantic Scholar API after 10 attempts."
+                    "Please retry the same query."
+                ) from e
+
+    if response is None:
+        raise RuntimeError("Failed to obtain a response from the Semantic Scholar API.")
 
     logger.info(
         "API Response Status for multi-paper recommendations: %s", response.status_code
```
```diff
@@ -137,11 +147,22 @@ def get_multi_paper_recommendations(
     # Create a dictionary to store the papers
     filtered_papers = {
         paper["paperId"]: {
-            "
+            "semantic_scholar_paper_id": paper["paperId"],
             "Title": paper.get("title", "N/A"),
             "Abstract": paper.get("abstract", "N/A"),
             "Year": paper.get("year", "N/A"),
+            "Publication Date": paper.get("publicationDate", "N/A"),
+            "Venue": paper.get("venue", "N/A"),
+            # "Publication Venue": (paper.get("publicationVenue") or {}).get("name", "N/A"),
+            # "Venue Type": (paper.get("publicationVenue") or {}).get("name", "N/A"),
+            "Journal Name": (paper.get("journal") or {}).get("name", "N/A"),
+            # "Journal Volume": paper.get("journal", {}).get("volume", "N/A"),
+            # "Journal Pages": paper.get("journal", {}).get("pages", "N/A"),
             "Citation Count": paper.get("citationCount", "N/A"),
+            "Authors": [
+                f"{author.get('name', 'N/A')} (ID: {author.get('authorId', 'N/A')})"
+                for author in paper.get("authors", [])
+            ],
             "URL": paper.get("url", "N/A"),
             "arxiv_id": paper.get("externalIds", {}).get("ArXiv", "N/A"),
         }
```
```diff
@@ -153,7 +174,10 @@ def get_multi_paper_recommendations(
     top_papers = list(filtered_papers.values())[:3]
     top_papers_info = "\n".join(
         [
-            f"{i+1}. {paper['Title']} ({paper['Year']})"
+            # f"{i+1}. {paper['Title']} ({paper['Year']})"
+            f"{i+1}. {paper['Title']} ({paper['Year']}; "
+            f"semantic_scholar_paper_id: {paper['semantic_scholar_paper_id']}; "
+            f"arXiv ID: {paper['arxiv_id']})"
             for i, paper in enumerate(top_papers)
         ]
     )
```
```diff
@@ -165,10 +189,10 @@ def get_multi_paper_recommendations(
         "Papers are attached as an artifact."
     )
     content += " Here is a summary of the recommendations:\n"
-    content += f"Number of papers found: {len(filtered_papers)}\n"
+    content += f"Number of recommended papers found: {len(filtered_papers)}\n"
    content += f"Query Paper IDs: {', '.join(paper_ids)}\n"
     content += f"Year: {year}\n" if year else ""
-    content += "
+    content += "Here are a few of these papers:\n" + top_papers_info
 
     return Command(
         update={
```