aiagents4pharma 1.20.0__py3-none-any.whl → 1.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. aiagents4pharma/talk2biomodels/configs/config.yaml +5 -0
  2. aiagents4pharma/talk2scholars/agents/main_agent.py +90 -91
  3. aiagents4pharma/talk2scholars/agents/s2_agent.py +61 -17
  4. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +31 -10
  5. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +8 -16
  6. aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +11 -9
  7. aiagents4pharma/talk2scholars/configs/config.yaml +1 -0
  8. aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +2 -0
  9. aiagents4pharma/talk2scholars/configs/tools/retrieve_semantic_scholar_paper_id/__init__.py +3 -0
  10. aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +1 -0
  11. aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +1 -0
  12. aiagents4pharma/talk2scholars/state/state_talk2scholars.py +36 -7
  13. aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py +58 -0
  14. aiagents4pharma/talk2scholars/tests/test_main_agent.py +98 -122
  15. aiagents4pharma/talk2scholars/tests/test_s2_agent.py +95 -29
  16. aiagents4pharma/talk2scholars/tests/test_s2_tools.py +158 -22
  17. aiagents4pharma/talk2scholars/tools/s2/__init__.py +4 -2
  18. aiagents4pharma/talk2scholars/tools/s2/display_results.py +60 -21
  19. aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +35 -8
  20. aiagents4pharma/talk2scholars/tools/s2/query_results.py +61 -0
  21. aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +79 -0
  22. aiagents4pharma/talk2scholars/tools/s2/search.py +34 -10
  23. aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +39 -9
  24. {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/METADATA +2 -2
  25. {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/RECORD +28 -24
  26. aiagents4pharma/talk2scholars/tests/test_integration.py +0 -237
  27. {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/LICENSE +0 -0
  28. {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/WHEEL +0 -0
  29. {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/top_level.txt +0 -0
aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py (new file)
@@ -0,0 +1,58 @@
+ """
+ Integration tests for talk2scholars system with OpenAI.
+ """
+
+ import os
+ import pytest
+ import hydra
+ from langchain_openai import ChatOpenAI
+ from langchain_core.messages import HumanMessage, AIMessage
+ from ..agents.main_agent import get_app
+ from ..state.state_talk2scholars import Talk2Scholars
+
+ # pylint: disable=redefined-outer-name
+
+
+ @pytest.mark.skipif(
+     not os.getenv("OPENAI_API_KEY"), reason="Requires OpenAI API key to run"
+ )
+ def test_main_agent_real_llm():
+     """
+     Test that the main agent invokes S2 agent correctly
+     and updates the state with real LLM execution.
+     """
+
+     # Load Hydra Configuration EXACTLY like in main_agent.py
+     with hydra.initialize(version_base=None, config_path="../configs"):
+         cfg = hydra.compose(
+             config_name="config", overrides=["agents/talk2scholars/main_agent=default"]
+         )
+         hydra_cfg = cfg.agents.talk2scholars.main_agent
+
+     assert hydra_cfg is not None, "Hydra config failed to load"
+
+     # Use the real OpenAI API (ensure env variable is set)
+     llm = ChatOpenAI(model="gpt-4o-mini", temperature=hydra_cfg.temperature)
+
+     # Initialize main agent workflow (WITH real Hydra config)
+     thread_id = "test_thread"
+     app = get_app(thread_id, llm)
+
+     # Provide an actual user query
+     initial_state = Talk2Scholars(
+         messages=[HumanMessage(content="Find AI papers on transformers")]
+     )
+
+     # Invoke the agent (triggers supervisor → s2_agent)
+     result = app.invoke(
+         initial_state,
+         {"configurable": {"config_id": thread_id, "thread_id": thread_id}},
+     )
+
+     # Assert that the supervisor routed correctly
+     assert "messages" in result, "Expected messages in response"
+
+     # Fix: Accept AIMessage as a valid response type
+     assert isinstance(
+         result["messages"][-1], (HumanMessage, AIMessage, str)
+     ), "Last message should be a valid response"
aiagents4pharma/talk2scholars/tests/test_main_agent.py
@@ -4,10 +4,12 @@ Tests the supervisor agent's routing logic and state management.
  """

  # pylint: disable=redefined-outer-name
+ # pylint: disable=redefined-outer-name,too-few-public-methods
  from unittest.mock import Mock, patch, MagicMock
  import pytest
  from langchain_core.messages import HumanMessage, AIMessage
- from ..agents.main_agent import make_supervisor_node, get_app
+ from langgraph.graph import END
+ from ..agents.main_agent import make_supervisor_node, get_hydra_config, get_app
  from ..state.state_talk2scholars import Talk2Scholars


@@ -17,7 +19,8 @@ def mock_hydra():
      with patch("hydra.initialize"), patch("hydra.compose") as mock_compose:
          cfg_mock = MagicMock()
          cfg_mock.agents.talk2scholars.main_agent.temperature = 0
-         cfg_mock.agents.talk2scholars.main_agent.main_agent = "Test prompt"
+         cfg_mock.agents.talk2scholars.main_agent.system_prompt = "System prompt"
+         cfg_mock.agents.talk2scholars.main_agent.router_prompt = "Router prompt"
          mock_compose.return_value = cfg_mock
          yield mock_compose

@@ -25,156 +28,129 @@ def mock_hydra():
  def test_get_app():
      """Test the full initialization of the LangGraph application."""
      thread_id = "test_thread"
-     llm_model = "gpt-4o-mini"
-
-     # Mock the LLM
      mock_llm = Mock()
+     app = get_app(thread_id, mock_llm)
+     assert app is not None
+     assert "supervisor" in app.nodes
+     assert "s2_agent" in app.nodes  # Ensure nodes exist

-     with patch(
-         "aiagents4pharma.talk2scholars.agents.main_agent.ChatOpenAI",
-         return_value=mock_llm,
-     ):

-         app = get_app(thread_id, llm_model)
-         assert app is not None
-         assert "supervisor" in app.nodes
-         assert "s2_agent" in app.nodes  # Ensure nodes exist
+ def test_get_app_with_default_llm():
+     """Test app initialization with default LLM parameters."""
+     thread_id = "test_thread"
+     llm_mock = Mock()
+
+     # We need to explicitly pass the mock instead of patching, since the function uses
+     # ChatOpenAI as a default argument value which is evaluated at function definition time
+     app = get_app(thread_id, llm_mock)
+     assert app is not None
+     # We can only verify the app was created successfully
+
+
+ def test_get_hydra_config():
+     """Test that Hydra configuration loads correctly."""
+     with patch("hydra.initialize"), patch("hydra.compose") as mock_compose:
+         cfg_mock = MagicMock()
+         cfg_mock.agents.talk2scholars.main_agent.temperature = 0
+         mock_compose.return_value = cfg_mock
+         cfg = get_hydra_config()
+         assert cfg is not None
+         assert cfg.temperature == 0
+
+
+ def test_hydra_failure():
+     """Test exception handling when Hydra fails to load config."""
+     thread_id = "test_thread"
+     with patch("hydra.initialize", side_effect=Exception("Hydra error")):
+         with pytest.raises(Exception) as exc_info:
+             get_app(thread_id)
+         assert "Hydra error" in str(exc_info.value)


  def test_supervisor_node_execution():
-     """Test that the supervisor node processes messages and makes a decision."""
+     """Test that the supervisor node routes correctly."""
      mock_llm = Mock()
      thread_id = "test_thread"

-     # Mock the supervisor agent's response
-     mock_supervisor = Mock()
-     mock_supervisor.invoke.return_value = {"messages": [AIMessage(content="s2_agent")]}
+     class MockRouter:
+         """Mock router class."""

-     with patch(
-         "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent",
-         return_value=mock_supervisor,
+         next = "s2_agent"
+
+     with (
+         patch.object(mock_llm, "with_structured_output", return_value=mock_llm),
+         patch.object(mock_llm, "invoke", return_value=MockRouter()),
      ):
          supervisor_node = make_supervisor_node(mock_llm, thread_id)
-
-         # Create a mock state
          mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
-
-         # Execute
          result = supervisor_node(mock_state)
-
-         # Validate
          assert result.goto == "s2_agent"
-         mock_supervisor.invoke.assert_called_once_with(
-             mock_state, {"configurable": {"thread_id": thread_id}}
-         )  # Ensure invoke was called correctly


- def test_call_s2_agent():
-     """Test the call to S2 agent and its integration with the state."""
+ def test_supervisor_node_finish():
+     """Test that supervisor node correctly handles FINISH case."""
+     mock_llm = Mock()
      thread_id = "test_thread"
-     mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])

-     with patch("aiagents4pharma.talk2scholars.agents.s2_agent.get_app") as mock_s2_app:
-         mock_s2_app.return_value.invoke.return_value = {
-             "messages": [AIMessage(content="Here are the papers")],
-             "papers": {"id123": "Sample Paper"},
-         }
+     class MockRouter:
+         """Mock router class."""

-         app = get_app(thread_id)
-         result = app.invoke(
-             mock_state,
-             config={
-                 "configurable": {
-                     "thread_id": thread_id,
-                     "checkpoint_ns": "test_ns",
-                     "checkpoint_id": "test_checkpoint",
-                 }
-             },
-         )
+         next = "FINISH"

-         assert "messages" in result
-         assert "papers" in result
-         assert result["papers"]["id123"] == "Sample Paper"
+     class MockAIResponse:
+         """Mock AI response class."""

-         mock_s2_app.return_value.invoke.assert_called_once()
+         def __init__(self):
+             self.content = "Final AI Response"

-
- def test_hydra_failure():
-     """Test exception handling when Hydra fails to load config."""
-     thread_id = "test_thread"
-     with patch("hydra.initialize", side_effect=Exception("Hydra error")):
-         with pytest.raises(Exception) as exc_info:
-             get_app(thread_id)
-
-         assert "Hydra error" in str(exc_info.value)
+     with (
+         patch.object(mock_llm, "with_structured_output", return_value=mock_llm),
+         patch.object(mock_llm, "invoke", side_effect=[MockRouter(), MockAIResponse()]),
+     ):
+         supervisor_node = make_supervisor_node(mock_llm, thread_id)
+         mock_state = Talk2Scholars(messages=[HumanMessage(content="End conversation")])
+         result = supervisor_node(mock_state)
+         assert result.goto == END
+         assert "messages" in result.update
+         assert isinstance(result.update["messages"], AIMessage)
+         assert result.update["messages"].content == "Final AI Response"


- class TestMainAgent:
-     """Basic tests for the main agent initialization and configuration"""
+ def test_call_s2_agent_failure_in_get_app():
+     """Test handling failure when calling s2_agent.get_app()."""
+     thread_id = "test_thread"
+     mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])

-     def test_supervisor_node_creation(self, mock_hydra):
-         """Test that supervisor node can be created with correct config"""
-         mock_llm = Mock()
-         thread_id = "test_thread"
+     with patch(
+         "aiagents4pharma.talk2scholars.agents.s2_agent.get_app",
+         side_effect=Exception("S2 Agent Failure"),
+     ):
+         with pytest.raises(Exception) as exc_info:
+             app = get_app(thread_id)  # Get the compiled workflow
+             app.invoke(
+                 mock_state,
+                 {"configurable": {"config_id": thread_id, "thread_id": thread_id}},
+             )

-         with patch(
-             "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
-         ) as mock_create:
-             mock_create.return_value = Mock()
-             supervisor = make_supervisor_node(mock_llm, thread_id)
+         assert "S2 Agent Failure" in str(exc_info.value)

-             assert supervisor is not None
-             assert mock_create.called
-             # Verify Hydra was called with correct parameters
-             assert mock_hydra.call_count == 1  # Updated assertion

-     def test_supervisor_config_loading(self, mock_hydra):
-         """Test that supervisor loads configuration correctly"""
-         mock_llm = Mock()
-         thread_id = "test_thread"
+ def test_call_s2_agent_failure_in_invoke():
+     """Test handling failure when invoking s2_agent app."""
+     thread_id = "test_thread"
+     mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])

-         with patch(
-             "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
-         ):
-             make_supervisor_node(mock_llm, thread_id)
+     mock_app = Mock()
+     mock_app.invoke.side_effect = Exception("S2 Agent Invoke Failure")

-             # Verify Hydra initialization
-             assert mock_hydra.call_count == 1
-             assert "agents/talk2scholars/main_agent=default" in str(
-                 mock_hydra.call_args_list[0]
+     with patch(
+         "aiagents4pharma.talk2scholars.agents.s2_agent.get_app", return_value=mock_app
+     ):
+         with pytest.raises(Exception) as exc_info:
+             app = get_app(thread_id)  # Get the compiled workflow
+             app.invoke(
+                 mock_state,
+                 {"configurable": {"config_id": thread_id, "thread_id": thread_id}},
              )

-     def test_react_agent_params(self):
-         """Test that react agent is created with correct parameters"""
-         mock_llm = Mock()
-         thread_id = "test_thread"
-
-         with patch(
-             "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
-         ) as mock_create:
-             mock_create.return_value = Mock()
-             make_supervisor_node(mock_llm, thread_id)
-
-             # Verify create_react_agent was called
-             assert mock_create.called
-
-             # Verify the parameters
-             args, kwargs = mock_create.call_args
-             assert args[0] == mock_llm  # First argument should be the LLM
-             assert "state_schema" in kwargs  # Should have state_schema
-             assert hasattr(
-                 mock_create.return_value, "invoke"
-             )  # Should have invoke method
-
-     def test_supervisor_custom_config(self, mock_hydra):
-         """Test supervisor with custom configuration"""
-         mock_llm = Mock()
-         thread_id = "test_thread"
-
-         with patch(
-             "aiagents4pharma.talk2scholars.agents.main_agent.create_react_agent"
-         ):
-             make_supervisor_node(mock_llm, thread_id)
-
-             # Verify Hydra was called
-             mock_hydra.assert_called_once()
+     assert "S2 Agent Invoke Failure" in str(exc_info.value)
aiagents4pharma/talk2scholars/tests/test_s2_agent.py
@@ -1,7 +1,8 @@
  """
- Unit tests for the S2 agent (Semantic Scholar sub-agent).
+ Updated Unit Tests for the S2 agent (Semantic Scholar sub-agent).
  """

+ # pylint: disable=redefined-outer-name
  from unittest import mock
  import pytest
  from langchain_core.messages import HumanMessage, AIMessage
@@ -35,31 +36,42 @@ def mock_tools_fixture():
              "get_single_paper_recommendations"
          ) as mock_s2_single_rec,
          mock.patch(
-             "aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec."
-             "get_multi_paper_recommendations"
+             "aiagents4pharma.talk2scholars.tools.s2.multi_paper_rec.get_multi_paper_recommendations"
          ) as mock_s2_multi_rec,
+         mock.patch(
+             "aiagents4pharma.talk2scholars.tools.s2.query_results.query_results"
+         ) as mock_s2_query_results,
+         mock.patch(
+             "aiagents4pharma.talk2scholars.tools.s2.retrieve_semantic_scholar_paper_id."
+             "retrieve_semantic_scholar_paper_id"
+         ) as mock_s2_retrieve_id,
      ):
-
-         mock_s2_search.return_value = mock.Mock()
-         mock_s2_display.return_value = mock.Mock()
-         mock_s2_single_rec.return_value = mock.Mock()
-         mock_s2_multi_rec.return_value = mock.Mock()
-
-         yield [mock_s2_search, mock_s2_display, mock_s2_single_rec, mock_s2_multi_rec]
+         mock_s2_search.return_value = {"result": "Mock Search Result"}
+         mock_s2_display.return_value = {"result": "Mock Display Result"}
+         mock_s2_single_rec.return_value = {"result": "Mock Single Rec"}
+         mock_s2_multi_rec.return_value = {"result": "Mock Multi Rec"}
+         mock_s2_query_results.return_value = {"result": "Mock Query Result"}
+         mock_s2_retrieve_id.return_value = {"paper_id": "MockPaper123"}
+
+         yield [
+             mock_s2_search,
+             mock_s2_display,
+             mock_s2_single_rec,
+             mock_s2_multi_rec,
+             mock_s2_query_results,
+             mock_s2_retrieve_id,
+         ]


  @pytest.mark.usefixtures("mock_hydra_fixture")
  def test_s2_agent_initialization():
      """Test that S2 agent initializes correctly with mock configuration."""
      thread_id = "test_thread"
-
      with mock.patch(
          "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
      ) as mock_create:
          mock_create.return_value = mock.Mock()
-
          app = get_app(thread_id)
-
          assert app is not None
          assert mock_create.called

@@ -68,7 +80,6 @@ def test_s2_agent_invocation():
      """Test that the S2 agent processes user input and returns a valid response."""
      thread_id = "test_thread"
      mock_state = Talk2Scholars(messages=[HumanMessage(content="Find AI papers")])
-
      with mock.patch(
          "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
      ) as mock_create:
@@ -78,7 +89,6 @@ def test_s2_agent_invocation():
              "messages": [AIMessage(content="Here are some AI papers")],
              "papers": {"id123": "AI Research Paper"},
          }
-
          app = get_app(thread_id)
          result = app.invoke(
              mock_state,
@@ -90,7 +100,6 @@ def test_s2_agent_invocation():
                  }
              },
          )
-
          assert "messages" in result
          assert "papers" in result
          assert result["papers"]["id123"] == "AI Research Paper"
@@ -99,10 +108,7 @@ def test_s2_agent_invocation():
  def test_s2_agent_tools_assignment(request):
      """Ensure that the correct tools are assigned to the agent."""
      thread_id = "test_thread"
-
-     # Dynamically retrieve the fixture to avoid redefining it in the function signature
      mock_tools = request.getfixturevalue("mock_tools_fixture")
-
      with (
          mock.patch(
              "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
@@ -111,28 +117,88 @@ def test_s2_agent_tools_assignment(request):
              "aiagents4pharma.talk2scholars.agents.s2_agent.ToolNode"
          ) as mock_toolnode,
      ):
-
          mock_agent = mock.Mock()
          mock_create.return_value = mock_agent
-
-         # Mock ToolNode behavior
          mock_tool_instance = mock.Mock()
-         mock_tool_instance.tools = mock_tools  # Use the dynamically retrieved fixture
+         mock_tool_instance.tools = mock_tools
          mock_toolnode.return_value = mock_tool_instance
-
          get_app(thread_id)
-
-         # Ensure the agent was created with the mocked ToolNode
          assert mock_toolnode.called
-         assert len(mock_tool_instance.tools) == 4
+         assert len(mock_tool_instance.tools) == 6
+
+
+ def test_s2_query_results_tool():
+     """Test if the query_results tool is correctly utilized by the agent."""
+     thread_id = "test_thread"
+     mock_state = Talk2Scholars(
+         messages=[HumanMessage(content="Query results for AI papers")]
+     )
+     with mock.patch(
+         "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
+     ) as mock_create:
+         mock_agent = mock.Mock()
+         mock_create.return_value = mock_agent
+         mock_agent.invoke.return_value = {
+             "messages": [HumanMessage(content="Query results for AI papers")],
+             "last_displayed_papers": {},
+             "papers": {
+                 "query_results": "Mock Query Result"
+             },  # Ensure the expected key is inside 'papers'
+             "multi_papers": {},
+         }
+         app = get_app(thread_id)
+         result = app.invoke(
+             mock_state,
+             config={
+                 "configurable": {
+                     "thread_id": thread_id,
+                     "checkpoint_ns": "test_ns",
+                     "checkpoint_id": "test_checkpoint",
+                 }
+             },
+         )
+         assert "query_results" in result["papers"]
+         assert mock_agent.invoke.called
+
+
+ def test_s2_retrieve_id_tool():
+     """Test if the retrieve_semantic_scholar_paper_id tool is correctly utilized by the agent."""
+     thread_id = "test_thread"
+     mock_state = Talk2Scholars(
+         messages=[HumanMessage(content="Retrieve paper ID for AI research")]
+     )
+     with mock.patch(
+         "aiagents4pharma.talk2scholars.agents.s2_agent.create_react_agent"
+     ) as mock_create:
+         mock_agent = mock.Mock()
+         mock_create.return_value = mock_agent
+         mock_agent.invoke.return_value = {
+             "messages": [HumanMessage(content="Retrieve paper ID for AI research")],
+             "last_displayed_papers": {},
+             "papers": {
+                 "paper_id": "MockPaper123"
+             },  # Ensure 'paper_id' is inside 'papers'
+             "multi_papers": {},
+         }
+         app = get_app(thread_id)
+         result = app.invoke(
+             mock_state,
+             config={
+                 "configurable": {
+                     "thread_id": thread_id,
+                     "checkpoint_ns": "test_ns",
+                     "checkpoint_id": "test_checkpoint",
+                 }
+             },
+         )
+         assert "paper_id" in result["papers"]
+         assert mock_agent.invoke.called


  def test_s2_agent_hydra_failure():
      """Test exception handling when Hydra fails to load config."""
      thread_id = "test_thread"
-
      with mock.patch("hydra.initialize", side_effect=Exception("Hydra error")):
          with pytest.raises(Exception) as exc_info:
              get_app(thread_id)
-
          assert "Hydra error" in str(exc_info.value)