@moleculeagora/cli 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,282 @@
1
+ import json
2
+ import math
3
+ import os
4
+ import tempfile
5
+ import zipfile
6
+ from pathlib import Path
7
+ from typing import Any, Callable
8
+
9
+ from runtime_manifest import (
10
+ load_runtime_manifest,
11
+ resolve_artifact_by_role,
12
+ resolve_program_scoring_asset,
13
+ resolve_scoring_asset_by_role,
14
+ )
15
+
16
# Directory inputs are read from; AGORA_RUNTIME_INPUT_ROOT overrides "/input".
INPUT_ROOT = Path(os.environ.get("AGORA_RUNTIME_INPUT_ROOT", "/input"))
# Directory results are written to; AGORA_RUNTIME_OUTPUT_ROOT overrides "/output".
OUTPUT_ROOT = Path(os.environ.get("AGORA_RUNTIME_OUTPUT_ROOT", "/output"))
# Location of the score document inside OUTPUT_ROOT.
OUTPUT_PATH = OUTPUT_ROOT / "score.json"
# Upper bound on the number of entries accepted in one ZIP archive.
MAX_ZIP_MEMBER_COUNT = 1024
# Upper bound on the extracted size of a single ZIP member (50 MiB).
MAX_ZIP_MEMBER_BYTES = 50 * 1024 * 1024
# Upper bound on the extracted size of a whole ZIP archive (50 MiB).
MAX_ZIP_TOTAL_BYTES = 50 * 1024 * 1024
# Read size used when streaming ZIP members to disk (1 MiB).
_ZIP_COPY_CHUNK_BYTES = 1024 * 1024
23
+
24
+
25
+ def _serialize_payload(payload: dict[str, Any]) -> str:
26
+ return json.dumps(payload, sort_keys=True, separators=(",", ":"))
27
+
28
+
29
def write_score(
    *,
    score: float,
    details: dict[str, Any] | None = None,
    ok: bool = True,
    error: str | None = None,
) -> None:
    """Write the score document to ``OUTPUT_PATH``.

    The output directory is created when missing. The document always
    carries ``ok``, ``score`` and ``details``; ``error`` is included only
    when a message is supplied.
    """
    OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
    document: dict[str, Any] = {
        "ok": bool(ok),
        "score": float(score),
        "details": details if details else {},
    }
    if error is not None:
        document["error"] = error
    serialized = _serialize_payload(document)
    OUTPUT_PATH.write_text(serialized, encoding="utf-8")
45
+
46
+
47
def fail_runtime(message: str) -> None:
    """Record a zero score with ``ok=False`` and exit with status 1.

    This function never returns normally: it always raises ``SystemExit``.
    """
    write_score(error=message, ok=False, details={}, score=0.0)
    raise SystemExit(1)
50
+
51
+
52
def reject_submission(
    message: str,
    *,
    details: dict[str, Any] | None = None,
) -> None:
    """Record a zero score with ``ok=False`` and exit with status 0.

    This function never returns normally: it always raises ``SystemExit``.
    """
    rejection_details = details or {}
    write_score(error=message, ok=False, details=rejection_details, score=0.0)
    raise SystemExit(0)
59
+
60
+
61
def load_runtime_context(
    *,
    input_dir: Path | None = None,
    fail_runtime_handler: Callable[[str], None] | None = None,
) -> dict[str, Any]:
    """Load the runtime manifest and attach its program scoring asset.

    ``input_dir`` falls back to ``INPUT_ROOT`` and ``fail_runtime_handler``
    falls back to ``fail_runtime``. The resolved program asset is stored
    under the ``"program_asset"`` key of the returned manifest.
    """
    on_failure = fail_runtime_handler or fail_runtime
    source_dir = input_dir or INPUT_ROOT
    manifest = load_runtime_manifest(input_dir=source_dir, fail_runtime=on_failure)
    program_asset = resolve_program_scoring_asset(
        manifest,
        fail_runtime=on_failure,
        supported_abi_versions={"python-v1"},
    )
    manifest["program_asset"] = program_asset
    return manifest
