evalsense 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +678 -0
- package/bin/evalsense.js +3 -0
- package/dist/chunk-5P7LNNO6.js +747 -0
- package/dist/chunk-5P7LNNO6.js.map +1 -0
- package/dist/chunk-BRPM6AB6.js +925 -0
- package/dist/chunk-BRPM6AB6.js.map +1 -0
- package/dist/chunk-HDJID3GC.cjs +779 -0
- package/dist/chunk-HDJID3GC.cjs.map +1 -0
- package/dist/chunk-Y23VHTD3.cjs +942 -0
- package/dist/chunk-Y23VHTD3.cjs.map +1 -0
- package/dist/cli.cjs +65 -0
- package/dist/cli.cjs.map +1 -0
- package/dist/cli.d.cts +1 -0
- package/dist/cli.d.ts +1 -0
- package/dist/cli.js +63 -0
- package/dist/cli.js.map +1 -0
- package/dist/index.cjs +1126 -0
- package/dist/index.cjs.map +1 -0
- package/dist/index.d.cts +604 -0
- package/dist/index.d.ts +604 -0
- package/dist/index.js +1043 -0
- package/dist/index.js.map +1 -0
- package/dist/metrics/index.cjs +275 -0
- package/dist/metrics/index.cjs.map +1 -0
- package/dist/metrics/index.d.cts +299 -0
- package/dist/metrics/index.d.ts +299 -0
- package/dist/metrics/index.js +191 -0
- package/dist/metrics/index.js.map +1 -0
- package/dist/metrics/opinionated/index.cjs +24 -0
- package/dist/metrics/opinionated/index.cjs.map +1 -0
- package/dist/metrics/opinionated/index.d.cts +163 -0
- package/dist/metrics/opinionated/index.d.ts +163 -0
- package/dist/metrics/opinionated/index.js +3 -0
- package/dist/metrics/opinionated/index.js.map +1 -0
- package/dist/types-C71p0wzM.d.cts +265 -0
- package/dist/types-C71p0wzM.d.ts +265 -0
- package/package.json +91 -0
package/dist/index.cjs
ADDED
|
@@ -0,0 +1,1126 @@
|
|
|
1
|
+
'use strict';
|
|
2
|
+
|
|
3
|
+
var chunkHDJID3GC_cjs = require('./chunk-HDJID3GC.cjs');
|
|
4
|
+
var fs = require('fs');
|
|
5
|
+
var path = require('path');
|
|
6
|
+
|
|
7
|
+
// src/core/describe.ts
|
|
8
|
+
/**
 * Declares a test suite: creates a fresh suite object, makes it the current
 * suite while `fn` runs, then restores the enclosing suite and registers the
 * new one via `addSuite`.
 *
 * `fn` is invoked synchronously; a promise it returns is not awaited here.
 * If `fn` throws, the enclosing suite is still restored and the error
 * propagates without the suite being registered.
 *
 * @param {string} name - Suite name.
 * @param {() => void} fn - Body that registers tests and hooks.
 */
function describe(name, fn) {
  const enclosing = chunkHDJID3GC_cjs.getCurrentSuite();
  const current = {
    name,
    tests: [],
    beforeAll: [],
    afterAll: [],
    beforeEach: [],
    afterEach: []
  };
  chunkHDJID3GC_cjs.setCurrentSuite(current);
  try {
    fn();
  } finally {
    // Always restore, otherwise a throwing body would leave its suite as "current".
    chunkHDJID3GC_cjs.setCurrentSuite(enclosing);
  }
  chunkHDJID3GC_cjs.addSuite(current);
}
|
|
26
|
+
/**
 * Adds `fn` to the current suite's `beforeAll` hook list.
 * (When the runner invokes these hooks is decided elsewhere and is not visible here.)
 *
 * @param {Function} fn - Hook to register.
 * @throws {Error} When no suite is current, i.e. outside a describe() block.
 */
function beforeAll(fn) {
  const current = chunkHDJID3GC_cjs.getCurrentSuite();
  if (current) {
    // Optional chaining: a suite object without a hook array is silently ignored.
    current.beforeAll?.push(fn);
    return;
  }
  throw new Error("beforeAll() must be called inside a describe() block");
}
|
|
33
|
+
/**
 * Adds `fn` to the current suite's `afterAll` hook list.
 * (When the runner invokes these hooks is decided elsewhere and is not visible here.)
 *
 * @param {Function} fn - Hook to register.
 * @throws {Error} When no suite is current, i.e. outside a describe() block.
 */
function afterAll(fn) {
  const current = chunkHDJID3GC_cjs.getCurrentSuite();
  if (current) {
    // Optional chaining: a suite object without a hook array is silently ignored.
    current.afterAll?.push(fn);
    return;
  }
  throw new Error("afterAll() must be called inside a describe() block");
}
|
|
40
|
+
/**
 * Adds `fn` to the current suite's `beforeEach` hook list.
 * (When the runner invokes these hooks is decided elsewhere and is not visible here.)
 *
 * @param {Function} fn - Hook to register.
 * @throws {Error} When no suite is current, i.e. outside a describe() block.
 */
function beforeEach(fn) {
  const current = chunkHDJID3GC_cjs.getCurrentSuite();
  if (current) {
    // Optional chaining: a suite object without a hook array is silently ignored.
    current.beforeEach?.push(fn);
    return;
  }
  throw new Error("beforeEach() must be called inside a describe() block");
}
|
|
47
|
+
/**
 * Adds `fn` to the current suite's `afterEach` hook list.
 * (When the runner invokes these hooks is decided elsewhere and is not visible here.)
 *
 * @param {Function} fn - Hook to register.
 * @throws {Error} When no suite is current, i.e. outside a describe() block.
 */
function afterEach(fn) {
  const current = chunkHDJID3GC_cjs.getCurrentSuite();
  if (current) {
    // Optional chaining: a suite object without a hook array is silently ignored.
    current.afterEach?.push(fn);
    return;
  }
  throw new Error("afterEach() must be called inside a describe() block");
}
|
|
54
|
+
|
|
55
|
+
// src/core/eval-test.ts
|
|
56
|
+
/**
 * Registers a test `{ name, fn }` on the current suite.
 *
 * @param {string} name - Test name.
 * @param {Function} fn - Test body (stored as-is; not called here).
 * @throws {Error} When called outside a describe() block.
 */
function evalTest(name, fn) {
  if (!chunkHDJID3GC_cjs.getCurrentSuite()) {
    throw new Error("evalTest() must be called inside a describe() block");
  }
  chunkHDJID3GC_cjs.addTestToCurrentSuite({ name, fn });
}
|
|
67
|
+
// `it` and `test` are plain aliases of evalTest for familiar test-runner spelling.
// Kept as `var` so any earlier reference elsewhere in the bundle behaves as before.
var it = evalTest;
var test = it;
|
|
69
|
+
/**
 * Registers a skipped test: a placeholder named "[SKIPPED] <name>" whose body
 * is a no-op async function. The supplied `_fn` is never stored or called.
 *
 * @param {string} name - Test name.
 * @param {Function} _fn - Ignored test body.
 * @throws {Error} When called outside a describe() block.
 */
function evalTestSkip(name, _fn) {
  if (!chunkHDJID3GC_cjs.getCurrentSuite()) {
    throw new Error("evalTest.skip() must be called inside a describe() block");
  }
  chunkHDJID3GC_cjs.addTestToCurrentSuite({
    name: `[SKIPPED] ${name}`,
    fn: async () => {
    }
  });
}
|
|
81
|
+
/**
 * Registers a test whose name is prefixed with "[ONLY]". This function itself
 * does not restrict which tests run; any exclusive-run behaviour must be
 * implemented elsewhere (not visible here).
 *
 * @param {string} name - Test name.
 * @param {Function} fn - Test body (stored as-is).
 * @throws {Error} When called outside a describe() block.
 */
