@datagrok/eda 1.1.28 → 1.1.29
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +5 -0
- package/README.md +1 -0
- package/dist/{05e5e0770f54f07e9474.wasm → 12a82b8001995d426ed2.wasm} +0 -0
- package/dist/23.js +1 -1
- package/dist/23.js.map +1 -1
- package/dist/501.js +2 -0
- package/dist/501.js.map +1 -0
- package/dist/727.js +2 -0
- package/dist/727.js.map +1 -0
- package/dist/package.js +1 -1
- package/dist/package.js.map +1 -1
- package/package.json +5 -5
- package/scripts/command.txt +1 -1
- package/scripts/func.json +664 -1
- package/scripts/module.json +1 -1
- package/src/data-generators.ts +1 -44
- package/src/missing-values-imputation/ui.ts +16 -6
- package/src/package.ts +60 -78
- package/src/regression.ts +1 -1
- package/src/softmax-classifier.ts +412 -0
- package/src/svm.ts +11 -33
- package/src/workers/softmax-worker.ts +146 -0
- package/wasm/EDA.js +55 -1
- package/wasm/EDA.wasm +0 -0
- package/wasm/EDAAPI.js +15 -0
- package/wasm/EDAForWebWorker.js +1 -1
- package/wasm/regression.h +2 -5
- package/wasm/softmax-api.cpp +49 -0
- package/wasm/softmax.h +156 -0
- package/wasm/workers/fitSoftmaxWorker.js +13 -0
- package/webpack.config.js +3 -2
|
@@ -0,0 +1,412 @@
|
|
|
1
|
+
// Softmax classifier (multinomial logistic regression): https://en.wikipedia.org/wiki/Multinomial_logistic_regression
|
|
2
|
+
|
|
3
|
+
import * as grok from 'datagrok-api/grok';
|
|
4
|
+
import * as ui from 'datagrok-api/ui';
|
|
5
|
+
import * as DG from 'datagrok-api/dg';
|
|
6
|
+
|
|
7
|
+
import {_fitSoftmax} from '../wasm/EDAAPI';
|
|
8
|
+
|
|
9
|
+
// Model dataframe layout: every params column holds one weight per feature plus one extra
// row (the bias), and two extra columns keep the features' averages & standard deviations.
const ROWS_EXTRA = 1;
const COLS_EXTRA = 2;
const MIN_COLS_COUNT = COLS_EXTRA + 1;

// Names of the service columns of the packed model dataframe
const AVGS_NAME = 'Avg-s';
const STDEVS_NAME = 'Stddev-s';

// Name of the column returned by prediction
const PRED_NAME = 'predicted';

// Default fitting hyperparameters
const DEFAULT_LEARNING_RATE = 1;
const DEFAULT_ITER_COUNT = 100;
const DEFAULT_PENALTY = 0.1;
const DEFAULT_TOLERANCE = 0.001;

// Size (in bytes) of the header that stores the model's bytes count
const BYTES_PER_MODEL_SIZE = 4;

/** Train data sizes */
interface DataSpecification {
  classesCount: number;
  featuresCount: number;
}

/** Target labels specification: one-hot encoded labels & number of samples of each class */
interface TargetLabelsData {
  oneHot: Array<Uint8Array>;
  weights: Uint32Array;
}

/** Interactivity thresholds (the misspelled MAX_SAMLPES key is kept because it is referenced elsewhere) */
const INTERACTIVITY = {
  MAX_SAMLPES: 50000,
  MAX_FEATURES: 100,
} as const;
|
|
38
|
+
|
|
39
|
+
/** Softmax classifier (multinomial logistic regression) */
export class SoftmaxClassifier {
  /**
   * Check applicability: all features must be numerical and the target must be a string column.
   * @param features - feature columns
   * @param predictColumn - target column
   * @returns true if the classifier can be trained on the given columns
   */
  static isApplicable(features: DG.ColumnList, predictColumn: DG.Column): boolean {
    for (const col of features) {
      if (!col.matches('numerical'))
        return false;
    }

    return (predictColumn.type === DG.COLUMN_TYPE.STRING);
  }

