aiagents4pharma 1.20.0__py3-none-any.whl → 1.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2biomodels/configs/config.yaml +5 -0
- aiagents4pharma/talk2scholars/agents/main_agent.py +90 -91
- aiagents4pharma/talk2scholars/agents/s2_agent.py +61 -17
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +31 -10
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +8 -16
- aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +11 -9
- aiagents4pharma/talk2scholars/configs/config.yaml +1 -0
- aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +2 -0
- aiagents4pharma/talk2scholars/configs/tools/retrieve_semantic_scholar_paper_id/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +1 -0
- aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +1 -0
- aiagents4pharma/talk2scholars/state/state_talk2scholars.py +36 -7
- aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py +58 -0
- aiagents4pharma/talk2scholars/tests/test_main_agent.py +98 -122
- aiagents4pharma/talk2scholars/tests/test_s2_agent.py +95 -29
- aiagents4pharma/talk2scholars/tests/test_s2_tools.py +158 -22
- aiagents4pharma/talk2scholars/tools/s2/__init__.py +4 -2
- aiagents4pharma/talk2scholars/tools/s2/display_results.py +60 -21
- aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +35 -8
- aiagents4pharma/talk2scholars/tools/s2/query_results.py +61 -0
- aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +79 -0
- aiagents4pharma/talk2scholars/tools/s2/search.py +34 -10
- aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +39 -9
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/METADATA +2 -2
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/RECORD +28 -24
- aiagents4pharma/talk2scholars/tests/test_integration.py +0 -237
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/top_level.txt +0 -0
@@ -4,12 +4,20 @@ Unit tests for S2 tools functionality.
|
|
4
4
|
|
5
5
|
# pylint: disable=redefined-outer-name
|
6
6
|
from unittest.mock import patch
|
7
|
-
from
|
7
|
+
from unittest.mock import MagicMock
|
8
8
|
import pytest
|
9
|
-
from
|
9
|
+
from langgraph.types import Command
|
10
|
+
from ..tools.s2.display_results import (
|
11
|
+
display_results,
|
12
|
+
NoPapersFoundError as raised_error,
|
13
|
+
)
|
10
14
|
from ..tools.s2.multi_paper_rec import get_multi_paper_recommendations
|
11
15
|
from ..tools.s2.search import search_tool
|
12
16
|
from ..tools.s2.single_paper_rec import get_single_paper_recommendations
|
17
|
+
from ..tools.s2.query_results import query_results, NoPapersFoundError
|
18
|
+
from ..tools.s2.retrieve_semantic_scholar_paper_id import (
|
19
|
+
retrieve_semantic_scholar_paper_id,
|
20
|
+
)
|
13
21
|
|
14
22
|
|
15
23
|
@pytest.fixture
|
@@ -50,20 +58,29 @@ class TestS2Tools:
|
|
50
58
|
def test_display_results_empty_state(self, initial_state):
|
51
59
|
"""Verifies display_results tool behavior when state is empty and raises an exception"""
|
52
60
|
with pytest.raises(
|
53
|
-
|
54
|
-
match="No papers found. A search needs to be performed first.",
|
61
|
+
raised_error,
|
62
|
+
match="No papers found. A search/rec needs to be performed first.",
|
55
63
|
):
|
56
|
-
display_results.invoke({"state": initial_state})
|
64
|
+
display_results.invoke({"state": initial_state, "tool_call_id": "test123"})
|
57
65
|
|
58
66
|
def test_display_results_shows_papers(self, initial_state):
|
59
67
|
"""Verifies display_results tool correctly returns papers from state"""
|
60
68
|
state = initial_state.copy()
|
69
|
+
state["last_displayed_papers"] = "papers"
|
61
70
|
state["papers"] = MOCK_STATE_PAPER
|
62
|
-
|
63
|
-
result = display_results.invoke(
|
64
|
-
|
65
|
-
|
66
|
-
|
71
|
+
|
72
|
+
result = display_results.invoke(
|
73
|
+
input={"state": state, "tool_call_id": "test123"}
|
74
|
+
)
|
75
|
+
|
76
|
+
assert isinstance(result, Command) # Expect a Command object
|
77
|
+
assert isinstance(result.update, dict) # Ensure update is a dictionary
|
78
|
+
assert "messages" in result.update
|
79
|
+
assert len(result.update["messages"]) == 1
|
80
|
+
assert (
|
81
|
+
"1 papers found. Papers are attached as an artifact."
|
82
|
+
in result.update["messages"][0].content
|
83
|
+
)
|
67
84
|
|
68
85
|
@patch("requests.get")
|
69
86
|
def test_search_finds_papers(self, mock_get):
|
@@ -152,14 +169,12 @@ class TestS2Tools:
|
|
152
169
|
mock_get.return_value.status_code = 200
|
153
170
|
|
154
171
|
result = get_single_paper_recommendations.invoke(
|
155
|
-
input={
|
156
|
-
"paper_id": "123",
|
157
|
-
"limit": 1,
|
158
|
-
"tool_call_id": "test123",
|
159
|
-
}
|
172
|
+
input={"paper_id": "123", "limit": 1, "tool_call_id": "test123"}
|
160
173
|
)
|
174
|
+
|
175
|
+
assert isinstance(result, Command)
|
161
176
|
assert "papers" in result.update
|
162
|
-
assert
|
177
|
+
assert len(result.update["messages"]) == 1
|
163
178
|
|
164
179
|
@patch("requests.get")
|
165
180
|
def test_single_paper_rec_with_optional_params(self, mock_get):
|
@@ -189,14 +204,12 @@ class TestS2Tools:
|
|
189
204
|
mock_post.return_value.status_code = 200
|
190
205
|
|
191
206
|
result = get_multi_paper_recommendations.invoke(
|
192
|
-
input={
|
193
|
-
"paper_ids": ["123", "456"],
|
194
|
-
"limit": 1,
|
195
|
-
"tool_call_id": "test123",
|
196
|
-
}
|
207
|
+
input={"paper_ids": ["123", "456"], "limit": 1, "tool_call_id": "test123"}
|
197
208
|
)
|
209
|
+
|
210
|
+
assert isinstance(result, Command)
|
198
211
|
assert "multi_papers" in result.update
|
199
|
-
assert
|
212
|
+
assert len(result.update["messages"]) == 1
|
200
213
|
|
201
214
|
@patch("requests.post")
|
202
215
|
def test_multi_paper_rec_with_optional_params(self, mock_post):
|
@@ -217,3 +230,126 @@ class TestS2Tools:
|
|
217
230
|
)
|
218
231
|
assert "multi_papers" in result.update
|
219
232
|
assert len(result.update["messages"]) == 1
|
233
|
+
|
234
|
+
@patch("requests.get")
|
235
|
+
def test_search_tool_finds_papers(self, mock_get):
|
236
|
+
"""Verifies search tool finds and formats papers correctly"""
|
237
|
+
mock_get.return_value.json.return_value = MOCK_SEARCH_RESPONSE
|
238
|
+
mock_get.return_value.status_code = 200
|
239
|
+
|
240
|
+
result = search_tool.invoke(
|
241
|
+
input={"query": "machine learning", "limit": 1, "tool_call_id": "test123"}
|
242
|
+
)
|
243
|
+
|
244
|
+
assert isinstance(result, Command) # Expect a Command object
|
245
|
+
assert "papers" in result.update
|
246
|
+
assert len(result.update["papers"]) > 0
|
247
|
+
|
248
|
+
def test_query_results_empty_state(self, initial_state):
|
249
|
+
"""Tests query_results tool behavior when no papers are found."""
|
250
|
+
with pytest.raises(
|
251
|
+
NoPapersFoundError,
|
252
|
+
match="No papers found. A search needs to be performed first.",
|
253
|
+
):
|
254
|
+
query_results.invoke(
|
255
|
+
{"question": "List all papers", "state": initial_state}
|
256
|
+
)
|
257
|
+
|
258
|
+
@patch(
|
259
|
+
"aiagents4pharma.talk2scholars.tools.s2.query_results.create_pandas_dataframe_agent"
|
260
|
+
)
|
261
|
+
def test_query_results_with_papers(self, mock_create_agent, initial_state):
|
262
|
+
"""Tests querying papers when data is available."""
|
263
|
+
state = initial_state.copy()
|
264
|
+
state["last_displayed_papers"] = "papers"
|
265
|
+
state["papers"] = MOCK_STATE_PAPER
|
266
|
+
|
267
|
+
# Mock the dataframe agent instead of the LLM
|
268
|
+
mock_agent = MagicMock()
|
269
|
+
mock_agent.invoke.return_value = {"output": "Mocked response"}
|
270
|
+
|
271
|
+
mock_create_agent.return_value = (
|
272
|
+
mock_agent # Mock the function returning the agent
|
273
|
+
)
|
274
|
+
|
275
|
+
# Ensure that the output of query_results is correctly structured
|
276
|
+
result = query_results.invoke({"question": "List all papers", "state": state})
|
277
|
+
|
278
|
+
assert isinstance(result, str) # Ensure output is a string
|
279
|
+
assert result == "Mocked response" # Validate the expected response
|
280
|
+
|
281
|
+
@patch("requests.get")
|
282
|
+
def test_retrieve_semantic_scholar_paper_id(self, mock_get):
|
283
|
+
"""Tests retrieving a paper ID from Semantic Scholar."""
|
284
|
+
mock_get.return_value.json.return_value = MOCK_SEARCH_RESPONSE
|
285
|
+
mock_get.return_value.status_code = 200
|
286
|
+
|
287
|
+
result = retrieve_semantic_scholar_paper_id.invoke(
|
288
|
+
input={"paper_title": "Machine Learning Basics", "tool_call_id": "test123"}
|
289
|
+
)
|
290
|
+
|
291
|
+
assert isinstance(result, Command)
|
292
|
+
assert "messages" in result.update
|
293
|
+
assert (
|
294
|
+
"Paper ID for 'Machine Learning Basics' is: 123"
|
295
|
+
in result.update["messages"][0].content
|
296
|
+
)
|
297
|
+
|
298
|
+
def test_retrieve_semantic_scholar_paper_id_no_results(self):
|
299
|
+
"""Test retrieving a paper ID when no results are found."""
|
300
|
+
with pytest.raises(ValueError, match="No papers found for query: UnknownPaper"):
|
301
|
+
retrieve_semantic_scholar_paper_id.invoke(
|
302
|
+
input={"paper_title": "UnknownPaper", "tool_call_id": "test123"}
|
303
|
+
)
|
304
|
+
|
305
|
+
def test_single_paper_rec_invalid_id(self):
|
306
|
+
"""Test single paper recommendation with an invalid ID."""
|
307
|
+
with pytest.raises(ValueError, match="Invalid paper ID or API error."):
|
308
|
+
get_single_paper_recommendations.invoke(
|
309
|
+
input={"paper_id": "", "tool_call_id": "test123"} # Empty ID case
|
310
|
+
)
|
311
|
+
|
312
|
+
@patch("requests.post")
|
313
|
+
def test_multi_paper_rec_no_recommendations(self, mock_post):
|
314
|
+
"""Tests behavior when multi-paper recommendation API returns no results."""
|
315
|
+
mock_post.return_value.json.return_value = {"recommendedPapers": []}
|
316
|
+
mock_post.return_value.status_code = 200
|
317
|
+
|
318
|
+
result = get_multi_paper_recommendations.invoke(
|
319
|
+
input={"paper_ids": ["123", "456"], "limit": 1, "tool_call_id": "test123"}
|
320
|
+
)
|
321
|
+
|
322
|
+
assert isinstance(result, Command)
|
323
|
+
assert "messages" in result.update
|
324
|
+
assert (
|
325
|
+
"No recommendations found based on multiple papers."
|
326
|
+
in result.update["messages"][0].content
|
327
|
+
)
|
328
|
+
|
329
|
+
@patch("requests.get")
|
330
|
+
def test_search_no_results(self, mock_get):
|
331
|
+
"""Tests behavior when search API returns no results."""
|
332
|
+
mock_get.return_value.json.return_value = {"data": []}
|
333
|
+
mock_get.return_value.status_code = 200
|
334
|
+
|
335
|
+
result = search_tool.invoke(
|
336
|
+
input={"query": "nonexistent topic", "limit": 1, "tool_call_id": "test123"}
|
337
|
+
)
|
338
|
+
|
339
|
+
assert isinstance(result, Command)
|
340
|
+
assert "messages" in result.update
|
341
|
+
assert "No papers found." in result.update["messages"][0].content
|
342
|
+
|
343
|
+
@patch("requests.get")
|
344
|
+
def test_single_paper_rec_no_recommendations(self, mock_get):
|
345
|
+
"""Tests behavior when single paper recommendation API returns no results."""
|
346
|
+
mock_get.return_value.json.return_value = {"recommendedPapers": []}
|
347
|
+
mock_get.return_value.status_code = 200
|
348
|
+
|
349
|
+
result = get_single_paper_recommendations.invoke(
|
350
|
+
input={"paper_id": "123", "limit": 1, "tool_call_id": "test123"}
|
351
|
+
)
|
352
|
+
|
353
|
+
assert isinstance(result, Command)
|
354
|
+
assert "messages" in result.update
|
355
|
+
assert "No recommendations found for" in result.update["messages"][0].content
|
@@ -1,8 +1,10 @@
|
|
1
|
-
|
1
|
+
"""
|
2
2
|
This file is used to import all the modules in the package.
|
3
|
-
|
3
|
+
"""
|
4
4
|
|
5
5
|
from . import display_results
|
6
6
|
from . import multi_paper_rec
|
7
7
|
from . import search
|
8
8
|
from . import single_paper_rec
|
9
|
+
from . import query_results
|
10
|
+
from . import retrieve_semantic_scholar_paper_id
|
@@ -1,13 +1,24 @@
|
|
1
1
|
#!/usr/bin/env python3
|
2
2
|
|
3
|
+
|
3
4
|
"""
|
4
|
-
|
5
|
+
Tool for displaying search or recommendation results.
|
6
|
+
|
7
|
+
This module defines a tool that retrieves and displays a table of research papers
|
8
|
+
found during searches or recommendations. If no papers are found, an exception is raised
|
9
|
+
to signal the need for a new search.
|
5
10
|
"""
|
6
11
|
|
12
|
+
|
7
13
|
import logging
|
8
|
-
|
14
|
+
|
15
|
+
from typing import Annotated
|
16
|
+
from langchain_core.messages import ToolMessage
|
9
17
|
from langchain_core.tools import tool
|
18
|
+
from langchain_core.tools.base import InjectedToolCallId
|
10
19
|
from langgraph.prebuilt import InjectedState
|
20
|
+
from langgraph.types import Command
|
21
|
+
|
11
22
|
|
12
23
|
# Configure logging
|
13
24
|
logging.basicConfig(level=logging.INFO)
|
@@ -15,36 +26,64 @@ logger = logging.getLogger(__name__)
|
|
15
26
|
|
16
27
|
|
17
28
|
class NoPapersFoundError(Exception):
|
18
|
-
"""
|
29
|
+
"""
|
30
|
+
Exception raised when no research papers are found in the agent's state.
|
19
31
|
|
32
|
+
This exception helps the language model determine whether a new search
|
33
|
+
or recommendation should be initiated.
|
20
34
|
|
21
|
-
|
22
|
-
|
35
|
+
Example:
|
36
|
+
>>> if not papers:
|
37
|
+
>>> raise NoPapersFoundError("No papers found. A search is needed.")
|
23
38
|
"""
|
24
|
-
|
25
|
-
|
39
|
+
|
40
|
+
|
41
|
+
@tool("display_results", parse_docstring=True)
|
42
|
+
def display_results(
|
43
|
+
tool_call_id: Annotated[str, InjectedToolCallId],
|
44
|
+
state: Annotated[dict, InjectedState],
|
45
|
+
) -> Command:
|
46
|
+
"""
|
47
|
+
Displays retrieved research papers after a search or recommendation.
|
48
|
+
|
49
|
+
This function retrieves the last displayed research papers from the state and
|
50
|
+
returns them as an artifact for further processing. If no papers are found,
|
51
|
+
it raises a `NoPapersFoundError` to indicate that a new search is needed.
|
26
52
|
|
27
53
|
Args:
|
28
|
-
|
54
|
+
tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID for tracking.
|
55
|
+
state (dict): The agent's state containing retrieved papers.
|
29
56
|
|
30
57
|
Returns:
|
31
|
-
|
58
|
+
Command: A command containing a message with the number of displayed papers
|
59
|
+
and an attached artifact for further reference.
|
32
60
|
|
33
61
|
Raises:
|
34
|
-
NoPapersFoundError: If no papers are found in the state.
|
62
|
+
NoPapersFoundError: If no research papers are found in the agent's state.
|
35
63
|
|
36
|
-
|
37
|
-
|
64
|
+
Example:
|
65
|
+
>>> state = {"last_displayed_papers": "papers", "papers": {"p1": "Title 1", "p2": "Title 2"}}
|
66
|
+
>>> result = display_results(tool_call_id="123", state=state)
|
67
|
+
>>> print(result.update["messages"][0].content)
|
68
|
+
"2 papers found. Papers are attached as an artifact."
|
38
69
|
"""
|
39
|
-
logger.info("Displaying papers
|
40
|
-
|
41
|
-
|
70
|
+
logger.info("Displaying papers")
|
71
|
+
context_key = state.get("last_displayed_papers")
|
72
|
+
artifact = state.get(context_key)
|
73
|
+
if not artifact:
|
42
74
|
logger.info("No papers found in state, raising NoPapersFoundError")
|
43
75
|
raise NoPapersFoundError(
|
44
|
-
"No papers found. A search needs to be performed first."
|
76
|
+
"No papers found. A search/rec needs to be performed first."
|
45
77
|
)
|
46
|
-
|
47
|
-
return
|
48
|
-
|
49
|
-
|
50
|
-
|
78
|
+
content = f"{len(artifact)} papers found. Papers are attached as an artifact."
|
79
|
+
return Command(
|
80
|
+
update={
|
81
|
+
"messages": [
|
82
|
+
ToolMessage(
|
83
|
+
content=content,
|
84
|
+
tool_call_id=tool_call_id,
|
85
|
+
artifact=artifact,
|
86
|
+
)
|
87
|
+
],
|
88
|
+
}
|
89
|
+
)
|
@@ -7,7 +7,7 @@ multi_paper_rec: Tool for getting recommendations
|
|
7
7
|
|
8
8
|
import json
|
9
9
|
import logging
|
10
|
-
from typing import Annotated, Any,
|
10
|
+
from typing import Annotated, Any, List, Optional
|
11
11
|
import hydra
|
12
12
|
import requests
|
13
13
|
from langchain_core.messages import ToolMessage
|
@@ -52,15 +52,16 @@ with hydra.initialize(version_base=None, config_path="../../configs"):
|
|
52
52
|
cfg = cfg.tools.multi_paper_recommendation
|
53
53
|
|
54
54
|
|
55
|
-
@tool(args_schema=MultiPaperRecInput)
|
55
|
+
@tool(args_schema=MultiPaperRecInput, parse_docstring=True)
|
56
56
|
def get_multi_paper_recommendations(
|
57
57
|
paper_ids: List[str],
|
58
58
|
tool_call_id: Annotated[str, InjectedToolCallId],
|
59
59
|
limit: int = 2,
|
60
60
|
year: Optional[str] = None,
|
61
|
-
) ->
|
61
|
+
) -> Command[Any]:
|
62
62
|
"""
|
63
|
-
Get
|
63
|
+
Get recommendations for a group of multiple papers using the Semantic Scholar IDs.
|
64
|
+
No other paper IDs are supported.
|
64
65
|
|
65
66
|
Args:
|
66
67
|
paper_ids (List[str]): The list of paper IDs to base recommendations on.
|
@@ -72,7 +73,9 @@ def get_multi_paper_recommendations(
|
|
72
73
|
Returns:
|
73
74
|
Command[Any]: A command whose update holds the recommended papers and a summary ToolMessage.
|
74
75
|
"""
|
75
|
-
logging.info(
|
76
|
+
logging.info(
|
77
|
+
"Starting multi-paper recommendations search with paper IDs: %s", paper_ids
|
78
|
+
)
|
76
79
|
|
77
80
|
endpoint = cfg.api_endpoint
|
78
81
|
headers = cfg.headers
|
@@ -101,26 +104,50 @@ def get_multi_paper_recommendations(
|
|
101
104
|
data = response.json()
|
102
105
|
recommendations = data.get("recommendedPapers", [])
|
103
106
|
|
107
|
+
if not recommendations:
|
108
|
+
return Command(
|
109
|
+
update={ # Place 'messages' inside 'update'
|
110
|
+
"messages": [
|
111
|
+
ToolMessage(
|
112
|
+
content="No recommendations found based on multiple papers.",
|
113
|
+
tool_call_id=tool_call_id,
|
114
|
+
)
|
115
|
+
]
|
116
|
+
}
|
117
|
+
)
|
118
|
+
|
104
119
|
# Create a dictionary to store the papers
|
105
120
|
filtered_papers = {
|
106
121
|
paper["paperId"]: {
|
122
|
+
# "semantic_scholar_id": paper["paperId"], # Store Semantic Scholar ID
|
107
123
|
"Title": paper.get("title", "N/A"),
|
108
124
|
"Abstract": paper.get("abstract", "N/A"),
|
109
125
|
"Year": paper.get("year", "N/A"),
|
110
126
|
"Citation Count": paper.get("citationCount", "N/A"),
|
111
127
|
"URL": paper.get("url", "N/A"),
|
128
|
+
# "arXiv_ID": paper.get("externalIds", {}).get(
|
129
|
+
# "ArXiv", "N/A"
|
130
|
+
# ), # Extract arXiv ID
|
112
131
|
}
|
113
132
|
for paper in recommendations
|
114
|
-
if paper.get("title") and paper.get("
|
133
|
+
if paper.get("title") and paper.get("authors")
|
115
134
|
}
|
116
135
|
|
136
|
+
content = "Recommendations based on multiple papers was successful."
|
137
|
+
content += " Here is a summary of the recommendations:"
|
138
|
+
content += f"Number of papers found: {len(filtered_papers)}\n"
|
139
|
+
content += f"Query Paper IDs: {', '.join(paper_ids)}\n"
|
140
|
+
content += f"Year: {year}\n" if year else ""
|
141
|
+
|
117
142
|
return Command(
|
118
143
|
update={
|
119
144
|
"multi_papers": filtered_papers, # Now sending the dictionary directly
|
145
|
+
"last_displayed_papers": "multi_papers",
|
120
146
|
"messages": [
|
121
147
|
ToolMessage(
|
122
|
-
content=
|
123
|
-
tool_call_id=tool_call_id
|
148
|
+
content=content,
|
149
|
+
tool_call_id=tool_call_id,
|
150
|
+
artifact=filtered_papers,
|
124
151
|
)
|
125
152
|
],
|
126
153
|
}
|
@@ -0,0 +1,61 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
|
3
|
+
"""
|
4
|
+
This tool is used to answer questions about the last displayed papers using a pandas DataFrame agent.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import logging
|
8
|
+
from typing import Annotated
|
9
|
+
import pandas as pd
|
10
|
+
from langchain_experimental.agents import create_pandas_dataframe_agent
|
11
|
+
from langchain_core.tools import tool
|
12
|
+
from langgraph.prebuilt import InjectedState
|
13
|
+
|
14
|
+
# Configure logging
|
15
|
+
logging.basicConfig(level=logging.INFO)
|
16
|
+
logger = logging.getLogger(__name__)
|
17
|
+
|
18
|
+
|
19
|
+
class NoPapersFoundError(Exception):
|
20
|
+
"""Exception raised when no papers are found in the state."""
|
21
|
+
|
22
|
+
|
23
|
+
@tool("query_results", parse_docstring=True)
|
24
|
+
def query_results(question: str, state: Annotated[dict, InjectedState]) -> str:
|
25
|
+
"""
|
26
|
+
Query the last displayed papers from the state. If no papers are found,
|
27
|
+
raises an exception.
|
28
|
+
|
29
|
+
Use this also to get the last displayed papers from the state,
|
30
|
+
and then use the papers to get recommendations for a single paper or
|
31
|
+
multiple papers.
|
32
|
+
|
33
|
+
Args:
|
34
|
+
question (str): The question to ask the agent.
|
35
|
+
state (dict): The state of the agent containing the papers.
|
36
|
+
|
37
|
+
Returns:
|
38
|
+
str: The dataframe agent's answer to the question about the last displayed papers.
|
39
|
+
"""
|
40
|
+
logger.info("Querying last displayed papers with question: %s", question)
|
41
|
+
llm_model = state.get("llm_model")
|
42
|
+
if not state.get("last_displayed_papers"):
|
43
|
+
logger.info("No papers displayed so far, raising NoPapersFoundError")
|
44
|
+
raise NoPapersFoundError(
|
45
|
+
"No papers found. A search needs to be performed first."
|
46
|
+
)
|
47
|
+
context_key = state.get("last_displayed_papers")
|
48
|
+
dic_papers = state.get(context_key)
|
49
|
+
df_papers = pd.DataFrame.from_dict(dic_papers, orient="index")
|
50
|
+
df_agent = create_pandas_dataframe_agent(
|
51
|
+
llm_model,
|
52
|
+
allow_dangerous_code=True,
|
53
|
+
agent_type="tool-calling",
|
54
|
+
df=df_papers,
|
55
|
+
max_iterations=5,
|
56
|
+
include_df_in_prompt=True,
|
57
|
+
number_of_head_rows=df_papers.shape[0],
|
58
|
+
verbose=True,
|
59
|
+
)
|
60
|
+
llm_result = df_agent.invoke(question, stream_mode=None)
|
61
|
+
return llm_result["output"]
|
@@ -0,0 +1,79 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
|
3
|
+
"""
|
4
|
+
This tool is used to look up a paper by title on Semantic Scholar and retrieve its Semantic Scholar paper ID.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import logging
|
8
|
+
from typing import Annotated, Any
|
9
|
+
import hydra
|
10
|
+
import requests
|
11
|
+
from langchain_core.messages import ToolMessage
|
12
|
+
from langchain_core.tools import tool
|
13
|
+
from langchain_core.tools.base import InjectedToolCallId
|
14
|
+
from langgraph.types import Command
|
15
|
+
from pydantic import Field
|
16
|
+
|
17
|
+
|
18
|
+
# Configure logging
|
19
|
+
logging.basicConfig(level=logging.INFO)
|
20
|
+
logger = logging.getLogger(__name__)
|
21
|
+
|
22
|
+
# Load hydra configuration
|
23
|
+
with hydra.initialize(version_base=None, config_path="../../configs"):
|
24
|
+
cfg = hydra.compose(
|
25
|
+
config_name="config",
|
26
|
+
overrides=["tools/retrieve_semantic_scholar_paper_id=default"],
|
27
|
+
)
|
28
|
+
cfg = cfg.tools.retrieve_semantic_scholar_paper_id
|
29
|
+
|
30
|
+
|
31
|
+
@tool("retrieve_semantic_scholar_paper_id", parse_docstring=True)
|
32
|
+
def retrieve_semantic_scholar_paper_id(
|
33
|
+
tool_call_id: Annotated[str, InjectedToolCallId],
|
34
|
+
paper_title: str = Field(
|
35
|
+
description="The title of the paper to search for on Semantic Scholar."
|
36
|
+
),
|
37
|
+
) -> Command[Any]:
|
38
|
+
"""
|
39
|
+
This tool can be used to search for a paper on Semantic Scholar
|
40
|
+
and retrieve the paper Semantic Scholar ID.
|
41
|
+
|
42
|
+
This is useful for when an article is retrieved from users Zotero library
|
43
|
+
and the Semantic Scholar ID is needed to retrieve more information about the paper.
|
44
|
+
|
45
|
+
Args:
|
46
|
+
tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
|
47
|
+
paper_title (str): The title of the paper to search for on Semantic Scholar.
|
48
|
+
|
49
|
+
Returns:
|
50
|
+
Command[Any]: A command whose update carries a ToolMessage containing the paper ID.
|
51
|
+
"""
|
52
|
+
logger.info("Retrieving ID of paper with title: %s", paper_title)
|
53
|
+
endpoint = cfg.api_endpoint
|
54
|
+
params = {
|
55
|
+
"query": paper_title,
|
56
|
+
"limit": 1,
|
57
|
+
"fields": ",".join(cfg.api_fields),
|
58
|
+
}
|
59
|
+
|
60
|
+
response = requests.get(endpoint, params=params, timeout=10)
|
61
|
+
data = response.json()
|
62
|
+
papers = data.get("data", [])
|
63
|
+
logger.info("Received %d papers", len(papers))
|
64
|
+
if not papers:
|
65
|
+
logger.error("No papers found for query: %s", paper_title)
|
66
|
+
raise ValueError(f"No papers found for query: {paper_title}. Try again.")
|
67
|
+
# Get the paper ID
|
68
|
+
paper_id = papers[0]["paperId"]
|
69
|
+
|
70
|
+
return Command(
|
71
|
+
update={
|
72
|
+
"messages": [
|
73
|
+
ToolMessage(
|
74
|
+
content=f"Paper ID for '{paper_title}' is: {paper_id}",
|
75
|
+
tool_call_id=tool_call_id,
|
76
|
+
)
|
77
|
+
],
|
78
|
+
}
|
79
|
+
)
|