evalsense 0.2.1 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +99 -82
- package/dist/{chunk-HDJID3GC.cjs → chunk-BE7CB3AM.cjs} +39 -28
- package/dist/chunk-BE7CB3AM.cjs.map +1 -0
- package/dist/chunk-DGUM43GV.js +10 -0
- package/dist/chunk-DGUM43GV.js.map +1 -0
- package/dist/chunk-JEQ2X3Z6.cjs +12 -0
- package/dist/chunk-JEQ2X3Z6.cjs.map +1 -0
- package/dist/{chunk-5P7LNNO6.js → chunk-K6QPJ2NO.js} +39 -28
- package/dist/chunk-K6QPJ2NO.js.map +1 -0
- package/dist/{chunk-Y23VHTD3.cjs → chunk-RZFLCWTW.cjs} +2 -2
- package/dist/chunk-RZFLCWTW.cjs.map +1 -0
- package/dist/{chunk-BRPM6AB6.js → chunk-Z3U6AUWX.js} +2 -2
- package/dist/chunk-Z3U6AUWX.js.map +1 -0
- package/dist/cli.cjs +39 -36
- package/dist/cli.cjs.map +1 -1
- package/dist/cli.js +37 -34
- package/dist/cli.js.map +1 -1
- package/dist/index.cjs +320 -104
- package/dist/index.cjs.map +1 -1
- package/dist/index.d.cts +93 -7
- package/dist/index.d.ts +93 -7
- package/dist/index.js +242 -26
- package/dist/index.js.map +1 -1
- package/dist/metrics/index.cjs +257 -17
- package/dist/metrics/index.cjs.map +1 -1
- package/dist/metrics/index.d.cts +252 -1
- package/dist/metrics/index.d.ts +252 -1
- package/dist/metrics/index.js +240 -2
- package/dist/metrics/index.js.map +1 -1
- package/dist/metrics/opinionated/index.cjs +6 -5
- package/dist/metrics/opinionated/index.js +2 -1
- package/package.json +4 -3
- package/dist/chunk-5P7LNNO6.js.map +0 -1
- package/dist/chunk-BRPM6AB6.js.map +0 -1
- package/dist/chunk-HDJID3GC.cjs.map +0 -1
- package/dist/chunk-Y23VHTD3.cjs.map +0 -1
package/dist/index.d.ts
CHANGED
|
@@ -143,8 +143,12 @@ declare function runModelParallel<T extends Record<string, unknown>>(dataset: Da
|
|
|
143
143
|
interface AlignOptions {
|
|
144
144
|
/** Whether to throw on missing IDs (default: false) */
|
|
145
145
|
strict?: boolean;
|
|
146
|
-
/** Field to use as ID (default: "id") */
|
|
146
|
+
/** Field to use as ID in both arrays (default: "id") - legacy option */
|
|
147
147
|
idField?: string;
|
|
148
|
+
/** Field to use as ID in predictions array (default: "id") */
|
|
149
|
+
predictionIdField?: string;
|
|
150
|
+
/** Field to use as ID in expected/ground truth array (default: "id") */
|
|
151
|
+
expectedIdField?: string;
|
|
148
152
|
}
|
|
149
153
|
/**
|
|
150
154
|
* Aligns predictions with expected values by ID
|
|
@@ -321,6 +325,52 @@ declare class FieldSelector {
|
|
|
321
325
|
* .toHavePercentageAbove(0.7, 0.8)
|
|
322
326
|
*/
|
|
323
327
|
toHavePercentageAbove(valueThreshold: number, percentageThreshold: number): this;
|
|
328
|
+
/**
|
|
329
|
+
* Validates that ground truth exists and both arrays contain numeric values.
|
|
330
|
+
* Returns the filtered numeric arrays for regression metrics.
|
|
331
|
+
*/
|
|
332
|
+
private validateRegressionInputs;
|
|
333
|
+
/**
|
|
334
|
+
* Asserts that Mean Absolute Error is below a threshold.
|
|
335
|
+
* Requires numeric values in both actual and expected.
|
|
336
|
+
*
|
|
337
|
+
* @param threshold - Maximum allowed MAE
|
|
338
|
+
* @returns this for method chaining
|
|
339
|
+
*
|
|
340
|
+
* @example
|
|
341
|
+
* expectStats(predictions, groundTruth)
|
|
342
|
+
* .field("score")
|
|
343
|
+
* .toHaveMAEBelow(0.1)
|
|
344
|
+
*/
|
|
345
|
+
toHaveMAEBelow(threshold: number): this;
|
|
346
|
+
/**
|
|
347
|
+
* Asserts that Root Mean Squared Error is below a threshold.
|
|
348
|
+
* Requires numeric values in both actual and expected.
|
|
349
|
+
*
|
|
350
|
+
* @param threshold - Maximum allowed RMSE
|
|
351
|
+
* @returns this for method chaining
|
|
352
|
+
*
|
|
353
|
+
* @example
|
|
354
|
+
* expectStats(predictions, groundTruth)
|
|
355
|
+
* .field("score")
|
|
356
|
+
* .toHaveRMSEBelow(0.15)
|
|
357
|
+
*/
|
|
358
|
+
toHaveRMSEBelow(threshold: number): this;
|
|
359
|
+
/**
|
|
360
|
+
* Asserts that R-squared (coefficient of determination) is above a threshold.
|
|
361
|
+
* R² measures how well the predictions explain the variance in expected values.
|
|
362
|
+
* R² = 1.0 means perfect prediction, R² = 0 means prediction is no better than mean.
|
|
363
|
+
* Requires numeric values in both actual and expected.
|
|
364
|
+
*
|
|
365
|
+
* @param threshold - Minimum required R² value (0-1)
|
|
366
|
+
* @returns this for method chaining
|
|
367
|
+
*
|
|
368
|
+
* @example
|
|
369
|
+
* expectStats(predictions, groundTruth)
|
|
370
|
+
* .field("score")
|
|
371
|
+
* .toHaveR2Above(0.8)
|
|
372
|
+
*/
|
|
373
|
+
toHaveR2Above(threshold: number): this;
|
|
324
374
|
/**
|
|
325
375
|
* Gets the computed metrics for this field
|
|
326
376
|
*/
|
|
@@ -339,15 +389,40 @@ declare class FieldSelector {
|
|
|
339
389
|
* Input types that expectStats() accepts
|
|
340
390
|
*/
|
|
341
391
|
type StatsInput = ModelRunResult | Prediction[] | AlignedRecord[];
|
|
392
|
+
/**
|
|
393
|
+
* Options for expectStats when using two-argument form
|
|
394
|
+
*/
|
|
395
|
+
interface ExpectStatsOptions {
|
|
396
|
+
/**
|
|
397
|
+
* Field to use as ID in both arrays (default: "id") - legacy option
|
|
398
|
+
* Also checks "_id" as fallback for expected records.
|
|
399
|
+
*/
|
|
400
|
+
idField?: string;
|
|
401
|
+
/**
|
|
402
|
+
* Field to use as ID in predictions array (default: "id")
|
|
403
|
+
*/
|
|
404
|
+
predictionIdField?: string;
|
|
405
|
+
/**
|
|
406
|
+
* Field to use as ID in expected/ground truth array (default: "id")
|
|
407
|
+
*/
|
|
408
|
+
expectedIdField?: string;
|
|
409
|
+
/**
|
|
410
|
+
* Whether to throw on missing IDs (default: false)
|
|
411
|
+
* When true, throws if any prediction has no matching expected record.
|
|
412
|
+
*/
|
|
413
|
+
strict?: boolean;
|
|
414
|
+
}
|
|
342
415
|
/**
|
|
343
416
|
* Entry point for statistical assertions.
|
|
344
417
|
*
|
|
345
|
-
* Supports
|
|
418
|
+
* Supports multiple usage patterns:
|
|
346
419
|
* 1. Single argument: predictions without ground truth (for distribution assertions)
|
|
347
|
-
* 2. Two arguments: predictions with ground truth (for classification metrics)
|
|
420
|
+
* 2. Two arguments: predictions with ground truth (for classification/regression metrics)
|
|
421
|
+
* 3. Three arguments: predictions with ground truth and options (for custom ID field)
|
|
348
422
|
*
|
|
349
|
-
* @param inputOrActual - Either StatsInput (one-arg) or Prediction[] (two-arg)
|
|
350
|
-
* @param expected - Ground truth data (optional, only for two-arg usage)
|
|
423
|
+
* @param inputOrActual - Either StatsInput (one-arg) or Prediction[] (two/three-arg)
|
|
424
|
+
* @param expected - Ground truth data (optional, only for two/three-arg usage)
|
|
425
|
+
* @param options - Alignment options (optional, only for three-arg usage)
|
|
351
426
|
* @returns ExpectStats instance for chaining assertions
|
|
352
427
|
*
|
|
353
428
|
* @example
|
|
@@ -357,14 +432,21 @@ type StatsInput = ModelRunResult | Prediction[] | AlignedRecord[];
|
|
|
357
432
|
* .toHavePercentageBelow(0.5, 0.9);
|
|
358
433
|
*
|
|
359
434
|
* @example
|
|
360
|
-
* // Pattern
|
|
435
|
+
* // Pattern 2: Classification with ground truth
|
|
361
436
|
* expectStats(judgeOutputs, humanLabels)
|
|
362
437
|
* .field("hallucinated")
|
|
363
438
|
* .toHaveRecallAbove(true, 0.85)
|
|
364
439
|
* .toHavePrecisionAbove(true, 0.8);
|
|
440
|
+
*
|
|
441
|
+
* @example
|
|
442
|
+
* // Pattern 3: Custom ID field
|
|
443
|
+
* expectStats(predictions, groundTruth, { idField: 'uuid' })
|
|
444
|
+
* .field("score")
|
|
445
|
+
* .toHaveAccuracyAbove(0.8);
|
|
365
446
|
*/
|
|
366
447
|
declare function expectStats(input: StatsInput): ExpectStats;
|
|
367
448
|
declare function expectStats(actual: Prediction[], expected: Array<Record<string, unknown>>): ExpectStats;
|
|
449
|
+
declare function expectStats(actual: Prediction[], expected: Array<Record<string, unknown>>, options: ExpectStatsOptions): ExpectStats;
|
|
368
450
|
/**
|
|
369
451
|
* Main stats expectation class
|
|
370
452
|
*/
|
|
@@ -520,6 +602,10 @@ declare class ConsoleReporter {
|
|
|
520
602
|
* Applies color if enabled
|
|
521
603
|
*/
|
|
522
604
|
private color;
|
|
605
|
+
/**
|
|
606
|
+
* Formats a value for display
|
|
607
|
+
*/
|
|
608
|
+
private formatValue;
|
|
523
609
|
/**
|
|
524
610
|
* Logs a line
|
|
525
611
|
*/
|
|
@@ -601,4 +687,4 @@ declare class TestExecutionError extends EvalSenseError {
|
|
|
601
687
|
constructor(message: string, testName: string, originalError?: Error);
|
|
602
688
|
}
|
|
603
689
|
|
|
604
|
-
export { AlignedRecord, AssertionError, AssertionResult, ClassificationMetrics, ConfigurationError, ConfusionMatrix, ConsoleReporter, Dataset, DatasetError, EvalReport, EvalSenseError, FieldMetricResult, IntegrityError, IntegrityResult, JsonReporter, Prediction, TestExecutionError, TestFn, afterAll, afterEach, alignByKey, beforeAll, beforeEach, buildConfusionMatrix, checkIntegrity, computeAccuracy, computeClassificationMetrics, computeF1, computePrecision, computeRecall, createDataset, describe, discoverEvalFiles, evalTest, executeEvalFiles, expectStats, extractFieldValues, filterComplete, formatConfusionMatrix, getExitCode, it, loadDataset, parseReport, runModel, runModelParallel, test, validatePredictions };
|
|
690
|
+
export { AlignedRecord, AssertionError, AssertionResult, ClassificationMetrics, ConfigurationError, ConfusionMatrix, ConsoleReporter, Dataset, DatasetError, EvalReport, EvalSenseError, type ExpectStatsOptions, FieldMetricResult, IntegrityError, IntegrityResult, JsonReporter, Prediction, TestExecutionError, TestFn, afterAll, afterEach, alignByKey, beforeAll, beforeEach, buildConfusionMatrix, checkIntegrity, computeAccuracy, computeClassificationMetrics, computeF1, computePrecision, computeRecall, createDataset, describe, discoverEvalFiles, evalTest, executeEvalFiles, expectStats, extractFieldValues, filterComplete, formatConfusionMatrix, getExitCode, it, loadDataset, parseReport, runModel, runModelParallel, test, validatePredictions };
|
package/dist/index.js
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
|
-
import { getCurrentSuite, setCurrentSuite, addSuite, addTestToCurrentSuite, DatasetError, IntegrityError, buildConfusionMatrix, getTruePositives, getFalsePositives, getFalseNegatives, getSupport, AssertionError, recordAssertion, recordFieldMetrics } from './chunk-
|
|
2
|
-
export { AssertionError, ConfigurationError, ConsoleReporter, DatasetError, EvalSenseError, ExitCodes, IntegrityError, JsonReporter, TestExecutionError, buildConfusionMatrix, discoverEvalFiles, executeEvalFiles, formatConfusionMatrix, getExitCode, parseReport } from './chunk-
|
|
1
|
+
import { getCurrentSuite, setCurrentSuite, addSuite, addTestToCurrentSuite, DatasetError, IntegrityError, buildConfusionMatrix, getTruePositives, getFalsePositives, getFalseNegatives, getSupport, AssertionError, recordAssertion, recordFieldMetrics } from './chunk-K6QPJ2NO.js';
|
|
2
|
+
export { AssertionError, ConfigurationError, ConsoleReporter, DatasetError, EvalSenseError, ExitCodes, IntegrityError, JsonReporter, TestExecutionError, buildConfusionMatrix, discoverEvalFiles, executeEvalFiles, formatConfusionMatrix, getExitCode, parseReport } from './chunk-K6QPJ2NO.js';
|
|
3
|
+
import './chunk-DGUM43GV.js';
|
|
3
4
|
import { readFileSync } from 'fs';
|
|
4
5
|
import { resolve, extname } from 'path';
|
|
5
6
|
|
|
@@ -183,9 +184,7 @@ async function runModel(dataset, modelFn) {
|
|
|
183
184
|
function getRecordId(record) {
|
|
184
185
|
const id = record.id ?? record._id;
|
|
185
186
|
if (id === void 0 || id === null) {
|
|
186
|
-
throw new DatasetError(
|
|
187
|
-
'Dataset records must have an "id" or "_id" field for alignment'
|
|
188
|
-
);
|
|
187
|
+
throw new DatasetError('Dataset records must have an "id" or "_id" field for alignment');
|
|
189
188
|
}
|
|
190
189
|
return String(id);
|
|
191
190
|
}
|
|
@@ -207,9 +206,7 @@ async function runModelParallel(dataset, modelFn, concurrency = 10) {
|
|
|
207
206
|
for (const { prediction, record } of results) {
|
|
208
207
|
const id = getRecordId(record);
|
|
209
208
|
if (prediction.id !== id) {
|
|
210
|
-
throw new DatasetError(
|
|
211
|
-
`Prediction ID mismatch: expected "${id}", got "${prediction.id}".`
|
|
212
|
-
);
|
|
209
|
+
throw new DatasetError(`Prediction ID mismatch: expected "${id}", got "${prediction.id}".`);
|
|
213
210
|
}
|
|
214
211
|
predictions.push(prediction);
|
|
215
212
|
aligned.push({
|
|
@@ -227,16 +224,28 @@ async function runModelParallel(dataset, modelFn, concurrency = 10) {
|
|
|
227
224
|
|
|
228
225
|
// src/dataset/alignment.ts
|
|
229
226
|
function alignByKey(predictions, expected, options = {}) {
|
|
230
|
-
const { strict = false, idField
|
|
227
|
+
const { strict = false, idField, predictionIdField, expectedIdField } = options;
|
|
228
|
+
const predIdField = predictionIdField ?? idField ?? "id";
|
|
229
|
+
const expIdField = expectedIdField ?? idField ?? "id";
|
|
231
230
|
const expectedMap = /* @__PURE__ */ new Map();
|
|
232
231
|
for (const record of expected) {
|
|
233
|
-
const id = String(record[
|
|
232
|
+
const id = String(record[expIdField] ?? record._id);
|
|
233
|
+
if (!id || id === "undefined") {
|
|
234
|
+
throw new IntegrityError(
|
|
235
|
+
`Expected record missing ${expIdField} field: ${JSON.stringify(record)}`
|
|
236
|
+
);
|
|
237
|
+
}
|
|
234
238
|
expectedMap.set(id, record);
|
|
235
239
|
}
|
|
236
240
|
const aligned = [];
|
|
237
241
|
const missingIds = [];
|
|
238
242
|
for (const prediction of predictions) {
|
|
239
|
-
const id = prediction
|
|
243
|
+
const id = String(prediction[predIdField]);
|
|
244
|
+
if (!id || id === "undefined") {
|
|
245
|
+
throw new IntegrityError(
|
|
246
|
+
`Prediction missing ${predIdField} field: ${JSON.stringify(prediction)}`
|
|
247
|
+
);
|
|
248
|
+
}
|
|
240
249
|
const expectedRecord = expectedMap.get(id);
|
|
241
250
|
if (!expectedRecord) {
|
|
242
251
|
missingIds.push(id);
|
|
@@ -306,9 +315,7 @@ function checkIntegrity(dataset, options = {}) {
|
|
|
306
315
|
}
|
|
307
316
|
}
|
|
308
317
|
if (requiredFields.length > 0) {
|
|
309
|
-
const missing = requiredFields.filter(
|
|
310
|
-
(field) => record[field] === void 0
|
|
311
|
-
);
|
|
318
|
+
const missing = requiredFields.filter((field) => record[field] === void 0);
|
|
312
319
|
if (missing.length > 0) {
|
|
313
320
|
missingFields.push({
|
|
314
321
|
id: String(id ?? `record[${i}]`),
|
|
@@ -331,7 +338,9 @@ function checkIntegrity(dataset, options = {}) {
|
|
|
331
338
|
issues.push(`${missingIds.length} record(s) missing ID`);
|
|
332
339
|
}
|
|
333
340
|
if (duplicateIds.length > 0) {
|
|
334
|
-
issues.push(
|
|
341
|
+
issues.push(
|
|
342
|
+
`${duplicateIds.length} duplicate ID(s): ${duplicateIds.slice(0, 3).join(", ")}${duplicateIds.length > 3 ? "..." : ""}`
|
|
343
|
+
);
|
|
335
344
|
}
|
|
336
345
|
if (missingFields.length > 0) {
|
|
337
346
|
issues.push(`${missingFields.length} record(s) missing required fields`);
|
|
@@ -429,6 +438,67 @@ function computeAccuracy(actual, expected) {
|
|
|
429
438
|
return total > 0 ? correct / total : 0;
|
|
430
439
|
}
|
|
431
440
|
|
|
441
|
+
// src/statistics/regression.ts
|
|
442
|
+
function computeRegressionMetrics(actual, expected) {
|
|
443
|
+
if (actual.length !== expected.length) {
|
|
444
|
+
throw new Error(
|
|
445
|
+
`Array length mismatch: actual has ${actual.length} elements, expected has ${expected.length}`
|
|
446
|
+
);
|
|
447
|
+
}
|
|
448
|
+
const n = actual.length;
|
|
449
|
+
if (n === 0) {
|
|
450
|
+
return { mae: 0, mse: 0, rmse: 0, r2: 0 };
|
|
451
|
+
}
|
|
452
|
+
const mae = computeMAE(actual, expected);
|
|
453
|
+
const mse = computeMSE(actual, expected);
|
|
454
|
+
const rmse = Math.sqrt(mse);
|
|
455
|
+
const r2 = computeR2(actual, expected);
|
|
456
|
+
return { mae, mse, rmse, r2 };
|
|
457
|
+
}
|
|
458
|
+
function computeMAE(actual, expected) {
|
|
459
|
+
if (actual.length !== expected.length || actual.length === 0) {
|
|
460
|
+
return 0;
|
|
461
|
+
}
|
|
462
|
+
let sum = 0;
|
|
463
|
+
for (let i = 0; i < actual.length; i++) {
|
|
464
|
+
sum += Math.abs((actual[i] ?? 0) - (expected[i] ?? 0));
|
|
465
|
+
}
|
|
466
|
+
return sum / actual.length;
|
|
467
|
+
}
|
|
468
|
+
function computeMSE(actual, expected) {
|
|
469
|
+
if (actual.length !== expected.length || actual.length === 0) {
|
|
470
|
+
return 0;
|
|
471
|
+
}
|
|
472
|
+
let sum = 0;
|
|
473
|
+
for (let i = 0; i < actual.length; i++) {
|
|
474
|
+
const diff = (actual[i] ?? 0) - (expected[i] ?? 0);
|
|
475
|
+
sum += diff * diff;
|
|
476
|
+
}
|
|
477
|
+
return sum / actual.length;
|
|
478
|
+
}
|
|
479
|
+
function computeR2(actual, expected) {
|
|
480
|
+
if (actual.length !== expected.length || actual.length === 0) {
|
|
481
|
+
return 0;
|
|
482
|
+
}
|
|
483
|
+
let meanExpected = 0;
|
|
484
|
+
for (const val of expected) {
|
|
485
|
+
meanExpected += val ?? 0;
|
|
486
|
+
}
|
|
487
|
+
meanExpected /= expected.length;
|
|
488
|
+
let ssTotal = 0;
|
|
489
|
+
let ssResidual = 0;
|
|
490
|
+
for (let i = 0; i < actual.length; i++) {
|
|
491
|
+
const exp = expected[i] ?? 0;
|
|
492
|
+
const act = actual[i] ?? 0;
|
|
493
|
+
ssTotal += (exp - meanExpected) ** 2;
|
|
494
|
+
ssResidual += (exp - act) ** 2;
|
|
495
|
+
}
|
|
496
|
+
if (ssTotal === 0) {
|
|
497
|
+
return ssResidual === 0 ? 1 : 0;
|
|
498
|
+
}
|
|
499
|
+
return 1 - ssResidual / ssTotal;
|
|
500
|
+
}
|
|
501
|
+
|
|
432
502
|
// src/statistics/distribution.ts
|
|
433
503
|
function filterNumericValues(values) {
|
|
434
504
|
return values.filter(
|
|
@@ -693,9 +763,7 @@ var FieldSelector = class {
|
|
|
693
763
|
* Throws a clear error if expected values are missing.
|
|
694
764
|
*/
|
|
695
765
|
validateGroundTruth() {
|
|
696
|
-
const hasExpected = this.expectedValues.some(
|
|
697
|
-
(v) => v !== void 0 && v !== null
|
|
698
|
-
);
|
|
766
|
+
const hasExpected = this.expectedValues.some((v) => v !== void 0 && v !== null);
|
|
699
767
|
if (!hasExpected) {
|
|
700
768
|
throw new AssertionError(
|
|
701
769
|
`Classification metric requires ground truth, but field "${this.fieldName}" has no expected values. Use expectStats(predictions, groundTruth) to provide expected values.`,
|
|
@@ -920,7 +988,12 @@ var FieldSelector = class {
|
|
|
920
988
|
this.assertions.push(result);
|
|
921
989
|
recordAssertion(result);
|
|
922
990
|
if (!passed) {
|
|
923
|
-
throw new AssertionError(
|
|
991
|
+
throw new AssertionError(
|
|
992
|
+
result.message,
|
|
993
|
+
percentageThreshold,
|
|
994
|
+
actualPercentage,
|
|
995
|
+
this.fieldName
|
|
996
|
+
);
|
|
924
997
|
}
|
|
925
998
|
return this;
|
|
926
999
|
}
|
|
@@ -961,7 +1034,144 @@ var FieldSelector = class {
|
|
|
961
1034
|
this.assertions.push(result);
|
|
962
1035
|
recordAssertion(result);
|
|
963
1036
|
if (!passed) {
|
|
964
|
-
throw new AssertionError(
|
|
1037
|
+
throw new AssertionError(
|
|
1038
|
+
result.message,
|
|
1039
|
+
percentageThreshold,
|
|
1040
|
+
actualPercentage,
|
|
1041
|
+
this.fieldName
|
|
1042
|
+
);
|
|
1043
|
+
}
|
|
1044
|
+
return this;
|
|
1045
|
+
}
|
|
1046
|
+
// ============================================================================
|
|
1047
|
+
// Regression Assertions
|
|
1048
|
+
// ============================================================================
|
|
1049
|
+
/**
|
|
1050
|
+
* Validates that ground truth exists and both arrays contain numeric values.
|
|
1051
|
+
* Returns the filtered numeric arrays for regression metrics.
|
|
1052
|
+
*/
|
|
1053
|
+
validateRegressionInputs() {
|
|
1054
|
+
this.validateGroundTruth();
|
|
1055
|
+
const numericActual = filterNumericValues(this.actualValues);
|
|
1056
|
+
const numericExpected = filterNumericValues(this.expectedValues);
|
|
1057
|
+
if (numericActual.length === 0) {
|
|
1058
|
+
throw new AssertionError(
|
|
1059
|
+
`Regression metric requires numeric values, but field "${this.fieldName}" has no numeric actual values.`,
|
|
1060
|
+
void 0,
|
|
1061
|
+
void 0,
|
|
1062
|
+
this.fieldName
|
|
1063
|
+
);
|
|
1064
|
+
}
|
|
1065
|
+
if (numericExpected.length === 0) {
|
|
1066
|
+
throw new AssertionError(
|
|
1067
|
+
`Regression metric requires numeric values, but field "${this.fieldName}" has no numeric expected values.`,
|
|
1068
|
+
void 0,
|
|
1069
|
+
void 0,
|
|
1070
|
+
this.fieldName
|
|
1071
|
+
);
|
|
1072
|
+
}
|
|
1073
|
+
if (numericActual.length !== numericExpected.length) {
|
|
1074
|
+
throw new AssertionError(
|
|
1075
|
+
`Regression metric requires equal-length arrays, but got ${numericActual.length} actual and ${numericExpected.length} expected values.`,
|
|
1076
|
+
numericExpected.length,
|
|
1077
|
+
numericActual.length,
|
|
1078
|
+
this.fieldName
|
|
1079
|
+
);
|
|
1080
|
+
}
|
|
1081
|
+
return { actual: numericActual, expected: numericExpected };
|
|
1082
|
+
}
|
|
1083
|
+
/**
|
|
1084
|
+
* Asserts that Mean Absolute Error is below a threshold.
|
|
1085
|
+
* Requires numeric values in both actual and expected.
|
|
1086
|
+
*
|
|
1087
|
+
* @param threshold - Maximum allowed MAE
|
|
1088
|
+
* @returns this for method chaining
|
|
1089
|
+
*
|
|
1090
|
+
* @example
|
|
1091
|
+
* expectStats(predictions, groundTruth)
|
|
1092
|
+
* .field("score")
|
|
1093
|
+
* .toHaveMAEBelow(0.1)
|
|
1094
|
+
*/
|
|
1095
|
+
toHaveMAEBelow(threshold) {
|
|
1096
|
+
const { actual, expected } = this.validateRegressionInputs();
|
|
1097
|
+
const metrics = computeRegressionMetrics(actual, expected);
|
|
1098
|
+
const passed = metrics.mae <= threshold;
|
|
1099
|
+
const result = {
|
|
1100
|
+
type: "mae",
|
|
1101
|
+
passed,
|
|
1102
|
+
message: passed ? `MAE ${metrics.mae.toFixed(4)} is below ${threshold}` : `MAE ${metrics.mae.toFixed(4)} exceeds threshold ${threshold}`,
|
|
1103
|
+
expected: threshold,
|
|
1104
|
+
actual: metrics.mae,
|
|
1105
|
+
field: this.fieldName
|
|
1106
|
+
};
|
|
1107
|
+
this.assertions.push(result);
|
|
1108
|
+
recordAssertion(result);
|
|
1109
|
+
if (!passed) {
|
|
1110
|
+
throw new AssertionError(result.message, threshold, metrics.mae, this.fieldName);
|
|
1111
|
+
}
|
|
1112
|
+
return this;
|
|
1113
|
+
}
|
|
1114
|
+
/**
|
|
1115
|
+
* Asserts that Root Mean Squared Error is below a threshold.
|
|
1116
|
+
* Requires numeric values in both actual and expected.
|
|
1117
|
+
*
|
|
1118
|
+
* @param threshold - Maximum allowed RMSE
|
|
1119
|
+
* @returns this for method chaining
|
|
1120
|
+
*
|
|
1121
|
+
* @example
|
|
1122
|
+
* expectStats(predictions, groundTruth)
|
|
1123
|
+
* .field("score")
|
|
1124
|
+
* .toHaveRMSEBelow(0.15)
|
|
1125
|
+
*/
|
|
1126
|
+
toHaveRMSEBelow(threshold) {
|
|
1127
|
+
const { actual, expected } = this.validateRegressionInputs();
|
|
1128
|
+
const metrics = computeRegressionMetrics(actual, expected);
|
|
1129
|
+
const passed = metrics.rmse <= threshold;
|
|
1130
|
+
const result = {
|
|
1131
|
+
type: "rmse",
|
|
1132
|
+
passed,
|
|
1133
|
+
message: passed ? `RMSE ${metrics.rmse.toFixed(4)} is below ${threshold}` : `RMSE ${metrics.rmse.toFixed(4)} exceeds threshold ${threshold}`,
|
|
1134
|
+
expected: threshold,
|
|
1135
|
+
actual: metrics.rmse,
|
|
1136
|
+
field: this.fieldName
|
|
1137
|
+
};
|
|
1138
|
+
this.assertions.push(result);
|
|
1139
|
+
recordAssertion(result);
|
|
1140
|
+
if (!passed) {
|
|
1141
|
+
throw new AssertionError(result.message, threshold, metrics.rmse, this.fieldName);
|
|
1142
|
+
}
|
|
1143
|
+
return this;
|
|
1144
|
+
}
|
|
1145
|
+
/**
|
|
1146
|
+
* Asserts that R-squared (coefficient of determination) is above a threshold.
|
|
1147
|
+
* R² measures how well the predictions explain the variance in expected values.
|
|
1148
|
+
* R² = 1.0 means perfect prediction, R² = 0 means prediction is no better than mean.
|
|
1149
|
+
* Requires numeric values in both actual and expected.
|
|
1150
|
+
*
|
|
1151
|
+
* @param threshold - Minimum required R² value (0-1)
|
|
1152
|
+
* @returns this for method chaining
|
|
1153
|
+
*
|
|
1154
|
+
* @example
|
|
1155
|
+
* expectStats(predictions, groundTruth)
|
|
1156
|
+
* .field("score")
|
|
1157
|
+
* .toHaveR2Above(0.8)
|
|
1158
|
+
*/
|
|
1159
|
+
toHaveR2Above(threshold) {
|
|
1160
|
+
const { actual, expected } = this.validateRegressionInputs();
|
|
1161
|
+
const metrics = computeRegressionMetrics(actual, expected);
|
|
1162
|
+
const passed = metrics.r2 >= threshold;
|
|
1163
|
+
const result = {
|
|
1164
|
+
type: "r2",
|
|
1165
|
+
passed,
|
|
1166
|
+
message: passed ? `R\xB2 ${metrics.r2.toFixed(4)} is above ${threshold}` : `R\xB2 ${metrics.r2.toFixed(4)} is below threshold ${threshold}`,
|
|
1167
|
+
expected: threshold,
|
|
1168
|
+
actual: metrics.r2,
|
|
1169
|
+
field: this.fieldName
|
|
1170
|
+
};
|
|
1171
|
+
this.assertions.push(result);
|
|
1172
|
+
recordAssertion(result);
|
|
1173
|
+
if (!passed) {
|
|
1174
|
+
throw new AssertionError(result.message, threshold, metrics.r2, this.fieldName);
|
|
965
1175
|
}
|
|
966
1176
|
return this;
|
|
967
1177
|
}
|
|
@@ -998,16 +1208,22 @@ function normalizeInput(input) {
|
|
|
998
1208
|
expected: {}
|
|
999
1209
|
}));
|
|
1000
1210
|
}
|
|
1001
|
-
throw new Error(
|
|
1211
|
+
throw new Error(
|
|
1212
|
+
"Invalid input to expectStats(): expected ModelRunResult, Prediction[], or AlignedRecord[]"
|
|
1213
|
+
);
|
|
1002
1214
|
}
|
|
1003
|
-
function expectStats(inputOrActual, expected) {
|
|
1215
|
+
function expectStats(inputOrActual, expected, options) {
|
|
1004
1216
|
if (expected !== void 0) {
|
|
1005
1217
|
if (!Array.isArray(inputOrActual)) {
|
|
1006
|
-
throw new Error(
|
|
1007
|
-
"When using two-argument expectStats(), first argument must be Prediction[]"
|
|
1008
|
-
);
|
|
1218
|
+
throw new Error("When using two-argument expectStats(), first argument must be Prediction[]");
|
|
1009
1219
|
}
|
|
1010
|
-
const
|
|
1220
|
+
const alignOptions = options ? {
|
|
1221
|
+
idField: options.idField,
|
|
1222
|
+
predictionIdField: options.predictionIdField,
|
|
1223
|
+
expectedIdField: options.expectedIdField,
|
|
1224
|
+
strict: options.strict
|
|
1225
|
+
} : void 0;
|
|
1226
|
+
const aligned2 = alignByKey(inputOrActual, expected, alignOptions);
|
|
1011
1227
|
return new ExpectStats(aligned2);
|
|
1012
1228
|
}
|
|
1013
1229
|
const aligned = normalizeInput(inputOrActual);
|