evalsense 0.3.2 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +235 -98
- package/dist/{chunk-BFGA2NUB.cjs → chunk-4BKZPVY4.cjs} +13 -6
- package/dist/chunk-4BKZPVY4.cjs.map +1 -0
- package/dist/{chunk-IYLSY7NX.js → chunk-IUVDDMJ3.js} +13 -6
- package/dist/chunk-IUVDDMJ3.js.map +1 -0
- package/dist/chunk-NCCQRZ2Y.cjs +1141 -0
- package/dist/chunk-NCCQRZ2Y.cjs.map +1 -0
- package/dist/chunk-TDGWDK2L.js +1108 -0
- package/dist/chunk-TDGWDK2L.js.map +1 -0
- package/dist/cli.cjs +11 -11
- package/dist/cli.js +1 -1
- package/dist/index-CATqAHNK.d.cts +416 -0
- package/dist/index-CoMpaW-K.d.ts +416 -0
- package/dist/index.cjs +507 -580
- package/dist/index.cjs.map +1 -1
- package/dist/index.d.cts +210 -161
- package/dist/index.d.ts +210 -161
- package/dist/index.js +455 -524
- package/dist/index.js.map +1 -1
- package/dist/metrics/index.cjs +103 -342
- package/dist/metrics/index.cjs.map +1 -1
- package/dist/metrics/index.d.cts +260 -31
- package/dist/metrics/index.d.ts +260 -31
- package/dist/metrics/index.js +24 -312
- package/dist/metrics/index.js.map +1 -1
- package/dist/metrics/opinionated/index.cjs +5 -5
- package/dist/metrics/opinionated/index.d.cts +2 -163
- package/dist/metrics/opinionated/index.d.ts +2 -163
- package/dist/metrics/opinionated/index.js +1 -1
- package/dist/{types-C71p0wzM.d.cts → types-D0hzfyKm.d.cts} +1 -13
- package/dist/{types-C71p0wzM.d.ts → types-D0hzfyKm.d.ts} +1 -13
- package/package.json +1 -1
- package/dist/chunk-BFGA2NUB.cjs.map +0 -1
- package/dist/chunk-IYLSY7NX.js.map +0 -1
- package/dist/chunk-RZFLCWTW.cjs +0 -942
- package/dist/chunk-RZFLCWTW.cjs.map +0 -1
- package/dist/chunk-Z3U6AUWX.js +0 -925
- package/dist/chunk-Z3U6AUWX.js.map +0 -1
package/dist/index.js
CHANGED
|
@@ -1,8 +1,6 @@
|
|
|
1
|
-
import { getCurrentSuite, setCurrentSuite, addSuite, addTestToCurrentSuite,
|
|
2
|
-
export { AssertionError, ConfigurationError, ConsoleReporter, DatasetError, EvalSenseError, ExitCodes, IntegrityError, JsonReporter, TestExecutionError, buildConfusionMatrix, discoverEvalFiles, executeEvalFiles, formatConfusionMatrix, getExitCode, parseReport } from './chunk-
|
|
1
|
+
import { getCurrentSuite, setCurrentSuite, addSuite, addTestToCurrentSuite, IntegrityError, buildConfusionMatrix, getTruePositives, getFalsePositives, getFalseNegatives, getSupport, AssertionError, recordFieldMetrics, recordAssertion } from './chunk-IUVDDMJ3.js';
|
|
2
|
+
export { AssertionError, ConfigurationError, ConsoleReporter, DatasetError, EvalSenseError, ExitCodes, IntegrityError, JsonReporter, TestExecutionError, buildConfusionMatrix, discoverEvalFiles, executeEvalFiles, formatConfusionMatrix, getExitCode, parseReport } from './chunk-IUVDDMJ3.js';
|
|
3
3
|
import './chunk-DGUM43GV.js';
|
|
4
|
-
import { readFileSync } from 'fs';
|
|
5
|
-
import { resolve, extname } from 'path';
|
|
6
4
|
|
|
7
5
|
// src/core/describe.ts
|
|
8
6
|
function describe(name, fn) {
|
|
@@ -91,136 +89,6 @@ function evalTestOnly(name, fn) {
|
|
|
91
89
|
}
|
|
92
90
|
evalTest.skip = evalTestSkip;
|
|
93
91
|
evalTest.only = evalTestOnly;
|
|
94
|
-
function loadDataset(path) {
|
|
95
|
-
const absolutePath = resolve(process.cwd(), path);
|
|
96
|
-
const ext = extname(absolutePath).toLowerCase();
|
|
97
|
-
let records;
|
|
98
|
-
try {
|
|
99
|
-
const content = readFileSync(absolutePath, "utf-8");
|
|
100
|
-
if (ext === ".ndjson" || ext === ".jsonl") {
|
|
101
|
-
records = parseNDJSON(content);
|
|
102
|
-
} else if (ext === ".json") {
|
|
103
|
-
records = parseJSON(content);
|
|
104
|
-
} else {
|
|
105
|
-
throw new DatasetError(
|
|
106
|
-
`Unsupported file format: ${ext}. Use .json, .ndjson, or .jsonl`,
|
|
107
|
-
path
|
|
108
|
-
);
|
|
109
|
-
}
|
|
110
|
-
} catch (error) {
|
|
111
|
-
if (error instanceof DatasetError) {
|
|
112
|
-
throw error;
|
|
113
|
-
}
|
|
114
|
-
const message = error instanceof Error ? error.message : String(error);
|
|
115
|
-
throw new DatasetError(`Failed to load dataset from ${path}: ${message}`, path);
|
|
116
|
-
}
|
|
117
|
-
return {
|
|
118
|
-
records,
|
|
119
|
-
metadata: {
|
|
120
|
-
source: path,
|
|
121
|
-
count: records.length,
|
|
122
|
-
loadedAt: /* @__PURE__ */ new Date()
|
|
123
|
-
}
|
|
124
|
-
};
|
|
125
|
-
}
|
|
126
|
-
function parseJSON(content) {
|
|
127
|
-
const parsed = JSON.parse(content);
|
|
128
|
-
if (!Array.isArray(parsed)) {
|
|
129
|
-
throw new DatasetError("JSON dataset must be an array of records");
|
|
130
|
-
}
|
|
131
|
-
return parsed;
|
|
132
|
-
}
|
|
133
|
-
function parseNDJSON(content) {
|
|
134
|
-
const lines = content.split("\n").filter((line) => line.trim() !== "");
|
|
135
|
-
const records = [];
|
|
136
|
-
for (let i = 0; i < lines.length; i++) {
|
|
137
|
-
const line = lines[i];
|
|
138
|
-
if (line === void 0) continue;
|
|
139
|
-
try {
|
|
140
|
-
records.push(JSON.parse(line));
|
|
141
|
-
} catch {
|
|
142
|
-
throw new DatasetError(`Invalid JSON at line ${i + 1} in NDJSON file`);
|
|
143
|
-
}
|
|
144
|
-
}
|
|
145
|
-
return records;
|
|
146
|
-
}
|
|
147
|
-
function createDataset(records, source = "inline") {
|
|
148
|
-
return {
|
|
149
|
-
records,
|
|
150
|
-
metadata: {
|
|
151
|
-
source,
|
|
152
|
-
count: records.length,
|
|
153
|
-
loadedAt: /* @__PURE__ */ new Date()
|
|
154
|
-
}
|
|
155
|
-
};
|
|
156
|
-
}
|
|
157
|
-
|
|
158
|
-
// src/dataset/run-model.ts
|
|
159
|
-
async function runModel(dataset, modelFn) {
|
|
160
|
-
const startTime = Date.now();
|
|
161
|
-
const predictions = [];
|
|
162
|
-
const aligned = [];
|
|
163
|
-
for (const record of dataset.records) {
|
|
164
|
-
const id = getRecordId(record);
|
|
165
|
-
const prediction = await modelFn(record);
|
|
166
|
-
if (prediction.id !== id) {
|
|
167
|
-
throw new DatasetError(
|
|
168
|
-
`Prediction ID mismatch: expected "${id}", got "${prediction.id}". Model function must return the same ID as the input record.`
|
|
169
|
-
);
|
|
170
|
-
}
|
|
171
|
-
predictions.push(prediction);
|
|
172
|
-
aligned.push({
|
|
173
|
-
id,
|
|
174
|
-
actual: { ...prediction },
|
|
175
|
-
expected: { ...record }
|
|
176
|
-
});
|
|
177
|
-
}
|
|
178
|
-
return {
|
|
179
|
-
predictions,
|
|
180
|
-
aligned,
|
|
181
|
-
duration: Date.now() - startTime
|
|
182
|
-
};
|
|
183
|
-
}
|
|
184
|
-
function getRecordId(record) {
|
|
185
|
-
const id = record.id ?? record._id;
|
|
186
|
-
if (id === void 0 || id === null) {
|
|
187
|
-
throw new DatasetError('Dataset records must have an "id" or "_id" field for alignment');
|
|
188
|
-
}
|
|
189
|
-
return String(id);
|
|
190
|
-
}
|
|
191
|
-
async function runModelParallel(dataset, modelFn, concurrency = 10) {
|
|
192
|
-
const startTime = Date.now();
|
|
193
|
-
const results = [];
|
|
194
|
-
for (let i = 0; i < dataset.records.length; i += concurrency) {
|
|
195
|
-
const batch = dataset.records.slice(i, i + concurrency);
|
|
196
|
-
const batchResults = await Promise.all(
|
|
197
|
-
batch.map(async (record) => {
|
|
198
|
-
const prediction = await modelFn(record);
|
|
199
|
-
return { prediction, record };
|
|
200
|
-
})
|
|
201
|
-
);
|
|
202
|
-
results.push(...batchResults);
|
|
203
|
-
}
|
|
204
|
-
const predictions = [];
|
|
205
|
-
const aligned = [];
|
|
206
|
-
for (const { prediction, record } of results) {
|
|
207
|
-
const id = getRecordId(record);
|
|
208
|
-
if (prediction.id !== id) {
|
|
209
|
-
throw new DatasetError(`Prediction ID mismatch: expected "${id}", got "${prediction.id}".`);
|
|
210
|
-
}
|
|
211
|
-
predictions.push(prediction);
|
|
212
|
-
aligned.push({
|
|
213
|
-
id,
|
|
214
|
-
actual: { ...prediction },
|
|
215
|
-
expected: { ...record }
|
|
216
|
-
});
|
|
217
|
-
}
|
|
218
|
-
return {
|
|
219
|
-
predictions,
|
|
220
|
-
aligned,
|
|
221
|
-
duration: Date.now() - startTime
|
|
222
|
-
};
|
|
223
|
-
}
|
|
224
92
|
|
|
225
93
|
// src/dataset/alignment.ts
|
|
226
94
|
function alignByKey(predictions, expected, options = {}) {
|
|
@@ -293,14 +161,14 @@ function filterComplete(aligned, field) {
|
|
|
293
161
|
}
|
|
294
162
|
|
|
295
163
|
// src/dataset/integrity.ts
|
|
296
|
-
function checkIntegrity(
|
|
164
|
+
function checkIntegrity(records, options = {}) {
|
|
297
165
|
const { requiredFields = [], throwOnFailure = false } = options;
|
|
298
166
|
const seenIds = /* @__PURE__ */ new Map();
|
|
299
167
|
const missingIds = [];
|
|
300
168
|
const duplicateIds = [];
|
|
301
169
|
const missingFields = [];
|
|
302
|
-
for (let i = 0; i <
|
|
303
|
-
const record =
|
|
170
|
+
for (let i = 0; i < records.length; i++) {
|
|
171
|
+
const record = records[i];
|
|
304
172
|
if (!record) continue;
|
|
305
173
|
const id = record.id ?? record._id;
|
|
306
174
|
if (id === void 0 || id === null) {
|
|
@@ -327,7 +195,7 @@ function checkIntegrity(dataset, options = {}) {
|
|
|
327
195
|
const valid = missingIds.length === 0 && duplicateIds.length === 0 && missingFields.length === 0;
|
|
328
196
|
const result = {
|
|
329
197
|
valid,
|
|
330
|
-
totalRecords:
|
|
198
|
+
totalRecords: records.length,
|
|
331
199
|
missingIds,
|
|
332
200
|
duplicateIds,
|
|
333
201
|
missingFields
|
|
@@ -520,6 +388,91 @@ function calculatePercentageAbove(values, threshold) {
|
|
|
520
388
|
return countAbove / values.length;
|
|
521
389
|
}
|
|
522
390
|
|
|
391
|
+
// src/assertions/metric-matcher.ts
|
|
392
|
+
var MetricMatcher = class {
|
|
393
|
+
context;
|
|
394
|
+
constructor(context) {
|
|
395
|
+
this.context = context;
|
|
396
|
+
}
|
|
397
|
+
formatMetricValue(value) {
|
|
398
|
+
if (this.context.formatValue) {
|
|
399
|
+
return this.context.formatValue(value);
|
|
400
|
+
}
|
|
401
|
+
if (value >= 0 && value <= 1) {
|
|
402
|
+
return `${(value * 100).toFixed(1)}%`;
|
|
403
|
+
}
|
|
404
|
+
return value.toFixed(4);
|
|
405
|
+
}
|
|
406
|
+
createAssertion(operator, threshold, passed) {
|
|
407
|
+
const { metricName, metricValue, fieldName, targetClass } = this.context;
|
|
408
|
+
const formattedActual = this.formatMetricValue(metricValue);
|
|
409
|
+
const formattedThreshold = this.formatMetricValue(threshold);
|
|
410
|
+
const classInfo = targetClass ? ` for "${targetClass}"` : "";
|
|
411
|
+
const operatorText = {
|
|
412
|
+
">=": "at least",
|
|
413
|
+
">": "above",
|
|
414
|
+
"<=": "at most",
|
|
415
|
+
"<": "below",
|
|
416
|
+
"===": "equal to"
|
|
417
|
+
}[operator];
|
|
418
|
+
const message = passed ? `${metricName}${classInfo} ${formattedActual} is ${operatorText} ${formattedThreshold}` : `${metricName}${classInfo} ${formattedActual} is not ${operatorText} ${formattedThreshold}`;
|
|
419
|
+
return {
|
|
420
|
+
type: metricName.toLowerCase().replace(/\s+/g, "").replace(/²/g, "2"),
|
|
421
|
+
passed,
|
|
422
|
+
message,
|
|
423
|
+
expected: threshold,
|
|
424
|
+
actual: metricValue,
|
|
425
|
+
field: fieldName,
|
|
426
|
+
class: targetClass
|
|
427
|
+
};
|
|
428
|
+
}
|
|
429
|
+
recordAndReturn(result) {
|
|
430
|
+
this.context.assertions.push(result);
|
|
431
|
+
recordAssertion(result);
|
|
432
|
+
return this.context.parent;
|
|
433
|
+
}
|
|
434
|
+
/**
|
|
435
|
+
* Assert that the metric is greater than or equal to the threshold (>=)
|
|
436
|
+
*/
|
|
437
|
+
toBeAtLeast(threshold) {
|
|
438
|
+
const passed = this.context.metricValue >= threshold;
|
|
439
|
+
const result = this.createAssertion(">=", threshold, passed);
|
|
440
|
+
return this.recordAndReturn(result);
|
|
441
|
+
}
|
|
442
|
+
/**
|
|
443
|
+
* Assert that the metric is strictly greater than the threshold (>)
|
|
444
|
+
*/
|
|
445
|
+
toBeAbove(threshold) {
|
|
446
|
+
const passed = this.context.metricValue > threshold;
|
|
447
|
+
const result = this.createAssertion(">", threshold, passed);
|
|
448
|
+
return this.recordAndReturn(result);
|
|
449
|
+
}
|
|
450
|
+
/**
|
|
451
|
+
* Assert that the metric is less than or equal to the threshold (<=)
|
|
452
|
+
*/
|
|
453
|
+
toBeAtMost(threshold) {
|
|
454
|
+
const passed = this.context.metricValue <= threshold;
|
|
455
|
+
const result = this.createAssertion("<=", threshold, passed);
|
|
456
|
+
return this.recordAndReturn(result);
|
|
457
|
+
}
|
|
458
|
+
/**
|
|
459
|
+
* Assert that the metric is strictly less than the threshold (<)
|
|
460
|
+
*/
|
|
461
|
+
toBeBelow(threshold) {
|
|
462
|
+
const passed = this.context.metricValue < threshold;
|
|
463
|
+
const result = this.createAssertion("<", threshold, passed);
|
|
464
|
+
return this.recordAndReturn(result);
|
|
465
|
+
}
|
|
466
|
+
/**
|
|
467
|
+
* Assert that the metric equals the expected value (with optional tolerance for floats)
|
|
468
|
+
*/
|
|
469
|
+
toEqual(expected, tolerance = 1e-9) {
|
|
470
|
+
const passed = Math.abs(this.context.metricValue - expected) <= tolerance;
|
|
471
|
+
const result = this.createAssertion("===", expected, passed);
|
|
472
|
+
return this.recordAndReturn(result);
|
|
473
|
+
}
|
|
474
|
+
};
|
|
475
|
+
|
|
523
476
|
// src/assertions/binarize.ts
|
|
524
477
|
var BinarizeSelector = class {
|
|
525
478
|
fieldName;
|
|
@@ -551,149 +504,127 @@ var BinarizeSelector = class {
|
|
|
551
504
|
}
|
|
552
505
|
}
|
|
553
506
|
}
|
|
507
|
+
// ============================================================================
|
|
508
|
+
// Classification Metric Getters
|
|
509
|
+
// ============================================================================
|
|
554
510
|
/**
|
|
555
|
-
*
|
|
511
|
+
* Access accuracy metric for assertions
|
|
512
|
+
* @example
|
|
513
|
+
* expectStats(predictions, groundTruth)
|
|
514
|
+
* .field("score")
|
|
515
|
+
* .binarize(0.5)
|
|
516
|
+
* .accuracy.toBeAtLeast(0.8)
|
|
556
517
|
*/
|
|
557
|
-
|
|
518
|
+
get accuracy() {
|
|
558
519
|
const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
field: this.fieldName
|
|
567
|
-
};
|
|
568
|
-
this.assertions.push(result);
|
|
569
|
-
recordAssertion(result);
|
|
570
|
-
return this;
|
|
520
|
+
return new MetricMatcher({
|
|
521
|
+
parent: this,
|
|
522
|
+
metricName: "Accuracy",
|
|
523
|
+
metricValue: metrics.accuracy,
|
|
524
|
+
fieldName: this.fieldName,
|
|
525
|
+
assertions: this.assertions
|
|
526
|
+
});
|
|
571
527
|
}
|
|
572
528
|
/**
|
|
573
|
-
*
|
|
574
|
-
* @
|
|
575
|
-
*
|
|
529
|
+
* Access F1 score metric for assertions (macro average)
|
|
530
|
+
* @example
|
|
531
|
+
* expectStats(predictions, groundTruth)
|
|
532
|
+
* .field("score")
|
|
533
|
+
* .binarize(0.5)
|
|
534
|
+
* .f1.toBeAtLeast(0.75)
|
|
576
535
|
*/
|
|
577
|
-
|
|
536
|
+
get f1() {
|
|
578
537
|
const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
}
|
|
586
|
-
targetClass = String(classOrThreshold);
|
|
587
|
-
actualThreshold = threshold;
|
|
588
|
-
const classMetrics = metrics.perClass[targetClass];
|
|
589
|
-
if (!classMetrics) {
|
|
590
|
-
throw new AssertionError(
|
|
591
|
-
`Class "${targetClass}" not found in binarized predictions`,
|
|
592
|
-
targetClass,
|
|
593
|
-
Object.keys(metrics.perClass),
|
|
594
|
-
this.fieldName
|
|
595
|
-
);
|
|
596
|
-
}
|
|
597
|
-
actualPrecision = classMetrics.precision;
|
|
598
|
-
}
|
|
599
|
-
const passed = actualPrecision >= actualThreshold;
|
|
600
|
-
const result = {
|
|
601
|
-
type: "precision",
|
|
602
|
-
passed,
|
|
603
|
-
message: passed ? `Precision${targetClass ? ` for ${targetClass}` : ""} ${(actualPrecision * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Precision${targetClass ? ` for ${targetClass}` : ""} ${(actualPrecision * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
|
|
604
|
-
expected: actualThreshold,
|
|
605
|
-
actual: actualPrecision,
|
|
606
|
-
field: this.fieldName,
|
|
607
|
-
class: targetClass
|
|
608
|
-
};
|
|
609
|
-
this.assertions.push(result);
|
|
610
|
-
recordAssertion(result);
|
|
611
|
-
return this;
|
|
538
|
+
return new MetricMatcher({
|
|
539
|
+
parent: this,
|
|
540
|
+
metricName: "F1",
|
|
541
|
+
metricValue: metrics.macroAvg.f1,
|
|
542
|
+
fieldName: this.fieldName,
|
|
543
|
+
assertions: this.assertions
|
|
544
|
+
});
|
|
612
545
|
}
|
|
613
546
|
/**
|
|
614
|
-
*
|
|
615
|
-
* @param
|
|
616
|
-
* @
|
|
547
|
+
* Access precision metric for assertions
|
|
548
|
+
* @param targetClass - Optional boolean class (true/false). If omitted, uses macro average
|
|
549
|
+
* @example
|
|
550
|
+
* expectStats(predictions, groundTruth)
|
|
551
|
+
* .field("score")
|
|
552
|
+
* .binarize(0.5)
|
|
553
|
+
* .precision(true).toBeAtLeast(0.7)
|
|
617
554
|
*/
|
|
618
|
-
|
|
555
|
+
precision(targetClass) {
|
|
619
556
|
const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
|
|
620
|
-
let
|
|
621
|
-
let
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
actualRecall = metrics.macroAvg.recall;
|
|
625
|
-
actualThreshold = classOrThreshold;
|
|
557
|
+
let metricValue;
|
|
558
|
+
let classKey;
|
|
559
|
+
if (targetClass === void 0) {
|
|
560
|
+
metricValue = metrics.macroAvg.precision;
|
|
626
561
|
} else {
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
const classMetrics = metrics.perClass[targetClass];
|
|
562
|
+
classKey = String(targetClass);
|
|
563
|
+
const classMetrics = metrics.perClass[classKey];
|
|
630
564
|
if (!classMetrics) {
|
|
631
565
|
throw new AssertionError(
|
|
632
|
-
`Class "${
|
|
633
|
-
|
|
566
|
+
`Class "${classKey}" not found in binarized predictions`,
|
|
567
|
+
classKey,
|
|
634
568
|
Object.keys(metrics.perClass),
|
|
635
569
|
this.fieldName
|
|
636
570
|
);
|
|
637
571
|
}
|
|
638
|
-
|
|
572
|
+
metricValue = classMetrics.precision;
|
|
639
573
|
}
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
class: targetClass
|
|
649
|
-
};
|
|
650
|
-
this.assertions.push(result);
|
|
651
|
-
recordAssertion(result);
|
|
652
|
-
return this;
|
|
574
|
+
return new MetricMatcher({
|
|
575
|
+
parent: this,
|
|
576
|
+
metricName: "Precision",
|
|
577
|
+
metricValue,
|
|
578
|
+
fieldName: this.fieldName,
|
|
579
|
+
targetClass: classKey,
|
|
580
|
+
assertions: this.assertions
|
|
581
|
+
});
|
|
653
582
|
}
|
|
654
583
|
/**
|
|
655
|
-
*
|
|
584
|
+
* Access recall metric for assertions
|
|
585
|
+
* @param targetClass - Optional boolean class (true/false). If omitted, uses macro average
|
|
586
|
+
* @example
|
|
587
|
+
* expectStats(predictions, groundTruth)
|
|
588
|
+
* .field("score")
|
|
589
|
+
* .binarize(0.5)
|
|
590
|
+
* .recall(true).toBeAtLeast(0.7)
|
|
656
591
|
*/
|
|
657
|
-
|
|
592
|
+
recall(targetClass) {
|
|
658
593
|
const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
|
|
659
|
-
let
|
|
660
|
-
let
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
actualF1 = metrics.macroAvg.f1;
|
|
664
|
-
actualThreshold = classOrThreshold;
|
|
594
|
+
let metricValue;
|
|
595
|
+
let classKey;
|
|
596
|
+
if (targetClass === void 0) {
|
|
597
|
+
metricValue = metrics.macroAvg.recall;
|
|
665
598
|
} else {
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
const classMetrics = metrics.perClass[targetClass];
|
|
599
|
+
classKey = String(targetClass);
|
|
600
|
+
const classMetrics = metrics.perClass[classKey];
|
|
669
601
|
if (!classMetrics) {
|
|
670
602
|
throw new AssertionError(
|
|
671
|
-
`Class "${
|
|
672
|
-
|
|
603
|
+
`Class "${classKey}" not found in binarized predictions`,
|
|
604
|
+
classKey,
|
|
673
605
|
Object.keys(metrics.perClass),
|
|
674
606
|
this.fieldName
|
|
675
607
|
);
|
|
676
608
|
}
|
|
677
|
-
|
|
609
|
+
metricValue = classMetrics.recall;
|
|
678
610
|
}
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
class: targetClass
|
|
688
|
-
};
|
|
689
|
-
this.assertions.push(result);
|
|
690
|
-
recordAssertion(result);
|
|
691
|
-
return this;
|
|
611
|
+
return new MetricMatcher({
|
|
612
|
+
parent: this,
|
|
613
|
+
metricName: "Recall",
|
|
614
|
+
metricValue,
|
|
615
|
+
fieldName: this.fieldName,
|
|
616
|
+
targetClass: classKey,
|
|
617
|
+
assertions: this.assertions
|
|
618
|
+
});
|
|
692
619
|
}
|
|
620
|
+
// ============================================================================
|
|
621
|
+
// Display Methods
|
|
622
|
+
// ============================================================================
|
|
693
623
|
/**
|
|
694
|
-
*
|
|
624
|
+
* Displays the confusion matrix in the report
|
|
625
|
+
* This is not an assertion - it always passes and just records the matrix for display
|
|
695
626
|
*/
|
|
696
|
-
|
|
627
|
+
displayConfusionMatrix() {
|
|
697
628
|
const metrics = computeClassificationMetrics(this.binaryActual, this.binaryExpected);
|
|
698
629
|
const fieldResult = {
|
|
699
630
|
field: this.fieldName,
|
|
@@ -712,6 +643,9 @@ var BinarizeSelector = class {
|
|
|
712
643
|
recordAssertion(result);
|
|
713
644
|
return this;
|
|
714
645
|
}
|
|
646
|
+
// ============================================================================
|
|
647
|
+
// Utility Methods
|
|
648
|
+
// ============================================================================
|
|
715
649
|
/**
|
|
716
650
|
* Gets computed metrics
|
|
717
651
|
*/
|
|
@@ -726,6 +660,73 @@ var BinarizeSelector = class {
|
|
|
726
660
|
}
|
|
727
661
|
};
|
|
728
662
|
|
|
663
|
+
// src/assertions/percentage-matcher.ts
|
|
664
|
+
var PercentageMatcher = class {
|
|
665
|
+
context;
|
|
666
|
+
constructor(context) {
|
|
667
|
+
this.context = context;
|
|
668
|
+
}
|
|
669
|
+
formatPercentage(value) {
|
|
670
|
+
return `${(value * 100).toFixed(1)}%`;
|
|
671
|
+
}
|
|
672
|
+
createAssertion(operator, percentageThreshold, passed) {
|
|
673
|
+
const { fieldName, valueThreshold, direction, actualPercentage } = this.context;
|
|
674
|
+
const operatorText = {
|
|
675
|
+
">=": "at least",
|
|
676
|
+
">": "above",
|
|
677
|
+
"<=": "at most",
|
|
678
|
+
"<": "below"
|
|
679
|
+
}[operator];
|
|
680
|
+
const directionText = direction === "above" ? "above" : "below or equal to";
|
|
681
|
+
const message = passed ? `${this.formatPercentage(actualPercentage)} of '${fieldName}' values are ${directionText} ${valueThreshold} (expected ${operatorText} ${this.formatPercentage(percentageThreshold)})` : `Only ${this.formatPercentage(actualPercentage)} of '${fieldName}' values are ${directionText} ${valueThreshold} (expected ${operatorText} ${this.formatPercentage(percentageThreshold)})`;
|
|
682
|
+
return {
|
|
683
|
+
type: direction === "above" ? "percentageAbove" : "percentageBelow",
|
|
684
|
+
passed,
|
|
685
|
+
message,
|
|
686
|
+
expected: percentageThreshold,
|
|
687
|
+
actual: actualPercentage,
|
|
688
|
+
field: fieldName
|
|
689
|
+
};
|
|
690
|
+
}
|
|
691
|
+
recordAndReturn(result) {
|
|
692
|
+
this.context.assertions.push(result);
|
|
693
|
+
recordAssertion(result);
|
|
694
|
+
return this.context.parent;
|
|
695
|
+
}
|
|
696
|
+
/**
|
|
697
|
+
* Assert that the percentage is greater than or equal to the threshold (>=)
|
|
698
|
+
*/
|
|
699
|
+
toBeAtLeast(percentageThreshold) {
|
|
700
|
+
const passed = this.context.actualPercentage >= percentageThreshold;
|
|
701
|
+
const result = this.createAssertion(">=", percentageThreshold, passed);
|
|
702
|
+
return this.recordAndReturn(result);
|
|
703
|
+
}
|
|
704
|
+
/**
|
|
705
|
+
* Assert that the percentage is strictly greater than the threshold (>)
|
|
706
|
+
*/
|
|
707
|
+
toBeAbove(percentageThreshold) {
|
|
708
|
+
const passed = this.context.actualPercentage > percentageThreshold;
|
|
709
|
+
const result = this.createAssertion(">", percentageThreshold, passed);
|
|
710
|
+
return this.recordAndReturn(result);
|
|
711
|
+
}
|
|
712
|
+
/**
|
|
713
|
+
* Assert that the percentage is less than or equal to the threshold (<=)
|
|
714
|
+
*/
|
|
715
|
+
toBeAtMost(percentageThreshold) {
|
|
716
|
+
const passed = this.context.actualPercentage <= percentageThreshold;
|
|
717
|
+
const result = this.createAssertion("<=", percentageThreshold, passed);
|
|
718
|
+
return this.recordAndReturn(result);
|
|
719
|
+
}
|
|
720
|
+
/**
|
|
721
|
+
* Assert that the percentage is strictly less than the threshold (<)
|
|
722
|
+
*/
|
|
723
|
+
toBeBelow(percentageThreshold) {
|
|
724
|
+
const passed = this.context.actualPercentage < percentageThreshold;
|
|
725
|
+
const result = this.createAssertion("<", percentageThreshold, passed);
|
|
726
|
+
return this.recordAndReturn(result);
|
|
727
|
+
}
|
|
728
|
+
};
|
|
729
|
+
|
|
729
730
|
// src/assertions/field-selector.ts
|
|
730
731
|
var FieldSelector = class {
|
|
731
732
|
aligned;
|
|
@@ -762,83 +763,93 @@ var FieldSelector = class {
|
|
|
762
763
|
}
|
|
763
764
|
}
|
|
764
765
|
/**
|
|
765
|
-
*
|
|
766
|
+
* Validates that ground truth exists and both arrays contain numeric values.
|
|
767
|
+
* Returns the filtered numeric arrays for regression metrics.
|
|
766
768
|
*/
|
|
767
|
-
|
|
769
|
+
validateRegressionInputs() {
|
|
770
|
+
this.validateGroundTruth();
|
|
771
|
+
const numericActual = filterNumericValues(this.actualValues);
|
|
772
|
+
const numericExpected = filterNumericValues(this.expectedValues);
|
|
773
|
+
if (numericActual.length === 0) {
|
|
774
|
+
throw new AssertionError(
|
|
775
|
+
`Regression metric requires numeric values, but field "${this.fieldName}" has no numeric actual values.`,
|
|
776
|
+
void 0,
|
|
777
|
+
void 0,
|
|
778
|
+
this.fieldName
|
|
779
|
+
);
|
|
780
|
+
}
|
|
781
|
+
if (numericExpected.length === 0) {
|
|
782
|
+
throw new AssertionError(
|
|
783
|
+
`Regression metric requires numeric values, but field "${this.fieldName}" has no numeric expected values.`,
|
|
784
|
+
void 0,
|
|
785
|
+
void 0,
|
|
786
|
+
this.fieldName
|
|
787
|
+
);
|
|
788
|
+
}
|
|
789
|
+
if (numericActual.length !== numericExpected.length) {
|
|
790
|
+
throw new AssertionError(
|
|
791
|
+
`Regression metric requires equal-length arrays, but got ${numericActual.length} actual and ${numericExpected.length} expected values.`,
|
|
792
|
+
numericExpected.length,
|
|
793
|
+
numericActual.length,
|
|
794
|
+
this.fieldName
|
|
795
|
+
);
|
|
796
|
+
}
|
|
797
|
+
return { actual: numericActual, expected: numericExpected };
|
|
798
|
+
}
|
|
799
|
+
// ============================================================================
|
|
800
|
+
// Classification Metric Getters
|
|
801
|
+
// ============================================================================
|
|
802
|
+
/**
|
|
803
|
+
* Access accuracy metric for assertions
|
|
804
|
+
* @example
|
|
805
|
+
* expectStats(predictions, groundTruth)
|
|
806
|
+
* .field("sentiment")
|
|
807
|
+
* .accuracy.toBeAtLeast(0.8)
|
|
808
|
+
*/
|
|
809
|
+
get accuracy() {
|
|
768
810
|
this.validateGroundTruth();
|
|
769
811
|
const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
field: this.fieldName
|
|
778
|
-
};
|
|
779
|
-
this.assertions.push(result);
|
|
780
|
-
recordAssertion(result);
|
|
781
|
-
return this;
|
|
812
|
+
return new MetricMatcher({
|
|
813
|
+
parent: this,
|
|
814
|
+
metricName: "Accuracy",
|
|
815
|
+
metricValue: metrics.accuracy,
|
|
816
|
+
fieldName: this.fieldName,
|
|
817
|
+
assertions: this.assertions
|
|
818
|
+
});
|
|
782
819
|
}
|
|
783
820
|
/**
|
|
784
|
-
*
|
|
785
|
-
* @
|
|
786
|
-
*
|
|
821
|
+
* Access F1 score metric for assertions (macro average)
|
|
822
|
+
* @example
|
|
823
|
+
* expectStats(predictions, groundTruth)
|
|
824
|
+
* .field("sentiment")
|
|
825
|
+
* .f1.toBeAtLeast(0.75)
|
|
787
826
|
*/
|
|
788
|
-
|
|
827
|
+
get f1() {
|
|
789
828
|
this.validateGroundTruth();
|
|
790
829
|
const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
}
|
|
798
|
-
targetClass = classOrThreshold;
|
|
799
|
-
actualThreshold = threshold;
|
|
800
|
-
const classMetrics = metrics.perClass[targetClass];
|
|
801
|
-
if (!classMetrics) {
|
|
802
|
-
throw new AssertionError(
|
|
803
|
-
`Class "${targetClass}" not found in predictions`,
|
|
804
|
-
targetClass,
|
|
805
|
-
Object.keys(metrics.perClass),
|
|
806
|
-
this.fieldName
|
|
807
|
-
);
|
|
808
|
-
}
|
|
809
|
-
actualPrecision = classMetrics.precision;
|
|
810
|
-
}
|
|
811
|
-
const passed = actualPrecision >= actualThreshold;
|
|
812
|
-
const result = {
|
|
813
|
-
type: "precision",
|
|
814
|
-
passed,
|
|
815
|
-
message: passed ? `Precision${targetClass ? ` for "${targetClass}"` : ""} ${(actualPrecision * 100).toFixed(1)}% is above ${(actualThreshold * 100).toFixed(1)}%` : `Precision${targetClass ? ` for "${targetClass}"` : ""} ${(actualPrecision * 100).toFixed(1)}% is below threshold ${(actualThreshold * 100).toFixed(1)}%`,
|
|
816
|
-
expected: actualThreshold,
|
|
817
|
-
actual: actualPrecision,
|
|
818
|
-
field: this.fieldName,
|
|
819
|
-
class: targetClass
|
|
820
|
-
};
|
|
821
|
-
this.assertions.push(result);
|
|
822
|
-
recordAssertion(result);
|
|
823
|
-
return this;
|
|
830
|
+
return new MetricMatcher({
|
|
831
|
+
parent: this,
|
|
832
|
+
metricName: "F1",
|
|
833
|
+
metricValue: metrics.macroAvg.f1,
|
|
834
|
+
fieldName: this.fieldName,
|
|
835
|
+
assertions: this.assertions
|
|
836
|
+
});
|
|
824
837
|
}
|
|
825
838
|
/**
|
|
826
|
-
*
|
|
827
|
-
* @param
|
|
828
|
-
* @
|
|
839
|
+
* Access precision metric for assertions
|
|
840
|
+
* @param targetClass - Optional class name. If omitted, uses macro average
|
|
841
|
+
* @example
|
|
842
|
+
* expectStats(predictions, groundTruth)
|
|
843
|
+
* .field("sentiment")
|
|
844
|
+
* .precision("positive").toBeAtLeast(0.7)
|
|
829
845
|
*/
|
|
830
|
-
|
|
846
|
+
precision(targetClass) {
|
|
831
847
|
this.validateGroundTruth();
|
|
832
848
|
const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
833
|
-
let
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
if (typeof classOrThreshold === "number") {
|
|
837
|
-
actualRecall = metrics.macroAvg.recall;
|
|
838
|
-
actualThreshold = classOrThreshold;
|
|
849
|
+
let metricValue;
|
|
850
|
+
if (targetClass === void 0) {
|
|
851
|
+
metricValue = metrics.macroAvg.precision;
|
|
839
852
|
} else {
|
|
840
|
-
targetClass = classOrThreshold;
|
|
841
|
-
actualThreshold = threshold;
|
|
842
853
|
const classMetrics = metrics.perClass[targetClass];
|
|
843
854
|
if (!classMetrics) {
|
|
844
855
|
throw new AssertionError(
|
|
@@ -848,39 +859,32 @@ var FieldSelector = class {
|
|
|
848
859
|
this.fieldName
|
|
849
860
|
);
|
|
850
861
|
}
|
|
851
|
-
|
|
862
|
+
metricValue = classMetrics.precision;
|
|
852
863
|
}
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
class: targetClass
|
|
862
|
-
};
|
|
863
|
-
this.assertions.push(result);
|
|
864
|
-
recordAssertion(result);
|
|
865
|
-
return this;
|
|
864
|
+
return new MetricMatcher({
|
|
865
|
+
parent: this,
|
|
866
|
+
metricName: "Precision",
|
|
867
|
+
metricValue,
|
|
868
|
+
fieldName: this.fieldName,
|
|
869
|
+
targetClass,
|
|
870
|
+
assertions: this.assertions
|
|
871
|
+
});
|
|
866
872
|
}
|
|
867
873
|
/**
|
|
868
|
-
*
|
|
869
|
-
* @param
|
|
870
|
-
* @
|
|
874
|
+
* Access recall metric for assertions
|
|
875
|
+
* @param targetClass - Optional class name. If omitted, uses macro average
|
|
876
|
+
* @example
|
|
877
|
+
* expectStats(predictions, groundTruth)
|
|
878
|
+
* .field("sentiment")
|
|
879
|
+
* .recall("positive").toBeAtLeast(0.7)
|
|
871
880
|
*/
|
|
872
|
-
|
|
881
|
+
recall(targetClass) {
|
|
873
882
|
this.validateGroundTruth();
|
|
874
883
|
const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
875
|
-
let
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
if (typeof classOrThreshold === "number") {
|
|
879
|
-
actualF1 = metrics.macroAvg.f1;
|
|
880
|
-
actualThreshold = classOrThreshold;
|
|
884
|
+
let metricValue;
|
|
885
|
+
if (targetClass === void 0) {
|
|
886
|
+
metricValue = metrics.macroAvg.recall;
|
|
881
887
|
} else {
|
|
882
|
-
targetClass = classOrThreshold;
|
|
883
|
-
actualThreshold = threshold;
|
|
884
888
|
const classMetrics = metrics.perClass[targetClass];
|
|
885
889
|
if (!classMetrics) {
|
|
886
890
|
throw new AssertionError(
|
|
@@ -890,244 +894,171 @@ var FieldSelector = class {
|
|
|
890
894
|
this.fieldName
|
|
891
895
|
);
|
|
892
896
|
}
|
|
893
|
-
|
|
897
|
+
metricValue = classMetrics.recall;
|
|
894
898
|
}
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
class: targetClass
|
|
904
|
-
};
|
|
905
|
-
this.assertions.push(result);
|
|
906
|
-
recordAssertion(result);
|
|
907
|
-
return this;
|
|
899
|
+
return new MetricMatcher({
|
|
900
|
+
parent: this,
|
|
901
|
+
metricName: "Recall",
|
|
902
|
+
metricValue,
|
|
903
|
+
fieldName: this.fieldName,
|
|
904
|
+
targetClass,
|
|
905
|
+
assertions: this.assertions
|
|
906
|
+
});
|
|
908
907
|
}
|
|
908
|
+
// ============================================================================
|
|
909
|
+
// Regression Metric Getters
|
|
910
|
+
// ============================================================================
|
|
909
911
|
/**
|
|
910
|
-
*
|
|
912
|
+
* Access Mean Absolute Error metric for assertions
|
|
913
|
+
* @example
|
|
914
|
+
* expectStats(predictions, groundTruth)
|
|
915
|
+
* .field("score")
|
|
916
|
+
* .mae.toBeAtMost(0.1)
|
|
911
917
|
*/
|
|
912
|
-
|
|
913
|
-
const
|
|
914
|
-
const
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
message: `Confusion matrix recorded for field "${this.fieldName}"`,
|
|
924
|
-
field: this.fieldName
|
|
925
|
-
};
|
|
926
|
-
this.assertions.push(result);
|
|
927
|
-
recordAssertion(result);
|
|
928
|
-
return this;
|
|
918
|
+
get mae() {
|
|
919
|
+
const { actual, expected } = this.validateRegressionInputs();
|
|
920
|
+
const metrics = computeRegressionMetrics(actual, expected);
|
|
921
|
+
return new MetricMatcher({
|
|
922
|
+
parent: this,
|
|
923
|
+
metricName: "MAE",
|
|
924
|
+
metricValue: metrics.mae,
|
|
925
|
+
fieldName: this.fieldName,
|
|
926
|
+
assertions: this.assertions,
|
|
927
|
+
formatValue: (v) => v.toFixed(4)
|
|
928
|
+
});
|
|
929
929
|
}
|
|
930
930
|
/**
|
|
931
|
-
*
|
|
932
|
-
*
|
|
933
|
-
*
|
|
931
|
+
* Access Root Mean Squared Error metric for assertions
|
|
932
|
+
* @example
|
|
933
|
+
* expectStats(predictions, groundTruth)
|
|
934
|
+
* .field("score")
|
|
935
|
+
* .rmse.toBeAtMost(0.15)
|
|
936
|
+
*/
|
|
937
|
+
get rmse() {
|
|
938
|
+
const { actual, expected } = this.validateRegressionInputs();
|
|
939
|
+
const metrics = computeRegressionMetrics(actual, expected);
|
|
940
|
+
return new MetricMatcher({
|
|
941
|
+
parent: this,
|
|
942
|
+
metricName: "RMSE",
|
|
943
|
+
metricValue: metrics.rmse,
|
|
944
|
+
fieldName: this.fieldName,
|
|
945
|
+
assertions: this.assertions,
|
|
946
|
+
formatValue: (v) => v.toFixed(4)
|
|
947
|
+
});
|
|
948
|
+
}
|
|
949
|
+
/**
|
|
950
|
+
* Access R-squared (coefficient of determination) metric for assertions
|
|
951
|
+
* @example
|
|
952
|
+
* expectStats(predictions, groundTruth)
|
|
953
|
+
* .field("score")
|
|
954
|
+
* .r2.toBeAtLeast(0.8)
|
|
955
|
+
*/
|
|
956
|
+
get r2() {
|
|
957
|
+
const { actual, expected } = this.validateRegressionInputs();
|
|
958
|
+
const metrics = computeRegressionMetrics(actual, expected);
|
|
959
|
+
return new MetricMatcher({
|
|
960
|
+
parent: this,
|
|
961
|
+
metricName: "R\xB2",
|
|
962
|
+
metricValue: metrics.r2,
|
|
963
|
+
fieldName: this.fieldName,
|
|
964
|
+
assertions: this.assertions,
|
|
965
|
+
formatValue: (v) => v.toFixed(4)
|
|
966
|
+
});
|
|
967
|
+
}
|
|
968
|
+
// ============================================================================
|
|
969
|
+
// Distribution Assertions
|
|
970
|
+
// ============================================================================
|
|
971
|
+
/**
|
|
972
|
+
* Assert on the percentage of values below or equal to a threshold
|
|
934
973
|
* @param valueThreshold - The value threshold to compare against
|
|
935
|
-
* @param percentageThreshold - The minimum percentage (0-1) of values that should be <= valueThreshold
|
|
936
|
-
* @returns this for method chaining
|
|
937
|
-
*
|
|
938
974
|
* @example
|
|
939
|
-
* // Assert that 90% of confidence scores are below 0.5
|
|
940
975
|
* expectStats(predictions)
|
|
941
976
|
* .field("confidence")
|
|
942
|
-
* .
|
|
977
|
+
* .percentageBelow(0.5).toBeAtLeast(0.9)
|
|
943
978
|
*/
|
|
944
|
-
|
|
979
|
+
percentageBelow(valueThreshold) {
|
|
945
980
|
const numericActual = filterNumericValues(this.actualValues);
|
|
946
981
|
if (numericActual.length === 0) {
|
|
947
982
|
throw new AssertionError(
|
|
948
983
|
`Field '${this.fieldName}' contains no numeric values (found 0 numeric out of ${this.actualValues.length} total values)`,
|
|
949
|
-
|
|
984
|
+
void 0,
|
|
950
985
|
void 0,
|
|
951
986
|
this.fieldName
|
|
952
987
|
);
|
|
953
988
|
}
|
|
954
989
|
const actualPercentage = calculatePercentageBelow(numericActual, valueThreshold);
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
};
|
|
964
|
-
this.assertions.push(result);
|
|
965
|
-
recordAssertion(result);
|
|
966
|
-
return this;
|
|
990
|
+
return new PercentageMatcher({
|
|
991
|
+
parent: this,
|
|
992
|
+
fieldName: this.fieldName,
|
|
993
|
+
valueThreshold,
|
|
994
|
+
direction: "below",
|
|
995
|
+
actualPercentage,
|
|
996
|
+
assertions: this.assertions
|
|
997
|
+
});
|
|
967
998
|
}
|
|
968
999
|
/**
|
|
969
|
-
*
|
|
970
|
-
* This is a distributional assertion that only looks at actual values (no ground truth required).
|
|
971
|
-
*
|
|
1000
|
+
* Assert on the percentage of values above a threshold
|
|
972
1001
|
* @param valueThreshold - The value threshold to compare against
|
|
973
|
-
* @param percentageThreshold - The minimum percentage (0-1) of values that should be > valueThreshold
|
|
974
|
-
* @returns this for method chaining
|
|
975
|
-
*
|
|
976
1002
|
* @example
|
|
977
|
-
* // Assert that 80% of quality scores are above 0.7
|
|
978
1003
|
* expectStats(predictions)
|
|
979
1004
|
* .field("quality")
|
|
980
|
-
* .
|
|
1005
|
+
* .percentageAbove(0.7).toBeAtLeast(0.8)
|
|
981
1006
|
*/
|
|
982
|
-
|
|
1007
|
+
percentageAbove(valueThreshold) {
|
|
983
1008
|
const numericActual = filterNumericValues(this.actualValues);
|
|
984
1009
|
if (numericActual.length === 0) {
|
|
985
1010
|
throw new AssertionError(
|
|
986
1011
|
`Field '${this.fieldName}' contains no numeric values (found 0 numeric out of ${this.actualValues.length} total values)`,
|
|
987
|
-
|
|
1012
|
+
void 0,
|
|
988
1013
|
void 0,
|
|
989
1014
|
this.fieldName
|
|
990
1015
|
);
|
|
991
1016
|
}
|
|
992
1017
|
const actualPercentage = calculatePercentageAbove(numericActual, valueThreshold);
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
};
|
|
1002
|
-
this.assertions.push(result);
|
|
1003
|
-
recordAssertion(result);
|
|
1004
|
-
return this;
|
|
1018
|
+
return new PercentageMatcher({
|
|
1019
|
+
parent: this,
|
|
1020
|
+
fieldName: this.fieldName,
|
|
1021
|
+
valueThreshold,
|
|
1022
|
+
direction: "above",
|
|
1023
|
+
actualPercentage,
|
|
1024
|
+
assertions: this.assertions
|
|
1025
|
+
});
|
|
1005
1026
|
}
|
|
1006
1027
|
// ============================================================================
|
|
1007
|
-
//
|
|
1028
|
+
// Display Methods
|
|
1008
1029
|
// ============================================================================
|
|
1009
1030
|
/**
|
|
1010
|
-
*
|
|
1011
|
-
*
|
|
1012
|
-
*/
|
|
1013
|
-
validateRegressionInputs() {
|
|
1014
|
-
this.validateGroundTruth();
|
|
1015
|
-
const numericActual = filterNumericValues(this.actualValues);
|
|
1016
|
-
const numericExpected = filterNumericValues(this.expectedValues);
|
|
1017
|
-
if (numericActual.length === 0) {
|
|
1018
|
-
throw new AssertionError(
|
|
1019
|
-
`Regression metric requires numeric values, but field "${this.fieldName}" has no numeric actual values.`,
|
|
1020
|
-
void 0,
|
|
1021
|
-
void 0,
|
|
1022
|
-
this.fieldName
|
|
1023
|
-
);
|
|
1024
|
-
}
|
|
1025
|
-
if (numericExpected.length === 0) {
|
|
1026
|
-
throw new AssertionError(
|
|
1027
|
-
`Regression metric requires numeric values, but field "${this.fieldName}" has no numeric expected values.`,
|
|
1028
|
-
void 0,
|
|
1029
|
-
void 0,
|
|
1030
|
-
this.fieldName
|
|
1031
|
-
);
|
|
1032
|
-
}
|
|
1033
|
-
if (numericActual.length !== numericExpected.length) {
|
|
1034
|
-
throw new AssertionError(
|
|
1035
|
-
`Regression metric requires equal-length arrays, but got ${numericActual.length} actual and ${numericExpected.length} expected values.`,
|
|
1036
|
-
numericExpected.length,
|
|
1037
|
-
numericActual.length,
|
|
1038
|
-
this.fieldName
|
|
1039
|
-
);
|
|
1040
|
-
}
|
|
1041
|
-
return { actual: numericActual, expected: numericExpected };
|
|
1042
|
-
}
|
|
1043
|
-
/**
|
|
1044
|
-
* Asserts that Mean Absolute Error is below a threshold.
|
|
1045
|
-
* Requires numeric values in both actual and expected.
|
|
1046
|
-
*
|
|
1047
|
-
* @param threshold - Maximum allowed MAE
|
|
1048
|
-
* @returns this for method chaining
|
|
1049
|
-
*
|
|
1031
|
+
* Displays the confusion matrix in the report
|
|
1032
|
+
* This is not an assertion - it always passes and just records the matrix for display
|
|
1050
1033
|
* @example
|
|
1051
1034
|
* expectStats(predictions, groundTruth)
|
|
1052
|
-
* .field("
|
|
1053
|
-
* .
|
|
1035
|
+
* .field("sentiment")
|
|
1036
|
+
* .accuracy.toBeAtLeast(0.8)
|
|
1037
|
+
* .displayConfusionMatrix()
|
|
1054
1038
|
*/
|
|
1055
|
-
|
|
1056
|
-
const
|
|
1057
|
-
const
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
passed,
|
|
1062
|
-
message: passed ? `MAE ${metrics.mae.toFixed(4)} is below ${threshold}` : `MAE ${metrics.mae.toFixed(4)} exceeds threshold ${threshold}`,
|
|
1063
|
-
expected: threshold,
|
|
1064
|
-
actual: metrics.mae,
|
|
1065
|
-
field: this.fieldName
|
|
1066
|
-
};
|
|
1067
|
-
this.assertions.push(result);
|
|
1068
|
-
recordAssertion(result);
|
|
1069
|
-
return this;
|
|
1070
|
-
}
|
|
1071
|
-
/**
|
|
1072
|
-
* Asserts that Root Mean Squared Error is below a threshold.
|
|
1073
|
-
* Requires numeric values in both actual and expected.
|
|
1074
|
-
*
|
|
1075
|
-
* @param threshold - Maximum allowed RMSE
|
|
1076
|
-
* @returns this for method chaining
|
|
1077
|
-
*
|
|
1078
|
-
* @example
|
|
1079
|
-
* expectStats(predictions, groundTruth)
|
|
1080
|
-
* .field("score")
|
|
1081
|
-
* .toHaveRMSEBelow(0.15)
|
|
1082
|
-
*/
|
|
1083
|
-
toHaveRMSEBelow(threshold) {
|
|
1084
|
-
const { actual, expected } = this.validateRegressionInputs();
|
|
1085
|
-
const metrics = computeRegressionMetrics(actual, expected);
|
|
1086
|
-
const passed = metrics.rmse <= threshold;
|
|
1087
|
-
const result = {
|
|
1088
|
-
type: "rmse",
|
|
1089
|
-
passed,
|
|
1090
|
-
message: passed ? `RMSE ${metrics.rmse.toFixed(4)} is below ${threshold}` : `RMSE ${metrics.rmse.toFixed(4)} exceeds threshold ${threshold}`,
|
|
1091
|
-
expected: threshold,
|
|
1092
|
-
actual: metrics.rmse,
|
|
1093
|
-
field: this.fieldName
|
|
1039
|
+
displayConfusionMatrix() {
|
|
1040
|
+
const metrics = computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
1041
|
+
const fieldResult = {
|
|
1042
|
+
field: this.fieldName,
|
|
1043
|
+
metrics,
|
|
1044
|
+
binarized: false
|
|
1094
1045
|
};
|
|
1095
|
-
|
|
1096
|
-
recordAssertion(result);
|
|
1097
|
-
return this;
|
|
1098
|
-
}
|
|
1099
|
-
/**
|
|
1100
|
-
* Asserts that R-squared (coefficient of determination) is above a threshold.
|
|
1101
|
-
* R² measures how well the predictions explain the variance in expected values.
|
|
1102
|
-
* R² = 1.0 means perfect prediction, R² = 0 means prediction is no better than mean.
|
|
1103
|
-
* Requires numeric values in both actual and expected.
|
|
1104
|
-
*
|
|
1105
|
-
* @param threshold - Minimum required R² value (0-1)
|
|
1106
|
-
* @returns this for method chaining
|
|
1107
|
-
*
|
|
1108
|
-
* @example
|
|
1109
|
-
* expectStats(predictions, groundTruth)
|
|
1110
|
-
* .field("score")
|
|
1111
|
-
* .toHaveR2Above(0.8)
|
|
1112
|
-
*/
|
|
1113
|
-
toHaveR2Above(threshold) {
|
|
1114
|
-
const { actual, expected } = this.validateRegressionInputs();
|
|
1115
|
-
const metrics = computeRegressionMetrics(actual, expected);
|
|
1116
|
-
const passed = metrics.r2 >= threshold;
|
|
1046
|
+
recordFieldMetrics(fieldResult);
|
|
1117
1047
|
const result = {
|
|
1118
|
-
type: "
|
|
1119
|
-
passed,
|
|
1120
|
-
message:
|
|
1121
|
-
expected: threshold,
|
|
1122
|
-
actual: metrics.r2,
|
|
1048
|
+
type: "confusionMatrix",
|
|
1049
|
+
passed: true,
|
|
1050
|
+
message: `Confusion matrix recorded for field "${this.fieldName}"`,
|
|
1123
1051
|
field: this.fieldName
|
|
1124
1052
|
};
|
|
1125
1053
|
this.assertions.push(result);
|
|
1126
1054
|
recordAssertion(result);
|
|
1127
1055
|
return this;
|
|
1128
1056
|
}
|
|
1057
|
+
// ============================================================================
|
|
1058
|
+
// Utility Methods
|
|
1059
|
+
// ============================================================================
|
|
1129
1060
|
/**
|
|
1130
|
-
* Gets the computed metrics for this field
|
|
1061
|
+
* Gets the computed classification metrics for this field
|
|
1131
1062
|
*/
|
|
1132
1063
|
getMetrics() {
|
|
1133
1064
|
return computeClassificationMetrics(this.actualValues, this.expectedValues);
|
|
@@ -1160,7 +1091,7 @@ function normalizeInput(input) {
|
|
|
1160
1091
|
}));
|
|
1161
1092
|
}
|
|
1162
1093
|
throw new Error(
|
|
1163
|
-
"Invalid input to expectStats(): expected
|
|
1094
|
+
"Invalid input to expectStats(): expected { aligned: AlignedRecord[] }, Prediction[], or AlignedRecord[]"
|
|
1164
1095
|
);
|
|
1165
1096
|
}
|
|
1166
1097
|
function expectStats(inputOrActual, expected, options) {
|
|
@@ -1205,6 +1136,6 @@ var ExpectStats = class {
|
|
|
1205
1136
|
}
|
|
1206
1137
|
};
|
|
1207
1138
|
|
|
1208
|
-
export { afterAll, afterEach, alignByKey, beforeAll, beforeEach, checkIntegrity, computeAccuracy, computeClassificationMetrics, computeF1, computePrecision, computeRecall,
|
|
1139
|
+
export { afterAll, afterEach, alignByKey, beforeAll, beforeEach, checkIntegrity, computeAccuracy, computeClassificationMetrics, computeF1, computePrecision, computeRecall, describe, evalTest, expectStats, extractFieldValues, filterComplete, it, test, validatePredictions };
|
|
1209
1140
|
//# sourceMappingURL=index.js.map
|
|
1210
1141
|
//# sourceMappingURL=index.js.map
|