  /** Check interactivity: true if features & samples counts do not exceed the interactivity thresholds */
  static isInteractive(features: DG.ColumnList, predictColumn: DG.Column): boolean {
    return (features.length <= INTERACTIVITY.MAX_FEATURES) &&
      (predictColumn.length <= INTERACTIVITY.MAX_SAMLPES);
  }

  /** Features' averages (used for normalization) */
  private avgs: Float32Array;

  /** Features' standard deviations (used for normalization) */
  private stdevs: Float32Array;

  /** Class labels: the i-th label corresponds to params[i] */
  private categories: string[];

  /** Per-class parameters: featuresCount weights followed by the bias; undefined until trained */
  private params: Float32Array[] | undefined = undefined;

  private classesCount = 1;
  private featuresCount = 1;

  /**
   * Create either an empty (non-trained) model from the data specification,
   * or restore a trained model from the bytes produced by `toBytes`.
   * @param specification - features & classes counts of an empty model
   * @param packedModel - packed model bytes
   * @throws Error if neither argument is provided or the inputs are inconsistent
   */
  constructor(specification?: DataSpecification, packedModel?: Uint8Array) {
    if (specification !== undefined) { // Create empty model
      /** features count */
      const n = specification.featuresCount;

      /** classes count */
      const c = specification.classesCount;

      if (n < 1)
        throw new Error('Incorrect features count');

      if (c < 1)
        throw new Error('Incorrect classes count');

      // Stats arrays get an extra slot so that they match the row count of the model dataframe
      const len = n + ROWS_EXTRA;

      this.avgs = new Float32Array(len);
      this.stdevs = new Float32Array(len);
      this.categories = new Array<string>(c); // one label per class
      this.featuresCount = n;
      this.classesCount = c;
    } else if (packedModel !== undefined) { // Get classifier from packed model (bytes)
      try {
        if (packedModel.length < BYTES_PER_MODEL_SIZE)
          throw new Error('incorrect model size');

        // Extract model's bytes count. The header is copied first, so neither the byteOffset
        // nor the alignment of the passed view affects the result.
        const bytesCount = new Uint32Array(packedModel.slice(0, BYTES_PER_MODEL_SIZE).buffer)[0];

        if (bytesCount > packedModel.length - BYTES_PER_MODEL_SIZE)
          throw new Error('incorrect model size');

        // Model's bytes (relative to the passed view, not to its underlying buffer)
        const modelBytes = packedModel.subarray(BYTES_PER_MODEL_SIZE, BYTES_PER_MODEL_SIZE + bytesCount);

        const modelDf = DG.DataFrame.fromByteArray(modelBytes);
        const columns = modelDf.columns;
        const colsCount = columns.length;

        if (colsCount < MIN_COLS_COUNT)
          throw new Error('incorrect columns count');

        this.classesCount = colsCount - COLS_EXTRA;
        this.featuresCount = modelDf.rowCount - ROWS_EXTRA;

        if (this.featuresCount < 1)
          throw new Error('incorrect rows count');

        const c = this.classesCount;

        // extract params & categories: the first c columns are the per-class params
        this.params = new Array<Float32Array>(c);
        this.categories = new Array<string>(c);

        for (let i = 0; i < c; ++i) {
          const col = columns.byIndex(i);
          this.categories[i] = col.name;

          if (col.type !== DG.COLUMN_TYPE.FLOAT)
            throw new Error(`Incorrect input column type. Expected: float, passed: ${col.type}`);

          this.params[i] = col.getRawData() as Float32Array;
        }

        // extract averages
        const avgsCol = columns.byName(AVGS_NAME);
        if ((avgsCol == null) || (avgsCol.type !== DG.COLUMN_TYPE.FLOAT))
          throw new Error('incorrect average values column type');
        this.avgs = avgsCol.getRawData() as Float32Array;

        // extract stdevs
        const stdevsCol = columns.byName(STDEVS_NAME);
        if ((stdevsCol == null) || (stdevsCol.type !== DG.COLUMN_TYPE.FLOAT))
          throw new Error('incorrect standard deviations column type');
        this.stdevs = stdevsCol.getRawData() as Float32Array;
      } catch (e) {
        throw new Error(`Failed to load model: ${(e instanceof Error ? e.message : 'the platform issue')}`);
      }
    } else
      throw new Error('Softmax classifier not initialized');
  } // constructor

