pearmut 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
pearmut/app.py CHANGED
@@ -4,7 +4,7 @@ from typing import Any
4
4
 
5
5
  from fastapi import FastAPI, Query
6
6
  from fastapi.middleware.cors import CORSMiddleware
7
- from fastapi.responses import JSONResponse, Response
7
+ from fastapi.responses import FileResponse, JSONResponse, Response
8
8
  from fastapi.staticfiles import StaticFiles
9
9
  from pydantic import BaseModel
10
10
 
@@ -17,6 +17,7 @@ from .results_export import (
17
17
  )
18
18
  from .utils import (
19
19
  ROOT,
20
+ TOKEN_MAIN,
20
21
  check_validation_threshold,
21
22
  load_progress_data,
22
23
  save_db_payload,
@@ -48,7 +49,7 @@ for campaign_id in progress_data.keys():
48
49
  class LogResponseRequest(BaseModel):
49
50
  campaign_id: str
50
51
  user_id: str
51
- item_i: int
52
+ item_i: int | str
52
53
  payload: dict[str, Any]
53
54
 
54
55
 
@@ -123,7 +124,7 @@ async def _get_next_item(request: NextItemRequest):
123
124
  class GetItemRequest(BaseModel):
124
125
  campaign_id: str
125
126
  user_id: str
126
- item_i: int
127
+ item_i: int | str
127
128
 
128
129
 
129
130
  @app.post("/get-i-item")
@@ -178,7 +179,11 @@ async def _dashboard_data(request: DashboardDataRequest):
178
179
  ]
179
180
 
180
181
  # Add threshold pass/fail status (only when user is complete)
181
- if all(entry["progress"]):
182
+ if (
183
+ tasks_data[campaign_id]["info"]["assignment"] != "dynamic" and all(v in {"completed", "completed_foreign"} for v in entry["progress"])
184
+ ) or (
185
+ tasks_data[campaign_id]["info"]["assignment"] == "dynamic" and all(v in {"completed", "completed_foreign"} for mv in entry["progress"] for v in mv.values())
186
+ ):
182
187
  entry["threshold_passed"] = check_validation_threshold(
183
188
  tasks_data, progress_data, campaign_id, user_id
184
189
  )
@@ -192,7 +197,11 @@ async def _dashboard_data(request: DashboardDataRequest):
192
197
  progress_new[user_id] = entry
193
198
 
194
199
  return JSONResponse(
195
- content={"data": progress_new, "validation_threshold": validation_threshold},
200
+ content={
201
+ "data": progress_new,
202
+ "validation_threshold": validation_threshold,
203
+ "assignment": assignment,
204
+ },
196
205
  status_code=200,
197
206
  )
198
207
 
@@ -288,7 +297,7 @@ class PurgeCampaignRequest(BaseModel):
288
297
  @app.post("/purge-campaign")
289
298
  async def _purge_campaign(request: PurgeCampaignRequest):
290
299
  global progress_data, tasks_data
291
-
300
+
292
301
  campaign_id = request.campaign_id
293
302
  token = request.token
294
303
 
@@ -298,57 +307,69 @@ async def _purge_campaign(request: PurgeCampaignRequest):
298
307
  return JSONResponse(content="Invalid token", status_code=400)
299
308
 
300
309
  # Unlink assets if they exist
301
- destination = tasks_data[campaign_id].get("info", {}).get("assets", {}).get("destination")
310
+ destination = (
311
+ tasks_data[campaign_id].get("info", {}).get("assets", {}).get("destination")
312
+ )
302
313
  if destination:
303
314
  symlink_path = f"{ROOT}/data/{destination}".rstrip("/")
304
315
  if os.path.islink(symlink_path):
305
316
  os.remove(symlink_path)
306
-
317
+
307
318
  # Remove task file
308
319
  task_file = f"{ROOT}/data/tasks/{campaign_id}.json"
309
320
  if os.path.exists(task_file):
