@moleculeagora/cli 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,442 @@
1
+ import csv
2
+ import math
3
+
4
+ from agora_runtime import (
5
+ fail_runtime,
6
+ load_json_file,
7
+ load_runtime_context,
8
+ reject_submission,
9
+ resolve_scoring_asset,
10
+ resolve_submission_artifact,
11
+ write_score,
12
+ )
13
+
14
# Values accepted for endpoint_test.expected_direction.
SUPPORTED_DIRECTIONS = ("treatment_lt_control", "treatment_gt_control")
# Upper bound on the number of passes made by beta_continued_fraction.
BETA_CONTINUED_FRACTION_MAX_ITERATIONS = 200
# beta_continued_fraction stops once |delta - 1| drops below this threshold.
BETA_CONTINUED_FRACTION_EPSILON = 3e-14
# Replacement for intermediate terms whose magnitude is too small to divide by.
BETA_CONTINUED_FRACTION_MIN_FLOAT = 1e-300
18
+
19
+
20
def require_string(value, label, invalid_handler=fail_runtime):
    """Return ``value`` with surrounding whitespace removed.

    ``invalid_handler`` is called with an explanatory message when ``value``
    is not a string or contains nothing but whitespace. ``label`` names the
    offending field in that message.
    """
    has_text = isinstance(value, str) and bool(value.strip())
    if not has_text:
        invalid_handler(f"{label} must be a non-empty string.")
    return value.strip()
24
+
25
+
26
def require_config_dict(config, key):
    """Return ``config[key]``, reporting via ``fail_runtime`` if it is not a dict."""
    section = config.get(key)
    if isinstance(section, dict):
        return section
    fail_runtime(f"compiled_config.{key} must be an object.")
    return section
31
+
32
+
33
def require_alpha(value):
    """Validate the significance level and return it as a float.

    Booleans and non-numeric values are reported via ``fail_runtime``, as is
    any number that is not finite or does not lie strictly between 0 and 1.
    """
    is_plain_number = isinstance(value, (int, float)) and not isinstance(value, bool)
    if not is_plain_number:
        fail_runtime("endpoint_test.alpha must be a finite number.")
    alpha = float(value)
    in_open_unit_interval = math.isfinite(alpha) and 0.0 < alpha < 1.0
    if not in_open_unit_interval:
        fail_runtime("endpoint_test.alpha must be finite, greater than 0, and less than 1.")
    return alpha
40
+
41
+
42
def require_minimum_n(value):
    """Validate the per-group minimum sample size and return it unchanged.

    The value must be an ``int`` (booleans are not accepted) and at least 2;
    violations are reported via ``fail_runtime``.
    """
    is_integer = isinstance(value, int) and not isinstance(value, bool)
    if not is_integer:
        fail_runtime("endpoint_test.minimum_n_per_group must be an integer.")
    if value < 2:
        fail_runtime("endpoint_test.minimum_n_per_group must be at least 2.")
    return value
48
+
49
+
50
def load_endpoint_test_config(config):
    """Validate ``config["endpoint_test"]`` and return its normalized parameters.

    The returned dict has the keys ``group_column``, ``value_column``,
    ``treatment_group``, ``control_group``, ``expected_direction``, ``alpha``
    and ``minimum_n_per_group``. Validation failures are reported through
    ``fail_runtime`` (directly or via the ``require_*`` helpers).
    """
    section = require_config_dict(config, "endpoint_test")

    def text_field(name):
        # Every string setting shares the same lookup and label pattern.
        return require_string(section.get(name), f"endpoint_test.{name}")

    group_column = text_field("group_column")
    value_column = text_field("value_column")
    if group_column == value_column:
        fail_runtime("endpoint_test.group_column and endpoint_test.value_column must be different.")

    treatment_group = text_field("treatment_group")
    control_group = text_field("control_group")
    if treatment_group == control_group:
        fail_runtime("endpoint_test.treatment_group and endpoint_test.control_group must be different.")

    expected_direction = text_field("expected_direction")
    if expected_direction not in SUPPORTED_DIRECTIONS:
        allowed = ", ".join(SUPPORTED_DIRECTIONS)
        fail_runtime(f"endpoint_test.expected_direction must be one of {allowed}.")

    alpha = require_alpha(section.get("alpha"))
    minimum_n_per_group = require_minimum_n(section.get("minimum_n_per_group"))
    return {
        "group_column": group_column,
        "value_column": value_column,
        "treatment_group": treatment_group,
        "control_group": control_group,
        "expected_direction": expected_direction,
        "alpha": alpha,
        "minimum_n_per_group": minimum_n_per_group,
    }
