evalsense 0.3.2 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. package/README.md +235 -98
  2. package/dist/{chunk-BFGA2NUB.cjs → chunk-4BKZPVY4.cjs} +13 -6
  3. package/dist/chunk-4BKZPVY4.cjs.map +1 -0
  4. package/dist/{chunk-IYLSY7NX.js → chunk-IUVDDMJ3.js} +13 -6
  5. package/dist/chunk-IUVDDMJ3.js.map +1 -0
  6. package/dist/chunk-NCCQRZ2Y.cjs +1141 -0
  7. package/dist/chunk-NCCQRZ2Y.cjs.map +1 -0
  8. package/dist/chunk-TDGWDK2L.js +1108 -0
  9. package/dist/chunk-TDGWDK2L.js.map +1 -0
  10. package/dist/cli.cjs +11 -11
  11. package/dist/cli.js +1 -1
  12. package/dist/index-CATqAHNK.d.cts +416 -0
  13. package/dist/index-CoMpaW-K.d.ts +416 -0
  14. package/dist/index.cjs +507 -580
  15. package/dist/index.cjs.map +1 -1
  16. package/dist/index.d.cts +210 -161
  17. package/dist/index.d.ts +210 -161
  18. package/dist/index.js +455 -524
  19. package/dist/index.js.map +1 -1
  20. package/dist/metrics/index.cjs +103 -342
  21. package/dist/metrics/index.cjs.map +1 -1
  22. package/dist/metrics/index.d.cts +260 -31
  23. package/dist/metrics/index.d.ts +260 -31
  24. package/dist/metrics/index.js +24 -312
  25. package/dist/metrics/index.js.map +1 -1
  26. package/dist/metrics/opinionated/index.cjs +5 -5
  27. package/dist/metrics/opinionated/index.d.cts +2 -163
  28. package/dist/metrics/opinionated/index.d.ts +2 -163
  29. package/dist/metrics/opinionated/index.js +1 -1
  30. package/dist/{types-C71p0wzM.d.cts → types-D0hzfyKm.d.cts} +1 -13
  31. package/dist/{types-C71p0wzM.d.ts → types-D0hzfyKm.d.ts} +1 -13
  32. package/package.json +1 -1
  33. package/dist/chunk-BFGA2NUB.cjs.map +0 -1
  34. package/dist/chunk-IYLSY7NX.js.map +0 -1
  35. package/dist/chunk-RZFLCWTW.cjs +0 -942
  36. package/dist/chunk-RZFLCWTW.cjs.map +0 -1
  37. package/dist/chunk-Z3U6AUWX.js +0 -925
  38. package/dist/chunk-Z3U6AUWX.js.map +0 -1
package/dist/index.js CHANGED
@@ -1,8 +1,6 @@
1
- import { getCurrentSuite, setCurrentSuite, addSuite, addTestToCurrentSuite, DatasetError, IntegrityError, buildConfusionMatrix, getTruePositives, getFalsePositives, getFalseNegatives, getSupport, AssertionError, recordAssertion, recordFieldMetrics } from './chunk-IYLSY7NX.js';
2
- export { AssertionError, ConfigurationError, ConsoleReporter, DatasetError, EvalSenseError, ExitCodes, IntegrityError, JsonReporter, TestExecutionError, buildConfusionMatrix, discoverEvalFiles, executeEvalFiles, formatConfusionMatrix, getExitCode, parseReport } from './chunk-IYLSY7NX.js';
1
+ import { getCurrentSuite, setCurrentSuite, addSuite, addTestToCurrentSuite, IntegrityError, buildConfusionMatrix, getTruePositives, getFalsePositives, getFalseNegatives, getSupport, AssertionError, recordFieldMetrics, recordAssertion } from './chunk-IUVDDMJ3.js';
2
+ export { AssertionError, ConfigurationError, ConsoleReporter, DatasetError, EvalSenseError, ExitCodes, IntegrityError, JsonReporter, TestExecutionError, buildConfusionMatrix, discoverEvalFiles, executeEvalFiles, formatConfusionMatrix, getExitCode, parseReport } from './chunk-IUVDDMJ3.js';
3
3
  import './chunk-DGUM43GV.js';
4
- import { readFileSync } from 'fs';
5
- import { resolve, extname } from 'path';
6
4
 
7
5
  // src/core/describe.ts
