@datagrok/eda 1.4.1 → 1.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,13 +1,13 @@
1
1
  {
2
2
  "name": "@datagrok/eda",
3
3
  "friendlyName": "EDA",
4
- "version": "1.4.1",
4
+ "version": "1.4.2",
5
5
  "description": "Exploratory Data Analysis Tools",
6
6
  "dependencies": {
7
7
  "@datagrok-libraries/math": "^1.2.6",
8
8
  "@datagrok-libraries/ml": "^6.8.3",
9
9
  "@datagrok-libraries/tutorials": "^1.6.1",
10
- "@datagrok-libraries/utils": "^4.5.0",
10
+ "@datagrok-libraries/utils": "^4.5.7",
11
11
  "@keckelt/tsne": "^1.0.2",
12
12
  "@webgpu/types": "^0.1.40",
13
13
  "cash-dom": "^8.1.1",
package/src/package.ts CHANGED
@@ -278,6 +278,7 @@ export async function PLS(table: DG.DataFrame, features: DG.ColumnList, predict:
278
278
  features: features,
279
279
  predict: predict,
280
280
  components: components,
281
+ isQuadratic: false,
281
282
  names: names,
282
283
  });
283
284
  }
@@ -13,6 +13,12 @@ export enum ERROR_MSG {
13
13
  NO_COLS = 'No numeric columns without missing values',
14
14
  ONE_COL = 'No columns to be used as features (just one numeric columns without missing values)',
15
15
  EMPTY_DF = 'Dataframe is empty',
16
+ PREDICT = 'Predictors must not contain a response variable',
17
+ ENOUGH = 'Not enough of features',
18
+ COMP_LIN_PLS = 'Components count must be less than the number of features',
19
+ COMP_QUA_PLS = 'Too large components count for the quadratic PLS regression',
20
+ COMPONENTS = 'Components count must be greater than 1',
21
+ INV_INP = 'Invalid inputs',
16
22
  }
17
23
 
18
24
  /** Widget titles */
@@ -37,6 +43,7 @@ export enum TITLE {
37
43
  FEATURES = 'Feature names',
38
44
  BROWSE = 'Browse',
39
45
  ANALYSIS = 'Features Analysis',
46
+ QUADRATIC = 'Quadratic',
40
47
  }
41
48
 
42
49
  /** Tooltips */
@@ -47,6 +54,7 @@ export enum HINT {
47
54
  PLS = 'Compute PLS components',
48
55
  MVA = 'Perform multivariate analysis',
49
56
  NAMES = 'Names of data samples',
57
+ QUADRATIC = 'Specifies whether to include squared terms as additional predictors in the PLS model',
50
58
  }
51
59
 
52
60
  /** Links to help */
@@ -96,6 +104,9 @@ export const DELAY = 2000;
96
104
  export enum COLOR {
97
105
  AXIS = '#838383',
98
106
  CIRCLE = '#0000FF',
107
+ INVALID = '#EB6767',
108
+ VALID_TEXT = '#4d5261',
109
+ VALID_LINE = '#dbdcdf',
99
110
  };
100
111
 
101
112
  /** Intro markdown for demo app */
@@ -131,7 +142,7 @@ export const DEMO_RESULTS = [
131
142
  },
132
143
  {
133
144
  caption: TITLE.REGR_COEFS,
134
- text: 'Parameters of the obtained linear model: features make different contribution to the prediction.',
145
+ text: 'Parameters of the obtained model: features make different contribution to the prediction.',
135
146
  },
136
147
  {
137
148
  caption: TITLE.EXPL_VAR,
package/src/pls/pls-ml.ts CHANGED
@@ -121,6 +121,7 @@ export class PlsModel {
121
121
  predict: target,
122
122
  components: components,
123
123
  names: undefined,
124
+ isQuadratic: false,
124
125
  });
125
126
 
126
127
  // 1. Names of features
@@ -30,9 +30,12 @@ export type PlsInput = {
30
30
  features: DG.ColumnList,
31
31
  predict: DG.Column,
32
32
  components: number,
33
+ isQuadratic: boolean,
33
34
  names : DG.Column | undefined,
34
35
  };
35
36
 
37
+ type TypedArray = Int32Array | Float32Array | Uint32Array | Float64Array;
38
+
36
39
  /** Return lines */
