@datagrok/eda 1.4.11 → 1.4.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. package/.eslintrc.json +0 -1
  2. package/CHANGELOG.md +15 -0
  3. package/CLAUDE.md +185 -0
  4. package/README.md +8 -0
  5. package/css/pmpo.css +35 -0
  6. package/dist/package-test.js +1 -1
  7. package/dist/package-test.js.map +1 -1
  8. package/dist/package.js +1 -1
  9. package/dist/package.js.map +1 -1
  10. package/eslintrc.json +45 -0
  11. package/files/drugs-props-test.csv +126 -0
  12. package/files/drugs-props-train-scores.csv +664 -0
  13. package/files/drugs-props-train.csv +664 -0
  14. package/package.json +9 -3
  15. package/src/anova/anova-tools.ts +1 -1
  16. package/src/anova/anova-ui.ts +1 -1
  17. package/src/package-api.ts +18 -0
  18. package/src/package-test.ts +4 -1
  19. package/src/package.g.ts +25 -0
  20. package/src/package.ts +55 -15
  21. package/src/pareto-optimization/pareto-computations.ts +6 -0
  22. package/src/pareto-optimization/utils.ts +6 -4
  23. package/src/probabilistic-scoring/data-generator.ts +157 -0
  24. package/src/probabilistic-scoring/nelder-mead.ts +204 -0
  25. package/src/probabilistic-scoring/pmpo-defs.ts +218 -0
  26. package/src/probabilistic-scoring/pmpo-utils.ts +603 -0
  27. package/src/probabilistic-scoring/prob-scoring.ts +991 -0
  28. package/src/probabilistic-scoring/stat-tools.ts +303 -0
  29. package/src/softmax-classifier.ts +1 -1
  30. package/src/tests/anova-tests.ts +1 -1
  31. package/src/tests/classifiers-tests.ts +1 -1
  32. package/src/tests/dim-reduction-tests.ts +1 -1
  33. package/src/tests/linear-methods-tests.ts +1 -1
  34. package/src/tests/mis-vals-imputation-tests.ts +1 -1
  35. package/src/tests/pareto-tests.ts +253 -0
  36. package/src/tests/pmpo-tests.ts +157 -0
  37. package/test-console-output-1.log +175 -209
  38. package/test-record-1.mp4 +0 -0
@@ -0,0 +1,991 @@
1
+ /* eslint-disable max-len */
2
+ // Probabilistic scoring (pMPO) features
3
+ // Link: https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/
4
+
5
+ import * as grok from 'datagrok-api/grok';
6
+ import * as ui from 'datagrok-api/ui';
7
+ import * as DG from 'datagrok-api/dg';
8
+
9
+ import {MpoProfileEditor} from '@datagrok-libraries/statistics/src/mpo/mpo-profile-editor';
10
+
11
+ import '../../css/pmpo.css';
12
+
13
+ import {getDesiredTables, getDescriptorStatistics, getBoolPredictionColumn, getPmpoEvaluation} from './stat-tools';
14
+ import {MIN_SAMPLES_COUNT, PMPO_NON_APPLICABLE, DescriptorStatistics, P_VAL_TRES_MIN, DESCR_TITLE,
15
+ R2_MIN, Q_CUTOFF_MIN, PmpoParams, SCORES_TITLE, DESCR_TABLE_TITLE, PMPO_COMPUTE_FAILED, SELECTED_TITLE,
16
+ P_VAL, DESIRABILITY_COL_NAME, STAT_GRID_HEIGHT, DESIRABILITY_COLUMN_WIDTH, WEIGHT_TITLE,
17
+ P_VAL_TRES_DEFAULT, R2_DEFAULT, Q_CUTOFF_DEFAULT, USE_SIGMOID_DEFAULT, ROC_TRESHOLDS,
18
+ FPR_TITLE, TPR_TITLE, COLORS, THRESHOLD, AUTO_TUNE_MAX_APPLICABLE_ROWS, DEFAULT_OPTIMIZATION_SETTINGS,
19
+ P_VAL_TRES_MAX, R2_MAX, Q_CUTOFF_MAX, OptimalPoint, LOW_PARAMS_BOUNDS, HIGH_PARAMS_BOUNDS, FORMAT} from './pmpo-defs';
20
+ import {addSelectedDescriptorsCol, getDescriptorStatisticsTable, getFilteredByPvalue, getFilteredByCorrelations,
21
+ getModelParams, getDescrTooltip, saveModel, getScoreTooltip, getDesirabilityProfileJson, getCorrelationTriples,
22
+ addCorrelationColumns, setPvalColumnColorCoding, setCorrColumnColorCoding, PmpoError} from './pmpo-utils';
23
+ import {getOutputPalette} from '../pareto-optimization/utils';
24
+ import {OPT_TYPE} from '../pareto-optimization/defs';
25
+ import {optimizeNM} from './nelder-mead';
26
+
27
/** Output of pMPO training: the fitted parameters and the descriptor statistics they were derived from */
export type PmpoTrainingResult = {
  /** Fitted pMPO parameters, keyed by descriptor name */
  params: Map<string, PmpoParams>;
  /** Per-descriptor statistics table (p-values, correlations, selection flags, weights) */
  descrStatsTable: DG.DataFrame;
  /** Descriptors that passed the p-value filter */
  selectedByPvalue: string[];
  /** Descriptors that passed both the p-value and the correlation filters */
  selectedByCorr: string[];
};

/** pMPO training controls: the input form wrapper and its Save button */
export type Controls = {
  form: HTMLElement;
  saveBtn: HTMLButtonElement;
};

