@datagrok/eda 1.4.11 → 1.4.13
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.eslintrc.json +0 -1
- package/CHANGELOG.md +15 -0
- package/CLAUDE.md +185 -0
- package/README.md +8 -0
- package/css/pmpo.css +35 -0
- package/dist/package-test.js +1 -1
- package/dist/package-test.js.map +1 -1
- package/dist/package.js +1 -1
- package/dist/package.js.map +1 -1
- package/eslintrc.json +45 -0
- package/files/drugs-props-test.csv +126 -0
- package/files/drugs-props-train-scores.csv +664 -0
- package/files/drugs-props-train.csv +664 -0
- package/package.json +9 -3
- package/src/anova/anova-tools.ts +1 -1
- package/src/anova/anova-ui.ts +1 -1
- package/src/package-api.ts +18 -0
- package/src/package-test.ts +4 -1
- package/src/package.g.ts +25 -0
- package/src/package.ts +55 -15
- package/src/pareto-optimization/pareto-computations.ts +6 -0
- package/src/pareto-optimization/utils.ts +6 -4
- package/src/probabilistic-scoring/data-generator.ts +157 -0
- package/src/probabilistic-scoring/nelder-mead.ts +204 -0
- package/src/probabilistic-scoring/pmpo-defs.ts +218 -0
- package/src/probabilistic-scoring/pmpo-utils.ts +603 -0
- package/src/probabilistic-scoring/prob-scoring.ts +991 -0
- package/src/probabilistic-scoring/stat-tools.ts +303 -0
- package/src/softmax-classifier.ts +1 -1
- package/src/tests/anova-tests.ts +1 -1
- package/src/tests/classifiers-tests.ts +1 -1
- package/src/tests/dim-reduction-tests.ts +1 -1
- package/src/tests/linear-methods-tests.ts +1 -1
- package/src/tests/mis-vals-imputation-tests.ts +1 -1
- package/src/tests/pareto-tests.ts +253 -0
- package/src/tests/pmpo-tests.ts +157 -0
- package/test-console-output-1.log +175 -209
- package/test-record-1.mp4 +0 -0
package/src/probabilistic-scoring/stat-tools.ts
ADDED
@@ -0,0 +1,303 @@
+// Probabilistic scoring (pMPO) statistical tools
+// Link: https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/
+
+import * as grok from 'datagrok-api/grok';
+import * as ui from 'datagrok-api/ui';
+import * as DG from 'datagrok-api/dg';
+
+//@ts-ignore: no types
+import * as jStat from 'jstat';
+
+import {ConfusionMatrix, Cutoff, DescriptorStatistics, ModelEvaluationResult,
+  ROC_TRESHOLDS, ROC_TRESHOLDS_COUNT, SigmoidParams} from './pmpo-defs';
+
+/** Splits the dataframe into desired and non-desired tables based on the desirability column */
+export function getDesiredTables(df: DG.DataFrame, desirability: DG.Column) {
+  const groups = df.groupBy([desirability.name]).getGroups() as any;
+  let desired: DG.DataFrame;
+  let nonDesired: DG.DataFrame;
+
+  for (const name in groups) {
+    if (name.toLowerCase().includes('true'))
+      desired = groups[name];
+    else
+      nonDesired = groups[name];
+  }
+
+  //@ts-ignore
+  return {desired, nonDesired};
+} // getDesiredTables
+
+/* Welch two-sample t-test (two-sided) */
+export function getDescriptorStatistics(des: DG.Column, nonDes: DG.Column): DescriptorStatistics {
+  const desLen = des.length;
+  const nonDesLen = nonDes.length;
+  if (desLen < 2 || nonDesLen < 2) {
+    throw new Error(`Failed to compute the "${des.name}" descriptor statistics:
+      both samples must have at least two observations.`);
+  }
+
+  const desAvg = des.stats.avg;
+  const nonDesAvg = nonDes.stats.avg;
+  const desVar = des.stats.variance;
+  const nonDesVar = nonDes.stats.variance;
+  const desStd = des.stats.stdev;
+  const nonDesStd = nonDes.stats.stdev;
+
+  const se = Math.sqrt(desVar / desLen + nonDesVar / nonDesLen);
+  if (se === 0) {
+    throw new Error(`Failed to compute the "${des.name}" descriptor statistics:
+      zero variance.`);
+  }
+
+  const t = (desAvg - nonDesAvg) / se;
+
+  // Welch–Satterthwaite degrees of freedom
+  const numerator = (desVar / desLen + nonDesVar / nonDesLen) ** 2;
+  const denom = (desVar * desVar) / (desLen * desLen * (desLen - 1)) +
+    (nonDesVar * nonDesVar) / (nonDesLen * nonDesLen * (nonDesLen - 1));
+  const df = numerator / denom;
+
+  // two-sided p-value
+  const cdf = jStat.studentt.cdf(Math.abs(t), df);
+  const pValue = 2 * (1 - cdf);
+
+  return {
+    desAvg: desAvg,
+    desStd: desStd,
+    desLen: desLen,
+    nonDesAvg: nonDesAvg,
+    nonDesStd: nonDesStd,
+    nonSesLen: nonDesLen,
+    min: Math.min(des.stats.min, nonDes.stats.min),
+    max: Math.max(des.stats.max, nonDes.stats.max),
+    tstat: t,
+    pValue: pValue,
+  };
+} // getDescriptorStatistics
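
For orientation, a minimal standalone sketch of the same Welch test over plain arrays rather than DG.Column statistics (the sample values are made up; jStat is used as in the source):

//@ts-ignore: no types
import * as jStat from 'jstat';

// Welch's two-sided t-test on plain arrays; mirrors getDescriptorStatistics above.
function welchTwoSided(a: number[], b: number[]): {t: number; df: number; pValue: number} {
  const mean = (v: number[]) => v.reduce((s, x) => s + x, 0) / v.length;
  const sampleVar = (v: number[]) => {
    const m = mean(v);
    return v.reduce((s, x) => s + (x - m) ** 2, 0) / (v.length - 1);
  };
  const va = sampleVar(a);
  const vb = sampleVar(b);
  const se = Math.sqrt(va / a.length + vb / b.length);
  const t = (mean(a) - mean(b)) / se;
  // Welch–Satterthwaite approximation for the degrees of freedom
  const df = (va / a.length + vb / b.length) ** 2 /
    ((va * va) / (a.length * a.length * (a.length - 1)) +
     (vb * vb) / (b.length * b.length * (b.length - 1)));
  const pValue = 2 * (1 - jStat.studentt.cdf(Math.abs(t), df));
  return {t, df, pValue};
}

// Made-up samples: clearly separated means give a small p-value.
console.log(welchTwoSided([1.1, 1.3, 0.9, 1.2], [2.0, 2.2, 1.9, 2.1]));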
+
+/** Compute cutoffs for the pMPO method */
+export function getCutoffs(muDesired: number, stdDesired: number, muNotDesired: number,
+  stdNotDesired: number): Cutoff {
+  if (muDesired < muNotDesired) {
+    return {
+      cutoff: ((muNotDesired - muDesired) / (stdDesired + stdNotDesired)) * stdDesired + muDesired,
+      cutoffDesired: Math.max(muDesired, muNotDesired - stdNotDesired),
+      cutoffNotDesired: Math.max(muDesired + stdDesired, muNotDesired),
+    };
+  } else {
+    return {
+      cutoff: ((muDesired - muNotDesired) / (stdDesired + stdNotDesired)) * stdNotDesired + muNotDesired,
+      cutoffDesired: Math.min(muNotDesired + stdNotDesired, muDesired),
+      cutoffNotDesired: Math.max(muNotDesired, muDesired - stdDesired),
+    };
+  }
+} // getCutoffs
+
+/** Solve normal intersection for the pMPO method */
+export function solveNormalIntersection(mu1: number, s1: number, mu2: number, s2: number): number[] {
+  const a = 1 / (2 * s1 ** 2) - 1 / (2 * s2 ** 2);
+  const b = mu2 / (s2 ** 2) - mu1 / (s1 ** 2);
+  const c = (mu1 ** 2) / (2 * s1 ** 2) - (mu2 ** 2) / (2 * s2 ** 2) - Math.log(s2 / s1);
+
+  // If a is nearly zero, solve linear equation
+  if (Math.abs(a) < 1e-12) {
+    if (Math.abs(b) < 1e-12) return [];
+    return [-c / b];
+  }
+
+  const disc = b * b - 4 * a * c;
+  if (disc < 0) return [];
+
+  const sqrtDisc = Math.sqrt(disc);
+  const x1 = (-b + sqrtDisc) / (2 * a);
+  const x2 = (-b - sqrtDisc) / (2 * a);
+
+  return [x1, x2];
+} // solveNormalIntersection
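
For reference, the coefficients above follow from equating the two normal densities and taking logarithms; the a ≈ 0 branch is the equal-variance case, where the quadratic degenerates to a linear equation:

\[
\ln\frac{1}{\sigma_1} - \frac{(x-\mu_1)^2}{2\sigma_1^2}
= \ln\frac{1}{\sigma_2} - \frac{(x-\mu_2)^2}{2\sigma_2^2}
\;\Longrightarrow\;
a x^2 + b x + c = 0,
\]
\[
a = \frac{1}{2\sigma_1^2} - \frac{1}{2\sigma_2^2}, \qquad
b = \frac{\mu_2}{\sigma_2^2} - \frac{\mu_1}{\sigma_1^2}, \qquad
c = \frac{\mu_1^2}{2\sigma_1^2} - \frac{\mu_2^2}{2\sigma_2^2} - \ln\frac{\sigma_2}{\sigma_1}.
\]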
+
+/** Compute sigmoid parameters for the pMPO method */
+export function computeSigmoidParamsFromX0(muDes: number, sigmaDes: number, x0: number, xBound: number,
+  qCutoff: number = 0.05): SigmoidParams {
+  let pX0: number;
+
+  if (sigmaDes <= 0)
+    pX0 = x0 === muDes ? 1.0 : 0.0;
+  else {
+    // normal pdf
+    const coef = 1 / (sigmaDes * Math.sqrt(2 * Math.PI));
+    const exponent = -0.5 * ((x0 - muDes) / sigmaDes) ** 2;
+    pX0 = coef * Math.exp(exponent);
+  }
+
+  const eps = 1e-12;
+  pX0 = Math.max(pX0, eps);
+
+  const b = Math.max(1.0 / pX0 - 1.0, eps);
+  const n = 1.0 / qCutoff - 1.0;
+  const dx = xBound - x0;
+
+  let c: number;
+  if (Math.abs(dx) < 1e-12)
+    c = 1.0;
+  else {
+    const ratio = n / b;
+    if (ratio <= 0)
+      c = 1.0;
+    else {
+      try {
+        c = Math.exp(-Math.log(ratio) / dx);
+      } catch {
+        c = 1.0;
+      }
+    }
+  }
+
+  return {pX0: pX0, b: b, c: c};
+} // computeSigmoidParamsFromX0
+
+/** Generalized sigmoid function */
+export function sigmoidS(x: number, x0: number, b: number, c: number): number {
+  if (c > 0)
+    return 1.0 / (1.0 + b * (c ** (-(x - x0))));
+  return 1.0 / (1.0 + b);
+}
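
A quick sketch of how this sigmoid behaves; the parameter values below are illustrative, not taken from the package. At x = x0 the score is 1 / (1 + b), and c sets the steepness of the transition:

// Standalone copy of sigmoidS, evaluated over a small grid.
function sigmoidS(x: number, x0: number, b: number, c: number): number {
  if (c > 0)
    return 1.0 / (1.0 + b * (c ** (-(x - x0))));
  return 1.0 / (1.0 + b);
}

const x0 = 3.0; // e.g., a descriptor cutoff
const b = 1.0;  // score is 0.5 exactly at the cutoff
const c = 2.0;  // odds p/(1-p) double per unit increase in x
for (const x of [1, 2, 3, 4, 5])
  console.log(x, sigmoidS(x, x0, b, c).toFixed(3)); // 0.200, 0.333, 0.500, 0.667, 0.800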
+
+/** Gaussian desirability function (unnormalized normal density; equals 1 at x = mu) */
+export function gaussDesirabilityFunc(x: number, mu: number, sigma: number): number {
+  return Math.exp(-((x - mu) ** 2) / (2 * sigma ** 2));
+}
+
+/** Computes the confusion matrix given desirability (labels) and prediction columns
+ * @param desirability - desirability column (boolean)
+ * @param prediction - prediction column (numeric)
+ * @param threshold - threshold to convert prediction scores to binary labels
+ * @return ConfusionMatrix object with TP, TN, FP, FN counts
+ */
+export function getConfusionMatrix(desirability: DG.Column, prediction: DG.Column, threshold: number): ConfusionMatrix {
+  if (desirability.length !== prediction.length)
+    throw new Error('Failed to compute confusion matrix: columns have different lengths.');
+
+  if (desirability.type !== DG.COLUMN_TYPE.BOOL)
+    throw new Error('Failed to compute confusion matrix: desirability column must be boolean.');
+
+  if (!prediction.isNumerical)
+    throw new Error('Failed to compute confusion matrix: prediction column must be numerical.');
+
+  let TP = 0;
+  let TN = 0;
+  let FP = 0;
+  let FN = 0;
+
+  const desRaw = desirability.getRawData();
+  const predRaw = prediction.getRawData();
+
+  let desIdx = 0;
+  let curPos = 0;
+  let desElem = desRaw[0];
+
+  // Here, we extract bits from the desirability boolean column in chunks of 32 bits
+  for (let predIdx = 0; predIdx < prediction.length; ++predIdx) {
+    // console.log(predIdx + 1, ': ',
+    //   desirability.get(predIdx), '<-->', (desElem >>> curPos) & 1, ' vs ', predRaw[predIdx] >= threshold);
+
+    if (((desElem >>> curPos) & 1) == 1) { // True actual
+      if (predRaw[predIdx] >= threshold) { // True predicted
+        ++TP;
+      } else { // False predicted
+        ++FN;
+      }
+    } else { // False actual
+      if (predRaw[predIdx] >= threshold) { // True predicted
+        ++FP;
+      } else { // False predicted
+        ++TN;
+      }
+    }
+
+    ++curPos;
+
+    // Move to the next desirability element if we have processed 32 bits
+    if (curPos >= 32) {
+      curPos = 0;
+      ++desIdx;
+      desElem = desRaw[desIdx];
+    }
+  } // for predIdx
+
+  return {TP: TP, TN: TN, FP: FP, FN: FN};
+} // getConfusionMatrix
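
The loop above assumes the boolean column's raw data is bit-packed into 32-bit words, low bit first, as the desIdx/curPos bookkeeping implies. A standalone sketch of that unpacking pattern under the same assumption:

// Unpack n booleans from a bit-packed Int32Array, low bit first within each word.
function unpackBits(raw: Int32Array, n: number): boolean[] {
  const out = new Array<boolean>(n);
  for (let i = 0; i < n; ++i)
    out[i] = ((raw[i >> 5] >>> (i & 31)) & 1) === 1; // i >> 5: word index; i & 31: bit position
  return out;
}

// 0b1011 packs [true, true, false, true] into the first four indices.
console.log(unpackBits(new Int32Array([0b1011]), 4)); // [true, true, false, true]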
+
+/** Computes Area Under Curve (AUC) given TPR and FPR arrays
+ * @param tpr - True Positive Rate array
+ * @param fpr - False Positive Rate array
+ * @return AUC value
+ */
+export function getAuc(tpr: Float32Array, fpr: Float32Array): number {
+  if (tpr.length !== fpr.length)
+    throw new Error('Failed to compute AUC: TPR and FPR arrays have different lengths.');
+
+  let auc = 0.0;
+
+  for (let i = 1; i < tpr.length; ++i) {
+    const xDiff = Math.abs(fpr[i] - fpr[i - 1]);
+    const yAvg = (tpr[i] + tpr[i - 1]) / 2.0;
+    auc += xDiff * yAvg;
+  }
+
+  return auc;
+} // getAuc
+
+/** Converts numeric prediction column to boolean based on the given threshold
+ * @param numericPrediction - numeric prediction column
+ * @param threshold - threshold to convert prediction scores to binary labels
+ * @param name - name for the resulting boolean column
+ * @return Boolean prediction column
+ */
+export function getBoolPredictionColumn(numericPrediction: DG.Column, threshold: number, name: string): DG.Column {
+  if (!numericPrediction.isNumerical)
+    throw new Error('Failed to compute confusion matrix: prediction column must be numerical.');
+
+  const size = numericPrediction.length;
+  const boolPredData = new Array<boolean>(size);
+  const predRaw = numericPrediction.getRawData();
+
+  for (let i = 0; i < size; ++i)
+    boolPredData[i] = (predRaw[i] >= threshold);
+
+  return DG.Column.fromList(DG.COLUMN_TYPE.BOOL, name, boolPredData);
+} // getBoolPredictionColumn
+
+/** Computes pMPO model evaluation metrics: AUC, optimal threshold, TPR and FPR arrays
+ * @param desirability - desirability column (boolean)
+ * @param prediction - prediction column (numeric)
+ * @return ModelEvaluationResult object with AUC, optimal threshold, TPR and FPR arrays
+ */
+export function getPmpoEvaluation(desirability: DG.Column, prediction: DG.Column): ModelEvaluationResult {
+  const tpr = new Float32Array(ROC_TRESHOLDS_COUNT);
+  const fpr = new Float32Array(ROC_TRESHOLDS_COUNT);
+
+  let bestJ = -1;
+  let currentJ = -1;
+  let bestThreshold = ROC_TRESHOLDS[0];
+
+  // Compute TPR and FPR for each threshold
+  for (let i = 0; i < ROC_TRESHOLDS_COUNT; ++i) {
+    const confusion = getConfusionMatrix(desirability, prediction, ROC_TRESHOLDS[i]);
+    tpr[i] = (confusion.TP + confusion.FN) > 0 ? confusion.TP / (confusion.TP + confusion.FN) : 0;
+    fpr[i] = (confusion.FP + confusion.TN) > 0 ? confusion.FP / (confusion.FP + confusion.TN) : 0;
+    currentJ = tpr[i] - fpr[i];
+
+    if (currentJ > bestJ) {
+      bestJ = currentJ;
+      bestThreshold = ROC_TRESHOLDS[i];
+    }
+  }
+
+  return {
+    auc: getAuc(tpr, fpr),
+    threshold: bestThreshold,
+    tpr: tpr,
+    fpr: fpr,
+  };
+} // getPmpoEvaluation
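getPmpoEvaluation integrates the ROC curve with the trapezoidal rule (getAuc) and picks the threshold maximizing Youden's J statistic, J = TPR - FPR. A tiny worked sketch with made-up rates:

// Trapezoidal AUC over made-up ROC points, ordered by threshold as in getAuc.
const tpr = new Float32Array([1.0, 0.9, 0.6, 0.0]);
const fpr = new Float32Array([1.0, 0.5, 0.1, 0.0]);

let auc = 0;
for (let i = 1; i < tpr.length; ++i)
  auc += Math.abs(fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]) / 2;
console.log(auc.toFixed(3)); // 0.5*0.95 + 0.4*0.75 + 0.1*0.3 = 0.805

// Youden's J picks the point farthest above the chance diagonal.
let best = 0;
for (let i = 0; i < tpr.length; ++i)
  if (tpr[i] - fpr[i] > tpr[best] - fpr[best]) best = i;
console.log(best, (tpr[best] - fpr[best]).toFixed(1)); // index 2: J = 0.5
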
package/src/softmax-classifier.ts
CHANGED
@@ -1,4 +1,4 @@
-
+/* Softmax classifier (multinomial logistic regression): https://en.wikipedia.org/wiki/Multinomial_logistic_regression */
 
 import * as grok from 'datagrok-api/grok';
 import * as ui from 'datagrok-api/ui';
package/src/tests/anova-tests.ts
CHANGED
@@ -5,7 +5,7 @@ import * as ui from 'datagrok-api/ui';
 import * as DG from 'datagrok-api/dg';
 import {_package} from '../package-test';
 
-import {category, expect, test} from '@datagrok-libraries/
+import {category, expect, test} from '@datagrok-libraries/test/src/test';
 
 import {oneWayAnova, FactorizedData} from '../anova/anova-tools';
 
package/src/tests/classifiers-tests.ts
CHANGED
@@ -5,7 +5,7 @@ import * as ui from 'datagrok-api/ui';
 import * as DG from 'datagrok-api/dg';
 import {_package} from '../package-test';
 
-import {category, expect, test} from '@datagrok-libraries/
+import {category, expect, test} from '@datagrok-libraries/test/src/test';
 
 import {classificationDataset, accuracy} from './utils';
 import {SoftmaxClassifier} from '../softmax-classifier';
package/src/tests/dim-reduction-tests.ts
CHANGED
@@ -5,7 +5,7 @@ import {_package} from '../package-test';
 
 // tests for dimensionality reduction
 
-import {category, expect, test} from '@datagrok-libraries/
+import {category, expect, test} from '@datagrok-libraries/test/src/test';
 import {DimReductionMethods} from '@datagrok-libraries/ml/src/multi-column-dimensionality-reduction/types';
 import {KnownMetrics, NumberMetricsNames, StringMetricsNames} from '@datagrok-libraries/ml/src/typed-metrics';
 import {multiColReduceDimensionality}
package/src/tests/linear-methods-tests.ts
CHANGED
@@ -5,7 +5,7 @@ import * as ui from 'datagrok-api/ui';
 import * as DG from 'datagrok-api/dg';
 import {_package} from '../package-test';
 
-import {category, expect, test} from '@datagrok-libraries/
+import {category, expect, test} from '@datagrok-libraries/test/src/test';
 import {computePCA} from '../eda-tools';
 import {getPlsAnalysis} from '../pls/pls-tools';
 import {PlsModel} from '../pls/pls-ml';
package/src/tests/mis-vals-imputation-tests.ts
CHANGED
@@ -5,7 +5,7 @@ import * as ui from 'datagrok-api/ui';
 import * as DG from 'datagrok-api/dg';
 import {_package} from '../package-test';
 
-import {category, expect, test} from '@datagrok-libraries/
+import {category, expect, test} from '@datagrok-libraries/test/src/test';
 
 import {MetricInfo, DISTANCE_TYPE, impute} from '../missing-values-imputation/knn-imputer';
 import {getFeatureInputSettings} from '../missing-values-imputation/ui';
package/src/tests/pareto-tests.ts
ADDED
@@ -0,0 +1,253 @@
+// Tests for Pareto Front Computations
+// Performance tests for the Pareto optimality algorithm
+
+import * as grok from 'datagrok-api/grok';
+import * as ui from 'datagrok-api/ui';
+import * as DG from 'datagrok-api/dg';
+import {_package} from '../package-test';
+
+import {category, expect, test} from '@datagrok-libraries/test/src/test';
+
+import {getParetoMask} from '../pareto-optimization/pareto-computations';
+import {OPT_TYPE, NumericArray} from '../pareto-optimization/defs';
+
+const TIMEOUT = 5000;
+
+// Test dataset sizes
+const ROWS_COUNT = 1000000;
+const M = 1000000;
+const COLS_COUNT = 2;
+const suffix = M < 1e6 ? 'K' : 'M';
+const DATASET_SIZE_LABEL = `${ROWS_COUNT / M}${suffix} points, ${COLS_COUNT}D`;
+
+/** Generates synthetic numeric data for Pareto front testing */
+function generateSyntheticData(nPoints: number, nDims: number, seed: number = 42): NumericArray[] {
+  const data: NumericArray[] = [];
+
+  // Simple deterministic pseudo-random generator for reproducibility
+  let rng = seed;
+  const random = () => {
+    rng = (rng * 1664525 + 1013904223) % 4294967296;
+    return rng / 4294967296;
+  };
+
+  for (let d = 0; d < nDims; d++) {
+    const column = new Float32Array(nPoints);
+    for (let i = 0; i < nPoints; i++) {
+      // Generate values with some correlation to create realistic Pareto fronts
+      column[i] = random() * 100 + (d * 10);
+    }
+    data.push(column);
+  }
+
+  return data;
+}
+
+/** Generates optimization sense array */
+function generateSense(nDims: number, pattern: 'all-min' | 'all-max' | 'mixed'): OPT_TYPE[] {
+  const sense: OPT_TYPE[] = [];
+
+  for (let d = 0; d < nDims; d++) {
+    if (pattern === 'all-min') {
+      sense.push(OPT_TYPE.MIN);
+    } else if (pattern === 'all-max') {
+      sense.push(OPT_TYPE.MAX);
+    } else {
+      // Mixed: alternate between MIN and MAX
+      sense.push(d % 2 === 0 ? OPT_TYPE.MIN : OPT_TYPE.MAX);
+    }
+  }
+
+  return sense;
+}
+
+/** Generates null indices set for testing missing value handling */
+function generateNullIndices(nPoints: number, nullRatio: number): Set<number> {
+  const nullCount = Math.floor(nPoints * nullRatio);
+  const nullIndices = new Set<number>();
+
+  // Distribute null indices evenly
+  const step = Math.floor(nPoints / nullCount);
+  for (let i = 0; i < nullCount; i++) {
+    nullIndices.add(i * step);
+  }
+
+  return nullIndices;
+}
+
+/** Validates Pareto mask result */
+function validateParetoMask(mask: boolean[], nPoints: number): void {
+  if (mask.length !== nPoints) {
+    throw new Error(`Invalid mask length: expected ${nPoints}, got ${mask.length}`);
+  }
+
+  const optimalCount = mask.filter(x => x).length;
+  if (optimalCount === 0) {
+    throw new Error('No optimal points found');
+  }
+
+  if (optimalCount === nPoints) {
+    grok.shell.warning('All points are optimal - data may be degenerate');
+  }
+}
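
getParetoMask itself is not part of this diff; as a point of reference for what these tests validate, here is a brute-force O(n²) sketch of a Pareto mask under per-column optimization sense (an illustration only, not the package's algorithm):

enum Opt { MIN, MAX } // stand-in for the package's OPT_TYPE

// Point i is Pareto-optimal if no point j dominates it: j is at least as good
// in every column (per its sense) and strictly better in at least one.
function paretoMaskNaive(data: Float32Array[], sense: Opt[], n: number): boolean[] {
  const dominates = (j: number, i: number) => {
    let strict = false;
    for (let d = 0; d < data.length; ++d) {
      const a = data[d][j];
      const b = data[d][i];
      const atLeastAsGood = sense[d] === Opt.MIN ? a <= b : a >= b;
      if (!atLeastAsGood) return false;
      if (a !== b) strict = true;
    }
    return strict;
  };
  const mask = new Array<boolean>(n).fill(true);
  for (let i = 0; i < n; ++i)
    for (let j = 0; j < n && mask[i]; ++j)
      if (j !== i && dominates(j, i)) mask[i] = false;
  return mask;
}

// Minimize both columns: (1,4) and (2,2) are non-dominated; (3,3) is dominated by (2,2).
console.log(paretoMaskNaive(
  [new Float32Array([1, 2, 3]), new Float32Array([4, 2, 3])],
  [Opt.MIN, Opt.MIN], 3)); // [true, true, false]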
+
+category('Pareto optimization', () => {
+
+  test(`Performance: ${DATASET_SIZE_LABEL}`, async () => {
+    let mask: boolean[] | null = null;
+    let error: Error | null = null;
+
+    try {
+      const data = generateSyntheticData(ROWS_COUNT, COLS_COUNT);
+      const sense = generateSense(COLS_COUNT, 'mixed');
+      mask = getParetoMask(data, sense, ROWS_COUNT);
+      validateParetoMask(mask, ROWS_COUNT);
+    } catch (e) {
+      error = e as Error;
+      grok.shell.error(error.message);
+    }
+
+    expect(mask !== null, true, 'Failed to compute Pareto mask');
+    expect(error === null, true, error?.message ?? '');
+  }, {timeout: TIMEOUT});
+
+  // Tests for different optimization patterns
+  test(`Performance: ${DATASET_SIZE_LABEL}, all minimize`, async () => {
+    let mask: boolean[] | null = null;
+    let error: Error | null = null;
+
+    try {
+      const data = generateSyntheticData(ROWS_COUNT, COLS_COUNT);
+      const sense = generateSense(COLS_COUNT, 'all-min');
+      mask = getParetoMask(data, sense, ROWS_COUNT);
+      validateParetoMask(mask, ROWS_COUNT);
+    } catch (e) {
+      error = e as Error;
+      grok.shell.error(error.message);
+    }
+
+    expect(mask !== null, true, 'Failed to compute Pareto mask');
+    expect(error === null, true, error?.message ?? '');
+  }, {timeout: TIMEOUT});
+
+  test(`Performance: ${DATASET_SIZE_LABEL}, all maximize`, async () => {
+    let mask: boolean[] | null = null;
+    let error: Error | null = null;
+
+    try {
+      const data = generateSyntheticData(ROWS_COUNT, COLS_COUNT);
+      const sense = generateSense(COLS_COUNT, 'all-max');
+      mask = getParetoMask(data, sense, ROWS_COUNT);
+      validateParetoMask(mask, ROWS_COUNT);
+    } catch (e) {
+      error = e as Error;
+      grok.shell.error(error.message);
+    }
+
+    expect(mask !== null, true, 'Failed to compute Pareto mask');
+    expect(error === null, true, error?.message ?? '');
+  }, {timeout: TIMEOUT});
+
+  // Tests with missing values
+  test(`Performance: ${DATASET_SIZE_LABEL} with 10% null indices`, async () => {
+    let mask: boolean[] | null = null;
+    let error: Error | null = null;
+
+    try {
+      const data = generateSyntheticData(ROWS_COUNT, COLS_COUNT);
+      const sense = generateSense(COLS_COUNT, 'mixed');
+      const nullIndices = generateNullIndices(ROWS_COUNT, 0.1);
+      mask = getParetoMask(data, sense, ROWS_COUNT, nullIndices);
+      validateParetoMask(mask, ROWS_COUNT);
+    } catch (e) {
+      error = e as Error;
+      grok.shell.error(error.message);
+    }
+
+    expect(mask !== null, true, 'Failed to compute Pareto mask');
+    expect(error === null, true, error?.message ?? '');
+  }, {timeout: TIMEOUT});
+
+  test(`Performance: ${DATASET_SIZE_LABEL} with 25% null indices`, async () => {
+    let mask: boolean[] | null = null;
+    let error: Error | null = null;
+
+    try {
+      const data = generateSyntheticData(ROWS_COUNT, COLS_COUNT);
+      const sense = generateSense(COLS_COUNT, 'mixed');
+      const nullIndices = generateNullIndices(ROWS_COUNT, 0.25);
+      mask = getParetoMask(data, sense, ROWS_COUNT, nullIndices);
+      validateParetoMask(mask, ROWS_COUNT);
+    } catch (e) {
+      error = e as Error;
+      grok.shell.error(error.message);
+    }
+
+    expect(mask !== null, true, 'Failed to compute Pareto mask');
+    expect(error === null, true, error?.message ?? '');
+  }, {timeout: TIMEOUT});
+
+  // Edge cases
+  test('Edge case: Empty dataset', async () => {
+    let mask: boolean[] | null = null;
+    let error: Error | null = null;
+
+    try {
+      const data: NumericArray[] = [new Float32Array(0), new Float32Array(0)];
+      const sense = generateSense(COLS_COUNT, 'mixed');
+      mask = getParetoMask(data, sense, 0);
+    } catch (e) {
+      error = e as Error;
+      grok.shell.error(error.message);
+    }
+
+    expect(mask !== null, true, 'Failed to compute Pareto mask');
+    expect(mask!.length, 0, 'Empty dataset should return empty mask');
+    expect(error === null, true, error?.message ?? '');
+  }, {timeout: TIMEOUT});
+
+  test('Edge case: Single point', async () => {
+    let mask: boolean[] | null = null;
+    let error: Error | null = null;
+
+    try {
+      const data: NumericArray[] = [new Float32Array([1.0]), new Float32Array([2.0])];
+      const sense = generateSense(COLS_COUNT, 'mixed');
+
+      mask = getParetoMask(data, sense, 1);
+    } catch (e) {
+      error = e as Error;
+      grok.shell.error(error.message);
+    }
+
+    expect(mask !== null, true, 'Failed to compute Pareto mask');
+    expect(mask!.length, 1, 'Single point dataset should return mask with one element');
+    expect(mask![0], true, 'Single point should be optimal');
+    expect(error === null, true, error?.message ?? '');
+  }, {timeout: TIMEOUT});
+
+  test('Edge case: All identical points', async () => {
+    let mask: boolean[] | null = null;
+    let error: Error | null = null;
+
+    try {
+      const nPoints = 100;
+      const data: NumericArray[] = [
+        new Float32Array(nPoints).fill(5.0),
+        new Float32Array(nPoints).fill(10.0),
+      ];
+      const sense = generateSense(COLS_COUNT, 'mixed');
+
+      mask = getParetoMask(data, sense, nPoints);
+    } catch (e) {
+      error = e as Error;
+      grok.shell.error(error.message);
+    }
+
+    expect(mask !== null, true, 'Failed to compute Pareto mask');
+    expect(mask!.length, 100, 'Should return mask with correct length');
+    const optimalCount = mask!.filter(x => x).length;
+    expect(optimalCount > 0, true, 'At least some identical points should be optimal');
+    expect(error === null, true, error?.message ?? '');
+  }, {timeout: TIMEOUT});
+});