@datagrok/eda 1.1.28 → 1.1.29

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,412 @@
1
+ // Softmax classifier (multinomial logistic regression): https://en.wikipedia.org/wiki/Multinomial_logistic_regression
2
+
3
+ import * as grok from 'datagrok-api/grok';
4
+ import * as ui from 'datagrok-api/ui';
5
+ import * as DG from 'datagrok-api/dg';
6
+
7
+ import {_fitSoftmax} from '../wasm/EDAAPI';
8
+
9
// Packed-model layout constants
const ROWS_EXTRA = 1;                  // extra row in the packed model dataframe (slot for the bias term)
const COLS_EXTRA = 2;                  // extra columns in the packed model: averages & standard deviations
const MIN_COLS_COUNT = 1 + COLS_EXTRA; // at least one class column + the extra columns
const AVGS_NAME = 'Avg-s';             // name of the packed averages column
const STDEVS_NAME = 'Stddev-s';        // name of the packed standard deviations column
const PRED_NAME = 'predicted';         // name of the prediction column

// Default fitting hyperparameters
const DEFAULT_LEARNING_RATE = 1;
const DEFAULT_ITER_COUNT = 100;
const DEFAULT_PENALTY = 0.1;           // regularization coefficient (used as weight decay in fitting)
const DEFAULT_TOLERANCE = 0.001;       // early-stop threshold on the loss change
const BYTES_PER_MODEL_SIZE = 4;        // Uint32 prefix storing the packed model's byte count

/** Train data sizes */
type DataSpecification = {
  classesCount: number,
  featuresCount: number,
};

/** Target labels specification */
type TargetLabelsData = {
  oneHot: Array<Uint8Array>, // one-hot encoded labels, one vector per sample
  weights: Uint32Array,      // samples count per class
};

