aiagents4pharma 1.29.0__py3-none-any.whl → 1.30.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. aiagents4pharma/talk2scholars/agents/__init__.py +1 -0
  2. aiagents4pharma/talk2scholars/agents/main_agent.py +18 -10
  3. aiagents4pharma/talk2scholars/agents/paper_download_agent.py +85 -0
  4. aiagents4pharma/talk2scholars/agents/pdf_agent.py +4 -10
  5. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +18 -9
  6. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/__init__.py +3 -0
  7. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +2 -2
  8. aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +1 -0
  9. aiagents4pharma/talk2scholars/configs/config.yaml +2 -0
  10. aiagents4pharma/talk2scholars/configs/tools/download_arxiv_paper/__init__.py +3 -0
  11. aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +6 -1
  12. aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +7 -1
  13. aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +6 -1
  14. aiagents4pharma/talk2scholars/configs/tools/zotero_read/default.yaml +1 -1
  15. aiagents4pharma/talk2scholars/state/state_talk2scholars.py +4 -0
  16. aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py +84 -53
  17. aiagents4pharma/talk2scholars/tests/test_main_agent.py +24 -0
  18. aiagents4pharma/talk2scholars/tests/test_paper_download_agent.py +142 -0
  19. aiagents4pharma/talk2scholars/tests/test_paper_download_tools.py +154 -0
  20. aiagents4pharma/talk2scholars/tests/test_question_and_answer_tool.py +79 -15
  21. aiagents4pharma/talk2scholars/tests/test_routing_logic.py +12 -8
  22. aiagents4pharma/talk2scholars/tests/test_s2_multi.py +27 -4
  23. aiagents4pharma/talk2scholars/tests/test_s2_search.py +19 -3
  24. aiagents4pharma/talk2scholars/tests/test_s2_single.py +27 -3
  25. aiagents4pharma/talk2scholars/tests/test_zotero_read.py +17 -10
  26. aiagents4pharma/talk2scholars/tools/paper_download/__init__.py +17 -0
  27. aiagents4pharma/talk2scholars/tools/paper_download/abstract_downloader.py +45 -0
  28. aiagents4pharma/talk2scholars/tools/paper_download/arxiv_downloader.py +115 -0
  29. aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +64 -0
  30. aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +73 -26
  31. aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +46 -22
  32. aiagents4pharma/talk2scholars/tools/s2/query_results.py +1 -1
  33. aiagents4pharma/talk2scholars/tools/s2/search.py +40 -12
  34. aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +42 -16
  35. aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +33 -16
  36. aiagents4pharma/talk2scholars/tools/zotero/zotero_write.py +39 -7
  37. {aiagents4pharma-1.29.0.dist-info → aiagents4pharma-1.30.1.dist-info}/METADATA +2 -2
  38. {aiagents4pharma-1.29.0.dist-info → aiagents4pharma-1.30.1.dist-info}/RECORD +41 -32
  39. {aiagents4pharma-1.29.0.dist-info → aiagents4pharma-1.30.1.dist-info}/WHEEL +1 -1
  40. {aiagents4pharma-1.29.0.dist-info → aiagents4pharma-1.30.1.dist-info}/LICENSE +0 -0
  41. {aiagents4pharma-1.29.0.dist-info → aiagents4pharma-1.30.1.dist-info}/top_level.txt +0 -0
