aiagents4pharma 1.28.0__py3-none-any.whl → 1.29.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. aiagents4pharma/talk2scholars/agents/main_agent.py +35 -209
  2. aiagents4pharma/talk2scholars/agents/s2_agent.py +10 -6
  3. aiagents4pharma/talk2scholars/agents/zotero_agent.py +12 -6
  4. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +2 -48
  5. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +5 -28
  6. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +5 -21
  7. aiagents4pharma/talk2scholars/configs/config.yaml +1 -0
  8. aiagents4pharma/talk2scholars/configs/tools/__init__.py +1 -0
  9. aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +1 -1
  10. aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +1 -1
  11. aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +1 -1
  12. aiagents4pharma/talk2scholars/configs/tools/zotero_read/default.yaml +42 -1
  13. aiagents4pharma/talk2scholars/configs/tools/zotero_write/__inti__.py +3 -0
  14. aiagents4pharma/talk2scholars/tests/test_main_agent.py +186 -111
  15. aiagents4pharma/talk2scholars/tests/test_s2_display.py +74 -0
  16. aiagents4pharma/talk2scholars/tests/test_s2_multi.py +282 -0
  17. aiagents4pharma/talk2scholars/tests/test_s2_query.py +78 -0
  18. aiagents4pharma/talk2scholars/tests/test_s2_retrieve.py +65 -0
  19. aiagents4pharma/talk2scholars/tests/test_s2_search.py +266 -0
  20. aiagents4pharma/talk2scholars/tests/test_s2_single.py +274 -0
  21. aiagents4pharma/talk2scholars/tests/test_zotero_path.py +57 -0
  22. aiagents4pharma/talk2scholars/tests/test_zotero_read.py +412 -0
  23. aiagents4pharma/talk2scholars/tests/test_zotero_write.py +626 -0
  24. aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +50 -34
  25. aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +8 -8
  26. aiagents4pharma/talk2scholars/tools/s2/search.py +36 -23
  27. aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +44 -38
  28. aiagents4pharma/talk2scholars/tools/zotero/__init__.py +2 -0
  29. aiagents4pharma/talk2scholars/tools/zotero/utils/__init__.py +5 -0
  30. aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_path.py +63 -0
  31. aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +64 -19
  32. aiagents4pharma/talk2scholars/tools/zotero/zotero_write.py +247 -0
  33. {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.29.0.dist-info}/METADATA +6 -5
  34. {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.29.0.dist-info}/RECORD +37 -28
  35. aiagents4pharma/talk2scholars/tests/test_call_s2.py +0 -100
  36. aiagents4pharma/talk2scholars/tests/test_call_zotero.py +0 -94
  37. aiagents4pharma/talk2scholars/tests/test_s2_tools.py +0 -355
  38. aiagents4pharma/talk2scholars/tests/test_zotero_tool.py +0 -171
  39. {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.29.0.dist-info}/LICENSE +0 -0
  40. {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.29.0.dist-info}/WHEEL +0 -0
  41. {aiagents4pharma-1.28.0.dist-info → aiagents4pharma-1.29.0.dist-info}/top_level.txt +0 -0
@@ -3,119 +3,194 @@ Unit tests for main agent functionality.
3
3
  Tests the supervisor agent's routing logic and state management.
