@moleculeagora/cli 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1291 @@
1
+ import csv
2
+ import math
3
+
4
+ from agora_runtime import (
5
+ fail_runtime,
6
+ load_json_file,
7
+ load_runtime_context,
8
+ reject_submission,
9
+ resolve_evaluation_artifact,
10
+ resolve_scoring_asset,
11
+ resolve_submission_artifact,
12
+ write_score,
13
+ )
14
+
15
# Metric names accepted for compiled_config.metric, in the order they are
# listed in the "must be one of ..." error message.
SUPPORTED_METRICS = (
    "rmse", "mae", "r2", "pearson", "spearman",
    "accuracy", "f1", "auroc",
    "brier", "log_loss", "auprc", "ece",
    "z_rmse", "within_noise_fraction",
)
# Set form of the tuple above, used for membership checks.
SUPPORTED_METRIC_SET = {*SUPPORTED_METRICS}

# Probabilities are clipped to [epsilon, 1 - epsilon] before taking logs.
LOG_LOSS_EPSILON = 1e-15

# Metrics whose submission values must lie in [0, 1].
PROBABILITY_METRICS = {"auprc", "brier", "ece", "log_loss"}
# Metrics that read extra calibration columns from the evaluation artifact.
CALIBRATED_METRICS = {"within_noise_fraction", "z_rmse"}
# Lower-is-better metrics.
MINIMIZE_METRICS = {"brier", "ece", "log_loss", "mae", "rmse", "z_rmse"}
CORRELATION_METRICS = {"spearman", "pearson"}
# Metrics whose reference values must be 0/1 labels.
BINARY_LABEL_METRICS = PROBABILITY_METRICS.union(("accuracy", "f1", "auroc"))
# Higher-is-better metrics whose raw value is clamped to [0, 1] when normalized.
UNIT_INTERVAL_METRICS = set(
    ("accuracy", "f1", "auroc", "auprc", "within_noise_fraction")
)
# Higher-is-better metrics.
MAXIMIZE_METRICS = UNIT_INTERVAL_METRICS | CORRELATION_METRICS | {"r2"}
46
+
47
+
48
def format_metric_list(metrics):
    """Join metric names into an English "a, b, or c" style list.

    Args:
        metrics: Iterable of metric-name strings.

    Returns:
        An empty string for no names, the single name itself for one name,
        "a or b" for two names, and "a, b, or c" for three or more.
    """
    names = list(metrics)
    if not names:
        # An empty iterable used to raise IndexError on names[-1].
        return ""
    if len(names) == 1:
        return names[0]
    if len(names) == 2:
        return f"{names[0]} or {names[1]}"
    *leading, final = names
    return f"{', '.join(leading)}, or {final}"
55
+
56
+
57
def require_string(value, label):
    """Return ``value`` with surrounding whitespace removed.

    Anything that is not a str, or that is blank after stripping, is reported
    through fail_runtime using ``label`` to name the offending field.
    """
    is_usable = isinstance(value, str) and bool(value.strip())
    if not is_usable:
        fail_runtime(f"{label} must be a non-empty string.")
    return value.strip()
61
+
62
+
63
def require_metric(config):
    """Read compiled_config.metric, lower-case it and check it is supported.

    Unsupported names are reported through fail_runtime with the full list of
    supported metrics.
    """
    raw_metric = require_string(config.get("metric"), "compiled_config.metric")
    metric = raw_metric.lower()
    if metric in SUPPORTED_METRIC_SET:
        return metric
    fail_runtime(
        f"compiled_config.metric must be one of {format_metric_list(SUPPORTED_METRICS)}."
    )
    return metric
70
+
71
+
72
def require_metric_params(config, metric):
    """Validate compiled_config.metric_params for ``metric`` and return it.

    metric_params must be a dict whose lower-cased "metric" entry equals
    ``metric``. Metric-specific requirements:

    * ece: "bin_count" is a non-bool int from 2 to 100, "binning" is "uniform".
    * z_rmse: "baseline_mean_field" and "baseline_sd_field" are non-empty strings.
    * within_noise_fraction: "noise_floor_field" is a non-empty string.

    Violations are reported through fail_runtime.
    """
    params = config.get("metric_params")
    if not isinstance(params, dict):
        fail_runtime("compiled_config.metric_params must be an object.")

    declared_metric = require_string(
        params.get("metric"),
        "compiled_config.metric_params.metric",
    )
    if declared_metric.lower() != metric:
        fail_runtime(
            f"compiled_config.metric_params.metric must equal compiled_config.metric {metric!r}."
        )

    if metric == "ece":
        bin_count = params.get("bin_count")
        # bool is a subclass of int, so it has to be excluded explicitly.
        if isinstance(bin_count, bool) or not isinstance(bin_count, int):
            fail_runtime(
                "compiled_config.metric_params.bin_count must be an integer from 2 to 100 for ece."
            )
        if not 2 <= bin_count <= 100:
            fail_runtime(
                "compiled_config.metric_params.bin_count must be an integer from 2 to 100 for ece."
            )
        binning_mode = require_string(
            params.get("binning"),
            "compiled_config.metric_params.binning",
        )
        if binning_mode != "uniform":
            fail_runtime(
                "compiled_config.metric_params.binning must be uniform for ece."
            )
    elif metric == "z_rmse":
        require_string(
            params.get("baseline_mean_field"),
            "compiled_config.metric_params.baseline_mean_field",
        )
        require_string(
            params.get("baseline_sd_field"),
            "compiled_config.metric_params.baseline_sd_field",
        )
    elif metric == "within_noise_fraction":
        require_string(
            params.get("noise_floor_field"),
            "compiled_config.metric_params.noise_floor_field",
        )
    return params
119
+
120
+
121
def require_config_field_matches_metric_params(config, metric_params, key):
    """Return config[key] after checking metric_params[key] holds the same string.

    Both entries must be non-empty strings (compared after stripping); a
    mismatch is reported through fail_runtime.
    """
    top_level = require_string(config.get(key), f"compiled_config.{key}")
    nested = require_string(
        metric_params.get(key),
        f"compiled_config.metric_params.{key}",
    )
    if nested == top_level:
        return top_level
    fail_runtime(
        f"compiled_config.{key} must match compiled_config.metric_params.{key}."
    )
    return top_level
