pearmut 1.0.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pearmut/app.py CHANGED
@@ -280,6 +280,79 @@ async def _reset_task(request: ResetTaskRequest):
280
280
  return response
281
281
 
282
282
 
283
class PurgeCampaignRequest(BaseModel):
    """Request body accepted by the ``/purge-campaign`` endpoint."""

    # Identifier of the campaign that should be removed.
    campaign_id: str
    # Token checked against the one stored in the campaign's task data.
    token: str
286
+
287
+
288
@app.post("/purge-campaign")
async def _purge_campaign(request: PurgeCampaignRequest):
    """
    Remove a campaign: its asset symlink, task file, output file and
    in-memory state, then persist the updated progress data.

    Args:
        request: Body carrying the ``campaign_id`` to purge and the
            campaign ``token`` authorising the operation.

    Returns:
        JSONResponse with ``"ok"`` (200) on success, or with
        ``"Unknown campaign ID"`` / ``"Invalid token"`` (400) when the
        campaign is not loaded or the token does not match.
    """
    global progress_data, tasks_data

    # Function-scope imports: only this handler needs them.
    import contextlib
    import hmac

    campaign_id = request.campaign_id

    # Require the campaign in BOTH stores; a half-registered campaign would
    # otherwise surface as an unhandled KeyError (HTTP 500) below.
    if campaign_id not in progress_data or campaign_id not in tasks_data:
        return JSONResponse(content="Unknown campaign ID", status_code=400)
    campaign = tasks_data[campaign_id]

    # Constant-time comparison so response timing does not leak how much of
    # the token matched. Compare bytes because compare_digest rejects
    # non-ASCII str; "surrogatepass" keeps odd JSON strings from raising.
    expected_token = campaign.get("token")
    if not isinstance(expected_token, str) or not hmac.compare_digest(
        request.token.encode("utf-8", "surrogatepass"),
        expected_token.encode("utf-8", "surrogatepass"),
    ):
        return JSONResponse(content="Invalid token", status_code=400)

    # Unlink assets if they exist. `or {}` also covers keys that are present
    # but null, which a `.get(key, {})` default does not.
    campaign_info = campaign.get("info") or {}
    destination = (campaign_info.get("assets") or {}).get("destination")
    if destination:
        # NOTE(review): destination is joined into the path unchecked; only
        # symlinks are ever removed here, but confirm it is validated upstream.
        symlink_path = f"{ROOT}/data/{destination}".rstrip("/")
        if os.path.islink(symlink_path):
            with contextlib.suppress(FileNotFoundError):
                os.remove(symlink_path)

    # Remove task and output files. Deleting directly and ignoring
    # "already gone" avoids the race between an exists() check and remove().
    for stale_file in (
        f"{ROOT}/data/tasks/{campaign_id}.json",
        f"{ROOT}/data/outputs/{campaign_id}.jsonl",
    ):
        with contextlib.suppress(FileNotFoundError):
            os.remove(stale_file)

    # Remove from in-memory data structures
    del tasks_data[campaign_id]
    del progress_data[campaign_id]

    # Save updated progress data
    save_progress_data(progress_data)

    return JSONResponse(content="ok", status_code=200)
325
+
326
+
327
class AddCampaignRequest(BaseModel):
    """Request body accepted by the ``/add-campaign`` endpoint."""

    # Campaign definition; the handler reads its "campaign_id" entry and
    # passes the whole mapping on to `_add_single_campaign`.
    campaign_data: dict[str, Any]
329
+
330
+
331
@app.post("/add-campaign")
async def _add_campaign(request: AddCampaignRequest):
    """
    Register a campaign from the posted definition and load it into memory.

    The definition is handed to ``_add_single_campaign`` (without
    overwriting), the resulting task file is read back into ``tasks_data``
    and ``progress_data`` is reloaded.

    Args:
        request: Body whose ``campaign_data`` holds the campaign definition,
            including its ``"campaign_id"``.

    Returns:
        JSONResponse with status, campaign id and token (200) on success,
        or ``{"error": <message>}`` (400) if any step raises.
    """
    global progress_data, tasks_data

    # Imported at call time rather than at module import.
    from .cli import _add_single_campaign

    try:
        server_url = os.environ.get("PEARMUT_SERVER_URL", "http://localhost:8001")
        _add_single_campaign(
            request.campaign_data,
            overwrite=False,
            server=server_url,
        )

        campaign_id = request.campaign_data["campaign_id"]
        task_path = f"{ROOT}/data/tasks/{campaign_id}.json"
        with open(task_path, "r") as task_file:
            tasks_data[campaign_id] = json.load(task_file)

        # Rebinds the module-level progress data from its persisted form.
        progress_data = load_progress_data(warn=None)

        # Built inside the try block so a missing "token" is reported as 400.
        body = {
            "status": "ok",
            "campaign_id": campaign_id,
            "token": tasks_data[campaign_id]["token"],
        }
        return JSONResponse(content=body, status_code=200)
    except Exception as exc:
        return JSONResponse(content={"error": str(exc)}, status_code=400)