/** Interactivity thresholds */
// NOTE(review): 'MAX_SAMLPES' is a typo for 'MAX_SAMPLES'; renaming requires updating all usages
enum INTERACTIVITY {
  MAX_SAMLPES = 50000,
  MAX_FEATURES = 100,
};
38
+
39
+ /** Softmax classifier */
40
+ export class SoftmaxClassifier {
41
+ /** Check applicability */
42
+ static isApplicable(features: DG.ColumnList, predictColumn: DG.Column): boolean {
43
+ for (const col of features) {
44
+ if (!col.matches('numerical'))
45
+ return false;
46
+ }
47
+
48
+ return (predictColumn.type === DG.COLUMN_TYPE.STRING);
49
+ }
50
+
51
+ /** Check interactivity */
52
+ static isInteractive(features: DG.ColumnList, predictColumn: DG.Column): boolean {
53
+ return (features.length <= INTERACTIVITY.MAX_FEATURES) &&
54
+ (predictColumn.length <= INTERACTIVITY.MAX_SAMLPES);
55
+ }
56
+
57
+ private avgs: Float32Array;
58
+ private stdevs: Float32Array;
59
+ private categories: string[];
60
+ private params: Float32Array[] | undefined = undefined;
61
+ private classesCount = 1;
62
+ private featuresCount = 1;
63
+
64
+ constructor(specification?: DataSpecification, packedModel?: Uint8Array) {
65
+ if (specification !== undefined) { // Create empty model
66
+ /** features count */
67
+ const n = specification.featuresCount;
68
+
69
+ /** classes count */
70
+ const c = specification.classesCount;
71
+
72
+ if (n < 1)
73
+ throw new Error('Incorrect features count');
74
+
75
+ if (c < 1)
76
+ throw new Error('Incorrect classes count');
77
+
78
+ /** length of arrays */
79
+ const len = n + ROWS_EXTRA;
80
+
81
+ // Init model routine
82
+ this.avgs = new Float32Array(len);
83
+ this.stdevs = new Float32Array(len);
84
+ this.categories = new Array<string>(len);
85
+ this.featuresCount = n;
86
+ this.classesCount = c;
87
+ } else if (packedModel !== undefined) { // Get classifier from packed model (bytes)
88
+ try {
89
+ // Extract model's bytes count
90
+ const sizeArr = new Uint32Array(packedModel.buffer, 0, 1);
91
+ const bytesCount = sizeArr[0];
92
+
93
+ // Model's bytes
94
+ const modelBytes = new Uint8Array(packedModel.buffer, BYTES_PER_MODEL_SIZE, bytesCount);
95
+
96
+ const modelDf = DG.DataFrame.fromByteArray(modelBytes);
97
+ const columns = modelDf.columns;
98
+ const colsCount = columns.length;
99
+
100
+ if (colsCount < MIN_COLS_COUNT)
101
+ throw new Error('incorrect columns count');
102
+
103
+ this.classesCount = colsCount - COLS_EXTRA;
104
+ this.featuresCount = modelDf.rowCount - ROWS_EXTRA;
105
+
106
+ const c = this.classesCount;
107
+
108
+ // extract params & categories
109
+ this.params = new Array<Float32Array>(c);
110
+ this.categories = new Array<string>(modelDf.rowCount);
111
+
112
+ for (let i = 0; i < c; ++i) {
113
+ const col = columns.byIndex(i);
114
+ this.categories[i] = col.name;
115
+
116
+ if (col.type !== DG.COLUMN_TYPE.FLOAT)
117
+ throw new Error(`Incorrect input column type. Expected: float, passed: ${col.type}`);
118
+
119
+ this.params[i] = col.getRawData() as Float32Array;
120
+ }
121
+
122
+ // extract averages
123
+ const avgsCol = columns.byName(AVGS_NAME);
124
+ if (avgsCol.type !== DG.COLUMN_TYPE.FLOAT)
125
+ throw new Error('incorrect average values column type');
126
+ this.avgs = avgsCol.getRawData() as Float32Array;
127
+
128
+ // extract stdevs
129
+ const stdevsCol = columns.byName(STDEVS_NAME);
130
+ if (stdevsCol.type !== DG.COLUMN_TYPE.FLOAT)
131
+ throw new Error('incorrect standard deviations column type');
132
+ this.stdevs = stdevsCol.getRawData() as Float32Array;
133
+ } catch (e) {
134
+ throw new Error(`Failed to load model: ${(e instanceof Error ? e.message : 'the platform issue')}`);
135
+ }
136
+ } else
137
+ throw new Error('Softmax classifier not initialized');
138
+ }; // constructor
139
+
140
+ /** Return packed softmax classifier */
141
+ public toBytes(): Uint8Array {
142
+ if (this.params === undefined)
143
+ throw new Error('Non-trained model');
144
+
145
+ const c = this.classesCount;
146
+ const columns = new Array<DG.Column>(c + COLS_EXTRA);
147
+
148
+ // params columns
149
+ for (let i = 0; i < c; ++i)
150
+ columns[i] = DG.Column.fromFloat32Array(this.categories[i], this.params[i]);
151
+
152
+ // averages
153
+ columns[c] = DG.Column.fromFloat32Array(AVGS_NAME, this.avgs);
154
+
155
+ // stdevs
156
+ columns[c + 1] = DG.Column.fromFloat32Array(STDEVS_NAME, this.stdevs);
157
+
158
+ const modelDf = DG.DataFrame.fromColumns(columns);
159
+
160
+ const modelBytes = modelDf.toByteArray();
161
+ const bytesCount = modelBytes.length;
162
+
163
+ // Packed model bytes, including bytes count
164
+ const packedModel = new Uint8Array(bytesCount + BYTES_PER_MODEL_SIZE);
165
+
166
+ // 4 bytes for storing model's bytes count
167
+ const sizeArr = new Uint32Array(packedModel.buffer, 0, 1);
168
+ sizeArr[0] = bytesCount;
169
+
170
+ // Store model's bytes
171
+ packedModel.set(modelBytes, BYTES_PER_MODEL_SIZE);
172
+
173
+ return packedModel;
174
+ } // toBytes
175
+
176
+ /** Train classifier */
177
+ public async fit(features: DG.ColumnList, target: DG.Column, rate: number = DEFAULT_LEARNING_RATE,
178
+ iterations: number = DEFAULT_ITER_COUNT, penalty: number = DEFAULT_PENALTY, tolerance: number = DEFAULT_TOLERANCE) {
179
+ if (features.length !== this.featuresCount)
180
+ throw new Error('Training failes - incorrect features count');
181
+
182
+ if ((rate <= 0) || (iterations < 1) || (penalty <= 0) || (tolerance <= 0))
183
+ throw new Error('Training failes - incorrect fitting hyperparameters');
184
+
185
+ // Extract statistics & categories
186
+ this.extractStats(features);
187
+ const rowsCount = target.length;
188
+ const classesCount = target.categories.length;
189
+ const cats = target.categories;
190
+ for (let i = 0; i < classesCount; ++i)
191
+ this.categories[i] = cats[i];
192
+
193
+ try {
194
+ // call wasm-computations
195
+ const paramCols = _fitSoftmax(
196
+ features,
197
+ DG.Column.fromFloat32Array('avgs', this.avgs, this.featuresCount),
198
+ DG.Column.fromFloat32Array('stdevs', this.stdevs, this.featuresCount),
199
+ DG.Column.fromInt32Array('targets', target.getRawData() as Int32Array, rowsCount),
200
+ classesCount,
201
+ iterations, rate, penalty, tolerance,
202
+ this.featuresCount + 1, classesCount,
203
+ ).columns as DG.ColumnList;
204
+
205
+ this.params = new Array<Float32Array>(classesCount);
206
+ for (let i = 0; i < classesCount; ++i)
207
+ this.params[i] = paramCols.byIndex(i).getRawData() as Float32Array;
208
+ } catch (error) {
209
+ try { // call fitting TS-computations (if wasm failed)
210
+ this.params = await this.fitSoftmaxParams(
211
+ features,
212
+ target,
213
+ iterations,
214
+ rate,
215
+ penalty,
216
+ tolerance,
217
+ ) as Float32Array[];
218
+ } catch (error) {
219
+ throw new Error('Training failes');
220
+ }
221
+ }
222
+
223
+ if (this.params === undefined)
224
+ throw new Error('Training failes');
225
+ }; // fit
226
+
227
+ /** Extract features' stats */
228
+ private extractStats(features: DG.ColumnList): void {
229
+ let j = 0;
230
+
231
+ for (const col of features) {
232
+ if ((col.type !== DG.COLUMN_TYPE.INT) && (col.type !== DG.COLUMN_TYPE.FLOAT))
233
+ throw new Error('Training failes - incorrect features type');
234
+
235
+ this.avgs[j] = col.stats.avg;
236
+ this.stdevs[j] = col.stats.stdev;
237
+
238
+ ++j;
239
+ }
240
+ } // extractStats
241
+
242
+ /** Retrun normalized features */
243
+ private normalized(features: DG.ColumnList): Array<Float32Array> {
244
+ const m = features.byIndex(0).length;
245
+
246
+ const X = new Array<Float32Array>(m);
247
+
248
+ for (let i = 0; i < m; ++i)
249
+ X[i] = new Float32Array(this.featuresCount);
250
+
251
+ let j = 0;
252
+ for (const col of features) {
253
+ if ((col.type !== DG.COLUMN_TYPE.INT) && (col.type !== DG.COLUMN_TYPE.FLOAT))
254
+ throw new Error('Training failes - incorrect features type');
255
+
256
+ const raw = col.getRawData();
257
+ const avg = this.avgs[j];
258
+ const stdev = this.stdevs[j];
259
+
260
+ if (stdev > 0) {
261
+ for (let i = 0; i < m; ++i)
262
+ X[i][j] = (raw[i] - avg) / stdev;
263
+ } else {
264
+ for (let i = 0; i < m; ++i)
265
+ X[i][j] = 0;
266
+ }
267
+
268
+ ++j;
269
+ }
270
+
271
+ return X;
272
+ } // normalized
273
+
274
+ /** Retrun normalized & transposed features */
275
+ private transposed(features: DG.ColumnList): Array<Float32Array> {
276
+ const m = features.byIndex(0).length;
277
+ const n = this.featuresCount;
278
+
279
+ const X = new Array<Float32Array>(n);
280
+
281
+ for (let i = 0; i < n; ++i)
282
+ X[i] = new Float32Array(m);
283
+
284
+ let j = 0;
285
+ for (const col of features) {
286
+ if ((col.type !== DG.COLUMN_TYPE.INT) && (col.type !== DG.COLUMN_TYPE.FLOAT))
287
+ throw new Error('Training failes - incorrect features type');
288
+
289
+ const raw = col.getRawData();
290
+ const avg = this.avgs[j];
291
+ const stdev = this.stdevs[j];
292
+
293
+ if (stdev > 0) {
294
+ for (let i = 0; i < m; ++i)
295
+ X[j][i] = (raw[i] - avg) / stdev;
296
+ } else {
297
+ for (let i = 0; i < m; ++i)
298
+ X[j][i] = 0;
299
+ }
300
+
301
+ ++j;
302
+ }
303
+
304
+ return X;
305
+ } // transposed
306
+
307
+ /** Return one-hot vectors and classes weights */
308
+ private preprocessedTargets(target: DG.Column): TargetLabelsData {
309
+ if (target.type !== DG.COLUMN_TYPE.STRING)
310
+ throw new Error('Training failes - incorrect target type');
311
+
312
+ const c = this.classesCount;
313
+ const m = target.length;
314
+ const raw = target.getRawData();
315
+
316
+ const Y = new Array<Uint8Array>(m);
317
+ const weights = new Uint32Array(c).fill(0);
318
+
319
+ for (let i = 0; i < m; ++i)
320
+ Y[i] = new Uint8Array(c).fill(0);
321
+
322
+ for (let i = 0; i < m; ++i) {
323
+ Y[i][raw[i]] = 1;
324
+ ++weights[raw[i]];
325
+ }
326
+
327
+ return {
328
+ oneHot: Y,
329
+ weights: weights,
330
+ };
331
+ } // getOneHot
332
+
333
+ /** Return prediction column */
334
+ public predict(features: DG.ColumnList): DG.Column {
335
+ if (this.params === undefined)
336
+ throw new Error('Non-trained model');
337
+
338
+ if (features.length !== this.featuresCount)
339
+ throw new Error('Predcition fails: incorrect features count');
340
+
341
+ // Normalize features
342
+ const X = this.normalized(features);
343
+
344
+ // Routine items
345
+ const m = X.length;
346
+ const n = this.featuresCount;
347
+ const c = this.classesCount;
348
+ let xBuf: Float32Array;
349
+ let wBuf: Float32Array;
350
+ const Z = new Float32Array(c);
351
+ let sum: number;
352
+ let max: number;
353
+ let argMax: number;
354
+ const predClass = new Array<string>(m);
355
+
356
+ // get prediction for each sample
357
+ for (let j = 0; j < m; ++j) {
358
+ xBuf = X[j];
359
+ sum = 0;
360
+
361
+ for (let i = 0; i < c; ++i) {
362
+ wBuf = this.params[i];
363
+ sum = wBuf[n];
364
+
365
+ for (let k = 0; k < n; ++k)
366
+ sum += wBuf[k] * xBuf[k];
367
+
368
+ Z[i] = Math.exp(sum);
369
+ }
370
+
371
+ max = Z[0];
372
+ argMax = 0;
373
+
374
+ for (let k = 1; k < c; ++k) {
375
+ if (max < Z[k]) {
376
+ max = Z[k];
377
+ argMax = k;
378
+ }
379
+ }
380
+
381
+ predClass[j] = this.categories[argMax];
382
+ }
383
+
384
+ return DG.Column.fromStrings(PRED_NAME, predClass);
385
+ }
386
+
387
+ /** Fit params in the webworker */
388
+ private async fitSoftmaxParams(features: DG.ColumnList, target: DG.Column,
389
+ iterations: number, rate: number, penalty: number, tolerance: number) {
390
+ const targetData = this.preprocessedTargets(target);
391
+
392
+ return new Promise((resolve, reject) => {
393
+ const worker = new Worker(new URL('./workers/softmax-worker.ts', import.meta.url));
394
+ worker.postMessage({
395
+ features: this.normalized(features),
396
+ transposed: this.transposed(features),
397
+ oneHot: targetData.oneHot,
398
+ classesWeights: targetData.weights,
399
+ targetRaw: target.getRawData(),
400
+ iterations: iterations,
401
+ rate: rate,
402
+ penalty: penalty,
403
+ tolerance: tolerance,
404
+ });
405
+ worker.onmessage = function(e) {
406
+ worker.terminate();
407
+ resolve(e.data.params);
408
+ console.log(`Loss: ${e.data.loss}`);
409
+ };
410
+ });
411
+ }
412
+ }; // SoftmaxClassifier
package/src/svm.ts CHANGED
@@ -75,19 +75,11 @@ const FEATURES_COUNT_NAME = 'Features count';
75
75
  const TRAIN_SAMPLES_COUNT_NAME = 'Train samples count';
