@moleculeagora/cli 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,532 @@
1
+ import csv
2
+ import json
3
+ import math
4
+ import re
5
+ import tempfile
6
+ from pathlib import Path
7
+
8
+ from agora_runtime import (
9
+ fail_runtime,
10
+ load_json_file,
11
+ load_runtime_context,
12
+ reject_submission,
13
+ resolve_evaluation_artifact,
14
+ resolve_scoring_asset,
15
+ resolve_submission_artifact,
16
+ safe_extract_zip,
17
+ write_score,
18
+ )
19
+
20
# Identifier accepted after "." in a JSON path such as "$.metrics.accuracy":
# an ASCII letter or underscore followed by ASCII letters, digits or underscores.
FIELD_RE = re.compile(pattern=r"[A-Za-z_][A-Za-z0-9_]*")
21
+
22
+
23
def require_string(value, label):
    """Return *value* with surrounding whitespace removed.

    Anything other than a string containing at least one non-whitespace
    character is reported through ``fail_runtime`` using *label*.
    """
    is_usable = isinstance(value, str) and bool(value.strip())
    if not is_usable:
        fail_runtime(f"{label} must be a non-empty string.")
    return value.strip()
27
+
28
+
29
def require_object(value, label):
    """Return *value* unchanged after checking that it is a JSON object (``dict``).

    A non-dict is reported through ``fail_runtime`` using *label*.
    """
    if isinstance(value, dict):
        return value
    fail_runtime(f"{label} must be a JSON object.")
    return value
33
+
34
+
35
def require_list(value, label):
    """Return *value* unchanged after checking that it is a JSON array (``list``).

    A non-list is reported through ``fail_runtime`` using *label*.
    """
    if isinstance(value, list):
        return value
    fail_runtime(f"{label} must be an array.")
    return value
39
+
40
+
41
def require_number(value, label):
    """Validate that *value* is a finite JSON number and return it as a float.

    Booleans are rejected even though ``bool`` subclasses ``int``. Integers
    too large to convert to a float are reported through ``fail_runtime``
    rather than escaping as an ``OverflowError``.
    """
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        fail_runtime(f"{label} must be a finite number.")
    try:
        number = float(value)
    except OverflowError:
        # float(int) raises for integers beyond the double range (e.g. 10**400),
        # which json.loads will happily produce from a long integer literal.
        fail_runtime(f"{label} must be a finite number.")
    if not math.isfinite(number):
        fail_runtime(f"{label} must be a finite number.")
    return number
48
+
49
+
50
def require_positive_number(value, label):
    """Return *value* as a float after checking it is finite and strictly above zero.

    Finiteness and type are checked by ``require_number``; a non-positive
    value is reported through ``fail_runtime``.
    """
    number = require_number(value, label)
    if number > 0:
        return number
    fail_runtime(f"{label} must be positive.")
    return number
55
+
56
+
57
def require_nonnegative_number(value, label):
    """Return *value* as a float after checking it is finite and at least zero.

    Finiteness and type are checked by ``require_number``; a negative value
    is reported through ``fail_runtime``.
    """
    number = require_number(value, label)
    if number >= 0:
        return number
    fail_runtime(f"{label} must be nonnegative.")
    return number
62
+
63
+
64
def require_positive_int(value, label):
    """Return *value* if it is an ``int`` (not a ``bool``) greater than zero.

    Anything else is reported through ``fail_runtime`` using *label*.
    """
    is_plain_int = isinstance(value, int) and not isinstance(value, bool)
    if is_plain_int and value > 0:
        return value
    fail_runtime(f"{label} must be a positive integer.")
    return value
68
+
69
+
70
def require_json_scalar(value, label):
    """Return *value* if it is a JSON scalar: null, string, boolean or finite number.

    Integers are accepted regardless of magnitude. Only floats need the
    finiteness check, so an oversized integer no longer raises
    ``OverflowError`` from ``float()``. Anything else (objects, arrays, NaN,
    infinities) is reported through ``fail_runtime``.
    """
    if value is None or isinstance(value, (str, bool, int)):
        return value
    if isinstance(value, float) and math.isfinite(value):
        return value
    fail_runtime(f"{label} must be a JSON scalar.")
