aiagents4pharma 1.30.0__py3-none-any.whl → 1.30.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. aiagents4pharma/talk2scholars/agents/main_agent.py +18 -10
  2. aiagents4pharma/talk2scholars/agents/paper_download_agent.py +5 -6
  3. aiagents4pharma/talk2scholars/agents/pdf_agent.py +4 -10
  4. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +18 -9
  5. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +2 -2
  6. aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +1 -0
  7. aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +6 -1
  8. aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +7 -1
  9. aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +6 -1
  10. aiagents4pharma/talk2scholars/configs/tools/zotero_read/default.yaml +1 -1
  11. aiagents4pharma/talk2scholars/state/state_talk2scholars.py +4 -1
  12. aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py +84 -53
  13. aiagents4pharma/talk2scholars/tests/test_main_agent.py +24 -0
  14. aiagents4pharma/talk2scholars/tests/test_question_and_answer_tool.py +79 -15
  15. aiagents4pharma/talk2scholars/tests/test_routing_logic.py +12 -8
  16. aiagents4pharma/talk2scholars/tests/test_s2_multi.py +27 -4
  17. aiagents4pharma/talk2scholars/tests/test_s2_search.py +19 -3
  18. aiagents4pharma/talk2scholars/tests/test_s2_single.py +27 -3
  19. aiagents4pharma/talk2scholars/tests/test_zotero_read.py +17 -10
  20. aiagents4pharma/talk2scholars/tools/paper_download/abstract_downloader.py +2 -0
  21. aiagents4pharma/talk2scholars/tools/paper_download/arxiv_downloader.py +11 -4
  22. aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +5 -1
  23. aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +73 -26
  24. aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +46 -22
  25. aiagents4pharma/talk2scholars/tools/s2/query_results.py +1 -1
  26. aiagents4pharma/talk2scholars/tools/s2/search.py +40 -12
  27. aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +42 -16
  28. aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +33 -16
  29. aiagents4pharma/talk2scholars/tools/zotero/zotero_write.py +39 -7
  30. {aiagents4pharma-1.30.0.dist-info → aiagents4pharma-1.30.1.dist-info}/METADATA +2 -2
  31. {aiagents4pharma-1.30.0.dist-info → aiagents4pharma-1.30.1.dist-info}/RECORD +34 -34
  32. {aiagents4pharma-1.30.0.dist-info → aiagents4pharma-1.30.1.dist-info}/WHEEL +1 -1
  33. {aiagents4pharma-1.30.0.dist-info → aiagents4pharma-1.30.1.dist-info}/LICENSE +0 -0
  34. {aiagents4pharma-1.30.0.dist-info → aiagents4pharma-1.30.1.dist-info}/top_level.txt +0 -0
@@ -23,8 +23,7 @@ def mock_router():
 
     def mock_supervisor_node(state):
         query = state["messages"][-1].content.lower()
-
-        # Expanded keyword matching for S2 Agent
+        # Define keywords for each sub-agent.
         s2_keywords = [
             "paper",
             "research",
@@ -34,13 +33,19 @@ def mock_router():
             "references",
         ]
         zotero_keywords = ["zotero", "library", "saved papers", "academic library"]
+        pdf_keywords = ["pdf", "document", "read pdf"]
+        paper_download_keywords = ["download", "arxiv", "fetch paper", "paper download"]
 
+        # Priority ordering: Zotero, then paper download, then PDF, then S2.
         if any(keyword in query for keyword in zotero_keywords):
            return Command(goto="zotero_agent")
+        if any(keyword in query for keyword in paper_download_keywords):
+            return Command(goto="paper_download_agent")
+        if any(keyword in query for keyword in pdf_keywords):
+            return Command(goto="pdf_agent")
        if any(keyword in query for keyword in s2_keywords):
            return Command(goto="s2_agent")
-
-        # If no match, default to ending the conversation
+        # Default to end if no keyword matches.
        return Command(goto=END)
 
    return mock_supervisor_node
@@ -55,10 +60,9 @@ def mock_router():
         ("Fetch my academic library.", "zotero_agent"),
         ("Retrieve citations.", "s2_agent"),
         ("Can you get journal articles?", "s2_agent"),
-        (
-            "Completely unrelated query.",
-            "__end__",
-        ),  # NEW: Should trigger the `END` case
+        ("I want to read the PDF document.", "pdf_agent"),
+        ("Download the paper from arxiv.", "paper_download_agent"),
+        ("Completely unrelated query.", "__end__"),
     ],
 )
 def test_routing_logic(mock_state, mock_router, user_query, expected_agent):
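A minimal self-contained sketch (illustrative only, not package code) of why the priority ordering in mock_supervisor_node matters: a query that matches both the download and the PDF keyword sets is routed to paper_download_agent because that set is checked first. The keyword lists below are abbreviated copies of the ones in the test.

def route(query: str) -> str:
    query = query.lower()
    # Keyword sets are checked in priority order; the first match wins.
    keyword_sets = [
        ("zotero_agent", ["zotero", "library", "saved papers", "academic library"]),
        ("paper_download_agent", ["download", "arxiv", "fetch paper", "paper download"]),
        ("pdf_agent", ["pdf", "document", "read pdf"]),
        ("s2_agent", ["paper", "research", "references"]),
    ]
    for agent, keywords in keyword_sets:
        if any(keyword in query for keyword in keywords):
            return agent
    return "__end__"

assert route("Download the PDF of this paper from arxiv.") == "paper_download_agent"
assert route("Completely unrelated query.") == "__end__"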
@@ -94,7 +94,7 @@ def dummy_requests_post_success(url, headers, params, data, timeout):
         {
             "paperId": "paperA",
             "title": "Multi Rec Paper A",
-            "authors": ["Author X"],
+            "authors": [{"name": "Author X", "authorId": "AX"}],
             "year": 2019,
             "citationCount": 12,
             "url": "http://paperA",
@@ -103,7 +103,7 @@ def dummy_requests_post_success(url, headers, params, data, timeout):
         {
             "paperId": "paperB",
             "title": "Multi Rec Paper B",
-            "authors": ["Author Y"],
+            "authors": [{"name": "Author Y", "authorId": "AY"}],
             "year": 2020,
             "citationCount": 18,
             "url": "http://paperB",
@@ -112,7 +112,7 @@ def dummy_requests_post_success(url, headers, params, data, timeout):
         {
             "paperId": "paperC",
             "title": "Multi Rec Paper C",
-            "authors": None,  # This one should be filtered out.
+            "authors": None,  # This paper should be filtered out.
             "year": 2021,
             "citationCount": 25,
             "url": "http://paperC",
@@ -277,6 +277,29 @@ def test_multi_paper_rec_requests_exception(monkeypatch):
     }
     with pytest.raises(
         RuntimeError,
-        match="Failed to connect to Semantic Scholar API. Please retry the same query.",
+        match="Failed to connect to Semantic Scholar API after 10 attempts."
+        "Please retry the same query.",
+    ):
+        get_multi_paper_recommendations.run(input_data)
+
+
+def test_multi_paper_rec_no_response(monkeypatch):
+    """
+    Test that get_multi_paper_recommendations raises a RuntimeError when no response is obtained.
+    This is simulated by patching 'range' in the underlying function's globals to
+    return an empty iterator,
+    so that the for loop does not iterate and response remains None.
+    """
+    # Patch 'range' in the underlying function's globals (accessed via .func.__globals__)
+    monkeypatch.setitem(
+        get_multi_paper_recommendations.func.__globals__, "range", lambda x: iter([])
+    )
+    tool_call_id = "test_tool_call_id"
+    input_data = {
+        "paper_ids": ["p1", "p2"],
+        "tool_call_id": tool_call_id,
+    }
+    with pytest.raises(
+        RuntimeError, match="Failed to obtain a response from the Semantic Scholar API."
     ):
         get_multi_paper_recommendations.run(input_data)
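The no-response test above relies on a somewhat subtle trick: the tool's retry loop calls range(...), which Python resolves through the wrapped function's module globals before falling back to builtins, so injecting a "range" entry there makes the loop run zero times. A self-contained sketch of the same idea, using a toy retry function rather than the real tool (the real tests reach the wrapped function via .func.__globals__ because the @tool decorator wraps it in a tool object):

import pytest


def toy_fetch():
    """Toy stand-in for the tool's retry loop; not part of the package."""
    response = None
    for _ in range(10):  # 'range' is looked up in this module's globals, then builtins
        response = "ok"
        break
    if response is None:
        raise RuntimeError("Failed to obtain a response from the Semantic Scholar API.")
    return response


def test_toy_fetch_no_response(monkeypatch):
    # Shadow the builtin 'range' for this module only; monkeypatch undoes it afterwards.
    monkeypatch.setitem(toy_fetch.__globals__, "range", lambda n: iter([]))
    with pytest.raises(RuntimeError, match="Failed to obtain a response"):
        toy_fetch()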
@@ -85,7 +85,7 @@ def dummy_requests_get_success(url, params, timeout):
         {
             "paperId": "1",
             "title": "Paper 1",
-            "authors": ["Author A"],
+            "authors": [{"name": "Author A", "authorId": "A1"}],
             "year": 2020,
             "citationCount": 10,
             "url": "http://paper1",
@@ -94,7 +94,7 @@ def dummy_requests_get_success(url, params, timeout):
         {
             "paperId": "2",
             "title": "Paper 2",
-            "authors": ["Author B"],
+            "authors": [{"name": "Author B", "authorId": "B1"}],
             "year": 2021,
             "citationCount": 20,
             "url": "http://paper2",
@@ -256,7 +256,8 @@ def test_search_tool_requests_exception(monkeypatch):
     tool_call_id = "test_tool_call_id"
     with pytest.raises(
         RuntimeError,
-        match="Failed to connect to Semantic Scholar API. Please retry the same query.",
+        match="Failed to connect to Semantic Scholar API after 10 attempts."
+        "Please retry the same query.",
     ):
         search_tool.run(
             {
@@ -264,3 +265,18 @@ def test_search_tool_requests_exception(monkeypatch):
                 "tool_call_id": tool_call_id,
             }
         )
+
+
+def test_search_tool_no_response(monkeypatch):
+    """
+    Test that search_tool raises a RuntimeError when no response is obtained.
+    This is simulated by patching 'range' in the original function's globals (a dict)
+    so that it returns an empty iterator, leaving response as None.
+    """
+    # Patch 'range' in the original function's globals using setitem.
+    monkeypatch.setitem(search_tool.func.__globals__, "range", lambda x: iter([]))
+    tool_call_id = "test_tool_call_id"
+    with pytest.raises(
+        RuntimeError, match="Failed to obtain a response from the Semantic Scholar API."
+    ):
+        search_tool.run({"query": "test", "tool_call_id": tool_call_id})
@@ -92,7 +92,7 @@ def dummy_requests_get_success(url, params, timeout):
         {
             "paperId": "paper1",
             "title": "Recommended Paper 1",
-            "authors": ["Author A"],
+            "authors": [{"name": "Author A", "authorId": "A1"}],
             "year": 2020,
             "citationCount": 15,
             "url": "http://paper1",
@@ -101,7 +101,7 @@ def dummy_requests_get_success(url, params, timeout):
         {
             "paperId": "paper2",
             "title": "Recommended Paper 2",
-            "authors": ["Author B"],
+            "authors": [{"name": "Author B", "authorId": "B1"}],
             "year": 2021,
             "citationCount": 25,
             "url": "http://paper2",
@@ -269,6 +269,30 @@ def test_single_paper_rec_requests_exception(monkeypatch):
     }
     with pytest.raises(
         RuntimeError,
-        match="Failed to connect to Semantic Scholar API. Please retry the same query.",
+        match="Failed to connect to Semantic Scholar API after 10 attempts."
+        "Please retry the same query.",
+    ):
+        get_single_paper_recommendations.run(input_data)
+
+
+def test_single_paper_rec_no_response(monkeypatch):
+    """
+    Test that get_single_paper_recommendations raises a RuntimeError
+    when no response is obtained from the API.
+
+    This is simulated by patching 'range' in the underlying function's globals
+    to return an empty iterator, so the for-loop never iterates and response remains None.
+    """
+    # Patch 'range' in the underlying function's globals (accessed via .func.__globals__)
+    monkeypatch.setitem(
+        get_single_paper_recommendations.func.__globals__, "range", lambda x: iter([])
+    )
+    tool_call_id = "test_tool_call_id"
+    input_data = {
+        "paper_id": "12345",
+        "tool_call_id": tool_call_id,
+    }
+    with pytest.raises(
+        RuntimeError, match="Failed to obtain a response from the Semantic Scholar API."
     ):
         get_single_paper_recommendations.run(input_data)
@@ -26,7 +26,7 @@ dummy_cfg = SimpleNamespace(tools=SimpleNamespace(zotero_read=dummy_zotero_read_
 
 
 class TestZoteroSearchTool(unittest.TestCase):
-    """test for Zotero search tool"""
+    """Tests for Zotero search tool."""
 
     @patch(
         "aiagents4pharma.talk2scholars.tools.zotero.zotero_read.get_item_collections"
@@ -41,7 +41,7 @@ class TestZoteroSearchTool(unittest.TestCase):
         mock_zotero_class,
         mock_get_item_collections,
     ):
-        """test valid query"""
+        """Test valid query returns correct Command output."""
        # Setup Hydra mocks
        mock_hydra_compose.return_value = dummy_cfg
        mock_hydra_init.return_value.__enter__.return_value = None
@@ -116,7 +116,7 @@ class TestZoteroSearchTool(unittest.TestCase):
         mock_zotero_class,
         mock_get_item_collections,
     ):
-        """test empty query fetches all items"""
+        """Test empty query fetches all items."""
        mock_hydra_compose.return_value = dummy_cfg
        mock_hydra_init.return_value.__enter__.return_value = None
 
@@ -166,7 +166,7 @@ class TestZoteroSearchTool(unittest.TestCase):
         mock_zotero_class,
         mock_get_item_collections,
     ):
-        """test no items returned from Zotero"""
+        """Test no items returned from Zotero."""
        mock_hydra_compose.return_value = dummy_cfg
        mock_hydra_init.return_value.__enter__.return_value = None
 
@@ -199,7 +199,10 @@ class TestZoteroSearchTool(unittest.TestCase):
         mock_zotero_class,
         mock_get_item_collections,
     ):
-        """test no matching papers returned from Zotero"""
+        """
+        Test that when non-research items (e.g. attachments, notes) are returned,
+        they are still included since filtering is disabled.
+        """
        mock_hydra_compose.return_value = dummy_cfg
        mock_hydra_init.return_value.__enter__.return_value = None
 
@@ -240,9 +243,13 @@ class TestZoteroSearchTool(unittest.TestCase):
             "tool_call_id": tool_call_id,
             "limit": 2,
         }
-        with self.assertRaises(RuntimeError) as context:
-            zotero_search_tool.run(tool_input)
-        self.assertIn("No matching papers returned from Zotero", str(context.exception))
+        # Instead of expecting a RuntimeError, we now expect both items to be returned.
+        result = zotero_search_tool.run(tool_input)
+        update = result.update
+        filtered_papers = update["zotero_read"]
+        self.assertIn("paper1", filtered_papers)
+        self.assertIn("paper2", filtered_papers)
+        self.assertEqual(len(filtered_papers), 2)
 
     @patch(
         "aiagents4pharma.talk2scholars.tools.zotero.zotero_read.get_item_collections"
@@ -257,7 +264,7 @@ class TestZoteroSearchTool(unittest.TestCase):
         mock_zotero_class,
         mock_get_item_collections,
     ):
-        """test items API exception"""
+        """Test items API exception is properly raised."""
        mock_hydra_compose.return_value = dummy_cfg
        mock_hydra_init.return_value.__enter__.return_value = None
        mock_get_item_collections.return_value = {}
@@ -306,7 +313,7 @@ class TestZoteroSearchTool(unittest.TestCase):
                     "url": "http://example.com",
                     "itemType": "journalArticle",
                 }
-            },  # missing key triggers line 136
+            },  # Missing 'key' field
            {
                "data": {
                    "key": "paper_valid",
@@ -9,6 +9,8 @@ inherit from this class and implement its methods.
 
 from abc import ABC, abstractmethod
 from typing import Any, Dict
+
+
 class AbstractPaperDownloader(ABC):
     """
     Abstract base class for scholarly paper downloaders.
@@ -8,6 +8,7 @@ downloads the corresponding PDF.
 By using an abstract base class, this implementation is extendable to other
 APIs like PubMed, IEEE Xplore, etc.
 """
+
 import xml.etree.ElementTree as ET
 from typing import Any, Dict
 import logging
@@ -19,6 +20,7 @@ from .abstract_downloader import AbstractPaperDownloader
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+
 class ArxivPaperDownloader(AbstractPaperDownloader):
     """
     Downloader class for arXiv.
@@ -35,13 +37,13 @@ class ArxivPaperDownloader(AbstractPaperDownloader):
         """
         with hydra.initialize(version_base=None, config_path="../../configs"):
             cfg = hydra.compose(
-                config_name="config",
-                overrides=["tools/download_arxiv_paper=default"]
+                config_name="config", overrides=["tools/download_arxiv_paper=default"]
             )
         self.api_url = cfg.tools.download_arxiv_paper.api_url
         self.request_timeout = cfg.tools.download_arxiv_paper.request_timeout
         self.chunk_size = cfg.tools.download_arxiv_paper.chunk_size
         self.pdf_base_url = cfg.tools.download_arxiv_paper.pdf_base_url
+
     def fetch_metadata(self, paper_id: str) -> Dict[str, Any]:
         """
         Fetch metadata from arXiv for a given paper ID.
@@ -95,11 +97,16 @@ class ArxivPaperDownloader(AbstractPaperDownloader):
         logger.info("Downloading PDF from: %s", pdf_url)
         pdf_response = requests.get(pdf_url, stream=True, timeout=self.request_timeout)
         pdf_response.raise_for_status()
+        # print (pdf_response)
 
         # Combine the PDF data from chunks.
         pdf_object = b"".join(
-            chunk for chunk in pdf_response.iter_content(chunk_size=self.chunk_size) if chunk
-        )
+            chunk
+            for chunk in pdf_response.iter_content(chunk_size=self.chunk_size)
+            if chunk
+        )
+        # print (pdf_object)
+        print("PDF_URL", pdf_url)
 
         return {
             "pdf_object": pdf_object,
@@ -14,16 +14,19 @@ from langgraph.types import Command
 # Local import from the same package:
 from .arxiv_downloader import ArxivPaperDownloader
 
+
 class DownloadArxivPaperInput(BaseModel):
     """
     Input schema for the arXiv paper download tool.
     (Optional: if you decide to keep Pydantic validation in the future)
     """
+
     arxiv_id: str = Field(
         description="The arXiv paper ID used to retrieve the paper details and PDF."
-        )
+    )
     tool_call_id: Annotated[str, InjectedToolCallId]
 
+
 @tool(args_schema=DownloadArxivPaperInput, parse_docstring=True)
 def download_arxiv_paper(
     arxiv_id: str,
@@ -49,6 +52,7 @@ def download_arxiv_paper(
 
     # If the downloader fails or the arxiv_id is invalid, this might raise an error
     pdf_data = downloader.download_pdf(arxiv_id)
+    # print (pdf_data)
 
     content = f"Successfully downloaded PDF for arXiv ID {arxiv_id}"
 
@@ -2,8 +2,8 @@
 """
 question_and_answer: Tool for performing Q&A on PDF documents using retrieval augmented generation.
 
-This module provides functionality to extract text from PDF binary data, split it into 
-chunks, retrieve relevant segments via a vector store, and generate an answer to a 
+This module provides functionality to extract text from PDF binary data, split it into
+chunks, retrieve relevant segments via a vector store, and generate an answer to a
 user-provided question using a language model chain.
 """
 
@@ -18,13 +18,15 @@ import hydra
 from langchain.chains.question_answering import load_qa_chain
 from langchain.docstore.document import Document
 from langchain.text_splitter import CharacterTextSplitter
-from langchain_community.vectorstores import Annoy
-from langchain_openai import OpenAIEmbeddings
 from langchain_core.language_models.chat_models import BaseChatModel
-
+from langchain_core.vectorstores import InMemoryVectorStore
 from langchain_core.messages import ToolMessage
 from langchain_core.tools import tool
 from langchain_core.tools.base import InjectedToolCallId
+from langchain_core.embeddings import Embeddings
+from langchain_community.vectorstores import Annoy
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_openai import OpenAIEmbeddings
 from langgraph.types import Command
 from langgraph.prebuilt import InjectedState
 
@@ -35,10 +37,13 @@ logger.setLevel(logging.INFO)
 
 # Load configuration using Hydra.
 with hydra.initialize(version_base=None, config_path="../../configs"):
-    cfg = hydra.compose(config_name="config", overrides=["tools/question_and_answer=default"])
+    cfg = hydra.compose(
+        config_name="config", overrides=["tools/question_and_answer=default"]
+    )
     cfg = cfg.tools.question_and_answer
     logger.info("Loaded Question and Answer tool configuration.")
 
+
 class QuestionAndAnswerInput(BaseModel):
     """
     Input schema for the PDF Question and Answer tool.
@@ -47,12 +52,12 @@ class QuestionAndAnswerInput(BaseModel):
         question (str): The question to ask regarding the PDF content.
         tool_call_id (str): Unique identifier for the tool call, injected automatically.
     """
-    question: str = Field(
-        description="The question to ask regarding the PDF content."
-    )
+
+    question: str = Field(description="The question to ask regarding the PDF content.")
     tool_call_id: Annotated[str, InjectedToolCallId]
     state: Annotated[dict, InjectedState]
 
+
 def extract_text_from_pdf_data(pdf_bytes: bytes) -> str:
     """
     Extract text content from PDF binary data.
@@ -73,7 +78,10 @@ def extract_text_from_pdf_data(pdf_bytes: bytes) -> str:
             text += page_text
     return text
 
-def generate_answer(question: str, pdf_bytes: bytes, llm_model: BaseChatModel) -> Dict[str, Any]:
+
+def generate_answer(
+    question: str, pdf_bytes: bytes, llm_model: BaseChatModel
+) -> Dict[str, Any]:
     """
     Generate an answer for a question using retrieval augmented generation on PDF content.
 
@@ -92,9 +100,7 @@ def generate_answer(question: str, pdf_bytes: bytes, llm_model: BaseChatModel) -
     text = extract_text_from_pdf_data(pdf_bytes)
     logger.info("Extracted text from PDF.")
     text_splitter = CharacterTextSplitter(
-        separator="\n",
-        chunk_size=cfg.chunk_size,
-        chunk_overlap=cfg.chunk_overlap
+        separator="\n", chunk_size=cfg.chunk_size, chunk_overlap=cfg.chunk_overlap
     )
     chunks = text_splitter.split_text(text)
     documents: List[Document] = [Document(page_content=chunk) for chunk in chunks]
@@ -102,10 +108,7 @@ def generate_answer(question: str, pdf_bytes: bytes, llm_model: BaseChatModel) -
 
     embeddings = OpenAIEmbeddings(openai_api_key=cfg.openai_api_key)
     vector_store = Annoy.from_documents(documents, embeddings)
-    search_results = vector_store.similarity_search(
-        question,
-        k=cfg.num_retrievals
-    )
+    search_results = vector_store.similarity_search(question, k=cfg.num_retrievals)
     logger.info("Retrieved %d relevant document chunks.", len(search_results))
     # Use the provided llm_model to build the QA chain.
     qa_chain = load_qa_chain(llm_model, chain_type=cfg.qa_chain_type)
@@ -114,6 +117,49 @@ def generate_answer(question: str, pdf_bytes: bytes, llm_model: BaseChatModel) -
     )
     return answer
 
+
+def generate_answer2(
+    question: str, pdf_url: str, text_embedding_model: Embeddings
+) -> Dict[str, Any]:
+    """
+    Generate an answer for a question using retrieval augmented generation on PDF content.
+
+    This function extracts text from the PDF data, splits the text into manageable chunks,
+    performs a similarity search to retrieve the most relevant segments, and then uses a
+    question-answering chain (built using the provided llm_model) to generate an answer.
+
+    Args:
+        question (str): The question to be answered.
+        pdf_bytes (bytes): The binary content of the PDF document.
+        llm_model (BaseChatModel): The language model instance to use for answering.
+
+    Returns:
+        Dict[str, Any]: A dictionary containing the answer generated by the language model.
+    """
+    # text = extract_text_from_pdf_data(pdf_bytes)
+    # logger.info("Extracted text from PDF.")
+    logger.log(logging.INFO, "searching the article with the question: %s", question)
+    # Load the article
+    # loader = PyPDFLoader(state['pdf_file_name'])
+    # loader = PyPDFLoader("https://arxiv.org/pdf/2310.08365")
+    loader = PyPDFLoader(pdf_url)
+    # Load the pages of the article
+    pages = []
+    for page in loader.lazy_load():
+        pages.append(page)
+    # Set up text embedding model
+    # text_embedding_model = state['text_embedding_model']
+    # text_embedding_model = OpenAIEmbeddings(openai_api_key=cfg.openai_api_key)
+    logging.info("Loaded text embedding model %s", text_embedding_model)
+    # Create a vector store from the pages
+    vector_store = InMemoryVectorStore.from_documents(pages, text_embedding_model)
+    # Search the article with the question
+    docs = vector_store.similarity_search(question)
+    # Return the content of the pages
+    return "\n".join([doc.page_content for doc in docs])
+    # return answer
+
+
 @tool(args_schema=QuestionAndAnswerInput)
 def question_and_answer_tool(
     question: str,
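A hedged usage sketch of the new generate_answer2 helper added above: despite its docstring, it returns the concatenated text of the retrieved pages rather than an LLM-generated answer, and it loads the PDF by URL with PyPDFLoader. The embedding model and question below are illustrative; the URL is the one that appears commented out in the function.

from langchain_openai import OpenAIEmbeddings

from aiagents4pharma.talk2scholars.tools.pdf.question_and_answer import generate_answer2

embeddings = OpenAIEmbeddings()  # assumes OPENAI_API_KEY is set in the environment
context = generate_answer2(
    question="What problem does the paper address?",
    pdf_url="https://arxiv.org/pdf/2310.08365",
    text_embedding_model=embeddings,
)
print(context[:500])  # first 500 characters of the retrieved page content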
@@ -124,7 +170,7 @@ def question_and_answer_tool(
     Answer a question using PDF content stored in the state via retrieval augmented generation.
 
     This tool retrieves the PDF binary data from the state (under the key "pdf_data"), extracts its
-    textual content, and generates an answer to the specified question. It also extracts the 
+    textual content, and generates an answer to the specified question. It also extracts the
     llm_model (of type BaseChatModel) from the state to use for answering.
 
     Args:
@@ -138,15 +184,15 @@ def question_and_answer_tool(
         Dict[str, Any]: A dictionary containing the generated answer or an error message.
     """
     logger.info("Starting PDF Question and Answer tool using PDF data from state.")
+    # print (state['text_embedding_model'])
+    text_embedding_model = state["text_embedding_model"]
     pdf_state = state.get("pdf_data")
     if not pdf_state:
         error_msg = "No pdf_data found in state."
         logger.error(error_msg)
         return Command(
             update={
-                "messages": [
-                    ToolMessage(content=error_msg, tool_call_id=tool_call_id)
-                ]
+                "messages": [ToolMessage(content=error_msg, tool_call_id=tool_call_id)]
             }
         )
     pdf_bytes = pdf_state.get("pdf_object")
@@ -155,16 +201,17 @@ def question_and_answer_tool(
         logger.error(error_msg)
         return Command(
             update={
-                "messages": [
-                    ToolMessage(content=error_msg, tool_call_id=tool_call_id)
-                ]
+                "messages": [ToolMessage(content=error_msg, tool_call_id=tool_call_id)]
             }
         )
+    pdf_url = pdf_state.get("pdf_url")
     # Retrieve llm_model from state; use a default if not provided.
     llm_model = state.get("llm_model")
     if not llm_model:
         logger.error("Missing LLM model instance in state.")
         return {"error": "No LLM model found in state."}
-    answer = generate_answer(question, pdf_bytes, llm_model)
-    logger.info("Generated answer: %s", answer)
+    # answer = generate_answer(question, pdf_bytes, llm_model)
+    print(pdf_url)
+    answer = generate_answer2(question, pdf_url, text_embedding_model)
+    # logger.info("Generated answer: %s", answer)
     return answer