function evalTestOnly(name, fn) {
  if (!chunkHDJID3GC_cjs.getCurrentSuite()) {
    throw new Error("evalTest.only() must be called inside a describe() block");
  }
  chunkHDJID3GC_cjs.addTestToCurrentSuite({ name: `[ONLY] ${name}`, fn });
}
|
|
92
|
+
// Expose the skip/only variants as properties: evalTest.skip(...) / evalTest.only(...).
Object.assign(evalTest, { skip: evalTestSkip, only: evalTestOnly });
|
|
94
|
+
/**
 * Reads a dataset file (resolved against the current working directory) and
 * parses it according to its extension (case-insensitive):
 * `.json` via parseJSON, `.ndjson` / `.jsonl` via parseNDJSON.
 *
 * @param {string} path$1 - File path as given by the caller; also reported as `metadata.source`.
 * @returns {{records: Array, metadata: {source: string, count: number, loadedAt: Date}}}
 * @throws {DatasetError} For unsupported extensions, read failures, or parse failures
 *   (non-DatasetError failures are wrapped with a "Failed to load dataset from" message).
 */
function loadDataset(path$1) {
  const absolutePath = path.resolve(process.cwd(), path$1);
  const ext = path.extname(absolutePath).toLowerCase();
  let records;
  try {
    const content = fs.readFileSync(absolutePath, "utf-8");
    switch (ext) {
      case ".ndjson":
      case ".jsonl":
        records = parseNDJSON(content);
        break;
      case ".json":
        records = parseJSON(content);
        break;
      default:
        throw new chunkHDJID3GC_cjs.DatasetError(
          `Unsupported file format: ${ext}. Use .json, .ndjson, or .jsonl`,
          path$1
        );
    }
  } catch (error) {
    // Errors already raised as DatasetError carry their own precise message.
    if (error instanceof chunkHDJID3GC_cjs.DatasetError) {
      throw error;
    }
    const reason = error instanceof Error ? error.message : String(error);
    throw new chunkHDJID3GC_cjs.DatasetError(`Failed to load dataset from ${path$1}: ${reason}`, path$1);
  }
  const metadata = {
    source: path$1,
    count: records.length,
    loadedAt: /* @__PURE__ */ new Date()
  };
  return { records, metadata };
}
|
|
126
|
+
/**
 * Parses a JSON dataset whose top level must be an array of records.
 *
 * A leading UTF-8 byte-order mark is ignored: fs.readFileSync(..., "utf-8")
 * keeps the U+FEFF character, and JSON.parse would reject it as invalid JSON.
 *
 * @param {string} content - Raw file contents.
 * @returns {Array} The parsed records.
 * @throws {SyntaxError} If `content` is not valid JSON.
 * @throws {DatasetError} If the top-level value is not an array.
 */
function parseJSON(content) {
  const text = content.charCodeAt(0) === 0xfeff ? content.slice(1) : content;
  const parsed = JSON.parse(text);
  if (!Array.isArray(parsed)) {
    throw new chunkHDJID3GC_cjs.DatasetError("JSON dataset must be an array of records");
  }
  return parsed;
}
|
|
133
|
+
/**
 * Parses newline-delimited JSON (one record per line). Blank or
 * whitespace-only lines are skipped; a leading UTF-8 byte-order mark is ignored.
 *
 * Reported line numbers are physical 1-based line numbers in `content`:
 * blank lines are skipped while iterating, not filtered out beforehand, so
 * they still count towards the number. CRLF endings work because the
 * trailing "\r" is JSON whitespace.
 *
 * @param {string} content - Raw file contents.
 * @returns {Array} One parsed value per non-blank line.
 * @throws {DatasetError} On the first line that is not valid JSON.
 */
function parseNDJSON(content) {
  const text = content.charCodeAt(0) === 0xfeff ? content.slice(1) : content;
  const records = [];
  for (const [index, line] of text.split("\n").entries()) {
    if (line.trim() === "") {
      continue;
    }
    try {
      records.push(JSON.parse(line));
    } catch {
      throw new chunkHDJID3GC_cjs.DatasetError(`Invalid JSON at line ${index + 1} in NDJSON file`);
    }
  }
  return records;
}
|
|
147
|
+
/**
 * Wraps in-memory records in the same `{ records, metadata }` shape that
 * loadDataset returns. The `records` array is used as-is (not copied).
 *
 * @param {Array} records - Dataset records.
 * @param {string} [source="inline"] - Label stored as `metadata.source`.
 * @returns {{records: Array, metadata: {source: string, count: number, loadedAt: Date}}}
 */
function createDataset(records, source = "inline") {
  const metadata = {
    source,
    count: records.length,
    loadedAt: /* @__PURE__ */ new Date()
  };
  return { records, metadata };
}
|
|
157
|
+
|
|
158
|
+
// src/dataset/run-model.ts
|
|
159
|
+
/**
 * Runs `modelFn` sequentially over every record in `dataset` and pairs each
 * prediction with its source record.
 *
 * IDs are compared as strings: getRecordId() stringifies the record ID, so a
 * model that echoes back a numeric ID (e.g. `1`) matches the record ID `"1"`.
 * The prediction objects themselves are pushed unchanged.
 *
 * @param {{records: object[]}} dataset - Records carrying an `id` or `_id`.
 * @param {(record: object) => object|Promise<object>} modelFn - Must return an object whose `id` matches the record's ID.
 * @returns {Promise<{predictions: object[], aligned: object[], duration: number}>}
 *   `aligned[i]` is `{ id, actual: shallow copy of prediction, expected: shallow copy of record }`;
 *   `duration` is wall-clock milliseconds.
 * @throws {DatasetError} If a record has no ID, the model returns null/undefined,
 *   or the returned ID does not match.
 */
async function runModel(dataset, modelFn) {
  const startTime = Date.now();
  const predictions = [];
  const aligned = [];
  for (const record of dataset.records) {
    const id = getRecordId(record);
    const prediction = await modelFn(record);
    if (prediction === void 0 || prediction === null || String(prediction.id) !== id) {
      throw new chunkHDJID3GC_cjs.DatasetError(
        `Prediction ID mismatch: expected "${id}", got "${prediction?.id}". Model function must return the same ID as the input record.`
      );
    }
    predictions.push(prediction);
    aligned.push({
      id,
      actual: { ...prediction },
      expected: { ...record }
    });
  }
  return {
    predictions,
    aligned,
    duration: Date.now() - startTime
  };
}
|
|
184
|
+
/**
 * Returns a record's identifier as a string, preferring `id` over `_id`.
 * Falsy-but-present IDs such as `0` or `""` are accepted; only null/undefined are rejected.
 *
 * @param {object} record - Dataset record.
 * @returns {string} The stringified ID.
 * @throws {DatasetError} If neither `id` nor `_id` holds a non-nullish value.
 */
function getRecordId(record) {
  const raw = record.id ?? record._id;
  if (raw !== void 0 && raw !== null) {
    return String(raw);
  }
  throw new chunkHDJID3GC_cjs.DatasetError(
    'Dataset records must have an "id" or "_id" field for alignment'
  );
}
|
|
193
|
+
/**
 * Runs `modelFn` over every record in consecutive batches of at most
 * `concurrency` concurrent calls and pairs each prediction with its record.
 * Output order always matches `dataset.records` order.
 *
 * Record IDs are resolved before a batch is sent to the model, and each
 * batch is validated as soon as it completes. A missing or mismatched ID
 * therefore stops further model calls instead of surfacing only after the
 * whole dataset has been processed. IDs are compared as strings, as in runModel.
 *
 * @param {{records: object[]}} dataset - Records carrying an `id` or `_id`.
 * @param {(record: object) => object|Promise<object>} modelFn - Must return an object whose `id` matches the record's ID.
 * @param {number} [concurrency=10] - Batch size. Must be >= 1; fractions are floored
 *   and `Infinity` runs every record at once.
 * @returns {Promise<{predictions: object[], aligned: object[], duration: number}>}
 * @throws {RangeError} If `concurrency` is not a number >= 1. Previously 0/negative
 *   values looped forever and NaN silently produced no predictions.
 * @throws {DatasetError} If a record has no ID or a prediction's ID does not match.
 */