83
+
84
+
85
def normalize_member_path(value, label):
    """Validate an archive member path and return it (whitespace-stripped).

    The path must be relative and may not contain empty, ``.`` or ``..``
    segments when split on ``/``. Violations are reported through
    ``fail_runtime``.
    """
    candidate = require_string(value, label)
    unsafe_segments = {"", ".", ".."}
    has_unsafe_segment = not unsafe_segments.isdisjoint(candidate.split("/"))
    if Path(candidate).is_absolute() or has_unsafe_segment:
        fail_runtime(f"{label} must be a safe relative archive path.")
    return candidate
94
+
95
+
96
def resolve_member(root, path_value, label):
    """Resolve an archive member path under *root* and return the absolute path.

    The path is validated by ``normalize_member_path``. After resolving
    (which follows symlinks), the target must be *root* itself or lie
    beneath it; otherwise ``fail_runtime`` is called.
    """
    relative = normalize_member_path(path_value, label)
    base = root.resolve()
    resolved = (base / relative).resolve()
    if resolved != base and base not in resolved.parents:
        fail_runtime(f"{label} must stay inside the extracted archive root.")
    return resolved
105
+
106
+
107
def load_compiled_config(path):
    """Load the compiled scoring config JSON from *path*.

    A ``RuntimeError`` raised by ``load_json_file`` is converted into a
    ``fail_runtime`` call carrying the same message.
    """
    try:
        config = load_json_file(path, label="compiled_config")
    except RuntimeError as error:
        fail_runtime(str(error))
    else:
        return config
112
+
113
+
114
def read_text(path, label):
    """Read *path* as UTF-8 text and return ``(ok, text, reason)``.

    On success ``reason`` is None. A missing file, undecodable bytes or any
    other OS error yields ``ok=False`` with a human-readable reason that
    starts with *label*.
    """
    try:
        content = path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return False, None, f"{label} is missing."
    except UnicodeDecodeError:
        return False, None, f"{label} is not valid UTF-8 text."
    except OSError as error:
        return False, None, f"{label} could not be read: {error}."
    return True, content, None
123
+
124
+
125
def load_json_member(path, label):
    """Read *path* as UTF-8 and parse it as JSON, returning ``(ok, data, reason)``.

    On success ``reason`` is None. Read failures (missing file, bad UTF-8,
    OS errors) and parse failures all come back as ``ok=False`` with a reason
    starting with *label*. These failures are never raised, because the file
    may come from an untrusted submission archive. Parse failures include
    syntax errors, pathologically deep nesting (``RecursionError``) and other
    ``ValueError``s from the decoder.
    """
    try:
        text = path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return False, None, f"{label} is missing."
    except UnicodeDecodeError:
        return False, None, f"{label} is not valid UTF-8 text."
    except OSError as error:
        return False, None, f"{label} could not be read: {error}."
    try:
        return True, json.loads(text), None
    except json.JSONDecodeError as error:
        return False, None, f"{label} is not valid JSON: {error.msg}."
    except RecursionError:
        # e.g. "[[[[...]]]]" nested deeper than the decoder's recursion guard.
        return False, None, f"{label} is not valid JSON: nesting is too deep."
    except ValueError as error:
        # e.g. integer literals longer than sys.get_int_max_str_digits() (3.11+).
        return False, None, f"{label} is not valid JSON: {error}."