91
+
92
+
93
def require_observations_slot(runtime_context, role, params):
    """Check that the submission slot for ``role`` declares the needed CSV columns.

    The first dict entry of ``artifact_contract.submission`` whose ``role``
    equals ``role`` must use a ``csv_columns`` validator whose ``required``
    list contains ``observation_id`` plus the configured group and value
    columns. Any violation, including a missing slot, is reported via
    ``fail_runtime``. Returns ``None``.
    """
    contract = runtime_context.get("artifact_contract")
    if not isinstance(contract, dict):
        fail_runtime("Runtime context is missing artifact_contract.")
    submission_slots = contract.get("submission")
    if not isinstance(submission_slots, list):
        fail_runtime("Runtime context is missing artifact_contract.submission.")

    # Only the first slot carrying the requested role is inspected.
    matching_slot = next(
        (
            candidate
            for candidate in submission_slots
            if isinstance(candidate, dict) and candidate.get("role") == role
        ),
        None,
    )
    if matching_slot is None:
        fail_runtime(f"Runtime context is missing submission slot for role {role}.")
        return

    validator = matching_slot.get("validator")
    if not isinstance(validator, dict) or validator.get("kind") != "csv_columns":
        fail_runtime(
            f"submission role {role} must use validator.kind=csv_columns for two_group_endpoint_test@1."
        )
    required = validator.get("required")
    if not isinstance(required, list):
        fail_runtime(
            f"submission role {role} validator.required must be an array for two_group_endpoint_test@1."
        )

    # Non-string and blank entries in validator.required are ignored.
    declared = set()
    for entry in required:
        if isinstance(entry, str) and str(entry).strip():
            declared.add(str(entry).strip())
    needed = {"observation_id", params["group_column"], params["value_column"]}
    missing = sorted(name for name in needed if name not in declared)
    if missing:
        fail_runtime(
            f"submission role {role} validator.required must include {', '.join(missing)} for two_group_endpoint_test@1."
        )
130
+
131
+
132
def read_csv_rows(path, label):
    """Read a headered CSV file into rows keyed by whitespace-stripped column names.

    Args:
        path: Object exposing ``open(mode, encoding=..., newline=...)`` (for
            example a ``pathlib.Path``) that points at the CSV file.
        label: Human-readable artifact name used in rejection messages.

    Returns:
        A ``(columns, rows)`` tuple. ``columns`` lists the stripped header
        names in file order; ``rows`` holds one dict per data record mapping
        each stripped column name to the raw cell text. Cells absent from a
        short row come through as ``None`` (``csv.DictReader``'s default
        ``restval``).

    Problems with the submission are reported through ``reject_submission``:
    a missing or unreadable file, text that is not valid UTF-8, malformed CSV,
    a missing header, blank or duplicate column names, or a record with more
    cells than the header.
    """
    try:
        # "utf-8-sig" drops a leading byte-order mark (commonly written by
        # spreadsheet exports) so it cannot become part of the first header
        # name; files without a BOM decode exactly as with plain "utf-8".
        with path.open("r", encoding="utf-8-sig", newline="") as handle:
            reader = csv.DictReader(handle)
            fieldnames = reader.fieldnames
            if not fieldnames:
                reject_submission(f"{label} must include a CSV header row.")
            normalized = []
            seen = set()
            for fieldname in fieldnames:
                if not isinstance(fieldname, str) or not fieldname.strip():
                    reject_submission(f"{label} contains an empty CSV column name.")
                column = fieldname.strip()
                if column in seen:
                    reject_submission(f"{label} contains duplicate CSV column {column!r}.")
                seen.add(column)
                # Keep the raw header too: DictReader keys rows by the raw name.
                normalized.append((fieldname, column))
            rows = []
            # Data records are numbered from 2 because the header is row 1.
            for row_index, row in enumerate(reader, start=2):
                # DictReader files surplus cells under the key None (restkey).
                if None in row:
                    reject_submission(f"{label} row {row_index} has too many columns.")
                rows.append(
                    {
                        normalized_name: row.get(raw_name, "")
                        for raw_name, normalized_name in normalized
                    }
                )
    except FileNotFoundError:
        reject_submission(f"Missing {label} at {path}.")
    except UnicodeDecodeError as error:
        # Decoding happens lazily while the reader iterates, so undecodable
        # bytes surface here rather than at open().
        reject_submission(f"{label} must be UTF-8 encoded text: {error}.")
    except csv.Error as error:
        reject_submission(f"{label} is not valid CSV: {error}.")
    except OSError as error:
        reject_submission(f"Unable to read {label}: {error}.")
    return [column for _, column in normalized], rows
