@datagrok/eda 1.4.13 → 1.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87) hide show
  1. package/CHANGELOG.md +11 -5
  2. package/dist/111.js +1 -1
  3. package/dist/111.js.map +1 -1
  4. package/dist/128.js +1 -1
  5. package/dist/128.js.map +1 -1
  6. package/dist/153.js +1 -1
  7. package/dist/153.js.map +1 -1
  8. package/dist/23.js +1 -1
  9. package/dist/23.js.map +1 -1
  10. package/dist/234.js +1 -1
  11. package/dist/234.js.map +1 -1
  12. package/dist/242.js +1 -1
  13. package/dist/242.js.map +1 -1
  14. package/dist/260.js +1 -1
  15. package/dist/260.js.map +1 -1
  16. package/dist/33.js +1 -1
  17. package/dist/33.js.map +1 -1
  18. package/dist/348.js +1 -1
  19. package/dist/348.js.map +1 -1
  20. package/dist/377.js +1 -1
  21. package/dist/377.js.map +1 -1
  22. package/dist/397.js +2 -0
  23. package/dist/397.js.map +1 -0
  24. package/dist/412.js +1 -1
  25. package/dist/412.js.map +1 -1
  26. package/dist/415.js +1 -1
  27. package/dist/415.js.map +1 -1
  28. package/dist/501.js +1 -1
  29. package/dist/501.js.map +1 -1
  30. package/dist/531.js +1 -1
  31. package/dist/531.js.map +1 -1
  32. package/dist/583.js +1 -1
  33. package/dist/583.js.map +1 -1
  34. package/dist/589.js +1 -1
  35. package/dist/589.js.map +1 -1
  36. package/dist/603.js +1 -1
  37. package/dist/603.js.map +1 -1
  38. package/dist/656.js +1 -1
  39. package/dist/656.js.map +1 -1
  40. package/dist/682.js +1 -1
  41. package/dist/682.js.map +1 -1
  42. package/dist/705.js +1 -1
  43. package/dist/705.js.map +1 -1
  44. package/dist/727.js +1 -1
  45. package/dist/727.js.map +1 -1
  46. package/dist/731.js +1 -1
  47. package/dist/731.js.map +1 -1
  48. package/dist/738.js +1 -1
  49. package/dist/738.js.map +1 -1
  50. package/dist/763.js +1 -1
  51. package/dist/763.js.map +1 -1
  52. package/dist/778.js +1 -1
  53. package/dist/778.js.map +1 -1
  54. package/dist/783.js +1 -1
  55. package/dist/783.js.map +1 -1
  56. package/dist/793.js +1 -1
  57. package/dist/793.js.map +1 -1
  58. package/dist/810.js +1 -1
  59. package/dist/810.js.map +1 -1
  60. package/dist/860.js +1 -1
  61. package/dist/860.js.map +1 -1
  62. package/dist/907.js +1 -1
  63. package/dist/907.js.map +1 -1
  64. package/dist/950.js +1 -1
  65. package/dist/950.js.map +1 -1
  66. package/dist/980.js +1 -1
  67. package/dist/980.js.map +1 -1
  68. package/dist/990.js +1 -1
  69. package/dist/990.js.map +1 -1
  70. package/dist/package-test.js +1 -1
  71. package/dist/package-test.js.map +1 -1
  72. package/dist/package.js +1 -1
  73. package/dist/package.js.map +1 -1
  74. package/package.json +5 -5
  75. package/src/package.ts +2 -1
  76. package/src/pareto-optimization/pareto-optimizer.ts +1 -1
  77. package/src/pls/pls-constants.ts +8 -1
  78. package/src/pls/pls-tools.ts +176 -74
  79. package/src/probabilistic-scoring/data-generator.ts +48 -3
  80. package/src/probabilistic-scoring/pmpo-defs.ts +30 -2
  81. package/src/probabilistic-scoring/pmpo-utils.ts +143 -52
  82. package/src/probabilistic-scoring/prob-scoring.ts +477 -104
  83. package/src/probabilistic-scoring/stat-tools.ts +1 -1
  84. package/src/tests/pareto-tests.ts +13 -15
  85. package/src/tests/pmpo-tests.ts +643 -3
  86. package/test-console-output-1.log +224 -86
  87. package/test-record-1.mp4 +0 -0