77
+
78
+
79
def resolve_evaluation_artifact(
    runtime_context: dict[str, Any],
    role: str,
    *,
    fail_runtime_handler: Callable[[str], None] | None = None,
) -> Path:
    """Return the path of the evaluation-lane artifact registered for *role*.

    A missing path is reported through the failure handler, which
    defaults to ``fail_runtime``.
    """
    on_failure = fail_runtime_handler or fail_runtime
    record = resolve_artifact_by_role(
        runtime_context,
        lane="evaluation",
        role=role,
        fail_runtime=on_failure,
    )
    if record["path"] is None:
        on_failure(f"Missing required evaluation artifact for role {role}.")
    return record["path"]
95
+
96
+
97
def resolve_submission_artifact(
    runtime_context: dict[str, Any],
    role: str,
    *,
    fail_runtime_handler: Callable[[str], None] | None = None,
) -> Path:
    """Return the path of the submission-lane artifact registered for *role*.

    Lookup problems go to the failure handler (default ``fail_runtime``).
    A missing path is treated as a submission problem and reported through
    ``reject_submission``.
    """
    on_failure = fail_runtime_handler or fail_runtime
    record = resolve_artifact_by_role(
        runtime_context,
        lane="submission",
        role=role,
        fail_runtime=on_failure,
    )
    if record["path"] is None:
        reject_submission(f"Missing required submission artifact for role {role}.")
    return record["path"]
113
+
114
+
115
def resolve_scoring_asset(
    runtime_context: dict[str, Any],
    role: str,
    *,
    kind: str | None = None,
    fail_runtime_handler: Callable[[str], None] | None = None,
) -> Path:
    """Return the path of the scoring asset registered for *role*.

    ``kind`` is forwarded unchanged to the manifest lookup; failures go to
    the failure handler, which defaults to ``fail_runtime``.
    """
    on_failure = fail_runtime_handler or fail_runtime
    record = resolve_scoring_asset_by_role(
        runtime_context,
        role=role,
        kind=kind,
        fail_runtime=on_failure,
    )
    return record["path"]
130
+
131
+
132
+ def load_json_file(path: Path, *, label: str | None = None) -> Any:
133
+ try:
134
+ return json.loads(path.read_text(encoding="utf-8"))
135
+ except FileNotFoundError as error:
136
+ human_label = label or str(path)
137
+ raise RuntimeError(f"{human_label} is missing") from error
138
+ except json.JSONDecodeError as error:
139
+ human_label = label or str(path)
140
+ raise RuntimeError(f"{human_label} is not valid JSON: {error.msg}") from error
141
+
142
+
143
+ def _require_finite_score(value: Any, *, label: str) -> float:
144
+ if isinstance(value, bool) or not isinstance(value, (int, float)):
145
+ raise RuntimeError(f"{label}.score must be a finite number.")
146
+ score = float(value)
147
+ if not math.isfinite(score):
148
+ raise RuntimeError(f"{label}.score must be a finite number.")
149
+ return score
150
+
151
+
152
def load_score_output(path: Path, *, label: str) -> dict[str, Any]:
    """Load and validate a score document.

    Checks are applied in this order: the document is a JSON object,
    ``ok`` is a boolean, ``details`` (default ``{}``) is an object,
    ``score`` is present and finite, ``error`` is a string when present,
    and a non-blank ``error`` accompanies ``ok=false``.

    Returns:
        A shallow copy of the document with ``ok``, ``score`` (as float)
        and ``details`` set to their validated values.

    Raises:
        RuntimeError: On the first failed check.
    """
    document = load_json_file(path, label=label)
    if not isinstance(document, dict):
        raise RuntimeError(f"{label} must be a JSON object.")

    ok_flag = document.get("ok")
    if not isinstance(ok_flag, bool):
        raise RuntimeError(f"{label}.ok must be a boolean.")

    detail_map = document.get("details", {})
    if not isinstance(detail_map, dict):
        raise RuntimeError(f"{label}.details must be an object when present.")

    if "score" not in document:
        raise RuntimeError(f"{label}.score is missing.")
    finite_score = _require_finite_score(document["score"], label=label)

    error_text = document.get("error")
    if not (error_text is None or isinstance(error_text, str)):
        raise RuntimeError(f"{label}.error must be a string when present.")
    has_explanation = bool(error_text and error_text.strip())
    if not ok_flag and not has_explanation:
        raise RuntimeError(f"{label}.error must explain ok=false outputs.")

    return {
        **document,
        "ok": ok_flag,
        "score": finite_score,
        "details": detail_map,
    }
180
+
181
+
182
def load_text_file(path: Path) -> str:
    """Return the full contents of *path* decoded as UTF-8."""
    with path.open(encoding="utf-8") as stream:
        return stream.read()
