evalsense 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js ADDED
@@ -0,0 +1,1043 @@
1
+ import { getCurrentSuite, setCurrentSuite, addSuite, addTestToCurrentSuite, DatasetError, IntegrityError, buildConfusionMatrix, getTruePositives, getFalsePositives, getFalseNegatives, getSupport, AssertionError, recordAssertion, recordFieldMetrics } from './chunk-5P7LNNO6.js';
2
+ export { AssertionError, ConfigurationError, ConsoleReporter, DatasetError, EvalSenseError, ExitCodes, IntegrityError, JsonReporter, TestExecutionError, buildConfusionMatrix, discoverEvalFiles, executeEvalFiles, formatConfusionMatrix, getExitCode, parseReport } from './chunk-5P7LNNO6.js';
3
+ import { readFileSync } from 'fs';
4
+ import { resolve, extname } from 'path';
5
+
6
+ // src/core/describe.ts
7
/**
 * Registers a suite of eval tests. Supports nesting: the enclosing
 * suite is saved before running the callback and restored afterwards,
 * then the new suite is added to the registry.
 * @param name - suite display name
 * @param fn - callback that registers tests/hooks for this suite
 */
function describe(name, fn) {
  const enclosing = getCurrentSuite();
  const suite = {
    name,
    tests: [],
    beforeAll: [],
    afterAll: [],
    beforeEach: [],
    afterEach: [],
  };
  setCurrentSuite(suite);
  try {
    fn();
  } finally {
    // Always restore the parent, even if the registration callback throws.
    setCurrentSuite(enclosing);
  }
  addSuite(suite);
}
25
/**
 * Queues a callback to run once before any test in the current suite.
 * @throws {Error} when called outside a describe() block.
 */
function beforeAll(fn) {
  const current = getCurrentSuite();
  if (!current) {
    throw new Error("beforeAll() must be called inside a describe() block");
  }
  current.beforeAll?.push(fn);
}
32
/**
 * Queues a callback to run once after all tests in the current suite.
 * @throws {Error} when called outside a describe() block.
 */
function afterAll(fn) {
  const current = getCurrentSuite();
  if (!current) {
    throw new Error("afterAll() must be called inside a describe() block");
  }
  current.afterAll?.push(fn);
}
39
/**
 * Queues a callback to run before every test in the current suite.
 * @throws {Error} when called outside a describe() block.
 */
function beforeEach(fn) {
  const current = getCurrentSuite();
  if (!current) {
    throw new Error("beforeEach() must be called inside a describe() block");
  }
  current.beforeEach?.push(fn);
}
46
/**
 * Queues a callback to run after every test in the current suite.
 * @throws {Error} when called outside a describe() block.
 */
function afterEach(fn) {
  const current = getCurrentSuite();
  if (!current) {
    throw new Error("afterEach() must be called inside a describe() block");
  }
  current.afterEach?.push(fn);
}
53
+
54
+ // src/core/eval-test.ts
55
/**
 * Registers a single eval test in the current suite.
 * Also exported under the familiar aliases `test` and `it`.
 * @throws {Error} when called outside a describe() block.
 */
function evalTest(name, fn) {
  const suite = getCurrentSuite();
  if (!suite) {
    throw new Error("evalTest() must be called inside a describe() block");
  }
  addTestToCurrentSuite({ name, fn });
}
var test = evalTest;
var it = evalTest;
68
/**
 * Registers a skipped test: the test body is discarded and replaced
 * with an async no-op, so the skip is still visible in the report
 * under a "[SKIPPED]" name prefix.
 * @throws {Error} when called outside a describe() block.
 */
function evalTestSkip(name, _fn) {
  const suite = getCurrentSuite();
  if (!suite) {
    throw new Error("evalTest.skip() must be called inside a describe() block");
  }
  addTestToCurrentSuite({
    name: `[SKIPPED] ${name}`,
    fn: async () => {},
  });
}
80
/**
 * Registers a focused test, tagged with an "[ONLY]" name prefix.
 * @throws {Error} when called outside a describe() block.
 */
function evalTestOnly(name, fn) {
  const suite = getCurrentSuite();
  if (!suite) {
    throw new Error("evalTest.only() must be called inside a describe() block");
  }
  addTestToCurrentSuite({ name: `[ONLY] ${name}`, fn });
}
// Attach the modifiers so callers can write evalTest.skip(...) / evalTest.only(...).
evalTest.skip = evalTestSkip;
evalTest.only = evalTestOnly;
93
/**
 * Loads a dataset from disk, resolving the path against the current
 * working directory. Supported formats: .json (array of records) and
 * .ndjson / .jsonl (one JSON record per line).
 * @returns a dataset envelope: { records, metadata: { source, count, loadedAt } }
 * @throws {DatasetError} for unsupported extensions, unreadable files,
 *   or invalid content (parse errors are wrapped with the path).
 */
function loadDataset(path) {
  const absolutePath = resolve(process.cwd(), path);
  const ext = extname(absolutePath).toLowerCase();
  let records;
  try {
    const content = readFileSync(absolutePath, "utf-8");
    switch (ext) {
      case ".ndjson":
      case ".jsonl":
        records = parseNDJSON(content);
        break;
      case ".json":
        records = parseJSON(content);
        break;
      default:
        throw new DatasetError(
          `Unsupported file format: ${ext}. Use .json, .ndjson, or .jsonl`,
          path
        );
    }
  } catch (error) {
    // Re-throw our own errors untouched; wrap anything else with context.
    if (error instanceof DatasetError) {
      throw error;
    }
    const message = error instanceof Error ? error.message : String(error);
    throw new DatasetError(`Failed to load dataset from ${path}: ${message}`, path);
  }
  return {
    records,
    metadata: {
      source: path,
      count: records.length,
      loadedAt: new Date(),
    },
  };
}
125
/**
 * Parses a whole-file JSON dataset; the top-level value must be an array.
 * @throws {DatasetError} when the parsed value is not an array.
 */
