aiagents4pharma-1.28.0-py3-none-any.whl → aiagents4pharma-1.30.0-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. aiagents4pharma/talk2scholars/agents/__init__.py +1 -0
  2. aiagents4pharma/talk2scholars/agents/main_agent.py +35 -209
  3. aiagents4pharma/talk2scholars/agents/paper_download_agent.py +86 -0
  4. aiagents4pharma/talk2scholars/agents/s2_agent.py +10 -6
  5. aiagents4pharma/talk2scholars/agents/zotero_agent.py +12 -6
  6. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +2 -48
  7. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/__init__.py +3 -0
  8. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +5 -28
  9. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +5 -21
  10. aiagents4pharma/talk2scholars/configs/config.yaml +3 -0
  11. aiagents4pharma/talk2scholars/configs/tools/__init__.py +1 -0
  12. aiagents4pharma/talk2scholars/configs/tools/download_arxiv_paper/__init__.py +3 -0
  13. aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +1 -1
  14. aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +1 -1
  15. aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +1 -1
  16. aiagents4pharma/talk2scholars/configs/tools/zotero_read/default.yaml +42 -1
  17. aiagents4pharma/talk2scholars/configs/tools/zotero_write/__inti__.py +3 -0
  18. aiagents4pharma/talk2scholars/state/state_talk2scholars.py +1 -0
  19. aiagents4pharma/talk2scholars/tests/test_main_agent.py +186 -111
  20. aiagents4pharma/talk2scholars/tests/test_paper_download_agent.py +142 -0
  21. aiagents4pharma/talk2scholars/tests/test_paper_download_tools.py +154 -0
  22. aiagents4pharma/talk2scholars/tests/test_s2_display.py +74 -0
  23. aiagents4pharma/talk2scholars/tests/test_s2_multi.py +282 -0
  24. aiagents4pharma/talk2scholars/tests/test_s2_query.py +78 -0
  25. aiagents4pharma/talk2scholars/tests/test_s2_retrieve.py +65 -0
  26. aiagents4pharma/talk2scholars/tests/test_s2_search.py +266 -0
  27. aiagents4pharma/talk2scholars/tests/test_s2_single.py +274 -0
  28. aiagents4pharma/talk2scholars/tests/test_zotero_path.py +57 -0
  29. aiagents4pharma/talk2scholars/tests/test_zotero_read.py +412 -0
  30. aiagents4pharma/talk2scholars/tests/test_zotero_write.py +626 -0
  31. aiagents4pharma/talk2scholars/tools/paper_download/__init__.py +17 -0
  32. aiagents4pharma/talk2scholars/tools/paper_download/abstract_downloader.py +43 -0
  33. aiagents4pharma/talk2scholars/tools/paper_download/arxiv_downloader.py +108 -0
  34. aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +60 -0
  35. aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +50 -34
  36. aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +8 -8
  37. aiagents4pharma/talk2scholars/tools/s2/search.py +36 -23
  38. aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +44 -38
  39. aiagents4pharma/talk2scholars/tools/zotero/__init__.py +2 -0
  40. aiagents4pharma/talk2scholars/tools/zotero/utils/__init__.py +5 -0
  41. aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_path.py +63 -0
  42. aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +64 -19
  43. aiagents4pharma/talk2scholars/tools/zotero/zotero_write.py +247 -0
  44. {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.30.0.dist-info}/METADATA +6 -5
  45. {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.30.0.dist-info}/RECORD +48 -30
  46. aiagents4pharma/talk2scholars/tests/test_call_s2.py +0 -100
  47. aiagents4pharma/talk2scholars/tests/test_call_zotero.py +0 -94
  48. aiagents4pharma/talk2scholars/tests/test_s2_tools.py +0 -355
  49. aiagents4pharma/talk2scholars/tests/test_zotero_tool.py +0 -171
  50. {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.30.0.dist-info}/LICENSE +0 -0
  51. {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.30.0.dist-info}/WHEEL +0 -0
  52. {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.30.0.dist-info}/top_level.txt +0 -0
