pearmut 0.3.2__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
pearmut/app.py CHANGED
@@ -1,20 +1,23 @@
- import collections
  import json
  import os
- import statistics
  from typing import Any

  from fastapi import FastAPI, Query
  from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse
+ from fastapi.responses import JSONResponse, Response
  from fastapi.staticfiles import StaticFiles
  from pydantic import BaseModel

  from .assignment import get_i_item, get_next_item, reset_task, update_progress
+ from .results_export import (
+     compute_model_scores,
+     generate_latex_table,
+     generate_pdf,
+     generate_typst_table,
+ )
  from .utils import (
      ROOT,
      check_validation_threshold,
-     get_db_log,
      load_progress_data,
      save_db_payload,
      save_progress_data,
@@ -159,7 +162,7 @@ async def _dashboard_data(request: DashboardDataRequest):

      progress_new = {}
      assignment = tasks_data[campaign_id]["info"]["assignment"]
-     if assignment not in ["task-based", "single-stream"]:
+     if assignment not in ["task-based", "single-stream", "dynamic"]:
          return JSONResponse(
              content="Unsupported campaign assignment type", status_code=400
          )
@@ -211,31 +214,47 @@ async def _dashboard_results(request: DashboardResultsRequest):
      if token != tasks_data[campaign_id]["token"]:
          return JSONResponse(content="Invalid token", status_code=400)

-     # Compute model scores from annotations
-     model_scores = collections.defaultdict(dict)
-
-     # Iterate through all tasks to find items with 'models' field (basic template)
-     log = get_db_log(campaign_id)
-     for entry in log:
-         if "item" not in entry or "annotation" not in entry:
-             continue
-         for item, annotation in zip(entry["item"], entry["annotation"]):
-             for model, annotation in annotation.items():
-                 if "score" in annotation:
-                     model_scores[model][json.dumps(item)] = annotation["score"]
-
-     results = [
-         {
-             "model": model,
-             "score": statistics.mean(scores.values()),
-             "count": len(scores),
-         }
-         for model, scores in model_scores.items()
-     ]
-     results.sort(key=lambda x: x["score"], reverse=True)
+     results = compute_model_scores(campaign_id)
      return JSONResponse(content=results, status_code=200)


+ @app.get("/export-results")
+ async def _export_results(
+     campaign_id: str = Query(),
+     token: str = Query(),
+     format: str = Query(),
+ ):
+     if campaign_id not in progress_data:
+         return JSONResponse(content="Unknown campaign ID", status_code=400)
+
+     # Check if token is valid
+     if token != tasks_data[campaign_id]["token"]:
+         return JSONResponse(content="Invalid token", status_code=400)
+
+     results = compute_model_scores(campaign_id)
+
+     if format == "typst":
+         content = generate_typst_table(results)
+         return Response(
+             content=content,
+             media_type="text/plain",
+         )
+     elif format == "latex":
+         content = generate_latex_table(results)
+         return Response(
+             content=content,
+             media_type="text/plain",
+         )
+     elif format == "pdf":
+         pdf_bytes = generate_pdf(results, campaign_id)
+         return Response(
+             content=pdf_bytes,
+             media_type="application/pdf",
+         )
+     else:
+         return JSONResponse(content="Invalid export format", status_code=400)
+
+
  class ResetTaskRequest(BaseModel):
      campaign_id: str
      user_id: str
@@ -284,7 +303,9 @@ async def _download_annotations(
      return JSONResponse(
          content=output,
          status_code=200,
-         headers={"Content-Disposition": 'inline; filename="annotations.json"'},
+         headers={
+             "Content-Disposition": 'attachment; filename="annotations.json"',
+         },
      )


@@ -312,7 +333,9 @@ async def _download_progress(
      return JSONResponse(
          content=output,
          status_code=200,
-         headers={"Content-Disposition": 'inline; filename="progress.json"'},
+         headers={
+             "Content-Disposition": 'attachment; filename="progress.json"',
+         },
      )


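For reference, a minimal sketch of calling the new /export-results endpoint added above. Only the query parameters (campaign_id, token, and format, one of typst, latex, or pdf) come from the diff; the host, port, campaign ID, and token are placeholder assumptions.

# Sketch only: host, port, campaign ID, and token are invented placeholders.
import requests

resp = requests.get(
    "http://localhost:8000/export-results",
    params={"campaign_id": "my-campaign", "token": "my-token", "format": "pdf"},
)
resp.raise_for_status()
with open("results.pdf", "wb") as f:
    f.write(resp.content)  # application/pdf bytes produced by generate_pdf()
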
pearmut/assignment.py CHANGED
@@ -1,4 +1,6 @@
+ import collections
  import random
+ import statistics
  from typing import Any

  from fastapi.responses import JSONResponse
@@ -6,6 +8,7 @@ from fastapi.responses import JSONResponse
  from .utils import (
      RESET_MARKER,
      check_validation_threshold,
+     get_db_log,
      get_db_log_item,
      save_db_payload,
  )
@@ -20,14 +23,33 @@ def _completed_response(
      """Build a completed response with progress, time, and token."""
      user_progress = progress_data[campaign_id][user_id]
      is_ok = check_validation_threshold(tasks_data, progress_data, campaign_id, user_id)
+     token = user_progress["token_correct" if is_ok else "token_incorrect"]
+
+     # Get instructions_goodbye from campaign info, with default value
+     instructions_goodbye = tasks_data[campaign_id]["info"].get(
+         "instructions_goodbye",
+         "If someone asks you for a token of completion, show them: ${TOKEN}",
+     )
+
+     # Replace variables ${TOKEN} and ${USER_ID}
+     instructions_goodbye = instructions_goodbye.replace("${TOKEN}", token).replace(
+         "${USER_ID}", user_id
+     )
+
+     # Convert sets to lists for JSON serialization (for dynamic assignment)
+     progress = user_progress["progress"]
+     if progress and isinstance(progress[0], set):
+         progress = [list(s) for s in progress]
+
      return JSONResponse(
          content={
-             "status": "completed",
-             "progress": user_progress["progress"],
+             "status": "goodbye",
+             "progress": progress,
              "time": user_progress["time"],
-             "token": user_progress["token_correct" if is_ok else "token_incorrect"],
+             "token": token,
+             "instructions_goodbye": instructions_goodbye,
          },
-         status_code=200
+         status_code=200,
      )


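The goodbye-message templating added above is plain string substitution. A minimal illustration follows; the template is the default from the diff, while the token and user ID values are invented.

# Illustration of the ${TOKEN}/${USER_ID} substitution; values are invented.
template = "If someone asks you for a token of completion, show them: ${TOKEN}"
message = template.replace("${TOKEN}", "tok-123").replace("${USER_ID}", "annotator-7")
print(message)  # If someone asks you for a token of completion, show them: tok-123
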
@@ -44,7 +66,9 @@ def get_next_item(
      if assignment == "task-based":
          return get_next_item_taskbased(campaign_id, user_id, tasks_data, progress_data)
      elif assignment == "single-stream":
-         return get_next_item_singlestream(campaign_id, user_id, tasks_data, progress_data)
+         return get_next_item_singlestream(
+             campaign_id, user_id, tasks_data, progress_data
+         )
      elif assignment == "dynamic":
          return get_next_item_dynamic(campaign_id, user_id, tasks_data, progress_data)
      else:
@@ -63,11 +87,17 @@ def get_i_item(
      """
      assignment = tasks_data[campaign_id]["info"]["assignment"]
      if assignment == "task-based":
-         return get_i_item_taskbased(campaign_id, user_id, tasks_data, progress_data, item_i)
+         return get_i_item_taskbased(
+             campaign_id, user_id, tasks_data, progress_data, item_i
+         )
      elif assignment == "single-stream":
-         return get_i_item_singlestream(campaign_id, user_id, tasks_data, progress_data, item_i)
+         return get_i_item_singlestream(
+             campaign_id, user_id, tasks_data, progress_data, item_i
+         )
      else:
-         return JSONResponse(content="Get item not supported for this assignment type", status_code=400)
+         return JSONResponse(
+             content="Get item not supported for this assignment type", status_code=400
+         )


  def get_i_item_taskbased(
@@ -93,10 +123,7 @@ def get_i_item_taskbased(
          payload_existing["comment"] = latest_item["comment"]

      if item_i < 0 or item_i >= len(data_all[campaign_id]["data"][user_id]):
-         return JSONResponse(
-             content="Item index out of range",
-             status_code=400
-         )
+         return JSONResponse(content="Item index out of range", status_code=400)

      return JSONResponse(
          content={
@@ -105,14 +132,16 @@
              "time": user_progress["time"],
              "info": {
                  "item_i": item_i,
-             } | {
+             }
+             | {
                  k: v
                  for k, v in data_all[campaign_id]["info"].items()
                  if k.startswith("protocol")
              },
-             "payload": data_all[campaign_id]["data"][user_id][item_i]
-         } | ({"payload_existing": payload_existing} if payload_existing else {}),
-         status_code=200
+             "payload": data_all[campaign_id]["data"][user_id][item_i],
+         }
+         | ({"payload_existing": payload_existing} if payload_existing else {}),
+         status_code=200,
      )


@@ -140,10 +169,7 @@ def get_i_item_singlestream(
          payload_existing["comment"] = latest_item["comment"]

      if item_i < 0 or item_i >= len(data_all[campaign_id]["data"]):
-         return JSONResponse(
-             content="Item index out of range",
-             status_code=400
-         )
+         return JSONResponse(content="Item index out of range", status_code=400)

      return JSONResponse(
          content={
@@ -152,14 +178,16 @@
              "time": user_progress["time"],
              "info": {
                  "item_i": item_i,
-             } | {
+             }
+             | {
                  k: v
                  for k, v in data_all[campaign_id]["info"].items()
                  if k.startswith("protocol")
              },
-             "payload": data_all[campaign_id]["data"][item_i]
-         } | ({"payload_existing": payload_existing} if payload_existing else {}),
-         status_code=200
+             "payload": data_all[campaign_id]["data"][item_i],
+         }
+         | ({"payload_existing": payload_existing} if payload_existing else {}),
+         status_code=200,
      )


@@ -196,14 +224,16 @@ def get_next_item_taskbased(
              "time": user_progress["time"],
              "info": {
                  "item_i": item_i,
-             } | {
+             }
+             | {
                  k: v
                  for k, v in data_all[campaign_id]["info"].items()
                  if k.startswith("protocol")
              },
-             "payload": data_all[campaign_id]["data"][user_id][item_i]
-         } | ({"payload_existing": payload_existing} if payload_existing else {}),
-         status_code=200
+             "payload": data_all[campaign_id]["data"][user_id][item_i],
+         }
+         | ({"payload_existing": payload_existing} if payload_existing else {}),
+         status_code=200,
      )


@@ -249,21 +279,176 @@ def get_next_item_singlestream(
              "progress": progress,
              "info": {
                  "item_i": item_i,
-             } | {
+             }
+             | {
                  k: v
                  for k, v in data_all[campaign_id]["info"].items()
                  if k.startswith("protocol")
              },
-             "payload": data_all[campaign_id]["data"][item_i]
-         } | ({"payload_existing": payload_existing} if payload_existing else {}),
-         status_code=200
+             "payload": data_all[campaign_id]["data"][item_i],
+         }
+         | ({"payload_existing": payload_existing} if payload_existing else {}),
+         status_code=200,
      )


+ def get_next_item_dynamic(
+     campaign_id: str,
+     user_id: str,
+     tasks_data: dict,
+     progress_data: dict,
+ ) -> JSONResponse:
+     """
+     Get the next item for dynamic assignment based on model performance.
+
+     NOTE: All items must contain all model outputs for this assignment type to work.
+
+     In this mode, items are selected based on the current performance of models:
+     1. Contrastive comparison: `dynamic_contrastive_models` models are randomly selected and shown per item
+     2. First phase: Each model gets `dynamic_first` annotations with fully random selection
+     3. After first phase: Top `dynamic_top` models are identified, K randomly selected from them
+     4. Items with least annotations for the selected models are prioritized
+     5. With probability `dynamic_backoff`, uniformly random selection is used instead
+     """
+     import random
+
+     user_progress = progress_data[campaign_id][user_id]
+     campaign_data = tasks_data[campaign_id]

- def get_next_item_dynamic(campaign_data: dict, user_id: str, progress_data: dict, data_all: dict):
-     raise NotImplementedError("Dynamic protocol is not implemented yet.")
+     # Get all unique models in the campaign (all items must have all models)
+     all_models = list(set(campaign_data["data"][0][0]["tgt"].keys()))

+     # Check if completed (all models completed for all items)
+     # NOTE: this will rarely trigger but we don't have a good way to know when to end anyway for now
+     if all(len(v) == len(all_models) for v in user_progress["progress"]):
+         return _completed_response(tasks_data, progress_data, campaign_id, user_id)
+
+     # Get configuration parameters
+     dynamic_top = campaign_data["info"].get("dynamic_top", 2)
+     dynamic_first = campaign_data["info"].get("dynamic_first", 5)
+     dynamic_contrastive_models = campaign_data["info"].get(
+         "dynamic_contrastive_models", 1
+     )
+     dynamic_backoff = campaign_data["info"].get("dynamic_backoff", 0)
+
+     # Count annotations per (model, item) pair to track coverage
+     annotations = get_db_log(campaign_id)
+     model_item_counts = collections.defaultdict(int)  # (model, item_i) -> count
+     model_total_counts = collections.defaultdict(int)  # model -> total count
+
+     for annotation_line in annotations:
+         if (item_i := annotation_line.get("item_i")) is not None:
+             # Count which models were annotated in this annotation
+             for annotation_item in annotation_line.get("annotation", []):
+                 for model in annotation_item:
+                     model_item_counts[(model, item_i)] += 1
+                     model_total_counts[model] += 1
+
+     # Check if we're still in the first phase (collecting initial data)
+     in_first_phase = any(
+         model_total_counts.get(model, 0) < dynamic_first for model in all_models
+     )
+
+     # Select which models to show
+     if in_first_phase:
+         # First phase or backoff: select models that don't have enough annotations yet
+         selected_models = random.sample(
+             [
+                 model
+                 for model in all_models
+                 if model_total_counts.get(model, 0) < dynamic_first
+             ],
+             k=min(dynamic_contrastive_models, len(all_models)),
+         )
+     elif random.random() < dynamic_backoff:
+         # Backoff: select K models randomly from all models
+         selected_models = random.sample(
+             all_models, k=min(dynamic_contrastive_models, len(all_models))
+         )
+     else:
+         # Calculate model scores from annotations
+         model_scores = collections.defaultdict(list)
+         for annotation_line in annotations:
+             for annotation_item in annotation_line.get("annotation", {}):
+                 for model in annotation_item:
+                     if "score" in annotation_item[model]:
+                         model_scores[model].append(annotation_item[model]["score"])
+
+         # Calculate average scores
+         model_avg_scores = {
+             model: statistics.mean(scores) for model, scores in model_scores.items()
+         }
+
+         # Get top N models
+         sorted_models = sorted(
+             model_avg_scores.items(), key=lambda x: x[1], reverse=True
+         )
+         top_models = [model for model, score in sorted_models[:dynamic_top]]
+
+         # From top N, randomly select K models
+         selected_models = random.sample(
+             top_models, k=min(dynamic_contrastive_models, len(top_models))
+         )
+
+     # Find incomplete items for the selected models (items where not all selected models are done)
+     item_annotation_counts = {
+         i: sum(model in completed_models for model in selected_models)
+         for i, completed_models in enumerate(user_progress["progress"])
+     }
+
+     # Select item with minimum annotations (with random tiebreaking)
+     min_annotations = min(item_annotation_counts.values())
+     items_with_min = [
+         item_i
+         for item_i, count in item_annotation_counts.items()
+         if count == min_annotations
+     ]
+     item_i = random.choice(items_with_min)
+
+     # Prune the payload to only include selected models
+     original_item = campaign_data["data"][item_i]
+     pruned_item = []
+     for doc_segment in original_item:
+         pruned_segment = doc_segment.copy()
+         # Filter tgt to only include selected models
+         pruned_segment["tgt"] = {
+             model: doc_segment["tgt"][model]
+             for model in selected_models
+             if model in doc_segment["tgt"]
+         }
+         # Also filter error_spans if present
+         if "error_spans" in doc_segment:
+             pruned_segment["error_spans"] = {
+                 model: doc_segment["error_spans"][model]
+                 for model in selected_models
+                 if model in doc_segment.get("error_spans", {})
+             }
+         # Also filter validation if present
+         if "validation" in doc_segment:
+             pruned_segment["validation"] = {
+                 model: doc_segment["validation"][model]
+                 for model in selected_models
+                 if model in doc_segment.get("validation", {})
+             }
+         pruned_item.append(pruned_segment)
+
+     return JSONResponse(
+         content={
+             "status": "ok",
+             "time": user_progress["time"],
+             "progress": user_progress["progress"],
+             "info": {
+                 "item_i": item_i,
+             }
+             | {
+                 k: v
+                 for k, v in campaign_data["info"].items()
+                 if k.startswith("protocol")
+             },
+             "payload": pruned_item,
+         },
+         status_code=200,
+     )


  def _reset_user_time(progress_data: dict, campaign_id: str, user_id: str) -> None:
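The counting loop in get_next_item_dynamic implies a rough shape for each get_db_log() entry; this is inferred from the .get() lookups above, not a documented schema, and the model names and scores are invented.

# Inferred shape of one log entry, based on the lookups above.
annotation_line = {
    "item_i": 3,
    "annotation": [
        {"model-a": {"score": 80}, "model-b": {"score": 65}},
    ],
}
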
@@ -289,11 +474,10 @@ def reset_task(
          # Save reset marker for this user to mask existing annotations
          num_items = len(tasks_data[campaign_id]["data"][user_id])
          for item_i in range(num_items):
-             save_db_payload(campaign_id, {
-                 "user_id": user_id,
-                 "item_i": item_i,
-                 "annotation": RESET_MARKER
-             })
+             save_db_payload(
+                 campaign_id,
+                 {"user_id": user_id, "item_i": item_i, "annotation": RESET_MARKER},
+             )
          progress_data[campaign_id][user_id]["progress"] = [False] * num_items
          _reset_user_time(progress_data, campaign_id, user_id)
          return JSONResponse(content="ok", status_code=200)
@@ -301,18 +485,32 @@
          # Save reset markers for all items (shared pool)
          num_items = len(tasks_data[campaign_id]["data"])
          for item_i in range(num_items):
-             save_db_payload(campaign_id, {
-                 "user_id": None,
-                 "item_i": item_i,
-                 "annotation": RESET_MARKER
-             })
+             save_db_payload(
+                 campaign_id,
+                 {"user_id": None, "item_i": item_i, "annotation": RESET_MARKER},
+             )
          # for single-stream reset all progress
          for uid in progress_data[campaign_id]:
              progress_data[campaign_id][uid]["progress"] = [False] * num_items
          _reset_user_time(progress_data, campaign_id, user_id)
          return JSONResponse(content="ok", status_code=200)
+     elif assignment == "dynamic":
+         # Save reset markers for all items (shared pool like single-stream)
+         num_items = len(tasks_data[campaign_id]["data"])
+         for item_i in range(num_items):
+             save_db_payload(
+                 campaign_id,
+                 {"user_id": None, "item_i": item_i, "annotation": RESET_MARKER},
+             )
+         # for dynamic reset all progress (use sets to track models)
+         for uid in progress_data[campaign_id]:
+             progress_data[campaign_id][uid]["progress"] = [[] for _ in range(num_items)]
+         _reset_user_time(progress_data, campaign_id, user_id)
+         return JSONResponse(content="ok", status_code=200)
      else:
-         return JSONResponse(content="Reset not supported for this assignment type", status_code=400)
+         return JSONResponse(
+             content="Reset not supported for this assignment type", status_code=400
+         )


  def update_progress(
@@ -337,6 +535,18 @@ def update_progress(
              progress_data[campaign_id][uid]["progress"][item_i] = True
          return JSONResponse(content="ok", status_code=200)
      elif assignment == "dynamic":
-         return JSONResponse(content="Dynamic protocol logging not implemented yet.", status_code=400)
+         # For dynamic, track which models were annotated
+         # Extract models from the payload annotation
+         annotated_models = []
+         if "annotation" in payload:
+             for annotation_item in payload.get("annotation", []):
+                 if isinstance(annotation_item, dict):
+                     annotated_models.extend(annotation_item.keys())
+
+         # Update progress for all users (shared pool)
+         for uid in progress_data[campaign_id]:
+             # Add the newly annotated models
+             progress_data[campaign_id][uid]["progress"][item_i].extend(annotated_models)
+         return JSONResponse(content="ok", status_code=200)
      else:
          return JSONResponse(content="Unknown campaign assignment type", status_code=400)