@@ -92,7 +92,7 @@ def dummy_requests_get_success(url, params, timeout):
92
92
  {
93
93
  "paperId": "paper1",
94
94
  "title": "Recommended Paper 1",
95
- "authors": ["Author A"],
95
+ "authors": [{"name": "Author A", "authorId": "A1"}],
96
96
  "year": 2020,
97
97
  "citationCount": 15,
98
98
  "url": "http://paper1",
@@ -101,7 +101,7 @@ def dummy_requests_get_success(url, params, timeout):
101
101
  {
102
102
  "paperId": "paper2",
103
103
  "title": "Recommended Paper 2",
104
- "authors": ["Author B"],
104
+ "authors": [{"name": "Author B", "authorId": "B1"}],
105
105
  "year": 2021,
106
106
  "citationCount": 25,
107
107
  "url": "http://paper2",
@@ -269,6 +269,30 @@ def test_single_paper_rec_requests_exception(monkeypatch):
269
269
  }
270
270
  with pytest.raises(
271
271
  RuntimeError,
272
- match="Failed to connect to Semantic Scholar API. Please retry the same query.",
272
+ match="Failed to connect to Semantic Scholar API after 10 attempts."
273
+ "Please retry the same query.",
274
+ ):
275
+ get_single_paper_recommendations.run(input_data)
276
+
277
+
278
+ def test_single_paper_rec_no_response(monkeypatch):
279
+ """
280
+ Test that get_single_paper_recommendations raises a RuntimeError
281
+ when no response is obtained from the API.
282
+
283
+ This is simulated by patching 'range' in the underlying function's globals
284
+ to return an empty iterator, so the for-loop never iterates and response remains None.
285
+ """
286
+ # Patch 'range' in the underlying function's globals (accessed via .func.__globals__)
287
+ monkeypatch.setitem(
288
+ get_single_paper_recommendations.func.__globals__, "range", lambda x: iter([])
289
+ )
290
+ tool_call_id = "test_tool_call_id"
291
+ input_data = {
292
+ "paper_id": "12345",
293
+ "tool_call_id": tool_call_id,
294
+ }
295
+ with pytest.raises(
296
+ RuntimeError, match="Failed to obtain a response from the Semantic Scholar API."
273
297
  ):
274
298
  get_single_paper_recommendations.run(input_data)
@@ -26,7 +26,7 @@ dummy_cfg = SimpleNamespace(tools=SimpleNamespace(zotero_read=dummy_zotero_read_
26
26
 
27
27
 
28
28
  class TestZoteroSearchTool(unittest.TestCase):
29
- """test for Zotero search tool"""
29
+ """Tests for Zotero search tool."""
30
30
 
31
31
  @patch(
32
32
  "aiagents4pharma.talk2scholars.tools.zotero.zotero_read.get_item_collections"
@@ -41,7 +41,7 @@ class TestZoteroSearchTool(unittest.TestCase):
41
41
  mock_zotero_class,
42
42
  mock_get_item_collections,
43
43
  ):
44
- """test valid query"""
44
+ """Test valid query returns correct Command output."""
45
45
  # Setup Hydra mocks
46
46
  mock_hydra_compose.return_value = dummy_cfg
47
47
  mock_hydra_init.return_value.__enter__.return_value = None
@@ -116,7 +116,7 @@ class TestZoteroSearchTool(unittest.TestCase):
116
116
  mock_zotero_class,
117
117
  mock_get_item_collections,
118
118
  ):
119
- """test empty query fetches all items"""
119
+ """Test empty query fetches all items."""
120
120
  mock_hydra_compose.return_value = dummy_cfg
121
121
  mock_hydra_init.return_value.__enter__.return_value = None
122
122
 
@@ -166,7 +166,7 @@ class TestZoteroSearchTool(unittest.TestCase):
166
166
  mock_zotero_class,
167
167
  mock_get_item_collections,
168
168
  ):
169
- """test no items returned from Zotero"""
169
+ """Test no items returned from Zotero."""
170
170
  mock_hydra_compose.return_value = dummy_cfg
171
171
  mock_hydra_init.return_value.__enter__.return_value = None
172
172
 
@@ -199,7 +199,10 @@ class TestZoteroSearchTool(unittest.TestCase):
199
199
  mock_zotero_class,
200
200
  mock_get_item_collections,
201
201
  ):
202
- """test no matching papers returned from Zotero"""
202
+ """
203
+ Test that when non-research items (e.g. attachments, notes) are returned,
204
+ they are still included since filtering is disabled.
205
+ """
203
206
  mock_hydra_compose.return_value = dummy_cfg
204
207
  mock_hydra_init.return_value.__enter__.return_value = None
205
208
 
@@ -240,9 +243,13 @@ class TestZoteroSearchTool(unittest.TestCase):
240
243
  "tool_call_id": tool_call_id,
241
244
  "limit": 2,
242
245
  }
243
- with self.assertRaises(RuntimeError) as context:
244
- zotero_search_tool.run(tool_input)
245
- self.assertIn("No matching papers returned from Zotero", str(context.exception))
246
+ # Instead of expecting a RuntimeError, we now expect both items to be returned.
247
+ result = zotero_search_tool.run(tool_input)
248
+ update = result.update
249
+ filtered_papers = update["zotero_read"]
250
+ self.assertIn("paper1", filtered_papers)
251
+ self.assertIn("paper2", filtered_papers)
252
+ self.assertEqual(len(filtered_papers), 2)
246
253
 