132
+
133
+
134
def require_grouping_config(config):
    """Read the optional grouping settings from the compiled config.

    Returns:
        ``([], None)`` when "group_by" is absent or None; otherwise
        ``(columns, aggregation)`` where columns is a list of one to three
        unique, stripped column names and aggregation is "macro_mean" or
        "weighted_mean".

    Invalid settings are reported through fail_runtime.
    """
    raw_columns = config.get("group_by")
    if raw_columns is None:
        # An aggregation mode without grouping columns is a config error.
        if "aggregation" in config:
            fail_runtime("compiled_config.aggregation requires compiled_config.group_by.")
        return [], None

    has_valid_shape = isinstance(raw_columns, list) and 1 <= len(raw_columns) <= 3
    if not has_valid_shape:
        fail_runtime("compiled_config.group_by must list one to three CSV columns.")

    columns = []
    for position, raw_name in enumerate(raw_columns):
        name = require_string(
            raw_name,
            f"compiled_config.group_by[{position}]",
        )
        # At most three entries, so a list membership test is sufficient.
        if name in columns:
            fail_runtime(
                f"compiled_config.group_by repeats column {name!r}; grouped metrics require canonical unique columns."
            )
        columns.append(name)

    mode = require_string(
        config.get("aggregation"),
        "compiled_config.aggregation",
    )
    if mode != "macro_mean" and mode != "weighted_mean":
        fail_runtime(
            "compiled_config.aggregation must be macro_mean or weighted_mean for grouped table metrics."
        )
    return columns, mode
166
+
167
+
168
def require_policy(policies, key, allowed):
    """Return policies[key] after checking it is one of ``allowed``.

    The value must be a non-empty string; a value outside ``allowed`` is
    reported through fail_runtime with the sorted list of options.
    """
    choice = require_string(policies.get(key), f"policies.{key}")
    if choice in allowed:
        return choice
    fail_runtime(
        f"policies.{key} must be one of {', '.join(sorted(allowed))}."
    )
    return choice
175
+
176
+
177
def find_slot(runtime_context, lane, role):
    """Return the first dict in runtime_context["<lane>_slots"] with this role.

    A missing/non-list slot collection, or no slot with a matching "role",
    is reported through fail_runtime.
    """
    slot_key = f"{lane}_slots"
    candidates = runtime_context.get(slot_key)
    if not isinstance(candidates, list):
        fail_runtime(f"Runtime context is missing {slot_key}.")
    matching = (
        candidate
        for candidate in candidates
        if isinstance(candidate, dict) and candidate.get("role") == role
    )
    found = next(matching, None)
    if found is not None:
        return found
    fail_runtime(f"Runtime context is missing {lane} slot for role {role}.")
186
+
187
+
188
def require_csv_slot_columns(runtime_context, lane, role):
    """Return ``(record_key, value_field)`` from a slot's csv_columns validator.

    The slot is located with find_slot; its "validator" must be a dict with
    kind "csv_columns" and non-empty string "record_key" / "value_field"
    entries. Problems are reported through fail_runtime.
    """
    validator = find_slot(runtime_context, lane, role).get("validator")
    uses_csv_columns = (
        isinstance(validator, dict) and validator.get("kind") == "csv_columns"
    )
    if not uses_csv_columns:
        fail_runtime(
            f"{lane} role {role} must use validator.kind=csv_columns for table_metric."
        )
    key_column = require_string(
        validator.get("record_key"),
        f"{lane}.{role}.validator.record_key",
    )
    value_column = require_string(
        validator.get("value_field"),
        f"{lane}.{role}.validator.value_field",
    )
    return key_column, value_column
204
+
205
+
206
def read_csv_rows(path, label, *, invalid_handler):
    """Read a UTF-8 CSV file into ``(fieldnames, rows)``.

    Args:
        path: Object with a pathlib-style ``open`` method pointing at the CSV.
        label: Human-readable artifact name used in error messages.
        invalid_handler: Callable taking one message string; invoked for
            every problem found. NOTE(review): it is assumed not to return
            (callers pass fail_runtime / reject_submission) - confirm.

    Returns:
        A tuple of the whitespace-stripped header names and the list of row
        dicts produced by csv.DictReader, keyed by those stripped names.
    """
    try:
        with path.open("r", encoding="utf-8", newline="") as handle:
            reader = csv.DictReader(handle)
            fieldnames = reader.fieldnames
            if not fieldnames:
                invalid_handler(f"{label} must include a CSV header row.")
            normalized_fieldnames = []
            for fieldname in fieldnames:
                if not isinstance(fieldname, str) or not fieldname.strip():
                    invalid_handler(f"{label} contains an empty CSV column name.")
                normalized_fieldnames.append(fieldname.strip())
            # Key the rows by the same stripped names that are returned to the
            # caller; otherwise a padded header such as " id " passes the
            # caller's column check but every row.get("id") comes back None.
            reader.fieldnames = normalized_fieldnames
            rows = list(reader)
    except FileNotFoundError:
        invalid_handler(f"Missing {label} at {path}.")
    except OSError as error:
        invalid_handler(f"Unable to read {label}: {error}.")
    except (UnicodeDecodeError, csv.Error) as error:
        # Undecodable bytes and malformed CSV are input problems, so they go
        # through the handler instead of escaping as an unhandled exception.
        invalid_handler(f"Unable to read {label}: {error}.")
    return normalized_fieldnames, rows
224
+
225
+
226
def parse_reference_value(raw_value, label):
    """Parse one evaluation-artifact cell into a finite float.

    Blank, non-string, non-numeric and non-finite cells are reported through
    fail_runtime, using ``label`` to identify the cell.
    """
    text = ""
    if isinstance(raw_value, str):
        text = raw_value.strip()
    if not text:
        fail_runtime(f"{label} is blank.")
    try:
        number = float(text)
    except ValueError:
        fail_runtime(f"{label} must be numeric, received {text!r}.")
    # float() accepts "nan" and "inf"; neither is usable as a reference value.
    if math.isnan(number) or math.isinf(number):
        fail_runtime(f"{label} must be finite.")
    return number
237
+
238
+
239
def parse_submission_value(raw_value, label, invalid_value_policy, metric):
    """Parse one submission cell into a finite float, or None when unusable.

    A cell is unusable when it is blank / not a string, not numeric, not
    finite, or - for metrics in PROBABILITY_METRICS - outside [0, 1]. With
    ``invalid_value_policy == "reject"`` an unusable cell is reported through
    reject_submission; in every case the function then yields None.
    """

    def discard(reason):
        # Shared tail of every failure branch.
        if invalid_value_policy == "reject":
            reject_submission(reason)
        return None

    text = raw_value.strip() if isinstance(raw_value, str) else ""
    if not text:
        return discard(f"{label} is blank.")
    try:
        value = float(text)
    except ValueError:
        return discard(f"{label} must be numeric, received {text!r}.")
    if not math.isfinite(value):
        return discard(f"{label} must be finite.")
    if metric in PROBABILITY_METRICS and (value < 0.0 or value > 1.0):
        return discard(
            f"{label} must be a probability in [0, 1] for table_metric metric {metric}; received {value}."
        )
    return value