354
+
355
+
283
356
  @app.get("/download-annotations")
284
357
  async def _download_annotations(
285
358
  campaign_id: list[str] = Query(),
pearmut/assignment.py CHANGED
@@ -1,10 +1,12 @@
1
1
  import collections
2
+ import copy
2
3
  import random
3
4
  import statistics
4
5
  from typing import Any
5
6
 
6
7
  from fastapi.responses import JSONResponse
7
8
 
9
+ from .constants import PROTOCOL_INSTRUCTIONS
8
10
  from .utils import (
9
11
  RESET_MARKER,
10
12
  check_validation_threshold,
@@ -14,6 +16,15 @@ from .utils import (
14
16
  )
15
17
 
16
18
 
19
+ def _get_instructions(tasks_data: dict, campaign_id: str) -> str:
20
+ """Get instructions: custom if provided, else protocol default, else empty."""
21
+ campaign_info = tasks_data[campaign_id]["info"]
22
+ if "instructions" in campaign_info:
23
+ return campaign_info["instructions"]
24
+ return PROTOCOL_INSTRUCTIONS.get(campaign_info.get("protocol", ""), "")
25
+
26
+
27
+
17
28
  def _completed_response(
18
29
  tasks_data: dict,
19
30
  progress_data: dict,
@@ -132,11 +143,12 @@ def get_i_item_taskbased(
132
143
  "time": user_progress["time"],
133
144
  "info": {
134
145
  "item_i": item_i,
146
+ "instructions": _get_instructions(data_all, campaign_id),
135
147
  }
136
148
  | {
137
149
  k: v
138
150
  for k, v in data_all[campaign_id]["info"].items()
139
- if k.startswith("protocol")
151
+ if k in {"protocol", "sliders"}
140
152
  },
141
153
  "payload": data_all[campaign_id]["data"][user_id][item_i],
142
154
  }
@@ -178,11 +190,12 @@ def get_i_item_singlestream(
178
190
  "time": user_progress["time"],
179
191
  "info": {
180
192
  "item_i": item_i,
193
+ "instructions": _get_instructions(data_all, campaign_id),
181
194
  }
182
195
  | {
183
196
  k: v
184
197
  for k, v in data_all[campaign_id]["info"].items()
185
- if k.startswith("protocol")
198
+ if k in {"protocol", "sliders"}
186
199
  },
187
200
  "payload": data_all[campaign_id]["data"][item_i],
188
201
  }
@@ -224,11 +237,12 @@ def get_next_item_taskbased(
224
237
  "time": user_progress["time"],
225
238
  "info": {
226
239
  "item_i": item_i,
240
+ "instructions": _get_instructions(data_all, campaign_id),
227
241
  }
228
242
  | {
229
243
  k: v
230
244
  for k, v in data_all[campaign_id]["info"].items()
231
- if k.startswith("protocol")
245
+ if k in {"protocol", "sliders"}
232
246
  },
233
247
  "payload": data_all[campaign_id]["data"][user_id][item_i],
234
248
  }
@@ -279,11 +293,12 @@ def get_next_item_singlestream(
279
293
  "progress": progress,
280
294
  "info": {
281
295
  "item_i": item_i,
296
+ "instructions": _get_instructions(data_all, campaign_id),
282
297
  }
283
298
  | {
284
299
  k: v
285
300
  for k, v in data_all[campaign_id]["info"].items()
286
- if k.startswith("protocol")
301
+ if k in {"protocol", "sliders"}
287
302
  },
288
303
  "payload": data_all[campaign_id]["data"][item_i],
289
304
  }
@@ -439,11 +454,12 @@ def get_next_item_dynamic(
439
454
  "progress": user_progress["progress"],
440
455
  "info": {
441
456
  "item_i": item_i,
457
+ "instructions": _get_instructions(tasks_data, campaign_id),
442
458
  }
443
459
  | {
444
460
  k: v
445
461
  for k, v in campaign_data["info"].items()
446
- if k.startswith("protocol")
462
+ if k in {"protocol", "sliders"}
447
463
  },
448
464
  "payload": pruned_item,
449
465
  },
@@ -459,6 +475,26 @@ def _reset_user_time(progress_data: dict, campaign_id: str, user_id: str) -> Non
459
475
  progress_data[campaign_id][user_id]["validations"] = {}
460
476
 
461
477
 
478
def _get_user_annotated_items(campaign_id: str, user_id: str) -> set[int]:
    """
    Collect the item indices a given user has annotated in a campaign.

    Log entries whose annotation is the reset marker, and entries without
    an "item_i" value, are ignored.

    Args:
        campaign_id: The campaign identifier
        user_id: The user identifier

    Returns:
        Set of item indices (item_i) found in the user's log entries
    """
    # NOTE(review): entries logged before an earlier reset are still counted;
    # confirm that is the intended behaviour.
    return {
        record["item_i"]
        for record in get_db_log(campaign_id)
        if record.get("user_id") == user_id
        and record.get("annotation") != RESET_MARKER
        and record.get("item_i") is not None
    }
496
+
497
+
462
498
  def reset_task(
463
499
  campaign_id: str,
464
500
  user_id: str,
@@ -482,29 +518,46 @@ def reset_task(
482
518
  _reset_user_time(progress_data, campaign_id, user_id)
483
519
  return JSONResponse(content="ok", status_code=200)
484
520
  elif assignment == "single-stream":
485
- # Save reset markers for all items (shared pool)
486
- num_items = len(tasks_data[campaign_id]["data"])
487
- for item_i in range(num_items):
521
+ # Find all items that this user has annotated
522
+ user_items = _get_user_annotated_items(campaign_id, user_id)
523
+
524
+ # Save reset markers only for items this user has touched
525
+ for item_i in user_items:
488
526
  save_db_payload(
489
527
  campaign_id,
490
- {"user_id": None, "item_i": item_i, "annotation": RESET_MARKER},
528
+ {"user_id": user_id, "item_i": item_i, "annotation": RESET_MARKER},
491
529
  )
492
- # for single-stream reset all progress
530
+
531
+ # Reset only the touched items in all users' progress (shared pool)
493
532
  for uid in progress_data[campaign_id]:
494
- progress_data[campaign_id][uid]["progress"] = [False] * num_items
533
+ for item_i in user_items:
534
+ progress_data[campaign_id][uid]["progress"][item_i] = False
535
+
536
+ # Reset only the specified user's time
495
537
  _reset_user_time(progress_data, campaign_id, user_id)
496
538
  return JSONResponse(content="ok", status_code=200)
497
539
  elif assignment == "dynamic":
498
- # Save reset markers for all items (shared pool like single-stream)
499
- num_items = len(tasks_data[campaign_id]["data"])
500
- for item_i in range(num_items):
540
+ # Find all items that this user has annotated
541
+ user_items = _get_user_annotated_items(campaign_id, user_id)
542
+
543
+ # Save reset markers only for items this user has touched
544
+ for item_i in user_items:
501
545
  save_db_payload(
502
546
  campaign_id,
503
- {"user_id": None, "item_i": item_i, "annotation": RESET_MARKER},
547
+ {"user_id": user_id, "item_i": item_i, "annotation": RESET_MARKER},
504
548
  )
505
- # for dynamic reset all progress (use sets to track models)
549
+
550
+ progress_data_user = copy.deepcopy(progress_data[campaign_id][user_id]["progress"])
551
+
552
+ # Reset only the touched items in all users' progress (shared pool, use lists to track models)
506
553
  for uid in progress_data[campaign_id]:
507
- progress_data[campaign_id][uid]["progress"] = [[] for _ in range(num_items)]
554
+ for item_i in user_items:
555
+ progress_data[campaign_id][uid]["progress"][item_i] = [
556
+ x for x in progress_data[campaign_id][uid]["progress"][item_i]
557
+ if x not in progress_data_user[item_i]
558
+ ]
559
+
560
+ # Reset only the specified user's time
508
561
  _reset_user_time(progress_data, campaign_id, user_id)
509
562
  return JSONResponse(content="ok", status_code=200)
510
563
  else: