academia-mcp 1.10.9__tar.gz → 1.11.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/PKG-INFO +1 -1
  2. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/server.py +41 -21
  3. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/arxiv_download.py +30 -33
  4. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/arxiv_search.py +43 -34
  5. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/bitflip.py +63 -60
  6. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/s2.py +50 -40
  7. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/show_image.py +21 -6
  8. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/visit_webpage.py +25 -14
  9. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/web_search.py +42 -35
  10. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/utils.py +2 -0
  11. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp.egg-info/PKG-INFO +1 -1
  12. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp.egg-info/SOURCES.txt +1 -0
  13. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/pyproject.toml +2 -1
  14. academia_mcp-1.11.1/tests/test_arxiv_download.py +35 -0
  15. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/tests/test_arxiv_search.py +12 -21
  16. academia_mcp-1.11.1/tests/test_bitflip.py +52 -0
  17. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/tests/test_document_qa.py +1 -1
  18. academia_mcp-1.11.1/tests/test_s2.py +42 -0
  19. academia_mcp-1.11.1/tests/test_server.py +81 -0
  20. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/tests/test_visit_webpage.py +16 -10
  21. academia_mcp-1.11.1/tests/test_web_search.py +59 -0
  22. academia_mcp-1.10.9/tests/test_arxiv_download.py +0 -25
  23. academia_mcp-1.10.9/tests/test_bitflip.py +0 -54
  24. academia_mcp-1.10.9/tests/test_s2.py +0 -44
  25. academia_mcp-1.10.9/tests/test_web_search.py +0 -55
  26. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/LICENSE +0 -0
  27. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/README.md +0 -0
  28. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/__init__.py +0 -0
  29. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/__main__.py +0 -0
  30. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/files.py +0 -0
  31. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/latex_templates/agents4science_2025/agents4science_2025.sty +0 -0
  32. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/latex_templates/agents4science_2025/agents4science_2025.tex +0 -0
  33. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/llm.py +0 -0
  34. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/pdf.py +0 -0
  35. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/py.typed +0 -0
  36. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/settings.py +0 -0
  37. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/__init__.py +0 -0
  38. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/anthology_search.py +0 -0
  39. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/document_qa.py +0 -0
  40. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/hf_datasets_search.py +0 -0
  41. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/latex.py +0 -0
  42. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/py.typed +0 -0
  43. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/review.py +0 -0
  44. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/speech_to_text.py +0 -0
  45. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/yt_transcript.py +0 -0
  46. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp.egg-info/dependency_links.txt +0 -0
  47. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp.egg-info/entry_points.txt +0 -0
  48. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp.egg-info/requires.txt +0 -0
  49. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp.egg-info/top_level.txt +0 -0
  50. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/setup.cfg +0 -0
  51. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/tests/test_anthology_search.py +0 -0
  52. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/tests/test_extract_json.py +0 -0
  53. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/tests/test_hf_dataset_search.py +0 -0
  54. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/tests/test_latex.py +0 -0
  55. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/tests/test_review.py +0 -0
  56. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/tests/test_show_image.py +0 -0
  57. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/tests/test_speech_to_text.py +0 -0
  58. {academia_mcp-1.10.9 → academia_mcp-1.11.1}/tests/test_yt_transcript.py +0 -0
{academia_mcp-1.10.9 → academia_mcp-1.11.1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: academia-mcp
-Version: 1.10.9
+Version: 1.11.1
 Summary: MCP server that provides different tools to search for scientific publications
 Author-email: Ilya Gusev <phoenixilya@gmail.com>
 Project-URL: Homepage, https://github.com/IlyaGusev/academia_mcp
{academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/server.py
@@ -63,35 +63,34 @@ def find_free_port() -> int:
     raise RuntimeError("No free port in range 5000-6000 found")
 
 
-def run(
-    host: str = "0.0.0.0",
-    port: Optional[int] = None,
-    mount_path: str = "/",
+def create_server(
     streamable_http_path: str = "/mcp",
-    transport: Literal["stdio", "sse", "streamable-http"] = "streamable-http",
+    mount_path: str = "/",
+    stateless_http: bool = True,
     disable_web_search_tools: bool = False,
     disable_llm_tools: bool = False,
-) -> None:
-    configure_uvicorn_style_logging()
+    port: Optional[int] = None,
+    host: str = "0.0.0.0",
+) -> FastMCP:
     server = FastMCP(
         "Academia MCP",
-        stateless_http=True,
+        stateless_http=stateless_http,
         streamable_http_path=streamable_http_path,
         mount_path=mount_path,
     )
     logger = logging.getLogger(__name__)
 
-    server.add_tool(arxiv_search)
-    server.add_tool(arxiv_download)
-    server.add_tool(s2_get_citations)
-    server.add_tool(s2_get_references)
+    server.add_tool(arxiv_search, structured_output=True)
+    server.add_tool(arxiv_download, structured_output=True)
+    server.add_tool(visit_webpage, structured_output=True)
+    server.add_tool(s2_get_citations, structured_output=True)
+    server.add_tool(s2_get_references, structured_output=True)
+    server.add_tool(s2_get_info, structured_output=True)
     server.add_tool(s2_corpus_id_from_arxiv_id)
-    server.add_tool(s2_get_info)
     server.add_tool(hf_datasets_search)
     server.add_tool(anthology_search)
     server.add_tool(get_latex_template)
     server.add_tool(get_latex_templates_list)
-    server.add_tool(visit_webpage)
     server.add_tool(show_image)
     server.add_tool(yt_transcript)
 
@@ -106,20 +105,20 @@ def run(
 
     if not disable_web_search_tools:
         if settings.TAVILY_API_KEY:
-            server.add_tool(tavily_web_search)
+            server.add_tool(tavily_web_search, structured_output=True)
         if settings.EXA_API_KEY:
-            server.add_tool(exa_web_search)
+            server.add_tool(exa_web_search, structured_output=True)
         if settings.BRAVE_API_KEY:
-            server.add_tool(brave_web_search)
+            server.add_tool(brave_web_search, structured_output=True)
         if settings.EXA_API_KEY or settings.BRAVE_API_KEY or settings.TAVILY_API_KEY:
-            server.add_tool(web_search)
+            server.add_tool(web_search, structured_output=True)
         else:
             logger.warning("No web search tools keys are set, web_search will not be available!")
 
     if not disable_llm_tools and settings.OPENROUTER_API_KEY:
-        server.add_tool(extract_bitflip_info)
-        server.add_tool(generate_research_proposals)
-        server.add_tool(score_research_proposals)
+        server.add_tool(extract_bitflip_info, structured_output=True)
+        server.add_tool(generate_research_proposals, structured_output=True)
+        server.add_tool(score_research_proposals, structured_output=True)
         server.add_tool(document_qa)
         server.add_tool(describe_image)
         if settings.WORKSPACE_DIR:
@@ -140,6 +139,27 @@ def run(
 
     server.settings.port = port
     server.settings.host = host
+    return server
+
+
+def run(
+    host: str = "0.0.0.0",
+    port: Optional[int] = None,
+    mount_path: str = "/",
+    streamable_http_path: str = "/mcp",
+    transport: Literal["stdio", "sse", "streamable-http"] = "streamable-http",
+    disable_web_search_tools: bool = False,
+    disable_llm_tools: bool = False,
+) -> None:
+    configure_uvicorn_style_logging()
+    server = create_server(
+        streamable_http_path=streamable_http_path,
+        mount_path=mount_path,
+        disable_web_search_tools=disable_web_search_tools,
+        disable_llm_tools=disable_llm_tools,
+        port=port,
+        host=host,
+    )
 
     if transport == "streamable-http":
         # Enable CORS for browser-based clients
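Note on the server.py change: run() previously built the FastMCP instance and started the transport in one step; 1.11.1 splits construction into create_server(), so the fully configured server can be obtained without binding a port or starting a transport (presumably what the new tests/test_server.py relies on). A minimal usage sketch, assuming only the names visible in this diff; the flag values are illustrative:

from academia_mcp.server import create_server

# Build the configured FastMCP instance without starting any transport.
# Disabling both optional tool groups avoids requiring TAVILY/EXA/BRAVE
# web-search keys or an OPENROUTER_API_KEY.
server = create_server(
    disable_web_search_tools=True,
    disable_llm_tools=True,
)
# run() now delegates to create_server() and then starts the chosen transport.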
{academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/arxiv_download.py
@@ -3,19 +3,17 @@
 # https://github.com/bytedance/pasa/blob/main/utils.py
 
 import re
-import json
 import tempfile
 from pathlib import Path
-from typing import Any, List, Optional, Dict
-from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
 
-import requests
 import bs4
+import requests
 from markdownify import MarkdownConverter  # type: ignore
+from pydantic import BaseModel, Field
 
+from academia_mcp.pdf import download_pdf, parse_pdf_file
 from academia_mcp.utils import get_with_retries
-from academia_mcp.pdf import parse_pdf_file, download_pdf
-
 
 HTML_URL = "https://arxiv.org/html/{paper_id}"
 ABS_URL = "https://arxiv.org/abs/{paper_id}"
@@ -28,12 +26,24 @@ SECTION_STOP_WORDS = (
 )
 
 
-@dataclass
-class TOCEntry:
+class DownloadResponse(BaseModel):  # type: ignore
+    title: str = Field(description="Title of the paper")
+    abstract: str = Field(description="Abstract of the paper")
+    toc: str = Field(description="Table of Contents", default="")
+    sections: Optional[List[str]] = Field(description="Sections of the paper", default=None)
+    references: Optional[List[Dict[str, Any]]] = Field(
+        description="Parsed references from the paper", default=None
+    )
+    original_format: str = Field(
+        description="Original format of the paper (pdf or html)", default="html"
+    )
+
+
+class TOCEntry(BaseModel):  # type: ignore
     level: int
     title: str
     html_id: Optional[str] = None
-    subsections: List["TOCEntry"] = field(default_factory=list)
+    subsections: List["TOCEntry"] = Field(default_factory=list)
 
     def linearize(self) -> List["TOCEntry"]:
         entries = [self]
@@ -196,7 +206,7 @@ def _parse_citation_metadata(metas: List[str]) -> Dict[str, Any]:
     return result
 
 
-def _extract_citations(soup_biblist: bs4.element.Tag) -> List[Dict[str, Any]]:
+def _extract_references(soup_biblist: bs4.element.Tag) -> List[Dict[str, Any]]:
     extracted = []
     for li in soup_biblist.find_all("li", recursive=False):
         metas = [x.text.strip() for x in li.find_all("span", class_="ltx_bibblock")]
@@ -214,17 +224,17 @@ def _parse_html(paper_id: str) -> Dict[str, Any]:
     article = soup.article
     assert article and isinstance(article, bs4.element.Tag)
 
-    citations = []
+    references = []
     biblist_tag = article.find(class_="ltx_biblist")
     if biblist_tag and isinstance(biblist_tag, bs4.element.Tag):
-        citations = _extract_citations(biblist_tag)
+        references = _extract_references(biblist_tag)
 
     toc = _generate_toc(article)
     sections = _build_by_toc(toc, article, url)
     return {
         "toc": toc.to_str(),
         "sections": sections,
-        "citations": citations,
+        "references": references,
         "original_format": "html",
     }
 
@@ -255,36 +265,24 @@ def _parse_pdf(paper_id: str) -> Dict[str, Any]:
     return {
         "toc": "\n".join([f"Page {page_number}" for page_number in range(1, len(pages) + 1)]),
         "sections": pages,
-        "citations": [],
+        "references": [],
         "original_format": "pdf",
     }
 
 
 def arxiv_download(
     paper_id: str,
-    include_citations: Optional[bool] = False,
+    include_references: Optional[bool] = False,
     mode: Optional[str] = "html",
-) -> str:
+) -> DownloadResponse:
     """
     Downloads a paper from Arxiv and converts it to text.
     Use mode = "html" by default.
     Fall back to mode = "pdf" if there are any problems with the HTML version.
 
-    Returns a JSON with a following structure:
-    {
-        "title": "...",
-        "abstract": "...",
-        "toc": "...",
-        "sections": ["...", ...],
-        "citations": [...]
-    }
-    Use `json.loads` to deserialize the result if you want to get specific fields.
-    For example, `abstract = json.loads(arxiv_download("2409.06820v1"))`
-    The "toc" key contains Table of Contents, that sometimes has indexing for sections.
-
     Args:
         paper_id: ID of the paper on Arxiv. For instance: 2409.06820v1
-        include_citations: include "citations" in the result or not. False by default.
+        include_references: include "references" in the result or not. False by default.
        mode: Which version of paper to use. Options: ["html", "pdf"]. "html" by default.
     """
@@ -297,7 +295,6 @@ def arxiv_download(
     else:
         content = _parse_pdf(paper_id)
 
-    if not include_citations and "citations" in content:
-        content.pop("citations")
-
-    return json.dumps({**abs_meta, **content}, ensure_ascii=False)
+    if not include_references and "references" in content:
+        content.pop("references")
+    return DownloadResponse(**{**abs_meta, **content})
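Note on the arxiv_download change: the tool now returns a pydantic DownloadResponse instead of a JSON string, and "citations" is renamed to "references" throughout (including the include_references argument). A caller-side sketch of the migration, using only the fields declared in the DownloadResponse model above; the paper ID is the one from the old docstring example:

from academia_mcp.tools.arxiv_download import arxiv_download

# 1.10.9: abstract = json.loads(arxiv_download("2409.06820v1"))["abstract"]
# 1.11.1: fields are attributes on the returned model.
paper = arxiv_download("2409.06820v1", include_references=True)
print(paper.title)
print(paper.abstract)
if paper.references:  # stays None unless include_references=True
    print(f"{len(paper.references)} references parsed")
print(paper.model_dump_json())  # pydantic serialization, if a JSON string is still needed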
{academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/arxiv_search.py
@@ -2,12 +2,12 @@
 # https://github.com/jonatasgrosman/findpapers/blob/master/findpapers/searchers/arxiv_searcher.py
 # https://info.arxiv.org/help/api/user-manual.html
 
-import json
 import re
-from typing import Optional, List, Dict, Any, Union
-from datetime import datetime, date
+from datetime import date, datetime
+from typing import Any, Dict, List, Optional, Union
 
 import xmltodict
+from pydantic import BaseModel, Field
 
 from academia_mcp.utils import get_with_retries
 
@@ -17,6 +17,25 @@ SORT_BY_OPTIONS = ("relevance", "lastUpdatedDate", "submittedDate")
 SORT_ORDER_OPTIONS = ("ascending", "descending")
 
 
+class ArxivSearchEntry(BaseModel):  # type: ignore
+    id: str = Field(description="Paper ID")
+    title: str = Field(description="Paper title")
+    authors: str = Field(description="Authors of the paper")
+    published: str = Field(description="Published date of the paper")
+    updated: str = Field(description="Updated date of the paper")
+    categories: str = Field(description="Categories of the paper")
+    comment: str = Field(description="Comment of the paper")
+    index: int = Field(description="Index of the paper", default=0)
+    abstract: Optional[str] = Field(description="Abstract of the paper", default=None)
+
+
+class ArxivSearchResponse(BaseModel):  # type: ignore
+    total_count: int = Field(description="Total number of results")
+    returned_count: int = Field(description="Number of results returned")
+    offset: int = Field(description="Offset for pagination")
+    results: List[ArxivSearchEntry] = Field(description="Search entries")
+
+
 def _format_text_field(text: str) -> str:
     return " ".join([line.strip() for line in text.split() if line.strip()])
 
@@ -48,17 +67,17 @@ def _format_date(date: str) -> str:
     return dt.strftime("%B %d, %Y")
 
 
-def _clean_entry(entry: Dict[str, Any]) -> Dict[str, Any]:
-    return {
-        "id": entry["id"].split("/")[-1],
-        "title": _format_text_field(entry["title"]),
-        "authors": _format_authors(entry["author"]),
-        "abstract": _format_text_field(entry["summary"]),
-        "published": _format_date(entry["published"]),
-        "updated": _format_date(entry["updated"]),
-        "categories": _format_categories(entry.get("category", {})),
-        "comment": _format_text_field(entry.get("arxiv:comment", {}).get("#text", "")),
-    }
+def _clean_entry(entry: Dict[str, Any]) -> ArxivSearchEntry:
+    return ArxivSearchEntry(
+        id=entry["id"].split("/")[-1],
+        title=_format_text_field(entry["title"]),
+        authors=_format_authors(entry["author"]),
+        abstract=_format_text_field(entry["summary"]),
+        published=_format_date(entry["published"]),
+        updated=_format_date(entry["updated"]),
+        categories=_format_categories(entry.get("category", {})),
+        comment=_format_text_field(entry.get("arxiv:comment", {}).get("#text", "")),
+    )
 
 
 def _convert_to_yyyymmddtttt(date_str: str) -> str:
@@ -105,22 +124,19 @@ def _format_entries(
     start_index: int,
     include_abstracts: bool,
     total_results: int,
-) -> str:
+) -> ArxivSearchResponse:
     clean_entries: List[Dict[str, Any]] = []
     for entry_num, entry in enumerate(entries):
         clean_entry = _clean_entry(entry)
         if not include_abstracts:
-            clean_entry.pop("abstract")
-        clean_entry["index"] = start_index + entry_num
+            clean_entry.abstract = None
+        clean_entry.index = start_index + entry_num
         clean_entries.append(clean_entry)
-    return json.dumps(
-        {
-            "total_count": total_results,
-            "returned_count": len(entries),
-            "offset": start_index,
-            "results": clean_entries,
-        },
-        ensure_ascii=False,
+    return ArxivSearchResponse(
+        total_count=total_results,
+        returned_count=len(entries),
+        offset=start_index,
+        results=clean_entries,
     )
 
 
@@ -133,7 +149,7 @@ def arxiv_search(
     sort_by: Optional[str] = "relevance",
     sort_order: Optional[str] = "descending",
     include_abstracts: Optional[bool] = False,
-) -> str:
+) -> ArxivSearchResponse:
     """
     Search arXiv papers with field-specific queries.
 
@@ -158,12 +174,6 @@ def arxiv_search(
         all:role OR all:playing OR all:"language model"
         (au:vaswani OR au:"del maestro") ANDNOT ti:attention
 
-    Returns a JSON object serialized to a string. The structure is:
-    {"total_count": ..., "returned_count": ..., "offset": ..., "results": [...]}
-    Every item in the "results" has the following fields:
-    ("index", "id", "title", "authors", "abstract", "published", "updated", "categories", "comment")
-    Use `json.loads` to deserialize the result if you want to get specific fields.
-
     Args:
         query: The search query, required.
         offset: The offset to scroll search results. 10 items will be skipped if offset=10. 0 by default.
@@ -211,10 +221,9 @@ def arxiv_search(
     entries = feed.get("entry", [])
     if isinstance(entries, dict):
         entries = [entries]
-    formatted_entries: str = _format_entries(
+    return _format_entries(
         entries,
         start_index=start_index,
         total_results=total_results,
         include_abstracts=include_abstracts,
     )
-    return formatted_entries
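Note on the arxiv_search change: the tool likewise returns an ArxivSearchResponse model rather than a serialized JSON object, which is what allows server.py to register it with structured_output=True. A caller sketch restricted to the parameters and fields visible in this diff:

from academia_mcp.tools.arxiv_search import arxiv_search

response = arxiv_search('all:"language model"', include_abstracts=True)
print(f"{response.returned_count} of {response.total_count} results, offset={response.offset}")
for entry in response.results:
    # entry.abstract is None when include_abstracts=False;
    # entry.index reflects the pagination offset.
    print(entry.index, entry.id, entry.title)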
{academia_mcp-1.10.9 → academia_mcp-1.11.1}/academia_mcp/tools/bitflip.py
@@ -1,17 +1,18 @@
+# Based on
 # https://arxiv.org/abs/2504.12976
 # https://web.stanford.edu/class/cs197c/slides/02-literature-search.pdf
 
 import json
 import random
-from typing import List, Optional, Any, Dict
+from typing import Any, Dict, List, Optional
 
-from pydantic import BaseModel
 from datasets import load_dataset  # type: ignore
+from pydantic import BaseModel, Field
 
-from academia_mcp.tools.arxiv_download import arxiv_download
-from academia_mcp.utils import extract_json, encode_prompt
-from academia_mcp.llm import llm_acall, ChatMessage
+from academia_mcp.llm import ChatMessage, llm_acall
 from academia_mcp.settings import settings
+from academia_mcp.tools.arxiv_download import arxiv_download
+from academia_mcp.utils import encode_prompt, extract_json
 
 
 class ProposalDataset:
@@ -128,7 +129,7 @@ Return only the JSON list of proposals in this exact format:
         "spark": "4-6 word summary",
         "abstract": "An abstract that summarizes the proposal in conference format (approximately 250 words).",
         "experiments": ["...", "..."],
-        "risks_and_limitations": "A list of potential risks and limitations of the proposal."
+        "risks_and_limitations": ["...", "..."]
     },
     ...
 ]
@@ -177,12 +178,12 @@ Return only scores for all proposals in this exact format (no extra text):
 
 
 class BitFlipInfo(BaseModel):  # type: ignore
-    bit: str
-    flip: str
-    spark: str
+    bit: str = Field(description="Technical limitation or conventional approach")
+    flip: str = Field(description="Innovative approach or solution")
+    spark: str = Field(description="4-6 word summary")
 
 
-async def extract_bitflip_info(arxiv_id: str) -> str:
+async def extract_bitflip_info(arxiv_id: str) -> BitFlipInfo:
     """
     Extracts the Bit-Flip information from the arXiv paper.
 
@@ -190,20 +191,12 @@ async def extract_bitflip_info(arxiv_id: str) -> str:
     questioning existing constraints or reapplying techniques to new domains/scales.
     The "Bit" is the prevailing belief, and the "Flip" is the counterargument.
 
-    Returns a JSON object in this format:
-    {
-        "bit": "Technical limitation or conventional approach, in at least two sentences",
-        "flip": "Innovative approach or solution, in at least two sentences",
-        "spark": "4-6 word summary of the core idea"
-    }
-    Use `json.loads` to deserialize the result if you want to get specific fields.
-
     Args:
         arxiv_id: The arXiv ID of the paper to extract the Bit-Flip information from.
     """
     model_name = settings.BITFLIP_MODEL_NAME
     paper = arxiv_download(arxiv_id)
-    abstract = json.loads(paper)["abstract"]
+    abstract = paper.abstract
     prompt = encode_prompt(EXTRACT_PROMPT, abstract=abstract)
     content = await llm_acall(
         model_name=model_name,
@@ -212,12 +205,31 @@ async def extract_bitflip_info(arxiv_id: str) -> str:
     )
     result = extract_json(content)
     bitflip_info: BitFlipInfo = BitFlipInfo.model_validate(result)
-    return str(bitflip_info.model_dump_json())
+    return bitflip_info
+
+
+class ResearchProposal(BaseModel):  # type: ignore
+    proposal_id: int = Field(default=0, description="ID of the proposal")
+    flip: str = Field(description="Innovative approach or solution, in at least two sentences")
+    spark: str = Field(description="4-6 word summary")
+    abstract: str = Field(
+        description="An abstract that summarizes the proposal in conference format."
+    )
+    experiments: List[str] = Field(
+        description="A list of experiments that would be conducted to validate the proposal."
+    )
+    risks_and_limitations: List[str] = Field(
+        description="A list of potential risks and limitations of the proposal."
+    )
+
+
+class GenerateResearchProposalResponse(BaseModel):  # type: ignore
+    proposals: List[ResearchProposal] = Field(description="A list of research proposals")
 
 
 async def generate_research_proposals(
     bit: str, num_proposals: int = 3, additional_context: str = ""
-) -> str:
+) -> GenerateResearchProposalResponse:
     """
     Proposes improvement ideas for the Bit.
 
@@ -225,20 +237,6 @@ async def generate_research_proposals(
         bit: The Bit to propose improvement ideas for. The bit is a technical limitation or conventional approach of some paper.
         num_proposals: The number of proposals to generate.
         additional_context: Additional context to use when proposing the improvement idea.
-
-    Returns a JSON string with a research proposal in this format:
-    [
-        {
-            "proposal_id": ...,
-            "flip": "Innovative approach or solution, in at least two sentences",
-            "spark": "4-6 word summary",
-            "abstract": "An abstract that summarizes the proposal in conference format (approximately 250 words).",
-            "experiments": ["...", "..."],
-            "risks_and_limitations": "A list of potential risks and limitations of the proposal."
-        },
-        ...
-    ]
-    Use `json.loads` to deserialize the result if you want to get specific items.
     """
     model_name = settings.BITFLIP_MODEL_NAME
     max_completion_tokens = int(settings.BITFLIP_MAX_COMPLETION_TOKENS)
@@ -262,46 +260,51 @@
         temperature=1.0,
     )
     result = extract_json(content)
-    for proposal in result:
-        proposal["proposal_id"] = random.randint(0, 1000000)
-    return json.dumps(result, ensure_ascii=False)
+    return GenerateResearchProposalResponse(
+        proposals=[ResearchProposal.model_validate(proposal) for proposal in result]
+    )
+
+
+class ScoredProposal(BaseModel):  # type: ignore
+    proposal_id: int = Field(default=0, description="ID of the proposal")
+    spark: str = Field(description="4-6 word summary")
+    strengths: List[str] = Field(description="A list of strengths of the proposal")
+    weaknesses: List[str] = Field(description="A list of weaknesses of the proposal")
+    novelty: int = Field(description="Novelty rating from 1 to 4")
+    clarity: int = Field(description="Clarity rating from 1 to 4")
+    significance: int = Field(description="Significance rating from 1 to 4")
+    feasibility: int = Field(description="Feasibility rating from 1 to 4")
+    soundness: int = Field(description="Soundness rating from 1 to 4")
+    overall: int = Field(description="Overall rating from 1 to 10")
 
 
-async def score_research_proposals(proposals: str | List[str | Dict[str, Any] | Any]) -> str:
+class ScoreResearchProposalsResponse(BaseModel):  # type: ignore
+    proposals: List[ScoredProposal] = Field(description="List of scored proposals")
+
+
+async def score_research_proposals(
+    proposals: str | List[str | Dict[str, Any] | Any],
+) -> ScoreResearchProposalsResponse:
     """
     Scores a list of research proposals.
     Use proposals obtained with the `generate_research_proposal` tool.
 
-    Returns a JSON string with a list of scores in this format:
-    [
-        {
-            "proposal_id": 0,
-            "spark": "...",
-            "strengths": ["...", "..."],
-            "weaknesses": ["...", "..."],
-            "novelty": 2,
-            "clarity": 2,
-            "significance": 2,
-            "feasibility": 2,
-            "soundness": 2,
-            "overall": 5
-        },
-        ...
-    ]
-    Use `json.loads` to deserialize the result if you want to get specific fields.
-
     Args:
         proposals: A list of JSON strings with research proposals.
     """
     model_name = settings.BITFLIP_MODEL_NAME
     if isinstance(proposals, str):
         proposals = json.loads(proposals)
-    assert isinstance(proposals, list), "Proposals should be a list of JSON strings"
-    prompt = encode_prompt(SCORE_PROMPT, proposals=[str(p) for p in proposals])
+    assert isinstance(proposals, list), "Proposals should be a list"
+    if isinstance(proposals, list):
+        proposals = [str(p) for p in proposals]
+    prompt = encode_prompt(SCORE_PROMPT, proposals=proposals)
     content = await llm_acall(
         model_name=model_name,
         messages=[ChatMessage(role="user", content=prompt)],
         temperature=0.0,
     )
     scores = extract_json(content)
-    return json.dumps(scores, ensure_ascii=False)
+    return ScoreResearchProposalsResponse(
+        proposals=[ScoredProposal.model_validate(score) for score in scores]
    )
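Note on the bitflip changes: all three tools now return pydantic models (BitFlipInfo, GenerateResearchProposalResponse, ScoreResearchProposalsResponse), so they chain without json.loads/json.dumps round-trips, and proposal IDs are no longer randomized. An end-to-end sketch, assuming the import path follows the academia_mcp.tools.arxiv_download pattern above and an OPENROUTER_API_KEY is set; the arXiv ID is illustrative:

import asyncio

from academia_mcp.tools.bitflip import (
    extract_bitflip_info,
    generate_research_proposals,
    score_research_proposals,
)

async def main() -> None:
    info = await extract_bitflip_info("2409.06820v1")  # -> BitFlipInfo
    generated = await generate_research_proposals(info.bit, num_proposals=3)
    # score_research_proposals accepts JSON strings or dict-like items.
    scored = await score_research_proposals([p.model_dump() for p in generated.proposals])
    for proposal in scored.proposals:
        print(proposal.spark, proposal.overall)

asyncio.run(main())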