academic-refchecker 2.0.19__py3-none-any.whl → 2.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
backend/main.py CHANGED
@@ -4,6 +4,7 @@ FastAPI application for RefChecker Web UI
4
4
  import asyncio
5
5
  import uuid
6
6
  import os
7
+ import sys
7
8
  import tempfile
8
9
  from pathlib import Path
9
10
  from typing import Optional
@@ -15,6 +16,16 @@ from pydantic import BaseModel
15
16
  import logging
16
17
  from refchecker.__version__ import __version__
17
18
 
19
+ # Fix Windows encoding issues with Unicode characters (e.g., Greek letters in paper titles).
20
+ # Skip this when running under pytest so we don't replace pytest's capture streams, which can
21
+ # lead to closed-file errors during teardown.
22
+ if sys.platform == 'win32' and not os.environ.get("PYTEST_CURRENT_TEST"):
23
+ import io
24
+ if hasattr(sys.stdout, "buffer"):
25
+ sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
26
+ if hasattr(sys.stderr, "buffer"):
27
+ sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace')
28
+
18
29
  import aiosqlite
19
30
  from .database import db
20
31
  from .websocket_manager import manager
@@ -69,6 +80,20 @@ class CheckLabelUpdate(BaseModel):
69
80
  custom_label: str
70
81
 
71
82
 
83
+ class BatchLabelUpdate(BaseModel):
84
+ batch_label: str
85
+
86
+
87
+ class BatchUrlsRequest(BaseModel):
88
+ """Request model for batch URL submission"""
89
+ urls: list[str]
90
+ batch_label: Optional[str] = None
91
+ llm_config_id: Optional[int] = None
92
+ llm_provider: str = "anthropic"
93
+ llm_model: Optional[str] = None
94
+ use_llm: bool = True
95
+
96
+
72
97
  # Create FastAPI app
73
98
  app = FastAPI(title="RefChecker Web UI API", version="1.0.0")
74
99
 