package/package.json CHANGED
@@ -1,18 +1,18 @@
1
1
  {
2
2
  "name": "@datagrok/eda",
3
3
  "friendlyName": "EDA",
4
- "version": "1.4.13",
4
+ "version": "1.5.1",
5
5
  "description": "Exploratory Data Analysis Tools",
6
6
  "dependencies": {
7
7
  "@datagrok-libraries/math": "^1.2.6",
8
- "@datagrok-libraries/ml": "^6.10.8",
9
- "@datagrok-libraries/statistics": "^1.10.0",
8
+ "@datagrok-libraries/ml": "^6.10.10",
9
+ "@datagrok-libraries/statistics": "^1.12.1",
10
10
  "@datagrok-libraries/tutorials": "^1.7.4",
11
- "@datagrok-libraries/utils": "^4.6.5",
11
+ "@datagrok-libraries/utils": "^4.7.0",
12
12
  "@keckelt/tsne": "^1.0.2",
13
13
  "@webgpu/types": "^0.1.40",
14
14
  "cash-dom": "^8.1.1",
15
- "datagrok-api": "^1.26.3",
15
+ "datagrok-api": "^1.27.0",
16
16
  "dayjs": "^1.11.9",
17
17
  "jstat": "^1.9.6",
18
18
  "mathjs": "^15.1.0",
package/src/package.ts CHANGED
@@ -1027,8 +1027,9 @@ export class PackageFunctions {
1027
1027
  'outputs': [{name: 'Synthetic', type: 'dataframe'}],
1028
1028
  })
1029
1029
  static async generatePmpoDataset(@grok.decorators.param({'type': 'int'}) samples: number): Promise<DG.DataFrame> {
1030
- const df = await getSynteticPmpoData(samples);
1030
+ const df = await getSynteticPmpoData(samples, false);
1031
1031
  df.name = 'Synthetic';
1032
1032
  return df;
1033
1033
  }
1034
+
1034
1035
  }