262
+
263
+
264
def load_reference_values(path, role, record_key, value_field):
    """Load the evaluation CSV into a ``{record id: float value}`` dict.

    The file must contain the ``record_key`` and ``value_field`` columns,
    unique non-blank record ids, finite numeric values and at least one row.
    Every violation is reported through fail_runtime.
    """
    header, data_rows = read_csv_rows(
        path,
        f"evaluation artifact {role}",
        invalid_handler=fail_runtime,
    )
    if record_key not in header:
        fail_runtime(
            f"evaluation artifact {role} is missing record key column {record_key}."
        )
    if value_field not in header:
        fail_runtime(
            f"evaluation artifact {role} is missing value column {value_field}."
        )

    parsed = {}
    # Row numbers in messages start at 2, i.e. after the header row.
    for row_number, row in enumerate(data_rows, start=2):
        id_cell = row.get(record_key)
        record_id = id_cell.strip() if isinstance(id_cell, str) else ""
        if not record_id:
            fail_runtime(
                f"evaluation artifact {role} row {row_number} is missing {record_key}."
            )
        if record_id in parsed:
            fail_runtime(
                f"evaluation artifact {role} contains duplicate record id {record_id!r}."
            )
        parsed[record_id] = parse_reference_value(
            row.get(value_field),
            f"evaluation artifact {role} row {row_number} column {value_field}",
        )

    if len(parsed) == 0:
        fail_runtime(
            f"evaluation artifact {role} must contain at least one scored row."
        )
    return parsed
299
+
300
+
301
def load_reference_group_records(path, role, record_key, value_field, group_by):
    """Load the evaluation CSV with a group key attached to every record.

    Returns:
        ``{record id: {"value": float, "group_key": tuple}}`` where the tuple
        holds the stripped ``group_by`` column values in ``group_by`` order.

    Missing columns, blank or duplicate record ids, blank group cells,
    unparsable values and an empty file are reported through fail_runtime.
    """
    header, data_rows = read_csv_rows(
        path,
        f"evaluation artifact {role}",
        invalid_handler=fail_runtime,
    )
    if record_key not in header:
        fail_runtime(
            f"evaluation artifact {role} is missing record key column {record_key}."
        )
    if value_field not in header:
        fail_runtime(
            f"evaluation artifact {role} is missing value column {value_field}."
        )
    for column in group_by:
        if column not in header:
            fail_runtime(
                f"evaluation artifact {role} is missing group_by column {column}."
            )

    by_id = {}
    # Row numbers in messages start at 2, i.e. after the header row.
    for row_number, row in enumerate(data_rows, start=2):
        id_cell = row.get(record_key)
        record_id = id_cell.strip() if isinstance(id_cell, str) else ""
        if not record_id:
            fail_runtime(
                f"evaluation artifact {role} row {row_number} is missing {record_key}."
            )
        if record_id in by_id:
            fail_runtime(
                f"evaluation artifact {role} contains duplicate record id {record_id!r}."
            )

        key_parts = []
        for column in group_by:
            group_cell = row.get(column)
            part = group_cell.strip() if isinstance(group_cell, str) else ""
            if not part:
                fail_runtime(
                    f"evaluation artifact {role} row {row_number} is missing group_by column {column}."
                )
            key_parts.append(part)

        reference_value = parse_reference_value(
            row.get(value_field),
            f"evaluation artifact {role} row {row_number} column {value_field}",
        )
        by_id[record_id] = {
            "value": reference_value,
            "group_key": tuple(key_parts),
        }

    if len(by_id) == 0:
        fail_runtime(
            f"evaluation artifact {role} must contain at least one scored row."
        )
    return by_id
357
+
358
+
359
def parse_positive_reference_value(raw_value, label):
    """Parse a reference cell that must be strictly greater than zero.

    Parsing is delegated to parse_reference_value; a value of zero or below
    is reported through fail_runtime.
    """
    number = parse_reference_value(raw_value, label)
    if number <= 0.0:
        fail_runtime(f"{label} must be greater than 0.")
    return number
364
+
365
+
366
def parse_nonnegative_reference_value(raw_value, label):
    """Parse a reference cell that must be zero or greater.

    Parsing is delegated to parse_reference_value; a negative value is
    reported through fail_runtime.
    """
    number = parse_reference_value(raw_value, label)
    if number < 0.0:
        fail_runtime(f"{label} must be greater than or equal to 0.")
    return number
371
+
372
+
373
def load_calibrated_reference_records(
    path,
    role,
    record_key,
    value_field,
    metric,
    metric_params,
):
    """Load the evaluation CSV together with per-row calibration columns.

    Returns:
        ``{record id: record}`` where every record has a "value" float and,
        depending on ``metric``:

        * z_rmse: "baseline_mean" (finite) and "baseline_sd" (> 0), read from
          the columns named by metric_params["baseline_mean_field"] and
          metric_params["baseline_sd_field"].
        * within_noise_fraction: "noise_floor" (>= 0), read from the column
          named by metric_params["noise_floor_field"].

    Missing columns, blank or duplicate ids, unparsable cells and an empty
    file are reported through fail_runtime.
    """
    header, data_rows = read_csv_rows(
        path,
        f"evaluation artifact {role}",
        invalid_handler=fail_runtime,
    )
    if record_key not in header:
        fail_runtime(
            f"evaluation artifact {role} is missing record key column {record_key}."
        )
    if value_field not in header:
        fail_runtime(
            f"evaluation artifact {role} is missing value column {value_field}."
        )

    is_z_rmse = metric == "z_rmse"
    is_within_noise = metric == "within_noise_fraction"

    calibration_columns = []
    if is_z_rmse:
        calibration_columns.append(metric_params["baseline_mean_field"])
        calibration_columns.append(metric_params["baseline_sd_field"])
    elif is_within_noise:
        calibration_columns.append(metric_params["noise_floor_field"])

    for column in calibration_columns:
        if column not in header:
            fail_runtime(
                f"evaluation artifact {role} is missing calibration column {column}."
            )

    by_id = {}
    # Row numbers in messages start at 2, i.e. after the header row.
    for row_number, row in enumerate(data_rows, start=2):
        id_cell = row.get(record_key)
        record_id = id_cell.strip() if isinstance(id_cell, str) else ""
        if not record_id:
            fail_runtime(
                f"evaluation artifact {role} row {row_number} is missing {record_key}."
            )
        if record_id in by_id:
            fail_runtime(
                f"evaluation artifact {role} contains duplicate record id {record_id!r}."
            )

        entry = {}
        entry["value"] = parse_reference_value(
            row.get(value_field),
            f"evaluation artifact {role} row {row_number} column {value_field}",
        )
        if is_z_rmse:
            mean_column = metric_params["baseline_mean_field"]
            sd_column = metric_params["baseline_sd_field"]
            entry["baseline_mean"] = parse_reference_value(
                row.get(mean_column),
                f"evaluation artifact {role} row {row_number} column {mean_column}",
            )
            # The SD is later used as a divisor, so it must be strictly positive.
            entry["baseline_sd"] = parse_positive_reference_value(
                row.get(sd_column),
                f"evaluation artifact {role} row {row_number} column {sd_column}",
            )
        elif is_within_noise:
            floor_column = metric_params["noise_floor_field"]
            entry["noise_floor"] = parse_nonnegative_reference_value(
                row.get(floor_column),
                f"evaluation artifact {role} row {row_number} column {floor_column}",
            )
        by_id[record_id] = entry

    if len(by_id) == 0:
        fail_runtime(
            f"evaluation artifact {role} must contain at least one scored row."
        )
    return by_id
