@datagrok/eda 1.4.11 → 1.4.13
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.eslintrc.json +0 -1
- package/CHANGELOG.md +15 -0
- package/CLAUDE.md +185 -0
- package/README.md +8 -0
- package/css/pmpo.css +35 -0
- package/dist/package-test.js +1 -1
- package/dist/package-test.js.map +1 -1
- package/dist/package.js +1 -1
- package/dist/package.js.map +1 -1
- package/eslintrc.json +45 -0
- package/files/drugs-props-test.csv +126 -0
- package/files/drugs-props-train-scores.csv +664 -0
- package/files/drugs-props-train.csv +664 -0
- package/package.json +9 -3
- package/src/anova/anova-tools.ts +1 -1
- package/src/anova/anova-ui.ts +1 -1
- package/src/package-api.ts +18 -0
- package/src/package-test.ts +4 -1
- package/src/package.g.ts +25 -0
- package/src/package.ts +55 -15
- package/src/pareto-optimization/pareto-computations.ts +6 -0
- package/src/pareto-optimization/utils.ts +6 -4
- package/src/probabilistic-scoring/data-generator.ts +157 -0
- package/src/probabilistic-scoring/nelder-mead.ts +204 -0
- package/src/probabilistic-scoring/pmpo-defs.ts +218 -0
- package/src/probabilistic-scoring/pmpo-utils.ts +603 -0
- package/src/probabilistic-scoring/prob-scoring.ts +991 -0
- package/src/probabilistic-scoring/stat-tools.ts +303 -0
- package/src/softmax-classifier.ts +1 -1
- package/src/tests/anova-tests.ts +1 -1
- package/src/tests/classifiers-tests.ts +1 -1
- package/src/tests/dim-reduction-tests.ts +1 -1
- package/src/tests/linear-methods-tests.ts +1 -1
- package/src/tests/mis-vals-imputation-tests.ts +1 -1
- package/src/tests/pareto-tests.ts +253 -0
- package/src/tests/pmpo-tests.ts +157 -0
- package/test-console-output-1.log +175 -209
- package/test-record-1.mp4 +0 -0
|
@@ -0,0 +1,991 @@
|
|
|
1
|
+
/* eslint-disable max-len */
|
|
2
|
+
// Probabilistic scoring (pMPO) features
|
|
3
|
+
// Link: https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/
|
|
4
|
+
|
|
5
|
+
import * as grok from 'datagrok-api/grok';
|
|
6
|
+
import * as ui from 'datagrok-api/ui';
|
|
7
|
+
import * as DG from 'datagrok-api/dg';
|
|
8
|
+
|
|
9
|
+
import {MpoProfileEditor} from '@datagrok-libraries/statistics/src/mpo/mpo-profile-editor';
|
|
10
|
+
|
|
11
|
+
import '../../css/pmpo.css';
|
|
12
|
+
|
|
13
|
+
import {getDesiredTables, getDescriptorStatistics, getBoolPredictionColumn, getPmpoEvaluation} from './stat-tools';
|
|
14
|
+
import {MIN_SAMPLES_COUNT, PMPO_NON_APPLICABLE, DescriptorStatistics, P_VAL_TRES_MIN, DESCR_TITLE,
|
|
15
|
+
R2_MIN, Q_CUTOFF_MIN, PmpoParams, SCORES_TITLE, DESCR_TABLE_TITLE, PMPO_COMPUTE_FAILED, SELECTED_TITLE,
|
|
16
|
+
P_VAL, DESIRABILITY_COL_NAME, STAT_GRID_HEIGHT, DESIRABILITY_COLUMN_WIDTH, WEIGHT_TITLE,
|
|
17
|
+
P_VAL_TRES_DEFAULT, R2_DEFAULT, Q_CUTOFF_DEFAULT, USE_SIGMOID_DEFAULT, ROC_TRESHOLDS,
|
|
18
|
+
FPR_TITLE, TPR_TITLE, COLORS, THRESHOLD, AUTO_TUNE_MAX_APPLICABLE_ROWS, DEFAULT_OPTIMIZATION_SETTINGS,
|
|
19
|
+
P_VAL_TRES_MAX, R2_MAX, Q_CUTOFF_MAX, OptimalPoint, LOW_PARAMS_BOUNDS, HIGH_PARAMS_BOUNDS, FORMAT} from './pmpo-defs';
|
|
20
|
+
import {addSelectedDescriptorsCol, getDescriptorStatisticsTable, getFilteredByPvalue, getFilteredByCorrelations,
|
|
21
|
+
getModelParams, getDescrTooltip, saveModel, getScoreTooltip, getDesirabilityProfileJson, getCorrelationTriples,
|
|
22
|
+
addCorrelationColumns, setPvalColumnColorCoding, setCorrColumnColorCoding, PmpoError} from './pmpo-utils';
|
|
23
|
+
import {getOutputPalette} from '../pareto-optimization/utils';
|
|
24
|
+
import {OPT_TYPE} from '../pareto-optimization/defs';
|
|
25
|
+
import {optimizeNM} from './nelder-mead';
|
|
26
|
+
|
|
27
|
+
/** Outcome of `Pmpo.fit`: the fitted model parameters plus descriptor-selection diagnostics */
export type PmpoTrainingResult = {
  /** Model parameters keyed by descriptor (column) name, computed for the descriptors in `selectedByCorr` */
  params: Map<string, PmpoParams>;
  /** Per-descriptor statistics table, with the selection column and correlation columns added */
  descrStatsTable: DG.DataFrame;
  /** Descriptors that passed the p-value filter */
  selectedByPvalue: string[];
  /** Descriptors that additionally passed the correlation (R²) filter */
  selectedByCorr: string[];
};
|
|
33
|
+
|
|
34
|
+
/** Controls of the pMPO training UI */
export type Controls = {
  /** Form with the training inputs */
  form: HTMLElement;
  /** Save button */
  saveBtn: HTMLButtonElement;
};
|
|
36
|
+
|
|
37
|
+
/** UI elements of the pMPO app, as returned by `Pmpo.getPmpoAppItems` */
export type PmpoAppItems = {
  /** Grid with descriptor statistics */
  statsGrid: DG.Viewer,
  /** ROC curve of the pMPO scores */
  rocCurve: DG.Viewer,
  /** Confusion matrix of actual vs predicted desirability */
  confusionMatrix: DG.Viewer,
  /** Training inputs form and save button */
  controls: Controls,
};
|
|
44
|
+
|
|
45
|
+
/** Class implementing probabilistic MPO (pMPO) model training and prediction */
|
|
46
|
+
export class Pmpo {
|
|
47
|
+
/** Checks whether the pMPO model can be trained on the given descriptors and desirability labels.
 * Checks run in order and stop at the first failure.
 * @param descriptors - descriptor columns; each must be numerical and free of missing values,
 *   and at least one of them must have non-zero standard deviation
 * @param desirability - labels; must be a boolean column with non-zero standard deviation
 * @param pValThresh - p-value threshold, must be at least P_VAL_TRES_MIN
 * @param r2Tresh - R² threshold, must be at least R2_MIN
 * @param qCutoff - q-cutoff, must be at least Q_CUTOFF_MIN
 * @param toShowWarning - whether to show a warning explaining the failed check
 * @returns true if all checks pass */
static isApplicable(descriptors: DG.ColumnList, desirability: DG.Column, pValThresh: number,
  r2Tresh: number, qCutoff: number, toShowWarning: boolean = false): boolean {
  // Reports the reason (if requested) and yields the negative verdict
  const reject = (reason: string): boolean => {
    if (toShowWarning)
      grok.shell.warning(PMPO_NON_APPLICABLE + reason);
    return false;
  };

  // Model settings
  if (pValThresh < P_VAL_TRES_MIN)
    return reject(`: too small p-value threshold - ${pValThresh}, minimum - ${P_VAL_TRES_MIN}`);

  if (r2Tresh < R2_MIN)
    return reject(`: too small R² threshold - ${r2Tresh}, minimum - ${R2_MIN}`);

  if (qCutoff < Q_CUTOFF_MIN)
    return reject(`: too small q-cutoff - ${qCutoff}, minimum - ${Q_CUTOFF_MIN}`);

  // Data size
  const samplesCount = desirability.length;
  if (samplesCount < MIN_SAMPLES_COUNT)
    return reject(`: not enough of samples - ${samplesCount}, minimum - ${MIN_SAMPLES_COUNT}`);

  // Labels: boolean, and both TRUE and FALSE must be present
  if (desirability.type !== DG.COLUMN_TYPE.BOOL)
    return reject(`: "${desirability.name}" must be boolean column.`);

  if (desirability.stats.stdev === 0)
    return reject(`: "${desirability.name}" has a single category.`);

  // Descriptors: numerical, complete, and at least one of them varying
  let hasVaryingDescriptor = false;

  for (const col of descriptors) {
    if (!col.isNumerical)
      return reject(`: "${col.name}" is not numerical.`);

    if (col.stats.missingValueCount > 0)
      return reject(`: "${col.name}" contains missing values.`);

    if (col.stats.stdev > 0)
      hasVaryingDescriptor = true;
  }

  if (!hasVaryingDescriptor)
    return reject(`: not enough of non-constant descriptors.`);

  return true;
} // isApplicable
|
|
117
|
+
|
|
118
|
+
/** Checks whether a table contains data suitable for pMPO training.
 * The table must have at least 2 rows, at least one boolean column, and at least one numerical
 * column without missing values and with non-zero standard deviation.
 * @param df - table to check
 * @param toShowMsg - whether to show a warning explaining the failed check
 * @returns true if the table passes all checks */
static isTableValid(df: DG.DataFrame, toShowMsg: boolean = true): boolean {
  const warn = (msg: string): void => {
    if (toShowMsg)
      grok.shell.warning(PMPO_NON_APPLICABLE + msg);
  };

  if (df.rowCount < 2) {
    warn(`. Not enough of samples: ${df.rowCount}, minimum: 2.`);
    return false;
  }

  let boolCount = 0;
  let usableNumericCount = 0;

  for (const col of df.columns) {
    if (col.isNumerical) {
      const isComplete = col.stats.missingValueCount < 1;
      if (isComplete && col.stats.stdev > 0)
        usableNumericCount++;
    } else if (col.type === DG.COLUMN_TYPE.BOOL)
      boolCount++;
  }

  if (boolCount < 1) {
    warn(': no boolean columns.');
    return false;
  }

  if (usableNumericCount < 1) {
    warn(': no numeric columns without missing values and non-zero variance.');
    return false;
  }

  return true;
} // isTableValid
|
|
155
|
+
|
|
156
|
+
/** Trains the pMPO model.
 *
 * Pipeline:
 * 1. split the rows into desired / non-desired subsets;
 * 2. compute statistics of each descriptor on these subsets;
 * 3. keep descriptors passing the p-value filter;
 * 4. among them, drop those highly correlated with others (R² filter);
 * 5. compute model parameters for the remaining descriptors.
 *
 * @param df - training table
 * @param descriptors - descriptor columns
 * @param desirability - boolean desirability labels
 * @param pValTresh - p-value threshold
 * @param r2Tresh - R² threshold
 * @param qCutoff - q-cutoff passed to the parameters computation
 * @param toCheckApplicability - whether to run `Pmpo.isApplicable` first
 * @returns fitted parameters and selection diagnostics
 * @throws Error if the applicability check is enabled and fails
 * @throws PmpoError if no descriptor passes the p-value filter */
static fit(df: DG.DataFrame, descriptors: DG.ColumnList, desirability: DG.Column,
  pValTresh: number, r2Tresh: number, qCutoff: number, toCheckApplicability: boolean = true): PmpoTrainingResult {
  if (toCheckApplicability && !Pmpo.isApplicable(descriptors, desirability, pValTresh, r2Tresh, qCutoff))
    throw new Error('Failed to train pMPO model: the method is not applicable to the inputs');

  const descriptorNames = descriptors.names();
  const {desired, nonDesired} = getDesiredTables(df, desirability);

  // Statistics of each descriptor, computed from the desired and non-desired subsets
  const descrStats = new Map<string, DescriptorStatistics>(descriptorNames.map(
    (name): [string, DescriptorStatistics] => [name, getDescriptorStatistics(desired.col(name)!, nonDesired.col(name)!)],
  ));
  const descrStatsTable = getDescriptorStatisticsTable(descrStats);
  setPvalColumnColorCoding(descrStatsTable, pValTresh);

  // Significance filter
  const selectedByPvalue = getFilteredByPvalue(descrStatsTable, pValTresh);
  if (selectedByPvalue.length === 0)
    throw new PmpoError('Cannot train pMPO model: all descriptors have high p-values (not significant).');

  // Correlation filter among the significant descriptors
  const correlationTriples = getCorrelationTriples(descriptors, selectedByPvalue);
  const selectedByCorr = getFilteredByCorrelations(descriptors, selectedByPvalue, descrStats, r2Tresh, correlationTriples);

  // Decorate the statistics table with the selection and correlation info
  addSelectedDescriptorsCol(descrStatsTable, selectedByCorr);
  addCorrelationColumns(descrStatsTable, descriptorNames, correlationTriples, selectedByCorr);
  setCorrColumnColorCoding(descrStatsTable, descriptorNames, r2Tresh);

  // Training proper: parameters of the finally selected descriptors
  const params = getModelParams(desired, nonDesired, selectedByCorr, qCutoff);

  return {params, descrStatsTable, selectedByPvalue, selectedByCorr};
} // fit
|
|
208
|
+
|
|
209
|
+
/** Computes pMPO scores for every row of a table
 * (see https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/).
 *
 * The score is a sum over descriptors of `weight * exp(-(x - desAvg)^2 / (2 * desStd^2))`.
 * With the sigmoidal correction:
 * - if `c > 0`, each term is divided by `1 + b * c^(-(x - cutoff))`;
 * - otherwise, the weight is divided by the constant `1 + b`.
 *
 * @param df - table with a column for every descriptor in `params`
 * @param params - model parameters keyed by descriptor (column) name
 * @param useSigmoid - whether to apply the sigmoidal correction
 * @param predictionName - name of the returned column
 * @returns float column of scores
 * @throws Error if `df` lacks a descriptor column */
static predict(df: DG.DataFrame, params: Map<string, PmpoParams>, useSigmoid: boolean, predictionName: string): DG.Column {
  const rowCount = df.rowCount;
  const scores = new Float64Array(rowCount); // zero-filled on construction

  for (const [name, param] of params) {
    const {b, c, cutoff, weight, desAvg, desStd} = param;
    const frac = 1.0 / (2 * desStd ** 2);

    const col = df.col(name);
    if (col == null)
      throw new Error(`Failed to apply pMPO: inconsistent data, no column "${name}" in the table "${df.name}"`);

    const vals = col.getRawData();

    if (useSigmoid && c > 0) {
      for (let i = 0; i < rowCount; ++i) {
        const x = vals[i];
        scores[i] += weight * Math.exp(-((x - desAvg) ** 2) * frac) / (1.0 + b * (c ** (-(x - cutoff))));
      }
    } else {
      // No correction: the weight is used as is; correction with c <= 0: the weight is scaled by 1 / (1 + b)
      const effectiveWeight = useSigmoid ? weight / (1.0 + b) : weight;

      for (let i = 0; i < rowCount; ++i)
        scores[i] += effectiveWeight * Math.exp(-((vals[i] - desAvg) ** 2) * frac);
    }
  }

  return DG.Column.fromFloat64Array(predictionName, scores);
} // predict
|
|
248
|
+
|
|
249
|
+
/** Model parameters of the latest fit; null until the model is trained */
private params: Map<string, PmpoParams> | null = null;

/** Training table */
private table: DG.DataFrame;
/** Table view hosting the app viewers */
private view: DG.TableView;
/** Columns from getBoolCols; choices of the desirability input */
private boolCols: DG.Column[];
/** Columns from getValidNumericCols; choices of the descriptors input */
private numericCols: DG.Column[];

/** Table the viewers below are created with before the first fit (DG.DataFrame.create() with defaults) */
private initTable = DG.DataFrame.create();

/** Grid with descriptor statistics; its table is replaced in updateStatisticsGrid */
private statGrid = DG.Viewer.grid(this.initTable, {showTitle: true, title: DESCR_TABLE_TITLE});

/** Name of the scores column; made unique against the table's columns in the constructor */
private predictionName = SCORES_TITLE;
/** Name of the boolean (thresholded) prediction column; empty until the first fit */
private boolPredictionName = '';

/** Desirability profile chart elements keyed by descriptor name; filled in updateDesirabilityProfileData */
private desirabilityProfileRoots = new Map<string, HTMLElement>();

/** ROC curve viewer; its table and axes are set in updateRocCurve */
private rocCurve = DG.Viewer.scatterPlot(this.initTable, {
  showTitle: true,
  showSizeSelector: false,
  showColorSelector: false,
});

/** Confusion matrix viewer; the initial 'control' column names are overridden in updateConfusionMatrix */
private confusionMatrix = DG.Viewer.fromType('Confusion matrix', this.initTable, {
  xColumnName: 'control',
  yColumnName: 'control',
  showTitle: true,
  title: 'Confusion Matrix',
  descriptionPosition: 'Bottom',
  description: 'Confusion matrix for the predicted vs actual desirability labels.',
  descriptionVisibilityMode: 'Always',
});
|
|
280
|
+
|
|
281
|
+
/** Binds the app to a table.
 * @param df - training table
 * @param view - host view; if omitted, an open view found by the table's name is used,
 *   otherwise a new view of `df` is added */
constructor(df: DG.DataFrame, view?: DG.TableView) {
  this.table = df;

  const hostView = view ?? grok.shell.tableView(df.name) ?? grok.shell.addTableView(df);
  this.view = hostView;

  this.boolCols = this.getBoolCols();
  this.numericCols = this.getValidNumericCols();

  // Pick a scores column name that does not clash with existing columns
  this.predictionName = df.columns.getUnusedName(SCORES_TITLE);
}
|
|
288
|
+
|
|
289
|
+
/** Removes the first ribbon panel of the table view and keeps the rest; no-op when there are no panels */
private setRibbons(): void {
  const panels = this.view.getRibbonPanels();

  if (panels.length > 0)
    this.view.setRibbonPanels(panels.slice(1));
}
|
|
298
|
+
|
|
299
|
+
/** Subscriptions of the statistics grid event handlers.
 * They are released on every update. Otherwise each refit would stack another pair of handlers
 * whose closures still hold the selections of an older fit. */
private statGridSubs: {unsubscribe(): void}[] = [];

/** Shows the descriptor statistics table in the statistics grid.
 * It sorts by the selection column, sets number formats and text color coding, installs tooltips
 * explaining why each descriptor was selected or excluded, and renders the desirability profile
 * charts in the desirability column.
 * @param table - descriptor statistics table
 * @param descriptorNames - all candidate descriptors (also the names of the correlation columns)
 * @param selectedByPvalue - descriptors that passed the p-value filter
 * @param selectedByCorr - descriptors that also passed the correlation filter */
private updateStatisticsGrid(table: DG.DataFrame, descriptorNames: string[], selectedByPvalue: string[], selectedByCorr: string[]): void {
  const grid = this.statGrid;

  // Release the handlers registered for a previous fit before subscribing new ones
  this.statGridSubs.forEach((sub) => sub.unsubscribe());
  this.statGridSubs = [];

  grid.dataFrame = table;
  grid.setOptions({
    showTitle: true,
    title: table.name,
  });

  grid.sort([SELECTED_TITLE], [false]);

  // Formats and color coding
  const pValCol = grid.col(P_VAL)!;
  pValCol.format = 'scientific';

  const descrCol = grid.col(DESCR_TITLE)!;
  descrCol.isTextColorCoded = true;

  pValCol.isTextColorCoded = true;

  descriptorNames.forEach((name) => {
    const col = grid.col(name);
    if (col == null)
      return;

    col.isTextColorCoded = true;
    col.format = '0.000';
  });

  // Tooltips explaining the columns (headers) and the selection verdicts (cells)
  this.statGridSubs.push(grid.onCellTooltip((cell, x, y) => {
    if (cell.isColHeader) {
      const cellCol = cell.tableColumn;

      if (cellCol == null)
        return false;

      const colName = cellCol.name;

      switch (colName) {
      case DESCR_TITLE:
        ui.tooltip.show(getDescrTooltip(
          DESCR_TITLE,
          'Use of descriptors in model construction:',
          'selected',
          'excluded',
        ), x, y);
        return true;

      case DESIRABILITY_COL_NAME:
        ui.tooltip.show(ui.divV([
          ui.h2(DESIRABILITY_COL_NAME),
          ui.divText('Desirability profile charts for each descriptor. Only profiles for selected descriptors are shown.'),
        ]), x, y);
        return true;

      case WEIGHT_TITLE:
        ui.tooltip.show(ui.divV([
          ui.h2(WEIGHT_TITLE),
          ui.divText('Weights of selected descriptors.'),
        ]), x, y);
        return true;

      case P_VAL:
        ui.tooltip.show(getDescrTooltip(
          P_VAL,
          'Filtering descriptors by p-value:',
          'selected',
          'excluded',
        ), x, y);
        return true;

      default:
        // Correlation columns are named after the descriptors
        if (descriptorNames.includes(colName)) {
          ui.tooltip.show(
            getDescrTooltip(
              colName,
              `Correlation of ${colName} with other descriptors, measured by R²:`,
              'weakly correlated',
              'highly correlated',
            ), x, y);

          return true;
        }

        return false;
      }
    }

    if (cell.isTableCell) {
      const cellCol = cell.tableColumn;

      if (cellCol == null)
        return false;

      const colName = cellCol.name;
      const value = cell.value;

      if (colName === DESCR_TITLE) {
        if (selectedByCorr.includes(value))
          ui.tooltip.show('Selected for model construction.', x, y);
        else if (selectedByPvalue.includes(value))
          ui.tooltip.show('Excluded due to a high correlation with other descriptors.', x, y);
        else
          ui.tooltip.show('Excluded due to a high p-value.', x, y);

        return true;
      }

      const descriptor = grid.cell(DESCR_TITLE, cell.gridRow).value;

      if (colName === WEIGHT_TITLE) {
        // Only selected descriptors have a profile chart, hence a weight
        if (!this.desirabilityProfileRoots.has(descriptor)) {
          if (selectedByPvalue.includes(descriptor))
            ui.tooltip.show(`No weight: <b>${descriptor}</b> is excluded due to a high correlation with other descriptors.`, x, y);
          else
            ui.tooltip.show(`No weight: <b>${descriptor}</b> is excluded due to a high p-value.`, x, y);

          return true;
        }

        return false;
      }

      if (descriptorNames.includes(colName) && (!selectedByPvalue.includes(descriptor))) {
        ui.tooltip.show(`<b>${descriptor}</b> is excluded due to a high p-value; so correlation with <b>${colName}</b> is not needed.`, x, y);
        return true;
      }
    }

    return false;
  })); // grid.onCellTooltip

  const desirabilityCol = grid.col(DESIRABILITY_COL_NAME);
  grid.setOptions({'rowHeight': STAT_GRID_HEIGHT});
  desirabilityCol!.width = DESIRABILITY_COLUMN_WIDTH;
  desirabilityCol!.cellType = 'html';

  // Desirability profile charts, or a note explaining why a descriptor has none
  this.statGridSubs.push(grid.onCellPrepare((cell) => {
    const cellCol = cell.tableColumn;

    if (cellCol == null || !cell.isTableCell || cellCol.name !== DESIRABILITY_COL_NAME)
      return;

    const descriptor = grid.cell(DESCR_TITLE, cell.gridRow).value;
    const element = this.desirabilityProfileRoots.get(descriptor);

    if (element != null) {
      cell.element = element;
      return;
    }

    const selected = selectedByPvalue.includes(descriptor);
    const text = selected ? 'highly correlated with other descriptors' : 'statistically insignificant';
    const tooltipMsg = selected ?
      `No chart shown: <b>${descriptor}</b> is excluded due to a high correlation with other descriptors.` :
      `No chart shown: <b>${descriptor}</b> is excluded due to a high p-value.`;

    const divWithDescription = ui.divText(text);
    divWithDescription.style.color = COLORS.SKIPPED;
    divWithDescription.classList.add('eda-pmpo-centered-text');
    ui.tooltip.bind(divWithDescription, tooltipMsg);
    cell.element = divWithDescription;
  })); // grid.onCellPrepare
} // updateStatisticsGrid
|
|
471
|
+
|
|
472
|
+
/** Tooltip subscription on the main grid. It is replaced on every update so handlers do not pile up across refits. */
private mainGridTooltipSub: {unsubscribe(): void} | null = null;

/** Updates the table view grid after the scores column has been (re)added.
 * It sorts by the scores column (ascending flag `false`), formats the scores, and shows the score
 * description as the header tooltip of the scores column. */
private updateGrid(): void {
  const grid = this.view.grid;
  const name = this.predictionName;

  grid.sort([name], [false]);
  grid.col(name)!.format = '0.0000';

  // Re-register the header tooltip, releasing the one from a previous fit
  this.mainGridTooltipSub?.unsubscribe();
  this.mainGridTooltipSub = grid.onCellTooltip((cell, x, y) => {
    if (cell.isColHeader && cell.tableColumn?.name === name) {
      ui.tooltip.show(getScoreTooltip(), x, y);
      return true;
    }

    return false;
  });
} // updateGrid
|
|
497
|
+
|
|
498
|
+
/** Rebuilds the per-descriptor desirability profile charts and copies the profile weights into the statistics table.
 *
 * The charts are taken from an MpoProfileEditor rendering of the model profile. Each profile row is
 * expected to hold three children: descriptor name, weight, and the chart. The chart element is stored
 * in `desirabilityProfileRoots` under the descriptor name. Rows with an unexpected layout are skipped.
 *
 * Does nothing if the model has not been trained.
 * @param descrStatsTable - statistics table with the descriptor and weight columns
 * @param useSigmoidalCorrection - whether the profile uses the sigmoidal correction */
private updateDesirabilityProfileData(descrStatsTable: DG.DataFrame, useSigmoidalCorrection: boolean): void {
  if (this.params == null)
    return;

  // Detach the charts of the previous fit
  this.desirabilityProfileRoots.forEach((root) => root.remove());
  this.desirabilityProfileRoots.clear();

  const desirabilityProfile = getDesirabilityProfileJson(this.params, useSigmoidalCorrection, '', '', true);

  // Copy the profile weights into the statistics table
  const descrNames = descrStatsTable.col(DESCR_TITLE)!.toList();
  const weightsRaw = descrStatsTable.col(WEIGHT_TITLE)!.getRawData();
  const props = desirabilityProfile.properties;

  for (const name of Object.keys(props)) {
    const idx = descrNames.indexOf(name);
    if (idx >= 0) // skip profile entries without a matching descriptor row
      weightsRaw[idx] = props[name].weight;
  }

  // Extract the chart elements from the editor's DOM
  const mpoEditor = new MpoProfileEditor();
  mpoEditor.setProfile(desirabilityProfile);
  const container = mpoEditor.root;
  const rootsCol = container.querySelector('div.d4-flex-col.ui-div');

  if (rootsCol == null)
    return;

  const rows = rootsCol.querySelectorAll('div.d4-flex-row.ui-div');

  rows.forEach((row) => {
    const children = row.children;
    if (children.length < 3) // expecting descriptor name, weight & profile; the chart is children[2]
      return;

    const descrDivChildren = (children[0] as HTMLElement).children;
    if (descrDivChildren.length < 1) // expecting 1 div with descriptor name
      return;

    const descrName = (descrDivChildren[0] as HTMLElement).innerText;

    this.desirabilityProfileRoots.set(descrName, children[2] as HTMLElement);
  });
} // updateDesirabilityProfileData
|
|
542
|
+
|
|
543
|
+
/** Draws the ROC curve of the scores against the desirability labels.
 * Adds a dashed diagonal baseline and shows the AUC in the title.
 * @param desirability - actual labels
 * @param prediction - pMPO scores
 * @return Best threshold according to Youden's J statistic (as reported by getPmpoEvaluation) */
private updateRocCurve(desirability: DG.Column, prediction: DG.Column): number {
  const {fpr, tpr, auc, threshold} = getPmpoEvaluation(desirability, prediction);

  const thresholdsCol = DG.Column.fromFloat32Array(THRESHOLD, ROC_TRESHOLDS);
  const fprCol = DG.Column.fromFloat32Array(FPR_TITLE, fpr);
  const tprCol = DG.Column.fromFloat32Array(TPR_TITLE, tpr);
  const rocTable = DG.DataFrame.fromColumns([thresholdsCol, fprCol, tprCol]);

  // Diagonal TPR = FPR: the curve of a non-informative classifier
  rocTable.meta.formulaLines.addLine({
    title: 'Non-informative baseline',
    formula: `\${${TPR_TITLE}} = \${${FPR_TITLE}}`,
    width: 1,
    style: 'dashed',
    min: 0,
    max: 1,
  });

  const viewer = this.rocCurve;
  viewer.dataFrame = rocTable;
  viewer.setOptions({
    xColumnName: FPR_TITLE,
    yColumnName: TPR_TITLE,
    linesOrderColumnName: FPR_TITLE,
    linesWidth: 5,
    markerType: 'dot',
    title: `ROC Curve (AUC = ${auc.toFixed(3)})`,
  });

  return threshold;
} // updateRocCurve
|
|
577
|
+
|
|
578
|
+
/** Points the confusion matrix at the table, showing actual (x) vs predicted (y) desirability.
 * @param df - table containing both the actual and the predicted boolean columns
 * @param desColName - name of the actual desirability column
 * @param bestThreshold - score threshold behind the predicted labels; shown in the description */
private updateConfusionMatrix(df: DG.DataFrame, desColName: string, bestThreshold: number): void {
  const matrix = this.confusionMatrix;
  const thresholdText = bestThreshold.toFixed(3);

  matrix.dataFrame = df;
  matrix.setOptions({
    xColumnName: desColName,
    yColumnName: this.boolPredictionName,
    description: `Threshold: ${thresholdText} (optimized via Youden's J)`,
    title: `${desColName} Confusion Matrix`,
  });
} // updateConfusionMatrix
|
|
588
|
+
|
|
589
|
+
/** Trains the model and refreshes everything that depends on it.
 * This covers the scores column, the main grid, the desirability profiles, the statistics grid,
 * the ROC curve, the thresholded predictions and the confusion matrix. The row selection of the
 * view's table is cleared at the end.
 * The training arguments are forwarded to `Pmpo.fit`; `useSigmoid` controls the scoring correction. */
private fitAndUpdateViewers(df: DG.DataFrame, descriptors: DG.ColumnList, desirability: DG.Column,
  pValTresh: number, r2Tresh: number, qCutoff: number, useSigmoid: boolean): void {
  const {params, descrStatsTable, selectedByPvalue, selectedByCorr} =
    Pmpo.fit(df, descriptors, desirability, pValTresh, r2Tresh, qCutoff);
  this.params = params;

  const descriptorNames = descriptors.names();
  const scoresName = this.predictionName;

  // Scores, colored linearly from their min to their max
  const prediction = Pmpo.predict(df, params, useSigmoid, scoresName);
  const {min, max} = prediction.stats;
  prediction.colors.setLinear(getOutputPalette(OPT_TYPE.MAX), {min, max});

  // Replace the scores column of a previous fit
  df.columns.remove(scoresName);
  df.columns.add(prediction);

  this.updateGrid();
  this.updateDesirabilityProfileData(descrStatsTable, useSigmoid);
  this.updateStatisticsGrid(descrStatsTable, descriptorNames, selectedByPvalue, selectedByCorr);

  const bestThreshold = this.updateRocCurve(desirability, prediction);

  // Replace the thresholded (boolean) predictions of a previous fit
  const desColName = desirability.name;
  df.columns.remove(this.boolPredictionName);
  this.boolPredictionName = df.columns.getUnusedName(`${desColName}(predicted)`);
  df.columns.add(getBoolPredictionColumn(prediction, bestThreshold, this.boolPredictionName));

  this.updateConfusionMatrix(df, desColName, bestThreshold);

  this.view.dataFrame.selection.setAll(false, true);
} // fitAndUpdateViewers
|
|
633
|
+
|
|
634
|
+
/** Lays out the pMPO training UI in the table view.
 * - the inputs form is docked on the left;
 * - the statistics grid goes below the main grid;
 * - the ROC curve and the confusion matrix go to the right of the statistics grid.
 * Finally, the first ribbon panel is removed.
 * @throws Error if the view has no grid */
public runTrainingApp(): void {
  const dock = this.view.dockManager;

  const inputForm = this.getInputForm(true).form;
  dock.dock(inputForm, DG.DOCK_TYPE.LEFT, null, undefined, 0.1);

  const gridNode = dock.findNode(this.view.grid.root);
  if (gridNode == null)
    throw new Error('Failed to train pMPO: missing a grid in the table view.');

  const statsNode = dock.dock(this.statGrid, DG.DOCK_TYPE.DOWN, gridNode, undefined, 0.5);
  const rocNode = dock.dock(this.rocCurve, DG.DOCK_TYPE.RIGHT, statsNode, undefined, 0.3);
  dock.dock(this.confusionMatrix, DG.DOCK_TYPE.RIGHT, rocNode, undefined, 0.2);

  this.setRibbons();
} // runTrainingApp
|
|
658
|
+
|
|
659
|
+
/** Returns the pMPO UI elements without docking them, so the caller can place them.
 * The elements are the statistics grid, the ROC curve, the confusion matrix, and the training
 * controls created by getInputForm(false). */
public getPmpoAppItems(): PmpoAppItems {
  const items: PmpoAppItems = {
    statsGrid: this.statGrid,
    rocCurve: this.rocCurve,
    confusionMatrix: this.confusionMatrix,
    controls: this.getInputForm(false),
  };

  return items;
} // getPmpoAppItems
|
|
668
|
+
|
|
669
|
+
/**
 * Creates and returns the input form for pMPO model training.
 *
 * The form has two parts:
 * - training data inputs: descriptor columns and the desirability column;
 * - model settings: σ correction, auto-tuning, and the p-value, R² and q-cutoff thresholds.
 *
 * Changing an input refits the model and updates the viewers. While auto-tuning is on, the
 * threshold inputs are disabled and their values are set from the optimization results.
 *
 * @param addBtn - whether to append the "Save" button to the form
 * @returns the form wrapper element and the "Save" button
 */
private getInputForm(addBtn: boolean): Controls {
  const form = ui.form([]);
  form.append(ui.h2('Training data'));
  const numericColNames = this.numericCols.map((col) => col.name);

  // Auto-tuning is switched on initially only for tables with at most AUTO_TUNE_MAX_APPLICABLE_ROWS rows
  const toUseAutoTune = (this.table.rowCount <= AUTO_TUNE_MAX_APPLICABLE_ROWS);

  // Flag indicating whether optimal parameters from auto-tuning are currently used.
  // Declared before the inputs, because their callbacks reset it.
  let areTunedSettingsUsed = false;

  // Function to run computations on input changes
  const runComputations = () => {
    try {
      this.fitAndUpdateViewers(
        this.table,
        DG.DataFrame.fromColumns(descrInput.value).columns,
        this.table.col(desInput.value!)!,
        pInput.value!,
        rInput.value!,
        qInput.value!,
        useSigmoidInput.value,
      );
    } catch (err) {
      // pMPO-specific failures are reported as warnings, anything else as errors
      if (err instanceof PmpoError)
        grok.shell.warning(err.message);
      else
        grok.shell.error(err instanceof Error ? err.message : PMPO_COMPUTE_FAILED + ': the platform issue.');
    }
  };

  // Descriptor columns input
  const descrInput = ui.input.columns('Descriptors', {
    table: this.table,
    nullable: false,
    available: numericColNames,
    checked: numericColNames,
    tooltipText: 'Descriptor columns used for model construction.',
    onValueChanged: (value) => {
      if (value != null) {
        areTunedSettingsUsed = false;
        checkAutoTuneAndRun();
      }
    },
  });
  form.append(descrInput.root);

  // Desirability column input
  const desInput = ui.input.choice('Desirability', {
    nullable: false,
    value: this.boolCols[0].name,
    items: this.boolCols.map((col) => col.name),
    tooltipText: 'Desirability column.',
    onValueChanged: (value) => {
      if (value != null) {
        areTunedSettingsUsed = false;
        checkAutoTuneAndRun();
      }
    },
  });
  form.append(desInput.root);

  const header = ui.h2('Settings');
  form.append(header);
  ui.tooltip.bind(header, 'Settings of the pMPO model.');

  // use sigmoid correction
  const useSigmoidInput = ui.input.bool('\u03C3 correction', {
    value: USE_SIGMOID_DEFAULT,
    tooltipText: 'Use the sigmoidal correction to the weighted Gaussian scores.',
    onValueChanged: (_value) => {
      areTunedSettingsUsed = false;
      checkAutoTuneAndRun();
    },
  });
  form.append(useSigmoidInput.root);

  // Tunes the thresholds (unless tuned values are already in use), then refits the model
  const setOptimalParametersAndRun = async () => {
    if (!areTunedSettingsUsed) {
      let isTuned = false;

      try {
        const optimalSettings = await this.getOptimalSettings(
          DG.DataFrame.fromColumns(descrInput.value).columns,
          this.table.col(desInput.value!)!,
          useSigmoidInput.value,
        );

        if (optimalSettings.success) {
          pInput.value = Math.max(optimalSettings.pValTresh, P_VAL_TRES_MIN);
          rInput.value = Math.max(optimalSettings.r2Tresh, R2_MIN);
          qInput.value = Math.max(optimalSettings.qCutoff, Q_CUTOFF_MIN);
          isTuned = true;
        }
      } catch (err) {
        // An optimization error is handled like an unsuccessful optimization (see below);
        // this also prevents an unhandled rejection, since callers do not await this function
      }

      if (isTuned)
        areTunedSettingsUsed = true;
      else
        autoTuneInput.value = false; // revert to manual mode if optimization failed
    }

    runComputations();
  };

  const checkAutoTuneAndRun = () => {
    if (autoTuneInput.value)
      setOptimalParametersAndRun();
    else
      runComputations();
  };

  // autotuning input
  const autoTuneInput = ui.input.bool('Auto-tuning', {
    value: false,
    tooltipText: 'Automatically select optimal p-value, R², and q-cutoff by maximizing AUC.',
    onValueChanged: async (value) => {
      setEnability(!value);

      if (areTunedSettingsUsed)
        return;

      // If auto-tuning is turned on, set optimal parameters and run computations
      if (value)
        await setOptimalParametersAndRun();
    },
  });
  form.append(autoTuneInput.root);

  // p-value threshold input
  const pInput = ui.input.float('p-value', {
    nullable: false,
    min: P_VAL_TRES_MIN,
    max: P_VAL_TRES_MAX,
    step: 0.001,
    value: P_VAL_TRES_DEFAULT,
    // @ts-ignore
    format: FORMAT,
    tooltipText: 'P-value threshold. Descriptors with p-values above this threshold are excluded.',
    onValueChanged: (value) => {
      // Prevent running computations when auto-tuning is on, since parameters will be set automatically
      if (autoTuneInput.value)
        return;

      areTunedSettingsUsed = false;
      if ((value != null) && (value >= P_VAL_TRES_MIN) && (value <= P_VAL_TRES_MAX))
        runComputations();
    },
  });
  form.append(pInput.root);

  // R² threshold input
  const rInput = ui.input.float('R²', {
    // @ts-ignore
    format: FORMAT,
    nullable: false,
    min: R2_MIN,
    value: R2_DEFAULT,
    max: R2_MAX,
    step: 0.01,
    // eslint-disable-next-line max-len
    tooltipText: 'Squared correlation threshold. Descriptors with squared correlation above this threshold are considered highly correlated. Among them, the descriptor with the lower p-value is retained.',
    onValueChanged: (value) => {
      // Prevent running computations when auto-tuning is on, since parameters will be set automatically
      if (autoTuneInput.value)
        return;

      areTunedSettingsUsed = false;

      if ((value != null) && (value >= R2_MIN) && (value <= R2_MAX))
        runComputations();
    },
  });
  form.append(rInput.root);

  // q-cutoff input
  const qInput = ui.input.float('q-cutoff', {
    // @ts-ignore
    format: FORMAT,
    nullable: false,
    min: Q_CUTOFF_MIN,
    value: Q_CUTOFF_DEFAULT,
    max: Q_CUTOFF_MAX,
    step: 0.01,
    tooltipText: 'Q-cutoff for the pMPO model computation.',
    onValueChanged: (value) => {
      // Prevent running computations when auto-tuning is on, since parameters will be set automatically
      if (autoTuneInput.value)
        return;

      areTunedSettingsUsed = false;

      if ((value != null) && (value >= Q_CUTOFF_MIN) && (value <= Q_CUTOFF_MAX))
        runComputations();
    },
  });
  form.append(qInput.root);

  // The threshold inputs are editable only in manual (non-auto-tuning) mode
  const setEnability = (toEnable: boolean) => {
    pInput.enabled = toEnable;
    rInput.enabled = toEnable;
    qInput.enabled = toEnable;
  };

  setTimeout(() => {
    // Show the results for the default settings right away
    runComputations();

    // Switching auto-tuning on triggers tuning and a re-run with the optimal settings.
    // Without auto-tuning, the run above is sufficient (previously it was repeated needlessly).
    if (toUseAutoTune)
      autoTuneInput.value = true;
  }, 10);

  // Save model button
  const saveBtn = ui.button('Save', async () => {
    if (this.params == null) {
      grok.shell.warning('Failed to save pMPO model: null parameters.');
      return;
    }

    saveModel(this.params, this.table.name, useSigmoidInput.value);
  }, 'Save model as platform file.');

  if (addBtn)
    form.append(saveBtn);

  const div = ui.div([form]);
  div.classList.add('eda-pmpo-input-form');

  return {
    form: div,
    saveBtn: saveBtn,
  };
} // getInputForm
|
|
896
|
+
|
|
897
|
+
/**
 * Collects the boolean columns of the table whose values are not all equal
 * (their standard deviation is positive). Columns keep their table order.
 */
private getBoolCols(): DG.Column[] {
  const isVaryingBool = (col: DG.Column) => (col.type === DG.COLUMN_TYPE.BOOL) && (col.stats.stdev > 0);

  return Array.from(this.table.columns).filter(isVaryingBool);
} // getBoolCols
|
|
908
|
+
|
|
909
|
+
/**
 * Collects the table columns usable as descriptors, keeping their table order.
 * A column qualifies when it is numerical, has no missing values, and has a positive standard deviation.
 */
private getValidNumericCols(): DG.Column[] {
  return Array.from(this.table.columns).filter((col) => {
    const isComplete = col.stats.missingValueCount < 1;
    return col.isNumerical && isComplete && (col.stats.stdev > 0);
  });
} // getValidNumericCols
|
|
920
|
+
|
|
921
|
+
/**
 * Searches for pMPO settings that maximize the AUC of the model's predictions on the table.
 *
 * How the search works:
 * - The p-value threshold is fixed to P_VAL_TRES_DEFAULT.
 * - The R² threshold and the q-cutoff are found with `optimizeNM` by minimizing `1 - AUC`.
 * - The search starts from (R2_DEFAULT, Q_CUTOFF_DEFAULT) and is bounded by LOW_PARAMS_BOUNDS / HIGH_PARAMS_BOUNDS.
 * - The user can cancel it via the task bar progress indicator.
 *
 * @param descriptors - descriptor columns
 * @param desirability - desirability column
 * @param useSigmoid - whether the sigmoidal correction is applied when predicting
 * @returns the optimal settings with `success: true`. Otherwise it returns a zero-filled result with
 *   `success: false`: when no descriptor passes the p-value filter, when the search is canceled,
 *   or when any step throws. This function never rejects.
 */
private async getOptimalSettings(descriptors: DG.ColumnList, desirability: DG.Column, useSigmoid: boolean): Promise<OptimalPoint> {
  const failedResult: OptimalPoint = {
    pValTresh: 0,
    r2Tresh: 0,
    qCutoff: 0,
    success: false,
  };

  const pi = DG.TaskBarProgressIndicator.create('Optimizing... ', {cancelable: true});

  // All steps, including the statistics computed before the search, are guarded, so that callers
  // get `failedResult` instead of a rejected promise. The indicator is closed exactly once, in `finally`.
  try {
    const descriptorNames = descriptors.names();
    const {desired, nonDesired} = getDesiredTables(this.table, desirability);

    // Compute descriptors' statistics
    const descrStats = new Map<string, DescriptorStatistics>();
    descriptorNames.forEach((name) => {
      descrStats.set(name, getDescriptorStatistics(desired.col(name)!, nonDesired.col(name)!));
    });
    const descrStatsTable = getDescriptorStatisticsTable(descrStats);

    // Filter by p-value (the p-value threshold is not tuned)
    const selectedByPvalue = getFilteredByPvalue(descrStatsTable, P_VAL_TRES_DEFAULT);
    if (selectedByPvalue.length < 1)
      return failedResult;

    const correlationTriples = getCorrelationTriples(descriptors, selectedByPvalue);

    // point[0] - R² threshold, point[1] - q-cutoff
    const funcToBeMinimized = (point: Float32Array) => {
      // Filter by correlations
      const selectedByCorr = getFilteredByCorrelations(descriptors, selectedByPvalue, descrStats, point[0], correlationTriples);

      // Compute pMPO parameters - training
      const params = getModelParams(desired, nonDesired, selectedByCorr, point[1]);

      // Get predictions
      const prediction = Pmpo.predict(this.table, params, useSigmoid, this.predictionName);

      // Evaluate predictions and return 1 - AUC (since optimization minimizes the function, but we want to maximize AUC)
      return 1 - getPmpoEvaluation(desirability, prediction).auc;
    }; // funcToBeMinimized

    const optimalResult = await optimizeNM(
      pi,
      funcToBeMinimized,
      new Float32Array([R2_DEFAULT, Q_CUTOFF_DEFAULT]),
      DEFAULT_OPTIMIZATION_SETTINGS,
      LOW_PARAMS_BOUNDS,
      HIGH_PARAMS_BOUNDS,
    );

    // The search was interrupted by the user
    if (pi.canceled)
      return failedResult;

    return {
      pValTresh: P_VAL_TRES_DEFAULT,
      r2Tresh: optimalResult.optimalPoint[0],
      qCutoff: optimalResult.optimalPoint[1],
      success: true,
    };
  } catch (err) {
    // Any failure is reported as an unsuccessful optimization; callers fall back to manual settings
    return failedResult;
  } finally {
    pi.close();
  }
} // getOptimalSettings
|
|
991
|
+
}; // Pmpo
|