aiagents4pharma 1.9.0__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their respective public registries. The information is provided for informational purposes only.
- aiagents4pharma/__init__.py +9 -6
- aiagents4pharma/talk2biomodels/agents/t2b_agent.py +7 -10
- aiagents4pharma/talk2biomodels/models/basico_model.py +29 -32
- aiagents4pharma/talk2biomodels/models/sys_bio_model.py +9 -6
- aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +3 -3
- aiagents4pharma/talk2biomodels/tests/test_basico_model.py +7 -8
- aiagents4pharma/talk2biomodels/tests/test_langgraph.py +64 -2
- aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +13 -7
- aiagents4pharma/talk2biomodels/tools/__init__.py +1 -0
- aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +5 -3
- aiagents4pharma/talk2biomodels/tools/parameter_scan.py +292 -0
- aiagents4pharma/talk2biomodels/tools/simulate_model.py +9 -11
- aiagents4pharma/talk2competitors/__init__.py +5 -0
- aiagents4pharma/talk2competitors/agents/__init__.py +6 -0
- aiagents4pharma/talk2competitors/agents/main_agent.py +130 -0
- aiagents4pharma/talk2competitors/agents/s2_agent.py +75 -0
- aiagents4pharma/talk2competitors/config/__init__.py +5 -0
- aiagents4pharma/talk2competitors/config/config.py +110 -0
- aiagents4pharma/talk2competitors/state/__init__.py +5 -0
- aiagents4pharma/talk2competitors/state/state_talk2competitors.py +32 -0
- aiagents4pharma/talk2competitors/tests/__init__.py +3 -0
- aiagents4pharma/talk2competitors/tests/test_langgraph.py +274 -0
- aiagents4pharma/talk2competitors/tools/__init__.py +7 -0
- aiagents4pharma/talk2competitors/tools/s2/__init__.py +8 -0
- aiagents4pharma/talk2competitors/tools/s2/display_results.py +25 -0
- aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py +132 -0
- aiagents4pharma/talk2competitors/tools/s2/search.py +119 -0
- aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py +141 -0
- {aiagents4pharma-1.9.0.dist-info → aiagents4pharma-1.11.0.dist-info}/METADATA +39 -23
- {aiagents4pharma-1.9.0.dist-info → aiagents4pharma-1.11.0.dist-info}/RECORD +33 -17
- {aiagents4pharma-1.9.0.dist-info → aiagents4pharma-1.11.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.9.0.dist-info → aiagents4pharma-1.11.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.9.0.dist-info → aiagents4pharma-1.11.0.dist-info}/top_level.txt +0 -0
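
The bulk of the new code is the `talk2competitors` sub-package (main and S2 agents, shared state, and Semantic Scholar tools). As a quick orientation, the import sketch below is inferred from the file list above and the relative imports in the test hunk that follows; it is a sketch, not a documented public entry point.

    # Hypothetical import sketch for the new talk2competitors modules in 1.11.0
    # (paths inferred from the file list above; not an officially documented API).
    from aiagents4pharma.talk2competitors.agents.main_agent import get_app
    from aiagents4pharma.talk2competitors.tools.s2.search import search_tool
    from aiagents4pharma.talk2competitors.tools.s2.single_paper_rec import (
        get_single_paper_recommendations,
    )
    from aiagents4pharma.talk2competitors.tools.s2.multi_paper_rec import (
        get_multi_paper_recommendations,
    )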
aiagents4pharma/talk2competitors/tests/test_langgraph.py
@@ -0,0 +1,274 @@
+"""
+Unit and integration tests for Talk2Competitors system.
+Each test focuses on a single, specific functionality.
+Tests are deterministic and independent of each other.
+"""
+
+from unittest.mock import Mock, patch
+
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage
+
+from ..agents.main_agent import get_app, make_supervisor_node
+from ..state.state_talk2competitors import replace_dict
+from ..tools.s2.display_results import display_results
+from ..tools.s2.multi_paper_rec import get_multi_paper_recommendations
+from ..tools.s2.search import search_tool
+from ..tools.s2.single_paper_rec import get_single_paper_recommendations
+
+# pylint: disable=redefined-outer-name
+
+# Fixed test data for deterministic results
+MOCK_SEARCH_RESPONSE = {
+    "data": [
+        {
+            "paperId": "123",
+            "title": "Machine Learning Basics",
+            "abstract": "An introduction to ML",
+            "year": 2023,
+            "citationCount": 100,
+            "url": "https://example.com/paper1",
+            "authors": [{"name": "Test Author"}],
+        }
+    ]
+}
+
+MOCK_STATE_PAPER = {
+    "123": {
+        "Title": "Machine Learning Basics",
+        "Abstract": "An introduction to ML",
+        "Year": 2023,
+        "Citation Count": 100,
+        "URL": "https://example.com/paper1",
+    }
+}
+
+
+@pytest.fixture
+def initial_state():
+    """Create a base state for tests"""
+    return {
+        "messages": [],
+        "papers": {},
+        "is_last_step": False,
+        "current_agent": None,
+        "llm_model": "gpt-4o-mini",
+    }
+
+
+class TestMainAgent:
+    """Unit tests for main agent functionality"""
+
+    def test_supervisor_routes_search_to_s2(self, initial_state):
+        """Verifies that search-related queries are routed to S2 agent"""
+        llm_mock = Mock()
+        llm_mock.invoke.return_value = AIMessage(content="Search initiated")
+
+        supervisor = make_supervisor_node(llm_mock)
+        state = initial_state.copy()
+        state["messages"] = [HumanMessage(content="search for papers")]
+
+        result = supervisor(state)
+        assert result.goto == "s2_agent"
+        assert not result.update["is_last_step"]
+        assert result.update["current_agent"] == "s2_agent"
+
+    def test_supervisor_routes_general_to_end(self, initial_state):
+        """Verifies that non-search queries end the conversation"""
+        llm_mock = Mock()
+        llm_mock.invoke.return_value = AIMessage(content="General response")
+
+        supervisor = make_supervisor_node(llm_mock)
+        state = initial_state.copy()
+        state["messages"] = [HumanMessage(content="What is ML?")]
+
+        result = supervisor(state)
+        assert result.goto == "__end__"
+        assert result.update["is_last_step"]
+
+
+class TestS2Tools:
+    """Unit tests for individual S2 tools"""
+
+    def test_display_results_shows_papers(self, initial_state):
+        """Verifies display_results tool correctly returns papers from state"""
+        state = initial_state.copy()
+        state["papers"] = MOCK_STATE_PAPER
+        result = display_results.invoke(input={"state": state})
+        assert result == MOCK_STATE_PAPER
+        assert isinstance(result, dict)
+
+    @patch("requests.get")
+    def test_search_finds_papers(self, mock_get):
+        """Verifies search tool finds and formats papers correctly"""
+        mock_get.return_value.json.return_value = MOCK_SEARCH_RESPONSE
+        mock_get.return_value.status_code = 200
+
+        result = search_tool.invoke(
+            input={
+                "query": "machine learning",
+                "limit": 1,
+                "tool_call_id": "test123",
+                "id": "test123",
+            }
+        )
+
+        assert "papers" in result.update
+        assert "messages" in result.update
+        papers = result.update["papers"]
+        assert isinstance(papers, dict)
+        assert len(papers) > 0
+        paper = next(iter(papers.values()))
+        assert paper["Title"] == "Machine Learning Basics"
+        assert paper["Year"] == 2023
+
+    @patch("requests.get")
+    def test_single_paper_rec_basic(self, mock_get):
+        """Tests basic single paper recommendation functionality"""
+        mock_get.return_value.json.return_value = {
+            "recommendedPapers": [MOCK_SEARCH_RESPONSE["data"][0]]
+        }
+        mock_get.return_value.status_code = 200
+
+        result = get_single_paper_recommendations.invoke(
+            input={
+                "paper_id": "123",
+                "limit": 1,
+                "tool_call_id": "test123",
+                "id": "test123",
+            }
+        )
+        assert "papers" in result.update
+        assert len(result.update["messages"]) == 1
+
+    @patch("requests.get")
+    def test_single_paper_rec_with_optional_params(self, mock_get):
+        """Tests single paper recommendations with year parameter"""
+        mock_get.return_value.json.return_value = {
+            "recommendedPapers": [MOCK_SEARCH_RESPONSE["data"][0]]
+        }
+        mock_get.return_value.status_code = 200
+
+        result = get_single_paper_recommendations.invoke(
+            input={
+                "paper_id": "123",
+                "limit": 1,
+                "year": "2023-",
+                "tool_call_id": "test123",
+                "id": "test123",
+            }
+        )
+        assert "papers" in result.update
+
+    @patch("requests.post")
+    def test_multi_paper_rec_basic(self, mock_post):
+        """Tests basic multi-paper recommendation functionality"""
+        mock_post.return_value.json.return_value = {
+            "recommendedPapers": [MOCK_SEARCH_RESPONSE["data"][0]]
+        }
+        mock_post.return_value.status_code = 200
+
+        result = get_multi_paper_recommendations.invoke(
+            input={
+                "paper_ids": ["123", "456"],
+                "limit": 1,
+                "tool_call_id": "test123",
+                "id": "test123",
+            }
+        )
+        assert "papers" in result.update
+        assert len(result.update["messages"]) == 1
+
+    @patch("requests.post")
+    def test_multi_paper_rec_with_optional_params(self, mock_post):
+        """Tests multi-paper recommendations with all optional parameters"""
+        mock_post.return_value.json.return_value = {
+            "recommendedPapers": [MOCK_SEARCH_RESPONSE["data"][0]]
+        }
+        mock_post.return_value.status_code = 200
+
+        result = get_multi_paper_recommendations.invoke(
+            input={
+                "paper_ids": ["123", "456"],
+                "limit": 1,
+                "year": "2023-",
+                "tool_call_id": "test123",
+                "id": "test123",
+            }
+        )
+        assert "papers" in result.update
+        assert len(result.update["messages"]) == 1
+
+    @patch("requests.get")
+    def test_single_paper_rec_empty_response(self, mock_get):
+        """Tests single paper recommendations with empty response"""
+        mock_get.return_value.json.return_value = {"recommendedPapers": []}
+        mock_get.return_value.status_code = 200
+
+        result = get_single_paper_recommendations.invoke(
+            input={
+                "paper_id": "123",
+                "limit": 1,
+                "tool_call_id": "test123",
+                "id": "test123",
+            }
+        )
+        assert "papers" in result.update
+        assert len(result.update["papers"]) == 0
+
+    @patch("requests.post")
+    def test_multi_paper_rec_empty_response(self, mock_post):
+        """Tests multi-paper recommendations with empty response"""
+        mock_post.return_value.json.return_value = {"recommendedPapers": []}
+        mock_post.return_value.status_code = 200
+
+        result = get_multi_paper_recommendations.invoke(
+            input={
+                "paper_ids": ["123", "456"],
+                "limit": 1,
+                "tool_call_id": "test123",
+                "id": "test123",
+            }
+        )
+        assert "papers" in result.update
+        assert len(result.update["papers"]) == 0
+
+
+def test_state_replace_dict():
+    """Verifies state dictionary replacement works correctly"""
+    existing = {"key1": "value1", "key2": "value2"}
+    new = {"key3": "value3"}
+    result = replace_dict(existing, new)
+    assert result == new
+    assert isinstance(result, dict)
+
+
+@pytest.mark.integration
+def test_end_to_end_search_workflow(initial_state):
+    """Integration test: Complete search workflow"""
+    with (
+        patch("requests.get") as mock_get,
+        patch("langchain_openai.ChatOpenAI") as mock_llm,
+    ):
+        mock_get.return_value.json.return_value = MOCK_SEARCH_RESPONSE
+        mock_get.return_value.status_code = 200
+
+        llm_instance = Mock()
+        llm_instance.invoke.return_value = AIMessage(content="Search completed")
+        mock_llm.return_value = llm_instance
+
+        app = get_app("test_integration")
+        test_state = initial_state.copy()
+        test_state["messages"] = [HumanMessage(content="search for ML papers")]
+
+        config = {
+            "configurable": {
+                "thread_id": "test_integration",
+                "checkpoint_ns": "test",
+                "checkpoint_id": "test123",
+            }
+        }
+
+        response = app.invoke(test_state, config)
+        assert "papers" in response
+        assert len(response["messages"]) > 0
aiagents4pharma/talk2competitors/tools/s2/display_results.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python3
+
+'''
+This tool is used to display the table of studies.
+'''
+
+import logging
+from typing import Annotated
+from langchain_core.tools import tool
+from langgraph.prebuilt import InjectedState
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+@tool('display_results')
+def display_results(state: Annotated[dict, InjectedState]):
+    """
+    Display the papers in the state.
+
+    Args:
+        state (dict): The state of the agent.
+    """
+    logger.info("Displaying papers from the state")
+    return state["papers"]
aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python3
+
+"""
+multi_paper_rec: Tool for getting recommendations
+based on multiple papers
+"""
+
+import json
+import logging
+from typing import Annotated, Any, Dict, List, Optional
+
+import pandas as pd
+import requests
+from langchain_core.messages import ToolMessage
+from langchain_core.tools import tool
+from langchain_core.tools.base import InjectedToolCallId
+from langgraph.types import Command
+from pydantic import BaseModel, Field
+
+
+class MultiPaperRecInput(BaseModel):
+    """Input schema for multiple paper recommendations tool."""
+
+    paper_ids: List[str] = Field(
+        description=("List of Semantic Scholar Paper IDs to get recommendations for")
+    )
+    limit: int = Field(
+        default=2,
+        description="Maximum total number of recommendations to return",
+        ge=1,
+        le=500,
+    )
+    year: Optional[str] = Field(
+        default=None,
+        description="Year range in format: YYYY for specific year, "
+        "YYYY- for papers after year, -YYYY for papers before year, or YYYY:YYYY for range",
+    )
+    tool_call_id: Annotated[str, InjectedToolCallId]
+
+    model_config = {"arbitrary_types_allowed": True}
+
+
+@tool(args_schema=MultiPaperRecInput)
+def get_multi_paper_recommendations(
+    paper_ids: List[str],
+    tool_call_id: Annotated[str, InjectedToolCallId],
+    limit: int = 2,
+    year: Optional[str] = None,
+) -> Dict[str, Any]:
+    """
+    Get paper recommendations based on multiple papers.
+
+    Args:
+        paper_ids (List[str]): The list of paper IDs to base recommendations on.
+        tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
+        limit (int, optional): The maximum number of recommendations to return. Defaults to 2.
+        year (str, optional): Year range for papers.
+        Supports formats like "2024-", "-2024", "2024:2025". Defaults to None.
+
+    Returns:
+        Dict[str, Any]: The recommendations and related information.
+    """
+    logging.info("Starting multi-paper recommendations search.")
+
+    endpoint = "https://api.semanticscholar.org/recommendations/v1/papers"
+    headers = {"Content-Type": "application/json"}
+    payload = {"positivePaperIds": paper_ids, "negativePaperIds": []}
+    params = {
+        "limit": min(limit, 500),
+        "fields": "paperId,title,abstract,year,authors,citationCount,url",
+    }
+
+    # Add year parameter if provided
+    if year:
+        params["year"] = year
+
+    # Getting recommendations
+    response = requests.post(
+        endpoint,
+        headers=headers,
+        params=params,
+        data=json.dumps(payload),
+        timeout=10,
+    )
+    logging.info(
+        "API Response Status for multi-paper recommendations: %s", response.status_code
+    )
+
+    data = response.json()
+    recommendations = data.get("recommendedPapers", [])
+
+    # Create a dictionary to store the papers
+    filtered_papers = {
+        paper["paperId"]: {
+            "Title": paper.get("title", "N/A"),
+            "Abstract": paper.get("abstract", "N/A"),
+            "Year": paper.get("year", "N/A"),
+            "Citation Count": paper.get("citationCount", "N/A"),
+            "URL": paper.get("url", "N/A"),
+        }
+        for paper in recommendations
+        if paper.get("title") and paper.get("paperId")
+    }
+
+    # Create a DataFrame from the dictionary
+    df = pd.DataFrame.from_dict(filtered_papers, orient="index")
+    # print("Created DataFrame with results:")
+    logging.info("Created DataFrame with results: %s", df)
+
+    # Format papers for state update
+    papers = [
+        f"Paper ID: {paper_id}\n"
+        f"Title: {paper_data['Title']}\n"
+        f"Abstract: {paper_data['Abstract']}\n"
+        f"Year: {paper_data['Year']}\n"
+        f"Citations: {paper_data['Citation Count']}\n"
+        f"URL: {paper_data['URL']}\n"
+        for paper_id, paper_data in filtered_papers.items()
+    ]
+
+    # Convert DataFrame to markdown table
+    markdown_table = df.to_markdown(tablefmt="grid")
+    logging.info("Search results: %s", papers)
+
+    return Command(
+        update={
+            "papers": filtered_papers,  # Now sending the dictionary directly
+            "messages": [
+                ToolMessage(content=markdown_table, tool_call_id=tool_call_id)
+            ],
+        }
+    )
aiagents4pharma/talk2competitors/tools/s2/search.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+
+"""
+This tool is used to search for academic papers on Semantic Scholar.
+"""
+
+import logging
+from typing import Annotated, Any, Dict, Optional
+
+import pandas as pd
+import requests
+from langchain_core.messages import ToolMessage
+from langchain_core.tools import tool
+from langchain_core.tools.base import InjectedToolCallId
+from langgraph.types import Command
+from pydantic import BaseModel, Field
+
+
+class SearchInput(BaseModel):
+    """Input schema for the search papers tool."""
+
+    query: str = Field(
+        description="Search query string to find academic papers."
+        "Be specific and include relevant academic terms."
+    )
+    limit: int = Field(
+        default=2, description="Maximum number of results to return", ge=1, le=100
+    )
+    year: Optional[str] = Field(
+        default=None,
+        description="Year range in format: YYYY for specific year, "
+        "YYYY- for papers after year, -YYYY for papers before year, or YYYY:YYYY for range",
+    )
+    tool_call_id: Annotated[str, InjectedToolCallId]
+
+
+@tool(args_schema=SearchInput)
+def search_tool(
+    query: str,
+    tool_call_id: Annotated[str, InjectedToolCallId],
+    limit: int = 2,
+    year: Optional[str] = None,
+) -> Dict[str, Any]:
+    """
+    Search for academic papers on Semantic Scholar.
+
+    Args:
+        query (str): The search query string to find academic papers.
+        tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
+        limit (int, optional): The maximum number of results to return. Defaults to 2.
+        year (str, optional): Year range for papers.
+        Supports formats like "2024-", "-2024", "2024:2025". Defaults to None.
+
+    Returns:
+        Dict[str, Any]: The search results and related information.
+    """
+    print("Starting paper search...")
+    endpoint = "https://api.semanticscholar.org/graph/v1/paper/search"
+    params = {
+        "query": query,
+        "limit": min(limit, 100),
+        # "fields": "paperId,title,abstract,year,authors,
+        # citationCount,url,publicationTypes,openAccessPdf",
+        "fields": "paperId,title,abstract,year,authors,citationCount,url",
+    }
+
+    # Add year parameter if provided
+    if year:
+        params["year"] = year
+
+    response = requests.get(endpoint, params=params, timeout=10)
+    data = response.json()
+    papers = data.get("data", [])
+
+    # Create a dictionary to store the papers
+    filtered_papers = {
+        paper["paperId"]: {
+            "Title": paper.get("title", "N/A"),
+            "Abstract": paper.get("abstract", "N/A"),
+            "Year": paper.get("year", "N/A"),
+            "Citation Count": paper.get("citationCount", "N/A"),
+            "URL": paper.get("url", "N/A"),
+            # "Publication Type": paper.get("publicationTypes", ["N/A"])[0]
+            # if paper.get("publicationTypes")
+            # else "N/A",
+            # "Open Access PDF": paper.get("openAccessPdf", {}).get("url", "N/A")
+            # if paper.get("openAccessPdf") is not None
+            # else "N/A",
+        }
+        for paper in papers
+        if paper.get("title") and paper.get("authors")
+    }
+
+    df = pd.DataFrame(filtered_papers)
+
+    # Format papers for state update
+    papers = [
+        f"Paper ID: {paper_id}\n"
+        f"Title: {paper_data['Title']}\n"
+        f"Abstract: {paper_data['Abstract']}\n"
+        f"Year: {paper_data['Year']}\n"
+        f"Citations: {paper_data['Citation Count']}\n"
+        f"URL: {paper_data['URL']}\n"
+        # f"Publication Type: {paper_data['Publication Type']}\n"
+        # f"Open Access PDF: {paper_data['Open Access PDF']}"
+        for paper_id, paper_data in filtered_papers.items()
+    ]
+
+    markdown_table = df.to_markdown(tablefmt="grid")
+    logging.info("Search results: %s", papers)
+
+    return Command(
+        update={
+            "papers": filtered_papers,  # Now sending the dictionary directly
+            "messages": [
+                ToolMessage(content=markdown_table, tool_call_id=tool_call_id)
+            ],
+        }
+    )