136
+
137
+
138
def parse_json_path(path_value, label):
    """Parse a restricted JSONPath expression into a list of tokens.

    Accepted syntax is ``$`` followed by any sequence of:

    * ``.name`` segments, where *name* matches ``FIELD_RE``, producing
      ``("field", name)``;
    * ``[N]`` segments, where *N* is a plain ASCII decimal with no leading
      zeros, producing ``("index", N)``.

    ``"$"`` alone yields ``[]``. Malformed paths are reported through
    ``fail_runtime``.
    """
    path = require_string(path_value, label)
    if path == "$":
        return []
    if not path.startswith("$"):
        fail_runtime(f"{label} must start with $.")
    tokens = []
    index = 1
    while index < len(path):
        if path[index] == ".":
            index += 1
            match = FIELD_RE.match(path, index)
            if not match:
                fail_runtime(f"{label} contains an invalid field segment.")
            tokens.append(("field", match.group(0)))
            index = match.end()
            continue
        if path[index] == "[":
            end = path.find("]", index)
            if end == -1:
                fail_runtime(f"{label} contains an unterminated array index.")
            raw_index = path[index + 1 : end]
            # str.isdigit() also accepts characters such as "\u00b2" that int()
            # rejects with ValueError, so require plain ASCII digits instead.
            if not re.fullmatch(r"0|[1-9][0-9]*", raw_index):
                fail_runtime(f"{label} contains an invalid array index.")
            tokens.append(("index", int(raw_index)))
            index = end + 1
            continue
        fail_runtime(f"{label} contains an unsupported path segment.")
    return tokens
169
+
170
+
171
def extract_json_path(data, tokens):
    """Walk *data* along parsed path *tokens* and return ``(found, value, reason)``.

    ``("field", name)`` tokens index into dicts and any other token kind is
    treated as a list index. The first step that cannot be taken yields
    ``found=False`` with a reason naming the missing field or index.
    """
    node = data
    for token_kind, key in tokens:
        if token_kind == "field":
            present = isinstance(node, dict) and key in node
            missing_reason = f"JSON field {key} is missing."
        else:
            present = isinstance(node, list) and key < len(node)
            missing_reason = f"JSON array index {key} is missing."
        if not present:
            return False, None, missing_reason
        node = node[key]
    return True, node, None
183
+
184
+
185
def json_scalar_equal(candidate_value, reference_value):
    """Compare two JSON scalars with JSON-type-aware equality.

    * ``null`` and booleans match only by identity, so ``1`` never equals ``true``.
    * Strings match only other strings.
    * Numbers (never booleans) are compared as floats, so ``1 == 1.0``.
      Integers too large for a float fall back to Python's exact int/float
      comparison instead of raising ``OverflowError``.

    Any other reference type never matches.
    """
    if reference_value is None or isinstance(reference_value, bool):
        return candidate_value is reference_value
    if isinstance(reference_value, str):
        return isinstance(candidate_value, str) and candidate_value == reference_value
    if isinstance(reference_value, (int, float)):
        if isinstance(candidate_value, bool) or not isinstance(
            candidate_value, (int, float)
        ):
            return False
        try:
            return float(candidate_value) == float(reference_value)
        except OverflowError:
            # float(int) fails beyond the double range; Python's mixed
            # int/float equality is exact, so use it directly.
            return candidate_value == reference_value
    return False
200
+
201
+
202
def read_csv_cell(path, row_number, column, label):
    """Return ``(ok, value, reason)`` for one cell of a UTF-8 CSV file with a header.

    *row_number* is 1-based and counts data rows after the header, as yielded
    by ``csv.DictReader``. The cell value is returned as the raw string. A
    missing file, header, column or row, a short row that lacks the cell,
    undecodable bytes or malformed CSV all yield ``ok=False`` with a reason
    that starts with *label*.
    """
    try:
        with path.open("r", encoding="utf-8", newline="") as handle:
            reader = csv.DictReader(handle)
            if reader.fieldnames is None:
                return False, None, f"{label} is missing a header row."
            if column not in reader.fieldnames:
                return False, None, f"{label} is missing CSV column {column}."
            for current_row, row in enumerate(reader, start=1):
                if current_row != row_number:
                    continue
                value = row.get(column)
                if value is None:
                    # DictReader fills columns absent from a short row with
                    # restval (None). Treat that as a missing cell, not as a
                    # value that could spuriously equal a JSON null reference.
                    return (
                        False,
                        None,
                        f"{label} is missing CSV column {column} in row {row_number}.",
                    )
                return True, value, None
    except FileNotFoundError:
        return False, None, f"{label} is missing."
    except UnicodeDecodeError:
        return False, None, f"{label} is not valid UTF-8 text."
    except csv.Error as error:
        return False, None, f"{label} is not valid CSV: {error}."
    except OSError as error:
        return False, None, f"{label} could not be read: {error}."
    return False, None, f"{label} is missing CSV row {row_number}."