453
+
454
+
455
def load_submission_values(
    path,
    role,
    record_key,
    value_field,
    duplicate_id_policy,
    invalid_value_policy,
    metric,
):
    """Load the submission CSV into a ``{record id: float value}`` dict.

    File-level problems (unreadable file, missing columns) are always
    reported through reject_submission. Row-level problems follow the
    policies: with "reject" the row is reported through reject_submission;
    otherwise a row with a blank id or an unusable value is skipped, and for
    a repeated id the first usable occurrence is kept.
    """
    header, data_rows = read_csv_rows(
        path,
        f"submission artifact {role}",
        invalid_handler=reject_submission,
    )
    if record_key not in header:
        reject_submission(
            f"submission artifact {role} is missing record key column {record_key}."
        )
    if value_field not in header:
        reject_submission(
            f"submission artifact {role} is missing value column {value_field}."
        )

    reject_invalid = invalid_value_policy == "reject"
    reject_duplicates = duplicate_id_policy == "reject"

    accepted = {}
    # Row numbers in messages start at 2, i.e. after the header row.
    for row_number, row in enumerate(data_rows, start=2):
        id_cell = row.get(record_key)
        record_id = id_cell.strip() if isinstance(id_cell, str) else ""
        if not record_id:
            if reject_invalid:
                reject_submission(
                    f"submission artifact {role} row {row_number} is missing {record_key}."
                )
            continue
        if record_id in accepted:
            if reject_duplicates:
                reject_submission(
                    f"submission artifact {role} contains duplicate record id {record_id!r}."
                )
            continue
        candidate = parse_submission_value(
            row.get(value_field),
            f"submission artifact {role} row {row_number} column {value_field}",
            invalid_value_policy,
            metric,
        )
        if candidate is not None:
            accepted[record_id] = candidate
    return accepted
503
+
504
+
505
def compute_rmse(reference_values, submission_values):
    """Return the root-mean-square error between paired values.

    Pairs are formed positionally with zip; the mean divides by
    ``len(reference_values)``.
    """
    pairs = zip(reference_values, submission_values)
    squared_residuals = [(predicted - actual) ** 2 for actual, predicted in pairs]
    mean_squared = sum(squared_residuals) / len(reference_values)
    return math.sqrt(mean_squared)
511
+
512
+
513
def compute_mae(reference_values, submission_values):
    """Return the mean absolute error between paired values.

    Pairs are formed positionally with zip; the mean divides by
    ``len(reference_values)``.
    """
    pairs = zip(reference_values, submission_values)
    absolute_residuals = [abs(predicted - actual) for actual, predicted in pairs]
    return sum(absolute_residuals) / len(reference_values)
519
+
520
+
521
def compute_r2(reference_values, submission_values):
    """Return the coefficient of determination of the submission.

    When the reference values have zero variance the ratio is undefined;
    the result is then 1.0 for an exact match and 0.0 otherwise.
    """
    reference_mean = sum(reference_values) / len(reference_values)
    residual_squares = [
        (actual - predicted) ** 2
        for actual, predicted in zip(reference_values, submission_values)
    ]
    residual_total = sum(residual_squares)
    variance_total = sum([(actual - reference_mean) ** 2 for actual in reference_values])
    if variance_total != 0:
        return 1.0 - (residual_total / variance_total)
    return 1.0 if residual_total == 0 else 0.0
531
+
532
+
533
def compute_pearson(reference_values, submission_values):
    """Return the Pearson correlation between paired values.

    Fewer than two reference rows is reported through reject_submission.
    When either side has zero variance the correlation is undefined; the
    result is then 1.0 if both lists compare equal and 0.0 otherwise.
    """
    row_count = len(reference_values)
    if row_count < 2:
        reject_submission("Pearson correlation requires at least two scored rows.")
    reference_mean = sum(reference_values) / row_count
    submission_mean = sum(submission_values) / len(submission_values)

    covariance_terms = [
        (actual - reference_mean) * (predicted - submission_mean)
        for actual, predicted in zip(reference_values, submission_values)
    ]
    covariance = sum(covariance_terms)
    reference_spread = sum(
        [(actual - reference_mean) ** 2 for actual in reference_values]
    )
    submission_spread = sum(
        [(predicted - submission_mean) ** 2 for predicted in submission_values]
    )
    scale = math.sqrt(reference_spread * submission_spread)
    if scale != 0:
        return covariance / scale
    return 1.0 if reference_values == submission_values else 0.0
552
+
553
+
554
def average_ranks(values):
    """Return 1-based ascending ranks for ``values``, averaging over ties.

    The result is a list of floats aligned with the input positions; equal
    values all receive the mean of the ranks they jointly occupy.
    """
    total = len(values)
    # Positions of the input ordered by ascending value (stable for ties).
    order = sorted(range(total), key=lambda position: values[position])
    ranks = [0.0] * total
    run_start = 0
    for cursor in range(total):
        following = cursor + 1
        if following < total and values[order[following]] == values[order[run_start]]:
            # Still inside a run of equal values.
            continue
        # The run covers sorted slots run_start..cursor, i.e. ranks
        # run_start + 1 .. cursor + 1; their mean is shared by the run.
        shared_rank = (run_start + cursor + 2) / 2.0
        for position in order[run_start:following]:
            ranks[position] = shared_rank
        run_start = following
    return ranks
567
+
568
+
569
def compute_spearman(reference_values, submission_values):
    """Return the Spearman correlation between paired values.

    Computed as the Pearson correlation of the tie-averaged ranks of each
    list.
    """
    reference_ranks = average_ranks(reference_values)
    submission_ranks = average_ranks(submission_values)
    return compute_pearson(reference_ranks, submission_ranks)
574
+
575
+
576
def require_binary_reference_labels(metric, reference_values):
    """Convert reference values to a list of int labels 0/1.

    Any value that does not compare equal to 0.0 or 1.0 is reported through
    fail_runtime, naming ``metric`` and the 1-based position of the row.
    """
    labels = []
    for position, raw_label in enumerate(reference_values, start=1):
        if raw_label == 0.0:
            labels.append(0)
        elif raw_label == 1.0:
            labels.append(1)
        else:
            fail_runtime(
                f"table_metric metric {metric} requires evaluation labels encoded as 0 or 1; scored row {position} had {raw_label}."
            )
    return labels
589
+
590
+
591
def require_positive_reference_label(metric, reference_labels, label):
    """Report through fail_runtime when the labels sum to zero.

    ``label`` names the set of rows being checked in the error message.
    """
    positive_total = sum(reference_labels)
    if positive_total == 0:
        fail_runtime(
            f"table_metric metric {metric} requires at least one positive label in {label}."
        )
596
+
597
+
598
def threshold_submission_labels(submission_values):
    """Turn scores into 0/1 labels: 1 when the score is at least 0.5."""
    predicted = []
    for score in submission_values:
        predicted.append(1 if score >= 0.5 else 0)
    return predicted