164
+
165
+
166
def require_columns(fieldnames, required_columns, label):
    """Report via ``reject_submission`` any required column absent from ``fieldnames``.

    Missing names are listed in the order they appear in ``required_columns``.
    """
    absent = []
    for column in required_columns:
        if column not in fieldnames:
            absent.append(column)
    if not absent:
        return
    reject_submission(f"{label} is missing required columns: {', '.join(absent)}.")
170
+
171
+
172
def parse_finite_number(raw_value, label):
    """Convert a raw cell value to a finite float.

    ``None`` and blank text, text that ``float`` cannot parse, and non-finite
    results are each reported through ``reject_submission``. ``label``
    identifies the cell in those messages.
    """
    if raw_value is None:
        text = ""
    else:
        text = str(raw_value).strip()
    if not text:
        reject_submission(f"{label} must be present.")
    try:
        number = float(text)
    except ValueError:
        reject_submission(f"{label} must be numeric, received {text!r}.")
    if not math.isfinite(number):
        reject_submission(f"{label} must be finite.")
    return number
183
+
184
+
185
def load_observations(path, role, params):
    """Load the observations CSV and split its values into the two groups.

    Every data row must carry a unique, non-empty ``observation_id``, a
    non-empty group label and a finite numeric value; violations are reported
    via ``reject_submission``. Rows whose group label matches neither the
    configured treatment group nor the control group are validated but their
    values are not collected.

    Returns:
        ``(treatment_values, control_values)`` as lists of floats in file order.
    """
    artifact_label = f"submission artifact {role}"
    fieldnames, rows = read_csv_rows(path, artifact_label)
    group_column = params["group_column"]
    value_column = params["value_column"]
    require_columns(
        fieldnames,
        ["observation_id", group_column, value_column],
        artifact_label,
    )

    known_ids = set()
    treatment_values = []
    control_values = []
    # Row numbers start at 2 to account for the header row.
    for row_number, row in enumerate(rows, start=2):
        row_label = f"{artifact_label} row {row_number}"
        observation_id = require_string(
            row.get("observation_id"),
            f"{row_label} observation_id",
            reject_submission,
        )
        if observation_id in known_ids:
            reject_submission(
                f"submission artifact {role} contains duplicate observation_id {observation_id!r}."
            )
        known_ids.add(observation_id)

        group = require_string(
            row.get(group_column),
            f"{row_label} {group_column}",
            reject_submission,
        )
        value = parse_finite_number(
            row.get(value_column),
            f"{row_label} {value_column}",
        )
        if group == params["treatment_group"]:
            treatment_values.append(value)
        elif group == params["control_group"]:
            control_values.append(value)
    return treatment_values, control_values
