@datagrok/eda 1.4.11 → 1.4.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +7 -0
- package/README.md +8 -0
- package/css/pmpo.css +26 -0
- package/dist/package-test.js.map +1 -1
- package/dist/package.js +1 -1
- package/dist/package.js.map +1 -1
- package/eslintrc.json +46 -0
- package/files/drugs-props-test.csv +126 -0
- package/files/drugs-props-train.csv +664 -0
- package/files/mpo-done.ipynb +2123 -0
- package/package.json +3 -1
- package/src/anova/anova-tools.ts +1 -1
- package/src/anova/anova-ui.ts +1 -1
- package/src/package-api.ts +14 -0
- package/src/package.g.ts +18 -5
- package/src/package.ts +45 -14
- package/src/pareto-optimization/utils.ts +6 -4
- package/src/probabilistic-scoring/pmpo-defs.ts +108 -0
- package/src/probabilistic-scoring/pmpo-utils.ts +580 -0
- package/src/probabilistic-scoring/prob-scoring.ts +637 -0
- package/src/probabilistic-scoring/stat-tools.ts +168 -0
- package/src/softmax-classifier.ts +1 -1
- package/test-console-output-1.log +77 -47
- package/test-record-1.mp4 +0 -0
|
@@ -0,0 +1,637 @@
|
|
|
1
|
+
/* eslint-disable max-len */
|
|
2
|
+
// Probabilistic scoring (pMPO) features
|
|
3
|
+
// Link: https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/
|
|
4
|
+
|
|
5
|
+
import * as grok from 'datagrok-api/grok';
|
|
6
|
+
import * as ui from 'datagrok-api/ui';
|
|
7
|
+
import * as DG from 'datagrok-api/dg';
|
|
8
|
+
|
|
9
|
+
import {MpoProfileEditor} from '@datagrok-libraries/statistics/src/mpo/mpo-profile-editor';
|
|
10
|
+
|
|
11
|
+
import '../../css/pmpo.css';
|
|
12
|
+
|
|
13
|
+
import {getDesiredTables, getDescriptorStatistics, normalPdf, sigmoidS} from './stat-tools';
|
|
14
|
+
import {MIN_SAMPLES_COUNT, PMPO_NON_APPLICABLE, DescriptorStatistics, P_VAL_TRES_MIN, DESCR_TITLE,
|
|
15
|
+
R2_MIN, Q_CUTOFF_MIN, PmpoParams, SCORES_TITLE, DESCR_TABLE_TITLE, PMPO_COMPUTE_FAILED, SELECTED_TITLE,
|
|
16
|
+
P_VAL, DESIRABILITY_COL_NAME, STAT_GRID_HEIGHT, DESIRABILITY_COLUMN_WIDTH, WEIGHT_TITLE} from './pmpo-defs';
|
|
17
|
+
import {addSelectedDescriptorsCol, getDescriptorStatisticsTable, getFilteredByPvalue, getFilteredByCorrelations,
|
|
18
|
+
getModelParams, getDescrTooltip, saveModel, getScoreTooltip, getDesirabilityProfileJson, getCorrelationTriples,
|
|
19
|
+
addCorrelationColumns, setPvalColumnColorCoding, setCorrColumnColorCoding} from './pmpo-utils';
|
|
20
|
+
import {getOutputPalette} from '../pareto-optimization/utils';
|
|
21
|
+
import {OPT_TYPE} from '../pareto-optimization/defs';
|
|
22
|
+
|
|
23
|
+
/** Class implementing probabilistic MPO (pMPO) model training and prediction */
export class Pmpo {
  /**
   * Checks if pMPO model can be applied to the given descriptors and desirability column.
   *
   * @param descriptors - descriptor columns; each must be numerical and free of missing values,
   *   and at least one of them must be non-constant
   * @param desirability - boolean column that must contain both true and false values
   * @param pValThresh - p-value threshold, must be at least P_VAL_TRES_MIN
   * @param r2Tresh - R² threshold, must be at least R2_MIN
   * @param qCutoff - q-cutoff, must be at least Q_CUTOFF_MIN
   * @param toShowWarning - if true, a warning describing the first failed check is shown
   * @returns true when all checks pass
   */
  static isApplicable(descriptors: DG.ColumnList, desirability: DG.Column, pValThresh: number,
    r2Tresh: number, qCutoff: number, toShowWarning: boolean = false): boolean {
    const rows = desirability.length;

    const showWarning = (msg: string) => {
      if (toShowWarning)
        grok.shell.warning(PMPO_NON_APPLICABLE + msg);
    };

    // Check p-value threshold
    if (pValThresh < P_VAL_TRES_MIN) {
      showWarning(`: too small p-value threshold - ${pValThresh}, minimum - ${P_VAL_TRES_MIN}`);
      return false;
    }

    // Check R2 threshold
    if (r2Tresh < R2_MIN) {
      showWarning(`: too small R² threshold - ${r2Tresh}, minimum - ${R2_MIN}`);
      return false;
    }

    // Check q-cutoff
    if (qCutoff < Q_CUTOFF_MIN) {
      showWarning(`: too small q-cutoff - ${qCutoff}, minimum - ${Q_CUTOFF_MIN}`);
      return false;
    }

    // Check samples count
    if (rows < MIN_SAMPLES_COUNT) {
      showWarning(`: not enough of samples - ${rows}, minimum - ${MIN_SAMPLES_COUNT}`);
      return false;
    }

    // Check desirability
    if (desirability.type !== DG.COLUMN_TYPE.BOOL) {
      showWarning(`: "${desirability.name}" must be boolean column.`);
      return false;
    }

    if (desirability.stats.stdev === 0) { // zero deviation means only TRUE or only FALSE values
      showWarning(`: "${desirability.name}" has a single category.`);
      return false;
    }

    // Check descriptors
    let nonConstantCols = 0;

    for (const col of descriptors) {
      if (!col.isNumerical) {
        showWarning(`: "${col.name}" is not numerical.`);
        return false;
      }

      if (col.stats.missingValueCount > 0) {
        showWarning(`: "${col.name}" contains missing values.`);
        return false;
      }

      if (col.stats.stdev > 0)
        ++nonConstantCols;
    }

    if (nonConstantCols < 1) {
      showWarning(`: not enough of non-constant descriptors.`);
      return false;
    }

    return true;
  } // isApplicable