4
4
  """
5
5
 
6
- # pylint: disable=redefined-outer-name
7
6
  # pylint: disable=redefined-outer-name,too-few-public-methods
8
- import random
9
- from unittest.mock import Mock, patch, MagicMock
7
+
8
+ from types import SimpleNamespace
10
9
  import pytest
11
- from langchain_core.messages import HumanMessage, AIMessage
12
- from langgraph.graph import END
13
- from ..agents.main_agent import make_supervisor_node, get_hydra_config, get_app
14
- from ..state.state_talk2scholars import Talk2Scholars
10
+ import hydra
11
+ from langchain_core.language_models.chat_models import BaseChatModel
12
+ from langchain_openai import ChatOpenAI
13
+ from pydantic import Field
14
+ from aiagents4pharma.talk2scholars.agents.main_agent import get_app
15
+
16
+ # --- Dummy LLM Implementation ---
17
+
18
+
19
class DummyLLM(BaseChatModel):
    """Stand-in chat model for tests; it never contacts a real LLM backend."""

    # Required pydantic field; the get_app tests set it to steer model handling.
    model_name: str = Field(...)

    def _generate(self, prompt, stop=None):
        """Remember *prompt* on the class and hand back a canned string.

        ``stop`` is accepted but unused.
        """
        DummyLLM.called_prompt = prompt
        return "dummy output"

    @property
    def _llm_type(self):
        """Short identifier reported for this model type."""
        return "dummy"
33
+
34
+
35
+ # --- Dummy Workflow and Sub-agent Functions ---
36
+
37
+
38
class DummyWorkflow:
    """Fake workflow/graph object that remembers how it was built and compiled."""

    def __init__(self, supervisor_args=None):
        """Keep *supervisor_args* (a fresh empty dict if falsy); nothing compiled yet."""
        if not supervisor_args:
            supervisor_args = {}
        self.supervisor_args = supervisor_args
        self.checkpointer = None
        self.name = None

    def compile(self, checkpointer, name):
        """Store *checkpointer* and *name*, then return this same object."""
        self.checkpointer, self.name = checkpointer, name
        return self
52
+
53
+
54
def dummy_get_app_s2(uniq_id, llm_model):
    """Fake S2 sub-agent factory: records its arguments and returns a DummyWorkflow."""
    recorder = dummy_get_app_s2
    recorder.called_uniq_id = uniq_id
    recorder.called_llm_model = llm_model
    args = {"agent": "s2", "uniq_id": uniq_id}
    return DummyWorkflow(supervisor_args=args)
59
+
60
+
61
def dummy_get_app_zotero(uniq_id, llm_model):
    """Fake Zotero sub-agent factory: records its arguments and returns a DummyWorkflow."""
    recorder = dummy_get_app_zotero
    recorder.called_uniq_id = uniq_id
    recorder.called_llm_model = llm_model
    args = {"agent": "zotero", "uniq_id": uniq_id}
    return DummyWorkflow(supervisor_args=args)
66
+
67
+
68
def dummy_create_supervisor(apps, model, state_schema, **kwargs):
    """Fake ``create_supervisor``: echo every argument into a DummyWorkflow.

    The extra keyword arguments are also stored on the function as
    ``called_kwargs`` so tests can inspect them.
    """
    dummy_create_supervisor.called_kwargs = kwargs
    recorded = {"apps": apps, "model": model, "state_schema": state_schema}
    # kwargs cannot repeat the three named parameters, so update() matches {**...}.
    recorded.update(kwargs)
    return DummyWorkflow(supervisor_args=recorded)
79
+
80
+
81
+ # --- Dummy Hydra Configuration Setup ---
82
+
83
+
84
class DummyHydraContext:
    """No-op context manager substituted for ``hydra.initialize`` in tests."""

    def __enter__(self):
        """Enter the context; the ``as`` target receives ``None``."""

    def __exit__(self, exc_type, exc_val, traceback):
        """Leave the context; returning ``None`` lets exceptions propagate."""
94
+
95
+
96
def dict_to_namespace(d):
    """Recursively turn a nested dict into ``SimpleNamespace`` objects.

    Only values that are themselves dicts are converted; everything else
    (including lists that contain dicts) is stored unchanged.
    """
    fields = {}
    for key, value in d.items():
        if isinstance(value, dict):
            value = dict_to_namespace(value)
        fields[key] = value
    return SimpleNamespace(**fields)
104
+
105
+
106
# Nested dict served by DummyHydraCompose; it supplies only the main agent's
# system prompt under agents.talk2scholars.main_agent.
dummy_config = dict(
    agents=dict(
        talk2scholars=dict(
            main_agent=dict(system_prompt="Dummy system prompt"),
        ),
    ),
)
111
+
112
+
113
class DummyHydraCompose:
    """Fake result of ``hydra.compose`` backed by a plain nested dict.

    Reading ``obj.<key>`` returns ``dict_to_namespace(config[key])``, or an
    empty namespace when *key* is absent from the config.
    """

    def __init__(self, config):
        """Store the nested dict that attribute lookups read from."""
        self.config = config

    def __getattr__(self, item):
        """Return config section *item* converted to a namespace.

        Dunder names and ``config`` itself raise ``AttributeError``. Protocol
        probes such as ``copy``/``pickle`` checking ``__setstate__`` then fall back
        to their defaults instead of receiving a namespace. An instance created
        without ``__init__`` no longer recurses forever looking up
        ``self.config``.
        """
        if item == "config" or (item.startswith("__") and item.endswith("__")):
            raise AttributeError(item)
        return dict_to_namespace(self.config.get(item, {}))
123
+
124
+
125
+ # --- Pytest Fixtures to Patch Dependencies ---
126
+
127
+
128
@pytest.fixture(autouse=True)
def patch_hydra(monkeypatch):
    """Swap ``hydra.initialize``/``hydra.compose`` for in-memory fakes in every test."""

    def fake_initialize(version_base, config_path):  # pylint: disable=unused-argument
        return DummyHydraContext()

    def fake_compose(config_name, overrides):  # pylint: disable=unused-argument
        return DummyHydraCompose(dummy_config)

    monkeypatch.setattr(hydra, "initialize", fake_initialize)
    monkeypatch.setattr(hydra, "compose", fake_compose)
15
137
 
16
138
 
17
139
  @pytest.fixture(autouse=True)
18
- def mock_hydra():
19
- """Mock Hydra configuration."""
20
- with patch("hydra.initialize"), patch("hydra.compose") as mock_compose:
21
- cfg_mock = MagicMock()
22
- cfg_mock.agents.talk2scholars.main_agent.temperature = 0
23
- cfg_mock.agents.talk2scholars.main_agent.system_prompt = "System prompt"
24
- cfg_mock.agents.talk2scholars.main_agent.router_prompt = "Router prompt"
25
- mock_compose.return_value = cfg_mock
26
- yield mock_compose
27
-
28
-
29
- def test_get_app():
30
- """Test the full initialization of the LangGraph application."""
31
- thread_id = "test_thread"
32
- mock_llm = Mock()
33
- app = get_app(thread_id, mock_llm)
34
- assert app is not None
35
- assert "supervisor" in app.nodes
36
- assert "s2_agent" in app.nodes # Ensure nodes exist
37
- assert "zotero_agent" in app.nodes
38
-
39
-
40
- def test_get_app_with_default_llm():
41
- """Test app initialization with default LLM parameters."""
42
- thread_id = "test_thread"
43
- llm_mock = Mock()
44
-
45
- # We need to explicitly pass the mock instead of patching, since the function uses
46
- # ChatOpenAI as a default argument value which is evaluated at function definition time
47
- app = get_app(thread_id, llm_mock)
48
- assert app is not None
49
- # We can only verify the app was created successfully
50
-
51
-
52
- def test_get_hydra_config():
53
- """Test that Hydra configuration loads correctly."""
54
- with patch("hydra.initialize"), patch("hydra.compose") as mock_compose:
55
- cfg_mock = MagicMock()
56
- cfg_mock.agents.talk2scholars.main_agent.temperature = 0
57
- mock_compose.return_value = cfg_mock
58
- cfg = get_hydra_config()
59
- assert cfg is not None
60
- assert cfg.temperature == 0
61
-
62
-
63
- def test_hydra_failure():
64
- """Test exception handling when Hydra fails to load config."""
65
- thread_id = "test_thread"
66
- llm_mock = Mock()
67
- with patch("hydra.initialize", side_effect=Exception("Hydra error")):
68
- with pytest.raises(Exception) as exc_info:
69
- get_app(thread_id, llm_model=llm_mock)
70
- assert "Hydra error" in str(exc_info.value)
71
-
72
-
73
- def test_supervisor_node_execution():
74
- """Test that the supervisor node routes correctly."""
75
- mock_llm = Mock()
76
- thread_id = "test_thread"
77
-
78
- class MockRouter:
79
- """Mock router class."""
80
-
81
- next = random.choice(["s2_agent", "zotero_agent"])
82
-
83
- with (
84
- patch.object(mock_llm, "with_structured_output", return_value=mock_llm),
85
- patch.object(mock_llm, "invoke", return_value=MockRouter()),
86
- ):
87
- supervisor_node = make_supervisor_node(mock_llm, thread_id)
88
- mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
89
- result = supervisor_node(mock_state)
90
-
91
- # Accept either "s2_agent" or "zotero_agent"
92
- assert result.goto in ["s2_agent", "zotero_agent"]
93
-
94
-
95
- def test_supervisor_node_finish():
96
- """Test that supervisor node correctly handles FINISH case."""
97
- mock_llm = Mock()
98
- thread_id = "test_thread"
99
-
100
- class MockRouter:
101
- """Mock router class."""
102
-
103
- next = "FINISH"
104
-
105
- class MockAIResponse:
106
- """Mock AI response class."""
107
-
108
- def __init__(self):
109
- self.content = "Final AI Response"
110
-
111
- with (
112
- patch.object(mock_llm, "with_structured_output", return_value=mock_llm),
113
- patch.object(mock_llm, "invoke", side_effect=[MockRouter(), MockAIResponse()]),
114
- ):
115
- supervisor_node = make_supervisor_node(mock_llm, thread_id)
116
- mock_state = Talk2Scholars(messages=[HumanMessage(content="End conversation")])
117
- result = supervisor_node(mock_state)
118
- assert result.goto == END
119
- assert "messages" in result.update
120
- assert isinstance(result.update["messages"], AIMessage)
121
- assert result.update["messages"].content == "Final AI Response"
140
def patch_sub_agents_and_supervisor(monkeypatch):
    """Replace main_agent's sub-agent factories and create_supervisor with recording fakes."""
    replacements = (
        ("aiagents4pharma.talk2scholars.agents.main_agent.get_app_s2", dummy_get_app_s2),
        (
            "aiagents4pharma.talk2scholars.agents.main_agent.get_app_zotero",
            dummy_get_app_zotero,
        ),
        (
            "aiagents4pharma.talk2scholars.agents.main_agent.create_supervisor",
            dummy_create_supervisor,
        ),
    )
    for dotted_path, fake in replacements:
        monkeypatch.setattr(dotted_path, fake)