@@ -206,6 +231,7 @@ async def start_check(
206
231
  # Handle file upload or pasted text
207
232
  paper_source = source_value
208
233
  paper_title = "Processing..." # Placeholder title until we parse the paper
234
+ original_filename = None # Only set for file uploads
209
235
  if source_type == "file" and file:
210
236
  # Save uploaded file to permanent uploads directory
211
237
  uploads_dir = Path(__file__).parent / "uploads"
@@ -218,6 +244,7 @@ async def start_check(
218
244
  f.write(content)
219
245
  paper_source = str(file_path)
220
246
  paper_title = file.filename
247
+ original_filename = file.filename # Store original filename
221
248
  elif source_type == "text":
222
249
  if not source_text:
223
250
  raise HTTPException(status_code=400, detail="No text provided")
@@ -244,7 +271,8 @@ async def start_check(
244
271
  paper_source=paper_source,
245
272
  source_type=source_type,
246
273
  llm_provider=llm_provider if use_llm else None,
247
- llm_model=llm_model if use_llm else None
274
+ llm_model=llm_model if use_llm else None,
275
+ original_filename=original_filename
248
276
  )
249
277
  logger.info(f"Created pending check with ID {check_id}")
250
278
 
@@ -358,6 +386,9 @@ async def run_check(
358
386
  except Exception as e:
359
387
  logger.warning(f"Failed to save bibliography source: {e}")
360
388
 
389
+ # Get Semantic Scholar API key from database settings
390
+ semantic_scholar_api_key = await db.get_setting("semantic_scholar_api_key")
391
+
361
392
  # Create checker with progress callback
362
393
  checker = ProgressRefChecker(
363
394
  llm_provider=llm_provider,
@@ -369,7 +400,8 @@ async def run_check(
369
400
  cancel_event=cancel_event,
370
401
  check_id=check_id,
371
402
  title_update_callback=title_update_callback,
372
- bibliography_source_callback=bibliography_source_callback
403
+ bibliography_source_callback=bibliography_source_callback,
404
+ semantic_scholar_api_key=semantic_scholar_api_key
373
405
  )
374
406
 
375
407
  # Run the check
@@ -854,7 +886,8 @@ async def recheck(check_id: int):
854
886
  paper_source=source,
855
887
  source_type=source_type,
856
888
  llm_provider=llm_provider,
857
- llm_model=llm_model
889
+ llm_model=llm_model,
890
+ original_filename=original.get("original_filename")
858
891
  )
859
892
 
860
893
  # Start check in background
@@ -900,6 +933,420 @@ async def cancel_check(session_id: str):
900
933
  return {"message": "Cancellation requested"}
901
934
 
902
935
 
936
+ # ============ Batch Operations ============
937
+
938
+ @app.post("/api/check/batch")
939
+ async def start_batch_check(request: BatchUrlsRequest):
940
+ """
941
+ Start a batch of reference checks from a list of URLs/ArXiv IDs.
942
+
943
+ Returns batch_id and list of individual check sessions.
944
+ """
945
+ try:
946
+ if not request.urls or len(request.urls) == 0:
947
+ raise HTTPException(status_code=400, detail="No URLs provided")
948
+
949
+ # Limit batch size to prevent abuse
950
+ MAX_BATCH_SIZE = 50
951
+ if len(request.urls) > MAX_BATCH_SIZE:
952
+ raise HTTPException(
953
+ status_code=400,
954
+ detail=f"Batch size exceeds maximum of {MAX_BATCH_SIZE} papers"
955
+ )
956
+
957
+ # Generate unique batch ID
958
+ batch_id = str(uuid.uuid4())
959
+ batch_label = request.batch_label or f"Batch of {len(request.urls)} papers"
960
+
961
+ # Get API key from config if provided
962
+ api_key = None
963
+ endpoint = None
964
+ llm_provider = request.llm_provider
965
+ llm_model = request.llm_model
966
+
967
+ if request.llm_config_id and request.use_llm:
968
+ config = await db.get_llm_config_by_id(request.llm_config_id)
969
+ if config:
970
+ api_key = config.get('api_key')
971
+ endpoint = config.get('endpoint')
972
+ llm_provider = config.get('provider', llm_provider)
973
+ llm_model = config.get('model') or llm_model
974
+
975
+ checks = []
976
+
977
+ for url in request.urls:
978
+ url = url.strip()
979
+ if not url:
980
+ continue
981
+
982
+ session_id = str(uuid.uuid4())
983
+
984
+ # Create pending check entry with batch info
985
+ check_id = await db.create_pending_check(
986
+ paper_title=url, # Will be updated during processing
987
+ paper_source=url,
988
+ source_type='url',
989
+ llm_provider=llm_provider if request.use_llm else None,
990
+ llm_model=llm_model if request.use_llm else None,
991
+ batch_id=batch_id,
992
+ batch_label=batch_label
993
+ )
994
+
995
+ # Start check in background
996
+ cancel_event = asyncio.Event()
997
+ task = asyncio.create_task(
998
+ run_check(
999
+ session_id, check_id, url, 'url',
1000
+ llm_provider, llm_model, api_key, endpoint,
1001
+ request.use_llm, cancel_event
1002
+ )
1003
+ )
1004
+ active_checks[session_id] = {
1005
+ "task": task,
1006
+ "cancel_event": cancel_event,
1007
+ "check_id": check_id,
1008
+ "batch_id": batch_id
1009
+ }
1010
+
1011
+ checks.append({
1012
+ "session_id": session_id,
1013
+ "check_id": check_id,
1014
+ "source": url
1015
+ })
1016
+
1017
+ logger.info(f"Started batch {batch_id} with {len(checks)} papers")
1018
+
1019
+ return {
1020
+ "batch_id": batch_id,
1021
+ "batch_label": batch_label,
1022
+ "total_papers": len(checks),
1023
+ "checks": checks
1024
+ }
1025
+
1026
+ except HTTPException:
1027
+ raise
1028
+ except Exception as e:
1029
+ logger.error(f"Error starting batch check: {e}", exc_info=True)
1030
+ raise HTTPException(status_code=500, detail=str(e))
1031
+
1032
+
1033
+ @app.post("/api/check/batch/files")
1034
+ async def start_batch_check_files(
1035
+ files: list[UploadFile] = File(...),
1036
+ batch_label: Optional[str] = Form(None),
1037
+ llm_config_id: Optional[int] = Form(None),
1038
+ llm_provider: str = Form("anthropic"),
1039
+ llm_model: Optional[str] = Form(None),
1040
+ use_llm: bool = Form(True)
1041
+ ):
1042
+ """
1043
+ Start a batch of reference checks from uploaded files.
1044
+
1045
+ Accepts multiple files or a single ZIP file containing documents.
1046
+ """
1047
+ try:
1048
+ if not files or len(files) == 0:
1049
+ raise HTTPException(status_code=400, detail="No files provided")
1050
+
1051
+ MAX_BATCH_SIZE = 50
1052
+
1053
+ # Generate unique batch ID
1054
+ batch_id = str(uuid.uuid4())
1055
+ uploads_dir = Path(__file__).parent / "uploads"
1056
+ uploads_dir.mkdir(parents=True, exist_ok=True)
1057
+
1058
+ # Get API key from config if provided
1059
+ api_key = None
1060
+ endpoint = None
1061
+
1062
+ if llm_config_id and use_llm:
1063
+ config = await db.get_llm_config_by_id(llm_config_id)
1064
+ if config:
1065
+ api_key = config.get('api_key')
1066
+ endpoint = config.get('endpoint')
1067
+ llm_provider = config.get('provider', llm_provider)
1068
+ llm_model = config.get('model') or llm_model
1069
+
1070
+ files_to_process = []
1071
+
1072
+ # Check if single ZIP file
1073
+ if len(files) == 1 and files[0].filename.lower().endswith('.zip'):
1074
+ import zipfile
1075
+ import io
1076
+
1077
+ zip_content = await files[0].read()
1078
+ with zipfile.ZipFile(io.BytesIO(zip_content), 'r') as zf:
1079
+ for name in zf.namelist():
1080
+ # Skip directories and hidden files
1081
+ if name.endswith('/') or name.startswith('__') or '/.' in name:
1082
+ continue
1083
+
1084
+ # Only process supported file types
1085
+ lower_name = name.lower()
1086
+ if not any(lower_name.endswith(ext) for ext in ['.pdf', '.txt', '.tex', '.bib', '.bbl']):
1087
+ continue
1088
+
1089
+ if len(files_to_process) >= MAX_BATCH_SIZE:
1090
+ break
1091
+
1092
+ # Extract file
1093
+ content = zf.read(name)
1094
+ filename = os.path.basename(name)
1095
+ file_path = uploads_dir / f"{batch_id}_{filename}"
1096
+ with open(file_path, 'wb') as f:
1097
+ f.write(content)
1098
+
1099
+ files_to_process.append({
1100
+ 'path': str(file_path),
1101
+ 'filename': filename
1102
+ })
1103
+ else:
1104
+ # Process individual files
1105
+ for file in files[:MAX_BATCH_SIZE]:
1106
+ safe_filename = file.filename.replace("/", "_").replace("\\", "_")
1107
+ file_path = uploads_dir / f"{batch_id}_{safe_filename}"
1108
+ content = await file.read()
1109
+ with open(file_path, "wb") as f:
1110
+ f.write(content)
1111
+
1112
+ files_to_process.append({
1113
+ 'path': str(file_path),
1114
+ 'filename': file.filename
1115
+ })
1116
+
1117
+ if not files_to_process:
1118
+ raise HTTPException(status_code=400, detail="No valid files found")
1119
+
1120
+ label = batch_label or f"Batch of {len(files_to_process)} files"
1121
+
1122
+ checks = []
1123
+ for file_info in files_to_process:
1124
+ session_id = str(uuid.uuid4())
1125
+
1126
+ check_id = await db.create_pending_check(
1127
+ paper_title=file_info['filename'],
1128
+ paper_source=file_info['path'],
1129
+ source_type='file',
1130
+ llm_provider=llm_provider if use_llm else None,
1131
+ llm_model=llm_model if use_llm else None,
1132
+ batch_id=batch_id,
1133
+ batch_label=label,
1134
+ original_filename=file_info['filename']
1135
+ )
1136
+
1137
+ cancel_event = asyncio.Event()
1138
+ task = asyncio.create_task(
1139
+ run_check(
1140
+ session_id, check_id, file_info['path'], 'file',
1141
+ llm_provider, llm_model, api_key, endpoint,
1142
+ use_llm, cancel_event
1143
+ )
1144
+ )
1145
+ active_checks[session_id] = {
1146
+ "task": task,
1147
+ "cancel_event": cancel_event,
1148
+ "check_id": check_id,
1149
+ "batch_id": batch_id
1150
+ }
1151
+
1152
+ checks.append({
1153
+ "session_id": session_id,
1154
+ "check_id": check_id,
1155
+ "source": file_info['filename']
1156
+ })
1157
+
1158
+ logger.info(f"Started file batch {batch_id} with {len(checks)} files")
1159
+
1160
+ return {
1161
+ "batch_id": batch_id,
1162
+ "batch_label": label,
1163
+ "total_papers": len(checks),
1164
+ "checks": checks
1165
+ }
1166
+
1167
+ except HTTPException:
1168
+ raise
1169
+ except Exception as e:
1170
+ logger.error(f"Error starting batch file check: {e}", exc_info=True)
1171
+ raise HTTPException(status_code=500, detail=str(e))
1172
+
1173
+
1174
+ @app.get("/api/batch/{batch_id}")
1175
+ async def get_batch(batch_id: str):
1176
+ """Get batch summary and all checks in the batch"""
1177
+ try:
1178
+ summary = await db.get_batch_summary(batch_id)
1179
+ if not summary:
1180
+ raise HTTPException(status_code=404, detail="Batch not found")
1181
+
1182
+ checks = await db.get_batch_checks(batch_id)
1183
+
1184
+ # Add session_id for in-progress checks
1185
+ for check in checks:
1186
+ if check.get("status") == "in_progress":
1187
+ session_id = _session_id_for_check(check["id"])
1188
+ if session_id:
1189
+ check["session_id"] = session_id
1190
+
1191
+ return {
1192
+ **summary,
1193
+ "checks": checks
1194
+ }
1195
+ except HTTPException:
1196
+ raise
1197
+ except Exception as e:
1198
+ logger.error(f"Error getting batch: {e}", exc_info=True)
1199
+ raise HTTPException(status_code=500, detail=str(e))
1200
+
1201
+
1202
+ @app.post("/api/cancel/batch/{batch_id}")
1203
+ async def cancel_batch(batch_id: str):
1204
+ """Cancel all active checks in a batch"""
1205
+ try:
1206
+ # Cancel active tasks
1207
+ cancelled_sessions = 0
1208
+ for session_id, meta in list(active_checks.items()):
1209
+ if meta.get("batch_id") == batch_id:
1210
+ meta["cancel_event"].set()
1211
+ meta["task"].cancel()
1212
+ cancelled_sessions += 1
1213
+
1214
+ # Update database status for any remaining in-progress
1215
+ db_cancelled = await db.cancel_batch(batch_id)
1216
+
1217
+ logger.info(f"Cancelled batch {batch_id}: {cancelled_sessions} active, {db_cancelled} in DB")
1218
+
1219
+ return {
1220
+ "message": "Batch cancellation requested",
1221
+ "batch_id": batch_id,
1222
+ "cancelled_active": cancelled_sessions,
1223
+ "cancelled_pending": db_cancelled
1224
+ }
1225
+ except Exception as e:
1226
+ logger.error(f"Error cancelling batch: {e}", exc_info=True)
1227
+ raise HTTPException(status_code=500, detail=str(e))
1228
+
1229
+
1230
+ @app.delete("/api/batch/{batch_id}")
1231
+ async def delete_batch(batch_id: str):
1232
+ """Delete all checks in a batch"""
1233
+ try:
1234
+ # First cancel any active checks
1235
+ for session_id, meta in list(active_checks.items()):
1236
+ if meta.get("batch_id") == batch_id:
1237
+ meta["cancel_event"].set()
1238
+ meta["task"].cancel()
1239
+ active_checks.pop(session_id, None)
1240
+
1241
+ # Delete from database
1242
+ deleted_count = await db.delete_batch(batch_id)
1243
+
1244
+ if deleted_count == 0:
1245
+ raise HTTPException(status_code=404, detail="Batch not found")
1246
+
1247
+ logger.info(f"Deleted batch {batch_id}: {deleted_count} checks")
1248
+
1249
+ return {
1250
+ "message": "Batch deleted successfully",
1251
+ "batch_id": batch_id,
1252
+ "deleted_count": deleted_count
1253
+ }
1254
+ except HTTPException:
1255
+ raise
1256
+ except Exception as e:
1257
+ logger.error(f"Error deleting batch: {e}", exc_info=True)
1258
+ raise HTTPException(status_code=500, detail=str(e))
1259
+
1260
+
1261
+ @app.patch("/api/batch/{batch_id}")
1262
+ async def update_batch_label(batch_id: str, update: BatchLabelUpdate):
1263
+ """Update the label for a batch"""
1264
+ try:
1265
+ success = await db.update_batch_label(batch_id, update.batch_label)
1266
+ if success:
1267
+ return {"message": "Batch label updated successfully"}
1268
+ else:
1269
+ raise HTTPException(status_code=404, detail="Batch not found")
1270
+ except HTTPException:
1271
+ raise
1272
+ except Exception as e:
1273
+ logger.error(f"Error updating batch label: {e}", exc_info=True)
1274
+ raise HTTPException(status_code=500, detail=str(e))
1275
+
1276
+
1277
+ @app.post("/api/recheck/batch/{batch_id}")
1278
+ async def recheck_batch(batch_id: str):
1279
+ """Re-run all checks in a batch"""
1280
+ try:
1281
+ # Get original batch checks
1282
+ original_checks = await db.get_batch_checks(batch_id)
1283
+ if not original_checks:
1284
+ raise HTTPException(status_code=404, detail="Batch not found")
1285
+
1286
+ # Create new batch
1287
+ new_batch_id = str(uuid.uuid4())
1288
+ original_label = original_checks[0].get("batch_label", "Re-checked batch")
1289
+ new_label = f"Re-check: {original_label}"
1290
+
1291
+ checks = []
1292
+ for original in original_checks:
1293
+ session_id = str(uuid.uuid4())
1294
+ source = original["paper_source"]
1295
+ source_type = original.get("source_type", "url")
1296
+ llm_provider = original.get("llm_provider", "anthropic")
1297
+ llm_model = original.get("llm_model")
1298
+
1299
+ check_id = await db.create_pending_check(
1300
+ paper_title=original.get("paper_title", "Re-checking..."),
1301
+ paper_source=source,
1302
+ source_type=source_type,
1303
+ llm_provider=llm_provider,
1304
+ llm_model=llm_model,
1305
+ batch_id=new_batch_id,
1306
+ batch_label=new_label,
1307
+ original_filename=original.get("original_filename")
1308
+ )
1309
+
1310
+ cancel_event = asyncio.Event()
1311
+ task = asyncio.create_task(
1312
+ run_check(
1313
+ session_id, check_id, source, source_type,
1314
+ llm_provider, llm_model, None, None, True, cancel_event
1315
+ )
1316
+ )
1317
+ active_checks[session_id] = {
1318
+ "task": task,
1319
+ "cancel_event": cancel_event,
1320
+ "check_id": check_id,
1321
+ "batch_id": new_batch_id
1322
+ }
1323
+
1324
+ checks.append({
1325
+ "session_id": session_id,
1326
+ "check_id": check_id,
1327
+ "original_id": original["id"],
1328
+ "source": source
1329
+ })
1330
+
1331
+ logger.info(f"Re-started batch {batch_id} as {new_batch_id} with {len(checks)} papers")
1332
+
1333
+ return {
1334
+ "batch_id": new_batch_id,
1335
+ "batch_label": new_label,
1336
+ "original_batch_id": batch_id,
1337
+ "total_papers": len(checks),
1338
+ "checks": checks
1339
+ }
1340
+ except HTTPException:
1341
+ raise
1342
+ except Exception as e:
1343
+ logger.error(f"Error rechecking batch: {e}", exc_info=True)
1344
+ raise HTTPException(status_code=500, detail=str(e))
1345
+
1346
+
1347
+ # ============ End Batch Operations ============
1348
+
1349
+
903
1350
  @app.delete("/api/history/{check_id}")
904
1351
  async def delete_check(check_id: int):
905
1352
  """Delete a check from history"""