222
+
223
+
224
def to_finite_number(value, label):
    """Convert an extracted value to a finite float, returning ``(ok, number, reason)``.

    Accepts ints, floats and numeric strings (anything ``float()`` parses).
    Booleans and other types are "not numeric". NaN, infinities and integers
    too large to represent as a float are "not finite". The value may come
    from a candidate submission, so failures are returned, never raised.
    """
    if isinstance(value, bool) or not isinstance(value, (int, float, str)):
        return False, None, f"{label} is not numeric."
    try:
        number = float(value)
    except ValueError:
        return False, None, f"{label} is not numeric."
    except OverflowError:
        # float(int) raises for integers beyond the double range (e.g. a long
        # integer literal in candidate JSON).
        return False, None, f"{label} is not finite."
    if not math.isfinite(number):
        return False, None, f"{label} is not finite."
    return True, number, None
234
+
235
+
236
def extract_regex_capture(root, extractor, label):
    """Search a text member of the archive and return one capture group.

    *extractor* supplies ``path`` (archive member), ``pattern`` (regex) and
    ``group`` (1-based capture index). Returns ``(ok, text, reason)``.

    The pattern and group are validated before the file is read, so
    configuration errors are reported through ``fail_runtime`` consistently.
    Otherwise a missing or empty file could mask them.

    Data problems return ``ok=False`` with a reason:
    * the file is unreadable;
    * the pattern does not match;
    * the requested group did not take part in the match (e.g. an optional group).
    """
    path = resolve_member(root, extractor.get("path"), f"{label}.path")
    pattern = require_string(extractor.get("pattern"), f"{label}.pattern")
    group = require_positive_int(extractor.get("group"), f"{label}.group")
    try:
        compiled = re.compile(pattern)
    except re.error as error:
        fail_runtime(f"{label}.pattern is not a valid regex: {error}.")
    if group > compiled.groups:
        fail_runtime(f"{label}.group references a missing capture group.")
    ok, text, reason = read_text(path, f"{label}.path")
    if not ok:
        return False, None, reason
    match = compiled.search(text)
    if not match:
        return False, None, f"{label}.pattern did not match."
    captured = match.group(group)
    if captured is None:
        # Match.group() returns None for a group that did not participate;
        # that is not a real value and must not compare equal to JSON null.
        return False, None, f"{label}.group did not participate in the match."
    return True, captured, None
253
+
254
+
255
def _is_finite_json_scalar(value):
    """Return True for null, strings, booleans, integers and finite floats."""
    if value is None or isinstance(value, (str, bool, int)):
        return True
    return isinstance(value, float) and math.isfinite(value)


