aiagents4pharma 1.20.0__py3-none-any.whl → 1.21.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- aiagents4pharma/talk2biomodels/configs/config.yaml +5 -0
- aiagents4pharma/talk2scholars/agents/main_agent.py +90 -91
- aiagents4pharma/talk2scholars/agents/s2_agent.py +61 -17
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +31 -10
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +8 -16
- aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +11 -9
- aiagents4pharma/talk2scholars/configs/config.yaml +1 -0
- aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +2 -0
- aiagents4pharma/talk2scholars/configs/tools/retrieve_semantic_scholar_paper_id/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +1 -0
- aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +1 -0
- aiagents4pharma/talk2scholars/state/state_talk2scholars.py +36 -7
- aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py +58 -0
- aiagents4pharma/talk2scholars/tests/test_main_agent.py +98 -122
- aiagents4pharma/talk2scholars/tests/test_s2_agent.py +95 -29
- aiagents4pharma/talk2scholars/tests/test_s2_tools.py +158 -22
- aiagents4pharma/talk2scholars/tools/s2/__init__.py +4 -2
- aiagents4pharma/talk2scholars/tools/s2/display_results.py +60 -21
- aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +35 -8
- aiagents4pharma/talk2scholars/tools/s2/query_results.py +61 -0
- aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +79 -0
- aiagents4pharma/talk2scholars/tools/s2/search.py +34 -10
- aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +39 -9
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/METADATA +2 -2
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/RECORD +28 -24
- aiagents4pharma/talk2scholars/tests/test_integration.py +0 -237
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/top_level.txt +0 -0
@@ -4,12 +4,20 @@ Unit tests for S2 tools functionality.
|
|
4
4
|
|
5
5
|
# pylint: disable=redefined-outer-name
|
6
6
|
from unittest.mock import patch
|
7
|
-
from
|
7
|
+
from unittest.mock import MagicMock
|
8
8
|
import pytest
|
9
|
-
from
|
9
|
+
from langgraph.types import Command
|
10
|
+
from ..tools.s2.display_results import (
|
11
|
+
display_results,
|
12
|
+
NoPapersFoundError as raised_error,
|
13
|
+
)
|
10
14
|
from ..tools.s2.multi_paper_rec import get_multi_paper_recommendations
|
11
15
|
from ..tools.s2.search import search_tool
|
12
16
|
from ..tools.s2.single_paper_rec import get_single_paper_recommendations
|
17
|
+
from ..tools.s2.query_results import query_results, NoPapersFoundError
|
18
|
+
from ..tools.s2.retrieve_semantic_scholar_paper_id import (
|
19
|
+
retrieve_semantic_scholar_paper_id,
|
20
|
+
)
|
13
21
|
|
14
22
|
|
15
23
|
@pytest.fixture
|
@@ -50,20 +58,29 @@ class TestS2Tools:
|
|
50
58
|
def test_display_results_empty_state(self, initial_state):
|
51
59
|
"""Verifies display_results tool behavior when state is empty and raises an exception"""
|
52
60
|
with pytest.raises(
|
53
|
-
|
54
|
-
match="No papers found. A search needs to be performed first.",
|
61
|
+
raised_error,
|
62
|
+
match="No papers found. A search/rec needs to be performed first.",
|
55
63
|
):
|
56
|
-
display_results.invoke({"state": initial_state})
|
64
|
+
display_results.invoke({"state": initial_state, "tool_call_id": "test123"})
|
57
65
|
|
58
66
|
def test_display_results_shows_papers(self, initial_state):
|
59
67
|
"""Verifies display_results tool correctly returns papers from state"""
|
60
68
|
state = initial_state.copy()
|
69
|
+
state["last_displayed_papers"] = "papers"
|
61
70
|
state["papers"] = MOCK_STATE_PAPER
|
62
|
-
|
63
|
-
result = display_results.invoke(
|
64
|
-
|
65
|
-
|
66
|
-
|
71
|
+
|
72
|
+
result = display_results.invoke(
|
73
|
+
input={"state": state, "tool_call_id": "test123"}
|
74
|
+
)
|
75
|
+
|
76
|
+
assert isinstance(result, Command) # Expect a Command object
|
77
|
+
assert isinstance(result.update, dict) # Ensure update is a dictionary
|
78
|
+
assert "messages" in result.update
|
79
|
+
assert len(result.update["messages"]) == 1
|
80
|
+
assert (
|
81
|
+
"1 papers found. Papers are attached as an artifact."
|
82
|
+
in result.update["messages"][0].content
|
83
|
+
)
|
67
84
|
|
68
85
|
@patch("requests.get")
|
69
86
|
def test_search_finds_papers(self, mock_get):
|
@@ -152,14 +169,12 @@ class TestS2Tools:
|
|
152
169
|
mock_get.return_value.status_code = 200
|
153
170
|
|
154
171
|
result = get_single_paper_recommendations.invoke(
|
155
|
-
input={
|
156
|
-
"paper_id": "123",
|
157
|
-
"limit": 1,
|
158
|
-
"tool_call_id": "test123",
|
159
|
-
}
|
172
|
+
input={"paper_id": "123", "limit": 1, "tool_call_id": "test123"}
|
160
173
|
)
|
174
|
+
|
175
|
+
assert isinstance(result, Command)
|
161
176
|
assert "papers" in result.update
|
162
|
-
assert
|
177
|
+
assert len(result.update["messages"]) == 1
|
163
178
|
|
164
179
|
@patch("requests.get")
|
165
180
|
def test_single_paper_rec_with_optional_params(self, mock_get):
|
@@ -189,14 +204,12 @@ class TestS2Tools:
|
|
189
204
|
mock_post.return_value.status_code = 200
|
190
205
|
|
191
206
|
result = get_multi_paper_recommendations.invoke(
|
192
|
-
input={
|
193
|
-
"paper_ids": ["123", "456"],
|
194
|
-
"limit": 1,
|
195
|
-
"tool_call_id": "test123",
|
196
|
-
}
|
207
|
+
input={"paper_ids": ["123", "456"], "limit": 1, "tool_call_id": "test123"}
|
197
208
|
)
|
209
|
+
|
210
|
+
assert isinstance(result, Command)
|
198
211
|
assert "multi_papers" in result.update
|
199
|
-
assert
|
212
|
+
assert len(result.update["messages"]) == 1
|
200
213
|
|
201
214
|
@patch("requests.post")
|
202
215
|
def test_multi_paper_rec_with_optional_params(self, mock_post):
|
@@ -217,3 +230,126 @@ class TestS2Tools:
|
|
217
230
|
)
|
218
231
|
assert "multi_papers" in result.update
|
219
232
|
assert len(result.update["messages"]) == 1
|
233
|
+
|
234
|
+
@patch("requests.get")
|
235
|
+
def test_search_tool_finds_papers(self, mock_get):
|
236
|
+
"""Verifies search tool finds and formats papers correctly"""
|
237
|
+
mock_get.return_value.json.return_value = MOCK_SEARCH_RESPONSE
|
238
|
+
mock_get.return_value.status_code = 200
|
239
|
+
|
240
|
+
result = search_tool.invoke(
|
241
|
+
input={"query": "machine learning", "limit": 1, "tool_call_id": "test123"}
|
242
|
+
)
|
243
|
+
|
244
|
+
assert isinstance(result, Command) # Expect a Command object
|
245
|
+
assert "papers" in result.update
|
246
|
+
assert len(result.update["papers"]) > 0
|
247
|
+
|
248
|
+
def test_query_results_empty_state(self, initial_state):
|
249
|
+
"""Tests query_results tool behavior when no papers are found."""
|
250
|
+
with pytest.raises(
|
251
|
+
NoPapersFoundError,
|
252
|
+
match="No papers found. A search needs to be performed first.",
|
253
|
+
):
|
254
|
+
query_results.invoke(
|
255
|
+
{"question": "List all papers", "state": initial_state}
|
256
|
+
)
|
257
|
+
|
258
|
+
@patch(
|
259
|
+
"aiagents4pharma.talk2scholars.tools.s2.query_results.create_pandas_dataframe_agent"
|
260
|
+
)
|
261
|
+
def test_query_results_with_papers(self, mock_create_agent, initial_state):
|
262
|
+
"""Tests querying papers when data is available."""
|
263
|
+
state = initial_state.copy()
|
264
|
+
state["last_displayed_papers"] = "papers"
|
265
|
+
state["papers"] = MOCK_STATE_PAPER
|
266
|
+
|
267
|
+
# Mock the dataframe agent instead of the LLM
|
268
|
+
mock_agent = MagicMock()
|
269
|
+
mock_agent.invoke.return_value = {"output": "Mocked response"}
|
270
|
+
|
271
|
+
mock_create_agent.return_value = (
|
272
|
+
mock_agent # Mock the function returning the agent
|
273
|
+
)
|
274
|
+
|
275
|
+
# Ensure that the output of query_results is correctly structured
|
276
|
+
result = query_results.invoke({"question": "List all papers", "state": state})
|
277
|
+
|
278
|
+
assert isinstance(result, str) # Ensure output is a string
|
279
|
+
assert result == "Mocked response" # Validate the expected response
|
280
|
+
|
281
|
+
@patch("requests.get")
|
282
|
+
def test_retrieve_semantic_scholar_paper_id(self, mock_get):
|
283
|
+
"""Tests retrieving a paper ID from Semantic Scholar."""
|
284
|
+
mock_get.return_value.json.return_value = MOCK_SEARCH_RESPONSE
|
285
|
+
mock_get.return_value.status_code = 200
|
286
|
+
|
287
|
+
result = retrieve_semantic_scholar_paper_id.invoke(
|
288
|
+
input={"paper_title": "Machine Learning Basics", "tool_call_id": "test123"}
|
289
|
+
)
|
290
|
+
|
291
|
+
assert isinstance(result, Command)
|
292
|
+
assert "messages" in result.update
|
293
|
+
assert (
|
294
|
+
"Paper ID for 'Machine Learning Basics' is: 123"
|
295
|
+
in result.update["messages"][0].content
|
296
|
+
)
|
297
|
+
|
298
|
+
def test_retrieve_semantic_scholar_paper_id_no_results(self):
|
299
|
+
"""Test retrieving a paper ID when no results are found."""
|
300
|
+
with pytest.raises(ValueError, match="No papers found for query: UnknownPaper"):
|
301
|
+
retrieve_semantic_scholar_paper_id.invoke(
|
302
|
+
input={"paper_title": "UnknownPaper", "tool_call_id": "test123"}
|
303
|
+
)
|
304
|
+
|
305
|
+
def test_single_paper_rec_invalid_id(self):
|
306
|
+
"""Test single paper recommendation with an invalid ID."""
|
307
|
+
with pytest.raises(ValueError, match="Invalid paper ID or API error."):
|
308
|
+
get_single_paper_recommendations.invoke(
|
309
|
+
input={"paper_id": "", "tool_call_id": "test123"} # Empty ID case
|
310
|
+
)
|
311
|
+
|
312
|
+
@patch("requests.post")
|
313
|
+
def test_multi_paper_rec_no_recommendations(self, mock_post):
|
314
|
+
"""Tests behavior when multi-paper recommendation API returns no results."""
|
315
|
+
mock_post.return_value.json.return_value = {"recommendedPapers": []}
|
316
|
+
mock_post.return_value.status_code = 200
|
317
|
+
|
318
|
+
result = get_multi_paper_recommendations.invoke(
|
319
|
+
input={"paper_ids": ["123", "456"], "limit": 1, "tool_call_id": "test123"}
|
320
|
+
)
|
321
|
+
|
322
|
+
assert isinstance(result, Command)
|
323
|
+
assert "messages" in result.update
|
324
|
+
assert (
|
325
|
+
"No recommendations found based on multiple papers."
|
326
|
+
in result.update["messages"][0].content
|
327
|
+
)
|
328
|
+
|
329
|
+
@patch("requests.get")
|
330
|
+
def test_search_no_results(self, mock_get):
|
331
|
+
"""Tests behavior when search API returns no results."""
|
332
|
+
mock_get.return_value.json.return_value = {"data": []}
|
333
|
+
mock_get.return_value.status_code = 200
|
334
|
+
|
335
|
+
result = search_tool.invoke(
|
336
|
+
input={"query": "nonexistent topic", "limit": 1, "tool_call_id": "test123"}
|
337
|
+
)
|
338
|
+
|
339
|
+
assert isinstance(result, Command)
|
340
|
+
assert "messages" in result.update
|
341
|
+
assert "No papers found." in result.update["messages"][0].content
|
342
|
+
|
343
|
+
@patch("requests.get")
|
344
|
+
def test_single_paper_rec_no_recommendations(self, mock_get):
|
345
|
+
"""Tests behavior when single paper recommendation API returns no results."""
|
346
|
+
mock_get.return_value.json.return_value = {"recommendedPapers": []}
|
347
|
+
mock_get.return_value.status_code = 200
|
348
|
+
|
349
|
+
result = get_single_paper_recommendations.invoke(
|
350
|
+
input={"paper_id": "123", "limit": 1, "tool_call_id": "test123"}
|
351
|
+
)
|
352
|
+
|
353
|
+
assert isinstance(result, Command)
|
354
|
+
assert "messages" in result.update
|
355
|
+
assert "No recommendations found for" in result.update["messages"][0].content
|
@@ -1,8 +1,10 @@
|
|
1
|
-
|
1
|
+
"""
|
2
2
|
This file is used to import all the modules in the package.
|
3
|
-
|
3
|
+
"""
|
4
4
|
|
5
5
|
from . import display_results
|
6
6
|
from . import multi_paper_rec
|
7
7
|
from . import search
|
8
8
|
from . import single_paper_rec
|
9
|
+
from . import query_results
|
10
|
+
from . import retrieve_semantic_scholar_paper_id
|
@@ -1,13 +1,24 @@
|
|
1
1
|
#!/usr/bin/env python3
|
2
2
|
|
3
|
+
|
3
4
|
"""
|
4
|
-
|
5
|
+
Tool for displaying search or recommendation results.
|
6
|
+
|
7
|
+
This module defines a tool that retrieves and displays a table of research papers
|
8
|
+
found during searches or recommendations. If no papers are found, an exception is raised
|
9
|
+
to signal the need for a new search.
|
5
10
|
"""
|
6
11
|
|
12
|
+
|
7
13
|
import logging
|
8
|
-
|
14
|
+
|
15
|
+
from typing import Annotated
|
16
|
+
from langchain_core.messages import ToolMessage
|
9
17
|
from langchain_core.tools import tool
|
18
|
+
from langchain_core.tools.base import InjectedToolCallId
|
10
19
|
from langgraph.prebuilt import InjectedState
|
20
|
+
from langgraph.types import Command
|
21
|
+
|
11
22
|
|
12
23
|
# Configure logging
|
13
24
|
logging.basicConfig(level=logging.INFO)
|
@@ -15,36 +26,64 @@ logger = logging.getLogger(__name__)
|
|
15
26
|
|
16
27
|
|
17
28
|
class NoPapersFoundError(Exception):
|
18
|
-
"""
|
29
|
+
"""
|
30
|
+
Exception raised when no research papers are found in the agent's state.
|
19
31
|
|
32
|
+
This exception helps the language model determine whether a new search
|
33
|
+
or recommendation should be initiated.
|
20
34
|
|
21
|
-
|
22
|
-
|
35
|
+
Example:
|
36
|
+
>>> if not papers:
|
37
|
+
>>> raise NoPapersFoundError("No papers found. A search is needed.")
|
23
38
|
"""
|
24
|
-
|
25
|
-
|
39
|
+
|
40
|
+
|
41
|
+
@tool("display_results", parse_docstring=True)
|
42
|
+
def display_results(
|
43
|
+
tool_call_id: Annotated[str, InjectedToolCallId],
|
44
|
+
state: Annotated[dict, InjectedState],
|
45
|
+
) -> Command:
|
46
|
+
"""
|
47
|
+
Displays retrieved research papers after a search or recommendation.
|
48
|
+
|
49
|
+
This function retrieves the last displayed research papers from the state and
|
50
|
+
returns them as an artifact for further processing. If no papers are found,
|
51
|
+
it raises a `NoPapersFoundError` to indicate that a new search is needed.
|
26
52
|
|
27
53
|
Args:
|
28
|
-
|
54
|
+
tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID for tracking.
|
55
|
+
state (dict): The agent's state containing retrieved papers.
|
29
56
|
|
30
57
|
Returns:
|
31
|
-
|
58
|
+
Command: A command containing a message with the number of displayed papers
|
59
|
+
and an attached artifact for further reference.
|
32
60
|
|
33
61
|
Raises:
|
34
|
-
NoPapersFoundError: If no papers are found in the state.
|
62
|
+
NoPapersFoundError: If no research papers are found in the agent's state.
|
35
63
|
|
36
|
-
|
37
|
-
|
64
|
+
Example:
|
65
|
+
>>> state = {"last_displayed_papers": {"paper1": "Title 1", "paper2": "Title 2"}}
|
66
|
+
>>> result = display_results(tool_call_id="123", state=state)
|
67
|
+
>>> print(result.update["messages"][0].content)
|
68
|
+
"2 papers found. Papers are attached as an artifact."
|
38
69
|
"""
|
39
|
-
logger.info("Displaying papers
|
40
|
-
|
41
|
-
|
70
|
+
logger.info("Displaying papers")
|
71
|
+
context_key = state.get("last_displayed_papers")
|
72
|
+
artifact = state.get(context_key)
|
73
|
+
if not artifact:
|
42
74
|
logger.info("No papers found in state, raising NoPapersFoundError")
|
43
75
|
raise NoPapersFoundError(
|
44
|
-
"No papers found. A search needs to be performed first."
|
76
|
+
"No papers found. A search/rec needs to be performed first."
|
45
77
|
)
|
46
|
-
|
47
|
-
return
|
48
|
-
|
49
|
-
|
50
|
-
|
78
|
+
content = f"{len(artifact)} papers found. Papers are attached as an artifact."
|
79
|
+
return Command(
|
80
|
+
update={
|
81
|
+
"messages": [
|
82
|
+
ToolMessage(
|
83
|
+
content=content,
|
84
|
+
tool_call_id=tool_call_id,
|
85
|
+
artifact=artifact,
|
86
|
+
)
|
87
|
+
],
|
88
|
+
}
|
89
|
+
)
|
@@ -7,7 +7,7 @@ multi_paper_rec: Tool for getting recommendations
|
|
7
7
|
|
8
8
|
import json
|
9
9
|
import logging
|
10
|
-
from typing import Annotated, Any,
|
10
|
+
from typing import Annotated, Any, List, Optional
|
11
11
|
import hydra
|
12
12
|
import requests
|
13
13
|
from langchain_core.messages import ToolMessage
|
@@ -52,15 +52,16 @@ with hydra.initialize(version_base=None, config_path="../../configs"):
|
|
52
52
|
cfg = cfg.tools.multi_paper_recommendation
|
53
53
|
|
54
54
|
|
55
|
-
@tool(args_schema=MultiPaperRecInput)
|
55
|
+
@tool(args_schema=MultiPaperRecInput, parse_docstring=True)
|
56
56
|
def get_multi_paper_recommendations(
|
57
57
|
paper_ids: List[str],
|
58
58
|
tool_call_id: Annotated[str, InjectedToolCallId],
|
59
59
|
limit: int = 2,
|
60
60
|
year: Optional[str] = None,
|
61
|
-
) ->
|
61
|
+
) -> Command[Any]:
|
62
62
|
"""
|
63
|
-
Get
|
63
|
+
Get recommendations for a group of multiple papers using the Semantic Scholar IDs.
|
64
|
+
No other paper IDs are supported.
|
64
65
|
|
65
66
|
Args:
|
66
67
|
paper_ids (List[str]): The list of paper IDs to base recommendations on.
|
@@ -72,7 +73,9 @@ def get_multi_paper_recommendations(
|
|
72
73
|
Returns:
|
73
74
|
Dict[str, Any]: The recommendations and related information.
|
74
75
|
"""
|
75
|
-
logging.info(
|
76
|
+
logging.info(
|
77
|
+
"Starting multi-paper recommendations search with paper IDs: %s", paper_ids
|
78
|
+
)
|
76
79
|
|
77
80
|
endpoint = cfg.api_endpoint
|
78
81
|
headers = cfg.headers
|
@@ -101,26 +104,50 @@ def get_multi_paper_recommendations(
|
|
101
104
|
data = response.json()
|
102
105
|
recommendations = data.get("recommendedPapers", [])
|
103
106
|
|
107
|
+
if not recommendations:
|
108
|
+
return Command(
|
109
|
+
update={ # Place 'messages' inside 'update'
|
110
|
+
"messages": [
|
111
|
+
ToolMessage(
|
112
|
+
content="No recommendations found based on multiple papers.",
|
113
|
+
tool_call_id=tool_call_id,
|
114
|
+
)
|
115
|
+
]
|
116
|
+
}
|
117
|
+
)
|
118
|
+
|
104
119
|
# Create a dictionary to store the papers
|
105
120
|
filtered_papers = {
|
106
121
|
paper["paperId"]: {
|
122
|
+
# "semantic_scholar_id": paper["paperId"], # Store Semantic Scholar ID
|
107
123
|
"Title": paper.get("title", "N/A"),
|
108
124
|
"Abstract": paper.get("abstract", "N/A"),
|
109
125
|
"Year": paper.get("year", "N/A"),
|
110
126
|
"Citation Count": paper.get("citationCount", "N/A"),
|
111
127
|
"URL": paper.get("url", "N/A"),
|
128
|
+
# "arXiv_ID": paper.get("externalIds", {}).get(
|
129
|
+
# "ArXiv", "N/A"
|
130
|
+
# ), # Extract arXiv ID
|
112
131
|
}
|
113
132
|
for paper in recommendations
|
114
|
-
if paper.get("title") and paper.get("
|
133
|
+
if paper.get("title") and paper.get("authors")
|
115
134
|
}
|
116
135
|
|
136
|
+
content = "Recommendations based on multiple papers was successful."
|
137
|
+
content += " Here is a summary of the recommendations:"
|
138
|
+
content += f"Number of papers found: {len(filtered_papers)}\n"
|
139
|
+
content += f"Query Paper IDs: {', '.join(paper_ids)}\n"
|
140
|
+
content += f"Year: {year}\n" if year else ""
|
141
|
+
|
117
142
|
return Command(
|
118
143
|
update={
|
119
144
|
"multi_papers": filtered_papers, # Now sending the dictionary directly
|
145
|
+
"last_displayed_papers": "multi_papers",
|
120
146
|
"messages": [
|
121
147
|
ToolMessage(
|
122
|
-
content=
|
123
|
-
tool_call_id=tool_call_id
|
148
|
+
content=content,
|
149
|
+
tool_call_id=tool_call_id,
|
150
|
+
artifact=filtered_papers,
|
124
151
|
)
|
125
152
|
],
|
126
153
|
}
|
@@ -0,0 +1,61 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
|
3
|
+
"""
|
4
|
+
This tool is used to display the table of studies.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import logging
|
8
|
+
from typing import Annotated
|
9
|
+
import pandas as pd
|
10
|
+
from langchain_experimental.agents import create_pandas_dataframe_agent
|
11
|
+
from langchain_core.tools import tool
|
12
|
+
from langgraph.prebuilt import InjectedState
|
13
|
+
|
14
|
+
# Configure logging
|
15
|
+
logging.basicConfig(level=logging.INFO)
|
16
|
+
logger = logging.getLogger(__name__)
|
17
|
+
|
18
|
+
|
19
|
+
class NoPapersFoundError(Exception):
|
20
|
+
"""Exception raised when no papers are found in the state."""
|
21
|
+
|
22
|
+
|
23
|
+
@tool("query_results", parse_docstring=True)
|
24
|
+
def query_results(question: str, state: Annotated[dict, InjectedState]) -> str:
|
25
|
+
"""
|
26
|
+
Query the last displayed papers from the state. If no papers are found,
|
27
|
+
raises an exception.
|
28
|
+
|
29
|
+
Use this also to get the last displayed papers from the state,
|
30
|
+
and then use the papers to get recommendations for a single paper or
|
31
|
+
multiple papers.
|
32
|
+
|
33
|
+
Args:
|
34
|
+
question (str): The question to ask the agent.
|
35
|
+
state (dict): The state of the agent containing the papers.
|
36
|
+
|
37
|
+
Returns:
|
38
|
+
str: A message with the last displayed papers.
|
39
|
+
"""
|
40
|
+
logger.info("Querying last displayed papers with question: %s", question)
|
41
|
+
llm_model = state.get("llm_model")
|
42
|
+
if not state.get("last_displayed_papers"):
|
43
|
+
logger.info("No papers displayed so far, raising NoPapersFoundError")
|
44
|
+
raise NoPapersFoundError(
|
45
|
+
"No papers found. A search needs to be performed first."
|
46
|
+
)
|
47
|
+
context_key = state.get("last_displayed_papers")
|
48
|
+
dic_papers = state.get(context_key)
|
49
|
+
df_papers = pd.DataFrame.from_dict(dic_papers, orient="index")
|
50
|
+
df_agent = create_pandas_dataframe_agent(
|
51
|
+
llm_model,
|
52
|
+
allow_dangerous_code=True,
|
53
|
+
agent_type="tool-calling",
|
54
|
+
df=df_papers,
|
55
|
+
max_iterations=5,
|
56
|
+
include_df_in_prompt=True,
|
57
|
+
number_of_head_rows=df_papers.shape[0],
|
58
|
+
verbose=True,
|
59
|
+
)
|
60
|
+
llm_result = df_agent.invoke(question, stream_mode=None)
|
61
|
+
return llm_result["output"]
|
@@ -0,0 +1,79 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
|
3
|
+
"""
|
4
|
+
This tool is used to search for academic papers on Semantic Scholar.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import logging
|
8
|
+
from typing import Annotated, Any
|
9
|
+
import hydra
|
10
|
+
import requests
|
11
|
+
from langchain_core.messages import ToolMessage
|
12
|
+
from langchain_core.tools import tool
|
13
|
+
from langchain_core.tools.base import InjectedToolCallId
|
14
|
+
from langgraph.types import Command
|
15
|
+
from pydantic import Field
|
16
|
+
|
17
|
+
|
18
|
+
# Configure logging
|
19
|
+
logging.basicConfig(level=logging.INFO)
|
20
|
+
logger = logging.getLogger(__name__)
|
21
|
+
|
22
|
+
# Load hydra configuration
|
23
|
+
with hydra.initialize(version_base=None, config_path="../../configs"):
|
24
|
+
cfg = hydra.compose(
|
25
|
+
config_name="config",
|
26
|
+
overrides=["tools/retrieve_semantic_scholar_paper_id=default"],
|
27
|
+
)
|
28
|
+
cfg = cfg.tools.retrieve_semantic_scholar_paper_id
|
29
|
+
|
30
|
+
|
31
|
+
@tool("retrieve_semantic_scholar_paper_id", parse_docstring=True)
|
32
|
+
def retrieve_semantic_scholar_paper_id(
|
33
|
+
tool_call_id: Annotated[str, InjectedToolCallId],
|
34
|
+
paper_title: str = Field(
|
35
|
+
description="The title of the paper to search for on Semantic Scholar."
|
36
|
+
),
|
37
|
+
) -> Command[Any]:
|
38
|
+
"""
|
39
|
+
This tool can be used to search for a paper on Semantic Scholar
|
40
|
+
and retrieve the paper Semantic Scholar ID.
|
41
|
+
|
42
|
+
This is useful for when an article is retrieved from users Zotero library
|
43
|
+
and the Semantic Scholar ID is needed to retrieve more information about the paper.
|
44
|
+
|
45
|
+
Args:
|
46
|
+
tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
|
47
|
+
paper_title (str): The title of the paper to search for on Semantic Scholar.
|
48
|
+
|
49
|
+
Returns:
|
50
|
+
ToolMessage: A message containing the paper ID.
|
51
|
+
"""
|
52
|
+
logger.info("Retrieving ID of paper with title: %s", paper_title)
|
53
|
+
endpoint = cfg.api_endpoint
|
54
|
+
params = {
|
55
|
+
"query": paper_title,
|
56
|
+
"limit": 1,
|
57
|
+
"fields": ",".join(cfg.api_fields),
|
58
|
+
}
|
59
|
+
|
60
|
+
response = requests.get(endpoint, params=params, timeout=10)
|
61
|
+
data = response.json()
|
62
|
+
papers = data.get("data", [])
|
63
|
+
logger.info("Received %d papers", len(papers))
|
64
|
+
if not papers:
|
65
|
+
logger.error("No papers found for query: %s", paper_title)
|
66
|
+
raise ValueError(f"No papers found for query: {paper_title}. Try again.")
|
67
|
+
# Get the paper ID
|
68
|
+
paper_id = papers[0]["paperId"]
|
69
|
+
|
70
|
+
return Command(
|
71
|
+
update={
|
72
|
+
"messages": [
|
73
|
+
ToolMessage(
|
74
|
+
content=f"Paper ID for '{paper_title}' is: {paper_id}",
|
75
|
+
tool_call_id=tool_call_id,
|
76
|
+
)
|
77
|
+
],
|
78
|
+
}
|
79
|
+
)
|