184
+
185
+
186
def safe_extract_zip(
    archive_path: Path,
    destination: Path,
    *,
    label: str,
    invalid_handler: Callable[[str], None] | None = None,
) -> None:
    """Extract *archive_path* into *destination* with path and size limits.

    Every entry name is validated before anything is written: empty names,
    absolute paths, ``..`` components and paths resolving outside the
    extraction root are refused. Archives whose declared sizes already
    exceed the limits are refused before any data is written. Members are
    then streamed to a temporary file next to their target, with the
    per-member and total limits enforced on the bytes actually read, and
    moved into place once complete.

    Args:
        archive_path: ZIP archive to extract.
        destination: Extraction root; created when missing.
        label: Human-readable archive name used in messages.
        invalid_handler: Called with a message when the archive is
            rejected; defaults to ``fail_runtime``. If the handler returns,
            this function returns immediately afterwards.

    Problems that are not about the archive contents (for example a
    missing archive file or a failing filesystem write) propagate as the
    underlying exception; temporary files are removed in every case.
    """
    # Function-scope import: needed only to recognise corrupt compressed
    # streams, which surface as zlib.error rather than BadZipFile.
    import zlib

    handler = invalid_handler or fail_runtime
    invalid_archive_message = f"{label} is not a valid ZIP archive."
    total_too_large_message = (
        f"{label} exceeds the maximum ZIP extracted size "
        f"of {MAX_ZIP_TOTAL_BYTES} bytes. Next step: "
        "remove unnecessary files or upload a smaller "
        "ZIP archive."
    )

    def member_too_large_message(member_name: str) -> str:
        return (
            f"{label} member {member_name!r} exceeds the maximum ZIP "
            f"member size of {MAX_ZIP_MEMBER_BYTES} bytes. "
            "Next step: reduce that file and upload a ZIP "
            "within the Python-v1 artifact limits."
        )

    root = destination.resolve()
    root.mkdir(parents=True, exist_ok=True)
    try:
        archive = zipfile.ZipFile(archive_path, "r")
    except zipfile.BadZipFile:
        handler(invalid_archive_message)
        return

    with archive:
        members = archive.infolist()
        if len(members) > MAX_ZIP_MEMBER_COUNT:
            handler(
                f"{label} contains {len(members)} archive entries, exceeding the "
                f"maximum ZIP member count of {MAX_ZIP_MEMBER_COUNT}. Next step: "
                "remove unnecessary files and upload a smaller ZIP archive."
            )
            return

        # Pass 1: validate every path before touching the filesystem.
        extraction_plan = []
        for member in members:
            name = member.filename
            member_path = Path(name)
            if (
                not name
                or member_path.is_absolute()
                or any(part in {"", ".", ".."} for part in member_path.parts)
            ):
                handler(f"{label} contains unsafe archive path {name!r}.")
                return
            target = (root / member_path).resolve()
            try:
                target.relative_to(root)
            except ValueError:
                handler(f"{label} contains archive path outside extraction root: {name!r}.")
                return
            extraction_plan.append((member, target))

        # Pass 2: fail fast on declared sizes so an oversize archive is
        # refused before any of its data is written. Headers are untrusted,
        # so the limits are enforced again on the bytes actually read below.
        declared_total = 0
        for member, _ in extraction_plan:
            if member.is_dir():
                continue
            if member.file_size > MAX_ZIP_MEMBER_BYTES:
                handler(member_too_large_message(member.filename))
                return
            declared_total += member.file_size
            if declared_total > MAX_ZIP_TOTAL_BYTES:
                handler(total_too_large_message)
                return

        # Pass 3: stream each member to a temp file, then move it into place.
        total_bytes = 0
        for member, target in extraction_plan:
            if member.is_dir():
                target.mkdir(parents=True, exist_ok=True)
                continue

            name = member.filename
            target.parent.mkdir(parents=True, exist_ok=True)
            member_bytes = 0
            temp_path = None
            failure = None
            try:
                with archive.open(member, "r") as source, tempfile.NamedTemporaryFile(
                    "wb",
                    dir=target.parent,
                    prefix=".agora-zip-",
                    delete=False,
                ) as output:
                    temp_path = Path(output.name)
                    while chunk := source.read(_ZIP_COPY_CHUNK_BYTES):
                        member_bytes += len(chunk)
                        total_bytes += len(chunk)
                        if member_bytes > MAX_ZIP_MEMBER_BYTES:
                            failure = member_too_large_message(name)
                            break
                        if total_bytes > MAX_ZIP_TOTAL_BYTES:
                            failure = total_too_large_message
                            break
                        output.write(chunk)
                if failure is None:
                    temp_path.replace(target)
                    # The file now lives at its final location; nothing to clean up.
                    temp_path = None
            except (zipfile.BadZipFile, zlib.error, EOFError):
                # Corrupt or truncated member data is an archive problem and
                # is reported through the handler instead of crashing.
                failure = invalid_archive_message
            finally:
                # Runs for every exit path, including unexpected exceptions,
                # so a partially written temp file is never left behind.
                if temp_path is not None:
                    temp_path.unlink(missing_ok=True)

            if failure is not None:
                handler(failure)
                return