247
254
  @patch(
248
255
  "aiagents4pharma.talk2scholars.tools.zotero.zotero_read.get_item_collections"
@@ -257,7 +264,7 @@ class TestZoteroSearchTool(unittest.TestCase):
257
264
  mock_zotero_class,
258
265
  mock_get_item_collections,
259
266
  ):
260
- """test items API exception"""
267
+ """Test items API exception is properly raised."""
261
268
  mock_hydra_compose.return_value = dummy_cfg
262
269
  mock_hydra_init.return_value.__enter__.return_value = None
263
270
  mock_get_item_collections.return_value = {}
@@ -306,7 +313,7 @@ class TestZoteroSearchTool(unittest.TestCase):
306
313
  "url": "http://example.com",
307
314
  "itemType": "journalArticle",
308
315
  }
309
- }, # missing key triggers line 136
316
+ }, # Missing 'key' field
310
317
  {
311
318
  "data": {
312
319
  "key": "paper_valid",
@@ -0,0 +1,17 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ This package provides modules for fetching and downloading academic papers from arXiv.
4
+ """
5
+
6
+ # Import modules
7
+ from . import abstract_downloader
8
+ from . import arxiv_downloader
9
+ from . import download_arxiv_input
10
+ from .download_arxiv_input import download_arxiv_paper
11
+
12
+ __all__ = [
13
+ "abstract_downloader",
14
+ "arxiv_downloader",
15
+ "download_arxiv_input",
16
+ "download_arxiv_paper",
17
+ ]
@@ -0,0 +1,45 @@
1
+ """
2
+ Abstract Base Class for Paper Downloaders.
3
+
4
+ This module defines the `AbstractPaperDownloader` class, which serves as a
5
+ base class for downloading scholarly papers from different sources
6
+ (e.g., arXiv, PubMed, IEEE Xplore). Any specific downloader should
7
+ inherit from this class and implement its methods.
8
+ """
9
+
10
+ from abc import ABC, abstractmethod
11
+ from typing import Any, Dict
12
+
13
+
14
+ class AbstractPaperDownloader(ABC):
15
+ """
16
+ Abstract base class for scholarly paper downloaders.
17
+
18
+ This is designed to be extended for different paper sources
19
+ like arXiv, PubMed, IEEE Xplore, etc. Each implementation
20
+ must define methods for fetching metadata and downloading PDFs.
21
+ """
22
+
23
+ @abstractmethod
24
+ def fetch_metadata(self, paper_id: str) -> Dict[str, Any]:
25
+ """
26
+ Fetch metadata for a given paper ID.
27
+
28
+ Args:
29
+ paper_id (str): The unique identifier for the paper.
30
+
31
+ Returns:
32
+ Dict[str, Any]: The metadata dictionary (format depends on the data source).
33
+ """
34
+
35
+ @abstractmethod
36
+ def download_pdf(self, paper_id: str) -> bytes:
37
+ """
38
+ Download the PDF for a given paper ID.
39
+
40
+ Args:
41
+ paper_id (str): The unique identifier for the paper.
42
+
43
+ Returns:
44
+ bytes: The binary content of the downloaded PDF.
45
+ """
@@ -0,0 +1,115 @@
1
+ """
2
+ Arxiv Paper Downloader
3
+
4
+ This module provides an implementation of `AbstractPaperDownloader` for arXiv.
5
+ It connects to the arXiv API, retrieves metadata for a research paper, and
6
+ downloads the corresponding PDF.
7
+
8
+ By using an abstract base class, this implementation is extendable to other
9
+ APIs like PubMed, IEEE Xplore, etc.
10
+ """
11
+
12
+ import xml.etree.ElementTree as ET
13
+ from typing import Any, Dict
14
+ import logging
15
+ import hydra
16
+ import requests
17
+ from .abstract_downloader import AbstractPaperDownloader
18
+
19
+ # Configure logging
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ class ArxivPaperDownloader(AbstractPaperDownloader):
25
+ """
26
+ Downloader class for arXiv.
27
+
28
+ This class interfaces with the arXiv API to fetch metadata
29
+ and retrieve PDFs of academic papers based on their arXiv IDs.
30
+ """
31
+
32
+ def __init__(self):
33
+ """
34
+ Initializes the arXiv paper downloader.
35
+
36
+ Uses Hydra for configuration management to retrieve API details.
37
+ """
38
+ with hydra.initialize(version_base=None, config_path="../../configs"):
39
+ cfg = hydra.compose(
40
+ config_name="config", overrides=["tools/download_arxiv_paper=default"]
41
+ )
42
+ self.api_url = cfg.tools.download_arxiv_paper.api_url
43
+ self.request_timeout = cfg.tools.download_arxiv_paper.request_timeout
44
+ self.chunk_size = cfg.tools.download_arxiv_paper.chunk_size
45
+ self.pdf_base_url = cfg.tools.download_arxiv_paper.pdf_base_url
46
+
47
+ def fetch_metadata(self, paper_id: str) -> Dict[str, Any]:
48
+ """
49
+ Fetch metadata from arXiv for a given paper ID.
50
+
51
+ Args:
52
+ paper_id (str): The arXiv ID of the paper.
53
+
54
+ Returns:
55
+ Dict[str, Any]: A dictionary containing metadata, including the XML response.
56
+ """
57
+ logger.info("Fetching metadata from arXiv for paper ID: %s", paper_id)
58
+ api_url = f"{self.api_url}?search_query=id:{paper_id}&start=0&max_results=1"
59
+ response = requests.get(api_url, timeout=self.request_timeout)
60
+ response.raise_for_status()
61
+ return {"xml": response.text}
62
+
63
+ def download_pdf(self, paper_id: str) -> Dict[str, Any]:
64
+ """
65
+ Download the PDF of a paper from arXiv.
66
+
67
+ This function first retrieves the paper's metadata to locate the PDF link
68
+ before downloading the file.
69
+
70
+ Args:
71
+ paper_id (str): The arXiv ID of the paper.
72
+
73
+ Returns:
74
+ Dict[str, Any]: A dictionary containing:
75
+ - `pdf_object`: The binary content of the downloaded PDF.
76
+ - `pdf_url`: The URL from which the PDF was fetched.
77
+ - `arxiv_id`: The arXiv ID of the downloaded paper.
78
+ """
79
+ metadata = self.fetch_metadata(paper_id)
80
+
81
+ # Parse the XML response to locate the PDF link.
82
+ root = ET.fromstring(metadata["xml"])
83
+ ns = {"atom": "http://www.w3.org/2005/Atom"}
84
+ pdf_url = next(
85
+ (
86
+ link.attrib.get("href")
87
+ for entry in root.findall("atom:entry", ns)
88
+ for link in entry.findall("atom:link", ns)
89
+ if link.attrib.get("title") == "pdf"
90
+ ),
91
+ None,
92
+ )
93
+
94
+ if not pdf_url:
95
+ raise RuntimeError(f"Failed to download PDF for arXiv ID {paper_id}.")
96
+
97
+ logger.info("Downloading PDF from: %s", pdf_url)
98
+ pdf_response = requests.get(pdf_url, stream=True, timeout=self.request_timeout)
99
+ pdf_response.raise_for_status()
100
+ # print (pdf_response)
101
+
102
+ # Combine the PDF data from chunks.
103
+ pdf_object = b"".join(
104
+ chunk
105
+ for chunk in pdf_response.iter_content(chunk_size=self.chunk_size)
106
+ if chunk
107
+ )
108
+ # print (pdf_object)
109
+ print("PDF_URL", pdf_url)
110
+
111
+ return {
112
+ "pdf_object": pdf_object,
113
+ "pdf_url": pdf_url,
114
+ "arxiv_id": paper_id,
115
+ }
@@ -0,0 +1,64 @@
1
+ # File: aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py
2
+ """
3
+ This module defines the `download_arxiv_paper` tool, which leverages the
4
+ `ArxivPaperDownloader` class to fetch and download academic papers from arXiv
5
+ based on their unique arXiv ID.
6
+ """
7
+ from typing import Annotated, Any
8
+ from pydantic import BaseModel, Field
9
+ from langchain_core.tools import tool
10
+ from langchain_core.messages import ToolMessage
11
+ from langchain_core.tools.base import InjectedToolCallId
12
+ from langgraph.types import Command
13
+
14
+ # Local import from the same package:
15
+ from .arxiv_downloader import ArxivPaperDownloader
16
+
17
+
18
+ class DownloadArxivPaperInput(BaseModel):
19
+ """
20
+ Input schema for the arXiv paper download tool.
21
+ (Optional: if you decide to keep Pydantic validation in the future)
22
+ """
23
+
24
+ arxiv_id: str = Field(
25
+ description="The arXiv paper ID used to retrieve the paper details and PDF."
26
+ )
27
+ tool_call_id: Annotated[str, InjectedToolCallId]
28
+
29
+
30
+ @tool(args_schema=DownloadArxivPaperInput, parse_docstring=True)
31
+ def download_arxiv_paper(
32
+ arxiv_id: str,
33
+ tool_call_id: Annotated[str, InjectedToolCallId],
34
+ ) -> Command[Any]:
35
+ """
36
+ Download an arXiv paper's PDF using its unique arXiv ID.
37
+
38
+ This function:
39
+ 1. Creates an `ArxivPaperDownloader` instance.
40
+ 2. Fetches metadata from arXiv using the provided `arxiv_id`.
41
+ 3. Downloads the PDF from the returned link.
42
+ 4. Returns a `Command` object containing the PDF data and a success message.
43
+
44
+ Args:
45
+ arxiv_id (str): The unique arXiv paper ID.
46
+ tool_call_id (InjectedToolCallId): A unique identifier for tracking this tool call.
47
+
48
+ Returns:
49
+ Command[Any]: Contains metadata and messages about the success of the operation.
50
+ """
51
+ downloader = ArxivPaperDownloader()
52
+
53
+ # If the downloader fails or the arxiv_id is invalid, this might raise an error
54
+ pdf_data = downloader.download_pdf(arxiv_id)
55
+ # print (pdf_data)
56
+
57
+ content = f"Successfully downloaded PDF for arXiv ID {arxiv_id}"
58
+
59
+ return Command(
60
+ update={
61
+ "pdf_data": pdf_data,
62
+ "messages": [ToolMessage(content=content, tool_call_id=tool_call_id)],
63
+ }
64
+ )
@@ -2,8 +2,8 @@
2
2
  """
3
3
  question_and_answer: Tool for performing Q&A on PDF documents using retrieval augmented generation.
4
4
 
5
- This module provides functionality to extract text from PDF binary data, split it into
6
- chunks, retrieve relevant segments via a vector store, and generate an answer to a
5
+ This module provides functionality to extract text from PDF binary data, split it into
6
+ chunks, retrieve relevant segments via a vector store, and generate an answer to a
7
7
  user-provided question using a language model chain.
8
8
  """
9
9
 
@@ -18,13 +18,15 @@ import hydra
18
18
  from langchain.chains.question_answering import load_qa_chain
19
19
  from langchain.docstore.document import Document
20
20
  from langchain.text_splitter import CharacterTextSplitter
21
- from langchain_community.vectorstores import Annoy
22
- from langchain_openai import OpenAIEmbeddings
23
21
  from langchain_core.language_models.chat_models import BaseChatModel
24
-
22
+ from langchain_core.vectorstores import InMemoryVectorStore
25
23
  from langchain_core.messages import ToolMessage
26
24
  from langchain_core.tools import tool
27
25
  from langchain_core.tools.base import InjectedToolCallId
26
+ from langchain_core.embeddings import Embeddings
27
+ from langchain_community.vectorstores import Annoy
28
+ from langchain_community.document_loaders import PyPDFLoader
29
+ from langchain_openai import OpenAIEmbeddings
28
30
  from langgraph.types import Command
29
31
  from langgraph.prebuilt import InjectedState
30
32
 
@@ -35,10 +37,13 @@ logger.setLevel(logging.INFO)
35
37
 
36
38
  # Load configuration using Hydra.
37
39
  with hydra.initialize(version_base=None, config_path="../../configs"):
38
- cfg = hydra.compose(config_name="config", overrides=["tools/question_and_answer=default"])
40
+ cfg = hydra.compose(
41
+ config_name="config", overrides=["tools/question_and_answer=default"]
42
+ )
39
43
  cfg = cfg.tools.question_and_answer
40
44
  logger.info("Loaded Question and Answer tool configuration.")
41
45
 
46
+
42
47
  class QuestionAndAnswerInput(BaseModel):
43
48
  """
44
49
  Input schema for the PDF Question and Answer tool.
@@ -47,12 +52,12 @@ class QuestionAndAnswerInput(BaseModel):
47
52
  question (str): The question to ask regarding the PDF content.
48
53
  tool_call_id (str): Unique identifier for the tool call, injected automatically.
49
54
  """
50
- question: str = Field(
51
- description="The question to ask regarding the PDF content."
52
- )
55
+
56
+ question: str = Field(description="The question to ask regarding the PDF content.")
53
57
  tool_call_id: Annotated[str, InjectedToolCallId]
54
58
  state: Annotated[dict, InjectedState]
55
59
 
60
+
56
61
  def extract_text_from_pdf_data(pdf_bytes: bytes) -> str:
57
62
  """
58
63
  Extract text content from PDF binary data.
@@ -73,7 +78,10 @@ def extract_text_from_pdf_data(pdf_bytes: bytes) -> str:
73
78
  text += page_text
74
79
  return text
75
80
 
76
- def generate_answer(question: str, pdf_bytes: bytes, llm_model: BaseChatModel) -> Dict[str, Any]:
81
+
82
+ def generate_answer(
83
+ question: str, pdf_bytes: bytes, llm_model: BaseChatModel
84
+ ) -> Dict[str, Any]:
77
85
  """
78
86
  Generate an answer for a question using retrieval augmented generation on PDF content.
79
87
 
@@ -92,9 +100,7 @@ def generate_answer(question: str, pdf_bytes: bytes, llm_model: BaseChatModel) -
92
100
  text = extract_text_from_pdf_data(pdf_bytes)
93
101
  logger.info("Extracted text from PDF.")
94
102
  text_splitter = CharacterTextSplitter(
95
- separator="\n",
96
- chunk_size=cfg.chunk_size,
97
- chunk_overlap=cfg.chunk_overlap
103
+ separator="\n", chunk_size=cfg.chunk_size, chunk_overlap=cfg.chunk_overlap
98
104
  )
99
105
  chunks = text_splitter.split_text(text)
100
106
  documents: List[Document] = [Document(page_content=chunk) for chunk in chunks]
@@ -102,10 +108,7 @@ def generate_answer(question: str, pdf_bytes: bytes, llm_model: BaseChatModel) -
102
108
 
103
109
  embeddings = OpenAIEmbeddings(openai_api_key=cfg.openai_api_key)
104
110
  vector_store = Annoy.from_documents(documents, embeddings)
105
- search_results = vector_store.similarity_search(
106
- question,
107
- k=cfg.num_retrievals
108
- )
111
+ search_results = vector_store.similarity_search(question, k=cfg.num_retrievals)
109
112
  logger.info("Retrieved %d relevant document chunks.", len(search_results))
110
113
  # Use the provided llm_model to build the QA chain.
111
114
  qa_chain = load_qa_chain(llm_model, chain_type=cfg.qa_chain_type)
@@ -114,6 +117,49 @@ def generate_answer(question: str, pdf_bytes: bytes, llm_model: BaseChatModel) -
114
117
  )
115
118
  return answer
116
119
 
120
+
121
+ def generate_answer2(
122
+ question: str, pdf_url: str, text_embedding_model: Embeddings
123
+ ) -> Dict[str, Any]:
124
+ """
125
+ Generate an answer for a question using retrieval augmented generation on PDF content.
126
+
127
+ This function loads the PDF found at the given URL page by page, embeds the pages into
128
+ an in-memory vector store, and performs a similarity search with the question to
129
+ retrieve the most relevant pages. No question-answering chain is run on the results.
130
+
131
+ Args:
132
+ question (str): The question to be answered.
133
+ pdf_url (str): The URL of the PDF document to load.
134
+ text_embedding_model (Embeddings): The embedding model used to build the vector store.
135
+
136
+ Returns:
137
+ str: The page contents of the most similar pages, joined by newlines.
138
+ """
139
+ # text = extract_text_from_pdf_data(pdf_bytes)
140
+ # logger.info("Extracted text from PDF.")
141
+ logger.log(logging.INFO, "searching the article with the question: %s", question)
142
+ # Load the article
143
+ # loader = PyPDFLoader(state['pdf_file_name'])
144
+ # loader = PyPDFLoader("https://arxiv.org/pdf/2310.08365")
145
+ loader = PyPDFLoader(pdf_url)
146
+ # Load the pages of the article
147
+ pages = []
148
+ for page in loader.lazy_load():
149
+ pages.append(page)
150
+ # Set up text embedding model
151
+ # text_embedding_model = state['text_embedding_model']
152
+ # text_embedding_model = OpenAIEmbeddings(openai_api_key=cfg.openai_api_key)
153
+ logging.info("Loaded text embedding model %s", text_embedding_model)
154
+ # Create a vector store from the pages
155
+ vector_store = InMemoryVectorStore.from_documents(pages, text_embedding_model)
156
+ # Search the article with the question
157
+ docs = vector_store.similarity_search(question)
158
+ # Return the content of the pages
159
+ return "\n".join([doc.page_content for doc in docs])
160
+ # return answer
161
+
162
+
117
163
  @tool(args_schema=QuestionAndAnswerInput)
118
164
  def question_and_answer_tool(
119
165
  question: str,
@@ -124,7 +170,7 @@ def question_and_answer_tool(
124
170
  Answer a question using PDF content stored in the state via retrieval augmented generation.
125
171
 
126
172
  This tool retrieves the PDF binary data from the state (under the key "pdf_data"), extracts its
127
- textual content, and generates an answer to the specified question. It also extracts the
173
+ textual content, and generates an answer to the specified question. It also extracts the
128
174
  llm_model (of type BaseChatModel) from the state to use for answering.
129
175
 
130
176
  Args:
@@ -138,15 +184,15 @@ def question_and_answer_tool(
138
184
  Dict[str, Any]: A dictionary containing the generated answer or an error message.
139
185
  """
140
186
  logger.info("Starting PDF Question and Answer tool using PDF data from state.")
187
+ # print (state['text_embedding_model'])
188
+ text_embedding_model = state["text_embedding_model"]
141
189
  pdf_state = state.get("pdf_data")
142
190
  if not pdf_state:
143
191
  error_msg = "No pdf_data found in state."
144
192
  logger.error(error_msg)
145
193
  return Command(
146
194
  update={
147
- "messages": [
148
- ToolMessage(content=error_msg, tool_call_id=tool_call_id)
149
- ]
195
+ "messages": [ToolMessage(content=error_msg, tool_call_id=tool_call_id)]
150
196
  }
151
197
  )
152
198
  pdf_bytes = pdf_state.get("pdf_object")
@@ -155,16 +201,17 @@ def question_and_answer_tool(
155
201
  logger.error(error_msg)
156
202
  return Command(
157
203
  update={
158
- "messages": [
159
- ToolMessage(content=error_msg, tool_call_id=tool_call_id)
160
- ]
204
+ "messages": [ToolMessage(content=error_msg, tool_call_id=tool_call_id)]
161
205
  }
162
206
  )
207
+ pdf_url = pdf_state.get("pdf_url")
163
208
  # Retrieve llm_model from state; return an error if it is missing.
164
209
  llm_model = state.get("llm_model")
165
210
  if not llm_model:
166
211
  logger.error("Missing LLM model instance in state.")
167
212
  return {"error": "No LLM model found in state."}
168
- answer = generate_answer(question, pdf_bytes, llm_model)
169
- logger.info("Generated answer: %s", answer)
213
+ # answer = generate_answer(question, pdf_bytes, llm_model)
214
+ print(pdf_url)
215
+ answer = generate_answer2(question, pdf_url, text_embedding_model)
216
+ # logger.info("Generated answer: %s", answer)
170
217
  return answer
@@ -30,7 +30,7 @@ class MultiPaperRecInput(BaseModel):
30
30
  description="List of Semantic Scholar Paper IDs to get recommendations for"
31
31
  )
32
32
  limit: int = Field(
33
- default=2,
33
+ default=10,
34
34
  description="Maximum total number of recommendations to return",
35
35
  ge=1,
36
36
  le=500,
@@ -90,23 +90,33 @@ def get_multi_paper_recommendations(
90
90
  params["year"] = year
91
91
 
92
92
  # Wrap API call in try/except to catch connectivity issues and validate response format
93
- try:
94
- response = requests.post(
95
- endpoint,
96
- headers=headers,
97
- params=params,
98
- data=json.dumps(payload),
99
- timeout=cfg.request_timeout,
100
- )
101
- response.raise_for_status() # Raises HTTPError for bad responses
102
- except requests.exceptions.RequestException as e:
103
- logger.error(
104
- "Failed to connect to Semantic Scholar API for multi-paper recommendations: %s",
105
- e,
106
- )
107
- raise RuntimeError(
108
- "Failed to connect to Semantic Scholar API. Please retry the same query."
109
- ) from e
93
+ response = None
94
+ for attempt in range(10):
95
+ try:
96
+ response = requests.post(
97
+ endpoint,
98
+ headers=headers,
99
+ params=params,
100
+ data=json.dumps(payload),
101
+ timeout=cfg.request_timeout,
102
+ )
103
+ response.raise_for_status() # Raises HTTPError for bad responses
104
+ break # Exit loop if request is successful
105
+ except requests.exceptions.RequestException as e:
106
+ logger.error(
107
+ "Attempt %d: Failed to connect to Semantic Scholar API for "
108
+ "multi-paper recommendations: %s",
109
+ attempt + 1,
110
+ e,
111
+ )
112
+ if attempt == 9: # Last attempt
113
+ raise RuntimeError(
114
+ "Failed to connect to Semantic Scholar API after 10 attempts."
115
+ "Please retry the same query."
116
+ ) from e
117
+
118
+ if response is None:
119
+ raise RuntimeError("Failed to obtain a response from the Semantic Scholar API.")
110
120
 
111
121
  logger.info(
112
122
  "API Response Status for multi-paper recommendations: %s", response.status_code
@@ -137,11 +147,22 @@ def get_multi_paper_recommendations(
137
147
  # Create a dictionary to store the papers
138
148
  filtered_papers = {
139
149
  paper["paperId"]: {
140
- "paper_id": paper["paperId"],
150
+ "semantic_scholar_paper_id": paper["paperId"],
141
151
  "Title": paper.get("title", "N/A"),
142
152
  "Abstract": paper.get("abstract", "N/A"),
143
153
  "Year": paper.get("year", "N/A"),
154
+ "Publication Date": paper.get("publicationDate", "N/A"),
155
+ "Venue": paper.get("venue", "N/A"),
156
+ # "Publication Venue": (paper.get("publicationVenue") or {}).get("name", "N/A"),
157
+ # "Venue Type": (paper.get("publicationVenue") or {}).get("name", "N/A"),
158
+ "Journal Name": (paper.get("journal") or {}).get("name", "N/A"),
159
+ # "Journal Volume": paper.get("journal", {}).get("volume", "N/A"),
160
+ # "Journal Pages": paper.get("journal", {}).get("pages", "N/A"),
144
161
  "Citation Count": paper.get("citationCount", "N/A"),
162
+ "Authors": [
163
+ f"{author.get('name', 'N/A')} (ID: {author.get('authorId', 'N/A')})"
164
+ for author in paper.get("authors", [])
165
+ ],
145
166
  "URL": paper.get("url", "N/A"),
146
167
  "arxiv_id": paper.get("externalIds", {}).get("ArXiv", "N/A"),
147
168
  }
@@ -153,7 +174,10 @@ def get_multi_paper_recommendations(
153
174
  top_papers = list(filtered_papers.values())[:3]
154
175
  top_papers_info = "\n".join(
155
176
  [
156
- f"{i+1}. {paper['Title']} ({paper['Year']})"
177
+ # f"{i+1}. {paper['Title']} ({paper['Year']})"
178
+ f"{i+1}. {paper['Title']} ({paper['Year']}; "
179
+ f"semantic_scholar_paper_id: {paper['semantic_scholar_paper_id']}; "
180
+ f"arXiv ID: {paper['arxiv_id']})"
157
181
  for i, paper in enumerate(top_papers)
158
182
  ]
159
183
  )
@@ -165,10 +189,10 @@ def get_multi_paper_recommendations(
165
189
  "Papers are attached as an artifact."
166
190
  )
167
191
  content += " Here is a summary of the recommendations:\n"
168
- content += f"Number of papers found: {len(filtered_papers)}\n"
192
+ content += f"Number of recommended papers found: {len(filtered_papers)}\n"
169
193
  content += f"Query Paper IDs: {', '.join(paper_ids)}\n"
170
194
  content += f"Year: {year}\n" if year else ""
171
- content += "Top papers:\n" + top_papers_info
195
+ content += "Here are a few of these papers:\n" + top_papers_info
172
196
 
173
197
  return Command(
174
198
  update={