async function runModelParallel(dataset, modelFn, concurrency = 10) {
  if (typeof concurrency !== "number" || !(concurrency >= 1)) {
    throw new RangeError(`runModelParallel(): concurrency must be a number >= 1, got ${String(concurrency)}`);
  }
  const batchSize = Math.floor(concurrency);
  const startTime = Date.now();
  const { records } = dataset;
  const predictions = [];
  const aligned = [];
  for (let start = 0; start < records.length; start += batchSize) {
    const batch = records.slice(start, start + batchSize);
    const ids = batch.map((record) => getRecordId(record));
    // async wrapper: a synchronous throw in modelFn becomes a rejection like any other failure.
    const batchPredictions = await Promise.all(batch.map(async (record) => modelFn(record)));
    for (const [offset, record] of batch.entries()) {
      const id = ids[offset];
      const prediction = batchPredictions[offset];
      if (prediction === void 0 || prediction === null || String(prediction.id) !== id) {
        throw new chunkHDJID3GC_cjs.DatasetError(
          `Prediction ID mismatch: expected "${id}", got "${prediction?.id}".`
        );
      }
      predictions.push(prediction);
      aligned.push({
        id,
        actual: { ...prediction },
        expected: { ...record }
      });
    }
  }
  return {
    predictions,
    aligned,
    duration: Date.now() - startTime
  };
}
|
|
228
|
+
|
|
229
|
+
// src/dataset/alignment.ts
|
|
230
|
+
/**
 * Pairs predictions with expected (ground-truth) records by ID.
 *
 * Expected records are keyed by `String(record[idField] ?? record._id)`, and
 * prediction IDs are stringified the same way. Numeric prediction IDs
 * therefore match string expected IDs (e.g. `1` ↔ `"1"`). Expected records
 * with no ID at all are skipped because they cannot be matched; they are no
 * longer keyed under the string "undefined". When several expected records
 * share an ID, the last one wins.
 *
 * @param {object[]} predictions - Objects carrying an `id`.
 * @param {object[]} expected - Ground-truth records.
 * @param {{strict?: boolean, idField?: string}} [options]
 *   `strict` (default false): throw if any prediction has no expected record.
 *   `idField` (default "id"): ID field to read on expected records (falls back to `_id`).
 * @returns {Array<{id: string, actual: object, expected: object}>} One entry per
 *   prediction in input order; `expected` is `{}` for unmatched predictions in non-strict mode.
 * @throws {IntegrityError} In strict mode, listing the unmatched prediction IDs.
 */
function alignByKey(predictions, expected, options = {}) {
  const { strict = false, idField = "id" } = options;
  const expectedById = /* @__PURE__ */ new Map();
  for (const record of expected) {
    const rawId = record[idField] ?? record._id;
    if (rawId === void 0 || rawId === null) {
      continue;
    }
    expectedById.set(String(rawId), record);
  }
  const aligned = [];
  const missingIds = [];
  for (const prediction of predictions) {
    const id = String(prediction.id);
    const expectedRecord = expectedById.get(id);
    if (!expectedRecord) {
      missingIds.push(id);
      if (strict) {
        continue;
      }
      aligned.push({ id, actual: { ...prediction }, expected: {} });
      continue;
    }
    aligned.push({ id, actual: { ...prediction }, expected: { ...expectedRecord } });
  }
  if (strict && missingIds.length > 0) {
    throw new chunkHDJID3GC_cjs.IntegrityError(
      `${missingIds.length} prediction(s) have no matching expected record`,
      missingIds
    );
  }
  return aligned;
}
|
|
268
|
+
/**
 * Splits aligned records into parallel arrays for one field.
 *
 * @param {Iterable<{id: string, actual: object, expected: object}>} aligned - Aligned records.
 * @param {string} field - Field name read from both `actual` and `expected`.
 * @returns {{actual: Array, expected: Array, ids: Array}} Index i of each array
 *   refers to the same aligned record; missing fields yield `undefined`.
 */
function extractFieldValues(aligned, field) {
  const ids = [];
  const actualValues = [];
  const expectedValues = [];
  for (const { id, actual, expected } of aligned) {
    actualValues.push(actual[field]);
    expectedValues.push(expected[field]);
    ids.push(id);
  }
  return { actual: actualValues, expected: expectedValues, ids };
}
|
|
279
|
+
/**
 * Keeps only aligned records in which `field` is defined (not `undefined`) on
 * both the actual and the expected side. `null` counts as defined.
 *
 * @param {Array<{actual: object, expected: object}>} aligned - Aligned records.
 * @param {string} field - Field name to check.
 * @returns {Array} A new array of the qualifying records (same object references).
 */
function filterComplete(aligned, field) {
  return aligned.filter(({ actual, expected }) => {
    const pair = [actual[field], expected[field]];
    return !pair.includes(void 0);
  });
}
|
|
286
|
+
|
|
287
|
+
// src/dataset/integrity.ts
|
|
288
|
+
/**
 * Checks a dataset for missing IDs, duplicate IDs and missing required fields.
 *
 * - A record whose `id ?? _id` is null/undefined is reported in `missingIds` as
 *   `"record[i]"`. This includes null/undefined entries in the records array,
 *   which were previously skipped and could let an unusable dataset report
 *   `valid: true`.
 * - Each duplicated ID is reported once in `duplicateIds` (in order of its first
 *   repeat), however many times it occurs.
 * - For every record missing any of `requiredFields` (value `undefined`), one
 *   `{ id, fields }` entry is added to `missingFields`.
 *
 * @param {{records: Array}} dataset
 * @param {{requiredFields?: string[], throwOnFailure?: boolean}} [options]
 * @returns {{valid: boolean, totalRecords: number, missingIds: string[], duplicateIds: string[], missingFields: Array<{id: string, fields: string[]}>}}
 * @throws {IntegrityError} When `throwOnFailure` is true and any check fails.
 */
function checkIntegrity(dataset, options = {}) {
  const { requiredFields = [], throwOnFailure = false } = options;
  const seenIds = /* @__PURE__ */ new Set();
  const reportedDuplicates = /* @__PURE__ */ new Set();
  const missingIds = [];
  const duplicateIds = [];
  const missingFields = [];
  for (let i = 0; i < dataset.records.length; i++) {
    const record = dataset.records[i];
    const position = `record[${i}]`;
    const id = record?.id ?? record?._id;
    if (id === void 0 || id === null) {
      missingIds.push(position);
    } else {
      const idStr = String(id);
      if (!seenIds.has(idStr)) {
        seenIds.add(idStr);
      } else if (!reportedDuplicates.has(idStr)) {
        reportedDuplicates.add(idStr);
        duplicateIds.push(idStr);
      }
    }
    if (requiredFields.length > 0) {
      const missing = requiredFields.filter((field) => record?.[field] === void 0);
      if (missing.length > 0) {
        missingFields.push({
          id: String(id ?? position),
          fields: missing
        });
      }
    }
  }
  const valid = missingIds.length === 0 && duplicateIds.length === 0 && missingFields.length === 0;
  const result = {
    valid,
    totalRecords: dataset.records.length,
    missingIds,
    duplicateIds,
    missingFields
  };
  if (throwOnFailure && !valid) {
    const issues = [];
    if (missingIds.length > 0) {
      issues.push(`${missingIds.length} record(s) missing ID`);
    }
    if (duplicateIds.length > 0) {
      issues.push(`${duplicateIds.length} duplicate ID(s): ${duplicateIds.slice(0, 3).join(", ")}${duplicateIds.length > 3 ? "..." : ""}`);
    }
    if (missingFields.length > 0) {
      issues.push(`${missingFields.length} record(s) missing required fields`);
    }
    throw new chunkHDJID3GC_cjs.IntegrityError(`Dataset integrity check failed: ${issues.join("; ")}`);
  }
  return result;
}
|
|
344
|
+
/**
 * Compares prediction IDs against an expected ID list (strict value equality,
 * no string normalisation).
 *
 * @param {Array<{id: *}>} predictions - Predictions carrying an `id`.
 * @param {Array} expectedIds - IDs that should all be predicted.
 * @returns {{valid: boolean, missing: Array, extra: Array}} `missing` are expected IDs
 *   with no prediction; `extra` are prediction IDs not in `expectedIds` (repeats kept).
 */