@@ -0,0 +1,264 @@
1
+ import math
2
+ import re
3
+
4
+ from agora_runtime import (
5
+ fail_runtime,
6
+ load_json_file,
7
+ load_runtime_context,
8
+ reject_submission,
9
+ resolve_evaluation_artifact,
10
+ resolve_scoring_asset,
11
+ resolve_submission_artifact,
12
+ write_score,
13
+ )
14
+
15
# Values accepted for compiled_config.metric.
SUPPORTED_METRICS = ("precision", "recall", "f1", "iou")
# Values accepted for compiled_config.id_space.
SUPPORTED_ID_SPACES = ("opaque", "doi", "pmid", "arxiv")
# Bare DOI: "10.", four to nine digits, "/", then one or more printable
# non-space ASCII characters.
DOI_PATTERN = re.compile(r"^10\.[0-9]{4,9}/[!-~]+$")
# Bare PMID: one or more ASCII digits.
PMID_PATTERN = re.compile(r"^[0-9]+$")
# arXiv id: either "<archive>[.<xx>]/<7 digits>" or "<4 digits>.<4-5 digits>",
# optionally followed by a "v<digits>" version suffix. Lowercase only.
ARXIV_PATTERN = re.compile(
    r"^(?:[a-z-]+(?:\.[a-z]{2})?/[0-9]{7}|[0-9]{4}\.[0-9]{4,5})(?:v[0-9]+)?$"
)
22
+
23
+
24
def format_value_list(values):
    """Join *values* into an English "a, b, or c" alternatives list.

    Returns an empty string for no values, the value itself for one,
    ``"a or b"`` for two, and a comma-separated list ending in ``", or z"``
    for three or more.
    """
    ordered = list(values)
    if not ordered:
        # Previously fell through to ordered[-1] and raised IndexError.
        return ""
    if len(ordered) == 1:
        return ordered[0]
    if len(ordered) == 2:
        return f"{ordered[0]} or {ordered[1]}"
    return f"{', '.join(ordered[:-1])}, or {ordered[-1]}"
31
+
32
+
33
def require_string(value, label):
    """Return *value* stripped of surrounding whitespace.

    Anything that is not a string, or is blank after stripping, is
    reported through ``fail_runtime``.
    """
    usable = isinstance(value, str) and bool(value.strip())
    if not usable:
        fail_runtime(f"{label} must be a non-empty string.")
    return value.strip()
37
+
38
+
39
def require_config_value(config, key, supported_values):
    """Return the lowercased config entry *key*, checked against an allow-list.

    The entry must be a non-empty string; a value outside
    *supported_values* is reported through ``fail_runtime``.
    """
    field_label = f"compiled_config.{key}"
    chosen = require_string(config.get(key), field_label).lower()
    if chosen in supported_values:
        return chosen
    fail_runtime(
        f"compiled_config.{key} must be one of {format_value_list(supported_values)}."
    )
    return chosen
46
+
47
+
48
def require_json_slot(runtime_context, lane, role):
    """Check that *lane* declares a JSON-validated slot for *role*.

    The first slot in ``artifact_contract[lane]`` whose ``role`` matches
    decides the outcome: its ``validator.kind`` must be ``json_document``
    or ``json_schema``. A missing contract, lane or slot, or a different
    validator kind, is reported through ``fail_runtime``.
    """
    contract = runtime_context.get("artifact_contract")
    if not isinstance(contract, dict):
        fail_runtime("Runtime context is missing artifact_contract.")

    lane_slots = contract.get(lane)
    if not isinstance(lane_slots, list):
        fail_runtime(f"Runtime context is missing artifact_contract.{lane}.")

    for entry in lane_slots:
        if isinstance(entry, dict) and entry.get("role") == role:
            validator = entry.get("validator")
            kind_is_json = isinstance(validator, dict) and validator.get("kind") in {
                "json_document",
                "json_schema",
            }
            if not kind_is_json:
                fail_runtime(
                    f"{lane} role {role} must use validator.kind=json_document or json_schema for answer_set_metric."
                )
            return

    fail_runtime(f"Runtime context is missing {lane} slot for role {role}.")
