@moleculeagora/cli 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,392 @@
1
+ import csv
2
+ import math
3
+
4
+ from agora_runtime import (
5
+ fail_runtime,
6
+ load_json_file,
7
+ load_runtime_context,
8
+ reject_submission,
9
+ resolve_evaluation_artifact,
10
+ resolve_scoring_asset,
11
+ resolve_submission_artifact,
12
+ write_score,
13
+ )
14
+
15
# The only metric this scorer implements; require_metric compares the
# stripped, lower-cased compiled_config.metric against this value.
SUPPORTED_METRIC = "mcrmse"
16
+
17
+
18
def require_string(value, label):
    """Return ``value`` with surrounding whitespace removed.

    ``label`` names the setting in the error message. ``fail_runtime`` is
    called when ``value`` is not a string or is empty after stripping.
    """
    is_usable = isinstance(value, str) and bool(value.strip())
    if not is_usable:
        fail_runtime(f"{label} must be a non-empty string.")
    return value.strip()
22
+
23
+
24
def require_metric(config):
    """Read ``compiled_config.metric`` and check it names the supported metric.

    The value is stripped and lower-cased before being compared with
    ``SUPPORTED_METRIC``. Any other metric triggers ``fail_runtime``.
    Returns the normalized metric name.
    """
    raw_metric = require_string(config.get("metric"), "compiled_config.metric")
    normalized = raw_metric.lower()
    if normalized == SUPPORTED_METRIC:
        return normalized
    fail_runtime("compiled_config.metric must be mcrmse.")
    return normalized
29
+
30
+
31
def require_policy(policies, key, allowed):
    """Return the stripped ``policies[key]`` after checking it is in ``allowed``.

    A missing or blank value is reported by ``require_string``. A value outside
    ``allowed`` is reported through ``fail_runtime``, listing the sorted choices.
    """
    label = f"policies.{key}"
    value = require_string(policies.get(key), label)
    if value in allowed:
        return value
    choices = ", ".join(sorted(allowed))
    fail_runtime(f"{label} must be one of {choices}.")
    return value
38
+
39
+
40
def find_slot(runtime_context, lane, role):
    """Return the first slot in ``runtime_context["<lane>_slots"]`` for ``role``.

    Non-dict entries in the slot list are skipped. ``fail_runtime`` is called
    when the slot list is absent or not a list, or when no slot has a
    matching ``role``.
    """
    slot_key = f"{lane}_slots"
    slots = runtime_context.get(slot_key)
    if not isinstance(slots, list):
        fail_runtime(f"Runtime context is missing {slot_key}.")
    match = next(
        (
            candidate
            for candidate in slots
            if isinstance(candidate, dict) and candidate.get("role") == role
        ),
        None,
    )
    if match is not None:
        return match
    fail_runtime(f"Runtime context is missing {lane} slot for role {role}.")
49
+
50
+
51
def require_csv_slot_columns(runtime_context, lane, role):
    """Return ``(record_key, value_fields)`` from a slot's ``csv_columns`` validator.

    The slot for ``role`` in ``lane`` must carry a dict ``validator`` whose
    ``kind`` is ``"csv_columns"``. It must also have a non-empty
    ``record_key`` and a non-empty list of unique, non-empty
    ``value_fields``. Field names are returned stripped, in their
    original order.
    """
    prefix = f"{lane}.{role}.validator"
    validator = find_slot(runtime_context, lane, role).get("validator")
    uses_csv_columns = (
        isinstance(validator, dict) and validator.get("kind") == "csv_columns"
    )
    if not uses_csv_columns:
        fail_runtime(
            f"{lane} role {role} must use validator.kind=csv_columns for multi_output_tabular_metric."
        )
    record_key = require_string(validator.get("record_key"), f"{prefix}.record_key")
    raw_fields = validator.get("value_fields")
    if not (isinstance(raw_fields, list) and raw_fields):
        fail_runtime(f"{prefix}.value_fields must be a non-empty array.")
    fields = []
    for position, raw_field in enumerate(raw_fields):
        fields.append(require_string(raw_field, f"{prefix}.value_fields[{position}]"))
    # Uniqueness is checked after stripping, so " a" and "a" collide.
    if len(fields) != len(set(fields)):
        fail_runtime(f"{prefix}.value_fields must be unique.")
    return record_key, fields