600
+
601
+
602
def compute_accuracy(reference_values, submission_values):
    """Return the fraction of rows whose thresholded prediction matches the label.

    Reference values must be 0/1 labels; submission values are thresholded
    at 0.5 by threshold_submission_labels.
    """
    actual = require_binary_reference_labels("accuracy", reference_values)
    predicted = threshold_submission_labels(submission_values)
    matches = 0
    for truth, guess in zip(actual, predicted):
        if truth == guess:
            matches += 1
    return matches / len(actual)
611
+
612
+
613
def compute_f1(reference_values, submission_values):
    """Return the F1 score of the thresholded predictions for the positive class.

    Reference values must be 0/1 labels; submission values are thresholded
    at 0.5. With no positives in either the labels or the predictions the
    score is 1.0; with no true positives otherwise it is 0.0.
    """
    actual = require_binary_reference_labels("f1", reference_values)
    predicted = threshold_submission_labels(submission_values)

    # One pass over the pairs instead of one pass per confusion-matrix cell.
    true_positive = false_positive = false_negative = 0
    for truth, guess in zip(actual, predicted):
        if truth == 1 and guess == 1:
            true_positive += 1
        elif truth == 0 and guess == 1:
            false_positive += 1
        elif truth == 1 and guess == 0:
            false_negative += 1

    if true_positive == 0:
        nothing_to_find = false_positive == 0 and false_negative == 0
        return 1.0 if nothing_to_find else 0.0

    precision = true_positive / (true_positive + false_positive)
    recall = true_positive / (true_positive + false_negative)
    if precision + recall == 0.0:
        return 0.0
    return (2.0 * precision * recall) / (precision + recall)
640
+
641
+
642
def compute_auroc(reference_values, submission_values):
    """Return the area under the ROC curve from tie-averaged ranks.

    Reference values must be 0/1 labels. The area is derived from the rank
    sum of the positive rows: (R_pos - n_pos * (n_pos + 1) / 2) divided by
    n_pos * n_neg. Labels containing only one class are reported through
    fail_runtime.
    """
    labels = require_binary_reference_labels("auroc", reference_values)
    positives = sum(labels)
    negatives = len(labels) - positives
    if positives == 0 or negatives == 0:
        fail_runtime(
            "table_metric metric auroc requires both positive and negative labels in the evaluation panel."
        )
    ranks = average_ranks(submission_values)
    positive_ranks = [rank for label, rank in zip(labels, ranks) if label == 1]
    positive_rank_total = sum(positive_ranks)
    numerator = positive_rank_total - (positives * (positives + 1) / 2.0)
    return numerator / (positives * negatives)
659
+
660
+
661
def compute_brier(reference_values, submission_values):
    """Return the Brier score: mean squared gap between probability and label.

    Reference values must be 0/1 labels.
    """
    labels = require_binary_reference_labels("brier", reference_values)
    squared_gaps = [
        (probability - label) ** 2
        for label, probability in zip(labels, submission_values)
    ]
    return sum(squared_gaps) / len(labels)
668
+
669
+
670
def compute_log_loss(reference_values, submission_values):
    """Return the mean binary cross-entropy of the submitted probabilities.

    Reference values must be 0/1 labels. Each probability is clipped to
    [LOG_LOSS_EPSILON, 1 - LOG_LOSS_EPSILON] so that the logarithms stay
    finite.
    """
    labels = require_binary_reference_labels("log_loss", reference_values)
    upper_bound = 1.0 - LOG_LOSS_EPSILON
    running_total = 0.0
    for label, probability in zip(labels, submission_values):
        clipped = min(max(probability, LOG_LOSS_EPSILON), upper_bound)
        positive_term = label * math.log(clipped)
        negative_term = (1 - label) * math.log(1.0 - clipped)
        running_total += -(positive_term + negative_term)
    return running_total / len(labels)
683
+
684
+
685
def compute_auprc(reference_values, submission_values, record_ids):
    """Return the average precision of the submitted scores.

    Rows are ranked by descending score, with ties ordered by ascending
    record id so the result is deterministic. The precision at the rank of
    each positive row is averaged over the number of positives. Reference
    values must be 0/1 labels; having no positive label is reported through
    reject_submission.
    """
    labels = require_binary_reference_labels("auprc", reference_values)
    positives = sum(labels)
    if positives == 0:
        reject_submission(
            "table_metric metric auprc requires at least one positive label among scored rows."
        )

    ranked = list(zip(labels, submission_values, record_ids))
    ranked.sort(key=lambda entry: (-entry[1], entry[2]))

    hits = 0
    precision_total = 0.0
    for position, entry in enumerate(ranked, start=1):
        if entry[0] != 1:
            continue
        hits += 1
        precision_total += hits / position
    return precision_total / positives
704
+
705
+
706
def compute_ece(reference_values, submission_values, bin_count):
    """Return the expected calibration error over ``bin_count`` uniform bins.

    Each probability p falls into bin ``min(int(p * bin_count), bin_count - 1)``.
    For every non-empty bin the absolute gap between the mean probability
    and the observed positive rate is weighted by the bin's share of rows.
    Reference values must be 0/1 labels.
    """
    labels = require_binary_reference_labels("ece", reference_values)

    # Per-bin accumulators kept as parallel lists.
    counts = [0] * bin_count
    probability_sums = [0.0] * bin_count
    positive_counts = [0] * bin_count
    last_bin = bin_count - 1
    for label, probability in zip(labels, submission_values):
        # p == 1.0 would index one past the end, so it joins the last bin.
        slot = min(int(probability * bin_count), last_bin)
        counts[slot] += 1
        probability_sums[slot] += probability
        positive_counts[slot] += label

    row_total = len(labels)
    calibration_error = 0.0
    for count, probability_sum, positives in zip(
        counts, probability_sums, positive_counts
    ):
        if count == 0:
            continue
        mean_probability = probability_sum / count
        positive_rate = positives / count
        calibration_error += count / row_total * abs(mean_probability - positive_rate)
    return calibration_error
732
+
733
+
734
def compute_metric(
    metric,
    reference_values,
    submission_values,
    record_ids,
    metric_params,
):
    """Dispatch to the compute_* function for ``metric`` and return its result.

    ``record_ids`` is only used by auprc and ``metric_params["bin_count"]``
    only by ece. A metric without an implementation on this path is reported
    through fail_runtime.
    """
    # Metrics that need nothing beyond the two value lists.
    pairwise_scorers = {
        "rmse": compute_rmse,
        "mae": compute_mae,
        "r2": compute_r2,
        "pearson": compute_pearson,
        "spearman": compute_spearman,
        "accuracy": compute_accuracy,
        "f1": compute_f1,
        "auroc": compute_auroc,
        "brier": compute_brier,
        "log_loss": compute_log_loss,
    }
    scorer = pairwise_scorers.get(metric)
    if scorer is not None:
        return scorer(reference_values, submission_values)
    if metric == "auprc":
        return compute_auprc(reference_values, submission_values, record_ids)
    if metric == "ece":
        return compute_ece(
            reference_values,
            submission_values,
            metric_params["bin_count"],
        )
    fail_runtime(f"table_metric metric {metric} is not supported by this scorer path.")
