evalsense 0.3.1 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +309 -149
- package/dist/{chunk-BE7CB3AM.cjs → chunk-4BKZPVY4.cjs} +50 -13
- package/dist/chunk-4BKZPVY4.cjs.map +1 -0
- package/dist/{chunk-K6QPJ2NO.js → chunk-IUVDDMJ3.js} +50 -13
- package/dist/chunk-IUVDDMJ3.js.map +1 -0
- package/dist/chunk-NCCQRZ2Y.cjs +1141 -0
- package/dist/chunk-NCCQRZ2Y.cjs.map +1 -0
- package/dist/chunk-TDGWDK2L.js +1108 -0
- package/dist/chunk-TDGWDK2L.js.map +1 -0
- package/dist/cli.cjs +11 -11
- package/dist/cli.js +1 -1
- package/dist/index-CATqAHNK.d.cts +416 -0
- package/dist/index-CoMpaW-K.d.ts +416 -0
- package/dist/index.cjs +507 -629
- package/dist/index.cjs.map +1 -1
- package/dist/index.d.cts +210 -161
- package/dist/index.d.ts +210 -161
- package/dist/index.js +455 -573
- package/dist/index.js.map +1 -1
- package/dist/metrics/index.cjs +103 -342
- package/dist/metrics/index.cjs.map +1 -1
- package/dist/metrics/index.d.cts +260 -31
- package/dist/metrics/index.d.ts +260 -31
- package/dist/metrics/index.js +24 -312
- package/dist/metrics/index.js.map +1 -1
- package/dist/metrics/opinionated/index.cjs +5 -5
- package/dist/metrics/opinionated/index.d.cts +2 -163
- package/dist/metrics/opinionated/index.d.ts +2 -163
- package/dist/metrics/opinionated/index.js +1 -1
- package/dist/{types-C71p0wzM.d.cts → types-D0hzfyKm.d.cts} +1 -13
- package/dist/{types-C71p0wzM.d.ts → types-D0hzfyKm.d.ts} +1 -13
- package/package.json +1 -1
- package/dist/chunk-BE7CB3AM.cjs.map +0 -1
- package/dist/chunk-K6QPJ2NO.js.map +0 -1
- package/dist/chunk-RZFLCWTW.cjs +0 -942
- package/dist/chunk-RZFLCWTW.cjs.map +0 -1
- package/dist/chunk-Z3U6AUWX.js +0 -925
- package/dist/chunk-Z3U6AUWX.js.map +0 -1
package/dist/index.js
CHANGED
|
@@ -1,8 +1,6 @@
|
|
|
1
|
-
import { getCurrentSuite, setCurrentSuite, addSuite, addTestToCurrentSuite,
|
|
2
|
-
export { AssertionError, ConfigurationError, ConsoleReporter, DatasetError, EvalSenseError, ExitCodes, IntegrityError, JsonReporter, TestExecutionError, buildConfusionMatrix, discoverEvalFiles, executeEvalFiles, formatConfusionMatrix, getExitCode, parseReport } from './chunk-
|
|
1
|
+
import { getCurrentSuite, setCurrentSuite, addSuite, addTestToCurrentSuite, IntegrityError, buildConfusionMatrix, getTruePositives, getFalsePositives, getFalseNegatives, getSupport, AssertionError, recordFieldMetrics, recordAssertion } from './chunk-IUVDDMJ3.js';
|
|
2
|
+
export { AssertionError, ConfigurationError, ConsoleReporter, DatasetError, EvalSenseError, ExitCodes, IntegrityError, JsonReporter, TestExecutionError, buildConfusionMatrix, discoverEvalFiles, executeEvalFiles, formatConfusionMatrix, getExitCode, parseReport } from './chunk-IUVDDMJ3.js';
|
|
3
3
|
import './chunk-DGUM43GV.js';
|
|
4
|
-
import { readFileSync } from 'fs';
|
|
5
|
-
import { resolve, extname } from 'path';
|
|
6
4
|
|
|
7
5
|
// src/core/describe.ts
|
|
8
6
|
function describe(name, fn) {
|
|
@@ -91,136 +89,6 @@ function evalTestOnly(name, fn) {
|
|
|
91
89
|
}
|
|
92
90
|
evalTest.skip = evalTestSkip;
|
|
93
91
|
evalTest.only = evalTestOnly;
|
|
94
|
-
function loadDataset(path) {
|
|
95
|
-
const absolutePath = resolve(process.cwd(), path);
|
|
96
|
-
const ext = extname(absolutePath).toLowerCase();
|
|
97
|
-
let records;
|
|
98
|
-
try {
|
|
99
|
-
const content = readFileSync(absolutePath, "utf-8");
|
|
100
|
-
if (ext === ".ndjson" || ext === ".jsonl") {
|
|
101
|
-
records = parseNDJSON(content);
|
|
102
|
-
} else if (ext === ".json") {
|
|
103
|
-
records = parseJSON(content);
|
|
104
|
-
} else {
|
|
105
|
-
throw new DatasetError(
|
|
106
|
-
`Unsupported file format: ${ext}. Use .json, .ndjson, or .jsonl`,
|
|
107
|
-
path
|
|
108
|
-
);
|
|
109
|
-
}
|
|
110
|
-
} catch (error) {
|
|
111
|
-
if (error instanceof DatasetError) {
|
|
112
|
-
throw error;
|
|
113
|
-
}
|
|
114
|
-
const message = error instanceof Error ? error.message : String(error);
|
|
115
|
-
throw new DatasetError(`Failed to load dataset from ${path}: ${message}`, path);
|
|
116
|
-
}
|
|
117
|
-
return {
|
|
118
|
-
records,
|
|
119
|
-
metadata: {
|
|
120
|
-
source: path,
|
|
121
|
-
count: records.length,
|
|
122
|
-
loadedAt: /* @__PURE__ */ new Date()
|
|
123
|
-
}
|
|
124
|
-
};
|
|
125
|
-
}
|
|
126
|
-
function parseJSON(content) {
|
|
127
|
-
const parsed = JSON.parse(content);
|
|
128
|
-
if (!Array.isArray(parsed)) {
|
|
129
|
-
throw new DatasetError("JSON dataset must be an array of records");
|
|
130
|
-
}
|
|
131
|
-
return parsed;
|
|
132
|
-
}
|
|
133
|
-
function parseNDJSON(content) {
|
|
134
|
-
const lines = content.split("\n").filter((line) => line.trim() !== "");
|
|
135
|
-
const records = [];
|
|
136
|
-
for (let i = 0; i < lines.length; i++) {
|
|
137
|
-
const line = lines[i];
|
|
138
|
-
if (line === void 0) continue;
|
|
139
|
-
try {
|
|
140
|
-
records.push(JSON.parse(line));
|
|
141
|
-
} catch {
|
|
142
|
-
throw new DatasetError(`Invalid JSON at line ${i + 1} in NDJSON file`);
|
|
143
|
-
}
|
|
144
|
-
}
|
|
145
|
-
return records;
|
|
146
|
-
}
|
|
147
|
-
function createDataset(records, source = "inline") {
|
|
148
|
-
return {
|
|
149
|
-
records,
|
|
150
|
-
metadata: {
|
|
151
|
-
source,
|
|
152
|
-
count: records.length,
|
|
153
|
-
loadedAt: /* @__PURE__ */ new Date()
|
|
154
|
-
}
|
|
155
|
-
};
|
|
156
|
-
}
|
|
157
|
-
|
|
158
|
-
// src/dataset/run-model.ts
|
|
159
|
-
async function runModel(dataset, modelFn) {
|
|
160
|
-
const startTime = Date.now();
|
|
161
|
-
const predictions = [];
|
|
162
|
-
const aligned = [];
|
|
163
|
-
for (const record of dataset.records) {
|
|
164
|
-
const id = getRecordId(record);
|
|
165
|
-
const prediction = await modelFn(record);
|
|
166
|
-
if (prediction.id !== id) {
|
|
167
|
-
throw new DatasetError(
|
|
168
|
-
`Prediction ID mismatch: expected "${id}", got "${prediction.id}". Model function must return the same ID as the input record.`
|
|
169
|
-
);
|
|
170
|
-
}
|
|
171
|
-
predictions.push(prediction);
|
|
172
|
-
aligned.push({
|
|
173
|
-
id,
|
|
174
|
-
actual: { ...prediction },
|
|
175
|
-
expected: { ...record }
|
|
176
|
-
});
|
|
177
|
-
}
|
|
178
|
-
return {
|
|
179
|
-
predictions,
|
|
180
|
-
aligned,
|
|
181
|
-
duration: Date.now() - startTime
|
|
182
|
-
};
|
|
183
|
-
}
|
|
184
|
-
function getRecordId(record) {
|
|
185
|
-
const id = record.id ?? record._id;
|
|
186
|
-
if (id === void 0 || id === null) {
|
|
187
|
-
throw new DatasetError('Dataset records must have an "id" or "_id" field for alignment');
|
|
188
|
-
}
|
|
189
|
-
return String(id);
|
|
190
|
-
}
|
|
191
|
-
async function runModelParallel(dataset, modelFn, concurrency = 10) {
|
|
192
|
-
const startTime = Date.now();
|
|
193
|
-
const results = [];
|
|
194
|
-
for (let i = 0; i < dataset.records.length; i += concurrency) {
|
|
195
|
-
const batch = dataset.records.slice(i, i + concurrency);
|
|
196
|
-
const batchResults = await Promise.all(
|
|
197
|
-
batch.map(async (record) => {
|
|
198
|
-
const prediction = await modelFn(record);
|
|
199
|
-
return { prediction, record };
|
|
200
|
-
})
|
|
201
|
-
);
|
|
202
|
-
results.push(...batchResults);
|
|
203
|
-
}
|
|
204
|
-
const predictions = [];
|
|
205
|
-
const aligned = [];
|
|
206
|
-
for (const { prediction, record } of results) {
|
|
207
|
-
const id = getRecordId(record);
|
|
208
|
-
if (prediction.id !== id) {
|
|
209
|
-
throw new DatasetError(`Prediction ID mismatch: expected "${id}", got "${prediction.id}".`);
|
|
210
|
-
}
|
|
211
|
-
predictions.push(prediction);
|
|
212
|
-
aligned.push({
|
|
213
|
-
id,
|
|
214
|
-
actual: { ...prediction },
|
|
215
|
-
expected: { ...record }
|
|
216
|
-
});
|
|
217
|
-
}
|
|
218
|
-
return {
|
|
219
|
-
predictions,
|
|
220
|
-
aligned,
|
|
221
|
-
duration: Date.now() - startTime
|
|
222
|
-
};
|
|
223
|
-
}
|
|
224
92
|
|
|
225
93
|
// src/dataset/alignment.ts
|
|
226
94
|
function alignByKey(predictions, expected, options = {}) {
|
|
@@ -293,14 +161,14 @@ function filterComplete(aligned, field) {
|
|
|
293
161
|
}
|
|
294
162
|
|
|
295
163
|
// src/dataset/integrity.ts
|
|
296
|
-
function checkIntegrity(
|
|
164
|
+
function checkIntegrity(records, options = {}) {
|
|
297
165
|
const { requiredFields = [], throwOnFailure = false } = options;
|
|
298
166
|
const seenIds = /* @__PURE__ */ new Map();
|
|
299
167
|
const missingIds = [];
|
|
300
168
|
const duplicateIds = [];
|
|
301
169
|
const missingFields = [];
|
|
302
|
-
for (let i = 0; i <
|
|
303
|
-
const record =
|
|
170
|
+
for (let i = 0; i < records.length; i++) {
|
|
171
|
+
const record = records[i];
|
|
304
172
|
if (!record) continue;
|
|
305
173
|
const id = record.id ?? record._id;
|
|
306
174
|
if (id === void 0 || id === null) {
|
|
@@ -327,7 +195,7 @@ function checkIntegrity(dataset, options = {}) {
|
|
|
327
195
|
const valid = missingIds.length === 0 && duplicateIds.length === 0 && missingFields.length === 0;
|
|
328
196
|
const result = {
|
|
329
197
|
valid,
|
|
330
|
-
totalRecords:
|
|
198
|
+
totalRecords: records.length,
|
|
331
199
|
missingIds,
|
|
332
200
|
duplicateIds,
|
|
333
201
|
missingFields
|
|
@@ -520,6 +388,91 @@ function calculatePercentageAbove(values, threshold) {
|
|
|
520
388
|
return countAbove / values.length;
|
|
521
389
|
}
|
|
522
390
|
|
|
391
|
+
// src/assertions/metric-matcher.ts
|
|
392
|
+
var MetricMatcher = class {
|
|
393
|
+
context;
|
|
394
|
+
constructor(context) {
|
|
395
|
+
this.context = context;
|
|
396
|
+
}
|
|
397
|
+
formatMetricValue(value) {
|
|
398
|
+
if (this.context.formatValue) {
|
|
399
|
+
return this.context.formatValue(value);
|
|
400
|
+
}
|
|
401
|
+
if (value >= 0 && value <= 1) {
|
|
402
|
+
return `${(value * 100).toFixed(1)}%`;
|
|
403
|
+
}
|
|
404
|
+
return value.toFixed(4);
|
|
405
|
+
}
|
|
406
|
+
createAssertion(operator, threshold, passed) {
|
|
407
|
+
const { metricName, metricValue, fieldName, targetClass } = this.context;
|
|
408
|
+
const formattedActual = this.formatMetricValue(metricValue);
|
|
409
|
+
const formattedThreshold = this.formatMetricValue(threshold);
|
|
410
|
+
const classInfo = targetClass ? ` for "${targetClass}"` : "";
|
|
411
|
+
const operatorText = {
|
|
412
|
+
">=": "at least",
|
|
413
|
+
">": "above",
|
|
414
|
+
"<=": "at most",
|
|
415
|
+
"<": "below",
|
|
416
|
+
"===": "equal to"
|
|
417
|
+
}[operator];
|
|
418
|
+
const message = passed ? `${metricName}${classInfo} ${formattedActual} is ${operatorText} ${formattedThreshold}` : `${metricName}${classInfo} ${formattedActual} is not ${operatorText} ${formattedThreshold}`;
|
|
419
|
+
return {
|
|
420
|
+
type: metricName.toLowerCase().replace(/\s+/g, "").replace(/²/g, "2"),
|
|
421
|
+
passed,
|
|
422
|
+
message,
|
|
423
|
+
expected: threshold,
|
|
424
|
+
actual: metricValue,
|
|
425
|
+
field: fieldName,
|
|
426
|
+
class: targetClass
|
|
427
|
+
};
|
|
428
|
+
}
|
|
429
|
+
recordAndReturn(result) {
|
|
430
|
+
this.context.assertions.push(result);
|
|
431
|
+
recordAssertion(result);
|
|
432
|
+
return this.context.parent;
|
|
433
|
+
}
|
|
434
|
+
/**
|
|
435
|
+
* Assert that the metric is greater than or equal to the threshold (>=)
|
|
436
|
+
*/
|
|
437
|
+
toBeAtLeast(threshold) {
|
|
438
|
+
const passed = this.context.metricValue >= threshold;
|
|
439
|
+
const result = this.createAssertion(">=", threshold, passed);
|
|
440
|
+
return this.recordAndReturn(result);
|
|
441
|
+
}
|
|
442
|
+
/**
|
|
443
|
+
* Assert that the metric is strictly greater than the threshold (>)
|
|
444
|
+
*/
|
|
445
|
+
toBeAbove(threshold) {
|
|
446
|
+
const passed = this.context.metricValue > threshold;
|
|
447
|
+
const result = this.createAssertion(">", threshold, passed);
|
|
448
|
+
return this.recordAndReturn(result);
|
|
449
|
+
}
|
|
450
|
+
/**
|
|
451
|
+
* Assert that the metric is less than or equal to the threshold (<=)
|
|
452
|
+
*/
|
|
453
|
+
toBeAtMost(threshold) {
|
|
454
|
+
const passed = this.context.metricValue <= threshold;
|
|
455
|
+
const result = this.createAssertion("<=", threshold, passed);
|
|
456
|
+
return this.recordAndReturn(result);
|
|
457
|
+
}
|
|
458
|
+
/**
|
|
459
|
+
* Assert that the metric is strictly less than the threshold (<)
|
|
460
|
+
*/
|
|
461
|
+
toBeBelow(threshold) {
|
|
462
|
+
const passed = this.context.metricValue < threshold;
|
|
463
|
+
const result = this.createAssertion("<", threshold, passed);
|
|
464
|
+
return this.recordAndReturn(result);
|
|
465
|
+
}
|
|
466
|
+
/**
|
|
467
|
+
* Assert that the metric equals the expected value (with optional tolerance for floats)
|
|
468
|
+
*/
|
|
469
|
+
toEqual(expected, tolerance = 1e-9) {
|
|
470
|
+
const passed = Math.abs(this.context.metricValue - expected) <= tolerance;
|
|
471
|
+
const result = this.createAssertion("===", expected, passed);
|
|
472
|
+
return this.recordAndReturn(result);
|
|
473
|
+
}
|
|
474
|
+
};
|
|
475
|
+
|
|
523
476
|
// src/assertions/binarize.ts
|
|
524
477
|
var BinarizeSelector = class {
|
|
525
478
|
fieldName;
|
|
@@ -551,161 +504,127 @@ var BinarizeSelector = class {
|
|
|
551
504
|
}
|
|
552
505
|
}
|
|
553
506
|
}
|
|
507
|
+
// ============================================================================
|
|
508
|
+
// Classification Metric Getters
|
|
509
|
+
// ============================================================================
|
|
554
510
|
/**
|
|
555
|
-
*
|
|
511
|
+
* Access accuracy metric for assertions
|
|
512
|
+
* @example
|
|
513
|
+
* expectStats(predictions, groundTruth)
|
|
514
|
+
* .field("score")
|
|
515
|
+
* .binarize(0.5)
|
|
516
|
+
* .accuracy.toBeAtLeast(0.8)
|
|
556
517
|
*/
|
|
557
|
-
|
|
518
|
+
get accuracy() {
|
|
558
519
|
const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
field: this.fieldName
|
|
567
|
-
};
|
|
568
|
-
this.assertions.push(result);
|
|
569
|
-
recordAssertion(result);
|
|
570
|
-
if (!passed) {
|
|
571
|
-
throw new AssertionError(result.message, threshold, metrics.accuracy, this.fieldName);
|
|
572
|
-
}
|
|
573
|
-
return this;
|
|
520
|
+
return new MetricMatcher({
|
|
521
|
+
parent: this,
|
|
522
|
+
metricName: "Accuracy",
|
|
523
|
+
metricValue: metrics.accuracy,
|
|
524
|
+
fieldName: this.fieldName,
|
|
525
|
+
assertions: this.assertions
|
|
526
|
+
});
|
|
574
527
|
}
|
|
575
528
|
/**
|
|
576
|
-
*
|
|
577
|
-
* @
|
|
578
|
-
*
|
|
529
|
+
* Access F1 score metric for assertions (macro average)
|
|
530
|
+
* @example
|
|
531
|
+
* expectStats(predictions, groundTruth)
|
|
532
|
+
* .field("score")
|
|
533
|
+
* .binarize(0.5)
|
|
534
|
+
* .f1.toBeAtLeast(0.75)
|
|
579
535
|
*/
|
|
580
|
-
|
|
536
|
+
get f1() {
|
|
581
537
|
const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
}
|
|
589
|
-
targetClass = String(classOrThreshold);
|
|
590
|
-
actualThreshold = threshold;
|
|
591
|
-
const classMetrics = metrics.perClass[targetClass];
|
|
592
|
-
if (!classMetrics) {
|
|
593
|
-
throw new AssertionError(
|
|
594
|
-
`Class "${targetClass}" not found in binarized predictions`,
|
|
595
|
-
targetClass,
|
|
596
|
-
Object.keys(metrics.perClass),
|
|
597
|
-
this.fieldName
|
|
598
|
-
);
|
|
599
|
-
}
|
|
600
|
-
actualPrecision = classMetrics.precision;
|
|
601
|
-
}
|
|
602
|
-
const passed = actualPrecision >= actualThreshold;
|
|
603
|
-
const result = {
|
|
604
|
-
type: "precision",
|
|
605
|
-
passed,
|
|
606
|
-
message: passed ? `Precision${targetClass ? ` for ${targetClass}` : ""} ${(actualPrecision * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Precision${targetClass ? ` for ${targetClass}` : ""} ${(actualPrecision * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
|
|
607
|
-
expected: actualThreshold,
|
|
608
|
-
actual: actualPrecision,
|
|
609
|
-
field: this.fieldName,
|
|
610
|
-
class: targetClass
|
|
611
|
-
};
|
|
612
|
-
this.assertions.push(result);
|
|
613
|
-
recordAssertion(result);
|
|
614
|
-
if (!passed) {
|
|
615
|
-
throw new AssertionError(result.message, actualThreshold, actualPrecision, this.fieldName);
|
|
616
|
-
}
|
|
617
|
-
return this;
|
|
538
|
+
return new MetricMatcher({
|
|
539
|
+
parent: this,
|
|
540
|
+
metricName: "F1",
|
|
541
|
+
metricValue: metrics.macroAvg.f1,
|
|
542
|
+
fieldName: this.fieldName,
|
|
543
|
+
assertions: this.assertions
|
|
544
|
+
});
|
|
618
545
|
}
|
|
619
546
|
/**
|
|
620
|
-
*
|
|
621
|
-
* @param
|
|
622
|
-
* @
|
|
547
|
+
* Access precision metric for assertions
|
|
548
|
+
* @param targetClass - Optional boolean class (true/false). If omitted, uses macro average
|
|
549
|
+
* @example
|
|
550
|
+
* expectStats(predictions, groundTruth)
|
|
551
|
+
* .field("score")
|
|
552
|
+
* .binarize(0.5)
|
|
553
|
+
* .precision(true).toBeAtLeast(0.7)
|
|
623
554
|
*/
|
|
624
|
-
|
|
555
|
+
precision(targetClass) {
|
|
625
556
|
const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
|
|
626
|
-
let
|
|
627
|
-
let
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
actualRecall = metrics.macroAvg.recall;
|
|
631
|
-
actualThreshold = classOrThreshold;
|
|
557
|
+
let metricValue;
|
|
558
|
+
let classKey;
|
|
559
|
+
if (targetClass === void 0) {
|
|
560
|
+
metricValue = metrics.macroAvg.precision;
|
|
632
561
|
} else {
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
const classMetrics = metrics.perClass[targetClass];
|
|
562
|
+
classKey = String(targetClass);
|
|
563
|
+
const classMetrics = metrics.perClass[classKey];
|
|
636
564
|
if (!classMetrics) {
|
|
637
565
|
throw new AssertionError(
|
|
638
|
-
`Class "${
|
|
639
|
-
|
|
566
|
+
`Class "${classKey}" not found in binarized predictions`,
|
|
567
|
+
classKey,
|
|
640
568
|
Object.keys(metrics.perClass),
|
|
641
569
|
this.fieldName
|
|
642
570
|
);
|
|
643
571
|
}
|
|
644
|
-
|
|
645
|
-
}
|
|
646
|
-
const passed = actualRecall >= actualThreshold;
|
|
647
|
-
const result = {
|
|
648
|
-
type: "recall",
|
|
649
|
-
passed,
|
|
650
|
-
message: passed ? `Recall${targetClass ? ` for ${targetClass}` : ""} ${(actualRecall * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Recall${targetClass ? ` for ${targetClass}` : ""} ${(actualRecall * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
|
|
651
|
-
expected: actualThreshold,
|
|
652
|
-
actual: actualRecall,
|
|
653
|
-
field: this.fieldName,
|
|
654
|
-
class: targetClass
|
|
655
|
-
};
|
|
656
|
-
this.assertions.push(result);
|
|
657
|
-
recordAssertion(result);
|
|
658
|
-
if (!passed) {
|
|
659
|
-
throw new AssertionError(result.message, actualThreshold, actualRecall, this.fieldName);
|
|
572
|
+
metricValue = classMetrics.precision;
|
|
660
573
|
}
|
|
661
|
-
return
|
|
574
|
+
return new MetricMatcher({
|
|
575
|
+
parent: this,
|
|
576
|
+
metricName: "Precision",
|
|
577
|
+
metricValue,
|
|
578
|
+
fieldName: this.fieldName,
|
|
579
|
+
targetClass: classKey,
|
|
580
|
+
assertions: this.assertions
|
|
581
|
+
});
|
|
662
582
|
}
|
|
663
583
|
/**
|
|
664
|
-
*
|
|
584
|
+
* Access recall metric for assertions
|
|
585
|
+
* @param targetClass - Optional boolean class (true/false). If omitted, uses macro average
|
|
586
|
+
* @example
|
|
587
|
+
* expectStats(predictions, groundTruth)
|
|
588
|
+
* .field("score")
|
|
589
|
+
* .binarize(0.5)
|
|
590
|
+
* .recall(true).toBeAtLeast(0.7)
|
|
665
591
|
*/
|
|
666
|
-
|
|
592
|
+
recall(targetClass) {
|
|
667
593
|
const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
|
|
668
|
-
let
|
|
669
|
-
let
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
actualF1 = metrics.macroAvg.f1;
|
|
673
|
-
actualThreshold = classOrThreshold;
|
|
594
|
+
let metricValue;
|
|
595
|
+
let classKey;
|
|
596
|
+
if (targetClass === void 0) {
|
|
597
|
+
metricValue = metrics.macroAvg.recall;
|
|
674
598
|
} else {
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
const classMetrics = metrics.perClass[targetClass];
|
|
599
|
+
classKey = String(targetClass);
|
|
600
|
+
const classMetrics = metrics.perClass[classKey];
|
|
678
601
|
if (!classMetrics) {
|
|
679
602
|
throw new AssertionError(
|
|
680
|
-
`Class "${
|
|
681
|
-
|
|
603
|
+
`Class "${classKey}" not found in binarized predictions`,
|
|
604
|
+
classKey,
|
|
682
605
|
Object.keys(metrics.perClass),
|
|
683
606
|
this.fieldName
|
|
684
607
|
);
|
|
685
608
|
}
|
|
686
|
-
|
|
687
|
-
}
|
|
688
|
-
const passed = actualF1 >= actualThreshold;
|
|
689
|
-
const result = {
|
|
690
|
-
type: "f1",
|
|
691
|
-
passed,
|
|
692
|
-
message: passed ? `F1${targetClass ? ` for ${targetClass}` : ""} ${(actualF1 * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `F1${targetClass ? ` for ${targetClass}` : ""} ${(actualF1 * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
|
|
693
|
-
expected: actualThreshold,
|
|
694
|
-
actual: actualF1,
|
|
695
|
-
field: this.fieldName,
|
|
696
|
-
class: targetClass
|
|
697
|
-
};
|
|
698
|
-
this.assertions.push(result);
|
|
699
|
-
recordAssertion(result);
|
|
700
|
-
if (!passed) {
|
|
701
|
-
throw new AssertionError(result.message, actualThreshold, actualF1, this.fieldName);
|
|
609
|
+
metricValue = classMetrics.recall;
|
|
702
610
|
}
|
|
703
|
-
return
|
|
611
|
+
return new MetricMatcher({
|
|
612
|
+
parent: this,
|
|
613
|
+
metricName: "Recall",
|
|
614
|
+
metricValue,
|
|
615
|
+
fieldName: this.fieldName,
|
|
616
|
+
targetClass: classKey,
|
|
617
|
+
assertions: this.assertions
|
|
618
|
+
});
|
|
704
619
|
}
|
|
620
|
+
// ============================================================================
|
|
621
|
+
// Display Methods
|
|
622
|
+
// ============================================================================
|
|
705
623
|
/**
|
|
706
|
-
*
|
|
624
|
+
* Displays the confusion matrix in the report
|
|
625
|
+
* This is not an assertion - it always passes and just records the matrix for display
|
|
707
626
|
*/
|
|
708
|
-
|
|
627
|
+
displayConfusionMatrix() {
|
|
709
628
|
const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
|
|
710
629
|
const fieldResult = {
|
|
711
630
|
field: this.fieldName,
|
|
@@ -724,6 +643,9 @@ var BinarizeSelector = class {
|
|
|
724
643
|
recordAssertion(result);
|
|
725
644
|
return this;
|
|
726
645
|
}
|
|
646
|
+
// ============================================================================
|
|
647
|
+
// Utility Methods
|
|
648
|
+
// ============================================================================
|
|
727
649
|
/**
|
|
728
650
|
* Gets computed metrics
|
|
729
651
|
*/
|
|
@@ -738,6 +660,73 @@ var BinarizeSelector = class {
|
|
|
738
660
|
}
|
|
739
661
|
};
|
|
740
662
|
|
|
663
|
+
// src/assertions/percentage-matcher.ts
|
|
664
|
+
var PercentageMatcher = class {
|
|
665
|
+
context;
|
|
666
|
+
constructor(context) {
|
|
667
|
+
this.context = context;
|
|
668
|
+
}
|
|
669
|
+
formatPercentage(value) {
|
|
670
|
+
return `${(value * 100).toFixed(1)}%`;
|
|
671
|
+
}
|
|
672
|
+
createAssertion(operator, percentageThreshold, passed) {
|
|
673
|
+
const { fieldName, valueThreshold, direction, actualPercentage } = this.context;
|
|
674
|
+
const operatorText = {
|
|
675
|
+
">=": "at least",
|
|
676
|
+
">": "above",
|
|
677
|
+
"<=": "at most",
|
|
678
|
+
"<": "below"
|
|
679
|
+
}[operator];
|
|
680
|
+
const directionText = direction === "above" ? "above" : "below or equal to";
|
|
681
|
+
const message = passed ? `${this.formatPercentage(actualPercentage)} of '${fieldName}' values are ${directionText} ${valueThreshold} (expected ${operatorText} ${this.formatPercentage(percentageThreshold)})` : `Only ${this.formatPercentage(actualPercentage)} of '${fieldName}' values are ${directionText} ${valueThreshold} (expected ${operatorText} ${this.formatPercentage(percentageThreshold)})`;
|
|
682
|
+
return {
|
|
683
|
+
type: direction === "above" ? "percentageAbove" : "percentageBelow",
|
|
684
|
+
passed,
|
|
685
|
+
message,
|
|
686
|
+
expected: percentageThreshold,
|
|
687
|
+
actual: actualPercentage,
|
|
688
|
+
field: fieldName
|
|
689
|
+
};
|
|
690
|
+
}
|
|
691
|
+
recordAndReturn(result) {
|
|
692
|
+
this.context.assertions.push(result);
|
|
693
|
+
recordAssertion(result);
|
|
694
|
+
return this.context.parent;
|
|
695
|
+
}
|
|
696
|
+
/**
|
|
697
|
+
* Assert that the percentage is greater than or equal to the threshold (>=)
|
|
698
|
+
*/
|
|
699
|
+
toBeAtLeast(percentageThreshold) {
|
|
700
|
+
const passed = this.context.actualPercentage >= percentageThreshold;
|
|
701
|
+
const result = this.createAssertion(">=", percentageThreshold, passed);
|
|
702
|
+
return this.recordAndReturn(result);
|
|
703
|
+
}
|
|
704
|
+
/**
|
|
705
|
+
* Assert that the percentage is strictly greater than the threshold (>)
|
|
706
|
+
*/
|
|
707
|
+
toBeAbove(percentageThreshold) {
|
|
708
|
+
const passed = this.context.actualPercentage > percentageThreshold;
|
|
709
|
+
const result = this.createAssertion(">", percentageThreshold, passed);
|
|
710
|
+
return this.recordAndReturn(result);
|
|
711
|
+
}
|
|
712
|
+
/**
|
|
713
|
+
* Assert that the percentage is less than or equal to the threshold (<=)
|
|
714
|
+
*/
|
|
715
|
+
toBeAtMost(percentageThreshold) {
|
|
716
|
+
const passed = this.context.actualPercentage <= percentageThreshold;
|
|
717
|
+
const result = this.createAssertion("<=", percentageThreshold, passed);
|
|
718
|
+
return this.recordAndReturn(result);
|
|
719
|
+
}
|
|
720
|
+
/**
|
|
721
|
+
* Assert that the percentage is strictly less than the threshold (<)
|
|
722
|
+
*/
|
|
723
|
+
toBeBelow(percentageThreshold) {
|
|
724
|
+
const passed = this.context.actualPercentage < percentageThreshold;
|
|
725
|
+
const result = this.createAssertion("<", percentageThreshold, passed);
|
|
726
|
+
return this.recordAndReturn(result);
|
|
727
|
+
}
|
|
728
|
+
};
|
|
729
|
+
|
|
741
730
|
// src/assertions/field-selector.ts
|
|
742
731
|
var FieldSelector = class {
|
|
743
732
|
aligned;
|
|
@@ -774,89 +763,93 @@ var FieldSelector = class {
|
|
|
774
763
|
}
|
|
775
764
|
}
|
|
776
765
|
/**
|
|
777
|
-
*
|
|
766
|
+
* Validates that ground truth exists and both arrays contain numeric values.
|
|
767
|
+
* Returns the filtered numeric arrays for regression metrics.
|
|
778
768
|
*/
|
|
779
|
-
|
|
769
|
+
validateRegressionInputs() {
|
|
780
770
|
this.validateGroundTruth();
|
|
781
|
-
const
|
|
782
|
-
const
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
};
|
|
791
|
-
this.assertions.push(result);
|
|
792
|
-
recordAssertion(result);
|
|
793
|
-
if (!passed) {
|
|
794
|
-
throw new AssertionError(result.message, threshold, metrics.accuracy, this.fieldName);
|
|
771
|
+
const numericActual = filterNumericValues(this.actualValues);
|
|
772
|
+
const numericExpected = filterNumericValues(this.expectedValues);
|
|
773
|
+
if (numericActual.length === 0) {
|
|
774
|
+
throw new AssertionError(
|
|
775
|
+
`Regression metric requires numeric values, but field "${this.fieldName}" has no numeric actual values.`,
|
|
776
|
+
void 0,
|
|
777
|
+
void 0,
|
|
778
|
+
this.fieldName
|
|
779
|
+
);
|
|
795
780
|
}
|
|
796
|
-
|
|
781
|
+
if (numericExpected.length === 0) {
|
|
782
|
+
throw new AssertionError(
|
|
783
|
+
`Regression metric requires numeric values, but field "${this.fieldName}" has no numeric expected values.`,
|
|
784
|
+
void 0,
|
|
785
|
+
void 0,
|
|
786
|
+
this.fieldName
|
|
787
|
+
);
|
|
788
|
+
}
|
|
789
|
+
if (numericActual.length !== numericExpected.length) {
|
|
790
|
+
throw new AssertionError(
|
|
791
|
+
`Regression metric requires equal-length arrays, but got ${numericActual.length} actual and ${numericExpected.length} expected values.`,
|
|
792
|
+
numericExpected.length,
|
|
793
|
+
numericActual.length,
|
|
794
|
+
this.fieldName
|
|
795
|
+
);
|
|
796
|
+
}
|
|
797
|
+
return { actual: numericActual, expected: numericExpected };
|
|
797
798
|
}
|
|
799
|
+
// ============================================================================
|
|
800
|
+
// Classification Metric Getters
|
|
801
|
+
// ============================================================================
|
|
798
802
|
/**
|
|
799
|
-
*
|
|
800
|
-
* @
|
|
801
|
-
*
|
|
803
|
+
* Access accuracy metric for assertions
|
|
804
|
+
* @example
|
|
805
|
+
* expectStats(predictions, groundTruth)
|
|
806
|
+
* .field("sentiment")
|
|
807
|
+
* .accuracy.toBeAtLeast(0.8)
|
|
802
808
|
*/
|
|
803
|
-
|
|
809
|
+
get accuracy() {
|
|
804
810
|
this.validateGroundTruth();
|
|
805
811
|
const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
}
|
|
813
|
-
targetClass = classOrThreshold;
|
|
814
|
-
actualThreshold = threshold;
|
|
815
|
-
const classMetrics = metrics.perClass[targetClass];
|
|
816
|
-
if (!classMetrics) {
|
|
817
|
-
throw new AssertionError(
|
|
818
|
-
`Class "${targetClass}" not found in predictions`,
|
|
819
|
-
targetClass,
|
|
820
|
-
Object.keys(metrics.perClass),
|
|
821
|
-
this.fieldName
|
|
822
|
-
);
|
|
823
|
-
}
|
|
824
|
-
actualPrecision = classMetrics.precision;
|
|
825
|
-
}
|
|
826
|
-
const passed = actualPrecision >= actualThreshold;
|
|
827
|
-
const result = {
|
|
828
|
-
type: "precision",
|
|
829
|
-
passed,
|
|
830
|
-
message: passed ? `Precision${targetClass ? ` for "${targetClass}"` : ""} ${(actualPrecision * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Precision${targetClass ? ` for "${targetClass}"` : ""} ${(actualPrecision * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
|
|
831
|
-
expected: actualThreshold,
|
|
832
|
-
actual: actualPrecision,
|
|
833
|
-
field: this.fieldName,
|
|
834
|
-
class: targetClass
|
|
835
|
-
};
|
|
836
|
-
this.assertions.push(result);
|
|
837
|
-
recordAssertion(result);
|
|
838
|
-
if (!passed) {
|
|
839
|
-
throw new AssertionError(result.message, actualThreshold, actualPrecision, this.fieldName);
|
|
840
|
-
}
|
|
841
|
-
return this;
|
|
812
|
+
return new MetricMatcher({
|
|
813
|
+
parent: this,
|
|
814
|
+
metricName: "Accuracy",
|
|
815
|
+
metricValue: metrics.accuracy,
|
|
816
|
+
fieldName: this.fieldName,
|
|
817
|
+
assertions: this.assertions
|
|
818
|
+
});
|
|
842
819
|
}
|
|
843
820
|
/**
|
|
844
|
-
*
|
|
845
|
-
* @
|
|
846
|
-
*
|
|
821
|
+
* Access F1 score metric for assertions (macro average)
|
|
822
|
+
* @example
|
|
823
|
+
* expectStats(predictions, groundTruth)
|
|
824
|
+
* .field("sentiment")
|
|
825
|
+
* .f1.toBeAtLeast(0.75)
|
|
847
826
|
*/
|
|
848
|
-
|
|
827
|
+
get f1() {
|
|
849
828
|
this.validateGroundTruth();
|
|
850
829
|
const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
830
|
+
return new MetricMatcher({
|
|
831
|
+
parent: this,
|
|
832
|
+
metricName: "F1",
|
|
833
|
+
metricValue: metrics.macroAvg.f1,
|
|
834
|
+
fieldName: this.fieldName,
|
|
835
|
+
assertions: this.assertions
|
|
836
|
+
});
|
|
837
|
+
}
|
|
838
|
+
/**
|
|
839
|
+
* Access precision metric for assertions
|
|
840
|
+
* @param targetClass - Optional class name. If omitted, uses macro average
|
|
841
|
+
* @example
|
|
842
|
+
* expectStats(predictions, groundTruth)
|
|
843
|
+
* .field("sentiment")
|
|
844
|
+
* .precision("positive").toBeAtLeast(0.7)
|
|
845
|
+
*/
|
|
846
|
+
precision(targetClass) {
|
|
847
|
+
this.validateGroundTruth();
|
|
848
|
+
const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
849
|
+
let metricValue;
|
|
850
|
+
if (targetClass === void 0) {
|
|
851
|
+
metricValue = metrics.macroAvg.precision;
|
|
857
852
|
} else {
|
|
858
|
-
targetClass = classOrThreshold;
|
|
859
|
-
actualThreshold = threshold;
|
|
860
853
|
const classMetrics = metrics.perClass[targetClass];
|
|
861
854
|
if (!classMetrics) {
|
|
862
855
|
throw new AssertionError(
|
|
@@ -866,42 +859,32 @@ var FieldSelector = class {
|
|
|
866
859
|
this.fieldName
|
|
867
860
|
);
|
|
868
861
|
}
|
|
869
|
-
|
|
870
|
-
}
|
|
871
|
-
const passed = actualRecall >= actualThreshold;
|
|
872
|
-
const result = {
|
|
873
|
-
type: "recall",
|
|
874
|
-
passed,
|
|
875
|
-
message: passed ? `Recall${targetClass ? ` for "${targetClass}"` : ""} ${(actualRecall * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Recall${targetClass ? ` for "${targetClass}"` : ""} ${(actualRecall * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
|
|
876
|
-
expected: actualThreshold,
|
|
877
|
-
actual: actualRecall,
|
|
878
|
-
field: this.fieldName,
|
|
879
|
-
class: targetClass
|
|
880
|
-
};
|
|
881
|
-
this.assertions.push(result);
|
|
882
|
-
recordAssertion(result);
|
|
883
|
-
if (!passed) {
|
|
884
|
-
throw new AssertionError(result.message, actualThreshold, actualRecall, this.fieldName);
|
|
862
|
+
metricValue = classMetrics.precision;
|
|
885
863
|
}
|
|
886
|
-
return
|
|
864
|
+
return new MetricMatcher({
|
|
865
|
+
parent: this,
|
|
866
|
+
metricName: "Precision",
|
|
867
|
+
metricValue,
|
|
868
|
+
fieldName: this.fieldName,
|
|
869
|
+
targetClass,
|
|
870
|
+
assertions: this.assertions
|
|
871
|
+
});
|
|
887
872
|
}
|
|
888
873
|
/**
|
|
889
|
-
*
|
|
890
|
-
* @param
|
|
891
|
-
* @
|
|
874
|
+
* Access recall metric for assertions
|
|
875
|
+
* @param targetClass - Optional class name. If omitted, uses macro average
|
|
876
|
+
* @example
|
|
877
|
+
* expectStats(predictions, groundTruth)
|
|
878
|
+
* .field("sentiment")
|
|
879
|
+
* .recall("positive").toBeAtLeast(0.7)
|
|
892
880
|
*/
|
|
893
|
-
|
|
881
|
+
recall(targetClass) {
|
|
894
882
|
this.validateGroundTruth();
|
|
895
883
|
const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
896
|
-
let
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
if (typeof classOrThreshold === "number") {
|
|
900
|
-
actualF1 = metrics.macroAvg.f1;
|
|
901
|
-
actualThreshold = classOrThreshold;
|
|
884
|
+
let metricValue;
|
|
885
|
+
if (targetClass === void 0) {
|
|
886
|
+
metricValue = metrics.macroAvg.recall;
|
|
902
887
|
} else {
|
|
903
|
-
targetClass = classOrThreshold;
|
|
904
|
-
actualThreshold = threshold;
|
|
905
888
|
const classMetrics = metrics.perClass[targetClass];
|
|
906
889
|
if (!classMetrics) {
|
|
907
890
|
throw new AssertionError(
|
|
@@ -911,272 +894,171 @@ var FieldSelector = class {
|
|
|
911
894
|
this.fieldName
|
|
912
895
|
);
|
|
913
896
|
}
|
|
914
|
-
|
|
915
|
-
}
|
|
916
|
-
const passed = actualF1 >= actualThreshold;
|
|
917
|
-
const result = {
|
|
918
|
-
type: "f1",
|
|
919
|
-
passed,
|
|
920
|
-
message: passed ? `F1${targetClass ? ` for "${targetClass}"` : ""} ${(actualF1 * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `F1${targetClass ? ` for "${targetClass}"` : ""} ${(actualF1 * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
|
|
921
|
-
expected: actualThreshold,
|
|
922
|
-
actual: actualF1,
|
|
923
|
-
field: this.fieldName,
|
|
924
|
-
class: targetClass
|
|
925
|
-
};
|
|
926
|
-
this.assertions.push(result);
|
|
927
|
-
recordAssertion(result);
|
|
928
|
-
if (!passed) {
|
|
929
|
-
throw new AssertionError(result.message, actualThreshold, actualF1, this.fieldName);
|
|
897
|
+
metricValue = classMetrics.recall;
|
|
930
898
|
}
|
|
931
|
-
return
|
|
899
|
+
return new MetricMatcher({
|
|
900
|
+
parent: this,
|
|
901
|
+
metricName: "Recall",
|
|
902
|
+
metricValue,
|
|
903
|
+
fieldName: this.fieldName,
|
|
904
|
+
targetClass,
|
|
905
|
+
assertions: this.assertions
|
|
906
|
+
});
|
|
932
907
|
}
|
|
908
|
+
// ============================================================================
|
|
909
|
+
// Regression Metric Getters
|
|
910
|
+
// ============================================================================
|
|
933
911
|
/**
|
|
934
|
-
*
|
|
912
|
+
* Access Mean Absolute Error metric for assertions
|
|
913
|
+
* @example
|
|
914
|
+
* expectStats(predictions, groundTruth)
|
|
915
|
+
* .field("score")
|
|
916
|
+
* .mae.toBeAtMost(0.1)
|
|
935
917
|
*/
|
|
936
|
-
|
|
937
|
-
const
|
|
938
|
-
const
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
message: `Confusion matrix recorded for field "${this.fieldName}"`,
|
|
948
|
-
field: this.fieldName
|
|
949
|
-
};
|
|
950
|
-
this.assertions.push(result);
|
|
951
|
-
recordAssertion(result);
|
|
952
|
-
return this;
|
|
918
|
+
get mae() {
|
|
919
|
+
const { actual, expected } = this.validateRegressionInputs();
|
|
920
|
+
const metrics = computeRegressionMetrics(actual, expected);
|
|
921
|
+
return new MetricMatcher({
|
|
922
|
+
parent: this,
|
|
923
|
+
metricName: "MAE",
|
|
924
|
+
metricValue: metrics.mae,
|
|
925
|
+
fieldName: this.fieldName,
|
|
926
|
+
assertions: this.assertions,
|
|
927
|
+
formatValue: (v) => v.toFixed(4)
|
|
928
|
+
});
|
|
953
929
|
}
|
|
954
930
|
/**
|
|
955
|
-
*
|
|
956
|
-
*
|
|
957
|
-
*
|
|
931
|
+
* Access Root Mean Squared Error metric for assertions
|
|
932
|
+
* @example
|
|
933
|
+
* expectStats(predictions, groundTruth)
|
|
934
|
+
* .field("score")
|
|
935
|
+
* .rmse.toBeAtMost(0.15)
|
|
936
|
+
*/
|
|
937
|
+
get rmse() {
|
|
938
|
+
const { actual, expected } = this.validateRegressionInputs();
|
|
939
|
+
const metrics = computeRegressionMetrics(actual, expected);
|
|
940
|
+
return new MetricMatcher({
|
|
941
|
+
parent: this,
|
|
942
|
+
metricName: "RMSE",
|
|
943
|
+
metricValue: metrics.rmse,
|
|
944
|
+
fieldName: this.fieldName,
|
|
945
|
+
assertions: this.assertions,
|
|
946
|
+
formatValue: (v) => v.toFixed(4)
|
|
947
|
+
});
|
|
948
|
+
}
|
|
949
|
+
/**
|
|
950
|
+
* Access R-squared (coefficient of determination) metric for assertions
|
|
951
|
+
* @example
|
|
952
|
+
* expectStats(predictions, groundTruth)
|
|
953
|
+
* .field("score")
|
|
954
|
+
* .r2.toBeAtLeast(0.8)
|
|
955
|
+
*/
|
|
956
|
+
get r2() {
|
|
957
|
+
const { actual, expected } = this.validateRegressionInputs();
|
|
958
|
+
const metrics = computeRegressionMetrics(actual, expected);
|
|
959
|
+
return new MetricMatcher({
|
|
960
|
+
parent: this,
|
|
961
|
+
metricName: "R\xB2",
|
|
962
|
+
metricValue: metrics.r2,
|
|
963
|
+
fieldName: this.fieldName,
|
|
964
|
+
assertions: this.assertions,
|
|
965
|
+
formatValue: (v) => v.toFixed(4)
|
|
966
|
+
});
|
|
967
|
+
}
|
|
968
|
+
// ============================================================================
|
|
969
|
+
// Distribution Assertions
|
|
970
|
+
// ============================================================================
|
|
971
|
+
/**
|
|
972
|
+
* Assert on the percentage of values below or equal to a threshold
|
|
958
973
|
* @param valueThreshold - The value threshold to compare against
|
|
959
|
-
* @param percentageThreshold - The minimum percentage (0-1) of values that should be <= valueThreshold
|
|
960
|
-
* @returns this for method chaining
|
|
961
|
-
*
|
|
962
974
|
* @example
|
|
963
|
-
* // Assert that 90% of confidence scores are below 0.5
|
|
964
975
|
* expectStats(predictions)
|
|
965
976
|
* .field("confidence")
|
|
966
|
-
* .
|
|
977
|
+
* .percentageBelow(0.5).toBeAtLeast(0.9)
|
|
967
978
|
*/
|
|
968
|
-
|
|
979
|
+
percentageBelow(valueThreshold) {
|
|
969
980
|
const numericActual = filterNumericValues(this.actualValues);
|
|
970
981
|
if (numericActual.length === 0) {
|
|
971
982
|
throw new AssertionError(
|
|
972
983
|
`Field '${this.fieldName}' contains no numeric values (found 0 numeric out of ${this.actualValues.length} total values)`,
|
|
973
|
-
|
|
984
|
+
void 0,
|
|
974
985
|
void 0,
|
|
975
986
|
this.fieldName
|
|
976
987
|
);
|
|
977
988
|
}
|
|
978
989
|
const actualPercentage = calculatePercentageBelow(numericActual, valueThreshold);
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
};
|
|
988
|
-
this.assertions.push(result);
|
|
989
|
-
recordAssertion(result);
|
|
990
|
-
if (!passed) {
|
|
991
|
-
throw new AssertionError(
|
|
992
|
-
result.message,
|
|
993
|
-
percentageThreshold,
|
|
994
|
-
actualPercentage,
|
|
995
|
-
this.fieldName
|
|
996
|
-
);
|
|
997
|
-
}
|
|
998
|
-
return this;
|
|
990
|
+
return new PercentageMatcher({
|
|
991
|
+
parent: this,
|
|
992
|
+
fieldName: this.fieldName,
|
|
993
|
+
valueThreshold,
|
|
994
|
+
direction: "below",
|
|
995
|
+
actualPercentage,
|
|
996
|
+
assertions: this.assertions
|
|
997
|
+
});
|
|
999
998
|
}
|
|
1000
999
|
/**
|
|
1001
|
-
*
|
|
1002
|
-
* This is a distributional assertion that only looks at actual values (no ground truth required).
|
|
1003
|
-
*
|
|
1000
|
+
* Assert on the percentage of values above a threshold
|
|
1004
1001
|
* @param valueThreshold - The value threshold to compare against
|
|
1005
|
-
* @param percentageThreshold - The minimum percentage (0-1) of values that should be > valueThreshold
|
|
1006
|
-
* @returns this for method chaining
|
|
1007
|
-
*
|
|
1008
1002
|
* @example
|
|
1009
|
-
* // Assert that 80% of quality scores are above 0.7
|
|
1010
1003
|
* expectStats(predictions)
|
|
1011
1004
|
* .field("quality")
|
|
1012
|
-
* .
|
|
1005
|
+
* .percentageAbove(0.7).toBeAtLeast(0.8)
|
|
1013
1006
|
*/
|
|
1014
|
-
|
|
1007
|
+
percentageAbove(valueThreshold) {
|
|
1015
1008
|
const numericActual = filterNumericValues(this.actualValues);
|
|
1016
1009
|
if (numericActual.length === 0) {
|
|
1017
1010
|
throw new AssertionError(
|
|
1018
1011
|
`Field '${this.fieldName}' contains no numeric values (found 0 numeric out of ${this.actualValues.length} total values)`,
|
|
1019
|
-
|
|
1012
|
+
void 0,
|
|
1020
1013
|
void 0,
|
|
1021
1014
|
this.fieldName
|
|
1022
1015
|
);
|
|
1023
1016
|
}
|
|
1024
1017
|
const actualPercentage = calculatePercentageAbove(numericActual, valueThreshold);
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
};
|
|
1034
|
-
this.assertions.push(result);
|
|
1035
|
-
recordAssertion(result);
|
|
1036
|
-
if (!passed) {
|
|
1037
|
-
throw new AssertionError(
|
|
1038
|
-
result.message,
|
|
1039
|
-
percentageThreshold,
|
|
1040
|
-
actualPercentage,
|
|
1041
|
-
this.fieldName
|
|
1042
|
-
);
|
|
1043
|
-
}
|
|
1044
|
-
return this;
|
|
1018
|
+
return new PercentageMatcher({
|
|
1019
|
+
parent: this,
|
|
1020
|
+
fieldName: this.fieldName,
|
|
1021
|
+
valueThreshold,
|
|
1022
|
+
direction: "above",
|
|
1023
|
+
actualPercentage,
|
|
1024
|
+
assertions: this.assertions
|
|
1025
|
+
});
|
|
1045
1026
|
}
|
|
1046
1027
|
// ============================================================================
|
|
1047
|
-
//
|
|
1028
|
+
// Display Methods
|
|
1048
1029
|
// ============================================================================
|
|
1049
1030
|
/**
|
|
1050
|
-
*
|
|
1051
|
-
*
|
|
1052
|
-
*/
|
|
1053
|
-
validateRegressionInputs() {
|
|
1054
|
-
this.validateGroundTruth();
|
|
1055
|
-
const numericActual = filterNumericValues(this.actualValues);
|
|
1056
|
-
const numericExpected = filterNumericValues(this.expectedValues);
|
|
1057
|
-
if (numericActual.length === 0) {
|
|
1058
|
-
throw new AssertionError(
|
|
1059
|
-
`Regression metric requires numeric values, but field "${this.fieldName}" has no numeric actual values.`,
|
|
1060
|
-
void 0,
|
|
1061
|
-
void 0,
|
|
1062
|
-
this.fieldName
|
|
1063
|
-
);
|
|
1064
|
-
}
|
|
1065
|
-
if (numericExpected.length === 0) {
|
|
1066
|
-
throw new AssertionError(
|
|
1067
|
-
`Regression metric requires numeric values, but field "${this.fieldName}" has no numeric expected values.`,
|
|
1068
|
-
void 0,
|
|
1069
|
-
void 0,
|
|
1070
|
-
this.fieldName
|
|
1071
|
-
);
|
|
1072
|
-
}
|
|
1073
|
-
if (numericActual.length !== numericExpected.length) {
|
|
1074
|
-
throw new AssertionError(
|
|
1075
|
-
`Regression metric requires equal-length arrays, but got ${numericActual.length} actual and ${numericExpected.length} expected values.`,
|
|
1076
|
-
numericExpected.length,
|
|
1077
|
-
numericActual.length,
|
|
1078
|
-
this.fieldName
|
|
1079
|
-
);
|
|
1080
|
-
}
|
|
1081
|
-
return { actual: numericActual, expected: numericExpected };
|
|
1082
|
-
}
|
|
1083
|
-
/**
|
|
1084
|
-
* Asserts that Mean Absolute Error is below a threshold.
|
|
1085
|
-
* Requires numeric values in both actual and expected.
|
|
1086
|
-
*
|
|
1087
|
-
* @param threshold - Maximum allowed MAE
|
|
1088
|
-
* @returns this for method chaining
|
|
1089
|
-
*
|
|
1090
|
-
* @example
|
|
1091
|
-
* expectStats(predictions, groundTruth)
|
|
1092
|
-
* .field("score")
|
|
1093
|
-
* .toHaveMAEBelow(0.1)
|
|
1094
|
-
*/
|
|
1095
|
-
toHaveMAEBelow(threshold) {
|
|
1096
|
-
const { actual, expected } = this.validateRegressionInputs();
|
|
1097
|
-
const metrics = computeRegressionMetrics(actual, expected);
|
|
1098
|
-
const passed = metrics.mae <= threshold;
|
|
1099
|
-
const result = {
|
|
1100
|
-
type: "mae",
|
|
1101
|
-
passed,
|
|
1102
|
-
message: passed ? `MAE ${metrics.mae.toFixed(4)} is below ${threshold}` : `MAE ${metrics.mae.toFixed(4)} exceeds threshold ${threshold}`,
|
|
1103
|
-
expected: threshold,
|
|
1104
|
-
actual: metrics.mae,
|
|
1105
|
-
field: this.fieldName
|
|
1106
|
-
};
|
|
1107
|
-
this.assertions.push(result);
|
|
1108
|
-
recordAssertion(result);
|
|
1109
|
-
if (!passed) {
|
|
1110
|
-
throw new AssertionError(result.message, threshold, metrics.mae, this.fieldName);
|
|
1111
|
-
}
|
|
1112
|
-
return this;
|
|
1113
|
-
}
|
|
1114
|
-
/**
|
|
1115
|
-
* Asserts that Root Mean Squared Error is below a threshold.
|
|
1116
|
-
* Requires numeric values in both actual and expected.
|
|
1117
|
-
*
|
|
1118
|
-
* @param threshold - Maximum allowed RMSE
|
|
1119
|
-
* @returns this for method chaining
|
|
1120
|
-
*
|
|
1031
|
+
* Displays the confusion matrix in the report
|
|
1032
|
+
* This is not an assertion - it always passes and just records the matrix for display
|
|
1121
1033
|
* @example
|
|
1122
1034
|
* expectStats(predictions, groundTruth)
|
|
1123
|
-
* .field("
|
|
1124
|
-
* .
|
|
1035
|
+
* .field("sentiment")
|
|
1036
|
+
* .accuracy.toBeAtLeast(0.8)
|
|
1037
|
+
* .displayConfusionMatrix()
|
|
1125
1038
|
*/
|
|
1126
|
-
|
|
1127
|
-
const
|
|
1128
|
-
const
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
passed,
|
|
1133
|
-
message: passed ? `RMSE ${metrics.rmse.toFixed(4)} is below ${threshold}` : `RMSE ${metrics.rmse.toFixed(4)} exceeds threshold ${threshold}`,
|
|
1134
|
-
expected: threshold,
|
|
1135
|
-
actual: metrics.rmse,
|
|
1136
|
-
field: this.fieldName
|
|
1039
|
+
displayConfusionMatrix() {
|
|
1040
|
+
const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
1041
|
+
const fieldResult = {
|
|
1042
|
+
field: this.fieldName,
|
|
1043
|
+
metrics,
|
|
1044
|
+
binarized: false
|
|
1137
1045
|
};
|
|
1138
|
-
|
|
1139
|
-
recordAssertion(result);
|
|
1140
|
-
if (!passed) {
|
|
1141
|
-
throw new AssertionError(result.message, threshold, metrics.rmse, this.fieldName);
|
|
1142
|
-
}
|
|
1143
|
-
return this;
|
|
1144
|
-
}
|
|
1145
|
-
/**
|
|
1146
|
-
* Asserts that R-squared (coefficient of determination) is above a threshold.
|
|
1147
|
-
* R² measures how well the predictions explain the variance in expected values.
|
|
1148
|
-
* R² = 1.0 means perfect prediction, R² = 0 means prediction is no better than mean.
|
|
1149
|
-
* Requires numeric values in both actual and expected.
|
|
1150
|
-
*
|
|
1151
|
-
* @param threshold - Minimum required R² value (0-1)
|
|
1152
|
-
* @returns this for method chaining
|
|
1153
|
-
*
|
|
1154
|
-
* @example
|
|
1155
|
-
* expectStats(predictions, groundTruth)
|
|
1156
|
-
* .field("score")
|
|
1157
|
-
* .toHaveR2Above(0.8)
|
|
1158
|
-
*/
|
|
1159
|
-
toHaveR2Above(threshold) {
|
|
1160
|
-
const { actual, expected } = this.validateRegressionInputs();
|
|
1161
|
-
const metrics = computeRegressionMetrics(actual, expected);
|
|
1162
|
-
const passed = metrics.r2 >= threshold;
|
|
1046
|
+
recordFieldMetrics(fieldResult);
|
|
1163
1047
|
const result = {
|
|
1164
|
-
type: "
|
|
1165
|
-
passed,
|
|
1166
|
-
message:
|
|
1167
|
-
expected: threshold,
|
|
1168
|
-
actual: metrics.r2,
|
|
1048
|
+
type: "confusionMatrix",
|
|
1049
|
+
passed: true,
|
|
1050
|
+
message: `Confusion matrix recorded for field "${this.fieldName}"`,
|
|
1169
1051
|
field: this.fieldName
|
|
1170
1052
|
};
|
|
1171
1053
|
this.assertions.push(result);
|
|
1172
1054
|
recordAssertion(result);
|
|
1173
|
-
if (!passed) {
|
|
1174
|
-
throw new AssertionError(result.message, threshold, metrics.r2, this.fieldName);
|
|
1175
|
-
}
|
|
1176
1055
|
return this;
|
|
1177
1056
|
}
|
|
1057
|
+
// ============================================================================
|
|
1058
|
+
// Utility Methods
|
|
1059
|
+
// ============================================================================
|
|
1178
1060
|
/**
|
|
1179
|
-
* Gets the computed metrics for this field
|
|
1061
|
+
* Gets the computed classification metrics for this field
|
|
1180
1062
|
*/
|
|
1181
1063
|
getMetrics() {
|
|
1182
1064
|
return computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
@@ -1209,7 +1091,7 @@ function normalizeInput(input) {
|
|
|
1209
1091
|
}));
|
|
1210
1092
|
}
|
|
1211
1093
|
throw new Error(
|
|
1212
|
-
"Invalid input to expectStats(): expected
|
|
1094
|
+
"Invalid input to expectStats(): expected { aligned: AlignedRecord[] }, Prediction[], or AlignedRecord[]"
|
|
1213
1095
|
);
|
|
1214
1096
|
}
|
|
1215
1097
|
function expectStats(inputOrActual, expected, options) {
|
|
@@ -1254,6 +1136,6 @@ var ExpectStats = class {
|
|
|
1254
1136
|
}
|
|
1255
1137
|
};
|
|
1256
1138
|
|
|
1257
|
-
export { afterAll, afterEach, alignByKey, beforeAll, beforeEach, checkIntegrity, computeAccuracy, computeClassificationMetrics, computeF1, computePrecision, computeRecall,
|
|
1139
|
+
export { afterAll, afterEach, alignByKey, beforeAll, beforeEach, checkIntegrity, computeAccuracy, computeClassificationMetrics, computeF1, computePrecision, computeRecall, describe, evalTest, expectStats, extractFieldValues, filterComplete, it, test, validatePredictions };
|
|
1258
1140
|
//# sourceMappingURL=index.js.map
|
|
1259
1141
|
//# sourceMappingURL=index.js.map
|