evalsense 0.3.1 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. package/README.md +309 -149
  2. package/dist/{chunk-BE7CB3AM.cjs → chunk-4BKZPVY4.cjs} +50 -13
  3. package/dist/chunk-4BKZPVY4.cjs.map +1 -0
  4. package/dist/{chunk-K6QPJ2NO.js → chunk-IUVDDMJ3.js} +50 -13
  5. package/dist/chunk-IUVDDMJ3.js.map +1 -0
  6. package/dist/chunk-NCCQRZ2Y.cjs +1141 -0
  7. package/dist/chunk-NCCQRZ2Y.cjs.map +1 -0
  8. package/dist/chunk-TDGWDK2L.js +1108 -0
  9. package/dist/chunk-TDGWDK2L.js.map +1 -0
  10. package/dist/cli.cjs +11 -11
  11. package/dist/cli.js +1 -1
  12. package/dist/index-CATqAHNK.d.cts +416 -0
  13. package/dist/index-CoMpaW-K.d.ts +416 -0
  14. package/dist/index.cjs +507 -629
  15. package/dist/index.cjs.map +1 -1
  16. package/dist/index.d.cts +210 -161
  17. package/dist/index.d.ts +210 -161
  18. package/dist/index.js +455 -573
  19. package/dist/index.js.map +1 -1
  20. package/dist/metrics/index.cjs +103 -342
  21. package/dist/metrics/index.cjs.map +1 -1
  22. package/dist/metrics/index.d.cts +260 -31
  23. package/dist/metrics/index.d.ts +260 -31
  24. package/dist/metrics/index.js +24 -312
  25. package/dist/metrics/index.js.map +1 -1
  26. package/dist/metrics/opinionated/index.cjs +5 -5
  27. package/dist/metrics/opinionated/index.d.cts +2 -163
  28. package/dist/metrics/opinionated/index.d.ts +2 -163
  29. package/dist/metrics/opinionated/index.js +1 -1
  30. package/dist/{types-C71p0wzM.d.cts → types-D0hzfyKm.d.cts} +1 -13
  31. package/dist/{types-C71p0wzM.d.ts → types-D0hzfyKm.d.ts} +1 -13
  32. package/package.json +1 -1
  33. package/dist/chunk-BE7CB3AM.cjs.map +0 -1
  34. package/dist/chunk-K6QPJ2NO.js.map +0 -1
  35. package/dist/chunk-RZFLCWTW.cjs +0 -942
  36. package/dist/chunk-RZFLCWTW.cjs.map +0 -1
  37. package/dist/chunk-Z3U6AUWX.js +0 -925
  38. package/dist/chunk-Z3U6AUWX.js.map +0 -1
package/dist/index.js CHANGED
@@ -1,8 +1,6 @@
1
- import { getCurrentSuite, setCurrentSuite, addSuite, addTestToCurrentSuite, DatasetError, IntegrityError, buildConfusionMatrix, getTruePositives, getFalsePositives, getFalseNegatives, getSupport, AssertionError, recordAssertion, recordFieldMetrics } from './chunk-K6QPJ2NO.js';
2
- export { AssertionError, ConfigurationError, ConsoleReporter, DatasetError, EvalSenseError, ExitCodes, IntegrityError, JsonReporter, TestExecutionError, buildConfusionMatrix, discoverEvalFiles, executeEvalFiles, formatConfusionMatrix, getExitCode, parseReport } from './chunk-K6QPJ2NO.js';
1
+ import { getCurrentSuite, setCurrentSuite, addSuite, addTestToCurrentSuite, IntegrityError, buildConfusionMatrix, getTruePositives, getFalsePositives, getFalseNegatives, getSupport, AssertionError, recordFieldMetrics, recordAssertion } from './chunk-IUVDDMJ3.js';
2
+ export { AssertionError, ConfigurationError, ConsoleReporter, DatasetError, EvalSenseError, ExitCodes, IntegrityError, JsonReporter, TestExecutionError, buildConfusionMatrix, discoverEvalFiles, executeEvalFiles, formatConfusionMatrix, getExitCode, parseReport } from './chunk-IUVDDMJ3.js';
3
3
  import './chunk-DGUM43GV.js';
4
- import { readFileSync } from 'fs';
5
- import { resolve, extname } from 'path';
6
4
 
7
5
  // src/core/describe.ts
