@datagrok/eda 1.4.13 → 1.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87) hide show
  1. package/CHANGELOG.md +7 -5
  2. package/dist/111.js +1 -1
  3. package/dist/111.js.map +1 -1
  4. package/dist/128.js +1 -1
  5. package/dist/128.js.map +1 -1
  6. package/dist/153.js +1 -1
  7. package/dist/153.js.map +1 -1
  8. package/dist/23.js +1 -1
  9. package/dist/23.js.map +1 -1
  10. package/dist/234.js +1 -1
  11. package/dist/234.js.map +1 -1
  12. package/dist/242.js +1 -1
  13. package/dist/242.js.map +1 -1
  14. package/dist/260.js +1 -1
  15. package/dist/260.js.map +1 -1
  16. package/dist/33.js +1 -1
  17. package/dist/33.js.map +1 -1
  18. package/dist/348.js +1 -1
  19. package/dist/348.js.map +1 -1
  20. package/dist/377.js +1 -1
  21. package/dist/377.js.map +1 -1
  22. package/dist/397.js +2 -0
  23. package/dist/397.js.map +1 -0
  24. package/dist/412.js +1 -1
  25. package/dist/412.js.map +1 -1
  26. package/dist/415.js +1 -1
  27. package/dist/415.js.map +1 -1
  28. package/dist/501.js +1 -1
  29. package/dist/501.js.map +1 -1
  30. package/dist/531.js +1 -1
  31. package/dist/531.js.map +1 -1
  32. package/dist/583.js +1 -1
  33. package/dist/583.js.map +1 -1
  34. package/dist/589.js +1 -1
  35. package/dist/589.js.map +1 -1
  36. package/dist/603.js +1 -1
  37. package/dist/603.js.map +1 -1
  38. package/dist/656.js +1 -1
  39. package/dist/656.js.map +1 -1
  40. package/dist/682.js +1 -1
  41. package/dist/682.js.map +1 -1
  42. package/dist/705.js +1 -1
  43. package/dist/705.js.map +1 -1
  44. package/dist/727.js +1 -1
  45. package/dist/727.js.map +1 -1
  46. package/dist/731.js +1 -1
  47. package/dist/731.js.map +1 -1
  48. package/dist/738.js +1 -1
  49. package/dist/738.js.map +1 -1
  50. package/dist/763.js +1 -1
  51. package/dist/763.js.map +1 -1
  52. package/dist/778.js +1 -1
  53. package/dist/778.js.map +1 -1
  54. package/dist/783.js +1 -1
  55. package/dist/783.js.map +1 -1
  56. package/dist/793.js +1 -1
  57. package/dist/793.js.map +1 -1
  58. package/dist/810.js +1 -1
  59. package/dist/810.js.map +1 -1
  60. package/dist/860.js +1 -1
  61. package/dist/860.js.map +1 -1
  62. package/dist/907.js +1 -1
  63. package/dist/907.js.map +1 -1
  64. package/dist/950.js +1 -1
  65. package/dist/950.js.map +1 -1
  66. package/dist/980.js +1 -1
  67. package/dist/980.js.map +1 -1
  68. package/dist/990.js +1 -1
  69. package/dist/990.js.map +1 -1
  70. package/dist/package-test.js +1 -1
  71. package/dist/package-test.js.map +1 -1
  72. package/dist/package.js +1 -1
  73. package/dist/package.js.map +1 -1
  74. package/package.json +5 -5
  75. package/src/package.ts +2 -1
  76. package/src/pareto-optimization/pareto-optimizer.ts +1 -1
  77. package/src/pls/pls-constants.ts +3 -1
  78. package/src/pls/pls-tools.ts +73 -69
  79. package/src/probabilistic-scoring/data-generator.ts +48 -3
  80. package/src/probabilistic-scoring/pmpo-defs.ts +30 -2
  81. package/src/probabilistic-scoring/pmpo-utils.ts +143 -52
  82. package/src/probabilistic-scoring/prob-scoring.ts +475 -102
  83. package/src/probabilistic-scoring/stat-tools.ts +1 -1
  84. package/src/tests/pareto-tests.ts +13 -15
  85. package/src/tests/pmpo-tests.ts +643 -3
  86. package/test-console-output-1.log +221 -93
  87. package/test-record-1.mp4 +0 -0