def extract_scalar(root, extractor, label):
    """Extract a JSON-comparable scalar from an archive rooted at *root*.

    Supported ``kind`` values:
    * ``json_path``: value at a JSONPath inside a JSON member;
    * ``csv_cell``: raw string cell of a CSV member;
    * ``text_regex_capture``: regex capture from a text member.

    Returns ``(ok, value, reason)``. Problems with the archive contents
    (missing file, missing path, non-scalar or non-finite JSON value) come
    back as ``ok=False`` rather than aborting. This function is also used on
    the candidate archive, whose contents are submission-controlled.
    Malformed extractor specs are reported through ``fail_runtime``.
    """
    source = require_object(extractor, label)
    extractor_kind = require_string(source.get("kind"), f"{label}.kind")
    if extractor_kind == "json_path":
        path = resolve_member(root, source.get("path"), f"{label}.path")
        ok, data, reason = load_json_member(path, f"{label}.path")
        if not ok:
            return False, None, reason
        tokens = parse_json_path(source.get("json_path"), f"{label}.json_path")
        ok, value, reason = extract_json_path(data, tokens)
        if not ok:
            return False, None, reason
        if not _is_finite_json_scalar(value):
            # An object, array or NaN/Infinity (json.loads accepts those
            # literals) fails the assertion instead of aborting the whole run.
            return False, None, f"{label} is not a finite JSON scalar."
        return True, value, None
    if extractor_kind == "csv_cell":
        path = resolve_member(root, source.get("path"), f"{label}.path")
        row = require_positive_int(source.get("row"), f"{label}.row")
        column = require_string(source.get("column"), f"{label}.column")
        return read_csv_cell(path, row, column, f"{label}.path")
    if extractor_kind == "text_regex_capture":
        return extract_regex_capture(root, source, label)
    fail_runtime(f"{label}.kind={extractor_kind} is unsupported.")
276
+
277
+
278
def extract_numeric(root, extractor, label):
    """Extract a finite number from an archive rooted at *root*.

    Supported ``kind`` values are ``json_path_number``, ``csv_cell_number``
    and ``text_regex_number``. The raw value found is converted with
    ``to_finite_number``. Returns ``(ok, number, reason)``. Failures from the
    underlying extractors come back as ``ok=False``; an unsupported kind is
    reported through ``fail_runtime``.
    """
    spec = require_object(extractor, label)
    kind = require_string(spec.get("kind"), f"{label}.kind")

    if kind == "json_path_number":
        member = resolve_member(root, spec.get("path"), f"{label}.path")
        found, document, problem = load_json_member(member, f"{label}.path")
        if not found:
            return False, None, problem
        path_tokens = parse_json_path(spec.get("json_path"), f"{label}.json_path")
        found, raw, problem = extract_json_path(document, path_tokens)
    elif kind == "csv_cell_number":
        member = resolve_member(root, spec.get("path"), f"{label}.path")
        row_number = require_positive_int(spec.get("row"), f"{label}.row")
        column_name = require_string(spec.get("column"), f"{label}.column")
        found, raw, problem = read_csv_cell(
            member, row_number, column_name, f"{label}.path"
        )
    elif kind == "text_regex_number":
        found, raw, problem = extract_regex_capture(root, spec, label)
    else:
        fail_runtime(f"{label}.kind={kind} is unsupported.")
        return None

    if not found:
        return False, None, problem
    return to_finite_number(raw, label)
305
+
306
+
307
def require_reference_scalar(root, extractor, label):
    """Extract a scalar from the reference archive or stop with ``fail_runtime``.

    Unlike candidate extraction, a reference value that cannot be extracted
    is treated as a runtime failure rather than a failed assertion.
    """
    extracted, value, reason = extract_scalar(root, extractor, label)
    if extracted:
        return value
    fail_runtime(f"{label} could not be extracted from reference archive: {reason}")
    return value
312
+
313
+
314
def require_reference_number(root, extractor, label):
    """Extract a finite number from the reference archive or stop with ``fail_runtime``.

    Unlike candidate extraction, a reference value that cannot be extracted
    is treated as a runtime failure rather than a failed assertion.
    """
    extracted, value, reason = extract_numeric(root, extractor, label)
    if extracted:
        return value
    fail_runtime(f"{label} could not be extracted from reference archive: {reason}")
    return value
319
+
320
+
321
def resolve_tolerance(reference_root, tolerance, label):
    """Resolve an absolute tolerance policy to a nonnegative float.

    Supported ``kind`` values:
    * ``absolute``: the literal ``value`` from the policy;
    * ``absolute_from_reference``: a number extracted from the reference
      archive via ``source``.

    Invalid or negative tolerances are reported through ``fail_runtime``.
    """
    spec = require_object(tolerance, label)
    kind = require_string(spec.get("kind"), f"{label}.kind")
    if kind == "absolute":
        return require_nonnegative_number(spec.get("value"), f"{label}.value")
    if kind != "absolute_from_reference":
        fail_runtime(f"{label}.kind={kind} is unsupported.")
        return None
    extracted = require_reference_number(
        reference_root,
        spec.get("source"),
        f"{label}.source",
    )
    if extracted < 0:
        fail_runtime(f"{label}.source must extract a nonnegative tolerance.")
    return extracted