76
76
  const TRAIN_ERROR = 'Train error, %';
77
77
  const KERNEL_TYPE_TO_NAME_MAP = ['linear', 'polynomial', 'RBF', 'sigmoid'];
78
- const POSITIVE_NAME = 'positive (P)';
79
- const NEGATIVE_NAME = 'negative (N)';
80
- const PREDICTED_POSITIVE_NAME = 'predicted positive (PP)';
81
- const PREDICTED_NEGATIVE_NAME = 'predicted negative (PN)';
82
78
  const SENSITIVITY = 'Sensitivity';
83
79
  const SPECIFICITY = 'Specificity';
84
80
  const BALANCED_ACCURACY = 'Balanced accuracy';
85
81
  const POSITIVE_PREDICTIVE_VALUE = 'Positive predicitve value';
86
82
  const NEGATIVE_PREDICTIVE_VALUE = 'Negative predicitve value';
87
- const ML_REPORT = 'Model report';
88
- const ML_REPORT_PREDICTED_LABELS = 'Predicted labels';
89
- const ML_REPORT_TRAIN_LABELS = 'Train labels';
90
- const ML_REPORT_CORRECTNESS = 'Prediction correctness';
91
83
  const PREDICTION = 'prediction';
92
84
 
93
85
  // Pack/unpack constants