  /**
   * Validates the input data frame for pMPO applicability.
   *
   * The criteria match the ones used to populate the training form (see getBoolCols and
   * getValidNumericCols): a valid table has at least one boolean column with both values and
   * at least one numeric column without missing values and with non-zero variance.
   *
   * @param df - table to check
   * @param toShowMsg - if true, a warning describing the failed check is shown
   * @returns true if the table can be used for pMPO training
   */
  static isTableValid(df: DG.DataFrame, toShowMsg: boolean = true): boolean {
    // Check row count
    if (df.rowCount < 2) {
      if (toShowMsg)
        grok.shell.warning(PMPO_NON_APPLICABLE + `. Not enough of samples: ${df.rowCount}, minimum: 2.`);
      return false;
    }

    let boolColsCount = 0;
    let validNumericColsCount = 0;

    // Check numeric columns and boolean columns
    for (const col of df.columns) {
      if (col.isNumerical) {
        if ((col.stats.missingValueCount < 1) && (col.stats.stdev > 0))
          ++validNumericColsCount;
      } else if ((col.type === DG.COLUMN_TYPE.BOOL) && (col.stats.stdev > 0)) {
        // Constant boolean columns cannot serve as desirability, so they are not counted
        ++boolColsCount;
      }
    }

    // Check boolean columns count
    if (boolColsCount < 1) {
      if (toShowMsg)
        grok.shell.warning(PMPO_NON_APPLICABLE + ': no boolean columns with both true and false values.');
      return false;
    }

    // Check valid numeric columns count
    if (validNumericColsCount < 1) {
      if (toShowMsg)
        grok.shell.warning(PMPO_NON_APPLICABLE + ': no numeric columns without missing values and non-zero variance.');
      return false;
    }

    return true;
  } // isTableValid