310
321
  os.remove(task_file)
311
-
322
+
312
323
  # Remove output file
313
324
  output_file = f"{ROOT}/data/outputs/{campaign_id}.jsonl"
314
325
  if os.path.exists(output_file):
315
326
  os.remove(output_file)
316
-
327
+
317
328
  # Remove from in-memory data structures
318
329
  del tasks_data[campaign_id]
319
330
  del progress_data[campaign_id]
320
-
331
+
321
332
  # Save updated progress data
322
333
  save_progress_data(progress_data)
323
-
334
+
324
335
  return JSONResponse(content="ok", status_code=200)
325
336
 
326
337
 
327
338
  class AddCampaignRequest(BaseModel):
328
339
  campaign_data: dict[str, Any]
340
+ token_main: str
329
341
 
330
342
 
331
343
  @app.post("/add-campaign")
332
344
  async def _add_campaign(request: AddCampaignRequest):
333
345
  global progress_data, tasks_data
334
-
346
+
335
347
  from .cli import _add_single_campaign
336
-
348
+
349
+ if request.token_main != TOKEN_MAIN:
350
+ return JSONResponse(
351
+ content={"error": "Invalid main token. Use the latest one."},
352
+ status_code=400,
353
+ )
354
+
337
355
  try:
338
356
  server = f"{os.environ.get('PEARMUT_SERVER_URL', 'http://localhost:8001')}"
339
357
  _add_single_campaign(request.campaign_data, overwrite=False, server=server)
340
-
341
- campaign_id = request.campaign_data['campaign_id']
358
+
359
+ campaign_id = request.campaign_data["campaign_id"]
342
360
  with open(f"{ROOT}/data/tasks/{campaign_id}.json", "r") as f:
343
361
  tasks_data[campaign_id] = json.load(f)
344
-
362
+
345
363
  progress_data = load_progress_data(warn=None)
346
-
347
- return JSONResponse(content={
348
- "status": "ok",
349
- "campaign_id": campaign_id,
350
- "token": tasks_data[campaign_id]["token"]
351
- }, status_code=200)
364
+
365
+ return JSONResponse(
366
+ content={
367
+ "status": "ok",
368
+ "campaign_id": campaign_id,
369
+ "token": tasks_data[campaign_id]["token"],
370
+ },
371
+ status_code=200,
372
+ )
352
373
  except Exception as e:
353
374
  return JSONResponse(content={"error": str(e)}, status_code=400)
354
375
 
@@ -359,7 +380,6 @@ async def _download_annotations(
359
380
  # NOTE: currently not checking tokens for progress download as it is non-destructive
360
381
  # token: list[str] = Query()
361
382
  ):
362
-
363
383
  output = {}
364
384
  for campaign_id in campaign_id:
365
385
  output_path = f"{ROOT}/data/outputs/{campaign_id}.jsonl"
@@ -386,7 +406,6 @@ async def _download_annotations(
386
406
  async def _download_progress(
387
407
  campaign_id: list[str] = Query(), token: list[str] = Query()
388
408
  ):
389
-
390
409
  if len(campaign_id) != len(token):
391
410
  return JSONResponse(
392
411
  content="Mismatched campaign_id and token count", status_code=400
@@ -418,6 +437,18 @@ if not os.path.exists(static_dir + "index.html"):
418
437
  "Static directory not found. Please build the frontend first."
419
438
  )
420
439
 
440
+
441
+ # Serve HTML files directly without redirect
442
+ @app.get("/annotate")
443
+ async def serve_annotate():
444
+ return FileResponse(static_dir + "annotate.html")
445
+
446
+
447
+ @app.get("/dashboard")
448
+ async def serve_dashboard():
449
+ return FileResponse(static_dir + "dashboard.html")
450
+
451
+
421
452
  # Mount user assets from data/assets/
422
453
  assets_dir = f"{ROOT}/data/assets"
423
454
  os.makedirs(assets_dir, exist_ok=True)