220
+
221
+
222
def require_group_size(values, minimum_n, group_label, group_kind):
    """Report via ``reject_submission`` when a group has fewer than ``minimum_n`` values.

    The observed count is also passed in ``details`` under ``n_<group_kind>``.
    """
    observed = len(values)
    if observed >= minimum_n:
        return
    reject_submission(
        f"two_group_endpoint_test@1 requires at least {minimum_n} {group_kind} observations for group {group_label!r}; received {observed}.",
        details={f"n_{group_kind}": observed},
    )
228
+
229
+
230
def sample_mean(values):
    """Return the arithmetic mean of ``values``.

    Raises ``ZeroDivisionError`` when ``values`` is empty.
    """
    total = sum(values)
    count = len(values)
    return total / count
232
+
233
+
234
def sample_variance(values, mean):
    """Return the unbiased (n - 1 denominator) sample variance of ``values``.

    Args:
        values: Sequence of numbers; must contain at least two entries.
        mean: The mean of ``values``, supplied by the caller.

    Returns:
        The sum of squared deviations from ``mean`` divided by
        ``len(values) - 1``. The result is ``inf`` when a squared deviation
        exceeds the float range.

    Raises:
        ZeroDivisionError: if ``values`` has exactly one element.
    """
    # Square by multiplication: ``float ** 2`` raises OverflowError for large
    # but finite deviations, whereas ``*`` yields inf, which callers can
    # detect with math.isfinite instead of crashing.
    squared_deviations = (
        (value - mean) * (value - mean) for value in values
    )
    return sum(squared_deviations) / (len(values) - 1)
236
+
237
+
238
def beta_continued_fraction(a, b, x):
    """Evaluate the continued fraction consumed by regularized_incomplete_beta.

    The loop keeps a running product ``h`` built from the ratio terms ``c``
    and ``d`` (this looks like the modified Lentz scheme familiar from
    textbook incomplete-beta routines; confirm against the intended
    reference). Each pass applies two coefficient updates and then checks
    convergence.

    Returns ``h`` as soon as the latest multiplicative update ``delta`` is
    within BETA_CONTINUED_FRACTION_EPSILON of 1. If that never happens within
    BETA_CONTINUED_FRACTION_MAX_ITERATIONS passes, ``fail_runtime`` is called.
    """
    qab = a + b
    qap = a + 1.0
    qam = a - 1.0
    c = 1.0
    d = 1.0 - (qab * x / qap)
    # Floor near-zero terms so the reciprocal below stays finite.
    if abs(d) < BETA_CONTINUED_FRACTION_MIN_FLOAT:
        d = BETA_CONTINUED_FRACTION_MIN_FLOAT
    d = 1.0 / d
    h = d
    for iteration in range(1, BETA_CONTINUED_FRACTION_MAX_ITERATIONS + 1):
        m2 = 2 * iteration
        # First coefficient update of this pass.
        aa = iteration * (b - iteration) * x / ((qam + m2) * (a + m2))
        d = 1.0 + aa * d
        if abs(d) < BETA_CONTINUED_FRACTION_MIN_FLOAT:
            d = BETA_CONTINUED_FRACTION_MIN_FLOAT
        c = 1.0 + aa / c
        if abs(c) < BETA_CONTINUED_FRACTION_MIN_FLOAT:
            c = BETA_CONTINUED_FRACTION_MIN_FLOAT
        d = 1.0 / d
        h *= d * c

        # Second coefficient update of this pass (note the leading minus sign).
        aa = -((a + iteration) * (qab + iteration) * x) / (
            (a + m2) * (qap + m2)
        )
        d = 1.0 + aa * d
        if abs(d) < BETA_CONTINUED_FRACTION_MIN_FLOAT:
            d = BETA_CONTINUED_FRACTION_MIN_FLOAT
        c = 1.0 + aa / c
        if abs(c) < BETA_CONTINUED_FRACTION_MIN_FLOAT:
            c = BETA_CONTINUED_FRACTION_MIN_FLOAT
        d = 1.0 / d
        delta = d * c
        h *= delta
        # Convergence is tested only after the second update of the pass.
        if abs(delta - 1.0) < BETA_CONTINUED_FRACTION_EPSILON:
            return h
    fail_runtime("regularized incomplete beta calculation did not converge.")