function parseJSON(content) {
  const records = JSON.parse(content);
  if (Array.isArray(records)) {
    return records;
  }
  throw new DatasetError("JSON dataset must be an array of records");
}
132
/**
 * Parses newline-delimited JSON: one record per non-blank line.
 * Blank/whitespace-only lines are dropped before parsing, so reported
 * line numbers count only content lines.
 * @throws {DatasetError} when any line fails to parse.
 */
function parseNDJSON(content) {
  const contentLines = content.split("\n").filter((line) => line.trim() !== "");
  return contentLines.map((line, index) => {
    try {
      return JSON.parse(line);
    } catch {
      throw new DatasetError(`Invalid JSON at line ${index + 1} in NDJSON file`);
    }
  });
}
146
/**
 * Wraps an in-memory array of records in a dataset envelope.
 * @param records - dataset rows
 * @param source - label recorded in metadata (defaults to "inline")
 */
function createDataset(records, source = "inline") {
  const metadata = {
    source,
    count: records.length,
    loadedAt: new Date(),
  };
  return { records, metadata };
}
156
+
157
+ // src/dataset/run-model.ts
158
// src/dataset/run-model.ts
/**
 * Runs `modelFn` over every record sequentially and aligns each
 * prediction with its source record by ID.
 * @returns { predictions, aligned, duration } where duration is wall-clock ms
 * @throws {DatasetError} if a record lacks an id/_id, or a prediction's
 *   id does not match its input record's id.
 */
async function runModel(dataset, modelFn) {
  const startedAt = Date.now();
  const predictions = [];
  const aligned = [];
  for (const record of dataset.records) {
    const id = getRecordId(record);
    const prediction = await modelFn(record);
    if (prediction.id !== id) {
      throw new DatasetError(
        `Prediction ID mismatch: expected "${id}", got "${prediction.id}". Model function must return the same ID as the input record.`
      );
    }
    predictions.push(prediction);
    // Shallow copies so later mutation of inputs can't corrupt the alignment.
    aligned.push({ id, actual: { ...prediction }, expected: { ...record } });
  }
  return { predictions, aligned, duration: Date.now() - startedAt };
}
/**
 * Resolves a record's alignment key from `id` (preferred) or `_id`,
 * normalized to a string.
 * @throws {DatasetError} when neither field is present.
 */
function getRecordId(record) {
  const id = record.id ?? record._id;
  if (id == null) {
    throw new DatasetError(
      'Dataset records must have an "id" or "_id" field for alignment'
    );
  }
  return String(id);
}
192
/**
 * Runs `modelFn` over the dataset in fixed-size batches of concurrent
 * calls (default 10), then aligns predictions with their source records
 * by ID. Output order matches the input record order.
 * @param concurrency - maximum number of in-flight modelFn calls per batch
 * @throws {DatasetError} on missing record IDs or prediction ID mismatch.
 */
async function runModelParallel(dataset, modelFn, concurrency = 10) {
  const startedAt = Date.now();
  const { records } = dataset;
  const pairs = [];
  // Phase 1: run the model batch by batch, keeping each prediction with its record.
  for (let offset = 0; offset < records.length; offset += concurrency) {
    const batch = records.slice(offset, offset + concurrency);
    const settled = await Promise.all(
      batch.map(async (record) => ({ prediction: await modelFn(record), record }))
    );
    pairs.push(...settled);
  }
  // Phase 2: validate IDs and build the aligned output.
  const predictions = [];
  const aligned = [];
  for (const { prediction, record } of pairs) {
    const id = getRecordId(record);
    if (prediction.id !== id) {
      throw new DatasetError(
        `Prediction ID mismatch: expected "${id}", got "${prediction.id}".`
      );
    }
    predictions.push(prediction);
    aligned.push({ id, actual: { ...prediction }, expected: { ...record } });
  }
  return { predictions, aligned, duration: Date.now() - startedAt };
}
227
+
228
+ // src/dataset/alignment.ts
229
// src/dataset/alignment.ts
/**
 * Joins predictions to expected records on an ID field.
 * Non-strict (default): predictions with no match are kept with an
 * empty `expected` object. Strict: unmatched predictions are dropped
 * and an IntegrityError listing their IDs is thrown at the end.
 * @param options.strict - throw on unmatched predictions (default false)
 * @param options.idField - key to read from expected records (default "id"; falls back to _id)
 */
function alignByKey(predictions, expected, options = {}) {
  const { strict = false, idField = "id" } = options;
  // Later duplicates win, matching Map.set semantics.
  const byId = new Map(
    expected.map((record) => [String(record[idField] ?? record._id), record])
  );
  const aligned = [];
  const unmatched = [];
  for (const prediction of predictions) {
    const id = prediction.id;
    const match = byId.get(id);
    if (match) {
      aligned.push({ id, actual: { ...prediction }, expected: { ...match } });
      continue;
    }
    unmatched.push(id);
    if (!strict) {
      aligned.push({ id, actual: { ...prediction }, expected: {} });
    }
  }
  if (strict && unmatched.length > 0) {
    throw new IntegrityError(
      `${unmatched.length} prediction(s) have no matching expected record`,
      unmatched
    );
  }
  return aligned;
}
267
/**
 * Collects, per aligned record, the actual value, expected value and
 * record ID for one field, as three parallel arrays.
 */
function extractFieldValues(aligned, field) {
  return {
    actual: aligned.map((record) => record.actual[field]),
    expected: aligned.map((record) => record.expected[field]),
    ids: aligned.map((record) => record.id),
  };
}
278
/**
 * Keeps only aligned records where both actual and expected define
 * `field` (i.e. neither side is undefined).
 */