37
40
  export function getLines(names: string[]): DG.FormulaLine[] {
38
41
  const lines: DG.FormulaLine[] = [];
@@ -113,12 +116,63 @@ function debiasedPrediction(features: DG.ColumnList, params: DG.Column,
113
116
  return DG.Column.fromFloat32Array('Debiased', debiased, samples);
114
117
  }
115
118
 
119
/**
 * Build the input for the quadratic PLS regression.
 *
 * When `input.isQuadratic` is not set, the given input is returned as is.
 * Otherwise, element-wise products of every pair of feature columns
 * (each column with itself included) are computed; products with zero
 * standard deviation are dropped. The returned input uses a new dataframe
 * made of the original feature columns followed by the kept products,
 * while `names`, `predict` and `components` are taken from the given input.
 */
function getQuadraticPlsInput(input: PlsInput): PlsInput {
  // Linear model: nothing to extend
  if (!input.isQuadratic)
    return input;

  const sourceCols: DG.Column[] = input.features.toList();
  const rows = input.table.rowCount;
  const productCols: DG.Column[] = [];

  sourceCols.forEach((first, idx) => {
    const firstRaw: TypedArray = first.getRawData();

    // Start from the current column: squares are included, each pair is taken once
    for (const second of sourceCols.slice(idx)) {
      const secondRaw: TypedArray = second.getRawData();
      const products = new Float32Array(rows).map((_, k) => firstRaw[k] * secondRaw[k]);
      const productCol = DG.Column.fromFloat32Array(`${first.name} x ${second.name}`, products);

      // Keep only non-constant products
      if (productCol.stats.stdev > 0)
        productCols.push(productCol);
    }
  });

  const extended = DG.DataFrame.fromColumns([...sourceCols, ...productCols]);

  return {
    table: extended,
    features: extended.columns,
    isQuadratic: true,
    names: input.names,
    predict: input.predict,
    components: input.components,
  };
}
164
+
116
165
  /** Perform multivariate analysis using the PLS regression */
117
166
  async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<void> {
167
+ const sourceTable = input.table;
168
+
169
+ if (input.isQuadratic)
170
+ input = getQuadraticPlsInput(input);
171
+
118
172
  const result = await getPlsAnalysis(input);
119
173
 
120
174
  const plsCols = result.tScores;
121
- const cols = input.table.columns;
175
+ const cols = sourceTable.columns;
122
176
  const features = input.features;
123
177
  const featuresNames = features.names();
124
178
  const prefix = (analysisType === PLS_ANALYSIS.COMPUTE_COMPONENTS) ? RESULT_NAMES.PREFIX : TITLE.XSCORE;
@@ -132,7 +186,7 @@ async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<
132
186
  if (analysisType === PLS_ANALYSIS.COMPUTE_COMPONENTS)
133
187
  return;
134
188
 
135
- const view = grok.shell.tableView(input.table.name);
189
+ const view = grok.shell.tableView(sourceTable.name);
136
190
 
137
191
  // 0.1 Buffer table
138
192
  const loadingsRegrCoefsTable = DG.DataFrame.fromColumns([
@@ -140,7 +194,7 @@ async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<
140
194
  result.regressionCoefficients,
141
195
  ]);
142
196
 
143
- loadingsRegrCoefsTable.name = `${input.table.name}(${TITLE.ANALYSIS})`;
197
+ loadingsRegrCoefsTable.name = `${sourceTable.name}(${TITLE.ANALYSIS})`;
144
198
  grok.shell.addTable(loadingsRegrCoefsTable);
145
199
 
146
200
  // 0.2. Add X-Loadings
@@ -154,7 +208,7 @@ async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<
154
208
  const pred = debiasedPrediction(features, result.regressionCoefficients, input.predict, result.prediction);
155
209
  pred.name = cols.getUnusedName(`${input.predict.name} ${RESULT_NAMES.SUFFIX}`);
156
210
  cols.add(pred);
157
- const predictVsReferScatter = view.addViewer(DG.Viewer.scatterPlot(input.table, {
211
+ const predictVsReferScatter = view.addViewer(DG.Viewer.scatterPlot(sourceTable, {
158
212
  title: TITLE.MODEL,
159
213
  xColumnName: input.predict.name,
160
214
  yColumnName: pred.name,
@@ -203,7 +257,7 @@ async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<
203
257
  });
204
258
 
205
259
  // 4.2) create scatter
206
- const scoresScatter = DG.Viewer.scatterPlot(input.table, {
260
+ const scoresScatter = DG.Viewer.scatterPlot(sourceTable, {
207
261
  title: TITLE.SCORES,
208
262
  xColumnName: plsCols[0].name,
209
263
  yColumnName: (plsCols.length > 1) ? plsCols[1].name : result.uScores[0].name,
@@ -224,7 +278,7 @@ async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<
224
278
  // here, we use notations from this paper
225
279
  const q = result.yLoadings.getRawData();
226
280
  const p = result.xLoadings.map((col) => col.getRawData());
227
- const n = input.table.rowCount;
281
+ const n = sourceTable.rowCount;
228
282
  const m = featuresNames.length;
229
283
  const A = input.components;
230
284
  const yExplVars = new Float32Array(A);
@@ -249,7 +303,7 @@ async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<
249
303
  DG.Column.fromFloat32Array(input.predict.name, yExplVars),
250
304
  ]);
251
305
 
252
- explVarsDF.name = `${input.table.name}(${TITLE.EXPL_VAR})`;
306
+ explVarsDF.name = `${sourceTable.name}(${TITLE.EXPL_VAR})`;
253
307
  grok.shell.addTable(explVarsDF);
254
308
 
255
309
  xExplVars.forEach((arr, idx) => explVarsDF.columns.add(DG.Column.fromFloat32Array(featuresNames[idx], arr)));
@@ -318,33 +372,65 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
318
372
  return;
319
373
  }
320
374
 
321
- // responce (to predict)
375
+ let features: DG.Column[] = numCols.slice(0, numCols.length - 1);
322
376
  let predict = numCols[numCols.length - 1];
323
- const predictInput = ui.input.column(TITLE.PREDICT, {table: table, value: predict, onValueChanged: (value) => {
324
- predict = value;
325
- updateIputs();
326
- }, filter: (col: DG.Column) => isValidNumeric(col)},
327
- );
328
- predictInput.setTooltip(HINT.PREDICT);
377
+ let components = min(numColNames.length - 1, COMPONENTS.DEFAULT as number);
378
+ let isQuadratic = false;
379
+
380
/** Check that no selected feature has the same name as the response column */
const isPredictValid = () => features.every((col) => col.name !== predict.name);
386
+
387
/** Check that the components count is at least 1 and does not exceed the number of model terms */
const isCompConsistent = () => {
  if (components < 1)
    return false;

  const n = features.length;

  // Quadratic model: n * (n + 1) / 2 pairwise products in addition to n features
  const limit = isQuadratic ? (n + 1) * n / 2 + n : n;

  return components <= limit;
};
398
+
399
+ // response (to predict)
400
+ const predictInput = ui.input.column(TITLE.PREDICT, {
401
+ table: table,
402
+ value: predict,
403
+ nullable: false,
404
+ onValueChanged: (value) => {
405
+ predict = value;
406
+ updateIputs();
407
+ },
408
+ filter: (col: DG.Column) => isValidNumeric(col),
409
+ tooltipText: HINT.PREDICT,
410
+ });
329
411
 
330
412
  // predictors (features)
331
- let features: DG.Column[];
332
- const featuresInput = ui.input.columns(TITLE.USING, {table: table, available: numColNames});
333
- featuresInput.onInput.subscribe(() => updateIputs());
334
- featuresInput.setTooltip(HINT.FEATURES);
413
+ const featuresInput = ui.input.columns(TITLE.USING, {
414
+ table: table,
415
+ available: numColNames,
416
+ value: features,
417
+ onValueChanged: (val) => {
418
+ features = val;
419
+ updateIputs();
420
+ },
421
+ tooltipText: HINT.FEATURES,
422
+ });
335
423
 
336
424
  // components count
337
- let components = min(numColNames.length - 1, COMPONENTS.DEFAULT as number);
338
- const componentsInput = ui.input.forProperty(DG.Property.fromOptions({
339
- name: TITLE.COMPONENTS,
340
- inputType: INT,
341
- defaultValue: components,
342
- //@ts-ignore
425
+ const componentsInput = ui.input.int(TITLE.COMPONENTS, {
426
+ value: components,
343
427
  showPlusMinus: true,
344
- min: COMPONENTS.MIN,
345
- }));
346
- componentsInput.onInput.subscribe(() => updateIputs());
347
- componentsInput.setTooltip(HINT.COMPONENTS);
428
+ onValueChanged: (val) => {
429
+ components = val;
430
+ updateIputs();
431
+ },
432
+ tooltipText: HINT.COMPONENTS,
433
+ });
348
434
 
349
435
  let dlgTitle: string;
350
436
  let dlgHelpUrl: string;
@@ -360,14 +446,57 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
360
446
  dlgRunBtnTooltip = HINT.MVA;
361
447
  }
362
448
 
363
- const updateIputs = () => {
364
- featuresInput.value = featuresInput.value.filter((col) => col !== predict);
365
- features = featuresInput.value;
366
-
367
- componentsInput.value = min(max(componentsInput.value ?? components, COMPONENTS.MIN), features.length);
368
- components = componentsInput.value;
449
/**
 * Apply valid/invalid colors to the element and bind its tooltip.
 * For an invalid state, the tooltip shows the hint followed by the error message.
 */
const setStyle = (valid: boolean, element: HTMLElement, tooltip: string, errorMsg: string) => {
  element.style.color = valid ? COLOR.VALID_TEXT : COLOR.INVALID;
  element.style.borderBottomColor = valid ? COLOR.VALID_LINE : COLOR.INVALID;

  if (valid) {
    ui.tooltip.bind(element, tooltip);
    return;
  }

  ui.tooltip.bind(element, () => {
    const hintLabel = ui.label(tooltip);
    const errorLabel = ui.label(errorMsg);
    errorLabel.style.color = COLOR.INVALID;

    return ui.divV([hintLabel, errorLabel]);
  });
};
369
465
 
370
- dlg.getButton(TITLE.RUN).disabled = (features.length === 0) || (components <= 0);
466
/**
 * Validate the current inputs, restyle them accordingly and
 * enable/disable the run button. Returns true if all inputs are valid.
 */
const updateIputs = () => {
  // 1. Response must not be among predictors
  const predictOk = isPredictValid();
  const predictErr = predictOk ? '' : ERROR_MSG.PREDICT;
  setStyle(predictOk, predictInput.input, HINT.PREDICT, predictErr);
  setStyle(predictOk, featuresInput.input, HINT.FEATURES, predictErr);

  // 2. Components count: first the lower bound, then consistency with features
  const belowMin = components < 1;
  const componentsOk = !belowMin && isCompConsistent();

  if (componentsOk) {
    setStyle(true, componentsInput.input, HINT.COMPONENTS, '');

    if (predictOk)
      setStyle(true, featuresInput.input, HINT.FEATURES, '');
  } else if (belowMin)
    setStyle(false, componentsInput.input, HINT.COMPONENTS, ERROR_MSG.COMPONENTS);
  else {
    const message = isQuadratic ? ERROR_MSG.COMP_QUA_PLS : ERROR_MSG.COMP_LIN_PLS;
    setStyle(false, componentsInput.input, HINT.COMPONENTS, message);
    // This overrides the features style set at step 1
    setStyle(false, featuresInput.input, HINT.FEATURES, ERROR_MSG.ENOUGH);
  }

  // 3. Run is available only for valid inputs
  const allValid = predictOk && componentsOk;
  dlg.getButton(TITLE.RUN).disabled = !allValid;

  return allValid;
};
372
501
 
373
502
  // names of samples
@@ -381,8 +510,18 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
381
510
  namesInputs.setTooltip(HINT.NAMES);
382
511
  namesInputs.root.hidden = (strCols.length === 0) || (analysisType === PLS_ANALYSIS.COMPUTE_COMPONENTS);
383
512
 
513
+ // quadratic/linear model
514
+ const isQuadraticInput = ui.input.bool(TITLE.QUADRATIC, {
515
+ value: isQuadratic,
516
+ tooltipText: HINT.QUADRATIC,
517
+ onValueChanged: (val) => {
518
+ isQuadratic = val;
519
+ updateIputs();
520
+ },
521
+ });
522
+
384
523
  const dlg = ui.dialog({title: dlgTitle, helpUrl: dlgHelpUrl})
385
- .add(ui.form([predictInput, featuresInput, componentsInput, namesInputs]))
524
+ .add(ui.form([predictInput, featuresInput, componentsInput, isQuadraticInput, namesInputs]))
386
525
  .addButton(TITLE.RUN, async () => {
387
526
  dlg.close();
388
527
 
@@ -391,6 +530,7 @@ export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
391
530
  features: DG.DataFrame.fromColumns(features).columns,
392
531
  predict: predict,
393
532
  components: components,
533
+ isQuadratic: isQuadratic,
394
534
  names: names,
395
535
  }, analysisType);
396
536
  }, undefined, dlgRunBtnTooltip)
package/src/regression.ts CHANGED
@@ -179,6 +179,7 @@ async function getLinearRegressionParamsUsingPLS(features: DG.ColumnList,
179
179
  predict: targets,
180
180
  components: components,
181
181
  names: undefined,
182
+ isQuadratic: false,
182
183
  });
183
184
 
184
185
  return plsAnalysis.regressionCoefficients.getRawData() as Float32Array;
@@ -61,6 +61,7 @@ category('Partial least squares regression', () => {
61
61
  predict: cols.byIndex(COLS - 1),
62
62
  components: COMPONENTS,
63
63
  names: undefined,
64
+ isQuadratic: false,
64
65
  });
65
66
  }, {timeout: TIMEOUT, benchmark: true});
66
67
 
@@ -76,6 +77,7 @@ category('Partial least squares regression', () => {
76
77
  features: cols,
77
78
  predict: target,
78
79
  components: COMPONENTS,
80
+ isQuadratic: false,
79
81
  names: undefined,
80
82
  });
81
83