74
+
75
+
76
def read_csv_rows(path, label, *, invalid_handler):
    """Read a CSV file into stripped header names and row dicts.

    Args:
        path: Path-like object exposing ``open`` (e.g. ``pathlib.Path``).
        label: Human-readable artifact name used in error messages.
        invalid_handler: Called with a message when the file is missing,
            unreadable, not valid UTF-8, malformed CSV, or has an empty,
            blank, or duplicate column name. Callers in this file pass
            ``fail_runtime`` or ``reject_submission``. The handler is
            presumably expected to abort rather than return; confirm.

    Returns:
        ``(fieldnames, rows)``. ``fieldnames`` are the header names with
        surrounding whitespace stripped. Each row is a dict keyed by those same
        stripped names, so lookups with validated column names succeed.
    """
    try:
        # utf-8-sig decodes plain UTF-8 identically but also drops a leading
        # byte-order mark, which would otherwise corrupt the first column name.
        with path.open("r", encoding="utf-8-sig", newline="") as handle:
            reader = csv.DictReader(handle)
            fieldnames = reader.fieldnames
            if not fieldnames:
                invalid_handler(f"{label} must include a CSV header row.")
            normalized_fieldnames = []
            seen_fieldnames = set()
            for fieldname in fieldnames:
                if not isinstance(fieldname, str) or not fieldname.strip():
                    invalid_handler(f"{label} contains an empty CSV column name.")
                normalized = fieldname.strip()
                # DictReader keeps only the last of duplicate keys, which would
                # silently discard a column.
                if normalized in seen_fieldnames:
                    invalid_handler(
                        f"{label} contains duplicate CSV column name {normalized!r}."
                    )
                seen_fieldnames.add(normalized)
                normalized_fieldnames.append(normalized)
            # Re-key rows by the stripped names. Otherwise a header such as
            # " id " passes column validation but row.get("id") returns None.
            reader.fieldnames = normalized_fieldnames
            rows = list(reader)
    except FileNotFoundError:
        invalid_handler(f"Missing {label} at {path}.")
    except OSError as error:
        invalid_handler(f"Unable to read {label}: {error}.")
    except (UnicodeDecodeError, csv.Error) as error:
        # Undecodable bytes or malformed CSV are problems with the artifact,
        # so route them through the handler instead of crashing.
        invalid_handler(f"Unable to parse {label} as UTF-8 CSV: {error}.")
    return normalized_fieldnames, rows
94
+
95
+
96
def parse_reference_value(raw_value, label):
    """Parse a ground-truth cell into a finite float.

    Blank, non-numeric, or non-finite cells are reported through
    ``fail_runtime``. ``label`` identifies the cell in the message.
    """
    if isinstance(raw_value, str):
        text = raw_value.strip()
    else:
        text = ""
    if text == "":
        fail_runtime(f"{label} is blank.")
    try:
        parsed = float(text)
    except ValueError:
        fail_runtime(f"{label} must be numeric, received {text!r}.")
    if math.isinf(parsed) or math.isnan(parsed):
        fail_runtime(f"{label} must be finite.")
    return parsed
107
+
108
+
109
def parse_submission_value(raw_value, label, invalid_value_policy):
    """Parse one submitted prediction cell.

    Returns the cell as a finite float. Blank, non-numeric, or non-finite
    cells yield ``None``. When ``invalid_value_policy`` is ``"reject"``,
    the problem is first reported through ``reject_submission``.
    """
    rejecting = invalid_value_policy == "reject"
    text = raw_value.strip() if isinstance(raw_value, str) else ""
    problem = None
    value = None
    if not text:
        problem = f"{label} is blank."
    else:
        try:
            value = float(text)
        except ValueError:
            problem = f"{label} must be numeric, received {text!r}."
        else:
            if not math.isfinite(value):
                problem = f"{label} must be finite."
    if problem is None:
        return value
    if rejecting:
        reject_submission(problem)
    return None