770
+
771
+
772
def compute_calibrated_metric(metric, reference_records, submission_by_id, scored_ids):
    """Compute z_rmse or within_noise_fraction over the scored records.

    Args:
        metric: "z_rmse" or "within_noise_fraction"; any other name is
            reported through fail_runtime.
        reference_records: Mapping of record id to a dict with "value" plus
            "baseline_mean"/"baseline_sd" (z_rmse) or "noise_floor"
            (within_noise_fraction).
        submission_by_id: Mapping of record id to the submitted float.
        scored_ids: Record ids to score; each must be present in both
            mappings. NOTE(review): an empty sequence divides by zero -
            callers are presumably expected to guard against it; confirm.

    Returns:
        ``(raw_metric, details)``. For z_rmse the details hold
        "mean_absolute_z_error" and "max_absolute_z_error"; for
        within_noise_fraction they hold "passed_count" and "failed_count".
    """
    if metric == "z_rmse":
        z_errors = []
        for record_id in scored_ids:
            record = reference_records[record_id]
            baseline_mean = record["baseline_mean"]
            baseline_sd = record["baseline_sd"]
            # Standardize both sides against the same per-row baseline.
            candidate_z = (submission_by_id[record_id] - baseline_mean) / baseline_sd
            reference_z = (record["value"] - baseline_mean) / baseline_sd
            z_errors.append(candidate_z - reference_z)
        raw_metric = math.sqrt(sum(error**2 for error in z_errors) / len(z_errors))
        return raw_metric, {
            "mean_absolute_z_error": sum(abs(error) for error in z_errors)
            / len(z_errors),
            "max_absolute_z_error": max(abs(error) for error in z_errors),
        }

    if metric != "within_noise_fraction":
        # Without this guard an unexpected metric fell through to the
        # noise-floor path and died with KeyError("noise_floor").
        fail_runtime(
            f"table_metric metric {metric} is not supported by the calibrated scorer path."
        )

    passed_count = 0
    for record_id in scored_ids:
        record = reference_records[record_id]
        if abs(submission_by_id[record_id] - record["value"]) <= record["noise_floor"]:
            passed_count += 1
    raw_metric = passed_count / len(scored_ids)
    return raw_metric, {
        "passed_count": passed_count,
        "failed_count": len(scored_ids) - passed_count,
    }
799
+
800
+
801
def clamp01(value):
    """Limit ``value`` to the closed interval [0.0, 1.0]."""
    return 0.0 if value < 0.0 else 1.0 if value > 1.0 else value
807
+
808
+
809
def normalize_score(metric, raw_metric):
    """Map a raw metric value onto [0, 1], where higher is better.

    * Metrics in MINIMIZE_METRICS: 1 / (1 + raw), with negative raw values
      treated as 0.
    * r2 and metrics in UNIT_INTERVAL_METRICS: the raw value clamped to [0, 1].
    * Everything else: (raw + 1) / 2 clamped to [0, 1].
    """
    if metric in MINIMIZE_METRICS:
        penalty = max(raw_metric, 0.0)
        return 1.0 / (1.0 + penalty)
    if metric == "r2" or metric in UNIT_INTERVAL_METRICS:
        return clamp01(raw_metric)
    shifted = (raw_metric + 1.0) / 2.0
    return clamp01(shifted)
817
+
818
+
819
def group_object(group_by, group_key):
    """Pair each group_by column with its value in ``group_key`` as a dict."""
    return dict(zip(group_by, group_key))
821
+
822
+
823
def format_group_label(group_by, group_key):
    """Render a group as "column='value', ..." for use in messages."""
    return ", ".join(
        f"{column}={value!r}" for column, value in zip(group_by, group_key)
    )
826
+
827
+
828
def validate_reference_group_preconditions(metric, group_label, reference_values):
    """Check that one reference group can be scored with ``metric``.

    * Correlation metrics need at least two reference rows.
    * Binary-label metrics need 0/1 reference values; auprc additionally
      needs a positive label and auroc needs both classes.

    Violations are reported through fail_runtime.
    """
    if metric in CORRELATION_METRICS:
        if len(reference_values) < 2:
            fail_runtime(
                f"grouped table_metric metric {metric} requires at least two reference rows in group {group_label}."
            )

    if metric in BINARY_LABEL_METRICS:
        labels = require_binary_reference_labels(metric, reference_values)
        if metric == "auprc":
            require_positive_reference_label(
                metric,
                labels,
                f"reference group {group_label}",
            )
        elif metric == "auroc":
            positives = sum(labels)
            negatives = len(labels) - positives
            if positives == 0 or negatives == 0:
                fail_runtime(
                    f"grouped table_metric metric auroc requires both positive and negative labels in reference group {group_label}."
                )
854
+
855
+
856
def validate_scored_group_preconditions(metric, group_label, reference_values):
    """Check that the rows actually scored in one group support ``metric``.

    * Correlation metrics need at least two scored rows.
    * auprc needs at least one positive label among the scored rows.
    * auroc needs both classes among the scored rows.

    Violations are reported through reject_submission.
    """
    if metric in CORRELATION_METRICS and len(reference_values) < 2:
        reject_submission(
            f"grouped table_metric metric {metric} requires at least two scored rows in group {group_label}."
        )

    if metric != "auprc" and metric != "auroc":
        return

    labels = require_binary_reference_labels(metric, reference_values)
    positives = sum(labels)
    if metric == "auprc":
        if positives == 0:
            reject_submission(
                f"grouped table_metric metric auprc requires at least one positive label among scored rows in group {group_label}."
            )
        return

    negatives = len(labels) - positives
    if positives == 0 or negatives == 0:
        reject_submission(
            f"grouped table_metric metric auroc requires both positive and negative labels among scored rows in group {group_label}."
        )
880
+
881
+
882
def validate_reference_panel_preconditions(metric, reference_values):
    """Check that the whole evaluation panel can be scored with ``metric``.

    Only binary-label metrics are checked: reference values must be 0/1,
    auprc needs a positive label and auroc needs both classes. Violations
    are reported through fail_runtime.
    """
    if metric in BINARY_LABEL_METRICS:
        labels = require_binary_reference_labels(metric, reference_values)
        if metric == "auprc":
            require_positive_reference_label(metric, labels, "the evaluation panel")
        elif metric == "auroc":
            positives = sum(labels)
            negatives = len(labels) - positives
            if positives == 0 or negatives == 0:
                fail_runtime(
                    "table_metric metric auroc requires both positive and negative labels in the evaluation panel."
                )
899
+
900
+
901
def validate_objective(metric, objective):
    """Check that ``objective`` agrees with the metric's direction.

    Metrics in MINIMIZE_METRICS require "minimize" and metrics in
    MAXIMIZE_METRICS require "maximize"; a mismatch is reported through
    fail_runtime.
    """
    if metric in MINIMIZE_METRICS:
        if objective != "minimize":
            fail_runtime(
                f"table_metric metric {metric} requires objective=minimize, received {objective}."
            )
    if metric in MAXIMIZE_METRICS:
        if objective != "maximize":
            fail_runtime(
                f"table_metric metric {metric} requires objective=maximize, received {objective}."
            )
