pearmut 0.3.3__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pearmut/app.py +46 -27
- pearmut/assignment.py +256 -46
- pearmut/cli.py +45 -8
- pearmut/results_export.py +210 -0
- pearmut/static/basic.bundle.js +1 -1
- pearmut/static/basic.html +1 -1
- pearmut/static/dashboard.bundle.js +1 -1
- pearmut/static/dashboard.html +27 -12
- pearmut/static/index.bundle.js +1 -1
- pearmut/static/index.html +1 -1
- pearmut/utils.py +16 -2
- {pearmut-0.3.3.dist-info → pearmut-1.0.0.dist-info}/METADATA +54 -26
- pearmut-1.0.0.dist-info/RECORD +19 -0
- pearmut-0.3.3.dist-info/RECORD +0 -18
- {pearmut-0.3.3.dist-info → pearmut-1.0.0.dist-info}/WHEEL +0 -0
- {pearmut-0.3.3.dist-info → pearmut-1.0.0.dist-info}/entry_points.txt +0 -0
- {pearmut-0.3.3.dist-info → pearmut-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {pearmut-0.3.3.dist-info → pearmut-1.0.0.dist-info}/top_level.txt +0 -0
pearmut/app.py
CHANGED
@@ -1,20 +1,23 @@
-import collections
 import json
 import os
-import statistics
 from typing import Any
 
 from fastapi import FastAPI, Query
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, Response
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
 
 from .assignment import get_i_item, get_next_item, reset_task, update_progress
+from .results_export import (
+    compute_model_scores,
+    generate_latex_table,
+    generate_pdf,
+    generate_typst_table,
+)
 from .utils import (
     ROOT,
     check_validation_threshold,
-    get_db_log,
     load_progress_data,
     save_db_payload,
     save_progress_data,
@@ -159,7 +162,7 @@ async def _dashboard_data(request: DashboardDataRequest):
 
     progress_new = {}
     assignment = tasks_data[campaign_id]["info"]["assignment"]
-    if assignment not in ["task-based", "single-stream"]:
+    if assignment not in ["task-based", "single-stream", "dynamic"]:
         return JSONResponse(
             content="Unsupported campaign assignment type", status_code=400
         )
@@ -211,31 +214,47 @@ async def _dashboard_results(request: DashboardResultsRequest):
     if token != tasks_data[campaign_id]["token"]:
         return JSONResponse(content="Invalid token", status_code=400)
 
-
-    model_scores = collections.defaultdict(dict)
-
-    # Iterate through all tasks to find items with 'models' field (basic template)
-    log = get_db_log(campaign_id)
-    for entry in log:
-        if "item" not in entry or "annotation" not in entry:
-            continue
-        for item, annotation in zip(entry["item"], entry["annotation"]):
-            for model, annotation in annotation.items():
-                if "score" in annotation and annotation["score"] is not None:
-                    model_scores[model][json.dumps(item)] = annotation["score"]
-
-    results = [
-        {
-            "model": model,
-            "score": statistics.mean(scores.values()),
-            "count": len(scores),
-        }
-        for model, scores in model_scores.items()
-    ]
-    results.sort(key=lambda x: x["score"], reverse=True)
+    results = compute_model_scores(campaign_id)
     return JSONResponse(content=results, status_code=200)
 
 
+@app.get("/export-results")
+async def _export_results(
+    campaign_id: str = Query(),
+    token: str = Query(),
+    format: str = Query(),
+):
+    if campaign_id not in progress_data:
+        return JSONResponse(content="Unknown campaign ID", status_code=400)
+
+    # Check if token is valid
+    if token != tasks_data[campaign_id]["token"]:
+        return JSONResponse(content="Invalid token", status_code=400)
+
+    results = compute_model_scores(campaign_id)
+
+    if format == "typst":
+        content = generate_typst_table(results)
+        return Response(
+            content=content,
+            media_type="text/plain",
+        )
+    elif format == "latex":
+        content = generate_latex_table(results)
+        return Response(
+            content=content,
+            media_type="text/plain",
+        )
+    elif format == "pdf":
+        pdf_bytes = generate_pdf(results, campaign_id)
+        return Response(
+            content=pdf_bytes,
+            media_type="application/pdf",
+        )
+    else:
+        return JSONResponse(content="Invalid export format", status_code=400)
+
+
 class ResetTaskRequest(BaseModel):
     campaign_id: str
     user_id: str
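The new /export-results endpoint serves the aggregated model scores as Typst source, LaTeX source, or a rendered PDF, selected by the format query parameter. A minimal client sketch in Python (the host/port and the placeholder campaign ID and token are assumptions, not part of the release):

    import requests

    # format="latex" returns text/plain LaTeX source; "typst" is also text/plain,
    # "pdf" returns application/pdf bytes. Host/port are assumed values.
    resp = requests.get(
        "http://localhost:8000/export-results",
        params={"campaign_id": "my-campaign", "token": "DASHBOARD_TOKEN", "format": "latex"},
    )
    print(resp.text)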
pearmut/assignment.py
CHANGED
@@ -1,4 +1,6 @@
+import collections
 import random
+import statistics
 from typing import Any
 
 from fastapi.responses import JSONResponse
@@ -6,6 +8,7 @@ from fastapi.responses import JSONResponse
 from .utils import (
     RESET_MARKER,
     check_validation_threshold,
+    get_db_log,
     get_db_log_item,
     save_db_payload,
 )
@@ -20,14 +23,33 @@ def _completed_response(
     """Build a completed response with progress, time, and token."""
     user_progress = progress_data[campaign_id][user_id]
     is_ok = check_validation_threshold(tasks_data, progress_data, campaign_id, user_id)
+    token = user_progress["token_correct" if is_ok else "token_incorrect"]
+
+    # Get instructions_goodbye from campaign info, with default value
+    instructions_goodbye = tasks_data[campaign_id]["info"].get(
+        "instructions_goodbye",
+        "If someone asks you for a token of completion, show them: ${TOKEN}",
+    )
+
+    # Replace variables ${TOKEN} and ${USER_ID}
+    instructions_goodbye = instructions_goodbye.replace("${TOKEN}", token).replace(
+        "${USER_ID}", user_id
+    )
+
+    # Convert sets to lists for JSON serialization (for dynamic assignment)
+    progress = user_progress["progress"]
+    if progress and isinstance(progress[0], set):
+        progress = [list(s) for s in progress]
+
     return JSONResponse(
         content={
-            "status": "
-            "progress":
+            "status": "goodbye",
+            "progress": progress,
             "time": user_progress["time"],
-            "token":
+            "token": token,
+            "instructions_goodbye": instructions_goodbye,
         },
-        status_code=200
+        status_code=200,
     )
 
 
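The completion response now carries a configurable instructions_goodbye message in which ${TOKEN} and ${USER_ID} are filled in with plain str.replace, as the hunk above shows. A quick illustration (the token and user ID values are invented):

    template = "If someone asks you for a token of completion, show them: ${TOKEN}"
    print(template.replace("${TOKEN}", "a1b2c3").replace("${USER_ID}", "annotator0"))
    # -> If someone asks you for a token of completion, show them: a1b2c3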
@@ -44,7 +66,9 @@ def get_next_item(
     if assignment == "task-based":
         return get_next_item_taskbased(campaign_id, user_id, tasks_data, progress_data)
     elif assignment == "single-stream":
-        return get_next_item_singlestream(
+        return get_next_item_singlestream(
+            campaign_id, user_id, tasks_data, progress_data
+        )
     elif assignment == "dynamic":
         return get_next_item_dynamic(campaign_id, user_id, tasks_data, progress_data)
     else:
@@ -63,11 +87,17 @@ def get_i_item(
     """
     assignment = tasks_data[campaign_id]["info"]["assignment"]
     if assignment == "task-based":
-        return get_i_item_taskbased(
+        return get_i_item_taskbased(
+            campaign_id, user_id, tasks_data, progress_data, item_i
+        )
     elif assignment == "single-stream":
-        return get_i_item_singlestream(
+        return get_i_item_singlestream(
+            campaign_id, user_id, tasks_data, progress_data, item_i
+        )
     else:
-        return JSONResponse(
+        return JSONResponse(
+            content="Get item not supported for this assignment type", status_code=400
+        )
 
 
 def get_i_item_taskbased(
@@ -93,10 +123,7 @@ def get_i_item_taskbased(
             payload_existing["comment"] = latest_item["comment"]
 
     if item_i < 0 or item_i >= len(data_all[campaign_id]["data"][user_id]):
-        return JSONResponse(
-            content="Item index out of range",
-            status_code=400
-        )
+        return JSONResponse(content="Item index out of range", status_code=400)
 
     return JSONResponse(
         content={
@@ -105,14 +132,16 @@ def get_i_item_taskbased(
             "time": user_progress["time"],
             "info": {
                 "item_i": item_i,
-            }
+            }
+            | {
                 k: v
                 for k, v in data_all[campaign_id]["info"].items()
                 if k.startswith("protocol")
             },
-            "payload": data_all[campaign_id]["data"][user_id][item_i]
-        }
-
+            "payload": data_all[campaign_id]["data"][user_id][item_i],
+        }
+        | ({"payload_existing": payload_existing} if payload_existing else {}),
+        status_code=200,
     )
 
 
@@ -140,10 +169,7 @@ def get_i_item_singlestream(
             payload_existing["comment"] = latest_item["comment"]
 
     if item_i < 0 or item_i >= len(data_all[campaign_id]["data"]):
-        return JSONResponse(
-            content="Item index out of range",
-            status_code=400
-        )
+        return JSONResponse(content="Item index out of range", status_code=400)
 
     return JSONResponse(
         content={
@@ -152,14 +178,16 @@ def get_i_item_singlestream(
             "time": user_progress["time"],
             "info": {
                 "item_i": item_i,
-            }
+            }
+            | {
                 k: v
                 for k, v in data_all[campaign_id]["info"].items()
                 if k.startswith("protocol")
             },
-            "payload": data_all[campaign_id]["data"][item_i]
-        }
-
+            "payload": data_all[campaign_id]["data"][item_i],
+        }
+        | ({"payload_existing": payload_existing} if payload_existing else {}),
+        status_code=200,
     )
 
 
@@ -196,14 +224,16 @@ def get_next_item_taskbased(
             "time": user_progress["time"],
             "info": {
                 "item_i": item_i,
-            }
+            }
+            | {
                 k: v
                 for k, v in data_all[campaign_id]["info"].items()
                 if k.startswith("protocol")
             },
-            "payload": data_all[campaign_id]["data"][user_id][item_i]
-        }
-
+            "payload": data_all[campaign_id]["data"][user_id][item_i],
+        }
+        | ({"payload_existing": payload_existing} if payload_existing else {}),
+        status_code=200,
     )
 
 
@@ -249,21 +279,176 @@ def get_next_item_singlestream(
             "progress": progress,
             "info": {
                 "item_i": item_i,
-            }
+            }
+            | {
                 k: v
                 for k, v in data_all[campaign_id]["info"].items()
                 if k.startswith("protocol")
             },
-            "payload": data_all[campaign_id]["data"][item_i]
-        }
-
+            "payload": data_all[campaign_id]["data"][item_i],
+        }
+        | ({"payload_existing": payload_existing} if payload_existing else {}),
+        status_code=200,
     )
 
 
+def get_next_item_dynamic(
+    campaign_id: str,
+    user_id: str,
+    tasks_data: dict,
+    progress_data: dict,
+) -> JSONResponse:
+    """
+    Get the next item for dynamic assignment based on model performance.
+
+    NOTE: All items must contain all model outputs for this assignment type to work.
+
+    In this mode, items are selected based on the current performance of models:
+    1. Contrastive comparison: `dynamic_contrastive_models` models are randomly selected and shown per item
+    2. First phase: Each model gets `dynamic_first` annotations with fully random selection
+    3. After first phase: Top `dynamic_top` models are identified, K randomly selected from them
+    4. Items with least annotations for the selected models are prioritized
+    5. With probability `dynamic_backoff`, uniformly random selection is used instead
+    """
+    import random
+
+    user_progress = progress_data[campaign_id][user_id]
+    campaign_data = tasks_data[campaign_id]
 
-
-
+    # Get all unique models in the campaign (all items must have all models)
+    all_models = list(set(campaign_data["data"][0][0]["tgt"].keys()))
 
+    # Check if completed (all models completed for all items)
+    # NOTE: this will rarely trigger but we don't have a good way to know when to end anyway for now
+    if all(len(v) == len(all_models) for v in user_progress["progress"]):
+        return _completed_response(tasks_data, progress_data, campaign_id, user_id)
+
+    # Get configuration parameters
+    dynamic_top = campaign_data["info"].get("dynamic_top", 2)
+    dynamic_first = campaign_data["info"].get("dynamic_first", 5)
+    dynamic_contrastive_models = campaign_data["info"].get(
+        "dynamic_contrastive_models", 1
+    )
+    dynamic_backoff = campaign_data["info"].get("dynamic_backoff", 0)
+
+    # Count annotations per (model, item) pair to track coverage
+    annotations = get_db_log(campaign_id)
+    model_item_counts = collections.defaultdict(int)  # (model, item_i) -> count
+    model_total_counts = collections.defaultdict(int)  # model -> total count
+
+    for annotation_line in annotations:
+        if (item_i := annotation_line.get("item_i")) is not None:
+            # Count which models were annotated in this annotation
+            for annotation_item in annotation_line.get("annotation", []):
+                for model in annotation_item:
+                    model_item_counts[(model, item_i)] += 1
+                    model_total_counts[model] += 1
+
+    # Check if we're still in the first phase (collecting initial data)
+    in_first_phase = any(
+        model_total_counts.get(model, 0) < dynamic_first for model in all_models
+    )
+
+    # Select which models to show
+    if in_first_phase:
+        # First phase or backoff: select models that don't have enough annotations yet
+        selected_models = random.sample(
+            [
+                model
+                for model in all_models
+                if model_total_counts.get(model, 0) < dynamic_first
+            ],
+            k=min(dynamic_contrastive_models, len(all_models)),
+        )
+    elif random.random() < dynamic_backoff:
+        # Backoff: select K models randomly from all models
+        selected_models = random.sample(
+            all_models, k=min(dynamic_contrastive_models, len(all_models))
+        )
+    else:
+        # Calculate model scores from annotations
+        model_scores = collections.defaultdict(list)
+        for annotation_line in annotations:
+            for annotation_item in annotation_line.get("annotation", {}):
+                for model in annotation_item:
+                    if "score" in annotation_item[model]:
+                        model_scores[model].append(annotation_item[model]["score"])
+
+        # Calculate average scores
+        model_avg_scores = {
+            model: statistics.mean(scores) for model, scores in model_scores.items()
+        }
+
+        # Get top N models
+        sorted_models = sorted(
+            model_avg_scores.items(), key=lambda x: x[1], reverse=True
+        )
+        top_models = [model for model, score in sorted_models[:dynamic_top]]
+
+        # From top N, randomly select K models
+        selected_models = random.sample(
+            top_models, k=min(dynamic_contrastive_models, len(top_models))
+        )
+
+    # Find incomplete items for the selected models (items where not all selected models are done)
+    item_annotation_counts = {
+        i: sum(model in completed_models for model in selected_models)
+        for i, completed_models in enumerate(user_progress["progress"])
+    }
+
+    # Select item with minimum annotations (with random tiebreaking)
+    min_annotations = min(item_annotation_counts.values())
+    items_with_min = [
+        item_i
+        for item_i, count in item_annotation_counts.items()
+        if count == min_annotations
+    ]
+    item_i = random.choice(items_with_min)
+
+    # Prune the payload to only include selected models
+    original_item = campaign_data["data"][item_i]
+    pruned_item = []
+    for doc_segment in original_item:
+        pruned_segment = doc_segment.copy()
+        # Filter tgt to only include selected models
+        pruned_segment["tgt"] = {
+            model: doc_segment["tgt"][model]
+            for model in selected_models
+            if model in doc_segment["tgt"]
+        }
+        # Also filter error_spans if present
+        if "error_spans" in doc_segment:
+            pruned_segment["error_spans"] = {
+                model: doc_segment["error_spans"][model]
+                for model in selected_models
+                if model in doc_segment.get("error_spans", {})
+            }
+        # Also filter validation if present
+        if "validation" in doc_segment:
+            pruned_segment["validation"] = {
+                model: doc_segment["validation"][model]
+                for model in selected_models
+                if model in doc_segment.get("validation", {})
+            }
+        pruned_item.append(pruned_segment)
+
+    return JSONResponse(
+        content={
+            "status": "ok",
+            "time": user_progress["time"],
+            "progress": user_progress["progress"],
+            "info": {
+                "item_i": item_i,
+            }
+            | {
                k: v
                for k, v in campaign_data["info"].items()
                if k.startswith("protocol")
+            },
+            "payload": pruned_item,
+        },
+        status_code=200,
    )
 
 
 def _reset_user_time(progress_data: dict, campaign_id: str, user_id: str) -> None:
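The model-selection policy in get_next_item_dynamic has three branches: while any model has fewer than dynamic_first annotations, sample among the under-annotated models; otherwise, with probability dynamic_backoff, sample uniformly from all models; otherwise sample from the dynamic_top models ranked by mean score. A standalone sketch of just that policy under invented counts and scores (not the packaged function itself):

    import random, statistics

    def select_models(all_models, total_counts, scores, first=5, top=2, k=1, backoff=0.0):
        under = [m for m in all_models if total_counts.get(m, 0) < first]
        if under:  # first phase: some model still lacks its initial annotations
            return random.sample(under, k=min(k, len(under)))
        if random.random() < backoff:  # occasional uniform exploration
            return random.sample(all_models, k=min(k, len(all_models)))
        avg = {m: statistics.mean(s) for m, s in scores.items()}
        best = [m for m, _ in sorted(avg.items(), key=lambda x: x[1], reverse=True)[:top]]
        return random.sample(best, k=min(k, len(best)))

    # With every model past the first phase, only the top-2 by score ("c", "a") can be picked.
    print(select_models(["a", "b", "c"], {"a": 5, "b": 5, "c": 5},
                        {"a": [80], "b": [60], "c": [90]}))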
|
|
289
474
|
# Save reset marker for this user to mask existing annotations
|
|
290
475
|
num_items = len(tasks_data[campaign_id]["data"][user_id])
|
|
291
476
|
for item_i in range(num_items):
|
|
292
|
-
save_db_payload(
|
|
293
|
-
|
|
294
|
-
"item_i": item_i,
|
|
295
|
-
|
|
296
|
-
})
|
|
477
|
+
save_db_payload(
|
|
478
|
+
campaign_id,
|
|
479
|
+
{"user_id": user_id, "item_i": item_i, "annotation": RESET_MARKER},
|
|
480
|
+
)
|
|
297
481
|
progress_data[campaign_id][user_id]["progress"] = [False] * num_items
|
|
298
482
|
_reset_user_time(progress_data, campaign_id, user_id)
|
|
299
483
|
return JSONResponse(content="ok", status_code=200)
|
|
@@ -301,18 +485,32 @@ def reset_task(
|
|
|
301
485
|
# Save reset markers for all items (shared pool)
|
|
302
486
|
num_items = len(tasks_data[campaign_id]["data"])
|
|
303
487
|
for item_i in range(num_items):
|
|
304
|
-
save_db_payload(
|
|
305
|
-
|
|
306
|
-
"item_i": item_i,
|
|
307
|
-
|
|
308
|
-
})
|
|
488
|
+
save_db_payload(
|
|
489
|
+
campaign_id,
|
|
490
|
+
{"user_id": None, "item_i": item_i, "annotation": RESET_MARKER},
|
|
491
|
+
)
|
|
309
492
|
# for single-stream reset all progress
|
|
310
493
|
for uid in progress_data[campaign_id]:
|
|
311
494
|
progress_data[campaign_id][uid]["progress"] = [False] * num_items
|
|
312
495
|
_reset_user_time(progress_data, campaign_id, user_id)
|
|
313
496
|
return JSONResponse(content="ok", status_code=200)
|
|
497
|
+
elif assignment == "dynamic":
|
|
498
|
+
# Save reset markers for all items (shared pool like single-stream)
|
|
499
|
+
num_items = len(tasks_data[campaign_id]["data"])
|
|
500
|
+
for item_i in range(num_items):
|
|
501
|
+
save_db_payload(
|
|
502
|
+
campaign_id,
|
|
503
|
+
{"user_id": None, "item_i": item_i, "annotation": RESET_MARKER},
|
|
504
|
+
)
|
|
505
|
+
# for dynamic reset all progress (use sets to track models)
|
|
506
|
+
for uid in progress_data[campaign_id]:
|
|
507
|
+
progress_data[campaign_id][uid]["progress"] = [[] for _ in range(num_items)]
|
|
508
|
+
_reset_user_time(progress_data, campaign_id, user_id)
|
|
509
|
+
return JSONResponse(content="ok", status_code=200)
|
|
314
510
|
else:
|
|
315
|
-
return JSONResponse(
|
|
511
|
+
return JSONResponse(
|
|
512
|
+
content="Reset not supported for this assignment type", status_code=400
|
|
513
|
+
)
|
|
316
514
|
|
|
317
515
|
|
|
318
516
|
def update_progress(
|
|
@@ -337,6 +535,18 @@ def update_progress(
|
|
|
337
535
|
progress_data[campaign_id][uid]["progress"][item_i] = True
|
|
338
536
|
return JSONResponse(content="ok", status_code=200)
|
|
339
537
|
elif assignment == "dynamic":
|
|
340
|
-
|
|
538
|
+
# For dynamic, track which models were annotated
|
|
539
|
+
# Extract models from the payload annotation
|
|
540
|
+
annotated_models = []
|
|
541
|
+
if "annotation" in payload:
|
|
542
|
+
for annotation_item in payload.get("annotation", []):
|
|
543
|
+
if isinstance(annotation_item, dict):
|
|
544
|
+
annotated_models.extend(annotation_item.keys())
|
|
545
|
+
|
|
546
|
+
# Update progress for all users (shared pool)
|
|
547
|
+
for uid in progress_data[campaign_id]:
|
|
548
|
+
# Add the newly annotated models
|
|
549
|
+
progress_data[campaign_id][uid]["progress"][item_i].extend(annotated_models)
|
|
550
|
+
return JSONResponse(content="ok", status_code=200)
|
|
341
551
|
else:
|
|
342
552
|
return JSONResponse(content="Unknown campaign assignment type", status_code=400)
|
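For dynamic campaigns, progress is no longer one boolean per item but one list of annotated model names per item, shared across all users: update_progress extends the item's list and reset_task reinitializes it to empty lists. Illustrative values (invented model names):

    progress_single_stream = [False, True, False]    # one done-flag per item
    progress_dynamic = [["m1"], [], ["m1", "m2"]]    # per item: models annotated so far
    progress_dynamic[1].extend(["m2"])               # what update_progress does per annotation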
pearmut/cli.py
CHANGED
@@ -179,7 +179,7 @@ def _shuffle_campaign_data(campaign_data, rng):
         for user_id, task in campaign_data["data"].items():
             for doc in task:
                 shuffle_document(doc)
-    elif assignment
+    elif assignment in ["single-stream", "dynamic"]:
         # Shuffle each document in the shared pool
         for doc in campaign_data["data"]:
             shuffle_document(doc)
@@ -259,8 +259,46 @@ def _add_single_campaign(data_file, overwrite, server):
         else:
             raise ValueError("'users' must be an integer or a list.")
     elif assignment == "dynamic":
-
-
+        tasks = campaign_data["data"]
+        if users_spec is None:
+            raise ValueError(
+                "Dynamic campaigns must specify 'users' in info.")
+        if not isinstance(campaign_data["data"], list):
+            raise ValueError(
+                "Dynamic campaign 'data' must be a list of items.")
+        # Validate item structure for dynamic
+        for doc_i, doc in enumerate(tasks):
+            try:
+                _validate_item_structure(doc)
+            except ValueError as e:
+                raise ValueError(f"Document {doc_i}: {e}")
+        if isinstance(users_spec, int):
+            num_users = users_spec
+        elif isinstance(users_spec, list):
+            num_users = len(users_spec)
+        else:
+            raise ValueError("'users' must be an integer or a list.")
+        # Validate dynamic-specific parameters
+        if "dynamic_top" not in campaign_data["info"]:
+            campaign_data["info"]["dynamic_top"] = 2
+        if "dynamic_first" not in campaign_data["info"]:
+            campaign_data["info"]["dynamic_first"] = 5
+        if "dynamic_contrastive_models" not in campaign_data["info"]:
+            campaign_data["info"]["dynamic_contrastive_models"] = 1
+        # Validate that dynamic_first is at least 1
+        assert campaign_data["info"]["dynamic_first"] >= 1, "dynamic_first must be at least 1"
+        # Validate that dynamic_contrastive_models is at most dynamic_top
+        assert campaign_data["info"]["dynamic_contrastive_models"] <= campaign_data["info"]["dynamic_top"], \
+            "dynamic_contrastive_models must be at most dynamic_top"
+        # Validate that all items have the same models
+        all_models = set()
+        for item in campaign_data["data"]:
+            if item and len(item) > 0:
+                all_models.update(item[0]["tgt"].keys())
+        for item in campaign_data["data"]:
+            if item and len(item) > 0:
+                item_models = set(item[0]["tgt"].keys())
+                assert item_models == all_models, "All items must have the same model outputs"
     else:
         raise ValueError(f"Unknown campaign assignment type: {assignment}")
 
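Together, these checks imply the shape of a dynamic campaign file: info must name users and may tune the dynamic_* parameters (defaults 2/5/1 as set above, plus dynamic_backoff read at serving time), and data is a shared list of documents whose first segment's tgt carries one output per model, with identical model sets across items. A hypothetical minimal example (all values invented; segments may carry further fields such as per-model error_spans or validation):

    {
      "campaign_id": "example-dynamic",
      "info": {
        "assignment": "dynamic",
        "users": 3,
        "dynamic_top": 2,
        "dynamic_first": 5,
        "dynamic_contrastive_models": 1,
        "dynamic_backoff": 0.1
      },
      "data": [
        [{"tgt": {"model_a": "output for item 0", "model_b": "output for item 0"}}],
        [{"tgt": {"model_a": "output for item 1", "model_b": "output for item 1"}}]
      ]
    }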
@@ -310,13 +348,13 @@ def _add_single_campaign(data_file, overwrite, server):
         os.remove(output_file)
 
     # For task-based, data is a dict mapping user_id -> tasks
-    # For single-stream, data is a flat list (shared among all users)
+    # For single-stream and dynamic, data is a flat list (shared among all users)
     if assignment == "task-based":
         campaign_data["data"] = {
             user_id: task
             for user_id, task in zip(user_ids, tasks)
         }
-    elif assignment
+    elif assignment in ["single-stream", "dynamic"]:
         campaign_data["data"] = tasks
 
     # generate a token for dashboard access if not present
@@ -338,6 +376,7 @@ def _add_single_campaign(data_file, overwrite, server):
         "progress": (
             [False]*len(campaign_data["data"][user_id]) if assignment == "task-based"
             else [False]*len(campaign_data["data"]) if assignment == "single-stream"
+            else [list() for _ in range(len(campaign_data["data"]))] if assignment == "dynamic"
             else []
         ),
         "time_start": None,
@@ -421,9 +460,7 @@ def _add_single_campaign(data_file, overwrite, server):
         json.dump(campaign_data, f, indent=2, ensure_ascii=False)
 
     progress_data[campaign_data['campaign_id']] = user_progress
-
-    with open(f"{ROOT}/data/progress.json", "w") as f:
-        json.dump(progress_data, f, indent=2, ensure_ascii=False)
+    save_progress_data(progress_data)
 
 
     print(