153
+
154
+
155
+ # --- Test Cases ---
156
+
157
+
158
def test_dummy_llm_generate():
    """DummyLLM._generate should return its fixed placeholder text."""
    generate = getattr(DummyLLM(model_name="test-model"), "_generate")
    assert generate("any prompt") == "dummy output"
163
+
164
+
165
def test_dummy_llm_llm_type():
    """DummyLLM should report "dummy" as its _llm_type."""
    llm_type = getattr(DummyLLM(model_name="test-model"), "_llm_type")
    assert llm_type == "dummy"
169
+
170
+
171
def test_get_app_with_gpt4o_mini():
    """
    get_app should swap a 'gpt-4o-mini' DummyLLM for a fresh ChatOpenAI instance.
    """
    app = get_app("test_thread", DummyLLM(model_name="gpt-4o-mini"))

    args = getattr(app, "supervisor_args", {})
    assert isinstance(args.get("model"), ChatOpenAI)
    assert args.get("prompt") == "Dummy system prompt"
    assert getattr(app, "name", "") == "Talk2Scholars_MainAgent"
183
+
184
+
185
def test_get_app_with_other_model():
    """
    get_app should pass through an LLM whose model_name is not 'gpt-4o-mini'.
    """
    llm = DummyLLM(model_name="other-model")
    app = get_app("test_thread_2", llm)

    args = getattr(app, "supervisor_args", {})
    assert args.get("model") is llm
    assert args.get("prompt") == "Dummy system prompt"
    assert getattr(app, "name", "") == "Talk2Scholars_MainAgent"