  /**
   * Predicts pMPO scores for the given data frame using provided pMPO parameters.
   * Score of a row is the sum over descriptors of weight * normalPdf * sigmoidS
   * (see https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/).
   *
   * NOTE(review): raw column data is read directly, so missing values are not handled here;
   * callers are presumably expected to pass tables without missing descriptor values.
   *
   * @param df - table containing all descriptor columns named in `params`
   * @param params - pMPO parameters keyed by descriptor name
   * @param predictionName - name of the resulting column
   * @returns a new float column with the scores (not added to `df`)
   * @throws Error if a descriptor column from `params` is missing in `df`
   */
  static predict(df: DG.DataFrame, params: Map<string, PmpoParams>, predictionName: string): DG.Column {
    const count = df.rowCount;
    const scores = new Float64Array(count); // zero-initialized

    params.forEach((param, name) => {
      const col = df.col(name);

      if (col == null)
        throw new Error(`Failed to apply pMPO: inconsistent data, no column "${name}" in the table "${df.name}"`);

      const vals = col.getRawData();
      for (let i = 0; i < count; ++i) {
        const x = vals[i];
        scores[i] += param.weight * normalPdf(x, param.desAvg, param.desStd) * sigmoidS(x, param.x0, param.b, param.c);
      }
    });

    return DG.Column.fromFloat64Array(predictionName, scores);
  } // predict

  private params: Map<string, PmpoParams> | null = null;

  private table: DG.DataFrame;
  private view: DG.TableView;
  private boolCols: DG.Column[];
  private numericCols: DG.Column[];

  private initTable = grok.data.demo.demog(10);

  private statGrid = DG.Viewer.grid(this.initTable, {showTitle: true, title: DESCR_TABLE_TITLE});

  private predictionName = SCORES_TITLE;

  private desirabilityProfileRoots = new Map<string, HTMLElement>();

  /** Descriptor names and selection results of the latest fit; read by the statistics grid handlers */
  private descriptorNames: string[] = [];
  private selectedByPvalue: string[] = [];
  private selectedByCorr: string[] = [];

  /** Grid event handlers are subscribed only once; these flags track whether it has been done */
  private statGridHandlersSet = false;
  private mainGridHandlersSet = false;

  constructor(df: DG.DataFrame) {
    this.table = df;
    this.view = grok.shell.tableView(df.name) ?? grok.shell.addTableView(df);
    this.boolCols = this.getBoolCols();
    this.numericCols = this.getValidNumericCols();
    this.predictionName = df.columns.getUnusedName(SCORES_TITLE);
  }

  /** Sets the ribbon panels in the table view (removes the first panel) */
  private setRibbons(): void {
    const ribPanel = this.view.getRibbonPanels();

    if (ribPanel.length < 1)
      return;

    this.view.setRibbonPanels(ribPanel.slice(1));
  }