126
+
127
+
128
def require_columns(fieldnames, role, lane, record_key, value_fields, invalid_handler):
    """Report every required column absent from ``fieldnames``.

    The record key is checked first, then each value field in order.
    ``invalid_handler`` receives one message per missing column.
    """
    if record_key not in fieldnames:
        invalid_handler(f"{lane} artifact {role} is missing record key column {record_key}.")
    missing_targets = [field for field in value_fields if field not in fieldnames]
    for missing in missing_targets:
        invalid_handler(f"{lane} artifact {role} is missing target column {missing}.")
136
+
137
+
138
def load_reference_values(path, role, record_key, value_fields):
    """Load the evaluation CSV into ``{record_id: {value_field: float}}``.

    Problems with the reference data are reported through ``fail_runtime``
    rather than ``reject_submission``. These include an unreadable file,
    missing columns, blank or duplicate record ids, invalid cells, and a
    file with no data rows.
    """
    artifact_label = f"evaluation artifact {role}"
    fieldnames, rows = read_csv_rows(path, artifact_label, invalid_handler=fail_runtime)
    require_columns(fieldnames, role, "evaluation", record_key, value_fields, fail_runtime)
    reference = {}
    # Numbering starts at 2 so the header counts as row 1 in messages.
    for line_number, row in enumerate(rows, 2):
        raw_key = row.get(record_key)
        record_id = raw_key.strip() if isinstance(raw_key, str) else ""
        if record_id == "":
            fail_runtime(f"{artifact_label} row {line_number} is missing {record_key}.")
        if record_id in reference:
            fail_runtime(f"{artifact_label} contains duplicate record id {record_id!r}.")
        parsed = {}
        for value_field in value_fields:
            parsed[value_field] = parse_reference_value(
                row.get(value_field),
                f"{artifact_label} row {line_number} column {value_field}",
            )
        reference[record_id] = parsed
    if len(reference) == 0:
        fail_runtime(f"{artifact_label} must contain at least one scored row.")
    return reference
176
+
177
+
178
def parse_submission_row(row, role, row_index, value_fields, invalid_value_policy):
    """Parse every target cell of one submission row.

    Returns ``{value_field: float}``. Returns ``None`` as soon as any cell
    fails to parse, after ``parse_submission_value`` has applied
    ``invalid_value_policy``.
    """
    label_prefix = f"submission artifact {role} row {row_index} column"
    parsed = {}
    for value_field in value_fields:
        cell = parse_submission_value(
            row.get(value_field),
            f"{label_prefix} {value_field}",
            invalid_value_policy,
        )
        if cell is None:
            return None
        parsed[value_field] = cell
    return parsed
190
+
191
+
192
def load_submission_values(
    path,
    role,
    record_key,
    value_fields,
    duplicate_id_policy,
    invalid_value_policy,
):
    """Load the submission CSV into ``{record_id: {value_field: float}}``.

    An unreadable file or missing columns are reported through
    ``reject_submission``. Per-row handling follows the runtime policies:

    * ``invalid_value_policy == "reject"``: a blank record id or an invalid
      cell is reported through ``reject_submission``. Otherwise the row is
      skipped.
    * ``duplicate_id_policy == "reject"``: a repeated record id is reported
      through ``reject_submission``. Otherwise only the first row carrying
      an id is considered.

    Duplicates are detected against every id seen in the file, including
    rows skipped for invalid values. A later duplicate therefore cannot
    stand in for, or hide, an earlier invalid row.
    """
    fieldnames, rows = read_csv_rows(
        path,
        f"submission artifact {role}",
        invalid_handler=reject_submission,
    )
    require_columns(
        fieldnames,
        role,
        "submission",
        record_key,
        value_fields,
        reject_submission,
    )
    values = {}
    # Every non-blank id encountered, whether or not its row parsed. Checking
    # `values` alone would let a duplicate of an ignored-invalid row slip
    # through even under duplicate_id_policy == "reject".
    seen_ids = set()
    for row_index, row in enumerate(rows, start=2):
        raw_key = row.get(record_key)
        key = raw_key.strip() if isinstance(raw_key, str) else ""
        if not key:
            if invalid_value_policy == "reject":
                reject_submission(
                    f"submission artifact {role} row {row_index} is missing {record_key}."
                )
            continue
        if key in seen_ids:
            if duplicate_id_policy == "reject":
                reject_submission(
                    f"submission artifact {role} contains duplicate record id {key!r}."
                )
            continue
        seen_ids.add(key)
        parsed_row = parse_submission_row(
            row,
            role,
            row_index,
            value_fields,
            invalid_value_policy,
        )
        if parsed_row is not None:
            values[key] = parsed_row
    return values
