@oxide-js/spiking 1.1.0 → 1.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,42 +2,45 @@ import { BaseLayer, LayerConfig, ForwardOptions } from "@oxide-js/layers";
2
2
  import { Matrix, mj } from "@oxide-js/core";
3
3
  import {
4
4
  isNativeAvailable,
5
- lifStepNativeWrapper,
6
- maskSurrogateNativeWrapper
5
+ lifStepNativeWrapper,
6
+ maskSurrogateNativeWrapper,
7
+ applyEmbeddingDeltaNativeWrapper
7
8
  } from "../native_backend.js";
8
9
 
9
10
  export interface SpikingEmbeddingConfig extends LayerConfig {
10
- inputDim: number; // Ukuran vocabulary
11
- outputDim: number; // Dimensi embedding (jumlah neuron)
12
- beta?: number; // Decay factor LIF
13
- threshold?: number; // Ambang batas Spike
14
- embeddingsInitializer?: string; // Tipe inisialisasi bobot
11
+ inputDim: number;
12
+ outputDim: number;
13
+ embeddingsInitializer?: string;
14
+ betaRange?: [number, number];
15
+ thresholdRange?: [number, number];
15
16
  }
16
17
 
17
18
  export class SpikingEmbedding extends BaseLayer {
18
19
  public inputDim: number;
19
20
  public outputDim: number;
20
- public beta: number;
21
- public threshold: number;
21
+ public embeddingsInitializer: string;
22
+
23
+ public betaRange: [number, number];
24
+ public thresholdRange: [number, number];
25
+ public beta!: Float32Array;
26
+ public threshold!: Float32Array;
22
27
 
23
28
  public potentials!: Matrix;
24
29
  public lastPotentials?: Matrix;
25
30
  public lastInputs?: Matrix;
26
31
  public lastSpikes?: Matrix;
27
32
 
28
- public embeddingsInitializer: string;
29
-
30
- public get kernel(): Matrix | undefined {
31
- return this.getParameter("kernel");
33
+ public get embeddings(): Matrix | undefined {
34
+ return this.getParameter("embeddings");
32
35
  }
33
36
 
34
37
  constructor(config: SpikingEmbeddingConfig) {
35
38
  super(config);
36
39
  this.inputDim = config.inputDim;
37
40
  this.outputDim = config.outputDim;
38
- this.beta = config.beta ?? 0.9;
39
- this.threshold = config.threshold ?? 1.0;
40
41
  this.embeddingsInitializer = config.embeddingsInitializer || "glorot_normal";
42
+ this.betaRange = config.betaRange || [0.8, 0.99];
43
+ this.thresholdRange = config.thresholdRange || [0.01, 0.1];
41
44
  }
42
45
 
43
46
  public computeOutputShape(inputShape: number[]): number[] {
@@ -47,181 +50,190 @@ export class SpikingEmbedding extends BaseLayer {
47
50
 
48
51
  public build(inputShape: number[]): void {
49
52
  super.build(inputShape);
50
- const kernelVal = this.createInitializer(this.embeddingsInitializer, [this.inputDim, this.outputDim]);
51
- this.addParameter("kernel", kernelVal, true, [this.inputDim, this.outputDim]);
53
+
54
+ const embVal = this.createInitializer(this.embeddingsInitializer, [this.inputDim, this.outputDim]);
55
+
56
+ // Scale up the embedding values because Glorot Normal makes them too small for large vocabularies (e.g. 32000)
57
+ // which prevents the LIF neurons from ever reaching the threshold.
58
+ const scaleFactor = Math.sqrt(this.inputDim);
59
+ for (let i = 0; i < embVal._data.length; i++) {
60
+ embVal._data[i] *= scaleFactor;
61
+ }
62
+
63
+ this.addParameter("embeddings", embVal, true, [this.inputDim, this.outputDim]);
64
+
65
+ // Inisialisasi beta dan threshold secara acak untuk setiap neuron
66
+ this.beta = new Float32Array(this.outputDim);
67
+ this.threshold = new Float32Array(this.outputDim);
68
+ for (let i = 0; i < this.outputDim; i++) {
69
+ this.beta[i] = this.betaRange[0] + Math.random() * (this.betaRange[1] - this.betaRange[0]);
70
+ this.threshold[i] = this.thresholdRange[0] + Math.random() * (this.thresholdRange[1] - this.thresholdRange[0]);
71
+ }
72
+
73
+ // Potentials start at 0, shape [batch, outputDim].
74
+ this.potentials = Matrix.fromFlat(new Float32Array(this.outputDim), [1, this.outputDim]);
75
+ }
76
+
77
+ private dotDataBuffer?: Float32Array;
78
+ private outDataBuffer?: Float32Array;
79
+
80
+ private ensurePotentialsShape(batch: number) {
81
+ if (this.potentials._shape[0] !== batch || !this.dotDataBuffer) {
82
+ this.potentials = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
83
+ this.dotDataBuffer = new Float32Array(batch * this.outputDim);
84
+ this.outDataBuffer = new Float32Array(batch * this.outputDim);
85
+ this.lastPotentials = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
86
+ }
52
87
  }
53
88
 
54
89
  public resetState() {
55
90
  if (this.potentials) this.potentials._data.fill(0);
56
- this.lastPotentials = undefined;
91
+ if (this.lastPotentials) this.lastPotentials._data.fill(0);
57
92
  this.lastInputs = undefined;
58
93
  this.lastSpikes = undefined;
59
94
  }
60
95
 
61
- private ensurePotentialsShape(batch: number) {
62
- if (!this.potentials || this.potentials._shape[0] !== batch) {
63
- this.potentials = Matrix.fromFlat(
64
- new Float32Array(batch * this.outputDim),
65
- [batch, this.outputDim]
66
- );
67
- }
68
- }
69
-
70
96
  protected compute(inputs: Matrix, options?: ForwardOptions): Matrix {
71
- const kernel = this.kernel!._data;
72
97
  const batch = inputs._shape[0];
73
- const inputData = inputs._data;
74
-
75
98
  this.ensurePotentialsShape(batch);
76
-
77
- // 1. Lookup Row (Pengganti dot-product)
78
- const dotData = new Float32Array(batch * this.outputDim);
99
+
100
+ // 1. Embedding lookup
101
+ const emb = this.embeddings!;
102
+ const dotData = this.dotDataBuffer!;
103
+ dotData.fill(0);
104
+
79
105
  for (let b = 0; b < batch; b++) {
80
- const tokenId = Math.round(inputData[b]); // Asumsi input adalah ID token berukuran [batch, 1]
81
-
82
- // Jika token valid, ekstrak barisnya sebagai Arus (Current)
83
- if (tokenId >= 0 && tokenId < this.inputDim) {
84
- const kernelOffset = tokenId * this.outputDim;
85
- const dotOffset = b * this.outputDim;
86
- for (let j = 0; j < this.outputDim; j++) {
87
- dotData[dotOffset + j] = kernel[kernelOffset + j];
88
- }
89
- }
106
+ const tokenIdx = inputs._data[b];
107
+ if (tokenIdx >= 0 && tokenIdx < this.inputDim) {
108
+ const embOffset = tokenIdx * this.outputDim;
109
+ const dotOffset = b * this.outputDim;
110
+ for (let i = 0; i < this.outputDim; i++) {
111
+ dotData[dotOffset + i] = emb._data[embOffset + i];
112
+ }
113
+ }
90
114
  }
91
-
92
- // 2 & 3. Leaky Integrate, Fire & Reset
93
- const outData = new Float32Array(batch * this.outputDim);
115
+
116
+ // 2. Leaky Integrate and Fire (LIF Restore untuk Spiking Murni)
117
+ const outData = this.outDataBuffer!;
118
+ outData.fill(0);
94
119
  const outSpikes = Matrix.fromFlat(outData, [batch, this.outputDim]);
95
- this.lastPotentials = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
120
+ // lastPotentials is already ensured in shape
96
121
 
97
122
  if (isNativeAvailable()) {
98
123
  lifStepNativeWrapper(
99
124
  this.potentials._data,
100
125
  dotData,
101
- outSpikes._data,
102
- this.lastPotentials._data,
126
+ outData,
127
+ this.lastPotentials!._data,
103
128
  this.beta,
104
129
  this.threshold
105
130
  );
106
131
  } else {
107
132
  const potData = this.potentials._data;
108
- const thresh = this.threshold;
109
- const lpData = this.lastPotentials._data;
110
- for (let i = 0; i < potData.length; i++) {
111
- potData[i] = (potData[i] * this.beta) + dotData[i];
112
- lpData[i] = potData[i];
113
- }
114
- for (let i = 0; i < potData.length; i++) {
115
- if (potData[i] >= thresh) {
116
- outData[i] = 1;
117
- potData[i] -= thresh;
118
- } else {
119
- outData[i] = 0;
120
- }
133
+ const lpData = this.lastPotentials!._data;
134
+
135
+ for (let b = 0; b < batch; b++) {
136
+ const offset = b * this.outputDim;
137
+ for (let i = 0; i < this.outputDim; i++) {
138
+ const idx = offset + i;
139
+ potData[idx] = Math.min((potData[idx] * this.beta[i]) + dotData[idx], 1.0); // Clamp potential max 1.0
140
+ lpData[idx] = potData[idx];
141
+ }
142
+ for (let i = 0; i < this.outputDim; i++) {
143
+ const idx = offset + i;
144
+ if (potData[idx] >= this.threshold[i]) {
145
+ outData[idx] = 1;
146
+ potData[idx] -= this.threshold[i];
147
+ } else {
148
+ outData[idx] = 0;
149
+ }
150
+ }
121
151
  }
122
152
  }
123
-
124
- // Simpan memori untuk update bobot
153
+
125
154
  this.lastInputs = inputs;
126
155
  this.lastSpikes = outSpikes;
127
156
 
128
157
  return outSpikes;
129
158
  }
130
159
 
131
- // Embedding hanya menerima instruksi belajar dari layer atasnya (eHidden yang sudah dikalikan matriks B)
132
- public learnEmbedding(errorFromNext: Matrix, B: Matrix, learningRate: number = 0.01): Matrix {
133
- if (!this.lastInputs) {
134
- throw new Error("[SpikingEmbedding] Cannot run learnEmbedding() before forward() is executed. 'lastInputs' is undefined.");
135
- }
136
-
137
- const kernel = this.kernel!._data;
138
- const inputData = this.lastInputs._data;
139
- const batch = this.lastInputs._shape[0];
140
-
141
- // Hitung error yang mampir ke embedding
142
- // E * B (Feedback Alignment)
143
- // Gunakan matmul biasa karena B adalah float, dan errorFromNext mungkin float
144
- const eHidden = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
145
- // Namun karena OxideJS Matrix belum memiliki fungsi dot produk standar terbuka yang stabil,
146
- // kita harus hati-hati di sini. Untuk simplifikasi, eHidden = errorFromNext * B.
147
- // Kita asumsikan ada utilitas dotProduct standar dari core.
148
- // Jika B adalah matriks Dense (dimensi: outUnits x hiddenUnits), maka
149
- // eHidden [batch, hiddenUnits] = errorFromNext [batch, outUnits] dot B [outUnits, hiddenUnits]
160
+ public learnEmbedding(errorSignal: Matrix, B: Matrix, learningRate: number = 0.01): Matrix {
161
+ // Broadcast error mundur (Feedback Alignment)
162
+ let eHidden = mj.dotProduct(errorSignal, B, undefined, false, false); // E * B
150
163
 
151
- // Kita panggil dot product standar (bukan Add-Only, karena error dan B sama-sama float)
152
- let eHiddenMatrix = mj.dotProduct(errorFromNext, B, undefined, false, false);
153
-
154
164
  // Surrogate Mask: Boxcar
155
165
  if (this.lastPotentials) {
166
+ const eData = eHidden._data;
167
+ const pData = this.lastPotentials._data;
168
+ const windowSize = 1.0;
169
+
156
170
  if (isNativeAvailable()) {
157
171
  maskSurrogateNativeWrapper(
158
- eHiddenMatrix._data,
159
- this.lastPotentials._data,
160
- this.threshold,
161
- 1.0
172
+ eData,
173
+ pData,
174
+ this.threshold,
175
+ windowSize
162
176
  );
163
177
  } else {
164
- const eData = eHiddenMatrix._data;
165
- const pData = this.lastPotentials._data;
166
- const thresh = this.threshold;
167
- const windowSize = 1.0;
168
-
169
- for (let i = 0; i < eData.length; i++) {
170
- if (Math.abs(pData[i] - thresh) > windowSize) {
171
- eData[i] = 0;
178
+ const batch = eHidden._shape[0];
179
+ for (let b = 0; b < batch; b++) {
180
+ const offset = b * this.outputDim;
181
+ for (let i = 0; i < this.outputDim; i++) {
182
+ const idx = offset + i;
183
+ if (Math.abs(pData[idx] - this.threshold[i]) > windowSize) {
184
+ eData[idx] = 0;
185
+ }
172
186
  }
173
187
  }
174
188
  }
175
189
  }
176
190
 
177
- // Delta Rule Update pada baris Lookup (sangat efisien)
178
- const err = eHiddenMatrix._data;
179
- for (let b = 0; b < batch; b++) {
180
- const tokenId = Math.round(inputData[b]);
181
- if (tokenId >= 0 && tokenId < this.inputDim) {
182
- const kOffset = tokenId * this.outputDim;
183
- const errOffset = b * this.outputDim;
184
- for (let j = 0; j < this.outputDim; j++) {
185
- kernel[kOffset + j] += learningRate * err[errOffset + j];
186
- }
187
- }
188
- }
189
-
190
- return eHiddenMatrix;
191
+ this.applyEmbeddingDelta(eHidden, learningRate);
192
+ return eHidden;
191
193
  }
192
-
193
- /**
194
- * Word2Vec CBOW-style Hebbian Contrastive Learning
195
- * Memungkinkan pembelajaran embedding semantik secara topologis tanpa representation collapse.
196
- */
197
- public learnHebbian(
198
- tokens: number[] | Float32Array,
199
- positiveContext: Float32Array,
200
- negativeContexts: Float32Array[],
201
- learningRate: number = 0.01,
202
- marginPositive: number = 0.1,
203
- marginNegative: number = 0.05
204
- ): void {
205
- const kernel = this.kernel!._data;
206
- const dim = this.outputDim;
207
-
208
- for (let n = 0; n < negativeContexts.length; n++) {
209
- const negMean = negativeContexts[n];
210
- for (let i = 0; i < tokens.length; i++) {
211
- const tokenId = Math.round(tokens[i]);
212
- if (tokenId >= 0 && tokenId < this.inputDim) {
213
- const offset = tokenId * dim;
214
- for (let j = 0; j < dim; j++) {
215
- // Tarik kata ke arah konteks kalimatnya (Positive) - hanya sekali per token
216
- const posGradient = (n === 0) ? (positiveContext[j] - kernel[offset + j]) : 0;
217
- // Tolak kata dari konteks kalimat acak (Negative)
218
- const negGradient = kernel[offset + j] - negMean[j];
219
-
220
- const update = (posGradient * marginPositive) - (negGradient * marginNegative);
221
- kernel[offset + j] += learningRate * update;
194
+
195
+ private applyEmbeddingDelta(errorSignal: Matrix, learningRate: number) {
196
+ if (!this.lastInputs || !this.lastSpikes) {
197
+ throw new Error("[SpikingEmbedding] Cannot run learning before forward() is executed.");
198
+ }
199
+
200
+ const embeddings = this.embeddings!._data;
201
+ const inputs = this.lastInputs._data;
202
+ const err = errorSignal._data;
203
+
204
+ const batch = this.lastInputs._shape[0];
205
+ const outputDim = this.outputDim;
206
+
207
+ if (isNativeAvailable()) {
208
+ applyEmbeddingDeltaNativeWrapper(
209
+ embeddings,
210
+ inputs,
211
+ err,
212
+ learningRate,
213
+ this.inputDim,
214
+ outputDim
215
+ );
216
+ } else {
217
+ for (let b = 0; b < batch; b++) {
218
+ const tokenIdx = inputs[b];
219
+ if (tokenIdx >= 0 && tokenIdx < this.inputDim) {
220
+ const embOffset = tokenIdx * outputDim;
221
+ const errOffset = b * outputDim;
222
+ for (let j = 0; j < outputDim; j++) {
223
+ embeddings[embOffset + j] += learningRate * err[errOffset + j];
224
+ embeddings[embOffset + j] = Math.max(-1.0, Math.min(1.0, embeddings[embOffset + j])); // Clamp weight [-1, 1]
222
225
  }
223
226
  }
224
227
  }
225
228
  }
226
229
  }
230
+
231
+ public getConfig(): Record<string, any> {
232
+ return {
233
+ ...super.getConfig(),
234
+ inputDim: this.inputDim,
235
+ outputDim: this.outputDim,
236
+ embeddingsInitializer: this.embeddingsInitializer
237
+ };
238
+ }
227
239
  }