@@ -12,4 +12,45 @@ search_params:
12
12
  # Item Types and Limit
13
13
  zotero:
14
14
  max_limit: 100
15
- filter_item_types: ["journalArticle", "conferencePaper", "preprint"]
15
+ filter_item_types:
16
+ [
17
+ "Artwork",
18
+ "Audio Recording",
19
+ "Bill",
20
+ "Blog Post",
21
+ "Book",
22
+ "Book Section",
23
+ "Case",
24
+ "Conference Paper",
25
+ "Dataset",
26
+ "Dictionary Entry",
27
+ "Document",
28
+ "E-mail",
29
+ "Encyclopedia Article",
30
+ "Film",
31
+ "Forum Post",
32
+ "Hearing",
33
+ "Instant Message",
34
+ "Interview",
35
+ "Journal Article",
36
+ "Letter",
37
+ "Magazine Article",
38
+ "Manuscript",
39
+ "Map",
40
+ "Newspaper Article",
41
+ "Patent",
42
+ "Podcast",
43
+ "Preprint",
44
+ "Presentation",
45
+ "Radio Broadcast",
46
+ "Report",
47
+ "Software",
48
+ "Standard",
49
+ "Statute",
50
+ "Thesis",
51
+ "TV Broadcast",
52
+ "Video Recording",
53
+ "Web Page",
54
+ ]
55
+
56
+ filter_excluded_types: ["attachment", "note", "annotation"]
@@ -0,0 +1,3 @@
1
+ """
2
+ Import all the modules in the package
3
+ """
@@ -63,3 +63,4 @@ class Talk2Scholars(AgentState):
63
63
  pdf_data: Annotated[Dict[str, Any], replace_dict]
64
64
  zotero_read: Annotated[Dict[str, Any], replace_dict]
65
65
  llm_model: BaseChatModel
66
+ pdf_data: Annotated[Dict[str, Any], replace_dict]
@@ -3,119 +3,194 @@ Unit tests for main agent functionality.
3
3
  Tests the supervisor agent's routing logic and state management.