  /**
   * Return packed softmax classifier: a 4-byte header with the size of the serialized model
   * dataframe, followed by the dataframe bytes (per-class params, averages & standard deviations).
   * @throws Error if the model is not trained
   */
  public toBytes(): Uint8Array {
    const params = this.params;
    if (params === undefined)
      throw new Error('Non-trained model');

    const c = this.classesCount;
    const columns = new Array<DG.Column>(c + COLS_EXTRA);

    // params columns
    for (let i = 0; i < c; ++i)
      columns[i] = DG.Column.fromFloat32Array(this.categories[i], params[i]);

    // averages & stdevs
    columns[c] = DG.Column.fromFloat32Array(AVGS_NAME, this.avgs);
    columns[c + 1] = DG.Column.fromFloat32Array(STDEVS_NAME, this.stdevs);

    const modelBytes = DG.DataFrame.fromColumns(columns).toByteArray();
    const bytesCount = modelBytes.length;

    // Packed model bytes, including bytes count
    const packedModel = new Uint8Array(bytesCount + BYTES_PER_MODEL_SIZE);

    // 4 bytes for storing model's bytes count
    new Uint32Array(packedModel.buffer, 0, 1)[0] = bytesCount;

    // Store model's bytes
    packedModel.set(modelBytes, BYTES_PER_MODEL_SIZE);

    return packedModel;
  } // toBytes

  /**
   * Train classifier: wasm-computations are used first, web worker TS-computations are the fallback.
   * @param features - numerical feature columns (their count must match the model)
   * @param target - string column of class labels (same length as the features)
   * @param rate - learning rate (> 0)
   * @param iterations - max iterations count (>= 1)
   * @param penalty - L2 regularization penalty (> 0)
   * @param tolerance - loss change threshold for early stopping (> 0)
   * @throws Error if the inputs are incorrect or training fails
   */
  public async fit(features: DG.ColumnList, target: DG.Column, rate: number = DEFAULT_LEARNING_RATE,
    iterations: number = DEFAULT_ITER_COUNT, penalty: number = DEFAULT_PENALTY, tolerance: number = DEFAULT_TOLERANCE) {
    if (features.length !== this.featuresCount)
      throw new Error('Training failes - incorrect features count');

    if ((rate <= 0) || (iterations < 1) || (penalty <= 0) || (tolerance <= 0))
      throw new Error('Training failes - incorrect fitting hyperparameters');

    const rowsCount = target.length;

    // Each feature must provide exactly one value per target label
    for (const col of features) {
      if (col.length !== rowsCount)
        throw new Error('Training failes - features and target lengths differ');
    }

    // Extract statistics & categories
    this.extractStats(features);

    // The target defines the actual classes: keep the model size consistent with it,
    // otherwise predict/toBytes would read missing params or ignore some classes
    const cats = target.categories;
    const classesCount = cats.length;

    if (classesCount < 1)
      throw new Error('Training failes - incorrect target');

    this.classesCount = classesCount;
    this.categories = cats.slice();

    // A failed (re)training must not leave stale params paired with the new categories
    this.params = undefined;

    try {
      // call wasm-computations
      const paramCols = _fitSoftmax(
        features,
        DG.Column.fromFloat32Array('avgs', this.avgs, this.featuresCount),
        DG.Column.fromFloat32Array('stdevs', this.stdevs, this.featuresCount),
        DG.Column.fromInt32Array('targets', target.getRawData() as Int32Array, rowsCount),
        classesCount,
        iterations, rate, penalty, tolerance,
        this.featuresCount + 1, classesCount,
      ).columns as DG.ColumnList;

      const fitted = new Array<Float32Array>(classesCount);
      for (let i = 0; i < classesCount; ++i)
        fitted[i] = paramCols.byIndex(i).getRawData() as Float32Array;
      this.params = fitted;
    } catch (error) {
      try { // call fitting TS-computations (if wasm failed)
        this.params = await this.fitSoftmaxParams(
          features,
          target,
          iterations,
          rate,
          penalty,
          tolerance,
        );
      } catch (error) {
        throw new Error('Training failes');
      }
    }

    if (this.params === undefined)
      throw new Error('Training failes');
  } // fit