function validatePredictions(predictions, expectedIds) {
  const predictedIds = predictions.map((p) => p.id);
  const predicted = new Set(predictedIds);
  const wanted = new Set(expectedIds);
  const missing = expectedIds.filter((id) => !predicted.has(id));
  const extra = predictedIds.filter((id) => !wanted.has(id));
  return {
    valid: missing.length === 0 && extra.length === 0,
    missing,
    extra
  };
}
|
|
355
|
+
|
|
356
|
+
// src/statistics/classification.ts
|
|
357
|
+
/**
 * Builds a confusion matrix from predicted (`actual`) and ground-truth
 * (`expected`) labels and derives classification metrics from it.
 *
 * @param {Array} actual - Predicted labels.
 * @param {Array} expected - Ground-truth labels.
 * @returns {object} The result of computeMetricsFromMatrix for that matrix.
 */
function computeClassificationMetrics(actual, expected) {
  return computeMetricsFromMatrix(
    chunkHDJID3GC_cjs.buildConfusionMatrix(actual, expected)
  );
}
|
|
361
|
+
/**
 * Derives accuracy, per-class precision/recall/F1/support, and macro and
 * support-weighted averages from a confusion matrix.
 *
 * Every ratio with a non-positive denominator is reported as 0.
 * Accuracy = (sum of true positives over `cm.labels`) / (sum of supports).
 *
 * @param {{labels: Array}} cm - Confusion matrix; per-label counts are read through
 *   the chunk helpers getTruePositives / getFalsePositives / getFalseNegatives / getSupport.
 * @returns {{accuracy: number, perClass: object, macroAvg: object, weightedAvg: object, confusionMatrix: object}}
 */
function computeMetricsFromMatrix(cm) {
  const ratio = (numerator, denominator) => denominator > 0 ? numerator / denominator : 0;
  const perClass = {};
  let totalSupport = 0;
  let correctPredictions = 0;
  for (const label of cm.labels) {
    const tp = chunkHDJID3GC_cjs.getTruePositives(cm, label);
    const fp = chunkHDJID3GC_cjs.getFalsePositives(cm, label);
    const fn = chunkHDJID3GC_cjs.getFalseNegatives(cm, label);
    const support = chunkHDJID3GC_cjs.getSupport(cm, label);
    const precision = ratio(tp, tp + fp);
    const recall = ratio(tp, tp + fn);
    const f1 = ratio(2 * precision * recall, precision + recall);
    perClass[label] = { precision, recall, f1, support };
    totalSupport += support;
    correctPredictions += tp;
  }
  // Sums walk Object.values(perClass) (object property order), keeping the
  // floating-point summation order identical to the per-class map.
  const sumOver = (pick) => Object.values(perClass).reduce((acc, m) => acc + pick(m), 0);
  const classCount = cm.labels.length;
  const macroAvg = {
    precision: ratio(sumOver((m) => m.precision), classCount),
    recall: ratio(sumOver((m) => m.recall), classCount),
    f1: ratio(sumOver((m) => m.f1), classCount)
  };
  const weightedAvg = {
    precision: ratio(sumOver((m) => m.precision * m.support), totalSupport),
    recall: ratio(sumOver((m) => m.recall * m.support), totalSupport),
    f1: ratio(sumOver((m) => m.f1 * m.support), totalSupport)
  };
  return {
    accuracy: ratio(correctPredictions, totalSupport),
    perClass,
    macroAvg,
    weightedAvg,
    confusionMatrix: cm
  };
}
|
|
397
|
+
/**
 * Precision for one class: TP / (TP + FP), or 0 when nothing was predicted as that class.
 *
 * @param {Array} actual - Predicted labels.
 * @param {Array} expected - Ground-truth labels.
 * @param {*} targetClass - Class to score.
 * @returns {number}
 */
function computePrecision(actual, expected, targetClass) {
  const matrix = chunkHDJID3GC_cjs.buildConfusionMatrix(actual, expected);
  const truePositives = chunkHDJID3GC_cjs.getTruePositives(matrix, targetClass);
  const predictedPositives = truePositives + chunkHDJID3GC_cjs.getFalsePositives(matrix, targetClass);
  if (predictedPositives > 0) {
    return truePositives / predictedPositives;
  }
  return 0;
}
|
|
403
|
+
/**
 * Recall for one class: TP / (TP + FN), or 0 when the class never occurs in the ground truth.
 *
 * @param {Array} actual - Predicted labels.
 * @param {Array} expected - Ground-truth labels.
 * @param {*} targetClass - Class to score.
 * @returns {number}
 */
function computeRecall(actual, expected, targetClass) {
  const matrix = chunkHDJID3GC_cjs.buildConfusionMatrix(actual, expected);
  const truePositives = chunkHDJID3GC_cjs.getTruePositives(matrix, targetClass);
  const actualPositives = truePositives + chunkHDJID3GC_cjs.getFalseNegatives(matrix, targetClass);
  if (actualPositives > 0) {
    return truePositives / actualPositives;
  }
  return 0;
}
|
|
409
|
+
/**
 * F1 score for one class: the harmonic mean of precision and recall, or 0 when both are 0.
 *
 * Builds the confusion matrix once and reads TP/FP/FN from it directly.
 * Delegating to computePrecision and computeRecall would rebuild the same
 * matrix twice. The results are identical because the formulas are the same.
 *
 * @param {Array} actual - Predicted labels.
 * @param {Array} expected - Ground-truth labels.
 * @param {*} targetClass - Class to score.
 * @returns {number}
 */
function computeF1(actual, expected, targetClass) {
  const cm = chunkHDJID3GC_cjs.buildConfusionMatrix(actual, expected);
  const tp = chunkHDJID3GC_cjs.getTruePositives(cm, targetClass);
  const fp = chunkHDJID3GC_cjs.getFalsePositives(cm, targetClass);
  const fn = chunkHDJID3GC_cjs.getFalseNegatives(cm, targetClass);
  const precision = tp + fp > 0 ? tp / (tp + fp) : 0;
  const recall = tp + fn > 0 ? tp / (tp + fn) : 0;
  return precision + recall > 0 ? 2 * precision * recall / (precision + recall) : 0;
}
|
|
414
|
+
/**
 * Fraction of positions where the predicted and expected labels match.
 *
 * Labels are compared via `String()`, so `1` matches `"1"`. Positions where
 * either side is null/undefined are excluded from both numerator and
 * denominator. Returns 0 when the arrays differ in length, are empty, or have
 * no comparable positions.
 *
 * @param {Array} actual - Predicted labels.
 * @param {Array} expected - Ground-truth labels.
 * @returns {number}
 */
function computeAccuracy(actual, expected) {
  if (actual.length === 0 || actual.length !== expected.length) {
    return 0;
  }
  let compared = 0;
  let matches = 0;
  for (const [index, predicted] of actual.entries()) {
    const truth = expected[index];
    if (predicted == null || truth == null) {
      continue;
    }
    compared += 1;
    if (String(predicted) === String(truth)) {
      matches += 1;
    }
  }
  return compared === 0 ? 0 : matches / compared;
}
|
|
432
|
+
|
|
433
|
+
// src/statistics/distribution.ts
|
|
434
|
+
/**
 * Keeps only values of type "number" that are not NaN. ±Infinity is kept.
 * (`typeof` already excludes null and undefined.)
 *
 * @param {Array} values
 * @returns {number[]}
 */
function filterNumericValues(values) {
  const isUsableNumber = (v) => typeof v === "number" && !Number.isNaN(v);
  return values.filter(isUsableNumber);
}
|
|
439
|
+
/**
 * Fraction (0..1, not 0..100) of values that are at or below `threshold`.
 * Note the inclusive comparison: despite the name, values equal to the
 * threshold count. Returns 0 for an empty array.
 *
 * @param {number[]} values
 * @param {number} threshold
 * @returns {number}
 */
function calculatePercentageBelow(values, threshold) {
  const total = values.length;
  if (total === 0) {
    return 0;
  }
  let atOrBelow = 0;
  for (const v of values) {
    if (v <= threshold) {
      atOrBelow += 1;
    }
  }
  return atOrBelow / total;
}
|
|
446
|
+
/**
 * Fraction (0..1, not 0..100) of values strictly greater than `threshold`.
 * This is the complement of calculatePercentageBelow for numeric inputs.
 * Returns 0 for an empty array.
 *
 * @param {number[]} values
 * @param {number} threshold
 * @returns {number}
 */