/** Viewers and controls that make up the pMPO training UI */
export type PmpoAppItems = {
  statsGrid: DG.Viewer;
  rocCurve: DG.Viewer;
  confusionMatrix: DG.Viewer;
  controls: Controls;
};
44
+
45
+ /** Class implementing probabilistic MPO (pMPO) model training and prediction */
46
+ export class Pmpo {
47
/** Checks whether the pMPO method can be run on the given inputs.
 * The thresholds are checked first, then the desirability column, then every descriptor.
 * @param descriptors - candidate descriptors; each must be numerical and free of missing values,
 *   and at least one must have non-zero standard deviation
 * @param desirability - label column; must be boolean and contain both categories
 * @param pValThresh - p-value threshold, not below P_VAL_TRES_MIN
 * @param r2Tresh - R² threshold, not below R2_MIN
 * @param qCutoff - q-cutoff, not below Q_CUTOFF_MIN
 * @param toShowWarning - when true, the rejection reason is shown as a platform warning
 * @returns true when every check passes
 */
static isApplicable(descriptors: DG.ColumnList, desirability: DG.Column, pValThresh: number,
  r2Tresh: number, qCutoff: number, toShowWarning: boolean = false): boolean {
  // Reports the rejection reason (if requested) and yields the "not applicable" verdict
  const reject = (reason: string): boolean => {
    if (toShowWarning)
      grok.shell.warning(PMPO_NON_APPLICABLE + reason);
    return false;
  };

  if (pValThresh < P_VAL_TRES_MIN)
    return reject(`: too small p-value threshold - ${pValThresh}, minimum - ${P_VAL_TRES_MIN}`);

  if (r2Tresh < R2_MIN)
    return reject(`: too small R² threshold - ${r2Tresh}, minimum - ${R2_MIN}`);

  if (qCutoff < Q_CUTOFF_MIN)
    return reject(`: too small q-cutoff - ${qCutoff}, minimum - ${Q_CUTOFF_MIN}`);

  const sampleCount = desirability.length;
  if (sampleCount < MIN_SAMPLES_COUNT)
    return reject(`: not enough of samples - ${sampleCount}, minimum - ${MIN_SAMPLES_COUNT}`);

  if (desirability.type !== DG.COLUMN_TYPE.BOOL)
    return reject(`: "${desirability.name}" must be boolean column.`);

  // Zero spread means the labels are all TRUE or all FALSE
  if (desirability.stats.stdev === 0)
    return reject(`: "${desirability.name}" has a single category.`);

  let hasVaryingDescriptor = false;

  for (const col of descriptors) {
    if (!col.isNumerical)
      return reject(`: "${col.name}" is not numerical.`);

    if (col.stats.missingValueCount > 0)
      return reject(`: "${col.name}" contains missing values.`);

    if (col.stats.stdev > 0)
      hasVaryingDescriptor = true;
  }

  if (!hasVaryingDescriptor)
    return reject(`: not enough of non-constant descriptors.`);

  return true;
} // isApplicable
117
+
118
/** Checks whether a data frame can be used for pMPO training.
 * Requirements:
 * - at least two rows;
 * - at least one non-constant boolean column (a candidate desirability label);
 * - at least one numeric column without missing values and with non-zero standard deviation
 *   (a candidate descriptor).
 * Boolean columns are counted with the same rule the training UI uses to list desirability
 * candidates (non-zero standard deviation). A table accepted here therefore always offers at
 * least one selectable label; previously, tables with only constant boolean columns passed.
 * @param df - table to validate
 * @param toShowMsg - when true, the rejection reason is shown as a platform warning
 * @returns true if the table is suitable for pMPO training
 */
static isTableValid(df: DG.DataFrame, toShowMsg: boolean = true): boolean {
  const reject = (reason: string): boolean => {
    if (toShowMsg)
      grok.shell.warning(PMPO_NON_APPLICABLE + reason);
    return false;
  };

  if (df.rowCount < 2)
    return reject(`. Not enough of samples: ${df.rowCount}, minimum: 2.`);

  let boolColsCount = 0;
  let validNumericColsCount = 0;

  for (const col of df.columns) {
    if (col.isNumerical) {
      if ((col.stats.missingValueCount < 1) && (col.stats.stdev > 0))
        ++validNumericColsCount;
    } else if ((col.type === DG.COLUMN_TYPE.BOOL) && (col.stats.stdev > 0)) {
      // A constant label column has a single category and cannot be trained on
      ++boolColsCount;
    }
  }

  if (boolColsCount < 1)
    return reject(': no non-constant boolean columns.');

  if (validNumericColsCount < 1)
    return reject(': no numeric columns without missing values and non-zero variance.');

  return true;
} // isTableValid
155
+
156
/** Trains a pMPO model.
 * Steps:
 * 1. Split the rows by desirability and compute per-descriptor statistics.
 * 2. Keep the descriptors that pass the p-value filter.
 * 3. Drop descriptors that are highly correlated with others (R² threshold).
 * 4. Estimate the model parameters of the remaining descriptors using the q-cutoff.
 * The statistics table is annotated with selection, correlation and color-coding columns along the way.
 * @param df - training table
 * @param descriptors - candidate descriptor columns
 * @param desirability - boolean label column
 * @param pValTresh - p-value threshold
 * @param r2Tresh - R² threshold
 * @param qCutoff - q-cutoff
 * @param toCheckApplicability - when true, inputs are validated with `isApplicable` first
 * @throws Error if the applicability check is requested and fails
 * @throws PmpoError if no descriptor passes the p-value filter
 * @returns fitted parameters, the annotated statistics table and the selected descriptor lists
 */
static fit(df: DG.DataFrame, descriptors: DG.ColumnList, desirability: DG.Column,
  pValTresh: number, r2Tresh: number, qCutoff: number, toCheckApplicability: boolean = true): PmpoTrainingResult {
  if (toCheckApplicability && !Pmpo.isApplicable(descriptors, desirability, pValTresh, r2Tresh, qCutoff))
    throw new Error('Failed to train pMPO model: the method is not applicable to the inputs');

  const names = descriptors.names();
  const {desired, nonDesired} = getDesiredTables(df, desirability);

  // Per-descriptor statistics, in descriptor order
  const statsByName = new Map<string, DescriptorStatistics>(
    names.map((name): [string, DescriptorStatistics] =>
      [name, getDescriptorStatistics(desired.col(name)!, nonDesired.col(name)!)]),
  );

  const descrStatsTable = getDescriptorStatisticsTable(statsByName);
  setPvalColumnColorCoding(descrStatsTable, pValTresh);

  // Significance filter
  const selectedByPvalue = getFilteredByPvalue(descrStatsTable, pValTresh);
  if (selectedByPvalue.length === 0)
    throw new PmpoError('Cannot train pMPO model: all descriptors have high p-values (not significant).');

  // Redundancy filter
  const correlationTriples = getCorrelationTriples(descriptors, selectedByPvalue);
  const selectedByCorr = getFilteredByCorrelations(descriptors, selectedByPvalue, statsByName, r2Tresh, correlationTriples);

  // Annotate the statistics table
  addSelectedDescriptorsCol(descrStatsTable, selectedByCorr);
  addCorrelationColumns(descrStatsTable, names, correlationTriples, selectedByCorr);
  setCorrColumnColorCoding(descrStatsTable, names, r2Tresh);

  const params = getModelParams(desired, nonDesired, selectedByCorr, qCutoff);

  return {params, descrStatsTable, selectedByPvalue, selectedByCorr};
} // fit
208
+
209
/** Computes pMPO scores for every row of `df` (see https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/).
 * Each descriptor in `params` adds `weight * exp(-(x - desAvg)² / (2 * desStd²))` to the row score.
 * With `useSigmoid`:
 * - when `c > 0`, that term is divided by `1 + b * c^(-(x - cutoff))`;
 * - otherwise the weight is divided by `1 + b`.
 * @param df - table containing every descriptor column named in `params`
 * @param params - model parameters keyed by descriptor name
 * @param useSigmoid - whether to apply the sigmoidal correction
 * @param predictionName - name of the returned column
 * @throws Error if a descriptor column is missing from `df`
 * @returns float64 column with the scores
 */
static predict(df: DG.DataFrame, params: Map<string, PmpoParams>, useSigmoid: boolean, predictionName: string): DG.Column {
  const rowCount = df.rowCount;
  const scores = new Float64Array(rowCount); // typed arrays start zero-filled

  for (const [name, param] of params) {
    const col = df.col(name);
    if (col == null)
      throw new Error(`Failed to apply pMPO: inconsistent data, no column "${name}" in the table "${df.name}"`);

    const values = col.getRawData();
    const {b, c, cutoff: x0, desAvg: mean, desStd: sd} = param;
    const invTwoVar = 1.0 / (2 * sd**2);

    // Per-row sigmoid only for c > 0; otherwise the correction collapses into the weight
    const perRowSigmoid = useSigmoid && (c > 0);
    const w = (useSigmoid && !(c > 0)) ? param.weight / (1.0 + b) : param.weight;

    for (let i = 0; i < rowCount; ++i) {
      const gauss = w * Math.exp(-((values[i] - mean)**2) * invTwoVar);
      scores[i] += perRowSigmoid ? gauss / (1.0 + b * (c ** (-(values[i] - x0)))) : gauss;
    }
  }

  return DG.Column.fromFloat64Array(predictionName, scores);
} // predict
248
+
249
/** Parameters of the most recently fitted model; null until the first successful fit */
private params: Map<string, PmpoParams> | null = null;

/** Training table and the view it is shown in */
private table: DG.DataFrame;
private view: DG.TableView;
/** Non-constant boolean columns: candidate desirability labels */
private boolCols: DG.Column[];
/** Numerical columns without missing values and with non-zero standard deviation: candidate descriptors */
private numericCols: DG.Column[];

/** Empty placeholder table; the viewers below are created on it and rebound after fitting */
private initTable = DG.DataFrame.create();

/** Grid with per-descriptor statistics */
private statGrid = DG.Viewer.grid(this.initTable, {showTitle: true, title: DESCR_TABLE_TITLE});

/** Names of the score column and of the thresholded (boolean) prediction column in the training table */
private predictionName = SCORES_TITLE;
private boolPredictionName = '';

/** Desirability profile chart elements, keyed by descriptor name */
private desirabilityProfileRoots = new Map<string, HTMLElement>();

private rocCurve = DG.Viewer.scatterPlot(this.initTable, {
  showTitle: true,
  showSizeSelector: false,
  showColorSelector: false,
});

private confusionMatrix = DG.Viewer.fromType('Confusion matrix', this.initTable, {
  xColumnName: 'control',
  yColumnName: 'control',
  showTitle: true,
  title: 'Confusion Matrix',
  descriptionPosition: 'Bottom',
  description: 'Confusion matrix for the predicted vs actual desirability labels.',
  descriptionVisibilityMode: 'Always',
});

/** @param df - training table
 * @param view - view to build the UI in. If omitted, an open view of `df.name` is reused,
 *   or a new table view is added. */
constructor(df: DG.DataFrame, view?: DG.TableView) {
  this.table = df;
  this.view = view ?? grok.shell.tableView(df.name) ?? grok.shell.addTableView(df);
  // Candidate columns are collected once, from the table as given
  this.boolCols = this.getBoolCols();
  this.numericCols = this.getValidNumericCols();
  this.predictionName = df.columns.getUnusedName(SCORES_TITLE);
}
288
+
289
/** Removes the first ribbon panel of the table view and keeps the others; does nothing if there are no panels */
private setRibbons(): void {
  const panels = this.view.getRibbonPanels();
  if (panels.length > 0)
    this.view.setRibbonPanels(panels.slice(1));
}
298
+
299
/** Event subscriptions of the statistics grid handlers; replaced on every update so handlers do not pile up */
private statGridSubs: {unsubscribe: () => void}[] = [];

/** Shows a descriptor statistics table in the statistics grid.
 * Sorts by the Selected column, formats and color-codes the p-value and correlation columns, and
 * (re)installs the tooltip and desirability-chart handlers. This method runs on every refit. The
 * handlers from the previous call are therefore unsubscribed first; otherwise they would
 * accumulate and keep referencing outdated descriptor selections.
 * @param table - descriptor statistics table produced by training
 * @param descriptorNames - names of all candidate descriptors (also the names of the correlation columns)
 * @param selectedByPvalue - descriptors that passed the p-value filter
 * @param selectedByCorr - descriptors that passed both filters and are used by the model
 */
private updateStatisticsGrid(table: DG.DataFrame, descriptorNames: string[], selectedByPvalue: string[], selectedByCorr: string[]): void {
  const grid = this.statGrid;
  grid.dataFrame = table;
  grid.setOptions({
    showTitle: true,
    title: table.name,
  });

  grid.sort([SELECTED_TITLE], [false]);

  // Formatting and color coding
  const pValCol = grid.col(P_VAL)!;
  pValCol.format = 'scientific';
  pValCol.isTextColorCoded = true;
  grid.col(DESCR_TITLE)!.isTextColorCoded = true;

  for (const name of descriptorNames) {
    const col = grid.col(name);
    if (col == null)
      continue;

    col.isTextColorCoded = true;
    col.format = '0.000';
  }

  // Drop the handlers installed by the previous update
  this.statGridSubs.forEach((sub) => sub.unsubscribe());
  this.statGridSubs = [];

  this.statGridSubs.push(grid.onCellTooltip((cell, x, y) =>
    this.showStatGridTooltip(cell, x, y, descriptorNames, selectedByPvalue, selectedByCorr)));

  grid.setOptions({'rowHeight': STAT_GRID_HEIGHT});
  const desirabilityCol = grid.col(DESIRABILITY_COL_NAME);
  if (desirabilityCol != null) {
    desirabilityCol.width = DESIRABILITY_COLUMN_WIDTH;
    desirabilityCol.cellType = 'html';
  }

  this.statGridSubs.push(grid.onCellPrepare((cell) => this.prepareDesirabilityCell(cell, selectedByPvalue)));
} // updateStatisticsGrid

/** Tooltip handler of the statistics grid.
 * Explains the column headers and, for table cells, why a descriptor was selected or excluded.
 * @returns true if a tooltip was shown
 */
private showStatGridTooltip(cell: DG.GridCell, x: number, y: number,
  descriptorNames: string[], selectedByPvalue: string[], selectedByCorr: string[]): boolean {
  if (!cell.isColHeader && !cell.isTableCell)
    return false;

  const tableCol = cell.tableColumn;
  if (tableCol == null)
    return false;

  const colName = tableCol.name;

  if (cell.isColHeader) {
    switch (colName) {
    case DESCR_TITLE:
      ui.tooltip.show(getDescrTooltip(
        DESCR_TITLE,
        'Use of descriptors in model construction:',
        'selected',
        'excluded',
      ), x, y);
      return true;

    case DESIRABILITY_COL_NAME:
      ui.tooltip.show(ui.divV([
        ui.h2(DESIRABILITY_COL_NAME),
        ui.divText('Desirability profile charts for each descriptor. Only profiles for selected descriptors are shown.'),
      ]), x, y);
      return true;

    case WEIGHT_TITLE:
      ui.tooltip.show(ui.divV([
        ui.h2(WEIGHT_TITLE),
        ui.divText('Weights of selected descriptors.'),
      ]), x, y);
      return true;

    case P_VAL:
      ui.tooltip.show(getDescrTooltip(
        P_VAL,
        'Filtering descriptors by p-value:',
        'selected',
        'excluded',
      ), x, y);
      return true;

    default:
      if (!descriptorNames.includes(colName))
        return false;

      ui.tooltip.show(getDescrTooltip(
        colName,
        `Correlation of ${colName} with other descriptors, measured by R²:`,
        'weakly correlated',
        'highly correlated',
      ), x, y);
      return true;
    }
  }

  // Table cell
  if (colName === DESCR_TITLE) {
    const value = cell.value;
    if (selectedByCorr.includes(value))
      ui.tooltip.show('Selected for model construction.', x, y);
    else if (selectedByPvalue.includes(value))
      ui.tooltip.show('Excluded due to a high correlation with other descriptors.', x, y);
    else
      ui.tooltip.show('Excluded due to a high p-value.', x, y);

    return true;
  }

  const descriptor = this.statGrid.cell(DESCR_TITLE, cell.gridRow).value;

  if (colName === WEIGHT_TITLE) {
    // Descriptors with a chart are in the model and have a weight: nothing to explain
    if (this.desirabilityProfileRoots.has(descriptor))
      return false;

    if (selectedByPvalue.includes(descriptor))
      ui.tooltip.show(`No weight: <b>${descriptor}</b> is excluded due to a high correlation with other descriptors.`, x, y);
    else
      ui.tooltip.show(`No weight: <b>${descriptor}</b> is excluded due to a high p-value.`, x, y);

    return true;
  }

  if (descriptorNames.includes(colName) && !selectedByPvalue.includes(descriptor)) {
    ui.tooltip.show(`<b>${descriptor}</b> is excluded due to a high p-value; so correlation with <b>${colName}</b> is not needed.`, x, y);
    return true;
  }

  return false;
} // showStatGridTooltip

/** Cell-prepare handler of the statistics grid.
 * Puts the descriptor's desirability chart into the desirability column. For descriptors without
 * a chart, it shows a short explanation of why the descriptor was excluded.
 */
private prepareDesirabilityCell(cell: DG.GridCell, selectedByPvalue: string[]): void {
  if (!cell.isTableCell || cell.tableColumn?.name !== DESIRABILITY_COL_NAME)
    return;

  const descriptor = this.statGrid.cell(DESCR_TITLE, cell.gridRow).value;
  const chart = this.desirabilityProfileRoots.get(descriptor);

  if (chart != null) {
    cell.element = chart;
    return;
  }

  const passedPvalue = selectedByPvalue.includes(descriptor);
  const text = passedPvalue ? 'highly correlated with other descriptors' : 'statistically insignificant';
  const tooltipMsg = passedPvalue ?
    `No chart shown: <b>${descriptor}</b> is excluded due to a high correlation with other descriptors.` :
    `No chart shown: <b>${descriptor}</b> is excluded due to a high p-value.`;

  const explanation = ui.divText(text);
  explanation.style.color = COLORS.SKIPPED;
  explanation.classList.add('eda-pmpo-centered-text');
  ui.tooltip.bind(explanation, tooltipMsg);
  cell.element = explanation;
} // prepareDesirabilityCell
471
+
472
/** Subscription of the score-column tooltip handler on the main grid; replaced on every update */
private mainGridTooltipSub: {unsubscribe: () => void} | null = null;

/** Updates the main grid for the pMPO score column.
 * Sorts the grid by the score column, sets its number format and installs a header tooltip for it.
 * This method runs on every refit, so the previous tooltip handler is unsubscribed first to
 * avoid stacking duplicate handlers.
 */
private updateGrid(): void {
  const grid = this.view.grid;
  const name = this.predictionName;

  grid.sort([name], [false]);
  grid.col(name)!.format = '0.0000';

  this.mainGridTooltipSub?.unsubscribe();
  this.mainGridTooltipSub = grid.onCellTooltip((cell, x, y) => {
    if (!cell.isColHeader || cell.tableColumn?.name !== name)
      return false;

    ui.tooltip.show(getScoreTooltip(), x, y);
    return true;
  });
} // updateGrid
497
+
498
/** Rebuilds the desirability profile charts for the current model and copies the descriptor
 * weights of the generated profile into the Weight column of `descrStatsTable`.
 * The charts come from an MpoProfileEditor rendering of the profile. Each profile row is expected
 * to contain three children: descriptor name, weight and chart. The chart (third child) is stored
 * in `desirabilityProfileRoots` under the descriptor name. Rows with fewer children are skipped;
 * previously a two-child row stored `undefined` as a chart.
 * Does nothing when no model has been fitted yet.
 * @param descrStatsTable - descriptor statistics table; its Weight column is updated in place
 * @param useSigmoidalCorrection - whether the profile is built with the sigmoidal correction
 */
private updateDesirabilityProfileData(descrStatsTable: DG.DataFrame, useSigmoidalCorrection: boolean): void {
  if (this.params == null)
    return;

  // Drop the charts of the previous model
  this.desirabilityProfileRoots.forEach((root) => root.remove());
  this.desirabilityProfileRoots.clear();

  const desirabilityProfile = getDesirabilityProfileJson(this.params, useSigmoidalCorrection, '', '', true);

  // Copy weights into the statistics table; profile entries with no matching table row are skipped
  const descrNames = descrStatsTable.col(DESCR_TITLE)!.toList();
  const weightsRaw = descrStatsTable.col(WEIGHT_TITLE)!.getRawData();
  const props = desirabilityProfile.properties;

  for (const name of Object.keys(props)) {
    const rowIdx = descrNames.indexOf(name);
    if (rowIdx >= 0)
      weightsRaw[rowIdx] = props[name].weight;
  }

  // Render the profile and harvest the chart element of every descriptor row
  const mpoEditor = new MpoProfileEditor();
  mpoEditor.setProfile(desirabilityProfile);
  const rowsContainer = mpoEditor.root.querySelector('div.d4-flex-col.ui-div');

  if (rowsContainer == null)
    return;

  rowsContainer.querySelectorAll('div.d4-flex-row.ui-div').forEach((row) => {
    const children = row.children;
    if (children.length < 3) // expecting descriptor name, weight & profile
      return;

    const descrDivChildren = (children[0] as HTMLElement).children;
    if (descrDivChildren.length < 1) // expecting 1 div with descriptor name
      return;

    const descrName = (descrDivChildren[0] as HTMLElement).innerText;
    this.desirabilityProfileRoots.set(descrName, children[2] as HTMLElement);
  });
} // updateDesirabilityProfileData
542
+
543
/** Draws the ROC curve of `prediction` against the `desirability` labels.
 * Adds a dashed non-informative baseline (TPR = FPR) and shows the AUC in the viewer title.
 * @param desirability - actual boolean labels
 * @param prediction - pMPO scores
 * @return Best threshold according to Youden's J statistic
 */
private updateRocCurve(desirability: DG.Column, prediction: DG.Column): number {
  const {fpr, tpr, auc, threshold} = getPmpoEvaluation(desirability, prediction);

  const rocDf = DG.DataFrame.fromColumns([
    DG.Column.fromFloat32Array(THRESHOLD, ROC_TRESHOLDS),
    DG.Column.fromFloat32Array(FPR_TITLE, fpr),
    DG.Column.fromFloat32Array(TPR_TITLE, tpr),
  ]);

  // Diagonal reference line of a random classifier
  rocDf.meta.formulaLines.addLine({
    title: 'Non-informative baseline',
    formula: `\${${TPR_TITLE}} = \${${FPR_TITLE}}`,
    width: 1,
    style: 'dashed',
    min: 0,
    max: 1,
  });

  const roc = this.rocCurve;
  roc.dataFrame = rocDf;
  roc.setOptions({
    xColumnName: FPR_TITLE,
    yColumnName: TPR_TITLE,
    linesOrderColumnName: FPR_TITLE,
    linesWidth: 5,
    markerType: 'dot',
    title: `ROC Curve (AUC = ${auc.toFixed(3)})`,
  });

  return threshold;
} // updateRocCurve
577
+
578
/** Binds the confusion matrix to `df` and plots the actual labels against the boolean predictions.
 * The decision threshold is shown in the viewer description.
 * @param df - table that contains both the label column and the boolean prediction column
 * @param desColName - name of the actual desirability (label) column
 * @param bestThreshold - score threshold that produced the boolean predictions
 */
private updateConfusionMatrix(df: DG.DataFrame, desColName: string, bestThreshold: number): void {
  const viewer = this.confusionMatrix;
  viewer.dataFrame = df;
  viewer.setOptions({
    xColumnName: desColName,
    yColumnName: this.boolPredictionName,
    description: `Threshold: ${bestThreshold.toFixed(3)} (optimized via Youden's J)`,
    title: `${desColName} Confusion Matrix`,
  });
} // updateConfusionMatrix
588
+
589
/** Trains the model on the current inputs and refreshes everything that depends on it:
 * the colored score column in the training table, the main grid, the desirability charts,
 * the statistics grid, the ROC curve, the thresholded prediction column and the confusion matrix.
 * Finally clears the row selection of the view's table.
 * @throws whatever `Pmpo.fit` / `Pmpo.predict` throw
 */
private fitAndUpdateViewers(df: DG.DataFrame, descriptors: DG.ColumnList, desirability: DG.Column,
  pValTresh: number, r2Tresh: number, qCutoff: number, useSigmoid: boolean): void {
  const {params, descrStatsTable, selectedByPvalue, selectedByCorr} =
    Pmpo.fit(df, descriptors, desirability, pValTresh, r2Tresh, qCutoff);
  this.params = params;

  const descriptorNames = descriptors.names();

  // Fresh scores, colored over their own range, replace the previous score column
  const prediction = Pmpo.predict(df, params, useSigmoid, this.predictionName);
  const {min, max} = prediction.stats;
  prediction.colors.setLinear(getOutputPalette(OPT_TYPE.MAX), {min, max});

  df.columns.remove(this.predictionName);
  df.columns.add(prediction);

  this.updateGrid();
  this.updateDesirabilityProfileData(descrStatsTable, useSigmoid);
  this.updateStatisticsGrid(descrStatsTable, descriptorNames, selectedByPvalue, selectedByCorr);

  const bestThreshold = this.updateRocCurve(desirability, prediction);

  // Replace the thresholded (boolean) prediction column
  const labelName = desirability.name;
  df.columns.remove(this.boolPredictionName);
  this.boolPredictionName = df.columns.getUnusedName(`${labelName}(predicted)`);
  df.columns.add(getBoolPredictionColumn(prediction, bestThreshold, this.boolPredictionName));

  this.updateConfusionMatrix(df, labelName, bestThreshold);

  this.view.dataFrame.selection.setAll(false, true);
} // fitAndUpdateViewers
633
+
634
/** Opens the pMPO training UI in the table view.
 * Docks the input form on the left and the statistics grid below the main grid, then places the
 * ROC curve and the confusion matrix to its right. Finally removes the first ribbon panel.
 * @throws Error if the main grid is not found in the view's dock manager
 */
public runTrainingApp(): void {
  const dock = this.view.dockManager;

  const controls = this.getInputForm(true);
  dock.dock(controls.form, DG.DOCK_TYPE.LEFT, null, undefined, 0.1);

  const gridNode = dock.findNode(this.view.grid.root);
  if (gridNode == null)
    throw new Error('Failed to train pMPO: missing a grid in the table view.');

  const statsNode = dock.dock(this.statGrid, DG.DOCK_TYPE.DOWN, gridNode, undefined, 0.5);
  const rocNode = dock.dock(this.rocCurve, DG.DOCK_TYPE.RIGHT, statsNode, undefined, 0.3);
  dock.dock(this.confusionMatrix, DG.DOCK_TYPE.RIGHT, rocNode, undefined, 0.2);

  this.setRibbons();
} // runTrainingApp
658
+
659
/** Returns the pMPO UI parts for embedding elsewhere, as an alternative to docking them with `runTrainingApp`.
 * The result contains the statistics grid, ROC curve, confusion matrix and a newly built input form.
 * The Save button is returned in `controls` but not appended to the form.
 * Note: building the input form also schedules the initial model computation.
 */
public getPmpoAppItems(): PmpoAppItems {
  return {
    statsGrid: this.statGrid,
    rocCurve: this.rocCurve,
    confusionMatrix: this.confusionMatrix,
    controls: this.getInputForm(false),
  };
} // getPmpoAppItems
668
+
669
/** Creates the input form of the pMPO training UI.
 * The form holds:
 * - the descriptor and desirability inputs;
 * - the σ-correction and auto-tuning switches;
 * - the p-value, R² and q-cutoff thresholds.
 * Any input change refits the model. When auto-tuning is on, the thresholds are first chosen by
 * `getOptimalSettings` and their inputs are disabled; a failed or throwing optimization switches
 * auto-tuning off. An initial fit is scheduled right after creation. Auto-tuning is then switched on
 * automatically for tables with at most AUTO_TUNE_MAX_APPLICABLE_ROWS rows.
 * @param addBtn - whether to append the Save button to the form (the button is returned either way)
 * @returns the form wrapper and the Save button
 */
private getInputForm(addBtn: boolean): Controls {
  const form = ui.form([]);
  form.append(ui.h2('Training data'));
  const numericColNames = this.numericCols.map((col) => col.name);

  // Whether the threshold inputs currently hold values found by auto-tuning.
  // Declared before the inputs whose handlers read it.
  let areTunedSettingsUsed = false;

  // Refits the model with the current input values; failures are reported, not thrown
  const runComputations = () => {
    try {
      this.fitAndUpdateViewers(
        this.table,
        DG.DataFrame.fromColumns(descrInput.value).columns,
        this.table.col(desInput.value!)!,
        pInput.value!,
        rInput.value!,
        qInput.value!,
        useSigmoidInput.value,
      );
    } catch (err) {
      err instanceof PmpoError ?
        grok.shell.warning(err.message) :
        grok.shell.error(err instanceof Error ? err.message : PMPO_COMPUTE_FAILED + ': the platform issue.');
    }
  };

  // Descriptor columns input
  const descrInput = ui.input.columns('Descriptors', {
    table: this.table,
    nullable: false,
    available: numericColNames,
    checked: numericColNames,
    tooltipText: 'Descriptor columns used for model construction.',
    onValueChanged: (value) => {
      if (value != null) {
        areTunedSettingsUsed = false;
        checkAutoTuneAndRun();
      }
    },
  });
  form.append(descrInput.root);

  // Desirability column input
  const desInput = ui.input.choice('Desirability', {
    nullable: false,
    value: this.boolCols[0].name,
    items: this.boolCols.map((col) => col.name),
    tooltipText: 'Desirability column.',
    onValueChanged: (value) => {
      if (value != null) {
        areTunedSettingsUsed = false;
        checkAutoTuneAndRun();
      }
    },
  });
  form.append(desInput.root);

  const header = ui.h2('Settings');
  form.append(header);
  ui.tooltip.bind(header, 'Settings of the pMPO model.');

  // use sigmoid correction
  const useSigmoidInput = ui.input.bool('\u03C3 correction', {
    value: USE_SIGMOID_DEFAULT,
    tooltipText: 'Use the sigmoidal correction to the weighted Gaussian scores.',
    onValueChanged: (_value) => {
      areTunedSettingsUsed = false;
      checkAutoTuneAndRun();
    },
  });
  form.append(useSigmoidInput.root);

  const toUseAutoTune = (this.table.rowCount <= AUTO_TUNE_MAX_APPLICABLE_ROWS);

  // Chooses the thresholds by auto-tuning (unless the current ones are already tuned), then refits.
  // A failed or throwing optimization reverts to manual mode and keeps the current thresholds.
  const setOptimalParametersAndRun = async () => {
    if (!areTunedSettingsUsed) {
      let tuned = false;

      try {
        const optimalSettings = await this.getOptimalSettings(
          DG.DataFrame.fromColumns(descrInput.value).columns,
          this.table.col(desInput.value!)!,
          useSigmoidInput.value,
        );

        if (optimalSettings.success) {
          // The threshold handlers ignore these changes while auto-tuning is on
          pInput.value = Math.max(optimalSettings.pValTresh, P_VAL_TRES_MIN);
          rInput.value = Math.max(optimalSettings.r2Tresh, R2_MIN);
          qInput.value = Math.max(optimalSettings.qCutoff, Q_CUTOFF_MIN);
          tuned = true;
        }
      } catch (err) {
        grok.shell.warning(err instanceof Error ? err.message : PMPO_COMPUTE_FAILED + ': auto-tuning failed.');
      }

      if (tuned)
        areTunedSettingsUsed = true;
      else
        autoTuneInput.value = false; // revert to manual mode if optimization failed
    }

    runComputations();
  };

  const checkAutoTuneAndRun = () => {
    if (autoTuneInput.value)
      void setOptimalParametersAndRun(); // errors are handled inside
    else
      runComputations();
  };

  // autotuning input
  const autoTuneInput = ui.input.bool('Auto-tuning', {
    value: false,
    tooltipText: 'Automatically select optimal p-value, R², and q-cutoff by maximizing AUC.',
    onValueChanged: async (value) => {
      setEnability(!value);

      if (areTunedSettingsUsed)
        return;

      // If auto-tuning is turned on, set optimal parameters and run computations
      if (value)
        await setOptimalParametersAndRun();
    },
  });
  form.append(autoTuneInput.root);

  // p-value threshold input
  const pInput = ui.input.float('p-value', {
    nullable: false,
    min: P_VAL_TRES_MIN,
    max: P_VAL_TRES_MAX,
    step: 0.001,
    value: P_VAL_TRES_DEFAULT,
    // @ts-ignore
    format: FORMAT,
    tooltipText: 'P-value threshold. Descriptors with p-values above this threshold are excluded.',
    onValueChanged: (value) => {
      // Prevent running computations when auto-tuning is on, since parameters will be set automatically
      if (autoTuneInput.value)
        return;

      areTunedSettingsUsed = false;
      if ((value != null) && (value >= P_VAL_TRES_MIN) && (value <= P_VAL_TRES_MAX))
        runComputations();
    },
  });
  form.append(pInput.root);

  // R² threshold input
  const rInput = ui.input.float('R²', {
    // @ts-ignore
    format: FORMAT,
    nullable: false,
    min: R2_MIN,
    value: R2_DEFAULT,
    max: R2_MAX,
    step: 0.01,
    // eslint-disable-next-line max-len
    tooltipText: 'Squared correlation threshold. Descriptors with squared correlation above this threshold are considered highly correlated. Among them, the descriptor with the lower p-value is retained.',
    onValueChanged: (value) => {
      // Prevent running computations when auto-tuning is on, since parameters will be set automatically
      if (autoTuneInput.value)
        return;

      areTunedSettingsUsed = false;

      if ((value != null) && (value >= R2_MIN) && (value <= R2_MAX))
        runComputations();
    },
  });
  form.append(rInput.root);

  // q-cutoff input
  const qInput = ui.input.float('q-cutoff', {
    // @ts-ignore
    format: FORMAT,
    nullable: false,
    min: Q_CUTOFF_MIN,
    value: Q_CUTOFF_DEFAULT,
    max: Q_CUTOFF_MAX,
    step: 0.01,
    tooltipText: 'Q-cutoff for the pMPO model computation.',
    onValueChanged: (value) => {
      // Prevent running computations when auto-tuning is on, since parameters will be set automatically
      if (autoTuneInput.value)
        return;

      areTunedSettingsUsed = false;

      if ((value != null) && (value >= Q_CUTOFF_MIN) && (value <= Q_CUTOFF_MAX))
        runComputations();
    },
  });
  form.append(qInput.root);

  // Threshold inputs are editable only in manual mode
  const setEnability = (toEnable: boolean) => {
    pInput.enabled = toEnable;
    rInput.enabled = toEnable;
    qInput.enabled = toEnable;
  };

  // Initial fit shortly after the form is created. For small tables, auto-tuning is then switched on,
  // and its value-change handler tunes the thresholds and refits. Without auto-tuning, the single fit
  // suffices; the previous version ran the same computation twice here.
  setTimeout(() => {
    runComputations();

    if (toUseAutoTune)
      autoTuneInput.value = true;
  }, 10);

  // Save model button
  const saveBtn = ui.button('Save', async () => {
    if (this.params == null) {
      grok.shell.warning('Failed to save pMPO model: null parameters.');
      return;
    }

    saveModel(this.params, this.table.name, useSigmoidInput.value);
  }, 'Save model as platform file.');

  if (addBtn)
    form.append(saveBtn);

  const div = ui.div([form]);
  div.classList.add('eda-pmpo-input-form');

  return {
    form: div,
    saveBtn: saveBtn,
  };
} // getInputForm
896
+
897
/**
 * Collects the boolean columns of the table that actually vary.
 * A boolean column whose standard deviation is not strictly positive (i.e. constant) is skipped.
 */
private getBoolCols(): DG.Column[] {
  const boolCols: DG.Column[] = [];

  for (const column of this.table.columns) {
    // Only boolean columns are candidates; stats are not touched for other types
    if (column.type !== DG.COLUMN_TYPE.BOOL)
      continue;

    // Constant columns (zero spread) are excluded
    if (!(column.stats.stdev > 0))
      continue;

    boolCols.push(column);
  }

  return boolCols;
} // getBoolCols
908
+
909
/**
 * Collects the numeric columns of the table that are usable as descriptors:
 * numerical type, no missing values, and a strictly positive standard deviation.
 */
private getValidNumericCols(): DG.Column[] {
  // Checks run in order and stop at the first failing one
  const isUsable = (column: DG.Column): boolean => {
    if (!column.isNumerical)
      return false;

    if (!(column.stats.missingValueCount < 1))
      return false;

    return column.stats.stdev > 0;
  };

  const validCols: DG.Column[] = [];

  for (const column of this.table.columns) {
    if (isUsable(column))
      validCols.push(column);
  }

  return validCols;
} // getValidNumericCols
920
+
921
/**
 * Searches for the R² threshold and q-cutoff that maximize the AUC of pMPO predictions.
 *
 * The p-value threshold is fixed at P_VAL_TRES_DEFAULT. Descriptors passing it are filtered by
 * correlation (R² threshold) and used to fit the model with the given q-cutoff. optimizeNM
 * (Nelder-Mead) minimizes `1 - AUC` over the point [R² threshold, q-cutoff], starting from
 * [R2_DEFAULT, Q_CUTOFF_DEFAULT] within LOW_PARAMS_BOUNDS / HIGH_PARAMS_BOUNDS.
 * A cancelable task-bar progress indicator is shown while the search runs.
 *
 * @param descriptors - descriptor columns to consider
 * @param desirability - column used by getDesiredTables to split the table into desired / non-desired subsets
 * @param useSigmoid - forwarded to Pmpo.predict when scoring each candidate point
 * @returns the optimal settings with `success: true`; or a zero-filled result with `success: false`
 *   when no descriptor passes the p-value filter, the user cancels, or the optimization throws
 */
private async getOptimalSettings(descriptors: DG.ColumnList, desirability: DG.Column, useSigmoid: boolean): Promise<OptimalPoint> {
  const failedResult: OptimalPoint = {
    pValTresh: 0,
    r2Tresh: 0,
    qCutoff: 0,
    success: false,
  };

  const descriptorNames = descriptors.names();
  const {desired, nonDesired} = getDesiredTables(this.table, desirability);

  // Compute descriptors' statistics
  const descrStats = new Map<string, DescriptorStatistics>();
  descriptorNames.forEach((name) => {
    descrStats.set(name, getDescriptorStatistics(desired.col(name)!, nonDesired.col(name)!));
  });
  const descrStatsTable = getDescriptorStatisticsTable(descrStats);

  // Filter by p-value (the p-value threshold is not optimized)
  const selectedByPvalue = getFilteredByPvalue(descrStatsTable, P_VAL_TRES_DEFAULT);
  if (selectedByPvalue.length < 1)
    return failedResult;

  const correlationTriples = getCorrelationTriples(descriptors, selectedByPvalue);

  // point[0] - R² threshold, point[1] - q-cutoff
  const funcToBeMinimized = (point: Float32Array) => {
    // Filter by correlations
    const selectedByCorr = getFilteredByCorrelations(descriptors, selectedByPvalue, descrStats, point[0], correlationTriples);

    // Compute pMPO parameters - training
    const params = getModelParams(desired, nonDesired, selectedByCorr, point[1]);

    // Get predictions
    const prediction = Pmpo.predict(this.table, params, useSigmoid, this.predictionName);

    // Evaluate predictions and return 1 - AUC (since optimization minimizes the function, but we want to maximize AUC)
    return 1 - getPmpoEvaluation(desirability, prediction).auc;
  }; // funcToBeMinimized

  const pi = DG.TaskBarProgressIndicator.create('Optimizing... ', {cancelable: true});

  try {
    const optimalResult = await optimizeNM(
      pi,
      funcToBeMinimized,
      new Float32Array([R2_DEFAULT, Q_CUTOFF_DEFAULT]),
      DEFAULT_OPTIMIZATION_SETTINGS,
      LOW_PARAMS_BOUNDS,
      HIGH_PARAMS_BOUNDS,
    );

    // The cancel flag is read here, before `finally` closes the indicator
    if (pi.canceled)
      return failedResult;

    return {
      pValTresh: P_VAL_TRES_DEFAULT,
      r2Tresh: optimalResult.optimalPoint[0],
      qCutoff: optimalResult.optimalPoint[1],
      success: true,
    };
  } catch (err) {
    // Auto-tuning is best-effort: report the failure rather than discarding it silently
    console.error('pMPO: failed to optimize settings.', err);

    return failedResult;
  } finally {
    // Close the indicator on every exit path (success, cancel, error)
    pi.close();
  }
} // getOptimalSettings
991
+ }; // Pmpo