aiagents4pharma 1.8.0__py3-none-any.whl → 1.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85):
  1. aiagents4pharma/__init__.py +9 -5
  2. aiagents4pharma/configs/__init__.py +5 -0
  3. aiagents4pharma/configs/config.yaml +4 -0
  4. aiagents4pharma/configs/talk2biomodels/__init__.py +6 -0
  5. aiagents4pharma/configs/talk2biomodels/agents/__init__.py +5 -0
  6. aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/__init__.py +3 -0
  7. aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/default.yaml +14 -0
  8. aiagents4pharma/configs/talk2biomodels/tools/__init__.py +4 -0
  9. aiagents4pharma/configs/talk2biomodels/tools/ask_question/__init__.py +3 -0
  10. aiagents4pharma/talk2biomodels/__init__.py +3 -0
  11. aiagents4pharma/talk2biomodels/agents/__init__.py +5 -0
  12. aiagents4pharma/talk2biomodels/agents/t2b_agent.py +96 -0
  13. aiagents4pharma/talk2biomodels/api/__init__.py +6 -0
  14. aiagents4pharma/talk2biomodels/api/kegg.py +83 -0
  15. aiagents4pharma/talk2biomodels/api/ols.py +72 -0
  16. aiagents4pharma/talk2biomodels/api/uniprot.py +35 -0
  17. aiagents4pharma/talk2biomodels/models/basico_model.py +29 -32
  18. aiagents4pharma/talk2biomodels/models/sys_bio_model.py +9 -6
  19. aiagents4pharma/talk2biomodels/states/__init__.py +5 -0
  20. aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +41 -0
  21. aiagents4pharma/talk2biomodels/tests/__init__.py +3 -0
  22. aiagents4pharma/talk2biomodels/tests/test_api.py +57 -0
  23. aiagents4pharma/talk2biomodels/tests/test_ask_question.py +44 -0
  24. aiagents4pharma/talk2biomodels/tests/test_basico_model.py +54 -0
  25. aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +171 -0
  26. aiagents4pharma/talk2biomodels/tests/test_getmodelinfo.py +26 -0
  27. aiagents4pharma/talk2biomodels/tests/test_integration.py +126 -0
  28. aiagents4pharma/talk2biomodels/tests/test_param_scan.py +68 -0
  29. aiagents4pharma/talk2biomodels/tests/test_query_article.py +76 -0
  30. aiagents4pharma/talk2biomodels/tests/test_search_models.py +28 -0
  31. aiagents4pharma/talk2biomodels/tests/test_simulate_model.py +39 -0
  32. aiagents4pharma/talk2biomodels/tests/test_steady_state.py +90 -0
  33. aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +63 -0
  34. aiagents4pharma/talk2biomodels/tools/__init__.py +5 -0
  35. aiagents4pharma/talk2biomodels/tools/ask_question.py +61 -18
  36. aiagents4pharma/talk2biomodels/tools/custom_plotter.py +20 -14
  37. aiagents4pharma/talk2biomodels/tools/get_annotation.py +304 -0
  38. aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +11 -9
  39. aiagents4pharma/talk2biomodels/tools/load_arguments.py +114 -0
  40. aiagents4pharma/talk2biomodels/tools/load_biomodel.py +0 -1
  41. aiagents4pharma/talk2biomodels/tools/parameter_scan.py +287 -0
  42. aiagents4pharma/talk2biomodels/tools/query_article.py +59 -0
  43. aiagents4pharma/talk2biomodels/tools/simulate_model.py +35 -90
  44. aiagents4pharma/talk2biomodels/tools/steady_state.py +167 -0
  45. aiagents4pharma/talk2cells/tests/scp_agent/test_scp_agent.py +23 -0
  46. aiagents4pharma/talk2cells/tools/scp_agent/__init__.py +6 -0
  47. aiagents4pharma/talk2cells/tools/scp_agent/display_studies.py +25 -0
  48. aiagents4pharma/talk2cells/tools/scp_agent/search_studies.py +79 -0
  49. aiagents4pharma/talk2competitors/__init__.py +5 -0
  50. aiagents4pharma/talk2competitors/agents/__init__.py +6 -0
  51. aiagents4pharma/talk2competitors/agents/main_agent.py +130 -0
  52. aiagents4pharma/talk2competitors/agents/s2_agent.py +75 -0
  53. aiagents4pharma/talk2competitors/config/__init__.py +5 -0
  54. aiagents4pharma/talk2competitors/config/config.py +110 -0
  55. aiagents4pharma/talk2competitors/state/__init__.py +5 -0
  56. aiagents4pharma/talk2competitors/state/state_talk2competitors.py +32 -0
  57. aiagents4pharma/talk2competitors/tests/__init__.py +3 -0
  58. aiagents4pharma/talk2competitors/tests/test_langgraph.py +274 -0
  59. aiagents4pharma/talk2competitors/tools/__init__.py +7 -0
  60. aiagents4pharma/talk2competitors/tools/s2/__init__.py +8 -0
  61. aiagents4pharma/talk2competitors/tools/s2/display_results.py +25 -0
  62. aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py +132 -0
  63. aiagents4pharma/talk2competitors/tools/s2/search.py +119 -0
  64. aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py +141 -0
  65. aiagents4pharma/talk2knowledgegraphs/__init__.py +2 -1
  66. aiagents4pharma/talk2knowledgegraphs/tests/__init__.py +0 -0
  67. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py +242 -0
  68. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py +29 -0
  69. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py +73 -0
  70. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py +116 -0
  71. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py +47 -0
  72. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +45 -0
  73. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_sentencetransformer.py +40 -0
  74. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_enrichments.py +39 -0
  75. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +117 -0
  76. aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +5 -0
  77. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py +5 -0
  78. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py +36 -0
  79. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py +123 -0
  80. {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/METADATA +44 -25
  81. aiagents4pharma-1.15.0.dist-info/RECORD +102 -0
  82. aiagents4pharma-1.8.0.dist-info/RECORD +0 -35
  83. {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/LICENSE +0 -0
  84. {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/WHEEL +0 -0
  85. {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,3 @@
1
+ """
2
+ This module contains the test cases.
3
+ """
@@ -0,0 +1,274 @@
1
+ """
2
+ Unit and integration tests for Talk2Competitors system.
3
+ Each test focuses on a single, specific functionality.
4
+ Tests are deterministic and independent of each other.
5
+ """
6
+
7
+ from unittest.mock import Mock, patch
8
+
9
+ import pytest
10
+ from langchain_core.messages import AIMessage, HumanMessage
11
+
12
+ from ..agents.main_agent import get_app, make_supervisor_node
13
+ from ..state.state_talk2competitors import replace_dict
14
+ from ..tools.s2.display_results import display_results
15
+ from ..tools.s2.multi_paper_rec import get_multi_paper_recommendations
16
+ from ..tools.s2.search import search_tool
17
+ from ..tools.s2.single_paper_rec import get_single_paper_recommendations
18
+
19
+ # pylint: disable=redefined-outer-name
20
+
21
# Fixed test data for deterministic results.
# One raw Semantic Scholar record is the single source of truth for both the
# mocked API payload and the matching entry in the tools' "papers" state format,
# so the two fixtures cannot drift apart.
_MOCK_RAW_PAPER = {
    "paperId": "123",
    "title": "Machine Learning Basics",
    "abstract": "An introduction to ML",
    "year": 2023,
    "citationCount": 100,
    "url": "https://example.com/paper1",
    "authors": [{"name": "Test Author"}],
}

# Mocked search payload: the list of papers lives under the "data" key.
MOCK_SEARCH_RESPONSE = {"data": [_MOCK_RAW_PAPER]}

# The same paper in state form, keyed by its paper ID.
MOCK_STATE_PAPER = {
    _MOCK_RAW_PAPER["paperId"]: {
        "Title": _MOCK_RAW_PAPER["title"],
        "Abstract": _MOCK_RAW_PAPER["abstract"],
        "Year": _MOCK_RAW_PAPER["year"],
        "Citation Count": _MOCK_RAW_PAPER["citationCount"],
        "URL": _MOCK_RAW_PAPER["url"],
    }
}
45
+
46
+
47
@pytest.fixture
def initial_state():
    """Return a fresh, empty agent state; each test gets its own copy."""
    return dict(
        messages=[],
        papers={},
        is_last_step=False,
        current_agent=None,
        llm_model="gpt-4o-mini",
    )
57
+
58
+
59
class TestMainAgent:
    """Unit tests for the supervisor node produced by ``make_supervisor_node``."""

    @staticmethod
    def _run_supervisor(base_state, llm_reply, user_text):
        """Build a supervisor around a mocked LLM and run it on one user message.

        The mocked LLM's ``invoke`` always answers with ``llm_reply``; the
        returned value is whatever the supervisor node returns.
        """
        fake_llm = Mock()
        fake_llm.invoke.return_value = AIMessage(content=llm_reply)
        node = make_supervisor_node(fake_llm)
        run_state = {**base_state, "messages": [HumanMessage(content=user_text)]}
        return node(run_state)

    def test_supervisor_routes_search_to_s2(self, initial_state):
        """A search-style request is handed to the S2 agent."""
        command = self._run_supervisor(
            initial_state, "Search initiated", "search for papers"
        )
        assert command.goto == "s2_agent"
        assert not command.update["is_last_step"]
        assert command.update["current_agent"] == "s2_agent"

    def test_supervisor_routes_general_to_end(self, initial_state):
        """A general (non-search) question finishes the conversation."""
        command = self._run_supervisor(initial_state, "General response", "What is ML?")
        assert command.goto == "__end__"
        assert command.update["is_last_step"]
88
+
89
+
90
class TestS2Tools:
    """Unit tests for individual S2 tools.

    Every HTTP call is patched, so no test touches the network.
    """

    def test_display_results_shows_papers(self, initial_state):
        """Verifies display_results tool correctly returns papers from state"""
        state = initial_state.copy()
        state["papers"] = MOCK_STATE_PAPER
        result = display_results.invoke(input={"state": state})
        assert result == MOCK_STATE_PAPER
        assert isinstance(result, dict)

    @patch("requests.get")
    def test_search_finds_papers(self, mock_get):
        """Verifies search tool finds and formats papers correctly"""
        mock_get.return_value.json.return_value = MOCK_SEARCH_RESPONSE
        mock_get.return_value.status_code = 200

        result = search_tool.invoke(
            input={
                "query": "machine learning",
                "limit": 1,
                "tool_call_id": "test123",
                "id": "test123",
            }
        )

        # The query must actually reach the HTTP request, not just the result.
        assert mock_get.call_args.kwargs["params"]["query"] == "machine learning"

        assert "papers" in result.update
        assert "messages" in result.update
        papers = result.update["papers"]
        assert isinstance(papers, dict)
        assert len(papers) > 0
        paper = next(iter(papers.values()))
        assert paper["Title"] == "Machine Learning Basics"
        assert paper["Year"] == 2023

    @patch("requests.get")
    def test_single_paper_rec_basic(self, mock_get):
        """Tests basic single paper recommendation functionality"""
        mock_get.return_value.json.return_value = {
            "recommendedPapers": [MOCK_SEARCH_RESPONSE["data"][0]]
        }
        mock_get.return_value.status_code = 200

        result = get_single_paper_recommendations.invoke(
            input={
                "paper_id": "123",
                "limit": 1,
                "tool_call_id": "test123",
                "id": "test123",
            }
        )
        assert "papers" in result.update
        assert len(result.update["messages"]) == 1

    @patch("requests.get")
    def test_single_paper_rec_with_optional_params(self, mock_get):
        """Tests single paper recommendations with year parameter"""
        mock_get.return_value.json.return_value = {
            "recommendedPapers": [MOCK_SEARCH_RESPONSE["data"][0]]
        }
        mock_get.return_value.status_code = 200

        result = get_single_paper_recommendations.invoke(
            input={
                "paper_id": "123",
                "limit": 1,
                "year": "2023-",
                "tool_call_id": "test123",
                "id": "test123",
            }
        )
        # NOTE(review): this only checks the result shape. Consider also
        # asserting that "year" is forwarded to the request (as the multi-paper
        # optional-params test does) once the exact requests.get call shape in
        # single_paper_rec is confirmed.
        assert "papers" in result.update

    @patch("requests.post")
    def test_multi_paper_rec_basic(self, mock_post):
        """Tests basic multi-paper recommendation functionality"""
        mock_post.return_value.json.return_value = {
            "recommendedPapers": [MOCK_SEARCH_RESPONSE["data"][0]]
        }
        mock_post.return_value.status_code = 200

        result = get_multi_paper_recommendations.invoke(
            input={
                "paper_ids": ["123", "456"],
                "limit": 1,
                "tool_call_id": "test123",
                "id": "test123",
            }
        )
        assert "papers" in result.update
        assert len(result.update["messages"]) == 1

    @patch("requests.post")
    def test_multi_paper_rec_with_optional_params(self, mock_post):
        """Tests multi-paper recommendations with all optional parameters"""
        mock_post.return_value.json.return_value = {
            "recommendedPapers": [MOCK_SEARCH_RESPONSE["data"][0]]
        }
        mock_post.return_value.status_code = 200

        result = get_multi_paper_recommendations.invoke(
            input={
                "paper_ids": ["123", "456"],
                "limit": 1,
                "year": "2023-",
                "tool_call_id": "test123",
                "id": "test123",
            }
        )
        assert "papers" in result.update
        assert len(result.update["messages"]) == 1
        # Without this, the test would still pass if "year" were silently
        # dropped: the mocked response does not depend on the request.
        assert mock_post.call_args.kwargs["params"]["year"] == "2023-"

    @patch("requests.get")
    def test_single_paper_rec_empty_response(self, mock_get):
        """Tests single paper recommendations with empty response"""
        mock_get.return_value.json.return_value = {"recommendedPapers": []}
        mock_get.return_value.status_code = 200

        result = get_single_paper_recommendations.invoke(
            input={
                "paper_id": "123",
                "limit": 1,
                "tool_call_id": "test123",
                "id": "test123",
            }
        )
        assert "papers" in result.update
        assert len(result.update["papers"]) == 0

    @patch("requests.post")
    def test_multi_paper_rec_empty_response(self, mock_post):
        """Tests multi-paper recommendations with empty response"""
        mock_post.return_value.json.return_value = {"recommendedPapers": []}
        mock_post.return_value.status_code = 200

        result = get_multi_paper_recommendations.invoke(
            input={
                "paper_ids": ["123", "456"],
                "limit": 1,
                "tool_call_id": "test123",
                "id": "test123",
            }
        )
        assert "papers" in result.update
        assert len(result.update["papers"]) == 0
235
+
236
+
237
def test_state_replace_dict():
    """replace_dict discards the existing mapping and keeps only the new one."""
    previous = {"key1": "value1", "key2": "value2"}
    replacement = {"key3": "value3"}

    outcome = replace_dict(previous, replacement)

    assert outcome == replacement
    assert isinstance(outcome, dict)
244
+
245
+
246
@pytest.mark.integration
def test_end_to_end_search_workflow(initial_state):
    """Integration test: run the compiled app on a search request end to end."""
    run_config = {
        "configurable": {
            "thread_id": "test_integration",
            "checkpoint_ns": "test",
            "checkpoint_id": "test123",
        }
    }
    user_state = {
        **initial_state,
        "messages": [HumanMessage(content="search for ML papers")],
    }

    # NOTE(review): patch("langchain_openai.ChatOpenAI") only replaces the
    # attribute on the langchain_openai module. If main_agent bound the class
    # earlier via `from langchain_openai import ChatOpenAI`, this patch does not
    # reach it; confirm whether the app can still make real LLM calls here.
    with patch("requests.get") as fake_get, patch(
        "langchain_openai.ChatOpenAI"
    ) as fake_chat_cls:
        fake_get.return_value.status_code = 200
        fake_get.return_value.json.return_value = MOCK_SEARCH_RESPONSE

        stub_llm = Mock()
        stub_llm.invoke.return_value = AIMessage(content="Search completed")
        fake_chat_cls.return_value = stub_llm

        app = get_app("test_integration")
        final_state = app.invoke(user_state, run_config)

        assert "papers" in final_state
        assert len(final_state["messages"]) > 0
@@ -0,0 +1,7 @@
1
+ #!/usr/bin/env python3
2
+
3
+ '''
4
+ Import statements
5
+ '''
6
+
7
+ from . import s2
@@ -0,0 +1,8 @@
1
+ '''
2
+ This file is used to import all the modules in the package.
3
+ '''
4
+
5
+ from . import display_results
6
+ from . import multi_paper_rec
7
+ from . import search
8
+ from . import single_paper_rec
@@ -0,0 +1,25 @@
1
+ #!/usr/bin/env python3
2
+
3
+ '''
4
+ This tool is used to display the table of studies.
5
+ '''
6
+
7
+ import logging
8
+ from typing import Annotated
9
+ from langchain_core.tools import tool
10
+ from langgraph.prebuilt import InjectedState
11
+
12
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)

# Configure logging.
# NOTE(review): basicConfig() at import time configures the process-wide root
# logger as a side effect of importing this tool module (it is a no-op if the
# root logger already has handlers). Consider leaving this to the application.
logging.basicConfig(level=logging.INFO)
15
+
16
@tool('display_results')
def display_results(state: Annotated[dict, InjectedState]):
    """
    Display the papers in the state.

    Args:
        state (dict): The state of the agent.

    Returns:
        The value stored under the "papers" key of the state.
    """
    logger.info("Displaying papers from the state")
    stored_papers = state["papers"]
    return stored_papers
@@ -0,0 +1,132 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ multi_paper_rec: Tool for getting recommendations
5
+ based on multiple papers
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ from typing import Annotated, Any, Dict, List, Optional
11
+
12
+ import pandas as pd
13
+ import requests
14
+ from langchain_core.messages import ToolMessage
15
+ from langchain_core.tools import tool
16
+ from langchain_core.tools.base import InjectedToolCallId
17
+ from langgraph.types import Command
18
+ from pydantic import BaseModel, Field
19
+
20
+
21
class MultiPaperRecInput(BaseModel):
    """Input schema for multiple paper recommendations tool."""

    # Semantic Scholar IDs of the seed papers (required; no default).
    # NOTE(review): no minimum length is enforced, so an empty list validates.
    paper_ids: List[str] = Field(
        description=("List of Semantic Scholar Paper IDs to get recommendations for")
    )
    # Total number of recommendations requested; validated to the range 1..500.
    limit: int = Field(
        default=2,
        description="Maximum total number of recommendations to return",
        ge=1,
        le=500,
    )
    # Optional year filter; the accepted formats are spelled out in the description.
    year: Optional[str] = Field(
        default=None,
        description="Year range in format: YYYY for specific year, "
        "YYYY- for papers after year, -YYYY for papers before year, or YYYY:YYYY for range",
    )
    # Filled in by LangChain from the originating tool call, not by the model.
    tool_call_id: Annotated[str, InjectedToolCallId]

    # Lets pydantic accept field types it has no built-in validator for.
    model_config = {"arbitrary_types_allowed": True}
41
+
42
+
43
def _field_or_na(paper: Dict[str, Any], key: str) -> Any:
    """Return ``paper[key]``, substituting ``"N/A"`` when it is missing or null.

    ``paper.get(key, "N/A")`` is not enough on its own: a key that is present
    with a JSON ``null`` value (e.g. a paper without an abstract) comes back as
    ``None``. Falsy-but-valid values such as a citation count of 0 are kept.
    """
    value = paper.get(key)
    return "N/A" if value is None else value


@tool(args_schema=MultiPaperRecInput)
def get_multi_paper_recommendations(
    paper_ids: List[str],
    tool_call_id: Annotated[str, InjectedToolCallId],
    limit: int = 2,
    year: Optional[str] = None,
) -> Command:
    """
    Get paper recommendations based on multiple papers.

    Args:
        paper_ids (List[str]): The list of paper IDs to base recommendations on.
        tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
        limit (int, optional): The maximum number of recommendations to return. Defaults to 2.
        year (str, optional): Year range for papers.
            Supports formats like "2024-", "-2024", "2024:2025". Defaults to None.

    Returns:
        Command: A state update holding the recommended papers under "papers"
        (keyed by paper ID) and a ToolMessage with the same papers as a
        markdown table.
    """
    logging.info("Starting multi-paper recommendations search.")

    endpoint = "https://api.semanticscholar.org/recommendations/v1/papers"
    headers = {"Content-Type": "application/json"}
    payload = {"positivePaperIds": paper_ids, "negativePaperIds": []}
    params = {
        # Clamped defensively; the input schema already caps it at 500.
        "limit": min(limit, 500),
        "fields": "paperId,title,abstract,year,authors,citationCount,url",
    }

    # Add year parameter if provided
    if year:
        params["year"] = year

    response = requests.post(
        endpoint,
        headers=headers,
        params=params,
        data=json.dumps(payload),
        timeout=10,
    )
    logging.info(
        "API Response Status for multi-paper recommendations: %s", response.status_code
    )
    if response.status_code != 200:
        # An error body presumably lacks "recommendedPapers", which would make the
        # result below silently empty; surface the failure in the logs.
        logging.warning(
            "Multi-paper recommendations request failed with status %s",
            response.status_code,
        )

    data = response.json()
    recommendations = data.get("recommendedPapers", [])

    # Keep only papers that can be identified and displayed.
    filtered_papers = {
        paper["paperId"]: {
            "Title": _field_or_na(paper, "title"),
            "Abstract": _field_or_na(paper, "abstract"),
            "Year": _field_or_na(paper, "year"),
            "Citation Count": _field_or_na(paper, "citationCount"),
            "URL": _field_or_na(paper, "url"),
        }
        for paper in recommendations
        if paper.get("title") and paper.get("paperId")
    }

    # One row per paper, indexed by paper ID.
    df = pd.DataFrame.from_dict(filtered_papers, orient="index")
    logging.info("Created DataFrame with results: %s", df)

    # Human-readable summaries, used for logging only.
    papers = [
        f"Paper ID: {paper_id}\n"
        f"Title: {paper_data['Title']}\n"
        f"Abstract: {paper_data['Abstract']}\n"
        f"Year: {paper_data['Year']}\n"
        f"Citations: {paper_data['Citation Count']}\n"
        f"URL: {paper_data['URL']}\n"
        for paper_id, paper_data in filtered_papers.items()
    ]

    markdown_table = df.to_markdown(tablefmt="grid")
    logging.info("Search results: %s", papers)

    return Command(
        update={
            "papers": filtered_papers,
            "messages": [
                ToolMessage(content=markdown_table, tool_call_id=tool_call_id)
            ],
        }
    )
@@ -0,0 +1,119 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ This tool is used to search for academic papers on Semantic Scholar.
5
+ """
6
+
7
+ import logging
8
+ from typing import Annotated, Any, Dict, Optional
9
+
10
+ import pandas as pd
11
+ import requests
12
+ from langchain_core.messages import ToolMessage
13
+ from langchain_core.tools import tool
14
+ from langchain_core.tools.base import InjectedToolCallId
15
+ from langgraph.types import Command
16
+ from pydantic import BaseModel, Field
17
+
18
+
19
class SearchInput(BaseModel):
    """Input schema for the search papers tool."""

    # Free-text query. The two literals are joined with an explicit space;
    # previously they ran together as "papers.Be specific" in the
    # model-facing description.
    query: str = Field(
        description="Search query string to find academic papers. "
        "Be specific and include relevant academic terms."
    )
    # Maximum number of results; validated to the range 1..100.
    limit: int = Field(
        default=2, description="Maximum number of results to return", ge=1, le=100
    )
    # Optional year filter; the accepted formats are spelled out in the description.
    year: Optional[str] = Field(
        default=None,
        description="Year range in format: YYYY for specific year, "
        "YYYY- for papers after year, -YYYY for papers before year, or YYYY:YYYY for range",
    )
    # Filled in by LangChain from the originating tool call, not by the model.
    tool_call_id: Annotated[str, InjectedToolCallId]
35
+
36
+
37
def _field_or_na(paper: Dict[str, Any], key: str) -> Any:
    """Return ``paper[key]``, substituting ``"N/A"`` when it is missing or null.

    ``paper.get(key, "N/A")`` is not enough on its own: a key that is present
    with a JSON ``null`` value (e.g. a paper without an abstract) comes back as
    ``None``. Falsy-but-valid values such as a citation count of 0 are kept.
    """
    value = paper.get(key)
    return "N/A" if value is None else value


@tool(args_schema=SearchInput)
def search_tool(
    query: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
    limit: int = 2,
    year: Optional[str] = None,
) -> Command:
    """
    Search for academic papers on Semantic Scholar.

    Args:
        query (str): The search query string to find academic papers.
        tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
        limit (int, optional): The maximum number of results to return. Defaults to 2.
        year (str, optional): Year range for papers.
            Supports formats like "2024-", "-2024", "2024:2025". Defaults to None.

    Returns:
        Command: A state update holding the found papers under "papers"
        (keyed by paper ID) and a ToolMessage with the same papers as a
        markdown table.
    """
    logging.info("Starting paper search...")
    endpoint = "https://api.semanticscholar.org/graph/v1/paper/search"
    params = {
        "query": query,
        # Clamped defensively; the input schema already caps it at 100.
        "limit": min(limit, 100),
        "fields": "paperId,title,abstract,year,authors,citationCount,url",
    }

    # Add year parameter if provided
    if year:
        params["year"] = year

    response = requests.get(endpoint, params=params, timeout=10)
    if response.status_code != 200:
        # An error body presumably lacks "data", which would make the result
        # below silently empty; surface the failure in the logs.
        logging.warning("Paper search failed with status %s", response.status_code)
    data = response.json()
    raw_papers = data.get("data", [])

    # Keep only papers that can be keyed (paperId) and displayed (title,
    # authors). paperId is checked explicitly because it is used as the key.
    filtered_papers = {
        paper["paperId"]: {
            "Title": _field_or_na(paper, "title"),
            "Abstract": _field_or_na(paper, "abstract"),
            "Year": _field_or_na(paper, "year"),
            "Citation Count": _field_or_na(paper, "citationCount"),
            "URL": _field_or_na(paper, "url"),
        }
        for paper in raw_papers
        if paper.get("paperId") and paper.get("title") and paper.get("authors")
    }

    # One row per paper, indexed by paper ID, matching multi_paper_rec. The
    # previous pd.DataFrame(filtered_papers) produced one *column* per paper.
    df = pd.DataFrame.from_dict(filtered_papers, orient="index")

    # Human-readable summaries, used for logging only.
    formatted_papers = [
        f"Paper ID: {paper_id}\n"
        f"Title: {paper_data['Title']}\n"
        f"Abstract: {paper_data['Abstract']}\n"
        f"Year: {paper_data['Year']}\n"
        f"Citations: {paper_data['Citation Count']}\n"
        f"URL: {paper_data['URL']}\n"
        for paper_id, paper_data in filtered_papers.items()
    ]

    markdown_table = df.to_markdown(tablefmt="grid")
    logging.info("Search results: %s", formatted_papers)

    return Command(
        update={
            "papers": filtered_papers,
            "messages": [
                ToolMessage(content=markdown_table, tool_call_id=tool_call_id)
            ],
        }
    )