@@ -277,11 +269,11 @@ export async function getTrainedModel(hyperparameters: any, df: DG.DataFrame, la
277
269
 
278
270
  if (labels.categories.length != 2)
279
271
  throw new Error(WRONG_LABELS_MESSAGE);
280
- let labelNumeric : DG.Column = DG.Column.float(labels.name, labels.length);
281
- for (var i = 0; i < labels.length; i++)
272
+ const labelNumeric : DG.Column = DG.Column.float(labels.name, labels.length);
273
+ for (let i = 0; i < labels.length; i++)
282
274
  labelNumeric.set(i, labels.get(i) == labels.categories[0] ? -1.0 : 1.0, false);
283
275
 
284
- let model = await trainAndAnalyzeModel(hyperparameters, columns, labelNumeric);
276
+ const model = await trainAndAnalyzeModel(hyperparameters, columns, labelNumeric);
285
277
  model.realLabels = labels;
286
278
  return model;
287
279
  }
@@ -306,19 +298,6 @@ function getModelInfo(model: any): DG.DataFrame {
306
298
  ]);
307
299
  }
308
300
 
309
- // Get dataframe with confusion matrix
310
- function getConfusionMatrixDF(model: any): DG.DataFrame {
311
- const data = model.confusionMatrix.getRawData();
312
-
313
- return DG.DataFrame.fromColumns([
314
- DG.Column.fromStrings('', [POSITIVE_NAME, NEGATIVE_NAME]),
315
- DG.Column.fromList('int', PREDICTED_POSITIVE_NAME,
316
- [data[TRUE_POSITIVE_INDEX], data[FALSE_POSITIVE_INDEX]]),
317
- DG.Column.fromList('int', PREDICTED_NEGATIVE_NAME,
318
- [data[FALSE_NEGATIVE_INDEX], data[TRUE_NEGATIVE_INDEX]]),
319
- ]);
320
- }
321
-
322
301
  // Show training report
