pearmut 0.3.3__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pearmut/app.py +119 -27
- pearmut/assignment.py +318 -55
- pearmut/cli.py +245 -135
- pearmut/constants.py +93 -0
- pearmut/results_export.py +210 -0
- pearmut/static/basic.bundle.js +1 -1
- pearmut/static/basic.html +39 -3
- pearmut/static/dashboard.bundle.js +1 -1
- pearmut/static/dashboard.html +27 -12
- pearmut/static/index.bundle.js +1 -1
- pearmut/static/index.html +1 -1
- pearmut/utils.py +3 -1
- {pearmut-0.3.3.dist-info → pearmut-1.0.1.dist-info}/METADATA +152 -34
- pearmut-1.0.1.dist-info/RECORD +20 -0
- pearmut-0.3.3.dist-info/RECORD +0 -18
- {pearmut-0.3.3.dist-info → pearmut-1.0.1.dist-info}/WHEEL +0 -0
- {pearmut-0.3.3.dist-info → pearmut-1.0.1.dist-info}/entry_points.txt +0 -0
- {pearmut-0.3.3.dist-info → pearmut-1.0.1.dist-info}/licenses/LICENSE +0 -0
- {pearmut-0.3.3.dist-info → pearmut-1.0.1.dist-info}/top_level.txt +0 -0
pearmut/app.py
CHANGED
|
@@ -1,20 +1,23 @@
|
|
|
1
|
-
import collections
|
|
2
1
|
import json
|
|
3
2
|
import os
|
|
4
|
-
import statistics
|
|
5
3
|
from typing import Any
|
|
6
4
|
|
|
7
5
|
from fastapi import FastAPI, Query
|
|
8
6
|
from fastapi.middleware.cors import CORSMiddleware
|
|
9
|
-
from fastapi.responses import JSONResponse
|
|
7
|
+
from fastapi.responses import JSONResponse, Response
|
|
10
8
|
from fastapi.staticfiles import StaticFiles
|
|
11
9
|
from pydantic import BaseModel
|
|
12
10
|
|
|
13
11
|
from .assignment import get_i_item, get_next_item, reset_task, update_progress
|
|
12
|
+
from .results_export import (
|
|
13
|
+
compute_model_scores,
|
|
14
|
+
generate_latex_table,
|
|
15
|
+
generate_pdf,
|
|
16
|
+
generate_typst_table,
|
|
17
|
+
)
|
|
14
18
|
from .utils import (
|
|
15
19
|
ROOT,
|
|
16
20
|
check_validation_threshold,
|
|
17
|
-
get_db_log,
|
|
18
21
|
load_progress_data,
|
|
19
22
|
save_db_payload,
|
|
20
23
|
save_progress_data,
|
|
@@ -159,7 +162,7 @@ async def _dashboard_data(request: DashboardDataRequest):
|
|
|
159
162
|
|
|
160
163
|
progress_new = {}
|
|
161
164
|
assignment = tasks_data[campaign_id]["info"]["assignment"]
|
|
162
|
-
if assignment not in ["task-based", "single-stream"]:
|
|
165
|
+
if assignment not in ["task-based", "single-stream", "dynamic"]:
|
|
163
166
|
return JSONResponse(
|
|
164
167
|
content="Unsupported campaign assignment type", status_code=400
|
|
165
168
|
)
|
|
@@ -211,31 +214,47 @@ async def _dashboard_results(request: DashboardResultsRequest):
|
|
|
211
214
|
if token != tasks_data[campaign_id]["token"]:
|
|
212
215
|
return JSONResponse(content="Invalid token", status_code=400)
|
|
213
216
|
|
|
214
|
-
|
|
215
|
-
model_scores = collections.defaultdict(dict)
|
|
216
|
-
|
|
217
|
-
# Iterate through all tasks to find items with 'models' field (basic template)
|
|
218
|
-
log = get_db_log(campaign_id)
|
|
219
|
-
for entry in log:
|
|
220
|
-
if "item" not in entry or "annotation" not in entry:
|
|
221
|
-
continue
|
|
222
|
-
for item, annotation in zip(entry["item"], entry["annotation"]):
|
|
223
|
-
for model, annotation in annotation.items():
|
|
224
|
-
if "score" in annotation and annotation["score"] is not None:
|
|
225
|
-
model_scores[model][json.dumps(item)] = annotation["score"]
|
|
226
|
-
|
|
227
|
-
results = [
|
|
228
|
-
{
|
|
229
|
-
"model": model,
|
|
230
|
-
"score": statistics.mean(scores.values()),
|
|
231
|
-
"count": len(scores),
|
|
232
|
-
}
|
|
233
|
-
for model, scores in model_scores.items()
|
|
234
|
-
]
|
|
235
|
-
results.sort(key=lambda x: x["score"], reverse=True)
|
|
217
|
+
results = compute_model_scores(campaign_id)
|
|
236
218
|
return JSONResponse(content=results, status_code=200)
|
|
237
219
|
|
|
238
220
|
|
|
221
|
+
@app.get("/export-results")
async def _export_results(
    campaign_id: str = Query(),
    token: str = Query(),
    format: str = Query(),
):
    """Export the per-model results of a campaign as Typst, LaTeX or PDF.

    Args:
        campaign_id: Campaign whose results are exported.
        token: Campaign token; must equal the token stored for the campaign.
        format: One of ``"typst"``, ``"latex"`` or ``"pdf"``. The name shadows
            the builtin but is kept because it is the public query-parameter
            name.

    Returns:
        A ``text/plain`` response for Typst/LaTeX, an ``application/pdf``
        response for PDF, or a 400 ``JSONResponse`` for an unknown campaign,
        a wrong token or an unsupported format.
    """
    if campaign_id not in progress_data:
        return JSONResponse(content="Unknown campaign ID", status_code=400)

    # Check if token is valid
    if token != tasks_data[campaign_id]["token"]:
        return JSONResponse(content="Invalid token", status_code=400)

    # Plain-text table formats share one response shape; PDF is handled
    # separately because its generator also takes the campaign id.
    text_generators = {
        "typst": generate_typst_table,
        "latex": generate_latex_table,
    }

    # Reject unsupported formats up front so that a bad request does not
    # trigger the score computation at all.
    if format != "pdf" and format not in text_generators:
        return JSONResponse(content="Invalid export format", status_code=400)

    results = compute_model_scores(campaign_id)

    if format == "pdf":
        return Response(
            content=generate_pdf(results, campaign_id),
            media_type="application/pdf",
        )
    return Response(
        content=text_generators[format](results),
        media_type="text/plain",
    )
|
|
256
|
+
|
|
257
|
+
|
|
239
258
|
class ResetTaskRequest(BaseModel):
|
|
240
259
|
campaign_id: str
|
|
241
260
|
user_id: str
|
|
@@ -261,6 +280,79 @@ async def _reset_task(request: ResetTaskRequest):
|
|
|
261
280
|
return response
|
|
262
281
|
|
|
263
282
|
|
|
283
|
+
class PurgeCampaignRequest(BaseModel):
    """Request body accepted by the ``/purge-campaign`` endpoint."""

    # Identifier of the campaign that should be removed.
    campaign_id: str
    # Campaign token; the endpoint compares it with the stored token.
    token: str
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
@app.post("/purge-campaign")
async def _purge_campaign(request: PurgeCampaignRequest):
    """Delete a campaign: its asset symlink, task file, output file and state.

    Args:
        request: Carries the ``campaign_id`` to purge and the campaign
            ``token`` authorising the operation.

    Returns:
        ``JSONResponse("ok", 200)`` on success, or a 400 response when the
        campaign is unknown or the token does not match.
    """
    global progress_data, tasks_data

    campaign_id = request.campaign_id
    token = request.token

    if campaign_id not in progress_data:
        return JSONResponse(content="Unknown campaign ID", status_code=400)
    if token != tasks_data[campaign_id]["token"]:
        return JSONResponse(content="Invalid token", status_code=400)

    # Unlink assets if they exist. Only a symlink is ever removed here, so a
    # real directory at the destination is left untouched.
    destination = (
        tasks_data[campaign_id].get("info", {}).get("assets", {}).get("destination")
    )
    if destination:
        symlink_path = f"{ROOT}/data/{destination}".rstrip("/")
        if os.path.islink(symlink_path):
            try:
                os.remove(symlink_path)
            except FileNotFoundError:
                # Removed concurrently between the check and the unlink.
                pass

    # Remove the task file and the output file. Attempting the removal and
    # tolerating a missing file avoids the check-then-act race of
    # ``os.path.exists`` followed by ``os.remove``.
    for stale_path in (
        f"{ROOT}/data/tasks/{campaign_id}.json",
        f"{ROOT}/data/outputs/{campaign_id}.jsonl",
    ):
        try:
            os.remove(stale_path)
        except FileNotFoundError:
            pass

    # Remove from in-memory data structures
    del tasks_data[campaign_id]
    del progress_data[campaign_id]

    # Save updated progress data
    save_progress_data(progress_data)

    return JSONResponse(content="ok", status_code=200)
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
class AddCampaignRequest(BaseModel):
    """Request body accepted by the ``/add-campaign`` endpoint."""

    # Campaign definition; the endpoint reads its "campaign_id" key and
    # forwards the whole mapping to the campaign-creation helper.
    campaign_data: dict[str, Any]
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
@app.post("/add-campaign")
async def _add_campaign(request: AddCampaignRequest):
    """Register a new campaign at runtime and load it into memory.

    The campaign is created through ``_add_single_campaign`` (without
    overwriting), after which its task file is read into ``tasks_data`` and
    ``progress_data`` is reloaded.

    Returns:
        A 200 response with ``status``, ``campaign_id`` and ``token`` on
        success; any exception is reported as a 400 response whose body is
        ``{"error": <message>}``.
    """
    global progress_data, tasks_data

    # Imported here rather than at module level — presumably to avoid a
    # circular import between app and cli (TODO confirm).
    from .cli import _add_single_campaign

    try:
        server_url = os.environ.get("PEARMUT_SERVER_URL", "http://localhost:8001")
        _add_single_campaign(
            request.campaign_data,
            overwrite=False,
            server=server_url,
        )

        new_campaign_id = request.campaign_data["campaign_id"]
        task_path = f"{ROOT}/data/tasks/{new_campaign_id}.json"
        with open(task_path, "r") as task_file:
            tasks_data[new_campaign_id] = json.load(task_file)

        progress_data = load_progress_data(warn=None)

        payload = {
            "status": "ok",
            "campaign_id": new_campaign_id,
            "token": tasks_data[new_campaign_id]["token"],
        }
        return JSONResponse(content=payload, status_code=200)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=400)
