@datagrok/eda 1.4.12 → 1.4.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,15 +10,37 @@ import {MpoProfileEditor} from '@datagrok-libraries/statistics/src/mpo/mpo-profi
10
10
 
11
11
  import '../../css/pmpo.css';
12
12
 
13
- import {getDesiredTables, getDescriptorStatistics, normalPdf, sigmoidS} from './stat-tools';
13
+ import {getDesiredTables, getDescriptorStatistics, getBoolPredictionColumn, getPmpoEvaluation} from './stat-tools';
14
14
  import {MIN_SAMPLES_COUNT, PMPO_NON_APPLICABLE, DescriptorStatistics, P_VAL_TRES_MIN, DESCR_TITLE,
15
15
  R2_MIN, Q_CUTOFF_MIN, PmpoParams, SCORES_TITLE, DESCR_TABLE_TITLE, PMPO_COMPUTE_FAILED, SELECTED_TITLE,
16
- P_VAL, DESIRABILITY_COL_NAME, STAT_GRID_HEIGHT, DESIRABILITY_COLUMN_WIDTH, WEIGHT_TITLE} from './pmpo-defs';
16
+ P_VAL, DESIRABILITY_COL_NAME, STAT_GRID_HEIGHT, DESIRABILITY_COLUMN_WIDTH, WEIGHT_TITLE,
17
+ P_VAL_TRES_DEFAULT, R2_DEFAULT, Q_CUTOFF_DEFAULT, USE_SIGMOID_DEFAULT, ROC_TRESHOLDS,
18
+ FPR_TITLE, TPR_TITLE, COLORS, THRESHOLD, AUTO_TUNE_MAX_APPLICABLE_ROWS, DEFAULT_OPTIMIZATION_SETTINGS,
19
+ P_VAL_TRES_MAX, R2_MAX, Q_CUTOFF_MAX, OptimalPoint, LOW_PARAMS_BOUNDS, HIGH_PARAMS_BOUNDS, FORMAT} from './pmpo-defs';
17
20
  import {addSelectedDescriptorsCol, getDescriptorStatisticsTable, getFilteredByPvalue, getFilteredByCorrelations,
18
21
  getModelParams, getDescrTooltip, saveModel, getScoreTooltip, getDesirabilityProfileJson, getCorrelationTriples,
19
- addCorrelationColumns, setPvalColumnColorCoding, setCorrColumnColorCoding} from './pmpo-utils';
22
+ addCorrelationColumns, setPvalColumnColorCoding, setCorrColumnColorCoding, PmpoError} from './pmpo-utils';
20
23
  import {getOutputPalette} from '../pareto-optimization/utils';
21
24
  import {OPT_TYPE} from '../pareto-optimization/defs';
25
+ import {optimizeNM} from './nelder-mead';
26
+
27
/** Result of pMPO model training, as returned by `Pmpo.fit` */
export type PmpoTrainingResult = {
  /** pMPO parameters, keyed by name */
  params: Map<string, PmpoParams>;
  /** Table with descriptor statistics */
  descrStatsTable: DG.DataFrame;
  /** Descriptors retained by the p-value filter */
  selectedByPvalue: string[];
  /** Descriptors retained by the correlation filter */
  selectedByCorr: string[];
};
33
+
34
/** pMPO training controls: the inputs form and the button that saves the model */
export type Controls = {
  form: HTMLElement;
  saveBtn: HTMLButtonElement;
};
36
+
37
/** UI items of the pMPO training app, as returned by `getPmpoAppItems` */
export type PmpoAppItems = {
  /** Descriptor statistics grid */
  statsGrid: DG.Viewer;
  /** ROC curve viewer */
  rocCurve: DG.Viewer;
  /** Confusion matrix viewer */
  confusionMatrix: DG.Viewer;
  /** Inputs form and the save button */
  controls: Controls;
};
22
44
 
23
45
  /** Class implementing probabilistic MPO (pMPO) model training and prediction */
