@datagrok/eda 1.4.12 → 1.4.13
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.eslintrc.json +0 -1
- package/CHANGELOG.md +8 -0
- package/CLAUDE.md +185 -0
- package/css/pmpo.css +9 -0
- package/dist/package-test.js +1 -1
- package/dist/package-test.js.map +1 -1
- package/dist/package.js +1 -1
- package/dist/package.js.map +1 -1
- package/eslintrc.json +0 -1
- package/files/drugs-props-train-scores.csv +664 -0
- package/package.json +7 -3
- package/src/package-api.ts +7 -3
- package/src/package-test.ts +4 -1
- package/src/package.g.ts +21 -9
- package/src/package.ts +32 -23
- package/src/pareto-optimization/pareto-computations.ts +6 -0
- package/src/probabilistic-scoring/data-generator.ts +157 -0
- package/src/probabilistic-scoring/nelder-mead.ts +204 -0
- package/src/probabilistic-scoring/pmpo-defs.ts +112 -2
- package/src/probabilistic-scoring/pmpo-utils.ts +100 -77
- package/src/probabilistic-scoring/prob-scoring.ts +442 -88
- package/src/probabilistic-scoring/stat-tools.ts +140 -5
- package/src/tests/anova-tests.ts +1 -1
- package/src/tests/classifiers-tests.ts +1 -1
- package/src/tests/dim-reduction-tests.ts +1 -1
- package/src/tests/linear-methods-tests.ts +1 -1
- package/src/tests/mis-vals-imputation-tests.ts +1 -1
- package/src/tests/pareto-tests.ts +253 -0
- package/src/tests/pmpo-tests.ts +157 -0
- package/test-console-output-1.log +158 -222
- package/test-record-1.mp4 +0 -0
- package/files/mpo-done.ipynb +0 -2123
|
@@ -10,15 +10,37 @@ import {MpoProfileEditor} from '@datagrok-libraries/statistics/src/mpo/mpo-profi
|
|
|
10
10
|
|
|
11
11
|
import '../../css/pmpo.css';
|
|
12
12
|
|
|
13
|
-
import {getDesiredTables, getDescriptorStatistics,
|
|
13
|
+
import {getDesiredTables, getDescriptorStatistics, getBoolPredictionColumn, getPmpoEvaluation} from './stat-tools';
|
|
14
14
|
import {MIN_SAMPLES_COUNT, PMPO_NON_APPLICABLE, DescriptorStatistics, P_VAL_TRES_MIN, DESCR_TITLE,
|
|
15
15
|
R2_MIN, Q_CUTOFF_MIN, PmpoParams, SCORES_TITLE, DESCR_TABLE_TITLE, PMPO_COMPUTE_FAILED, SELECTED_TITLE,
|
|
16
|
-
P_VAL, DESIRABILITY_COL_NAME, STAT_GRID_HEIGHT, DESIRABILITY_COLUMN_WIDTH, WEIGHT_TITLE
|
|
16
|
+
P_VAL, DESIRABILITY_COL_NAME, STAT_GRID_HEIGHT, DESIRABILITY_COLUMN_WIDTH, WEIGHT_TITLE,
|
|
17
|
+
P_VAL_TRES_DEFAULT, R2_DEFAULT, Q_CUTOFF_DEFAULT, USE_SIGMOID_DEFAULT, ROC_TRESHOLDS,
|
|
18
|
+
FPR_TITLE, TPR_TITLE, COLORS, THRESHOLD, AUTO_TUNE_MAX_APPLICABLE_ROWS, DEFAULT_OPTIMIZATION_SETTINGS,
|
|
19
|
+
P_VAL_TRES_MAX, R2_MAX, Q_CUTOFF_MAX, OptimalPoint, LOW_PARAMS_BOUNDS, HIGH_PARAMS_BOUNDS, FORMAT} from './pmpo-defs';
|
|
17
20
|
import {addSelectedDescriptorsCol, getDescriptorStatisticsTable, getFilteredByPvalue, getFilteredByCorrelations,
|
|
18
21
|
getModelParams, getDescrTooltip, saveModel, getScoreTooltip, getDesirabilityProfileJson, getCorrelationTriples,
|
|
19
|
-
addCorrelationColumns, setPvalColumnColorCoding, setCorrColumnColorCoding} from './pmpo-utils';
|
|
22
|
+
addCorrelationColumns, setPvalColumnColorCoding, setCorrColumnColorCoding, PmpoError} from './pmpo-utils';
|
|
20
23
|
import {getOutputPalette} from '../pareto-optimization/utils';
|
|
21
24
|
import {OPT_TYPE} from '../pareto-optimization/defs';
|
|
25
|
+
import {optimizeNM} from './nelder-mead';
|
|
26
|
+
|
|
27
|
+
/** Result of `Pmpo.fit`: the trained parameters plus the intermediate descriptor-selection results */
export type PmpoTrainingResult = {
  /** pMPO parameters keyed by descriptor (column) name */
  params: Map<string, PmpoParams>;
  /** Descriptor statistics table, decorated with selection and correlation columns */
  descrStatsTable: DG.DataFrame;
  /** Names of the descriptors that passed the p-value filter */
  selectedByPvalue: string[];
  /** Names of the descriptors retained after the correlation filter */
  selectedByCorr: string[];
};

/** Type for pMPO training controls */
export type Controls = {
  /** Container element wrapping the training input form */
  form: HTMLElement;
  /** Button that saves the trained model as a platform file */
  saveBtn: HTMLButtonElement;
};

