@datagrok/eda 1.4.12 → 1.4.13
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.eslintrc.json +0 -1
- package/CHANGELOG.md +8 -0
- package/CLAUDE.md +185 -0
- package/css/pmpo.css +9 -0
- package/dist/package-test.js +1 -1
- package/dist/package-test.js.map +1 -1
- package/dist/package.js +1 -1
- package/dist/package.js.map +1 -1
- package/eslintrc.json +0 -1
- package/files/drugs-props-train-scores.csv +664 -0
- package/package.json +7 -3
- package/src/package-api.ts +7 -3
- package/src/package-test.ts +4 -1
- package/src/package.g.ts +21 -9
- package/src/package.ts +32 -23
- package/src/pareto-optimization/pareto-computations.ts +6 -0
- package/src/probabilistic-scoring/data-generator.ts +157 -0
- package/src/probabilistic-scoring/nelder-mead.ts +204 -0
- package/src/probabilistic-scoring/pmpo-defs.ts +112 -2
- package/src/probabilistic-scoring/pmpo-utils.ts +100 -77
- package/src/probabilistic-scoring/prob-scoring.ts +442 -88
- package/src/probabilistic-scoring/stat-tools.ts +140 -5
- package/src/tests/anova-tests.ts +1 -1
- package/src/tests/classifiers-tests.ts +1 -1
- package/src/tests/dim-reduction-tests.ts +1 -1
- package/src/tests/linear-methods-tests.ts +1 -1
- package/src/tests/mis-vals-imputation-tests.ts +1 -1
- package/src/tests/pareto-tests.ts +253 -0
- package/src/tests/pmpo-tests.ts +157 -0
- package/test-console-output-1.log +158 -222
- package/test-record-1.mp4 +0 -0
- package/files/mpo-done.ipynb +0 -2123
|
@@ -8,9 +8,8 @@ import * as DG from 'datagrok-api/dg';
|
|
|
8
8
|
//@ts-ignore: no types
|
|
9
9
|
import * as jStat from 'jstat';
|
|
10
10
|
|
|
11
|
-
import {Cutoff, DescriptorStatistics,
|
|
12
|
-
|
|
13
|
-
const SQRT_2_PI = Math.sqrt(2 * Math.PI);
|
|
11
|
+
import {ConfusionMatrix, Cutoff, DescriptorStatistics, ModelEvaluationResult,
|
|
12
|
+
ROC_TRESHOLDS, ROC_TRESHOLDS_COUNT, SigmoidParams} from './pmpo-defs';
|
|
14
13
|
|
|
15
14
|
/** Splits the dataframe into desired and non-desired tables based on the desirability column */
|
|
16
15
|
export function getDesiredTables(df: DG.DataFrame, desirability: DG.Column) {
|
|
@@ -70,6 +69,8 @@ export function getDescriptorStatistics(des: DG.Column, nonDes: DG.Column): Desc
|
|
|
70
69
|
nonDesAvg: nonDesAvg,
|
|
71
70
|
nonDesStd: nonDesStd,
|
|
72
71
|
nonSesLen: nonDesLen,
|
|
72
|
+
min: Math.min(des.stats.min, nonDes.stats.min),
|
|
73
|
+
max: Math.max(des.stats.max, nonDes.stats.max),
|
|
73
74
|
tstat: t,
|
|
74
75
|
pValue: pValue,
|
|
75
76
|
};
|
|
@@ -163,6 +164,140 @@ export function sigmoidS(x: number, x0: number, b: number, c: number): number {
|
|
|
163
164
|
}
|
|
164
165
|
|
|
165
166
|
/** Gaussian desirability function: an unnormalized bell curve that equals 1 at x = mu.
 *
 * Unlike the normal probability density, no 1 / (sigma * sqrt(2 * pi)) factor is applied,
 * so for finite inputs the result lies in [0, 1].
 * @param x - descriptor value
 * @param mu - center of the bell curve (value of maximal desirability)
 * @param sigma - width of the bell curve
 * @return desirability of x
 */
export function gaussDesirabilityFunc(x: number, mu: number, sigma: number): number {
  // A zero width would evaluate 0 / 0 = NaN at x = mu; treat it as a degenerate spike instead
  if (sigma === 0)
    return (x === mu) ? 1 : 0;

  return Math.exp(-((x - mu)**2) / (2 * sigma**2));
}
|
|
170
|
+
|
|
171
|
+
/** Builds the confusion matrix of numeric scores against boolean ground-truth labels.
 *
 * A row is "predicted positive" when its score is greater than or equal to the threshold.
 * NOTE(review): raw prediction values are compared as-is; confirm how missing values are encoded.
 * @param desirability - ground-truth labels (boolean column)
 * @param prediction - predicted scores (numerical column of the same length)
 * @param threshold - score at or above which a row is predicted positive
 * @return ConfusionMatrix object with TP, TN, FP, FN counts
 * @throws Error on length mismatch or wrong column types
 */
export function getConfusionMatrix(desirability: DG.Column, prediction: DG.Column, threshold: number): ConfusionMatrix {
  if (desirability.length !== prediction.length)
    throw new Error('Failed to compute confusion matrix: columns have different lengths.');

  if (desirability.type !== DG.COLUMN_TYPE.BOOL)
    throw new Error('Failed to compute confusion matrix: desirability column must be boolean.');

  if (!prediction.isNumerical)
    throw new Error('Failed to compute confusion matrix: prediction column must be numerical.');

  // The boolean column's raw data is read as 32-bit words of bits:
  // the label of row r is bit (r & 31) of word (r >>> 5)
  const labelWords = desirability.getRawData();
  const scores = prediction.getRawData();
  const rowCount = prediction.length;

  let truePos = 0;
  let trueNeg = 0;
  let falsePos = 0;
  let falseNeg = 0;

  for (let row = 0; row < rowCount; ++row) {
    const actualPositive = ((labelWords[row >>> 5] >>> (row & 31)) & 1) === 1;
    const predictedPositive = scores[row] >= threshold;

    if (actualPositive) {
      if (predictedPositive)
        ++truePos;
      else
        ++falseNeg;
    } else {
      if (predictedPositive)
        ++falsePos;
      else
        ++trueNeg;
    }
  }

  return {TP: truePos, TN: trueNeg, FP: falsePos, FN: falseNeg};
} // getConfusionMatrix
|
|
230
|
+
|
|
231
|
+
/** Integrates a ROC curve with the trapezoidal rule.
 *
 * Points are visited in array order and the absolute width of each segment is used,
 * so the curve may be listed with either increasing or decreasing FPR.
 * @param tpr - True Positive Rate at each ROC point
 * @param fpr - False Positive Rate at each ROC point (same length as tpr)
 * @return area under the piecewise-linear curve (0 for fewer than two points)
 * @throws Error if the arrays have different lengths
 */
export function getAuc(tpr: Float32Array, fpr: Float32Array): number {
  if (tpr.length !== fpr.length)
    throw new Error('Failed to compute AUC: TPR and FPR arrays have different lengths.');

  let area = 0.0;

  for (let right = 1; right < fpr.length; ++right) {
    const left = right - 1;
    const width = Math.abs(fpr[right] - fpr[left]);
    const meanHeight = (tpr[right] + tpr[left]) / 2.0;
    area += width * meanHeight;
  }

  return area;
} // getAuc
|
|
250
|
+
|
|
251
|
+
/** Converts a numeric prediction column to a boolean one by thresholding.
 *
 * A value maps to true when it is greater than or equal to the threshold.
 * @param numericPrediction - numeric prediction column
 * @param threshold - score at or above which a prediction becomes true
 * @param name - name for the resulting boolean column
 * @return Boolean prediction column of the same length
 * @throws Error if the prediction column is not numerical
 */
export function getBoolPredictionColumn(numericPrediction: DG.Column, threshold: number, name: string): DG.Column {
  // The message previously said "Failed to compute confusion matrix" (copied from getConfusionMatrix)
  if (!numericPrediction.isNumerical)
    throw new Error('Failed to create boolean prediction column: prediction column must be numerical.');

  const predRaw = numericPrediction.getRawData();
  const boolPredData = Array.from({length: numericPrediction.length}, (_, i) => predRaw[i] >= threshold);

  return DG.Column.fromList(DG.COLUMN_TYPE.BOOL, name, boolPredData);
} // getBoolPredictionColumn
|
|
270
|
+
|
|
271
|
+
/** Evaluates pMPO scores against desirability labels over the ROC_TRESHOLDS grid.
 *
 * For every threshold the TPR and FPR are computed. The reported threshold is the first one
 * that maximizes TPR - FPR (Youden's J statistic).
 * @param desirability - desirability column (boolean)
 * @param prediction - prediction column (numeric)
 * @return ModelEvaluationResult object with AUC, optimal threshold, TPR and FPR arrays
 */
export function getPmpoEvaluation(desirability: DG.Column, prediction: DG.Column): ModelEvaluationResult {
  const tpr = new Float32Array(ROC_TRESHOLDS_COUNT);
  const fpr = new Float32Array(ROC_TRESHOLDS_COUNT);

  // A rate with an empty denominator is reported as 0
  const safeRatio = (part: number, whole: number): number => (whole > 0) ? part / whole : 0;

  let bestThreshold = ROC_TRESHOLDS[0];
  let bestJ = -1;

  for (let k = 0; k < ROC_TRESHOLDS_COUNT; ++k) {
    const threshold = ROC_TRESHOLDS[k];
    const {TP, TN, FP, FN} = getConfusionMatrix(desirability, prediction, threshold);

    tpr[k] = safeRatio(TP, TP + FN);
    fpr[k] = safeRatio(FP, FP + TN);

    // J is computed from the stored Float32 values, so it uses the same precision as the ROC arrays
    const youdenJ = tpr[k] - fpr[k];
    if (youdenJ > bestJ) {
      bestJ = youdenJ;
      bestThreshold = threshold;
    }
  }

  return {
    auc: getAuc(tpr, fpr),
    threshold: bestThreshold,
    tpr,
    fpr,
  };
} // getPmpoEvaluation
|
package/src/tests/anova-tests.ts
CHANGED
|
@@ -5,7 +5,7 @@ import * as ui from 'datagrok-api/ui';
|
|
|
5
5
|
import * as DG from 'datagrok-api/dg';
|
|
6
6
|
import {_package} from '../package-test';
|
|
7
7
|
|
|
8
|
-
import {category, expect, test} from '@datagrok-libraries/
|
|
8
|
+
import {category, expect, test} from '@datagrok-libraries/test/src/test';
|
|
9
9
|
|
|
10
10
|
import {oneWayAnova, FactorizedData} from '../anova/anova-tools';
|
|
11
11
|
|
|
@@ -5,7 +5,7 @@ import * as ui from 'datagrok-api/ui';
|
|
|
5
5
|
import * as DG from 'datagrok-api/dg';
|
|
6
6
|
import {_package} from '../package-test';
|
|
7
7
|
|
|
8
|
-
import {category, expect, test} from '@datagrok-libraries/
|
|
8
|
+
import {category, expect, test} from '@datagrok-libraries/test/src/test';
|
|
9
9
|
|
|
10
10
|
import {classificationDataset, accuracy} from './utils';
|
|
11
11
|
import {SoftmaxClassifier} from '../softmax-classifier';
|
|
@@ -5,7 +5,7 @@ import {_package} from '../package-test';
|
|
|
5
5
|
|
|
6
6
|
// tests for dimensionality reduction
|
|
7
7
|
|
|
8
|
-
import {category, expect, test} from '@datagrok-libraries/
|
|
8
|
+
import {category, expect, test} from '@datagrok-libraries/test/src/test';
|
|
9
9
|
import {DimReductionMethods} from '@datagrok-libraries/ml/src/multi-column-dimensionality-reduction/types';
|
|
10
10
|
import {KnownMetrics, NumberMetricsNames, StringMetricsNames} from '@datagrok-libraries/ml/src/typed-metrics';
|
|
11
11
|
import {multiColReduceDimensionality}
|
|
@@ -5,7 +5,7 @@ import * as ui from 'datagrok-api/ui';
|
|
|
5
5
|
import * as DG from 'datagrok-api/dg';
|
|
6
6
|
import {_package} from '../package-test';
|
|
7
7
|
|
|
8
|
-
import {category, expect, test} from '@datagrok-libraries/
|
|
8
|
+
import {category, expect, test} from '@datagrok-libraries/test/src/test';
|
|
9
9
|
import {computePCA} from '../eda-tools';
|
|
10
10
|
import {getPlsAnalysis} from '../pls/pls-tools';
|
|
11
11
|
import {PlsModel} from '../pls/pls-ml';
|
|
@@ -5,7 +5,7 @@ import * as ui from 'datagrok-api/ui';
|
|
|
5
5
|
import * as DG from 'datagrok-api/dg';
|
|
6
6
|
import {_package} from '../package-test';
|
|
7
7
|
|
|
8
|
-
import {category, expect, test} from '@datagrok-libraries/
|
|
8
|
+
import {category, expect, test} from '@datagrok-libraries/test/src/test';
|
|
9
9
|
|
|
10
10
|
import {MetricInfo, DISTANCE_TYPE, impute} from '../missing-values-imputation/knn-imputer';
|
|
11
11
|
import {getFeatureInputSettings} from '../missing-values-imputation/ui';
|
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
// Tests for Pareto Front Computations
|
|
2
|
+
// Performance tests for the Pareto optimality algorithm
|
|
3
|
+
|
|
4
|
+
import * as grok from 'datagrok-api/grok';
|
|
5
|
+
import * as ui from 'datagrok-api/ui';
|
|
6
|
+
import * as DG from 'datagrok-api/dg';
|
|
7
|
+
import {_package} from '../package-test';
|
|
8
|
+
|
|
9
|
+
import {category, expect, test} from '@datagrok-libraries/test/src/test';
|
|
10
|
+
|
|
11
|
+
import {getParetoMask} from '../pareto-optimization/pareto-computations';
|
|
12
|
+
import {OPT_TYPE, NumericArray} from '../pareto-optimization/defs';
|
|
13
|
+
|
|
14
|
+
const TIMEOUT = 5000;

// Test dataset sizes: ROWS_COUNT points in COLS_COUNT dimensions.
// M is the unit used to print ROWS_COUNT in test names ('M' when M is at least a million, 'K' otherwise).
const ROWS_COUNT = 1000000;
const M = 1000000;
const COLS_COUNT = 2;
const suffix = (M >= 1e6) ? 'M' : 'K';
const DATASET_SIZE_LABEL = `${ROWS_COUNT / M}${suffix} points, ${COLS_COUNT}D`;
|
|
22
|
+
|
|
23
|
+
/** Generates deterministic synthetic numeric data for Pareto front testing.
 * @param nPoints - number of points, i.e. the length of every generated column
 * @param nDims - number of columns (objectives) to generate
 * @param seed - initial state of the pseudo-random generator; equal seeds give equal data
 * @return nDims Float32 columns; column d holds values in [10 * d, 10 * d + 100] (up to Float32 rounding)
 */
function generateSyntheticData(nPoints: number, nDims: number, seed: number = 42): NumericArray[] {
  // Linear congruential generator modulo 2^32 for reproducible output;
  // for seeds below 2^32 every intermediate product stays below 2^53, so the arithmetic is exact
  let state = seed;
  const nextUniform = (): number => {
    state = (state * 1664525 + 1013904223) % 4294967296;
    return state / 4294967296;
  };

  const columns: NumericArray[] = [];

  for (let dim = 0; dim < nDims; ++dim) {
    // Each column is shifted by its own offset, so the objectives occupy different value ranges
    const offset = dim * 10;
    const values = new Float32Array(nPoints);

    for (let row = 0; row < nPoints; ++row)
      values[row] = nextUniform() * 100 + offset;

    columns.push(values);
  }

  return columns;
}
|
|
45
|
+
|
|
46
|
+
/** Builds the optimization sense (minimize / maximize) of every dimension.
 * @param nDims - number of dimensions
 * @param pattern - 'all-min', 'all-max', or 'mixed' (even dimensions minimize, odd ones maximize)
 * @return array of nDims optimization types
 */
function generateSense(nDims: number, pattern: 'all-min' | 'all-max' | 'mixed'): OPT_TYPE[] {
  const senseOf = (dim: number): OPT_TYPE => {
    switch (pattern) {
    case 'all-min':
      return OPT_TYPE.MIN;
    case 'all-max':
      return OPT_TYPE.MAX;
    default:
      // Mixed: alternate between MIN and MAX
      return (dim % 2 === 0) ? OPT_TYPE.MIN : OPT_TYPE.MAX;
    }
  };

  const result: OPT_TYPE[] = [];
  for (let dim = 0; dim < nDims; ++dim)
    result.push(senseOf(dim));

  return result;
}
|
|
63
|
+
|
|
64
|
+
/** Generates an evenly spaced set of row indices to be treated as missing values.
 * @param nPoints - total number of points
 * @param nullRatio - fraction of points to mark as missing; values above 1 are clamped to 1
 * @return set of distinct indices in [0, nPoints); empty when nothing is to be marked
 */
function generateNullIndices(nPoints: number, nullRatio: number): Set<number> {
  const nullIndices = new Set<number>();

  // Clamping keeps step >= 1 below; a ratio > 1 previously gave step = 0 and collapsed all indices onto 0
  const nullCount = Math.min(nPoints, Math.floor(nPoints * nullRatio));

  // Nothing to mark (also covers NaN and avoids dividing by zero when computing the step)
  if (!(nullCount > 0))
    return nullIndices;

  // Distribute null indices evenly
  const step = Math.floor(nPoints / nullCount);
  for (let i = 0; i < nullCount; i++)
    nullIndices.add(i * step);

  return nullIndices;
}
|
|
77
|
+
|
|
78
|
+
/** Sanity-checks a Pareto mask: it must have one entry per point and mark at least one point optimal.
 * A mask marking every point optimal only triggers a UI warning, since that hints at degenerate data.
 * @param mask - Pareto optimality flag of every point
 * @param nPoints - expected number of points
 * @throws Error if the mask length differs from nPoints or no point is optimal
 */
function validateParetoMask(mask: boolean[], nPoints: number): void {
  if (mask.length !== nPoints)
    throw new Error(`Invalid mask length: expected ${nPoints}, got ${mask.length}`);

  const optimalCount = mask.reduce((count, isOptimal) => isOptimal ? count + 1 : count, 0);

  if (optimalCount === 0)
    throw new Error('No optimal points found');

  if (optimalCount === nPoints)
    grok.shell.warning('All points are optimal - data may be degenerate');
}
|
|
93
|
+
|
|
94
|
+
/** Outcome of a guarded Pareto mask computation: either a mask or the error that prevented it */
type ParetoRun = {mask: boolean[] | null, error: Error | null};

/** A performance scenario: optimization pattern plus an optional ratio of missing values */
type PerformanceCase = {name: string, pattern: 'all-min' | 'all-max' | 'mixed', nullRatio?: number};

/** Performance scenarios, registered as tests in this order */
const PERFORMANCE_CASES: PerformanceCase[] = [
  {name: `Performance: ${DATASET_SIZE_LABEL}`, pattern: 'mixed'},
  // Different optimization patterns
  {name: `Performance: ${DATASET_SIZE_LABEL}, all minimize`, pattern: 'all-min'},
  {name: `Performance: ${DATASET_SIZE_LABEL}, all maximize`, pattern: 'all-max'},
  // Missing values
  {name: `Performance: ${DATASET_SIZE_LABEL} with 10% null indices`, pattern: 'mixed', nullRatio: 0.1},
  {name: `Performance: ${DATASET_SIZE_LABEL} with 25% null indices`, pattern: 'mixed', nullRatio: 0.25},
];

/** Runs a mask computation, capturing (and showing in the UI) any thrown error instead of propagating it */
function guardedMask(compute: () => boolean[]): ParetoRun {
  try {
    return {mask: compute(), error: null};
  } catch (e) {
    const error = (e instanceof Error) ? e : new Error(String(e));
    grok.shell.error(error.message);
    return {mask: null, error: error};
  }
}

/** Expects a successful run and returns its mask.
 * The captured error is checked first, so a failing test reports the actual cause
 * instead of the generic "Failed to compute Pareto mask" message.
 */
function expectMask(run: ParetoRun): boolean[] {
  expect(run.error === null, true, run.error?.message ?? '');
  expect(run.mask !== null, true, 'Failed to compute Pareto mask');
  return run.mask as boolean[];
}

category('Pareto optimization', () => {
  for (const {name, pattern, nullRatio} of PERFORMANCE_CASES) {
    test(name, async () => {
      expectMask(guardedMask(() => {
        const data = generateSyntheticData(ROWS_COUNT, COLS_COUNT);
        const sense = generateSense(COLS_COUNT, pattern);
        const mask = (nullRatio === undefined) ?
          getParetoMask(data, sense, ROWS_COUNT) :
          getParetoMask(data, sense, ROWS_COUNT, generateNullIndices(ROWS_COUNT, nullRatio));
        validateParetoMask(mask, ROWS_COUNT);
        return mask;
      }));
    }, {timeout: TIMEOUT});
  }

  // Edge cases
  test('Edge case: Empty dataset', async () => {
    const data: NumericArray[] = [new Float32Array(0), new Float32Array(0)];
    const mask = expectMask(guardedMask(() => getParetoMask(data, generateSense(COLS_COUNT, 'mixed'), 0)));
    expect(mask.length, 0, 'Empty dataset should return empty mask');
  }, {timeout: TIMEOUT});

  test('Edge case: Single point', async () => {
    const data: NumericArray[] = [new Float32Array([1.0]), new Float32Array([2.0])];
    const mask = expectMask(guardedMask(() => getParetoMask(data, generateSense(COLS_COUNT, 'mixed'), 1)));
    expect(mask.length, 1, 'Single point dataset should return mask with one element');
    expect(mask[0], true, 'Single point should be optimal');
  }, {timeout: TIMEOUT});

  test('Edge case: All identical points', async () => {
    const nPoints = 100;
    const data: NumericArray[] = [
      new Float32Array(nPoints).fill(5.0),
      new Float32Array(nPoints).fill(10.0),
    ];
    const mask = expectMask(guardedMask(() => getParetoMask(data, generateSense(COLS_COUNT, 'mixed'), nPoints)));
    expect(mask.length, nPoints, 'Should return mask with correct length');
    const optimalCount = mask.filter((x) => x).length;
    expect(optimalCount > 0, true, 'At least some identical points should be optimal');
  }, {timeout: TIMEOUT});
});
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
// Tests for Probabilistic MPO (pMPO)
|
|
2
|
+
// Reference scores are pre-computed and stored in the 'drugs-props-train-scores.csv' file.
|
|
3
|
+
// These scores are computed using the library: https://github.com/Merck/pmpo
|
|
4
|
+
|
|
5
|
+
import * as grok from 'datagrok-api/grok';
|
|
6
|
+
import * as ui from 'datagrok-api/ui';
|
|
7
|
+
import * as DG from 'datagrok-api/dg';
|
|
8
|
+
import {_package} from '../package-test';
|
|
9
|
+
|
|
10
|
+
import {category, expect, test} from '@datagrok-libraries/test/src/test';
|
|
11
|
+
|
|
12
|
+
import {Pmpo} from '../probabilistic-scoring/prob-scoring';
|
|
13
|
+
import {P_VAL_TRES_DEFAULT, Q_CUTOFF_DEFAULT, R2_DEFAULT, SCORES_PATH,
|
|
14
|
+
SOURCE_PATH} from '../probabilistic-scoring/pmpo-defs';
|
|
15
|
+
import {getSynteticPmpoData} from '../probabilistic-scoring/data-generator';
|
|
16
|
+
|
|
17
|
+
const TIMEOUT = 10000;

/** Maximum allowed absolute deviation between computed and reference pMPO scores */
const MAD_THRESH = 1E-6;

// Column names of the training data
const DESIRABILITY_COL_NAME = 'CNS';
const DESCRIPTOR_NAMES = [
  'TPSA', 'TPSA_S', 'HBA', 'HBD', 'MW', 'nAtoms',
  'cLogD_ACD_v15', 'mapKa', 'cLogP_Biobyte', 'mbpKa', 'cLogP_ACD_v15', 'ALogP98',
];
const SCORES_NAME = 'Score';
const DRUG = 'Drug';

// Desirability modes; each name is also the column holding the reference scores for that mode
const SIGMOIDAL = 'Sigmoidal';
const GAUSSIAN = 'Gaussian';
const PMPO_MODES = [SIGMOIDAL, GAUSSIAN];

// Size of the synthetic dataset used by the performance test
const SAMPLES_K = 100;
const SAMPLES_COUNT = SAMPLES_K * 1000;
|
|
32
|
+
|
|
33
|
+
/** Computes the maximum absolute deviation between two sets of pMPO scores, matching rows by drug name.
 * @param sourceDrugCol - drug names of the computed scores
 * @param sourceScores - computed scores, row-aligned with sourceDrugCol
 * @param referenceDrugCol - drug names of the reference scores
 * @param referenceScores - reference scores, row-aligned with referenceDrugCol
 * @return max over source drugs of |source score - reference score|; 0 when there are no source drugs
 * @throws Error if a source drug is absent from the reference data
 */
function getScoreMaxDeviation(sourceDrugCol: DG.Column, sourceScores: DG.Column,
  referenceDrugCol: DG.Column, referenceScores: DG.Column): number {
  // Index reference rows by drug name once instead of a linear indexOf per source drug (O(n + m) vs O(n * m)).
  // The first occurrence wins, matching the previous indexOf semantics for duplicated names.
  const referenceRowByDrug = new Map<unknown, number>();
  referenceDrugCol.toList().forEach((name, idx) => {
    if (!referenceRowByDrug.has(name))
      referenceRowByDrug.set(name, idx);
  });

  const sourceScoresRaw = sourceScores.getRawData();
  const referenceScoresRaw = referenceScores.getRawData();
  let mad = 0;

  sourceDrugCol.toList().forEach((name, idx) => {
    const refIdx = referenceRowByDrug.get(name);

    if (refIdx === undefined)
      throw new Error(`Failed to compare pMPO scores: the "${name}" drug is missing in the reference data.`);

    mad = Math.max(mad, Math.abs(sourceScoresRaw[idx] - referenceScoresRaw[refIdx]));
  });

  return mad;
} // getScoreMaxDeviation
|
|
55
|
+
|
|
56
|
+
/** Trains a pMPO model with the default settings and scores the same data frame with it.
 * @param df - data frame used both for training and for scoring
 * @param desirability - desirability column
 * @param descriptors - descriptor columns
 * @param useSigmoid - true for the sigmoidal mode, false for the Gaussian one
 * @return the predicted scores column (created by Pmpo.predict with SCORES_NAME)
 */
function fitAndScore(df: DG.DataFrame, desirability: DG.Column, descriptors: DG.Column[],
  useSigmoid: boolean): DG.Column {
  const trainRes = Pmpo.fit(
    df,
    DG.DataFrame.fromColumns(descriptors).columns,
    desirability,
    P_VAL_TRES_DEFAULT,
    R2_DEFAULT,
    Q_CUTOFF_DEFAULT,
  );

  return Pmpo.predict(df, trainRes.params, useSigmoid, SCORES_NAME);
}

/** Converts a caught value into an Error and shows its message in the UI */
function toReportedError(e: unknown): Error {
  const error = (e instanceof Error) ? e : new Error(String(e));
  grok.shell.error(error.message);
  return error;
}

category('Probabilistic MPO', () => {
  // Correctness tests: compare pMPO scores with reference scores
  PMPO_MODES.forEach((refScoreName) => {
    const useSigmoid = (refScoreName === SIGMOIDAL);

    test('Correctness: ' + refScoreName, async () => {
      let sourceDf: DG.DataFrame | null = null;
      let referenceDf: DG.DataFrame | null = null;
      let desirability: DG.Column | null = null;
      let descriptors: DG.Column[] = [];
      let sourceDrugCol: DG.Column | null = null;
      let referenceDrugCol: DG.Column | null = null;
      let referencePrediction: DG.Column | null = null;
      let mad: number | null = null;
      let error: Error | null = null;

      try {
        // Load data
        sourceDf = await grok.dapi.files.readCsv(SOURCE_PATH);
        referenceDf = await grok.dapi.files.readCsv(SCORES_PATH);

        // Extract training items
        desirability = sourceDf.col(DESIRABILITY_COL_NAME);
        descriptors = sourceDf.columns.byNames(DESCRIPTOR_NAMES);

        if (desirability == null)
          throw new Error('Inconsistent source data: no column ' + DESIRABILITY_COL_NAME);

        // Train pMPO model and apply it
        const prediction = fitAndScore(sourceDf, desirability, descriptors, useSigmoid);

        // Compare with reference scores
        sourceDrugCol = sourceDf.col(DRUG);
        referenceDrugCol = referenceDf.col(DRUG);
        referencePrediction = referenceDf.col(refScoreName);

        if (sourceDrugCol == null || referenceDrugCol == null || referencePrediction == null)
          throw new Error('Inconsistent data: missing columns required to compare pMPO scores');

        mad = getScoreMaxDeviation(sourceDrugCol, prediction, referenceDrugCol, referencePrediction);
      } catch (e) {
        error = toReportedError(e);
      }

      // Specific data checks come first; the captured error is checked before the generic comparison failure
      expect(sourceDf !== null, true, 'Failed to load the source data: ' + SOURCE_PATH);
      expect(referenceDf !== null, true, 'Failed to load the scores data: ' + SCORES_PATH);
      expect(desirability !== null, true, 'Inconsistent source data: no column ' + DESIRABILITY_COL_NAME);
      expect(descriptors.length, DESCRIPTOR_NAMES.length, 'Inconsistent source data: not enough columns');
      expect(sourceDrugCol !== null, true, 'Inconsistent source data: no column ' + DRUG);
      expect(referenceDrugCol !== null, true, 'Inconsistent reference data: no column ' + DRUG);
      expect(referencePrediction !== null, true, 'Inconsistent reference data: no column ' + refScoreName);
      expect(error === null, true, error?.message ?? '');
      expect(mad !== null, true, 'Failed to compare pMPO scores with the reference data');
      expect(mad !== null && mad < MAD_THRESH, true,
        `Max absolute deviation of pMPO scores (${mad}) exceeds the threshold (${MAD_THRESH})`);
    }, {timeout: TIMEOUT});
  });

  // Performance tests: measure time of pMPO training
  test('Performance: ' + SAMPLES_K + 'K drugs, ' + DESCRIPTOR_NAMES.length + ' descriptors', async () => {
    let sourceDf: DG.DataFrame | null = null;
    let desirability: DG.Column | null = null;
    let descriptors: DG.Column[] = [];
    let error: Error | null = null;

    try {
      // Generate synthetic data
      sourceDf = await getSynteticPmpoData(SAMPLES_COUNT);

      // Extract training items
      desirability = sourceDf.col(DESIRABILITY_COL_NAME);
      descriptors = sourceDf.columns.byNames(DESCRIPTOR_NAMES);

      if (desirability == null)
        throw new Error('Inconsistent source data: no column ' + DESIRABILITY_COL_NAME);

      // Train pMPO model and apply it
      fitAndScore(sourceDf, desirability, descriptors, true);
    } catch (e) {
      error = toReportedError(e);
    }

    expect(sourceDf !== null, true, 'Failed to generate synthetic pMPO data');
    expect(desirability !== null, true, 'Inconsistent source data: no column ' + DESIRABILITY_COL_NAME);
    expect(descriptors.length, DESCRIPTOR_NAMES.length, 'Inconsistent source data: not enough columns');
    // Without this check, a failure in training or scoring would be swallowed and the test would pass
    expect(error === null, true, error?.message ?? '');
  }, {timeout: TIMEOUT});
});
|