aiagents4pharma 1.20.0__py3-none-any.whl → 1.21.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (29) hide show
  1. aiagents4pharma/talk2biomodels/configs/config.yaml +5 -0
  2. aiagents4pharma/talk2scholars/agents/main_agent.py +90 -91
  3. aiagents4pharma/talk2scholars/agents/s2_agent.py +61 -17
  4. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +31 -10
  5. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +8 -16
  6. aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +11 -9
  7. aiagents4pharma/talk2scholars/configs/config.yaml +1 -0
  8. aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +2 -0
  9. aiagents4pharma/talk2scholars/configs/tools/retrieve_semantic_scholar_paper_id/__init__.py +3 -0
  10. aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +1 -0
  11. aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +1 -0
  12. aiagents4pharma/talk2scholars/state/state_talk2scholars.py +36 -7
  13. aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py +58 -0
  14. aiagents4pharma/talk2scholars/tests/test_main_agent.py +98 -122
  15. aiagents4pharma/talk2scholars/tests/test_s2_agent.py +95 -29
  16. aiagents4pharma/talk2scholars/tests/test_s2_tools.py +158 -22
  17. aiagents4pharma/talk2scholars/tools/s2/__init__.py +4 -2
  18. aiagents4pharma/talk2scholars/tools/s2/display_results.py +60 -21
  19. aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +35 -8
  20. aiagents4pharma/talk2scholars/tools/s2/query_results.py +61 -0
  21. aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +79 -0
  22. aiagents4pharma/talk2scholars/tools/s2/search.py +34 -10
  23. aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +39 -9
  24. {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/METADATA +2 -2
  25. {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/RECORD +28 -24
  26. aiagents4pharma/talk2scholars/tests/test_integration.py +0 -237
  27. {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/LICENSE +0 -0
  28. {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/WHEEL +0 -0
  29. {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,58 @@
1
+ """
2
+ Integration tests for talk2scholars system with OpenAI.
3
+ """
4
+
5
+ import os
6
+ import pytest
7
+ import hydra
8
+ from langchain_openai import ChatOpenAI
9
+ from langchain_core.messages import HumanMessage, AIMessage
10
+ from ..agents.main_agent import get_app
11
+ from ..state.state_talk2scholars import Talk2Scholars
12
+
13
+ # pylint: disable=redefined-outer-name
14
+
15
+
16
+ @pytest.mark.skipif(
17
+ not os.getenv("OPENAI_API_KEY"), reason="Requires OpenAI API key to run"
18
+ )
19
+ def test_main_agent_real_llm():
20
+ """
21
+ Test that the main agent invokes S2 agent correctly
22
+ and updates the state with real LLM execution.
23
+ """
24
+
25
+ # Load Hydra Configuration EXACTLY like in main_agent.py
26
+ with hydra.initialize(version_base=None, config_path="../configs"):
27
+ cfg = hydra.compose(
28
+ config_name="config", overrides=["agents/talk2scholars/main_agent=default"]
29
+ )
30
+ hydra_cfg = cfg.agents.talk2scholars.main_agent
31
+
32
+ assert hydra_cfg is not None, "Hydra config failed to load"
33
+
34
+ # Use the real OpenAI API (ensure env variable is set)
35
+ llm = ChatOpenAI(model="gpt-4o-mini", temperature=hydra_cfg.temperature)
36
+
37
+ # Initialize main agent workflow (WITH real Hydra config)
38
+ thread_id = "test_thread"
39
+ app = get_app(thread_id, llm)
40
+
41
+ # Provide an actual user query
42
+ initial_state = Talk2Scholars(
43
+ messages=[HumanMessage(content="Find AI papers on transformers")]
44
+ )
45
+
46
+ # Invoke the agent (triggers supervisor → s2_agent)
47
+ result = app.invoke(
48
+ initial_state,
49
+ {"configurable": {"config_id": thread_id, "thread_id": thread_id}},
50
+ )
51
+
52
+ # Assert that the supervisor routed correctly
53
+ assert "messages" in result, "Expected messages in response"
54
+
55
+ # Fix: Accept AIMessage as a valid response type
56
+ assert isinstance(
57
+ result["messages"][-1], (HumanMessage, AIMessage, str)
58
+ ), "Last message should be a valid response"
@@ -4,10 +4,12 @@ Tests the supervisor agent's routing logic and state management.
4
4
  """
5
5
 
6
6
  # pylint: disable=redefined-outer-name
7
+ # pylint: disable=redefined-outer-name,too-few-public-methods
7
8
  from unittest.mock import Mock, patch, MagicMock
8
9
  import pytest
9
10
  from langchain_core.messages import HumanMessage, AIMessage
10
- from ..agents.main_agent import make_supervisor_node, get_app
11
+ from langgraph.graph import END
12
+ from ..agents.main_agent import make_supervisor_node, get_hydra_config, get_app
11
13
  from ..state.state_talk2scholars import Talk2Scholars
12
14
 
13
15
 
@@ -17,7 +19,8 @@ def mock_hydra():
17
19
  with patch("hydra.initialize"), patch("hydra.compose") as mock_compose:
18
20
  cfg_mock = MagicMock()
19
21
  cfg_mock.agents.talk2scholars.main_agent.temperature = 0
20
- cfg_mock.agents.talk2scholars.main_agent.main_agent = "Test prompt"
22
+ cfg_mock.agents.talk2scholars.main_agent.system_prompt = "System prompt"
23
+ cfg_mock.agents.talk2scholars.main_agent.router_prompt = "Router prompt"
21
24
  mock_compose.return_value = cfg_mock
22
25
  yield mock_compose
23
26
 
@@ -25,156 +28,129 @@ def mock_hydra():
25
28
  def test_get_app():
26
29
  """Test the full initialization of the LangGraph application."""
27
30
  thread_id = "test_thread"
28
- llm_model = "gpt-4o-mini"
29
-
30
- # Mock the LLM
31
31
  mock_llm = Mock()
32
+ app = get_app(thread_id, mock_llm)
33
+ assert app is not None
34
+ assert "supervisor" in app.nodes
35
+ assert "s2_agent" in app.nodes # Ensure nodes exist
32
36
 
33
- with patch(
34
- "aiagents4pharma.talk2scholars.agents.main_agent.ChatOpenAI",
35
- return_value=mock_llm,
36
- ):
37
37
 
38
- app = get_app(thread_id, llm_model)
39
- assert app is not None
40
- assert "supervisor" in app.nodes
41
- assert "s2_agent" in app.nodes # Ensure nodes exist
38
+ def test_get_app_with_default_llm():
39
+ """Test app initialization with default LLM parameters."""
40
+ thread_id = "test_thread"
41
+ llm_mock = Mock()
42
+
43
+ # We need to explicitly pass the mock instead of patching, since the function uses
44
+ # ChatOpenAI as a default argument value which is evaluated at function definition time
45
+ app = get_app(thread_id, llm_mock)
46
+ assert app is not None
47
+ # We can only verify the app was created successfully
48
+
49
+
50
+ def test_get_hydra_config():
51
+ """Test that Hydra configuration loads correctly."""
52
+ with patch("hydra.initialize"), patch("hydra.compose") as mock_compose:
53
+ cfg_mock = MagicMock()
54
+ cfg_mock.agents.talk2scholars.main_agent.temperature = 0
55
+ mock_compose.return_value = cfg_mock
56
+ cfg = get_hydra_config()
57
+ assert cfg is not None
58
+ assert cfg.temperature == 0
59
+
60
+
61
+ def test_hydra_failure():
62
+ """Test exception handling when Hydra fails to load config."""
63
+ thread_id = "test_thread"
64
+ with patch("hydra.initialize", side_effect=Exception("Hydra error")):
65
+ with pytest.raises(Exception) as exc_info:
66
+ get_app(thread_id)
67
+ assert "Hydra error" in str(exc_info.value)
42
68
 
43
69
 
44
70
  def test_supervisor_node_execution():
45
- """Test that the supervisor node processes messages and makes a decision."""
71
+ """Test that the supervisor node routes correctly."""
46
72
  mock_llm = Mock()
47
73
  thread_id = "test_thread"
48
74
 
49
- # Mock the supervisor agent's response
50
- mock_supervisor = Mock()
51
- mock_supervisor.invoke.return_value = {"messages": [AIMessage(content="s2_agent")]}
75
+ class MockRouter:
76
+ """Mock router class."""
52
77
 
53
- with patch(
54
- "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent",
55
- return_value=mock_supervisor,
78
+ next = "s2_agent"
79
+
80
+ with (
81
+ patch.object(mock_llm, "with_structured_output", return_value=mock_llm),
82
+ patch.object(mock_llm, "invoke", return_value=MockRouter()),
56
83
  ):
57
84
  supervisor_node = make_supervisor_node(mock_llm, thread_id)
58
-
59
- # Create a mock state
60
85
  mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
61
-
62
- # Execute
63
86
  result = supervisor_node(mock_state)
64
-
65
- # Validate
66
87
  assert result.goto == "s2_agent"
67
- mock_supervisor.invoke.assert_called_once_with(
68
- mock_state, {"configurable": {"thread_id": thread_id}}
69
- ) # Ensure invoke was called correctly
70
88
 
71
89
 
72
- def test_call_s2_agent():
73
- """Test the call to S2 agent and its integration with the state."""
90
+ def test_supervisor_node_finish():
91
+ """Test that supervisor node correctly handles FINISH case."""
92
+ mock_llm = Mock()
74
93
  thread_id = "test_thread"
75
- mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
76
94
 
77
- with patch("aiagents4pharma.talk2scholars.agents.s2_agent.get_app") as mock_s2_app:
78
- mock_s2_app.return_value.invoke.return_value = {
79
- "messages": [AIMessage(content="Here are the papers")],
80
- "papers": {"id123": "Sample Paper"},
81
- }
95
+ class MockRouter:
96
+ """Mock router class."""
82
97
 
83
- app = get_app(thread_id)
84
- result = app.invoke(
85
- mock_state,
86
- config={
87
- "configurable": {
88
- "thread_id": thread_id,
89
- "checkpoint_ns": "test_ns",
90
- "checkpoint_id": "test_checkpoint",
91
- }
92
- },
93
- )
98
+ next = "FINISH"
94
99
 
95
- assert "messages" in result
96
- assert "papers" in result
97
- assert result["papers"]["id123"] == "Sample Paper"
100
+ class MockAIResponse:
101
+ """Mock AI response class."""
98
102
 
99
- mock_s2_app.return_value.invoke.assert_called_once()
103
+ def __init__(self):
104
+ self.content = "Final AI Response"
100
105
 
101
-
102
- def test_hydra_failure():
103
- """Test exception handling when Hydra fails to load config."""
104
- thread_id = "test_thread"
105
- with patch("hydra.initialize", side_effect=Exception("Hydra error")):
106
- with pytest.raises(Exception) as exc_info:
107
- get_app(thread_id)
108
-
109
- assert "Hydra error" in str(exc_info.value)
106
+ with (
107
+ patch.object(mock_llm, "with_structured_output", return_value=mock_llm),
108
+ patch.object(mock_llm, "invoke", side_effect=[MockRouter(), MockAIResponse()]),
109
+ ):
110
+ supervisor_node = make_supervisor_node(mock_llm, thread_id)
111
+ mock_state = Talk2Scholars(messages=[HumanMessage(content="End conversation")])
112
+ result = supervisor_node(mock_state)
113
+ assert result.goto == END
114
+ assert "messages" in result.update
115
+ assert isinstance(result.update["messages"], AIMessage)
116
+ assert result.update["messages"].content == "Final AI Response"
110
117
 
111
118
 
112
- class TestMainAgent:
113
- """Basic tests for the main agent initialization and configuration"""
119
+ def test_call_s2_agent_failure_in_get_app():
120
+ """Test handling failure when calling s2_agent.get_app()."""
121
+ thread_id = "test_thread"
122
+ mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
114
123
 
115
- def test_supervisor_node_creation(self, mock_hydra):
116
- """Test that supervisor node can be created with correct config"""
117
- mock_llm = Mock()
118
- thread_id = "test_thread"
124
+ with patch(
125
+ "aiagents4pharma.talk2scholars.agents.s2_agent.get_app",
126
+ side_effect=Exception("S2 Agent Failure"),
127
+ ):
128
+ with pytest.raises(Exception) as exc_info:
129
+ app = get_app(thread_id) # Get the compiled workflow
130
+ app.invoke(
131
+ mock_state,
132
+ {"configurable": {"config_id": thread_id, "thread_id": thread_id}},
133
+ )
119
134
 
120
- with patch(
121
- "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
122
- ) as mock_create:
123
- mock_create.return_value = Mock()
124
- supervisor = make_supervisor_node(mock_llm, thread_id)
135
+ assert "S2 Agent Failure" in str(exc_info.value)
125
136
 
126
- assert supervisor is not None
127
- assert mock_create.called
128
- # Verify Hydra was called with correct parameters
129
- assert mock_hydra.call_count == 1 # Updated assertion
130
137
 
131
- def test_supervisor_config_loading(self, mock_hydra):
132
- """Test that supervisor loads configuration correctly"""
133
- mock_llm = Mock()
134
- thread_id = "test_thread"
138
+ def test_call_s2_agent_failure_in_invoke():
139
+ """Test handling failure when invoking s2_agent app."""
140
+ thread_id = "test_thread"
141
+ mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
135
142
 
136
- with patch(
137
- "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
138
- ):
139
- make_supervisor_node(mock_llm, thread_id)
143
+ mock_app = Mock()
144
+ mock_app.invoke.side_effect = Exception("S2 Agent Invoke Failure")
140
145
 
141
- # Verify Hydra initialization
142
- assert mock_hydra.call_count == 1
143
- assert "agents/talk2scholars/main_agent=default" in str(
144
- mock_hydra.call_args_list[0]
146
+ with patch(
147
+ "aiagents4pharma.talk2scholars.agents.s2_agent.get_app", return_value=mock_app
148
+ ):
149
+ with pytest.raises(Exception) as exc_info:
150
+ app = get_app(thread_id) # Get the compiled workflow
151
+ app.invoke(
152
+ mock_state,
153
+ {"configurable": {"config_id": thread_id, "thread_id": thread_id}},
145
154
  )
146
155
 
147
- def test_react_agent_params(self):
148
- """Test that react agent is created with correct parameters"""
149
- mock_llm = Mock()
150
- thread_id = "test_thread"
151
-
152
- with patch(
153
- "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
154
- ) as mock_create:
155
- mock_create.return_value = Mock()
156
- make_supervisor_node(mock_llm, thread_id)
157
-
158
- # Verify create_react_agent was called
159
- assert mock_create.called
160
-
161
- # Verify the parameters
162
- args, kwargs = mock_create.call_args
163
- assert args[0] == mock_llm # First argument should be the LLM
164
- assert "state_schema" in kwargs # Should have state_schema
165
- assert hasattr(
166
- mock_create.return_value, "invoke"
167
- ) # Should have invoke method
168
-
169
- def test_supervisor_custom_config(self, mock_hydra):
170
- """Test supervisor with custom configuration"""
171
- mock_llm = Mock()
172
- thread_id = "test_thread"
173
-
174
- with patch(
175
- "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
176
- ):
177
- make_supervisor_node(mock_llm, thread_id)
178
-
179
- # Verify Hydra was called
180
- mock_hydra.assert_called_once()
156
+ assert "S2 Agent Invoke Failure" in str(exc_info.value)
@@ -1,7 +1,8 @@
1
1
  """
2
- Unit tests for the S2 agent (Semantic Scholar sub-agent).
2
+ Updated Unit Tests for the S2 agent (Semantic Scholar sub-agent).
3
3
  """
4
4
 
5
+ # pylint: disable=redefined-outer-name
5
6
  from unittest import mock
6
7
  import pytest
7
8
  from langchain_core.messages import HumanMessage, AIMessage
@@ -35,31 +36,42 @@ def mock_tools_fixture():
35
36
  "get_single_paper_recommendations"
36
37
  ) as mock_s2_single_rec,
37
38
  mock.patch(
38
- "aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec."
39
- "get_multi_paper_recommendations"
39
+ "aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec.get_multi_paper_recommendations"
40
40
  ) as mock_s2_multi_rec,
41
+ mock.patch(
42
+ "aiagents4pharma.talk2scholars.tools.s2.query_results.query_results"
43
+ ) as mock_s2_query_results,
44
+ mock.patch(
45
+ "aiagents4pharma.talk2scholars.tools.s2.retrieve_semantic_scholar_paper_id."
46
+ "retrieve_semantic_scholar_paper_id"
47
+ ) as mock_s2_retrieve_id,
41
48
  ):
42
-
43
- mock_s2_search.return_value = mock.Mock()
44
- mock_s2_display.return_value = mock.Mock()
45
- mock_s2_single_rec.return_value = mock.Mock()
46
- mock_s2_multi_rec.return_value = mock.Mock()
47
-
48
- yield [mock_s2_search, mock_s2_display, mock_s2_single_rec, mock_s2_multi_rec]
49
+ mock_s2_search.return_value = {"result": "Mock Search Result"}
50
+ mock_s2_display.return_value = {"result": "Mock Display Result"}
51
+ mock_s2_single_rec.return_value = {"result": "Mock Single Rec"}
52
+ mock_s2_multi_rec.return_value = {"result": "Mock Multi Rec"}
53
+ mock_s2_query_results.return_value = {"result": "Mock Query Result"}
54
+ mock_s2_retrieve_id.return_value = {"paper_id": "MockPaper123"}
55
+
56
+ yield [
57
+ mock_s2_search,
58
+ mock_s2_display,
59
+ mock_s2_single_rec,
60
+ mock_s2_multi_rec,
61
+ mock_s2_query_results,
62
+ mock_s2_retrieve_id,
63
+ ]
49
64
 
50
65
 
51
66
  @pytest.mark.usefixtures("mock_hydra_fixture")
52
67
  def test_s2_agent_initialization():
53
68
  """Test that S2 agent initializes correctly with mock configuration."""
54
69
  thread_id = "test_thread"
55
-
56
70
  with mock.patch(
57
71
  "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
58
72
  ) as mock_create:
59
73
  mock_create.return_value = mock.Mock()
60
-
61
74
  app = get_app(thread_id)
62
-
63
75
  assert app is not None
64
76
  assert mock_create.called
65
77
 
@@ -68,7 +80,6 @@ def test_s2_agent_invocation():
68
80
  """Test that the S2 agent processes user input and returns a valid response."""
69
81
  thread_id = "test_thread"
70
82
  mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
71
-
72
83
  with mock.patch(
73
84
  "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
74
85
  ) as mock_create:
@@ -78,7 +89,6 @@ def test_s2_agent_invocation():
78
89
  "messages": [AIMessage(content="Here are some AI papers")],
79
90
  "papers": {"id123": "AI Research Paper"},
80
91
  }
81
-
82
92
  app = get_app(thread_id)
83
93
  result = app.invoke(
84
94
  mock_state,
@@ -90,7 +100,6 @@ def test_s2_agent_invocation():
90
100
  }
91
101
  },
92
102
  )
93
-
94
103
  assert "messages" in result
95
104
  assert "papers" in result
96
105
  assert result["papers"]["id123"] == "AI Research Paper"
@@ -99,10 +108,7 @@ def test_s2_agent_invocation():
99
108
  def test_s2_agent_tools_assignment(request):
100
109
  """Ensure that the correct tools are assigned to the agent."""
101
110
  thread_id = "test_thread"
102
-
103
- # Dynamically retrieve the fixture to avoid redefining it in the function signature
104
111
  mock_tools = request.getfixturevalue("mock_tools_fixture")
105
-
106
112
  with (
107
113
  mock.patch(
108
114
  "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
@@ -111,28 +117,88 @@ def test_s2_agent_tools_assignment(request):
111
117
  "aiagents4pharma.talk2scholars.agents.s2_agent.ToolNode"
112
118
  ) as mock_toolnode,
113
119
  ):
114
-
115
120
  mock_agent = mock.Mock()
116
121
  mock_create.return_value = mock_agent
117
-
118
- # Mock ToolNode behavior
119
122
  mock_tool_instance = mock.Mock()
120
- mock_tool_instance.tools = mock_tools # Use the dynamically retrieved fixture
123
+ mock_tool_instance.tools = mock_tools
121
124
  mock_toolnode.return_value = mock_tool_instance
122
-
123
125
  get_app(thread_id)
124
-
125
- # Ensure the agent was created with the mocked ToolNode
126
126
  assert mock_toolnode.called
127
- assert len(mock_tool_instance.tools) == 4
127
+ assert len(mock_tool_instance.tools) == 6
128
+
129
+
130
+ def test_s2_query_results_tool():
131
+ """Test if the query_results tool is correctly utilized by the agent."""
132
+ thread_id = "test_thread"
133
+ mock_state = Talk2Scholars(
134
+ messages=[HumanMessage(content="Query results for AI papers")]
135
+ )
136
+ with mock.patch(
137
+ "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
138
+ ) as mock_create:
139
+ mock_agent = mock.Mock()
140
+ mock_create.return_value = mock_agent
141
+ mock_agent.invoke.return_value = {
142
+ "messages": [HumanMessage(content="Query results for AI papers")],
143
+ "last_displayed_papers": {},
144
+ "papers": {
145
+ "query_results": "Mock Query Result"
146
+ }, # Ensure the expected key is inside 'papers'
147
+ "multi_papers": {},
148
+ }
149
+ app = get_app(thread_id)
150
+ result = app.invoke(
151
+ mock_state,
152
+ config={
153
+ "configurable": {
154
+ "thread_id": thread_id,
155
+ "checkpoint_ns": "test_ns",
156
+ "checkpoint_id": "test_checkpoint",
157
+ }
158
+ },
159
+ )
160
+ assert "query_results" in result["papers"]
161
+ assert mock_agent.invoke.called
162
+
163
+
164
+ def test_s2_retrieve_id_tool():
165
+ """Test if the retrieve_semantic_scholar_paper_id tool is correctly utilized by the agent."""
166
+ thread_id = "test_thread"
167
+ mock_state = Talk2Scholars(
168
+ messages=[HumanMessage(content="Retrieve paper ID for AI research")]
169
+ )
170
+ with mock.patch(
171
+ "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
172
+ ) as mock_create:
173
+ mock_agent = mock.Mock()
174
+ mock_create.return_value = mock_agent
175
+ mock_agent.invoke.return_value = {
176
+ "messages": [HumanMessage(content="Retrieve paper ID for AI research")],
177
+ "last_displayed_papers": {},
178
+ "papers": {
179
+ "paper_id": "MockPaper123"
180
+ }, # Ensure 'paper_id' is inside 'papers'
181
+ "multi_papers": {},
182
+ }
183
+ app = get_app(thread_id)
184
+ result = app.invoke(
185
+ mock_state,
186
+ config={
187
+ "configurable": {
188
+ "thread_id": thread_id,
189
+ "checkpoint_ns": "test_ns",
190
+ "checkpoint_id": "test_checkpoint",
191
+ }
192
+ },
193
+ )
194
+ assert "paper_id" in result["papers"]
195
+ assert mock_agent.invoke.called
128
196
 
129
197
 
130
198
  def test_s2_agent_hydra_failure():
131
199
  """Test exception handling when Hydra fails to load config."""
132
200
  thread_id = "test_thread"
133
-
134
201
  with mock.patch("hydra.initialize", side_effect=Exception("Hydra error")):
135
202
  with pytest.raises(Exception) as exc_info:
136
203
  get_app(thread_id)
137
-
138
204
  assert "Hydra error" in str(exc_info.value)