@moleculeagora/cli 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +31 -0
- package/dist/index.js +30368 -0
- package/dist/index.js.map +1 -0
- package/dist/python-v1/agora_runtime.py +282 -0
- package/dist/python-v1/answer-set-metric.py +264 -0
- package/dist/python-v1/assertion-set-evaluation.py +879 -0
- package/dist/python-v1/exact-match.py +60 -0
- package/dist/python-v1/l4-composition.py +435 -0
- package/dist/python-v1/multi-output-tabular-metric.py +392 -0
- package/dist/python-v1/panel-ranking-metric.py +622 -0
- package/dist/python-v1/project-test.py +256 -0
- package/dist/python-v1/protein-binder-assay-metric.py +600 -0
- package/dist/python-v1/public-tool-metric.py +161 -0
- package/dist/python-v1/ranking-metric.py +426 -0
- package/dist/python-v1/reference-artifact-assertion.py +532 -0
- package/dist/python-v1/rubric-validation.py +246 -0
- package/dist/python-v1/solver-python-stdio-test.py +160 -0
- package/dist/python-v1/statistical-endpoint-test-v2.py +629 -0
- package/dist/python-v1/statistical-endpoint-test.py +442 -0
- package/dist/python-v1/table-metric.py +1291 -0
- package/dist/release-metadata.json +7 -0
- package/package.json +67 -0
|
@@ -0,0 +1,629 @@
|
|
|
1
|
+
import csv
|
|
2
|
+
import math
|
|
3
|
+
import re
|
|
4
|
+
|
|
5
|
+
from agora_runtime import (
|
|
6
|
+
fail_runtime,
|
|
7
|
+
load_json_file,
|
|
8
|
+
load_runtime_context,
|
|
9
|
+
reject_submission,
|
|
10
|
+
resolve_scoring_asset,
|
|
11
|
+
resolve_submission_artifact,
|
|
12
|
+
write_score,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
# Allowed values for a planned contrast's expected_direction: the sign that
# (treatment mean - control mean) must take for the contrast to pass.
SUPPORTED_DIRECTIONS = (
    "treatment_lt_control",
    "treatment_gt_control",
)
# Multiple-comparison corrections accepted for endpoint_test.p_adjustment.
SUPPORTED_P_ADJUSTMENTS = (
    "bonferroni",
    "holm",
)
# Contrast ids: a lowercase letter followed by lowercase letters, digits, or underscores.
CONTRAST_ID_PATTERN = re.compile(r"^[a-z][a-z0-9_]*$")

# Settings for beta_continued_fraction: iteration cap, convergence tolerance on
# the per-step multiplicative update, and the floor used to keep intermediate
# denominators away from zero.
BETA_CONTINUED_FRACTION_MAX_ITERATIONS = 200
BETA_CONTINUED_FRACTION_EPSILON = 3.0e-14
BETA_CONTINUED_FRACTION_MIN_FLOAT = 1.0e-300
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def require_string(value, label, invalid_handler=fail_runtime):
    """Return ``value`` with surrounding whitespace removed.

    When ``value`` is not a ``str`` or is blank after stripping,
    ``invalid_handler`` is called with a message naming ``label``. The default
    handler is ``fail_runtime``; submission-row checks in this file pass
    ``reject_submission`` instead.
    """
    if isinstance(value, str):
        stripped = value.strip()
        if stripped:
            return stripped
    invalid_handler(f"{label} must be a non-empty string.")
    return value.strip()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def require_config_dict(config, key):
    """Return ``config[key]``, requiring it to be a dict (JSON object).

    A missing key or a non-dict value is reported through ``fail_runtime``.
    """
    section = config.get(key)
    if isinstance(section, dict):
        return section
    fail_runtime(f"compiled_config.{key} must be an object.")
    return section
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def require_alpha(value):
    """Validate ``endpoint_test.alpha`` and return it as a float.

    The value must be an ``int`` or ``float``. ``bool`` is rejected explicitly,
    since it subclasses ``int``. It must also be finite and lie strictly
    between 0 and 1. Violations are reported through ``fail_runtime``.
    """
    is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
    if not is_number:
        fail_runtime("endpoint_test.alpha must be a finite number.")
    alpha = float(value)
    inside_open_unit_interval = math.isfinite(alpha) and 0.0 < alpha < 1.0
    if not inside_open_unit_interval:
        fail_runtime("endpoint_test.alpha must be finite, greater than 0, and less than 1.")
    return alpha
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def require_minimum_n(value):
    """Validate ``endpoint_test.minimum_n_per_group`` as an integer of at least 2.

    ``bool`` is rejected even though it subclasses ``int``. The lower bound of
    2 matches ``sample_variance``, which divides by ``n - 1``.
    """
    if not isinstance(value, int) or isinstance(value, bool):
        fail_runtime("endpoint_test.minimum_n_per_group must be an integer.")
    if value >= 2:
        return value
    fail_runtime("endpoint_test.minimum_n_per_group must be at least 2.")
    return value
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def require_string_list(value, label, min_length, max_length):
    """Validate a list of distinct non-empty strings and return them stripped.

    The list length must lie within ``[min_length, max_length]``. Each entry
    is checked with ``require_string``, and an entry equal (after stripping)
    to an earlier one is reported through ``fail_runtime``. Input order is
    preserved.
    """
    if not isinstance(value, list):
        fail_runtime(f"{label} must be an array.")
    if not min_length <= len(value) <= max_length:
        fail_runtime(f"{label} must contain between {min_length} and {max_length} entries.")
    ordered = []
    already_seen = set()
    for position, entry in enumerate(value):
        text = require_string(entry, f"{label}[{position}]")
        if text in already_seen:
            fail_runtime(f"{label} repeats value {text!r}.")
        already_seen.add(text)
        ordered.append(text)
    return ordered
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def require_p_adjustment(value):
    """Return the configured multiple-comparison correction name.

    The stripped value must be one of ``SUPPORTED_P_ADJUSTMENTS``. Anything
    else is reported through ``fail_runtime``.
    """
    method = require_string(value, "endpoint_test.p_adjustment")
    if method in SUPPORTED_P_ADJUSTMENTS:
        return method
    allowed = ", ".join(SUPPORTED_P_ADJUSTMENTS)
    fail_runtime(f"endpoint_test.p_adjustment must be one of {allowed}.")
    return method
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def load_contrasts(value, included_groups):
    """Validate ``endpoint_test.contrasts`` and return normalized contrast dicts.

    The list must hold 1-20 objects. Each object needs:

    * a unique ``id`` matching ``CONTRAST_ID_PATTERN``;
    * distinct ``treatment_group`` / ``control_group`` names, both present in
      ``included_groups``;
    * an ``expected_direction`` from ``SUPPORTED_DIRECTIONS``.

    Checks run in that order for each entry, and the first violation is
    reported through ``fail_runtime``. Results keep declaration order.
    """
    if not isinstance(value, list):
        fail_runtime("endpoint_test.contrasts must be an array.")
    if not 1 <= len(value) <= 20:
        fail_runtime("endpoint_test.contrasts must contain between 1 and 20 entries.")
    known_groups = set(included_groups)
    used_ids = set()
    validated = []
    for index, entry in enumerate(value):
        prefix = f"endpoint_test.contrasts[{index}]"
        if not isinstance(entry, dict):
            fail_runtime(f"{prefix} must be an object.")

        contrast_id = require_string(entry.get("id"), f"{prefix}.id")
        if CONTRAST_ID_PATTERN.match(contrast_id) is None:
            fail_runtime(
                f"{prefix}.id must start with a lowercase letter and use only lowercase letters, digits, and underscores."
            )
        if contrast_id in used_ids:
            fail_runtime(f"endpoint_test.contrasts repeats id {contrast_id}.")
        used_ids.add(contrast_id)

        treatment_group = require_string(
            entry.get("treatment_group"), f"{prefix}.treatment_group"
        )
        control_group = require_string(
            entry.get("control_group"), f"{prefix}.control_group"
        )
        if treatment_group == control_group:
            fail_runtime(
                f"{prefix} treatment_group and control_group must be different."
            )
        # Treatment is checked before control, matching the field order above.
        for field_name, group in (
            ("treatment_group", treatment_group),
            ("control_group", control_group),
        ):
            if group not in known_groups:
                fail_runtime(f"{prefix}.{field_name} must reference included_groups.")

        expected_direction = require_string(
            entry.get("expected_direction"), f"{prefix}.expected_direction"
        )
        if expected_direction not in SUPPORTED_DIRECTIONS:
            fail_runtime(
                f"{prefix}.expected_direction must be one of {', '.join(SUPPORTED_DIRECTIONS)}."
            )

        validated.append(
            {
                "id": contrast_id,
                "treatment_group": treatment_group,
                "control_group": control_group,
                "expected_direction": expected_direction,
            }
        )
    return validated
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def load_endpoint_test_config(config):
    """Parse ``compiled_config.endpoint_test`` into the scorer's parameter dict.

    Validates, in this order:

    * the group and value column names, which must differ;
    * the 3-20 included groups;
    * alpha;
    * the minimum per-group sample size;
    * the p-value adjustment method;
    * the planned contrasts.

    The first failure is reported through ``fail_runtime``.
    """
    section = require_config_dict(config, "endpoint_test")
    group_column = require_string(section.get("group_column"), "endpoint_test.group_column")
    value_column = require_string(section.get("value_column"), "endpoint_test.value_column")
    if value_column == group_column:
        fail_runtime("endpoint_test.group_column and endpoint_test.value_column must be different.")
    included_groups = require_string_list(
        section.get("included_groups"), "endpoint_test.included_groups", 3, 20
    )
    alpha = require_alpha(section.get("alpha"))
    minimum_n_per_group = require_minimum_n(section.get("minimum_n_per_group"))
    p_adjustment = require_p_adjustment(section.get("p_adjustment"))
    contrasts = load_contrasts(section.get("contrasts"), included_groups)
    return {
        "group_column": group_column,
        "value_column": value_column,
        "included_groups": included_groups,
        "alpha": alpha,
        "minimum_n_per_group": minimum_n_per_group,
        "p_adjustment": p_adjustment,
        "contrasts": contrasts,
    }
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def require_observations_slot(runtime_context, role, params):
    """Check the submission slot for ``role`` declares the columns this scorer reads.

    Only the first dict in ``artifact_contract.submission`` whose ``role``
    equals ``role`` is inspected. Its validator must be of kind
    ``csv_columns``, and its ``required`` list must include
    ``observation_id`` plus the configured group and value columns. Entries
    are compared after stripping, and non-string or blank entries are ignored.
    Returns ``None`` on success. Every problem, including a missing slot, is
    reported through ``fail_runtime``.
    """
    contract = runtime_context.get("artifact_contract")
    if not isinstance(contract, dict):
        fail_runtime("Runtime context is missing artifact_contract.")
    submission_slots = contract.get("submission")
    if not isinstance(submission_slots, list):
        fail_runtime("Runtime context is missing artifact_contract.submission.")

    slot = next(
        (
            candidate
            for candidate in submission_slots
            if isinstance(candidate, dict) and candidate.get("role") == role
        ),
        None,
    )
    if slot is None:
        fail_runtime(f"Runtime context is missing submission slot for role {role}.")
        return

    validator = slot.get("validator")
    if not isinstance(validator, dict) or validator.get("kind") != "csv_columns":
        fail_runtime(
            f"submission role {role} must use validator.kind=csv_columns for multi_group_endpoint_test@1."
        )
    declared = validator.get("required")
    if not isinstance(declared, list):
        fail_runtime(
            f"submission role {role} validator.required must be an array for multi_group_endpoint_test@1."
        )
    declared_columns = set()
    for column in declared:
        if isinstance(column, str) and column.strip():
            declared_columns.add(column.strip())
    needed_columns = {"observation_id", params["group_column"], params["value_column"]}
    missing = sorted(needed_columns.difference(declared_columns))
    if missing:
        fail_runtime(
            f"submission role {role} validator.required must include {', '.join(missing)} for multi_group_endpoint_test@1."
        )
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def read_csv_rows(path, label):
    """Read a UTF-8 CSV file into ``(columns, rows)``.

    Header names are stripped of surrounding whitespace. A blank header name,
    or two names that collide after stripping, is rejected. Each row is a
    dict keyed by the stripped header name and holding the raw cell text.
    ``csv.DictReader`` fills cells missing from a short row with ``None``.
    Rows with more cells than the header are rejected.

    ``path`` must provide ``open()`` (e.g. ``pathlib.Path``). The following
    are all reported through ``reject_submission``, so a bad submission never
    escapes as an uncaught traceback:

    * a missing file;
    * OS-level read failures;
    * bytes that are not valid UTF-8;
    * malformed CSV.

    Returns:
        A list of stripped column names in header order, and the list of row
        dicts in file order.
    """
    try:
        with path.open("r", encoding="utf-8", newline="") as handle:
            reader = csv.DictReader(handle)
            fieldnames = reader.fieldnames
            if not fieldnames:
                reject_submission(f"{label} must include a CSV header row.")
            normalized = []
            seen = set()
            for fieldname in fieldnames:
                if not isinstance(fieldname, str) or not fieldname.strip():
                    reject_submission(f"{label} contains an empty CSV column name.")
                column = fieldname.strip()
                if column in seen:
                    reject_submission(f"{label} contains duplicate CSV column {column!r}.")
                seen.add(column)
                # Keep the raw name to index DictReader rows, the stripped one for output.
                normalized.append((fieldname, column))
            rows = []
            for row_index, row in enumerate(reader, start=2):
                # DictReader collects surplus cells under the ``None`` restkey.
                if None in row:
                    reject_submission(f"{label} row {row_index} has too many columns.")
                rows.append(
                    {
                        normalized_name: row.get(raw_name, "")
                        for raw_name, normalized_name in normalized
                    }
                )
    except FileNotFoundError:
        reject_submission(f"Missing {label} at {path}.")
    except OSError as error:
        reject_submission(f"Unable to read {label}: {error}.")
    except UnicodeDecodeError as error:
        # Decoding is lazy, so this can surface at the header or any later row.
        reject_submission(f"{label} must be UTF-8 encoded text: {error}.")
    except csv.Error as error:
        # e.g. a field exceeding csv.field_size_limit().
        reject_submission(f"{label} is not a valid CSV file: {error}.")
    return [column for _, column in normalized], rows
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def require_columns(fieldnames, required_columns, label):
    """Reject the submission when any of ``required_columns`` is absent from ``fieldnames``.

    Missing names are listed in ``required_columns`` order. Returns ``None``.
    """
    absent = list(filter(lambda column: column not in fieldnames, required_columns))
    if not absent:
        return
    reject_submission(f"{label} is missing required columns: {', '.join(absent)}.")
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def parse_finite_number(raw_value, label):
    """Parse a cell value as a finite float.

    The following are reported through ``reject_submission``:

    * ``None`` or whitespace-only text (treated as missing);
    * text that ``float()`` cannot parse;
    * parsed values that are NaN or infinite.
    """
    text = "" if raw_value is None else str(raw_value).strip()
    if text == "":
        reject_submission(f"{label} must be present.")
    try:
        number = float(text)
    except ValueError:
        reject_submission(f"{label} must be numeric, received {text!r}.")
    if math.isfinite(number):
        return number
    reject_submission(f"{label} must be finite.")
    return number
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def load_observations(path, role, params):
    """Read the submission CSV and bucket the numeric values by group.

    Requires the ``observation_id``, group, and value columns. Every row needs:

    * a unique non-empty ``observation_id``;
    * a non-empty group label;
    * a finite numeric value.

    Rows whose group is not in ``params["included_groups"]`` are validated but
    not collected.

    Returns:
        A dict mapping each included group, in configured order, to its values
        in file order. A group with no rows maps to an empty list.
    """
    label = f"submission artifact {role}"
    fieldnames, rows = read_csv_rows(path, label)
    group_column = params["group_column"]
    value_column = params["value_column"]
    require_columns(fieldnames, ["observation_id", group_column, value_column], label)

    collected = {group: [] for group in params["included_groups"]}
    observed_ids = set()
    for line_number, row in enumerate(rows, start=2):
        where = f"{label} row {line_number}"
        observation_id = require_string(
            row.get("observation_id"), f"{where} observation_id", reject_submission
        )
        if observation_id in observed_ids:
            reject_submission(
                f"{label} contains duplicate observation_id {observation_id!r}."
            )
        observed_ids.add(observation_id)
        group = require_string(
            row.get(group_column), f"{where} {group_column}", reject_submission
        )
        value = parse_finite_number(row.get(value_column), f"{where} {value_column}")
        # Only included groups have buckets; other groups are skipped.
        bucket = collected.get(group)
        if bucket is not None:
            bucket.append(value)
    return collected
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def require_group_sizes(values_by_group, minimum_n):
    """Reject the submission when any group has fewer than ``minimum_n`` values.

    Groups are checked in dict order. The first undersized group is reported,
    with its name and observed count attached as ``details``.
    """
    for group, observations in values_by_group.items():
        count = len(observations)
        if count >= minimum_n:
            continue
        reject_submission(
            f"multi_group_endpoint_test@1 requires at least {minimum_n} observations for group {group!r}; received {count}.",
            details={"group": group, "n": count},
        )
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def sample_mean(values):
    """Return the arithmetic mean of a non-empty sequence of numbers."""
    total = sum(values)
    count = len(values)
    return total / count
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def sample_variance(values, mean):
    """Return the sample variance of ``values`` about ``mean``, using an ``n - 1`` denominator.

    ``values`` needs at least two items; a single item divides by zero.
    """
    squared_deviation_total = sum((item - mean) ** 2 for item in values)
    degrees_of_freedom = len(values) - 1
    return squared_deviation_total / degrees_of_freedom
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def beta_continued_fraction(a, b, x):
    """Evaluate the continued fraction used by ``regularized_incomplete_beta``.

    The recurrence appears to match the Numerical Recipes ``betacf`` routine
    (modified Lentz method). The ``c``/``d`` updates flip ``d`` to ``1/d`` and
    clamp tiny magnitudes to ``BETA_CONTINUED_FRACTION_MIN_FLOAT``. Each loop
    iteration applies one even-numbered and one odd-numbered coefficient.

    Returns the running product ``h`` as soon as the latest odd-step factor is
    within ``BETA_CONTINUED_FRACTION_EPSILON`` of 1. If that does not happen
    within ``BETA_CONTINUED_FRACTION_MAX_ITERATIONS`` iterations,
    ``fail_runtime`` is called.
    """
    qab = a + b
    qap = a + 1.0
    qam = a - 1.0
    c = 1.0
    # First denominator term; clamped so the reciprocal below stays finite.
    d = 1.0 - (qab * x / qap)
    if abs(d) < BETA_CONTINUED_FRACTION_MIN_FLOAT:
        d = BETA_CONTINUED_FRACTION_MIN_FLOAT
    d = 1.0 / d
    h = d
    for iteration in range(1, BETA_CONTINUED_FRACTION_MAX_ITERATIONS + 1):
        m2 = 2 * iteration
        # Even-numbered coefficient.
        aa = iteration * (b - iteration) * x / ((qam + m2) * (a + m2))
        d = 1.0 + aa * d
        if abs(d) < BETA_CONTINUED_FRACTION_MIN_FLOAT:
            d = BETA_CONTINUED_FRACTION_MIN_FLOAT
        c = 1.0 + aa / c
        if abs(c) < BETA_CONTINUED_FRACTION_MIN_FLOAT:
            c = BETA_CONTINUED_FRACTION_MIN_FLOAT
        d = 1.0 / d
        h *= d * c

        # Odd-numbered coefficient.
        aa = -((a + iteration) * (qab + iteration) * x) / (
            (a + m2) * (qap + m2)
        )
        d = 1.0 + aa * d
        if abs(d) < BETA_CONTINUED_FRACTION_MIN_FLOAT:
            d = BETA_CONTINUED_FRACTION_MIN_FLOAT
        c = 1.0 + aa / c
        if abs(c) < BETA_CONTINUED_FRACTION_MIN_FLOAT:
            c = BETA_CONTINUED_FRACTION_MIN_FLOAT
        d = 1.0 / d
        delta = d * c
        h *= delta
        # Converged once the latest multiplicative update is effectively 1.
        if abs(delta - 1.0) < BETA_CONTINUED_FRACTION_EPSILON:
            return h
    fail_runtime("regularized incomplete beta calculation did not converge.")
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def regularized_incomplete_beta(x, a, b):
    """Return the regularized incomplete beta function I_x(a, b).

    Parameters must satisfy ``a > 0``, ``b > 0`` and ``0 <= x <= 1``;
    violations are reported through ``fail_runtime``. ``x == 0`` returns 0.0
    and ``x == 1`` returns 1.0.

    Otherwise the prefactor x^a (1-x)^b / B(a, b) is formed in log space.
    When ``x < (a + 1) / (a + b + 2)`` the continued fraction is evaluated
    directly. Elsewhere the result comes from the symmetry
    I_x(a, b) = 1 - I_{1-x}(b, a).
    """
    if a <= 0.0 or b <= 0.0:
        fail_runtime("regularized incomplete beta parameters must be positive.")
    if x < 0.0 or x > 1.0:
        fail_runtime("regularized incomplete beta x must be in [0, 1].")
    if x == 0.0:
        return 0.0
    if x == 1.0:
        return 1.0

    log_prefactor = (
        math.lgamma(a + b)
        - math.lgamma(a)
        - math.lgamma(b)
        + (a * math.log(x))
        + (b * math.log1p(-x))
    )
    prefactor = math.exp(log_prefactor)
    use_direct_fraction = x < (a + 1.0) / (a + b + 2.0)
    if not use_direct_fraction:
        reflected = prefactor * beta_continued_fraction(b, a, 1.0 - x) / b
        return 1.0 - reflected
    return prefactor * beta_continued_fraction(a, b, x) / a
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def clamp_probability(value, label):
    """Validate a computed probability, snapping tiny rounding overshoots.

    A value less than 1e-15 below 0 is snapped to 0.0, and a value less than
    1e-15 above 1 is snapped to 1.0. Non-finite values and larger excursions
    outside [0, 1] are reported through ``fail_runtime``.
    """
    if not math.isfinite(value):
        fail_runtime(f"{label} produced a non-finite result.")
    if -1e-15 < value < 0.0:
        return 0.0
    if 1.0 < value < 1.0 + 1e-15:
        return 1.0
    out_of_range = value < 0.0 or value > 1.0
    if out_of_range:
        fail_runtime(f"{label} produced a value outside [0, 1].")
    return value
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
def two_sided_student_t_p_value(test_statistic, degrees_of_freedom):
    """Return the two-sided Student t p-value for ``test_statistic``.

    Computed as I_x(df / 2, 1 / 2) with x = df / (df + t^2). A zero statistic
    returns 1.0 directly. Non-positive or non-finite degrees of freedom are
    reported through ``fail_runtime``.
    """
    df = degrees_of_freedom
    if not math.isfinite(df) or df <= 0.0:
        fail_runtime("Welch-Satterthwaite degrees of freedom must be positive and finite.")
    magnitude = abs(test_statistic)
    if magnitude == 0.0:
        return 1.0
    x = df / (df + (magnitude * magnitude))
    probability = regularized_incomplete_beta(x, df / 2.0, 0.5)
    return clamp_probability(probability, "Student t p-value calculation")
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def f_distribution_survival(f_statistic, df_between, df_within):
    """Return the upper-tail probability of ``f_statistic`` under F(df_between, df_within).

    Computed as I_x(df_within / 2, df_between / 2), where
    x = df_within / (df_within + df_between * F). An F of 0 returns 1.0.

    A negative or non-finite statistic, or a non-positive degrees of freedom,
    is reported through ``fail_runtime``.
    """
    statistic_ok = math.isfinite(f_statistic) and f_statistic >= 0.0
    if not statistic_ok:
        fail_runtime("ANOVA F statistic must be finite and nonnegative.")
    if df_between <= 0.0 or df_within <= 0.0:
        fail_runtime("ANOVA degrees of freedom must be positive.")
    if f_statistic == 0.0:
        return 1.0
    weighted = df_between * f_statistic
    x = df_within / (df_within + weighted)
    survival = regularized_incomplete_beta(x, df_within / 2.0, df_between / 2.0)
    return clamp_probability(survival, "F-distribution survival calculation")
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def compute_group_stats(values_by_group):
    """Summarize each group's values, preserving the input key order.

    Returns a dict mapping each group to ``n``, ``mean``, ``variance`` and
    ``values``:

    * ``variance`` is the ``n - 1`` sample variance;
    * ``values`` is the same list object that was passed in.
    """
    summary = {}
    for group_name, group_values in values_by_group.items():
        group_mean = sample_mean(group_values)
        summary[group_name] = dict(
            n=len(group_values),
            mean=group_mean,
            variance=sample_variance(group_values, group_mean),
            values=group_values,
        )
    return summary
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
def compute_anova(values_by_group, group_stats):
    """Run a one-way ANOVA across every group.

    ``group_stats`` is expected to come from
    ``compute_group_stats(values_by_group)``. The submission is rejected when:

    * the within-group degrees of freedom are not positive;
    * the within-group sum of squares is non-finite or not positive.

    Returns the grand mean, sums of squares, degrees of freedom, mean squares,
    F statistic, and the F upper-tail p-value.
    """
    pooled = [value for group_values in values_by_group.values() for value in group_values]
    total_n = len(pooled)
    group_count = len(values_by_group)
    grand_mean = sample_mean(pooled)

    # Accumulate per group, in group_stats order.
    between_sum = 0.0
    within_sum = 0.0
    for group, summary in group_stats.items():
        group_mean = summary["mean"]
        between_sum += summary["n"] * ((group_mean - grand_mean) ** 2)
        within_sum += sum(
            (value - group_mean) ** 2 for value in values_by_group[group]
        )

    df_between = group_count - 1
    df_within = total_n - group_count
    if df_within <= 0:
        reject_submission("ANOVA requires positive within-group degrees of freedom.")
    if not (math.isfinite(within_sum) and within_sum > 0.0):
        reject_submission(
            "One-way ANOVA requires positive within-group residual variance."
        )

    ms_between = between_sum / df_between
    ms_within = within_sum / df_within
    f_statistic = ms_between / ms_within
    p_value = f_distribution_survival(f_statistic, df_between, df_within)
    if not math.isfinite(p_value):
        reject_submission("One-way ANOVA produced a non-finite p-value.")
    return {
        "grand_mean": grand_mean,
        "ss_between": between_sum,
        "ss_within": within_sum,
        "df_between": df_between,
        "df_within": df_within,
        "ms_between": ms_between,
        "ms_within": ms_within,
        "f_statistic": f_statistic,
        "p_value": p_value,
    }
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
def compute_welch_contrast(contrast, group_stats):
    """Compare one treatment/control pair with Welch's unequal-variance t-test.

    The effect is ``treatment_mean - control_mean``, divided by
    sqrt(var_t / n_t + var_c / n_c). Degrees of freedom use the
    Welch-Satterthwaite formula, with a zero-variance group contributing
    nothing to its denominator. The submission is rejected when both groups
    have zero variance.

    Returns the contrast identifiers, group means and sizes, effect,
    t statistic, degrees of freedom, two-sided raw p-value, and whether the
    effect's sign matches ``expected_direction``.
    """
    treatment = group_stats[contrast["treatment_group"]]
    control = group_stats[contrast["control_group"]]
    treatment_term = treatment["variance"] / treatment["n"]
    control_term = control["variance"] / control["n"]
    standard_error_squared = treatment_term + control_term
    if not (math.isfinite(standard_error_squared) and standard_error_squared > 0.0):
        reject_submission(
            f"planned contrast {contrast['id']} requires positive within-group variance; observed zero pooled standard error."
        )
    effect = treatment["mean"] - control["mean"]
    test_statistic = effect / math.sqrt(standard_error_squared)

    satterthwaite_denominator = 0.0
    for term, group_n in ((treatment_term, treatment["n"]), (control_term, control["n"])):
        if term > 0.0:
            satterthwaite_denominator += (term * term) / (group_n - 1)
    if satterthwaite_denominator <= 0.0 or not math.isfinite(satterthwaite_denominator):
        reject_submission(
            f"planned contrast {contrast['id']} requires positive variance in at least one group."
        )
    degrees_of_freedom = (
        standard_error_squared * standard_error_squared
    ) / satterthwaite_denominator
    raw_p_value = two_sided_student_t_p_value(test_statistic, degrees_of_freedom)
    return {
        "id": contrast["id"],
        "treatment_group": contrast["treatment_group"],
        "control_group": contrast["control_group"],
        "expected_direction": contrast["expected_direction"],
        "treatment_mean": treatment["mean"],
        "control_mean": control["mean"],
        "n_treatment": treatment["n"],
        "n_control": control["n"],
        "effect": effect,
        "test_statistic": test_statistic,
        "degrees_of_freedom": degrees_of_freedom,
        "raw_p_value": raw_p_value,
        "direction_matched": direction_matches(effect, contrast["expected_direction"]),
    }
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
def direction_matches(effect, expected_direction):
    """Report whether ``effect`` (treatment minus control) has the expected sign.

    ``"treatment_lt_control"`` requires a strictly negative effect. Any other
    value is handled like ``"treatment_gt_control"`` and requires a strictly
    positive effect. A zero effect never matches.
    """
    wants_decrease = expected_direction == "treatment_lt_control"
    return effect < 0.0 if wants_decrease else effect > 0.0
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def adjust_p_values(raw_p_values, method):
    """Apply a family-wise multiple-comparison correction.

    Methods:

    * ``"bonferroni"`` multiplies every p-value by the number of tests.
    * ``"holm"`` multiplies the p-value at 0-based sorted rank k by ``m - k``
      and carries a running maximum, so adjusted values never decrease with
      rank.

    Both methods cap results at 1.0 and return them in input order. Any other
    method is reported through ``fail_runtime``.
    """
    family_size = len(raw_p_values)
    if method == "holm":
        ranked = sorted(enumerate(raw_p_values), key=lambda pair: pair[1])
        adjusted = [0.0] * family_size
        ceiling = 0.0
        for rank, (position, p_value) in enumerate(ranked):
            step_value = min((family_size - rank) * p_value, 1.0)
            ceiling = max(ceiling, step_value)
            adjusted[position] = ceiling
        return adjusted
    if method == "bonferroni":
        return [min(p_value * family_size, 1.0) for p_value in raw_p_values]
    fail_runtime(f"Unsupported p_adjustment {method}.")
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def compute_contrasts(params, group_stats):
    """Run every planned contrast and apply the configured p-value correction.

    Each ``compute_welch_contrast`` result gains ``adjusted_p_value`` and
    ``passed``. A contrast passes when its adjusted p-value is strictly below
    ``params["alpha"]`` and its effect points in the expected direction.
    Results keep the configured contrast order.
    """
    results = []
    for contrast in params["contrasts"]:
        results.append(compute_welch_contrast(contrast, group_stats))
    corrected = adjust_p_values(
        [result["raw_p_value"] for result in results],
        params["p_adjustment"],
    )
    for index, result in enumerate(results):
        adjusted = corrected[index]
        result["adjusted_p_value"] = adjusted
        result["passed"] = adjusted < params["alpha"] and result["direction_matched"]
    return results
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
def main():
    """Score a multi-group endpoint submission and write the result.

    Steps:

    1. Load the runtime context and ``compiled_config``.
    2. Validate the scorer settings; ``objective`` must be ``maximize``.
    3. Check the declared submission slot.
    4. Read and bucket the observations, then enforce per-group sample sizes.
    5. Run a one-way ANOVA plus Welch planned contrasts with the configured
       p-value correction.

    The score is 1.0 when the ANOVA p-value is below alpha and every contrast
    passes; otherwise it is 0.0. The score is written under ``"score"`` and
    under ``final_score_key``.

    ``final_score_key`` may equal ``"score"``. A collision with any other
    details field is treated as a configuration error (``fail_runtime``),
    because that field would otherwise replace the score value in details.
    """
    runtime_context = load_runtime_context()
    config_path = resolve_scoring_asset(
        runtime_context,
        "compiled_config",
        kind="config",
    )
    try:
        config = load_json_file(config_path, label="compiled_config")
    except RuntimeError as error:
        fail_runtime(str(error))
    if not isinstance(config, dict):
        fail_runtime("compiled_config must be a JSON object.")

    submission_role = require_string(
        config.get("submission_role"),
        "compiled_config.submission_role",
    )
    final_score_key = require_string(
        config.get("final_score_key"),
        "compiled_config.final_score_key",
    )
    objective = require_string(
        runtime_context.get("objective"),
        "runtime_context.objective",
    )
    if objective != "maximize":
        fail_runtime("multi_group_endpoint_test@1 requires objective=maximize.")

    params = load_endpoint_test_config(config)
    require_observations_slot(runtime_context, submission_role, params)
    observations_path = resolve_submission_artifact(runtime_context, submission_role)
    values_by_group = load_observations(
        observations_path,
        submission_role,
        params,
    )
    require_group_sizes(values_by_group, params["minimum_n_per_group"])

    group_stats = compute_group_stats(values_by_group)
    anova = compute_anova(values_by_group, group_stats)
    contrast_results = compute_contrasts(params, group_stats)
    passed_contrast_count = sum(1 for result in contrast_results if result["passed"])
    total_contrast_count = len(contrast_results)
    minimum_group_n = min(stats["n"] for stats in group_stats.values())
    # load_contrasts guarantees at least one contrast, so max() has input.
    maximum_adjusted_p_value = max(
        result["adjusted_p_value"] for result in contrast_results
    )
    omnibus_passed = anova["p_value"] < params["alpha"]
    all_contrasts_passed = passed_contrast_count == total_contrast_count
    score = 1.0 if omnibus_passed and all_contrasts_passed else 0.0
    report = {
        "score": score,
        "omnibus_p_value": anova["p_value"],
        "omnibus_f_statistic": anova["f_statistic"],
        "df_between": anova["df_between"],
        "df_within": anova["df_within"],
        "passed_contrast_count": passed_contrast_count,
        "total_contrast_count": total_contrast_count,
        "minimum_group_n": minimum_group_n,
        "maximum_adjusted_p_value": maximum_adjusted_p_value,
        "omnibus_passed": omnibus_passed,
        "alpha": params["alpha"],
        "minimum_n_per_group": params["minimum_n_per_group"],
        "p_adjustment": params["p_adjustment"],
        "group_column": params["group_column"],
        "value_column": params["value_column"],
        "included_groups": params["included_groups"],
        "group_stats": {
            group: {
                "n": stats["n"],
                "mean": stats["mean"],
                "variance": stats["variance"],
            }
            for group, stats in group_stats.items()
        },
        "anova": anova,
        "contrasts": contrast_results,
    }
    # In a dict literal a later duplicate key wins, so a final_score_key equal
    # to a statistic's key would have carried that statistic instead of the
    # score. "score" itself is fine because it holds the same value.
    if final_score_key != "score" and final_score_key in report:
        fail_runtime(
            f"compiled_config.final_score_key must not reuse the reserved details key {final_score_key!r}."
        )
    # final_score_key stays the first key, as before.
    details = {final_score_key: score, **report}
    write_score(score=score, details=details)
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
# Run the scorer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|