pearmut 0.3.3__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
pearmut/app.py CHANGED
@@ -1,20 +1,23 @@
- import collections
  import json
  import os
- import statistics
  from typing import Any

  from fastapi import FastAPI, Query
  from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse
+ from fastapi.responses import JSONResponse, Response
  from fastapi.staticfiles import StaticFiles
  from pydantic import BaseModel

  from .assignment import get_i_item, get_next_item, reset_task, update_progress
+ from .results_export import (
+     compute_model_scores,
+     generate_latex_table,
+     generate_pdf,
+     generate_typst_table,
+ )
  from .utils import (
      ROOT,
      check_validation_threshold,
-     get_db_log,
      load_progress_data,
      save_db_payload,
      save_progress_data,
@@ -159,7 +162,7 @@ async def _dashboard_data(request: DashboardDataRequest):

      progress_new = {}
      assignment = tasks_data[campaign_id]["info"]["assignment"]
-     if assignment not in ["task-based", "single-stream"]:
+     if assignment not in ["task-based", "single-stream", "dynamic"]:
          return JSONResponse(
              content="Unsupported campaign assignment type", status_code=400
          )
@@ -211,31 +214,47 @@ async def _dashboard_results(request: DashboardResultsRequest):
      if token != tasks_data[campaign_id]["token"]:
          return JSONResponse(content="Invalid token", status_code=400)

-     # Compute model scores from annotations
-     model_scores = collections.defaultdict(dict)
-
-     # Iterate through all tasks to find items with 'models' field (basic template)
-     log = get_db_log(campaign_id)
-     for entry in log:
-         if "item" not in entry or "annotation" not in entry:
-             continue
-         for item, annotation in zip(entry["item"], entry["annotation"]):
-             for model, annotation in annotation.items():
-                 if "score" in annotation and annotation["score"] is not None:
-                     model_scores[model][json.dumps(item)] = annotation["score"]
-
-     results = [
-         {
-             "model": model,
-             "score": statistics.mean(scores.values()),
-             "count": len(scores),
-         }
-         for model, scores in model_scores.items()
-     ]
-     results.sort(key=lambda x: x["score"], reverse=True)
+     results = compute_model_scores(campaign_id)
      return JSONResponse(content=results, status_code=200)


+ @app.get("/export-results")
+ async def _export_results(
+     campaign_id: str = Query(),
+     token: str = Query(),
+     format: str = Query(),
+ ):
+     if campaign_id not in progress_data:
+         return JSONResponse(content="Unknown campaign ID", status_code=400)
+
+     # Check if token is valid
+     if token != tasks_data[campaign_id]["token"]:
+         return JSONResponse(content="Invalid token", status_code=400)
+
+     results = compute_model_scores(campaign_id)
+
+     if format == "typst":
+         content = generate_typst_table(results)
+         return Response(
+             content=content,
+             media_type="text/plain",
+         )
+     elif format == "latex":
+         content = generate_latex_table(results)
+         return Response(
+             content=content,
+             media_type="text/plain",
+         )
+     elif format == "pdf":
+         pdf_bytes = generate_pdf(results, campaign_id)
+         return Response(
+             content=pdf_bytes,
+             media_type="application/pdf",
+         )
+     else:
+         return JSONResponse(content="Invalid export format", status_code=400)
+
+
  class ResetTaskRequest(BaseModel):
      campaign_id: str
      user_id: str
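For reference, the new /export-results endpoint serves the same ranking as the dashboard, rendered as Typst source, LaTeX source, or a PDF. A minimal client sketch; the base URL, campaign ID, and token below are placeholders, not values from this diff:

    # Sketch: fetch an exported results table from a running pearmut server.
    # BASE_URL, campaign_id, and token are hypothetical placeholders.
    import urllib.parse
    import urllib.request

    BASE_URL = "http://localhost:8000"  # wherever the server is running
    params = urllib.parse.urlencode(
        {"campaign_id": "my-campaign", "token": "dashboard-token", "format": "latex"}
    )
    with urllib.request.urlopen(f"{BASE_URL}/export-results?{params}") as resp:
        print(resp.read().decode())  # LaTeX table source

Per the handler above, format=typst and format=latex return the table source as text/plain, while format=pdf returns application/pdf bytes.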
pearmut/assignment.py CHANGED
@@ -1,4 +1,6 @@
+ import collections
  import random
+ import statistics
  from typing import Any

  from fastapi.responses import JSONResponse
@@ -6,6 +8,7 @@ from fastapi.responses import JSONResponse
  from .utils import (
      RESET_MARKER,
      check_validation_threshold,
+     get_db_log,
      get_db_log_item,
      save_db_payload,
  )
@@ -20,14 +23,33 @@ def _completed_response(
      """Build a completed response with progress, time, and token."""
      user_progress = progress_data[campaign_id][user_id]
      is_ok = check_validation_threshold(tasks_data, progress_data, campaign_id, user_id)
+     token = user_progress["token_correct" if is_ok else "token_incorrect"]
+
+     # Get instructions_goodbye from campaign info, with default value
+     instructions_goodbye = tasks_data[campaign_id]["info"].get(
+         "instructions_goodbye",
+         "If someone asks you for a token of completion, show them: ${TOKEN}",
+     )
+
+     # Replace variables ${TOKEN} and ${USER_ID}
+     instructions_goodbye = instructions_goodbye.replace("${TOKEN}", token).replace(
+         "${USER_ID}", user_id
+     )
+
+     # Convert sets to lists for JSON serialization (for dynamic assignment)
+     progress = user_progress["progress"]
+     if progress and isinstance(progress[0], set):
+         progress = [list(s) for s in progress]
+
      return JSONResponse(
          content={
-             "status": "completed",
-             "progress": user_progress["progress"],
+             "status": "goodbye",
+             "progress": progress,
              "time": user_progress["time"],
-             "token": user_progress["token_correct" if is_ok else "token_incorrect"],
+             "token": token,
+             "instructions_goodbye": instructions_goodbye,
          },
-         status_code=200
+         status_code=200,
      )
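The goodbye message is plain string templating: ${TOKEN} and ${USER_ID} are substituted via str.replace, so a campaign's instructions_goodbye text can embed either variable anywhere. A tiny illustration with made-up values:

    # Illustrative only: mirrors the .replace() chain in _completed_response.
    template = "Thanks ${USER_ID}! Your completion token is ${TOKEN}."
    print(template.replace("${TOKEN}", "abc123").replace("${USER_ID}", "annotator-7"))
    # -> Thanks annotator-7! Your completion token is abc123.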
@@ -44,7 +66,9 @@ def get_next_item(
      if assignment == "task-based":
          return get_next_item_taskbased(campaign_id, user_id, tasks_data, progress_data)
      elif assignment == "single-stream":
-         return get_next_item_singlestream(campaign_id, user_id, tasks_data, progress_data)
+         return get_next_item_singlestream(
+             campaign_id, user_id, tasks_data, progress_data
+         )
      elif assignment == "dynamic":
          return get_next_item_dynamic(campaign_id, user_id, tasks_data, progress_data)
      else:
@@ -63,11 +87,17 @@ def get_i_item(
      """
      assignment = tasks_data[campaign_id]["info"]["assignment"]
      if assignment == "task-based":
-         return get_i_item_taskbased(campaign_id, user_id, tasks_data, progress_data, item_i)
+         return get_i_item_taskbased(
+             campaign_id, user_id, tasks_data, progress_data, item_i
+         )
      elif assignment == "single-stream":
-         return get_i_item_singlestream(campaign_id, user_id, tasks_data, progress_data, item_i)
+         return get_i_item_singlestream(
+             campaign_id, user_id, tasks_data, progress_data, item_i
+         )
      else:
-         return JSONResponse(content="Get item not supported for this assignment type", status_code=400)
+         return JSONResponse(
+             content="Get item not supported for this assignment type", status_code=400
+         )


  def get_i_item_taskbased(
@@ -93,10 +123,7 @@ def get_i_item_taskbased(
              payload_existing["comment"] = latest_item["comment"]

      if item_i < 0 or item_i >= len(data_all[campaign_id]["data"][user_id]):
-         return JSONResponse(
-             content="Item index out of range",
-             status_code=400
-         )
+         return JSONResponse(content="Item index out of range", status_code=400)

      return JSONResponse(
          content={
@@ -105,14 +132,16 @@ def get_i_item_taskbased(
              "time": user_progress["time"],
              "info": {
                  "item_i": item_i,
-             } | {
+             }
+             | {
                  k: v
                  for k, v in data_all[campaign_id]["info"].items()
                  if k.startswith("protocol")
              },
-             "payload": data_all[campaign_id]["data"][user_id][item_i]
-         } | ({"payload_existing": payload_existing} if payload_existing else {}),
-         status_code=200
+             "payload": data_all[campaign_id]["data"][user_id][item_i],
+         }
+         | ({"payload_existing": payload_existing} if payload_existing else {}),
+         status_code=200,
      )


@@ -140,10 +169,7 @@ def get_i_item_singlestream(
              payload_existing["comment"] = latest_item["comment"]

      if item_i < 0 or item_i >= len(data_all[campaign_id]["data"]):
-         return JSONResponse(
-             content="Item index out of range",
-             status_code=400
-         )
+         return JSONResponse(content="Item index out of range", status_code=400)

      return JSONResponse(
          content={
@@ -152,14 +178,16 @@ def get_i_item_singlestream(
              "time": user_progress["time"],
              "info": {
                  "item_i": item_i,
-             } | {
+             }
+             | {
                  k: v
                  for k, v in data_all[campaign_id]["info"].items()
                  if k.startswith("protocol")
              },
-             "payload": data_all[campaign_id]["data"][item_i]
-         } | ({"payload_existing": payload_existing} if payload_existing else {}),
-         status_code=200
+             "payload": data_all[campaign_id]["data"][item_i],
+         }
+         | ({"payload_existing": payload_existing} if payload_existing else {}),
+         status_code=200,
      )


@@ -196,14 +224,16 @@ def get_next_item_taskbased(
              "time": user_progress["time"],
              "info": {
                  "item_i": item_i,
-             } | {
+             }
+             | {
                  k: v
                  for k, v in data_all[campaign_id]["info"].items()
                  if k.startswith("protocol")
              },
-             "payload": data_all[campaign_id]["data"][user_id][item_i]
-         } | ({"payload_existing": payload_existing} if payload_existing else {}),
-         status_code=200
+             "payload": data_all[campaign_id]["data"][user_id][item_i],
+         }
+         | ({"payload_existing": payload_existing} if payload_existing else {}),
+         status_code=200,
      )


@@ -249,21 +279,176 @@ def get_next_item_singlestream(
              "progress": progress,
              "info": {
                  "item_i": item_i,
-             } | {
+             }
+             | {
                  k: v
                  for k, v in data_all[campaign_id]["info"].items()
                  if k.startswith("protocol")
              },
-             "payload": data_all[campaign_id]["data"][item_i]
-         } | ({"payload_existing": payload_existing} if payload_existing else {}),
-         status_code=200
+             "payload": data_all[campaign_id]["data"][item_i],
+         }
+         | ({"payload_existing": payload_existing} if payload_existing else {}),
+         status_code=200,
      )


+ def get_next_item_dynamic(
+     campaign_id: str,
+     user_id: str,
+     tasks_data: dict,
+     progress_data: dict,
+ ) -> JSONResponse:
+     """
+     Get the next item for dynamic assignment based on model performance.
+
+     NOTE: All items must contain all model outputs for this assignment type to work.
+
+     In this mode, items are selected based on the current performance of models:
+     1. Contrastive comparison: `dynamic_contrastive_models` models are randomly selected and shown per item
+     2. First phase: each model gets `dynamic_first` annotations with fully random selection
+     3. After first phase: top `dynamic_top` models are identified, K randomly selected from them
+     4. Items with the fewest annotations for the selected models are prioritized
+     5. With probability `dynamic_backoff`, uniformly random selection is used instead
+     """
+     import random
+
+     user_progress = progress_data[campaign_id][user_id]
+     campaign_data = tasks_data[campaign_id]

- def get_next_item_dynamic(campaign_data: dict, user_id: str, progress_data: dict, data_all: dict):
-     raise NotImplementedError("Dynamic protocol is not implemented yet.")
+     # Get all unique models in the campaign (all items must have all models)
+     all_models = list(set(campaign_data["data"][0][0]["tgt"].keys()))

+     # Check if completed (all models completed for all items)
+     # NOTE: this will rarely trigger but we don't have a good way to know when to end anyway for now
+     if all(len(v) == len(all_models) for v in user_progress["progress"]):
+         return _completed_response(tasks_data, progress_data, campaign_id, user_id)
+
+     # Get configuration parameters
+     dynamic_top = campaign_data["info"].get("dynamic_top", 2)
+     dynamic_first = campaign_data["info"].get("dynamic_first", 5)
+     dynamic_contrastive_models = campaign_data["info"].get(
+         "dynamic_contrastive_models", 1
+     )
+     dynamic_backoff = campaign_data["info"].get("dynamic_backoff", 0)
+
+     # Count annotations per (model, item) pair to track coverage
+     annotations = get_db_log(campaign_id)
+     model_item_counts = collections.defaultdict(int)  # (model, item_i) -> count
+     model_total_counts = collections.defaultdict(int)  # model -> total count
+
+     for annotation_line in annotations:
+         if (item_i := annotation_line.get("item_i")) is not None:
+             # Count which models were annotated in this annotation
+             for annotation_item in annotation_line.get("annotation", []):
+                 for model in annotation_item:
+                     model_item_counts[(model, item_i)] += 1
+                     model_total_counts[model] += 1
+
+     # Check if we're still in the first phase (collecting initial data)
+     in_first_phase = any(
+         model_total_counts.get(model, 0) < dynamic_first for model in all_models
+     )
+
+     # Select which models to show
+     if in_first_phase:
+         # First phase: select models that don't have enough annotations yet
+         selected_models = random.sample(
+             [
+                 model
+                 for model in all_models
+                 if model_total_counts.get(model, 0) < dynamic_first
+             ],
+             k=min(dynamic_contrastive_models, len(all_models)),
+         )
+     elif random.random() < dynamic_backoff:
+         # Backoff: select K models randomly from all models
+         selected_models = random.sample(
+             all_models, k=min(dynamic_contrastive_models, len(all_models))
+         )
+     else:
+         # Calculate model scores from annotations
+         model_scores = collections.defaultdict(list)
+         for annotation_line in annotations:
+             for annotation_item in annotation_line.get("annotation", {}):
+                 for model in annotation_item:
+                     if "score" in annotation_item[model]:
+                         model_scores[model].append(annotation_item[model]["score"])
+
+         # Calculate average scores
+         model_avg_scores = {
+             model: statistics.mean(scores) for model, scores in model_scores.items()
+         }
+
+         # Get top N models
+         sorted_models = sorted(
+             model_avg_scores.items(), key=lambda x: x[1], reverse=True
+         )
+         top_models = [model for model, score in sorted_models[:dynamic_top]]
+
+         # From top N, randomly select K models
+         selected_models = random.sample(
+             top_models, k=min(dynamic_contrastive_models, len(top_models))
+         )
+
+     # Find incomplete items for the selected models (items where not all selected models are done)
+     item_annotation_counts = {
+         i: sum(model in completed_models for model in selected_models)
+         for i, completed_models in enumerate(user_progress["progress"])
+     }
+
+     # Select item with minimum annotations (with random tiebreaking)
+     min_annotations = min(item_annotation_counts.values())
+     items_with_min = [
+         item_i
+         for item_i, count in item_annotation_counts.items()
+         if count == min_annotations
+     ]
+     item_i = random.choice(items_with_min)
+
+     # Prune the payload to only include selected models
+     original_item = campaign_data["data"][item_i]
+     pruned_item = []
+     for doc_segment in original_item:
+         pruned_segment = doc_segment.copy()
+         # Filter tgt to only include selected models
+         pruned_segment["tgt"] = {
+             model: doc_segment["tgt"][model]
+             for model in selected_models
+             if model in doc_segment["tgt"]
+         }
+         # Also filter error_spans if present
+         if "error_spans" in doc_segment:
+             pruned_segment["error_spans"] = {
+                 model: doc_segment["error_spans"][model]
+                 for model in selected_models
+                 if model in doc_segment.get("error_spans", {})
+             }
+         # Also filter validation if present
+         if "validation" in doc_segment:
+             pruned_segment["validation"] = {
+                 model: doc_segment["validation"][model]
+                 for model in selected_models
+                 if model in doc_segment.get("validation", {})
+             }
+         pruned_item.append(pruned_segment)
+
+     return JSONResponse(
+         content={
+             "status": "ok",
+             "time": user_progress["time"],
+             "progress": user_progress["progress"],
+             "info": {
+                 "item_i": item_i,
+             }
+             | {
+                 k: v
+                 for k, v in campaign_data["info"].items()
+                 if k.startswith("protocol")
+             },
+             "payload": pruned_item,
+         },
+         status_code=200,
+     )


  def _reset_user_time(progress_data: dict, campaign_id: str, user_id: str) -> None:
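To make the phases above concrete, here is a small self-contained walk-through of the model-selection step only; the model names, counts, and scores are invented, and the snippet mirrors rather than imports the logic in get_next_item_dynamic:

    # Hypothetical walk-through of the dynamic selection phases.
    import random
    import statistics

    all_models = ["modelA", "modelB", "modelC"]
    dynamic_first, dynamic_top, k = 5, 2, 1

    # Phase check: modelC is below dynamic_first, so selection stays random
    # among under-annotated models.
    model_total_counts = {"modelA": 5, "modelB": 5, "modelC": 3}
    in_first_phase = any(model_total_counts[m] < dynamic_first for m in all_models)
    assert in_first_phase

    # Once every model reaches dynamic_first annotations, ranking kicks in:
    model_scores = {"modelA": [80, 90], "modelB": [70, 75], "modelC": [60]}
    top_models = sorted(
        all_models, key=lambda m: statistics.mean(model_scores[m]), reverse=True
    )[:dynamic_top]
    print(random.sample(top_models, k=k))  # e.g. ['modelA']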
@@ -289,11 +474,10 @@ def reset_task(
          # Save reset marker for this user to mask existing annotations
          num_items = len(tasks_data[campaign_id]["data"][user_id])
          for item_i in range(num_items):
-             save_db_payload(campaign_id, {
-                 "user_id": user_id,
-                 "item_i": item_i,
-                 "annotation": RESET_MARKER
-             })
+             save_db_payload(
+                 campaign_id,
+                 {"user_id": user_id, "item_i": item_i, "annotation": RESET_MARKER},
+             )
          progress_data[campaign_id][user_id]["progress"] = [False] * num_items
          _reset_user_time(progress_data, campaign_id, user_id)
          return JSONResponse(content="ok", status_code=200)
@@ -301,18 +485,32 @@ def reset_task(
          # Save reset markers for all items (shared pool)
          num_items = len(tasks_data[campaign_id]["data"])
          for item_i in range(num_items):
-             save_db_payload(campaign_id, {
-                 "user_id": None,
-                 "item_i": item_i,
-                 "annotation": RESET_MARKER
-             })
+             save_db_payload(
+                 campaign_id,
+                 {"user_id": None, "item_i": item_i, "annotation": RESET_MARKER},
+             )
          # for single-stream reset all progress
          for uid in progress_data[campaign_id]:
              progress_data[campaign_id][uid]["progress"] = [False] * num_items
          _reset_user_time(progress_data, campaign_id, user_id)
          return JSONResponse(content="ok", status_code=200)
+     elif assignment == "dynamic":
+         # Save reset markers for all items (shared pool like single-stream)
+         num_items = len(tasks_data[campaign_id]["data"])
+         for item_i in range(num_items):
+             save_db_payload(
+                 campaign_id,
+                 {"user_id": None, "item_i": item_i, "annotation": RESET_MARKER},
+             )
+         # for dynamic, reset all progress (lists of annotated models per item)
+         for uid in progress_data[campaign_id]:
+             progress_data[campaign_id][uid]["progress"] = [[] for _ in range(num_items)]
+         _reset_user_time(progress_data, campaign_id, user_id)
+         return JSONResponse(content="ok", status_code=200)
      else:
-         return JSONResponse(content="Reset not supported for this assignment type", status_code=400)
+         return JSONResponse(
+             content="Reset not supported for this assignment type", status_code=400
+         )


  def update_progress(
@@ -337,6 +535,18 @@ def update_progress(
              progress_data[campaign_id][uid]["progress"][item_i] = True
          return JSONResponse(content="ok", status_code=200)
      elif assignment == "dynamic":
-         return JSONResponse(content="Dynamic protocol logging not implemented yet.", status_code=400)
+         # For dynamic, track which models were annotated
+         # Extract models from the payload annotation
+         annotated_models = []
+         if "annotation" in payload:
+             for annotation_item in payload.get("annotation", []):
+                 if isinstance(annotation_item, dict):
+                     annotated_models.extend(annotation_item.keys())
+
+         # Update progress for all users (shared pool)
+         for uid in progress_data[campaign_id]:
+             # Add the newly annotated models
+             progress_data[campaign_id][uid]["progress"][item_i].extend(annotated_models)
+         return JSONResponse(content="ok", status_code=200)
      else:
          return JSONResponse(content="Unknown campaign assignment type", status_code=400)
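Unlike the per-item booleans used by task-based and single-stream campaigns, dynamic progress stores a list of annotated model names per item, which update_progress extends in place. A sketch of the two shapes, with invented values:

    # Hypothetical progress structures, per assignment type.
    progress_task_based = [False, True, False]  # one flag per item
    progress_dynamic = [["modelA"], [], ["modelA", "modelB"]]  # models done per item

    # After an annotation of modelC on item 1, update_progress extends the list:
    progress_dynamic[1].extend(["modelC"])
    assert progress_dynamic[1] == ["modelC"]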
pearmut/cli.py CHANGED
@@ -179,7 +179,7 @@ def _shuffle_campaign_data(campaign_data, rng):
          for user_id, task in campaign_data["data"].items():
              for doc in task:
                  shuffle_document(doc)
-     elif assignment == "single-stream":
+     elif assignment in ["single-stream", "dynamic"]:
          # Shuffle each document in the shared pool
          for doc in campaign_data["data"]:
              shuffle_document(doc)
@@ -259,8 +259,46 @@ def _add_single_campaign(data_file, overwrite, server):
          else:
              raise ValueError("'users' must be an integer or a list.")
      elif assignment == "dynamic":
-         raise NotImplementedError(
-             "Dynamic campaign assignment is not yet implemented.")
+         tasks = campaign_data["data"]
+         if users_spec is None:
+             raise ValueError(
+                 "Dynamic campaigns must specify 'users' in info.")
+         if not isinstance(campaign_data["data"], list):
+             raise ValueError(
+                 "Dynamic campaign 'data' must be a list of items.")
+         # Validate item structure for dynamic
+         for doc_i, doc in enumerate(tasks):
+             try:
+                 _validate_item_structure(doc)
+             except ValueError as e:
+                 raise ValueError(f"Document {doc_i}: {e}")
+         if isinstance(users_spec, int):
+             num_users = users_spec
+         elif isinstance(users_spec, list):
+             num_users = len(users_spec)
+         else:
+             raise ValueError("'users' must be an integer or a list.")
+         # Apply defaults for the dynamic-specific parameters
+         if "dynamic_top" not in campaign_data["info"]:
+             campaign_data["info"]["dynamic_top"] = 2
+         if "dynamic_first" not in campaign_data["info"]:
+             campaign_data["info"]["dynamic_first"] = 5
+         if "dynamic_contrastive_models" not in campaign_data["info"]:
+             campaign_data["info"]["dynamic_contrastive_models"] = 1
+         # Validate that dynamic_first is at least 1
+         assert campaign_data["info"]["dynamic_first"] >= 1, "dynamic_first must be at least 1"
+         # Validate that dynamic_contrastive_models is at most dynamic_top
+         assert campaign_data["info"]["dynamic_contrastive_models"] <= campaign_data["info"]["dynamic_top"], \
+             "dynamic_contrastive_models must be at most dynamic_top"
+         # Validate that all items have the same models
+         all_models = set()
+         for item in campaign_data["data"]:
+             if item and len(item) > 0:
+                 all_models.update(item[0]["tgt"].keys())
+         for item in campaign_data["data"]:
+             if item and len(item) > 0:
+                 item_models = set(item[0]["tgt"].keys())
+                 assert item_models == all_models, "All items must have the same model outputs"
      else:
          raise ValueError(f"Unknown campaign assignment type: {assignment}")
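Taken together, these checks imply a campaign file shaped roughly as follows. A hypothetical dynamic campaign sketched as a Python dict; only the keys validated above are certain, while the src field and all values are illustrative. The server also reads an optional dynamic_backoff (defaulting to 0), which the CLI does not set:

    # Hypothetical dynamic-campaign definition; values are invented.
    campaign = {
        "campaign_id": "demo-dynamic",
        "info": {
            "assignment": "dynamic",
            "users": 3,                       # required for dynamic campaigns
            "dynamic_top": 2,                 # size of the top-model pool
            "dynamic_first": 5,               # per-model annotations before ranking
            "dynamic_contrastive_models": 1,  # must be <= dynamic_top
        },
        # Shared pool: a list of items, each a list of segments whose "tgt"
        # maps every model in the campaign to its output.
        "data": [
            [{"src": "Hello world.", "tgt": {"modelA": "Hallo Welt.", "modelB": "Hallo, Welt."}}],
            [{"src": "Good night.", "tgt": {"modelA": "Gute Nacht.", "modelB": "Guten Nacht."}}],
        ],
    }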
@@ -310,13 +348,13 @@ def _add_single_campaign(data_file, overwrite, server):
          os.remove(output_file)

      # For task-based, data is a dict mapping user_id -> tasks
-     # For single-stream, data is a flat list (shared among all users)
+     # For single-stream and dynamic, data is a flat list (shared among all users)
      if assignment == "task-based":
          campaign_data["data"] = {
              user_id: task
              for user_id, task in zip(user_ids, tasks)
          }
-     elif assignment == "single-stream":
+     elif assignment in ["single-stream", "dynamic"]:
          campaign_data["data"] = tasks

      # generate a token for dashboard access if not present
@@ -338,6 +376,7 @@ def _add_single_campaign(data_file, overwrite, server):
          "progress": (
              [False]*len(campaign_data["data"][user_id]) if assignment == "task-based"
              else [False]*len(campaign_data["data"]) if assignment == "single-stream"
+             else [list() for _ in range(len(campaign_data["data"]))] if assignment == "dynamic"
              else []
          ),
          "time_start": None,
@@ -421,9 +460,7 @@ def _add_single_campaign(data_file, overwrite, server):
          json.dump(campaign_data, f, indent=2, ensure_ascii=False)

      progress_data[campaign_data['campaign_id']] = user_progress
-
-     with open(f"{ROOT}/data/progress.json", "w") as f:
-         json.dump(progress_data, f, indent=2, ensure_ascii=False)
+     save_progress_data(progress_data)


      print(