|
|
354
|
+
|
|
355
|
+
|
|
264
356
|
@app.get("/download-annotations")
|
|
265
357
|
async def _download_annotations(
|
|
266
358
|
campaign_id: list[str] = Query(),
|
pearmut/assignment.py
CHANGED
|
@@ -1,16 +1,30 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
import copy
|
|
1
3
|
import random
|
|
4
|
+
import statistics
|
|
2
5
|
from typing import Any
|
|
3
6
|
|
|
4
7
|
from fastapi.responses import JSONResponse
|
|
5
8
|
|
|
9
|
+
from .constants import PROTOCOL_INSTRUCTIONS
|
|
6
10
|
from .utils import (
|
|
7
11
|
RESET_MARKER,
|
|
8
12
|
check_validation_threshold,
|
|
13
|
+
get_db_log,
|
|
9
14
|
get_db_log_item,
|
|
10
15
|
save_db_payload,
|
|
11
16
|
)
|
|
12
17
|
|
|
13
18
|
|
|
19
|
+
def _get_instructions(tasks_data: dict, campaign_id: str) -> str:
|
|
20
|
+
"""Get instructions: custom if provided, else protocol default, else empty."""
|
|
21
|
+
campaign_info = tasks_data[campaign_id]["info"]
|
|
22
|
+
if "instructions" in campaign_info:
|
|
23
|
+
return campaign_info["instructions"]
|
|
24
|
+
return PROTOCOL_INSTRUCTIONS.get(campaign_info.get("protocol", ""), "")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
|
|
14
28
|
def _completed_response(
|
|
15
29
|
tasks_data: dict,
|
|
16
30
|
progress_data: dict,
|
|
@@ -20,14 +34,33 @@ def _completed_response(
|
|
|
20
34
|
"""Build a completed response with progress, time, and token."""
|
|
21
35
|
user_progress = progress_data[campaign_id][user_id]
|
|
22
36
|
is_ok = check_validation_threshold(tasks_data, progress_data, campaign_id, user_id)
|
|
37
|
+
token = user_progress["token_correct" if is_ok else "token_incorrect"]
|
|
38
|
+
|
|
39
|
+
# Get instructions_goodbye from campaign info, with default value
|
|
40
|
+
instructions_goodbye = tasks_data[campaign_id]["info"].get(
|
|
41
|
+
"instructions_goodbye",
|
|
42
|
+
"If someone asks you for a token of completion, show them: ${TOKEN}",
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
# Replace variables ${TOKEN} and ${USER_ID}
|
|
46
|
+
instructions_goodbye = instructions_goodbye.replace("${TOKEN}", token).replace(
|
|
47
|
+
"${USER_ID}", user_id
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
# Convert sets to lists for JSON serialization (for dynamic assignment)
|
|
51
|
+
progress = user_progress["progress"]
|
|
52
|
+
if progress and isinstance(progress[0], set):
|
|
53
|
+
progress = [list(s) for s in progress]
|
|
54
|
+
|
|
23
55
|
return JSONResponse(
|
|
24
56
|
content={
|
|
25
|
-
"status": "
|
|
26
|
-
"progress":
|
|
57
|
+
"status": "goodbye",
|
|
58
|
+
"progress": progress,
|
|
27
59
|
"time": user_progress["time"],
|
|
28
|
-
"token":
|
|
60
|
+
"token": token,
|
|
61
|
+
"instructions_goodbye": instructions_goodbye,
|
|
29
62
|
},
|
|
30
|
-
status_code=200
|
|
63
|
+
status_code=200,
|
|
31
64
|
)
|
|
32
65
|
|
|
33
66
|
|
|
@@ -44,7 +77,9 @@ def get_next_item(
|
|
|
44
77
|
if assignment == "task-based":
|
|
45
78
|
return get_next_item_taskbased(campaign_id, user_id, tasks_data, progress_data)
|
|
46
79
|
elif assignment == "single-stream":
|
|
47
|
-
return get_next_item_singlestream(
|
|
80
|
+
return get_next_item_singlestream(
|
|
81
|
+
campaign_id, user_id, tasks_data, progress_data
|
|
82
|
+
)
|
|
48
83
|
elif assignment == "dynamic":
|
|
49
84
|
return get_next_item_dynamic(campaign_id, user_id, tasks_data, progress_data)
|
|
50
85
|
else:
|
|
@@ -63,11 +98,17 @@ def get_i_item(
|
|
|
63
98
|
"""
|
|
64
99
|
assignment = tasks_data[campaign_id]["info"]["assignment"]
|
|
65
100
|
if assignment == "task-based":
|
|
66
|
-
return get_i_item_taskbased(
|
|
101
|
+
return get_i_item_taskbased(
|
|
102
|
+
campaign_id, user_id, tasks_data, progress_data, item_i
|
|
103
|
+
)
|
|
67
104
|
elif assignment == "single-stream":
|
|
68
|
-
return get_i_item_singlestream(
|
|
105
|
+
return get_i_item_singlestream(
|
|
106
|
+
campaign_id, user_id, tasks_data, progress_data, item_i
|
|
107
|
+
)
|
|
69
108
|
else:
|
|
70
|
-
return JSONResponse(
|
|
109
|
+
return JSONResponse(
|
|
110
|
+
content="Get item not supported for this assignment type", status_code=400
|
|
111
|
+
)
|
|
71
112
|
|
|
72
113
|
|
|
73
114
|
def get_i_item_taskbased(
|
|
@@ -93,10 +134,7 @@ def get_i_item_taskbased(
|
|
|
93
134
|
payload_existing["comment"] = latest_item["comment"]
|
|
94
135
|
|
|
95
136
|
if item_i < 0 or item_i >= len(data_all[campaign_id]["data"][user_id]):
|
|
96
|
-
return JSONResponse(
|
|
97
|
-
content="Item index out of range",
|
|
98
|
-
status_code=400
|
|
99
|
-
)
|
|
137
|
+
return JSONResponse(content="Item index out of range", status_code=400)
|
|
100
138
|
|
|
101
139
|
return JSONResponse(
|
|
102
140
|
content={
|
|
@@ -105,14 +143,17 @@ def get_i_item_taskbased(
|
|
|
105
143
|
"time": user_progress["time"],
|
|
106
144
|
"info": {
|
|
107
145
|
"item_i": item_i,
|
|
108
|
-
|
|
146
|
+
"instructions": _get_instructions(data_all, campaign_id),
|
|
147
|
+
}
|
|
148
|
+
| {
|
|
109
149
|
k: v
|
|
110
150
|
for k, v in data_all[campaign_id]["info"].items()
|
|
111
|
-
if k
|
|
151
|
+
if k in {"protocol", "sliders"}
|
|
112
152
|
},
|
|
113
|
-
"payload": data_all[campaign_id]["data"][user_id][item_i]
|
|
114
|
-
}
|
|
115
|
-
|
|
153
|
+
"payload": data_all[campaign_id]["data"][user_id][item_i],
|
|
154
|
+
}
|
|
155
|
+
| ({"payload_existing": payload_existing} if payload_existing else {}),
|
|
156
|
+
status_code=200,
|
|
116
157
|
)
|
|
117
158
|
|
|
118
159
|
|
|
@@ -140,10 +181,7 @@ def get_i_item_singlestream(
|
|
|
140
181
|
payload_existing["comment"] = latest_item["comment"]
|
|
141
182
|
|
|
142
183
|
if item_i < 0 or item_i >= len(data_all[campaign_id]["data"]):
|
|
143
|
-
return JSONResponse(
|
|
144
|
-
content="Item index out of range",
|
|
145
|
-
status_code=400
|
|
146
|
-
)
|
|
184
|
+
return JSONResponse(content="Item index out of range", status_code=400)
|
|
147
185
|
|
|
148
186
|
return JSONResponse(
|
|
149
187
|
content={
|
|
@@ -152,14 +190,17 @@ def get_i_item_singlestream(
|
|
|
152
190
|
"time": user_progress["time"],
|
|
153
191
|
"info": {
|
|
154
192
|
"item_i": item_i,
|
|
155
|
-
|
|
193
|
+
"instructions": _get_instructions(data_all, campaign_id),
|
|
194
|
+
}
|
|
195
|
+
| {
|
|
156
196
|
k: v
|
|
157
197
|
for k, v in data_all[campaign_id]["info"].items()
|
|
158
|
-
if k
|
|
198
|
+
if k in {"protocol", "sliders"}
|
|
159
199
|
},
|
|
160
|
-
"payload": data_all[campaign_id]["data"][item_i]
|
|
161
|
-
}
|
|
162
|
-
|
|
200
|
+
"payload": data_all[campaign_id]["data"][item_i],
|
|
201
|
+
}
|
|
202
|
+
| ({"payload_existing": payload_existing} if payload_existing else {}),
|
|
203
|
+
status_code=200,
|
|
163
204
|
)
|
|
164
205
|
|
|
165
206
|
|
|
@@ -196,14 +237,17 @@ def get_next_item_taskbased(
|
|
|
196
237
|
"time": user_progress["time"],
|
|
197
238
|
"info": {
|
|
198
239
|
"item_i": item_i,
|
|
199
|
-
|
|
240
|
+
"instructions": _get_instructions(data_all, campaign_id),
|
|
241
|
+
}
|
|
242
|
+
| {
|
|
200
243
|
k: v
|
|
201
244
|
for k, v in data_all[campaign_id]["info"].items()
|
|
202
|
-
if k
|
|
245
|
+
if k in {"protocol", "sliders"}
|
|
203
246
|
},
|
|
204
|
-
"payload": data_all[campaign_id]["data"][user_id][item_i]
|
|
205
|
-
}
|
|
206
|
-
|
|
247
|
+
"payload": data_all[campaign_id]["data"][user_id][item_i],
|
|
248
|
+
}
|
|
249
|
+
| ({"payload_existing": payload_existing} if payload_existing else {}),
|
|
250
|
+
status_code=200,
|
|
207
251
|
)
|
|
208
252
|
|
|
209
253
|
|
|
@@ -249,21 +293,178 @@ def get_next_item_singlestream(
|
|
|
249
293
|
"progress": progress,
|
|
250
294
|
"info": {
|
|
251
295
|
"item_i": item_i,
|
|
252
|
-
|
|
296
|
+
"instructions": _get_instructions(data_all, campaign_id),
|
|
297
|
+
}
|
|
298
|
+
| {
|
|
253
299
|
k: v
|
|
254
300
|
for k, v in data_all[campaign_id]["info"].items()
|
|
255
|
-
if k
|
|
301
|
+
if k in {"protocol", "sliders"}
|
|
256
302
|
},
|
|
257
|
-
"payload": data_all[campaign_id]["data"][item_i]
|
|
258
|
-
}
|
|
259
|
-
|
|
303
|
+
"payload": data_all[campaign_id]["data"][item_i],
|
|
304
|
+
}
|
|
305
|
+
| ({"payload_existing": payload_existing} if payload_existing else {}),
|
|
306
|
+
status_code=200,
|
|
307
|
+
)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def _dynamic_model_stats(annotations, known_models):
    """Tally per-model annotation counts and collect per-model scores.

    Returns ``(totals, scores)`` where ``totals`` maps model -> number of
    annotations on log lines that carry an ``item_i``, and ``scores`` maps
    model -> list of non-None scores (restricted to ``known_models``).
    """
    totals = collections.Counter()
    scores = collections.defaultdict(list)
    for line in annotations:
        entries = line.get("annotation")
        if not isinstance(entries, list):
            # Lines whose annotation is not a list carry no per-model entries
            # (presumably reset markers — confirm RESET_MARKER's type).
            continue
        has_item = line.get("item_i") is not None
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            for model, model_annotation in entry.items():
                if has_item:
                    totals[model] += 1
                if (
                    model in known_models
                    and isinstance(model_annotation, dict)
                    and model_annotation.get("score") is not None
                ):
                    scores[model].append(model_annotation["score"])
    return totals, scores


def get_next_item_dynamic(
    campaign_id: str,
    user_id: str,
    tasks_data: dict,
    progress_data: dict,
) -> JSONResponse:
    """
    Get the next item for dynamic assignment based on model performance.

    NOTE: All items must contain all model outputs for this assignment type to work.

    In this mode, items are selected based on the current performance of models:
    1. Contrastive comparison: `dynamic_contrastive_models` models are randomly selected and shown per item
    2. First phase: Each model gets `dynamic_first` annotations with fully random selection
    3. After first phase: Top `dynamic_top` models are identified, K randomly selected from them
    4. Items with least annotations for the selected models are prioritized
    5. With probability `dynamic_backoff`, uniformly random selection is used instead

    Args:
        campaign_id: Campaign to draw the item from.
        user_id: User requesting the item.
        tasks_data: Mapping campaign id -> campaign definition ("info", "data").
        progress_data: Mapping campaign id -> user id -> progress record.

    Returns:
        A 200 ``JSONResponse`` with the pruned item payload, or the completed
        response once every item has been annotated for every model.
    """
    user_progress = progress_data[campaign_id][user_id]
    campaign_data = tasks_data[campaign_id]
    campaign_info = campaign_data["info"]

    # Get all unique models in the campaign (all items must have all models).
    # Dict keys are already unique; keeping their order makes this deterministic.
    all_models = list(campaign_data["data"][0][0]["tgt"])
    known_models = set(all_models)

    # Check if completed (all models completed for all items).
    # Coverage is tested by membership rather than by length because progress
    # entries are extended on every annotation and may contain duplicates.
    # NOTE: this will rarely trigger but we don't have a good way to know when to end anyway for now
    if all(known_models.issubset(done) for done in user_progress["progress"]):
        return _completed_response(tasks_data, progress_data, campaign_id, user_id)

    # Get configuration parameters
    dynamic_top = campaign_info.get("dynamic_top", 2)
    dynamic_first = campaign_info.get("dynamic_first", 5)
    dynamic_contrastive_models = campaign_info.get("dynamic_contrastive_models", 1)
    dynamic_backoff = campaign_info.get("dynamic_backoff", 0)

    # Count annotations per model and gather scores in a single log pass
    annotations = get_db_log(campaign_id)
    model_total_counts, model_scores = _dynamic_model_stats(annotations, known_models)

    # Models that have not yet reached the first-phase quota
    under_sampled = [
        model for model in all_models if model_total_counts[model] < dynamic_first
    ]

    # Select which models to show
    if under_sampled:
        # First phase: select among models that don't have enough annotations
        # yet. The sample size is capped by the number of such models;
        # capping by len(all_models) raised ValueError once fewer than K
        # models remained under-sampled.
        selected_models = random.sample(
            under_sampled, k=min(dynamic_contrastive_models, len(under_sampled))
        )
    elif random.random() < dynamic_backoff:
        # Backoff: select K models randomly from all models
        selected_models = random.sample(
            all_models, k=min(dynamic_contrastive_models, len(all_models))
        )
    else:
        # Calculate average scores and rank models, best first
        model_avg_scores = {
            model: statistics.mean(scores) for model, scores in model_scores.items()
        }
        ranked_models = sorted(
            model_avg_scores, key=model_avg_scores.get, reverse=True
        )
        # Fall back to every model when nothing has been scored yet, so the
        # annotator is never served an item with no model outputs.
        top_models = ranked_models[:dynamic_top] or all_models

        # From top N, randomly select K models
        selected_models = random.sample(
            top_models, k=min(dynamic_contrastive_models, len(top_models))
        )

    # Per item: how many of the selected models are already annotated
    item_annotation_counts = [
        sum(model in completed_models for model in selected_models)
        for completed_models in user_progress["progress"]
    ]

    # Select item with minimum annotations (with random tiebreaking)
    min_annotations = min(item_annotation_counts)
    item_i = random.choice(
        [
            candidate_i
            for candidate_i, count in enumerate(item_annotation_counts)
            if count == min_annotations
        ]
    )

    # Prune the payload to only include selected models
    pruned_item = []
    for doc_segment in campaign_data["data"][item_i]:
        pruned_segment = doc_segment.copy()
        # Filter tgt to only include selected models
        pruned_segment["tgt"] = {
            model: doc_segment["tgt"][model]
            for model in selected_models
            if model in doc_segment["tgt"]
        }
        # Also filter the optional per-model fields when present
        for optional_field in ("error_spans", "validation"):
            if optional_field in doc_segment:
                pruned_segment[optional_field] = {
                    model: doc_segment[optional_field][model]
                    for model in selected_models
                    if model in doc_segment[optional_field]
                }
        pruned_item.append(pruned_segment)

    # Sets are not JSON-serializable; mirror the completed response and send
    # them as lists.
    progress = [
        list(done) if isinstance(done, (set, frozenset)) else done
        for done in user_progress["progress"]
    ]

    return JSONResponse(
        content={
            "status": "ok",
            "time": user_progress["time"],
            "progress": progress,
            "info": {
                "item_i": item_i,
                "instructions": _get_instructions(tasks_data, campaign_id),
            }
            | {
                k: v
                for k, v in campaign_info.items()
                if k in {"protocol", "sliders"}
            },
            "payload": pruned_item,
        },
        status_code=200,
    )
|
|
267
468
|
|
|
268
469
|
|
|
269
470
|
def _reset_user_time(progress_data: dict, campaign_id: str, user_id: str) -> None:
|
|
@@ -274,6 +475,26 @@ def _reset_user_time(progress_data: dict, campaign_id: str, user_id: str) -> Non
|
|
|
274
475
|
progress_data[campaign_id][user_id]["validations"] = {}
|
|
275
476
|
|
|
276
477
|
|
|
478
|
+
def _get_user_annotated_items(campaign_id: str, user_id: str) -> set[int]:
    """
    Collect the item indices a given user has annotated in a campaign.

    Log entries are taken from ``get_db_log``; an entry counts when it belongs
    to ``user_id``, its annotation is not the reset marker, and it has a
    non-None ``item_i``.

    Args:
        campaign_id: The campaign identifier
        user_id: The user identifier

    Returns:
        Set of item indices (item_i) that the user has annotated
    """
    return {
        entry["item_i"]
        for entry in get_db_log(campaign_id)
        if entry.get("user_id") == user_id
        and entry.get("annotation") != RESET_MARKER
        and entry.get("item_i") is not None
    }
|
|
496
|
+
|
|
497
|
+
|
|
277
498
|
def reset_task(
|
|
278
499
|
campaign_id: str,
|
|
279
500
|
user_id: str,
|
|
@@ -289,30 +510,60 @@ def reset_task(
|
|
|
289
510
|
# Save reset marker for this user to mask existing annotations
|
|
290
511
|
num_items = len(tasks_data[campaign_id]["data"][user_id])
|
|
291
512
|
for item_i in range(num_items):
|
|
292
|
-
save_db_payload(
|
|
293
|
-
|
|
294
|
-
"item_i": item_i,
|
|
295
|
-
|
|
296
|
-
})
|
|
513
|
+
save_db_payload(
|
|
514
|
+
campaign_id,
|
|
515
|
+
{"user_id": user_id, "item_i": item_i, "annotation": RESET_MARKER},
|
|
516
|
+
)
|
|
297
517
|
progress_data[campaign_id][user_id]["progress"] = [False] * num_items
|
|
298
518
|
_reset_user_time(progress_data, campaign_id, user_id)
|
|
299
519
|
return JSONResponse(content="ok", status_code=200)
|
|
300
520
|
elif assignment == "single-stream":
|
|
301
|
-
#
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
521
|
+
# Find all items that this user has annotated
|
|
522
|
+
user_items = _get_user_annotated_items(campaign_id, user_id)
|
|
523
|
+
|
|
524
|
+
# Save reset markers only for items this user has touched
|
|
525
|
+
for item_i in user_items:
|
|
526
|
+
save_db_payload(
|
|
527
|
+
campaign_id,
|
|
528
|
+
{"user_id": user_id, "item_i": item_i, "annotation": RESET_MARKER},
|
|
529
|
+
)
|
|
530
|
+
|
|
531
|
+
# Reset only the touched items in all users' progress (shared pool)
|
|
310
532
|
for uid in progress_data[campaign_id]:
|
|
311
|
-
|
|
533
|
+
for item_i in user_items:
|
|
534
|
+
progress_data[campaign_id][uid]["progress"][item_i] = False
|
|
535
|
+
|
|
536
|
+
# Reset only the specified user's time
|
|
537
|
+
_reset_user_time(progress_data, campaign_id, user_id)
|
|
538
|
+
return JSONResponse(content="ok", status_code=200)
|
|
539
|
+
elif assignment == "dynamic":
|
|
540
|
+
# Find all items that this user has annotated
|
|
541
|
+
user_items = _get_user_annotated_items(campaign_id, user_id)
|
|
542
|
+
|
|
543
|
+
# Save reset markers only for items this user has touched
|
|
544
|
+
for item_i in user_items:
|
|
545
|
+
save_db_payload(
|
|
546
|
+
campaign_id,
|
|
547
|
+
{"user_id": user_id, "item_i": item_i, "annotation": RESET_MARKER},
|
|
548
|
+
)
|
|
549
|
+
|
|
550
|
+
progress_data_user = copy.deepcopy(progress_data[campaign_id][user_id]["progress"])
|
|
551
|
+
|
|
552
|
+
# Reset only the touched items in all users' progress (shared pool, use lists to track models)
|
|
553
|
+
for uid in progress_data[campaign_id]:
|
|
554
|
+
for item_i in user_items:
|
|
555
|
+
progress_data[campaign_id][uid]["progress"][item_i] = [
|
|
556
|
+
x for x in progress_data[campaign_id][uid]["progress"][item_i]
|
|
557
|
+
if x not in progress_data_user[item_i]
|
|
558
|
+
]
|
|
559
|
+
|
|
560
|
+
# Reset only the specified user's time
|
|
312
561
|
_reset_user_time(progress_data, campaign_id, user_id)
|
|
313
562
|
return JSONResponse(content="ok", status_code=200)
|
|
314
563
|
else:
|
|
315
|
-
return JSONResponse(
|
|
564
|
+
return JSONResponse(
|
|
565
|
+
content="Reset not supported for this assignment type", status_code=400
|
|
566
|
+
)
|
|
316
567
|
|
|
317
568
|
|
|
318
569
|
def update_progress(
|
|
@@ -337,6 +588,18 @@ def update_progress(
|
|
|
337
588
|
progress_data[campaign_id][uid]["progress"][item_i] = True
|
|
338
589
|
return JSONResponse(content="ok", status_code=200)
|
|
339
590
|
elif assignment == "dynamic":
|
|
340
|
-
|
|
591
|
+
# For dynamic, track which models were annotated
|
|
592
|
+
# Extract models from the payload annotation
|
|
593
|
+
annotated_models = []
|
|
594
|
+
if "annotation" in payload:
|
|
595
|
+
for annotation_item in payload.get("annotation", []):
|
|
596
|
+
if isinstance(annotation_item, dict):
|
|
597
|
+
annotated_models.extend(annotation_item.keys())
|
|
598
|
+
|
|
599
|
+
# Update progress for all users (shared pool)
|
|
600
|
+
for uid in progress_data[campaign_id]:
|
|
601
|
+
# Add the newly annotated models
|
|
602
|
+
progress_data[campaign_id][uid]["progress"][item_i].extend(annotated_models)
|
|
603
|
+
return JSONResponse(content="ok", status_code=200)
|
|
341
604
|
else:
|
|
342
605
|
return JSONResponse(content="Unknown campaign assignment type", status_code=400)
|