package/package.json CHANGED
@@ -1,18 +1,18 @@
1
1
  {
2
2
  "name": "@datagrok/eda",
3
3
  "friendlyName": "EDA",
4
- "version": "1.4.13",
4
+ "version": "1.5.0",
5
5
  "description": "Exploratory Data Analysis Tools",
6
6
  "dependencies": {
7
7
  "@datagrok-libraries/math": "^1.2.6",
8
- "@datagrok-libraries/ml": "^6.10.8",
9
- "@datagrok-libraries/statistics": "^1.10.0",
8
+ "@datagrok-libraries/ml": "^6.10.10",
9
+ "@datagrok-libraries/statistics": "^1.12.0",
10
10
  "@datagrok-libraries/tutorials": "^1.7.4",
11
- "@datagrok-libraries/utils": "^4.6.5",
11
+ "@datagrok-libraries/utils": "^4.7.0",
12
12
  "@keckelt/tsne": "^1.0.2",
13
13
  "@webgpu/types": "^0.1.40",
14
14
  "cash-dom": "^8.1.1",
15
- "datagrok-api": "^1.26.3",
15
+ "datagrok-api": "^1.27.0",
16
16
  "dayjs": "^1.11.9",
17
17
  "jstat": "^1.9.6",
18
18
  "mathjs": "^15.1.0",
package/src/package.ts CHANGED
@@ -1027,8 +1027,9 @@ export class PackageFunctions {
1027
1027
  'outputs': [{name: 'Synthetic', type: 'dataframe'}],
1028
1028
  })
1029
1029
  static async generatePmpoDataset(@grok.decorators.param({'type': 'int'}) samples: number): Promise<DG.DataFrame> {
1030
- const df = await getSynteticPmpoData(samples);
1030
+ const df = await getSynteticPmpoData(samples, false);
1031
1031
  df.name = 'Synthetic';
1032
1032
  return df;
1033
1033
  }
1034
+
1034
1035
  }
@@ -19,7 +19,7 @@ export class ParetoOptimizer {
19
19
  private toUpdatePcCols = false;
20
20
  private paretoFrontViewer: DG.Viewer;
21
21
  private resultColName: string;
22
- private intervalId: NodeJS.Timeout | null = null;
22
+ private intervalId: ReturnType<typeof setInterval> | null = null;
23
23
  private inputsMap = new Map<string, DG.InputBase>();
24
24
  private pcPlotNode: DG.DockNode | null = null;
25
25
  private inputFormNode: DG.DockNode | null = null;
@@ -17,8 +17,10 @@ export enum ERROR_MSG {
17
17
  ENOUGH = 'Not enough of features',
18
18
  COMP_LIN_PLS = 'Components count must be less than the number of features',
19
19
  COMP_QUA_PLS = 'Too large components count for the quadratic PLS regression',
20
- COMPONENTS = 'Components count must be greater than 1',
20
+ COMP_ROWS = 'Components count must not exceed the number of rows',
21
+ COMPONENTS = 'Components count must be at least 1',
21
22
  INV_INP = 'Invalid inputs',
23
+ NULL_COMPS = 'Components count is not specified',
22
24
  }
23
25
 
24
26
  /** Widget titles */
@@ -36,6 +36,24 @@ export type PlsInput = {
36
36
 
37
37
  type TypedArray = Int32Array | Float32Array | Uint32Array | Float64Array;
38
38
 
39
/** Styles an input element according to the validity of its value and binds a tooltip to it.
 * @param valid Whether the current value of the input is valid.
 * @param element Element to be styled.
 * @param tooltip Hint text; it is the whole tooltip when the value is valid.
 * @param errorMsg Error text shown below the hint when the value is invalid.
 */
function setStyle(valid: boolean, element: HTMLElement, tooltip: string, errorMsg: string) {
  if (!valid) {
    element.style.color = COLOR.INVALID;
    element.style.borderBottomColor = COLOR.INVALID;

    // Tooltip content: the regular hint with the error message (in the invalid color) below it
    ui.tooltip.bind(element, () => {
      const hintLabel = ui.label(tooltip);
      const errorLabel = ui.label(errorMsg);
      errorLabel.style.color = COLOR.INVALID;

      return ui.divV([hintLabel, errorLabel]);
    });

    return;
  }

  element.style.color = COLOR.VALID_TEXT;
  element.style.borderBottomColor = COLOR.VALID_LINE;
  ui.tooltip.bind(element, tooltip);
}
56
+
39
57
  /** Return lines */
40
58
  export function getLines(names: string[]): DG.FormulaLine[] {
41
59
  const lines: DG.FormulaLine[] = [];
@@ -137,7 +155,7 @@ function getQuadraticPlsInput(input: PlsInput): PlsInput {
137
155
 
138
156
  for (let j = i; j < colsCount; ++j) {
139
157
  col2 = cols[j];
140
- raw2 = col2.getRawData();
158
+ raw2 = col2.getRawData();
141
159
  qaudrRaw = new Float32Array(rowsCount);
142
160
 
143
161
  for (let k = 0; k < rowsCount; ++k)
@@ -268,7 +286,7 @@ async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<
268
286
  });
269
287
 
270
288
 
271
- // 4.3) create lines & circles
289
+ // 4.3) create lines & circles
272
290
  view.addViewer(scoresScatter);
273
291
  scoresScatter.meta.formulaLines.addAll(getLines(scoreNames));
274
292
 
@@ -321,7 +339,7 @@ async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<
321
339
  }));