@@ -0,0 +1,74 @@
1
+ """
2
+ Unit tests for S2 tools functionality.
3
+ """
4
+
5
+ # pylint: disable=redefined-outer-name
6
+ import pytest
7
+ from langgraph.types import Command
8
+ from ..tools.s2.display_results import (
9
+ display_results,
10
+ NoPapersFoundError as raised_error,
11
+ )
12
+
13
+
14
@pytest.fixture
def initial_state():
    """Fresh state with no single-paper or multi-paper results yet."""
    return dict(papers={}, multi_papers={})
18
+
19
+
20
+ # Fixed test data for deterministic results
21
# Canned raw search payload (single paper) kept constant for deterministic tests.
MOCK_SEARCH_RESPONSE = {
    "data": [
        dict(
            paperId="123",
            title="Machine Learning Basics",
            abstract="An introduction to ML",
            year=2023,
            citationCount=100,
            url="https://example.com/paper1",
            authors=[{"name": "Test Author"}],
        )
    ]
}
34
+
35
# One paper keyed by its id, shaped as the tests place it into state["papers"].
_PAPER_123 = {
    "Title": "Machine Learning Basics",
    "Abstract": "An introduction to ML",
    "Year": 2023,
    "Citation Count": 100,
    "URL": "https://example.com/paper1",
}
MOCK_STATE_PAPER = {"123": _PAPER_123}
44
+
45
+
46
class TestS2Tools:
    """Unit tests for the S2 ``display_results`` tool."""

    def test_display_results_empty_state(self, initial_state):
        """With nothing in state, display_results must raise NoPapersFoundError."""
        expected = "No papers found. A search/rec needs to be performed first."
        with pytest.raises(raised_error, match=expected):
            display_results.invoke({"state": initial_state, "tool_call_id": "test123"})

    def test_display_results_shows_papers(self, initial_state):
        """Papers already stored in state are reported back through a Command."""
        state = dict(initial_state)
        state.update(last_displayed_papers="papers", papers=MOCK_STATE_PAPER)

        result = display_results.invoke(
            input={"state": state, "tool_call_id": "test123"}
        )

        assert isinstance(result, Command)
        update = result.update
        assert isinstance(update, dict)
        assert "messages" in update
        messages = update["messages"]
        assert len(messages) == 1
        assert (
            "1 papers found. Papers are attached as an artifact."
            in messages[0].content
        )