336
+
337
+
338
def normalize_assertion(value, index):
    """Validate the common fields of assertion number *index*.

    Returns ``(id, kind, weight, raw_assertion)``, where ``id`` and ``kind``
    are stripped strings and ``weight`` is a positive float. Kind-specific
    fields are left in ``raw_assertion`` for the caller. Validation failures
    are reported through ``fail_runtime``.
    """
    prefix = f"reference_artifact_assertion.assertions[{index}]"
    assertion = require_object(value, prefix)
    assertion_id = require_string(assertion.get("id"), f"{prefix}.id")
    assertion_kind = require_string(assertion.get("kind"), f"{prefix}.kind")
    weight = require_positive_number(assertion.get("weight"), f"{prefix}.weight")
    return assertion_id, assertion_kind, weight, assertion
356
+
357
+
358
def evaluate_assertion(candidate_root, reference_root, assertion, index):
    """Evaluate one assertion of the candidate archive against the reference archive.

    Supported kinds:
    * ``scalar_equals_reference``: JSON-aware scalar equality;
    * ``number_within_reference_tolerance``: ``|candidate - reference| <= tolerance``;
    * ``number_in_reference_interval``: ``min <= candidate <= max``.

    Returns a dict with ``id``, ``kind``, ``weight``, ``passed``, a
    ``score`` of 1.0 or 0.0, and ``reason`` when one is available. Candidate
    extraction failures fail the assertion; reference or configuration
    problems are reported through ``fail_runtime``.
    """
    assertion_id, assertion_kind, weight, spec = normalize_assertion(assertion, index)
    prefix = f"reference_artifact_assertion.assertions[{index}]"

    if assertion_kind == "scalar_equals_reference":
        extracted, observed, reason = extract_scalar(
            candidate_root, spec.get("candidate"), f"{prefix}.candidate"
        )
        expected = require_reference_scalar(
            reference_root, spec.get("reference"), f"{prefix}.reference"
        )
        passed = extracted and json_scalar_equal(observed, expected)
        if extracted and not passed:
            reason = "candidate scalar did not equal reference scalar."
    elif assertion_kind == "number_within_reference_tolerance":
        extracted, observed, reason = extract_numeric(
            candidate_root, spec.get("candidate"), f"{prefix}.candidate"
        )
        expected = require_reference_number(
            reference_root, spec.get("reference"), f"{prefix}.reference"
        )
        allowed = resolve_tolerance(
            reference_root, spec.get("tolerance"), f"{prefix}.tolerance"
        )
        passed = False
        if extracted:
            gap = abs(observed - expected)
            passed = gap <= allowed
            if not passed:
                reason = (
                    f"numeric difference {gap} exceeded absolute "
                    f"tolerance {allowed}."
                )
    elif assertion_kind == "number_in_reference_interval":
        extracted, observed, reason = extract_numeric(
            candidate_root, spec.get("candidate"), f"{prefix}.candidate"
        )
        bounds = require_object(spec.get("reference"), f"{prefix}.reference")
        low = require_reference_number(
            reference_root, bounds.get("min"), f"{prefix}.reference.min"
        )
        high = require_reference_number(
            reference_root, bounds.get("max"), f"{prefix}.reference.max"
        )
        if low > high:
            fail_runtime(f"{prefix}.reference.min must be less than or equal to max.")
        passed = extracted and low <= observed <= high
        if extracted and not passed:
            reason = "candidate number was outside the reference interval."
    else:
        fail_runtime(f"{prefix}.kind={assertion_kind} is unsupported.")

    result = {
        "id": assertion_id,
        "kind": assertion_kind,
        "weight": weight,
        "passed": bool(passed),
        "score": 1.0 if passed else 0.0,
    }
    if reason is not None:
        result["reason"] = reason
    return result