  /**
   * Updates the statistics grid viewer with the given statistics table and selected descriptors.
   * Event handlers are subscribed on the first call only; later calls just refresh the state they read.
   */
  private updateStatisticsGrid(table: DG.DataFrame, descriptorNames: string[], selectedByPvalue: string[], selectedByCorr: string[]): void {
    // Store the latest selection results for the (once-subscribed) handlers
    this.descriptorNames = descriptorNames;
    this.selectedByPvalue = selectedByPvalue;
    this.selectedByCorr = selectedByCorr;

    const grid = this.statGrid;
    grid.dataFrame = table;
    grid.setOptions({
      showTitle: true,
      title: table.name,
    });

    grid.sort([SELECTED_TITLE], [false]);

    // set color coding & formats
    grid.col(DESCR_TITLE)!.isTextColorCoded = true;

    const pValCol = grid.col(P_VAL)!;
    pValCol.format = 'scientific';
    pValCol.isTextColorCoded = true;

    for (const name of descriptorNames) {
      const col = grid.col(name);
      if (col == null)
        continue;

      col.isTextColorCoded = true;
      col.format = '0.000';
    }

    // desirability profiles are rendered as HTML elements
    grid.setOptions({'rowHeight': STAT_GRID_HEIGHT});
    const desirabilityCol = grid.col(DESIRABILITY_COL_NAME)!;
    desirabilityCol.width = DESIRABILITY_COLUMN_WIDTH;
    desirabilityCol.cellType = 'html';

    // Subscribing on every update would accumulate handlers that act on stale data
    if (!this.statGridHandlersSet) {
      grid.onCellTooltip((cell, x, y) => this.onStatGridTooltip(cell, x, y));
      grid.onCellPrepare((cell) => this.onStatGridCellPrepare(cell));
      this.statGridHandlersSet = true;
    }
  } // updateStatisticsGrid

  /** Tooltip handler of the statistics grid; returns true if a custom tooltip was shown */
  private onStatGridTooltip(cell: DG.GridCell, x: number, y: number): boolean {
    const cellCol = cell.tableColumn;

    if (cellCol == null)
      return false;

    if (cell.isColHeader)
      return this.showStatHeaderTooltip(cellCol.name, x, y);

    if (cell.isTableCell)
      return this.showStatCellTooltip(cellCol.name, cell, x, y);

    return false;
  } // onStatGridTooltip

  /** Shows a tooltip for a column header of the statistics grid */
  private showStatHeaderTooltip(colName: string, x: number, y: number): boolean {
    switch (colName) {
      case DESCR_TITLE:
        ui.tooltip.show(getDescrTooltip(
          DESCR_TITLE,
          'Use of descriptors in model construction:',
          'selected',
          'excluded',
        ), x, y);
        return true;

      case DESIRABILITY_COL_NAME:
        ui.tooltip.show(ui.divV([
          ui.h2(DESIRABILITY_COL_NAME),
          ui.divText('Desirability profile charts for each descriptor. Only profiles for selected descriptors are shown.'),
        ]), x, y);
        return true;

      case WEIGHT_TITLE:
        ui.tooltip.show(ui.divV([
          ui.h2(WEIGHT_TITLE),
          ui.divText('Weights of selected descriptors.'),
        ]), x, y);
        return true;

      case P_VAL:
        ui.tooltip.show(getDescrTooltip(
          P_VAL,
          'Filtering descriptors by p-value:',
          'selected',
          'excluded',
        ), x, y);
        return true;

      default:
        if (this.descriptorNames.includes(colName)) {
          ui.tooltip.show(
            getDescrTooltip(
              colName,
              `Correlation of ${colName} with other descriptors, measured by R²:`,
              'weakly correlated',
              'highly correlated',
            ), x, y);

          return true;
        }

        return false;
    }
  } // showStatHeaderTooltip

  /** Shows a tooltip for a data cell of the statistics grid */
  private showStatCellTooltip(colName: string, cell: DG.GridCell, x: number, y: number): boolean {
    if (colName === DESCR_TITLE) {
      const value = cell.value;

      if (this.selectedByCorr.includes(value))
        ui.tooltip.show('Selected for model construction.', x, y);
      else if (this.selectedByPvalue.includes(value))
        ui.tooltip.show('Excluded due to a high correlation with other descriptors.', x, y);
      else
        ui.tooltip.show('Excluded due to a high p-value.', x, y);

      return true;
    }

    const descriptor = this.statGrid.cell(DESCR_TITLE, cell.gridRow).value;

    if ((colName === DESIRABILITY_COL_NAME) || (colName === WEIGHT_TITLE)) {
      // Descriptors having a profile are part of the model: keep the default tooltip
      if (this.desirabilityProfileRoots.has(descriptor))
        return false;

      const startText = (colName === WEIGHT_TITLE) ? 'No weight' : 'No chart shown';

      if (this.selectedByPvalue.includes(descriptor))
        ui.tooltip.show(`${startText}: <b>${descriptor}</b> is excluded due to a high correlation with other descriptors.`, x, y);
      else
        ui.tooltip.show(`${startText}: <b>${descriptor}</b> is excluded due to a high p-value.`, x, y);

      return true;
    }

    if (this.descriptorNames.includes(colName) && !this.selectedByPvalue.includes(descriptor)) {
      ui.tooltip.show(`<b>${descriptor}</b> is excluded due to a high p-value; so correlation with <b>${colName}</b> is not needed.`, x, y);
      return true;
    }

    return false;
  } // showStatCellTooltip

