@moleculeagora/cli 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +31 -0
- package/dist/index.js +30368 -0
- package/dist/index.js.map +1 -0
- package/dist/python-v1/agora_runtime.py +282 -0
- package/dist/python-v1/answer-set-metric.py +264 -0
- package/dist/python-v1/assertion-set-evaluation.py +879 -0
- package/dist/python-v1/exact-match.py +60 -0
- package/dist/python-v1/l4-composition.py +435 -0
- package/dist/python-v1/multi-output-tabular-metric.py +392 -0
- package/dist/python-v1/panel-ranking-metric.py +622 -0
- package/dist/python-v1/project-test.py +256 -0
- package/dist/python-v1/protein-binder-assay-metric.py +600 -0
- package/dist/python-v1/public-tool-metric.py +161 -0
- package/dist/python-v1/ranking-metric.py +426 -0
- package/dist/python-v1/reference-artifact-assertion.py +532 -0
- package/dist/python-v1/rubric-validation.py +246 -0
- package/dist/python-v1/solver-python-stdio-test.py +160 -0
- package/dist/python-v1/statistical-endpoint-test-v2.py +629 -0
- package/dist/python-v1/statistical-endpoint-test.py +442 -0
- package/dist/python-v1/table-metric.py +1291 -0
- package/dist/release-metadata.json +7 -0
- package/package.json +67 -0
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import math
|
|
3
|
+
import os
|
|
4
|
+
import tempfile
|
|
5
|
+
import zipfile
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Callable
|
|
8
|
+
|
|
9
|
+
from runtime_manifest import (
|
|
10
|
+
load_runtime_manifest,
|
|
11
|
+
resolve_artifact_by_role,
|
|
12
|
+
resolve_program_scoring_asset,
|
|
13
|
+
resolve_scoring_asset_by_role,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
# Input and output directories used by the runtime. Each one can be overridden
# through its environment variable; otherwise /input and /output are used.
INPUT_ROOT = Path(os.environ.get("AGORA_RUNTIME_INPUT_ROOT", "/input"))
OUTPUT_ROOT = Path(os.environ.get("AGORA_RUNTIME_OUTPUT_ROOT", "/output"))
# File that receives the serialized score payload.
OUTPUT_PATH = OUTPUT_ROOT / "score.json"

# Limits applied while extracting ZIP archives.
MAX_ZIP_MEMBER_COUNT = 1024  # maximum number of archive entries
MAX_ZIP_MEMBER_BYTES = 50 * (1 << 20)  # 50 MiB per extracted member
MAX_ZIP_TOTAL_BYTES = 50 * (1 << 20)  # 50 MiB across all extracted members
_ZIP_COPY_CHUNK_BYTES = 1 << 20  # 1 MiB read size when streaming a member
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _serialize_payload(payload: dict[str, Any]) -> str:
    """Render *payload* as compact JSON with its keys in sorted order."""
    compact_separators = (",", ":")
    return json.dumps(payload, separators=compact_separators, sort_keys=True)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def write_score(
    *,
    score: float,
    details: dict[str, Any] | None = None,
    ok: bool = True,
    error: str | None = None,
) -> None:
    """Write the score payload to ``OUTPUT_PATH`` as UTF-8 JSON.

    ``OUTPUT_ROOT`` is created first when it does not exist. The payload
    always carries ``ok``, ``score`` and ``details``; ``error`` is added only
    when a value other than ``None`` is given.

    Args:
        score: Score value; coerced with ``float``.
        details: Extra fields to report; a falsy value becomes ``{}``.
        ok: Outcome flag; coerced with ``bool``.
        error: Optional error message.
    """
    OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
    result: dict[str, Any] = {"ok": bool(ok), "score": float(score)}
    result["details"] = details if details else {}
    if error is not None:
        result["error"] = error
    serialized = _serialize_payload(result)
    OUTPUT_PATH.write_text(serialized, encoding="utf-8")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def fail_runtime(message: str) -> None:
    """Record a zero-score ``ok=False`` result with *message*, then exit with status 1."""
    write_score(ok=False, error=message, score=0.0, details={})
    raise SystemExit(1)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def reject_submission(
    message: str,
    *,
    details: dict[str, Any] | None = None,
) -> None:
    """Record a zero-score ``ok=False`` result with *message*, then exit with status 0.

    Args:
        message: Explanation stored in the ``error`` field of the score file.
        details: Optional extra fields; a falsy value becomes ``{}``.
    """
    rejection_details = details if details else {}
    write_score(ok=False, error=message, score=0.0, details=rejection_details)
    raise SystemExit(0)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def load_runtime_context(
    *,
    input_dir: Path | None = None,
    fail_runtime_handler: Callable[[str], None] | None = None,
) -> dict[str, Any]:
    """Load the runtime manifest and attach the resolved program asset.

    Args:
        input_dir: Directory to load the manifest from; ``INPUT_ROOT`` when
            omitted.
        fail_runtime_handler: Failure callback handed to the manifest
            helpers; ``fail_runtime`` when omitted.

    Returns:
        The manifest mapping, with the result of
        ``resolve_program_scoring_asset`` stored under ``"program_asset"``.
    """
    on_failure = fail_runtime_handler or fail_runtime
    manifest_dir = input_dir or INPUT_ROOT
    context = load_runtime_manifest(input_dir=manifest_dir, fail_runtime=on_failure)
    program_asset = resolve_program_scoring_asset(
        context,
        fail_runtime=on_failure,
        supported_abi_versions={"python-v1"},
    )
    context["program_asset"] = program_asset
    return context
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def resolve_evaluation_artifact(
    runtime_context: dict[str, Any],
    role: str,
    *,
    fail_runtime_handler: Callable[[str], None] | None = None,
) -> Path:
    """Return the path of the evaluation-lane artifact for *role*.

    The failure handler (``fail_runtime`` by default) is passed to the
    resolver and is also called when the resolved artifact has no path.
    """
    on_failure = fail_runtime_handler or fail_runtime
    resolved = resolve_artifact_by_role(
        runtime_context,
        lane="evaluation",
        role=role,
        fail_runtime=on_failure,
    )
    artifact_path = resolved["path"]
    if artifact_path is None:
        on_failure(f"Missing required evaluation artifact for role {role}.")
    return artifact_path
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def resolve_submission_artifact(
    runtime_context: dict[str, Any],
    role: str,
    *,
    fail_runtime_handler: Callable[[str], None] | None = None,
) -> Path:
    """Return the path of the submission-lane artifact for *role*.

    The failure handler (``fail_runtime`` by default) is passed to the
    resolver. A resolved artifact without a path is reported through
    ``reject_submission`` rather than through the failure handler.
    """
    on_failure = fail_runtime_handler or fail_runtime
    resolved = resolve_artifact_by_role(
        runtime_context,
        lane="submission",
        role=role,
        fail_runtime=on_failure,
    )
    artifact_path = resolved["path"]
    if artifact_path is None:
        reject_submission(f"Missing required submission artifact for role {role}.")
    return artifact_path
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def resolve_scoring_asset(
    runtime_context: dict[str, Any],
    role: str,
    *,
    kind: str | None = None,
    fail_runtime_handler: Callable[[str], None] | None = None,
) -> Path:
    """Return the ``"path"`` entry of the scoring asset resolved for *role*.

    *kind* and the failure handler (``fail_runtime`` by default) are passed
    through to ``resolve_scoring_asset_by_role``.
    """
    on_failure = fail_runtime_handler or fail_runtime
    resolved = resolve_scoring_asset_by_role(
        runtime_context,
        role=role,
        kind=kind,
        fail_runtime=on_failure,
    )
    return resolved["path"]
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def load_json_file(path: Path, *, label: str | None = None) -> Any:
|
|
133
|
+
try:
|
|
134
|
+
return json.loads(path.read_text(encoding="utf-8"))
|
|
135
|
+
except FileNotFoundError as error:
|
|
136
|
+
human_label = label or str(path)
|
|
137
|
+
raise RuntimeError(f"{human_label} is missing") from error
|
|
138
|
+
except json.JSONDecodeError as error:
|
|
139
|
+
human_label = label or str(path)
|
|
140
|
+
raise RuntimeError(f"{human_label} is not valid JSON: {error.msg}") from error
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _require_finite_score(value: Any, *, label: str) -> float:
    """Return *value* as a float after checking it is a finite real number.

    Booleans are rejected even though ``bool`` is a subclass of ``int``.

    Raises:
        RuntimeError: *value* is not an int or float, is NaN or infinite, or
            is an integer too large to be represented as a float.
    """
    message = f"{label}.score must be a finite number."
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise RuntimeError(message)
    try:
        score = float(value)
    except OverflowError as error:
        # JSON allows arbitrarily large integers; float() raises instead of
        # returning inf for them, so convert that into the same RuntimeError.
        raise RuntimeError(message) from error
    if not math.isfinite(score):
        raise RuntimeError(message)
    return score
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def load_score_output(path: Path, *, label: str) -> dict[str, Any]:
    """Load a score file and validate its ``ok``/``score``/``details``/``error`` fields.

    Args:
        path: Score file to read.
        label: Human-readable name used as the prefix of every error message.

    Returns:
        A copy of the parsed object with ``ok``, ``score`` (as a float) and
        ``details`` (``{}`` when absent) filled in.

    Raises:
        RuntimeError: The file cannot be loaded, or a field is missing or
            has the wrong type, or ``ok`` is false without a non-blank
            ``error`` message.
    """
    document = load_json_file(path, label=label)
    if not isinstance(document, dict):
        raise RuntimeError(f"{label} must be a JSON object.")

    ok_flag = document.get("ok")
    if not isinstance(ok_flag, bool):
        raise RuntimeError(f"{label}.ok must be a boolean.")

    detail_fields = document.get("details", {})
    if not isinstance(detail_fields, dict):
        raise RuntimeError(f"{label}.details must be an object when present.")

    if "score" not in document:
        raise RuntimeError(f"{label}.score is missing.")
    finite_score = _require_finite_score(document["score"], label=label)

    error_text = document.get("error")
    if not (error_text is None or isinstance(error_text, str)):
        raise RuntimeError(f"{label}.error must be a string when present.")
    # ok_flag is known to be a bool here, so a plain negation is enough.
    if not ok_flag and (error_text is None or not error_text.strip()):
        raise RuntimeError(f"{label}.error must explain ok=false outputs.")

    return {
        **document,
        "ok": ok_flag,
        "score": finite_score,
        "details": detail_fields,
    }
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def load_text_file(path: Path) -> str:
    """Return the entire contents of *path* decoded as UTF-8."""
    with path.open("r", encoding="utf-8") as handle:
        return handle.read()
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def safe_extract_zip(
    archive_path: Path,
    destination: Path,
    *,
    label: str,
    invalid_handler: Callable[[str], None] | None = None,
) -> None:
    """Extract *archive_path* into *destination*, enforcing path and size limits.

    Every member path is checked before anything is written: empty names,
    absolute paths, ``..`` components and paths that resolve outside
    *destination* are refused. Files are streamed in chunks into a temporary
    file next to their target and renamed into place only after the whole
    member has been copied, so a refused member never appears under its
    final name.

    Args:
        archive_path: ZIP file to read.
        destination: Directory to extract into; created when missing.
        label: Human-readable archive name used in every message.
        invalid_handler: Called with a message when the archive is refused;
            ``fail_runtime`` when omitted. If the handler returns instead of
            raising, extraction stops and this function returns. Members
            extracted before the refusal are left on disk.
    """
    handler = invalid_handler or fail_runtime
    root = destination.resolve()
    root.mkdir(parents=True, exist_ok=True)
    try:
        archive = zipfile.ZipFile(archive_path, "r")
    except zipfile.BadZipFile:
        handler(f"{label} is not a valid ZIP archive.")
        return

    with archive:
        members = archive.infolist()
        if len(members) > MAX_ZIP_MEMBER_COUNT:
            handler(
                f"{label} contains {len(members)} archive entries, exceeding the "
                f"maximum ZIP member count of {MAX_ZIP_MEMBER_COUNT}. Next step: "
                "remove unnecessary files and upload a smaller ZIP archive."
            )
            return

        # Validate every path up front so nothing is written for an archive
        # that contains an unsafe entry.
        extraction_plan = []
        for member in members:
            name = member.filename
            member_path = Path(name)
            if (
                not name
                or member_path.is_absolute()
                or any(part in {"", ".", ".."} for part in member_path.parts)
            ):
                handler(f"{label} contains unsafe archive path {name!r}.")
                return
            target = (root / member_path).resolve()
            try:
                target.relative_to(root)
            except ValueError:
                handler(f"{label} contains archive path outside extraction root: {name!r}.")
                return
            extraction_plan.append((member, target))

        total_bytes = 0
        for member, target in extraction_plan:
            if member.is_dir():
                target.mkdir(parents=True, exist_ok=True)
                continue

            name = member.filename
            target.parent.mkdir(parents=True, exist_ok=True)
            member_bytes = 0
            temp_path = None
            try:
                with archive.open(member, "r") as source:
                    with tempfile.NamedTemporaryFile(
                        "wb",
                        dir=target.parent,
                        prefix=".agora-zip-",
                        delete=False,
                    ) as output:
                        temp_path = Path(output.name)
                        while True:
                            chunk = source.read(_ZIP_COPY_CHUNK_BYTES)
                            if not chunk:
                                break
                            # Count decompressed bytes as they arrive instead
                            # of trusting the sizes declared in the archive.
                            member_bytes += len(chunk)
                            total_bytes += len(chunk)
                            if member_bytes > MAX_ZIP_MEMBER_BYTES:
                                output.close()
                                temp_path.unlink(missing_ok=True)
                                handler(
                                    f"{label} member {name!r} exceeds the maximum ZIP "
                                    f"member size of {MAX_ZIP_MEMBER_BYTES} bytes. "
                                    "Next step: reduce that file and upload a ZIP "
                                    "within the Python-v1 artifact limits."
                                )
                                return
                            if total_bytes > MAX_ZIP_TOTAL_BYTES:
                                output.close()
                                temp_path.unlink(missing_ok=True)
                                handler(
                                    f"{label} exceeds the maximum ZIP extracted size "
                                    f"of {MAX_ZIP_TOTAL_BYTES} bytes. Next step: "
                                    "remove unnecessary files or upload a smaller "
                                    "ZIP archive."
                                )
                                return
                            output.write(chunk)
                temp_path.replace(target)
                # The temp file now lives at its final path; nothing to clean.
                temp_path = None
            except zipfile.BadZipFile:
                if temp_path is not None:
                    temp_path.unlink(missing_ok=True)
                handler(f"{label} is not a valid ZIP archive.")
                return
            finally:
                # Safety net: the temp file is created with delete=False, so
                # any other exception (I/O error, corrupt stream, failed
                # rename) would otherwise leave it behind in the destination.
                if temp_path is not None:
                    temp_path.unlink(missing_ok=True)
|
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import re
|
|
3
|
+
|
|
4
|
+
from agora_runtime import (
|
|
5
|
+
fail_runtime,
|
|
6
|
+
load_json_file,
|
|
7
|
+
load_runtime_context,
|
|
8
|
+
reject_submission,
|
|
9
|
+
resolve_evaluation_artifact,
|
|
10
|
+
resolve_scoring_asset,
|
|
11
|
+
resolve_submission_artifact,
|
|
12
|
+
write_score,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
# Values accepted for compiled_config.metric and compiled_config.id_space.
SUPPORTED_METRICS = ("precision", "recall", "f1", "iou")
SUPPORTED_ID_SPACES = ("opaque", "doi", "pmid", "arxiv")

# Shape checks for answer ids after trimming (DOI and arXiv ids are also
# lower-cased before matching).
DOI_PATTERN = re.compile(r"^10\.[0-9]{4,9}/[!-~]+$")
PMID_PATTERN = re.compile(r"^[0-9]+$")
ARXIV_PATTERN = re.compile(
    # "archive/1234567" or "1234.12345" form, ...
    r"^(?:[a-z-]+(?:\.[a-z]{2})?/[0-9]{7}|[0-9]{4}\.[0-9]{4,5})"
    # ... followed by an optional version suffix such as "v2".
    r"(?:v[0-9]+)?$"
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def format_value_list(values):
    """Join *values* into an English "a, b, or c" list.

    Returns ``""`` for no values, the value itself for a single value,
    ``"a or b"`` for two, and a comma-separated list ending in ``", or z"``
    for three or more.
    """
    ordered = list(values)
    if not ordered:
        # Previously an empty input fell through to ordered[-1] and raised
        # IndexError.
        return ""
    if len(ordered) == 1:
        return ordered[0]
    *leading, last = ordered
    if len(leading) == 1:
        return f"{leading[0]} or {last}"
    return f"{', '.join(leading)}, or {last}"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def require_string(value, label):
    """Return *value* stripped of surrounding whitespace.

    Calls ``fail_runtime`` when *value* is not a string or is blank.
    """
    is_usable = isinstance(value, str) and bool(value.strip())
    if not is_usable:
        fail_runtime(f"{label} must be a non-empty string.")
    return value.strip()
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def require_config_value(config, key, supported_values):
    """Return ``config[key]`` trimmed and lower-cased, checked against *supported_values*.

    Calls ``fail_runtime`` when the value is not a non-empty string or is
    not one of *supported_values*.
    """
    config_label = f"compiled_config.{key}"
    normalized = require_string(config.get(key), config_label).lower()
    if normalized in supported_values:
        return normalized
    fail_runtime(
        f"compiled_config.{key} must be one of {format_value_list(supported_values)}."
    )
    return normalized
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def require_json_slot(runtime_context, lane, role):
    """Check that the *lane* slot for *role* declares a JSON validator.

    Only the first slot whose ``role`` matches is inspected. ``fail_runtime``
    is called when the artifact contract or the lane list is missing, when
    no slot matches, or when the matching slot's ``validator.kind`` is
    neither ``json_document`` nor ``json_schema``.
    """
    contract = runtime_context.get("artifact_contract")
    if not isinstance(contract, dict):
        fail_runtime("Runtime context is missing artifact_contract.")
    lane_slots = contract.get(lane)
    if not isinstance(lane_slots, list):
        fail_runtime(f"Runtime context is missing artifact_contract.{lane}.")

    matching_slot = next(
        (
            candidate
            for candidate in lane_slots
            if isinstance(candidate, dict) and candidate.get("role") == role
        ),
        None,
    )
    if matching_slot is None:
        fail_runtime(f"Runtime context is missing {lane} slot for role {role}.")
        return

    validator = matching_slot.get("validator")
    has_json_validator = isinstance(validator, dict) and validator.get("kind") in {
        "json_document",
        "json_schema",
    }
    if not has_json_validator:
        fail_runtime(
            f"{lane} role {role} must use validator.kind=json_document or json_schema for answer_set_metric."
        )
    return
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def load_json_payload(path, label, invalid_handler):
    """Load JSON from *path*, routing load errors to *invalid_handler*.

    A ``RuntimeError`` from ``load_json_file`` is turned into a call to
    *invalid_handler* with the error text; if the handler returns, ``None``
    is returned.
    """
    try:
        parsed = load_json_file(path, label=label)
    except RuntimeError as error:
        invalid_handler(str(error))
        return None
    return parsed
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def require_ascii(value, label, invalid_handler):
    """Return *value* unchanged, calling *invalid_handler* if it has non-ASCII characters."""
    if not value.isascii():
        invalid_handler(f"{label} must contain ASCII characters only.")
    return value
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def normalize_answer_id(raw_value, id_space, label, invalid_handler):
    """Trim and validate one answer id according to *id_space*.

    ``opaque`` ids are only trimmed. ``pmid`` ids must be ASCII digits.
    ``doi`` ids, and ids in any other id space (checked as arXiv ids), must
    be ASCII, are lower-cased, and must match the corresponding pattern.
    Every violation is reported through *invalid_handler*.
    """
    if not isinstance(raw_value, str):
        invalid_handler(f"{label} must be a string.")
    trimmed = raw_value.strip()
    if trimmed == "":
        invalid_handler(f"{label} must be a non-empty string after trimming.")

    if id_space == "opaque":
        return trimmed

    ascii_only = require_ascii(trimmed, label, invalid_handler)
    if id_space == "pmid":
        if PMID_PATTERN.fullmatch(ascii_only) is None:
            invalid_handler(f"{label} must be a bare PMID containing ASCII digits only.")
        return ascii_only

    lowered = ascii_only.lower()
    if id_space == "doi":
        pattern = DOI_PATTERN
        complaint = f"{label} must be a bare DOI such as 10.1234/example, not a URL or alias."
    else:
        pattern = ARXIV_PATTERN
        complaint = f"{label} must be a canonical arXiv id string, not a URL or prefixed alias."
    if pattern.fullmatch(lowered) is None:
        invalid_handler(complaint)
    return lowered
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def load_answer_set(path, role, id_space, *, lane, invalid_handler):
    """Load the ``items`` array of an answer-set artifact as a set of normalized ids.

    The artifact must be a JSON object whose ``items`` is an array of ids
    valid for *id_space*, with no duplicates after normalization. An
    artifact in the ``evaluation`` lane must contain at least one id. Every
    violation is reported through *invalid_handler*.
    """
    payload = load_json_payload(path, f"{lane} artifact {role}", invalid_handler)
    if not isinstance(payload, dict):
        invalid_handler(f"{lane} artifact {role} must be a JSON object.")
    raw_items = payload.get("items")
    if not isinstance(raw_items, list):
        invalid_handler(f"{lane} artifact {role}.items must be an array.")

    seen_ids = set()
    for position, raw_item in enumerate(raw_items):
        item_label = f"{lane} artifact {role}.items[{position}]"
        item_id = normalize_answer_id(raw_item, id_space, item_label, invalid_handler)
        if item_id in seen_ids:
            invalid_handler(
                f"{lane} artifact {role} contains duplicate normalized id {item_id!r}."
            )
        seen_ids.add(item_id)

    if not seen_ids and lane == "evaluation":
        invalid_handler(f"evaluation artifact {role}.items must contain at least one id.")
    return seen_ids
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def compute_counts(reference_ids, candidate_ids):
    """Return ``(true_positives, false_positives, false_negatives)`` for two id sets."""
    shared = reference_ids & candidate_ids
    only_in_candidate = candidate_ids - reference_ids
    only_in_reference = reference_ids - candidate_ids
    return len(shared), len(only_in_candidate), len(only_in_reference)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def compute_metric(metric, true_positive_count, false_positive_count, false_negative_count):
    """Compute precision, recall, F1 or IoU from confusion counts.

    Args:
        metric: ``"precision"``, ``"recall"`` or ``"f1"``; any other value is
            computed as IoU (intersection over union).
        true_positive_count: Ids present in both sets.
        false_positive_count: Ids present only in the candidate set.
        false_negative_count: Ids present only in the reference set.

    Returns:
        The metric value. A ratio whose denominator is zero is reported as
        ``0.0`` instead of raising ``ZeroDivisionError``; previously only
        precision was guarded this way.
    """
    candidate_count = true_positive_count + false_positive_count
    reference_count = true_positive_count + false_negative_count
    if metric == "precision":
        if candidate_count == 0:
            return 0.0
        return true_positive_count / candidate_count
    if metric == "recall":
        if reference_count == 0:
            return 0.0
        return true_positive_count / reference_count
    if metric == "f1":
        precision = 0.0 if candidate_count == 0 else true_positive_count / candidate_count
        recall = 0.0 if reference_count == 0 else true_positive_count / reference_count
        if precision + recall == 0.0:
            return 0.0
        return (2.0 * precision * recall) / (precision + recall)

    union_count = true_positive_count + false_positive_count + false_negative_count
    if union_count == 0:
        return 0.0
    return true_positive_count / union_count
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def require_finite_unit_interval(value, label):
    """Clamp *value* into ``[0.0, 1.0]``, calling ``fail_runtime`` if it is not finite."""
    if not math.isfinite(value):
        fail_runtime(f"{label} must be finite.")
    if value < 0.0:
        return 0.0
    return 1.0 if value > 1.0 else value
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def main():
    """Score a submitted answer set against the reference set and write the result.

    Loads the runtime context and the compiled config, validates the
    configured metric, id space, roles and objective, loads both answer
    sets, computes the selected metric and writes it (clamped to the unit
    interval) together with the confusion counts via ``write_score``.
    Problems with the config, the runtime context or the reference artifact
    go through ``fail_runtime``; problems with the submitted artifact go
    through ``reject_submission``.
    """
    runtime_context = load_runtime_context()
    config = load_json_payload(
        resolve_scoring_asset(runtime_context, "compiled_config", kind="config"),
        "compiled_config",
        fail_runtime,
    )
    if not isinstance(config, dict):
        fail_runtime("compiled_config must be a JSON object.")

    metric = require_config_value(config, "metric", SUPPORTED_METRICS)
    id_space = require_config_value(config, "id_space", SUPPORTED_ID_SPACES)
    # The generator is consumed left to right, so the three fields are
    # validated in the order listed.
    reference_role, candidate_role, final_score_key = (
        require_string(config.get(field), f"compiled_config.{field}")
        for field in ("reference_role", "candidate_role", "final_score_key")
    )

    objective = require_string(runtime_context.get("objective"), "runtime_context.objective")
    if objective != "maximize":
        fail_runtime(f"answer_set_metric metric {metric} requires objective=maximize.")

    require_json_slot(runtime_context, "evaluation", reference_role)
    require_json_slot(runtime_context, "submission", candidate_role)
    reference_path = resolve_evaluation_artifact(runtime_context, reference_role)
    candidate_path = resolve_submission_artifact(runtime_context, candidate_role)
    reference_ids = load_answer_set(
        reference_path,
        reference_role,
        id_space,
        lane="evaluation",
        invalid_handler=fail_runtime,
    )
    candidate_ids = load_answer_set(
        candidate_path,
        candidate_role,
        id_space,
        lane="submission",
        invalid_handler=reject_submission,
    )

    counts = compute_counts(reference_ids, candidate_ids)
    true_positive_count, false_positive_count, false_negative_count = counts
    raw_metric = compute_metric(metric, *counts)
    score = require_finite_unit_interval(raw_metric, "answer_set_metric score")

    details = {
        final_score_key: score,
        "selected_metric": metric,
        "selected_id_space": id_space,
        "reference_count": len(reference_ids),
        "candidate_count": len(candidate_ids),
        "true_positive_count": true_positive_count,
        "false_positive_count": false_positive_count,
        "false_negative_count": false_negative_count,
        "selected_metric_value": raw_metric,
    }
    write_score(score=score, details=details)


if __name__ == "__main__":
    main()
|