aiagents4pharma 1.20.0__py3-none-any.whl → 1.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. aiagents4pharma/talk2biomodels/configs/config.yaml +5 -0
  2. aiagents4pharma/talk2scholars/agents/main_agent.py +90 -91
  3. aiagents4pharma/talk2scholars/agents/s2_agent.py +61 -17
  4. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +31 -10
  5. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +8 -16
  6. aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +11 -9
  7. aiagents4pharma/talk2scholars/configs/config.yaml +1 -0
  8. aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +2 -0
  9. aiagents4pharma/talk2scholars/configs/tools/retrieve_semantic_scholar_paper_id/__init__.py +3 -0
  10. aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +1 -0
  11. aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +1 -0
  12. aiagents4pharma/talk2scholars/state/state_talk2scholars.py +36 -7
  13. aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py +58 -0
  14. aiagents4pharma/talk2scholars/tests/test_main_agent.py +98 -122
  15. aiagents4pharma/talk2scholars/tests/test_s2_agent.py +95 -29
  16. aiagents4pharma/talk2scholars/tests/test_s2_tools.py +158 -22
  17. aiagents4pharma/talk2scholars/tools/s2/__init__.py +4 -2
  18. aiagents4pharma/talk2scholars/tools/s2/display_results.py +60 -21
  19. aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +35 -8
  20. aiagents4pharma/talk2scholars/tools/s2/query_results.py +61 -0
  21. aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +79 -0
  22. aiagents4pharma/talk2scholars/tools/s2/search.py +34 -10
  23. aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +39 -9
  24. {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/METADATA +2 -2
  25. {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/RECORD +28 -24
  26. aiagents4pharma/talk2scholars/tests/test_integration.py +0 -237
  27. {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/LICENSE +0 -0
  28. {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/WHEEL +0 -0
  29. {aiagents4pharma-1.20.0.dist-info → aiagents4pharma-1.21.0.dist-info}/top_level.txt +0 -0
@@ -4,12 +4,20 @@ Unit tests for S2 tools functionality.
4
4
 
5
5
  # pylint: disable=redefined-outer-name
6
6
  from unittest.mock import patch
7
- from langchain_core.messages import ToolMessage
7
+ from unittest.mock import MagicMock
8
8
  import pytest
9
- from ..tools.s2.display_results import display_results, NoPapersFoundError
9
+ from langgraph.types import Command
10
+ from ..tools.s2.display_results import (
11
+ display_results,
12
+ NoPapersFoundError as raised_error,
13
+ )
10
14
  from ..tools.s2.multi_paper_rec import get_multi_paper_recommendations
11
15
  from ..tools.s2.search import search_tool
12
16
  from ..tools.s2.single_paper_rec import get_single_paper_recommendations
17
+ from ..tools.s2.query_results import query_results, NoPapersFoundError
18
+ from ..tools.s2.retrieve_semantic_scholar_paper_id import (
19
+ retrieve_semantic_scholar_paper_id,
20
+ )
13
21
 
14
22
 
15
23
  @pytest.fixture
@@ -50,20 +58,29 @@ class TestS2Tools:
50
58
  def test_display_results_empty_state(self, initial_state):
51
59
  """Verifies display_results tool behavior when state is empty and raises an exception"""
52
60
  with pytest.raises(
53
- NoPapersFoundError,
54
- match="No papers found. A search needs to be performed first.",
61
+ raised_error,
62
+ match="No papers found. A search/rec needs to be performed first.",
55
63
  ):
56
- display_results.invoke({"state": initial_state})
64
+ display_results.invoke({"state": initial_state, "tool_call_id": "test123"})
57
65
 
58
66
  def test_display_results_shows_papers(self, initial_state):
59
67
  """Verifies display_results tool correctly returns papers from state"""
60
68
  state = initial_state.copy()
69
+ state["last_displayed_papers"] = "papers"
61
70
  state["papers"] = MOCK_STATE_PAPER
62
- state["multi_papers"] = {}
63
- result = display_results.invoke(input={"state": state})
64
- assert isinstance(result, dict)
65
- assert result["papers"] == MOCK_STATE_PAPER
66
- assert result["multi_papers"] == {}
71
+
72
+ result = display_results.invoke(
73
+ input={"state": state, "tool_call_id": "test123"}
74
+ )
75
+
76
+ assert isinstance(result, Command) # Expect a Command object
77
+ assert isinstance(result.update, dict) # Ensure update is a dictionary
78
+ assert "messages" in result.update
79
+ assert len(result.update["messages"]) == 1
80
+ assert (
81
+ "1 papers found. Papers are attached as an artifact."
82
+ in result.update["messages"][0].content
83
+ )
67
84
 
68
85
  @patch("requests.get")
69
86
  def test_search_finds_papers(self, mock_get):
@@ -152,14 +169,12 @@ class TestS2Tools:
152
169
  mock_get.return_value.status_code = 200
153
170
 
154
171
  result = get_single_paper_recommendations.invoke(
155
- input={
156
- "paper_id": "123",
157
- "limit": 1,
158
- "tool_call_id": "test123",
159
- }
172
+ input={"paper_id": "123", "limit": 1, "tool_call_id": "test123"}
160
173
  )
174
+
175
+ assert isinstance(result, Command)
161
176
  assert "papers" in result.update
162
- assert isinstance(result.update["messages"][0], ToolMessage)
177
+ assert len(result.update["messages"]) == 1
163
178
 
164
179
  @patch("requests.get")
165
180
  def test_single_paper_rec_with_optional_params(self, mock_get):
@@ -189,14 +204,12 @@ class TestS2Tools:
189
204
  mock_post.return_value.status_code = 200
190
205
 
191
206
  result = get_multi_paper_recommendations.invoke(
192
- input={
193
- "paper_ids": ["123", "456"],
194
- "limit": 1,
195
- "tool_call_id": "test123",
196
- }
207
+ input={"paper_ids": ["123", "456"], "limit": 1, "tool_call_id": "test123"}
197
208
  )
209
+
210
+ assert isinstance(result, Command)
198
211
  assert "multi_papers" in result.update
199
- assert isinstance(result.update["messages"][0], ToolMessage)
212
+ assert len(result.update["messages"]) == 1
200
213
 
201
214
  @patch("requests.post")
202
215
  def test_multi_paper_rec_with_optional_params(self, mock_post):
@@ -217,3 +230,126 @@ class TestS2Tools:
217
230
  )
218
231
  assert "multi_papers" in result.update
219
232
  assert len(result.update["messages"]) == 1
233
+
234
+ @patch("requests.get")
235
+ def test_search_tool_finds_papers(self, mock_get):
236
+ """Verifies search tool finds and formats papers correctly"""
237
+ mock_get.return_value.json.return_value = MOCK_SEARCH_RESPONSE
238
+ mock_get.return_value.status_code = 200
239
+
240
+ result = search_tool.invoke(
241
+ input={"query": "machine learning", "limit": 1, "tool_call_id": "test123"}
242
+ )
243
+
244
+ assert isinstance(result, Command) # Expect a Command object
245
+ assert "papers" in result.update
246
+ assert len(result.update["papers"]) > 0
247
+
248
+ def test_query_results_empty_state(self, initial_state):
249
+ """Tests query_results tool behavior when no papers are found."""
250
+ with pytest.raises(
251
+ NoPapersFoundError,
252
+ match="No papers found. A search needs to be performed first.",
253
+ ):
254
+ query_results.invoke(
255
+ {"question": "List all papers", "state": initial_state}
256
+ )
257
+
258
+ @patch(
259
+ "aiagents4pharma.talk2scholars.tools.s2.query_results.create_pandas_dataframe_agent"
260
+ )
261
+ def test_query_results_with_papers(self, mock_create_agent, initial_state):
262
+ """Tests querying papers when data is available."""
263
+ state = initial_state.copy()
264
+ state["last_displayed_papers"] = "papers"
265
+ state["papers"] = MOCK_STATE_PAPER
266
+
267
+ # Mock the dataframe agent instead of the LLM
268
+ mock_agent = MagicMock()
269
+ mock_agent.invoke.return_value = {"output": "Mocked response"}
270
+
271
+ mock_create_agent.return_value = (
272
+ mock_agent # Mock the function returning the agent
273
+ )
274
+
275
+ # Ensure that the output of query_results is correctly structured
276
+ result = query_results.invoke({"question": "List all papers", "state": state})
277
+
278
+ assert isinstance(result, str) # Ensure output is a string
279
+ assert result == "Mocked response" # Validate the expected response
280
+
281
+ @patch("requests.get")
282
+ def test_retrieve_semantic_scholar_paper_id(self, mock_get):
283
+ """Tests retrieving a paper ID from Semantic Scholar."""
284
+ mock_get.return_value.json.return_value = MOCK_SEARCH_RESPONSE
285
+ mock_get.return_value.status_code = 200
286
+
287
+ result = retrieve_semantic_scholar_paper_id.invoke(
288
+ input={"paper_title": "Machine Learning Basics", "tool_call_id": "test123"}
289
+ )
290
+
291
+ assert isinstance(result, Command)
292
+ assert "messages" in result.update
293
+ assert (
294
+ "Paper ID for 'Machine Learning Basics' is: 123"
295
+ in result.update["messages"][0].content
296
+ )
297
+
298
+ def test_retrieve_semantic_scholar_paper_id_no_results(self):
299
+ """Test retrieving a paper ID when no results are found."""
300
+ with pytest.raises(ValueError, match="No papers found for query: UnknownPaper"):
301
+ retrieve_semantic_scholar_paper_id.invoke(
302
+ input={"paper_title": "UnknownPaper", "tool_call_id": "test123"}
303
+ )
304
+
305
+ def test_single_paper_rec_invalid_id(self):
306
+ """Test single paper recommendation with an invalid ID."""
307
+ with pytest.raises(ValueError, match="Invalid paper ID or API error."):
308
+ get_single_paper_recommendations.invoke(
309
+ input={"paper_id": "", "tool_call_id": "test123"} # Empty ID case
310
+ )
311
+
312
+ @patch("requests.post")
313
+ def test_multi_paper_rec_no_recommendations(self, mock_post):
314
+ """Tests behavior when multi-paper recommendation API returns no results."""
315
+ mock_post.return_value.json.return_value = {"recommendedPapers": []}
316
+ mock_post.return_value.status_code = 200
317
+
318
+ result = get_multi_paper_recommendations.invoke(
319
+ input={"paper_ids": ["123", "456"], "limit": 1, "tool_call_id": "test123"}
320
+ )
321
+
322
+ assert isinstance(result, Command)
323
+ assert "messages" in result.update
324
+ assert (
325
+ "No recommendations found based on multiple papers."
326
+ in result.update["messages"][0].content
327
+ )
328
+
329
+ @patch("requests.get")
330
+ def test_search_no_results(self, mock_get):
331
+ """Tests behavior when search API returns no results."""
332
+ mock_get.return_value.json.return_value = {"data": []}
333
+ mock_get.return_value.status_code = 200
334
+
335
+ result = search_tool.invoke(
336
+ input={"query": "nonexistent topic", "limit": 1, "tool_call_id": "test123"}
337
+ )
338
+
339
+ assert isinstance(result, Command)
340
+ assert "messages" in result.update
341
+ assert "No papers found." in result.update["messages"][0].content
342
+
343
+ @patch("requests.get")
344
+ def test_single_paper_rec_no_recommendations(self, mock_get):
345
+ """Tests behavior when single paper recommendation API returns no results."""
346
+ mock_get.return_value.json.return_value = {"recommendedPapers": []}
347
+ mock_get.return_value.status_code = 200
348
+
349
+ result = get_single_paper_recommendations.invoke(
350
+ input={"paper_id": "123", "limit": 1, "tool_call_id": "test123"}
351
+ )
352
+
353
+ assert isinstance(result, Command)
354
+ assert "messages" in result.update
355
+ assert "No recommendations found for" in result.update["messages"][0].content
@@ -1,8 +1,10 @@
1
- '''
1
+ """
2
2
  This file is used to import all the modules in the package.
3
- '''
3
+ """
4
4
 
5
5
  from . import display_results
6
6
  from . import multi_paper_rec
7
7
  from . import search
8
8
  from . import single_paper_rec
9
+ from . import query_results
10
+ from . import retrieve_semantic_scholar_paper_id
@@ -1,13 +1,24 @@
1
1
  #!/usr/bin/env python3
2
2
 
3
+
3
4
  """
4
- This tool is used to display the table of studies.
5
+ Tool for displaying search or recommendation results.
6
+
7
+ This module defines a tool that retrieves and displays a table of research papers
8
+ found during searches or recommendations. If no papers are found, an exception is raised
9
+ to signal the need for a new search.
5
10
  """
6
11
 
12
+
7
13
  import logging
8
- from typing import Annotated, Dict, Any
14
+
15
+ from typing import Annotated
16
+ from langchain_core.messages import ToolMessage
9
17
  from langchain_core.tools import tool
18
+ from langchain_core.tools.base import InjectedToolCallId
10
19
  from langgraph.prebuilt import InjectedState
20
+ from langgraph.types import Command
21
+
11
22
 
12
23
  # Configure logging
13
24
  logging.basicConfig(level=logging.INFO)
@@ -15,36 +26,64 @@ logger = logging.getLogger(__name__)
15
26
 
16
27
 
17
28
  class NoPapersFoundError(Exception):
18
- """Exception raised when no papers are found in the state."""
29
+ """
30
+ Exception raised when no research papers are found in the agent's state.
19
31
 
32
+ This exception helps the language model determine whether a new search
33
+ or recommendation should be initiated.
20
34
 
21
- @tool("display_results")
22
- def display_results(state: Annotated[dict, InjectedState]) -> Dict[str, Any]:
35
+ Example:
36
+ >>> if not papers:
37
+ >>> raise NoPapersFoundError("No papers found. A search is needed.")
23
38
  """
24
- Display the papers in the state. If no papers are found, raises an exception
25
- indicating that a search is needed.
39
+
40
+
41
+ @tool("display_results", parse_docstring=True)
42
+ def display_results(
43
+ tool_call_id: Annotated[str, InjectedToolCallId],
44
+ state: Annotated[dict, InjectedState],
45
+ ) -> Command:
46
+ """
47
+ Displays retrieved research papers after a search or recommendation.
48
+
49
+ This function retrieves the last displayed research papers from the state and
50
+ returns them as an artifact for further processing. If no papers are found,
51
+ it raises a `NoPapersFoundError` to indicate that a new search is needed.
26
52
 
27
53
  Args:
28
- state (dict): The state of the agent containing the papers.
54
+ tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID for tracking.
55
+ state (dict): The agent's state containing retrieved papers.
29
56
 
30
57
  Returns:
31
- dict: A dictionary containing the papers and multi_papers from the state.
58
+ Command: A command containing a message with the number of displayed papers
59
+ and an attached artifact for further reference.
32
60
 
33
61
  Raises:
34
- NoPapersFoundError: If no papers are found in the state.
62
+ NoPapersFoundError: If no research papers are found in the agent's state.
35
63
 
36
- Note:
37
- The exception allows the LLM to make a more informed decision about initiating a search.
64
+ Example:
65
+ >>> state = {"last_displayed_papers": {"paper1": "Title 1", "paper2": "Title 2"}}
66
+ >>> result = display_results(tool_call_id="123", state=state)
67
+ >>> print(result.update["messages"][0].content)
68
+ "2 papers found. Papers are attached as an artifact."
38
69
  """
39
- logger.info("Displaying papers from the state")
40
-
41
- if not state.get("papers") and not state.get("multi_papers"):
70
+ logger.info("Displaying papers")
71
+ context_key = state.get("last_displayed_papers")
72
+ artifact = state.get(context_key)
73
+ if not artifact:
42
74
  logger.info("No papers found in state, raising NoPapersFoundError")
43
75
  raise NoPapersFoundError(
44
- "No papers found. A search needs to be performed first."
76
+ "No papers found. A search/rec needs to be performed first."
45
77
  )
46
-
47
- return {
48
- "papers": state.get("papers"),
49
- "multi_papers": state.get("multi_papers"),
50
- }
78
+ content = f"{len(artifact)} papers found. Papers are attached as an artifact."
79
+ return Command(
80
+ update={
81
+ "messages": [
82
+ ToolMessage(
83
+ content=content,
84
+ tool_call_id=tool_call_id,
85
+ artifact=artifact,
86
+ )
87
+ ],
88
+ }
89
+ )
@@ -7,7 +7,7 @@ multi_paper_rec: Tool for getting recommendations
7
7
 
8
8
  import json
9
9
  import logging
10
- from typing import Annotated, Any, Dict, List, Optional
10
+ from typing import Annotated, Any, List, Optional
11
11
  import hydra
12
12
  import requests
13
13
  from langchain_core.messages import ToolMessage
@@ -52,15 +52,16 @@ with hydra.initialize(version_base=None, config_path="../../configs"):
52
52
  cfg = cfg.tools.multi_paper_recommendation
53
53
 
54
54
 
55
- @tool(args_schema=MultiPaperRecInput)
55
+ @tool(args_schema=MultiPaperRecInput, parse_docstring=True)
56
56
  def get_multi_paper_recommendations(
57
57
  paper_ids: List[str],
58
58
  tool_call_id: Annotated[str, InjectedToolCallId],
59
59
  limit: int = 2,
60
60
  year: Optional[str] = None,
61
- ) -> Dict[str, Any]:
61
+ ) -> Command[Any]:
62
62
  """
63
- Get paper recommendations based on multiple papers.
63
+ Get recommendations for a group of multiple papers using the Semantic Scholar IDs.
64
+ No other paper IDs are supported.
64
65
 
65
66
  Args:
66
67
  paper_ids (List[str]): The list of paper IDs to base recommendations on.
@@ -72,7 +73,9 @@ def get_multi_paper_recommendations(
72
73
  Returns:
73
74
  Dict[str, Any]: The recommendations and related information.
74
75
  """
75
- logging.info("Starting multi-paper recommendations search.")
76
+ logging.info(
77
+ "Starting multi-paper recommendations search with paper IDs: %s", paper_ids
78
+ )
76
79
 
77
80
  endpoint = cfg.api_endpoint
78
81
  headers = cfg.headers
@@ -101,26 +104,50 @@ def get_multi_paper_recommendations(
101
104
  data = response.json()
102
105
  recommendations = data.get("recommendedPapers", [])
103
106
 
107
+ if not recommendations:
108
+ return Command(
109
+ update={ # Place 'messages' inside 'update'
110
+ "messages": [
111
+ ToolMessage(
112
+ content="No recommendations found based on multiple papers.",
113
+ tool_call_id=tool_call_id,
114
+ )
115
+ ]
116
+ }
117
+ )
118
+
104
119
  # Create a dictionary to store the papers
105
120
  filtered_papers = {
106
121
  paper["paperId"]: {
122
+ # "semantic_scholar_id": paper["paperId"], # Store Semantic Scholar ID
107
123
  "Title": paper.get("title", "N/A"),
108
124
  "Abstract": paper.get("abstract", "N/A"),
109
125
  "Year": paper.get("year", "N/A"),
110
126
  "Citation Count": paper.get("citationCount", "N/A"),
111
127
  "URL": paper.get("url", "N/A"),
128
+ # "arXiv_ID": paper.get("externalIds", {}).get(
129
+ # "ArXiv", "N/A"
130
+ # ), # Extract arXiv ID
112
131
  }
113
132
  for paper in recommendations
114
- if paper.get("title") and paper.get("paperId")
133
+ if paper.get("title") and paper.get("authors")
115
134
  }
116
135
 
136
+ content = "Recommendations based on multiple papers was successful."
137
+ content += " Here is a summary of the recommendations:"
138
+ content += f"Number of papers found: {len(filtered_papers)}\n"
139
+ content += f"Query Paper IDs: {', '.join(paper_ids)}\n"
140
+ content += f"Year: {year}\n" if year else ""
141
+
117
142
  return Command(
118
143
  update={
119
144
  "multi_papers": filtered_papers, # Now sending the dictionary directly
145
+ "last_displayed_papers": "multi_papers",
120
146
  "messages": [
121
147
  ToolMessage(
122
- content=f"Search Successful: {filtered_papers}",
123
- tool_call_id=tool_call_id
148
+ content=content,
149
+ tool_call_id=tool_call_id,
150
+ artifact=filtered_papers,
124
151
  )
125
152
  ],
126
153
  }
@@ -0,0 +1,61 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ This tool is used to query the last displayed papers with a natural-language question.
5
+ """
6
+
7
+ import logging
8
+ from typing import Annotated
9
+ import pandas as pd
10
+ from langchain_experimental.agents import create_pandas_dataframe_agent
11
+ from langchain_core.tools import tool
12
+ from langgraph.prebuilt import InjectedState
13
+
14
+ # Configure logging
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class NoPapersFoundError(Exception):
20
+ """Exception raised when no papers are found in the state."""
21
+
22
+
23
+ @tool("query_results", parse_docstring=True)
24
+ def query_results(question: str, state: Annotated[dict, InjectedState]) -> str:
25
+ """
26
+ Query the last displayed papers from the state. If no papers are found,
27
+ raises an exception.
28
+
29
+ Use this also to get the last displayed papers from the state,
30
+ and then use the papers to get recommendations for a single paper or
31
+ multiple papers.
32
+
33
+ Args:
34
+ question (str): The question to ask the agent.
35
+ state (dict): The state of the agent containing the papers.
36
+
37
+ Returns:
38
+ str: A message with the last displayed papers.
39
+ """
40
+ logger.info("Querying last displayed papers with question: %s", question)
41
+ llm_model = state.get("llm_model")
42
+ if not state.get("last_displayed_papers"):
43
+ logger.info("No papers displayed so far, raising NoPapersFoundError")
44
+ raise NoPapersFoundError(
45
+ "No papers found. A search needs to be performed first."
46
+ )
47
+ context_key = state.get("last_displayed_papers")
48
+ dic_papers = state.get(context_key)
49
+ df_papers = pd.DataFrame.from_dict(dic_papers, orient="index")
50
+ df_agent = create_pandas_dataframe_agent(
51
+ llm_model,
52
+ allow_dangerous_code=True,
53
+ agent_type="tool-calling",
54
+ df=df_papers,
55
+ max_iterations=5,
56
+ include_df_in_prompt=True,
57
+ number_of_head_rows=df_papers.shape[0],
58
+ verbose=True,
59
+ )
60
+ llm_result = df_agent.invoke(question, stream_mode=None)
61
+ return llm_result["output"]
@@ -0,0 +1,79 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ This tool is used to retrieve the Semantic Scholar ID of a paper from its title.
5
+ """
6
+
7
+ import logging
8
+ from typing import Annotated, Any
9
+ import hydra
10
+ import requests
11
+ from langchain_core.messages import ToolMessage
12
+ from langchain_core.tools import tool
13
+ from langchain_core.tools.base import InjectedToolCallId
14
+ from langgraph.types import Command
15
+ from pydantic import Field
16
+
17
+
18
+ # Configure logging
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # Load hydra configuration
23
+ with hydra.initialize(version_base=None, config_path="../../configs"):
24
+ cfg = hydra.compose(
25
+ config_name="config",
26
+ overrides=["tools/retrieve_semantic_scholar_paper_id=default"],
27
+ )
28
+ cfg = cfg.tools.retrieve_semantic_scholar_paper_id
29
+
30
+
31
+ @tool("retrieve_semantic_scholar_paper_id", parse_docstring=True)
32
+ def retrieve_semantic_scholar_paper_id(
33
+ tool_call_id: Annotated[str, InjectedToolCallId],
34
+ paper_title: str = Field(
35
+ description="The title of the paper to search for on Semantic Scholar."
36
+ ),
37
+ ) -> Command[Any]:
38
+ """
39
+ This tool can be used to search for a paper on Semantic Scholar
40
+ and retrieve the paper Semantic Scholar ID.
41
+
42
+ This is useful for when an article is retrieved from users Zotero library
43
+ and the Semantic Scholar ID is needed to retrieve more information about the paper.
44
+
45
+ Args:
46
+ tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
47
+ paper_title (str): The title of the paper to search for on Semantic Scholar.
48
+
49
+ Returns:
50
+ ToolMessage: A message containing the paper ID.
51
+ """
52
+ logger.info("Retrieving ID of paper with title: %s", paper_title)
53
+ endpoint = cfg.api_endpoint
54
+ params = {
55
+ "query": paper_title,
56
+ "limit": 1,
57
+ "fields": ",".join(cfg.api_fields),
58
+ }
59
+
60
+ response = requests.get(endpoint, params=params, timeout=10)
61
+ data = response.json()
62
+ papers = data.get("data", [])
63
+ logger.info("Received %d papers", len(papers))
64
+ if not papers:
65
+ logger.error("No papers found for query: %s", paper_title)
66
+ raise ValueError(f"No papers found for query: {paper_title}. Try again.")
67
+ # Get the paper ID
68
+ paper_id = papers[0]["paperId"]
69
+
70
+ return Command(
71
+ update={
72
+ "messages": [
73
+ ToolMessage(
74
+ content=f"Paper ID for '{paper_title}' is: {paper_id}",
75
+ tool_call_id=tool_call_id,
76
+ )
77
+ ],
78
+ }
79
+ )