function calculatePercentageAbove(values, threshold) {
  const total = values.length;
  if (total === 0) {
    return 0;
  }
  let above = 0;
  for (const v of values) {
    if (v > threshold) {
      above += 1;
    }
  }
  return above / total;
}
|
|
453
|
+
|
|
454
|
+
// src/assertions/binarize.ts
|
|
455
|
+
var BinarizeSelector = class {
  fieldName;
  threshold;
  binaryActual;
  binaryExpected;
  assertions = [];
  /**
   * Converts one field of every aligned record into binary labels.
   *
   * Numeric values become "true" when `>= threshold` and "false" otherwise.
   * Any other value (booleans, strings, null, undefined) is kept as
   * `String(value)`, so a missing value becomes the label "undefined".
   *
   * @param {Iterable<{actual: object, expected: object}>} aligned - Aligned records.
   * @param {string} fieldName - Field read from both `actual` and `expected`.
   * @param {number} threshold - Cut-off applied to numeric values.
   */
  constructor(aligned, fieldName, threshold) {
    this.fieldName = fieldName;
    this.threshold = threshold;
    this.binaryActual = [];
    this.binaryExpected = [];
    const toLabel = (value) => {
      if (typeof value === "number") {
        return value >= threshold ? "true" : "false";
      }
      return String(value);
    };
    for (const record of aligned) {
      this.binaryActual.push(toLabel(record.actual[fieldName]));
      this.binaryExpected.push(toLabel(record.expected[fieldName]));
    }
  }
  /**
   * Asserts that accuracy on the binarized labels is at least `threshold`.
   * Records the assertion, then throws AssertionError if it failed.
   * @param {number} threshold - Minimum accuracy (0..1).
   * @returns {this}
   */
  toHaveAccuracyAbove(threshold) {
    const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
    const passed = metrics.accuracy >= threshold;
    const result = {
      type: "accuracy",
      passed,
      message: passed ? `Accuracy ${(metrics.accuracy * 100).toFixed(1)}% is above ${(threshold * 100).toFixed(1)}% (binarized at ${this.threshold})` : `Accuracy ${(metrics.accuracy * 100).toFixed(1)}% is below threshold ${(threshold * 100).toFixed(1)}% (binarized at ${this.threshold})`,
      expected: threshold,
      actual: metrics.accuracy,
      field: this.fieldName
    };
    this.assertions.push(result);
    chunkHDJID3GC_cjs.recordAssertion(result);
    if (!passed) {
      throw new chunkHDJID3GC_cjs.AssertionError(result.message, threshold, metrics.accuracy, this.fieldName);
    }
    return this;
  }
  /**
   * Asserts that precision is at least a threshold.
   * @param {number|boolean|string} classOrThreshold - A class (true/false) or, when a number,
   *   the threshold applied to the macro-averaged precision.
   * @param {number} [threshold] - Threshold when a class is given.
   * @returns {this}
   */
  toHavePrecisionAbove(classOrThreshold, threshold) {
    return this.#assertClassMetric("precision", "Precision", classOrThreshold, threshold);
  }
  /**
   * Asserts that recall is at least a threshold.
   * @param {number|boolean|string} classOrThreshold - A class (true/false) or, when a number,
   *   the threshold applied to the macro-averaged recall.
   * @param {number} [threshold] - Threshold when a class is given.
   * @returns {this}
   */
  toHaveRecallAbove(classOrThreshold, threshold) {
    return this.#assertClassMetric("recall", "Recall", classOrThreshold, threshold);
  }
  /**
   * Asserts that the F1 score is at least a threshold.
   * @param {number|boolean|string} classOrThreshold - A class (true/false) or, when a number,
   *   the threshold applied to the macro-averaged F1.
   * @param {number} [threshold] - Threshold when a class is given.
   * @returns {this}
   */
  toHaveF1Above(classOrThreshold, threshold) {
    return this.#assertClassMetric("f1", "F1", classOrThreshold, threshold);
  }
  /**
   * Records the binarized confusion-matrix metrics for reporting and adds an
   * always-passing "confusionMatrix" assertion.
   * @returns {this}
   */
  toHaveConfusionMatrix() {
    const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
    const fieldResult = {
      field: this.fieldName,
      metrics,
      binarized: true,
      binarizeThreshold: this.threshold
    };
    chunkHDJID3GC_cjs.recordFieldMetrics(fieldResult);
    const result = {
      type: "confusionMatrix",
      passed: true,
      message: `Confusion matrix recorded for binarized field "${this.fieldName}" (threshold: ${this.threshold})`,
      field: this.fieldName
    };
    this.assertions.push(result);
    chunkHDJID3GC_cjs.recordAssertion(result);
    return this;
  }
  /**
   * Computes classification metrics on the binarized labels.
   */
  getMetrics() {
    return computeClassificationMetrics(this.binaryActual, this.binaryExpected);
  }
  /**
   * Returns every assertion recorded by this selector (the live array).
   */
  getAssertions() {
    return this.assertions;
  }
  /**
   * Shared implementation of the precision/recall/F1 assertions.
   *
   * A numeric `classOrThreshold` checks the macro average. Otherwise
   * `String(classOrThreshold)` selects a per-class entry, and only own keys
   * of `perClass` count. A class name such as "constructor" is therefore
   * reported as not found instead of reading Object.prototype.
   *
   * @param {"precision"|"recall"|"f1"} type - Metric key; also the result `type`.
   * @param {string} label - Metric name used in messages.
   * @param {number|boolean|string} classOrThreshold
   * @param {number} [threshold]
   * @returns {this}
   * @throws {AssertionError} If the class is unknown or the metric is below the threshold.
   */
  #assertClassMetric(type, label, classOrThreshold, threshold) {
    const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
    let value;
    let targetClass;
    let actualThreshold;
    if (typeof classOrThreshold === "number") {
      value = metrics.macroAvg[type];
      actualThreshold = classOrThreshold;
    } else {
      targetClass = String(classOrThreshold);
      actualThreshold = threshold;
      const classMetrics = Object.prototype.hasOwnProperty.call(metrics.perClass, targetClass) ? metrics.perClass[targetClass] : void 0;
      if (!classMetrics) {
        throw new chunkHDJID3GC_cjs.AssertionError(
          `Class "${targetClass}" not found in binarized predictions`,
          targetClass,
          Object.keys(metrics.perClass),
          this.fieldName
        );
      }
      value = classMetrics[type];
    }
    const passed = value >= actualThreshold;
    const scope = targetClass ? ` for ${targetClass}` : "";
    const pct = (x) => (x * 100).toFixed(1);
    const result = {
      type,
      passed,
      message: passed ? `${label}${scope} ${pct(value)}% is above ${pct(actualThreshold)}%` : `${label}${scope} ${pct(value)}% is below threshold ${pct(actualThreshold)}%`,
      expected: actualThreshold,
      actual: value,
      field: this.fieldName,
      class: targetClass
    };
    this.assertions.push(result);
    chunkHDJID3GC_cjs.recordAssertion(result);
    if (!passed) {
      throw new chunkHDJID3GC_cjs.AssertionError(result.message, actualThreshold, value, this.fieldName);
    }
    return this;
  }
};
|
|
671
|
+
|
|
672
|
+
// src/assertions/field-selector.ts
|
|
673
|
+
var FieldSelector = class {
|
|
674
|
+
aligned;
|
|
675
|
+
fieldName;
|
|
676
|
+
actualValues;
|
|
677
|
+
expectedValues;
|
|
678
|
+
assertions = [];
|
|
679
|
+
constructor(aligned, fieldName) {
|
|
680
|
+
this.aligned = aligned;
|
|
681
|
+
this.fieldName = fieldName;
|
|
682
|
+
const extracted = extractFieldValues(aligned, fieldName);
|
|
683
|
+
this.actualValues = extracted.actual;
|
|
684
|
+
this.expectedValues = extracted.expected;
|
|
685
|
+
}
|
|
686
|
+
/**
|
|
687
|
+
* Transforms continuous scores to binary classification using a threshold
|
|
688
|
+
*/
|
|
689
|
+
binarize(threshold) {
|
|
690
|
+
return new BinarizeSelector(this.aligned, this.fieldName, threshold);
|
|
691
|
+
}
|
|
692
|
+
/**
|
|
693
|
+
* Validates that ground truth exists for classification metrics.
|
|
694
|
+
* Throws a clear error if expected values are missing.
|
|
695
|
+
*/
|
|
696
|
+
validateGroundTruth() {
|
|
697
|
+
const hasExpected = this.expectedValues.some(
|
|
698
|
+
(v) => v !== void 0 && v !== null
|
|
699
|
+
);
|
|
700
|
+
if (!hasExpected) {
|
|
701
|
+
throw new chunkHDJID3GC_cjs.AssertionError(
|
|
702
|
+
`Classification metric requires ground truth, but field "${this.fieldName}" has no expected values. Use expectStats(predictions, groundTruth) to provide expected values.`,
|
|
703
|
+
void 0,
|
|
704
|
+
void 0,
|
|
705
|
+
this.fieldName
|
|
706
|
+
);
|
|
707
|
+
}
|
|
708
|
+
}
|
|
709
|
+
/**
|
|
710
|
+
* Asserts that accuracy is above a threshold
|
|
711
|
+
*/
|
|
712
|
+
toHaveAccuracyAbove(threshold) {
|
|
713
|
+
this.validateGroundTruth();
|
|
714
|
+
const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
715
|
+
const passed = metrics.accuracy >= threshold;
|
|
716
|
+
const result = {
|
|
717
|
+
type: "accuracy",
|
|
718
|
+
passed,
|
|
719
|
+
message: passed ? `Accuracy ${(metrics.accuracy * 100).toFixed(1)}% is above ${(threshold * 100).toFixed(1)}%` : `Accuracy ${(metrics.accuracy * 100).toFixed(1)}% is below threshold ${(threshold * 100).toFixed(1)}%`,
|
|
720
|
+
expected: threshold,
|
|
721
|
+
actual: metrics.accuracy,
|
|
722
|
+
field: this.fieldName
|
|
723
|
+
};
|
|
724
|
+
this.assertions.push(result);
|
|
725
|
+
chunkHDJID3GC_cjs.recordAssertion(result);
|
|
726
|
+
if (!passed) {
|
|
727
|
+
throw new chunkHDJID3GC_cjs.AssertionError(result.message, threshold, metrics.accuracy, this.fieldName);
|
|
728
|
+
}
|
|
729
|
+
return this;
|
|
730
|
+
}
|
|
731
|
+
/**
|
|
732
|
+
* Asserts that precision is above a threshold
|
|
733
|
+
* @param classOrThreshold - Either the class name or threshold (if class is omitted, uses macro average)
|
|
734
|
+
* @param threshold - Threshold when class is specified
|
|
735
|
+
*/
|
|
736
|
+
toHavePrecisionAbove(classOrThreshold, threshold) {
|
|
737
|
+
this.validateGroundTruth();
|
|
738
|
+
const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
739
|
+
let actualPrecision;
|
|
740
|
+
let targetClass;
|
|
741
|
+
let actualThreshold;
|
|
742
|
+
if (typeof classOrThreshold === "number") {
|
|
743
|
+
actualPrecision = metrics.macroAvg.precision;
|
|
744
|
+
actualThreshold = classOrThreshold;
|
|
745
|
+
} else {
|
|
746
|
+
targetClass = classOrThreshold;
|
|
747
|
+
actualThreshold = threshold;
|
|
748
|
+
const classMetrics = metrics.perClass[targetClass];
|
|
749
|
+
if (!classMetrics) {
|
|
750
|
+
throw new chunkHDJID3GC_cjs.AssertionError(
|
|
751
|
+
`Class "${targetClass}" not found in predictions`,
|
|
752
|
+
targetClass,
|
|
753
|
+
Object.keys(metrics.perClass),
|
|
754
|
+
this.fieldName
|
|
755
|
+
);
|
|
756
|
+
}
|
|
757
|
+
actualPrecision = classMetrics.precision;
|
|
758
|
+
}
|
|
759
|
+
const passed = actualPrecision >= actualThreshold;
|
|
760
|
+
const result = {
|
|
761
|
+
type: "precision",
|
|
762
|
+
passed,
|
|
763
|
+
message: passed ? `Precision${targetClass ? ` for "${targetClass}"` : ""} ${(actualPrecision * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Precision${targetClass ? ` for "${targetClass}"` : ""} ${(actualPrecision * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
|
|
764
|
+
expected: actualThreshold,
|
|
765
|
+
actual: actualPrecision,
|
|
766
|
+
field: this.fieldName,
|
|
767
|
+
class: targetClass
|
|
768
|
+
};
|
|
769
|
+
this.assertions.push(result);
|
|
770
|
+
chunkHDJID3GC_cjs.recordAssertion(result);
|
|
771
|
+
if (!passed) {
|
|
772
|
+
throw new chunkHDJID3GC_cjs.AssertionError(result.message, actualThreshold, actualPrecision, this.fieldName);
|
|
773
|
+
}
|
|
774
|
+
return this;
|
|
775
|
+
}
|
|
776
|
+
/**
|
|
777
|
+
* Asserts that recall is above a threshold
|
|
778
|
+
* @param classOrThreshold - Either the class name or threshold (if class is omitted, uses macro average)
|
|
779
|
+
* @param threshold - Threshold when class is specified
|
|
780
|
+
*/
|
|
781
|
+
toHaveRecallAbove(classOrThreshold, threshold) {
|
|
782
|
+
this.validateGroundTruth();
|
|
783
|
+
const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
784
|
+
let actualRecall;
|
|
785
|
+
let targetClass;
|
|
786
|
+
let actualThreshold;
|
|
787
|
+
if (typeof classOrThreshold === "number") {
|
|
788
|
+
actualRecall = metrics.macroAvg.recall;
|
|
789
|
+
actualThreshold = classOrThreshold;
|
|
790
|
+
} else {
|
|
791
|
+
targetClass = classOrThreshold;
|
|
792
|
+
actualThreshold = threshold;
|
|
793
|
+
const classMetrics = metrics.perClass[targetClass];
|
|
794
|
+
if (!classMetrics) {
|
|
795
|
+
throw new chunkHDJID3GC_cjs.AssertionError(
|
|
796
|
+
`Class "${targetClass}" not found in predictions`,
|
|
797
|
+
targetClass,
|
|
798
|
+
Object.keys(metrics.perClass),
|
|
799
|
+
this.fieldName
|
|
800
|
+
);
|
|
801
|
+
}
|
|
802
|
+
actualRecall = classMetrics.recall;
|
|
803
|
+
}
|
|
804
|
+
const passed = actualRecall >= actualThreshold;
|
|
805
|
+
const result = {
|
|
806
|
+
type: "recall",
|
|
807
|
+
passed,
|
|
808
|
+
message: passed ? `Recall${targetClass ? ` for "${targetClass}"` : ""} ${(actualRecall * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Recall${targetClass ? ` for "${targetClass}"` : ""} ${(actualRecall * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
|
|
809
|
+
expected: actualThreshold,
|
|
810
|
+
actual: actualRecall,
|
|
811
|
+
field: this.fieldName,
|
|
812
|
+
class: targetClass
|
|
813
|
+
};
|
|
814
|
+
this.assertions.push(result);
|
|
815
|
+
chunkHDJID3GC_cjs.recordAssertion(result);
|
|
816
|
+
if (!passed) {
|
|
817
|
+
throw new chunkHDJID3GC_cjs.AssertionError(result.message, actualThreshold, actualRecall, this.fieldName);
|
|
818
|
+
}
|
|
819
|
+
return this;
|
|
820
|
+
}
|
|
821
|
+
/**
|
|
822
|
+
* Asserts that F1 score is above a threshold
|
|
823
|
+
* @param classOrThreshold - Either the class name or threshold (if class is omitted, uses macro average)
|
|
824
|
+
* @param threshold - Threshold when class is specified
|
|
825
|
+
*/
|
|
826
|
+
toHaveF1Above(classOrThreshold, threshold) {
|
|
827
|
+
this.validateGroundTruth();
|
|
828
|
+
const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
829
|
+
let actualF1;
|
|
830
|
+
let targetClass;
|
|
831
|
+
let actualThreshold;
|
|
832
|
+
if (typeof classOrThreshold === "number") {
|
|
833
|
+
actualF1 = metrics.macroAvg.f1;
|
|
834
|
+
actualThreshold = classOrThreshold;
|
|
835
|
+
} else {
|
|
836
|
+
targetClass = classOrThreshold;
|
|
837
|
+
actualThreshold = threshold;
|
|
838
|
+
const classMetrics = metrics.perClass[targetClass];
|
|
839
|
+
if (!classMetrics) {
|
|
840
|
+
throw new chunkHDJID3GC_cjs.AssertionError(
|
|
841
|
+
`Class "${targetClass}" not found in predictions`,
|
|
842
|
+
targetClass,
|
|
843
|
+
Object.keys(metrics.perClass),
|
|
844
|
+
this.fieldName
|
|
845
|
+
);
|
|
846
|
+
}
|
|
847
|
+
actualF1 = classMetrics.f1;
|
|
848
|
+
}
|
|
849
|
+
const passed = actualF1 >= actualThreshold;
|
|
850
|
+
const result = {
|
|
851
|
+
type: "f1",
|
|
852
|
+
passed,
|
|
853
|
+
message: passed ? `F1${targetClass ? ` for "${targetClass}"` : ""} ${(actualF1 * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `F1${targetClass ? ` for "${targetClass}"` : ""} ${(actualF1 * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
|
|
854
|
+
expected: actualThreshold,
|
|
855
|
+
actual: actualF1,
|
|
856
|
+
field: this.fieldName,
|
|
857
|
+
class: targetClass
|
|
858
|
+
};
|
|
859
|
+
this.assertions.push(result);
|
|
860
|
+
chunkHDJID3GC_cjs.recordAssertion(result);
|
|
861
|
+
if (!passed) {
|
|
862
|
+
throw new chunkHDJID3GC_cjs.AssertionError(result.message, actualThreshold, actualF1, this.fieldName);
|
|
863
|
+
}
|
|
864
|
+
return this;
|
|
865
|
+
}
|
|
866
|
+
/**
|
|
867
|
+
* Includes the confusion matrix in the report
|
|
868
|
+
*/
|
|
869
|
+
toHaveConfusionMatrix() {
|
|
870
|
+
const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
871
|
+
const fieldResult = {
|
|
872
|
+
field: this.fieldName,
|
|
873
|
+
metrics,
|
|
874
|
+
binarized: false
|
|
875
|
+
};
|
|
876
|
+
chunkHDJID3GC_cjs.recordFieldMetrics(fieldResult);
|
|
877
|
+
const result = {
|
|
878
|
+
type: "confusionMatrix",
|
|
879
|
+
passed: true,
|
|
880
|
+
message: `Confusion matrix recorded for field "${this.fieldName}"`,
|
|
881
|
+
field: this.fieldName
|
|
882
|
+
};
|
|
883
|
+
this.assertions.push(result);
|
|
884
|
+
chunkHDJID3GC_cjs.recordAssertion(result);
|
|
885
|
+
return this;
|
|
886
|
+
}
|
|
887
|
+
/**
|
|
888
|
+
* Asserts that a percentage of values are below or equal to a threshold.
|
|
889
|
+
* This is a distributional assertion that only looks at actual values (no ground truth required).
|
|
890
|
+
*
|
|
891
|
+
* @param valueThreshold - The value threshold to compare against
|
|
892
|
+
* @param percentageThreshold - The minimum percentage (0-1) of values that should be <= valueThreshold
|
|
893
|
+
* @returns this for method chaining
|
|
894
|
+
*
|
|
895
|
+
* @example
|
|
896
|
+
* // Assert that 90% of confidence scores are below 0.5
|
|
897
|
+
* expectStats(predictions)
|
|
898
|
+
* .field("confidence")
|
|
899
|
+
* .toHavePercentageBelow(0.5, 0.9)
|
|
900
|
+
*/
|
|
901
|
+
toHavePercentageBelow(valueThreshold, percentageThreshold) {
|
|
902
|
+
const numericActual = filterNumericValues(this.actualValues);
|
|
903
|
+
if (numericActual.length === 0) {
|
|
904
|
+
throw new chunkHDJID3GC_cjs.AssertionError(
|
|
905
|
+
`Field '${this.fieldName}' contains no numeric values (found 0 numeric out of ${this.actualValues.length} total values)`,
|
|
906
|
+
percentageThreshold,
|
|
907
|
+
void 0,
|
|
908
|
+
this.fieldName
|
|
909
|
+
);
|
|
910
|
+
}
|
|
911
|
+
const actualPercentage = calculatePercentageBelow(numericActual, valueThreshold);
|
|
912
|
+
const passed = actualPercentage >= percentageThreshold;
|
|
913
|
+
const result = {
|
|
914
|
+
type: "percentageBelow",
|
|
915
|
+
passed,
|
|
916
|
+
message: passed ? `${(actualPercentage * 100).toFixed(1)}% of '${this.fieldName}' values are below or equal to ${valueThreshold} (expected >= ${(percentageThreshold * 100).toFixed(1)}%)` : `Only ${(actualPercentage * 100).toFixed(1)}% of '${this.fieldName}' values are below or equal to ${valueThreshold} (expected >= ${(percentageThreshold * 100).toFixed(1)}%)`,
|
|
917
|
+
expected: percentageThreshold,
|
|
918
|
+
actual: actualPercentage,
|
|
919
|
+
field: this.fieldName
|
|
920
|
+
};
|
|
921
|
+
this.assertions.push(result);
|
|
922
|
+
chunkHDJID3GC_cjs.recordAssertion(result);
|
|
923
|
+
if (!passed) {
|
|
924
|
+
throw new chunkHDJID3GC_cjs.AssertionError(result.message, percentageThreshold, actualPercentage, this.fieldName);
|
|
925
|
+
}
|
|
926
|
+
return this;
|
|
927
|
+
}
|
|
928
|
+
/**
|
|
929
|
+
* Asserts that a percentage of values are above a threshold.
|
|
930
|
+
* This is a distributional assertion that only looks at actual values (no ground truth required).
|
|
931
|
+
*
|
|
932
|
+
* @param valueThreshold - The value threshold to compare against
|
|
933
|
+
* @param percentageThreshold - The minimum percentage (0-1) of values that should be > valueThreshold
|
|
934
|
+
* @returns this for method chaining
|
|
935
|
+
*
|
|
936
|
+
* @example
|
|
937
|
+
* // Assert that 80% of quality scores are above 0.7
|
|
938
|
+
* expectStats(predictions)
|
|
939
|
+
* .field("quality")
|
|
940
|
+
* .toHavePercentageAbove(0.7, 0.8)
|
|
941
|
+
*/
|
|
942
|
+
toHavePercentageAbove(valueThreshold, percentageThreshold) {
|
|
943
|
+
const numericActual = filterNumericValues(this.actualValues);
|
|
944
|
+
if (numericActual.length === 0) {
|
|
945
|
+
throw new chunkHDJID3GC_cjs.AssertionError(
|
|
946
|
+
`Field '${this.fieldName}' contains no numeric values (found 0 numeric out of ${this.actualValues.length} total values)`,
|
|
947
|
+
percentageThreshold,
|
|
948
|
+
void 0,
|
|
949
|
+
this.fieldName
|
|
950
|
+
);
|
|
951
|
+
}
|
|
952
|
+
const actualPercentage = calculatePercentageAbove(numericActual, valueThreshold);
|
|
953
|
+
const passed = actualPercentage >= percentageThreshold;
|
|
954
|
+
const result = {
|
|
955
|
+
type: "percentageAbove",
|
|
956
|
+
passed,
|
|
957
|
+
message: passed ? `${(actualPercentage * 100).toFixed(1)}% of '${this.fieldName}' values are above ${valueThreshold} (expected >= ${(percentageThreshold * 100).toFixed(1)}%)` : `Only ${(actualPercentage * 100).toFixed(1)}% of '${this.fieldName}' values are above ${valueThreshold} (expected >= ${(percentageThreshold * 100).toFixed(1)}%)`,
|
|
958
|
+
expected: percentageThreshold,
|
|
959
|
+
actual: actualPercentage,
|
|
960
|
+
field: this.fieldName
|
|
961
|
+
};
|
|
962
|
+
this.assertions.push(result);
|
|
963
|
+
chunkHDJID3GC_cjs.recordAssertion(result);
|
|
964
|
+
if (!passed) {
|
|
965
|
+
throw new chunkHDJID3GC_cjs.AssertionError(result.message, percentageThreshold, actualPercentage, this.fieldName);
|
|
966
|
+
}
|
|
967
|
+
return this;
|
|
968
|
+
}
|
|
969
|
+
/**
|
|
970
|
+
* Gets the computed metrics for this field
|
|
971
|
+
*/
|
|
972
|
+
getMetrics() {
|
|
973
|
+
return computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
974
|
+
}
|
|
975
|
+
/**
|
|
976
|
+
* Gets all assertions made on this field
|
|
977
|
+
*/
|
|
978
|
+
getAssertions() {
|
|
979
|
+
return this.assertions;
|
|
980
|
+
}
|
|
981
|
+
};
|
|
982
|
+
|
|
983
|
+
// src/assertions/expect-stats.ts
|
|
984
|
+
/**
 * Normalizes the single-argument input of expectStats() into an array of aligned records.
 *
 * Accepted shapes:
 * - an object with an `aligned` array: that array is returned as-is;
 * - an array whose first element has both `actual` and `expected` keys: returned as-is;
 * - any other array: each element is wrapped as `{ id: p.id, actual: { ...p }, expected: {} }`
 *   (a shallow copy, so callers' objects are not mutated);
 * - an empty array: a new empty array.
 *
 * @param input - the value passed to expectStats()
 * @returns aligned records
 * @throws {Error} "Invalid input to expectStats()..." for anything else, including
 *   undefined, null and primitives
 */
function normalizeInput(input) {
  // The `in` operator throws a TypeError when its right-hand side is not an object.
  // Only probe for `aligned` on values that can carry keys, so that undefined, null and
  // primitives reach the descriptive error below instead.
  const canHaveKeys = input !== null && (typeof input === "object" || typeof input === "function");
  if (canHaveKeys && "aligned" in input && Array.isArray(input.aligned)) {
    return input.aligned;
  }
  if (Array.isArray(input)) {
    if (input.length === 0) {
      return [];
    }
    const first = input[0];
    if (first && "actual" in first && "expected" in first) {
      return input;
    }
    return input.map((p) => ({
      id: p.id,
      actual: { ...p },
      expected: {}
    }));
  }
  throw new Error("Invalid input to expectStats(): expected ModelRunResult, Prediction[], or AlignedRecord[]");
}
|
|
1004
|
+
/**
 * Entry point for statistical assertions over model output.
 *
 * - `expectStats(input)`: `input` goes through normalizeInput, which accepts an object
 *   with an `aligned` array, aligned records, or plain predictions.
 * - `expectStats(predictions, groundTruth)`: `predictions` must be an array; both arguments
 *   are passed to alignByKey.
 *
 * @returns a new ExpectStats wrapping the resulting records
 * @throws {Error} when `expected` is provided but the first argument is not an array
 */
function expectStats(inputOrActual, expected) {
  if (expected === void 0) {
    return new ExpectStats(normalizeInput(inputOrActual));
  }
  if (!Array.isArray(inputOrActual)) {
    throw new Error(
      "When using two-argument expectStats(), first argument must be Prediction[]"
    );
  }
  return new ExpectStats(alignByKey(inputOrActual, expected));
}
|
|
1017
|
+
/**
 * Wrapper around a list of aligned records; entry point for per-field assertions.
 */
var ExpectStats = class {
  // The aligned records supplied at construction (stored by reference).
  aligned;
  constructor(aligned) {
    this.aligned = aligned;
  }
  /**
   * Selects a field to evaluate; returns a FieldSelector over the same records.
   */
  field(fieldName) {
    const records = this.aligned;
    return new FieldSelector(records, fieldName);
  }
  /**
   * Gets the raw aligned records (for advanced use); the stored array itself, not a copy.
   */
  getAligned() {
    const { aligned } = this;
    return aligned;
  }
  /**
   * Gets the count of records.
   */
  count() {
    const { length } = this.aligned;
    return length;
  }
};
|
|
1041
|
+
|
|
1042
|
+
// Re-exports from the shared chunk. Each name is an enumerable getter, so its value is
// read from chunkHDJID3GC_cjs at access time rather than copied here.
// NOTE(review): keep this literal one-statement-per-export form. It was presumably emitted
// by the bundler so that static CommonJS export detection (e.g. Node's cjs-module-lexer,
// used when ESM code imports this file) can discover the names. Confirm before collapsing
// it into a loop or Object.assign.
Object.defineProperty(exports, "AssertionError", {
  enumerable: true,
  get: function () { return chunkHDJID3GC_cjs.AssertionError; }
});
Object.defineProperty(exports, "ConfigurationError", {
  enumerable: true,
  get: function () { return chunkHDJID3GC_cjs.ConfigurationError; }
});
Object.defineProperty(exports, "ConsoleReporter", {
  enumerable: true,
  get: function () { return chunkHDJID3GC_cjs.ConsoleReporter; }
});
Object.defineProperty(exports, "DatasetError", {
  enumerable: true,
  get: function () { return chunkHDJID3GC_cjs.DatasetError; }
});
Object.defineProperty(exports, "EvalSenseError", {
  enumerable: true,
  get: function () { return chunkHDJID3GC_cjs.EvalSenseError; }
});
Object.defineProperty(exports, "ExitCodes", {
  enumerable: true,
  get: function () { return chunkHDJID3GC_cjs.ExitCodes; }
});
Object.defineProperty(exports, "IntegrityError", {
  enumerable: true,
  get: function () { return chunkHDJID3GC_cjs.IntegrityError; }
});
Object.defineProperty(exports, "JsonReporter", {
  enumerable: true,
  get: function () { return chunkHDJID3GC_cjs.JsonReporter; }
});
Object.defineProperty(exports, "TestExecutionError", {
  enumerable: true,
  get: function () { return chunkHDJID3GC_cjs.TestExecutionError; }
});
Object.defineProperty(exports, "buildConfusionMatrix", {
  enumerable: true,
  get: function () { return chunkHDJID3GC_cjs.buildConfusionMatrix; }
});
Object.defineProperty(exports, "discoverEvalFiles", {
  enumerable: true,
  get: function () { return chunkHDJID3GC_cjs.discoverEvalFiles; }
});
Object.defineProperty(exports, "executeEvalFiles", {
  enumerable: true,
  get: function () { return chunkHDJID3GC_cjs.executeEvalFiles; }
});
Object.defineProperty(exports, "formatConfusionMatrix", {
  enumerable: true,
  get: function () { return chunkHDJID3GC_cjs.formatConfusionMatrix; }
});
Object.defineProperty(exports, "getExitCode", {
  enumerable: true,
  get: function () { return chunkHDJID3GC_cjs.getExitCode; }
});
Object.defineProperty(exports, "parseReport", {
  enumerable: true,
  get: function () { return chunkHDJID3GC_cjs.parseReport; }
});
// Names bound in this module's own scope (referenced without the chunk prefix), exported
// by plain assignment.
exports.afterAll = afterAll;
exports.afterEach = afterEach;
exports.alignByKey = alignByKey;
exports.beforeAll = beforeAll;
exports.beforeEach = beforeEach;
exports.checkIntegrity = checkIntegrity;
exports.computeAccuracy = computeAccuracy;
exports.computeClassificationMetrics = computeClassificationMetrics;
exports.computeF1 = computeF1;
exports.computePrecision = computePrecision;
exports.computeRecall = computeRecall;
exports.createDataset = createDataset;
exports.describe = describe;
exports.evalTest = evalTest;
exports.expectStats = expectStats;
exports.extractFieldValues = extractFieldValues;
exports.filterComplete = filterComplete;
exports.it = it;
exports.loadDataset = loadDataset;
exports.runModel = runModel;
exports.runModelParallel = runModelParallel;
exports.test = test;
exports.validatePredictions = validatePredictions;
|
|
1125
|
+
//# sourceMappingURL=index.cjs.map
|
|
1126
|
+
//# sourceMappingURL=index.cjs.map
|