@@ -0,0 +1,282 @@
1
+ """
2
+ Unit tests for S2 tools functionality.
3
+ """
4
+
5
+ import json
6
+ from types import SimpleNamespace
7
+ import pytest
8
+ import requests
9
+ from langgraph.types import Command
10
+ from langchain_core.messages import ToolMessage
11
+ import hydra
12
+ from aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec import (
13
+ get_multi_paper_recommendations,
14
+ )
15
+
16
+ # --- Dummy Hydra Config Setup ---
17
+
18
+
19
class DummyHydraContext:
    """No-op stand-in for the context manager returned by ``hydra.initialize``."""

    def __enter__(self):
        """Enter the context; the ``as`` target receives ``None``."""

    def __exit__(self, exc_type, exc_val, traceback):
        """Leave the context; returning ``None`` lets exceptions propagate."""
29
+
30
+
31
+ # Create a dummy configuration that mimics the expected hydra config.
32
# Only the fields read under tools.multi_paper_recommendation are provided;
# every other config section is absent.
_multi_rec = SimpleNamespace(
    api_endpoint="http://dummy.endpoint/multi",
    headers={"Content-Type": "application/json"},
    api_fields=["paperId", "title", "authors"],
    request_timeout=10,
)
dummy_config = SimpleNamespace(
    tools=SimpleNamespace(multi_paper_recommendation=_multi_rec)
)
42
+
43
+ # --- Dummy Response Classes and Functions for requests.post ---
44
+
45
+
46
class DummyResponse:
    """Fake ``requests`` response carrying a JSON payload and a status code."""

    def __init__(self, json_data, status_code=200):
        """Keep *json_data* for :meth:`json`; expose *status_code* publicly."""
        self.status_code = status_code
        self._json_data = json_data

    def json(self):
        """Return the stored payload unchanged."""
        return self._json_data

    def raise_for_status(self):
        """Raise ``requests.HTTPError`` for status 400 or above; otherwise do nothing."""
        if self.status_code < 400:
            return
        raise requests.HTTPError("HTTP Error")