435
+
436
+
437
def normalize_assertions(params):
    """Validate the ``assertions`` list in *params* and return it unchanged.

    The list must be non-empty, every entry must pass
    ``normalize_assertion``, and assertion ids must be unique. Violations are
    reported through ``fail_runtime``.
    """
    assertions = require_list(
        params.get("assertions"),
        "reference_artifact_assertion.assertions",
    )
    if not assertions:
        fail_runtime("reference_artifact_assertion.assertions must be non-empty.")
    seen_ids = set()
    for position, entry in enumerate(assertions):
        entry_id = normalize_assertion(entry, position)[0]
        if entry_id in seen_ids:
            fail_runtime(
                f"reference_artifact_assertion.assertions id {entry_id} is duplicated."
            )
        seen_ids.add(entry_id)
    return assertions
453
+
454
+
455
def main():
    """Score a candidate archive against assertions over a reference archive.

    Steps:
    1. Load the runtime context and the compiled config, checking that both
       agree on ``final_score_key``.
    2. Validate the assertion list.
    3. Extract the submission (candidate) and evaluation (reference) zip
       archives into a temporary directory and evaluate every assertion.
    4. Report the weighted fraction of passed assertions via ``write_score``.
    """
    context = load_runtime_context()
    config = require_object(
        load_compiled_config(
            resolve_scoring_asset(context, "compiled_config", kind="config")
        ),
        "compiled_config",
    )
    candidate_role = require_string(
        config.get("candidate_role"), "compiled_config.candidate_role"
    )
    reference_role = require_string(
        config.get("reference_role"), "compiled_config.reference_role"
    )
    declared_key = require_string(
        config.get("final_score_key"), "compiled_config.final_score_key"
    )
    final_score_key = require_string(
        context.get("final_score_key"), "runtime_context.final_score_key"
    )
    if declared_key != final_score_key:
        fail_runtime("compiled_config.final_score_key must match runtime context.")
    settings = require_object(
        config.get("reference_artifact_assertion"),
        "compiled_config.reference_artifact_assertion",
    )
    assertions = normalize_assertions(settings)
    candidate_archive = resolve_submission_artifact(context, candidate_role)
    reference_archive = resolve_evaluation_artifact(context, reference_role)

    with tempfile.TemporaryDirectory(prefix="agora-reference-artifact-") as scratch:
        workspace = Path(scratch)
        candidate_root = workspace / "candidate"
        reference_root = workspace / "reference"
        # Invalid submission archives are routed to reject_submission; the
        # reference archive uses safe_extract_zip's default handling.
        safe_extract_zip(
            candidate_archive,
            candidate_root,
            label=f"submission artifact {candidate_role}",
            invalid_handler=reject_submission,
        )
        safe_extract_zip(
            reference_archive,
            reference_root,
            label=f"evaluation artifact {reference_role}",
        )
        results = []
        for position, assertion in enumerate(assertions):
            results.append(
                evaluate_assertion(candidate_root, reference_root, assertion, position)
            )

    total_weight = sum(result["weight"] for result in results)
    if not math.isfinite(total_weight) or total_weight <= 0:
        fail_runtime(
            "reference_artifact_assertion.assertions must declare finite positive total weight."
        )
    passing = [result for result in results if result["passed"]]
    earned_weight = sum(result["weight"] for result in passing)
    score = earned_weight / total_weight
    if not math.isfinite(score):
        fail_runtime("reference_artifact_assertion score must be finite.")
    write_score(
        score=score,
        details={
            final_score_key: score,
            "passed_assertions": len(passing),
            "total_assertions": len(results),
            "earned_weight": earned_weight,
            "total_weight": total_weight,
            "assertion_results": results,
        },
    )
529
+
530
+
531
# Run the scorer only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()