function filterComplete(aligned, field) {
  return aligned.filter(
    (record) =>
      record.actual[field] !== undefined && record.expected[field] !== undefined
  );
}
285
+
286
+ // src/dataset/integrity.ts
287
// src/dataset/integrity.ts
/**
 * Validates a dataset: every record needs an id/_id, IDs must be
 * unique, and (optionally) every record must define a set of required
 * fields. Falsy array slots are skipped.
 * @param options.requiredFields - field names every record must define
 * @param options.throwOnFailure - throw an IntegrityError summary instead of
 *   returning an invalid result
 * @returns { valid, totalRecords, missingIds, duplicateIds, missingFields }
 */
function checkIntegrity(dataset, options = {}) {
  const { requiredFields = [], throwOnFailure = false } = options;
  const firstIndexById = new Map();
  const missingIds = [];
  const duplicateIds = [];
  const missingFields = [];
  dataset.records.forEach((record, index) => {
    if (!record) return;
    const id = record.id ?? record._id;
    if (id == null) {
      missingIds.push(`record[${index}]`);
    } else {
      const key = String(id);
      if (firstIndexById.has(key)) {
        // Every repeat occurrence is reported, not just the first.
        duplicateIds.push(key);
      } else {
        firstIndexById.set(key, index);
      }
    }
    if (requiredFields.length > 0) {
      const absent = requiredFields.filter((field) => record[field] === undefined);
      if (absent.length > 0) {
        // Records without an id are labelled by position.
        missingFields.push({ id: String(id ?? `record[${index}]`), fields: absent });
      }
    }
  });
  const valid =
    missingIds.length === 0 && duplicateIds.length === 0 && missingFields.length === 0;
  if (throwOnFailure && !valid) {
    const issues = [];
    if (missingIds.length > 0) {
      issues.push(`${missingIds.length} record(s) missing ID`);
    }
    if (duplicateIds.length > 0) {
      const preview = duplicateIds.slice(0, 3).join(", ");
      const ellipsis = duplicateIds.length > 3 ? "..." : "";
      issues.push(`${duplicateIds.length} duplicate ID(s): ${preview}${ellipsis}`);
    }
    if (missingFields.length > 0) {
      issues.push(`${missingFields.length} record(s) missing required fields`);
    }
    throw new IntegrityError(`Dataset integrity check failed: ${issues.join("; ")}`);
  }
  return {
    valid,
    totalRecords: dataset.records.length,
    missingIds,
    duplicateIds,
    missingFields,
  };
}
343
/**
 * Checks prediction IDs against the expected ID list.
 * @returns { valid, missing, extra } — `missing` are expected IDs with
 *   no prediction, `extra` are prediction IDs not in the expected set.
 */
function validatePredictions(predictions, expectedIds) {
  const produced = new Set();
  for (const p of predictions) {
    produced.add(p.id);
  }
  const wanted = new Set(expectedIds);
  const missing = expectedIds.filter((id) => !produced.has(id));
  const extra = [];
  for (const p of predictions) {
    if (!wanted.has(p.id)) {
      extra.push(p.id);
    }
  }
  return {
    valid: missing.length === 0 && extra.length === 0,
    missing,
    extra,
  };
}
354
+
355
+ // src/statistics/classification.ts
356
// src/statistics/classification.ts
/**
 * Computes full classification metrics (accuracy, per-class P/R/F1,
 * macro and weighted averages) from parallel actual/expected labels.
 */
function computeClassificationMetrics(actual, expected) {
  return computeMetricsFromMatrix(buildConfusionMatrix(actual, expected));
}
360
/**
 * Derives classification metrics from a confusion matrix.
 *
 * Per class: precision = TP/(TP+FP), recall = TP/(TP+FN),
 * F1 = harmonic mean of precision and recall (each defaults to 0 when
 * its denominator is 0). Accuracy is total TP over total support.
 * Macro averages weight every class equally; weighted averages weight
 * each class by its support.
 *
 * Improvement over the previous version: macro/weighted sums are
 * accumulated in the same pass as the per-class metrics instead of
 * re-scanning `Object.values(perClass)` six times; the math and the
 * returned shape are unchanged.
 */
function computeMetricsFromMatrix(cm) {
  const perClass = {};
  let totalSupport = 0;
  let correctPredictions = 0;
  const macroSum = { precision: 0, recall: 0, f1: 0 };
  const weightedSum = { precision: 0, recall: 0, f1: 0 };
  for (const label of cm.labels) {
    const tp = getTruePositives(cm, label);
    const fp = getFalsePositives(cm, label);
    const fn = getFalseNegatives(cm, label);
    const support = getSupport(cm, label);
    const precision = tp + fp > 0 ? tp / (tp + fp) : 0;
    const recall = tp + fn > 0 ? tp / (tp + fn) : 0;
    const f1 = precision + recall > 0 ? 2 * precision * recall / (precision + recall) : 0;
    perClass[label] = { precision, recall, f1, support };
    totalSupport += support;
    correctPredictions += tp;
    macroSum.precision += precision;
    macroSum.recall += recall;
    macroSum.f1 += f1;
    weightedSum.precision += precision * support;
    weightedSum.recall += recall * support;
    weightedSum.f1 += f1 * support;
  }
  const classCount = cm.labels.length;
  const accuracy = totalSupport > 0 ? correctPredictions / totalSupport : 0;
  const macroAvg = {
    precision: classCount > 0 ? macroSum.precision / classCount : 0,
    recall: classCount > 0 ? macroSum.recall / classCount : 0,
    f1: classCount > 0 ? macroSum.f1 / classCount : 0
  };
  const weightedAvg = {
    precision: totalSupport > 0 ? weightedSum.precision / totalSupport : 0,
    recall: totalSupport > 0 ? weightedSum.recall / totalSupport : 0,
    f1: totalSupport > 0 ? weightedSum.f1 / totalSupport : 0
  };
  return {
    accuracy,
    perClass,
    macroAvg,
    weightedAvg,
    confusionMatrix: cm
  };
}
396
/**
 * Precision for `targetClass`: TP / (TP + FP).
 * Returns 0 when the class was never predicted (denominator 0).
 */