8
6
  function describe(name, fn) {
@@ -91,136 +89,6 @@ function evalTestOnly(name, fn) {
91
89
  }
92
90
  evalTest.skip = evalTestSkip;
93
91
  evalTest.only = evalTestOnly;
94
- function loadDataset(path) {
95
- const absolutePath = resolve(process.cwd(), path);
96
- const ext = extname(absolutePath).toLowerCase();
97
- let records;
98
- try {
99
- const content = readFileSync(absolutePath, "utf-8");
100
- if (ext === ".ndjson" || ext === ".jsonl") {
101
- records = parseNDJSON(content);
102
- } else if (ext === ".json") {
103
- records = parseJSON(content);
104
- } else {
105
- throw new DatasetError(
106
- `Unsupported file format: ${ext}. Use .json, .ndjson, or .jsonl`,
107
- path
108
- );
109
- }
110
- } catch (error) {
111
- if (error instanceof DatasetError) {
112
- throw error;
113
- }
114
- const message = error instanceof Error ? error.message : String(error);
115
- throw new DatasetError(`Failed to load dataset from ${path}: ${message}`, path);
116
- }
117
- return {
118
- records,
119
- metadata: {
120
- source: path,
121
- count: records.length,
122
- loadedAt: /* @__PURE__ */ new Date()
123
- }
124
- };
125
- }
126
- function parseJSON(content) {
127
- const parsed = JSON.parse(content);
128
- if (!Array.isArray(parsed)) {
129
- throw new DatasetError("JSON dataset must be an array of records");
130
- }
131
- return parsed;
132
- }
133
- function parseNDJSON(content) {
134
- const lines = content.split("\n").filter((line) => line.trim() !== "");
135
- const records = [];
136
- for (let i = 0; i < lines.length; i++) {
137
- const line = lines[i];
138
- if (line === void 0) continue;
139
- try {
140
- records.push(JSON.parse(line));
141
- } catch {
142
- throw new DatasetError(`Invalid JSON at line ${i + 1} in NDJSON file`);
143
- }
144
- }
145
- return records;
146
- }
147
- function createDataset(records, source = "inline") {
148
- return {
149
- records,
150
- metadata: {
151
- source,
152
- count: records.length,
153
- loadedAt: /* @__PURE__ */ new Date()
154
- }
155
- };
156
- }
157
-
158
- // src/dataset/run-model.ts
159
- async function runModel(dataset, modelFn) {
160
- const startTime = Date.now();
161
- const predictions = [];
162
- const aligned = [];
163
- for (const record of dataset.records) {
164
- const id = getRecordId(record);
165
- const prediction = await modelFn(record);
166
- if (prediction.id !== id) {
167
- throw new DatasetError(
168
- `Prediction ID mismatch: expected "${id}", got "${prediction.id}". Model function must return the same ID as the input record.`
169
- );
170
- }
171
- predictions.push(prediction);
172
- aligned.push({
173
- id,
174
- actual: { ...prediction },
175
- expected: { ...record }
176
- });
177
- }
178
- return {
179
- predictions,
180
- aligned,
181
- duration: Date.now() - startTime
182
- };
183
- }
184
- function getRecordId(record) {
185
- const id = record.id ?? record._id;
186
- if (id === void 0 || id === null) {
187
- throw new DatasetError('Dataset records must have an "id" or "_id" field for alignment');
188
- }
189
- return String(id);
190
- }
191
- async function runModelParallel(dataset, modelFn, concurrency = 10) {
192
- const startTime = Date.now();
193
- const results = [];
194
- for (let i = 0; i < dataset.records.length; i += concurrency) {
195
- const batch = dataset.records.slice(i, i + concurrency);
196
- const batchResults = await Promise.all(
197
- batch.map(async (record) => {
198
- const prediction = await modelFn(record);
199
- return { prediction, record };
200
- })
201
- );
202
- results.push(...batchResults);
203
- }
204
- const predictions = [];
205
- const aligned = [];
206
- for (const { prediction, record } of results) {
207
- const id = getRecordId(record);
208
- if (prediction.id !== id) {
209
- throw new DatasetError(`Prediction ID mismatch: expected "${id}", got "${prediction.id}".`);
210
- }
211
- predictions.push(prediction);
212
- aligned.push({
213
- id,
214
- actual: { ...prediction },
215
- expected: { ...record }
216
- });
217
- }
218
- return {
219
- predictions,
220
- aligned,
221
- duration: Date.now() - startTime
222
- };
223
- }
224
92
 
225
93
  // src/dataset/alignment.ts
226
94
  function alignByKey(predictions, expected, options = {}) {
@@ -293,14 +161,14 @@ function filterComplete(aligned, field) {
293
161
  }
294
162
 
295
163
  // src/dataset/integrity.ts
296
- function checkIntegrity(dataset, options = {}) {
164
+ function checkIntegrity(records, options = {}) {
297
165
  const { requiredFields = [], throwOnFailure = false } = options;
298
166
  const seenIds = /* @__PURE__ */ new Map();
299
167
  const missingIds = [];
300
168
  const duplicateIds = [];
301
169
  const missingFields = [];
302
- for (let i = 0; i < dataset.records.length; i++) {
303
- const record = dataset.records[i];
170
+ for (let i = 0; i < records.length; i++) {
171
+ const record = records[i];
304
172
  if (!record) continue;
305
173
  const id = record.id ?? record._id;
306
174
  if (id === void 0 || id === null) {
@@ -327,7 +195,7 @@ function checkIntegrity(dataset, options = {}) {
327
195
  const valid = missingIds.length === 0 && duplicateIds.length === 0 && missingFields.length === 0;
328
196
  const result = {
329
197
  valid,
330
- totalRecords: dataset.records.length,
198
+ totalRecords: records.length,
331
199
  missingIds,
332
200
  duplicateIds,
333
201
  missingFields
@@ -520,6 +388,91 @@ function calculatePercentageAbove(values, threshold) {
520
388
  return countAbove / values.length;
521
389
  }
522
390
 
391
+ // src/assertions/metric-matcher.ts
392
+ var MetricMatcher = class {
393
+ context;
394
+ constructor(context) {
395
+ this.context = context;
396
+ }
397
+ formatMetricValue(value) {
398
+ if (this.context.formatValue) {
399
+ return this.context.formatValue(value);
400
+ }
401
+ if (value >= 0 && value <= 1) {
402
+ return `${(value * 100).toFixed(1)}%`;
403
+ }
404
+ return value.toFixed(4);
405
+ }
406
+ createAssertion(operator, threshold, passed) {
407
+ const { metricName, metricValue, fieldName, targetClass } = this.context;
408
+ const formattedActual = this.formatMetricValue(metricValue);
409
+ const formattedThreshold = this.formatMetricValue(threshold);
410
+ const classInfo = targetClass ? ` for "${targetClass}"` : "";
411
+ const operatorText = {
412
+ ">=": "at least",
413
+ ">": "above",
414
+ "<=": "at most",
415
+ "<": "below",
416
+ "===": "equal to"
417
+ }[operator];
418
+ const message = passed ? `${metricName}${classInfo} ${formattedActual} is ${operatorText} ${formattedThreshold}` : `${metricName}${classInfo} ${formattedActual} is not ${operatorText} ${formattedThreshold}`;
419
+ return {
420
+ type: metricName.toLowerCase().replace(/\s+/g, "").replace(/²/g, "2"),
421
+ passed,
422
+ message,
423
+ expected: threshold,
424
+ actual: metricValue,
425
+ field: fieldName,
426
+ class: targetClass
427
+ };
428
+ }
429
+ recordAndReturn(result) {
430
+ this.context.assertions.push(result);
431
+ recordAssertion(result);
432
+ return this.context.parent;
433
+ }
434
+ /**
435
+ * Assert that the metric is greater than or equal to the threshold (>=)
436
+ */
437
+ toBeAtLeast(threshold) {
438
+ const passed = this.context.metricValue >= threshold;
439
+ const result = this.createAssertion(">=", threshold, passed);
440
+ return this.recordAndReturn(result);
441
+ }
442
+ /**
443
+ * Assert that the metric is strictly greater than the threshold (>)
444
+ */
445
+ toBeAbove(threshold) {
446
+ const passed = this.context.metricValue > threshold;
447
+ const result = this.createAssertion(">", threshold, passed);
448
+ return this.recordAndReturn(result);
449
+ }
450
+ /**
451
+ * Assert that the metric is less than or equal to the threshold (<=)
452
+ */
453
+ toBeAtMost(threshold) {
454
+ const passed = this.context.metricValue <= threshold;
455
+ const result = this.createAssertion("<=", threshold, passed);
456
+ return this.recordAndReturn(result);
457
+ }
458
+ /**
459
+ * Assert that the metric is strictly less than the threshold (<)
460
+ */
461
+ toBeBelow(threshold) {
462
+ const passed = this.context.metricValue < threshold;
463
+ const result = this.createAssertion("<", threshold, passed);
464
+ return this.recordAndReturn(result);
465
+ }
466
+ /**
467
+ * Assert that the metric equals the expected value (with optional tolerance for floats)
468
+ */
469
+ toEqual(expected, tolerance = 1e-9) {
470
+ const passed = Math.abs(this.context.metricValue - expected) <= tolerance;
471
+ const result = this.createAssertion("===", expected, passed);
472
+ return this.recordAndReturn(result);
473
+ }
474
+ };
475
+
523
476
  // src/assertions/binarize.ts
524
477
  var BinarizeSelector = class {
525
478
  fieldName;
@@ -551,161 +504,127 @@ var BinarizeSelector = class {
551
504
  }
552
505
  }
553
506
  }
507
+ // ============================================================================
508
+ // Classification Metric Getters
509
+ // ============================================================================
554
510
  /**
555
- * Asserts that accuracy is above a threshold
511
+ * Access accuracy metric for assertions
512
+ * @example
513
+ * expectStats(predictions, groundTruth)
514
+ * .field("score")
515
+ * .binarize(0.5)
516
+ * .accuracy.toBeAtLeast(0.8)
556
517
  */
557
- toHaveAccuracyAbove(threshold) {
518
+ get accuracy() {
558
519
  const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
559
- const passed = metrics.accuracy >= threshold;
560
- const result = {
561
- type: "accuracy",
562
- passed,
563
- message: passed ? `Accuracy ${(metrics.accuracy * 100).toFixed(1)}% is above ${(threshold * 100).toFixed(1)}% (binarized at ${this.threshold})` : `Accuracy ${(metrics.accuracy * 100).toFixed(1)}% is below threshold ${(threshold * 100).toFixed(1)}% (binarized at ${this.threshold})`,
564
- expected: threshold,
565
- actual: metrics.accuracy,
566
- field: this.fieldName
567
- };
568
- this.assertions.push(result);
569
- recordAssertion(result);
570
- if (!passed) {
571
- throw new AssertionError(result.message, threshold, metrics.accuracy, this.fieldName);
572
- }
573
- return this;
520
+ return new MetricMatcher({
521
+ parent: this,
522
+ metricName: "Accuracy",
523
+ metricValue: metrics.accuracy,
524
+ fieldName: this.fieldName,
525
+ assertions: this.assertions
526
+ });
574
527
  }
575
528
  /**
576
- * Asserts that precision is above a threshold
577
- * @param classOrThreshold - Either the class (true/false) or threshold
578
- * @param threshold - Threshold when class is specified
529
+ * Access F1 score metric for assertions (macro average)
530
+ * @example
531
+ * expectStats(predictions, groundTruth)
532
+ * .field("score")
533
+ * .binarize(0.5)
534
+ * .f1.toBeAtLeast(0.75)
579
535
  */
580
- toHavePrecisionAbove(classOrThreshold, threshold) {
536
+ get f1() {
581
537
  const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
582
- let actualPrecision;
583
- let targetClass;
584
- let actualThreshold;
585
- if (typeof classOrThreshold === "number") {
586
- actualPrecision = metrics.macroAvg.precision;
587
- actualThreshold = classOrThreshold;
588
- } else {
589
- targetClass = String(classOrThreshold);
590
- actualThreshold = threshold;
591
- const classMetrics = metrics.perClass[targetClass];
592
- if (!classMetrics) {
593
- throw new AssertionError(
594
- `Class "${targetClass}" not found in binarized predictions`,
595
- targetClass,
596
- Object.keys(metrics.perClass),
597
- this.fieldName
598
- );
599
- }
600
- actualPrecision = classMetrics.precision;
601
- }
602
- const passed = actualPrecision >= actualThreshold;
603
- const result = {
604
- type: "precision",
605
- passed,
606
- message: passed ? `Precision${targetClass ? ` for ${targetClass}` : ""} ${(actualPrecision * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Precision${targetClass ? ` for ${targetClass}` : ""} ${(actualPrecision * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
607
- expected: actualThreshold,
608
- actual: actualPrecision,
609
- field: this.fieldName,
610
- class: targetClass
611
- };
612
- this.assertions.push(result);
613
- recordAssertion(result);
614
- if (!passed) {
615
- throw new AssertionError(result.message, actualThreshold, actualPrecision, this.fieldName);
616
- }
617
- return this;
538
+ return new MetricMatcher({
539
+ parent: this,
540
+ metricName: "F1",
541
+ metricValue: metrics.macroAvg.f1,
542
+ fieldName: this.fieldName,
543
+ assertions: this.assertions
544
+ });
618
545
  }
619
546
  /**
620
- * Asserts that recall is above a threshold
621
- * @param classOrThreshold - Either the class (true/false) or threshold
622
- * @param threshold - Threshold when class is specified
547
+ * Access precision metric for assertions
548
+ * @param targetClass - Optional boolean class (true/false). If omitted, uses macro average
549
+ * @example
550
+ * expectStats(predictions, groundTruth)
551
+ * .field("score")
552
+ * .binarize(0.5)
553
+ * .precision(true).toBeAtLeast(0.7)
623
554
  */
624
- toHaveRecallAbove(classOrThreshold, threshold) {
555
+ precision(targetClass) {
625
556
  const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
626
- let actualRecall;
627
- let targetClass;
628
- let actualThreshold;
629
- if (typeof classOrThreshold === "number") {
630
- actualRecall = metrics.macroAvg.recall;
631
- actualThreshold = classOrThreshold;
557
+ let metricValue;
558
+ let classKey;
559
+ if (targetClass === void 0) {
560
+ metricValue = metrics.macroAvg.precision;
632
561
  } else {
633
- targetClass = String(classOrThreshold);
634
- actualThreshold = threshold;
635
- const classMetrics = metrics.perClass[targetClass];
562
+ classKey = String(targetClass);
563
+ const classMetrics = metrics.perClass[classKey];
636
564
  if (!classMetrics) {
637
565
  throw new AssertionError(
638
- `Class "${targetClass}" not found in binarized predictions`,
639
- targetClass,
566
+ `Class "${classKey}" not found in binarized predictions`,
567
+ classKey,
640
568
  Object.keys(metrics.perClass),
641
569
  this.fieldName
642
570
  );
643
571
  }
644
- actualRecall = classMetrics.recall;
645
- }
646
- const passed = actualRecall >= actualThreshold;
647
- const result = {
648
- type: "recall",
649
- passed,
650
- message: passed ? `Recall${targetClass ? ` for ${targetClass}` : ""} ${(actualRecall * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Recall${targetClass ? ` for ${targetClass}` : ""} ${(actualRecall * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
651
- expected: actualThreshold,
652
- actual: actualRecall,
653
- field: this.fieldName,
654
- class: targetClass
655
- };
656
- this.assertions.push(result);
657
- recordAssertion(result);
658
- if (!passed) {
659
- throw new AssertionError(result.message, actualThreshold, actualRecall, this.fieldName);
572
+ metricValue = classMetrics.precision;
660
573
  }
661
- return this;
574
+ return new MetricMatcher({
575
+ parent: this,
576
+ metricName: "Precision",
577
+ metricValue,
578
+ fieldName: this.fieldName,
579
+ targetClass: classKey,
580
+ assertions: this.assertions
581
+ });
662
582
  }
663
583
  /**
664
- * Asserts that F1 score is above a threshold
584
+ * Access recall metric for assertions
585
+ * @param targetClass - Optional boolean class (true/false). If omitted, uses macro average
586
+ * @example
587
+ * expectStats(predictions, groundTruth)
588
+ * .field("score")
589
+ * .binarize(0.5)
590
+ * .recall(true).toBeAtLeast(0.7)
665
591
  */
666
- toHaveF1Above(classOrThreshold, threshold) {
592
+ recall(targetClass) {
667
593
  const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
668
- let actualF1;
669
- let targetClass;
670
- let actualThreshold;
671
- if (typeof classOrThreshold === "number") {
672
- actualF1 = metrics.macroAvg.f1;
673
- actualThreshold = classOrThreshold;
594
+ let metricValue;
595
+ let classKey;
596
+ if (targetClass === void 0) {
597
+ metricValue = metrics.macroAvg.recall;
674
598
  } else {
675
- targetClass = String(classOrThreshold);
676
- actualThreshold = threshold;
677
- const classMetrics = metrics.perClass[targetClass];
599
+ classKey = String(targetClass);
600
+ const classMetrics = metrics.perClass[classKey];
678
601
  if (!classMetrics) {
679
602
  throw new AssertionError(
680
- `Class "${targetClass}" not found in binarized predictions`,
681
- targetClass,
603
+ `Class "${classKey}" not found in binarized predictions`,
604
+ classKey,
682
605
  Object.keys(metrics.perClass),
683
606
  this.fieldName
684
607
  );
685
608
  }
686
- actualF1 = classMetrics.f1;
687
- }
688
- const passed = actualF1 >= actualThreshold;
689
- const result = {
690
- type: "f1",
691
- passed,
692
- message: passed ? `F1${targetClass ? ` for ${targetClass}` : ""} ${(actualF1 * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `F1${targetClass ? ` for ${targetClass}` : ""} ${(actualF1 * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
693
- expected: actualThreshold,
694
- actual: actualF1,
695
- field: this.fieldName,
696
- class: targetClass
697
- };
698
- this.assertions.push(result);
699
- recordAssertion(result);
700
- if (!passed) {
701
- throw new AssertionError(result.message, actualThreshold, actualF1, this.fieldName);
609
+ metricValue = classMetrics.recall;
702
610
  }
703
- return this;
611
+ return new MetricMatcher({
612
+ parent: this,
613
+ metricName: "Recall",
614
+ metricValue,
615
+ fieldName: this.fieldName,
616
+ targetClass: classKey,
617
+ assertions: this.assertions
618
+ });
704
619
  }
620
+ // ============================================================================
621
+ // Display Methods
622
+ // ============================================================================
705
623
  /**
706
- * Includes the confusion matrix in the report
624
+ * Displays the confusion matrix in the report
625
+ * This is not an assertion - it always passes and just records the matrix for display
707
626
  */
708
- toHaveConfusionMatrix() {
627
+ displayConfusionMatrix() {
709
628
  const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
710
629
  const fieldResult = {
711
630
  field: this.fieldName,
@@ -724,6 +643,9 @@ var BinarizeSelector = class {
724
643
  recordAssertion(result);
725
644
  return this;
726
645
  }
646
+ // ============================================================================
647
+ // Utility Methods
648
+ // ============================================================================
727
649
  /**
728
650
  * Gets computed metrics
729
651
  */
@@ -738,6 +660,73 @@ var BinarizeSelector = class {
738
660
  }
739
661
  };
740
662
 
663
+ // src/assertions/percentage-matcher.ts
664
+ var PercentageMatcher = class {
665
+ context;
666
+ constructor(context) {
667
+ this.context = context;
668
+ }
669
+ formatPercentage(value) {
670
+ return `${(value * 100).toFixed(1)}%`;
671
+ }
672
+ createAssertion(operator, percentageThreshold, passed) {
673
+ const { fieldName, valueThreshold, direction, actualPercentage } = this.context;
674
+ const operatorText = {
675
+ ">=": "at least",
676
+ ">": "above",
677
+ "<=": "at most",
678
+ "<": "below"
679
+ }[operator];
680
+ const directionText = direction === "above" ? "above" : "below or equal to";
681
+ const message = passed ? `${this.formatPercentage(actualPercentage)} of '${fieldName}' values are ${directionText} ${valueThreshold} (expected ${operatorText} ${this.formatPercentage(percentageThreshold)})` : `Only ${this.formatPercentage(actualPercentage)} of '${fieldName}' values are ${directionText} ${valueThreshold} (expected ${operatorText} ${this.formatPercentage(percentageThreshold)})`;
682
+ return {
683
+ type: direction === "above" ? "percentageAbove" : "percentageBelow",
684
+ passed,
685
+ message,
686
+ expected: percentageThreshold,
687
+ actual: actualPercentage,
688
+ field: fieldName
689
+ };
690
+ }
691
+ recordAndReturn(result) {
692
+ this.context.assertions.push(result);
693
+ recordAssertion(result);
694
+ return this.context.parent;
695
+ }
696
+ /**
697
+ * Assert that the percentage is greater than or equal to the threshold (>=)
698
+ */
699
+ toBeAtLeast(percentageThreshold) {
700
+ const passed = this.context.actualPercentage >= percentageThreshold;
701
+ const result = this.createAssertion(">=", percentageThreshold, passed);
702
+ return this.recordAndReturn(result);
703
+ }
704
+ /**
705
+ * Assert that the percentage is strictly greater than the threshold (>)
706
+ */
707
+ toBeAbove(percentageThreshold) {
708
+ const passed = this.context.actualPercentage > percentageThreshold;
709
+ const result = this.createAssertion(">", percentageThreshold, passed);
710
+ return this.recordAndReturn(result);
711
+ }
712
+ /**
713
+ * Assert that the percentage is less than or equal to the threshold (<=)
714
+ */
715
+ toBeAtMost(percentageThreshold) {
716
+ const passed = this.context.actualPercentage <= percentageThreshold;
717
+ const result = this.createAssertion("<=", percentageThreshold, passed);
718
+ return this.recordAndReturn(result);
719
+ }
720
+ /**
721
+ * Assert that the percentage is strictly less than the threshold (<)
722
+ */
723
+ toBeBelow(percentageThreshold) {
724
+ const passed = this.context.actualPercentage < percentageThreshold;
725
+ const result = this.createAssertion("<", percentageThreshold, passed);
726
+ return this.recordAndReturn(result);
727
+ }
728
+ };
729
+
741
730
  // src/assertions/field-selector.ts
742
731
  var FieldSelector = class {
743
732
  aligned;
@@ -774,89 +763,93 @@ var FieldSelector = class {
774
763
  }
775
764
  }
776
765
  /**
777
- * Asserts that accuracy is above a threshold
766
+ * Validates that ground truth exists and both arrays contain numeric values.
767
+ * Returns the filtered numeric arrays for regression metrics.
778
768
  */
779
- toHaveAccuracyAbove(threshold) {
769
+ validateRegressionInputs() {
780
770
  this.validateGroundTruth();
781
- const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
782
- const passed = metrics.accuracy >= threshold;
783
- const result = {
784
- type: "accuracy",
785
- passed,
786
- message: passed ? `Accuracy ${(metrics.accuracy * 100).toFixed(1)}% is above ${(threshold * 100).toFixed(1)}%` : `Accuracy ${(metrics.accuracy * 100).toFixed(1)}% is below threshold ${(threshold * 100).toFixed(1)}%`,
787
- expected: threshold,
788
- actual: metrics.accuracy,
789
- field: this.fieldName
790
- };
791
- this.assertions.push(result);
792
- recordAssertion(result);
793
- if (!passed) {
794
- throw new AssertionError(result.message, threshold, metrics.accuracy, this.fieldName);
771
+ const numericActual = filterNumericValues(this.actualValues);
772
+ const numericExpected = filterNumericValues(this.expectedValues);
773
+ if (numericActual.length === 0) {
774
+ throw new AssertionError(
775
+ `Regression metric requires numeric values, but field "${this.fieldName}" has no numeric actual values.`,
776
+ void 0,
777
+ void 0,
778
+ this.fieldName
779
+ );
795
780
  }
796
- return this;
781
+ if (numericExpected.length === 0) {
782
+ throw new AssertionError(
783
+ `Regression metric requires numeric values, but field "${this.fieldName}" has no numeric expected values.`,
784
+ void 0,
785
+ void 0,
786
+ this.fieldName
787
+ );
788
+ }
789
+ if (numericActual.length !== numericExpected.length) {
790
+ throw new AssertionError(
791
+ `Regression metric requires equal-length arrays, but got ${numericActual.length} actual and ${numericExpected.length} expected values.`,
792
+ numericExpected.length,
793
+ numericActual.length,
794
+ this.fieldName
795
+ );
796
+ }
797
+ return { actual: numericActual, expected: numericExpected };
797
798
  }
799
+ // ============================================================================
800
+ // Classification Metric Getters
801
+ // ============================================================================
798
802
  /**
799
- * Asserts that precision is above a threshold
800
- * @param classOrThreshold - Either the class name or threshold (if class is omitted, uses macro average)
801
- * @param threshold - Threshold when class is specified
803
+ * Access accuracy metric for assertions
804
+ * @example
805
+ * expectStats(predictions, groundTruth)
806
+ * .field("sentiment")
807
+ * .accuracy.toBeAtLeast(0.8)
802
808
  */
803
- toHavePrecisionAbove(classOrThreshold, threshold) {
809
+ get accuracy() {
804
810
  this.validateGroundTruth();
805
811
  const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
806
- let actualPrecision;
807
- let targetClass;
808
- let actualThreshold;
809
- if (typeof classOrThreshold === "number") {
810
- actualPrecision = metrics.macroAvg.precision;
811
- actualThreshold = classOrThreshold;
812
- } else {
813
- targetClass = classOrThreshold;
814
- actualThreshold = threshold;
815
- const classMetrics = metrics.perClass[targetClass];
816
- if (!classMetrics) {
817
- throw new AssertionError(
818
- `Class "${targetClass}" not found in predictions`,
819
- targetClass,
820
- Object.keys(metrics.perClass),
821
- this.fieldName
822
- );
823
- }
824
- actualPrecision = classMetrics.precision;
825
- }
826
- const passed = actualPrecision >= actualThreshold;
827
- const result = {
828
- type: "precision",
829
- passed,
830
- message: passed ? `Precision${targetClass ? ` for "${targetClass}"` : ""} ${(actualPrecision * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Precision${targetClass ? ` for "${targetClass}"` : ""} ${(actualPrecision * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
831
- expected: actualThreshold,
832
- actual: actualPrecision,
833
- field: this.fieldName,
834
- class: targetClass
835
- };
836
- this.assertions.push(result);
837
- recordAssertion(result);
838
- if (!passed) {
839
- throw new AssertionError(result.message, actualThreshold, actualPrecision, this.fieldName);
840
- }
841
- return this;
812
+ return new MetricMatcher({
813
+ parent: this,
814
+ metricName: "Accuracy",
815
+ metricValue: metrics.accuracy,
816
+ fieldName: this.fieldName,
817
+ assertions: this.assertions
818
+ });
842
819
  }
843
820
  /**
844
- * Asserts that recall is above a threshold
845
- * @param classOrThreshold - Either the class name or threshold (if class is omitted, uses macro average)
846
- * @param threshold - Threshold when class is specified
821
+ * Access F1 score metric for assertions (macro average)
822
+ * @example
823
+ * expectStats(predictions, groundTruth)
824
+ * .field("sentiment")
825
+ * .f1.toBeAtLeast(0.75)
847
826
  */
848
- toHaveRecallAbove(classOrThreshold, threshold) {
827
+ get f1() {
849
828
  this.validateGroundTruth();
850
829
  const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
851
- let actualRecall;
852
- let targetClass;
853
- let actualThreshold;
854
- if (typeof classOrThreshold === "number") {
855
- actualRecall = metrics.macroAvg.recall;
856
- actualThreshold = classOrThreshold;
830
+ return new MetricMatcher({
831
+ parent: this,
832
+ metricName: "F1",
833
+ metricValue: metrics.macroAvg.f1,
834
+ fieldName: this.fieldName,
835
+ assertions: this.assertions
836
+ });
837
+ }
838
+ /**
839
+ * Access precision metric for assertions
840
+ * @param targetClass - Optional class name. If omitted, uses macro average
841
+ * @example
842
+ * expectStats(predictions, groundTruth)
843
+ * .field("sentiment")
844
+ * .precision("positive").toBeAtLeast(0.7)
845
+ */
846
+ precision(targetClass) {
847
+ this.validateGroundTruth();
848
+ const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
849
+ let metricValue;
850
+ if (targetClass === void 0) {
851
+ metricValue = metrics.macroAvg.precision;
857
852
  } else {
858
- targetClass = classOrThreshold;
859
- actualThreshold = threshold;
860
853
  const classMetrics = metrics.perClass[targetClass];
861
854
  if (!classMetrics) {
862
855
  throw new AssertionError(
@@ -866,42 +859,32 @@ var FieldSelector = class {
866
859
  this.fieldName
867
860
  );
868
861
  }
869
- actualRecall = classMetrics.recall;
870
- }
871
- const passed = actualRecall >= actualThreshold;
872
- const result = {
873
- type: "recall",
874
- passed,
875
- message: passed ? `Recall${targetClass ? ` for "${targetClass}"` : ""} ${(actualRecall * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Recall${targetClass ? ` for "${targetClass}"` : ""} ${(actualRecall * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
876
- expected: actualThreshold,
877
- actual: actualRecall,
878
- field: this.fieldName,
879
- class: targetClass
880
- };
881
- this.assertions.push(result);
882
- recordAssertion(result);
883
- if (!passed) {
884
- throw new AssertionError(result.message, actualThreshold, actualRecall, this.fieldName);
862
+ metricValue = classMetrics.precision;
885
863
  }
886
- return this;
864
+ return new MetricMatcher({
865
+ parent: this,
866
+ metricName: "Precision",
867
+ metricValue,
868
+ fieldName: this.fieldName,
869
+ targetClass,
870
+ assertions: this.assertions
871
+ });
887
872
  }
888
873
  /**
889
- * Asserts that F1 score is above a threshold
890
- * @param classOrThreshold - Either the class name or threshold (if class is omitted, uses macro average)
891
- * @param threshold - Threshold when class is specified
874
+ * Access recall metric for assertions
875
+ * @param targetClass - Optional class name. If omitted, uses macro average
876
+ * @example
877
+ * expectStats(predictions, groundTruth)
878
+ * .field("sentiment")
879
+ * .recall("positive").toBeAtLeast(0.7)
892
880
  */
893
- toHaveF1Above(classOrThreshold, threshold) {
881
+ recall(targetClass) {
894
882
  this.validateGroundTruth();
895
883
  const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
896
- let actualF1;
897
- let targetClass;
898
- let actualThreshold;
899
- if (typeof classOrThreshold === "number") {
900
- actualF1 = metrics.macroAvg.f1;
901
- actualThreshold = classOrThreshold;
884
+ let metricValue;
885
+ if (targetClass === void 0) {
886
+ metricValue = metrics.macroAvg.recall;
902
887
  } else {
903
- targetClass = classOrThreshold;
904
- actualThreshold = threshold;
905
888
  const classMetrics = metrics.perClass[targetClass];
906
889
  if (!classMetrics) {
907
890
  throw new AssertionError(
@@ -911,272 +894,171 @@ var FieldSelector = class {
911
894
  this.fieldName
912
895
  );
913
896
  }
914
- actualF1 = classMetrics.f1;
915
- }
916
- const passed = actualF1 >= actualThreshold;
917
- const result = {
918
- type: "f1",
919
- passed,
920
- message: passed ? `F1${targetClass ? ` for "${targetClass}"` : ""} ${(actualF1 * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `F1${targetClass ? ` for "${targetClass}"` : ""} ${(actualF1 * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
921
- expected: actualThreshold,
922
- actual: actualF1,
923
- field: this.fieldName,
924
- class: targetClass
925
- };
926
- this.assertions.push(result);
927
- recordAssertion(result);
928
- if (!passed) {
929
- throw new AssertionError(result.message, actualThreshold, actualF1, this.fieldName);
897
+ metricValue = classMetrics.recall;
930
898
  }
931
- return this;
899
+ return new MetricMatcher({
900
+ parent: this,
901
+ metricName: "Recall",
902
+ metricValue,
903
+ fieldName: this.fieldName,
904
+ targetClass,
905
+ assertions: this.assertions
906
+ });
932
907
  }
908
+ // ============================================================================
909
+ // Regression Metric Getters
910
+ // ============================================================================
933
911
  /**
934
- * Includes the confusion matrix in the report
912
+ * Access Mean Absolute Error metric for assertions
913
+ * @example
914
+ * expectStats(predictions, groundTruth)
915
+ * .field("score")
916
+ * .mae.toBeAtMost(0.1)
935
917
  */
936
- toHaveConfusionMatrix() {
937
- const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
938
- const fieldResult = {
939
- field: this.fieldName,
940
- metrics,
941
- binarized: false
942
- };
943
- recordFieldMetrics(fieldResult);
944
- const result = {
945
- type: "confusionMatrix",
946
- passed: true,
947
- message: `Confusion matrix recorded for field "${this.fieldName}"`,
948
- field: this.fieldName
949
- };
950
- this.assertions.push(result);
951
- recordAssertion(result);
952
- return this;
918
+ get mae() {
919
+ const { actual, expected } = this.validateRegressionInputs();
920
+ const metrics = computeRegressionMetrics(actual, expected);
921
+ return new MetricMatcher({
922
+ parent: this,
923
+ metricName: "MAE",
924
+ metricValue: metrics.mae,
925
+ fieldName: this.fieldName,
926
+ assertions: this.assertions,
927
+ formatValue: (v) => v.toFixed(4)
928
+ });
953
929
  }
954
930
  /**
955
- * Asserts that a percentage of values are below or equal to a threshold.
956
- * This is a distributional assertion that only looks at actual values (no ground truth required).
957
- *
931
+ * Access Root Mean Squared Error metric for assertions
932
+ * @example
933
+ * expectStats(predictions, groundTruth)
934
+ * .field("score")
935
+ * .rmse.toBeAtMost(0.15)
936
+ */
937
+ get rmse() {
938
+ const { actual, expected } = this.validateRegressionInputs();
939
+ const metrics = computeRegressionMetrics(actual, expected);
940
+ return new MetricMatcher({
941
+ parent: this,
942
+ metricName: "RMSE",
943
+ metricValue: metrics.rmse,
944
+ fieldName: this.fieldName,
945
+ assertions: this.assertions,
946
+ formatValue: (v) => v.toFixed(4)
947
+ });
948
+ }
949
+ /**
950
+ * Access R-squared (coefficient of determination) metric for assertions
951
+ * @example
952
+ * expectStats(predictions, groundTruth)
953
+ * .field("score")
954
+ * .r2.toBeAtLeast(0.8)
955
+ */
956
+ get r2() {
957
+ const { actual, expected } = this.validateRegressionInputs();
958
+ const metrics = computeRegressionMetrics(actual, expected);
959
+ return new MetricMatcher({
960
+ parent: this,
961
+ metricName: "R\xB2",
962
+ metricValue: metrics.r2,
963
+ fieldName: this.fieldName,
964
+ assertions: this.assertions,
965
+ formatValue: (v) => v.toFixed(4)
966
+ });
967
+ }
968
+ // ============================================================================
969
+ // Distribution Assertions
970
+ // ============================================================================
971
+ /**
972
+ * Assert on the percentage of values below or equal to a threshold
958
973
  * @param valueThreshold - The value threshold to compare against
959
- * @param percentageThreshold - The minimum percentage (0-1) of values that should be <= valueThreshold
960
- * @returns this for method chaining
961
- *
962
974
  * @example
963
- * // Assert that 90% of confidence scores are below 0.5
964
975
  * expectStats(predictions)
965
976
  * .field("confidence")
966
- * .toHavePercentageBelow(0.5, 0.9)
977
+ * .percentageBelow(0.5).toBeAtLeast(0.9)
967
978
  */
968
- toHavePercentageBelow(valueThreshold, percentageThreshold) {
979
+ percentageBelow(valueThreshold) {
969
980
  const numericActual = filterNumericValues(this.actualValues);
970
981
  if (numericActual.length === 0) {
971
982
  throw new AssertionError(
972
983
  `Field '${this.fieldName}' contains no numeric values (found 0 numeric out of ${this.actualValues.length} total values)`,
973
- percentageThreshold,
984
+ void 0,
974
985
  void 0,
975
986
  this.fieldName
976
987
  );
977
988
  }
978
989
  const actualPercentage = calculatePercentageBelow(numericActual, valueThreshold);
979
- const passed = actualPercentage >= percentageThreshold;
980
- const result = {
981
- type: "percentageBelow",
982
- passed,
983
- message: passed ? `${(actualPercentage * 100).toFixed(1)}% of '${this.fieldName}' values are below or equal to ${valueThreshold} (expected >= ${(percentageThreshold * 100).toFixed(1)}%)` : `Only ${(actualPercentage * 100).toFixed(1)}% of '${this.fieldName}' values are below or equal to ${valueThreshold} (expected >= ${(percentageThreshold * 100).toFixed(1)}%)`,
984
- expected: percentageThreshold,
985
- actual: actualPercentage,
986
- field: this.fieldName
987
- };
988
- this.assertions.push(result);
989
- recordAssertion(result);
990
- if (!passed) {
991
- throw new AssertionError(
992
- result.message,
993
- percentageThreshold,
994
- actualPercentage,
995
- this.fieldName
996
- );
997
- }
998
- return this;
990
+ return new PercentageMatcher({
991
+ parent: this,
992
+ fieldName: this.fieldName,
993
+ valueThreshold,
994
+ direction: "below",
995
+ actualPercentage,
996
+ assertions: this.assertions
997
+ });
999
998
  }
1000
999
  /**
1001
- * Asserts that a percentage of values are above a threshold.
1002
- * This is a distributional assertion that only looks at actual values (no ground truth required).
1003
- *
1000
+ * Assert on the percentage of values above a threshold
1004
1001
  * @param valueThreshold - The value threshold to compare against
1005
- * @param percentageThreshold - The minimum percentage (0-1) of values that should be > valueThreshold
1006
- * @returns this for method chaining
1007
- *
1008
1002
  * @example
1009
- * // Assert that 80% of quality scores are above 0.7
1010
1003
  * expectStats(predictions)
1011
1004
  * .field("quality")
1012
- * .toHavePercentageAbove(0.7, 0.8)
1005
+ * .percentageAbove(0.7).toBeAtLeast(0.8)
1013
1006
  */
1014
- toHavePercentageAbove(valueThreshold, percentageThreshold) {
1007
+ percentageAbove(valueThreshold) {
1015
1008
  const numericActual = filterNumericValues(this.actualValues);
1016
1009
  if (numericActual.length === 0) {
1017
1010
  throw new AssertionError(
1018
1011
  `Field '${this.fieldName}' contains no numeric values (found 0 numeric out of ${this.actualValues.length} total values)`,
1019
- percentageThreshold,
1012
+ void 0,
1020
1013
  void 0,
1021
1014
  this.fieldName
1022
1015
  );
1023
1016
  }
1024
1017
  const actualPercentage = calculatePercentageAbove(numericActual, valueThreshold);
1025
- const passed = actualPercentage >= percentageThreshold;
1026
- const result = {
1027
- type: "percentageAbove",
1028
- passed,
1029
- message: passed ? `${(actualPercentage * 100).toFixed(1)}% of '${this.fieldName}' values are above ${valueThreshold} (expected >= ${(percentageThreshold * 100).toFixed(1)}%)` : `Only ${(actualPercentage * 100).toFixed(1)}% of '${this.fieldName}' values are above ${valueThreshold} (expected >= ${(percentageThreshold * 100).toFixed(1)}%)`,
1030
- expected: percentageThreshold,
1031
- actual: actualPercentage,
1032
- field: this.fieldName
1033
- };
1034
- this.assertions.push(result);
1035
- recordAssertion(result);
1036
- if (!passed) {
1037
- throw new AssertionError(
1038
- result.message,
1039
- percentageThreshold,
1040
- actualPercentage,
1041
- this.fieldName
1042
- );
1043
- }
1044
- return this;
1018
+ return new PercentageMatcher({
1019
+ parent: this,
1020
+ fieldName: this.fieldName,
1021
+ valueThreshold,
1022
+ direction: "above",
1023
+ actualPercentage,
1024
+ assertions: this.assertions
1025
+ });
1045
1026
  }
1046
1027
  // ============================================================================
1047
- // Regression Assertions
1028
+ // Display Methods
1048
1029
  // ============================================================================
1049
1030
  /**
1050
- * Validates that ground truth exists and both arrays contain numeric values.
1051
- * Returns the filtered numeric arrays for regression metrics.
1052
- */
1053
- validateRegressionInputs() {
1054
- this.validateGroundTruth();
1055
- const numericActual = filterNumericValues(this.actualValues);
1056
- const numericExpected = filterNumericValues(this.expectedValues);
1057
- if (numericActual.length === 0) {
1058
- throw new AssertionError(
1059
- `Regression metric requires numeric values, but field "${this.fieldName}" has no numeric actual values.`,
1060
- void 0,
1061
- void 0,
1062
- this.fieldName
1063
- );
1064
- }
1065
- if (numericExpected.length === 0) {
1066
- throw new AssertionError(
1067
- `Regression metric requires numeric values, but field "${this.fieldName}" has no numeric expected values.`,
1068
- void 0,
1069
- void 0,
1070
- this.fieldName
1071
- );
1072
- }
1073
- if (numericActual.length !== numericExpected.length) {
1074
- throw new AssertionError(
1075
- `Regression metric requires equal-length arrays, but got ${numericActual.length} actual and ${numericExpected.length} expected values.`,
1076
- numericExpected.length,
1077
- numericActual.length,
1078
- this.fieldName
1079
- );
1080
- }
1081
- return { actual: numericActual, expected: numericExpected };
1082
- }
1083
- /**
1084
- * Asserts that Mean Absolute Error is below a threshold.
1085
- * Requires numeric values in both actual and expected.
1086
- *
1087
- * @param threshold - Maximum allowed MAE
1088
- * @returns this for method chaining
1089
- *
1090
- * @example
1091
- * expectStats(predictions, groundTruth)
1092
- * .field("score")
1093
- * .toHaveMAEBelow(0.1)
1094
- */
1095
- toHaveMAEBelow(threshold) {
1096
- const { actual, expected } = this.validateRegressionInputs();
1097
- const metrics = computeRegressionMetrics(actual, expected);
1098
- const passed = metrics.mae <= threshold;
1099
- const result = {
1100
- type: "mae",
1101
- passed,
1102
- message: passed ? `MAE ${metrics.mae.toFixed(4)} is below ${threshold}` : `MAE ${metrics.mae.toFixed(4)} exceeds threshold ${threshold}`,
1103
- expected: threshold,
1104
- actual: metrics.mae,
1105
- field: this.fieldName
1106
- };
1107
- this.assertions.push(result);
1108
- recordAssertion(result);
1109
- if (!passed) {
1110
- throw new AssertionError(result.message, threshold, metrics.mae, this.fieldName);
1111
- }
1112
- return this;
1113
- }
1114
- /**
1115
- * Asserts that Root Mean Squared Error is below a threshold.
1116
- * Requires numeric values in both actual and expected.
1117
- *
1118
- * @param threshold - Maximum allowed RMSE
1119
- * @returns this for method chaining
1120
- *
1031
+ * Displays the confusion matrix in the report
1032
+ * This is not an assertion - it always passes and just records the matrix for display
1121
1033
  * @example
1122
1034
  * expectStats(predictions, groundTruth)
1123
- * .field("score")
1124
- * .toHaveRMSEBelow(0.15)
1035
+ * .field("sentiment")
1036
+ * .accuracy.toBeAtLeast(0.8)
1037
+ * .displayConfusionMatrix()
1125
1038
  */
1126
- toHaveRMSEBelow(threshold) {
1127
- const { actual, expected } = this.validateRegressionInputs();
1128
- const metrics = computeRegressionMetrics(actual, expected);
1129
- const passed = metrics.rmse <= threshold;
1130
- const result = {
1131
- type: "rmse",
1132
- passed,
1133
- message: passed ? `RMSE ${metrics.rmse.toFixed(4)} is below ${threshold}` : `RMSE ${metrics.rmse.toFixed(4)} exceeds threshold ${threshold}`,
1134
- expected: threshold,
1135
- actual: metrics.rmse,
1136
- field: this.fieldName
1039
+ displayConfusionMatrix() {
1040
+ const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
1041
+ const fieldResult = {
1042
+ field: this.fieldName,
1043
+ metrics,
1044
+ binarized: false
1137
1045
  };
1138
- this.assertions.push(result);
1139
- recordAssertion(result);
1140
- if (!passed) {
1141
- throw new AssertionError(result.message, threshold, metrics.rmse, this.fieldName);
1142
- }
1143
- return this;
1144
- }
1145
- /**
1146
- * Asserts that R-squared (coefficient of determination) is above a threshold.
1147
- * R² measures how well the predictions explain the variance in expected values.
1148
- * R² = 1.0 means perfect prediction, R² = 0 means prediction is no better than mean.
1149
- * Requires numeric values in both actual and expected.
1150
- *
1151
- * @param threshold - Minimum required R² value (0-1)
1152
- * @returns this for method chaining
1153
- *
1154
- * @example
1155
- * expectStats(predictions, groundTruth)
1156
- * .field("score")
1157
- * .toHaveR2Above(0.8)
1158
- */
1159
- toHaveR2Above(threshold) {
1160
- const { actual, expected } = this.validateRegressionInputs();
1161
- const metrics = computeRegressionMetrics(actual, expected);
1162
- const passed = metrics.r2 >= threshold;
1046
+ recordFieldMetrics(fieldResult);
1163
1047
  const result = {
1164
- type: "r2",
1165
- passed,
1166
- message: passed ? `R\xB2 ${metrics.r2.toFixed(4)} is above ${threshold}` : `R\xB2 ${metrics.r2.toFixed(4)} is below threshold ${threshold}`,
1167
- expected: threshold,
1168
- actual: metrics.r2,
1048
+ type: "confusionMatrix",
1049
+ passed: true,
1050
+ message: `Confusion matrix recorded for field "${this.fieldName}"`,
1169
1051
  field: this.fieldName
1170
1052
  };
1171
1053
  this.assertions.push(result);
1172
1054
  recordAssertion(result);
1173
- if (!passed) {
1174
- throw new AssertionError(result.message, threshold, metrics.r2, this.fieldName);
1175
- }
1176
1055
  return this;
1177
1056
  }
1057
+ // ============================================================================
1058
+ // Utility Methods
1059
+ // ============================================================================
1178
1060
  /**
1179
- * Gets the computed metrics for this field
1061
+ * Gets the computed classification metrics for this field
1180
1062
  */
1181
1063
  getMetrics() {
1182
1064
  return computeClassificationMetrics(this.actualValues, this.expectedValues);
@@ -1209,7 +1091,7 @@ function normalizeInput(input) {
1209
1091
  }));
1210
1092
  }
1211
1093
  throw new Error(
1212
- "Invalid input to expectStats(): expected ModelRunResult, Prediction[], or AlignedRecord[]"
1094
+ "Invalid input to expectStats(): expected { aligned: AlignedRecord[] }, Prediction[], or AlignedRecord[]"
1213
1095
  );
1214
1096
  }
1215
1097
  function expectStats(inputOrActual, expected, options) {
@@ -1254,6 +1136,6 @@ var ExpectStats = class {
1254
1136
  }
1255
1137
  };
1256
1138
 
1257
- export { afterAll, afterEach, alignByKey, beforeAll, beforeEach, checkIntegrity, computeAccuracy, computeClassificationMetrics, computeF1, computePrecision, computeRecall, createDataset, describe, evalTest, expectStats, extractFieldValues, filterComplete, it, loadDataset, runModel, runModelParallel, test, validatePredictions };
1139
+ export { afterAll, afterEach, alignByKey, beforeAll, beforeEach, checkIntegrity, computeAccuracy, computeClassificationMetrics, computeF1, computePrecision, computeRecall, describe, evalTest, expectStats, extractFieldValues, filterComplete, it, test, validatePredictions };
1258
1140
  //# sourceMappingURL=index.js.map
1259
1141
  //# sourceMappingURL=index.js.map