pearmut 0.3.3__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
pearmut/app.py CHANGED
@@ -1,20 +1,23 @@
-import collections
 import json
 import os
-import statistics
 from typing import Any

 from fastapi import FastAPI, Query
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, Response
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel

 from .assignment import get_i_item, get_next_item, reset_task, update_progress
+from .results_export import (
+    compute_model_scores,
+    generate_latex_table,
+    generate_pdf,
+    generate_typst_table,
+)
 from .utils import (
     ROOT,
     check_validation_threshold,
-    get_db_log,
     load_progress_data,
     save_db_payload,
     save_progress_data,
@@ -159,7 +162,7 @@ async def _dashboard_data(request: DashboardDataRequest):

     progress_new = {}
     assignment = tasks_data[campaign_id]["info"]["assignment"]
-    if assignment not in ["task-based", "single-stream"]:
+    if assignment not in ["task-based", "single-stream", "dynamic"]:
         return JSONResponse(
             content="Unsupported campaign assignment type", status_code=400
         )
@@ -211,31 +214,47 @@ async def _dashboard_results(request: DashboardResultsRequest):
     if token != tasks_data[campaign_id]["token"]:
         return JSONResponse(content="Invalid token", status_code=400)

-    # Compute model scores from annotations
-    model_scores = collections.defaultdict(dict)
-
-    # Iterate through all tasks to find items with 'models' field (basic template)
-    log = get_db_log(campaign_id)
-    for entry in log:
-        if "item" not in entry or "annotation" not in entry:
-            continue
-        for item, annotation in zip(entry["item"], entry["annotation"]):
-            for model, annotation in annotation.items():
-                if "score" in annotation and annotation["score"] is not None:
-                    model_scores[model][json.dumps(item)] = annotation["score"]
-
-    results = [
-        {
-            "model": model,
-            "score": statistics.mean(scores.values()),
-            "count": len(scores),
-        }
-        for model, scores in model_scores.items()
-    ]
-    results.sort(key=lambda x: x["score"], reverse=True)
+    results = compute_model_scores(campaign_id)
     return JSONResponse(content=results, status_code=200)


+@app.get("/export-results")
+async def _export_results(
+    campaign_id: str = Query(),
+    token: str = Query(),
+    format: str = Query(),
+):
+    if campaign_id not in progress_data:
+        return JSONResponse(content="Unknown campaign ID", status_code=400)
+
+    # Check if token is valid
+    if token != tasks_data[campaign_id]["token"]:
+        return JSONResponse(content="Invalid token", status_code=400)
+
+    results = compute_model_scores(campaign_id)
+
+    if format == "typst":
+        content = generate_typst_table(results)
+        return Response(
+            content=content,
+            media_type="text/plain",
+        )
+    elif format == "latex":
+        content = generate_latex_table(results)
+        return Response(
+            content=content,
+            media_type="text/plain",
+        )
+    elif format == "pdf":
+        pdf_bytes = generate_pdf(results, campaign_id)
+        return Response(
+            content=pdf_bytes,
+            media_type="application/pdf",
+        )
+    else:
+        return JSONResponse(content="Invalid export format", status_code=400)
+
+
 class ResetTaskRequest(BaseModel):
     campaign_id: str
     user_id: str
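
Note: a minimal client sketch for the /export-results endpoint added above, not part of the diff itself. The campaign ID and token are placeholders, and the base URL assumes the PEARMUT_SERVER_URL default (http://localhost:8001) that appears later in this diff.

    import urllib.parse
    import urllib.request

    # Placeholder values; "format" must be one of: typst, latex, pdf.
    params = urllib.parse.urlencode(
        {"campaign_id": "demo", "token": "secret", "format": "latex"}
    )
    with urllib.request.urlopen(f"http://localhost:8001/export-results?{params}") as resp:
        body = resp.read()  # text/plain for typst/latex, application/pdf for pdf
    print(body.decode())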
@@ -261,6 +280,79 @@ async def _reset_task(request: ResetTaskRequest):
     return response


+class PurgeCampaignRequest(BaseModel):
+    campaign_id: str
+    token: str
+
+
+@app.post("/purge-campaign")
+async def _purge_campaign(request: PurgeCampaignRequest):
+    global progress_data, tasks_data
+
+    campaign_id = request.campaign_id
+    token = request.token
+
+    if campaign_id not in progress_data:
+        return JSONResponse(content="Unknown campaign ID", status_code=400)
+    if token != tasks_data[campaign_id]["token"]:
+        return JSONResponse(content="Invalid token", status_code=400)
+
+    # Unlink assets if they exist
+    destination = tasks_data[campaign_id].get("info", {}).get("assets", {}).get("destination")
+    if destination:
+        symlink_path = f"{ROOT}/data/{destination}".rstrip("/")
+        if os.path.islink(symlink_path):
+            os.remove(symlink_path)
+
+    # Remove task file
+    task_file = f"{ROOT}/data/tasks/{campaign_id}.json"
+    if os.path.exists(task_file):
+        os.remove(task_file)
+
+    # Remove output file
+    output_file = f"{ROOT}/data/outputs/{campaign_id}.jsonl"
+    if os.path.exists(output_file):
+        os.remove(output_file)
+
+    # Remove from in-memory data structures
+    del tasks_data[campaign_id]
+    del progress_data[campaign_id]
+
+    # Save updated progress data
+    save_progress_data(progress_data)
+
+    return JSONResponse(content="ok", status_code=200)
+
+
+class AddCampaignRequest(BaseModel):
+    campaign_data: dict[str, Any]
+
+
+@app.post("/add-campaign")
+async def _add_campaign(request: AddCampaignRequest):
+    global progress_data, tasks_data
+
+    from .cli import _add_single_campaign
+
+    try:
+        server = f"{os.environ.get('PEARMUT_SERVER_URL', 'http://localhost:8001')}"
+        _add_single_campaign(request.campaign_data, overwrite=False, server=server)
+
+        campaign_id = request.campaign_data['campaign_id']
+        with open(f"{ROOT}/data/tasks/{campaign_id}.json", "r") as f:
+            tasks_data[campaign_id] = json.load(f)
+
+        progress_data = load_progress_data(warn=None)
+
+        return JSONResponse(content={
+            "status": "ok",
+            "campaign_id": campaign_id,
+            "token": tasks_data[campaign_id]["token"]
+        }, status_code=200)
+    except Exception as e:
+        return JSONResponse(content={"error": str(e)}, status_code=400)
+
+
 @app.get("/download-annotations")
 async def _download_annotations(
     campaign_id: list[str] = Query(),
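
Note: a hypothetical sketch of driving the two new admin endpoints above. The handler only reads campaign_data["campaign_id"] directly; the rest of the campaign format is defined by pearmut's CLI and is not shown in this diff, so the payload below is illustrative only.

    import json
    import urllib.request

    def post_json(url: str, payload: dict) -> dict:
        # Small helper (not part of pearmut) to POST JSON and parse the reply.
        req = urllib.request.Request(
            url,
            data=json.dumps(payload).encode(),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        with urllib.request.urlopen(req) as resp:
            return json.load(resp)

    base = "http://localhost:8001"
    created = post_json(f"{base}/add-campaign", {"campaign_data": {"campaign_id": "demo"}})
    # On success: {"status": "ok", "campaign_id": "demo", "token": "..."}
    post_json(f"{base}/purge-campaign", {"campaign_id": "demo", "token": created["token"]})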
pearmut/assignment.py CHANGED
@@ -1,16 +1,30 @@
+import collections
+import copy
 import random
+import statistics
 from typing import Any

 from fastapi.responses import JSONResponse

+from .constants import PROTOCOL_INSTRUCTIONS
 from .utils import (
     RESET_MARKER,
     check_validation_threshold,
+    get_db_log,
     get_db_log_item,
     save_db_payload,
 )


+def _get_instructions(tasks_data: dict, campaign_id: str) -> str:
+    """Get instructions: custom if provided, else protocol default, else empty."""
+    campaign_info = tasks_data[campaign_id]["info"]
+    if "instructions" in campaign_info:
+        return campaign_info["instructions"]
+    return PROTOCOL_INSTRUCTIONS.get(campaign_info.get("protocol", ""), "")
+
+
+
 def _completed_response(
     tasks_data: dict,
     progress_data: dict,
@@ -20,14 +34,33 @@ def _completed_response(
     """Build a completed response with progress, time, and token."""
     user_progress = progress_data[campaign_id][user_id]
     is_ok = check_validation_threshold(tasks_data, progress_data, campaign_id, user_id)
+    token = user_progress["token_correct" if is_ok else "token_incorrect"]
+
+    # Get instructions_goodbye from campaign info, with default value
+    instructions_goodbye = tasks_data[campaign_id]["info"].get(
+        "instructions_goodbye",
+        "If someone asks you for a token of completion, show them: ${TOKEN}",
+    )
+
+    # Replace variables ${TOKEN} and ${USER_ID}
+    instructions_goodbye = instructions_goodbye.replace("${TOKEN}", token).replace(
+        "${USER_ID}", user_id
+    )
+
+    # Convert sets to lists for JSON serialization (for dynamic assignment)
+    progress = user_progress["progress"]
+    if progress and isinstance(progress[0], set):
+        progress = [list(s) for s in progress]
+
     return JSONResponse(
         content={
-            "status": "completed",
-            "progress": user_progress["progress"],
+            "status": "goodbye",
+            "progress": progress,
             "time": user_progress["time"],
-            "token": user_progress["token_correct" if is_ok else "token_incorrect"],
+            "token": token,
+            "instructions_goodbye": instructions_goodbye,
         },
-        status_code=200
+        status_code=200,
     )


@@ -44,7 +77,9 @@ def get_next_item(
     if assignment == "task-based":
         return get_next_item_taskbased(campaign_id, user_id, tasks_data, progress_data)
     elif assignment == "single-stream":
-        return get_next_item_singlestream(campaign_id, user_id, tasks_data, progress_data)
+        return get_next_item_singlestream(
+            campaign_id, user_id, tasks_data, progress_data
+        )
     elif assignment == "dynamic":
         return get_next_item_dynamic(campaign_id, user_id, tasks_data, progress_data)
     else:
@@ -63,11 +98,17 @@ def get_i_item(
     """
     assignment = tasks_data[campaign_id]["info"]["assignment"]
     if assignment == "task-based":
-        return get_i_item_taskbased(campaign_id, user_id, tasks_data, progress_data, item_i)
+        return get_i_item_taskbased(
+            campaign_id, user_id, tasks_data, progress_data, item_i
+        )
     elif assignment == "single-stream":
-        return get_i_item_singlestream(campaign_id, user_id, tasks_data, progress_data, item_i)
+        return get_i_item_singlestream(
+            campaign_id, user_id, tasks_data, progress_data, item_i
+        )
     else:
-        return JSONResponse(content="Get item not supported for this assignment type", status_code=400)
+        return JSONResponse(
+            content="Get item not supported for this assignment type", status_code=400
+        )


 def get_i_item_taskbased(
@@ -93,10 +134,7 @@ def get_i_item_taskbased(
         payload_existing["comment"] = latest_item["comment"]

     if item_i < 0 or item_i >= len(data_all[campaign_id]["data"][user_id]):
-        return JSONResponse(
-            content="Item index out of range",
-            status_code=400
-        )
+        return JSONResponse(content="Item index out of range", status_code=400)

     return JSONResponse(
         content={
@@ -105,14 +143,17 @@ def get_i_item_taskbased(
             "time": user_progress["time"],
             "info": {
                 "item_i": item_i,
-            } | {
+                "instructions": _get_instructions(data_all, campaign_id),
+            }
+            | {
                 k: v
                 for k, v in data_all[campaign_id]["info"].items()
-                if k.startswith("protocol")
+                if k in {"protocol", "sliders"}
             },
-            "payload": data_all[campaign_id]["data"][user_id][item_i]
-        } | ({"payload_existing": payload_existing} if payload_existing else {}),
-        status_code=200
+            "payload": data_all[campaign_id]["data"][user_id][item_i],
+        }
+        | ({"payload_existing": payload_existing} if payload_existing else {}),
+        status_code=200,
     )


@@ -140,10 +181,7 @@ def get_i_item_singlestream(
         payload_existing["comment"] = latest_item["comment"]

     if item_i < 0 or item_i >= len(data_all[campaign_id]["data"]):
-        return JSONResponse(
-            content="Item index out of range",
-            status_code=400
-        )
+        return JSONResponse(content="Item index out of range", status_code=400)

     return JSONResponse(
         content={
@@ -152,14 +190,17 @@ def get_i_item_singlestream(
             "time": user_progress["time"],
             "info": {
                 "item_i": item_i,
-            } | {
+                "instructions": _get_instructions(data_all, campaign_id),
+            }
+            | {
                 k: v
                 for k, v in data_all[campaign_id]["info"].items()
-                if k.startswith("protocol")
+                if k in {"protocol", "sliders"}
             },
-            "payload": data_all[campaign_id]["data"][item_i]
-        } | ({"payload_existing": payload_existing} if payload_existing else {}),
-        status_code=200
+            "payload": data_all[campaign_id]["data"][item_i],
+        }
+        | ({"payload_existing": payload_existing} if payload_existing else {}),
+        status_code=200,
     )


@@ -196,14 +237,17 @@ def get_next_item_taskbased(
             "time": user_progress["time"],
             "info": {
                 "item_i": item_i,
-            } | {
+                "instructions": _get_instructions(data_all, campaign_id),
+            }
+            | {
                 k: v
                 for k, v in data_all[campaign_id]["info"].items()
-                if k.startswith("protocol")
+                if k in {"protocol", "sliders"}
             },
-            "payload": data_all[campaign_id]["data"][user_id][item_i]
-        } | ({"payload_existing": payload_existing} if payload_existing else {}),
-        status_code=200
+            "payload": data_all[campaign_id]["data"][user_id][item_i],
+        }
+        | ({"payload_existing": payload_existing} if payload_existing else {}),
+        status_code=200,
     )


@@ -249,21 +293,178 @@ def get_next_item_singlestream(
         "progress": progress,
         "info": {
             "item_i": item_i,
-        } | {
+            "instructions": _get_instructions(data_all, campaign_id),
+        }
+        | {
             k: v
             for k, v in data_all[campaign_id]["info"].items()
-            if k.startswith("protocol")
+            if k in {"protocol", "sliders"}
         },
-        "payload": data_all[campaign_id]["data"][item_i]
-    } | ({"payload_existing": payload_existing} if payload_existing else {}),
-    status_code=200
+        "payload": data_all[campaign_id]["data"][item_i],
+    }
+    | ({"payload_existing": payload_existing} if payload_existing else {}),
+    status_code=200,
+    )
+
+
+def get_next_item_dynamic(
+    campaign_id: str,
+    user_id: str,
+    tasks_data: dict,
+    progress_data: dict,
+) -> JSONResponse:
+    """
+    Get the next item for dynamic assignment based on model performance.
+
+    NOTE: All items must contain all model outputs for this assignment type to work.
+
+    In this mode, items are selected based on the current performance of models:
+    1. Contrastive comparison: `dynamic_contrastive_models` models are randomly selected and shown per item
+    2. First phase: Each model gets `dynamic_first` annotations with fully random selection
+    3. After first phase: Top `dynamic_top` models are identified, K randomly selected from them
+    4. Items with least annotations for the selected models are prioritized
+    5. With probability `dynamic_backoff`, uniformly random selection is used instead
+    """
+    import random
+
+    user_progress = progress_data[campaign_id][user_id]
+    campaign_data = tasks_data[campaign_id]
+
+    # Get all unique models in the campaign (all items must have all models)
+    all_models = list(set(campaign_data["data"][0][0]["tgt"].keys()))
+
+    # Check if completed (all models completed for all items)
+    # NOTE: this will rarely trigger but we don't have a good way to know when to end anyway for now
+    if all(len(v) == len(all_models) for v in user_progress["progress"]):
+        return _completed_response(tasks_data, progress_data, campaign_id, user_id)
+
+    # Get configuration parameters
+    dynamic_top = campaign_data["info"].get("dynamic_top", 2)
+    dynamic_first = campaign_data["info"].get("dynamic_first", 5)
+    dynamic_contrastive_models = campaign_data["info"].get(
+        "dynamic_contrastive_models", 1
+    )
+    dynamic_backoff = campaign_data["info"].get("dynamic_backoff", 0)
+
+    # Count annotations per (model, item) pair to track coverage
+    annotations = get_db_log(campaign_id)
+    model_item_counts = collections.defaultdict(int)  # (model, item_i) -> count
+    model_total_counts = collections.defaultdict(int)  # model -> total count
+
+    for annotation_line in annotations:
+        if (item_i := annotation_line.get("item_i")) is not None:
+            # Count which models were annotated in this annotation
+            for annotation_item in annotation_line.get("annotation", []):
+                for model in annotation_item:
+                    model_item_counts[(model, item_i)] += 1
+                    model_total_counts[model] += 1
+
+    # Check if we're still in the first phase (collecting initial data)
+    in_first_phase = any(
+        model_total_counts.get(model, 0) < dynamic_first for model in all_models
     )

+    # Select which models to show
+    if in_first_phase:
+        # First phase or backoff: select models that don't have enough annotations yet
+        selected_models = random.sample(
+            [
+                model
+                for model in all_models
+                if model_total_counts.get(model, 0) < dynamic_first
+            ],
+            k=min(dynamic_contrastive_models, len(all_models)),
+        )
+    elif random.random() < dynamic_backoff:
+        # Backoff: select K models randomly from all models
+        selected_models = random.sample(
+            all_models, k=min(dynamic_contrastive_models, len(all_models))
+        )
+    else:
+        # Calculate model scores from annotations
+        model_scores = collections.defaultdict(list)
+        for annotation_line in annotations:
+            for annotation_item in annotation_line.get("annotation", {}):
+                for model in annotation_item:
+                    if "score" in annotation_item[model]:
+                        model_scores[model].append(annotation_item[model]["score"])
+
+        # Calculate average scores
+        model_avg_scores = {
+            model: statistics.mean(scores) for model, scores in model_scores.items()
+        }
+
+        # Get top N models
+        sorted_models = sorted(
+            model_avg_scores.items(), key=lambda x: x[1], reverse=True
+        )
+        top_models = [model for model, score in sorted_models[:dynamic_top]]

+        # From top N, randomly select K models
+        selected_models = random.sample(
+            top_models, k=min(dynamic_contrastive_models, len(top_models))
+        )

-def get_next_item_dynamic(campaign_data: dict, user_id: str, progress_data: dict, data_all: dict):
-    raise NotImplementedError("Dynamic protocol is not implemented yet.")
+    # Find incomplete items for the selected models (items where not all selected models are done)
+    item_annotation_counts = {
+        i: sum(model in completed_models for model in selected_models)
+        for i, completed_models in enumerate(user_progress["progress"])
+    }
+
+    # Select item with minimum annotations (with random tiebreaking)
+    min_annotations = min(item_annotation_counts.values())
+    items_with_min = [
+        item_i
+        for item_i, count in item_annotation_counts.items()
+        if count == min_annotations
+    ]
+    item_i = random.choice(items_with_min)
+
+    # Prune the payload to only include selected models
+    original_item = campaign_data["data"][item_i]
+    pruned_item = []
+    for doc_segment in original_item:
+        pruned_segment = doc_segment.copy()
+        # Filter tgt to only include selected models
+        pruned_segment["tgt"] = {
+            model: doc_segment["tgt"][model]
+            for model in selected_models
+            if model in doc_segment["tgt"]
+        }
+        # Also filter error_spans if present
+        if "error_spans" in doc_segment:
+            pruned_segment["error_spans"] = {
+                model: doc_segment["error_spans"][model]
+                for model in selected_models
+                if model in doc_segment.get("error_spans", {})
+            }
+        # Also filter validation if present
+        if "validation" in doc_segment:
+            pruned_segment["validation"] = {
+                model: doc_segment["validation"][model]
+                for model in selected_models
+                if model in doc_segment.get("validation", {})
+            }
+        pruned_item.append(pruned_segment)

+    return JSONResponse(
+        content={
+            "status": "ok",
+            "time": user_progress["time"],
+            "progress": user_progress["progress"],
+            "info": {
+                "item_i": item_i,
+                "instructions": _get_instructions(tasks_data, campaign_id),
+            }
+            | {
+                k: v
+                for k, v in campaign_data["info"].items()
+                if k in {"protocol", "sliders"}
+            },
+            "payload": pruned_item,
+        },
+        status_code=200,
+    )


 def _reset_user_time(progress_data: dict, campaign_id: str, user_id: str) -> None:
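
Note: to make the selection policy of get_next_item_dynamic above concrete, here is a stripped-down, self-contained sketch of its three-way branch (first phase / backoff / top-K exploitation). The model names, counts, scores, and parameter values are toy stand-ins for the database log, not data from the diff.

    import random
    import statistics

    all_models = ["m1", "m2", "m3"]
    model_total_counts = {"m1": 5, "m2": 6, "m3": 5}  # annotations so far (toy)
    model_scores = {"m1": [80, 90], "m2": [60, 70], "m3": [85, 95]}
    dynamic_first, dynamic_top = 5, 2
    dynamic_contrastive_models, dynamic_backoff = 1, 0.1

    if any(model_total_counts[m] < dynamic_first for m in all_models):
        # First phase: only models still under the dynamic_first quota are eligible.
        pool = [m for m in all_models if model_total_counts[m] < dynamic_first]
        selected = random.sample(pool, k=min(dynamic_contrastive_models, len(all_models)))
    elif random.random() < dynamic_backoff:
        # Backoff: occasional uniform choice keeps exploring weaker models.
        selected = random.sample(all_models, k=min(dynamic_contrastive_models, len(all_models)))
    else:
        # Exploitation: rank by mean score and sample from the top dynamic_top models.
        ranked = sorted(model_scores, key=lambda m: statistics.mean(model_scores[m]), reverse=True)
        top = ranked[:dynamic_top]  # with these toy scores: ["m3", "m1"]
        selected = random.sample(top, k=min(dynamic_contrastive_models, len(top)))
    print(selected)  # e.g. ["m3"]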
@@ -274,6 +475,26 @@ def _reset_user_time(progress_data: dict, campaign_id: str, user_id: str) -> Non
     progress_data[campaign_id][user_id]["validations"] = {}


+def _get_user_annotated_items(campaign_id: str, user_id: str) -> set[int]:
+    """
+    Get the set of item indices that a specific user has annotated.
+
+    Args:
+        campaign_id: The campaign identifier
+        user_id: The user identifier
+
+    Returns:
+        Set of item indices (item_i) that the user has annotated
+    """
+    log = get_db_log(campaign_id)
+    user_items = set()
+    for entry in log:
+        if entry.get("user_id") == user_id and entry.get("annotation") != RESET_MARKER:
+            if (item_i := entry.get("item_i")) is not None:
+                user_items.add(item_i)
+    return user_items
+
+
 def reset_task(
     campaign_id: str,
     user_id: str,
@@ -289,30 +510,60 @@ def reset_task(
         # Save reset marker for this user to mask existing annotations
         num_items = len(tasks_data[campaign_id]["data"][user_id])
         for item_i in range(num_items):
-            save_db_payload(campaign_id, {
-                "user_id": user_id,
-                "item_i": item_i,
-                "annotation": RESET_MARKER
-            })
+            save_db_payload(
+                campaign_id,
+                {"user_id": user_id, "item_i": item_i, "annotation": RESET_MARKER},
+            )
         progress_data[campaign_id][user_id]["progress"] = [False] * num_items
         _reset_user_time(progress_data, campaign_id, user_id)
         return JSONResponse(content="ok", status_code=200)
     elif assignment == "single-stream":
-        # Save reset markers for all items (shared pool)
-        num_items = len(tasks_data[campaign_id]["data"])
-        for item_i in range(num_items):
-            save_db_payload(campaign_id, {
-                "user_id": None,
-                "item_i": item_i,
-                "annotation": RESET_MARKER
-            })
-        # for single-stream reset all progress
+        # Find all items that this user has annotated
+        user_items = _get_user_annotated_items(campaign_id, user_id)
+
+        # Save reset markers only for items this user has touched
+        for item_i in user_items:
+            save_db_payload(
+                campaign_id,
+                {"user_id": user_id, "item_i": item_i, "annotation": RESET_MARKER},
+            )
+
+        # Reset only the touched items in all users' progress (shared pool)
         for uid in progress_data[campaign_id]:
-            progress_data[campaign_id][uid]["progress"] = [False] * num_items
+            for item_i in user_items:
+                progress_data[campaign_id][uid]["progress"][item_i] = False
+
+        # Reset only the specified user's time
+        _reset_user_time(progress_data, campaign_id, user_id)
+        return JSONResponse(content="ok", status_code=200)
+    elif assignment == "dynamic":
+        # Find all items that this user has annotated
+        user_items = _get_user_annotated_items(campaign_id, user_id)
+
+        # Save reset markers only for items this user has touched
+        for item_i in user_items:
+            save_db_payload(
+                campaign_id,
+                {"user_id": user_id, "item_i": item_i, "annotation": RESET_MARKER},
+            )
+
+        progress_data_user = copy.deepcopy(progress_data[campaign_id][user_id]["progress"])
+
+        # Reset only the touched items in all users' progress (shared pool, use lists to track models)
+        for uid in progress_data[campaign_id]:
+            for item_i in user_items:
+                progress_data[campaign_id][uid]["progress"][item_i] = [
+                    x for x in progress_data[campaign_id][uid]["progress"][item_i]
+                    if x not in progress_data_user[item_i]
+                ]
+
+        # Reset only the specified user's time
         _reset_user_time(progress_data, campaign_id, user_id)
         return JSONResponse(content="ok", status_code=200)
     else:
-        return JSONResponse(content="Reset not supported for this assignment type", status_code=400)
+        return JSONResponse(
+            content="Reset not supported for this assignment type", status_code=400
+        )


 def update_progress(
@@ -337,6 +588,18 @@ def update_progress(
             progress_data[campaign_id][uid]["progress"][item_i] = True
         return JSONResponse(content="ok", status_code=200)
     elif assignment == "dynamic":
-        return JSONResponse(content="Dynamic protocol logging not implemented yet.", status_code=400)
+        # For dynamic, track which models were annotated
+        # Extract models from the payload annotation
+        annotated_models = []
+        if "annotation" in payload:
+            for annotation_item in payload.get("annotation", []):
+                if isinstance(annotation_item, dict):
+                    annotated_models.extend(annotation_item.keys())
+
+        # Update progress for all users (shared pool)
+        for uid in progress_data[campaign_id]:
+            # Add the newly annotated models
+            progress_data[campaign_id][uid]["progress"][item_i].extend(annotated_models)
+        return JSONResponse(content="ok", status_code=200)
     else:
         return JSONResponse(content="Unknown campaign assignment type", status_code=400)