  /** Cell-prepare handler of the statistics grid: puts desirability profile charts into the desirability column */
  private onStatGridCellPrepare(cell: DG.GridCell): void {
    if (!cell.isTableCell)
      return;

    const cellCol = cell.tableColumn;
    if ((cellCol == null) || (cellCol.name !== DESIRABILITY_COL_NAME))
      return;

    const descriptor = this.statGrid.cell(DESCR_TITLE, cell.gridRow).value;
    cell.element = this.desirabilityProfileRoots.get(descriptor) ?? ui.div();
  } // onStatGridCellPrepare

  /** Updates the main grid viewer with the pMPO scores column */
  private updateGrid(): void {
    const grid = this.view.grid;
    const name = this.predictionName;

    grid.sort([name], [false]);
    grid.col(name)!.format = '0.0000';

    // The prediction name is fixed at construction, so a single subscription is enough
    if (!this.mainGridHandlersSet) {
      grid.onCellTooltip((cell, x, y) => {
        if (!cell.isColHeader)
          return false;

        const cellCol = cell.tableColumn;
        if ((cellCol == null) || (cellCol.name !== name))
          return false;

        ui.tooltip.show(getScoreTooltip(), x, y);
        return true;
      });
      this.mainGridHandlersSet = true;
    }
  } // updateGrid

  /**
   * Updates the desirability profile data: writes model weights into the statistics table and
   * collects per-descriptor profile elements rendered by MpoProfileEditor.
   */
  private updateDesirabilityProfileData(descrStatsTable: DG.DataFrame): void {
    if (this.params == null)
      return;

    // Clear existing roots
    this.desirabilityProfileRoots.forEach((root) => root.remove());
    this.desirabilityProfileRoots.clear();

    const desirabilityProfile = getDesirabilityProfileJson(this.params, '', '');

    // Set weights (descriptors absent from the statistics table are skipped)
    const descrNames = descrStatsTable.col(DESCR_TITLE)!.toList();
    const weightsRaw = descrStatsTable.col(WEIGHT_TITLE)!.getRawData();
    const props = desirabilityProfile.properties;

    for (const name of Object.keys(props)) {
      const idx = descrNames.indexOf(name);
      if (idx >= 0)
        weightsRaw[idx] = props[name].weight;
    }

    // Set HTML elements
    const mpoEditor = new MpoProfileEditor();
    mpoEditor.setProfile(desirabilityProfile);
    const container = mpoEditor.root;
    const rootsCol = container.querySelector('div.d4-flex-col.ui-div');

    if (rootsCol == null)
      return;

    const rows = rootsCol.querySelectorAll('div.d4-flex-row.ui-div');

    rows.forEach((row) => {
      const children = row.children;
      if (children.length < 3) // expecting descriptor name, weight & profile
        return;

      const descrDivChildren = (children[0] as HTMLElement).children;
      if (descrDivChildren.length < 1) // expecting 1 div with descriptor name
        return;

      const descrName = (descrDivChildren[0] as HTMLElement).innerText;

      this.desirabilityProfileRoots.set(descrName, children[2] as HTMLElement);
    });
  } // updateDesirabilityProfileData