function computePrecision(actual, expected, targetClass) {
  const matrix = buildConfusionMatrix(actual, expected);
  const truePos = getTruePositives(matrix, targetClass);
  const falsePos = getFalsePositives(matrix, targetClass);
  const denominator = truePos + falsePos;
  return denominator > 0 ? truePos / denominator : 0;
}
402
/**
 * Recall for `targetClass`: TP / (TP + FN).
 * Returns 0 when the class never occurs in the expected labels.
 */
function computeRecall(actual, expected, targetClass) {
  const matrix = buildConfusionMatrix(actual, expected);
  const truePos = getTruePositives(matrix, targetClass);
  const falseNeg = getFalseNegatives(matrix, targetClass);
  const denominator = truePos + falseNeg;
  return denominator > 0 ? truePos / denominator : 0;
}
408
/**
 * F1 for `targetClass`: harmonic mean of precision and recall.
 * Returns 0 when both are 0 (avoids division by zero).
 */
function computeF1(actual, expected, targetClass) {
  const p = computePrecision(actual, expected, targetClass);
  const r = computeRecall(actual, expected, targetClass);
  if (p + r === 0) {
    return 0;
  }
  return (2 * p * r) / (p + r);
}
413
/**
 * Fraction of pairwise matches between actual and expected labels.
 * Pairs where either side is null/undefined are excluded from the
 * denominator; values are compared after stringification. Returns 0
 * for empty input or mismatched array lengths.
 */
function computeAccuracy(actual, expected) {
  if (actual.length === 0 || actual.length !== expected.length) {
    return 0;
  }
  let comparable = 0;
  let matches = 0;
  actual.forEach((a, i) => {
    const e = expected[i];
    if (a == null || e == null) {
      return;
    }
    comparable += 1;
    if (String(a) === String(e)) {
      matches += 1;
    }
  });
  return comparable > 0 ? matches / comparable : 0;
}
431
+
432
+ // src/statistics/distribution.ts
433
// src/statistics/distribution.ts
/**
 * Keeps only real numeric values: entries of type "number" that are
 * not NaN (Infinity is kept, matching prior behavior).
 * The typeof check already rules out null and undefined, so no extra
 * null/undefined guards are needed.
 */
function filterNumericValues(values) {
  return values.filter((v) => typeof v === "number" && !Number.isNaN(v));
}
438
/**
 * Fraction (0-1) of values that are <= threshold; 0 for empty input.
 * Note: the comparison is inclusive despite the "below" name.
 */
function calculatePercentageBelow(values, threshold) {
  const total = values.length;
  if (total === 0) {
    return 0;
  }
  let count = 0;
  for (const value of values) {
    if (value <= threshold) {
      count += 1;
    }
  }
  return count / total;
}
445
/**
 * Fraction (0-1) of values strictly greater than threshold; 0 for
 * empty input. Complements calculatePercentageBelow's inclusive <=.
 */
function calculatePercentageAbove(values, threshold) {
  const total = values.length;
  if (total === 0) {
    return 0;
  }
  let count = 0;
  for (const value of values) {
    if (value > threshold) {
      count += 1;
    }
  }
  return count / total;
}
452
+
453
+ // src/assertions/binarize.ts
454
// src/assertions/binarize.ts
/**
 * Fluent assertion helper for a field whose values have been binarized
 * against a numeric threshold. Each actual/expected value is mapped to
 * the string label "true" or "false": numbers become "true" when
 * value >= threshold; booleans and all other values are stringified.
 * Standard classification assertions then run over these labels.
 *
 * Instances are created via FieldSelector.binarize(threshold); every
 * assertion method records its result via recordAssertion, returns
 * `this` for chaining, and throws AssertionError on failure.
 */