4
4
  """
5
5
 
6
- # pylint: disable=redefined-outer-name
7
6
  # pylint: disable=redefined-outer-name,too-few-public-methods
8
- import random
9
- from unittest.mock import Mock, patch, MagicMock
7
+
8
+ from types import SimpleNamespace
10
9
  import pytest
11
- from langchain_core.messages import HumanMessage, AIMessage
12
- from langgraph.graph import END
13
- from ..agents.main_agent import make_supervisor_node, get_hydra_config, get_app
14
- from ..state.state_talk2scholars import Talk2Scholars
10
+ import hydra
11
+ from langchain_core.language_models.chat_models import BaseChatModel
12
+ from langchain_openai import ChatOpenAI
13
+ from pydantic import Field
14
+ from aiagents4pharma.talk2scholars.agents.main_agent import get_app
15
+
16
+ # --- Dummy LLM Implementation ---
17
+
18
+
19
+ class DummyLLM(BaseChatModel):
20
+ """A dummy language model implementation for testing purposes."""
21
+
22
+ model_name: str = Field(...)
23
+
24
+ def _generate(self, prompt, stop=None):
25
+ """Generate a response given a prompt."""
26
+ DummyLLM.called_prompt = prompt
27
+ return "dummy output"
28
+
29
+ @property
30
+ def _llm_type(self):
31
+ """Return the type of the language model."""
32
+ return "dummy"
33
+
34
+
35
+ # --- Dummy Workflow and Sub-agent Functions ---
36
+
37
+
38
+ class DummyWorkflow:
39
+ """A dummy workflow class that records arguments for verification."""
40
+
41
+ def __init__(self, supervisor_args=None):
42
+ """Initialize the workflow with the given supervisor arguments."""
43
+ self.supervisor_args = supervisor_args or {}
44
+ self.checkpointer = None
45
+ self.name = None
46
+
47
+ def compile(self, checkpointer, name):
48
+ """Compile the workflow with the given checkpointer and name."""
49
+ self.checkpointer = checkpointer
50
+ self.name = name
51
+ return self
52
+
53
+
54
+ def dummy_get_app_s2(uniq_id, llm_model):
55
+ """Return a DummyWorkflow for the S2 agent."""
56
+ dummy_get_app_s2.called_uniq_id = uniq_id
57
+ dummy_get_app_s2.called_llm_model = llm_model
58
+ return DummyWorkflow(supervisor_args={"agent": "s2", "uniq_id": uniq_id})
59
+
60
+
61
+ def dummy_get_app_zotero(uniq_id, llm_model):
62
+ """Return a DummyWorkflow for the Zotero agent."""
63
+ dummy_get_app_zotero.called_uniq_id = uniq_id
64
+ dummy_get_app_zotero.called_llm_model = llm_model
65
+ return DummyWorkflow(supervisor_args={"agent": "zotero", "uniq_id": uniq_id})
66
+
67
+
68
+ def dummy_create_supervisor(apps, model, state_schema, **kwargs):
69
+ """Return a DummyWorkflow for the supervisor."""
70
+ dummy_create_supervisor.called_kwargs = kwargs
71
+ return DummyWorkflow(
72
+ supervisor_args={
73
+ "apps": apps,
74
+ "model": model,
75
+ "state_schema": state_schema,
76
+ **kwargs,
77
+ }
78
+ )
79
+
80
+
81
+ # --- Dummy Hydra Configuration Setup ---
82
+
83
+
84
+ class DummyHydraContext:
85
+ """A dummy context manager for mocking Hydra's initialize and compose functions."""
86
+
87
+ def __enter__(self):
88
+ """Return None when entering the context."""
89
+ return None
90
+
91
+ def __exit__(self, exc_type, exc_val, traceback):
92
+ """Exit function that does nothing."""
93
+ return None
94
+
95
+
96
+ def dict_to_namespace(d):
97
+ """Convert a dictionary to a SimpleNamespace object."""
98
+ return SimpleNamespace(
99
+ **{
100
+ key: dict_to_namespace(val) if isinstance(val, dict) else val
101
+ for key, val in d.items()
102
+ }
103
+ )
104
+
105
+
106
+ dummy_config = {
107
+ "agents": {
108
+ "talk2scholars": {"main_agent": {"system_prompt": "Dummy system prompt"}}
109
+ }
110
+ }
111
+
112
+
113
+ class DummyHydraCompose:
114
+ """A dummy class that returns a namespace from a dummy config dictionary."""
115
+
116
+ def __init__(self, config):
117
+ """Constructor that stores the dummy config."""
118
+ self.config = config
119
+
120
+ def __getattr__(self, item):
121
+ """Return a namespace from the dummy config."""
122
+ return dict_to_namespace(self.config.get(item, {}))
123
+
124
+
125
+ # --- Pytest Fixtures to Patch Dependencies ---
126
+
127
+
128
+ @pytest.fixture(autouse=True)
129
+ def patch_hydra(monkeypatch):
130
+ """Patch the hydra.initialize and hydra.compose functions to return dummy objects."""
131
+ monkeypatch.setattr(
132
+ hydra, "initialize", lambda version_base, config_path: DummyHydraContext()
133
+ )
134
+ monkeypatch.setattr(
135
+ hydra, "compose", lambda config_name, overrides: DummyHydraCompose(dummy_config)
136
+ )
15
137
 
16
138
 
17
139
  @pytest.fixture(autouse=True)
