evalsense 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +190 -0
- package/README.md +99 -82
- package/dist/{chunk-HDJID3GC.cjs → chunk-DFC6FRTG.cjs} +8 -26
- package/dist/chunk-DFC6FRTG.cjs.map +1 -0
- package/dist/chunk-DGUM43GV.js +10 -0
- package/dist/chunk-DGUM43GV.js.map +1 -0
- package/dist/chunk-JEQ2X3Z6.cjs +12 -0
- package/dist/chunk-JEQ2X3Z6.cjs.map +1 -0
- package/dist/{chunk-5P7LNNO6.js → chunk-JPVZL45G.js} +8 -26
- package/dist/chunk-JPVZL45G.js.map +1 -0
- package/dist/{chunk-Y23VHTD3.cjs → chunk-RZFLCWTW.cjs} +2 -2
- package/dist/chunk-RZFLCWTW.cjs.map +1 -0
- package/dist/{chunk-BRPM6AB6.js → chunk-Z3U6AUWX.js} +2 -2
- package/dist/chunk-Z3U6AUWX.js.map +1 -0
- package/dist/cli.cjs +39 -36
- package/dist/cli.cjs.map +1 -1
- package/dist/cli.js +37 -34
- package/dist/cli.js.map +1 -1
- package/dist/index.cjs +300 -101
- package/dist/index.cjs.map +1 -1
- package/dist/index.d.cts +76 -6
- package/dist/index.d.ts +76 -6
- package/dist/index.js +222 -23
- package/dist/index.js.map +1 -1
- package/dist/metrics/index.cjs +257 -17
- package/dist/metrics/index.cjs.map +1 -1
- package/dist/metrics/index.d.cts +252 -1
- package/dist/metrics/index.d.ts +252 -1
- package/dist/metrics/index.js +240 -2
- package/dist/metrics/index.js.map +1 -1
- package/dist/metrics/opinionated/index.cjs +6 -5
- package/dist/metrics/opinionated/index.js +2 -1
- package/package.json +8 -6
- package/dist/chunk-5P7LNNO6.js.map +0 -1
- package/dist/chunk-BRPM6AB6.js.map +0 -1
- package/dist/chunk-HDJID3GC.cjs.map +0 -1
- package/dist/chunk-Y23VHTD3.cjs.map +0 -1
package/dist/index.d.ts
CHANGED
|
@@ -321,6 +321,52 @@ declare class FieldSelector {
|
|
|
321
321
|
* .toHavePercentageAbove(0.7, 0.8)
|
|
322
322
|
*/
|
|
323
323
|
toHavePercentageAbove(valueThreshold: number, percentageThreshold: number): this;
|
|
324
|
+
/**
|
|
325
|
+
* Validates that ground truth exists and both arrays contain numeric values.
|
|
326
|
+
* Returns the filtered numeric arrays for regression metrics.
|
|
327
|
+
*/
|
|
328
|
+
private validateRegressionInputs;
|
|
329
|
+
/**
|
|
330
|
+
* Asserts that Mean Absolute Error is at or below a threshold (the bound is inclusive).
|
|
331
|
+
* Requires numeric values in both actual and expected.
|
|
332
|
+
*
|
|
333
|
+
* @param threshold - Maximum allowed MAE
|
|
334
|
+
* @returns this for method chaining
|
|
335
|
+
*
|
|
336
|
+
* @example
|
|
337
|
+
* expectStats(predictions, groundTruth)
|
|
338
|
+
* .field("score")
|
|
339
|
+
* .toHaveMAEBelow(0.1)
|
|
340
|
+
*/
|
|
341
|
+
toHaveMAEBelow(threshold: number): this;
|
|
342
|
+
/**
|
|
343
|
+
* Asserts that Root Mean Squared Error is at or below a threshold (the bound is inclusive).
|
|
344
|
+
* Requires numeric values in both actual and expected.
|
|
345
|
+
*
|
|
346
|
+
* @param threshold - Maximum allowed RMSE
|
|
347
|
+
* @returns this for method chaining
|
|
348
|
+
*
|
|
349
|
+
* @example
|
|
350
|
+
* expectStats(predictions, groundTruth)
|
|
351
|
+
* .field("score")
|
|
352
|
+
* .toHaveRMSEBelow(0.15)
|
|
353
|
+
*/
|
|
354
|
+
toHaveRMSEBelow(threshold: number): this;
|
|
355
|
+
/**
|
|
356
|
+
* Asserts that R-squared (coefficient of determination) is at or above a threshold (the bound is inclusive).
|
|
357
|
+
* R² measures how well the predictions explain the variance in expected values.
|
|
358
|
+
* R² = 1.0 means perfect prediction, R² = 0 means prediction is no better than mean.
|
|
359
|
+
* Requires numeric values in both actual and expected.
|
|
360
|
+
*
|
|
361
|
+
* @param threshold - Minimum required R² value (at most 1; R² is negative when predictions fit worse than the mean of the expected values)
|
|
362
|
+
* @returns this for method chaining
|
|
363
|
+
*
|
|
364
|
+
* @example
|
|
365
|
+
* expectStats(predictions, groundTruth)
|
|
366
|
+
* .field("score")
|
|
367
|
+
* .toHaveR2Above(0.8)
|
|
368
|
+
*/
|
|
369
|
+
toHaveR2Above(threshold: number): this;
|
|
324
370
|
/**
|
|
325
371
|
* Gets the computed metrics for this field
|
|
326
372
|
*/
|
|
@@ -339,15 +385,32 @@ declare class FieldSelector {
|
|
|
339
385
|
* Input types that expectStats() accepts
|
|
340
386
|
*/
|
|
341
387
|
type StatsInput = ModelRunResult | Prediction[] | AlignedRecord[];
|
|
388
|
+
/**
|
|
389
|
+
* Options for expectStats when using the three-argument form
|
|
390
|
+
*/
|
|
391
|
+
interface ExpectStatsOptions {
|
|
392
|
+
/**
|
|
393
|
+
* Field to use as ID for alignment (default: "id")
|
|
394
|
+
* Also checks "_id" as fallback for expected records.
|
|
395
|
+
*/
|
|
396
|
+
idField?: string;
|
|
397
|
+
/**
|
|
398
|
+
* Whether to throw on missing IDs (default: false)
|
|
399
|
+
* When true, throws if any prediction has no matching expected record.
|
|
400
|
+
*/
|
|
401
|
+
strict?: boolean;
|
|
402
|
+
}
|
|
342
403
|
/**
|
|
343
404
|
* Entry point for statistical assertions.
|
|
344
405
|
*
|
|
345
|
-
* Supports
|
|
406
|
+
* Supports multiple usage patterns:
|
|
346
407
|
* 1. Single argument: predictions without ground truth (for distribution assertions)
|
|
347
|
-
* 2. Two arguments: predictions with ground truth (for classification metrics)
|
|
408
|
+
* 2. Two arguments: predictions with ground truth (for classification/regression metrics)
|
|
409
|
+
* 3. Three arguments: predictions with ground truth and options (for custom ID field)
|
|
348
410
|
*
|
|
349
|
-
* @param inputOrActual - Either StatsInput (one-arg) or Prediction[] (two-arg)
|
|
350
|
-
* @param expected - Ground truth data (optional, only for two-arg usage)
|
|
411
|
+
* @param inputOrActual - Either StatsInput (one-arg) or Prediction[] (two/three-arg)
|
|
412
|
+
* @param expected - Ground truth data (optional, only for two/three-arg usage)
|
|
413
|
+
* @param options - Alignment options (optional, only for three-arg usage)
|
|
351
414
|
* @returns ExpectStats instance for chaining assertions
|
|
352
415
|
*
|
|
353
416
|
* @example
|
|
@@ -357,14 +420,21 @@ type StatsInput = ModelRunResult | Prediction[] | AlignedRecord[];
|
|
|
357
420
|
* .toHavePercentageBelow(0.5, 0.9);
|
|
358
421
|
*
|
|
359
422
|
* @example
|
|
360
|
-
* // Pattern
|
|
423
|
+
* // Pattern 2: Classification with ground truth
|
|
361
424
|
* expectStats(judgeOutputs, humanLabels)
|
|
362
425
|
* .field("hallucinated")
|
|
363
426
|
* .toHaveRecallAbove(true, 0.85)
|
|
364
427
|
* .toHavePrecisionAbove(true, 0.8);
|
|
428
|
+
*
|
|
429
|
+
* @example
|
|
430
|
+
* // Pattern 3: Custom ID field
|
|
431
|
+
* expectStats(predictions, groundTruth, { idField: 'uuid' })
|
|
432
|
+
* .field("score")
|
|
433
|
+
* .toHaveAccuracyAbove(0.8);
|
|
365
434
|
*/
|
|
366
435
|
declare function expectStats(input: StatsInput): ExpectStats;
|
|
367
436
|
declare function expectStats(actual: Prediction[], expected: Array<Record<string, unknown>>): ExpectStats;
|
|
437
|
+
declare function expectStats(actual: Prediction[], expected: Array<Record<string, unknown>>, options: ExpectStatsOptions): ExpectStats;
|
|
368
438
|
/**
|
|
369
439
|
* Main stats expectation class
|
|
370
440
|
*/
|
|
@@ -601,4 +671,4 @@ declare class TestExecutionError extends EvalSenseError {
|
|
|
601
671
|
constructor(message: string, testName: string, originalError?: Error);
|
|
602
672
|
}
|
|
603
673
|
|
|
604
|
-
export { AlignedRecord, AssertionError, AssertionResult, ClassificationMetrics, ConfigurationError, ConfusionMatrix, ConsoleReporter, Dataset, DatasetError, EvalReport, EvalSenseError, FieldMetricResult, IntegrityError, IntegrityResult, JsonReporter, Prediction, TestExecutionError, TestFn, afterAll, afterEach, alignByKey, beforeAll, beforeEach, buildConfusionMatrix, checkIntegrity, computeAccuracy, computeClassificationMetrics, computeF1, computePrecision, computeRecall, createDataset, describe, discoverEvalFiles, evalTest, executeEvalFiles, expectStats, extractFieldValues, filterComplete, formatConfusionMatrix, getExitCode, it, loadDataset, parseReport, runModel, runModelParallel, test, validatePredictions };
|
|
674
|
+
export { AlignedRecord, AssertionError, AssertionResult, ClassificationMetrics, ConfigurationError, ConfusionMatrix, ConsoleReporter, Dataset, DatasetError, EvalReport, EvalSenseError, type ExpectStatsOptions, FieldMetricResult, IntegrityError, IntegrityResult, JsonReporter, Prediction, TestExecutionError, TestFn, afterAll, afterEach, alignByKey, beforeAll, beforeEach, buildConfusionMatrix, checkIntegrity, computeAccuracy, computeClassificationMetrics, computeF1, computePrecision, computeRecall, createDataset, describe, discoverEvalFiles, evalTest, executeEvalFiles, expectStats, extractFieldValues, filterComplete, formatConfusionMatrix, getExitCode, it, loadDataset, parseReport, runModel, runModelParallel, test, validatePredictions };
|
package/dist/index.js
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
|
-
import { getCurrentSuite, setCurrentSuite, addSuite, addTestToCurrentSuite, DatasetError, IntegrityError, buildConfusionMatrix, getTruePositives, getFalsePositives, getFalseNegatives, getSupport, AssertionError, recordAssertion, recordFieldMetrics } from './chunk-
|
|
2
|
-
export { AssertionError, ConfigurationError, ConsoleReporter, DatasetError, EvalSenseError, ExitCodes, IntegrityError, JsonReporter, TestExecutionError, buildConfusionMatrix, discoverEvalFiles, executeEvalFiles, formatConfusionMatrix, getExitCode, parseReport } from './chunk-
|
|
1
|
+
import { getCurrentSuite, setCurrentSuite, addSuite, addTestToCurrentSuite, DatasetError, IntegrityError, buildConfusionMatrix, getTruePositives, getFalsePositives, getFalseNegatives, getSupport, AssertionError, recordAssertion, recordFieldMetrics } from './chunk-JPVZL45G.js';
|
|
2
|
+
export { AssertionError, ConfigurationError, ConsoleReporter, DatasetError, EvalSenseError, ExitCodes, IntegrityError, JsonReporter, TestExecutionError, buildConfusionMatrix, discoverEvalFiles, executeEvalFiles, formatConfusionMatrix, getExitCode, parseReport } from './chunk-JPVZL45G.js';
|
|
3
|
+
import './chunk-DGUM43GV.js';
|
|
3
4
|
import { readFileSync } from 'fs';
|
|
4
5
|
import { resolve, extname } from 'path';
|
|
5
6
|
|
|
@@ -183,9 +184,7 @@ async function runModel(dataset, modelFn) {
|
|
|
183
184
|
/**
 * Resolves the alignment key of a dataset record.
 *
 * Reads `record.id`, falling back to `record._id` when `id` is null or
 * undefined, and returns the value converted with `String()`.
 *
 * @param record - Dataset record to read the identifier from
 * @returns The identifier as a string
 * @throws DatasetError when neither `id` nor `_id` holds a non-nullish value
 */
function getRecordId(record) {
  const candidate = record.id ?? record._id;
  if (candidate != null) {
    return String(candidate);
  }
  throw new DatasetError('Dataset records must have an "id" or "_id" field for alignment');
}
|
|
@@ -207,9 +206,7 @@ async function runModelParallel(dataset, modelFn, concurrency = 10) {
|
|
|
207
206
|
for (const { prediction, record } of results) {
|
|
208
207
|
const id = getRecordId(record);
|
|
209
208
|
if (prediction.id !== id) {
|
|
210
|
-
throw new DatasetError(
|
|
211
|
-
`Prediction ID mismatch: expected "${id}", got "${prediction.id}".`
|
|
212
|
-
);
|
|
209
|
+
throw new DatasetError(`Prediction ID mismatch: expected "${id}", got "${prediction.id}".`);
|
|
213
210
|
}
|
|
214
211
|
predictions.push(prediction);
|
|
215
212
|
aligned.push({
|
|
@@ -306,9 +303,7 @@ function checkIntegrity(dataset, options = {}) {
|
|
|
306
303
|
}
|
|
307
304
|
}
|
|
308
305
|
if (requiredFields.length > 0) {
|
|
309
|
-
const missing = requiredFields.filter(
|
|
310
|
-
(field) => record[field] === void 0
|
|
311
|
-
);
|
|
306
|
+
const missing = requiredFields.filter((field) => record[field] === void 0);
|
|
312
307
|
if (missing.length > 0) {
|
|
313
308
|
missingFields.push({
|
|
314
309
|
id: String(id ?? `record[${i}]`),
|
|
@@ -331,7 +326,9 @@ function checkIntegrity(dataset, options = {}) {
|
|
|
331
326
|
issues.push(`${missingIds.length} record(s) missing ID`);
|
|
332
327
|
}
|
|
333
328
|
if (duplicateIds.length > 0) {
|
|
334
|
-
issues.push(
|
|
329
|
+
issues.push(
|
|
330
|
+
`${duplicateIds.length} duplicate ID(s): ${duplicateIds.slice(0, 3).join(", ")}${duplicateIds.length > 3 ? "..." : ""}`
|
|
331
|
+
);
|
|
335
332
|
}
|
|
336
333
|
if (missingFields.length > 0) {
|
|
337
334
|
issues.push(`${missingFields.length} record(s) missing required fields`);
|
|
@@ -429,6 +426,67 @@ function computeAccuracy(actual, expected) {
|
|
|
429
426
|
return total > 0 ? correct / total : 0;
|
|
430
427
|
}
|
|
431
428
|
|
|
429
|
+
// src/statistics/regression.ts
|
|
430
|
+
/**
 * Computes the full set of regression metrics for two parallel arrays.
 *
 * @param actual - Predicted values
 * @param expected - Ground-truth values, index-aligned with `actual`
 * @returns An object with `mae`, `mse`, `rmse` (square root of `mse`) and `r2`;
 *   every metric is 0 when the arrays are empty
 * @throws Error when the two arrays differ in length
 */
function computeRegressionMetrics(actual, expected) {
  const sampleCount = actual.length;
  if (sampleCount !== expected.length) {
    throw new Error(
      `Array length mismatch: actual has ${actual.length} elements, expected has ${expected.length}`
    );
  }
  if (sampleCount === 0) {
    return { mae: 0, mse: 0, rmse: 0, r2: 0 };
  }
  const mae = computeMAE(actual, expected);
  const mse = computeMSE(actual, expected);
  return {
    mae,
    mse,
    rmse: Math.sqrt(mse),
    r2: computeR2(actual, expected)
  };
}
|
|
446
|
+
/**
 * Mean Absolute Error: the average of |actual[i] - expected[i]|.
 *
 * Nullish entries on either side are treated as 0.
 *
 * @param actual - Predicted values
 * @param expected - Ground-truth values, index-aligned with `actual`
 * @returns The MAE, or 0 when the arrays are empty or differ in length
 */
function computeMAE(actual, expected) {
  const count = actual.length;
  const comparable = count === expected.length && count > 0;
  if (!comparable) {
    return 0;
  }
  let absoluteErrorTotal = 0;
  // entries() also yields holes (as undefined), matching plain index access.
  for (const [index, predicted] of actual.entries()) {
    absoluteErrorTotal += Math.abs((predicted ?? 0) - (expected[index] ?? 0));
  }
  return absoluteErrorTotal / count;
}
|
|
456
|
+
/**
 * Mean Squared Error: the average of (actual[i] - expected[i])².
 *
 * Nullish entries on either side are treated as 0.
 *
 * @param actual - Predicted values
 * @param expected - Ground-truth values, index-aligned with `actual`
 * @returns The MSE, or 0 when the arrays are empty or differ in length
 */
function computeMSE(actual, expected) {
  const count = actual.length;
  const comparable = count === expected.length && count > 0;
  if (!comparable) {
    return 0;
  }
  let squaredErrorTotal = 0;
  let index = 0;
  while (index < count) {
    const residual = (actual[index] ?? 0) - (expected[index] ?? 0);
    squaredErrorTotal += residual * residual;
    index += 1;
  }
  return squaredErrorTotal / count;
}
|
|
467
|
+
/**
 * Coefficient of determination: 1 - SS_residual / SS_total, where SS_total is
 * measured around the mean of the expected values.
 *
 * Nullish entries on either side are treated as 0. When the expected values
 * have zero variance the result is 1 for an exact match and 0 otherwise.
 *
 * @param actual - Predicted values
 * @param expected - Ground-truth values, index-aligned with `actual`
 * @returns R² (can be negative), or 0 when the arrays are empty or differ in length
 */
function computeR2(actual, expected) {
  const count = actual.length;
  const comparable = count === expected.length && count > 0;
  if (!comparable) {
    return 0;
  }

  let expectedTotal = 0;
  for (let k = 0; k < expected.length; k++) {
    expectedTotal += expected[k] ?? 0;
  }
  const expectedMean = expectedTotal / expected.length;

  let totalSumOfSquares = 0;
  let residualSumOfSquares = 0;
  for (const [index, rawPredicted] of actual.entries()) {
    const truth = expected[index] ?? 0;
    const predicted = rawPredicted ?? 0;
    totalSumOfSquares += (truth - expectedMean) ** 2;
    residualSumOfSquares += (truth - predicted) ** 2;
  }

  if (totalSumOfSquares !== 0) {
    return 1 - residualSumOfSquares / totalSumOfSquares;
  }
  // Constant ground truth: the ratio is undefined, so report exact match or not.
  return residualSumOfSquares === 0 ? 1 : 0;
}
|
|
489
|
+
|
|
432
490
|
// src/statistics/distribution.ts
|
|
433
491
|
function filterNumericValues(values) {
|
|
434
492
|
return values.filter(
|
|
@@ -693,9 +751,7 @@ var FieldSelector = class {
|
|
|
693
751
|
* Throws a clear error if expected values are missing.
|
|
694
752
|
*/
|
|
695
753
|
validateGroundTruth() {
|
|
696
|
-
const hasExpected = this.expectedValues.some(
|
|
697
|
-
(v) => v !== void 0 && v !== null
|
|
698
|
-
);
|
|
754
|
+
const hasExpected = this.expectedValues.some((v) => v !== void 0 && v !== null);
|
|
699
755
|
if (!hasExpected) {
|
|
700
756
|
throw new AssertionError(
|
|
701
757
|
`Classification metric requires ground truth, but field "${this.fieldName}" has no expected values. Use expectStats(predictions, groundTruth) to provide expected values.`,
|
|
@@ -920,7 +976,12 @@ var FieldSelector = class {
|
|
|
920
976
|
this.assertions.push(result);
|
|
921
977
|
recordAssertion(result);
|
|
922
978
|
if (!passed) {
|
|
923
|
-
throw new AssertionError(
|
|
979
|
+
throw new AssertionError(
|
|
980
|
+
result.message,
|
|
981
|
+
percentageThreshold,
|
|
982
|
+
actualPercentage,
|
|
983
|
+
this.fieldName
|
|
984
|
+
);
|
|
924
985
|
}
|
|
925
986
|
return this;
|
|
926
987
|
}
|
|
@@ -961,7 +1022,144 @@ var FieldSelector = class {
|
|
|
961
1022
|
this.assertions.push(result);
|
|
962
1023
|
recordAssertion(result);
|
|
963
1024
|
if (!passed) {
|
|
964
|
-
throw new AssertionError(
|
|
1025
|
+
throw new AssertionError(
|
|
1026
|
+
result.message,
|
|
1027
|
+
percentageThreshold,
|
|
1028
|
+
actualPercentage,
|
|
1029
|
+
this.fieldName
|
|
1030
|
+
);
|
|
1031
|
+
}
|
|
1032
|
+
return this;
|
|
1033
|
+
}
|
|
1034
|
+
// ============================================================================
|
|
1035
|
+
// Regression Assertions
|
|
1036
|
+
// ============================================================================
|
|
1037
|
+
/**
|
|
1038
|
+
* Validates that ground truth exists and both arrays contain numeric values.
|
|
1039
|
+
* Returns the filtered numeric arrays for regression metrics.
|
|
1040
|
+
*/
|
|
1041
|
+
validateRegressionInputs() {
|
|
1042
|
+
this.validateGroundTruth();
|
|
1043
|
+
const numericActual = filterNumericValues(this.actualValues);
|
|
1044
|
+
const numericExpected = filterNumericValues(this.expectedValues);
|
|
1045
|
+
if (numericActual.length === 0) {
|
|
1046
|
+
throw new AssertionError(
|
|
1047
|
+
`Regression metric requires numeric values, but field "${this.fieldName}" has no numeric actual values.`,
|
|
1048
|
+
void 0,
|
|
1049
|
+
void 0,
|
|
1050
|
+
this.fieldName
|
|
1051
|
+
);
|
|
1052
|
+
}
|
|
1053
|
+
if (numericExpected.length === 0) {
|
|
1054
|
+
throw new AssertionError(
|
|
1055
|
+
`Regression metric requires numeric values, but field "${this.fieldName}" has no numeric expected values.`,
|
|
1056
|
+
void 0,
|
|
1057
|
+
void 0,
|
|
1058
|
+
this.fieldName
|
|
1059
|
+
);
|
|
1060
|
+
}
|
|
1061
|
+
if (numericActual.length !== numericExpected.length) {
|
|
1062
|
+
throw new AssertionError(
|
|
1063
|
+
`Regression metric requires equal-length arrays, but got ${numericActual.length} actual and ${numericExpected.length} expected values.`,
|
|
1064
|
+
numericExpected.length,
|
|
1065
|
+
numericActual.length,
|
|
1066
|
+
this.fieldName
|
|
1067
|
+
);
|
|
1068
|
+
}
|
|
1069
|
+
return { actual: numericActual, expected: numericExpected };
|
|
1070
|
+
}
|
|
1071
|
+
/**
|
|
1072
|
+
* Asserts that Mean Absolute Error is below a threshold.
|
|
1073
|
+
* Requires numeric values in both actual and expected.
|
|
1074
|
+
*
|
|
1075
|
+
* @param threshold - Maximum allowed MAE
|
|
1076
|
+
* @returns this for method chaining
|
|
1077
|
+
*
|
|
1078
|
+
* @example
|
|
1079
|
+
* expectStats(predictions, groundTruth)
|
|
1080
|
+
* .field("score")
|
|
1081
|
+
* .toHaveMAEBelow(0.1)
|
|
1082
|
+
*/
|
|
1083
|
+
toHaveMAEBelow(threshold) {
|
|
1084
|
+
const { actual, expected } = this.validateRegressionInputs();
|
|
1085
|
+
const metrics = computeRegressionMetrics(actual, expected);
|
|
1086
|
+
const passed = metrics.mae <= threshold;
|
|
1087
|
+
const result = {
|
|
1088
|
+
type: "mae",
|
|
1089
|
+
passed,
|
|
1090
|
+
message: passed ? `MAE ${metrics.mae.toFixed(4)} is below ${threshold}` : `MAE ${metrics.mae.toFixed(4)} exceeds threshold ${threshold}`,
|
|
1091
|
+
expected: threshold,
|
|
1092
|
+
actual: metrics.mae,
|
|
1093
|
+
field: this.fieldName
|
|
1094
|
+
};
|
|
1095
|
+
this.assertions.push(result);
|
|
1096
|
+
recordAssertion(result);
|
|
1097
|
+
if (!passed) {
|
|
1098
|
+
throw new AssertionError(result.message, threshold, metrics.mae, this.fieldName);
|
|
1099
|
+
}
|
|
1100
|
+
return this;
|
|
1101
|
+
}
|
|
1102
|
+
/**
|
|
1103
|
+
* Asserts that Root Mean Squared Error is below a threshold.
|
|
1104
|
+
* Requires numeric values in both actual and expected.
|
|
1105
|
+
*
|
|
1106
|
+
* @param threshold - Maximum allowed RMSE
|
|
1107
|
+
* @returns this for method chaining
|
|
1108
|
+
*
|
|
1109
|
+
* @example
|
|
1110
|
+
* expectStats(predictions, groundTruth)
|
|
1111
|
+
* .field("score")
|
|
1112
|
+
* .toHaveRMSEBelow(0.15)
|
|
1113
|
+
*/
|
|
1114
|
+
toHaveRMSEBelow(threshold) {
|
|
1115
|
+
const { actual, expected } = this.validateRegressionInputs();
|
|
1116
|
+
const metrics = computeRegressionMetrics(actual, expected);
|
|
1117
|
+
const passed = metrics.rmse <= threshold;
|
|
1118
|
+
const result = {
|
|
1119
|
+
type: "rmse",
|
|
1120
|
+
passed,
|
|
1121
|
+
message: passed ? `RMSE ${metrics.rmse.toFixed(4)} is below ${threshold}` : `RMSE ${metrics.rmse.toFixed(4)} exceeds threshold ${threshold}`,
|
|
1122
|
+
expected: threshold,
|
|
1123
|
+
actual: metrics.rmse,
|
|
1124
|
+
field: this.fieldName
|
|
1125
|
+
};
|
|
1126
|
+
this.assertions.push(result);
|
|
1127
|
+
recordAssertion(result);
|
|
1128
|
+
if (!passed) {
|
|
1129
|
+
throw new AssertionError(result.message, threshold, metrics.rmse, this.fieldName);
|
|
1130
|
+
}
|
|
1131
|
+
return this;
|
|
1132
|
+
}
|
|
1133
|
+
/**
|
|
1134
|
+
* Asserts that R-squared (coefficient of determination) is above a threshold.
|
|
1135
|
+
* R² measures how well the predictions explain the variance in expected values.
|
|
1136
|
+
* R² = 1.0 means perfect prediction, R² = 0 means prediction is no better than mean.
|
|
1137
|
+
* Requires numeric values in both actual and expected.
|
|
1138
|
+
*
|
|
1139
|
+
* @param threshold - Minimum required R² value (0-1)
|
|
1140
|
+
* @returns this for method chaining
|
|
1141
|
+
*
|
|
1142
|
+
* @example
|
|
1143
|
+
* expectStats(predictions, groundTruth)
|
|
1144
|
+
* .field("score")
|
|
1145
|
+
* .toHaveR2Above(0.8)
|
|
1146
|
+
*/
|
|
1147
|
+
toHaveR2Above(threshold) {
|
|
1148
|
+
const { actual, expected } = this.validateRegressionInputs();
|
|
1149
|
+
const metrics = computeRegressionMetrics(actual, expected);
|
|
1150
|
+
const passed = metrics.r2 >= threshold;
|
|
1151
|
+
const result = {
|
|
1152
|
+
type: "r2",
|
|
1153
|
+
passed,
|
|
1154
|
+
message: passed ? `R\xB2 ${metrics.r2.toFixed(4)} is above ${threshold}` : `R\xB2 ${metrics.r2.toFixed(4)} is below threshold ${threshold}`,
|
|
1155
|
+
expected: threshold,
|
|
1156
|
+
actual: metrics.r2,
|
|
1157
|
+
field: this.fieldName
|
|
1158
|
+
};
|
|
1159
|
+
this.assertions.push(result);
|
|
1160
|
+
recordAssertion(result);
|
|
1161
|
+
if (!passed) {
|
|
1162
|
+
throw new AssertionError(result.message, threshold, metrics.r2, this.fieldName);
|
|
965
1163
|
}
|
|
966
1164
|
return this;
|
|
967
1165
|
}
|
|
@@ -998,16 +1196,17 @@ function normalizeInput(input) {
|
|
|
998
1196
|
expected: {}
|
|
999
1197
|
}));
|
|
1000
1198
|
}
|
|
1001
|
-
throw new Error(
|
|
1199
|
+
throw new Error(
|
|
1200
|
+
"Invalid input to expectStats(): expected ModelRunResult, Prediction[], or AlignedRecord[]"
|
|
1201
|
+
);
|
|
1002
1202
|
}
|
|
1003
|
-
function expectStats(inputOrActual, expected) {
|
|
1203
|
+
function expectStats(inputOrActual, expected, options) {
|
|
1004
1204
|
if (expected !== void 0) {
|
|
1005
1205
|
if (!Array.isArray(inputOrActual)) {
|
|
1006
|
-
throw new Error(
|
|
1007
|
-
"When using two-argument expectStats(), first argument must be Prediction[]"
|
|
1008
|
-
);
|
|
1206
|
+
throw new Error("When using two-argument expectStats(), first argument must be Prediction[]");
|
|
1009
1207
|
}
|
|
1010
|
-
const
|
|
1208
|
+
const alignOptions = options ? { idField: options.idField, strict: options.strict } : void 0;
|
|
1209
|
+
const aligned2 = alignByKey(inputOrActual, expected, alignOptions);
|
|
1011
1210
|
return new ExpectStats(aligned2);
|
|
1012
1211
|
}
|
|
1013
1212
|
const aligned = normalizeInput(inputOrActual);
|