8
6
  function describe(name, fn) {
@@ -91,136 +89,6 @@ function evalTestOnly(name, fn) {
91
89
  }
92
90
  evalTest.skip = evalTestSkip;
93
91
  evalTest.only = evalTestOnly;
94
- function loadDataset(path) {
95
- const absolutePath = resolve(process.cwd(), path);
96
- const ext = extname(absolutePath).toLowerCase();
97
- let records;
98
- try {
99
- const content = readFileSync(absolutePath, "utf-8");
100
- if (ext === ".ndjson" || ext === ".jsonl") {
101
- records = parseNDJSON(content);
102
- } else if (ext === ".json") {
103
- records = parseJSON(content);
104
- } else {
105
- throw new DatasetError(
106
- `Unsupported file format: ${ext}. Use .json, .ndjson, or .jsonl`,
107
- path
108
- );
109
- }
110
- } catch (error) {
111
- if (error instanceof DatasetError) {
112
- throw error;
113
- }
114
- const message = error instanceof Error ? error.message : String(error);
115
- throw new DatasetError(`Failed to load dataset from ${path}: ${message}`, path);
116
- }
117
- return {
118
- records,
119
- metadata: {
120
- source: path,
121
- count: records.length,
122
- loadedAt: /* @__PURE__ */ new Date()
123
- }
124
- };
125
- }
126
- function parseJSON(content) {
127
- const parsed = JSON.parse(content);
128
- if (!Array.isArray(parsed)) {
129
- throw new DatasetError("JSON dataset must be an array of records");
130
- }
131
- return parsed;
132
- }
133
- function parseNDJSON(content) {
134
- const lines = content.split("\n").filter((line) => line.trim() !== "");
135
- const records = [];
136
- for (let i = 0; i < lines.length; i++) {
137
- const line = lines[i];
138
- if (line === void 0) continue;
139
- try {
140
- records.push(JSON.parse(line));
141
- } catch {
142
- throw new DatasetError(`Invalid JSON at line ${i + 1} in NDJSON file`);
143
- }
144
- }
145
- return records;
146
- }
147
- function createDataset(records, source = "inline") {
148
- return {
149
- records,
150
- metadata: {
151
- source,
152
- count: records.length,
153
- loadedAt: /* @__PURE__ */ new Date()
154
- }
155
- };
156
- }
157
-
158
- // src/dataset/run-model.ts
159
- async function runModel(dataset, modelFn) {
160
- const startTime = Date.now();
161
- const predictions = [];
162
- const aligned = [];
163
- for (const record of dataset.records) {
164
- const id = getRecordId(record);
165
- const prediction = await modelFn(record);
166
- if (prediction.id !== id) {
167
- throw new DatasetError(
168
- `Prediction ID mismatch: expected "${id}", got "${prediction.id}". Model function must return the same ID as the input record.`
169
- );
170
- }
171
- predictions.push(prediction);
172
- aligned.push({
173
- id,
174
- actual: { ...prediction },
175
- expected: { ...record }
176
- });
177
- }
178
- return {
179
- predictions,
180
- aligned,
181
- duration: Date.now() - startTime
182
- };
183
- }
184
- function getRecordId(record) {
185
- const id = record.id ?? record._id;
186
- if (id === void 0 || id === null) {
187
- throw new DatasetError('Dataset records must have an "id" or "_id" field for alignment');
188
- }
189
- return String(id);
190
- }
191
- async function runModelParallel(dataset, modelFn, concurrency = 10) {
192
- const startTime = Date.now();
193
- const results = [];
194
- for (let i = 0; i < dataset.records.length; i += concurrency) {
195
- const batch = dataset.records.slice(i, i + concurrency);
196
- const batchResults = await Promise.all(
197
- batch.map(async (record) => {
198
- const prediction = await modelFn(record);
199
- return { prediction, record };
200
- })
201
- );
202
- results.push(...batchResults);
203
- }
204
- const predictions = [];
205
- const aligned = [];
206
- for (const { prediction, record } of results) {
207
- const id = getRecordId(record);
208
- if (prediction.id !== id) {
209
- throw new DatasetError(`Prediction ID mismatch: expected "${id}", got "${prediction.id}".`);
210
- }
211
- predictions.push(prediction);
212
- aligned.push({
213
- id,
214
- actual: { ...prediction },
215
- expected: { ...record }
216
- });
217
- }
218
- return {
219
- predictions,
220
- aligned,
221
- duration: Date.now() - startTime
222
- };
223
- }
224
92
 
225
93
  // src/dataset/alignment.ts
226
94
  function alignByKey(predictions, expected, options = {}) {
@@ -293,14 +161,14 @@ function filterComplete(aligned, field) {
293
161
  }
294
162
 
295
163
  // src/dataset/integrity.ts
296
- function checkIntegrity(dataset, options = {}) {
164
+ function checkIntegrity(records, options = {}) {
297
165
  const { requiredFields = [], throwOnFailure = false } = options;
298
166
  const seenIds = /* @__PURE__ */ new Map();
299
167
  const missingIds = [];
300
168
  const duplicateIds = [];
301
169
  const missingFields = [];
302
- for (let i = 0; i < dataset.records.length; i++) {
303
- const record = dataset.records[i];
170
+ for (let i = 0; i < records.length; i++) {
171
+ const record = records[i];
304
172
  if (!record) continue;
305
173
  const id = record.id ?? record._id;
306
174
  if (id === void 0 || id === null) {
@@ -327,7 +195,7 @@ function checkIntegrity(dataset, options = {}) {
327
195
  const valid = missingIds.length === 0 && duplicateIds.length === 0 && missingFields.length === 0;
328
196
  const result = {
329
197
  valid,
330
- totalRecords: dataset.records.length,
198
+ totalRecords: records.length,
331
199
  missingIds,
332
200
  duplicateIds,
333
201
  missingFields
@@ -520,6 +388,91 @@ function calculatePercentageAbove(values, threshold) {
520
388
  return countAbove / values.length;
521
389
  }
522
390
 
391
+ // src/assertions/metric-matcher.ts
392
+ var MetricMatcher = class {
393
+ context;
394
+ constructor(context) {
395
+ this.context = context;
396
+ }
397
+ formatMetricValue(value) {
398
+ if (this.context.formatValue) {
399
+ return this.context.formatValue(value);
400
+ }
401
+ if (value >= 0 && value <= 1) {
402
+ return `${(value * 100).toFixed(1)}%`;
403
+ }
404
+ return value.toFixed(4);
405
+ }
406
+ createAssertion(operator, threshold, passed) {
407
+ const { metricName, metricValue, fieldName, targetClass } = this.context;
408
+ const formattedActual = this.formatMetricValue(metricValue);
409
+ const formattedThreshold = this.formatMetricValue(threshold);
410
+ const classInfo = targetClass ? ` for "${targetClass}"` : "";
411
+ const operatorText = {
412
+ ">=": "at least",
413
+ ">": "above",
414
+ "<=": "at most",
415
+ "<": "below",
416
+ "===": "equal to"
417
+ }[operator];
418
+ const message = passed ? `${metricName}${classInfo} ${formattedActual} is ${operatorText} ${formattedThreshold}` : `${metricName}${classInfo} ${formattedActual} is not ${operatorText} ${formattedThreshold}`;
419
+ return {
420
+ type: metricName.toLowerCase().replace(/\s+/g, "").replace(/²/g, "2"),
421
+ passed,
422
+ message,
423
+ expected: threshold,
424
+ actual: metricValue,
425
+ field: fieldName,
426
+ class: targetClass
427
+ };
428
+ }
429
+ recordAndReturn(result) {
430
+ this.context.assertions.push(result);
431
+ recordAssertion(result);
432
+ return this.context.parent;
433
+ }
434
+ /**
435
+ * Assert that the metric is greater than or equal to the threshold (>=)
436
+ */
437
+ toBeAtLeast(threshold) {
438
+ const passed = this.context.metricValue >= threshold;
439
+ const result = this.createAssertion(">=", threshold, passed);
440
+ return this.recordAndReturn(result);
441
+ }
442
+ /**
443
+ * Assert that the metric is strictly greater than the threshold (>)
444
+ */
445
+ toBeAbove(threshold) {
446
+ const passed = this.context.metricValue > threshold;
447
+ const result = this.createAssertion(">", threshold, passed);
448
+ return this.recordAndReturn(result);
449
+ }
450
+ /**
451
+ * Assert that the metric is less than or equal to the threshold (<=)
452
+ */
453
+ toBeAtMost(threshold) {
454
+ const passed = this.context.metricValue <= threshold;
455
+ const result = this.createAssertion("<=", threshold, passed);
456
+ return this.recordAndReturn(result);
457
+ }
458
+ /**
459
+ * Assert that the metric is strictly less than the threshold (<)
460
+ */
461
+ toBeBelow(threshold) {
462
+ const passed = this.context.metricValue < threshold;
463
+ const result = this.createAssertion("<", threshold, passed);
464
+ return this.recordAndReturn(result);
465
+ }
466
+ /**
467
+ * Assert that the metric equals the expected value (with optional tolerance for floats)
468
+ */
469
+ toEqual(expected, tolerance = 1e-9) {
470
+ const passed = Math.abs(this.context.metricValue - expected) <= tolerance;
471
+ const result = this.createAssertion("===", expected, passed);
472
+ return this.recordAndReturn(result);
473
+ }
474
+ };
475
+
523
476
  // src/assertions/binarize.ts
524
477
  var BinarizeSelector = class {
525
478
  fieldName;
@@ -551,149 +504,127 @@ var BinarizeSelector = class {
551
504
  }
552
505
  }
553
506
  }
507
+ // ============================================================================
508
+ // Classification Metric Getters
509
+ // ============================================================================
554
510
  /**
555
- * Asserts that accuracy is above a threshold
511
+ * Access accuracy metric for assertions
512
+ * @example
513
+ * expectStats(predictions, groundTruth)
514
+ * .field("score")
515
+ * .binarize(0.5)
516
+ * .accuracy.toBeAtLeast(0.8)
556
517
  */
557
- toHaveAccuracyAbove(threshold) {
518
+ get accuracy() {
558
519
  const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
559
- const passed = metrics.accuracy >= threshold;
560
- const result = {
561
- type: "accuracy",
562
- passed,
563
- message: passed ? `Accuracy ${(metrics.accuracy * 100).toFixed(1)}% is above ${(threshold * 100).toFixed(1)}% (binarized at ${this.threshold})` : `Accuracy ${(metrics.accuracy * 100).toFixed(1)}% is below threshold ${(threshold * 100).toFixed(1)}% (binarized at ${this.threshold})`,
564
- expected: threshold,
565
- actual: metrics.accuracy,
566
- field: this.fieldName
567
- };
568
- this.assertions.push(result);
569
- recordAssertion(result);
570
- return this;
520
+ return new MetricMatcher({
521
+ parent: this,
522
+ metricName: "Accuracy",
523
+ metricValue: metrics.accuracy,
524
+ fieldName: this.fieldName,
525
+ assertions: this.assertions
526
+ });
571
527
  }
572
528
  /**
573
- * Asserts that precision is above a threshold
574
- * @param classOrThreshold - Either the class (true/false) or threshold
575
- * @param threshold - Threshold when class is specified
529
+ * Access F1 score metric for assertions (macro average)
530
+ * @example
531
+ * expectStats(predictions, groundTruth)
532
+ * .field("score")
533
+ * .binarize(0.5)
534
+ * .f1.toBeAtLeast(0.75)
576
535
  */
577
- toHavePrecisionAbove(classOrThreshold, threshold) {
536
+ get f1() {
578
537
  const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
579
- let actualPrecision;
580
- let targetClass;
581
- let actualThreshold;
582
- if (typeof classOrThreshold === "number") {
583
- actualPrecision = metrics.macroAvg.precision;
584
- actualThreshold = classOrThreshold;
585
- } else {
586
- targetClass = String(classOrThreshold);
587
- actualThreshold = threshold;
588
- const classMetrics = metrics.perClass[targetClass];
589
- if (!classMetrics) {
590
- throw new AssertionError(
591
- `Class "${targetClass}" not found in binarized predictions`,
592
- targetClass,
593
- Object.keys(metrics.perClass),
594
- this.fieldName
595
- );
596
- }
597
- actualPrecision = classMetrics.precision;
598
- }
599
- const passed = actualPrecision >= actualThreshold;
600
- const result = {
601
- type: "precision",
602
- passed,
603
- message: passed ? `Precision${targetClass ? ` for ${targetClass}` : ""} ${(actualPrecision * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Precision${targetClass ? ` for ${targetClass}` : ""} ${(actualPrecision * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
604
- expected: actualThreshold,
605
- actual: actualPrecision,
606
- field: this.fieldName,
607
- class: targetClass
608
- };
609
- this.assertions.push(result);
610
- recordAssertion(result);
611
- return this;
538
+ return new MetricMatcher({
539
+ parent: this,
540
+ metricName: "F1",
541
+ metricValue: metrics.macroAvg.f1,
542
+ fieldName: this.fieldName,
543
+ assertions: this.assertions
544
+ });
612
545
  }
613
546
  /**
614
- * Asserts that recall is above a threshold
615
- * @param classOrThreshold - Either the class (true/false) or threshold
616
- * @param threshold - Threshold when class is specified
547
+ * Access precision metric for assertions
548
+ * @param targetClass - Optional boolean class (true/false). If omitted, uses macro average
549
+ * @example
550
+ * expectStats(predictions, groundTruth)
551
+ * .field("score")
552
+ * .binarize(0.5)
553
+ * .precision(true).toBeAtLeast(0.7)
617
554
  */
618
- toHaveRecallAbove(classOrThreshold, threshold) {
555
+ precision(targetClass) {
619
556
  const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
620
- let actualRecall;
621
- let targetClass;
622
- let actualThreshold;
623
- if (typeof classOrThreshold === "number") {
624
- actualRecall = metrics.macroAvg.recall;
625
- actualThreshold = classOrThreshold;
557
+ let metricValue;
558
+ let classKey;
559
+ if (targetClass === void 0) {
560
+ metricValue = metrics.macroAvg.precision;
626
561
  } else {
627
- targetClass = String(classOrThreshold);
628
- actualThreshold = threshold;
629
- const classMetrics = metrics.perClass[targetClass];
562
+ classKey = String(targetClass);
563
+ const classMetrics = metrics.perClass[classKey];
630
564
  if (!classMetrics) {
631
565
  throw new AssertionError(
632
- `Class "${targetClass}" not found in binarized predictions`,
633
- targetClass,
566
+ `Class "${classKey}" not found in binarized predictions`,
567
+ classKey,
634
568
  Object.keys(metrics.perClass),
635
569
  this.fieldName
636
570
  );
637
571
  }
638
- actualRecall = classMetrics.recall;
572
+ metricValue = classMetrics.precision;
639
573
  }
640
- const passed = actualRecall >= actualThreshold;
641
- const result = {
642
- type: "recall",
643
- passed,
644
- message: passed ? `Recall${targetClass ? ` for ${targetClass}` : ""} ${(actualRecall * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Recall${targetClass ? ` for ${targetClass}` : ""} ${(actualRecall * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
645
- expected: actualThreshold,
646
- actual: actualRecall,
647
- field: this.fieldName,
648
- class: targetClass
649
- };
650
- this.assertions.push(result);
651
- recordAssertion(result);
652
- return this;
574
+ return new MetricMatcher({
575
+ parent: this,
576
+ metricName: "Precision",
577
+ metricValue,
578
+ fieldName: this.fieldName,
579
+ targetClass: classKey,
580
+ assertions: this.assertions
581
+ });
653
582
  }
654
583
  /**
655
- * Asserts that F1 score is above a threshold
584
+ * Access recall metric for assertions
585
+ * @param targetClass - Optional boolean class (true/false). If omitted, uses macro average
586
+ * @example
587
+ * expectStats(predictions, groundTruth)
588
+ * .field("score")
589
+ * .binarize(0.5)
590
+ * .recall(true).toBeAtLeast(0.7)
656
591
  */
657
- toHaveF1Above(classOrThreshold, threshold) {
592
+ recall(targetClass) {
658
593
  const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
659
- let actualF1;
660
- let targetClass;
661
- let actualThreshold;
662
- if (typeof classOrThreshold === "number") {
663
- actualF1 = metrics.macroAvg.f1;
664
- actualThreshold = classOrThreshold;
594
+ let metricValue;
595
+ let classKey;
596
+ if (targetClass === void 0) {
597
+ metricValue = metrics.macroAvg.recall;
665
598
  } else {
666
- targetClass = String(classOrThreshold);
667
- actualThreshold = threshold;
668
- const classMetrics = metrics.perClass[targetClass];
599
+ classKey = String(targetClass);
600
+ const classMetrics = metrics.perClass[classKey];
669
601
  if (!classMetrics) {
670
602
  throw new AssertionError(
671
- `Class "${targetClass}" not found in binarized predictions`,
672
- targetClass,
603
+ `Class "${classKey}" not found in binarized predictions`,
604
+ classKey,
673
605
  Object.keys(metrics.perClass),
674
606
  this.fieldName
675
607
  );
676
608
  }
677
- actualF1 = classMetrics.f1;
609
+ metricValue = classMetrics.recall;
678
610
  }
679
- const passed = actualF1 >= actualThreshold;
680
- const result = {
681
- type: "f1",
682
- passed,
683
- message: passed ? `F1${targetClass ? ` for ${targetClass}` : ""} ${(actualF1 * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `F1${targetClass ? ` for ${targetClass}` : ""} ${(actualF1 * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
684
- expected: actualThreshold,
685
- actual: actualF1,
686
- field: this.fieldName,
687
- class: targetClass
688
- };
689
- this.assertions.push(result);
690
- recordAssertion(result);
691
- return this;
611
+ return new MetricMatcher({
612
+ parent: this,
613
+ metricName: "Recall",
614
+ metricValue,
615
+ fieldName: this.fieldName,
616
+ targetClass: classKey,
617
+ assertions: this.assertions
618
+ });
692
619
  }
620
+ // ============================================================================
621
+ // Display Methods
622
+ // ============================================================================
693
623
  /**
694
- * Includes the confusion matrix in the report
624
+ * Displays the confusion matrix in the report
625
+ * This is not an assertion - it always passes and just records the matrix for display
695
626
  */
696
- toHaveConfusionMatrix() {
627
+ displayConfusionMatrix() {
697
628
  const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
698
629
  const fieldResult = {
699
630
  field: this.fieldName,
@@ -712,6 +643,9 @@ var BinarizeSelector = class {
712
643
  recordAssertion(result);
713
644
  return this;
714
645
  }
646
+ // ============================================================================
647
+ // Utility Methods
648
+ // ============================================================================
715
649
  /**
716
650
  * Gets computed metrics
717
651
  */
@@ -726,6 +660,73 @@ var BinarizeSelector = class {
726
660
  }
727
661
  };
728
662
 
663
+ // src/assertions/percentage-matcher.ts
664
+ var PercentageMatcher = class {
665
+ context;
666
+ constructor(context) {
667
+ this.context = context;
668
+ }
669
+ formatPercentage(value) {
670
+ return `${(value * 100).toFixed(1)}%`;
671
+ }
672
+ createAssertion(operator, percentageThreshold, passed) {
673
+ const { fieldName, valueThreshold, direction, actualPercentage } = this.context;
674
+ const operatorText = {
675
+ ">=": "at least",
676
+ ">": "above",
677
+ "<=": "at most",
678
+ "<": "below"
679
+ }[operator];
680
+ const directionText = direction === "above" ? "above" : "below or equal to";
681
+ const message = passed ? `${this.formatPercentage(actualPercentage)} of '${fieldName}' values are ${directionText} ${valueThreshold} (expected ${operatorText} ${this.formatPercentage(percentageThreshold)})` : `Only ${this.formatPercentage(actualPercentage)} of '${fieldName}' values are ${directionText} ${valueThreshold} (expected ${operatorText} ${this.formatPercentage(percentageThreshold)})`;
682
+ return {
683
+ type: direction === "above" ? "percentageAbove" : "percentageBelow",
684
+ passed,
685
+ message,
686
+ expected: percentageThreshold,
687
+ actual: actualPercentage,
688
+ field: fieldName
689
+ };
690
+ }
691
+ recordAndReturn(result) {
692
+ this.context.assertions.push(result);
693
+ recordAssertion(result);
694
+ return this.context.parent;
695
+ }
696
+ /**
697
+ * Assert that the percentage is greater than or equal to the threshold (>=)
698
+ */
699
+ toBeAtLeast(percentageThreshold) {
700
+ const passed = this.context.actualPercentage >= percentageThreshold;
701
+ const result = this.createAssertion(">=", percentageThreshold, passed);
702
+ return this.recordAndReturn(result);
703
+ }
704
+ /**
705
+ * Assert that the percentage is strictly greater than the threshold (>)
706
+ */
707
+ toBeAbove(percentageThreshold) {
708
+ const passed = this.context.actualPercentage > percentageThreshold;
709
+ const result = this.createAssertion(">", percentageThreshold, passed);
710
+ return this.recordAndReturn(result);
711
+ }
712
+ /**
713
+ * Assert that the percentage is less than or equal to the threshold (<=)
714
+ */
715
+ toBeAtMost(percentageThreshold) {
716
+ const passed = this.context.actualPercentage <= percentageThreshold;
717
+ const result = this.createAssertion("<=", percentageThreshold, passed);
718
+ return this.recordAndReturn(result);
719
+ }
720
+ /**
721
+ * Assert that the percentage is strictly less than the threshold (<)
722
+ */
723
+ toBeBelow(percentageThreshold) {
724
+ const passed = this.context.actualPercentage < percentageThreshold;
725
+ const result = this.createAssertion("<", percentageThreshold, passed);
726
+ return this.recordAndReturn(result);
727
+ }
728
+ };
729
+
729
730
  // src/assertions/field-selector.ts
730
731
  var FieldSelector = class {
731
732
  aligned;
@@ -762,83 +763,93 @@ var FieldSelector = class {
762
763
  }
763
764
  }
764
765
  /**
765
- * Asserts that accuracy is above a threshold
766
+ * Validates that ground truth exists and both arrays contain numeric values.
767
+ * Returns the filtered numeric arrays for regression metrics.
766
768
  */
767
- toHaveAccuracyAbove(threshold) {
769
+ validateRegressionInputs() {
770
+ this.validateGroundTruth();
771
+ const numericActual = filterNumericValues(this.actualValues);
772
+ const numericExpected = filterNumericValues(this.expectedValues);
773
+ if (numericActual.length === 0) {
774
+ throw new AssertionError(
775
+ `Regression metric requires numeric values, but field "${this.fieldName}" has no numeric actual values.`,
776
+ void 0,
777
+ void 0,
778
+ this.fieldName
779
+ );
780
+ }
781
+ if (numericExpected.length === 0) {
782
+ throw new AssertionError(
783
+ `Regression metric requires numeric values, but field "${this.fieldName}" has no numeric expected values.`,
784
+ void 0,
785
+ void 0,
786
+ this.fieldName
787
+ );
788
+ }
789
+ if (numericActual.length !== numericExpected.length) {
790
+ throw new AssertionError(
791
+ `Regression metric requires equal-length arrays, but got ${numericActual.length} actual and ${numericExpected.length} expected values.`,
792
+ numericExpected.length,
793
+ numericActual.length,
794
+ this.fieldName
795
+ );
796
+ }
797
+ return { actual: numericActual, expected: numericExpected };
798
+ }
799
+ // ============================================================================
800
+ // Classification Metric Getters
801
+ // ============================================================================
802
+ /**
803
+ * Access accuracy metric for assertions
804
+ * @example
805
+ * expectStats(predictions, groundTruth)
806
+ * .field("sentiment")
807
+ * .accuracy.toBeAtLeast(0.8)
808
+ */
809
+ get accuracy() {
768
810
  this.validateGroundTruth();
769
811
  const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
770
- const passed = metrics.accuracy >= threshold;
771
- const result = {
772
- type: "accuracy",
773
- passed,
774
- message: passed ? `Accuracy ${(metrics.accuracy * 100).toFixed(1)}% is above ${(threshold * 100).toFixed(1)}%` : `Accuracy ${(metrics.accuracy * 100).toFixed(1)}% is below threshold ${(threshold * 100).toFixed(1)}%`,
775
- expected: threshold,
776
- actual: metrics.accuracy,
777
- field: this.fieldName
778
- };
779
- this.assertions.push(result);
780
- recordAssertion(result);
781
- return this;
812
+ return new MetricMatcher({
813
+ parent: this,
814
+ metricName: "Accuracy",
815
+ metricValue: metrics.accuracy,
816
+ fieldName: this.fieldName,
817
+ assertions: this.assertions
818
+ });
782
819
  }
783
820
  /**
784
- * Asserts that precision is above a threshold
785
- * @param classOrThreshold - Either the class name or threshold (if class is omitted, uses macro average)
786
- * @param threshold - Threshold when class is specified
821
+ * Access F1 score metric for assertions (macro average)
822
+ * @example
823
+ * expectStats(predictions, groundTruth)
824
+ * .field("sentiment")
825
+ * .f1.toBeAtLeast(0.75)
787
826
  */
788
- toHavePrecisionAbove(classOrThreshold, threshold) {
827
+ get f1() {
789
828
  this.validateGroundTruth();
790
829
  const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
791
- let actualPrecision;
792
- let targetClass;
793
- let actualThreshold;
794
- if (typeof classOrThreshold === "number") {
795
- actualPrecision = metrics.macroAvg.precision;
796
- actualThreshold = classOrThreshold;
797
- } else {
798
- targetClass = classOrThreshold;
799
- actualThreshold = threshold;
800
- const classMetrics = metrics.perClass[targetClass];
801
- if (!classMetrics) {
802
- throw new AssertionError(
803
- `Class "${targetClass}" not found in predictions`,
804
- targetClass,
805
- Object.keys(metrics.perClass),
806
- this.fieldName
807
- );
808
- }
809
- actualPrecision = classMetrics.precision;
810
- }
811
- const passed = actualPrecision >= actualThreshold;
812
- const result = {
813
- type: "precision",
814
- passed,
815
- message: passed ? `Precision${targetClass ? ` for "${targetClass}"` : ""} ${(actualPrecision * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Precision${targetClass ? ` for "${targetClass}"` : ""} ${(actualPrecision * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
816
- expected: actualThreshold,
817
- actual: actualPrecision,
818
- field: this.fieldName,
819
- class: targetClass
820
- };
821
- this.assertions.push(result);
822
- recordAssertion(result);
823
- return this;
830
+ return new MetricMatcher({
831
+ parent: this,
832
+ metricName: "F1",
833
+ metricValue: metrics.macroAvg.f1,
834
+ fieldName: this.fieldName,
835
+ assertions: this.assertions
836
+ });
824
837
  }
825
838
  /**
826
- * Asserts that recall is above a threshold
827
- * @param classOrThreshold - Either the class name or threshold (if class is omitted, uses macro average)
828
- * @param threshold - Threshold when class is specified
839
+ * Access precision metric for assertions
840
+ * @param targetClass - Optional class name. If omitted, uses macro average
841
+ * @example
842
+ * expectStats(predictions, groundTruth)
843
+ * .field("sentiment")
844
+ * .precision("positive").toBeAtLeast(0.7)
829
845
  */
830
- toHaveRecallAbove(classOrThreshold, threshold) {
846
+ precision(targetClass) {
831
847
  this.validateGroundTruth();
832
848
  const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
833
- let actualRecall;
834
- let targetClass;
835
- let actualThreshold;
836
- if (typeof classOrThreshold === "number") {
837
- actualRecall = metrics.macroAvg.recall;
838
- actualThreshold = classOrThreshold;
849
+ let metricValue;
850
+ if (targetClass === void 0) {
851
+ metricValue = metrics.macroAvg.precision;
839
852
  } else {
840
- targetClass = classOrThreshold;
841
- actualThreshold = threshold;
842
853
  const classMetrics = metrics.perClass[targetClass];
843
854
  if (!classMetrics) {
844
855
  throw new AssertionError(
@@ -848,39 +859,32 @@ var FieldSelector = class {
848
859
  this.fieldName
849
860
  );
850
861
  }
851
- actualRecall = classMetrics.recall;
862
+ metricValue = classMetrics.precision;
852
863
  }
853
- const passed = actualRecall >= actualThreshold;
854
- const result = {
855
- type: "recall",
856
- passed,
857
- message: passed ? `Recall${targetClass ? ` for "${targetClass}"` : ""} ${(actualRecall * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Recall${targetClass ? ` for "${targetClass}"` : ""} ${(actualRecall * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
858
- expected: actualThreshold,
859
- actual: actualRecall,
860
- field: this.fieldName,
861
- class: targetClass
862
- };
863
- this.assertions.push(result);
864
- recordAssertion(result);
865
- return this;
864
+ return new MetricMatcher({
865
+ parent: this,
866
+ metricName: "Precision",
867
+ metricValue,
868
+ fieldName: this.fieldName,
869
+ targetClass,
870
+ assertions: this.assertions
871
+ });
866
872
  }
867
873
  /**
868
- * Asserts that F1 score is above a threshold
869
- * @param classOrThreshold - Either the class name or threshold (if class is omitted, uses macro average)
870
- * @param threshold - Threshold when class is specified
874
+ * Access recall metric for assertions
875
+ * @param targetClass - Optional class name. If omitted, uses macro average
876
+ * @example
877
+ * expectStats(predictions, groundTruth)
878
+ * .field("sentiment")
879
+ * .recall("positive").toBeAtLeast(0.7)
871
880
  */
872
- toHaveF1Above(classOrThreshold, threshold) {
881
+ recall(targetClass) {
873
882
  this.validateGroundTruth();
874
883
  const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
875
- let actualF1;
876
- let targetClass;
877
- let actualThreshold;
878
- if (typeof classOrThreshold === "number") {
879
- actualF1 = metrics.macroAvg.f1;
880
- actualThreshold = classOrThreshold;
884
+ let metricValue;
885
+ if (targetClass === void 0) {
886
+ metricValue = metrics.macroAvg.recall;
881
887
  } else {
882
- targetClass = classOrThreshold;
883
- actualThreshold = threshold;
884
888
  const classMetrics = metrics.perClass[targetClass];
885
889
  if (!classMetrics) {
886
890
  throw new AssertionError(
@@ -890,244 +894,171 @@ var FieldSelector = class {
890
894
  this.fieldName
891
895
  );
892
896
  }
893
- actualF1 = classMetrics.f1;
897
+ metricValue = classMetrics.recall;
894
898
  }
895
- const passed = actualF1 >= actualThreshold;
896
- const result = {
897
- type: "f1",
898
- passed,
899
- message: passed ? `F1${targetClass ? ` for "${targetClass}"` : ""} ${(actualF1 * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `F1${targetClass ? ` for "${targetClass}"` : ""} ${(actualF1 * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
900
- expected: actualThreshold,
901
- actual: actualF1,
902
- field: this.fieldName,
903
- class: targetClass
904
- };
905
- this.assertions.push(result);
906
- recordAssertion(result);
907
- return this;
899
+ return new MetricMatcher({
900
+ parent: this,
901
+ metricName: "Recall",
902
+ metricValue,
903
+ fieldName: this.fieldName,
904
+ targetClass,
905
+ assertions: this.assertions
906
+ });
908
907
  }
908
+ // ============================================================================
909
+ // Regression Metric Getters
910
+ // ============================================================================
909
911
  /**
910
- * Includes the confusion matrix in the report
912
+ * Access Mean Absolute Error metric for assertions
913
+ * @example
914
+ * expectStats(predictions, groundTruth)
915
+ * .field("score")
916
+ * .mae.toBeAtMost(0.1)
911
917
  */
912
- toHaveConfusionMatrix() {
913
- const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
914
- const fieldResult = {
915
- field: this.fieldName,
916
- metrics,
917
- binarized: false
918
- };
919
- recordFieldMetrics(fieldResult);
920
- const result = {
921
- type: "confusionMatrix",
922
- passed: true,
923
- message: `Confusion matrix recorded for field "${this.fieldName}"`,
924
- field: this.fieldName
925
- };
926
- this.assertions.push(result);
927
- recordAssertion(result);
928
- return this;
918
+ get mae() {
919
+ const { actual, expected } = this.validateRegressionInputs();
920
+ const metrics = computeRegressionMetrics(actual, expected);
921
+ return new MetricMatcher({
922
+ parent: this,
923
+ metricName: "MAE",
924
+ metricValue: metrics.mae,
925
+ fieldName: this.fieldName,
926
+ assertions: this.assertions,
927
+ formatValue: (v) => v.toFixed(4)
928
+ });
929
929
  }
930
930
  /**
931
- * Asserts that a percentage of values are below or equal to a threshold.
932
- * This is a distributional assertion that only looks at actual values (no ground truth required).
933
- *
931
+ * Access Root Mean Squared Error metric for assertions
932
+ * @example
933
+ * expectStats(predictions, groundTruth)
934
+ * .field("score")
935
+ * .rmse.toBeAtMost(0.15)
936
+ */
937
+ get rmse() {
938
+ const { actual, expected } = this.validateRegressionInputs();
939
+ const metrics = computeRegressionMetrics(actual, expected);
940
+ return new MetricMatcher({
941
+ parent: this,
942
+ metricName: "RMSE",
943
+ metricValue: metrics.rmse,
944
+ fieldName: this.fieldName,
945
+ assertions: this.assertions,
946
+ formatValue: (v) => v.toFixed(4)
947
+ });
948
+ }
949
+ /**
950
+ * Access R-squared (coefficient of determination) metric for assertions
951
+ * @example
952
+ * expectStats(predictions, groundTruth)
953
+ * .field("score")
954
+ * .r2.toBeAtLeast(0.8)
955
+ */
956
+ get r2() {
957
+ const { actual, expected } = this.validateRegressionInputs();
958
+ const metrics = computeRegressionMetrics(actual, expected);
959
+ return new MetricMatcher({
960
+ parent: this,
961
+ metricName: "R\xB2",
962
+ metricValue: metrics.r2,
963
+ fieldName: this.fieldName,
964
+ assertions: this.assertions,
965
+ formatValue: (v) => v.toFixed(4)
966
+ });
967
+ }
968
+ // ============================================================================
969
+ // Distribution Assertions
970
+ // ============================================================================
971
+ /**
972
+ * Assert on the percentage of values below or equal to a threshold
934
973
  * @param valueThreshold - The value threshold to compare against
935
- * @param percentageThreshold - The minimum percentage (0-1) of values that should be <= valueThreshold
936
- * @returns this for method chaining
937
- *
938
974
  * @example
939
- * // Assert that 90% of confidence scores are below 0.5
940
975
  * expectStats(predictions)
941
976
  * .field("confidence")
942
- * .toHavePercentageBelow(0.5, 0.9)
977
+ * .percentageBelow(0.5).toBeAtLeast(0.9)
943
978
  */
944
- toHavePercentageBelow(valueThreshold, percentageThreshold) {
979
+ percentageBelow(valueThreshold) {
945
980
  const numericActual = filterNumericValues(this.actualValues);
946
981
  if (numericActual.length === 0) {
947
982
  throw new AssertionError(
948
983
  `Field '${this.fieldName}' contains no numeric values (found 0 numeric out of ${this.actualValues.length} total values)`,
949
- percentageThreshold,
984
+ void 0,
950
985
  void 0,
951
986
  this.fieldName
952
987
  );
953
988
  }
954
989
  const actualPercentage = calculatePercentageBelow(numericActual, valueThreshold);
955
- const passed = actualPercentage >= percentageThreshold;
956
- const result = {
957
- type: "percentageBelow",
958
- passed,
959
- message: passed ? `${(actualPercentage * 100).toFixed(1)}% of '${this.fieldName}' values are below or equal to ${valueThreshold} (expected >= ${(percentageThreshold * 100).toFixed(1)}%)` : `Only ${(actualPercentage * 100).toFixed(1)}% of '${this.fieldName}' values are below or equal to ${valueThreshold} (expected >= ${(percentageThreshold * 100).toFixed(1)}%)`,
960
- expected: percentageThreshold,
961
- actual: actualPercentage,
962
- field: this.fieldName
963
- };
964
- this.assertions.push(result);
965
- recordAssertion(result);
966
- return this;
990
+ return new PercentageMatcher({
991
+ parent: this,
992
+ fieldName: this.fieldName,
993
+ valueThreshold,
994
+ direction: "below",
995
+ actualPercentage,
996
+ assertions: this.assertions
997
+ });
967
998
  }
968
999
  /**
969
- * Asserts that a percentage of values are above a threshold.
970
- * This is a distributional assertion that only looks at actual values (no ground truth required).
971
- *
1000
+ * Assert on the percentage of values above a threshold
972
1001
  * @param valueThreshold - The value threshold to compare against
973
- * @param percentageThreshold - The minimum percentage (0-1) of values that should be > valueThreshold
974
- * @returns this for method chaining
975
- *
976
1002
  * @example
977
- * // Assert that 80% of quality scores are above 0.7
978
1003
  * expectStats(predictions)
979
1004
  * .field("quality")
980
- * .toHavePercentageAbove(0.7, 0.8)
1005
+ * .percentageAbove(0.7).toBeAtLeast(0.8)
981
1006
  */
982
- toHavePercentageAbove(valueThreshold, percentageThreshold) {
1007
+ percentageAbove(valueThreshold) {
983
1008
  const numericActual = filterNumericValues(this.actualValues);
984
1009
  if (numericActual.length === 0) {
985
1010
  throw new AssertionError(
986
1011
  `Field '${this.fieldName}' contains no numeric values (found 0 numeric out of ${this.actualValues.length} total values)`,
987
- percentageThreshold,
1012
+ void 0,
988
1013
  void 0,
989
1014
  this.fieldName
990
1015
  );
991
1016
  }
992
1017
  const actualPercentage = calculatePercentageAbove(numericActual, valueThreshold);
993
- const passed = actualPercentage >= percentageThreshold;
994
- const result = {
995
- type: "percentageAbove",
996
- passed,
997
- message: passed ? `${(actualPercentage * 100).toFixed(1)}% of '${this.fieldName}' values are above ${valueThreshold} (expected >= ${(percentageThreshold * 100).toFixed(1)}%)` : `Only ${(actualPercentage * 100).toFixed(1)}% of '${this.fieldName}' values are above ${valueThreshold} (expected >= ${(percentageThreshold * 100).toFixed(1)}%)`,
998
- expected: percentageThreshold,
999
- actual: actualPercentage,
1000
- field: this.fieldName
1001
- };
1002
- this.assertions.push(result);
1003
- recordAssertion(result);
1004
- return this;
1018
+ return new PercentageMatcher({
1019
+ parent: this,
1020
+ fieldName: this.fieldName,
1021
+ valueThreshold,
1022
+ direction: "above",
1023
+ actualPercentage,
1024
+ assertions: this.assertions
1025
+ });
1005
1026
  }
1006
1027
  // ============================================================================
1007
- // Regression Assertions
1028
+ // Display Methods
1008
1029
  // ============================================================================
1009
1030
  /**
1010
- * Validates that ground truth exists and both arrays contain numeric values.
1011
- * Returns the filtered numeric arrays for regression metrics.
1012
- */
1013
- validateRegressionInputs() {
1014
- this.validateGroundTruth();
1015
- const numericActual = filterNumericValues(this.actualValues);
1016
- const numericExpected = filterNumericValues(this.expectedValues);
1017
- if (numericActual.length === 0) {
1018
- throw new AssertionError(
1019
- `Regression metric requires numeric values, but field "${this.fieldName}" has no numeric actual values.`,
1020
- void 0,
1021
- void 0,
1022
- this.fieldName
1023
- );
1024
- }
1025
- if (numericExpected.length === 0) {
1026
- throw new AssertionError(
1027
- `Regression metric requires numeric values, but field "${this.fieldName}" has no numeric expected values.`,
1028
- void 0,
1029
- void 0,
1030
- this.fieldName
1031
- );
1032
- }
1033
- if (numericActual.length !== numericExpected.length) {
1034
- throw new AssertionError(
1035
- `Regression metric requires equal-length arrays, but got ${numericActual.length} actual and ${numericExpected.length} expected values.`,
1036
- numericExpected.length,
1037
- numericActual.length,
1038
- this.fieldName
1039
- );
1040
- }
1041
- return { actual: numericActual, expected: numericExpected };
1042
- }
1043
- /**
1044
- * Asserts that Mean Absolute Error is below a threshold.
1045
- * Requires numeric values in both actual and expected.
1046
- *
1047
- * @param threshold - Maximum allowed MAE
1048
- * @returns this for method chaining
1049
- *
1031
+ * Displays the confusion matrix in the report
1032
+ * This is not an assertion - it always passes and just records the matrix for display
1050
1033
  * @example
1051
1034
  * expectStats(predictions, groundTruth)
1052
- * .field("score")
1053
- * .toHaveMAEBelow(0.1)
1035
+ * .field("sentiment")
1036
+ * .accuracy.toBeAtLeast(0.8)
1037
+ * .displayConfusionMatrix()
1054
1038
  */
1055
- toHaveMAEBelow(threshold) {
1056
- const { actual, expected } = this.validateRegressionInputs();
1057
- const metrics = computeRegressionMetrics(actual, expected);
1058
- const passed = metrics.mae <= threshold;
1059
- const result = {
1060
- type: "mae",
1061
- passed,
1062
- message: passed ? `MAE ${metrics.mae.toFixed(4)} is below ${threshold}` : `MAE ${metrics.mae.toFixed(4)} exceeds threshold ${threshold}`,
1063
- expected: threshold,
1064
- actual: metrics.mae,
1065
- field: this.fieldName
1066
- };
1067
- this.assertions.push(result);
1068
- recordAssertion(result);
1069
- return this;
1070
- }
1071
- /**
1072
- * Asserts that Root Mean Squared Error is below a threshold.
1073
- * Requires numeric values in both actual and expected.
1074
- *
1075
- * @param threshold - Maximum allowed RMSE
1076
- * @returns this for method chaining
1077
- *
1078
- * @example
1079
- * expectStats(predictions, groundTruth)
1080
- * .field("score")
1081
- * .toHaveRMSEBelow(0.15)
1082
- */
1083
- toHaveRMSEBelow(threshold) {
1084
- const { actual, expected } = this.validateRegressionInputs();
1085
- const metrics = computeRegressionMetrics(actual, expected);
1086
- const passed = metrics.rmse <= threshold;
1087
- const result = {
1088
- type: "rmse",
1089
- passed,
1090
- message: passed ? `RMSE ${metrics.rmse.toFixed(4)} is below ${threshold}` : `RMSE ${metrics.rmse.toFixed(4)} exceeds threshold ${threshold}`,
1091
- expected: threshold,
1092
- actual: metrics.rmse,
1093
- field: this.fieldName
1039
+ displayConfusionMatrix() {
1040
+ const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
1041
+ const fieldResult = {
1042
+ field: this.fieldName,
1043
+ metrics,
1044
+ binarized: false
1094
1045
  };
1095
- this.assertions.push(result);
1096
- recordAssertion(result);
1097
- return this;
1098
- }
1099
- /**
1100
- * Asserts that R-squared (coefficient of determination) is above a threshold.
1101
- * R² measures how well the predictions explain the variance in expected values.
1102
- * R² = 1.0 means perfect prediction, R² = 0 means prediction is no better than mean.
1103
- * Requires numeric values in both actual and expected.
1104
- *
1105
- * @param threshold - Minimum required R² value (0-1)
1106
- * @returns this for method chaining
1107
- *
1108
- * @example
1109
- * expectStats(predictions, groundTruth)
1110
- * .field("score")
1111
- * .toHaveR2Above(0.8)
1112
- */
1113
- toHaveR2Above(threshold) {
1114
- const { actual, expected } = this.validateRegressionInputs();
1115
- const metrics = computeRegressionMetrics(actual, expected);
1116
- const passed = metrics.r2 >= threshold;
1046
+ recordFieldMetrics(fieldResult);
1117
1047
  const result = {
1118
- type: "r2",
1119
- passed,
1120
- message: passed ? `R\xB2 ${metrics.r2.toFixed(4)} is above ${threshold}` : `R\xB2 ${metrics.r2.toFixed(4)} is below threshold ${threshold}`,
1121
- expected: threshold,
1122
- actual: metrics.r2,
1048
+ type: "confusionMatrix",
1049
+ passed: true,
1050
+ message: `Confusion matrix recorded for field "${this.fieldName}"`,
1123
1051
  field: this.fieldName
1124
1052
  };
1125
1053
  this.assertions.push(result);
1126
1054
  recordAssertion(result);
1127
1055
  return this;
1128
1056
  }
1057
+ // ============================================================================
1058
+ // Utility Methods
1059
+ // ============================================================================
1129
1060
  /**
1130
- * Gets the computed metrics for this field
1061
+ * Gets the computed classification metrics for this field
1131
1062
  */
1132
1063
  getMetrics() {
1133
1064
  return computeClassificationMetrics(this.actualValues, this.expectedValues);
@@ -1160,7 +1091,7 @@ function normalizeInput(input) {
1160
1091
  }));
1161
1092
  }
1162
1093
  throw new Error(
1163
- "Invalid input to expectStats(): expected ModelRunResult, Prediction[], or AlignedRecord[]"
1094
+ "Invalid input to expectStats(): expected { aligned: AlignedRecord[] }, Prediction[], or AlignedRecord[]"
1164
1095
  );
1165
1096
  }
1166
1097
  function expectStats(inputOrActual, expected, options) {
@@ -1205,6 +1136,6 @@ var ExpectStats = class {
1205
1136
  }
1206
1137
  };
1207
1138
 
1208
- export { afterAll, afterEach, alignByKey, beforeAll, beforeEach, checkIntegrity, computeAccuracy, computeClassificationMetrics, computeF1, computePrecision, computeRecall, createDataset, describe, evalTest, expectStats, extractFieldValues, filterComplete, it, loadDataset, runModel, runModelParallel, test, validatePredictions };
1139
+ export { afterAll, afterEach, alignByKey, beforeAll, beforeEach, checkIntegrity, computeAccuracy, computeClassificationMetrics, computeF1, computePrecision, computeRecall, describe, evalTest, expectStats, extractFieldValues, filterComplete, it, test, validatePredictions };
1209
1140
  //# sourceMappingURL=index.js.map
1210
1141
  //# sourceMappingURL=index.js.map