@datagrok/eda 1.4.12 → 1.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103) hide show
  1. package/.eslintrc.json +0 -1
  2. package/CHANGELOG.md +10 -0
  3. package/CLAUDE.md +185 -0
  4. package/css/pmpo.css +9 -0
  5. package/dist/111.js +1 -1
  6. package/dist/111.js.map +1 -1
  7. package/dist/128.js +1 -1
  8. package/dist/128.js.map +1 -1
  9. package/dist/153.js +1 -1
  10. package/dist/153.js.map +1 -1
  11. package/dist/23.js +1 -1
  12. package/dist/23.js.map +1 -1
  13. package/dist/234.js +1 -1
  14. package/dist/234.js.map +1 -1
  15. package/dist/242.js +1 -1
  16. package/dist/242.js.map +1 -1
  17. package/dist/260.js +1 -1
  18. package/dist/260.js.map +1 -1
  19. package/dist/33.js +1 -1
  20. package/dist/33.js.map +1 -1
  21. package/dist/348.js +1 -1
  22. package/dist/348.js.map +1 -1
  23. package/dist/377.js +1 -1
  24. package/dist/377.js.map +1 -1
  25. package/dist/397.js +2 -0
  26. package/dist/397.js.map +1 -0
  27. package/dist/412.js +1 -1
  28. package/dist/412.js.map +1 -1
  29. package/dist/415.js +1 -1
  30. package/dist/415.js.map +1 -1
  31. package/dist/501.js +1 -1
  32. package/dist/501.js.map +1 -1
  33. package/dist/531.js +1 -1
  34. package/dist/531.js.map +1 -1
  35. package/dist/583.js +1 -1
  36. package/dist/583.js.map +1 -1
  37. package/dist/589.js +1 -1
  38. package/dist/589.js.map +1 -1
  39. package/dist/603.js +1 -1
  40. package/dist/603.js.map +1 -1
  41. package/dist/656.js +1 -1
  42. package/dist/656.js.map +1 -1
  43. package/dist/682.js +1 -1
  44. package/dist/682.js.map +1 -1
  45. package/dist/705.js +1 -1
  46. package/dist/705.js.map +1 -1
  47. package/dist/727.js +1 -1
  48. package/dist/727.js.map +1 -1
  49. package/dist/731.js +1 -1
  50. package/dist/731.js.map +1 -1
  51. package/dist/738.js +1 -1
  52. package/dist/738.js.map +1 -1
  53. package/dist/763.js +1 -1
  54. package/dist/763.js.map +1 -1
  55. package/dist/778.js +1 -1
  56. package/dist/778.js.map +1 -1
  57. package/dist/783.js +1 -1
  58. package/dist/783.js.map +1 -1
  59. package/dist/793.js +1 -1
  60. package/dist/793.js.map +1 -1
  61. package/dist/810.js +1 -1
  62. package/dist/810.js.map +1 -1
  63. package/dist/860.js +1 -1
  64. package/dist/860.js.map +1 -1
  65. package/dist/907.js +1 -1
  66. package/dist/907.js.map +1 -1
  67. package/dist/950.js +1 -1
  68. package/dist/950.js.map +1 -1
  69. package/dist/980.js +1 -1
  70. package/dist/980.js.map +1 -1
  71. package/dist/990.js +1 -1
  72. package/dist/990.js.map +1 -1
  73. package/dist/package-test.js +1 -1
  74. package/dist/package-test.js.map +1 -1
  75. package/dist/package.js +1 -1
  76. package/dist/package.js.map +1 -1
  77. package/eslintrc.json +0 -1
  78. package/files/drugs-props-train-scores.csv +664 -0
  79. package/package.json +11 -7
  80. package/src/package-api.ts +7 -3
  81. package/src/package-test.ts +4 -1
  82. package/src/package.g.ts +21 -9
  83. package/src/package.ts +33 -23
  84. package/src/pareto-optimization/pareto-computations.ts +6 -0
  85. package/src/pareto-optimization/pareto-optimizer.ts +1 -1
  86. package/src/pls/pls-constants.ts +3 -1
  87. package/src/pls/pls-tools.ts +73 -69
  88. package/src/probabilistic-scoring/data-generator.ts +202 -0
  89. package/src/probabilistic-scoring/nelder-mead.ts +204 -0
  90. package/src/probabilistic-scoring/pmpo-defs.ts +141 -3
  91. package/src/probabilistic-scoring/pmpo-utils.ts +240 -126
  92. package/src/probabilistic-scoring/prob-scoring.ts +862 -135
  93. package/src/probabilistic-scoring/stat-tools.ts +141 -6
  94. package/src/tests/anova-tests.ts +1 -1
  95. package/src/tests/classifiers-tests.ts +1 -1
  96. package/src/tests/dim-reduction-tests.ts +1 -1
  97. package/src/tests/linear-methods-tests.ts +1 -1
  98. package/src/tests/mis-vals-imputation-tests.ts +1 -1
  99. package/src/tests/pareto-tests.ts +251 -0
  100. package/src/tests/pmpo-tests.ts +797 -0
  101. package/test-console-output-1.log +303 -239
  102. package/test-record-1.mp4 +0 -0
  103. package/files/mpo-done.ipynb +0 -2123
@@ -1,6 +1,6 @@
1
1
  /* eslint-disable max-len */
2
2
  // Probabilistic scoring (pMPO) features
3
- // Link: https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/
3
+ // Source paper https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/
4
4
 
5
5
  import * as grok from 'datagrok-api/grok';
6
6
  import * as ui from 'datagrok-api/ui';
@@ -10,15 +10,45 @@ import {MpoProfileEditor} from '@datagrok-libraries/statistics/src/mpo/mpo-profi
10
10
 
11
11
  import '../../css/pmpo.css';
12
12
 
13
- import {getDesiredTables, getDescriptorStatistics, normalPdf, sigmoidS} from './stat-tools';
13
+ import {getDesiredTables, getDescriptorStatistics, getBoolPredictionColumn, getPmpoEvaluation} from './stat-tools';
14
14
  import {MIN_SAMPLES_COUNT, PMPO_NON_APPLICABLE, DescriptorStatistics, P_VAL_TRES_MIN, DESCR_TITLE,
15
- R2_MIN, Q_CUTOFF_MIN, PmpoParams, SCORES_TITLE, DESCR_TABLE_TITLE, PMPO_COMPUTE_FAILED, SELECTED_TITLE,
16
- P_VAL, DESIRABILITY_COL_NAME, STAT_GRID_HEIGHT, DESIRABILITY_COLUMN_WIDTH, WEIGHT_TITLE} from './pmpo-defs';
15
+ R2_MIN, Q_CUTOFF_MIN, PmpoParams, SCORES_TITLE, DESCR_TABLE_TITLE, SELECTED_TITLE,
16
+ P_VAL, DESIRABILITY_COL_NAME, STAT_GRID_HEIGHT, DESIRABILITY_COLUMN_WIDTH, WEIGHT_TITLE,
17
+ P_VAL_TRES_DEFAULT, R2_DEFAULT, Q_CUTOFF_DEFAULT, USE_SIGMOID_DEFAULT, ROC_TRESHOLDS,
18
+ FPR_TITLE, TPR_TITLE, COLORS, THRESHOLD, AUTO_TUNE_MAX_APPLICABLE_ROWS, DEFAULT_OPTIMIZATION_SETTINGS,
19
+ P_VAL_TRES_MAX, R2_MAX, Q_CUTOFF_MAX, OptimalPoint, LOW_PARAMS_BOUNDS, HIGH_PARAMS_BOUNDS, FORMAT,
20
+ EQUALITY_SIGN, SIGN_OPTIONS, THRESHOLDED_DESIRABILITY_COL_NAME, PMPO_COMPUTE_FAILED,
21
+ PmpoInputId, TooltipContent, PmpoValidationResult} from './pmpo-defs';
17
22
  import {addSelectedDescriptorsCol, getDescriptorStatisticsTable, getFilteredByPvalue, getFilteredByCorrelations,
18
23
  getModelParams, getDescrTooltip, saveModel, getScoreTooltip, getDesirabilityProfileJson, getCorrelationTriples,
19
- addCorrelationColumns, setPvalColumnColorCoding, setCorrColumnColorCoding} from './pmpo-utils';
24
+ addCorrelationColumns, setPvalColumnColorCoding, setCorrColumnColorCoding, PmpoError, getInitCol,
25
+ getBoolDesirabilityColData, isDesirabilityValid,
26
+ getDesirabilityColumnFromCategories,
27
+ getSelectedCategories} from './pmpo-utils';
20
28
  import {getOutputPalette} from '../pareto-optimization/utils';
21
29
  import {OPT_TYPE} from '../pareto-optimization/defs';
30
+ import {optimizeNM} from './nelder-mead';
31
+ import {getMissingValsIndices} from '../missing-values-imputation/knn-imputer';
32
+ import {DesirabilityProfile} from '@datagrok-libraries/statistics/src/mpo/mpo';
33
+
34
+ export type PmpoTrainingResult = {
35
+ params: Map<string, PmpoParams>,
36
+ descrStatsTable: DG.DataFrame,
37
+ selectedByPvalue: string[],
38
+ selectedByCorr: string[],
39
+ };
40
+
41
+ /** Type for pMPO training controls */
42
+ export type Controls = {form: HTMLElement, saveBtn: HTMLButtonElement};
43
+
44
+ /** Type for pMPO elements */
45
+ export type PmpoAppItems = {
46
+ statsGrid: DG.Viewer;
47
+ rocCurve: DG.Viewer;
48
+ confusionMatrix: DG.Viewer;
49
+ controls: Controls;
50
+ profile: DesirabilityProfile | null;
51
+ };
22
52
 
23
53
  /** Class implementing probabilistic MPO (pMPO) model training and prediction */
24
54
  export class Pmpo {
@@ -56,12 +86,6 @@ export class Pmpo {
56
86
  return false;
57
87
  }
58
88
 
59
- // Check desirability
60
- if (desirability.type !== DG.COLUMN_TYPE.BOOL) {
61
- showWarning(`: "${desirability.name}" must be boolean column.`);
62
- return false;
63
- }
64
-
65
89
  if (desirability.stats.stdev === 0) { // TRUE & FALSE
66
90
  showWarning(`: "${desirability.name}" has a single category.`);
67
91
  return false;
@@ -76,8 +100,8 @@ export class Pmpo {
76
100
  return false;
77
101
  }
78
102
 
79
- if (col.stats.missingValueCount > 0) {
80
- showWarning(`: "${col.name}" contains missing values.`);
103
+ if (col.stats.missingValueCount === col.length) {
104
+ showWarning(`: "${col.name}" contains only missing values.`);
81
105
  return false;
82
106
  }
83
107
 
@@ -102,52 +126,118 @@ export class Pmpo {
102
126
  return false;
103
127
  }
104
128
 
105
- let boolColsCount = 0;
106
- let validNumericColsCount = 0;
129
+ let validColsCount = 0;
107
130
 
108
131
  // Check numeric columns and boolean columns
109
132
  for (const col of df.columns) {
110
- if (col.isNumerical) {
111
- if ((col.stats.missingValueCount < 1) && (col.stats.stdev > 0))
112
- ++validNumericColsCount;
113
- } else if (col.type == DG.COLUMN_TYPE.BOOL)
114
- ++boolColsCount;
115
- }
116
-
117
- // Check boolean columns count
118
- if (boolColsCount < 1) {
119
- if (toShowMsg)
120
- grok.shell.warning(PMPO_NON_APPLICABLE + ': no boolean columns.');
121
- return false;
133
+ if (col.isNumerical || (col.type === DG.TYPE.BOOL)) {
134
+ if (col.stats.stdev > 0)
135
+ ++validColsCount;
136
+ }
122
137
  }
123
138
 
124
139
  // Check that there are at least two non-constant numeric/boolean columns
125
- if (validNumericColsCount < 1) {
140
+ if (validColsCount < 2) {
126
141
  if (toShowMsg)
127
- grok.shell.warning(PMPO_NON_APPLICABLE + ': no numeric columns without missing values and non-zero variance.');
142
+ grok.shell.warning(PMPO_NON_APPLICABLE + ': not enough of non-constant columns.');
128
143
  return false;
129
144
  }
130
145
 
131
146
  return true;
132
147
  } // isTableValid
133
148
 
149
+ /** Fits the pMPO model to the given data and returns training results */
150
+ static fit(df: DG.DataFrame, descriptors: DG.ColumnList, desirability: DG.Column,
151
+ pValTresh: number, r2Tresh: number, qCutoff: number, toCheckApplicability: boolean = true): PmpoTrainingResult {
152
+ if (toCheckApplicability) {
153
+ if (!Pmpo.isApplicable(descriptors, desirability, pValTresh, r2Tresh, qCutoff))
154
+ throw new Error('Failed to train pMPO model: the method is not applicable to the inputs');
155
+ }
156
+
157
+ const descriptorNames = descriptors.names();
158
+ const {desired, nonDesired} = getDesiredTables(df, desirability);
159
+
160
+ // Compute descriptors' statistics
161
+ const descrStats = new Map<string, DescriptorStatistics>();
162
+ descriptorNames.forEach((name) => {
163
+ descrStats.set(name, getDescriptorStatistics(desired.col(name)!, nonDesired.col(name)!));
164
+ });
165
+ const descrStatsTable = getDescriptorStatisticsTable(descrStats);
166
+
167
+ // Set p-value column color coding
168
+ setPvalColumnColorCoding(descrStatsTable, pValTresh);
169
+
170
+ // Filter by p-value
171
+ const selectedByPvalue = getFilteredByPvalue(descrStatsTable, pValTresh);
172
+
173
+ if (selectedByPvalue.length < 1)
174
+ throw new PmpoError('Cannot train pMPO model: all descriptors have high p-values (not significant).');
175
+
176
+ // Compute correlation triples
177
+ const correlationTriples = getCorrelationTriples(descriptors, selectedByPvalue);
178
+
179
+ // Filter by correlations
180
+ const selectedByCorr = getFilteredByCorrelations(descriptors, selectedByPvalue, descrStats, r2Tresh, correlationTriples);
181
+
182
+ // Add the Selected column
183
+ addSelectedDescriptorsCol(descrStatsTable, selectedByCorr);
184
+
185
+ // Add correlation columns
186
+ addCorrelationColumns(descrStatsTable, descriptorNames, correlationTriples, selectedByCorr);
187
+
188
+ // Set correlation columns color coding
189
+ setCorrColumnColorCoding(descrStatsTable, descriptorNames, r2Tresh);
190
+
191
+ // Compute pMPO parameters - training
192
+ const params = getModelParams(desired, nonDesired, selectedByCorr, qCutoff);
193
+
194
+ return {
195
+ params: params,
196
+ descrStatsTable: descrStatsTable,
197
+ selectedByPvalue: selectedByPvalue,
198
+ selectedByCorr: selectedByCorr,
199
+ };
200
+ } // fit
201
+
134
202
  /** Predicts pMPO scores for the given data frame using provided pMPO parameters */
135
- static predict(df: DG.DataFrame, params: Map<string, PmpoParams>, predictionName: string): DG.Column {
203
+ static predict(df: DG.DataFrame, params: Map<string, PmpoParams>, useSigmoid: boolean, predictionName: string): DG.Column {
136
204
  const count = df.rowCount;
137
205
  const scores = new Float64Array(count).fill(0);
138
- let x = 0;
206
+ const colsWithMissingVals: DG.Column[] = [];
139
207
 
140
208
  // Compute pMPO scores (see https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/)
141
209
  params.forEach((param, name) => {
142
210
  const col = df.col(name);
143
211
 
212
+ const b = param.b;
213
+ const c = param.c;
214
+ const x0 = param.cutoff;
215
+ let weight = param.weight;
216
+ const avg = param.desAvg;
217
+ const std = param.desStd;
218
+ const frac = 1.0 / (2 * std**2);
219
+
144
220
  if (col == null)
145
- throw new Error(`Filed to apply pMPO: inconsistent data, no column "${name}" in the table "${df.name}"`);
221
+ throw new Error(`Failed to apply pMPO: inconsistent data, no column "${name}" in the table "${df.name}"`);
222
+
223
+ if (col.stats.missingValueCount > 0)
224
+ colsWithMissingVals.push(col);
146
225
 
147
226
  const vals = col.getRawData();
148
- for (let i = 0; i < count; ++i) {
149
- x = vals[i];
150
- scores[i] += param.weight * normalPdf(x, param.desAvg, param.desStd) * sigmoidS(x, param.x0, param.b, param.c);
227
+
228
+ if (useSigmoid) {
229
+ if (c > 0) {
230
+ for (let i = 0; i < count; ++i)
231
+ scores[i] += weight * Math.exp(-((vals[i] - avg)**2) * frac) / (1.0 + b * (c ** (-(vals[i] - x0))));
232
+ } else {
233
+ weight = weight / (1.0 + b);
234
+
235
+ for (let i = 0; i < count; ++i)
236
+ scores[i] += weight * Math.exp(-((vals[i] - avg)**2) * frac);
237
+ }
238
+ } else {
239
+ for (let i = 0; i < count; ++i)
240
+ scores[i] += weight * Math.exp(-((vals[i] - avg)**2) * frac);
151
241
  }
152
242
  });
153
243
 
@@ -155,25 +245,49 @@ export class Pmpo {
155
245
  } // predict
156
246
 
157
247
  private params: Map<string, PmpoParams> | null = null;
248
+ private desirabilityProfile: DesirabilityProfile | null = null;
158
249
 
159
250
  private table: DG.DataFrame;
160
251
  private view: DG.TableView;
161
- private boolCols: DG.Column[];
252
+ private desirabilityColumns: DG.Column[];
162
253
  private numericCols: DG.Column[];
254
+ private missingValsIndeces: Map<string, number[]>;
163
255
 
164
- private initTable = grok.data.demo.demog(10);
256
+ private initTable = DG.DataFrame.create();
165
257
 
166
258
  private statGrid = DG.Viewer.grid(this.initTable, {showTitle: true, title: DESCR_TABLE_TITLE});
167
259
 
168
260
  private predictionName = SCORES_TITLE;
261
+ private boolPredictionName = '';
169
262
 
170
263
  private desirabilityProfileRoots = new Map<string, HTMLElement>();
171
- constructor(df: DG.DataFrame) {
264
+
265
+ private tresholdedColumn: DG.Column | null = null;
266
+ private threshColTooltip: string | null = null;
267
+
268
+ private rocCurve = DG.Viewer.scatterPlot(this.initTable, {
269
+ showTitle: true,
270
+ showSizeSelector: false,
271
+ showColorSelector: false,
272
+ });
273
+
274
+ private confusionMatrix = DG.Viewer.fromType('Confusion matrix', this.initTable, {
275
+ xColumnName: 'control',
276
+ yColumnName: 'control',
277
+ showTitle: true,
278
+ title: 'Confusion Matrix',
279
+ descriptionPosition: 'Bottom',
280
+ description: 'Confusion matrix for the predicted vs actual desirability labels.',
281
+ descriptionVisibilityMode: 'Always',
282
+ });
283
+
284
+ constructor(df: DG.DataFrame, view?: DG.TableView) {
172
285
  this.table = df;
173
- this.view = grok.shell.tableView(df.name) ?? grok.shell.addTableView(df);
174
- this.boolCols = this.getBoolCols();
286
+ this.view = view ?? (grok.shell.tableView(df.name) ?? grok.shell.addTableView(df));
287
+ this.desirabilityColumns = this.getDesirabilityColumns();
175
288
  this.numericCols = this.getValidNumericCols();
176
289
  this.predictionName = df.columns.getUnusedName(SCORES_TITLE);
290
+ this.missingValsIndeces = getMissingValsIndices(this.numericCols);
177
291
  };
178
292
 
179
293
  /** Sets the ribbon panels in the table view (removes the first panel) */
@@ -294,14 +408,12 @@ export class Pmpo {
294
408
  } else {
295
409
  const descriptor = grid.cell(DESCR_TITLE, cell.gridRow).value;
296
410
 
297
- if ((colName === DESIRABILITY_COL_NAME) || (colName === WEIGHT_TITLE)) {
298
- const startText = (colName === WEIGHT_TITLE) ? 'No weight' : 'No chart shown';
299
-
411
+ if (colName === WEIGHT_TITLE) {
300
412
  if (!this.desirabilityProfileRoots.has(descriptor)) {
301
413
  if (selectedByPvalue.includes(descriptor))
302
- ui.tooltip.show(`${startText}: <b>${descriptor}</b> is excluded due to a high correlation with other descriptors.`, x, y);
414
+ ui.tooltip.show(`No weight: <b>${descriptor}</b> is excluded due to a high correlation with other descriptors.`, x, y);
303
415
  else
304
- ui.tooltip.show(`${startText}: <b>${descriptor}</b> is excluded due to a high p-value.`, x, y);
416
+ ui.tooltip.show(`No weight: <b>${descriptor}</b> is excluded due to a high p-value.`, x, y);
305
417
 
306
418
  return true;
307
419
  }
@@ -341,7 +453,23 @@ export class Pmpo {
341
453
  return;
342
454
 
343
455
  const descriptor = grid.cell(DESCR_TITLE, cell.gridRow).value;
344
- cell.element = this.desirabilityProfileRoots.get(descriptor) ?? ui.div();
456
+ const element = this.desirabilityProfileRoots.get(descriptor);
457
+
458
+ if (element != null)
459
+ cell.element = element;
460
+ else {
461
+ const selected = selectedByPvalue.includes(descriptor);
462
+ const text = selected ? 'highly correlated with other descriptors' : 'statistically insignificant';
463
+ const tooltipMsg = selected ?
464
+ `No chart shown: <b>${descriptor}</b> is excluded due to a high correlation with other descriptors.` :
465
+ `No chart shown: <b>${descriptor}</b> is excluded due to a high p-value.`;
466
+
467
+ const divWithDescription = ui.divText(text);
468
+ divWithDescription.style.color = COLORS.SKIPPED;
469
+ divWithDescription.classList.add('eda-pmpo-centered-text');
470
+ ui.tooltip.bind(divWithDescription, tooltipMsg);
471
+ cell.element = divWithDescription;
472
+ }
345
473
  }); // grid.onCellPrepare
346
474
  } // updateGrid
347
475
 
@@ -352,7 +480,9 @@ export class Pmpo {
352
480
 
353
481
  grid.sort([this.predictionName], [false]);
354
482
 
355
- grid.col(name)!.format = '0.0000';
483
+ const scoresCol = grid.col(name);
484
+ scoresCol!.format = '0.0000';
485
+ scoresCol!.isTextColorCoded = true;
356
486
 
357
487
  // set tooltips
358
488
  grid.onCellTooltip((cell, x, y) => {
@@ -363,6 +493,12 @@ export class Pmpo {
363
493
  ui.tooltip.show(getScoreTooltip(), x, y);
364
494
 
365
495
  return true;
496
+ } else {
497
+ if (this.tresholdedColumn != null && cell.tableColumn.name === this.tresholdedColumn.name) {
498
+ ui.tooltip.show(ui.markdown(this.threshColTooltip ?? ''), x, y);
499
+
500
+ return true;
501
+ }
366
502
  }
367
503
 
368
504
  return false;
@@ -372,7 +508,7 @@ export class Pmpo {
372
508
  } // updateGrid
373
509
 
374
510
  /** Updates the desirability profile data */
375
- private updateDesirabilityProfileData(descrStatsTable: DG.DataFrame): void {
511
+ private updateDesirabilityProfileData(descrStatsTable: DG.DataFrame, useSigmoidalCorrection: boolean): void {
376
512
  if (this.params == null)
377
513
  return;
378
514
 
@@ -380,14 +516,16 @@ export class Pmpo {
380
516
  this.desirabilityProfileRoots.forEach((root) => root.remove());
381
517
  this.desirabilityProfileRoots.clear();
382
518
 
383
- const desirabilityProfile = getDesirabilityProfileJson(this.params, '', '');
519
+ const desirabilityProfile = getDesirabilityProfileJson(this.params, useSigmoidalCorrection, '', '', true);
520
+ this.desirabilityProfile = getDesirabilityProfileJson(this.params, useSigmoidalCorrection, '', '', false);
384
521
 
385
522
  // Set weights
386
523
  const descrNames = descrStatsTable.col(DESCR_TITLE)!.toList();
387
524
  const weightsRaw = descrStatsTable.col(WEIGHT_TITLE)!.getRawData();
388
525
  const props = desirabilityProfile.properties;
526
+ const names: string[] = Object.keys(props);
389
527
 
390
- for (const name of Object.keys(props))
528
+ for (const name of names)
391
529
  weightsRaw[descrNames.indexOf(name)] = props[name].weight;
392
530
 
393
531
  // Set HTML elements
@@ -399,69 +537,106 @@ export class Pmpo {
399
537
  if (rootsCol == null)
400
538
  return;
401
539
 
402
- const rows = rootsCol.querySelectorAll('div.d4-flex-row.ui-div');
540
+ const rows = rootsCol.querySelectorAll('div.d4-flex-row.ui-div.statistics-mpo-row');
403
541
 
404
- rows.forEach((row) => {
542
+ rows.forEach((row, idx) => {
405
543
  const children = row.children;
406
544
  if (children.length < 2) // expecting descriptor name, weight & profile
407
545
  return;
408
546
 
409
- const descrDivChildren = (children[0] as HTMLElement).children;
410
- if (descrDivChildren.length < 1) // expecting 1 div with descriptor name
411
- return;
412
-
413
- const descrName = (descrDivChildren[0] as HTMLElement).innerText;
414
-
415
- this.desirabilityProfileRoots.set(descrName, children[2] as HTMLElement);
547
+ const profileRoot = children[2] as HTMLElement;
548
+ profileRoot.style.width = '100%';
549
+ this.desirabilityProfileRoots.set(names[idx], profileRoot);
416
550
  });
417
551
  } // updateDesirabilityProfileData
418
552
 
419
- /** Fits the pMPO model to the given data and updates the viewers accordingly */
420
- private fitAndUpdateViewers(df: DG.DataFrame, descriptors: DG.ColumnList, desirability: DG.Column,
421
- pValTresh: number, r2Tresh: number, qCutoff: number): void {
422
- if (!Pmpo.isApplicable(descriptors, desirability, pValTresh, r2Tresh, qCutoff))
423
- throw new Error('Failed to train pMPO model: the method is not applicable to the inputs');
553
+ /** Updates the ROC curve viewer with the given desirability (labels) and prediction columns
554
+ * @return Best threshold according to Youden's J statistic
555
+ */
556
+ private updateRocCurve(desirability: DG.Column, prediction: DG.Column): number {
557
+ const evaluation = getPmpoEvaluation(desirability, prediction);
558
+
559
+ const rocDf = DG.DataFrame.fromColumns([
560
+ DG.Column.fromFloat32Array(THRESHOLD, ROC_TRESHOLDS),
561
+ DG.Column.fromFloat32Array(FPR_TITLE, evaluation.fpr),
562
+ DG.Column.fromFloat32Array(TPR_TITLE, evaluation.tpr),
563
+ ]);
564
+
565
+ // Add baseline
566
+ rocDf.meta.formulaLines.addLine({
567
+ title: 'Non-informative baseline',
568
+ formula: `\${${TPR_TITLE}} = \${${FPR_TITLE}}`,
569
+ width: 1,
570
+ style: 'dashed',
571
+ min: 0,
572
+ max: 1,
573
+ });
424
574
 
425
- const descriptorNames = descriptors.names();
426
- const {desired, nonDesired} = getDesiredTables(df, desirability);
575
+ this.rocCurve.dataFrame = rocDf;
576
+ this.rocCurve.setOptions({
577
+ xColumnName: FPR_TITLE,
578
+ yColumnName: TPR_TITLE,
579
+ linesOrderColumnName: FPR_TITLE,
580
+ linesWidth: 5,
581
+ markerType: 'dot',
582
+ title: `ROC Curve (AUC = ${evaluation.auc.toFixed(3)})`,
583
+ });
427
584
 
428
- // Compute descriptors' statistics
429
- const descrStats = new Map<string, DescriptorStatistics>();
430
- descriptorNames.forEach((name) => {
431
- descrStats.set(name, getDescriptorStatistics(desired.col(name)!, nonDesired.col(name)!));
585
+ return evaluation.threshold;
586
+ } // updateRocCurve
587
+
588
+ /** Updates the confusion matrix viewer with the given data frame, desirability column name, and best threshold */
589
+ private updateConfusionMatrix(df: DG.DataFrame, desColName: string, bestThreshold: number): void {
590
+ this.confusionMatrix.dataFrame = df;
591
+ this.confusionMatrix.setOptions({
592
+ xColumnName: desColName,
593
+ yColumnName: this.boolPredictionName,
594
+ description: `Threshold: ${bestThreshold.toFixed(3)} (optimized via Youden's J)`,
595
+ title: desColName + ' Confusion Matrix',
432
596
  });
433
- const descrStatsTable = getDescriptorStatisticsTable(descrStats);
597
+ } // updateConfusionMatrix
434
598
 
435
- // Set p-value column color coding
436
- setPvalColumnColorCoding(descrStatsTable, pValTresh);
599
+ /** Collects the stored missing-value indices for the given column names into a single array */
600
+ private getIndecesOfMissingValues(colNames: string[]): number[] {
601
+ const indeces: number[] = [];
437
602
 
438
- // Filter by p-value
439
- const selectedByPvalue = getFilteredByPvalue(descrStatsTable, pValTresh);
603
+ colNames.forEach((name) => {
604
+ const inds = this.missingValsIndeces.get(name);
440
605
 
441
- // Compute correlation triples
442
- const correlationTriples = getCorrelationTriples(descriptors, selectedByPvalue);
606
+ if (inds != null)
607
+ indeces.push(...inds);
608
+ });
443
609
 
444
- // Filter by correlations
445
- const selectedByCorr = getFilteredByCorrelations(descriptors, selectedByPvalue, descrStats, r2Tresh, correlationTriples);
610
+ return indeces;
611
+ }
446
612
 
447
- // Add the Selected column
448
- addSelectedDescriptorsCol(descrStatsTable, selectedByCorr);
613
+ /** Sets null values for the predicted scores in rows with missing values in any of the descriptors */
614
+ private setNulls(scores: DG.Column, indeces: number[]): void {
615
+ const raw = scores.getRawData();
616
+ indeces.forEach((ind) => raw[ind] = DG.FLOAT_NULL);
617
+ }
449
618
 
450
- // Add correlation columns
451
- addCorrelationColumns(descrStatsTable, descriptorNames, correlationTriples, selectedByCorr);
619
+ /** Fits the pMPO model to the given data and updates the viewers accordingly */
620
+ private fitAndUpdateViewers(df: DG.DataFrame, descriptors: DG.ColumnList, desirability: DG.Column,
621
+ pValTresh: number, r2Tresh: number, qCutoff: number, useSigmoid: boolean): void {
622
+ const trainResult = Pmpo.fit(df, descriptors, desirability, pValTresh, r2Tresh, qCutoff);
623
+ this.params = trainResult.params;
624
+ const descrStatsTable = trainResult.descrStatsTable;
625
+ const selectedByPvalue = trainResult.selectedByPvalue;
626
+ const selectedByCorr = trainResult.selectedByCorr;
452
627
 
453
- // Set correlation columns color coding
454
- setCorrColumnColorCoding(descrStatsTable, descriptorNames, r2Tresh);
628
+ const descriptorNames = descriptors.names();
455
629
 
456
- // Compute pMPO parameters - training
457
- this.params = getModelParams(desired, nonDesired, selectedByCorr, qCutoff);
630
+ const prediction = Pmpo.predict(df, this.params, useSigmoid, this.predictionName);
458
631
 
459
- //const weightsTable = getWeightsTable(this.params);
460
- const prediction = Pmpo.predict(df, this.params, this.predictionName);
632
+ // Set nulls for rows with missing values in any of the selected descriptors
633
+ const indecesOfMissingVals = this.getIndecesOfMissingValues(selectedByCorr);
634
+ this.setNulls(prediction, indecesOfMissingVals);
461
635
 
462
636
  // Mark predictions with a color
463
637
  prediction.colors.setLinear(getOutputPalette(OPT_TYPE.MAX), {min: prediction.stats.min, max: prediction.stats.max});
464
638
 
639
+ // Remove existing prediction column and add the new one
465
640
  df.columns.remove(this.predictionName);
466
641
  df.columns.add(prediction);
467
642
 
@@ -469,10 +644,25 @@ export class Pmpo {
469
644
  this.updateGrid();
470
645
 
471
646
  // Update desirability profile roots map
472
- this.updateDesirabilityProfileData(descrStatsTable);
647
+ this.updateDesirabilityProfileData(descrStatsTable, useSigmoid);
473
648
 
474
649
  // Update statistics grid
475
650
  this.updateStatisticsGrid(descrStatsTable, descriptorNames, selectedByPvalue, selectedByCorr);
651
+
652
+ // Update ROC curve
653
+ const bestThreshold = this.updateRocCurve(desirability, prediction);
654
+
655
+ // Update desirability prediction column
656
+ const desColName = desirability.name;
657
+ df.columns.remove(this.boolPredictionName);
658
+ this.boolPredictionName = df.columns.getUnusedName(desColName + '(predicted)');
659
+ const boolPrediction = getBoolPredictionColumn(prediction, bestThreshold, this.boolPredictionName);
660
+ df.columns.add(boolPrediction);
661
+
662
+ // Update confusion matrix
663
+ this.updateConfusionMatrix(df, desColName, bestThreshold);
664
+
665
+ this.view.dataFrame.selection.setAll(false, true);
476
666
  } // fitAndUpdateViewers
477
667
 
478
668
  /** Runs the pMPO model training application */
@@ -480,7 +670,7 @@ export class Pmpo {
480
670
  const dockMng = this.view.dockManager;
481
671
 
482
672
  // Inputs form
483
- dockMng.dock(this.getInputForm(), DG.DOCK_TYPE.LEFT, null, undefined, 0.1);
673
+ dockMng.dock(this.getInputForm(true).form, DG.DOCK_TYPE.LEFT, null, undefined, 0.1);
484
674
 
485
675
  // Dock viewers
486
676
  const gridNode = dockMng.findNode(this.view.grid.root);
@@ -488,150 +678,687 @@ export class Pmpo {
488
678
  if (gridNode == null)
489
679
  throw new Error('Failed to train pMPO: missing a grid in the table view.');
490
680
 
491
- dockMng.dock(this.statGrid, DG.DOCK_TYPE.DOWN, gridNode, undefined, 0.5);
681
+ // Dock statistics grid
682
+ const statGridNode = dockMng.dock(this.statGrid, DG.DOCK_TYPE.DOWN, gridNode, undefined, 0.5);
683
+
684
+ // Dock ROC curve
685
+ const rocNode = dockMng.dock(this.rocCurve, DG.DOCK_TYPE.RIGHT, statGridNode, undefined, 0.3);
686
+
687
+ // Dock confusion matrix
688
+ dockMng.dock(this.confusionMatrix, DG.DOCK_TYPE.RIGHT, rocNode, undefined, 0.2);
492
689
 
493
690
  this.setRibbons();
494
691
  } // runTrainingApp
495
692
 
693
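+ /** Returns the pMPO application items: statistics grid, ROC curve, confusion matrix, input controls (built without appending the save button to the form) and desirability profile */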
694
+ public getPmpoAppItems(): PmpoAppItems {
695
+ return {
696
+ statsGrid: this.statGrid,
697
+ rocCurve: this.rocCurve,
698
+ confusionMatrix: this.confusionMatrix,
699
+ controls: this.getInputForm(false),
700
+ profile: this.desirabilityProfile,
701
+ };
702
+ } // getPmpoAppItems
703
+
496
704
  /** Creates and returns the input form for pMPO model training */
497
- private getInputForm(): HTMLElement {
705
+ private getInputForm(addBtn: boolean): Controls {
498
706
  const form = ui.form([]);
499
707
  form.append(ui.h2('Training data'));
500
- const numericColNames = this.numericCols.map((col) => col.name);
708
+ const initDesirability = getInitCol(this.desirabilityColumns);
709
+
710
+ // returns the desirability column to be used for computations, based on the input desirability column and threshold settings
711
+ const getDesirabilityColumn = (): DG.Column => {
712
+ // remove existing thresholded column if exists
713
+ if (this.tresholdedColumn != null) {
714
+ this.table.columns.remove(this.tresholdedColumn.name);
715
+ this.tresholdedColumn = null;
716
+ }
717
+
718
+ if (desInput.value!.type === DG.COLUMN_TYPE.BOOL)
719
+ return desInput.value!;
720
+
721
+ const boolDesirabilityData = (desInput.value!.type === DG.COLUMN_TYPE.STRING) ?
722
+ getDesirabilityColumnFromCategories(desInput.value!, desirableCategoriesInput!.value!) :
723
+ getBoolDesirabilityColData(
724
+ desInput.value!,
725
+ desirabilityThresholdInput.value!,
726
+ signInput.value as EQUALITY_SIGN,
727
+ );
728
+
729
+ this.tresholdedColumn = boolDesirabilityData.column;
730
+ this.threshColTooltip = boolDesirabilityData.tooltip;
731
+
732
+ this.tresholdedColumn.name = this.table.columns.getUnusedName(THRESHOLDED_DESIRABILITY_COL_NAME);
733
+ this.table.columns.add(this.tresholdedColumn);
734
+
735
+ return this.tresholdedColumn;
736
+ }; // getDesirabilityColumn
501
737
 
502
738
  // Function to run computations on input changes
503
739
  const runComputations = () => {
740
+ if (!areInputsValid())
741
+ return;
742
+
504
743
  try {
505
744
  this.fitAndUpdateViewers(
506
745
  this.table,
507
746
  DG.DataFrame.fromColumns(descrInput.value).columns,
508
- this.table.col(desInput.value!)!,
509
- pInput.value!,
510
- rInput.value!,
511
- qInput.value!,
747
+ getDesirabilityColumn(),
748
+ pInput.value!,
749
+ rInput.value!,
750
+ qInput.value!,
751
+ useSigmoidInput.value,
512
752
  );
513
753
  } catch (err) {
514
- grok.shell.error(err instanceof Error ? err.message : PMPO_COMPUTE_FAILED + ': the platform issue.');
754
+ if (err instanceof PmpoError) {
755
+ grok.shell.warning(err.message);
756
+ ui.tooltip.bind(desInput.input, err.message);
757
+ ui.tooltip.bind(descrInput.input, err.message);
758
+ } else {
759
+ const msg = err instanceof Error ? err.message : PMPO_COMPUTE_FAILED + ': the platform issue.';
760
+ grok.shell.error(msg);
761
+ ui.tooltip.bind(desInput.input, msg);
762
+ ui.tooltip.bind(descrInput.input, msg);
763
+ };
764
+
765
+ desInput.input.classList.add('d4-invalid');
766
+ descrInput.input.classList.add('d4-invalid');
515
767
  }
516
- };
768
+ }; // runComputations
517
769
 
518
770
  // Descriptor columns input
519
771
  const descrInput = ui.input.columns('Descriptors', {
520
772
  table: this.table,
521
773
  nullable: false,
522
- available: numericColNames,
523
- checked: numericColNames,
774
+ available: this.numericCols.map((col) => col.name),
775
+ checked: this.numericCols.filter((col) => {
776
+ return (col.name !== initDesirability.name) && (col.stats.stdev > 0) && (col.stats.missingValueCount < col.length);
777
+ }).map((col) => col.name),
524
778
  tooltipText: 'Descriptor columns used for model construction.',
525
779
  onValueChanged: (value) => {
526
- if (value != null)
527
- runComputations();
780
+ if (value != null) {
781
+ areTunedSettingsUsed = false;
782
+ checkAutoTuneAndRun();
783
+ }
528
784
  },
529
785
  });
530
786
  form.append(descrInput.root);
531
787
 
532
- // Desirability column input
533
- const desInput = ui.input.choice('Desirability', {
788
+ descrInput.addValidator(() => {
789
+ if (descrInput.value == null || descrInput.value.length < 1)
790
+ return 'Select at least one descriptor column.';
791
+ if (desInput.value != null && descrInput.value.includes(desInput.value))
792
+ return 'Desirability column cannot be used as a descriptor.';
793
+ const zeroStdevCols = descrInput.value.filter((col) => col.stats.stdev === 0).map((col) => col.name);
794
+ if (zeroStdevCols.length > 0)
795
+ return `Descriptor columns with zero variance: ${zeroStdevCols.join(', ')}`;
796
+ const nullCols = descrInput.value.filter((col) => col.stats.missingValueCount === col.length).map((col) => col.name);
797
+ if (nullCols.length > 0)
798
+ return `Descriptor columns with only missing values: ${nullCols.join(', ')}`;
799
+ return null;
800
+ });
801
+
802
+ // Desirability column input and related controls
803
+ const setVisibilityOfDesirabilityAuxInputs = (value: DG.Column) => {
804
+ if (value.type === DG.COLUMN_TYPE.BOOL)
805
+ desOptionsInputDiv.hidden = true;
806
+ else {
807
+ desOptionsInputDiv.hidden = false;
808
+ const isString = (value.type === DG.COLUMN_TYPE.STRING);
809
+ desirabilityThresholdInput.root.hidden = isString;
810
+ signInput.root.hidden = isString;
811
+ }
812
+ }; // setVisibilityOfDesirabilityAuxInputs
813
+
814
+ const desInput = ui.input.column('Desirability', {
534
815
  nullable: false,
535
- value: this.boolCols[0].name,
536
- items: this.boolCols.map((col) => col.name),
816
+ value: initDesirability,
817
+ table: this.table,
818
+ filter: (col) => this.desirabilityColumns.includes(col),
537
819
  tooltipText: 'Desirability column.',
538
820
  onValueChanged: (value) => {
539
- if (value != null)
540
- runComputations();
541
- },
821
+ if (value != null) {
822
+ updateDesirableCategoriesInput();
823
+ setVisibilityOfDesirabilityAuxInputs(value);
824
+ areComputationsBlocked = true;
825
+ desirabilityThresholdInput.value = Math.round(value.stats.avg * 100) / 100;
826
+ areComputationsBlocked = false;
827
+ areTunedSettingsUsed = false;
828
+ checkAutoTuneAndRun();
829
+ }
830
+ }, // onValueChanged
542
831
  });
543
832
  form.append(desInput.root);
544
833
 
545
- const header = ui.h2('Thresholds');
546
- ui.tooltip.bind(header, 'Settings of the pMPO model training.');
834
+ desInput.addValidator(() => {
835
+ if (desInput.value == null)
836
+ return 'Select a desirability column.';
837
+ if (descrInput.value != null && descrInput.value.includes(desInput.value))
838
+ return 'Desirability column cannot be used as a descriptor.';
839
+ if (desInput.value.type === DG.COLUMN_TYPE.BOOL) {
840
+ if (desInput.value.stats.stdev === 0)
841
+ return 'All desirability values are the same - scoring is not feasible.';
842
+ } else if (desInput.value.type === DG.COLUMN_TYPE.STRING) {
843
+ if (desInput.value.categories.length < 2)
844
+ return 'String desirability column must have at least 2 categories.';
845
+ } else {
846
+ if (desInput.value.stats.stdev === 0) {
847
+ return desInput.value.stats.missingValueCount < desInput.value.length ?
848
+ 'All desirability values are the same - scoring is not feasible.' :
849
+ 'Empty column cannot be used as desirability column.';
850
+ }
851
+ if (desirabilityThresholdInput.value == null)
852
+ return 'Specify non-null desirability threshold.';
853
+ if (!isDesirabilityValid(desInput.value, desirabilityThresholdInput.value, signInput.value as EQUALITY_SIGN)) {
854
+ return `All compounds are either desired or non-desired for ${desInput.value.name} ` +
855
+ `${signInput.value} ${desirabilityThresholdInput.value}. Adjust the threshold or condition.`;
856
+ }
857
+ }
858
+ return null;
859
+ });
860
+
861
+ let areComputationsBlocked = false;
862
+
863
+ const signInput = ui.input.choice('Condition', {
864
+ value: EQUALITY_SIGN.DEFAULT,
865
+ items: SIGN_OPTIONS,
866
+ nullable: false,
867
+ tooltipText: 'How to compare numeric Desirability column values against the threshold.',
868
+ onValueChanged: (_value) => {
869
+ areTunedSettingsUsed = false;
870
+ checkAutoTuneAndRun();
871
+ },
872
+ });
873
+
874
+ const desirabilityThresholdInput = ui.input.float('Threshold', {
875
+ value: Math.round(initDesirability.stats.avg * 100) / 100,
876
+ nullable: false,
877
+ tooltipText: 'Boundary value that separates desired from non-desired compounds.',
878
+ format: '0.00',
879
+ onValueChanged: (value) => {
880
+ if (value != null) {
881
+ if (areComputationsBlocked)
882
+ return;
883
+ areTunedSettingsUsed = false;
884
+ checkAutoTuneAndRun();
885
+ }
886
+ },
887
+ });
888
+
889
+ desirabilityThresholdInput.addValidator(() => {
890
+ if (desInput.value == null || desInput.value.type === DG.COLUMN_TYPE.BOOL ||
891
+ desInput.value.type === DG.COLUMN_TYPE.STRING)
892
+ return null;
893
+ if (desirabilityThresholdInput.value == null)
894
+ return 'Specify non-null desirability threshold.';
895
+ if (!isDesirabilityValid(desInput.value, desirabilityThresholdInput.value, signInput.value as EQUALITY_SIGN))
896
+ return 'Adjust the threshold to get both desired and non-desired groups.';
897
+ return null;
898
+ });
899
+
900
+ const desOptionsInputDiv = ui.divV([signInput.root, desirabilityThresholdInput.root]);
901
+
902
+ form.append(desOptionsInputDiv);
903
+
904
+ let desirableCategoriesInput: DG.InputBase<string[] | null> | null = null;
905
+
906
+ // For string columns - input for selecting which categories are considered desirable
907
+ const updateDesirableCategoriesInput = () => {
908
+ if (desirableCategoriesInput != null) {
909
+ desirableCategoriesInput.root.remove();
910
+ desirableCategoriesInput = null;
911
+ }
912
+
913
+ if (desInput.value?.type === DG.COLUMN_TYPE.STRING) {
914
+ desirableCategoriesInput = ui.input.multiChoice('Preferred', {
915
+ value: getSelectedCategories(desInput.value!.categories),
916
+ items: desInput.value!.categories,
917
+ nullable: false,
918
+ tooltipText: 'Select which categories should be treated as desirable.',
919
+ onValueChanged: (value) => {
920
+ if (value != null) {
921
+ if (areComputationsBlocked)
922
+ return;
923
+ areTunedSettingsUsed = false;
924
+ checkAutoTuneAndRun();
925
+ }
926
+ },
927
+ });
928
+
929
+ desirableCategoriesInput.addValidator(() => {
930
+ if (desirableCategoriesInput!.value == null || desirableCategoriesInput!.value.length === 0)
931
+ return 'Select at least one preferable category.';
932
+ if (desInput.value != null && desirableCategoriesInput!.value.length === desInput.value.categories.length)
933
+ return 'At least one category must be non-preferable.';
934
+ return null;
935
+ });
936
+
937
+ desOptionsInputDiv.append(desirableCategoriesInput.root);
938
+ }
939
+ }; // updateDesirableCategoriesInput
940
+
941
+ setVisibilityOfDesirabilityAuxInputs(desInput.value!);
942
+
943
+ // Settings inputs
944
+
945
+ const header = ui.h2('Settings');
547
946
  form.append(header);
947
+ ui.tooltip.bind(header, 'Settings of the pMPO model.');
948
+
949
+ // use sigmoid correction
950
+ const useSigmoidInput = ui.input.bool('\u03C3 correction', {
951
+ value: USE_SIGMOID_DEFAULT,
952
+ tooltipText: 'Use the sigmoidal correction to the weighted Gaussian scores.',
953
+ onValueChanged: (_value) => {
954
+ areTunedSettingsUsed = false;
955
+ checkAutoTuneAndRun();
956
+ },
957
+ });
958
+ form.append(useSigmoidInput.root);
959
+
960
+ const toUseAutoTune = (this.table.rowCount <= AUTO_TUNE_MAX_APPLICABLE_ROWS);
961
+
962
+ // Flag indicating whether optimal parameters from auto-tuning are currently used
963
+ let areTunedSettingsUsed = false;
964
+
965
+ // Auto-tune parameters and run computations; if the tuned settings are already in use, just run computations with current settings
966
+ const setOptimalParametersAndRun = async () => {
967
+ await new Promise((resolve) => setTimeout(resolve, 50));
968
+
969
+ if (!areInputsValid())
970
+ return;
971
+
972
+ if (!areTunedSettingsUsed) {
973
+ const optimalSettings = await this.getOptimalSettings(
974
+ DG.DataFrame.fromColumns(descrInput.value).columns,
975
+ getDesirabilityColumn(),
976
+ useSigmoidInput.value,
977
+ );
978
+
979
+ if (optimalSettings.state === 'success') {
980
+ pInput.value = Math.max(optimalSettings.pValTresh, P_VAL_TRES_MIN);
981
+ rInput.value = Math.max(optimalSettings.r2Tresh, R2_MIN);
982
+ qInput.value = Math.max(optimalSettings.qCutoff, Q_CUTOFF_MIN);
983
+ areTunedSettingsUsed = true;
984
+ runComputations();
985
+ } else
986
+ grok.shell.warning(optimalSettings.msg);
987
+ /*descrInput.input.classList.add('d4-invalid');
988
+ desInput.input.classList.add('d4-invalid');
989
+ ui.tooltip.bind(descrInput.input, optimalSettings.msg);
990
+ ui.tooltip.bind(desInput.input, optimalSettings.msg);*/
991
+ } else
992
+ runComputations();
993
+ }; // setOptimalParametersAndRun
994
+
995
+ // Validates all inputs before running computations using registered validators
996
+ const areInputsValid = (): boolean => {
997
+ const results = [
998
+ descrInput.validate(),
999
+ desInput.validate(),
1000
+ desirabilityThresholdInput.validate(),
1001
+ pInput.validate(),
1002
+ rInput.validate(),
1003
+ qInput.validate(),
1004
+ ];
1005
+
1006
+ if (desirableCategoriesInput != null)
1007
+ results.push(desirableCategoriesInput.validate());
1008
+
1009
+ return results.every((r) => r);
1010
+ }; // areInputsValid
1011
+
1012
+ const checkAutoTuneAndRun = () => {
1013
+ if (autoTuneInput.value)
1014
+ setOptimalParametersAndRun();
1015
+ else
1016
+ runComputations();
1017
+ };
1018
+
1019
+ // autotuning input
1020
+ const autoTuneInput = ui.input.bool('Auto-tuning', {
1021
+ value: false,
1022
+ tooltipText: 'Automatically select optimal p-value, R², and q-cutoff by maximizing AUC.',
1023
+ onValueChanged: async (value) => {
1024
+ setEnability(!value);
1025
+
1026
+ if (areTunedSettingsUsed)
1027
+ return;
1028
+
1029
+ // If auto-tuning is turned on, set optimal parameters and run computations
1030
+ if (value)
1031
+ await setOptimalParametersAndRun();
1032
+ else
1033
+ runComputations();
1034
+ },
1035
+ });
1036
+ form.append(autoTuneInput.root);
548
1037
 
549
1038
  // p-value threshold input
550
1039
  const pInput = ui.input.float('p-value', {
551
1040
  nullable: false,
552
1041
  min: P_VAL_TRES_MIN,
553
- max: 1,
554
- step: 0.01,
555
- value: 0.05,
556
- tooltipText: 'Descriptors with p-values above this threshold are excluded.',
1042
+ max: P_VAL_TRES_MAX,
1043
+ step: 0.001,
1044
+ value: P_VAL_TRES_DEFAULT,
1045
+ // @ts-ignore
1046
+ format: FORMAT,
1047
+ tooltipText: 'P-value threshold. Descriptors with p-values above this threshold are excluded.',
557
1048
  onValueChanged: (value) => {
558
- if ((value != null) && (value >= P_VAL_TRES_MIN) && (value <= 1))
1049
+ // Prevent running computations when auto-tuning is on, since parameters will be set automatically
1050
+ if (autoTuneInput.value)
1051
+ return;
1052
+
1053
+ areTunedSettingsUsed = false;
1054
+ if ((value != null) && (value >= P_VAL_TRES_MIN) && (value <= P_VAL_TRES_MAX))
559
1055
  runComputations();
560
1056
  },
561
1057
  });
562
1058
  form.append(pInput.root);
563
1059
 
1060
+ pInput.addValidator(() => {
1061
+ if (pInput.value == null)
1062
+ return 'P-value is required.';
1063
+ if (pInput.value < P_VAL_TRES_MIN || pInput.value > P_VAL_TRES_MAX)
1064
+ return `P-value must be between ${P_VAL_TRES_MIN} and ${P_VAL_TRES_MAX}.`;
1065
+ return null;
1066
+ });
1067
+
564
1068
  // R² threshold input
565
1069
  const rInput = ui.input.float('R²', {
1070
+ // @ts-ignore
1071
+ format: FORMAT,
566
1072
  nullable: false,
567
1073
  min: R2_MIN,
568
- value: 0.5,
569
- max: 1,
1074
+ value: R2_DEFAULT,
1075
+ max: R2_MAX,
570
1076
  step: 0.01,
571
1077
  // eslint-disable-next-line max-len
572
- tooltipText: 'Descriptors with squared correlation above this threshold are considered highly correlated. Among them, the descriptor with the lower p-value is retained.',
1078
+ tooltipText: 'Squared correlation threshold. Descriptors with squared correlation above this threshold are considered highly correlated. Among them, the descriptor with the lower p-value is retained.',
573
1079
  onValueChanged: (value) => {
574
- if ((value != null) && (value >= R2_MIN) && (value <= 1))
1080
+ // Prevent running computations when auto-tuning is on, since parameters will be set automatically
1081
+ if (autoTuneInput.value)
1082
+ return;
1083
+
1084
+ areTunedSettingsUsed = false;
1085
+
1086
+ if ((value != null) && (value >= R2_MIN) && (value <= R2_MAX))
575
1087
  runComputations();
576
1088
  },
577
1089
  });
578
1090
  form.append(rInput.root);
579
1091
 
1092
+ rInput.addValidator(() => {
1093
+ if (rInput.value == null)
1094
+ return 'R² is required.';
1095
+ if (rInput.value < R2_MIN || rInput.value > R2_MAX)
1096
+ return `R² must be between ${R2_MIN} and ${R2_MAX}.`;
1097
+ return null;
1098
+ });
1099
+
580
1100
  // q-cutoff input
581
1101
  const qInput = ui.input.float('q-cutoff', {
1102
+ // @ts-ignore
1103
+ format: FORMAT,
582
1104
  nullable: false,
583
1105
  min: Q_CUTOFF_MIN,
584
- value: 0.05,
585
- max: 1,
1106
+ value: Q_CUTOFF_DEFAULT,
1107
+ max: Q_CUTOFF_MAX,
586
1108
  step: 0.01,
587
1109
  tooltipText: 'Q-cutoff for the pMPO model computation.',
588
1110
  onValueChanged: (value) => {
589
- if ((value != null) && (value >= Q_CUTOFF_MIN) && (value <= 1))
1111
+ // Prevent running computations when auto-tuning is on, since parameters will be set automatically
1112
+ if (autoTuneInput.value)
1113
+ return;
1114
+
1115
+ areTunedSettingsUsed = false;
1116
+
1117
+ if ((value != null) && (value >= Q_CUTOFF_MIN) && (value <= Q_CUTOFF_MAX))
590
1118
  runComputations();
591
1119
  },
592
1120
  });
593
1121
  form.append(qInput.root);
594
1122
 
595
- setTimeout(() => runComputations(), 10);
1123
+ qInput.addValidator(() => {
1124
+ if (qInput.value == null)
1125
+ return 'Q-cutoff is required.';
1126
+ if (qInput.value < Q_CUTOFF_MIN || qInput.value > Q_CUTOFF_MAX)
1127
+ return `Q-cutoff must be between ${Q_CUTOFF_MIN} and ${Q_CUTOFF_MAX}.`;
1128
+ return null;
1129
+ });
1130
+
1131
+ const setEnability = (toEnable: boolean) => {
1132
+ pInput.enabled = toEnable;
1133
+ rInput.enabled = toEnable;
1134
+ qInput.enabled = toEnable;
1135
+ };
1136
+
1137
+ setTimeout(() => {
1138
+ runComputations();
1139
+
1140
+ if (toUseAutoTune)
1141
+ autoTuneInput.value = true; // this will trigger setting optimal parameters and running computations
1142
+ else
1143
+ runComputations();
1144
+ }, 10);
596
1145
 
597
1146
  // Save model button
598
- const saveBtn = ui.button('Save model', async () => {
1147
+ const saveBtn = ui.button('Save', async () => {
599
1148
  if (this.params == null) {
600
1149
  grok.shell.warning('Failed to save pMPO model: null parameters.');
601
1150
  return;
602
1151
  }
603
1152
 
604
- saveModel(this.params, this.table.name);
1153
+ saveModel(this.params, this.table.name, useSigmoidInput.value);
605
1154
  }, 'Save model as platform file.');
606
- form.append(saveBtn);
1155
+
1156
+ if (addBtn)
1157
+ form.append(saveBtn);
607
1158
 
608
1159
  const div = ui.div([form]);
609
1160
  div.classList.add('eda-pmpo-input-form');
610
1161
 
611
- return div;
1162
+ return {
1163
+ form: div,
1164
+ saveBtn: saveBtn,
1165
+ };
612
1166
  } // getInputForm
613
1167
 
614
- /** Retrieves boolean columns from the data frame */
615
- private getBoolCols(): DG.Column[] {
1168
+ /** Validates all pMPO inputs and returns structured errors without mutating the DOM */
1169
+ static validateInputs(params: {
1170
+ descriptors: DG.Column[] | null,
1171
+ desirability: DG.Column | null,
1172
+ threshold: number | null,
1173
+ sign: EQUALITY_SIGN,
1174
+ desirableCategories: string[] | null,
1175
+ pValue: number | null,
1176
+ r2: number | null,
1177
+ qCutoff: number | null,
1178
+ }): PmpoValidationResult {
1179
+ const errors = new Map<PmpoInputId, TooltipContent>();
1180
+ const {descriptors, desirability, threshold, sign, desirableCategories, pValue, r2, qCutoff} = params;
1181
+
1182
+ // Settings null or out of range
1183
+ if (pValue == null || r2 == null || qCutoff == null)
1184
+ return {valid: false, errors};
1185
+
1186
+ if ((pValue <= 0) || (pValue > 1) || (r2 < 0) || (r2 > 1) || (qCutoff <= 0) || (qCutoff > 1))
1187
+ return {valid: false, errors};
1188
+
1189
+ // Column inputs null
1190
+ if (descriptors == null || desirability == null)
1191
+ return {valid: false, errors};
1192
+
1193
+ // At least one descriptor
1194
+ if (descriptors.length < 1) {
1195
+ errors.set('descriptors', 'Select at least one descriptor column.');
1196
+ return {valid: false, errors};
1197
+ }
1198
+
1199
+ // Desirability column must not be among descriptors
1200
+ if (descriptors.includes(desirability)) {
1201
+ const msg = 'Desirability column cannot be used as a descriptor.';
1202
+ errors.set('descriptors', msg);
1203
+ errors.set('desirability', msg);
1204
+ return {valid: false, errors};
1205
+ }
1206
+
1207
+ // No zero-variance descriptor columns
1208
+ const zeroStdevCols = descriptors.filter((col) => col.stats.stdev === 0).map((col) => col.name);
1209
+ if (zeroStdevCols.length > 0)
1210
+ errors.set('descriptors', () => ui.markdown(`Descriptor columns with zero variance cannot be used: **${zeroStdevCols.join(', ')}**`));
1211
+
1212
+ // No all-null descriptor columns
1213
+ const nullCols = descriptors.filter((col) => col.stats.missingValueCount === col.length).map((col) => col.name);
1214
+ if (nullCols.length > 0)
1215
+ errors.set('descriptors', () => ui.markdown(`Descriptor columns with only missing values cannot be used: **${nullCols.join(', ')}**`));
1216
+
1217
+ // Validate desirability column based on its type
1218
+ if (desirability.type === DG.COLUMN_TYPE.BOOL) {
1219
+ if (desirability.stats.stdev === 0)
1220
+ errors.set('desirability', 'All desirability values are the same - scoring is not feasible.');
1221
+ } else if (desirability.type === DG.COLUMN_TYPE.STRING) {
1222
+ const catsCount = desirability.categories.length;
1223
+ const selectedCatsCount = desirableCategories?.length ?? 0;
1224
+
1225
+ if (catsCount < 2)
1226
+ errors.set('desirability', 'String desirability column must have at least 2 categories.');
1227
+ else if (selectedCatsCount === 0)
1228
+ errors.set('desirability', 'Select at least one preferable category.');
1229
+ else if (selectedCatsCount === catsCount)
1230
+ errors.set('desirability', 'At least one category must be non-preferable.');
1231
+ } else {
1232
+ // Numeric desirability
1233
+ if (desirability.stats.stdev === 0) {
1234
+ errors.set('desirability',
1235
+ desirability.stats.missingValueCount < desirability.length ?
1236
+ 'All desirability values are the same - scoring is not feasible.' :
1237
+ 'Empty column cannot be used as desirability column.',
1238
+ );
1239
+ } else if (threshold == null)
1240
+ errors.set('desirability', 'Specify non-null desirability threshold.');
1241
+ else if (!isDesirabilityValid(desirability, threshold, sign)) {
1242
+ errors.set('desirability', () => ui.markdown(`All compounds are either desired or non-desired for
1243
+ <div align="center">
1244
+ **${desirability.name} ${sign} ${threshold}.**
1245
+ </div>
1246
+ Adjust the threshold or condition to get both groups.`));
1247
+ errors.set('threshold', 'Adjust the threshold to get both desired and non-desired groups.');
1248
+ }
1249
+ }
1250
+
1251
+ return {valid: !errors.size, errors};
1252
+ } // validateInputs
1253
+
1254
+ /** Retrieves candidate desirability columns (boolean, numerical or string) from the data frame */
1255
+ private getDesirabilityColumns(): DG.Column[] {
616
1256
  const res: DG.Column[] = [];
617
1257
 
618
1258
  for (const col of this.table.columns) {
619
- if ((col.type === DG.COLUMN_TYPE.BOOL) && (col.stats.stdev > 0))
1259
+ if (((col.type === DG.COLUMN_TYPE.BOOL) || (col.isNumerical) || (col.type === DG.COLUMN_TYPE.STRING)))
620
1260
  res.push(col);
621
1261
  }
622
1262
 
623
1263
  return res;
624
- } // getBoolCols
1264
+ } // getDesirabilityColumns
625
1265
 
626
1266
  /** Retrieves numerical columns from the data frame */
627
1267
  private getValidNumericCols(): DG.Column[] {
628
1268
  const res: DG.Column[] = [];
629
1269
 
630
1270
  for (const col of this.table.columns) {
631
- if ((col.isNumerical) && (col.stats.missingValueCount < 1) && (col.stats.stdev > 0))
1271
+ if (col.isNumerical)
632
1272
  res.push(col);
633
1273
  }
634
1274
 
635
1275
  return res;
636
1276
  } // getValidNumericCols
1277
+
1278
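+ /** Auto-tunes the R² threshold and q-cutoff by minimizing 1 - AUC, keeping the p-value threshold at its default; returns the optimal settings, or a 'failed'/'canceled' state with a message */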
1279
+ private async getOptimalSettings(descriptors: DG.ColumnList, desirability: DG.Column, useSigmoid: boolean): Promise<OptimalPoint> {
1280
+ const pi = DG.TaskBarProgressIndicator.create('Optimizing... ', {cancelable: true});
1281
+
1282
+ try {
1283
+ const descriptorNames = descriptors.names();
1284
+ const {desired, nonDesired} = getDesiredTables(this.table, desirability);
1285
+
1286
+ // Compute descriptors' statistics
1287
+ const descrStats = new Map<string, DescriptorStatistics>();
1288
+ descriptorNames.forEach((name) => {
1289
+ descrStats.set(name, getDescriptorStatistics(desired.col(name)!, nonDesired.col(name)!));
1290
+ });
1291
+ const descrStatsTable = getDescriptorStatisticsTable(descrStats);
1292
+
1293
+ // Filter by p-value
1294
+ const selectedByPvalue = getFilteredByPvalue(descrStatsTable, P_VAL_TRES_DEFAULT);
1295
+ if (selectedByPvalue.length < 1) {
1296
+ pi.close();
1297
+
1298
+ return {
1299
+ pValTresh: 0,
1300
+ r2Tresh: 0,
1301
+ qCutoff: 0,
1302
+ state: 'failed',
1303
+ msg: 'No descriptors passed the p-value threshold filter.',
1304
+ };
1305
+ }
1306
+
1307
+ const correlationTriples = getCorrelationTriples(descriptors, selectedByPvalue);
1308
+
1309
+ const funcToBeMinimized = (point: Float32Array) => {
1310
+ // Filter by correlations
1311
+ const selectedByCorr = getFilteredByCorrelations(descriptors, selectedByPvalue, descrStats, point[0], correlationTriples);
1312
+
1313
+ // Compute pMPO parameters - training
1314
+ const params = getModelParams(desired, nonDesired, selectedByCorr, point[1]);
1315
+
1316
+ // Get predictions
1317
+ const prediction = Pmpo.predict(this.table, params, useSigmoid, this.predictionName);
1318
+
1319
+ // Evaluate predictions and return 1 - AUC (since optimization minimizes the function, but we want to maximize AUC)
1320
+ return 1 - getPmpoEvaluation(desirability, prediction).auc;
1321
+ }; // funcToBeMinimized
1322
+
1323
+ const optimalResult = await optimizeNM(
1324
+ pi,
1325
+ funcToBeMinimized,
1326
+ new Float32Array([R2_DEFAULT, Q_CUTOFF_DEFAULT]),
1327
+ DEFAULT_OPTIMIZATION_SETTINGS,
1328
+ LOW_PARAMS_BOUNDS,
1329
+ HIGH_PARAMS_BOUNDS,
1330
+ );
1331
+
1332
+ const success = !pi.canceled;
1333
+ pi.close();
1334
+
1335
+ if (success) {
1336
+ return {
1337
+ pValTresh: P_VAL_TRES_DEFAULT,
1338
+ r2Tresh: optimalResult.optimalPoint[0],
1339
+ qCutoff: optimalResult.optimalPoint[1],
1340
+ state: 'success',
1341
+ msg: 'Optimization completed successfully.',
1342
+ };
1343
+ } else {
1344
+ return {
1345
+ pValTresh: 0,
1346
+ r2Tresh: 0,
1347
+ qCutoff: 0,
1348
+ state: 'canceled',
1349
+ msg: 'Auto-tuning was canceled by the user.',
1350
+ };
1351
+ }
1352
+ } catch (err) {
1353
+ pi.close();
1354
+
1355
+ return {
1356
+ pValTresh: 0,
1357
+ r2Tresh: 0,
1358
+ qCutoff: 0,
1359
+ state: 'failed',
1360
+ msg: err instanceof Error ? err.message : 'Optimization failed due to an unexpected error.',
1361
+ };
1362
+ }
1363
+ } // getOptimalSettings
637
1364
  }; // Pmpo