@datagrok/eda 1.1.33 → 1.1.35

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,13 +1,13 @@
1
1
  {
2
2
  "name": "@datagrok/eda",
3
3
  "friendlyName": "EDA",
4
- "version": "1.1.33",
4
+ "version": "1.1.35",
5
5
  "description": "Exploratory Data Analysis Tools",
6
6
  "dependencies": {
7
7
  "@datagrok-libraries/math": "^1.1.11",
8
- "@datagrok-libraries/ml": "^6.6.21",
8
+ "@datagrok-libraries/ml": "^6.6.23",
9
9
  "@datagrok-libraries/tutorials": "^1.3.13",
10
- "@datagrok-libraries/utils": "^4.2.20",
10
+ "@datagrok-libraries/utils": "^4.2.29",
11
11
  "@keckelt/tsne": "^1.0.2",
12
12
  "@webgpu/types": "^0.1.40",
13
13
  "cash-dom": "^8.1.1",
package/src/package.ts CHANGED
@@ -173,6 +173,7 @@ export async function reduceDimensionality(): Promise<void> {
173
173
  else
174
174
  okButton.classList.remove('disabled');
175
175
  };
176
+ dialog.history(() => ({editorSettings: editor.getStringInput()}), (x: any) => editor.applyStringInput(x['editorSettings']));
176
177
  editor.onColumnsChanged.subscribe(() => {
177
178
  try {
178
179
  validate();
@@ -688,11 +689,12 @@ export function isInteractiveSoftmax(df: DG.DataFrame, predictColumn: DG.Column)
688
689
  export async function trainPLSRegression(df: DG.DataFrame, predictColumn: DG.Column, components: number): Promise<Uint8Array> {
689
690
  const features = df.columns;
690
691
 
691
- if (components > features.length)
692
- throw new Error('Number of components is greater than features count');
693
-
694
692
  const model = new PlsModel();
695
- await model.fit(features, predictColumn, components);
693
+ await model.fit(
694
+ features,
695
+ predictColumn,
696
+ Math.min(components, features.length),
697
+ );
696
698
 
697
699
  return model.toBytes();
698
700
  }
@@ -92,13 +92,35 @@ export async function getPlsAnalysis(input: PlsInput): Promise<PlsOutput> {
92
92
  };
93
93
  }
94
94
 
95
+ /** Return debiased predction by PLS regression */
96
+ function debiasedPrediction(features: DG.ColumnList, params: DG.Column,
97
+ target: DG.Column, biasedPrediction: DG.Column): DG.Column {
98
+ const samples = target.length;
99
+ const dim = features.length;
100
+ const rawParams = params.getRawData();
101
+ const debiased = new Float32Array(samples);
102
+ const biased = biasedPrediction.getRawData();
103
+
104
+ // Compute bias
105
+ let bias = target.stats.avg;
106
+ for (let i = 0; i < dim; ++i)
107
+ bias -= rawParams[i] * features.byIndex(i).stats.avg;
108
+
109
+ // Compute debiased prediction
110
+ for (let i = 0; i < samples; ++i)
111
+ debiased[i] = bias + biased[i];
112
+
113
+ return DG.Column.fromFloat32Array('Debiased', debiased, samples);
114
+ }
115
+
95
116
  /** Perform multivariate analysis using the PLS regression */
96
117
  async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<void> {
97
118
  const result = await getPlsAnalysis(input);
98
119
 
99
120
  const plsCols = result.tScores;
100
121
  const cols = input.table.columns;
101
- const featuresNames = input.features.names();
122
+ const features = input.features;
123
+ const featuresNames = features.names();
102
124
  const prefix = (analysisType === PLS_ANALYSIS.COMPUTE_COMPONENTS) ? RESULT_NAMES.PREFIX : TITLE.XSCORE;
103
125
 
104
126
  // add PLS components to the table
@@ -129,7 +151,8 @@ async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<
129
151
  });
130
152
 
131
153
  // 1. Predicted vs Reference scatter plot
132
- const pred = result.prediction;
154
+ // Debias prediction (since PLS center data)
155
+ const pred = debiasedPrediction(features, result.regressionCoefficients, input.predict, result.prediction);
133
156
  pred.name = cols.getUnusedName(`${input.predict.name} ${RESULT_NAMES.SUFFIX}`);
134
157
  cols.add(pred);
135
158
  const predictVsReferScatter = view.addViewer(DG.Viewer.scatterPlot(input.table, {
package/src/regression.ts CHANGED
@@ -7,19 +7,9 @@ import * as DG from 'datagrok-api/dg';
7
7
  import {_fitLinearRegressionParamsWithDataNormalizing} from '../wasm/EDAAPI';
8
8
  import {getPlsAnalysis} from './pls/pls-tools';
9
9
 
10
- // Linear regression computations limits
11
- const FATURES_COUNT_LIMIT = 1000;
12
- const SAMPLES_COUNT_LIMIT = 1000000;
13
-
14
10
  // Default PLS components count
15
11
  const PLS_COMPONENTS_COUNT = 10;
16
12
 
17
- // Wasm computations specific constants (see https://eigen.tuxfamily.org/dox/classEigen_1_1LDLT.html)
18
- const BYTES_PER_VALUE = 4; // wasm computations operates 4-byte floats
19
- const MEMORY_SCALE = 2; // due to the features of the Eigen lib decomposition
20
- const BUFFERS_COUNT = 1; // due to the features of the Eigen lib decomposition
21
- const WASM_MEMORY = 268435456; // wasm buffer size specified in '../scripts/module.json'
22
-
23
13
  /** Compute coefficients of linear regression */
24
14
  export async function getLinearRegressionParams(features: DG.ColumnList, targets: DG.Column): Promise<Float32Array> {
25
15
  const featuresCount = features.length;
@@ -37,24 +27,6 @@ export async function getLinearRegressionParams(features: DG.ColumnList, targets
37
27
 
38
28
  try {
39
29
  // Analyze inputs sizes
40
- const inputsAnalysis = getInputsAnalysis(featuresCount, samplesCount);
41
-
42
- if (inputsAnalysis.toApplyPLS) {
43
- // Apply the PLS method
44
- const paramsByPLS = await getLinearRegressionParamsUsingPLS(features, targets, inputsAnalysis.components);
45
-
46
- let tmpSum = 0;
47
-
48
- // Compute bias (due to the centering feature of PLS)
49
- for (let i = 0; i < featuresCount; ++i) {
50
- params[i] = paramsByPLS[i];
51
- tmpSum += paramsByPLS[i] * features.byIndex(i).stats.avg;
52
- }
53
-
54
- params[featuresCount] -= tmpSum;
55
-
56
- return params;
57
- }
58
30
 
59
31
  // Non-constant columns data
60
32
  const nonConstFeatureColsIndeces: number[] = [];
@@ -101,7 +73,22 @@ export async function getLinearRegressionParams(features: DG.ColumnList, targets
101
73
 
102
74
  params[featuresCount] = tempParams[nonConstFeaturesCount];
103
75
  } catch (e) {
104
- grok.shell.error(`Fitted the trivial model: ${e instanceof Error ? e.message : 'due to the platform issue'}`);
76
+ // Apply PLS regression if regular linear regression failed
77
+ const paramsByPLS = await getLinearRegressionParamsUsingPLS(
78
+ features,
79
+ targets,
80
+ componentsCount(features.length, targets.length),
81
+ );
82
+
83
+ let tmpSum = 0;
84
+
85
+ // Compute bias (due to the centering feature of PLS)
86
+ for (let i = 0; i < featuresCount; ++i) {
87
+ params[i] = paramsByPLS[i];
88
+ tmpSum += paramsByPLS[i] * features.byIndex(i).stats.avg;
89
+ }
90
+
91
+ params[featuresCount] -= tmpSum;
105
92
  }
106
93
 
107
94
  return params;
@@ -197,36 +184,10 @@ async function getLinearRegressionParamsUsingPLS(features: DG.ColumnList,
197
184
  return plsAnalysis.regressionCoefficients.getRawData() as Float32Array;
198
185
  }
199
186
 
200
- /** Check wasm-buffer overflow */
201
- const wasmBufferOverflow = (featuresCount: number, samplesCount: number) => {
202
- return MEMORY_SCALE * BYTES_PER_VALUE * samplesCount * (featuresCount + BUFFERS_COUNT) >= WASM_MEMORY;
203
- };
204
-
205
- /** Check whether to apply the PLS method & how many components to use */
206
- const getInputsAnalysis = (featuresCount: number, samplesCount: number) => {
207
- if (wasmBufferOverflow(featuresCount, samplesCount) || (featuresCount >= FATURES_COUNT_LIMIT)) {
208
- return {
209
- toApplyPLS: true,
210
- components: PLS_COMPONENTS_COUNT,
211
- };
212
- }
187
+ /** Return number of PLS components to be used */
188
+ const componentsCount = (featuresCount: number, samplesCount: number) => {
189
+ if (samplesCount <= featuresCount)
190
+ return Math.min(PLS_COMPONENTS_COUNT, samplesCount);
213
191
 
214
- if (samplesCount >= SAMPLES_COUNT_LIMIT) {
215
- return {
216
- toApplyPLS: true,
217
- components: Math.min(PLS_COMPONENTS_COUNT, featuresCount),
218
- };
219
- }
220
-
221
- if (samplesCount <= featuresCount) {
222
- return {
223
- toApplyPLS: true,
224
- components: Math.min(PLS_COMPONENTS_COUNT, samplesCount),
225
- };
226
- }
227
-
228
- return {
229
- toApplyPLS: false,
230
- components: PLS_COMPONENTS_COUNT,
231
- };
232
- }; // getInputsAnalysis
192
+ return Math.min(PLS_COMPONENTS_COUNT, featuresCount);
193
+ };