275
+
276
+
277
def regularized_incomplete_beta(x, a, b):
    """Return the regularized incomplete beta function evaluated at ``x``.

    ``a`` and ``b`` must be positive and ``x`` must lie in [0, 1]; violations
    are reported via ``fail_runtime``. The endpoints ``x == 0`` and ``x == 1``
    return 0.0 and 1.0 directly. Otherwise the value is assembled from a
    log-space prefactor and ``beta_continued_fraction``, using the
    complementary form when ``x`` is at or above ``(a + 1) / (a + b + 2)``.
    """
    if a <= 0.0 or b <= 0.0:
        fail_runtime("regularized incomplete beta parameters must be positive.")
    if x < 0.0 or x > 1.0:
        fail_runtime("regularized incomplete beta x must be in [0, 1].")
    if x == 0.0:
        return 0.0
    if x == 1.0:
        return 1.0

    # Work in log space so the gamma-function ratio cannot overflow.
    log_normalizer = math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
    log_prefactor = log_normalizer + (a * math.log(x)) + (b * math.log1p(-x))
    prefactor = math.exp(log_prefactor)

    switch_point = (a + 1.0) / (a + b + 2.0)
    if x < switch_point:
        return prefactor * beta_continued_fraction(a, b, x) / a
    mirrored = prefactor * beta_continued_fraction(b, a, 1.0 - x) / b
    return 1.0 - mirrored
298
+
299
+
300
def clamp_probability(value):
    """Return ``value`` as a probability, absorbing tiny rounding overshoot.

    Values less than 1e-15 outside [0, 1] are snapped to the nearest bound.
    Non-finite values and larger excursions are reported via ``fail_runtime``.
    """
    if not math.isfinite(value):
        fail_runtime("Student t p-value calculation produced a non-finite result.")
    if -1e-15 < value < 0.0:
        return 0.0
    if 1.0 < value < 1.0 + 1e-15:
        return 1.0
    out_of_range = value < 0.0 or value > 1.0
    if out_of_range:
        fail_runtime("Student t p-value calculation produced a value outside [0, 1].")
    return value
310
+
311
+
312
def two_sided_student_t_p_value(test_statistic, degrees_of_freedom):
    """Return the two-sided p-value for a Student t statistic.

    ``degrees_of_freedom`` must be positive and finite, otherwise
    ``fail_runtime`` is called. A statistic of exactly zero yields 1.0.
    Otherwise the p-value is the regularized incomplete beta function at
    ``df / (df + t**2)`` with parameters ``df / 2`` and ``0.5``, passed
    through ``clamp_probability``.
    """
    df = degrees_of_freedom
    if df <= 0.0 or not math.isfinite(df):
        fail_runtime("Welch-Satterthwaite degrees of freedom must be positive and finite.")
    magnitude = abs(test_statistic)
    if magnitude == 0.0:
        return 1.0
    x = df / (df + (magnitude * magnitude))
    tail_probability = regularized_incomplete_beta(x, df / 2.0, 0.5)
    return clamp_probability(tail_probability)
322
+
323
+
324
def compute_welch_test(treatment_values, control_values):
    """Run Welch's unequal-variance t-test on the two groups.

    Returns a dict with ``treatment_mean``, ``control_mean``, ``effect``
    (treatment mean minus control mean), ``test_statistic``,
    ``degrees_of_freedom`` (Welch-Satterthwaite) and the two-sided
    ``p_value``. A non-finite or non-positive squared standard error, or an
    unusable degrees-of-freedom denominator, is reported via
    ``reject_submission``.
    """
    n_treatment = len(treatment_values)
    n_control = len(control_values)
    treatment_mean = sample_mean(treatment_values)
    control_mean = sample_mean(control_values)

    # Each term is that group's variance divided by its sample size.
    treatment_term = sample_variance(treatment_values, treatment_mean) / n_treatment
    control_term = sample_variance(control_values, control_mean) / n_control
    se_squared = treatment_term + control_term
    if not (math.isfinite(se_squared) and se_squared > 0.0):
        reject_submission(
            "Welch t-test requires positive within-group variance; observed zero pooled standard error."
        )

    effect = treatment_mean - control_mean
    test_statistic = effect / math.sqrt(se_squared)

    # Welch-Satterthwaite denominator; groups with zero variance contribute nothing.
    df_denominator = 0.0
    for term, size in ((treatment_term, n_treatment), (control_term, n_control)):
        if term > 0.0:
            df_denominator += (term * term) / (size - 1)
    if df_denominator <= 0.0 or not math.isfinite(df_denominator):
        reject_submission(
            "Welch t-test requires positive variance in at least one group."
        )

    degrees_of_freedom = (se_squared * se_squared) / df_denominator
    p_value = two_sided_student_t_p_value(test_statistic, degrees_of_freedom)
    return {
        "treatment_mean": treatment_mean,
        "control_mean": control_mean,
        "effect": effect,
        "test_statistic": test_statistic,
        "degrees_of_freedom": degrees_of_freedom,
        "p_value": p_value,
    }