  /**
   * Fits the pMPO model to the given data and updates the viewers accordingly.
   * @throws Error if the method is not applicable to the inputs
   */
  private fitAndUpdateViewers(df: DG.DataFrame, descriptors: DG.ColumnList, desirability: DG.Column,
    pValTresh: number, r2Tresh: number, qCutoff: number): void {
    if (!Pmpo.isApplicable(descriptors, desirability, pValTresh, r2Tresh, qCutoff))
      throw new Error('Failed to train pMPO model: the method is not applicable to the inputs');

    const descriptorNames = descriptors.names();
    const {desired, nonDesired} = getDesiredTables(df, desirability);

    // Compute descriptors' statistics
    const descrStats = new Map<string, DescriptorStatistics>();
    descriptorNames.forEach((name) => {
      descrStats.set(name, getDescriptorStatistics(desired.col(name)!, nonDesired.col(name)!));
    });
    const descrStatsTable = getDescriptorStatisticsTable(descrStats);

    // Set p-value column color coding
    setPvalColumnColorCoding(descrStatsTable, pValTresh);

    // Filter by p-value
    const selectedByPvalue = getFilteredByPvalue(descrStatsTable, pValTresh);

    // Compute correlation triples
    const correlationTriples = getCorrelationTriples(descriptors, selectedByPvalue);

    // Filter by correlations
    const selectedByCorr = getFilteredByCorrelations(descriptors, selectedByPvalue, descrStats, r2Tresh, correlationTriples);

    // Add the Selected column
    addSelectedDescriptorsCol(descrStatsTable, selectedByCorr);

    // Add correlation columns
    addCorrelationColumns(descrStatsTable, descriptorNames, correlationTriples, selectedByCorr);

    // Set correlation columns color coding
    setCorrColumnColorCoding(descrStatsTable, descriptorNames, r2Tresh);

    // Compute pMPO parameters - training
    this.params = getModelParams(desired, nonDesired, selectedByCorr, qCutoff);

    const prediction = Pmpo.predict(df, this.params, this.predictionName);

    // Mark predictions with a color
    prediction.colors.setLinear(getOutputPalette(OPT_TYPE.MAX), {min: prediction.stats.min, max: prediction.stats.max});

    // Replace the scores column of the previous fit (if any)
    df.columns.remove(this.predictionName);
    df.columns.add(prediction);

    // Update viewers
    this.updateGrid();

    // Update desirability profile roots map
    this.updateDesirabilityProfileData(descrStatsTable);

    // Update statistics grid
    this.updateStatisticsGrid(descrStatsTable, descriptorNames, selectedByPvalue, selectedByCorr);
  } // fitAndUpdateViewers

  /** Runs the pMPO model training application */
  public runTrainingApp(): void {
    // The input form needs at least one desirability (boolean) and one descriptor (numeric) column
    if ((this.boolCols.length < 1) || (this.numericCols.length < 1)) {
      grok.shell.warning(PMPO_NON_APPLICABLE + ': no suitable boolean or numeric columns.');
      return;
    }

    const dockMng = this.view.dockManager;

    // Inputs form
    dockMng.dock(this.getInputForm(), DG.DOCK_TYPE.LEFT, null, undefined, 0.1);

    // Dock viewers
    const gridNode = dockMng.findNode(this.view.grid.root);

    if (gridNode == null)
      throw new Error('Failed to train pMPO: missing a grid in the table view.');

    dockMng.dock(this.statGrid, DG.DOCK_TYPE.DOWN, gridNode, undefined, 0.5);

    this.setRibbons();
  } // runTrainingApp