@@ -19,7 +19,7 @@ export class ParetoOptimizer {
19
19
  private toUpdatePcCols = false;
20
20
  private paretoFrontViewer: DG.Viewer;
21
21
  private resultColName: string;
22
- private intervalId: NodeJS.Timeout | null = null;
22
+ private intervalId: ReturnType<typeof setInterval> | null = null;
23
23
  private inputsMap = new Map<string, DG.InputBase>();
24
24
  private pcPlotNode: DG.DockNode | null = null;
25
25
  private inputFormNode: DG.DockNode | null = null;
@@ -17,8 +17,10 @@ export enum ERROR_MSG {
17
17
  ENOUGH = 'Not enough of features',
18
18
  COMP_LIN_PLS = 'Components count must be less than the number of features',
19
19
  COMP_QUA_PLS = 'Too large components count for the quadratic PLS regression',
20
- COMPONENTS = 'Components count must be greater than 1',
20
+ COMP_ROWS = 'Components count must not exceed the number of rows',
21
+ COMPONENTS = 'Components count must be at least 1',
21
22
  INV_INP = 'Invalid inputs',
23
+ NULL_COMPS = 'Components count is not specified',
22
24
  }
23
25
 
24
26
  /** Widget titles */
@@ -44,6 +46,7 @@ export enum TITLE {
44
46
  BROWSE = 'Browse',
45
47
  ANALYSIS = 'Features Analysis',
46
48
  QUADRATIC = 'Quadratic',
49
+ BIAS = 'bias',
47
50
  }
48
51
 
49
52
  /** Tooltips */
@@ -100,6 +103,10 @@ export const X_COORD = 200;
100
103
  export const Y_COORD = 200;
101
104
  export const DELAY = 2000;
102
105
 
106
+ export const MAX_ROWS_IN_PREDICTION_TOOLTIP = 20;
107
+
108
+ export const NUMS_AFTER_COMMA = 3;
109
+
103
110
  /** Curves colors */
104
111
  export enum COLOR {
105
112
  AXIS = '#838383',
@@ -4,9 +4,10 @@ import * as grok from 'datagrok-api/grok';
4
4
  import * as ui from 'datagrok-api/ui';
5
5
  import * as DG from 'datagrok-api/dg';
6
6
 
7
- import {PLS_ANALYSIS, ERROR_MSG, TITLE, HINT, LINK, COMPONENTS, INT, TIMEOUT,
7
+ import {PLS_ANALYSIS, ERROR_MSG, TITLE, HINT, LINK, COMPONENTS,
8
8
  RESULT_NAMES, WASM_OUTPUT_IDX, RADIUS, LINE_WIDTH, COLOR, X_COORD, Y_COORD,
9
- DEMO_INTRO_MD, DEMO_RESULTS_MD, DEMO_RESULTS} from './pls-constants';
9
+ DEMO_INTRO_MD, DEMO_RESULTS_MD, DEMO_RESULTS, NUMS_AFTER_COMMA,
10
+ MAX_ROWS_IN_PREDICTION_TOOLTIP} from './pls-constants';
10
11
  import {checkWasmDimensionReducerInputs, checkColumnType, checkMissingVals, describeElements} from '../utils';
11
12
  import {_partialLeastSquareRegressionInWebWorker} from '../../wasm/EDAAPI';
12
13
  import {carsDataframe} from '../data-generators';
@@ -36,6 +37,37 @@ export type PlsInput = {
36
37
 
37
38
  type TypedArray = Int32Array | Float32Array | Uint32Array | Float64Array;
38
39
 
40
+ /** Set style for input element depending on the validity of the value */
41
+ function setStyle(valid: boolean, element: HTMLElement, tooltip: string, errorMsg: string) {
42
+ if (valid) {
43
+ element.style.color = COLOR.VALID_TEXT;
44
+ element.style.borderBottomColor = COLOR.VALID_LINE;
45
+ ui.tooltip.bind(element, tooltip);
46
+ } else {
47
+ element.style.color = COLOR.INVALID;
48
+ element.style.borderBottomColor = COLOR.INVALID;
49
+ ui.tooltip.bind(element, () => {
50
+ const hint = ui.label(tooltip);
51
+ const err = ui.label(errorMsg);
52
+ err.style.color = COLOR.INVALID;
53
+ return ui.divV([hint, err]);
54
+ });
55
+ }
56
+ };
57
+
58
+ function getModelFormulaTerms(loadingsRegrCoefsTable: DG.DataFrame, bias: number): Map<string, number> {
59
+ const featureNames = loadingsRegrCoefsTable.col(TITLE.FEATURE)!.toList() as string[];
60
+ const regrCoefs = loadingsRegrCoefsTable.col(TITLE.REGR_COEFS)!.getRawData();
61
+
62
+ const terms = new Map([[TITLE.BIAS as string, bias]]);
63
+
64
+ featureNames.forEach((name, idx) => {
65
+ terms.set(name, regrCoefs[idx]);
66
+ });
67
+
68
+ return terms;
69
+ }
70
+
39
71
  /** Return lines */
40
72
  export function getLines(names: string[]): DG.FormulaLine[] {
41
73
  const lines: DG.FormulaLine[] = [];
@@ -97,7 +129,7 @@ export async function getPlsAnalysis(input: PlsInput): Promise<PlsOutput> {
97
129
 
98
130
  /** Return debiased predction by PLS regression */
99
131
  function debiasedPrediction(features: DG.ColumnList, params: DG.Column,
100
- target: DG.Column, biasedPrediction: DG.Column): DG.Column {
132
+ target: DG.Column, biasedPrediction: DG.Column): {debiased: DG.Column, bias: number} {
101
133
  const samples = target.length;
102
134
  const dim = features.length;
103
135
  const rawParams = params.getRawData();
@@ -113,7 +145,7 @@ function debiasedPrediction(features: DG.ColumnList, params: DG.Column,
113
145
  for (let i = 0; i < samples; ++i)
114
146
  debiased[i] = bias + biased[i];
115
147
 
116
- return DG.Column.fromFloat32Array('Debiased', debiased, samples);
148
+ return {debiased: DG.Column.fromFloat32Array('Debiased', debiased, samples), bias: bias};
117
149
  }
118
150
 
119
151
  /** Return an input for the quadratic PLS regression */
@@ -137,7 +169,7 @@ function getQuadraticPlsInput(input: PlsInput): PlsInput {
137
169
 
138
170
  for (let j = i; j < colsCount; ++j) {
139
171
  col2 = cols[j];
140
- raw2 = col2.getRawData();
172
+ raw2 = col2.getRawData();
141
173
  qaudrRaw = new Float32Array(rowsCount);
142
174
 
143
175
  for (let k = 0; k < rowsCount; ++k)
@@ -205,7 +237,8 @@ async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<
205
237
 
206
238
  // 1. Predicted vs Reference scatter plot
207
239
  // Debias prediction (since PLS center data)
208
- const pred = debiasedPrediction(features, result.regressionCoefficients, input.predict, result.prediction);
240
+ const debiased = debiasedPrediction(features, result.regressionCoefficients, input.predict, result.prediction);
241
+ const pred = debiased.debiased;
209
242
  pred.name = cols.getUnusedName(`${input.predict.name} ${RESULT_NAMES.SUFFIX}`);
210
243
  cols.add(pred);
211
244
  const predictVsReferScatter = view.addViewer(DG.Viewer.scatterPlot(sourceTable, {
@@ -232,6 +265,9 @@ async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<
232
265
  help: LINK.COEFFS,
233
266
  showValueSelector: false,
234
267
  showStackSelector: false,
268
+ description: `bias = ${debiased.bias.toFixed(NUMS_AFTER_COMMA)}`,
269
+ descriptionVisibilityMode: 'Always',
270
+ descriptionPosition: 'Bottom',
235
271
  }));
236
272
 
237
273
  // 3. Loadings Scatter Plot
@@ -268,7 +304,7 @@ async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<
268
304
  });
269
305
 
270
306
 
271
- // 4.3) create lines & circles
307
+ // 4.3) create lines & circles
272
308
  view.addViewer(scoresScatter);
273
309
  scoresScatter.meta.formulaLines.addAll(getLines(scoreNames));
274
310
 
@@ -321,7 +357,7 @@ async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<
321
357
  }));
322
358
 
323
359
  // emphasize viewers in the demo case
324
- if (analysisType === PLS_ANALYSIS.DEMO) {
360
+ if (analysisType === PLS_ANALYSIS.DEMO) {
325
361
  grok.shell.windows.help.showHelp(ui.markdown(DEMO_RESULTS_MD));
326
362
 
327
363
  describeElements(
@@ -330,6 +366,10 @@ async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<
330
366
  ['left', 'left', 'right', 'right', 'left'],
331
367
  );
332
368
  }
369
+
370
+ // Add formula tooltip to the prediction column
371
+ const modelFormulaTerms = getModelFormulaTerms(loadingsRegrCoefsTable, debiased.bias);
372
+ setPredictionTooltip(view, pred, modelFormulaTerms);
333
373
  } // performMVA
334
374
 
335
375
  /** Run multivariate analysis (PLS) */
@@ -372,38 +412,32 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
372
412
  return;
373
413
  }
374
414
 
375
- let features: DG.Column[] = numCols.slice(0, numCols.length - 1);
376
- let predict = numCols[numCols.length - 1];
377
- let components = min(numColNames.length - 1, COMPONENTS.DEFAULT as number);
378
- let isQuadratic = false;
379
-
380
- const isPredictValid = () => {
381
- for (const col of features)
382
- if (col.name === predict.name)
383
- return false;
384
- return true;
415
+ const doFeaturesIncludePredict = () => {
416
+ return featuresInput.value.some((col) => col.name === predictInput.value!.name);
385
417
  };
386
418
 
387
419
  const isCompConsistent = () => {
388
- if (components < 1)
420
+ if (componentsInput.value! < 1)
389
421
  return false;
390
422
 
391
- const n = features.length;
423
+ if (componentsInput.value! > table.rowCount)
424
+ return false;
392
425
 
393
- if (isQuadratic)
394
- return components <= (n + 1) * n / 2 + n;
426
+ const n = featuresInput.value.length;
395
427
 
396
- return components <= n;
397
- }
428
+ if (isQuadraticInput.value)
429
+ return componentsInput.value! <= (n + 1) * n / 2 + n;
430
+
431
+ return componentsInput.value! <= n;
432
+ };
398
433
 
399
434
  // response (to predict)
400
435
  const predictInput = ui.input.column(TITLE.PREDICT, {
401
436
  table: table,
402
- value: predict,
437
+ value: numCols[numCols.length - 1],
403
438
  nullable: false,
404
- onValueChanged: (value) => {
405
- predict = value;
406
- updateIputs();
439
+ onValueChanged: (_) => {
440
+ updateInputStyles();
407
441
  },
408
442
  filter: (col: DG.Column) => isValidNumeric(col),
409
443
  tooltipText: HINT.PREDICT,
@@ -413,21 +447,21 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
413
447
  const featuresInput = ui.input.columns(TITLE.USING, {
414
448
  table: table,
415
449
  available: numColNames,
416
- value: features,
417
- onValueChanged: (val) => {
418
- features = val;
419
- updateIputs();
450
+ value: numCols.slice(0, numCols.length - 1),
451
+ onValueChanged: (_) => {
452
+ updateInputStyles();
420
453
  },
421
454
  tooltipText: HINT.FEATURES,
455
+ nullable: false,
422
456
  });
423
457
 
424
458
  // components count
425
459
  const componentsInput = ui.input.int(TITLE.COMPONENTS, {
426
- value: components,
460
+ value: min(numColNames.length - 1, COMPONENTS.DEFAULT as number),
427
461
  showPlusMinus: true,
428
- onValueChanged: (val) => {
429
- components = val;
430
- updateIputs();
462
+ nullable: false,
463
+ onValueChanged: (_) => {
464
+ updateInputStyles();
431
465
  },
432
466
  tooltipText: HINT.COMPONENTS,
433
467
  });
@@ -446,28 +480,14 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
446
480
  dlgRunBtnTooltip = HINT.MVA;
447
481
  }
448
482
 
449
- const setStyle = (valid: boolean, element: HTMLElement, tooltip: string, errorMsg: string) => {
450
- if (valid) {
451
- element.style.color = COLOR.VALID_TEXT;
452
- element.style.borderBottomColor = COLOR.VALID_LINE;
453
- ui.tooltip.bind(element, tooltip);
454
- } else {
455
- element.style.color = COLOR.INVALID;
456
- element.style.borderBottomColor = COLOR.INVALID;
457
- ui.tooltip.bind(element, () => {
458
- const hint = ui.label(tooltip);
459
- const err = ui.label(errorMsg);
460
- err.style.color = COLOR.INVALID;
461
- return ui.divV([hint, err]);
462
- });
463
- }
464
- };
465
-
466
- const updateIputs = () => {
467
- const predValid = isPredictValid();
483
+ const updateInputStyles = () => {
484
+ const featuresValid = featuresInput.value.length >= 1;
485
+ const predValid = featuresValid && !doFeaturesIncludePredict();
468
486
  let compValid: boolean;
469
487
 
470
- if (predValid) {
488
+ if (!featuresValid)
489
+ setStyle(false, featuresInput.input, HINT.FEATURES, ERROR_MSG.ENOUGH);
490
+ else if (predValid) {
471
491
  setStyle(true, predictInput.input, HINT.PREDICT, '');
472
492
  setStyle(true, featuresInput.input, HINT.FEATURES, '');
473
493
  } else {
@@ -475,9 +495,12 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
475
495
  setStyle(false, featuresInput.input, HINT.FEATURES, ERROR_MSG.PREDICT);
476
496
  }
477
497
 
478
- if (components < 1) {
498
+ if (componentsInput.value == null) {
499
+ setStyle(false, componentsInput.input, HINT.COMPONENTS, ERROR_MSG.NULL_COMPS);
500
+ compValid = false;
501
+ } else if (componentsInput.value < 1) {
479
502
  setStyle(false, componentsInput.input, HINT.COMPONENTS, ERROR_MSG.COMPONENTS);
480
- compValid = false;
503
+ compValid = false;
481
504
  } else {
482
505
  compValid = isCompConsistent();
483
506
 
@@ -486,7 +509,9 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
486
509
  if (predValid)
487
510
  setStyle(true, featuresInput.input, HINT.FEATURES, '');
488
511
  } else {
489
- const errMsg = isQuadratic ? ERROR_MSG.COMP_QUA_PLS : ERROR_MSG.COMP_LIN_PLS;
512
+ const errMsg = componentsInput.value! > table.rowCount ?
513
+ ERROR_MSG.COMP_ROWS :
514
+ isQuadraticInput.value ? ERROR_MSG.COMP_QUA_PLS : ERROR_MSG.COMP_LIN_PLS;
490
515
  setStyle(false, componentsInput.input, HINT.COMPONENTS, errMsg);
491
516
  setStyle(false, featuresInput.input, HINT.FEATURES, ERROR_MSG.ENOUGH);
492
517
  }
@@ -497,10 +522,18 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
497
522
  dlg.getButton(TITLE.RUN).disabled = !isValid;
498
523
 
499
524
  return isValid;
525
+ }; // updateInputStyles
526
+
527
+ const getStrColWithUniqueVals = () => {
528
+ for (const col of strCols) {
529
+ if (col.stats.uniqueCount === table.rowCount)
530
+ return col;
531
+ }
532
+ return undefined;
500
533
  };
501
534
 
502
535
  // names of samples
503
- let names = (strCols.length > 0) ? strCols[0] : undefined;
536
+ let names = getStrColWithUniqueVals();
504
537
  const namesInputs = ui.input.column(TITLE.NAMES, {
505
538
  table: table,
506
539
  value: names,
@@ -512,11 +545,10 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
512
545
 
513
546
  // quadratic/linear model
514
547
  const isQuadraticInput = ui.input.bool(TITLE.QUADRATIC, {
515
- value: isQuadratic,
548
+ value: false,
516
549
  tooltipText: HINT.QUADRATIC,
517
- onValueChanged: (val) => {
518
- isQuadratic = val;
519
- updateIputs();
550
+ onValueChanged: (_) => {
551
+ updateInputStyles();
520
552
  },
521
553
  });
522
554
 
@@ -527,21 +559,15 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
527
559
 
528
560
  await performMVA({
529
561
  table: table,
530
- features: DG.DataFrame.fromColumns(features).columns,
531
- predict: predict,
532
- components: components,
533
- isQuadratic: isQuadratic,
562
+ features: DG.DataFrame.fromColumns(featuresInput.value).columns,
563
+ predict: predictInput.value!,
564
+ components: componentsInput.value!,
565
+ isQuadratic: isQuadraticInput.value,
534
566
  names: names,
535
567
  }, analysisType);
536
568
  }, undefined, dlgRunBtnTooltip)
537
569
  .show({x: X_COORD, y: Y_COORD});
538
570
 
539
- // the following delay provides correct styles (see https://reddata.atlassian.net/browse/GROK-15196)
540
- setTimeout(() => {
541
- featuresInput.value = numCols.filter((col) => col !== predict);
542
- features = featuresInput.value;
543
- }, TIMEOUT);
544
-
545
571
  grok.shell.v.append(dlg.root);
546
572
  } // runMVA
547
573
 
@@ -555,3 +581,79 @@ export async function runDemoMVA(): Promise<void> {
555
581
 
556
582
  await runMVA(PLS_ANALYSIS.DEMO);
557
583
  }
584
+
585
+ function setPredictionTooltip(view: DG.TableView, predCol: DG.Column, modelTerms: Map<string, number>): void {
586
+ view.grid.onCellTooltip((cell, x, y) => {
587
+ if (cell.isColHeader) {
588
+ const cellCol = cell.tableColumn;
589
+
590
+ if (cellCol == null)
591
+ return false;
592
+
593
+ if (cellCol.name === predCol.name) {
594
+ ui.tooltip.show(getPredictionTooltip(modelTerms, predCol), x, y);
595
+ return true;
596
+ }
597
+ }
598
+ return false;
599
+ });
600
+ }
601
+
602
+ function getPredictionTooltip(modelTerms: Map<string, number>, predCol: DG.Column): HTMLElement {
603
+ let idx = 0;
604
+ const bias = modelTerms.get(TITLE.BIAS) ?? 0;
605
+ const elements: HTMLElement[] = [];
606
+ if (Math.abs(bias) > 0) {
607
+ const biasEl = ui.divText(`${bias}`);
608
+ biasEl.style.marginTop = '2px';
609
+ biasEl.style.marginLeft = '4px';
610
+ elements.push(biasEl);
611
+ ++idx;
612
+ }
613
+
614
+ const sortedTerms = [...modelTerms.entries()]
615
+ .filter(([key]) => key !== TITLE.BIAS)
616
+ .sort((a, b) => Math.abs(b[1]) - Math.abs(a[1]));
617
+
618
+ const maxFeatureRows = MAX_ROWS_IN_PREDICTION_TOOLTIP - elements.length;
619
+ const hasOverflow = sortedTerms.length > maxFeatureRows;
620
+ const visibleTerms = hasOverflow ? sortedTerms.slice(0, maxFeatureRows - 1) : sortedTerms;
621
+
622
+ for (const [key, value] of visibleTerms) {
623
+ const signEl = ui.divText(idx > 0 ? '+ ' : '');
624
+ signEl.style.marginRight = '4px';
625
+ signEl.style.marginLeft = '4px';
626
+
627
+ const featureEl = ui.divText(`${key}`);
628
+ featureEl.style.fontWeight = 'bold';
629
+
630
+ const valueEl = ui.divText(` * ${value > 0 ? value : `(${value})`}`);
631
+ valueEl.style.marginLeft = '4px';
632
+
633
+ const rowEl = ui.divH([signEl, featureEl, valueEl]);
634
+ rowEl.style.marginTop = '4px';
635
+ elements.push(rowEl);
636
+
637
+ ++idx;
638
+ }
639
+
640
+ if (hasOverflow) {
641
+ const hidden = sortedTerms.length - visibleTerms.length;
642
+ const ellipsisEl = ui.divText(`(${hidden} more term${hidden > 1 ? 's' : ''})`);
643
+ ellipsisEl.style.marginTop = '4px';
644
+ ellipsisEl.style.marginLeft = '4px';
645
+ ellipsisEl.style.fontStyle = 'italic';
646
+ elements.push(ellipsisEl);
647
+ }
648
+
649
+ const headerEl = ui.divText('Formula:');
650
+
651
+ const leftEl = ui.divText(`${predCol.name} = `);
652
+ leftEl.style.fontWeight = 'bold';
653
+ leftEl.style.marginTop = '4px';
654
+
655
+ const elementsContainer = ui.divV(elements);
656
+ elementsContainer.style.marginTop = '4px';
657
+
658
+ return ui.divV([headerEl, leftEl, elementsContainer]);
659
+ }
@@ -11,12 +11,57 @@ import * as jStat from 'jstat';
11
11
  /** Generates synthetic data for pMPO model training and testing
12
12
  * @param samplesCount Number of samples to generate
13
13
  * @returns DataFrame with generated data */
14
- export async function getSynteticPmpoData(samplesCount: number): Promise<DG.DataFrame> {
14
+ export async function getSynteticPmpoData(samplesCount: number, isTest: boolean = true): Promise<DG.DataFrame> {
15
15
  const df = await grok.dapi.files.readCsv(SOURCE_PATH);
16
16
  const generator = new PmpoDataGenerator(df, 'Drug', 'CNS', 'Smiles');
17
+ const genTable = generator.getGenerated(samplesCount);
18
+
19
+ if (!isTest) {
20
+ genTable.columns.add(DG.Column.fromList(DG.COLUMN_TYPE.BOOL, 'Const bool', new Array(samplesCount).fill(true)));
21
+ genTable.columns.add(DG.Column.fromInt32Array('Const int', new Int32Array(samplesCount).fill(1)));
22
+
23
+ // Add a copy of the first numeric column with 5 missing values
24
+ const firstNumCol = genTable.columns.toList().find((col) => col.isNumerical);
25
+ if (firstNumCol) {
26
+ const colWithMissing = firstNumCol.clone();
27
+ colWithMissing.name = `${firstNumCol.name} (missing)`;
28
+ for (let i = 0; i < Math.min(5, colWithMissing.length); ++i)
29
+ colWithMissing.set(i, DG.FLOAT_NULL);
30
+ genTable.columns.add(colWithMissing);
31
+ }
32
+
33
+ // Add a column with all null values
34
+ genTable.columns.add(DG.Column.fromFloat32Array('Nulls', new Float32Array(samplesCount).fill(DG.FLOAT_NULL)));
35
+
36
+ // Add categorical columns
37
+ const categoricalCols = getCategoricalColumns(genTable.col('CNS')!, samplesCount);
38
+ for (const col of categoricalCols)
39
+ genTable.columns.add(col);
40
+ }
41
+
42
+ return genTable;
43
+ } // getSynteticPmpoData
44
+
45
+ /** Generates categorical columns based on a boolean source column
46
+ * @param sourceBoolCol Source boolean column to base the categorical columns on
47
+ * @param samplesCount Number of samples to generate
48
+ * @returns Array of generated categorical columns */
49
+ function getCategoricalColumns(sourceBoolCol: DG.Column, samplesCount: number): DG.Column[] {
50
+ const source = sourceBoolCol.toList();
51
+ const stringLabels = new Array<string>(samplesCount);
52
+ const threeCats = new Array<string>(samplesCount);
53
+
54
+ for (let i = 0; i < samplesCount; ++i) {
55
+ stringLabels[i] = source[i] ? 'active' : 'non-active';
56
+ threeCats[i] = source[i] ? (Math.random() < 0.5 ? 'perfect' : 'good') : (Math.random() < 0.5 ? 'bad' : 'worst');
57
+ }
17
58
 
18
- return generator.getGenerated(samplesCount);
19
- }
59
+ return [
60
+ DG.Column.fromList(DG.COLUMN_TYPE.STRING, 'CNS (strings)', stringLabels),
61
+ DG.Column.fromList(DG.COLUMN_TYPE.STRING, 'CNS (4 categories)', threeCats),
62
+ DG.Column.fromList(DG.COLUMN_TYPE.STRING, 'Single category', new Array<string>(samplesCount).fill('Unknown')),
63
+ ];
64
+ } // getCategoricalColumns
20
65
 
21
66
  /** Class for generating synthetic data for pMPO model training and testing */
22
67
  export class PmpoDataGenerator {
@@ -1,5 +1,5 @@
1
1
  // Constants and type definitions for probabilistic scoring (pMPO)
2
- // Link: https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/
2
+ // Source paper https://pmc.ncbi.nlm.nih.gov/articles/PMC4716604/
3
3
 
4
4
  /** Minimum number of samples required to compute pMPO */
5
5
  export const MIN_SAMPLES_COUNT = 10;
@@ -208,7 +208,8 @@ export type OptimalPoint = {
208
208
  pValTresh: number,
209
209
  r2Tresh: number,
210
210
  qCutoff: number,
211
- success: boolean,
211
+ state: 'success' | 'canceled' | 'failed',
212
+ msg: string,
212
213
  };
213
214
 
214
215
  /** Minimum bounds for pMPO parameters during optimization */
@@ -216,3 +217,30 @@ export const LOW_PARAMS_BOUNDS = new Float32Array([0.5, Q_CUTOFF_MIN]);
216
217
 
217
218
  /** Maximum bounds for pMPO parameters during optimization */
218
219
  export const HIGH_PARAMS_BOUNDS = new Float32Array([R2_MAX, Q_CUTOFF_MAX]);
220
+
221
+ export enum EQUALITY_SIGN {
222
+ GREATER = '>',
223
+ LESS = '<',
224
+ GREATER_OR_EQUAL = '≥',
225
+ LESS_OR_EQUAL = '≤',
226
+ DEFAULT = LESS_OR_EQUAL,
227
+ };
228
+
229
+ export const SIGN_OPTIONS = [
230
+ EQUALITY_SIGN.GREATER,
231
+ EQUALITY_SIGN.LESS,
232
+ EQUALITY_SIGN.GREATER_OR_EQUAL,
233
+ EQUALITY_SIGN.LESS_OR_EQUAL,
234
+ ];
235
+
236
+ export const THRESHOLDED_DESIRABILITY_COL_NAME = 'Desirability';
237
+
238
+ export const PREFERABLE_CATEGORIES = ['perfect', 'good', 'true', 't', 'g', 'active', 'a', 'yes', 'y'];
239
+
240
+ export type PmpoInputId = 'descriptors' | 'desirability' | 'threshold' | 'categories';
241
+ export type TooltipContent = string | (() => HTMLElement);
242
+
243
+ export interface PmpoValidationResult {
244
+ valid: boolean;
245
+ errors: Map<PmpoInputId, TooltipContent>;
246
+ }