323
302
  export function showTrainReport(df: DG.DataFrame, packedModel: any): HTMLElement {
324
303
  const model = getUnpackedModel(packedModel);
@@ -335,7 +314,7 @@ export function getPackedModel(model: any): any {
335
314
  const realLabelsSize = BYTES + realLabelsBuffer.length + 4 - realLabelsBuffer.length % 4;
336
315
  const modelInfoBuffer = getModelInfo(model).toByteArray();
337
316
  const modelInfoSize = BYTES + modelInfoBuffer.length + 4 - modelInfoBuffer.length % 4;
338
-
317
+
339
318
  /*let bufferSize = BYTES * (7 + featuresCount * samplesCount
340
319
  + 3 * featuresCount + 2 * samplesCount);*/
341
320
 
@@ -472,7 +451,7 @@ function getUnpackedModel(packedModel: any): any {
472
451
  offset += BYTES;
473
452
  const modelInfo = DG.DataFrame.fromByteArray(new Uint8Array(modelBytes, offset, modelInfoSize));
474
453
  offset += modelInfoBytesSize;
475
-
454
+
476
455
  const model = {kernelType: header[MODEL_KERNEL_INDEX],
477
456
  kernelParams: kernelParams,
478
457
  trainLabels: trainLabels,
@@ -482,7 +461,7 @@ function getUnpackedModel(packedModel: any): any {
482
461
  modelParams: modelParams,
483
462
  modelWeights: modelWeights,
484
463
  normalizedTrainData: normalizedTrainData,
485
- modelInfo: modelInfo
464
+ modelInfo: modelInfo,
486
465
  };
487
466
 
488
467
  return model;
@@ -491,13 +470,12 @@ function getUnpackedModel(packedModel: any): any {
491
470
  // Wrapper for combining the function "predict" with Datagrok predicitve tools
492
471
  export async function getPrediction(df: DG.DataFrame, packedModel: any): Promise<DG.DataFrame> {
493
472
  const model = getUnpackedModel(new Uint8Array(packedModel));
494
-
495
473
  const resNumeric = await predict(model, df.columns);
496
474
  const res = DG.Column.string(PREDICTION, resNumeric.length);
497
475
  const categories = model.realLabels.categories;
498
- for (var i = 0; i < res.length; i++) {
476
+ for (let i = 0; i < res.length; i++)
499
477
  res.set(i, resNumeric.get(i) == -1 ? categories[0] : categories[1]);
500
- }
478
+
501
479
 
502
480
  return DG.DataFrame.fromColumns([res]);
503
481
  } // getPrediction
@@ -507,12 +485,12 @@ export function isApplicableSVM(df: DG.DataFrame, labels: DG.Column): boolean {
507
485
  const columns = df.columns;
508
486
  if (!labels.matches('categorical') || labels.categories.length > 2)
509
487
  return false;
510
- var res: boolean = true;
511
- for (var i = 0; i < columns.length; i++)
488
+ let res: boolean = true;
489
+ for (let i = 0; i < columns.length; i++)
512
490
  res = res && (columns.byIndex(i).matches('numerical'));
513
491
  return res;
514
492
  }
515
493
 
516
494
  export function isInteractiveSVM(df: DG.DataFrame, labels: DG.Column): boolean {
517
495
  return df.rowCount <= 1000;
518
- }
496
+ }
@@ -0,0 +1,146 @@
1
+ // Worker for softmax training
2
+
3
+ onmessage = async function(evt) {
4
+ const X = evt.data.features;
5
+ const transposedX = evt.data.transposed;
6
+ const Y = evt.data.oneHot;
7
+ const classesWeights = evt.data.classesWeights;
8
+ const iterations = evt.data.iterations;
9
+ const rate = evt.data.rate;
10
+ const penalty = evt.data.penalty;
11
+ const tolerance = evt.data.tolerance;
12
+ const targetRaw = evt.data.targetRaw;
13
+ let loss = 0;
14
+ let lossPrev = 0;
15
+
16
+ const n = transposedX.length;
17
+ const m = X.length;
18
+ const c = classesWeights.length;
19
+
20
+ // 1. Init params
21
+ // Xavier initialization scale value
22
+ const xavierScale = 2 * Math.sqrt(6.0 / (c + n));
23
+ const params = new Array<Float32Array>(c);
24
+ for (let i = 0; i < c; ++i) {
25
+ const current = new Float32Array(n + 1);
26
+
27
+ // initialize bias, b
28
+ current[n] = 0;
29
+
30
+ //Xavier initialization of weights, w
31
+ for (let j = 0; j < n; ++j)
32
+ current[j] = (Math.random() - 0.5) * xavierScale;
33
+
34
+ params[i] = current;
35
+ }
36
+
37
+ // 2. Fitting
38
+
39
+ // Routine
40
+ let xBuf: Float32Array;
41
+ let wBuf: Float32Array;
42
+ let zBuf: Float32Array;
43
+ let sum: number;
44
+ let sumExp: number;
45
+ let yTrue: Uint8Array;
46
+ let yPred: Float32Array;
47
+ let dWbuf: Float32Array;
48
+ const Z = new Array<Float32Array>(m);
49
+ for (let i = 0; i < m; ++i)
50
+ Z[i] = new Float32Array(c);
51
+ const dZ = new Array<Float32Array>(c);
52
+ for (let i = 0; i < c; ++i)
53
+ dZ[i] = new Float32Array(m);
54
+ const dW = new Array<Float32Array>(c);
55
+ for (let i = 0; i < c; ++i)
56
+ dW[i] = new Float32Array(n + 1);
57
+
58
+ // Fitting
59
+ for (let iter = 0; iter < iterations; ++iter) {
60
+ // 2.1) Forward propagation
61
+ for (let j = 0; j < m; ++j) {
62
+ xBuf = X[j];
63
+ zBuf = Z[j];
64
+ sum = 0;
65
+ sumExp = 0;
66
+
67
+ for (let i = 0; i < c; ++i) {
68
+ wBuf = params[i];
69
+ sum = wBuf[n];
70
+
71
+ for (let k = 0; k < n; ++k)
72
+ sum += wBuf[k] * xBuf[k];
73
+
74
+ zBuf[i] = Math.exp(sum) * classesWeights[i];
75
+ sumExp += zBuf[i];
76
+ }
77
+
78
+ for (let i = 0; i < c; ++i)
79
+ zBuf[i] /= sumExp;
80
+ }
81
+
82
+ // 2.2) Loss
83
+ loss = 0;
84
+
85
+ for (let i = 0; i < m; ++i)
86
+ loss += -Math.log(Z[i][targetRaw[i]]);
87
+
88
+ loss /= m;
89
+
90
+ if (Math.abs(loss - lossPrev) < tolerance)
91
+ break;
92
+
93
+ lossPrev = loss;
94
+
95
+ // 2.3) Backward propagation
96
+
97
+ // 2.3.1) dZ
98
+ for (let j = 0; j < m; ++j) {
99
+ yPred = Z[j];
100
+ yTrue = Y[j];
101
+
102
+ for (let i = 0; i < c; ++i)
103
+ dZ[i][j] = yPred[i] - yTrue[i];
104
+ }
105
+
106
+ // 2.3.2) dB
107
+ for (let i = 0; i < c; ++i) {
108
+ sum = 0;
109
+ zBuf = dZ[i];
110
+
111
+ for (let j = 0; j < m; ++j)
112
+ sum += zBuf[j];
113
+
114
+ dW[i][n] = sum / m;
115
+ }
116
+
117
+ // 2.3.3) dW
118
+ for (let i = 0; i < c; ++i) {
119
+ zBuf = dZ[i];
120
+ wBuf = dW[i];
121
+
122
+ for (let j = 0; j < n; ++j) {
123
+ xBuf = transposedX[j];
124
+
125
+ sum = 0;
126
+ for (let k = 0; k < m; ++k)
127
+ sum += zBuf[k] * xBuf[k];
128
+
129
+ wBuf[j] = sum / m;
130
+ }
131
+ }
132
+
133
+ // 2.4) Update weights
134
+ for (let i = 0; i < c; ++i) {
135
+ wBuf = params[i];
136
+ dWbuf = dW[i];
137
+
138
+ for (let j = 0; j < n; ++j)
139
+ wBuf[j] = (1 - rate * penalty / m) * wBuf[j] - rate * dWbuf[j];
140
+
141
+ wBuf[n] -= rate * dWbuf[n];
142
+ }
143
+ } // for iter
144
+
145
+ postMessage({'params': params, 'loss': loss});
146
+ };