322
340
 
323
341
  // emphasize viewers in the demo case
324
- if (analysisType === PLS_ANALYSIS.DEMO) {
342
+ if (analysisType === PLS_ANALYSIS.DEMO) {
325
343
  grok.shell.windows.help.showHelp(ui.markdown(DEMO_RESULTS_MD));
326
344
 
327
345
  describeElements(
@@ -372,38 +390,32 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
372
390
  return;
373
391
  }
374
392
 
375
- let features: DG.Column[] = numCols.slice(0, numCols.length - 1);
376
- let predict = numCols[numCols.length - 1];
377
- let components = min(numColNames.length - 1, COMPONENTS.DEFAULT as number);
378
- let isQuadratic = false;
379
-
380
- const isPredictValid = () => {
381
- for (const col of features)
382
- if (col.name === predict.name)
383
- return false;
384
- return true;
393
+ const doFeaturesIncludePredict = () => {
394
+ return featuresInput.value.some((col) => col.name === predictInput.value!.name);
385
395
  };
386
396
 
387
397
  const isCompConsistent = () => {
388
- if (components < 1)
398
+ if (componentsInput.value! < 1)
389
399
  return false;
390
400
 
391
- const n = features.length;
401
+ if (componentsInput.value! > table.rowCount)
402
+ return false;
392
403
 
393
- if (isQuadratic)
394
- return components <= (n + 1) * n / 2 + n;
404
+ const n = featuresInput.value.length;
395
405
 
396
- return components <= n;
397
- }
406
+ if (isQuadraticInput.value)
407
+ return componentsInput.value! <= (n + 1) * n / 2 + n;
408
+
409
+ return componentsInput.value! <= n;
410
+ };
398
411
 
399
412
  // response (to predict)
400
413
  const predictInput = ui.input.column(TITLE.PREDICT, {
401
414
  table: table,
402
- value: predict,
415
+ value: numCols[numCols.length - 1],
403
416
  nullable: false,
404
- onValueChanged: (value) => {
405
- predict = value;
406
- updateIputs();
417
+ onValueChanged: (_) => {
418
+ updateInputStyles();
407
419
  },
408
420
  filter: (col: DG.Column) => isValidNumeric(col),
409
421
  tooltipText: HINT.PREDICT,
@@ -413,21 +425,21 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
413
425
  const featuresInput = ui.input.columns(TITLE.USING, {
414
426
  table: table,
415
427
  available: numColNames,
416
- value: features,
417
- onValueChanged: (val) => {
418
- features = val;
419
- updateIputs();
428
+ value: numCols.slice(0, numCols.length - 1),
429
+ onValueChanged: (_) => {
430
+ updateInputStyles();
420
431
  },
421
432
  tooltipText: HINT.FEATURES,
433
+ nullable: false,
422
434
  });
423
435
 
424
436
  // components count
425
437
  const componentsInput = ui.input.int(TITLE.COMPONENTS, {
426
- value: components,
438
+ value: min(numColNames.length - 1, COMPONENTS.DEFAULT as number),
427
439
  showPlusMinus: true,
428
- onValueChanged: (val) => {
429
- components = val;
430
- updateIputs();
440
+ nullable: false,
441
+ onValueChanged: (_) => {
442
+ updateInputStyles();
431
443
  },
432
444
  tooltipText: HINT.COMPONENTS,
433
445
  });
@@ -446,28 +458,14 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
446
458
  dlgRunBtnTooltip = HINT.MVA;
447
459
  }
448
460
 
449
- const setStyle = (valid: boolean, element: HTMLElement, tooltip: string, errorMsg: string) => {
450
- if (valid) {
451
- element.style.color = COLOR.VALID_TEXT;
452
- element.style.borderBottomColor = COLOR.VALID_LINE;
453
- ui.tooltip.bind(element, tooltip);
454
- } else {
455
- element.style.color = COLOR.INVALID;
456
- element.style.borderBottomColor = COLOR.INVALID;
457
- ui.tooltip.bind(element, () => {
458
- const hint = ui.label(tooltip);
459
- const err = ui.label(errorMsg);
460
- err.style.color = COLOR.INVALID;
461
- return ui.divV([hint, err]);
462
- });
463
- }
464
- };
465
-
466
- const updateIputs = () => {
467
- const predValid = isPredictValid();
461
+ const updateInputStyles = () => {
462
+ const featuresValid = featuresInput.value.length >= 1;
463
+ const predValid = featuresValid && !doFeaturesIncludePredict();
468
464
  let compValid: boolean;
469
465
 
470
- if (predValid) {
466
+ if (!featuresValid)
467
+ setStyle(false, featuresInput.input, HINT.FEATURES, ERROR_MSG.ENOUGH);
468
+ else if (predValid) {
471
469
  setStyle(true, predictInput.input, HINT.PREDICT, '');
472
470
  setStyle(true, featuresInput.input, HINT.FEATURES, '');
473
471
  } else {
@@ -475,9 +473,12 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
475
473
  setStyle(false, featuresInput.input, HINT.FEATURES, ERROR_MSG.PREDICT);
476
474
  }
477
475
 
478
- if (components < 1) {
476
+ if (componentsInput.value == null) {
477
+ setStyle(false, componentsInput.input, HINT.COMPONENTS, ERROR_MSG.NULL_COMPS);
478
+ compValid = false;
479
+ } else if (componentsInput.value < 1) {
479
480
  setStyle(false, componentsInput.input, HINT.COMPONENTS, ERROR_MSG.COMPONENTS);
480
- compValid = false;
481
+ compValid = false;
481
482
  } else {
482
483
  compValid = isCompConsistent();
483
484
 
@@ -486,7 +487,9 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
486
487
  if (predValid)
487
488
  setStyle(true, featuresInput.input, HINT.FEATURES, '');
488
489
  } else {
489
- const errMsg = isQuadratic ? ERROR_MSG.COMP_QUA_PLS : ERROR_MSG.COMP_LIN_PLS;
490
+ const errMsg = componentsInput.value! > table.rowCount ?
491
+ ERROR_MSG.COMP_ROWS :
492
+ isQuadraticInput.value ? ERROR_MSG.COMP_QUA_PLS : ERROR_MSG.COMP_LIN_PLS;
490
493
  setStyle(false, componentsInput.input, HINT.COMPONENTS, errMsg);
491
494
  setStyle(false, featuresInput.input, HINT.FEATURES, ERROR_MSG.ENOUGH);
492
495
  }
@@ -497,10 +500,18 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
497
500
  dlg.getButton(TITLE.RUN).disabled = !isValid;
498
501
 
499
502
  return isValid;
503
+ }; // updateInputStyles
504
+
505
+ const getStrColWithUniqueVals = () => {
506
+ for (const col of strCols) {
507
+ if (col.stats.uniqueCount === table.rowCount)
508
+ return col;
509
+ }
510
+ return undefined;
500
511
  };
501
512
 
502
513
  // names of samples
503
- let names = (strCols.length > 0) ? strCols[0] : undefined;
514
+ let names = getStrColWithUniqueVals();
504
515
  const namesInputs = ui.input.column(TITLE.NAMES, {
505
516
  table: table,
506
517
  value: names,
@@ -512,11 +523,10 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
512
523
 
513
524
  // quadratic/linear model
514
525
  const isQuadraticInput = ui.input.bool(TITLE.QUADRATIC, {
515
- value: isQuadratic,
526
+ value: false,
516
527
  tooltipText: HINT.QUADRATIC,
517
- onValueChanged: (val) => {
518
- isQuadratic = val;
519
- updateIputs();
528
+ onValueChanged: (_) => {
529
+ updateInputStyles();
520
530
  },
521
531
  });
522
532
 
@@ -527,21 +537,15 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
527
537
 
528
538
  await performMVA({
529
539
  table: table,
530
- features: DG.DataFrame.fromColumns(features).columns,
531
- predict: predict,
532
- components: components,
533
- isQuadratic: isQuadratic,
540
+ features: DG.DataFrame.fromColumns(featuresInput.value).columns,
541
+ predict: predictInput.value!,
542
+ components: componentsInput.value!,
543
+ isQuadratic: isQuadraticInput.value,
534
544
  names: names,
535
545
  }, analysisType);
536
546
  }, undefined, dlgRunBtnTooltip)
537
547
  .show({x: X_COORD, y: Y_COORD});
538
548
 
539
- // the following delay provides correct styles (see https://reddata.atlassian.net/browse/GROK-15196)
540
- setTimeout(() => {
541
- featuresInput.value = numCols.filter((col) => col !== predict);
542
- features = featuresInput.value;
543
- }, TIMEOUT);
544
-
545
549
  grok.shell.v.append(dlg.root);
546
550
  } // runMVA
547
551
 
@@ -11,12 +11,57 @@ import * as jStat from 'jstat';
11
11
  /** Generates synthetic data for pMPO model training and testing
12
12
  * @param samplesCount Number of samples to generate
13
13
  * @returns DataFrame with generated data */
14
- export async function getSynteticPmpoData(samplesCount: number): Promise<DG.DataFrame> {
14
+ export async function getSynteticPmpoData(samplesCount: number, isTest: boolean = true): Promise<DG.DataFrame> {
15
15
  const df = await grok.dapi.files.readCsv(SOURCE_PATH);
16
16
  const generator = new PmpoDataGenerator(df, 'Drug', 'CNS', 'Smiles');
17
+ const genTable = generator.getGenerated(samplesCount);
18
+
19
+ if (!isTest) {
20
+ genTable.columns.add(DG.Column.fromList(DG.COLUMN_TYPE.BOOL, 'Const bool', new Array(samplesCount).fill(true)));
21
+ genTable.columns.add(DG.Column.fromInt32Array('Const int', new Int32Array(samplesCount).fill(1)));
22
+
23
+ // Add a copy of the first numeric column with 5 missing values
24
+ const firstNumCol = genTable.columns.toList().find((col) => col.isNumerical);
25
+ if (firstNumCol) {
26
+ const colWithMissing = firstNumCol.clone();
27
+ colWithMissing.name = `${firstNumCol.name} (missing)`;
28
+ for (let i = 0; i < Math.min(5, colWithMissing.length); ++i)
29
+ colWithMissing.set(i, DG.FLOAT_NULL);
30
+ genTable.columns.add(colWithMissing);
31
+ }
32
+
33
+ // Add a column with all null values
34
+ genTable.columns.add(DG.Column.fromFloat32Array('Nulls', new Float32Array(samplesCount).fill(DG.FLOAT_NULL)));
35
+
36
+ // Add categorical columns
37
+ const categoricalCols = getCategoricalColumns(genTable.col('CNS')!, samplesCount);
38
+ for (const col of categoricalCols)
39
+ genTable.columns.add(col);
40
+ }
41
+
42
+ return genTable;
43
+ } // getSynteticPmpoData
44
+
45
+ /** Generates categorical columns based on a boolean source column
46
+ * @param sourceBoolCol Source boolean column to base the categorical columns on
47
+ * @param samplesCount Number of samples to generate
48
+ * @returns Array of generated categorical columns */
49
+ function getCategoricalColumns(sourceBoolCol: DG.Column, samplesCount: number): DG.Column[] {
50
+ const source = sourceBoolCol.toList();
51
+ const stringLabels = new Array<string>(samplesCount);
52
+ const threeCats = new Array<string>(samplesCount);
53
+
54
+ for (let i = 0; i < samplesCount; ++i) {
55
+ stringLabels[i] = source[i] ? 'active' : 'non-active';
56
+ threeCats[i] = source[i] ? (Math.random() < 0.5 ? 'perfect' : 'good') : (Math.random() < 0.5 ? 'bad' : 'worst');
57
+ }
17
58
 
18
- return generator.getGenerated(samplesCount);
19
- }
59
+ return [
60
+ DG.Column.fromList(DG.COLUMN_TYPE.STRING, 'CNS (strings)', stringLabels),
61
+ DG.Column.fromList(DG.COLUMN_TYPE.STRING, 'CNS (4 categories)', threeCats),
62
+ DG.Column.fromList(DG.COLUMN_TYPE.STRING, 'Single category', new Array<string>(samplesCount).fill('Unknown')),
63
+ ];
64
+ } // getCategoricalColumns
20
65
 
21
66
  /** Class for generating synthetic data for pMPO model training and testing */
22
67
  export class PmpoDataGenerator {
@@ -1,5 +1,5 @@
1
1
  // Constants and type definitions for probabilistic scoring (pMPO)
2
- // Link: https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/
2
+ // Source paper https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/
3
3
 
4
4
  /** Minimum number of samples required to compute pMPO */
5
5
  export const MIN_SAMPLES_COUNT = 10;
@@ -208,7 +208,8 @@ export type OptimalPoint = {
208
208
  pValTresh: number,
209
209
  r2Tresh: number,
210
210
  qCutoff: number,
211
- success: boolean,
211
+ state: 'success' | 'canceled' | 'failed',
212
+ msg: string,
212
213
  };
213
214
 
214
215
  /** Minimum bounds for pMPO parameters during optimization */
@@ -216,3 +217,30 @@ export const LOW_PARAMS_BOUNDS = new Float32Array([0.5, Q_CUTOFF_MIN]);
216
217
 
217
218
  /** Maximum bounds for pMPO parameters during optimization */
218
219
  export const HIGH_PARAMS_BOUNDS = new Float32Array([R2_MAX, Q_CUTOFF_MAX]);
220
+
221
/** Comparison signs used to compare a value with a threshold.
 * `DEFAULT` is an alias of `LESS_OR_EQUAL` (both are '≤'). */
export enum EQUALITY_SIGN {
  GREATER = '>',
  LESS = '<',
  GREATER_OR_EQUAL = '≥',
  LESS_OR_EQUAL = '≤',
  DEFAULT = LESS_OR_EQUAL,
}

/** Distinct comparison signs, in display order (the `DEFAULT` alias is not repeated) */
export const SIGN_OPTIONS: EQUALITY_SIGN[] = [
  EQUALITY_SIGN.GREATER,
  EQUALITY_SIGN.LESS,
  EQUALITY_SIGN.GREATER_OR_EQUAL,
  EQUALITY_SIGN.LESS_OR_EQUAL,
];

/** Column name 'Desirability'.
 * NOTE(review): presumably the name of the boolean column obtained by thresholding - confirm against usage. */
export const THRESHOLDED_DESIRABILITY_COL_NAME = 'Desirability';

/** Category labels treated as desirable when categories are pre-selected (matched as exact, lower-case strings) */
export const PREFERABLE_CATEGORIES = ['perfect', 'good', 'true', 't', 'g', 'active', 'a', 'yes', 'y'];

/** Identifiers of the inputs that validation errors are attached to */
export type PmpoInputId = 'descriptors' | 'desirability' | 'threshold' | 'categories';

/** Tooltip content: plain text or a function producing an element */
export type TooltipContent = string | (() => HTMLElement);

/** Result of inputs validation */
export interface PmpoValidationResult {
  /** Overall validity flag */
  valid: boolean;
  /** Error tooltips keyed by the id of the input they belong to */
  errors: Map<PmpoInputId, TooltipContent>;
}
@@ -1,16 +1,18 @@
1
1
  // Utility functions for probabilistic scoring (pMPO)
2
- // Link: https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/
2
+ // Source paper https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/
3
3
 
4
4
  import * as grok from 'datagrok-api/grok';
5
5
  import * as ui from 'datagrok-api/ui';
6
6
  import * as DG from 'datagrok-api/dg';
7
7
 
8
+ import {generateMpoFileName, MPO_PROFILE_CHANGED_EVENT} from '@datagrok-libraries/statistics/src/mpo/utils';
8
9
  import '../../css/pmpo.css';
9
10
 
10
11
  import {COLORS, DESCR_TABLE_TITLE, DESCR_TITLE, DescriptorStatistics, DesirabilityProfileProperties,
11
12
  DESIRABILITY_COL_NAME, FOLDER, P_VAL, PMPO_COMPUTE_FAILED, PmpoParams, SCORES_TITLE,
12
13
  SELECTED_TITLE, STAT_TO_TITLE_MAP, TINY, WEIGHT_TITLE, CorrelationTriple,
13
- BASIC_RANGE_SIGMA_COEFFS, EXTENDED_RANGE_SIGMA_COEFFS} from './pmpo-defs';
14
+ BASIC_RANGE_SIGMA_COEFFS, EXTENDED_RANGE_SIGMA_COEFFS, EQUALITY_SIGN,
15
+ PREFERABLE_CATEGORIES} from './pmpo-defs';
14
16
  import {computeSigmoidParamsFromX0, getCutoffs, gaussDesirabilityFunc, sigmoidS,
15
17
  solveNormalIntersection} from './stat-tools';
16
18
  import {getColorScaleDiv} from '../pareto-optimization/utils';
@@ -369,38 +371,13 @@ export function getDesirabilityProfileJson(params: Map<string, PmpoParams>, useS
369
371
  */
370
372
  export async function saveModel(params: Map<string, PmpoParams>, modelName: string,
371
373
  useSigmoidalCorrection: boolean): Promise<void> {
372
- let fileName = modelName;
373
- const nameInput = ui.input.string('File', {
374
- value: fileName,
375
- nullable: false,
376
- onValueChanged: (val) => {
377
- fileName = val;
378
- dlg.getButton('Save').disabled = (fileName.length < 1) || (folderName.length < 1);
379
- },
380
- });
381
-
382
- let folderName = FOLDER;
383
- const folderInput = ui.input.string('Folder', {
384
- value: folderName,
385
- nullable: false,
386
- onValueChanged: (val) => {
387
- folderName = val;
388
- dlg.getButton('Save').disabled = (fileName.length < 1) || (folderName.length < 1);
389
- },
374
+ const nameInput = ui.input.string('Name', {value: modelName, nullable: false});
375
+ const descriptionInput = ui.input.textArea('Description', {value: ' ', nullable: true});
376
+ const typeInput = ui.input.bool('Desirability Profile', {
377
+ value: true,
378
+ tooltipText: 'Save the model as an MPO Desirability Profile. If disabled, the model is saved in the pMPO format.',
390
379
  });
391
380
 
392
- const save = async () => {
393
- const path = `${folderName}/${fileName}.json`;
394
- try {
395
- const jsonString = JSON.stringify(objectToSave(), null, 2);
396
- await grok.dapi.files.writeAsText(path, jsonString);
397
- grok.shell.info(`Saved to ${path}`);
398
- } catch (err) {
399
- grok.shell.error(`Failed to save: ${err instanceof Error ? err.message : 'the platform issue'}.`);
400
- }
401
- dlg.close();
402
- };
403
-
404
381
  const objectToSave = () => {
405
382
  if (typeInput.value) {
406
383
  return getDesirabilityProfileJson(
@@ -420,34 +397,29 @@ export async function saveModel(params: Map<string, PmpoParams>, modelName: stri
420
397
  };
421
398
  };
422
399
 
423
- const modelNameInput = ui.input.string('Name', {value: modelName, nullable: true});
424
- const descriptionInput = ui.input.textArea('Description', {value: ' ', nullable: true});
425
- const typeInput = ui.input.bool('Desirability Profile', {
426
- value: true,
427
- tooltipText: 'Save the model as an MPO Desirability Profile. If disabled, the model is saved in the pMPO format.',
428
- });
429
-
430
400
  const dlg = ui.dialog({title: 'Save model'})
431
- .add(ui.h2('Path'))
432
- .add(folderInput)
433
401
  .add(nameInput)
434
- .add(ui.h2('Model'))
435
- .add(modelNameInput)
436
402
  .add(descriptionInput)
437
403
  .add(typeInput)
438
404
  .addButton('Save', async () => {
439
- const exist = await grok.dapi.files.exists(`${folderName}/${fileName}.json`);
440
- if (!exist)
441
- await save();
442
- else {
443
- // Handle overwrite confirmation
444
- ui.dialog({title: 'Warning'})
445
- .add(ui.label('Overwrite existing file?'))
446
- .onOK(async () => await save())
447
- .show();
405
+ try {
406
+ const files = await grok.dapi.files.list(FOLDER);
407
+ const existingFileNames = new Set(files.map((f) => f.name));
408
+ const fileName = generateMpoFileName(nameInput.value, existingFileNames);
409
+ const path = `${FOLDER}/${fileName}`;
410
+ const jsonString = JSON.stringify(objectToSave(), null, 2);
411
+ await grok.dapi.files.writeAsText(path, jsonString);
412
+ grok.events.fireCustomEvent(MPO_PROFILE_CHANGED_EVENT, {});
413
+ grok.shell.info(`Saved to ${path}`);
414
+ } catch (err) {
415
+ grok.shell.error(`Failed to save: ${err instanceof Error ? err.message : 'the platform issue'}.`);
448
416
  }
417
+ dlg.close();
449
418
  })
450
419
  .show();
420
+
421
+ dlg.getButton('Save').disabled = !nameInput.validate();
422
+ nameInput.onInput.subscribe(() => dlg.getButton('Save').disabled = !nameInput.validate());
451
423
  } // saveModel
452
424
 
453
425
  /** Adds columns with correlation coefficients between descriptors.
@@ -601,3 +573,122 @@ export class PmpoError extends Error {
601
573
  this.name = 'PmpoError';
602
574
  }
603
575
  }
576
+
577
/** Picks the column initially offered in the desirability input.
 * Preference order: the first boolean column with non-zero standard deviation, then the first
 * numerical column with non-zero standard deviation, then simply the first column of the list.
 * @param cols List of columns to choose from.
 * @return The chosen column (`cols[0]` as the last resort, which is `undefined` for an empty list).
 */
export function getInitCol(cols: DG.Column[]): DG.Column {
  const varyingBool = cols.find((col: DG.Column) => (col.type === DG.COLUMN_TYPE.BOOL) && (col.stats.stdev > 0));
  if (varyingBool !== undefined)
    return varyingBool;

  const varyingNumeric = cols.find((col: DG.Column) => (col.isNumerical) && (col.stats.stdev > 0));
  if (varyingNumeric !== undefined)
    return varyingNumeric;

  return cols[0];
}
594
+
595
/** Maps an equality sign to the matching two-argument predicate.
 * @param sign Equality sign ('<', '≤', '>', '≥').
 * @return Function `(a, b)` that is true when `a <sign> b` holds.
 * Throws an Error for a sign that is not one of the four listed above.
 */
function getComparator(sign: EQUALITY_SIGN): (a: number, b: number) => boolean {
  const comparators = new Map<EQUALITY_SIGN, (a: number, b: number) => boolean>([
    [EQUALITY_SIGN.LESS, (a: number, b: number) => a < b],
    [EQUALITY_SIGN.LESS_OR_EQUAL, (a: number, b: number) => a <= b],
    [EQUALITY_SIGN.GREATER, (a: number, b: number) => a > b],
    [EQUALITY_SIGN.GREATER_OR_EQUAL, (a: number, b: number) => a >= b],
  ]);

  const comparator = comparators.get(sign);

  if (comparator === undefined)
    throw new Error(`Unsupported sign: ${sign}`);

  return comparator;
}
610
+
611
/** Builds a boolean desirability column by comparing each value of a numeric column with a threshold.
 * @param numericCol Numeric column to convert.
 * @param threshold Threshold value for comparison.
 * @param sign Equality sign for comparison ('<', '≤', '>', '≥').
 * @return Unnamed boolean column holding `value <sign> threshold` for every row,
 *   and a tooltip text describing the applied condition.
 */
export function getBoolDesirabilityColData(numericCol: DG.Column,
  threshold: number, sign: EQUALITY_SIGN): {column: DG.Column, tooltip: string} {
  // NOTE(review): raw values are compared as-is; how missing values are encoded in the raw data
  // is not visible here - confirm that they are handled before this call.
  const rawValues = numericCol.getRawData();
  const meetsCondition = getComparator(sign);

  const rowCount = numericCol.length;
  const flags: boolean[] = Array.from(
    {length: rowCount},
    (_: unknown, row: number) => meetsCondition(rawValues[row], threshold),
  );

  const column = DG.Column.fromList(DG.COLUMN_TYPE.BOOL, '', flags);
  const tooltip = `Desirability based on the condition:\n\n **${numericCol.name} ${sign} ${threshold}**`;

  return {column, tooltip};
}
632
+
633
/** Checks that thresholding the column yields both outcomes, i.e. the column's min/max show that
 * at least one value satisfies `value <sign> threshold` and at least one value does not.
 * @param desCol Desirability column to check.
 * @param threshold Threshold value for comparison.
 * @param sign Equality sign for comparison ('<', '≤', '>', '≥'); any other value is treated as '≥'.
 * @return True if the desirability column is valid, false otherwise.
 */
export function isDesirabilityValid(desCol: DG.Column, threshold: number, sign: EQUALITY_SIGN): boolean {
  const lowest = desCol.stats.min;
  const highest = desCol.stats.max;

  // '≤' and '>' split the values into {value <= threshold} and {value > threshold};
  // '<' and '≥' (and the fallback) split them into {value < threshold} and {value >= threshold}.
  // The column is valid when both parts of the split are non-empty.
  const thresholdGoesToLowerPart = (sign === EQUALITY_SIGN.LESS_OR_EQUAL) || (sign === EQUALITY_SIGN.GREATER);

  if (thresholdGoesToLowerPart)
    return (lowest <= threshold) && (highest > threshold);

  return (lowest < threshold) && (highest >= threshold);
}
657
+
658
/** Converts a string column to a boolean column based on the given desirable categories.
 * @param stringCol String column to convert.
 * @param desirableCategories List of categories that should be considered as desirable.
 * @return Unnamed boolean column marking the rows whose category is desirable, and a tooltip
 *   listing the desirable categories and the remaining categories of the column.
 */
export function getDesirabilityColumnFromCategories(stringCol: DG.Column, desirableCategories: string[]):
  {column: DG.Column, tooltip: string} {
  const raw = stringCol.getRawData();
  const categories = stringCol.categories;

  // Membership is decided once per category (not once per row): rows only store category indices
  const desirableSet = new Set<string>(desirableCategories);
  const isCategoryDesirable: boolean[] = categories.map((cat: string) => desirableSet.has(cat));

  const rowCount = stringCol.length;
  const boolArr = new Array<boolean>(rowCount);

  // An index that does not point into `categories` counts as not desirable
  for (let i = 0; i < rowCount; ++i)
    boolArr[i] = isCategoryDesirable[raw[i]] === true;

  const nonDesirableCategories = categories.filter((cat: string) => !desirableSet.has(cat));

  const checked = `\u2705 ${desirableCategories.join(', ')}`;
  const unchecked = `\u274c ${nonDesirableCategories.join(', ')}`;

  return {
    column: DG.Column.fromList(DG.COLUMN_TYPE.BOOL, '', boolArr),
    tooltip: `Desirability based on the selected categories:\n\n **${checked}**\n\n **${unchecked}**`,
  };
} // getDesirabilityColumnFromCategories
682
+
683
/** Returns the categories to be selected initially.
 * Categories found in `PREFERABLE_CATEGORIES` (exact match) are selected; if there are none,
 * the first category is selected.
 * @param categories List of categories to select from.
 * @return List of selected categories; empty if `categories` is empty.
 */
export function getSelectedCategories(categories: string[]): string[] {
  const selected = categories.filter((cat: string) => PREFERABLE_CATEGORIES.includes(cat));

  if (selected.length > 0)
    return selected;

  // Fall back to the first category; `slice` yields [] for an empty list instead of [undefined]
  return categories.slice(0, 1);
}