@moleculeagora/cli 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +31 -0
- package/dist/index.js +30368 -0
- package/dist/index.js.map +1 -0
- package/dist/python-v1/agora_runtime.py +282 -0
- package/dist/python-v1/answer-set-metric.py +264 -0
- package/dist/python-v1/assertion-set-evaluation.py +879 -0
- package/dist/python-v1/exact-match.py +60 -0
- package/dist/python-v1/l4-composition.py +435 -0
- package/dist/python-v1/multi-output-tabular-metric.py +392 -0
- package/dist/python-v1/panel-ranking-metric.py +622 -0
- package/dist/python-v1/project-test.py +256 -0
- package/dist/python-v1/protein-binder-assay-metric.py +600 -0
- package/dist/python-v1/public-tool-metric.py +161 -0
- package/dist/python-v1/ranking-metric.py +426 -0
- package/dist/python-v1/reference-artifact-assertion.py +532 -0
- package/dist/python-v1/rubric-validation.py +246 -0
- package/dist/python-v1/solver-python-stdio-test.py +160 -0
- package/dist/python-v1/statistical-endpoint-test-v2.py +629 -0
- package/dist/python-v1/statistical-endpoint-test.py +442 -0
- package/dist/python-v1/table-metric.py +1291 -0
- package/dist/release-metadata.json +7 -0
- package/package.json +67 -0
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
import subprocess
|
|
2
|
+
import sys
|
|
3
|
+
import tempfile
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
from agora_runtime import (
|
|
7
|
+
fail_runtime,
|
|
8
|
+
load_json_file,
|
|
9
|
+
load_runtime_context,
|
|
10
|
+
resolve_scoring_asset,
|
|
11
|
+
resolve_submission_artifact,
|
|
12
|
+
safe_extract_zip,
|
|
13
|
+
write_score,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def require_string(value, label):
    """Return ``value`` with surrounding whitespace removed.

    Args:
        value: Candidate taken from a config or manifest.
        label: Dotted name of the field, used in the error message.

    A value that is not a ``str``, or that is empty after stripping, is
    reported through ``fail_runtime``.
    """
    is_usable = isinstance(value, str) and bool(value.strip())
    if not is_usable:
        fail_runtime(f"{label} must be a non-empty string.")
    return value.strip()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def require_int(value, label):
    """Return ``value`` when it is a strictly positive integer.

    Args:
        value: Candidate taken from a config or manifest.
        label: Dotted name of the field, used in the error message.

    ``bool`` is rejected explicitly. It is a subclass of ``int``, so a JSON
    ``true`` would otherwise be accepted and silently treated as ``1``.
    Any other non-int, zero or negative value is reported through
    ``fail_runtime``.
    """
    if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
        fail_runtime(f"{label} must be a positive integer.")
    return value
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def load_json_object(path, label):
    """Load the JSON document at ``path`` and require a top-level object.

    Args:
        path: Location handed to ``load_json_file``.
        label: Human-readable name of the document for error messages.

    Returns:
        The parsed ``dict``.

    A ``RuntimeError`` raised by ``load_json_file`` is forwarded to
    ``fail_runtime`` as its message; a document whose top level is not a
    ``dict`` is reported the same way.
    """
    try:
        payload = load_json_file(path, label=label)
    except RuntimeError as load_error:
        fail_runtime(str(load_error))
    if isinstance(payload, dict):
        return payload
    fail_runtime(f"{label} must be a JSON object.")
    return payload
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def read_required_text(path, label):
    """Return the UTF-8 text stored at ``path``.

    Args:
        path: ``pathlib.Path``-like object exposing ``read_text``.
        label: Human-readable name of the file for error messages.

    A missing file, undecodable content, or any other OS-level read error is
    reported through ``fail_runtime``.
    """
    try:
        return path.read_text(encoding="utf-8")
    except FileNotFoundError:
        fail_runtime(f"Missing {label} at {path}.")
    except UnicodeDecodeError as error:
        # UnicodeDecodeError derives from ValueError, not OSError, so without
        # this clause a non-UTF-8 file would escape as a raw traceback.
        fail_runtime(f"Unable to decode {label} as UTF-8: {error}.")
    except OSError as error:
        fail_runtime(f"Unable to read {label}: {error}.")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def normalize_output(text, strip_trailing_whitespace):
    """Prepare program output for comparison.

    Args:
        text: Output to normalise.
        strip_trailing_whitespace: When truthy, trailing whitespace
            (including newlines) is removed; otherwise ``text`` is returned
            unchanged.
    """
    return text.rstrip() if strip_trailing_whitespace else text
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def run_python_solution(solution_path, stdin_text, timeout_ms, working_dir):
    """Run a submitted Python script once and capture its output.

    Args:
        solution_path: Path of the script to execute.
        stdin_text: Text fed to the script on standard input.
        timeout_ms: Wall-clock limit in milliseconds.
        working_dir: Directory used as the child's current directory.

    Returns:
        The ``subprocess.CompletedProcess`` (``check=False``, so a non-zero
        exit status is returned rather than raised), or ``None`` when the
        time limit was exceeded.

    The pipes are encoded and decoded as UTF-8 explicitly. Undecodable bytes
    in the script's output become U+FFFD instead of raising
    ``UnicodeDecodeError`` in the scorer, so such a run can be compared and
    counted as a mismatch by the caller.
    """
    import os  # local import: the module header does not import os

    # Ask the child interpreter to use UTF-8 on its standard streams so both
    # ends of the pipe agree regardless of the host locale.
    child_env = dict(os.environ)
    child_env["PYTHONIOENCODING"] = "utf-8"
    try:
        return subprocess.run(
            [sys.executable, str(solution_path)],
            input=stdin_text,
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            env=child_env,
            timeout=timeout_ms / 1000.0,
            cwd=str(working_dir),
            check=False,
        )
    except subprocess.TimeoutExpired:
        return None
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def main():
    """Score a submitted Python program against a bundle of stdin/stdout cases.

    Steps, in order:
      1. Load the runtime context and the ``compiled_config`` scoring asset.
      2. Resolve the tool bundle and the submitted script from the roles
         named in the config.
      3. Extract the bundle into a temporary directory and validate its
         ``agora-harness.json`` manifest (version ``v1``, language
         ``python``, positive ``timeout_ms``, non-empty ``tests`` list).
      4. Run the script once per test case; a case passes when the run
         finished in time, exited with status 0 and its normalised stdout
         equals the normalised expected stdout.
      5. Report ``passed / total`` through ``write_score``.
    """
    context = load_runtime_context()
    config = load_json_object(
        resolve_scoring_asset(context, "compiled_config", kind="config"),
        "compiled_config",
    )
    tool_role = require_string(config.get("tool_role"), "compiled_config.tool_role")
    submission_role = require_string(
        config.get("submission_role"),
        "compiled_config.submission_role",
    )
    final_score_key = require_string(
        context.get("final_score_key"),
        "runtime_context.final_score_key",
    )
    bundle_path = resolve_scoring_asset(context, tool_role, kind="bundle")
    solution_path = resolve_submission_artifact(context, submission_role)

    with tempfile.TemporaryDirectory(prefix="agora-public-tool-") as scratch:
        bundle_root = Path(scratch)
        safe_extract_zip(
            bundle_path,
            bundle_root,
            label=f"scoring asset {tool_role}",
        )
        manifest = load_json_object(
            bundle_root / "agora-harness.json",
            "agora-harness.json",
        )
        version = require_string(manifest.get("version"), "agora-harness.version")
        language = require_string(manifest.get("language"), "agora-harness.language")
        if version != "v1":
            fail_runtime(f"Unsupported harness version {version}.")
        if language != "python":
            fail_runtime(f"Unsupported harness language {language}.")
        timeout_ms = require_int(
            manifest.get("timeout_ms", 30000),
            "agora-harness.timeout_ms",
        )
        strip_trailing = bool(manifest.get("strip_trailing_whitespace", False))
        cases = manifest.get("tests")
        if not (isinstance(cases, list) and len(cases) > 0):
            fail_runtime("agora-harness.tests must declare at least one test case.")

        passed = 0
        for position, case in enumerate(cases):
            if not isinstance(case, dict):
                fail_runtime(f"agora-harness.tests[{position}] must be an object.")
            stdin_path = bundle_root / require_string(
                case.get("stdin_path"),
                f"agora-harness.tests[{position}].stdin_path",
            )
            expected_path = bundle_root / require_string(
                case.get("expected_stdout_path"),
                f"agora-harness.tests[{position}].expected_stdout_path",
            )
            stdin_text = read_required_text(
                stdin_path,
                f"test stdin {stdin_path.name}",
            )
            expected_raw = read_required_text(
                expected_path,
                f"test expected stdout {expected_path.name}",
            )
            expected = normalize_output(expected_raw, strip_trailing)
            # NOTE(review): the submission runs with the extracted bundle as
            # its working directory, so it can read the expected-output files
            # there. Confirm this is intended for public tool bundles.
            completed = run_python_solution(
                solution_path,
                stdin_text,
                timeout_ms,
                bundle_root,
            )
            # A timeout (None) or a non-zero exit status fails the case.
            if completed is None or completed.returncode != 0:
                continue
            if normalize_output(completed.stdout, strip_trailing) == expected:
                passed += 1

        total = len(cases)
        score = passed / total
        write_score(
            score=score,
            details={
                final_score_key: score,
                "tests_passed": passed,
                "total_tests": total,
            },
        )
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
# Run the scorer only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,426 @@
|
|
|
1
|
+
import csv
|
|
2
|
+
import math
|
|
3
|
+
|
|
4
|
+
from agora_runtime import (
|
|
5
|
+
fail_runtime,
|
|
6
|
+
load_json_file,
|
|
7
|
+
load_runtime_context,
|
|
8
|
+
reject_submission,
|
|
9
|
+
resolve_evaluation_artifact,
|
|
10
|
+
resolve_scoring_asset,
|
|
11
|
+
resolve_submission_artifact,
|
|
12
|
+
write_score,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
# Metric names accepted in compiled_config.metric. The tuple keeps a fixed
# order, which is the order shown in the "must be one of" error message.
SUPPORTED_METRICS = ("ndcg", "mrr", "map", "top_k_recall")
# Set view of the same names, used for membership tests.
SUPPORTED_METRIC_SET = set(SUPPORTED_METRICS)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def format_metric_list(metrics):
    """Join metric names into an English "a, b, or c" list.

    Args:
        metrics: Iterable of metric names; it is consumed once.

    Returns:
        ``""`` for no names, the name itself for one, ``"a or b"`` for two,
        and a comma-separated list ending in ``", or z"`` for three or more.
        The empty case previously raised ``IndexError``.
    """
    ordered = list(metrics)
    if not ordered:
        return ""
    if len(ordered) == 1:
        return ordered[0]
    if len(ordered) == 2:
        return f"{ordered[0]} or {ordered[1]}"
    return f"{', '.join(ordered[:-1])}, or {ordered[-1]}"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def require_string(value, label):
    """Return the stripped form of ``value``, insisting it is non-blank text.

    Args:
        value: Candidate taken from a config or the runtime context.
        label: Dotted name of the field, used in the error message.

    Non-string or whitespace-only input is reported through
    ``fail_runtime``.
    """
    cleaned = value.strip() if isinstance(value, str) else ""
    if not cleaned:
        fail_runtime(f"{label} must be a non-empty string.")
    return value.strip()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def require_policy(policies, key, allowed):
    """Read ``policies[key]`` and require it to be one of ``allowed``.

    Args:
        policies: Mapping of policy names to values.
        key: Policy name to look up.
        allowed: Collection of accepted string values.

    Returns:
        The stripped policy value.

    A missing, blank or unlisted value is reported through ``fail_runtime``;
    the message lists the allowed values in sorted order.
    """
    chosen = require_string(policies.get(key), f"policies.{key}")
    if chosen in allowed:
        return chosen
    fail_runtime(
        f"policies.{key} must be one of {', '.join(sorted(allowed))}."
    )
    return chosen
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def require_metric(config):
    """Return the lower-cased metric name declared in ``config``.

    Args:
        config: Mapping holding the ``metric`` entry.

    The name must be a non-empty string and, once lower-cased, a member of
    ``SUPPORTED_METRIC_SET``; otherwise the problem is reported through
    ``fail_runtime``.
    """
    declared = require_string(config.get("metric"), "compiled_config.metric")
    metric = declared.lower()
    if metric in SUPPORTED_METRIC_SET:
        return metric
    fail_runtime(
        f"compiled_config.metric must be one of {format_metric_list(SUPPORTED_METRICS)}."
    )
    return metric
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def require_metric_params(config, metric):
    """Return ``config["metric_params"]`` after checking it matches ``metric``.

    Args:
        config: Mapping holding the ``metric_params`` entry.
        metric: Metric name the parameters must be declared for.

    The entry must be a ``dict`` whose own ``metric`` value equals
    ``metric``; otherwise the problem is reported through ``fail_runtime``.
    """
    params = config.get("metric_params")
    if not isinstance(params, dict):
        fail_runtime("compiled_config.metric_params must be an object.")
    declared_for = params.get("metric")
    if declared_for != metric:
        fail_runtime("compiled_config.metric_params.metric must match compiled_config.metric.")
    return params
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def require_top_k_fraction(metric_params):
    """Return ``metric_params["k"]`` as a float in the interval (0, 1].

    Args:
        metric_params: Mapping holding the ``k`` entry.

    ``k`` must be an ``int`` or ``float`` (``bool`` is excluded), finite,
    greater than 0 and at most 1. Violations are reported through
    ``fail_runtime``. An integer too large to convert to ``float`` is
    reported as out of range rather than escaping as ``OverflowError``.
    """
    value = metric_params.get("k")
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        fail_runtime("compiled_config.metric_params.k must be a finite number.")
    try:
        fraction = float(value)
    except OverflowError:
        # Only an int can overflow here, and one that large is never <= 1.
        # NOTE(review): assumes fail_runtime does not return; confirm.
        fail_runtime("compiled_config.metric_params.k must be greater than 0 and at most 1.")
    if not math.isfinite(fraction):
        fail_runtime("compiled_config.metric_params.k must be a finite number.")
    if fraction <= 0.0 or fraction > 1.0:
        fail_runtime("compiled_config.metric_params.k must be greater than 0 and at most 1.")
    return fraction
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def find_slot(runtime_context, lane, role):
    """Return the first slot of ``lane`` whose ``role`` equals ``role``.

    Args:
        runtime_context: Mapping holding a ``"<lane>_slots"`` list.
        lane: Lane name used to build the slot-list key.
        role: Role to look for.

    Entries that are not dicts are skipped. A missing slot list or a missing
    role is reported through ``fail_runtime``.
    """
    slot_key = f"{lane}_slots"
    candidates = runtime_context.get(slot_key)
    if not isinstance(candidates, list):
        fail_runtime(f"Runtime context is missing {slot_key}.")
    match = next(
        (
            entry
            for entry in candidates
            if isinstance(entry, dict) and entry.get("role") == role
        ),
        None,
    )
    if match is not None:
        return match
    fail_runtime(f"Runtime context is missing {lane} slot for role {role}.")
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def require_csv_slot_columns(runtime_context, lane, role):
    """Return the ``(record_key, value_field)`` column names of a CSV slot.

    Args:
        runtime_context: Mapping passed on to ``find_slot``.
        lane: Lane the slot belongs to.
        role: Role identifying the slot.

    The slot's ``validator`` must be a dict with ``kind == "csv_columns"``
    and non-empty string ``record_key`` and ``value_field`` entries;
    otherwise the problem is reported through ``fail_runtime``.
    """
    validator = find_slot(runtime_context, lane, role).get("validator")
    uses_csv_columns = (
        isinstance(validator, dict) and validator.get("kind") == "csv_columns"
    )
    if not uses_csv_columns:
        fail_runtime(
            f"{lane} role {role} must use validator.kind=csv_columns for ranking_metric."
        )
    key_column = require_string(
        validator.get("record_key"),
        f"{lane}.{role}.validator.record_key",
    )
    value_column = require_string(
        validator.get("value_field"),
        f"{lane}.{role}.validator.value_field",
    )
    return key_column, value_column
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def read_csv_rows(path, label, *, invalid_handler):
    """Read a UTF-8 CSV file with a header row.

    Args:
        path: ``pathlib.Path``-like object exposing ``open``.
        label: Human-readable name of the file for error messages.
        invalid_handler: Callable taking one message string, invoked for
            every problem found. NOTE(review): assumed not to return;
            confirm against the runtime helpers passed in by callers.

    Returns:
        ``(fieldnames, rows)`` where ``fieldnames`` are the header names with
        surrounding whitespace removed and ``rows`` are the
        ``csv.DictReader`` dicts.

    Each row can be looked up by the stripped header names as well as the
    raw ones. Previously only the raw spelling was a key, so a header such
    as ``" id "`` passed the caller's column check but every ``row.get("id")``
    returned ``None``. Undecodable bytes and CSV parse errors are routed to
    ``invalid_handler`` instead of escaping as tracebacks.
    """
    try:
        with path.open("r", encoding="utf-8", newline="") as handle:
            reader = csv.DictReader(handle)
            fieldnames = reader.fieldnames
            if not fieldnames:
                invalid_handler(f"{label} must include a CSV header row.")
            normalized_fieldnames = []
            for fieldname in fieldnames:
                if not isinstance(fieldname, str) or not fieldname.strip():
                    invalid_handler(f"{label} contains an empty CSV column name.")
                normalized_fieldnames.append(fieldname.strip())
            rows = []
            for row in reader:
                # Add stripped aliases without disturbing existing keys, so a
                # column whose raw name already matches keeps its own value.
                for raw_name, clean_name in zip(fieldnames, normalized_fieldnames):
                    row.setdefault(clean_name, row.get(raw_name))
                rows.append(row)
    except FileNotFoundError:
        invalid_handler(f"Missing {label} at {path}.")
    except UnicodeDecodeError as error:
        # A ValueError subclass, so the OSError clause below never sees it.
        invalid_handler(f"Unable to decode {label} as UTF-8: {error}.")
    except csv.Error as error:
        invalid_handler(f"Unable to parse {label}: {error}.")
    except OSError as error:
        invalid_handler(f"Unable to read {label}: {error}.")
    return normalized_fieldnames, rows
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def parse_reference_relevance(raw_value, label):
    """Convert a reference CSV cell into a non-negative finite float.

    Args:
        raw_value: Cell content; anything that is not a ``str`` counts as
            blank.
        label: Description of the cell for error messages.

    Blank, non-numeric, non-finite and negative values are each reported
    through ``fail_runtime``.
    """
    text = ""
    if isinstance(raw_value, str):
        text = raw_value.strip()
    if text == "":
        fail_runtime(f"{label} is blank.")
    try:
        relevance = float(text)
    except ValueError:
        fail_runtime(f"{label} must be numeric, received {text!r}.")
    if math.isinf(relevance) or math.isnan(relevance):
        fail_runtime(f"{label} must be finite.")
    if relevance < 0:
        fail_runtime(f"{label} must be non-negative for ranking metrics.")
    return relevance
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def parse_submission_score(raw_value, label, invalid_value_policy):
    """Convert a submission CSV cell into a finite float.

    Args:
        raw_value: Cell content; anything that is not a ``str`` counts as
            blank.
        label: Description of the cell for error messages.
        invalid_value_policy: When equal to ``"reject"``, an unusable cell is
            reported through ``reject_submission``.

    Returns:
        The parsed float, or ``None`` for a blank, non-numeric or non-finite
        cell.
    """
    must_reject = invalid_value_policy == "reject"
    text = raw_value.strip() if isinstance(raw_value, str) else ""
    if text == "":
        if must_reject:
            reject_submission(f"{label} is blank.")
        return None
    try:
        score = float(text)
    except ValueError:
        if must_reject:
            reject_submission(f"{label} must be numeric, received {text!r}.")
        return None
    if math.isfinite(score):
        return score
    if must_reject:
        reject_submission(f"{label} must be finite.")
    return None
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def load_reference_relevance(path, role, record_key, value_field):
    """Load the reference CSV into a ``{record id: relevance}`` dict.

    Args:
        path: Location of the evaluation artifact.
        role: Artifact role, used in error messages.
        record_key: Name of the id column.
        value_field: Name of the relevance column.

    Every problem is reported through ``fail_runtime``: a missing column, a
    row without an id, a repeated id, an unusable relevance value, or a file
    with no data rows. Row numbers in messages start at 2, the first line
    after the header.
    """
    header, records = read_csv_rows(
        path,
        f"evaluation artifact {role}",
        invalid_handler=fail_runtime,
    )
    if record_key not in header:
        fail_runtime(
            f"evaluation artifact {role} is missing record key column {record_key}."
        )
    if value_field not in header:
        fail_runtime(
            f"evaluation artifact {role} is missing value column {value_field}."
        )

    relevance_by_id = {}
    row_index = 1
    for record in records:
        row_index += 1
        id_cell = record.get(record_key)
        record_id = id_cell.strip() if isinstance(id_cell, str) else ""
        if record_id == "":
            fail_runtime(
                f"evaluation artifact {role} row {row_index} is missing {record_key}."
            )
        if record_id in relevance_by_id:
            fail_runtime(
                f"evaluation artifact {role} contains duplicate record id {record_id!r}."
            )
        relevance_by_id[record_id] = parse_reference_relevance(
            record.get(value_field),
            f"evaluation artifact {role} row {row_index} column {value_field}",
        )

    if len(relevance_by_id) == 0:
        fail_runtime(
            f"evaluation artifact {role} must contain at least one ranked row."
        )
    return relevance_by_id
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def load_submission_scores(
    path,
    role,
    record_key,
    value_field,
    duplicate_id_policy,
    invalid_value_policy,
):
    """Load the submission CSV into a ``{record id: score}`` dict.

    Args:
        path: Location of the submission artifact.
        role: Artifact role, used in error messages.
        record_key: Name of the id column.
        value_field: Name of the score column.
        duplicate_id_policy: ``"reject"`` reports a repeated id through
            ``reject_submission``; any other value keeps the entry already
            stored and skips the repeat.
        invalid_value_policy: ``"reject"`` reports rows with a missing id or
            an unusable score; any other value skips them.

    File-level problems and missing columns are always reported through
    ``reject_submission``. Row numbers in messages start at 2.
    """
    header, records = read_csv_rows(
        path,
        f"submission artifact {role}",
        invalid_handler=reject_submission,
    )
    if record_key not in header:
        reject_submission(
            f"submission artifact {role} is missing record key column {record_key}."
        )
    if value_field not in header:
        reject_submission(
            f"submission artifact {role} is missing value column {value_field}."
        )

    reject_duplicates = duplicate_id_policy == "reject"
    reject_invalid = invalid_value_policy == "reject"
    scores_by_id = {}
    row_index = 1
    for record in records:
        row_index += 1
        id_cell = record.get(record_key)
        record_id = id_cell.strip() if isinstance(id_cell, str) else ""
        if record_id == "":
            if reject_invalid:
                reject_submission(
                    f"submission artifact {role} row {row_index} is missing {record_key}."
                )
            continue
        # NOTE(review): duplicates are detected only against ids that were
        # stored. If the first row for an id was skipped for an invalid
        # score, a later row with the same id is accepted and not treated as
        # a duplicate. Confirm this is the intended policy.
        if record_id in scores_by_id:
            if reject_duplicates:
                reject_submission(
                    f"submission artifact {role} contains duplicate record id {record_id!r}."
                )
            continue
        score = parse_submission_score(
            record.get(value_field),
            f"submission artifact {role} row {row_index} column {value_field}",
            invalid_value_policy,
        )
        if score is not None:
            scores_by_id[record_id] = score
    return scores_by_id
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def discount(rank_index):
    """Return the logarithmic discount ``1 / log2(rank_index + 2)``.

    ``rank_index`` is zero-based: index 0 yields a discount of exactly 1.0.
    """
    shifted_rank = rank_index + 2.0
    return 1.0 / math.log2(shifted_rank)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def dcg(labels):
    """Return the discounted cumulative gain of ``labels`` in the given order.

    Each label contributes ``(2 ** label - 1)`` weighted by ``discount`` of
    its zero-based position. An empty sequence yields ``0``.
    """
    # sum() is kept deliberately: on recent Python versions its float
    # summation is not bit-for-bit equivalent to a manual += loop.
    return sum(
        (2.0 ** relevance - 1.0) * discount(position)
        for position, relevance in enumerate(labels)
    )
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def rank_submission_ids(submission_by_id):
    """Return record ids ordered from highest to lowest submitted score.

    Ties are broken by ascending record id, so the order does not depend on
    the dict's insertion order.
    """

    def sort_key(record_id):
        # Negated score gives descending order under an ascending sort.
        return (-submission_by_id[record_id], record_id)

    ordered = list(submission_by_id.keys())
    ordered.sort(key=sort_key)
    return ordered
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def compute_ndcg(reference_by_id, submission_by_id):
    """Return the normalised DCG of the submitted ranking.

    Args:
        reference_by_id: ``{record id: relevance}``; must contain every id
            present in ``submission_by_id``.
        submission_by_id: ``{record id: submitted score}``.

    The ideal ordering is built from the labels of the submitted ids only.
    When the ideal DCG is zero the result is defined as ``1.0``.
    """
    submitted_order = rank_submission_ids(submission_by_id)
    labels_in_order = [reference_by_id[record_id] for record_id in submitted_order]
    best_possible = dcg(sorted(labels_in_order, reverse=True))
    if best_possible == 0.0:
        return 1.0
    achieved = dcg(labels_in_order)
    return achieved / best_possible
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def compute_mrr(reference_by_id, submission_by_id):
    """Return the reciprocal rank of the first relevant id in the ranking.

    An id is relevant when its reference value is greater than 0. Ranks
    start at 1. The result is ``0.0`` when no ranked id is relevant.
    """
    ranking = rank_submission_ids(submission_by_id)
    first_hit_rank = next(
        (
            rank
            for rank, record_id in enumerate(ranking, start=1)
            if reference_by_id[record_id] > 0
        ),
        None,
    )
    if first_hit_rank is None:
        return 0.0
    return 1.0 / first_hit_rank
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def compute_map(reference_by_id, submission_by_id):
    """Return the average precision of the submitted ranking.

    Precision is sampled at each rank holding a relevant id (reference value
    above 0) and averaged over the number of relevant ids encountered. The
    result is ``0.0`` when none is encountered.
    """
    relevant_seen = 0
    running_precision = 0.0
    ranking = rank_submission_ids(submission_by_id)
    for rank, record_id in enumerate(ranking, start=1):
        if reference_by_id[record_id] <= 0:
            continue
        relevant_seen += 1
        running_precision += relevant_seen / rank
    return running_precision / relevant_seen if relevant_seen != 0 else 0.0
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def compute_top_k_recall(reference_by_id, submission_by_id, metric_params):
    """Return the share of relevant ids found in the top fraction of the ranking.

    Args:
        reference_by_id: ``{record id: relevance}``; ids with a value above 0
            are the relevant ones.
        submission_by_id: ``{record id: submitted score}``.
        metric_params: Mapping whose ``k`` entry is the fraction of the
            ranking to inspect (validated by ``require_top_k_fraction``).

    The cutoff is ``ceil(len(ranking) * k)`` with a minimum of 1. A reference
    without any relevant id is reported through ``fail_runtime``.
    """
    positives = [
        record_id
        for record_id, relevance in reference_by_id.items()
        if relevance > 0
    ]
    if len(positives) == 0:
        fail_runtime(
            "ranking_metric metric top_k_recall requires at least one positive reference label."
        )
    ranking = rank_submission_ids(submission_by_id)
    fraction = require_top_k_fraction(metric_params)
    cutoff = max(1, math.ceil(len(ranking) * fraction))
    top_slice = set(ranking[:cutoff])
    found = len([record_id for record_id in positives if record_id in top_slice])
    return found / len(positives)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def compute_metric(metric, reference_by_id, submission_by_id, metric_params):
    """Dispatch to the scoring function selected by ``metric``.

    ``"ndcg"``, ``"mrr"`` and ``"map"`` map to their two-argument scorers.
    Any other name falls through to ``compute_top_k_recall``, which also
    receives ``metric_params``.
    """
    two_argument_scorers = (
        ("ndcg", compute_ndcg),
        ("mrr", compute_mrr),
        ("map", compute_map),
    )
    for name, scorer in two_argument_scorers:
        if metric == name:
            return scorer(reference_by_id, submission_by_id)
    return compute_top_k_recall(reference_by_id, submission_by_id, metric_params)
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def main():
    """Score a submission CSV against a reference CSV with a ranking metric.

    Steps, in order:
      1. Load the runtime context and the ``compiled_config`` scoring asset,
         which must be a JSON object.
      2. Validate the metric, its parameters, the artifact roles, the
         objective (only ``maximize`` is accepted), the score key and the
         three execution policies.
      3. Read the CSV column names from the evaluation and submission slots,
         then load both artifacts.
      4. Apply the coverage policy to reference ids missing from the
         submission: ``reject`` refuses the submission, ``penalize`` scales
         the score by the covered share, ``ignore`` scores what is present.
      5. Compute the metric over ids present in both files, clamp it to
         [0, 1] and report it through ``write_score``.

    Submission ids that do not appear in the reference are not scored.
    """
    runtime_context = load_runtime_context()
    config_path = resolve_scoring_asset(
        runtime_context,
        "compiled_config",
        kind="config",
    )
    try:
        config = load_json_file(config_path, label="compiled_config")
    except RuntimeError as error:
        fail_runtime(str(error))
    # Without this check a non-object document (list, string, number) would
    # crash below with AttributeError on ``config.get``.
    if not isinstance(config, dict):
        fail_runtime("compiled_config must be a JSON object.")
    metric = require_metric(config)
    metric_params = require_metric_params(config, metric)
    evaluation_role = require_string(
        config.get("evaluation_role"),
        "compiled_config.evaluation_role",
    )
    submission_role = require_string(
        config.get("submission_role"),
        "compiled_config.submission_role",
    )
    objective = require_string(
        runtime_context.get("objective"),
        "runtime_context.objective",
    )
    if objective != "maximize":
        fail_runtime(f"ranking_metric metric {metric} requires objective=maximize.")
    final_score_key = require_string(
        runtime_context.get("final_score_key"),
        "runtime_context.final_score_key",
    )
    policies = runtime_context.get("policies")
    if not isinstance(policies, dict):
        fail_runtime("Runtime context is missing execution policies.")
    coverage_policy = require_policy(
        policies,
        "coverage_policy",
        {"reject", "ignore", "penalize"},
    )
    duplicate_id_policy = require_policy(
        policies,
        "duplicate_id_policy",
        {"reject", "ignore"},
    )
    invalid_value_policy = require_policy(
        policies,
        "invalid_value_policy",
        {"reject", "ignore"},
    )
    evaluation_record_key, evaluation_value_field = require_csv_slot_columns(
        runtime_context,
        "evaluation",
        evaluation_role,
    )
    submission_record_key, submission_value_field = require_csv_slot_columns(
        runtime_context,
        "submission",
        submission_role,
    )
    evaluation_path = resolve_evaluation_artifact(runtime_context, evaluation_role)
    submission_path = resolve_submission_artifact(runtime_context, submission_role)
    reference_by_id = load_reference_relevance(
        evaluation_path,
        evaluation_role,
        evaluation_record_key,
        evaluation_value_field,
    )
    submission_by_id = load_submission_scores(
        submission_path,
        submission_role,
        submission_record_key,
        submission_value_field,
        duplicate_id_policy,
        invalid_value_policy,
    )
    missing_ids = [
        record_id
        for record_id in reference_by_id
        if record_id not in submission_by_id
    ]
    if missing_ids and coverage_policy == "reject":
        reject_submission(
            f"Submission is missing rankings for {len(missing_ids)} required rows; first missing id is {missing_ids[0]!r}."
        )
    # Only ids present in both files take part in the metric.
    scored_ids = [
        record_id
        for record_id in reference_by_id
        if record_id in submission_by_id
    ]
    if not scored_ids:
        reject_submission(
            "Submission produced no ranked rows after applying runtime policies."
        )
    scored_reference = {record_id: reference_by_id[record_id] for record_id in scored_ids}
    scored_submission = {record_id: submission_by_id[record_id] for record_id in scored_ids}
    raw_metric = compute_metric(metric, scored_reference, scored_submission, metric_params)
    normalized_score = max(0.0, min(raw_metric, 1.0))
    if coverage_policy == "penalize":
        # Scale by the share of reference rows the submission covered.
        normalized_score *= len(scored_ids) / len(reference_by_id)
    write_score(
        score=normalized_score,
        details={
            final_score_key: normalized_score,
            "selected_metric": metric,
            "selected_metric_value": raw_metric,
            "ranked_items": len(scored_ids),
        },
    )
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
# Run the scorer only when this file is executed as a script.
if __name__ == "__main__":
    main()
|