18
- def mock_hydra():
19
- """Mock Hydra configuration."""
20
- with patch("hydra.initialize"), patch("hydra.compose") as mock_compose:
21
- cfg_mock = MagicMock()
22
- cfg_mock.agents.talk2scholars.main_agent.temperature = 0
23
- cfg_mock.agents.talk2scholars.main_agent.system_prompt = "System prompt"
24
- cfg_mock.agents.talk2scholars.main_agent.router_prompt = "Router prompt"
25
- mock_compose.return_value = cfg_mock
26
- yield mock_compose
27
-
28
-
29
- def test_get_app():
30
- """Test the full initialization of the LangGraph application."""
31
- thread_id = "test_thread"
32
- mock_llm = Mock()
33
- app = get_app(thread_id, mock_llm)
34
- assert app is not None
35
- assert "supervisor" in app.nodes
36
- assert "s2_agent" in app.nodes # Ensure nodes exist
37
- assert "zotero_agent" in app.nodes
38
-
39
-
40
- def test_get_app_with_default_llm():
41
- """Test app initialization with default LLM parameters."""
42
- thread_id = "test_thread"
43
- llm_mock = Mock()
44
-
45
- # We need to explicitly pass the mock instead of patching, since the function uses
46
- # ChatOpenAI as a default argument value which is evaluated at function definition time
47
- app = get_app(thread_id, llm_mock)
48
- assert app is not None
49
- # We can only verify the app was created successfully
50
-
51
-
52
- def test_get_hydra_config():
53
- """Test that Hydra configuration loads correctly."""
54
- with patch("hydra.initialize"), patch("hydra.compose") as mock_compose:
55
- cfg_mock = MagicMock()
56
- cfg_mock.agents.talk2scholars.main_agent.temperature = 0
57
- mock_compose.return_value = cfg_mock
58
- cfg = get_hydra_config()
59
- assert cfg is not None
60
- assert cfg.temperature == 0
61
-
62
-
63
- def test_hydra_failure():
64
- """Test exception handling when Hydra fails to load config."""
65
- thread_id = "test_thread"
66
- llm_mock = Mock()
67
- with patch("hydra.initialize", side_effect=Exception("Hydra error")):
68
- with pytest.raises(Exception) as exc_info:
69
- get_app(thread_id, llm_model=llm_mock)
70
- assert "Hydra error" in str(exc_info.value)
71
-
72
-
73
- def test_supervisor_node_execution():
74
- """Test that the supervisor node routes correctly."""
75
- mock_llm = Mock()
76
- thread_id = "test_thread"
77
-
78
- class MockRouter:
79
- """Mock router class."""
80
-
81
- next = random.choice(["s2_agent", "zotero_agent"])
82
-
83
- with (
84
- patch.object(mock_llm, "with_structured_output", return_value=mock_llm),
85
- patch.object(mock_llm, "invoke", return_value=MockRouter()),
86
- ):
87
- supervisor_node = make_supervisor_node(mock_llm, thread_id)
88
- mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
89
- result = supervisor_node(mock_state)
90
-
91
- # Accept either "s2_agent" or "zotero_agent"
92
- assert result.goto in ["s2_agent", "zotero_agent"]
93
-
94
-
95
- def test_supervisor_node_finish():
96
- """Test that supervisor node correctly handles FINISH case."""
97
- mock_llm = Mock()
98
- thread_id = "test_thread"
99
-
100
- class MockRouter:
101
- """Mock router class."""
102
-
103
- next = "FINISH"
104
-
105
- class MockAIResponse:
106
- """Mock AI response class."""
107
-
108
- def __init__(self):
109
- self.content = "Final AI Response"
110
-
111
- with (
112
- patch.object(mock_llm, "with_structured_output", return_value=mock_llm),
113
- patch.object(mock_llm, "invoke", side_effect=[MockRouter(), MockAIResponse()]),
114
- ):
115
- supervisor_node = make_supervisor_node(mock_llm, thread_id)
116
- mock_state = Talk2Scholars(messages=[HumanMessage(content="End conversation")])
117
- result = supervisor_node(mock_state)
118
- assert result.goto == END
119
- assert "messages" in result.update
120
- assert isinstance(result.update["messages"], AIMessage)
121
- assert result.update["messages"].content == "Final AI Response"
140
+ def patch_sub_agents_and_supervisor(monkeypatch):
141
+ """Patch the sub-agents and supervisor creation functions."""
142
+ monkeypatch.setattr(
143
+ "aiagents4pharma.talk2scholars.agents.main_agent.get_app_s2", dummy_get_app_s2
144
+ )
145
+ monkeypatch.setattr(
146
+ "aiagents4pharma.talk2scholars.agents.main_agent.get_app_zotero",
147
+ dummy_get_app_zotero,
148
+ )
149
+ monkeypatch.setattr(
150
+ "aiagents4pharma.talk2scholars.agents.main_agent.create_supervisor",
151
+ dummy_create_supervisor,
152
+ )
153
+
154
+
155
+ # --- Test Cases ---
156
+
157
+
158
+ def test_dummy_llm_generate():
159
+ """Test the dummy LLM's generate function."""
160
+ dummy = DummyLLM(model_name="test-model")
161
+ output = getattr(dummy, "_generate")("any prompt")
162
+ assert output == "dummy output"
163
+
164
+
165
+ def test_dummy_llm_llm_type():
166
+ """Test the dummy LLM's _llm_type property."""
167
+ dummy = DummyLLM(model_name="test-model")
168
+ assert getattr(dummy, "_llm_type") == "dummy"
169
+
170
+
171
+ def test_get_app_with_gpt4o_mini():
172
+ """
173
+ Test that get_app replaces a 'gpt-4o-mini' LLM with a new ChatOpenAI instance.
174
+ """
175
+ uniq_id = "test_thread"
176
+ dummy_llm = DummyLLM(model_name="gpt-4o-mini")
177
+ app = get_app(uniq_id, dummy_llm)
178
+
179
+ supervisor_args = getattr(app, "supervisor_args", {})
180
+ assert isinstance(supervisor_args.get("model"), ChatOpenAI)
181
+ assert supervisor_args.get("prompt") == "Dummy system prompt"
182
+ assert getattr(app, "name", "") == "Talk2Scholars_MainAgent"
183
+
184
+
185
+ def test_get_app_with_other_model():
186
+ """
187
+ Test that get_app does not replace the LLM if its model_name is not 'gpt-4o-mini'.
188
+ """
189
+ uniq_id = "test_thread_2"
190
+ dummy_llm = DummyLLM(model_name="other-model")
191
+ app = get_app(uniq_id, dummy_llm)
192
+
193
+ supervisor_args = getattr(app, "supervisor_args", {})
194
+ assert supervisor_args.get("model") is dummy_llm
195
+ assert supervisor_args.get("prompt") == "Dummy system prompt"
196
+ assert getattr(app, "name", "") == "Talk2Scholars_MainAgent"
@@ -0,0 +1,142 @@
1
+ """Unit tests for the paper download agent in Talk2Scholars."""
2
+
3
+ from unittest import mock
4
+ import pytest
5
+ from langchain_core.messages import HumanMessage, AIMessage
6
+ from langchain_core.language_models.chat_models import BaseChatModel
7
+ from ..agents.paper_download_agent import get_app
8
+ from ..state.state_talk2scholars import Talk2Scholars
9
+
10
+
11
+ @pytest.fixture(autouse=True)
12
+ def mock_hydra_fixture():
13
+ """Mocks Hydra configuration for tests."""
14
+ with mock.patch("hydra.initialize"), mock.patch("hydra.compose") as mock_compose:
15
+ cfg_mock = mock.MagicMock()
16
+ cfg_mock.agents.talk2scholars.s2_agent.temperature = 0
17
+ cfg_mock.agents.talk2scholars.paper_download_agent.prompt = "Test prompt"
18
+ mock_compose.return_value = cfg_mock
19
+ yield mock_compose
20
+
21
+
22
+ @pytest.fixture
23
+ def mock_tools_fixture():
24
+ """Mocks paper download tools to prevent real HTTP calls."""
25
+ with (
26
+ mock.patch(
27
+ "aiagents4pharma.talk2scholars.tools.paper_download."
28
+ "download_arxiv_input.download_arxiv_paper"
29
+ ) as mock_download_arxiv_paper,
30
+ mock.patch(
31
+ "aiagents4pharma.talk2scholars.tools.s2.query_results.query_results"
32
+ ) as mock_query_results,
33
+ ):
34
+ mock_download_arxiv_paper.return_value = {
35
+ "pdf_data": {"dummy_key": "dummy_value"}
36
+ }
37
+ mock_query_results.return_value = {
38
+ "result": "Mocked Query Result"
39
+ }
40
+ yield [mock_download_arxiv_paper, mock_query_results]
41
+
42
+ @pytest.mark.usefixtures("mock_hydra_fixture")
43
+ def test_paper_download_agent_initialization():
44
+ """Ensures the paper download agent initializes properly with a prompt."""
45
+ thread_id = "test_thread_paper_dl"
46
+ llm_mock = mock.Mock(spec=BaseChatModel) # Mock LLM
47
+
48
+ with mock.patch(
49
+ "aiagents4pharma.talk2scholars.agents.paper_download_agent.create_react_agent"
50
+ ) as mock_create_agent:
51
+ mock_create_agent.return_value = mock.Mock()
52
+
53
+ app = get_app(thread_id, llm_mock)
54
+ assert app is not None, "The agent app should be successfully created."
55
+ assert mock_create_agent.called
56
+
57
+ def test_paper_download_agent_invocation():
58
+ """Verifies agent processes queries and updates state correctly."""
59
+ _ = mock_tools_fixture # Prevents unused-argument warning
60
+ thread_id = "test_thread_paper_dl"
61
+ mock_state = Talk2Scholars(
62
+ messages=[HumanMessage(content="Download paper 1234.5678")]
63
+ )
64
+ llm_mock = mock.Mock(spec=BaseChatModel)
65
+
66
+ with mock.patch(
67
+ "aiagents4pharma.talk2scholars.agents.paper_download_agent.create_react_agent"
68
+ ) as mock_create_agent:
69
+ mock_agent = mock.Mock()
70
+ mock_create_agent.return_value = mock_agent
71
+ mock_agent.invoke.return_value = {
72
+ "messages": [AIMessage(content="Here is the paper")],
73
+ "pdf_data": {"file_bytes": b"FAKE_PDF_CONTENTS"},
74
+ }
75
+
76
+
77
+ app = get_app(thread_id, llm_mock)
78
+ result = app.invoke(
79
+ mock_state,
80
+ config={
81
+ "configurable": {
82
+ "thread_id": thread_id,
83
+ "checkpoint_ns": "test_ns",
84
+ "checkpoint_id": "test_checkpoint",
85
+ }
86
+ },
87
+ )
88
+
89
+ assert "messages" in result
90
+ assert "pdf_data" in result
91
+
92
+
93
+ def test_paper_download_agent_tools_assignment(request): # Keep fixture name
94
+ """Checks correct tool assignment (download_arxiv_paper, query_results)."""
95
+ thread_id = "test_thread_paper_dl"
96
+ mock_tools = request.getfixturevalue("mock_tools_fixture")
97
+ llm_mock = mock.Mock(spec=BaseChatModel)
98
+
99
+ with (
100
+ mock.patch(
101
+ "aiagents4pharma.talk2scholars.agents.paper_download_agent.create_react_agent"
102
+ ) as mock_create_agent,
103
+ mock.patch(
104
+ "aiagents4pharma.talk2scholars.agents.paper_download_agent.ToolNode"
105
+ ) as mock_toolnode,
106
+ ):
107
+ mock_agent = mock.Mock()
108
+ mock_create_agent.return_value = mock_agent
109
+ mock_tool_instance = mock.Mock()
110
+ mock_tool_instance.tools = mock_tools
111
+ mock_toolnode.return_value= mock_tool_instance
112
+
113
+ get_app(thread_id, llm_mock)
114
+ assert mock_toolnode.called
115
+ assert len(mock_tool_instance.tools) == 2
116
+
117
+
118
+ def test_paper_download_agent_hydra_failure():
119
+ """Confirms the agent gracefully handles exceptions if Hydra fails."""
120
+ thread_id = "test_thread_paper_dl"
121
+ llm_mock = mock.Mock(spec=BaseChatModel)
122
+
123
+ with mock.patch("hydra.initialize", side_effect=Exception("Mock Hydra failure")):
124
+ with pytest.raises(Exception) as exc_info:
125
+ get_app(thread_id, llm_mock)
126
+ assert "Mock Hydra failure" in str(exc_info.value)
127
+
128
+
129
+ def test_paper_download_agent_model_failure():
130
+ """Ensures agent handles model-related failures gracefully."""
131
+ thread_id = "test_thread_paper_dl"
132
+ llm_mock = mock.Mock(spec=BaseChatModel)
133
+
134
+ with mock.patch(
135
+ "aiagents4pharma.talk2scholars.agents.paper_download_agent.create_react_agent",
136
+ side_effect=Exception("Mock model failure"),
137
+ ):
138
+ with pytest.raises(Exception) as exc_info:
139
+ get_app(thread_id, llm_mock)
140
+ assert "Mock model failure" in str(exc_info.value), (
141
+ "Model initialization failure should raise an exception."
142
+ )
@@ -0,0 +1,154 @@
1
+ """
2
+ Unit tests for arXiv paper downloading functionality, including:
3
+ - AbstractPaperDownloader (base class)
4
+ - ArxivPaperDownloader (arXiv-specific implementation)
5
+ - download_arxiv_paper tool function.
6
+ """
7
+
8
+ from unittest.mock import patch, MagicMock
9
+ import pytest
10
+ import requests
11
+ from requests.exceptions import HTTPError
12
+ from langgraph.types import Command
13
+ from langchain_core.messages import ToolMessage
14
+
15
+ # Import the classes and function under test
16
+ from aiagents4pharma.talk2scholars.tools.paper_download.abstract_downloader import (
17
+ AbstractPaperDownloader,
18
+ )
19
+ from aiagents4pharma.talk2scholars.tools.paper_download.arxiv_downloader import (
20
+ ArxivPaperDownloader,
21
+ )
22
+ from aiagents4pharma.talk2scholars.tools.paper_download.download_arxiv_input import (
23
+ download_arxiv_paper,
24
+ )
25
+
26
+ @pytest.mark.parametrize("class_obj", [AbstractPaperDownloader])
27
+
28
+ def test_abstract_downloader_cannot_be_instantiated(class_obj):
29
+ """
30
+ Validates that AbstractPaperDownloader is indeed abstract and raises TypeError
31
+ if anyone attempts to instantiate it directly.
32
+ """
33
+ with pytest.raises(TypeError):
34
+ class_obj()
35
+
36
+
37
+ @pytest.fixture(name="arxiv_downloader_fixture")
38
+ @pytest.mark.usefixtures("mock_hydra_config_setup")
39
+ def fixture_arxiv_downloader():
40
+ """
41
+ Provides an ArxivPaperDownloader instance with a mocked Hydra config.
42
+ """
43
+ return ArxivPaperDownloader()
44
+
45
+
46
+ def test_fetch_metadata_success(arxiv_downloader_fixture,):
47
+ """
48
+ Ensures fetch_metadata retrieves XML data correctly, given a successful HTTP response.
49
+ """
50
+ mock_response = MagicMock()
51
+ mock_response.text = "<xml>Mock ArXiv Metadata</xml>"
52
+ mock_response.raise_for_status = MagicMock()
53
+
54
+ with patch.object(requests, "get", return_value=mock_response) as mock_get:
55
+ paper_id = "1234.5678"
56
+ result = arxiv_downloader_fixture.fetch_metadata(paper_id)
57
+ mock_get.assert_called_once_with(
58
+ "http://export.arxiv.org/api/query?search_query=id:1234.5678&start=0&max_results=1",
59
+ timeout=10,
60
+ )
61
+ assert result["xml"] == "<xml>Mock ArXiv Metadata</xml>"
62
+
63
+
64
+ def test_fetch_metadata_http_error(arxiv_downloader_fixture):
65
+ """
66
+ Validates that fetch_metadata raises HTTPError when the response indicates a failure.
67
+ """
68
+ mock_response = MagicMock()
69
+ mock_response.raise_for_status.side_effect = HTTPError("Mocked HTTP failure")
70
+
71
+ with patch.object(requests, "get", return_value=mock_response):
72
+ with pytest.raises(HTTPError):
73
+ arxiv_downloader_fixture.fetch_metadata("invalid_id")
74
+
75
+
76
+ def test_download_pdf_success(arxiv_downloader_fixture):
77
+ """
78
+ Tests that download_pdf fetches the PDF link from metadata and successfully
79
+ retrieves the binary content.
80
+ """
81
+ mock_metadata = {
82
+ "xml": """
83
+ <feed xmlns="http://www.w3.org/2005/Atom">
84
+ <entry>
85
+ <link title="pdf" href="http://test.arxiv.org/pdf/1234.5678v1.pdf"/>
86
+ </entry>
87
+ </feed>
88
+ """
89
+ }
90
+
91
+ mock_pdf_response = MagicMock()
92
+ mock_pdf_response.raise_for_status = MagicMock()
93
+ mock_pdf_response.iter_content = lambda chunk_size: [b"FAKE_PDF_CONTENT"]
94
+
95
+ with patch.object(arxiv_downloader_fixture, "fetch_metadata", return_value=mock_metadata):
96
+ with patch.object(requests, "get", return_value=mock_pdf_response) as mock_get:
97
+ result = arxiv_downloader_fixture.download_pdf("1234.5678")
98
+ assert result["pdf_object"] == b"FAKE_PDF_CONTENT"
99
+ assert result["pdf_url"] == "http://test.arxiv.org/pdf/1234.5678v1.pdf"
100
+ assert result["arxiv_id"] == "1234.5678"
101
+ mock_get.assert_called_once_with(
102
+ "http://test.arxiv.org/pdf/1234.5678v1.pdf",
103
+ stream=True,
104
+ timeout=10,
105
+ )
106
+
107
+
108
+ def test_download_pdf_no_pdf_link(arxiv_downloader_fixture):
109
+ """
110
+ Ensures a RuntimeError is raised if no <link> with title="pdf" is found in the XML.
111
+ """
112
+ mock_metadata = {"xml": "<feed></feed>"}
113
+
114
+ with patch.object(arxiv_downloader_fixture, "fetch_metadata", return_value=mock_metadata):
115
+ with pytest.raises(RuntimeError, match="Failed to download PDF"):
116
+ arxiv_downloader_fixture.download_pdf("1234.5678")
117
+
118
+
119
+ def test_download_arxiv_paper_tool_success(arxiv_downloader_fixture):
120
+ """
121
+ Validates download_arxiv_paper orchestrates the ArxivPaperDownloader correctly,
122
+ returning a Command with PDF data and success messages.
123
+ """
124
+ mock_metadata = {"xml": "<mockxml></mockxml>"}
125
+ mock_pdf_response = {
126
+ "pdf_object": b"FAKE_PDF_CONTENT",
127
+ "pdf_url": "http://test.arxiv.org/mock.pdf",
128
+ "arxiv_id": "9999.8888",
129
+ }
130
+
131
+ with patch(
132
+ "aiagents4pharma.talk2scholars.tools.paper_download.download_arxiv_input."
133
+ "ArxivPaperDownloader",
134
+ return_value=arxiv_downloader_fixture,
135
+ ):
136
+ with patch.object(arxiv_downloader_fixture, "fetch_metadata", return_value=mock_metadata):
137
+ with patch.object(
138
+ arxiv_downloader_fixture,
139
+ "download_pdf",
140
+ return_value=mock_pdf_response,
141
+ ):
142
+ command_result = download_arxiv_paper.invoke(
143
+ {"arxiv_id": "9999.8888", "tool_call_id": "test_tool_call"}
144
+ )
145
+
146
+ assert isinstance(command_result, Command)
147
+ assert "pdf_data" in command_result.update
148
+ assert command_result.update["pdf_data"] == mock_pdf_response
149
+
150
+ messages = command_result.update.get("messages", [])
151
+ assert len(messages) == 1
152
+ assert isinstance(messages[0], ToolMessage)
153
+ assert "Successfully downloaded PDF" in messages[0].content
154
+ assert "9999.8888" in messages[0].content