910
+
911
+
912
def score_grouped_table_metric(
    *,
    metric,
    metric_params,
    group_by,
    aggregation,
    evaluation_path,
    evaluation_role,
    evaluation_record_key,
    evaluation_value_field,
    submission_path,
    submission_role,
    submission_record_key,
    submission_value_field,
    coverage_policy,
    duplicate_id_policy,
    invalid_value_policy,
    final_score_key,
):
    """Score a submission group by group and write one aggregated score.

    Reference records are bucketed by their ``"group_key"`` entry and every
    bucket is checked with ``validate_reference_group_preconditions`` before
    the submission is loaded. Each group is then scored with
    ``compute_metric`` / ``normalize_score`` over the reference ids that also
    appear in the submission, and the group scores are combined.

    All parameters are keyword-only.

    * ``coverage_policy == "reject"`` rejects a submission that misses any
      reference id; ``"penalize"`` multiplies each group's normalized score
      by the fraction of that group's reference rows that were scored.
    * A group without a single scored row rejects the submission.
    * ``aggregation == "macro_mean"`` averages the group scores equally; any
      other value weights each group by its number of scored rows.

    The outcome is emitted through ``write_score``; nothing is returned.

    NOTE(review): the flow after each ``reject_submission`` call assumes that
    helper does not return -- confirm against agora_runtime.
    """
    reference_records = load_reference_group_records(
        evaluation_path,
        evaluation_role,
        evaluation_record_key,
        evaluation_value_field,
        group_by,
    )

    # Bucket reference ids by group key, keeping reference order per bucket.
    ids_by_group = {}
    for reference_id, reference_record in reference_records.items():
        bucket_key = reference_record["group_key"]
        if bucket_key not in ids_by_group:
            ids_by_group[bucket_key] = []
        ids_by_group[bucket_key].append(reference_id)
    ordered_keys = sorted(ids_by_group)

    # Validate every reference group up front, before reading the submission.
    for bucket_key in ordered_keys:
        validate_reference_group_preconditions(
            metric,
            format_group_label(group_by, bucket_key),
            [
                reference_records[reference_id]["value"]
                for reference_id in ids_by_group[bucket_key]
            ],
        )

    submitted_values = load_submission_values(
        submission_path,
        submission_role,
        submission_record_key,
        submission_value_field,
        duplicate_id_policy,
        invalid_value_policy,
        metric,
    )
    missing_ids = [
        reference_id
        for reference_id in reference_records
        if reference_id not in submitted_values
    ]
    if coverage_policy == "reject" and missing_ids:
        reject_submission(
            f"Submission is missing predictions for {len(missing_ids)} required rows; first missing id is {missing_ids[0]!r}."
        )

    per_group_details = []
    weighted_scores = []  # one (normalized score, scored-row count) per group

    for bucket_key in ordered_keys:
        member_ids = ids_by_group[bucket_key]
        label = format_group_label(group_by, bucket_key)

        # Split the group's reference ids by whether a prediction exists.
        present_ids = []
        absent_ids = []
        for reference_id in member_ids:
            if reference_id in submitted_values:
                present_ids.append(reference_id)
            else:
                absent_ids.append(reference_id)
        if len(present_ids) == 0:
            reject_submission(
                f"Submission produced no scoreable rows for group {label} after applying runtime policies."
            )

        expected_values = [
            reference_records[reference_id]["value"] for reference_id in present_ids
        ]
        predicted_values = [
            submitted_values[reference_id] for reference_id in present_ids
        ]
        # The scored subset may lack a class the full group had, so re-check.
        validate_scored_group_preconditions(metric, label, expected_values)
        group_metric = compute_metric(
            metric,
            expected_values,
            predicted_values,
            present_ids,
            metric_params,
        )
        group_score = normalize_score(metric, group_metric)
        if coverage_policy == "penalize":
            group_score *= len(present_ids) / len(member_ids)

        weight = len(present_ids)
        weighted_scores.append((group_score, weight))
        per_group_details.append(
            {
                "group": group_object(group_by, bucket_key),
                "reference_row_count": len(member_ids),
                "rows_scored": len(present_ids),
                "missing_count": len(absent_ids),
                "selected_metric_value": group_metric,
                "normalized_score": group_score,
                "aggregation_weight": weight,
            }
        )

    if not weighted_scores:
        reject_submission(
            "Submission produced no scoreable groups after applying runtime policies."
        )

    if aggregation == "macro_mean":
        score_total = sum(score for score, _ in weighted_scores)
        final_score = score_total / len(weighted_scores)
    else:
        total_weight = sum(weight for _, weight in weighted_scores)
        if total_weight <= 0:
            reject_submission(
                "Submission produced no weighted group rows after applying runtime policies."
            )
        weighted_total = sum(score * weight for score, weight in weighted_scores)
        final_score = weighted_total / total_weight

    write_score(
        score=final_score,
        details={
            final_score_key: final_score,
            "selected_metric": metric,
            "aggregation": aggregation,
            "group_by": group_by,
            "group_count": len(per_group_details),
            "rows_scored": sum(weight for _, weight in weighted_scores),
            "missing_count": len(missing_ids),
            "groups": per_group_details,
        },
    )
1067
+
1068
+
1069
def _select_scored_ids(reference_by_id, submission_by_id, coverage_policy):
    """Return the reference ids that have a submitted value.

    Rejects the submission when ids are missing under the ``"reject"``
    coverage policy, or when no reference id is covered at all. Ids come back
    in ``reference_by_id`` iteration order.
    """
    missing_ids = [
        record_id for record_id in reference_by_id if record_id not in submission_by_id
    ]
    if missing_ids and coverage_policy == "reject":
        reject_submission(
            f"Submission is missing predictions for {len(missing_ids)} required rows; first missing id is {missing_ids[0]!r}."
        )
    scored_ids = [
        record_id for record_id in reference_by_id if record_id in submission_by_id
    ]
    if not scored_ids:
        reject_submission(
            "Submission produced no scoreable rows after applying runtime policies."
        )
    return scored_ids


def _apply_coverage_penalty(normalized_score, coverage_policy, scored_ids, reference_by_id):
    """Scale the score by the scored-row fraction under the ``"penalize"`` policy."""
    if coverage_policy == "penalize":
        normalized_score *= len(scored_ids) / len(reference_by_id)
    return normalized_score