240
+
241
+
242
def compute_rmse(reference_values, submission_values):
    """Return the root-mean-squared error between two aligned sequences.

    Args:
        reference_values: Ground-truth floats.
        submission_values: Predicted floats, aligned index-by-index with
            ``reference_values``.

    Returns:
        The RMSE as a float. Huge-but-finite predictions no longer crash the
        computation. If an individual difference overflows to infinity, the
        result is ``math.inf``. Otherwise the result stays finite even when
        the naive sum of squares would overflow.

    Raises:
        ValueError: If the sequences are empty or differ in length.
    """
    count = len(reference_values)
    if count != len(submission_values):
        raise ValueError("reference and submission sequences must have the same length.")
    if count == 0:
        raise ValueError("cannot compute RMSE of empty sequences.")
    differences = [
        submission - reference
        for reference, submission in zip(reference_values, submission_values)
    ]
    try:
        squared_error = sum(difference ** 2 for difference in differences)
    except OverflowError:
        # float ** 2 raises (rather than returning inf) for large finite
        # values such as 1e200, which a submission can legitimately contain.
        squared_error = math.inf
    if not math.isinf(squared_error):
        return math.sqrt(squared_error / count)
    # Overflow fallback: scale by the largest magnitude so every squared term
    # is at most 1, then undo the scaling after the square root.
    scale = max(abs(difference) for difference in differences)
    if math.isinf(scale):
        return math.inf
    scaled_sum = sum((difference / scale) ** 2 for difference in differences)
    return scale * math.sqrt(scaled_sum / count)
248
+
249
+
250
def compute_mcrmse(reference_by_id, submission_by_id, scored_ids, value_fields):
    """Compute the mean column-wise RMSE over ``scored_ids``.

    Returns ``(mean_rmse, per_field_rmse)``. ``per_field_rmse`` maps each
    value field, in ``value_fields`` order, to its RMSE across the scored
    ids. ``mean_rmse`` is the unweighted mean of those per-field values.
    """
    per_field = {
        field: compute_rmse(
            [reference_by_id[record_id][field] for record_id in scored_ids],
            [submission_by_id[record_id][field] for record_id in scored_ids],
        )
        for field in value_fields
    }
    mean_rmse = sum(per_field.values()) / len(per_field)
    return mean_rmse, per_field
261
+
262
+
263
def normalize_score(raw_metric):
    """Map a lower-is-better error to a higher-is-better score.

    Negative inputs are clamped to 0, so the result is ``1 / (1 + error)``.
    An error of 0 maps to 1.0.
    """
    # Same selection rule as max(raw_metric, 0.0): keep raw_metric unless
    # 0.0 is strictly greater (so NaN passes through unchanged).
    clamped = 0.0 if 0.0 > raw_metric else raw_metric
    return 1.0 / (1.0 + clamped)
265
+
266
+
267
def validate_objective(objective):
    """Ensure the runtime objective is ``minimize``, as MCRMSE requires."""
    if objective == "minimize":
        return
    fail_runtime(
        f"multi_output_tabular_metric metric mcrmse requires objective=minimize, received {objective}."
    )