  /** Extract features' averages & standard deviations */
  private extractStats(features: DG.ColumnList): void {
    let j = 0;

    for (const col of features) {
      if ((col.type !== DG.COLUMN_TYPE.INT) && (col.type !== DG.COLUMN_TYPE.FLOAT))
        throw new Error('Training failes - incorrect features type');

      this.avgs[j] = col.stats.avg;
      this.stdevs[j] = col.stats.stdev;

      ++j;
    }
  } // extractStats

  /** Return normalized features: one array of featuresCount values per sample (constant features become 0) */
  private normalized(features: DG.ColumnList): Array<Float32Array> {
    const m = features.byIndex(0).length;

    const X = new Array<Float32Array>(m);

    for (let i = 0; i < m; ++i)
      X[i] = new Float32Array(this.featuresCount);

    let j = 0;
    for (const col of features) {
      if ((col.type !== DG.COLUMN_TYPE.INT) && (col.type !== DG.COLUMN_TYPE.FLOAT))
        throw new Error('Training failes - incorrect features type');

      const raw = col.getRawData();
      const avg = this.avgs[j];
      const stdev = this.stdevs[j];

      if (stdev > 0) {
        for (let i = 0; i < m; ++i)
          X[i][j] = (raw[i] - avg) / stdev;
      } else {
        for (let i = 0; i < m; ++i)
          X[i][j] = 0;
      }

      ++j;
    }

    return X;
  } // normalized

  /** Return normalized & transposed features: one array of samples-count values per feature */
  private transposed(features: DG.ColumnList): Array<Float32Array> {
    const m = features.byIndex(0).length;
    const n = this.featuresCount;

    const X = new Array<Float32Array>(n);

    for (let i = 0; i < n; ++i)
      X[i] = new Float32Array(m);

    let j = 0;
    for (const col of features) {
      if ((col.type !== DG.COLUMN_TYPE.INT) && (col.type !== DG.COLUMN_TYPE.FLOAT))
        throw new Error('Training failes - incorrect features type');

      const raw = col.getRawData();
      const avg = this.avgs[j];
      const stdev = this.stdevs[j];

      if (stdev > 0) {
        for (let i = 0; i < m; ++i)
          X[j][i] = (raw[i] - avg) / stdev;
      } else {
        for (let i = 0; i < m; ++i)
          X[j][i] = 0;
      }

      ++j;
    }

    return X;
  } // transposed

  /** Return one-hot vectors of the target labels and the number of samples of each class */
  private preprocessedTargets(target: DG.Column): TargetLabelsData {
    if (target.type !== DG.COLUMN_TYPE.STRING)
      throw new Error('Training failes - incorrect target type');

    const c = this.classesCount;
    const m = target.length;
    const raw = target.getRawData(); // category indices of the string column

    const Y = new Array<Uint8Array>(m);
    const weights = new Uint32Array(c).fill(0);

    for (let i = 0; i < m; ++i)
      Y[i] = new Uint8Array(c).fill(0);

    for (let i = 0; i < m; ++i) {
      Y[i][raw[i]] = 1;
      ++weights[raw[i]];
    }

    return {
      oneHot: Y,
      weights: weights,
    };
  } // preprocessedTargets