def _score_calibrated_table_metric(
    *,
    config,
    metric,
    metric_params,
    evaluation_path,
    evaluation_role,
    evaluation_record_key,
    evaluation_value_field,
    submission_path,
    submission_role,
    submission_record_key,
    submission_value_field,
    coverage_policy,
    duplicate_id_policy,
    invalid_value_policy,
    final_score_key,
):
    """Score an ungrouped calibrated metric (``z_rmse`` / ``within_noise_fraction``).

    Cross-checks the metric-specific config fields against ``metric_params``,
    loads calibrated reference records and submission values, computes the
    metric over the covered ids and writes the normalized score together with
    the extra details returned by ``compute_calibrated_metric``.
    """
    if metric == "z_rmse":
        require_config_field_matches_metric_params(
            config,
            metric_params,
            "baseline_mean_field",
        )
        require_config_field_matches_metric_params(
            config,
            metric_params,
            "baseline_sd_field",
        )
    elif metric == "within_noise_fraction":
        require_config_field_matches_metric_params(
            config,
            metric_params,
            "noise_floor_field",
        )

    reference_by_id = load_calibrated_reference_records(
        evaluation_path,
        evaluation_role,
        evaluation_record_key,
        evaluation_value_field,
        metric,
        metric_params,
    )
    submission_by_id = load_submission_values(
        submission_path,
        submission_role,
        submission_record_key,
        submission_value_field,
        duplicate_id_policy,
        invalid_value_policy,
        metric,
    )
    scored_ids = _select_scored_ids(reference_by_id, submission_by_id, coverage_policy)
    raw_metric, metric_details = compute_calibrated_metric(
        metric,
        reference_by_id,
        submission_by_id,
        scored_ids,
    )
    normalized_score = _apply_coverage_penalty(
        normalize_score(metric, raw_metric),
        coverage_policy,
        scored_ids,
        reference_by_id,
    )
    write_score(
        score=normalized_score,
        details={
            final_score_key: normalized_score,
            "selected_metric": metric,
            "selected_metric_value": raw_metric,
            "rows_scored": len(scored_ids),
            **metric_details,
        },
    )


def _score_plain_table_metric(
    *,
    metric,
    metric_params,
    evaluation_path,
    evaluation_role,
    evaluation_record_key,
    evaluation_value_field,
    submission_path,
    submission_role,
    submission_record_key,
    submission_value_field,
    coverage_policy,
    duplicate_id_policy,
    invalid_value_policy,
    final_score_key,
):
    """Score an ungrouped, non-calibrated metric over the whole panel.

    Validates the reference panel for the metric, loads the submission,
    computes the metric over the covered ids and writes the normalized score.
    """
    reference_by_id = load_reference_values(
        evaluation_path,
        evaluation_role,
        evaluation_record_key,
        evaluation_value_field,
    )
    # Panel-level checks run before the submission is even read.
    validate_reference_panel_preconditions(
        metric,
        list(reference_by_id.values()),
    )
    submission_by_id = load_submission_values(
        submission_path,
        submission_role,
        submission_record_key,
        submission_value_field,
        duplicate_id_policy,
        invalid_value_policy,
        metric,
    )
    scored_ids = _select_scored_ids(reference_by_id, submission_by_id, coverage_policy)
    reference_values = [reference_by_id[record_id] for record_id in scored_ids]
    submission_values = [submission_by_id[record_id] for record_id in scored_ids]
    raw_metric = compute_metric(
        metric,
        reference_values,
        submission_values,
        scored_ids,
        metric_params,
    )
    normalized_score = _apply_coverage_penalty(
        normalize_score(metric, raw_metric),
        coverage_policy,
        scored_ids,
        reference_by_id,
    )
    write_score(
        score=normalized_score,
        details={
            final_score_key: normalized_score,
            "selected_metric": metric,
            "selected_metric_value": raw_metric,
            "rows_scored": len(scored_ids),
        },
    )


def main():
    """Load the runtime context and compiled config, then score the submission.

    Validates the metric, objective, policies and CSV slot columns, resolves
    the evaluation and submission artifacts, and dispatches to one of three
    scorers: grouped (when ``group_by`` is configured), calibrated (metric in
    ``CALIBRATED_METRICS``) or plain. Calibrated metrics combined with
    ``group_by`` are reported through ``fail_runtime``.

    NOTE(review): ``config`` is only bound when ``load_json_file`` succeeds,
    so this relies on ``fail_runtime`` not returning -- confirm against
    agora_runtime.
    """
    runtime_context = load_runtime_context()
    config_path = resolve_scoring_asset(
        runtime_context,
        "compiled_config",
        kind="config",
    )
    try:
        config = load_json_file(config_path, label="compiled_config")
    except RuntimeError as error:
        fail_runtime(str(error))
    metric = require_metric(config)
    metric_params = require_metric_params(config, metric)
    group_by, aggregation = require_grouping_config(config)
    evaluation_role = require_string(
        config.get("evaluation_role"),
        "compiled_config.evaluation_role",
    )
    submission_role = require_string(
        config.get("submission_role"),
        "compiled_config.submission_role",
    )
    final_score_key = require_string(
        runtime_context.get("final_score_key"),
        "runtime_context.final_score_key",
    )
    objective = require_string(
        runtime_context.get("objective"),
        "runtime_context.objective",
    )
    validate_objective(metric, objective)
    policies = runtime_context.get("policies")
    if not isinstance(policies, dict):
        fail_runtime("Runtime context is missing execution policies.")
    coverage_policy = require_policy(
        policies,
        "coverage_policy",
        {"reject", "ignore", "penalize"},
    )
    duplicate_id_policy = require_policy(
        policies,
        "duplicate_id_policy",
        {"reject", "ignore"},
    )
    invalid_value_policy = require_policy(
        policies,
        "invalid_value_policy",
        {"reject", "ignore"},
    )
    evaluation_record_key, evaluation_value_field = require_csv_slot_columns(
        runtime_context,
        "evaluation",
        evaluation_role,
    )
    submission_record_key, submission_value_field = require_csv_slot_columns(
        runtime_context,
        "submission",
        submission_role,
    )
    evaluation_path = resolve_evaluation_artifact(runtime_context, evaluation_role)
    submission_path = resolve_submission_artifact(runtime_context, submission_role)

    # Keyword arguments shared by all three scoring paths.
    scoring_inputs = {
        "metric": metric,
        "metric_params": metric_params,
        "evaluation_path": evaluation_path,
        "evaluation_role": evaluation_role,
        "evaluation_record_key": evaluation_record_key,
        "evaluation_value_field": evaluation_value_field,
        "submission_path": submission_path,
        "submission_role": submission_role,
        "submission_record_key": submission_record_key,
        "submission_value_field": submission_value_field,
        "coverage_policy": coverage_policy,
        "duplicate_id_policy": duplicate_id_policy,
        "invalid_value_policy": invalid_value_policy,
        "final_score_key": final_score_key,
    }

    if group_by:
        if metric in CALIBRATED_METRICS:
            fail_runtime(
                f"table_metric metric {metric} is not supported with compiled_config.group_by; use calibrated_table_metric@1 without grouped params."
            )
        score_grouped_table_metric(
            group_by=group_by,
            aggregation=aggregation,
            **scoring_inputs,
        )
    elif metric in CALIBRATED_METRICS:
        _score_calibrated_table_metric(config=config, **scoring_inputs)
    else:
        _score_plain_table_metric(**scoring_inputs)
1288
+
1289
+
1290
# Run the scorer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()