62
+
63
+
64
def test_dummy_response_no_error():
    """A 200 response must not raise from raise_for_status; it returns None."""
    ok_response = DummyResponse({"data": "success"}, status_code=200)
    assert ok_response.raise_for_status() is None
70
+
71
+
72
def test_dummy_response_raise_error():
    """A 400 response must raise requests.HTTPError from raise_for_status."""
    failing_response = DummyResponse({"error": "fail"}, status_code=400)
    with pytest.raises(requests.HTTPError):
        failing_response.raise_for_status()
79
+
80
+
81
def dummy_requests_post_success(url, headers, params, data, timeout):
    """Fake ``requests.post`` returning three recommendations.

    Each argument is stored on this function as ``called_<name>`` so tests can
    inspect what was sent.
    """
    for attr, value in (
        ("called_url", url),
        ("called_headers", headers),
        ("called_params", params),
        ("called_data", data),
        ("called_timeout", timeout),
    ):
        setattr(dummy_requests_post_success, attr, value)

    paper_a = {
        "paperId": "paperA",
        "title": "Multi Rec Paper A",
        "authors": ["Author X"],
        "year": 2019,
        "citationCount": 12,
        "url": "http://paperA",
        "externalIds": {"ArXiv": "arxivA"},
    }
    paper_b = {
        "paperId": "paperB",
        "title": "Multi Rec Paper B",
        "authors": ["Author Y"],
        "year": 2020,
        "citationCount": 18,
        "url": "http://paperB",
        "externalIds": {},
    }
    # paperC has no authors; the success test expects it to be filtered out.
    paper_c = {
        "paperId": "paperC",
        "title": "Multi Rec Paper C",
        "authors": None,
        "year": 2021,
        "citationCount": 25,
        "url": "http://paperC",
        "externalIds": {"ArXiv": "arxivC"},
    }
    return DummyResponse({"recommendedPapers": [paper_a, paper_b, paper_c]})
124
+
125
+
126
def dummy_requests_post_unexpected(url, headers, params, data, timeout):
    """Fake ``requests.post`` whose payload lacks the 'recommendedPapers' key.

    Each argument is stored on this function as ``called_<name>``.
    """
    for attr, value in (
        ("called_url", url),
        ("called_headers", headers),
        ("called_params", params),
        ("called_data", data),
        ("called_timeout", timeout),
    ):
        setattr(dummy_requests_post_unexpected, attr, value)
    return DummyResponse({"error": "Invalid format"})
135
+
136
+
137
def dummy_requests_post_no_recs(url, headers, params, data, timeout):
    """Fake ``requests.post`` returning an empty recommendations list.

    Each argument is stored on this function as ``called_<name>``.
    """
    for attr, value in (
        ("called_url", url),
        ("called_headers", headers),
        ("called_params", params),
        ("called_data", data),
        ("called_timeout", timeout),
    ):
        setattr(dummy_requests_post_no_recs, attr, value)
    return DummyResponse({"recommendedPapers": []})
146
+
147
+
148
def dummy_requests_post_exception(url, headers, params, data, timeout):
    """Fake ``requests.post`` that records its arguments, then fails like a network error.

    Each argument is stored on this function as ``called_<name>`` before raising.
    """
    for attr, value in (
        ("called_url", url),
        ("called_headers", headers),
        ("called_params", params),
        ("called_data", data),
        ("called_timeout", timeout),
    ):
        setattr(dummy_requests_post_exception, attr, value)
    raise requests.exceptions.RequestException("Connection error")