357
+
358
+
359
def direction_matches(effect, expected_direction):
    """Return True when the sign of ``effect`` agrees with ``expected_direction``.

    ``"treatment_lt_control"`` requires a strictly negative effect; any other
    value requires a strictly positive one. A zero effect never matches.
    """
    expects_decrease = expected_direction == "treatment_lt_control"
    return effect < 0.0 if expects_decrease else effect > 0.0
363
+
364
+
365
def main():
    """Score one submission with the two-group endpoint test.

    Loads the runtime context and compiled config, validates the endpoint-test
    settings and the submission slot, reads the observations, runs the Welch
    test, and writes a score of 1.0 when the p-value is below ``alpha`` and
    the effect has the expected sign (0.0 otherwise), together with the
    supporting statistics.
    """
    runtime_context = load_runtime_context()
    config_path = resolve_scoring_asset(runtime_context, "compiled_config", kind="config")
    try:
        config = load_json_file(config_path, label="compiled_config")
    except RuntimeError as error:
        fail_runtime(str(error))
    if not isinstance(config, dict):
        fail_runtime("compiled_config must be a JSON object.")

    submission_role = require_string(
        config.get("submission_role"), "compiled_config.submission_role"
    )
    final_score_key = require_string(
        config.get("final_score_key"), "compiled_config.final_score_key"
    )
    objective = require_string(
        runtime_context.get("objective"), "runtime_context.objective"
    )
    if objective != "maximize":
        fail_runtime("two_group_endpoint_test@1 requires objective=maximize.")

    params = load_endpoint_test_config(config)
    require_observations_slot(runtime_context, submission_role, params)
    observations_path = resolve_submission_artifact(runtime_context, submission_role)
    treatment_values, control_values = load_observations(
        observations_path, submission_role, params
    )

    # Both groups must meet the configured minimum before any statistics run.
    group_checks = (
        (treatment_values, "treatment_group", "treatment"),
        (control_values, "control_group", "control"),
    )
    for values, group_key, group_kind in group_checks:
        require_group_size(
            values, params["minimum_n_per_group"], params[group_key], group_kind
        )

    result = compute_welch_test(treatment_values, control_values)
    matched_direction = direction_matches(result["effect"], params["expected_direction"])
    significant = result["p_value"] < params["alpha"]
    score = 1.0 if significant and matched_direction else 0.0

    # Insertion order is kept: score keys, test statistics, counts, then config echo.
    details = {final_score_key: score, "score": score}
    for statistic in (
        "p_value",
        "effect",
        "test_statistic",
        "degrees_of_freedom",
        "treatment_mean",
        "control_mean",
    ):
        details[statistic] = result[statistic]
    details["n_treatment"] = len(treatment_values)
    details["n_control"] = len(control_values)
    details["direction_matched"] = matched_direction
    for setting in (
        "expected_direction",
        "alpha",
        "minimum_n_per_group",
        "treatment_group",
        "control_group",
        "group_column",
        "value_column",
    ):
        details[setting] = params[setting]
    write_score(score=score, details=details)
439
+
440
+
441
# Run the scorer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()