var BinarizeSelector = class {
  // Name of the field being asserted on; used in messages and reports.
  fieldName;
  // Numeric cutoff used to binarize number-typed values.
  threshold;
  // Binarized actual labels ("true"/"false", or String(value) fallback).
  binaryActual;
  // Binarized expected labels, same encoding as binaryActual.
  binaryExpected;
  // Every assertion result this selector has produced, pass or fail.
  assertions = [];
  /**
   * @param aligned - aligned records ({ id, actual, expected } per record)
   * @param fieldName - field to read from each record's actual/expected
   * @param threshold - numbers >= threshold map to "true", else "false"
   */
  constructor(aligned, fieldName, threshold) {
    this.fieldName = fieldName;
    this.threshold = threshold;
    this.binaryActual = [];
    this.binaryExpected = [];
    for (const record of aligned) {
      const actualVal = record.actual[fieldName];
      const expectedVal = record.expected[fieldName];
      if (typeof actualVal === "number") {
        this.binaryActual.push(actualVal >= threshold ? "true" : "false");
      } else if (typeof actualVal === "boolean") {
        // NOTE(review): this branch and the fallback both call String(value),
        // so the explicit boolean case is redundant but harmless.
        this.binaryActual.push(String(actualVal));
      } else {
        this.binaryActual.push(String(actualVal));
      }
      if (typeof expectedVal === "number") {
        this.binaryExpected.push(expectedVal >= threshold ? "true" : "false");
      } else if (typeof expectedVal === "boolean") {
        this.binaryExpected.push(String(expectedVal));
      } else {
        // Missing values become the literal string "undefined".
        this.binaryExpected.push(String(expectedVal));
      }
    }
  }
  /**
   * Asserts that accuracy is above a threshold.
   * Note: "above" is inclusive — passes when accuracy equals the threshold.
   * @throws {AssertionError} when the assertion fails.
   */
  toHaveAccuracyAbove(threshold) {
    const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
    const passed = metrics.accuracy >= threshold;
    const result = {
      type: "accuracy",
      passed,
      message: passed ? `Accuracy ${(metrics.accuracy * 100).toFixed(1)}% is above ${(threshold * 100).toFixed(1)}% (binarized at ${this.threshold})` : `Accuracy ${(metrics.accuracy * 100).toFixed(1)}% is below threshold ${(threshold * 100).toFixed(1)}% (binarized at ${this.threshold})`,
      expected: threshold,
      actual: metrics.accuracy,
      field: this.fieldName
    };
    this.assertions.push(result);
    // Record the result for reporting even when the assertion fails.
    recordAssertion(result);
    if (!passed) {
      throw new AssertionError(result.message, threshold, metrics.accuracy, this.fieldName);
    }
    return this;
  }
  /**
   * Asserts that precision is above a threshold (inclusive).
   * With only a number argument, the macro-average precision is used;
   * with a class plus a threshold, that class's precision is used.
   * @param classOrThreshold - Either the class (true/false) or threshold
   * @param threshold - Threshold when class is specified
   * @throws {AssertionError} when the class is unknown or the assertion fails.
   */
  toHavePrecisionAbove(classOrThreshold, threshold) {
    const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
    let actualPrecision;
    let targetClass;
    let actualThreshold;
    if (typeof classOrThreshold === "number") {
      actualPrecision = metrics.macroAvg.precision;
      actualThreshold = classOrThreshold;
    } else {
      // Class may be passed as a boolean; normalize to the string label.
      targetClass = String(classOrThreshold);
      actualThreshold = threshold;
      const classMetrics = metrics.perClass[targetClass];
      if (!classMetrics) {
        throw new AssertionError(
          `Class "${targetClass}" not found in binarized predictions`,
          targetClass,
          Object.keys(metrics.perClass),
          this.fieldName
        );
      }
      actualPrecision = classMetrics.precision;
    }
    const passed = actualPrecision >= actualThreshold;
    const result = {
      type: "precision",
      passed,
      message: passed ? `Precision${targetClass ? ` for ${targetClass}` : ""} ${(actualPrecision * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Precision${targetClass ? ` for ${targetClass}` : ""} ${(actualPrecision * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
      expected: actualThreshold,
      actual: actualPrecision,
      field: this.fieldName,
      class: targetClass
    };
    this.assertions.push(result);
    recordAssertion(result);
    if (!passed) {
      throw new AssertionError(result.message, actualThreshold, actualPrecision, this.fieldName);
    }
    return this;
  }
  /**
   * Asserts that recall is above a threshold (inclusive).
   * With only a number argument, the macro-average recall is used;
   * with a class plus a threshold, that class's recall is used.
   * @param classOrThreshold - Either the class (true/false) or threshold
   * @param threshold - Threshold when class is specified
   * @throws {AssertionError} when the class is unknown or the assertion fails.
   */
  toHaveRecallAbove(classOrThreshold, threshold) {
    const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
    let actualRecall;
    let targetClass;
    let actualThreshold;
    if (typeof classOrThreshold === "number") {
      actualRecall = metrics.macroAvg.recall;
      actualThreshold = classOrThreshold;
    } else {
      targetClass = String(classOrThreshold);
      actualThreshold = threshold;
      const classMetrics = metrics.perClass[targetClass];
      if (!classMetrics) {
        throw new AssertionError(
          `Class "${targetClass}" not found in binarized predictions`,
          targetClass,
          Object.keys(metrics.perClass),
          this.fieldName
        );
      }
      actualRecall = classMetrics.recall;
    }
    const passed = actualRecall >= actualThreshold;
    const result = {
      type: "recall",
      passed,
      message: passed ? `Recall${targetClass ? ` for ${targetClass}` : ""} ${(actualRecall * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Recall${targetClass ? ` for ${targetClass}` : ""} ${(actualRecall * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
      expected: actualThreshold,
      actual: actualRecall,
      field: this.fieldName,
      class: targetClass
    };
    this.assertions.push(result);
    recordAssertion(result);
    if (!passed) {
      throw new AssertionError(result.message, actualThreshold, actualRecall, this.fieldName);
    }
    return this;
  }
  /**
   * Asserts that F1 score is above a threshold (inclusive).
   * With only a number argument, the macro-average F1 is used;
   * with a class plus a threshold, that class's F1 is used.
   * @throws {AssertionError} when the class is unknown or the assertion fails.
   */
  toHaveF1Above(classOrThreshold, threshold) {
    const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
    let actualF1;
    let targetClass;
    let actualThreshold;
    if (typeof classOrThreshold === "number") {
      actualF1 = metrics.macroAvg.f1;
      actualThreshold = classOrThreshold;
    } else {
      targetClass = String(classOrThreshold);
      actualThreshold = threshold;
      const classMetrics = metrics.perClass[targetClass];
      if (!classMetrics) {
        throw new AssertionError(
          `Class "${targetClass}" not found in binarized predictions`,
          targetClass,
          Object.keys(metrics.perClass),
          this.fieldName
        );
      }
      actualF1 = classMetrics.f1;
    }
    const passed = actualF1 >= actualThreshold;
    const result = {
      type: "f1",
      passed,
      message: passed ? `F1${targetClass ? ` for ${targetClass}` : ""} ${(actualF1 * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `F1${targetClass ? ` for ${targetClass}` : ""} ${(actualF1 * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
      expected: actualThreshold,
      actual: actualF1,
      field: this.fieldName,
      class: targetClass
    };
    this.assertions.push(result);
    recordAssertion(result);
    if (!passed) {
      throw new AssertionError(result.message, actualThreshold, actualF1, this.fieldName);
    }
    return this;
  }
  /**
   * Includes the confusion matrix in the report.
   * Never fails: records the field metrics (flagged as binarized, with
   * the binarize threshold) plus an always-passing assertion result.
   */
  toHaveConfusionMatrix() {
    const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
    const fieldResult = {
      field: this.fieldName,
      metrics,
      binarized: true,
      binarizeThreshold: this.threshold
    };
    recordFieldMetrics(fieldResult);
    const result = {
      type: "confusionMatrix",
      passed: true,
      message: `Confusion matrix recorded for binarized field "${this.fieldName}" (threshold: ${this.threshold})`,
      field: this.fieldName
    };
    this.assertions.push(result);
    recordAssertion(result);
    return this;
  }
  /**
   * Gets computed metrics for the binarized labels.
   * Recomputed on every call; nothing is recorded or asserted.
   */
  getMetrics() {
    return computeClassificationMetrics(this.binaryActual, this.binaryExpected);
  }
  /**
   * Gets all assertions made so far through this selector.
   */
  getAssertions() {
    return this.assertions;
  }
};
670
+
671
+ // src/assertions/field-selector.ts
672
+ var FieldSelector = class {
673
+ aligned;
674
+ fieldName;
675
+ actualValues;
676
+ expectedValues;
677
+ assertions = [];
678
+ constructor(aligned, fieldName) {
679
+ this.aligned = aligned;
680
+ this.fieldName = fieldName;
681
+ const extracted = extractFieldValues(aligned, fieldName);
682
+ this.actualValues = extracted.actual;
683
+ this.expectedValues = extracted.expected;
684
+ }
685
+ /**
686
+ * Transforms continuous scores to binary classification using a threshold
687
+ */
688
+ binarize(threshold) {
689
+ return new BinarizeSelector(this.aligned, this.fieldName, threshold);
690
+ }
691
+ /**
692
+ * Validates that ground truth exists for classification metrics.
693
+ * Throws a clear error if expected values are missing.
694
+ */
695
+ validateGroundTruth() {
696
+ const hasExpected = this.expectedValues.some(
697
+ (v) => v !== void 0 && v !== null
698
+ );
699
+ if (!hasExpected) {
700
+ throw new AssertionError(
701
+ `Classification metric requires ground truth, but field "${this.fieldName}" has no expected values. Use expectStats(predictions, groundTruth) to provide expected values.`,
702
+ void 0,
703
+ void 0,
704
+ this.fieldName
705
+ );
706
+ }
707
+ }
708
+ /**
709
+ * Asserts that accuracy is above a threshold
710
+ */
711
+ toHaveAccuracyAbove(threshold) {
712
+ this.validateGroundTruth();
713
+ const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
714
+ const passed = metrics.accuracy >= threshold;
715
+ const result = {
716
+ type: "accuracy",
717
+ passed,
718
+ message: passed ? `Accuracy ${(metrics.accuracy * 100).toFixed(1)}% is above ${(threshold * 100).toFixed(1)}%` : `Accuracy ${(metrics.accuracy * 100).toFixed(1)}% is below threshold ${(threshold * 100).toFixed(1)}%`,
719
+ expected: threshold,
720
+ actual: metrics.accuracy,
721
+ field: this.fieldName
722
+ };
723
+ this.assertions.push(result);
724
+ recordAssertion(result);
725
+ if (!passed) {
726
+ throw new AssertionError(result.message, threshold, metrics.accuracy, this.fieldName);
727
+ }
728
+ return this;
729
+ }
730
+ /**
731
+ * Asserts that precision is above a threshold
732
+ * @param classOrThreshold - Either the class name or threshold (if class is omitted, uses macro average)
733
+ * @param threshold - Threshold when class is specified
734
+ */
735
+ toHavePrecisionAbove(classOrThreshold, threshold) {
736
+ this.validateGroundTruth();
737
+ const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
738
+ let actualPrecision;
739
+ let targetClass;
740
+ let actualThreshold;
741
+ if (typeof classOrThreshold === "number") {
742
+ actualPrecision = metrics.macroAvg.precision;
743
+ actualThreshold = classOrThreshold;
744
+ } else {
745
+ targetClass = classOrThreshold;
746
+ actualThreshold = threshold;
747
+ const classMetrics = metrics.perClass[targetClass];
748
+ if (!classMetrics) {
749
+ throw new AssertionError(
750
+ `Class "${targetClass}" not found in predictions`,
751
+ targetClass,
752
+ Object.keys(metrics.perClass),
753
+ this.fieldName
754
+ );
755
+ }
756
+ actualPrecision = classMetrics.precision;
757
+ }
758
+ const passed = actualPrecision >= actualThreshold;
759
+ const result = {
760
+ type: "precision",
761
+ passed,
762
+ message: passed ? `Precision${targetClass ? ` for "${targetClass}"` : ""} ${(actualPrecision * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Precision${targetClass ? ` for "${targetClass}"` : ""} ${(actualPrecision * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
763
+ expected: actualThreshold,
764
+ actual: actualPrecision,
765
+ field: this.fieldName,
766
+ class: targetClass
767
+ };
768
+ this.assertions.push(result);
769
+ recordAssertion(result);
770
+ if (!passed) {
771
+ throw new AssertionError(result.message, actualThreshold, actualPrecision, this.fieldName);
772
+ }
773
+ return this;
774
+ }
775
+ /**
776
+ * Asserts that recall is above a threshold
777
+ * @param classOrThreshold - Either the class name or threshold (if class is omitted, uses macro average)
778
+ * @param threshold - Threshold when class is specified
779
+ */
780
+ toHaveRecallAbove(classOrThreshold, threshold) {
781
+ this.validateGroundTruth();
782
+ const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
783
+ let actualRecall;
784
+ let targetClass;
785
+ let actualThreshold;
786
+ if (typeof classOrThreshold === "number") {
787
+ actualRecall = metrics.macroAvg.recall;
788
+ actualThreshold = classOrThreshold;
789
+ } else {
790
+ targetClass = classOrThreshold;
791
+ actualThreshold = threshold;
792
+ const classMetrics = metrics.perClass[targetClass];
793
+ if (!classMetrics) {
794
+ throw new AssertionError(
795
+ `Class "${targetClass}" not found in predictions`,
796
+ targetClass,
797
+ Object.keys(metrics.perClass),
798
+ this.fieldName
799
+ );
800
+ }
801
+ actualRecall = classMetrics.recall;
802
+ }
803
+ const passed = actualRecall >= actualThreshold;
804
+ const result = {
805
+ type: "recall",
806
+ passed,
807
+ message: passed ? `Recall${targetClass ? ` for "${targetClass}"` : ""} ${(actualRecall * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Recall${targetClass ? ` for "${targetClass}"` : ""} ${(actualRecall * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
808
+ expected: actualThreshold,
809
+ actual: actualRecall,
810
+ field: this.fieldName,
811
+ class: targetClass
812
+ };
813
+ this.assertions.push(result);
814
+ recordAssertion(result);
815
+ if (!passed) {
816
+ throw new AssertionError(result.message, actualThreshold, actualRecall, this.fieldName);
817
+ }
818
+ return this;
819
+ }
820
+ /**
821
+ * Asserts that F1 score is above a threshold
822
+ * @param classOrThreshold - Either the class name or threshold (if class is omitted, uses macro average)
823
+ * @param threshold - Threshold when class is specified
824
+ */
825
+ toHaveF1Above(classOrThreshold, threshold) {
826
+ this.validateGroundTruth();
827
+ const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
828
+ let actualF1;
829
+ let targetClass;
830
+ let actualThreshold;
831
+ if (typeof classOrThreshold === "number") {
832
+ actualF1 = metrics.macroAvg.f1;
833
+ actualThreshold = classOrThreshold;
834
+ } else {
835
+ targetClass = classOrThreshold;
836
+ actualThreshold = threshold;
837
+ const classMetrics = metrics.perClass[targetClass];
838
+ if (!classMetrics) {
839
+ throw new AssertionError(
840
+ `Class "${targetClass}" not found in predictions`,
841
+ targetClass,
842
+ Object.keys(metrics.perClass),
843
+ this.fieldName
844
+ );
845
+ }
846
+ actualF1 = classMetrics.f1;
847
+ }
848
+ const passed = actualF1 >= actualThreshold;
849
+ const result = {
850
+ type: "f1",
851
+ passed,
852
+ message: passed ? `F1${targetClass ? ` for "${targetClass}"` : ""} ${(actualF1 * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `F1${targetClass ? ` for "${targetClass}"` : ""} ${(actualF1 * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
853
+ expected: actualThreshold,
854
+ actual: actualF1,
855
+ field: this.fieldName,
856
+ class: targetClass
857
+ };
858
+ this.assertions.push(result);
859
+ recordAssertion(result);
860
+ if (!passed) {
861
+ throw new AssertionError(result.message, actualThreshold, actualF1, this.fieldName);
862
+ }
863
+ return this;
864
+ }
865
+ /**
866
+ * Includes the confusion matrix in the report
867
+ */
868
+ toHaveConfusionMatrix() {
869
+ const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
870
+ const fieldResult = {
871
+ field: this.fieldName,
872
+ metrics,
873
+ binarized: false
874
+ };
875
+ recordFieldMetrics(fieldResult);
876
+ const result = {
877
+ type: "confusionMatrix",
878
+ passed: true,
879
+ message: `Confusion matrix recorded for field "${this.fieldName}"`,
880
+ field: this.fieldName
881
+ };
882
+ this.assertions.push(result);
883
+ recordAssertion(result);
884
+ return this;
885
+ }
886
+ /**
887
+ * Asserts that a percentage of values are below or equal to a threshold.
888
+ * This is a distributional assertion that only looks at actual values (no ground truth required).
889
+ *
890
+ * @param valueThreshold - The value threshold to compare against
891
+ * @param percentageThreshold - The minimum percentage (0-1) of values that should be <= valueThreshold
892
+ * @returns this for method chaining
893
+ *
894
+ * @example
895
+ * // Assert that 90% of confidence scores are below 0.5
896
+ * expectStats(predictions)
897
+ * .field("confidence")
898
+ * .toHavePercentageBelow(0.5, 0.9)
899
+ */
900
+ toHavePercentageBelow(valueThreshold, percentageThreshold) {
901
+ const numericActual = filterNumericValues(this.actualValues);
902
+ if (numericActual.length === 0) {
903
+ throw new AssertionError(
904
+ `Field '${this.fieldName}' contains no numeric values (found 0 numeric out of ${this.actualValues.length} total values)`,
905
+ percentageThreshold,
906
+ void 0,
907
+ this.fieldName
908
+ );
909
+ }
910
+ const actualPercentage = calculatePercentageBelow(numericActual, valueThreshold);
911
+ const passed = actualPercentage >= percentageThreshold;
912
+ const result = {
913
+ type: "percentageBelow",
914
+ passed,
915
+ message: passed ? `${(actualPercentage * 100).toFixed(1)}% of '${this.fieldName}' values are below or equal to ${valueThreshold} (expected >= ${(percentageThreshold * 100).toFixed(1)}%)` : `Only ${(actualPercentage * 100).toFixed(1)}% of '${this.fieldName}' values are below or equal to ${valueThreshold} (expected >= ${(percentageThreshold * 100).toFixed(1)}%)`,
916
+ expected: percentageThreshold,
917
+ actual: actualPercentage,
918
+ field: this.fieldName
919
+ };
920
+ this.assertions.push(result);
921
+ recordAssertion(result);
922
+ if (!passed) {
923
+ throw new AssertionError(result.message, percentageThreshold, actualPercentage, this.fieldName);
924
+ }
925
+ return this;
926
+ }
927
+ /**
928
+ * Asserts that a percentage of values are above a threshold.
929
+ * This is a distributional assertion that only looks at actual values (no ground truth required).
930
+ *
931
+ * @param valueThreshold - The value threshold to compare against
932
+ * @param percentageThreshold - The minimum percentage (0-1) of values that should be > valueThreshold
933
+ * @returns this for method chaining
934
+ *
935
+ * @example
936
+ * // Assert that 80% of quality scores are above 0.7
937
+ * expectStats(predictions)
938
+ * .field("quality")
939
+ * .toHavePercentageAbove(0.7, 0.8)
940
+ */
941
+ toHavePercentageAbove(valueThreshold, percentageThreshold) {
942
+ const numericActual = filterNumericValues(this.actualValues);
943
+ if (numericActual.length === 0) {
944
+ throw new AssertionError(
945
+ `Field '${this.fieldName}' contains no numeric values (found 0 numeric out of ${this.actualValues.length} total values)`,
946
+ percentageThreshold,
947
+ void 0,
948
+ this.fieldName
949
+ );
950
+ }
951
+ const actualPercentage = calculatePercentageAbove(numericActual, valueThreshold);
952
+ const passed = actualPercentage >= percentageThreshold;
953
+ const result = {
954
+ type: "percentageAbove",
955
+ passed,
956
+ message: passed ? `${(actualPercentage * 100).toFixed(1)}% of '${this.fieldName}' values are above ${valueThreshold} (expected >= ${(percentageThreshold * 100).toFixed(1)}%)` : `Only ${(actualPercentage * 100).toFixed(1)}% of '${this.fieldName}' values are above ${valueThreshold} (expected >= ${(percentageThreshold * 100).toFixed(1)}%)`,
957
+ expected: percentageThreshold,
958
+ actual: actualPercentage,
959
+ field: this.fieldName
960
+ };
961
+ this.assertions.push(result);
962
+ recordAssertion(result);
963
+ if (!passed) {
964
+ throw new AssertionError(result.message, percentageThreshold, actualPercentage, this.fieldName);
965
+ }
966
+ return this;
967
+ }
968
+ /**
969
+ * Gets the computed metrics for this field
970
+ */
971
+ getMetrics() {
972
+ return computeClassificationMetrics(this.actualValues, this.expectedValues);
973
+ }
974
+ /**
975
+ * Gets all assertions made on this field
976
+ */
977
+ getAssertions() {
978
+ return this.assertions;
979
+ }
980
+ };
981
+
982
// src/assertions/expect-stats.ts
/**
 * Coerces any accepted expectStats() input into an array of aligned records.
 *
 * Accepted shapes:
 *  - a ModelRunResult-like object with an `aligned` array → that array is returned
 *  - AlignedRecord[] (first element has `actual` and `expected`) → returned as-is
 *  - Prediction[] → each prediction is wrapped as { id, actual: {...p}, expected: {} }
 *
 * @param input - ModelRunResult, Prediction[], or AlignedRecord[]
 * @returns the aligned records
 * @throws Error for null/undefined or any other unrecognized shape
 */
function normalizeInput(input) {
  // Guard before using `in`: applying `in` to null or a primitive throws a
  // raw TypeError instead of the descriptive Error at the bottom.
  if (input !== null && typeof input === "object" && "aligned" in input && Array.isArray(input.aligned)) {
    return input.aligned;
  }
  if (Array.isArray(input)) {
    if (input.length === 0) {
      return [];
    }
    const first = input[0];
    // Already-aligned records pass through untouched. The typeof guard keeps
    // arrays of primitives from crashing on the `in` operator.
    if (first !== null && typeof first === "object" && "actual" in first && "expected" in first) {
      return input;
    }
    // Bare predictions: the whole object becomes the "actual" side; there is
    // no ground truth, so "expected" is left empty.
    return input.map((p) => ({
      id: p.id,
      actual: { ...p },
      expected: {}
    }));
  }
  throw new Error("Invalid input to expectStats(): expected ModelRunResult, Prediction[], or AlignedRecord[]");
}
1003
/**
 * Entry point for statistical assertions over model output.
 *
 * One-argument form: accepts a ModelRunResult, Prediction[], or
 * AlignedRecord[]. Two-argument form: aligns a Prediction[] against an
 * expected dataset by key.
 *
 * @param inputOrActual - The predictions or pre-aligned input
 * @param expected - Optional expected records for the two-argument form
 * @returns an ExpectStats instance for fluent chaining
 * @throws Error when the two-argument form gets a non-array first argument
 */
function expectStats(inputOrActual, expected) {
  // One-argument form: normalize whatever shape was handed in.
  if (expected === void 0) {
    return new ExpectStats(normalizeInput(inputOrActual));
  }
  if (!Array.isArray(inputOrActual)) {
    throw new Error(
      "When using two-argument expectStats(), first argument must be Prediction[]"
    );
  }
  return new ExpectStats(alignByKey(inputOrActual, expected));
}
1016
var ExpectStats = class {
  // Aligned actual/expected record pairs this instance evaluates.
  aligned;
  /**
   * @param aligned - Pre-aligned records to evaluate
   */
  constructor(aligned) {
    this.aligned = aligned;
  }
  /**
   * Begins a fluent assertion chain scoped to one field of the records.
   *
   * @param fieldName - Name of the field to evaluate
   * @returns a FieldSelector bound to that field
   */
  field(fieldName) {
    const { aligned } = this;
    return new FieldSelector(aligned, fieldName);
  }
  /**
   * Exposes the underlying aligned records for advanced use.
   * The live internal array is returned, not a copy.
   */
  getAligned() {
    return this.aligned;
  }
  /**
   * Number of aligned records held by this instance.
   */
  count() {
    const { aligned } = this;
    return aligned.length;
  }
};
1040
+
1041
+ export { afterAll, afterEach, alignByKey, beforeAll, beforeEach, checkIntegrity, computeAccuracy, computeClassificationMetrics, computeF1, computePrecision, computeRecall, createDataset, describe, evalTest, expectStats, extractFieldValues, filterComplete, it, loadDataset, runModel, runModelParallel, test, validatePredictions };
1042
//# sourceMappingURL=index.js.map