157
+
158
+
159
+ # --- Pytest Fixture to Patch Hydra ---
160
@pytest.fixture(autouse=True)
def patch_hydra(monkeypatch):
    """Make Hydra return the in-memory ``dummy_config`` for every test in this module."""

    def fake_initialize(version_base, config_path):  # pylint: disable=unused-argument
        return DummyHydraContext()

    def fake_compose(config_name, overrides):  # pylint: disable=unused-argument
        return dummy_config

    monkeypatch.setattr(hydra, "initialize", fake_initialize)
    monkeypatch.setattr(hydra, "compose", fake_compose)
169
+
170
+
171
+ # --- Test Cases ---
172
+
173
+
174
def test_multi_paper_rec_success(monkeypatch):
    """
    A successful API call yields a Command whose ``multi_papers`` keeps only
    recommendations with authors, plus exactly one ToolMessage. Also checks the
    query params and JSON payload handed to ``requests.post``.
    """
    monkeypatch.setattr(requests, "post", dummy_requests_post_success)

    result = get_multi_paper_recommendations.run(
        {
            "paper_ids": ["p1", "p2"],
            "tool_call_id": "test_tool_call_id",
            "limit": 2,
            "year": "2020",
        }
    )

    assert isinstance(result, Command)
    update = result.update
    assert "multi_papers" in update

    papers = update["multi_papers"]
    for kept in ("paperA", "paperB"):
        assert kept in papers
    # paperC was returned without authors and must be dropped.
    assert "paperC" not in papers

    messages = update.get("messages", [])
    assert len(messages) == 1
    (message,) = messages
    assert isinstance(message, ToolMessage)
    assert "Recommendations based on multiple papers were successful" in message.content

    params = dummy_requests_post_success.called_params
    assert params["limit"] == 2  # Should be min(limit, 500)
    assert params["fields"] == "paperId,title,authors"
    assert params["year"] == "2020"

    payload = json.loads(dummy_requests_post_success.called_data)
    assert payload["positivePaperIds"] == ["p1", "p2"]
    assert payload["negativePaperIds"] == []
222
+
223
+
224
def test_multi_paper_rec_unexpected_format(monkeypatch):
    """
    A payload without 'recommendedPapers' must surface as a RuntimeError.
    """
    monkeypatch.setattr(requests, "post", dummy_requests_post_unexpected)
    expected_message = (
        "Unexpected response from Semantic Scholar API. The results could not be "
        "retrieved due to an unexpected format. "
        "Please modify your search query and try again."
    )
    with pytest.raises(RuntimeError, match=expected_message):
        get_multi_paper_recommendations.run(
            {"paper_ids": ["p1", "p2"], "tool_call_id": "test_tool_call_id"}
        )
244
+
245
+
246
def test_multi_paper_rec_no_recommendations(monkeypatch):
    """
    An empty recommendations list from the API must surface as a RuntimeError.
    """
    monkeypatch.setattr(requests, "post", dummy_requests_post_no_recs)
    expected_message = (
        "No recommendations were found for your query. Consider refining your search "
        "by using more specific keywords or different terms."
    )
    with pytest.raises(RuntimeError, match=expected_message):
        get_multi_paper_recommendations.run(
            {"paper_ids": ["p1", "p2"], "tool_call_id": "test_tool_call_id"}
        )
265
+
266
+
267
def test_multi_paper_rec_requests_exception(monkeypatch):
    """
    A network failure inside requests.post must surface as a RuntimeError.
    """
    monkeypatch.setattr(requests, "post", dummy_requests_post_exception)
    expected_message = (
        "Failed to connect to Semantic Scholar API. Please retry the same query."
    )
    with pytest.raises(RuntimeError, match=expected_message):
        get_multi_paper_recommendations.run(
            {"paper_ids": ["p1", "p2"], "tool_call_id": "test_tool_call_id"}
        )