68
+
69
+
70
def load_json_payload(path, label, invalid_handler):
    """Load JSON from *path*, routing load errors to *invalid_handler*.

    A ``RuntimeError`` raised by ``load_json_file`` is passed to the
    handler as a string; if the handler returns, the result is ``None``.
    """
    try:
        document = load_json_file(path, label=label)
    except RuntimeError as error:
        invalid_handler(str(error))
    else:
        return document
75
+
76
+
77
def require_ascii(value, label, invalid_handler):
    """Return *value*, reporting non-ASCII content through *invalid_handler*."""
    if not value.isascii():
        invalid_handler(f"{label} must contain ASCII characters only.")
    return value
83
+
84
+
85
def normalize_answer_id(raw_value, id_space, label, invalid_handler):
    """Return the canonical form of one answer id for *id_space*.

    The value is always stripped of surrounding whitespace. ``opaque`` ids
    are returned as-is; ``pmid`` ids must be ASCII digits; ``doi`` ids and
    every remaining id space (validated as arXiv) are lowercased and
    matched against their pattern. Violations are reported through
    *invalid_handler*.
    """
    if not isinstance(raw_value, str):
        invalid_handler(f"{label} must be a string.")
    trimmed = raw_value.strip()
    if not trimmed:
        invalid_handler(f"{label} must be a non-empty string after trimming.")

    if id_space == "opaque":
        return trimmed

    checked = require_ascii(trimmed, label, invalid_handler)
    if id_space == "pmid":
        if PMID_PATTERN.fullmatch(checked) is None:
            invalid_handler(f"{label} must be a bare PMID containing ASCII digits only.")
        return checked

    lowered = checked.lower()
    if id_space == "doi":
        pattern = DOI_PATTERN
        complaint = f"{label} must be a bare DOI such as 10.1234/example, not a URL or alias."
    else:
        pattern = ARXIV_PATTERN
        complaint = f"{label} must be a canonical arXiv id string, not a URL or prefixed alias."
    if pattern.fullmatch(lowered) is None:
        invalid_handler(complaint)
    return lowered
114
+
115
+
116
def load_answer_set(path, role, id_space, *, lane, invalid_handler):
    """Load an answer-set artifact and return its normalized ids as a set.

    The artifact must be a JSON object whose ``items`` entry is an array.
    Each item is normalized for *id_space*; duplicates after normalization
    are reported, as is an evaluation-lane artifact with no ids. All
    problems go through *invalid_handler*.
    """
    artifact_label = f"{lane} artifact {role}"
    document = load_json_payload(path, artifact_label, invalid_handler)
    if not isinstance(document, dict):
        invalid_handler(f"{lane} artifact {role} must be a JSON object.")
    entries = document.get("items")
    if not isinstance(entries, list):
        invalid_handler(f"{lane} artifact {role}.items must be an array.")

    seen = set()
    for position, entry in enumerate(entries):
        entry_label = f"{lane} artifact {role}.items[{position}]"
        canonical = normalize_answer_id(entry, id_space, entry_label, invalid_handler)
        if canonical in seen:
            invalid_handler(
                f"{lane} artifact {role} contains duplicate normalized id {canonical!r}."
            )
        seen.add(canonical)

    if not seen and lane == "evaluation":
        invalid_handler(f"evaluation artifact {role}.items must contain at least one id.")
    return seen
141
+
142
+
143
def compute_counts(reference_ids, candidate_ids):
    """Return ``(true_positives, false_positives, false_negatives)``.

    True positives are ids in both sets, false positives are ids only in
    *candidate_ids*, false negatives are ids only in *reference_ids*.
    """
    shared = reference_ids & candidate_ids
    only_candidate = candidate_ids - reference_ids
    only_reference = reference_ids - candidate_ids
    return len(shared), len(only_candidate), len(only_reference)