/** Type for pMPO elements */
export type PmpoAppItems = {
  /** Grid with the descriptor statistics */
  statsGrid: DG.Viewer;
  /** ROC curve of the model scores */
  rocCurve: DG.Viewer;
  /** Confusion matrix of actual vs predicted desirability */
  confusionMatrix: DG.Viewer;
  /** Training input form and save button */
  controls: Controls;
};
|
|
22
44
|
|
|
23
45
|
/** Class implementing probabilistic MPO (pMPO) model training and prediction */
|
|
24
46
|
export class Pmpo {
|
|
@@ -131,23 +153,93 @@ export class Pmpo {
|
|
|
131
153
|
return true;
|
|
132
154
|
} // isTableValid
|
|
133
155
|
|
|
156
|
+
/**
 * Trains a pMPO model on `df` and returns the fitted parameters together with the
 * intermediate descriptor-selection results.
 *
 * Pipeline:
 * 1. optional applicability check;
 * 2. per-descriptor statistics of the desired vs non-desired rows;
 * 3. p-value filtering;
 * 4. correlation filtering;
 * 5. parameter estimation with the q-cutoff.
 *
 * @param df - training table
 * @param descriptors - descriptor columns; their names are looked up in the desired/non-desired parts of `df`
 * @param desirability - column used to split `df` into desired and non-desired rows
 * @param pValTresh - p-value threshold for descriptor selection
 * @param r2Tresh - squared-correlation threshold for descriptor selection
 * @param qCutoff - q-cutoff passed to the parameter estimation
 * @param toCheckApplicability - when true, inputs are validated with `Pmpo.isApplicable` first
 * @returns model parameters, the descriptor statistics table, and the names selected at each filtering step
 * @throws Error if the applicability check is requested and fails
 * @throws PmpoError if no descriptor passes the p-value filter
 */
static fit(df: DG.DataFrame, descriptors: DG.ColumnList, desirability: DG.Column,
  pValTresh: number, r2Tresh: number, qCutoff: number, toCheckApplicability: boolean = true): PmpoTrainingResult {
  if (toCheckApplicability && !Pmpo.isApplicable(descriptors, desirability, pValTresh, r2Tresh, qCutoff))
    throw new Error('Failed to train pMPO model: the method is not applicable to the inputs');

  const names = descriptors.names();
  const {desired, nonDesired} = getDesiredTables(df, desirability);

  // Per-descriptor statistics of the desired vs non-desired subsets
  const descrStats = new Map<string, DescriptorStatistics>();
  for (const name of names)
    descrStats.set(name, getDescriptorStatistics(desired.col(name)!, nonDesired.col(name)!));

  const descrStatsTable = getDescriptorStatisticsTable(descrStats);
  setPvalColumnColorCoding(descrStatsTable, pValTresh);

  // Step 1: keep statistically significant descriptors only
  const selectedByPvalue = getFilteredByPvalue(descrStatsTable, pValTresh);
  if (selectedByPvalue.length === 0)
    throw new PmpoError('Cannot train pMPO model: all descriptors have high p-values (not significant).');

  // Step 2: among highly correlated descriptors, keep the one with the lower p-value
  const correlationTriples = getCorrelationTriples(descriptors, selectedByPvalue);
  const selectedByCorr = getFilteredByCorrelations(descriptors, selectedByPvalue, descrStats, r2Tresh, correlationTriples);

  // Decorate the statistics table: selection flags, correlation columns and their color coding
  addSelectedDescriptorsCol(descrStatsTable, selectedByCorr);
  addCorrelationColumns(descrStatsTable, names, correlationTriples, selectedByCorr);
  setCorrColumnColorCoding(descrStatsTable, names, r2Tresh);

  // Step 3: estimate the model parameters from the retained descriptors
  const params = getModelParams(desired, nonDesired, selectedByCorr, qCutoff);

  return {params, descrStatsTable, selectedByPvalue, selectedByCorr};
} // fit
|
|
208
|
+
|
|
134
209
|
/** Predicts pMPO scores for the given data frame using provided pMPO parameters */
|
|
135
|
-
static predict(df: DG.DataFrame, params: Map<string, PmpoParams>, predictionName: string): DG.Column {
|
|
210
|
+
static predict(df: DG.DataFrame, params: Map<string, PmpoParams>, useSigmoid: boolean, predictionName: string): DG.Column {
|
|
136
211
|
const count = df.rowCount;
|
|
137
212
|
const scores = new Float64Array(count).fill(0);
|
|
138
|
-
let x = 0;
|
|
139
213
|
|
|
140
214
|
// Compute pMPO scores (see https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/
|
|
141
215
|
params.forEach((param, name) => {
|
|
142
216
|
const col = df.col(name);
|
|
217
|
+
const b = param.b;
|
|
218
|
+
const c = param.c;
|
|
219
|
+
const x0 = param.cutoff;
|
|
220
|
+
let weight = param.weight;
|
|
221
|
+
const avg = param.desAvg;
|
|
222
|
+
const std = param.desStd;
|
|
223
|
+
const frac = 1.0 / (2 * std**2);
|
|
143
224
|
|
|
144
225
|
if (col == null)
|
|
145
|
-
throw new Error(`
|
|
226
|
+
throw new Error(`Failed to apply pMPO: inconsistent data, no column "${name}" in the table "${df.name}"`);
|
|
146
227
|
|
|
147
228
|
const vals = col.getRawData();
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
229
|
+
|
|
230
|
+
if (useSigmoid) {
|
|
231
|
+
if (c > 0) {
|
|
232
|
+
for (let i = 0; i < count; ++i)
|
|
233
|
+
scores[i] += weight * Math.exp(-((vals[i] - avg)**2) * frac) / (1.0 + b * (c ** (-(vals[i] - x0))));
|
|
234
|
+
} else {
|
|
235
|
+
weight = weight / (1.0 + b);
|
|
236
|
+
|
|
237
|
+
for (let i = 0; i < count; ++i)
|
|
238
|
+
scores[i] += weight * Math.exp(-((vals[i] - avg)**2) * frac);
|
|
239
|
+
}
|
|
240
|
+
} else {
|
|
241
|
+
for (let i = 0; i < count; ++i)
|
|
242
|
+
scores[i] += weight * Math.exp(-((vals[i] - avg)**2) * frac);
|
|
151
243
|
}
|
|
152
244
|
});
|
|
153
245
|
|
|
@@ -161,16 +253,34 @@ export class Pmpo {
|
|
|
161
253
|
private boolCols: DG.Column[];
|
|
162
254
|
private numericCols: DG.Column[];
|
|
163
255
|
|
|
164
|
-
private initTable =
|
|
256
|
+
private initTable = DG.DataFrame.create();
|
|
165
257
|
|
|
166
258
|
private statGrid = DG.Viewer.grid(this.initTable, {showTitle: true, title: DESCR_TABLE_TITLE});
|
|
167
259
|
|
|
168
260
|
private predictionName = SCORES_TITLE;
|
|
261
|
+
private boolPredictionName = '';
|
|
169
262
|
|
|
170
263
|
private desirabilityProfileRoots = new Map<string, HTMLElement>();
|
|
171
|
-
|
|
264
|
+
|
|
265
|
+
/** Scatter plot used as the ROC curve; `updateRocCurve` later replaces its data frame and axis options. */
private rocCurve = DG.Viewer.scatterPlot(this.initTable,
  {showTitle: true, showSizeSelector: false, showColorSelector: false});

/**
 * Confusion matrix of actual vs predicted desirability labels.
 * The x/y column names set here are placeholders that `updateConfusionMatrix` overrides.
 * NOTE(review): `initTable` is an empty frame, so 'control' does not exist in it; confirm the viewer tolerates this.
 */
private confusionMatrix = DG.Viewer.fromType('Confusion matrix', this.initTable, {
  // axes (placeholders)
  xColumnName: 'control',
  yColumnName: 'control',
  // header
  showTitle: true,
  title: 'Confusion Matrix',
  // footer description (updated with the chosen threshold later)
  descriptionPosition: 'Bottom',
  description: 'Confusion matrix for the predicted vs actual desirability labels.',
  descriptionVisibilityMode: 'Always',
});
|
|
280
|
+
|
|
281
|
+
constructor(df: DG.DataFrame, view?: DG.TableView) {
|
|
172
282
|
this.table = df;
|
|
173
|
-
this.view = grok.shell.tableView(df.name) ?? grok.shell.addTableView(df);
|
|
283
|
+
this.view = view ?? (grok.shell.tableView(df.name) ?? grok.shell.addTableView(df));
|
|
174
284
|
this.boolCols = this.getBoolCols();
|
|
175
285
|
this.numericCols = this.getValidNumericCols();
|
|
176
286
|
this.predictionName = df.columns.getUnusedName(SCORES_TITLE);
|
|
@@ -294,14 +404,12 @@ export class Pmpo {
|
|
|
294
404
|
} else {
|
|
295
405
|
const descriptor = grid.cell(DESCR_TITLE, cell.gridRow).value;
|
|
296
406
|
|
|
297
|
-
if (
|
|
298
|
-
const startText = (colName === WEIGHT_TITLE) ? 'No weight' : 'No chart shown';
|
|
299
|
-
|
|
407
|
+
if (colName === WEIGHT_TITLE) {
|
|
300
408
|
if (!this.desirabilityProfileRoots.has(descriptor)) {
|
|
301
409
|
if (selectedByPvalue.includes(descriptor))
|
|
302
|
-
ui.tooltip.show(
|
|
410
|
+
ui.tooltip.show(`No weight: <b>${descriptor}</b> is excluded due to a high correlation with other descriptors.`, x, y);
|
|
303
411
|
else
|
|
304
|
-
ui.tooltip.show(
|
|
412
|
+
ui.tooltip.show(`No weight: <b>${descriptor}</b> is excluded due to a high p-value.`, x, y);
|
|
305
413
|
|
|
306
414
|
return true;
|
|
307
415
|
}
|
|
@@ -341,7 +449,23 @@ export class Pmpo {
|
|
|
341
449
|
return;
|
|
342
450
|
|
|
343
451
|
const descriptor = grid.cell(DESCR_TITLE, cell.gridRow).value;
|
|
344
|
-
|
|
452
|
+
const element = this.desirabilityProfileRoots.get(descriptor);
|
|
453
|
+
|
|
454
|
+
if (element != null)
|
|
455
|
+
cell.element = element;
|
|
456
|
+
else {
|
|
457
|
+
const selected = selectedByPvalue.includes(descriptor);
|
|
458
|
+
const text = selected ? 'highly correlated with other descriptors' : 'statistically insignificant';
|
|
459
|
+
const tooltipMsg = selected ?
|
|
460
|
+
`No chart shown: <b>${descriptor}</b> is excluded due to a high correlation with other descriptors.` :
|
|
461
|
+
`No chart shown: <b>${descriptor}</b> is excluded due to a high p-value.`;
|
|
462
|
+
|
|
463
|
+
const divWithDescription = ui.divText(text);
|
|
464
|
+
divWithDescription.style.color = COLORS.SKIPPED;
|
|
465
|
+
divWithDescription.classList.add('eda-pmpo-centered-text');
|
|
466
|
+
ui.tooltip.bind(divWithDescription, tooltipMsg);
|
|
467
|
+
cell.element = divWithDescription;
|
|
468
|
+
}
|
|
345
469
|
}); // grid.onCellPrepare
|
|
346
470
|
} // updateGrid
|
|
347
471
|
|
|
@@ -372,7 +496,7 @@ export class Pmpo {
|
|
|
372
496
|
} // updateGrid
|
|
373
497
|
|
|
374
498
|
/** Updates the desirability profile data */
|
|
375
|
-
private updateDesirabilityProfileData(descrStatsTable: DG.DataFrame): void {
|
|
499
|
+
private updateDesirabilityProfileData(descrStatsTable: DG.DataFrame, useSigmoidalCorrection: boolean): void {
|
|
376
500
|
if (this.params == null)
|
|
377
501
|
return;
|
|
378
502
|
|
|
@@ -380,7 +504,7 @@ export class Pmpo {
|
|
|
380
504
|
this.desirabilityProfileRoots.forEach((root) => root.remove());
|
|
381
505
|
this.desirabilityProfileRoots.clear();
|
|
382
506
|
|
|
383
|
-
const desirabilityProfile = getDesirabilityProfileJson(this.params, '', '');
|
|
507
|
+
const desirabilityProfile = getDesirabilityProfileJson(this.params, useSigmoidalCorrection, '', '', true);
|
|
384
508
|
|
|
385
509
|
// Set weights
|
|
386
510
|
const descrNames = descrStatsTable.col(DESCR_TITLE)!.toList();
|
|
@@ -416,52 +540,69 @@ export class Pmpo {
|
|
|
416
540
|
});
|
|
417
541
|
} // updateDesirabilityProfileData
|
|
418
542
|
|
|
419
|
-
/**
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
const
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
543
|
+
/**
 * Recomputes the ROC curve of `prediction` against the `desirability` labels and shows it in the ROC viewer.
 *
 * The evaluation is rendered as a table of (threshold, FPR, TPR) rows at the `ROC_TRESHOLDS` thresholds.
 * A dashed TPR = FPR diagonal marks a non-informative classifier, and the viewer title reports the AUC.
 *
 * @param desirability - actual desirability labels
 * @param prediction - pMPO scores
 * @return Best threshold according to Youden's J statistic, as reported by `getPmpoEvaluation`
 */
private updateRocCurve(desirability: DG.Column, prediction: DG.Column): number {
  const rocStats = getPmpoEvaluation(desirability, prediction);

  // One row per threshold: the threshold and the resulting false/true positive rates
  const thresholdCol = DG.Column.fromFloat32Array(THRESHOLD, ROC_TRESHOLDS);
  const fprCol = DG.Column.fromFloat32Array(FPR_TITLE, rocStats.fpr);
  const tprCol = DG.Column.fromFloat32Array(TPR_TITLE, rocStats.tpr);
  const rocDf = DG.DataFrame.fromColumns([thresholdCol, fprCol, tprCol]);

  // Dashed diagonal TPR = FPR on [0, 1]: reference line of a random classifier
  rocDf.meta.formulaLines.addLine({
    title: 'Non-informative baseline',
    formula: `\${${TPR_TITLE}} = \${${FPR_TITLE}}`,
    width: 1,
    style: 'dashed',
    min: 0,
    max: 1,
  });

  this.rocCurve.dataFrame = rocDf;

  const aucText = rocStats.auc.toFixed(3);
  this.rocCurve.setOptions({
    xColumnName: FPR_TITLE,
    yColumnName: TPR_TITLE,
    linesOrderColumnName: FPR_TITLE, // line points are ordered by FPR
    linesWidth: 5,
    markerType: 'dot',
    title: `ROC Curve (AUC = ${aucText})`,
  });

  return rocStats.threshold;
} // updateRocCurve
|
|
577
|
+
|
|
578
|
+
/**
 * Points the confusion-matrix viewer at `df`. The actual labels go on the x axis and the
 * thresholded prediction column (named by `boolPredictionName`) goes on the y axis.
 *
 * @param df - table holding both the actual and the predicted label columns
 * @param desColName - name of the actual desirability column
 * @param bestThreshold - score threshold used for the binary prediction; shown in the viewer description
 */
private updateConfusionMatrix(df: DG.DataFrame, desColName: string, bestThreshold: number): void {
  const viewer = this.confusionMatrix;
  viewer.dataFrame = df;

  const thresholdText = bestThreshold.toFixed(3);
  viewer.setOptions({
    xColumnName: desColName,
    yColumnName: this.boolPredictionName,
    description: `Threshold: ${thresholdText} (optimized via Youden's J)`,
    title: `${desColName} Confusion Matrix`,
  });
} // updateConfusionMatrix
|
|
452
588
|
|
|
453
|
-
|
|
454
|
-
|
|
589
|
+
/** Fits the pMPO model to the given data and updates the viewers accordingly */
|
|
590
|
+
private fitAndUpdateViewers(df: DG.DataFrame, descriptors: DG.ColumnList, desirability: DG.Column,
|
|
591
|
+
pValTresh: number, r2Tresh: number, qCutoff: number, useSigmoid: boolean): void {
|
|
592
|
+
const trainResult = Pmpo.fit(df, descriptors, desirability, pValTresh, r2Tresh, qCutoff);
|
|
593
|
+
this.params = trainResult.params;
|
|
594
|
+
const descrStatsTable = trainResult.descrStatsTable;
|
|
595
|
+
const selectedByPvalue = trainResult.selectedByPvalue;
|
|
596
|
+
const selectedByCorr = trainResult.selectedByCorr;
|
|
455
597
|
|
|
456
|
-
|
|
457
|
-
this.params = getModelParams(desired, nonDesired, selectedByCorr, qCutoff);
|
|
598
|
+
const descriptorNames = descriptors.names();
|
|
458
599
|
|
|
459
|
-
|
|
460
|
-
const prediction = Pmpo.predict(df, this.params, this.predictionName);
|
|
600
|
+
const prediction = Pmpo.predict(df, this.params, useSigmoid, this.predictionName);
|
|
461
601
|
|
|
462
602
|
// Mark predictions with a color
|
|
463
603
|
prediction.colors.setLinear(getOutputPalette(OPT_TYPE.MAX), {min: prediction.stats.min, max: prediction.stats.max});
|
|
464
604
|
|
|
605
|
+
// Remove existing prediction column and add the new one
|
|
465
606
|
df.columns.remove(this.predictionName);
|
|
466
607
|
df.columns.add(prediction);
|
|
467
608
|
|
|
@@ -469,10 +610,25 @@ export class Pmpo {
|
|
|
469
610
|
this.updateGrid();
|
|
470
611
|
|
|
471
612
|
// Update desirability profile roots map
|
|
472
|
-
this.updateDesirabilityProfileData(descrStatsTable);
|
|
613
|
+
this.updateDesirabilityProfileData(descrStatsTable, useSigmoid);
|
|
473
614
|
|
|
474
615
|
// Update statistics grid
|
|
475
616
|
this.updateStatisticsGrid(descrStatsTable, descriptorNames, selectedByPvalue, selectedByCorr);
|
|
617
|
+
|
|
618
|
+
// Update ROC curve
|
|
619
|
+
const bestThreshold = this.updateRocCurve(desirability, prediction);
|
|
620
|
+
|
|
621
|
+
// Update desirability prediction column
|
|
622
|
+
const desColName = desirability.name;
|
|
623
|
+
df.columns.remove(this.boolPredictionName);
|
|
624
|
+
this.boolPredictionName = df.columns.getUnusedName(desColName + '(predicted)');
|
|
625
|
+
const boolPrediction = getBoolPredictionColumn(prediction, bestThreshold, this.boolPredictionName);
|
|
626
|
+
df.columns.add(boolPrediction);
|
|
627
|
+
|
|
628
|
+
// Update confusion matrix
|
|
629
|
+
this.updateConfusionMatrix(df, desColName, bestThreshold);
|
|
630
|
+
|
|
631
|
+
this.view.dataFrame.selection.setAll(false, true);
|
|
476
632
|
} // fitAndUpdateViewers
|
|
477
633
|
|
|
478
634
|
/** Runs the pMPO model training application */
|
|
@@ -480,7 +636,7 @@ export class Pmpo {
|
|
|
480
636
|
const dockMng = this.view.dockManager;
|
|
481
637
|
|
|
482
638
|
// Inputs form
|
|
483
|
-
dockMng.dock(this.getInputForm(), DG.DOCK_TYPE.LEFT, null, undefined, 0.1);
|
|
639
|
+
dockMng.dock(this.getInputForm(true).form, DG.DOCK_TYPE.LEFT, null, undefined, 0.1);
|
|
484
640
|
|
|
485
641
|
// Dock viewers
|
|
486
642
|
const gridNode = dockMng.findNode(this.view.grid.root);
|
|
@@ -488,13 +644,30 @@ export class Pmpo {
|
|
|
488
644
|
if (gridNode == null)
|
|
489
645
|
throw new Error('Failed to train pMPO: missing a grid in the table view.');
|
|
490
646
|
|
|
491
|
-
|
|
647
|
+
// Dock statistics grid
|
|
648
|
+
const statGridNode = dockMng.dock(this.statGrid, DG.DOCK_TYPE.DOWN, gridNode, undefined, 0.5);
|
|
649
|
+
|
|
650
|
+
// Dock ROC curve
|
|
651
|
+
const rocNode = dockMng.dock(this.rocCurve, DG.DOCK_TYPE.RIGHT, statGridNode, undefined, 0.3);
|
|
652
|
+
|
|
653
|
+
// Dock confusion matrix
|
|
654
|
+
dockMng.dock(this.confusionMatrix, DG.DOCK_TYPE.RIGHT, rocNode, undefined, 0.2);
|
|
492
655
|
|
|
493
656
|
this.setRibbons();
|
|
494
657
|
} // runTrainingApp
|
|
495
658
|
|
|
659
|
+
/**
 * Exposes the viewers of this pMPO instance together with a freshly built input form.
 * Unlike `runTrainingApp`, nothing is docked into the table view.
 *
 * NOTE(review): every call invokes `getInputForm(false)`, which creates new inputs and schedules
 * an initial computation. Calling this more than once per instance duplicates that work.
 *
 * @returns the statistics grid, the ROC curve, the confusion matrix, and the training controls.
 *   The Save button is returned in `controls.saveBtn` but is not appended to `controls.form`.
 */
public getPmpoAppItems(): PmpoAppItems {
  return {
    statsGrid: this.statGrid,
    rocCurve: this.rocCurve,
    confusionMatrix: this.confusionMatrix,
    controls: this.getInputForm(false),
  };
} // getPmpoAppItems
|
|
668
|
+
|
|
496
669
|
/** Creates and returns the input form for pMPO model training */
|
|
497
|
-
private getInputForm():
|
|
670
|
+
private getInputForm(addBtn: boolean): Controls {
|
|
498
671
|
const form = ui.form([]);
|
|
499
672
|
form.append(ui.h2('Training data'));
|
|
500
673
|
const numericColNames = this.numericCols.map((col) => col.name);
|
|
@@ -502,16 +675,21 @@ export class Pmpo {
|
|
|
502
675
|
// Function to run computations on input changes
|
|
503
676
|
const runComputations = () => {
|
|
504
677
|
try {
|
|
678
|
+
//grok.shell.info('Running...');
|
|
679
|
+
|
|
505
680
|
this.fitAndUpdateViewers(
|
|
506
681
|
this.table,
|
|
507
682
|
DG.DataFrame.fromColumns(descrInput.value).columns,
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
683
|
+
this.table.col(desInput.value!)!,
|
|
684
|
+
pInput.value!,
|
|
685
|
+
rInput.value!,
|
|
686
|
+
qInput.value!,
|
|
687
|
+
useSigmoidInput.value,
|
|
512
688
|
);
|
|
513
689
|
} catch (err) {
|
|
514
|
-
|
|
690
|
+
err instanceof PmpoError ?
|
|
691
|
+
grok.shell.warning(err.message) :
|
|
692
|
+
grok.shell.error(err instanceof Error ? err.message : PMPO_COMPUTE_FAILED + ': the platform issue.');
|
|
515
693
|
}
|
|
516
694
|
};
|
|
517
695
|
|
|
@@ -523,8 +701,10 @@ export class Pmpo {
|
|
|
523
701
|
checked: numericColNames,
|
|
524
702
|
tooltipText: 'Descriptor columns used for model construction.',
|
|
525
703
|
onValueChanged: (value) => {
|
|
526
|
-
if (value != null)
|
|
527
|
-
|
|
704
|
+
if (value != null) {
|
|
705
|
+
areTunedSettingsUsed = false;
|
|
706
|
+
checkAutoTuneAndRun();
|
|
707
|
+
}
|
|
528
708
|
},
|
|
529
709
|
});
|
|
530
710
|
form.append(descrInput.root);
|
|
@@ -536,26 +716,95 @@ export class Pmpo {
|
|
|
536
716
|
items: this.boolCols.map((col) => col.name),
|
|
537
717
|
tooltipText: 'Desirability column.',
|
|
538
718
|
onValueChanged: (value) => {
|
|
539
|
-
if (value != null)
|
|
540
|
-
|
|
719
|
+
if (value != null) {
|
|
720
|
+
areTunedSettingsUsed = false;
|
|
721
|
+
checkAutoTuneAndRun();
|
|
722
|
+
}
|
|
541
723
|
},
|
|
542
724
|
});
|
|
543
725
|
form.append(desInput.root);
|
|
544
726
|
|
|
545
|
-
const header = ui.h2('
|
|
546
|
-
ui.tooltip.bind(header, 'Settings of the pMPO model training.');
|
|
727
|
+
const header = ui.h2('Settings');
|
|
547
728
|
form.append(header);
|
|
729
|
+
ui.tooltip.bind(header, 'Settings of the pMPO model.');
|
|
730
|
+
|
|
731
|
+
// use sigmoid correction
|
|
732
|
+
const useSigmoidInput = ui.input.bool('\u03C3 correction', {
|
|
733
|
+
value: USE_SIGMOID_DEFAULT,
|
|
734
|
+
tooltipText: 'Use the sigmoidal correction to the weighted Gaussian scores.',
|
|
735
|
+
onValueChanged: (_value) => {
|
|
736
|
+
areTunedSettingsUsed = false;
|
|
737
|
+
checkAutoTuneAndRun();
|
|
738
|
+
},
|
|
739
|
+
});
|
|
740
|
+
form.append(useSigmoidInput.root);
|
|
741
|
+
|
|
742
|
+
const toUseAutoTune = (this.table.rowCount <= AUTO_TUNE_MAX_APPLICABLE_ROWS);
|
|
743
|
+
|
|
744
|
+
// Flag indicating whether optimal parameters from auto-tuning are currently used
|
|
745
|
+
let areTunedSettingsUsed = false;
|
|
746
|
+
|
|
747
|
+
const setOptimalParametersAndRun = async () => {
|
|
748
|
+
if (!areTunedSettingsUsed) {
|
|
749
|
+
const optimalSettings = await this.getOptimalSettings(
|
|
750
|
+
DG.DataFrame.fromColumns(descrInput.value).columns,
|
|
751
|
+
this.table.col(desInput.value!)!,
|
|
752
|
+
useSigmoidInput.value,
|
|
753
|
+
);
|
|
754
|
+
|
|
755
|
+
if (optimalSettings.success) {
|
|
756
|
+
pInput.value = Math.max(optimalSettings.pValTresh, P_VAL_TRES_MIN);
|
|
757
|
+
rInput.value = Math.max(optimalSettings.r2Tresh, R2_MIN);
|
|
758
|
+
qInput.value = Math.max(optimalSettings.qCutoff, Q_CUTOFF_MIN);
|
|
759
|
+
areTunedSettingsUsed = true;
|
|
760
|
+
} else
|
|
761
|
+
autoTuneInput.value = false; // revert to manual mode if optimization failed
|
|
762
|
+
}
|
|
763
|
+
|
|
764
|
+
runComputations();
|
|
765
|
+
};
|
|
766
|
+
|
|
767
|
+
const checkAutoTuneAndRun = () => {
|
|
768
|
+
if (autoTuneInput.value)
|
|
769
|
+
setOptimalParametersAndRun();
|
|
770
|
+
else
|
|
771
|
+
runComputations();
|
|
772
|
+
};
|
|
773
|
+
|
|
774
|
+
// autotuning input
|
|
775
|
+
const autoTuneInput = ui.input.bool('Auto-tuning', {
|
|
776
|
+
value: false,
|
|
777
|
+
tooltipText: 'Automatically select optimal p-value, R², and q-cutoff by maximizing AUC.',
|
|
778
|
+
onValueChanged: async (value) => {
|
|
779
|
+
setEnability(!value);
|
|
780
|
+
|
|
781
|
+
if (areTunedSettingsUsed)
|
|
782
|
+
return;
|
|
783
|
+
|
|
784
|
+
// If auto-tuning is turned on, set optimal parameters and run computations
|
|
785
|
+
if (value)
|
|
786
|
+
await setOptimalParametersAndRun();
|
|
787
|
+
},
|
|
788
|
+
});
|
|
789
|
+
form.append(autoTuneInput.root);
|
|
548
790
|
|
|
549
791
|
// p-value threshold input
|
|
550
792
|
const pInput = ui.input.float('p-value', {
|
|
551
793
|
nullable: false,
|
|
552
794
|
min: P_VAL_TRES_MIN,
|
|
553
|
-
max:
|
|
554
|
-
step: 0.
|
|
555
|
-
value:
|
|
556
|
-
|
|
795
|
+
max: P_VAL_TRES_MAX,
|
|
796
|
+
step: 0.001,
|
|
797
|
+
value: P_VAL_TRES_DEFAULT,
|
|
798
|
+
// @ts-ignore
|
|
799
|
+
format: FORMAT,
|
|
800
|
+
tooltipText: 'P-value threshold. Descriptors with p-values above this threshold are excluded.',
|
|
557
801
|
onValueChanged: (value) => {
|
|
558
|
-
|
|
802
|
+
// Prevent running computations when auto-tuning is on, since parameters will be set automatically
|
|
803
|
+
if (autoTuneInput.value)
|
|
804
|
+
return;
|
|
805
|
+
|
|
806
|
+
areTunedSettingsUsed = false;
|
|
807
|
+
if ((value != null) && (value >= P_VAL_TRES_MIN) && (value <= P_VAL_TRES_MAX))
|
|
559
808
|
runComputations();
|
|
560
809
|
},
|
|
561
810
|
});
|
|
@@ -563,15 +812,23 @@ export class Pmpo {
|
|
|
563
812
|
|
|
564
813
|
// R² threshold input
|
|
565
814
|
const rInput = ui.input.float('R²', {
|
|
815
|
+
// @ts-ignore
|
|
816
|
+
format: FORMAT,
|
|
566
817
|
nullable: false,
|
|
567
818
|
min: R2_MIN,
|
|
568
|
-
value:
|
|
569
|
-
max:
|
|
819
|
+
value: R2_DEFAULT,
|
|
820
|
+
max: R2_MAX,
|
|
570
821
|
step: 0.01,
|
|
571
822
|
// eslint-disable-next-line max-len
|
|
572
|
-
tooltipText: 'Descriptors with squared correlation above this threshold are considered highly correlated. Among them, the descriptor with the lower p-value is retained.',
|
|
823
|
+
tooltipText: 'Squared correlation threshold. Descriptors with squared correlation above this threshold are considered highly correlated. Among them, the descriptor with the lower p-value is retained.',
|
|
573
824
|
onValueChanged: (value) => {
|
|
574
|
-
|
|
825
|
+
// Prevent running computations when auto-tuning is on, since parameters will be set automatically
|
|
826
|
+
if (autoTuneInput.value)
|
|
827
|
+
return;
|
|
828
|
+
|
|
829
|
+
areTunedSettingsUsed = false;
|
|
830
|
+
|
|
831
|
+
if ((value != null) && (value >= R2_MIN) && (value <= R2_MAX))
|
|
575
832
|
runComputations();
|
|
576
833
|
},
|
|
577
834
|
});
|
|
@@ -579,36 +836,62 @@ export class Pmpo {
|
|
|
579
836
|
|
|
580
837
|
// q-cutoff input
|
|
581
838
|
const qInput = ui.input.float('q-cutoff', {
|
|
839
|
+
// @ts-ignore
|
|
840
|
+
format: FORMAT,
|
|
582
841
|
nullable: false,
|
|
583
842
|
min: Q_CUTOFF_MIN,
|
|
584
|
-
value:
|
|
585
|
-
max:
|
|
843
|
+
value: Q_CUTOFF_DEFAULT,
|
|
844
|
+
max: Q_CUTOFF_MAX,
|
|
586
845
|
step: 0.01,
|
|
587
846
|
tooltipText: 'Q-cutoff for the pMPO model computation.',
|
|
588
847
|
onValueChanged: (value) => {
|
|
589
|
-
|
|
848
|
+
// Prevent running computations when auto-tuning is on, since parameters will be set automatically
|
|
849
|
+
if (autoTuneInput.value)
|
|
850
|
+
return;
|
|
851
|
+
|
|
852
|
+
areTunedSettingsUsed = false;
|
|
853
|
+
|
|
854
|
+
if ((value != null) && (value >= Q_CUTOFF_MIN) && (value <= Q_CUTOFF_MAX))
|
|
590
855
|
runComputations();
|
|
591
856
|
},
|
|
592
857
|
});
|
|
593
858
|
form.append(qInput.root);
|
|
594
859
|
|
|
595
|
-
|
|
860
|
+
const setEnability = (toEnable: boolean) => {
|
|
861
|
+
pInput.enabled = toEnable;
|
|
862
|
+
rInput.enabled = toEnable;
|
|
863
|
+
qInput.enabled = toEnable;
|
|
864
|
+
};
|
|
865
|
+
|
|
866
|
+
setTimeout(() => {
|
|
867
|
+
runComputations();
|
|
868
|
+
|
|
869
|
+
if (toUseAutoTune)
|
|
870
|
+
autoTuneInput.value = true; // this will trigger setting optimal parameters and running computations
|
|
871
|
+
else
|
|
872
|
+
runComputations();
|
|
873
|
+
}, 10);
|
|
596
874
|
|
|
597
875
|
// Save model button
|
|
598
|
-
const saveBtn = ui.button('Save
|
|
876
|
+
const saveBtn = ui.button('Save', async () => {
|
|
599
877
|
if (this.params == null) {
|
|
600
878
|
grok.shell.warning('Failed to save pMPO model: null parameters.');
|
|
601
879
|
return;
|
|
602
880
|
}
|
|
603
881
|
|
|
604
|
-
saveModel(this.params, this.table.name);
|
|
882
|
+
saveModel(this.params, this.table.name, useSigmoidInput.value);
|
|
605
883
|
}, 'Save model as platform file.');
|
|
606
|
-
|
|
884
|
+
|
|
885
|
+
if (addBtn)
|
|
886
|
+
form.append(saveBtn);
|
|
607
887
|
|
|
608
888
|
const div = ui.div([form]);
|
|
609
889
|
div.classList.add('eda-pmpo-input-form');
|
|
610
890
|
|
|
611
|
-
return
|
|
891
|
+
return {
|
|
892
|
+
form: div,
|
|
893
|
+
saveBtn: saveBtn,
|
|
894
|
+
};
|
|
612
895
|
} // getInputForm
|
|
613
896
|
|
|
614
897
|
/** Retrieves boolean columns from the data frame */
|
|
@@ -634,4 +917,75 @@ export class Pmpo {
|
|
|
634
917
|
|
|
635
918
|
return res;
|
|
636
919
|
} // getValidNumericCols
|
|
920
|
+
|
|
921
|
+
/**
 * Auto-tunes the pMPO training settings for `this.table` by maximizing the ROC AUC of the model.
 *
 * The p-value threshold stays fixed at `P_VAL_TRES_DEFAULT`. The R² threshold and the q-cutoff are
 * searched by `optimizeNM` (Nelder-Mead), starting from `R2_DEFAULT` / `Q_CUTOFF_DEFAULT` and bounded
 * by `LOW_PARAMS_BOUNDS` .. `HIGH_PARAMS_BOUNDS`. A cancelable task-bar progress indicator is shown
 * while the search runs.
 *
 * @param descriptors - descriptor columns to build the model from
 * @param desirability - desirability (label) column used to split `this.table` and to compute the AUC
 * @param useSigmoid - whether predictions use the sigmoidal correction
 * @returns the tuned settings with `success: true`. Returns a zeroed result with `success: false`
 *   when no descriptor passes the p-value filter, the user cancels, or the optimization throws.
 */
private async getOptimalSettings(descriptors: DG.ColumnList, desirability: DG.Column,
  useSigmoid: boolean): Promise<OptimalPoint> {
  const failedResult: OptimalPoint = {
    pValTresh: 0,
    r2Tresh: 0,
    qCutoff: 0,
    success: false,
  };

  const descriptorNames = descriptors.names();
  const {desired, nonDesired} = getDesiredTables(this.table, desirability);

  // Descriptors' statistics do not depend on the tuned parameters, so they are computed once here
  const descrStats = new Map<string, DescriptorStatistics>();
  descriptorNames.forEach((name) => {
    descrStats.set(name, getDescriptorStatistics(desired.col(name)!, nonDesired.col(name)!));
  });
  const descrStatsTable = getDescriptorStatisticsTable(descrStats);

  // Filter by p-value (the p-value threshold itself is not tuned)
  const selectedByPvalue = getFilteredByPvalue(descrStatsTable, P_VAL_TRES_DEFAULT);
  if (selectedByPvalue.length < 1)
    return failedResult;

  const correlationTriples = getCorrelationTriples(descriptors, selectedByPvalue);

  // Objective over point = [R² threshold, q-cutoff]
  const funcToBeMinimized = (point: Float32Array) => {
    // Filter by correlations
    const selectedByCorr = getFilteredByCorrelations(descriptors, selectedByPvalue, descrStats, point[0], correlationTriples);

    // Compute pMPO parameters - training
    const params = getModelParams(desired, nonDesired, selectedByCorr, point[1]);

    // Get predictions
    const prediction = Pmpo.predict(this.table, params, useSigmoid, this.predictionName);

    // Evaluate predictions and return 1 - AUC (since optimization minimizes the function, but we want to maximize AUC)
    return 1 - getPmpoEvaluation(desirability, prediction).auc;
  }; // funcToBeMinimized

  const pi = DG.TaskBarProgressIndicator.create('Optimizing... ', {cancelable: true});

  try {
    const optimalResult = await optimizeNM(
      pi,
      funcToBeMinimized,
      new Float32Array([R2_DEFAULT, Q_CUTOFF_DEFAULT]),
      DEFAULT_OPTIMIZATION_SETTINGS,
      LOW_PARAMS_BOUNDS,
      HIGH_PARAMS_BOUNDS,
    );

    // The cancel flag is read here, before the indicator is closed in `finally`
    if (pi.canceled)
      return failedResult;

    return {
      pValTresh: P_VAL_TRES_DEFAULT,
      r2Tresh: optimalResult.optimalPoint[0],
      qCutoff: optimalResult.optimalPoint[1],
      success: true,
    };
  } catch {
    // Best effort: a failed optimization is reported as unsuccessful, and the caller falls back to manual settings
    return failedResult;
  } finally {
    // Single close point for the success, cancel and error paths
    pi.close();
  }
} // getOptimalSettings
|
|
637
991
|
}; // Pmpo
|