  /**
   * Return prediction column: the most probable class label for each sample.
   * @throws Error if the model is not trained or the features count differs from the model's one
   */
  public predict(features: DG.ColumnList): DG.Column {
    const params = this.params;
    if (params === undefined)
      throw new Error('Non-trained model');

    if (features.length !== this.featuresCount)
      throw new Error('Predcition fails: incorrect features count');

    // Normalize features
    const X = this.normalized(features);

    const m = X.length;
    const n = this.featuresCount;
    const c = this.classesCount;
    const logits = new Float64Array(c);
    const predClass = new Array<string>(m);

    // get prediction for each sample
    for (let j = 0; j < m; ++j) {
      const x = X[j];

      for (let i = 0; i < c; ++i) {
        const w = params[i];
        let sum = w[n]; // bias

        for (let k = 0; k < n; ++k)
          sum += w[k] * x[k];

        logits[i] = sum;
      }

      // Softmax is monotonic in the logits, so the most probable class is the argmax of the raw logits.
      // Skipping exp() avoids float overflow to Infinity, which made saturated classes tie.
      let max = logits[0];
      let argMax = 0;

      for (let k = 1; k < c; ++k) {
        if (max < logits[k]) {
          max = logits[k];
          argMax = k;
        }
      }

      predClass[j] = this.categories[argMax];
    }

    return DG.Column.fromStrings(PRED_NAME, predClass);
  } // predict