24
46
  export class Pmpo {
@@ -131,23 +153,93 @@ export class Pmpo {
131
153
  return true;
132
154
  } // isTableValid
133
155
 
156
/**
 * Trains the pMPO model on the given data.
 *
 * @param df Source table; it is split into desired and non-desired subsets using `desirability`
 * @param descriptors Descriptor columns
 * @param desirability Desirability column
 * @param pValTresh p-value threshold used for descriptor filtering
 * @param r2Tresh Squared-correlation threshold used for descriptor filtering
 * @param qCutoff q-cutoff passed to the model parameters computation
 * @param toCheckApplicability If true, the inputs are first validated with `Pmpo.isApplicable`
 * @returns Model parameters, the descriptor statistics table, and the descriptors selected at each filtering step
 * @throws Error if the applicability check is requested and fails
 * @throws PmpoError if no descriptor passes the p-value filter
 */
static fit(df: DG.DataFrame, descriptors: DG.ColumnList, desirability: DG.Column,
  pValTresh: number, r2Tresh: number, qCutoff: number, toCheckApplicability: boolean = true): PmpoTrainingResult {
  if (toCheckApplicability && !Pmpo.isApplicable(descriptors, desirability, pValTresh, r2Tresh, qCutoff))
    throw new Error('Failed to train pMPO model: the method is not applicable to the inputs');

  const names = descriptors.names();
  const {desired, nonDesired} = getDesiredTables(df, desirability);

  // Per-descriptor statistics computed on the desired vs. non-desired subsets
  const stats = new Map<string, DescriptorStatistics>();
  for (const name of names)
    stats.set(name, getDescriptorStatistics(desired.col(name)!, nonDesired.col(name)!));

  const descrStatsTable = getDescriptorStatisticsTable(stats);
  setPvalColumnColorCoding(descrStatsTable, pValTresh);

  // Step 1: p-value filter
  const selectedByPvalue = getFilteredByPvalue(descrStatsTable, pValTresh);
  if (selectedByPvalue.length < 1)
    throw new PmpoError('Cannot train pMPO model: all descriptors have high p-values (not significant).');

  // Step 2: correlation filter
  const triples = getCorrelationTriples(descriptors, selectedByPvalue);
  const selectedByCorr = getFilteredByCorrelations(descriptors, selectedByPvalue, stats, r2Tresh, triples);

  // Extend the statistics table: the Selected column, correlation columns, and their color coding
  addSelectedDescriptorsCol(descrStatsTable, selectedByCorr);
  addCorrelationColumns(descrStatsTable, names, triples, selectedByCorr);
  setCorrColumnColorCoding(descrStatsTable, names, r2Tresh);

  // Training: compute pMPO parameters for the finally selected descriptors
  const params = getModelParams(desired, nonDesired, selectedByCorr, qCutoff);

  return {params, descrStatsTable, selectedByPvalue, selectedByCorr};
} // fit
208
+
134
209
  /** Predicts pMPO scores for the given data frame using provided pMPO parameters */
135
- static predict(df: DG.DataFrame, params: Map<string, PmpoParams>, predictionName: string): DG.Column {
210
+ static predict(df: DG.DataFrame, params: Map<string, PmpoParams>, useSigmoid: boolean, predictionName: string): DG.Column {
136
211
  const count = df.rowCount;
137
212
  const scores = new Float64Array(count).fill(0);
138
- let x = 0;
139
213
 
140
214
  // Compute pMPO scores (see https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/)
141
215
  params.forEach((param, name) => {
142
216
  const col = df.col(name);
217
+ const b = param.b;
218
+ const c = param.c;
219
+ const x0 = param.cutoff;
220
+ let weight = param.weight;
221
+ const avg = param.desAvg;
222
+ const std = param.desStd;
223
+ const frac = 1.0 / (2 * std**2);
143
224
 
144
225
  if (col == null)
145
- throw new Error(`Filed to apply pMPO: inconsistent data, no column "${name}" in the table "${df.name}"`);
226
+ throw new Error(`Failed to apply pMPO: inconsistent data, no column "${name}" in the table "${df.name}"`);
146
227
 
147
228
  const vals = col.getRawData();
148
- for (let i = 0; i < count; ++i) {
149
- x = vals[i];
150
- scores[i] += param.weight * normalPdf(x, param.desAvg, param.desStd) * sigmoidS(x, param.x0, param.b, param.c);
229
+
230
+ if (useSigmoid) {
231
+ if (c > 0) {
232
+ for (let i = 0; i < count; ++i)
233
+ scores[i] += weight * Math.exp(-((vals[i] - avg)**2) * frac) / (1.0 + b * (c ** (-(vals[i] - x0))));
234
+ } else {
235
+ weight = weight / (1.0 + b);
236
+
237
+ for (let i = 0; i < count; ++i)
238
+ scores[i] += weight * Math.exp(-((vals[i] - avg)**2) * frac);
239
+ }
240
+ } else {
241
+ for (let i = 0; i < count; ++i)
242
+ scores[i] += weight * Math.exp(-((vals[i] - avg)**2) * frac);
151
243
  }
152
244
  });
153
245
 
@@ -161,16 +253,34 @@ export class Pmpo {
161
253
  private boolCols: DG.Column[];
162
254
  private numericCols: DG.Column[];
163
255
 
164
- private initTable = grok.data.demo.demog(10);
256
+ private initTable = DG.DataFrame.create();
165
257
 
166
258
  private statGrid = DG.Viewer.grid(this.initTable, {showTitle: true, title: DESCR_TABLE_TITLE});
167
259
 
168
260
  private predictionName = SCORES_TITLE;
261
+ private boolPredictionName = '';
169
262
 
170
263
  private desirabilityProfileRoots = new Map<string, HTMLElement>();
171
- constructor(df: DG.DataFrame) {
264
+
265
+ private rocCurve = DG.Viewer.scatterPlot(this.initTable, {
266
+ showTitle: true,
267
+ showSizeSelector: false,
268
+ showColorSelector: false,
269
+ });
270
+
271
+ private confusionMatrix = DG.Viewer.fromType('Confusion matrix', this.initTable, {
272
+ xColumnName: 'control',
273
+ yColumnName: 'control',
274
+ showTitle: true,
275
+ title: 'Confusion Matrix',
276
+ descriptionPosition: 'Bottom',
277
+ description: 'Confusion matrix for the predicted vs actual desirability labels.',
278
+ descriptionVisibilityMode: 'Always',
279
+ });
280
+
281
+ constructor(df: DG.DataFrame, view?: DG.TableView) {
172
282
  this.table = df;
173
- this.view = grok.shell.tableView(df.name) ?? grok.shell.addTableView(df);
283
+ this.view = view ?? (grok.shell.tableView(df.name) ?? grok.shell.addTableView(df));
174
284
  this.boolCols = this.getBoolCols();
175
285
  this.numericCols = this.getValidNumericCols();
176
286
  this.predictionName = df.columns.getUnusedName(SCORES_TITLE);
@@ -294,14 +404,12 @@ export class Pmpo {
294
404
  } else {
295
405
  const descriptor = grid.cell(DESCR_TITLE, cell.gridRow).value;
296
406
 
297
- if ((colName === DESIRABILITY_COL_NAME) || (colName === WEIGHT_TITLE)) {
298
- const startText = (colName === WEIGHT_TITLE) ? 'No weight' : 'No chart shown';
299
-
407
+ if (colName === WEIGHT_TITLE) {
300
408
  if (!this.desirabilityProfileRoots.has(descriptor)) {
301
409
  if (selectedByPvalue.includes(descriptor))
302
- ui.tooltip.show(`${startText}: <b>${descriptor}</b> is excluded due to a high correlation with other descriptors.`, x, y);
410
+ ui.tooltip.show(`No weight: <b>${descriptor}</b> is excluded due to a high correlation with other descriptors.`, x, y);
303
411
  else
304
- ui.tooltip.show(`${startText}: <b>${descriptor}</b> is excluded due to a high p-value.`, x, y);
412
+ ui.tooltip.show(`No weight: <b>${descriptor}</b> is excluded due to a high p-value.`, x, y);
305
413
 
306
414
  return true;
307
415
  }
@@ -341,7 +449,23 @@ export class Pmpo {
341
449
  return;
342
450
 
343
451
  const descriptor = grid.cell(DESCR_TITLE, cell.gridRow).value;
344
- cell.element = this.desirabilityProfileRoots.get(descriptor) ?? ui.div();
452
+ const element = this.desirabilityProfileRoots.get(descriptor);
453
+
454
+ if (element != null)
455
+ cell.element = element;
456
+ else {
457
+ const selected = selectedByPvalue.includes(descriptor);
458
+ const text = selected ? 'highly correlated with other descriptors' : 'statistically insignificant';
459
+ const tooltipMsg = selected ?
460
+ `No chart shown: <b>${descriptor}</b> is excluded due to a high correlation with other descriptors.` :
461
+ `No chart shown: <b>${descriptor}</b> is excluded due to a high p-value.`;
462
+
463
+ const divWithDescription = ui.divText(text);
464
+ divWithDescription.style.color = COLORS.SKIPPED;
465
+ divWithDescription.classList.add('eda-pmpo-centered-text');
466
+ ui.tooltip.bind(divWithDescription, tooltipMsg);
467
+ cell.element = divWithDescription;
468
+ }
345
469
  }); // grid.onCellPrepare
346
470
  } // updateGrid
347
471
 
@@ -372,7 +496,7 @@ export class Pmpo {
372
496
  } // updateGrid
373
497
 
374
498
  /** Updates the desirability profile data */
375
- private updateDesirabilityProfileData(descrStatsTable: DG.DataFrame): void {
499
+ private updateDesirabilityProfileData(descrStatsTable: DG.DataFrame, useSigmoidalCorrection: boolean): void {
376
500
  if (this.params == null)
377
501
  return;
378
502
 
@@ -380,7 +504,7 @@ export class Pmpo {
380
504
  this.desirabilityProfileRoots.forEach((root) => root.remove());
381
505
  this.desirabilityProfileRoots.clear();
382
506
 
383
- const desirabilityProfile = getDesirabilityProfileJson(this.params, '', '');
507
+ const desirabilityProfile = getDesirabilityProfileJson(this.params, useSigmoidalCorrection, '', '', true);
384
508
 
385
509
  // Set weights
386
510
  const descrNames = descrStatsTable.col(DESCR_TITLE)!.toList();
@@ -416,52 +540,69 @@ export class Pmpo {
416
540
  });
417
541
  } // updateDesirabilityProfileData
418
542
 
419
- /** Fits the pMPO model to the given data and updates the viewers accordingly */
420
- private fitAndUpdateViewers(df: DG.DataFrame, descriptors: DG.ColumnList, desirability: DG.Column,
421
- pValTresh: number, r2Tresh: number, qCutoff: number): void {
422
- if (!Pmpo.isApplicable(descriptors, desirability, pValTresh, r2Tresh, qCutoff))
423
- throw new Error('Failed to train pMPO model: the method is not applicable to the inputs');
424
-
425
- const descriptorNames = descriptors.names();
426
- const {desired, nonDesired} = getDesiredTables(df, desirability);
427
-
428
- // Compute descriptors' statistics
429
- const descrStats = new Map<string, DescriptorStatistics>();
430
- descriptorNames.forEach((name) => {
431
- descrStats.set(name, getDescriptorStatistics(desired.col(name)!, nonDesired.col(name)!));
543
+ /** Updates the ROC curve viewer with the given desirability (labels) and prediction columns
544
+ * @return Best threshold according to Youden's J statistic
545
+ */
546
+ private updateRocCurve(desirability: DG.Column, prediction: DG.Column): number {
547
+ const evaluation = getPmpoEvaluation(desirability, prediction);
548
+
549
+ const rocDf = DG.DataFrame.fromColumns([
550
+ DG.Column.fromFloat32Array(THRESHOLD, ROC_TRESHOLDS),
551
+ DG.Column.fromFloat32Array(FPR_TITLE, evaluation.fpr),
552
+ DG.Column.fromFloat32Array(TPR_TITLE, evaluation.tpr),
553
+ ]);
554
+
555
+ // Add baseline
556
+ rocDf.meta.formulaLines.addLine({
557
+ title: 'Non-informative baseline',
558
+ formula: `\${${TPR_TITLE}} = \${${FPR_TITLE}}`,
559
+ width: 1,
560
+ style: 'dashed',
561
+ min: 0,
562
+ max: 1,
432
563
  });
433
- const descrStatsTable = getDescriptorStatisticsTable(descrStats);
434
-
435
- // Set p-value column color coding
436
- setPvalColumnColorCoding(descrStatsTable, pValTresh);
437
-
438
- // Filter by p-value
439
- const selectedByPvalue = getFilteredByPvalue(descrStatsTable, pValTresh);
440
-
441
- // Compute correlation triples
442
- const correlationTriples = getCorrelationTriples(descriptors, selectedByPvalue);
443
-
444
- // Filter by correlations
445
- const selectedByCorr = getFilteredByCorrelations(descriptors, selectedByPvalue, descrStats, r2Tresh, correlationTriples);
446
564
 
447
- // Add the Selected column
448
- addSelectedDescriptorsCol(descrStatsTable, selectedByCorr);
565
+ this.rocCurve.dataFrame = rocDf;
566
+ this.rocCurve.setOptions({
567
+ xColumnName: FPR_TITLE,
568
+ yColumnName: TPR_TITLE,
569
+ linesOrderColumnName: FPR_TITLE,
570
+ linesWidth: 5,
571
+ markerType: 'dot',
572
+ title: `ROC Curve (AUC = ${evaluation.auc.toFixed(3)})`,
573
+ });
449
574
 
450
- // Add correlation columns
451
- addCorrelationColumns(descrStatsTable, descriptorNames, correlationTriples, selectedByCorr);
575
+ return evaluation.threshold;
576
+ } // updateRocCurve
577
+
578
/**
 * Rebinds the confusion matrix viewer to the given table and refreshes its options.
 *
 * @param df Table to show in the viewer
 * @param desColName Name of the desirability column (used as the X column and in the title)
 * @param bestThreshold Threshold shown in the viewer description
 */
private updateConfusionMatrix(df: DG.DataFrame, desColName: string, bestThreshold: number): void {
  const viewer = this.confusionMatrix;
  viewer.dataFrame = df;

  const description = `Threshold: ${bestThreshold.toFixed(3)} (optimized via Youden's J)`;
  const title = `${desColName} Confusion Matrix`;

  viewer.setOptions({
    xColumnName: desColName,
    yColumnName: this.boolPredictionName,
    description,
    title,
  });
} // updateConfusionMatrix
452
588
 
453
- // Set correlation columns color coding
454
- setCorrColumnColorCoding(descrStatsTable, descriptorNames, r2Tresh);
589
+ /** Fits the pMPO model to the given data and updates the viewers accordingly */
590
+ private fitAndUpdateViewers(df: DG.DataFrame, descriptors: DG.ColumnList, desirability: DG.Column,
591
+ pValTresh: number, r2Tresh: number, qCutoff: number, useSigmoid: boolean): void {
592
+ const trainResult = Pmpo.fit(df, descriptors, desirability, pValTresh, r2Tresh, qCutoff);
593
+ this.params = trainResult.params;
594
+ const descrStatsTable = trainResult.descrStatsTable;
595
+ const selectedByPvalue = trainResult.selectedByPvalue;
596
+ const selectedByCorr = trainResult.selectedByCorr;
455
597
 
456
- // Compute pMPO parameters - training
457
- this.params = getModelParams(desired, nonDesired, selectedByCorr, qCutoff);
598
+ const descriptorNames = descriptors.names();
458
599
 
459
- //const weightsTable = getWeightsTable(this.params);
460
- const prediction = Pmpo.predict(df, this.params, this.predictionName);
600
+ const prediction = Pmpo.predict(df, this.params, useSigmoid, this.predictionName);
461
601
 
462
602
  // Mark predictions with a color
463
603
  prediction.colors.setLinear(getOutputPalette(OPT_TYPE.MAX), {min: prediction.stats.min, max: prediction.stats.max});
464
604
 
605
+ // Remove existing prediction column and add the new one
465
606
  df.columns.remove(this.predictionName);
466
607
  df.columns.add(prediction);
467
608
 
@@ -469,10 +610,25 @@ export class Pmpo {
469
610
  this.updateGrid();
470
611
 
471
612
  // Update desirability profile roots map
472
- this.updateDesirabilityProfileData(descrStatsTable);
613
+ this.updateDesirabilityProfileData(descrStatsTable, useSigmoid);
473
614
 
474
615
  // Update statistics grid
475
616
  this.updateStatisticsGrid(descrStatsTable, descriptorNames, selectedByPvalue, selectedByCorr);
617
+
618
+ // Update ROC curve
619
+ const bestThreshold = this.updateRocCurve(desirability, prediction);
620
+
621
+ // Update desirability prediction column
622
+ const desColName = desirability.name;
623
+ df.columns.remove(this.boolPredictionName);
624
+ this.boolPredictionName = df.columns.getUnusedName(desColName + '(predicted)');
625
+ const boolPrediction = getBoolPredictionColumn(prediction, bestThreshold, this.boolPredictionName);
626
+ df.columns.add(boolPrediction);
627
+
628
+ // Update confusion matrix
629
+ this.updateConfusionMatrix(df, desColName, bestThreshold);
630
+
631
+ this.view.dataFrame.selection.setAll(false, true);
476
632
  } // fitAndUpdateViewers
477
633
 
478
634
  /** Runs the pMPO model training application */
@@ -480,7 +636,7 @@ export class Pmpo {
480
636
  const dockMng = this.view.dockManager;
481
637
 
482
638
  // Inputs form
483
- dockMng.dock(this.getInputForm(), DG.DOCK_TYPE.LEFT, null, undefined, 0.1);
639
+ dockMng.dock(this.getInputForm(true).form, DG.DOCK_TYPE.LEFT, null, undefined, 0.1);
484
640
 
485
641
  // Dock viewers
486
642
  const gridNode = dockMng.findNode(this.view.grid.root);
@@ -488,13 +644,30 @@ export class Pmpo {
488
644
  if (gridNode == null)
489
645
  throw new Error('Failed to train pMPO: missing a grid in the table view.');
490
646
 
491
- dockMng.dock(this.statGrid, DG.DOCK_TYPE.DOWN, gridNode, undefined, 0.5);
647
+ // Dock statistics grid
648
+ const statGridNode = dockMng.dock(this.statGrid, DG.DOCK_TYPE.DOWN, gridNode, undefined, 0.5);
649
+
650
+ // Dock ROC curve
651
+ const rocNode = dockMng.dock(this.rocCurve, DG.DOCK_TYPE.RIGHT, statGridNode, undefined, 0.3);
652
+
653
+ // Dock confusion matrix
654
+ dockMng.dock(this.confusionMatrix, DG.DOCK_TYPE.RIGHT, rocNode, undefined, 0.2);
492
655
 
493
656
  this.setRibbons();
494
657
  } // runTrainingApp
495
658
 
659
/**
 * Returns the items of the pMPO training app: the statistics grid, the ROC curve, the confusion matrix,
 * and the input controls. The input form is created by this call without the save button appended to it;
 * the button is returned separately as part of the controls.
 */
public getPmpoAppItems(): PmpoAppItems {
  const {statGrid, rocCurve, confusionMatrix} = this;
  const controls = this.getInputForm(false);

  return {statsGrid: statGrid, rocCurve, confusionMatrix, controls};
} // getPmpoAppItems
668
+
496
669
  /** Creates and returns the input form for pMPO model training */
497
- private getInputForm(): HTMLElement {
670
+ private getInputForm(addBtn: boolean): Controls {
498
671
  const form = ui.form([]);
499
672
  form.append(ui.h2('Training data'));
500
673
  const numericColNames = this.numericCols.map((col) => col.name);
@@ -502,16 +675,21 @@ export class Pmpo {
502
675
  // Function to run computations on input changes
503
676
  const runComputations = () => {
504
677
  try {
678
+ //grok.shell.info('Running...');
679
+
505
680
  this.fitAndUpdateViewers(
506
681
  this.table,
507
682
  DG.DataFrame.fromColumns(descrInput.value).columns,
508
- this.table.col(desInput.value!)!,
509
- pInput.value!,
510
- rInput.value!,
511
- qInput.value!,
683
+ this.table.col(desInput.value!)!,
684
+ pInput.value!,
685
+ rInput.value!,
686
+ qInput.value!,
687
+ useSigmoidInput.value,
512
688
  );
513
689
  } catch (err) {
514
- grok.shell.error(err instanceof Error ? err.message : PMPO_COMPUTE_FAILED + ': the platform issue.');
690
+ err instanceof PmpoError ?
691
+ grok.shell.warning(err.message) :
692
+ grok.shell.error(err instanceof Error ? err.message : PMPO_COMPUTE_FAILED + ': the platform issue.');
515
693
  }
516
694
  };
517
695
 
@@ -523,8 +701,10 @@ export class Pmpo {
523
701
  checked: numericColNames,
524
702
  tooltipText: 'Descriptor columns used for model construction.',
525
703
  onValueChanged: (value) => {
526
- if (value != null)
527
- runComputations();
704
+ if (value != null) {
705
+ areTunedSettingsUsed = false;
706
+ checkAutoTuneAndRun();
707
+ }
528
708
  },
529
709
  });
530
710
  form.append(descrInput.root);
@@ -536,26 +716,95 @@ export class Pmpo {
536
716
  items: this.boolCols.map((col) => col.name),
537
717
  tooltipText: 'Desirability column.',
538
718
  onValueChanged: (value) => {
539
- if (value != null)
540
- runComputations();
719
+ if (value != null) {
720
+ areTunedSettingsUsed = false;
721
+ checkAutoTuneAndRun();
722
+ }
541
723
  },
542
724
  });
543
725
  form.append(desInput.root);
544
726
 
545
- const header = ui.h2('Thresholds');
546
- ui.tooltip.bind(header, 'Settings of the pMPO model training.');
727
+ const header = ui.h2('Settings');
547
728
  form.append(header);
729
+ ui.tooltip.bind(header, 'Settings of the pMPO model.');
730
+
731
+ // use sigmoid correction
732
+ const useSigmoidInput = ui.input.bool('\u03C3 correction', {
733
+ value: USE_SIGMOID_DEFAULT,
734
+ tooltipText: 'Use the sigmoidal correction to the weighted Gaussian scores.',
735
+ onValueChanged: (_value) => {
736
+ areTunedSettingsUsed = false;
737
+ checkAutoTuneAndRun();
738
+ },
739
+ });
740
+ form.append(useSigmoidInput.root);
741
+
742
+ const toUseAutoTune = (this.table.rowCount <= AUTO_TUNE_MAX_APPLICABLE_ROWS);
743
+
744
+ // Flag indicating whether optimal parameters from auto-tuning are currently used
745
+ let areTunedSettingsUsed = false;
746
+
747
+ const setOptimalParametersAndRun = async () => {
748
+ if (!areTunedSettingsUsed) {
749
+ const optimalSettings = await this.getOptimalSettings(
750
+ DG.DataFrame.fromColumns(descrInput.value).columns,
751
+ this.table.col(desInput.value!)!,
752
+ useSigmoidInput.value,
753
+ );
754
+
755
+ if (optimalSettings.success) {
756
+ pInput.value = Math.max(optimalSettings.pValTresh, P_VAL_TRES_MIN);
757
+ rInput.value = Math.max(optimalSettings.r2Tresh, R2_MIN);
758
+ qInput.value = Math.max(optimalSettings.qCutoff, Q_CUTOFF_MIN);
759
+ areTunedSettingsUsed = true;
760
+ } else
761
+ autoTuneInput.value = false; // revert to manual mode if optimization failed
762
+ }
763
+
764
+ runComputations();
765
+ };
766
+
767
+ const checkAutoTuneAndRun = () => {
768
+ if (autoTuneInput.value)
769
+ setOptimalParametersAndRun();
770
+ else
771
+ runComputations();
772
+ };
773
+
774
+ // autotuning input
775
+ const autoTuneInput = ui.input.bool('Auto-tuning', {
776
+ value: false,
777
+ tooltipText: 'Automatically select optimal p-value, R², and q-cutoff by maximizing AUC.',
778
+ onValueChanged: async (value) => {
779
+ setEnability(!value);
780
+
781
+ if (areTunedSettingsUsed)
782
+ return;
783
+
784
+ // If auto-tuning is turned on, set optimal parameters and run computations
785
+ if (value)
786
+ await setOptimalParametersAndRun();
787
+ },
788
+ });
789
+ form.append(autoTuneInput.root);
548
790
 
549
791
  // p-value threshold input
550
792
  const pInput = ui.input.float('p-value', {
551
793
  nullable: false,
552
794
  min: P_VAL_TRES_MIN,
553
- max: 1,
554
- step: 0.01,
555
- value: 0.05,
556
- tooltipText: 'Descriptors with p-values above this threshold are excluded.',
795
+ max: P_VAL_TRES_MAX,
796
+ step: 0.001,
797
+ value: P_VAL_TRES_DEFAULT,
798
+ // @ts-ignore
799
+ format: FORMAT,
800
+ tooltipText: 'P-value threshold. Descriptors with p-values above this threshold are excluded.',
557
801
  onValueChanged: (value) => {
558
- if ((value != null) && (value >= P_VAL_TRES_MIN) && (value <= 1))
802
+ // Prevent running computations when auto-tuning is on, since parameters will be set automatically
803
+ if (autoTuneInput.value)
804
+ return;
805
+
806
+ areTunedSettingsUsed = false;
807
+ if ((value != null) && (value >= P_VAL_TRES_MIN) && (value <= P_VAL_TRES_MAX))
559
808
  runComputations();
560
809
  },
561
810
  });
@@ -563,15 +812,23 @@ export class Pmpo {
563
812
 
564
813
  // R² threshold input
565
814
  const rInput = ui.input.float('R²', {
815
+ // @ts-ignore
816
+ format: FORMAT,
566
817
  nullable: false,
567
818
  min: R2_MIN,
568
- value: 0.5,
569
- max: 1,
819
+ value: R2_DEFAULT,
820
+ max: R2_MAX,
570
821
  step: 0.01,
571
822
  // eslint-disable-next-line max-len
572
- tooltipText: 'Descriptors with squared correlation above this threshold are considered highly correlated. Among them, the descriptor with the lower p-value is retained.',
823
+ tooltipText: 'Squared correlation threshold. Descriptors with squared correlation above this threshold are considered highly correlated. Among them, the descriptor with the lower p-value is retained.',
573
824
  onValueChanged: (value) => {
574
- if ((value != null) && (value >= R2_MIN) && (value <= 1))
825
+ // Prevent running computations when auto-tuning is on, since parameters will be set automatically
826
+ if (autoTuneInput.value)
827
+ return;
828
+
829
+ areTunedSettingsUsed = false;
830
+
831
+ if ((value != null) && (value >= R2_MIN) && (value <= R2_MAX))
575
832
  runComputations();
576
833
  },
577
834
  });
@@ -579,36 +836,62 @@ export class Pmpo {
579
836
 
580
837
  // q-cutoff input
581
838
  const qInput = ui.input.float('q-cutoff', {
839
+ // @ts-ignore
840
+ format: FORMAT,
582
841
  nullable: false,
583
842
  min: Q_CUTOFF_MIN,
584
- value: 0.05,
585
- max: 1,
843
+ value: Q_CUTOFF_DEFAULT,
844
+ max: Q_CUTOFF_MAX,
586
845
  step: 0.01,
587
846
  tooltipText: 'Q-cutoff for the pMPO model computation.',
588
847
  onValueChanged: (value) => {
589
- if ((value != null) && (value >= Q_CUTOFF_MIN) && (value <= 1))
848
+ // Prevent running computations when auto-tuning is on, since parameters will be set automatically
849
+ if (autoTuneInput.value)
850
+ return;
851
+
852
+ areTunedSettingsUsed = false;
853
+
854
+ if ((value != null) && (value >= Q_CUTOFF_MIN) && (value <= Q_CUTOFF_MAX))
590
855
  runComputations();
591
856
  },
592
857
  });
593
858
  form.append(qInput.root);
594
859
 
595
- setTimeout(() => runComputations(), 10);
860
+ const setEnability = (toEnable: boolean) => {
861
+ pInput.enabled = toEnable;
862
+ rInput.enabled = toEnable;
863
+ qInput.enabled = toEnable;
864
+ };
865
+
866
+ setTimeout(() => {
867
+ runComputations();
868
+
869
+ if (toUseAutoTune)
870
+ autoTuneInput.value = true; // this will trigger setting optimal parameters and running computations
871
+ else
872
+ runComputations();
873
+ }, 10);
596
874
 
597
875
  // Save model button
598
- const saveBtn = ui.button('Save model', async () => {
876
+ const saveBtn = ui.button('Save', async () => {
599
877
  if (this.params == null) {
600
878
  grok.shell.warning('Failed to save pMPO model: null parameters.');
601
879
  return;
602
880
  }
603
881
 
604
- saveModel(this.params, this.table.name);
882
+ saveModel(this.params, this.table.name, useSigmoidInput.value);
605
883
  }, 'Save model as platform file.');
606
- form.append(saveBtn);
884
+
885
+ if (addBtn)
886
+ form.append(saveBtn);
607
887
 
608
888
  const div = ui.div([form]);
609
889
  div.classList.add('eda-pmpo-input-form');
610
890
 
611
- return div;
891
+ return {
892
+ form: div,
893
+ saveBtn: saveBtn,
894
+ };
612
895
  } // getInputForm
613
896
 
614
897
  /** Retrieves boolean columns from the data frame */
@@ -634,4 +917,75 @@ export class Pmpo {
634
917
 
635
918
  return res;
636
919
  } // getValidNumericCols
920
+
921
/**
 * Searches for the R² threshold and q-cutoff that maximize the AUC of pMPO predictions on the current table.
 *
 * Descriptors are pre-filtered with the default p-value threshold. The R² threshold and q-cutoff are then
 * optimized with `optimizeNM`, starting from their default values and using the predefined bounds.
 *
 * @param descriptors Descriptor columns
 * @param desirability Desirability column (labels)
 * @param useSigmoid Whether predictions are computed with the sigmoidal correction
 * @returns Optimal settings with `success: true`; a zero-filled point with `success: false` if no descriptor
 * passes the p-value filter, the optimization is canceled, or it throws
 */
private async getOptimalSettings(descriptors: DG.ColumnList, desirability: DG.Column, useSigmoid: boolean): Promise<OptimalPoint> {
  const failedResult: OptimalPoint = {
    pValTresh: 0,
    r2Tresh: 0,
    qCutoff: 0,
    success: false,
  };

  const descriptorNames = descriptors.names();
  const {desired, nonDesired} = getDesiredTables(this.table, desirability);

  // Compute descriptors' statistics
  const descrStats = new Map<string, DescriptorStatistics>();
  descriptorNames.forEach((name) => {
    descrStats.set(name, getDescriptorStatistics(desired.col(name)!, nonDesired.col(name)!));
  });
  const descrStatsTable = getDescriptorStatisticsTable(descrStats);

  // Filter by p-value (the p-value threshold is fixed at its default and is not optimized)
  const selectedByPvalue = getFilteredByPvalue(descrStatsTable, P_VAL_TRES_DEFAULT);
  if (selectedByPvalue.length < 1)
    return failedResult;

  const correlationTriples = getCorrelationTriples(descriptors, selectedByPvalue);

  // Objective: point[0] is the R² threshold, point[1] is the q-cutoff
  const funcToBeMinimized = (point: Float32Array) => {
    // Filter by correlations
    const selectedByCorr = getFilteredByCorrelations(descriptors, selectedByPvalue, descrStats, point[0], correlationTriples);

    // Compute pMPO parameters - training
    const params = getModelParams(desired, nonDesired, selectedByCorr, point[1]);

    // Get predictions
    const prediction = Pmpo.predict(this.table, params, useSigmoid, this.predictionName);

    // Return 1 - AUC: the optimizer minimizes, while the goal is to maximize AUC
    return 1 - getPmpoEvaluation(desirability, prediction).auc;
  }; // funcToBeMinimized

  const pi = DG.TaskBarProgressIndicator.create('Optimizing... ', {cancelable: true});

  try {
    const optimalResult = await optimizeNM(
      pi,
      funcToBeMinimized,
      new Float32Array([R2_DEFAULT, Q_CUTOFF_DEFAULT]),
      DEFAULT_OPTIMIZATION_SETTINGS,
      LOW_PARAMS_BOUNDS,
      HIGH_PARAMS_BOUNDS,
    );

    // A canceled run is reported as a failure
    if (pi.canceled)
      return failedResult;

    return {
      pValTresh: P_VAL_TRES_DEFAULT,
      r2Tresh: optimalResult.optimalPoint[0],
      qCutoff: optimalResult.optimalPoint[1],
      success: true,
    };
  } catch {
    // An optimization failure is reported to the caller through the `success` flag
    return failedResult;
  } finally {
    // Single cleanup point: the progress indicator is closed on every exit path
    pi.close();
  }
} // getOptimalSettings
637
991
  }; // Pmpo