  /** Creates and returns the input form for pMPO model training; every valid input change refits the model */
  private getInputForm(): HTMLElement {
    const form = ui.form([]);
    form.append(ui.h2('Training data'));
    const numericColNames = this.numericCols.map((col) => col.name);

    // Function to run computations on input changes
    const runComputations = () => {
      try {
        this.fitAndUpdateViewers(
          this.table,
          DG.DataFrame.fromColumns(descrInput.value).columns,
          this.table.col(desInput.value!)!,
          pInput.value!,
          rInput.value!,
          qInput.value!,
        );
      } catch (err) {
        grok.shell.error(err instanceof Error ? err.message : PMPO_COMPUTE_FAILED + ': the platform issue.');
      }
    };

    // Descriptor columns input
    const descrInput = ui.input.columns('Descriptors', {
      table: this.table,
      nullable: false,
      available: numericColNames,
      checked: numericColNames,
      tooltipText: 'Descriptor columns used for model construction.',
      onValueChanged: (value) => {
        if (value != null)
          runComputations();
      },
    });
    form.append(descrInput.root);

    // Desirability column input
    const desInput = ui.input.choice('Desirability', {
      nullable: false,
      value: this.boolCols[0].name,
      items: this.boolCols.map((col) => col.name),
      tooltipText: 'Desirability column.',
      onValueChanged: (value) => {
        if (value != null)
          runComputations();
      },
    });
    form.append(desInput.root);

    const header = ui.h2('Thresholds');
    ui.tooltip.bind(header, 'Settings of the pMPO model training.');
    form.append(header);

    // p-value threshold input
    const pInput = ui.input.float('p-value', {
      nullable: false,
      min: P_VAL_TRES_MIN,
      max: 1,
      step: 0.01,
      value: 0.05,
      tooltipText: 'Descriptors with p-values above this threshold are excluded.',
      onValueChanged: (value) => {
        if ((value != null) && (value >= P_VAL_TRES_MIN) && (value <= 1))
          runComputations();
      },
    });
    form.append(pInput.root);

    // R² threshold input
    const rInput = ui.input.float('R²', {
      nullable: false,
      min: R2_MIN,
      value: 0.5,
      max: 1,
      step: 0.01,
      // eslint-disable-next-line max-len
      tooltipText: 'Descriptors with squared correlation above this threshold are considered highly correlated. Among them, the descriptor with the lower p-value is retained.',
      onValueChanged: (value) => {
        if ((value != null) && (value >= R2_MIN) && (value <= 1))
          runComputations();
      },
    });
    form.append(rInput.root);

    // q-cutoff input
    const qInput = ui.input.float('q-cutoff', {
      nullable: false,
      min: Q_CUTOFF_MIN,
      value: 0.05,
      max: 1,
      step: 0.01,
      tooltipText: 'Q-cutoff for the pMPO model computation.',
      onValueChanged: (value) => {
        if ((value != null) && (value >= Q_CUTOFF_MIN) && (value <= 1))
          runComputations();
      },
    });
    form.append(qInput.root);

    // Initial fit, deferred until the inputs above are fully constructed
    setTimeout(() => runComputations(), 10);

    // Save model button
    const saveBtn = ui.button('Save model', async () => {
      if (this.params == null) {
        grok.shell.warning('Failed to save pMPO model: null parameters.');
        return;
      }

      saveModel(this.params, this.table.name);
    }, 'Save model as platform file.');
    form.append(saveBtn);

    const div = ui.div([form]);
    div.classList.add('eda-pmpo-input-form');

    return div;
  } // getInputForm

  /** Retrieves boolean columns containing both true and false values from the data frame */
  private getBoolCols(): DG.Column[] {
    const res: DG.Column[] = [];

    for (const col of this.table.columns) {
      if ((col.type === DG.COLUMN_TYPE.BOOL) && (col.stats.stdev > 0))
        res.push(col);
    }

    return res;
  } // getBoolCols

  /** Retrieves valid (numerical, no missing values, non-zero standard deviation) numeric columns from the data frame */
  private getValidNumericCols(): DG.Column[] {
    const res: DG.Column[] = [];

    for (const col of this.table.columns) {
      if ((col.isNumerical) && (col.stats.missingValueCount < 1) && (col.stats.stdev > 0))
        res.push(col);
    }

    return res;
  } // getValidNumericCols
} // Pmpo
|