  /** Fit params in the webworker; rejects if the worker fails */
  private async fitSoftmaxParams(features: DG.ColumnList, target: DG.Column,
    iterations: number, rate: number, penalty: number, tolerance: number): Promise<Float32Array[]> {
    const targetData = this.preprocessedTargets(target);

    return new Promise<Float32Array[]>((resolve, reject) => {
      const worker = new Worker(new URL('./workers/softmax-worker.ts', import.meta.url));

      worker.onmessage = function(e) {
        worker.terminate();
        resolve(e.data.params);
        console.log(`Loss: ${e.data.loss}`);
      };

      // Without this handler an error inside the worker would leave the promise pending forever
      worker.onerror = function(e) {
        worker.terminate();
        reject(new Error(`Softmax worker failed: ${e.message}`));
      };

      worker.postMessage({
        features: this.normalized(features),
        transposed: this.transposed(features),
        oneHot: targetData.oneHot,
        classesWeights: targetData.weights,
        targetRaw: target.getRawData(),
        iterations: iterations,
        rate: rate,
        penalty: penalty,
        tolerance: tolerance,
      });
    });
  } // fitSoftmaxParams
}; // SoftmaxClassifier
|
package/src/svm.ts
CHANGED
|
@@ -75,19 +75,11 @@ const FEATURES_COUNT_NAME = 'Features count';
|
|
|
75
75
|
const TRAIN_SAMPLES_COUNT_NAME = 'Train samples count';
|
|
76
76
|
const TRAIN_ERROR = 'Train error, %';
|
|
77
77
|
const KERNEL_TYPE_TO_NAME_MAP = ['linear', 'polynomial', 'RBF', 'sigmoid'];
|
|
78
|
-
const POSITIVE_NAME = 'positive (P)';
|
|
79
|
-
const NEGATIVE_NAME = 'negative (N)';
|
|
80
|
-
const PREDICTED_POSITIVE_NAME = 'predicted positive (PP)';
|
|
81
|
-
const PREDICTED_NEGATIVE_NAME = 'predicted negative (PN)';
|
|
82
78
|
const SENSITIVITY = 'Sensitivity';
|
|
83
79
|
const SPECIFICITY = 'Specificity';
|
|
84
80
|
const BALANCED_ACCURACY = 'Balanced accuracy';
|
|
85
81
|
const POSITIVE_PREDICTIVE_VALUE = 'Positive predicitve value';
|
|
86
82
|
const NEGATIVE_PREDICTIVE_VALUE = 'Negative predicitve value';
|
|
87
|
-
const ML_REPORT = 'Model report';
|
|
88
|
-
const ML_REPORT_PREDICTED_LABELS = 'Predicted labels';
|
|
89
|
-
const ML_REPORT_TRAIN_LABELS = 'Train labels';
|
|
90
|
-
const ML_REPORT_CORRECTNESS = 'Prediction correctness';
|
|
91
83
|
const PREDICTION = 'prediction';
|
|
92
84
|
|
|
93
85
|
// Pack/unpack constants
|
|
@@ -277,11 +269,11 @@ export async function getTrainedModel(hyperparameters: any, df: DG.DataFrame, la
|
|
|
277
269
|
|
|
278
270
|
if (labels.categories.length != 2)
|
|
279
271
|
throw new Error(WRONG_LABELS_MESSAGE);
|
|
280
|
-
|
|
281
|
-
for (
|
|
272
|
+
const labelNumeric : DG.Column = DG.Column.float(labels.name, labels.length);
|
|
273
|
+
for (let i = 0; i < labels.length; i++)
|
|
282
274
|
labelNumeric.set(i, labels.get(i) == labels.categories[0] ? -1.0 : 1.0, false);
|
|
283
275
|
|
|
284
|
-
|
|
276
|
+
const model = await trainAndAnalyzeModel(hyperparameters, columns, labelNumeric);
|
|
285
277
|
model.realLabels = labels;
|
|
286
278
|
return model;
|
|
287
279
|
}
|
|
@@ -306,19 +298,6 @@ function getModelInfo(model: any): DG.DataFrame {
|
|
|
306
298
|
]);
|
|
307
299
|
}
|
|
308
300
|
|
|
309
|
-
// Get dataframe with confusion matrix
|
|
310
|
-
function getConfusionMatrixDF(model: any): DG.DataFrame {
|
|
311
|
-
const data = model.confusionMatrix.getRawData();
|
|
312
|
-
|
|
313
|
-
return DG.DataFrame.fromColumns([
|
|
314
|
-
DG.Column.fromStrings('', [POSITIVE_NAME, NEGATIVE_NAME]),
|
|
315
|
-
DG.Column.fromList('int', PREDICTED_POSITIVE_NAME,
|
|
316
|
-
[data[TRUE_POSITIVE_INDEX], data[FALSE_POSITIVE_INDEX]]),
|
|
317
|
-
DG.Column.fromList('int', PREDICTED_NEGATIVE_NAME,
|
|
318
|
-
[data[FALSE_NEGATIVE_INDEX], data[TRUE_NEGATIVE_INDEX]]),
|
|
319
|
-
]);
|
|
320
|
-
}
|
|
321
|
-
|
|
322
301
|
// Show training report
|
|
323
302
|
export function showTrainReport(df: DG.DataFrame, packedModel: any): HTMLElement {
|
|
324
303
|
const model = getUnpackedModel(packedModel);
|
|
@@ -335,7 +314,7 @@ export function getPackedModel(model: any): any {
|
|
|
335
314
|
const realLabelsSize = BYTES + realLabelsBuffer.length + 4 - realLabelsBuffer.length % 4;
|
|
336
315
|
const modelInfoBuffer = getModelInfo(model).toByteArray();
|
|
337
316
|
const modelInfoSize = BYTES + modelInfoBuffer.length + 4 - modelInfoBuffer.length % 4;
|
|
338
|
-
|
|
317
|
+
|
|
339
318
|
/*let bufferSize = BYTES * (7 + featuresCount * samplesCount
|
|
340
319
|
+ 3 * featuresCount + 2 * samplesCount);*/
|
|
341
320
|
|
|
@@ -472,7 +451,7 @@ function getUnpackedModel(packedModel: any): any {
|
|
|
472
451
|
offset += BYTES;
|
|
473
452
|
const modelInfo = DG.DataFrame.fromByteArray(new Uint8Array(modelBytes, offset, modelInfoSize));
|
|
474
453
|
offset += modelInfoBytesSize;
|
|
475
|
-
|
|
454
|
+
|
|
476
455
|
const model = {kernelType: header[MODEL_KERNEL_INDEX],
|
|
477
456
|
kernelParams: kernelParams,
|
|
478
457
|
trainLabels: trainLabels,
|
|
@@ -482,7 +461,7 @@ function getUnpackedModel(packedModel: any): any {
|
|
|
482
461
|
modelParams: modelParams,
|
|
483
462
|
modelWeights: modelWeights,
|
|
484
463
|
normalizedTrainData: normalizedTrainData,
|
|
485
|
-
modelInfo: modelInfo
|
|
464
|
+
modelInfo: modelInfo,
|
|
486
465
|
};
|
|
487
466
|
|
|
488
467
|
return model;
|
|
@@ -491,13 +470,12 @@ function getUnpackedModel(packedModel: any): any {
|
|
|
491
470
|
// Wrapper for combining the function "predict" with Datagrok predicitve tools
|
|
492
471
|
export async function getPrediction(df: DG.DataFrame, packedModel: any): Promise<DG.DataFrame> {
|
|
493
472
|
const model = getUnpackedModel(new Uint8Array(packedModel));
|
|
494
|
-
|
|
495
473
|
const resNumeric = await predict(model, df.columns);
|
|
496
474
|
const res = DG.Column.string(PREDICTION, resNumeric.length);
|
|
497
475
|
const categories = model.realLabels.categories;
|
|
498
|
-
for (
|
|
476
|
+
for (let i = 0; i < res.length; i++)
|
|
499
477
|
res.set(i, resNumeric.get(i) == -1 ? categories[0] : categories[1]);
|
|
500
|
-
|
|
478
|
+
|
|
501
479
|
|
|
502
480
|
return DG.DataFrame.fromColumns([res]);
|
|
503
481
|
} // getPrediction
|
|
@@ -507,12 +485,12 @@ export function isApplicableSVM(df: DG.DataFrame, labels: DG.Column): boolean {
|
|
|
507
485
|
const columns = df.columns;
|
|
508
486
|
if (!labels.matches('categorical') || labels.categories.length > 2)
|
|
509
487
|
return false;
|
|
510
|
-
|
|
511
|
-
for (
|
|
488
|
+
let res: boolean = true;
|
|
489
|
+
for (let i = 0; i < columns.length; i++)
|
|
512
490
|
res = res && (columns.byIndex(i).matches('numerical'));
|
|
513
491
|
return res;
|
|
514
492
|
}
|
|
515
493
|
|
|
516
494
|
export function isInteractiveSVM(df: DG.DataFrame, labels: DG.Column): boolean {
|
|
517
495
|
return df.rowCount <= 1000;
|
|
518
|
-
}
|
|
496
|
+
}
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
// Worker for softmax training: fits per-class weights & biases with batch gradient descent
// and posts back {params, loss}

onmessage = async function(evt) {
  const X = evt.data.features; // normalized features, one array per sample
  const transposedX = evt.data.transposed; // normalized features, one array per feature
  const Y = evt.data.oneHot;
  const classesWeights = evt.data.classesWeights;
  const iterations = evt.data.iterations;
  const rate = evt.data.rate;
  const penalty = evt.data.penalty;
  const tolerance = evt.data.tolerance;
  const targetRaw = evt.data.targetRaw;
  let loss = 0;

  // Start from +Infinity so that the convergence check cannot stop fitting before the first update
  let lossPrev = Number.POSITIVE_INFINITY;

  const n = transposedX.length;
  const m = X.length;
  const c = classesWeights.length;

  // 1. Init params
  // Xavier initialization scale value
  const xavierScale = 2 * Math.sqrt(6.0 / (c + n));
  const params = new Array<Float32Array>(c);
  for (let i = 0; i < c; ++i) {
    const current = new Float32Array(n + 1);

    // initialize bias, b
    current[n] = 0;

    // Xavier initialization of weights, w
    for (let j = 0; j < n; ++j)
      current[j] = (Math.random() - 0.5) * xavierScale;

    params[i] = current;
  }

  // 2. Fitting

  // Routine
  let xBuf: Float32Array;
  let wBuf: Float32Array;
  let zBuf: Float32Array;
  let sum: number;
  let sumExp: number;
  let maxLogit: number;
  let yTrue: Uint8Array;
  let yPred: Float32Array;
  let dWbuf: Float32Array;
  const logits = new Float64Array(c);
  const Z = new Array<Float32Array>(m);
  for (let i = 0; i < m; ++i)
    Z[i] = new Float32Array(c);
  const dZ = new Array<Float32Array>(c);
  for (let i = 0; i < c; ++i)
    dZ[i] = new Float32Array(m);
  const dW = new Array<Float32Array>(c);
  for (let i = 0; i < c; ++i)
    dW[i] = new Float32Array(n + 1);

  // Fitting
  for (let iter = 0; iter < iterations; ++iter) {
    // 2.1) Forward propagation
    for (let j = 0; j < m; ++j) {
      xBuf = X[j];
      zBuf = Z[j];
      maxLogit = Number.NEGATIVE_INFINITY;

      // logits
      for (let i = 0; i < c; ++i) {
        wBuf = params[i];
        sum = wBuf[n];

        for (let k = 0; k < n; ++k)
          sum += wBuf[k] * xBuf[k];

        logits[i] = sum;

        if (sum > maxLogit)
          maxLogit = sum;
      }

      // Weighted softmax; shifting by the max logit prevents exp() overflow (the shift cancels in the ratio)
      sumExp = 0;

      for (let i = 0; i < c; ++i) {
        zBuf[i] = Math.exp(logits[i] - maxLogit) * classesWeights[i];
        sumExp += zBuf[i];
      }

      for (let i = 0; i < c; ++i)
        zBuf[i] /= sumExp;
    }

    // 2.2) Loss
    loss = 0;

    for (let i = 0; i < m; ++i)
      loss += -Math.log(Z[i][targetRaw[i]]);

    loss /= m;

    if (Math.abs(loss - lossPrev) < tolerance)
      break;

    lossPrev = loss;

    // 2.3) Backward propagation

    // 2.3.1) dZ
    for (let j = 0; j < m; ++j) {
      yPred = Z[j];
      yTrue = Y[j];

      for (let i = 0; i < c; ++i)
        dZ[i][j] = yPred[i] - yTrue[i];
    }

    // 2.3.2) dB
    for (let i = 0; i < c; ++i) {
      sum = 0;
      zBuf = dZ[i];

      for (let j = 0; j < m; ++j)
        sum += zBuf[j];

      dW[i][n] = sum / m;
    }

    // 2.3.3) dW
    for (let i = 0; i < c; ++i) {
      zBuf = dZ[i];
      wBuf = dW[i];

      for (let j = 0; j < n; ++j) {
        xBuf = transposedX[j];

        sum = 0;
        for (let k = 0; k < m; ++k)
          sum += zBuf[k] * xBuf[k];

        wBuf[j] = sum / m;
      }
    }

    // 2.4) Update weights (L2-penalized) and biases (not penalized)
    for (let i = 0; i < c; ++i) {
      wBuf = params[i];
      dWbuf = dW[i];

      for (let j = 0; j < n; ++j)
        wBuf[j] = (1 - rate * penalty / m) * wBuf[j] - rate * dWbuf[j];

      wBuf[n] -= rate * dWbuf[n];
    }
  } // for iter

  postMessage({'params': params, 'loss': loss});
};
|