272
+
273
+
274
def main():
    """Score one submission against the evaluation artifact using MCRMSE.

    Steps:
    1. Load the runtime context and the compiled config.
    2. Validate the metric, objective, execution policies, and the
       ``csv_columns`` validators of both slots.
    3. Load reference and submission values and apply the coverage policy.
    4. Compute MCRMSE over ids present in both files and map it to a score
       with ``normalize_score``.
    5. Under ``coverage_policy == "penalize"``, scale the score by the
       covered fraction of reference rows.
    6. Emit the result through ``write_score``.
    """
    runtime_context = load_runtime_context()
    config_path = resolve_scoring_asset(
        runtime_context,
        "compiled_config",
        kind="config",
    )
    try:
        config = load_json_file(config_path, label="compiled_config")
    except RuntimeError as error:
        fail_runtime(str(error))
    # The config is read with .get() below; a JSON array or scalar would
    # otherwise crash with AttributeError instead of a clear runtime failure.
    if not isinstance(config, dict):
        fail_runtime("compiled_config must be a JSON object.")
    metric = require_metric(config)
    evaluation_role = require_string(
        config.get("evaluation_role"),
        "compiled_config.evaluation_role",
    )
    submission_role = require_string(
        config.get("submission_role"),
        "compiled_config.submission_role",
    )
    final_score_key = require_string(
        runtime_context.get("final_score_key"),
        "runtime_context.final_score_key",
    )
    objective = require_string(
        runtime_context.get("objective"),
        "runtime_context.objective",
    )
    validate_objective(objective)
    policies = runtime_context.get("policies")
    if not isinstance(policies, dict):
        fail_runtime("Runtime context is missing execution policies.")
    coverage_policy = require_policy(
        policies,
        "coverage_policy",
        {"reject", "ignore", "penalize"},
    )
    duplicate_id_policy = require_policy(
        policies,
        "duplicate_id_policy",
        {"reject", "ignore"},
    )
    invalid_value_policy = require_policy(
        policies,
        "invalid_value_policy",
        {"reject", "ignore"},
    )
    evaluation_record_key, evaluation_value_fields = require_csv_slot_columns(
        runtime_context,
        "evaluation",
        evaluation_role,
    )
    submission_record_key, submission_value_fields = require_csv_slot_columns(
        runtime_context,
        "submission",
        submission_role,
    )
    # compute_mcrmse indexes both dicts with the evaluation field list, so the
    # two slots must declare exactly the same fields in the same order.
    if evaluation_value_fields != submission_value_fields:
        fail_runtime(
            "evaluation and submission csv_columns.validator.value_fields must match for multi_output_tabular_metric."
        )
    evaluation_path = resolve_evaluation_artifact(runtime_context, evaluation_role)
    submission_path = resolve_submission_artifact(runtime_context, submission_role)
    reference_by_id = load_reference_values(
        evaluation_path,
        evaluation_role,
        evaluation_record_key,
        evaluation_value_fields,
    )
    submission_by_id = load_submission_values(
        submission_path,
        submission_role,
        submission_record_key,
        submission_value_fields,
        duplicate_id_policy,
        invalid_value_policy,
    )
    missing_ids = [
        record_id
        for record_id in reference_by_id
        if record_id not in submission_by_id
    ]
    if missing_ids and coverage_policy == "reject":
        reject_submission(
            f"Submission is missing predictions for {len(missing_ids)} required rows; first missing id is {missing_ids[0]!r}."
        )
    # Only reference ids are scored; extra submission ids are ignored.
    scored_ids = [
        record_id
        for record_id in reference_by_id
        if record_id in submission_by_id
    ]
    if not scored_ids:
        reject_submission(
            "Submission produced no scoreable rows after applying runtime policies."
        )
    raw_metric, target_rmse = compute_mcrmse(
        reference_by_id,
        submission_by_id,
        scored_ids,
        evaluation_value_fields,
    )
    normalized_score = normalize_score(raw_metric)
    if coverage_policy == "penalize":
        # Scale by the fraction of reference rows that received predictions.
        normalized_score *= len(scored_ids) / len(reference_by_id)
    write_score(
        score=normalized_score,
        details={
            final_score_key: normalized_score,
            "selected_metric": metric,
            "selected_metric_value": raw_metric,
            "rows_scored": len(scored_ids),
            "target_count": len(evaluation_value_fields),
            "target_rmse": target_rmse,
        },
    )
389
+
390
+
391
# Run the scorer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()