148
+
149
+
150
def compute_metric(metric, true_positive_count, false_positive_count, false_negative_count):
    """Compute a set-overlap metric from confusion counts.

    Args:
        metric: ``"precision"``, ``"recall"`` or ``"f1"``; any other value
            is computed as intersection-over-union.
        true_positive_count: Ids present in both sets.
        false_positive_count: Ids present only in the candidate set.
        false_negative_count: Ids present only in the reference set.

    Returns:
        The metric as a float. A ratio whose denominator is zero (empty
        candidate set, empty reference set, or empty union) is 0.0 rather
        than raising ``ZeroDivisionError``.
    """
    candidate_count = true_positive_count + false_positive_count
    reference_count = true_positive_count + false_negative_count
    precision = true_positive_count / candidate_count if candidate_count else 0.0
    # Guarded like precision: the reference set can be empty when this
    # function is called without the scorer's own validation in front of it.
    recall = true_positive_count / reference_count if reference_count else 0.0

    if metric == "precision":
        return precision
    if metric == "recall":
        return recall
    if metric == "f1":
        if precision + recall == 0.0:
            return 0.0
        return (2.0 * precision * recall) / (precision + recall)

    union_count = true_positive_count + false_positive_count + false_negative_count
    return true_positive_count / union_count if union_count else 0.0
168
+
169
+
170
def require_finite_unit_interval(value, label):
    """Clamp *value* to the range [0.0, 1.0].

    A NaN or infinite value is reported through ``fail_runtime``. Values
    already inside the range are returned unchanged.
    """
    if not math.isfinite(value):
        fail_runtime(f"{label} must be finite.")
    if value < 0.0:
        return 0.0
    return 1.0 if value > 1.0 else value
178
+
179
+
180
def main():
    """Score a submitted answer set against the reference answer set.

    Steps: load the runtime context and ``compiled_config``; validate the
    metric, id space, roles, score key and objective; check that both
    roles use JSON validators; load and normalize both id sets; compute
    the selected metric; write the clamped score together with the
    supporting counts.
    """
    runtime_context = load_runtime_context()
    config = load_json_payload(
        resolve_scoring_asset(runtime_context, "compiled_config", kind="config"),
        "compiled_config",
        fail_runtime,
    )
    if not isinstance(config, dict):
        fail_runtime("compiled_config must be a JSON object.")

    metric = require_config_value(config, "metric", SUPPORTED_METRICS)
    id_space = require_config_value(config, "id_space", SUPPORTED_ID_SPACES)
    # The generator is consumed in order, so the fields are validated in
    # the sequence reference_role, candidate_role, final_score_key.
    reference_role, candidate_role, final_score_key = (
        require_string(config.get(field), f"compiled_config.{field}")
        for field in ("reference_role", "candidate_role", "final_score_key")
    )

    objective = require_string(
        runtime_context.get("objective"),
        "runtime_context.objective",
    )
    if objective != "maximize":
        fail_runtime(f"answer_set_metric metric {metric} requires objective=maximize.")

    require_json_slot(runtime_context, "evaluation", reference_role)
    require_json_slot(runtime_context, "submission", candidate_role)
    reference_path = resolve_evaluation_artifact(runtime_context, reference_role)
    candidate_path = resolve_submission_artifact(runtime_context, candidate_role)

    # Reference problems are runtime failures; candidate problems reject
    # the submission instead.
    reference_ids = load_answer_set(
        reference_path,
        reference_role,
        id_space,
        lane="evaluation",
        invalid_handler=fail_runtime,
    )
    candidate_ids = load_answer_set(
        candidate_path,
        candidate_role,
        id_space,
        lane="submission",
        invalid_handler=reject_submission,
    )

    counts = compute_counts(reference_ids, candidate_ids)
    true_positive_count, false_positive_count, false_negative_count = counts
    raw_metric = compute_metric(metric, *counts)
    score = require_finite_unit_interval(raw_metric, "answer_set_metric score")

    details = {
        final_score_key: score,
        "selected_metric": metric,
        "selected_id_space": id_space,
        "reference_count": len(reference_ids),
        "candidate_count": len(candidate_ids),
        "true_positive_count": true_positive_count,
        "false_positive_count": false_positive_count,
        "false_negative_count": false_negative_count,
        "selected_metric_value": raw_metric,
    }
    write_score(score=score, details=details)


if __name__ == "__main__":
    main()