@oxide-js/spiking 1.2.0 → 1.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/CHANGELOG.md CHANGED
@@ -1,5 +1,13 @@
1
1
  # @oxide-js/spiking
2
2
 
3
+ ## 1.3.0
4
+
5
+ ### Minor Changes
6
+
7
+ - 1f6489c: - Added `SpikingDenseBPTT` temporal pooler with Sequence-as-Time dynamics.
8
+ - Implemented Spike-Count Accumulation with L2 Normalization to solve membrane saturation in long sequences.
9
+ - Updated documentation for SpikingDenseBPTT in API layers.
10
+
3
11
  ## 1.2.0
4
12
 
5
13
  ### Minor Changes
package/index.d.ts CHANGED
@@ -8,3 +8,4 @@ export declare function lifStepNative(potentials: Float32Array, dot: Float32Arra
8
8
  export declare function maskSurrogateNative(errorSignal: Float32Array, potentials: Float32Array, threshold: Float32Array, windowSize: number): void
9
9
  export declare function applyAddOnlyDeltaNative(kernel: Float32Array, bias: Float32Array, inputs: Float32Array, errorSignal: Float32Array, learningRate: number, batch: number, inFeatures: number, units: number, useBias: boolean): void
10
10
  export declare function applyEmbeddingDeltaNative(embeddings: Float32Array, inputs: Float32Array, errorSignal: Float32Array, learningRate: number, inputDim: number, outputDim: number): void
11
+ export declare function contrastiveHebbianNative(spikes: Float32Array, errData: Float32Array, numPairs: number, sequenceLength: number, dModel: number): number
package/index.js CHANGED
@@ -310,10 +310,11 @@ if (!nativeBinding) {
310
310
  throw new Error(`Failed to load native binding`)
311
311
  }
312
312
 
313
- const { dotProductAddOnlyNative, lifStepNative, maskSurrogateNative, applyAddOnlyDeltaNative, applyEmbeddingDeltaNative } = nativeBinding
313
+ const { dotProductAddOnlyNative, lifStepNative, maskSurrogateNative, applyAddOnlyDeltaNative, applyEmbeddingDeltaNative, contrastiveHebbianNative } = nativeBinding
314
314
 
315
315
  module.exports.dotProductAddOnlyNative = dotProductAddOnlyNative
316
316
  module.exports.lifStepNative = lifStepNative
317
317
  module.exports.maskSurrogateNative = maskSurrogateNative
318
318
  module.exports.applyAddOnlyDeltaNative = applyAddOnlyDeltaNative
319
319
  module.exports.applyEmbeddingDeltaNative = applyEmbeddingDeltaNative
320
+ module.exports.contrastiveHebbianNative = contrastiveHebbianNative
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@oxide-js/spiking",
3
- "version": "1.2.0",
3
+ "version": "1.3.0",
4
4
  "type": "module",
5
5
  "description": "Oxide-JS Spiking: Add-Only Event-Driven Spiking Neural Network implementation.",
6
6
  "repository": {
Binary file
package/src/index.ts CHANGED
@@ -1,4 +1,7 @@
1
1
  export { default as dotProductAddOnly } from "./math/dotProductAddOnly.js";
2
2
  export { SpikingDense, type SpikingDenseConfig } from "./layers/SpikingDense.js";
3
+ export { SpikingDenseBPTT, type SpikingDenseBPTTConfig } from "./layers/SpikingDenseBPTT.js";
3
4
  export { SpikingEmbedding, type SpikingEmbeddingConfig } from "./layers/SpikingEmbedding.js";
5
+ export { SpikingSelfAttention, type SpikingSelfAttentionConfig } from "./layers/SpikingSelfAttention.js";
4
6
  // export { SpikingSentenceEmbedder, type SpikingSentenceConfig } from "./models/SpikingSentenceEmbedder.js";
7
+ export { contrastiveHebbianNativeWrapper, isNativeAvailable } from "./native_backend.js";
@@ -12,6 +12,8 @@ export interface SpikingDenseConfig extends LayerConfig {
12
12
  useBias?: boolean;
13
13
  kernelInitializer?: string;
14
14
  biasInitializer?: string;
15
+ betaRange?: [number, number];
16
+ thresholdRange?: [number, number];
15
17
  }
16
18
 
17
19
  export class SpikingDense extends BaseLayer {
@@ -19,6 +21,8 @@ export class SpikingDense extends BaseLayer {
19
21
  public useBias: boolean;
20
22
  public kernelInitializer: string;
21
23
  public biasInitializer: string;
24
+ public betaRange: [number, number];
25
+ public thresholdRange: [number, number];
22
26
  public beta!: Float32Array;
23
27
  public threshold!: Float32Array;
24
28
 
@@ -41,6 +45,8 @@ export class SpikingDense extends BaseLayer {
41
45
  this.useBias = config.useBias ?? true;
42
46
  this.kernelInitializer = config.kernelInitializer || "glorot_normal";
43
47
  this.biasInitializer = config.biasInitializer || "zeros";
48
+ this.betaRange = config.betaRange || [0.8, 0.99];
49
+ this.thresholdRange = config.thresholdRange || [0.5, 1.0];
44
50
  }
45
51
 
46
52
  public computeOutputShape(inputShape: number[]): number[] {
@@ -65,23 +71,27 @@ export class SpikingDense extends BaseLayer {
65
71
  this.beta = new Float32Array(this.units);
66
72
  this.threshold = new Float32Array(this.units);
67
73
  for (let i = 0; i < this.units; i++) {
68
- this.beta[i] = 0.8 + Math.random() * 0.19;
69
- this.threshold[i] = 0.5 + Math.random() * 0.5; // Max 1.0
74
+ this.beta[i] = this.betaRange[0] + Math.random() * (this.betaRange[1] - this.betaRange[0]);
75
+ this.threshold[i] = this.thresholdRange[0] + Math.random() * (this.thresholdRange[1] - this.thresholdRange[0]);
70
76
  }
71
77
 
72
78
  // Inisialisasi state
73
79
  this.potentials = Matrix.fromFlat(new Float32Array(this.units), [1, this.units]);
74
80
  }
75
81
 
82
+ private outSpikesDataBuffer?: Float32Array;
83
+
76
84
  private ensurePotentialsShape(batch: number) {
77
- if (this.potentials._shape[0] !== batch) {
85
+ if (this.potentials._shape[0] !== batch || !this.outSpikesDataBuffer) {
78
86
  this.potentials = Matrix.fromFlat(new Float32Array(batch * this.units), [batch, this.units]);
87
+ this.outSpikesDataBuffer = new Float32Array(batch * this.units);
88
+ this.lastPotentials = Matrix.fromFlat(new Float32Array(batch * this.units), [batch, this.units]);
79
89
  }
80
90
  }
81
91
 
82
92
  public resetState() {
83
93
  if (this.potentials) this.potentials._data.fill(0);
84
- this.lastPotentials = undefined;
94
+ if (this.lastPotentials) this.lastPotentials._data.fill(0);
85
95
  this.lastInputs = undefined;
86
96
  this.lastSpikes = undefined;
87
97
  }
@@ -100,23 +110,23 @@ export class SpikingDense extends BaseLayer {
100
110
  }
101
111
 
102
112
  // 3 & 4. Leaky Integrate, Fire & Reset
103
- const outData = new Float32Array(batch * this.units);
113
+ const outData = this.outSpikesDataBuffer!;
114
+ outData.fill(0);
104
115
  const outSpikes = Matrix.fromFlat(outData, [batch, this.units]);
105
- this.lastPotentials = Matrix.fromFlat(new Float32Array(batch * this.units), [batch, this.units]);
106
116
 
107
117
  if (isNativeAvailable()) {
108
118
  lifStepNativeWrapper(
109
119
  this.potentials._data,
110
120
  dot._data,
111
121
  outSpikes._data,
112
- this.lastPotentials._data,
122
+ this.lastPotentials!._data,
113
123
  this.beta,
114
124
  this.threshold
115
125
  );
116
126
  } else {
117
127
  const potData = this.potentials._data;
118
128
  const dotData = dot._data;
119
- const lpData = this.lastPotentials._data;
129
+ const lpData = this.lastPotentials!._data;
120
130
 
121
131
  for (let b = 0; b < batch; b++) {
122
132
  const offset = b * this.units;
@@ -0,0 +1,303 @@
import { BaseLayer, LayerConfig, ForwardOptions } from "@oxide-js/layers";
import { Matrix, mj } from "@oxide-js/core";
import {
  isNativeAvailable,
  lifStepNativeWrapper,
  maskSurrogateNativeWrapper,
  applyAddOnlyDeltaNativeWrapper
} from "../native_backend.js";
import dotProductAddOnly from "../math/dotProductAddOnly.js";

/** Configuration for {@link SpikingDenseBPTT}. */
export interface SpikingDenseBPTTConfig extends LayerConfig {
  /** Number of LIF neurons (output features). */
  units: number;
  /** Add a trainable bias to the synaptic current. Defaults to `true`. */
  useBias?: boolean;
  /** Kernel initializer name. Defaults to `"glorot_normal"`. */
  kernelInitializer?: string;
  /** Bias initializer name. Defaults to `"zeros"`. */
  biasInitializer?: string;
  /**
   * If given, each neuron's membrane decay `beta` is sampled uniformly from
   * `[min, max)`. If omitted, `beta` is drawn from `1 - 2^-s` with `s` in 2..5
   * (a value that could be realised with a bit shift).
   */
  betaRange?: [number, number];
  /** Range `[min, max)` for each neuron's firing threshold. Defaults to `[0.5, 1.0]`. */
  thresholdRange?: [number, number];
}

/**
 * Spiking dense (LIF) layer trained with Backpropagation Through Time.
 *
 * Per sequence:
 * 1. `resetSequence(T)`
 * 2. `computeStep(x_t, t)` for every `t` in `0..T-1`
 * 3. `learnThroughTime(errors, B, lr)` once
 *
 * The regular `compute()` entry point is intentionally unsupported.
 */
export class SpikingDenseBPTT extends BaseLayer {
  public units: number;
  public useBias: boolean;
  public kernelInitializer: string;
  public biasInitializer: string;
  public betaRange: [number, number];
  public thresholdRange: [number, number];
  public beta!: Float32Array;
  public threshold!: Float32Array;

  public potentials!: Matrix;

  // Per-time-step history used by the backward pass, indexed by t.
  public historyInputs: Matrix[] = [];
  public historyPotentials: Matrix[] = [];
  public historySpikes: Matrix[] = [];
  public maxTimeSteps: number = 0;

  // True when the caller supplied `betaRange`; otherwise the bit-shift beta scheme is used.
  private readonly hasCustomBetaRange: boolean;

  public get kernel(): Matrix | undefined {
    return this.getParameter("kernel");
  }

  public get bias(): Matrix | undefined {
    return this.getParameter("bias");
  }

  constructor(config: SpikingDenseBPTTConfig) {
    super(config);
    this.units = config.units;
    this.useBias = config.useBias ?? true;
    this.kernelInitializer = config.kernelInitializer || "glorot_normal";
    this.biasInitializer = config.biasInitializer || "zeros";
    this.hasCustomBetaRange = config.betaRange != null;
    this.betaRange = config.betaRange ?? [0.8, 0.99];
    this.thresholdRange = config.thresholdRange ?? [0.5, 1.0];
  }

  public computeOutputShape(inputShape: number[]): number[] {
    const batch = inputShape[0] ?? -1;
    return [batch, this.units];
  }

  /**
   * Creates kernel/bias parameters and samples per-neuron `beta` and `threshold`.
   * @param inputShape Input shape; the last dimension is the number of input features.
   */
  public build(inputShape: number[]): void {
    super.build(inputShape);

    const inFeatures = inputShape[inputShape.length - 1];

    const kernelVal = this.createInitializer(this.kernelInitializer, [inFeatures, this.units]);
    this.addParameter("kernel", kernelVal, true, [inFeatures, this.units]);

    if (this.useBias) {
      const biasVal = this.createInitializer(this.biasInitializer, [this.units, 1]);
      this.addParameter("bias", biasVal, true, [this.units, 1]);
    }

    this.beta = new Float32Array(this.units);
    this.threshold = new Float32Array(this.units);
    const [betaLo, betaHi] = this.betaRange;
    const [thrLo, thrHi] = this.thresholdRange;
    for (let i = 0; i < this.units; i++) {
      if (this.hasCustomBetaRange) {
        this.beta[i] = betaLo + Math.random() * (betaHi - betaLo);
      } else {
        // Random shift exponent in 2..5, pre-computed as a float multiplier: beta = 1 - 2^-shift.
        const shift = Math.floor(2 + Math.random() * 4);
        this.beta[i] = 1.0 - 1.0 / Math.pow(2, shift);
      }
      this.threshold[i] = thrLo + Math.random() * (thrHi - thrLo);
    }

    // Membrane state starts at zero; it is resized on the first step to match the batch.
    this.potentials = Matrix.fromFlat(new Float32Array(this.units), [1, this.units]);
  }

  /** Returns the kernel, throwing a descriptive error when the layer was never built. */
  private requireKernel(): Matrix {
    const kernel = this.kernel;
    if (!kernel) {
      throw new Error("[SpikingDenseBPTT] Layer belum di-build. Panggil build(inputShape) terlebih dahulu.");
    }
    return kernel;
  }

  /** Reallocates (and thereby zeroes) the membrane potentials when the batch size changes. */
  private ensurePotentialsShape(batch: number): void {
    if (this.potentials._shape[0] !== batch) {
      this.potentials = Matrix.fromFlat(new Float32Array(batch * this.units), [batch, this.units]);
    }
  }

  /**
   * Prepares the layer for a new sequence. Call this BEFORE feeding the first token.
   * Clears the BPTT history and zeroes the membrane potentials.
   * @param timeSteps Number of steps the upcoming sequence will contain.
   */
  public resetSequence(timeSteps: number) {
    this.maxTimeSteps = timeSteps;
    this.historyInputs = new Array(timeSteps);
    this.historyPotentials = new Array(timeSteps);
    this.historySpikes = new Array(timeSteps);

    if (this.potentials) this.potentials._data.fill(0);
  }

  protected compute(inputs: Matrix, options?: ForwardOptions): Matrix {
    throw new Error("[SpikingDenseBPTT] Harap gunakan computeStep(inputs, t) dan resetSequence(t) untuk model BPTT, jangan gunakan compute().");
  }

  /**
   * JS fallback for one LIF step: leaky integrate (clamped at 1.0), record the
   * pre-fire potential, then fire and soft-reset (subtract the threshold).
   */
  private lifStepFallback(
    potData: Float32Array,
    dotData: Float32Array,
    spikeData: Float32Array,
    potAtT: Float32Array,
    batch: number
  ): void {
    const units = this.units;
    for (let b = 0; b < batch; b++) {
      const offset = b * units;
      for (let i = 0; i < units; i++) {
        const idx = offset + i;
        potData[idx] = Math.min(potData[idx] * this.beta[i] + dotData[idx], 1.0);
        // Read back the float32-rounded value so the fire test sees exactly what is stored.
        const p = potData[idx];
        potAtT[idx] = p;
        if (p >= this.threshold[i]) {
          spikeData[idx] = 1;
          potData[idx] = p - this.threshold[i]; // soft reset
        } else {
          spikeData[idx] = 0;
        }
      }
    }
  }

  /**
   * Forward pass for the token at time step `t`.
   *
   * @param inputs Input matrix `[batch, inFeatures]` for this step; it is copied into the history.
   * @param t Zero-based step index, must be `< maxTimeSteps` set by `resetSequence`.
   * @returns The output spike matrix `[batch, units]`. It is a fresh matrix per step and shares
   *   storage with `historySpikes[t]`, so outputs of different steps never alias each other.
   * @throws Error when `t` is out of range, not an integer, or the layer was not built.
   */
  public computeStep(inputs: Matrix, t: number): Matrix {
    if (t >= this.maxTimeSteps) {
      throw new Error(`[SpikingDenseBPTT] Time step ${t} melebihi batas maxTimeSteps ${this.maxTimeSteps}`);
    }
    if (!Number.isInteger(t) || t < 0) {
      throw new Error(`[SpikingDenseBPTT] Time step ${t} tidak valid; harus bilangan bulat >= 0.`);
    }

    const kernel = this.requireKernel();
    const batch = inputs._shape[0];
    this.ensurePotentialsShape(batch);
    const size = batch * this.units;

    // 1. Snapshot the inputs: callers may reuse their buffer for the next token.
    this.historyInputs[t] = Matrix.fromFlat(new Float32Array(inputs._data), inputs._shape);

    // 2. Add-only spiking dot product (+ bias).
    const dot = dotProductAddOnly(inputs, kernel);
    if (this.useBias && this.bias) {
      mj.addBiasRow(dot, this.bias);
    }

    // 3. LIF step. `potAtT` holds the membrane potential after integration but before
    //    firing/reset; the surrogate gradient in the backward pass needs exactly this value.
    const spikeData = new Float32Array(size);
    const potAtT = new Float32Array(size);

    if (isNativeAvailable()) {
      // The 4th argument receives the pre-fire potentials.
      lifStepNativeWrapper(this.potentials._data, dot._data, spikeData, potAtT, this.beta, this.threshold);
    } else {
      this.lifStepFallback(this.potentials._data, dot._data, spikeData, potAtT, batch);
    }

    // 4. Record state for BPTT.
    this.historyPotentials[t] = Matrix.fromFlat(potAtT, [batch, this.units]);
    const spikes = Matrix.fromFlat(spikeData, [batch, this.units]);
    this.historySpikes[t] = spikes;

    return spikes;
  }

  /** JS fallback boxcar surrogate: zero the error wherever |v - threshold| is not within the window. */
  private maskSurrogateFallback(
    errors: Float32Array,
    potentials: Float32Array,
    windowSize: number,
    batch: number
  ): void {
    const units = this.units;
    for (let b = 0; b < batch; b++) {
      const offset = b * units;
      for (let i = 0; i < units; i++) {
        const idx = offset + i;
        // Negated `<=` (rather than `>`) so NaN potentials are masked out.
        if (!(Math.abs(potentials[idx] - this.threshold[i]) <= windowSize)) {
          errors[idx] = 0;
        }
      }
    }
  }

  /**
   * JS fallback add-only delta rule. Kernel rows whose input is active (> 0.5) receive
   * `lr * error`; bias receives `lr * error / batch`. Both are clamped to [-1, 1].
   */
  private applyAddOnlyDeltaFallback(
    kernel: Float32Array,
    biasData: Float32Array,
    inputData: Float32Array,
    errors: Float32Array,
    learningRate: number,
    batch: number,
    inFeatures: number,
    applyBias: boolean
  ): void {
    const units = this.units;
    for (let b = 0; b < batch; b++) {
      const inOffset = b * inFeatures;
      const errOffset = b * units;
      for (let k = 0; k < inFeatures; k++) {
        if (inputData[inOffset + k] > 0.5) {
          const kOffset = k * units;
          for (let j = 0; j < units; j++) {
            const w = kernel[kOffset + j] + learningRate * errors[errOffset + j];
            kernel[kOffset + j] = Math.max(-1.0, Math.min(1.0, w));
          }
        }
      }
      if (applyBias) {
        for (let j = 0; j < units; j++) {
          const v = biasData[j] + (learningRate * errors[errOffset + j]) / batch;
          biasData[j] = Math.max(-1.0, Math.min(1.0, v));
        }
      }
    }
  }

  /**
   * Backward pass over the whole recorded sequence. Call ONCE after the sequence finishes.
   *
   * Walks time backwards, combining the spatial error of step t with the temporal error
   * leaked back from step t+1 (scaled by `beta`), masks it with a boxcar surrogate around
   * the threshold, and applies the add-only delta rule to kernel and bias in place.
   *
   * @param errorSequence Per-step error matrices; must contain at least `maxTimeSteps` entries.
   * @param B Optional broadcast matrix for hidden layers (`error_t · B`); `undefined` for the output layer.
   * @param learningRate Step size for the delta rule.
   * @throws Error when no history exists, a step was never computed, or errors are missing.
   */
  public learnThroughTime(errorSequence: Matrix[], B: Matrix | undefined, learningRate: number = 0.01): void {
    const steps = this.maxTimeSteps;
    if (steps === 0 || !this.historyInputs[0]) {
      throw new Error("[SpikingDenseBPTT] Belum ada data di memory. Panggil computeStep(inputs, t) terlebih dahulu.");
    }
    if (errorSequence.length < steps) {
      throw new Error(`[SpikingDenseBPTT] errorSequence hanya berisi ${errorSequence.length} langkah, dibutuhkan ${steps}.`);
    }
    for (let t = 0; t < steps; t++) {
      if (!this.historyInputs[t] || !this.historyPotentials[t]) {
        throw new Error(`[SpikingDenseBPTT] Time step ${t} belum diproses oleh computeStep(); BPTT membutuhkan seluruh ${steps} langkah.`);
      }
    }

    const batch = errorSequence[0]._shape[0];
    const inFeatures = this.historyInputs[0]._shape[1];
    const units = this.units;
    const size = batch * units;
    const kernel = this.requireKernel()._data;
    const applyBias = this.useBias && !!this.bias;
    const biasData = applyBias ? this.bias!._data : new Float32Array(0);
    const windowSize = 1.0; // boxcar surrogate half-width
    const useNative = isNativeAvailable();

    // Error leaked from the future (t+1) back into step t; zero at the last step.
    const temporalErrorData = new Float32Array(size);
    // Reused across steps; the native mask/delta routines operate on it in place.
    const maskedErrorData = new Float32Array(size);

    for (let t = steps - 1; t >= 0; t--) {
      // Hidden layer: project the error through B. Output layer: use the loss error directly.
      const currentErrorData = B
        ? mj.dotProduct(errorSequence[t], B, undefined, false, false)._data
        : errorSequence[t]._data;
      const pData = this.historyPotentials[t]._data;
      const inputData = this.historyInputs[t]._data;

      // A. Combine spatial (layer above) and temporal (future step) error.
      for (let i = 0; i < size; i++) {
        maskedErrorData[i] = currentErrorData[i] + temporalErrorData[i];
      }

      // B. Surrogate mask + add-only weight update for step t.
      if (useNative) {
        maskSurrogateNativeWrapper(maskedErrorData, pData, this.threshold, windowSize);
        applyAddOnlyDeltaNativeWrapper(
          kernel,
          biasData,
          inputData,
          maskedErrorData,
          learningRate,
          batch,
          inFeatures,
          units,
          this.useBias
        );
      } else {
        this.maskSurrogateFallback(maskedErrorData, pData, windowSize, batch);
        this.applyAddOnlyDeltaFallback(kernel, biasData, inputData, maskedErrorData, learningRate, batch, inFeatures, applyBias);
      }

      // C. Temporal error for step t-1 travels through the leaky (beta) pathway.
      for (let b = 0; b < batch; b++) {
        const offset = b * units;
        for (let i = 0; i < units; i++) {
          temporalErrorData[offset + i] = maskedErrorData[offset + i] * this.beta[i];
        }
      }
    }
  }

  public getConfig(): Record<string, any> {
    return {
      ...super.getConfig(),
      units: this.units,
      useBias: this.useBias,
      kernelInitializer: this.kernelInitializer,
      biasInitializer: this.biasInitializer,
      thresholdRange: this.thresholdRange,
      // Only serialise betaRange when it was explicitly configured, so a round-trip keeps
      // the default bit-shift beta scheme instead of switching to uniform sampling.
      ...(this.hasCustomBetaRange ? { betaRange: this.betaRange } : {})
    };
  }
}
@@ -3,19 +3,25 @@ import { Matrix, mj } from "@oxide-js/core";
3
3
  import {
4
4
  isNativeAvailable,
5
5
  lifStepNativeWrapper,
6
- maskSurrogateNativeWrapper
6
+ maskSurrogateNativeWrapper,
7
+ applyEmbeddingDeltaNativeWrapper
7
8
  } from "../native_backend.js";
8
9
 
9
10
  export interface SpikingEmbeddingConfig extends LayerConfig {
10
11
  inputDim: number;
11
12
  outputDim: number;
12
13
  embeddingsInitializer?: string;
14
+ betaRange?: [number, number];
15
+ thresholdRange?: [number, number];
13
16
  }
14
17
 
15
18
  export class SpikingEmbedding extends BaseLayer {
16
19
  public inputDim: number;
17
20
  public outputDim: number;
18
21
  public embeddingsInitializer: string;
22
+
23
+ public betaRange: [number, number];
24
+ public thresholdRange: [number, number];
19
25
  public beta!: Float32Array;
20
26
  public threshold!: Float32Array;
21
27
 
@@ -33,6 +39,8 @@ export class SpikingEmbedding extends BaseLayer {
33
39
  this.inputDim = config.inputDim;
34
40
  this.outputDim = config.outputDim;
35
41
  this.embeddingsInitializer = config.embeddingsInitializer || "glorot_normal";
42
+ this.betaRange = config.betaRange || [0.8, 0.99];
43
+ this.thresholdRange = config.thresholdRange || [0.01, 0.1];
36
44
  }
37
45
 
38
46
  public computeOutputShape(inputShape: number[]): number[] {
@@ -44,29 +52,43 @@ export class SpikingEmbedding extends BaseLayer {
44
52
  super.build(inputShape);
45
53
 
46
54
  const embVal = this.createInitializer(this.embeddingsInitializer, [this.inputDim, this.outputDim]);
55
+
56
+ // Scale up the embedding values because Glorot Normal makes them too small for large vocabularies (e.g. 32000)
57
+ // which prevents the LIF neurons from ever reaching the threshold.
58
+ const scaleFactor = Math.sqrt(this.inputDim);
59
+ for (let i = 0; i < embVal._data.length; i++) {
60
+ embVal._data[i] *= scaleFactor;
61
+ }
62
+
47
63
  this.addParameter("embeddings", embVal, true, [this.inputDim, this.outputDim]);
48
64
 
49
65
  // Inisialisasi beta dan threshold secara acak untuk setiap neuron
50
66
  this.beta = new Float32Array(this.outputDim);
51
67
  this.threshold = new Float32Array(this.outputDim);
52
68
  for (let i = 0; i < this.outputDim; i++) {
53
- this.beta[i] = 0.8 + Math.random() * 0.19; // Random 0.8 - 0.99
54
- this.threshold[i] = 0.5 + Math.random() * 0.5; // Random 0.5 - 1.0 (Max 1.0)
69
+ this.beta[i] = this.betaRange[0] + Math.random() * (this.betaRange[1] - this.betaRange[0]);
70
+ this.threshold[i] = this.thresholdRange[0] + Math.random() * (this.thresholdRange[1] - this.thresholdRange[0]);
55
71
  }
56
72
 
57
73
  // Potentials start at 0, shape [batch, outputDim].
58
74
  this.potentials = Matrix.fromFlat(new Float32Array(this.outputDim), [1, this.outputDim]);
59
75
  }
60
76
 
77
+ private dotDataBuffer?: Float32Array;
78
+ private outDataBuffer?: Float32Array;
79
+
61
80
  private ensurePotentialsShape(batch: number) {
62
- if (this.potentials._shape[0] !== batch) {
81
+ if (this.potentials._shape[0] !== batch || !this.dotDataBuffer) {
63
82
  this.potentials = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
83
+ this.dotDataBuffer = new Float32Array(batch * this.outputDim);
84
+ this.outDataBuffer = new Float32Array(batch * this.outputDim);
85
+ this.lastPotentials = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
64
86
  }
65
87
  }
66
88
 
67
89
  public resetState() {
68
90
  if (this.potentials) this.potentials._data.fill(0);
69
- this.lastPotentials = undefined;
91
+ if (this.lastPotentials) this.lastPotentials._data.fill(0);
70
92
  this.lastInputs = undefined;
71
93
  this.lastSpikes = undefined;
72
94
  }
@@ -77,7 +99,8 @@ export class SpikingEmbedding extends BaseLayer {
77
99
 
78
100
  // 1. Embedding lookup
79
101
  const emb = this.embeddings!;
80
- const dotData = new Float32Array(batch * this.outputDim);
102
+ const dotData = this.dotDataBuffer!;
103
+ dotData.fill(0);
81
104
 
82
105
  for (let b = 0; b < batch; b++) {
83
106
  const tokenIdx = inputs._data[b];
@@ -90,23 +113,24 @@ export class SpikingEmbedding extends BaseLayer {
90
113
  }
91
114
  }
92
115
 
93
- // 2. Leaky Integrate and Fire
94
- const outData = new Float32Array(batch * this.outputDim);
116
+ // 2. Leaky Integrate and Fire (LIF Restore untuk Spiking Murni)
117
+ const outData = this.outDataBuffer!;
118
+ outData.fill(0);
95
119
  const outSpikes = Matrix.fromFlat(outData, [batch, this.outputDim]);
96
- this.lastPotentials = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
120
+ // lastPotentials is already ensured in shape
97
121
 
98
122
  if (isNativeAvailable()) {
99
123
  lifStepNativeWrapper(
100
124
  this.potentials._data,
101
125
  dotData,
102
126
  outData,
103
- this.lastPotentials._data,
127
+ this.lastPotentials!._data,
104
128
  this.beta,
105
129
  this.threshold
106
130
  );
107
131
  } else {
108
132
  const potData = this.potentials._data;
109
- const lpData = this.lastPotentials._data;
133
+ const lpData = this.lastPotentials!._data;
110
134
 
111
135
  for (let b = 0; b < batch; b++) {
112
136
  const offset = b * this.outputDim;
@@ -126,7 +150,7 @@ export class SpikingEmbedding extends BaseLayer {
126
150
  }
127
151
  }
128
152
  }
129
-
153
+
130
154
  this.lastInputs = inputs;
131
155
  this.lastSpikes = outSpikes;
132
156
 
@@ -0,0 +1,335 @@
1
+ import { BaseLayer, LayerConfig, ForwardOptions } from "@oxide-js/layers";
2
+ import { Matrix, mj } from "@oxide-js/core";
3
+ import { isNativeAvailable, lifStepNativeWrapper } from "../native_backend.js";
4
+ import dotProductAddOnly from "../math/dotProductAddOnly.js";
5
+
6
+ export interface SpikingSelfAttentionConfig extends LayerConfig {
7
+ d_model: number;
8
+ sequenceLength: number;
9
+ kernelInitializer?: string;
10
+ betaRange?: [number, number];
11
+ thresholdRange?: [number, number];
12
+ }
13
+
14
+ export class SpikingSelfAttention extends BaseLayer {
15
+ public d_model: number;
16
+ public sequenceLength: number;
17
+ public kernelInitializer: string;
18
+ public betaRange: [number, number];
19
+ public thresholdRange: [number, number];
20
+
21
+ // Q, K, V kernels
22
+ public get kernelQ(): Matrix | undefined { return this.getParameter("kernelQ"); }
23
+ public get kernelK(): Matrix | undefined { return this.getParameter("kernelK"); }
24
+ public get kernelV(): Matrix | undefined { return this.getParameter("kernelV"); }
25
+
26
+ // LIF state untuk Q, K, V (opsional, jika ingin akumulasi temporal)
27
+ public betaQKV!: Float32Array;
28
+ public thresholdQKV!: Float32Array;
29
+ public potentialsQ!: Matrix;
30
+ public potentialsK!: Matrix;
31
+ public potentialsV!: Matrix;
32
+
33
+ // LIF state untuk Attention Scores (Pengganti Softmax)
34
+ public betaScores!: Float32Array;
35
+ public thresholdScores!: Float32Array;
36
+ public potentialsScores!: Matrix;
37
+
38
+ // Cache input untuk Local Learning
39
+ public lastInputs?: Matrix;
40
+
41
+ constructor(config: SpikingSelfAttentionConfig) {
42
+ super(config);
43
+ this.d_model = config.d_model;
44
+ this.sequenceLength = config.sequenceLength;
45
+ this.kernelInitializer = config.kernelInitializer || "glorot_normal";
46
+ this.betaRange = config.betaRange || [0.8, 0.99];
47
+ this.thresholdRange = config.thresholdRange || [0.1, 0.3];
48
+ }
49
+
50
+ public computeOutputShape(inputShape: number[]): number[] {
51
+ const batch = inputShape[0] ?? -1;
52
+ // Asumsi input shape: [batch * seqLen, d_model]
53
+ return [batch, this.d_model]; // Actually [batch * seqLen, d_model]
54
+ }
55
+
56
+ public build(inputShape: number[]): void {
57
+ super.build(inputShape);
58
+
59
+ const inFeatures = inputShape[inputShape.length - 1]; // Seharusnya sama dengan d_model
60
+
61
+ // 1. Inisialisasi Bobot Q, K, V
62
+ this.addParameter("kernelQ", this.createInitializer(this.kernelInitializer, [inFeatures, this.d_model]), true, [inFeatures, this.d_model]);
63
+ this.addParameter("kernelK", this.createInitializer(this.kernelInitializer, [inFeatures, this.d_model]), true, [inFeatures, this.d_model]);
64
+ this.addParameter("kernelV", this.createInitializer(this.kernelInitializer, [inFeatures, this.d_model]), true, [inFeatures, this.d_model]);
65
+
66
+ // OPTIMIZATION: Scale up initial weights so neurons actually spike (prevent Layer 2 death)
67
+ const scale = Math.sqrt(inFeatures);
68
+ const kQ = this.kernelQ!._data;
69
+ const kK = this.kernelK!._data;
70
+ const kV = this.kernelV!._data;
71
+ for(let i = 0; i < kQ.length; i++) {
72
+ kQ[i] *= scale;
73
+ kK[i] *= scale;
74
+ kV[i] *= scale;
75
+ }
76
+
77
+ // 2. Inisialisasi parameter LIF untuk Q, K, V
78
+ this.betaQKV = new Float32Array(this.d_model);
79
+ this.thresholdQKV = new Float32Array(this.d_model);
80
+ for (let i = 0; i < this.d_model; i++) {
81
+ this.betaQKV[i] = this.betaRange[0] + Math.random() * (this.betaRange[1] - this.betaRange[0]);
82
+ this.thresholdQKV[i] = this.thresholdRange[0] + Math.random() * (this.thresholdRange[1] - this.thresholdRange[0]);
83
+ }
84
+
85
+ // 3. Inisialisasi parameter LIF untuk Attention Scores (Pengganti Softmax)
86
+ this.betaScores = new Float32Array(this.sequenceLength);
87
+ this.thresholdScores = new Float32Array(this.sequenceLength);
88
+ for (let i = 0; i < this.sequenceLength; i++) {
89
+ this.betaScores[i] = 0.9;
90
+ // Ambang batas diturunkan tajam untuk mencegah Dead Neurons
91
+ this.thresholdScores[i] = 1.0;
92
+ }
93
+
94
+ // Inisialisasi Potentials akan dilakukan secara dinamis pada saat forward
95
+ this.potentialsQ = Matrix.fromFlat(new Float32Array(0), [0, 0]);
96
+ this.potentialsK = Matrix.fromFlat(new Float32Array(0), [0, 0]);
97
+ this.potentialsV = Matrix.fromFlat(new Float32Array(0), [0, 0]);
98
+ this.potentialsScores = Matrix.fromFlat(new Float32Array(0), [0, 0]);
99
+ }
100
+ private sqDataBuffer?: Float32Array;
101
+ private skDataBuffer?: Float32Array;
102
+ private svDataBuffer?: Float32Array;
103
+ private dummyLpBuffer?: Float32Array;
104
+ private matchScoresBuffer?: Float32Array;
105
+ private qGatedVBuffer?: Float32Array;
106
+ private outSpikesBuffer?: Float32Array;
107
+ private sScoresDataBuffer?: Float32Array;
108
+ private dummyLpScoresBuffer?: Float32Array;
109
+ private tempMatchesBuffer?: Float32Array;
110
+
111
+ private ensurePotentialsShape(batchSeq: number, seqLen: number) {
112
+ if (this.potentialsQ._shape[0] !== batchSeq || !this.sqDataBuffer) {
113
+ this.potentialsQ = Matrix.fromFlat(new Float32Array(batchSeq * this.d_model), [batchSeq, this.d_model]);
114
+ this.potentialsK = Matrix.fromFlat(new Float32Array(batchSeq * this.d_model), [batchSeq, this.d_model]);
115
+ this.potentialsV = Matrix.fromFlat(new Float32Array(batchSeq * this.d_model), [batchSeq, this.d_model]);
116
+ this.potentialsScores = Matrix.fromFlat(new Float32Array(batchSeq * seqLen), [batchSeq, seqLen]);
117
+
118
+ this.sqDataBuffer = new Float32Array(batchSeq * this.d_model);
119
+ this.skDataBuffer = new Float32Array(batchSeq * this.d_model);
120
+ this.svDataBuffer = new Float32Array(batchSeq * this.d_model);
121
+ this.dummyLpBuffer = new Float32Array(batchSeq * this.d_model);
122
+ this.matchScoresBuffer = new Float32Array(batchSeq * seqLen);
123
+ this.qGatedVBuffer = new Float32Array(batchSeq * this.d_model);
124
+ this.outSpikesBuffer = new Float32Array(batchSeq * this.d_model);
125
+ this.sScoresDataBuffer = new Float32Array(batchSeq * seqLen);
126
+ this.dummyLpScoresBuffer = new Float32Array(batchSeq * seqLen);
127
+ this.tempMatchesBuffer = new Float32Array(seqLen);
128
+ }
129
+ }
130
+
131
+ public resetState() {
132
+ if (this.potentialsQ) this.potentialsQ._data.fill(0);
133
+ if (this.potentialsK) this.potentialsK._data.fill(0);
134
+ if (this.potentialsV) this.potentialsV._data.fill(0);
135
+ if (this.potentialsScores) this.potentialsScores._data.fill(0);
136
+ }
137
+
138
/**
 * Forward pass of spiking self-attention over a flattened batch of sequences.
 *
 * `inputs` is expected to be `[batch * sequenceLength, d_model]`, rows grouped per sequence.
 *
 * Steps:
 *  1. Project inputs to Q/K/V with `dotProductAddOnly`.
 *  2. Run each projection through an LIF step to get binary spike maps S_Q, S_K, S_V.
 *  3. For every query token i of a sequence, count the dimensions where S_Q[i] and S_K[j]
 *     both spike, and divide each row by its maximum count. This gives graded scores in [0, 1].
 *  4. Feed the graded scores through a second LIF population (softmax replacement).
 *     This updates `potentialsScores`.
 *  5. Sum S_V rows weighted by the graded scores, then clamp every value to at most 1.0.
 *
 * Values are clamped but NOT binarised: fractional graded scores can yield outputs in (0, 1).
 *
 * NOTE(review): the spikes produced in step 4 (`sScoresDataBuffer`) are not read by step 5,
 * which uses the graded scores from step 3 directly. Confirm this is intended.
 * NOTE(review): the returned matrix is built from `outSpikesBuffer`, which the next call
 * overwrites. Whether `Matrix.fromFlat` copies depends on @oxide-js/core; verify.
 *
 * @throws Error when the number of input rows is not a multiple of `sequenceLength`.
 */
protected compute(inputs: Matrix, options?: ForwardOptions): Matrix {
  // Inputs are assumed to be flat: [batch * seqLen, d_model].
  const batchSeq = inputs._shape[0];
  const seqLen = this.sequenceLength;
  const batch = batchSeq / seqLen;
  const d_model = this.d_model;

  if (!Number.isInteger(batch)) {
    throw new Error(`[SpikingSelfAttention] Jumlah baris input (${batchSeq}) harus merupakan kelipatan dari sequenceLength (${seqLen}).`);
  }

  this.ensurePotentialsShape(batchSeq, seqLen);
  this.lastInputs = inputs; // kept for the local learning rule (learnAttention)

  // 1. Q, K, V projections (addition only, because the inputs are binary spikes).
  const dotQ = dotProductAddOnly(inputs, this.kernelQ!);
  const dotK = dotProductAddOnly(inputs, this.kernelK!);
  const dotV = dotProductAddOnly(inputs, this.kernelV!);

  // 2. LIF step producing the binary spike maps S_Q, S_K, S_V.
  const sqData = this.sqDataBuffer!;
  sqData.fill(0);
  const skData = this.skDataBuffer!;
  skData.fill(0);
  const svData = this.svDataBuffer!;
  svData.fill(0);
  const dummyLp = this.dummyLpBuffer!;
  dummyLp.fill(0);

  if (isNativeAvailable()) {
    lifStepNativeWrapper(this.potentialsQ._data, dotQ._data, sqData, dummyLp, this.betaQKV, this.thresholdQKV);
    lifStepNativeWrapper(this.potentialsK._data, dotK._data, skData, dummyLp, this.betaQKV, this.thresholdQKV);
    lifStepNativeWrapper(this.potentialsV._data, dotV._data, svData, dummyLp, this.betaQKV, this.thresholdQKV);
  } else {
    this.runLIF(this.potentialsQ._data, dotQ._data, sqData, batchSeq, d_model, this.betaQKV, this.thresholdQKV);
    this.runLIF(this.potentialsK._data, dotK._data, skData, batchSeq, d_model, this.betaQKV, this.thresholdQKV);
    this.runLIF(this.potentialsV._data, dotV._data, svData, batchSeq, d_model, this.betaQKV, this.thresholdQKV);
  }

  // 3. Match scores (S_Q · S_K^T via counting co-active dimensions), normalised per query row.
  //    Layout: [batch * seqLen, seqLen].
  const matchScores = this.matchScoresBuffer!;
  matchScores.fill(0);
  const tempMatches = this.tempMatchesBuffer!;
  // Reusable scratch list of spiking dimensions, so no array is allocated per token.
  const activeDims: number[] = [];

  for (let b = 0; b < batch; b++) {
    const seqBase = b * seqLen * d_model;
    const scoreBase = b * seqLen * seqLen;
    for (let i = 0; i < seqLen; i++) {
      const qBase = seqBase + i * d_model;
      activeDims.length = 0;
      for (let d = 0; d < d_model; d++) {
        if (sqData[qBase + d] > 0) activeDims.push(d);
      }
      if (activeDims.length === 0) continue; // silent query: row stays 0

      // Every slot 0..seqLen-1 is overwritten below, so no clearing is needed.
      let maxMatch = 0;
      for (let j = 0; j < seqLen; j++) {
        const kBase = seqBase + j * d_model;
        let matchCount = 0;
        for (const d of activeDims) {
          if (skData[kBase + d] > 0) matchCount++;
        }
        tempMatches[j] = matchCount;
        if (matchCount > maxMatch) maxMatch = matchCount;
      }

      // No key matched: keep the zeros from fill(0); dividing would give 0 / 0 = NaN.
      if (maxMatch === 0) continue;
      const rowBase = scoreBase + i * seqLen;
      for (let j = 0; j < seqLen; j++) {
        matchScores[rowBase + j] = tempMatches[j] / maxMatch;
      }
    }
  }

  // 4. Softmax replacement: pass the match scores through an LIF population.
  const sScoresData = this.sScoresDataBuffer!;
  sScoresData.fill(0);
  const dummyLpScores = this.dummyLpScoresBuffer!;
  dummyLpScores.fill(0);

  if (isNativeAvailable()) {
    lifStepNativeWrapper(this.potentialsScores._data, matchScores, sScoresData, dummyLpScores, this.betaScores, this.thresholdScores);
  } else {
    this.runLIF(this.potentialsScores._data, matchScores, sScoresData, batchSeq, seqLen, this.betaScores, this.thresholdScores);
  }

  // 5. Aggregate S_V weighted by the graded match scores.
  const outData = this.outSpikesBuffer!;
  outData.fill(0);

  for (let b = 0; b < batch; b++) {
    const seqBase = b * seqLen * d_model;
    const scoreBase = b * seqLen * seqLen;
    for (let j = 0; j < seqLen; j++) {
      const vBase = seqBase + j * d_model;
      activeDims.length = 0;
      for (let d = 0; d < d_model; d++) {
        if (svData[vBase + d] > 0) activeDims.push(d);
      }
      if (activeDims.length === 0) continue; // silent value token contributes nothing

      for (let i = 0; i < seqLen; i++) {
        const gradedScore = matchScores[scoreBase + i * seqLen + j];
        if (gradedScore > 0) {
          const outBase = seqBase + i * d_model;
          for (const d of activeDims) {
            outData[outBase + d] += gradedScore * svData[vBase + d];
          }
        }
      }
    }
  }

  // Clamp to at most 1.0. This does not binarise: values in (0, 1) are kept as-is.
  for (let k = 0; k < outData.length; k++) {
    if (outData[k] > 1.0) outData[k] = 1.0;
  }

  return Matrix.fromFlat(outData, [batchSeq, d_model]);
}
265
+
266
/**
 * Pure-JS leaky integrate-and-fire step, used when the native backend is unavailable.
 *
 * Processes `batch` rows of `dim` neurons stored row-major in `pot`, `input` and `output`.
 * `beta` and `threshold` are per-neuron (indexed by column).
 *  - Integration: pot = min(pot * beta + input, 1.0)
 *  - Firing: if pot >= threshold, output 1.0 and subtract threshold from pot; otherwise output 0.0.
 */
private runLIF(pot: Float32Array, input: Float32Array, output: Float32Array, batch: number, dim: number, beta: Float32Array, threshold: Float32Array) {
  for (let row = 0, base = 0; row < batch; row++, base += dim) {
    // Phase 1: leaky integration (capped at 1.0).
    for (let n = 0; n < dim; n++) {
      const at = base + n;
      pot[at] = Math.min((pot[at] * beta[n]) + input[at], 1.0);
    }
    // Phase 2: fire and soft-reset.
    for (let n = 0; n < dim; n++) {
      const at = base + n;
      const fired = pot[at] >= threshold[n];
      output[at] = fired ? 1.0 : 0.0;
      if (fired) {
        pot[at] -= threshold[n];
      }
    }
  }
}
284
+
285
/**
 * Local (non-backprop) update of the Q, K and V projection kernels.
 *
 * For every input row b and every input feature i that spiked in the last forward
 * pass (value > 0), the same delta `learningRate * err[b, d] * inVal + dopamine` is
 * added to weight (i, d) of kernelQ, kernelK and kernelV. Each weight is then clipped
 * to [-1, 1]. The small constant "dopamine" term nudges weights of active inputs
 * upward even when the error is zero.
 *
 * @param errorSignal error for each output row; its flat data is read as [batchSeq, d_model].
 * @param learningRate step size (default 0.01).
 * @throws Error if forward() has not run yet.
 * @throws Error if errorSignal holds fewer than batchSeq * d_model values. Without this
 *         check the missing entries read as `undefined` and write NaN into the kernels.
 */
public learnAttention(errorSignal: Matrix, learningRate: number = 0.01) {
  if (!this.lastInputs) {
    throw new Error("[SpikingSelfAttention] Cannot run learning before forward() is executed.");
  }

  const err = errorSignal._data;
  const inputs = this.lastInputs._data;
  const batchSeq = this.lastInputs._shape[0];
  // The inputs arrive after the first layer, so their shape is [batchSeq, inFeatures].
  const inFeatures = this.lastInputs._shape[1] || this.d_model;
  const d_model = this.d_model;

  const expectedErr = batchSeq * d_model;
  if (err.length < expectedErr) {
    throw new Error(`[SpikingSelfAttention] errorSignal must contain at least ${expectedErr} values ([${batchSeq}, ${d_model}]), got ${err.length}.`);
  }

  // Very small constant drive that revives dead neurons without over-saturating them.
  const dopamine = 0.00005;

  // The error is distributed equally to Q, K and V (Hebbian / surrogate style), because
  // the spiking attention path is not differentiable in a simple way.
  const kQ = this.kernelQ!._data;
  const kK = this.kernelK!._data;
  const kV = this.kernelV!._data;

  for (let b = 0; b < batchSeq; b++) {
    const errOffset = b * d_model;
    const inOffset = b * inFeatures;
    for (let i = 0; i < inFeatures; i++) {
      const inVal = inputs[inOffset + i];
      if (!(inVal > 0)) continue; // sparse update: only inputs that spiked

      const kOffset = i * d_model;
      for (let d = 0; d < d_model; d++) {
        // Q, K and V receive one identical delta.
        const delta = (learningRate * err[errOffset + d] * inVal) + dopamine;
        const w = kOffset + d;
        kQ[w] = Math.max(-1.0, Math.min(1.0, kQ[w] + delta));
        kK[w] = Math.max(-1.0, Math.min(1.0, kK[w] + delta));
        kV[w] = Math.max(-1.0, Math.min(1.0, kV[w] + delta));
      }
    }
  }
}
326
+
327
/**
 * Serialisable layer configuration: the base-class config extended with this
 * layer's `d_model`, `sequenceLength` and `kernelInitializer`.
 */
public getConfig(): Record<string, any> {
  const baseConfig = super.getConfig();
  const ownConfig = {
    d_model: this.d_model,
    sequenceLength: this.sequenceLength,
    kernelInitializer: this.kernelInitializer,
  };
  return { ...baseConfig, ...ownConfig };
}
335
+ }
@@ -107,3 +107,20 @@ export const applyEmbeddingDeltaNativeWrapper = (
107
107
  outputDim
108
108
  );
109
109
  };
110
+
111
/**
 * Calls the native `contrastiveHebbianNative` kernel.
 *
 * `errData` is written in place by the native code. The returned number is the
 * loss the kernel reports.
 *
 * @throws Error when the native backend is not available.
 */
export const contrastiveHebbianNativeWrapper = (
  spikes: Float32Array,
  errData: Float32Array,
  numPairs: number,
  sequenceLength: number,
  dModel: number
): number => {
  if (native) {
    return native.contrastiveHebbianNative(spikes, errData, numPairs, sequenceLength, dModel);
  }
  throw new Error("Spiking Native backend not available");
};
@@ -0,0 +1,85 @@
1
+ use napi_derive::napi;
2
+ use napi::bindgen_prelude::Float32Array;
3
+ use rayon::prelude::*;
4
+
5
+ #[napi]
6
+ pub fn contrastive_hebbian_native(
7
+ spikes: Float32Array,
8
+ mut err_data: Float32Array,
9
+ num_pairs: u32,
10
+ sequence_length: u32,
11
+ d_model: u32,
12
+ ) -> f64 {
13
+ let spikes_slice: &[f32] = &spikes;
14
+ let err_slice: &mut [f32] = &mut err_data;
15
+
16
+ let num_pairs = num_pairs as usize;
17
+ let seq_len = sequence_length as usize;
18
+ let d_model = d_model as usize;
19
+ let chunk_size = seq_len * d_model;
20
+
21
+ let total_loss: f32 = err_slice.par_chunks_mut(chunk_size).enumerate().map(|(b, chunk)| {
22
+ let mut local_loss = 0.0f32;
23
+
24
+ if b < num_pairs {
25
+ // Ini adalah vektor Q
26
+ let i = b;
27
+ let p_offset = (num_pairs + i) * chunk_size;
28
+ let n_offset = (num_pairs + ((i + 1) % num_pairs)) * chunk_size;
29
+ let q_offset = i * chunk_size;
30
+
31
+ for rem in 0..chunk_size {
32
+ let q_spike = spikes_slice[q_offset + rem];
33
+ let p_spike = spikes_slice[p_offset + rem];
34
+ let n_spike = spikes_slice[n_offset + rem];
35
+
36
+ let mut pull = p_spike - q_spike;
37
+ if q_spike == 0.0 && p_spike == 0.0 && n_spike == 0.0 {
38
+ pull = 0.05; // Suntik energi
39
+ }
40
+ let push = (q_spike * n_spike) * 0.2;
41
+
42
+ chunk[rem] = pull - push;
43
+
44
+ if pull != 0.0 || push != 0.0 {
45
+ local_loss += pull.abs() + push;
46
+ }
47
+ }
48
+ } else {
49
+ // Ini adalah vektor P atau N
50
+ let p_index = b - num_pairs;
51
+
52
+ // Peran sebagai P untuk i = p_index
53
+ let q_offset_p = p_index * chunk_size;
54
+ let n_offset_p = (num_pairs + ((p_index + 1) % num_pairs)) * chunk_size;
55
+
56
+ // Peran sebagai N untuk i = p_index - 1 (dengan wrap around)
57
+ let i_n = if p_index == 0 { num_pairs - 1 } else { p_index - 1 };
58
+ let q_offset_n = i_n * chunk_size;
59
+
60
+ for rem in 0..chunk_size {
61
+ let q_spike_p = spikes_slice[q_offset_p + rem];
62
+ let p_spike_p = spikes_slice[b * chunk_size + rem];
63
+ let n_spike_p = spikes_slice[n_offset_p + rem];
64
+
65
+ let mut pull_p = p_spike_p - q_spike_p;
66
+ if q_spike_p == 0.0 && p_spike_p == 0.0 && n_spike_p == 0.0 {
67
+ pull_p = 0.05;
68
+ }
69
+ let contrib_p = -pull_p;
70
+
71
+ let q_spike_n = spikes_slice[q_offset_n + rem];
72
+ let n_spike_n = spikes_slice[b * chunk_size + rem];
73
+
74
+ let push_n = (q_spike_n * n_spike_n) * 0.2;
75
+ let contrib_n = -push_n;
76
+
77
+ chunk[rem] = contrib_p + contrib_n;
78
+ }
79
+ }
80
+
81
+ local_loss
82
+ }).sum();
83
+
84
+ total_loss as f64
85
+ }
@@ -8,9 +8,11 @@ mod lif;
8
8
  mod surrogate;
9
9
  mod delta;
10
10
  mod embedding;
11
+ mod contrastive;
11
12
 
12
13
  pub use dot_product::*;
13
14
  pub use lif::*;
14
15
  pub use surrogate::*;
15
16
  pub use delta::*;
16
17
  pub use embedding::*;
18
+ pub use contrastive::*;
@@ -0,0 +1,151 @@
1
+ import { describe, it, expect, beforeEach } from 'vitest';
2
+ import { Matrix } from '@oxide-js/core';
3
+ import { SpikingDenseBPTT } from '../src/layers/SpikingDenseBPTT.js';
4
+
5
describe('SpikingDenseBPTT Layer', () => {
  const units = 4;
  const inFeatures = 3;
  const batch = 2;
  const maxTimeSteps = 3;

  let layer: SpikingDenseBPTT;

  beforeEach(() => {
    layer = new SpikingDenseBPTT({
      units,
      kernelInitializer: 'ones',
      useBias: false
    });
    layer.build([batch, inFeatures]);
  });

  it('should initialize parameters correctly', () => {
    expect(layer.units).toBe(units);
    expect(layer.kernel).toBeDefined();

    // Check that the dynamic beta initialisation (bit-shift decay stored as float) works.
    expect(layer.beta).toBeDefined();
    expect(layer.beta.length).toBe(units);

    for (let i = 0; i < units; i++) {
      // beta is pre-computed as the multiplier 1.0 - 1.0 / 2^shift,
      // with shift in 2..5 (decay of 1/4 down to 1/32),
      // so valid values lie in [0.75, 0.96875].
      expect(layer.beta[i]).toBeGreaterThanOrEqual(0.75);
      expect(layer.beta[i]).toBeLessThanOrEqual(0.96875);
    }
  });

  it('should throw error when calling compute() directly', () => {
    const inputs = Matrix.fromFlat(new Float32Array(batch * inFeatures), [batch, inFeatures]);
    expect(() => {
      // @ts-ignore -- compute() is not public API; it is called directly on purpose
      layer.compute(inputs);
    }).toThrowError(/Harap gunakan computeStep/); // the layer's message: "please use computeStep"
  });

  it('should process sequence, enforce BPTT limits, and store history correctly', () => {
    layer.resetSequence(maxTimeSteps);

    expect(layer.maxTimeSteps).toBe(maxTimeSteps);
    expect(layer.historyInputs.length).toBe(maxTimeSteps);

    // Dummy binary spike input.
    const inputData = new Float32Array(batch * inFeatures).fill(1);
    const inputs = Matrix.fromFlat(inputData, [batch, inFeatures]);

    // Time step 0
    const out0 = layer.computeStep(inputs, 0);
    expect(out0._shape).toEqual([batch, units]);
    expect(layer.historyInputs[0]).toBeDefined();
    expect(layer.historyPotentials[0]).toBeDefined();
    expect(layer.historySpikes[0]).toBeDefined();

    // Time step 1
    const out1 = layer.computeStep(inputs, 1);
    expect(out1._shape).toEqual([batch, units]);
    expect(layer.historyInputs[1]).toBeDefined();

    // Time step 2
    const out2 = layer.computeStep(inputs, 2);
    expect(out2._shape).toEqual([batch, units]);

    // Time step 3 exceeds maxTimeSteps and must throw.
    expect(() => {
      layer.computeStep(inputs, 3);
    }).toThrowError(/melebihi batas maxTimeSteps/);
  });

  it('should run learnThroughTime properly without crashing', () => {
    layer.resetSequence(maxTimeSteps);

    const inputData = new Float32Array(batch * inFeatures).fill(1);
    const inputs = Matrix.fromFlat(inputData, [batch, inFeatures]);

    // Run the whole sequence.
    for (let t = 0; t < maxTimeSteps; t++) {
      layer.computeStep(inputs, t);
    }

    // Fake per-time-step errors for t = 0, 1, 2.
    const errors: Matrix[] = [];
    for (let t = 0; t < maxTimeSteps; t++) {
      errors.push(Matrix.fromFlat(new Float32Array(batch * units).fill(0.1), [batch, units]));
    }

    // BPTT as an output layer (feedback matrix B = undefined).
    expect(() => {
      layer.learnThroughTime(errors, undefined, 0.01);
    }).not.toThrow();

    // BPTT as a hidden layer with a feedback matrix B.
    // Note: B is an all-ones [units, units] matrix, not an identity matrix.
    const B = Matrix.fromFlat(new Float32Array(units * units).fill(1), [units, units]);
    expect(() => {
      layer.learnThroughTime(errors, B, 0.01);
    }).not.toThrow();
  });

  it('should correctly accumulate potentials and trigger spikes deterministically over time', () => {
    // Build a deterministic layer.
    const testLayer = new SpikingDenseBPTT({
      units: 2,
      useBias: false,
      kernelInitializer: 'zeros'
    });
    testLayer.build([1, 2]); // batch = 1, inFeatures = 2

    // Manual kernel: [[0.6, 0.0], [0.0, 0.8]]
    testLayer.kernel!._data.set([0.6, 0.0, 0.0, 0.8]);

    // Constant beta (0.5, easy to compute by hand) and threshold (1.0).
    testLayer.beta.fill(0.5);
    testLayer.threshold.fill(1.0);

    // Prepare a 3-time-step sequence.
    testLayer.resetSequence(3);

    // The input fires on every time step: [1, 1].
    const inputs = Matrix.fromFlat(new Float32Array([1, 1]), [1, 2]);

    // The expected values below imply that potentials are capped at 1.0 and reduced by
    // the threshold after a spike:
    //   t0: [0.6, 0.8]                      -> no spike
    //   t1: [0.3 + 0.6, 0.4 + 0.8 -> 1.0]   -> unit 1 spikes (reset to 0)
    //   t2: [0.45 + 0.6 -> 1.0, 0 + 0.8]    -> unit 0 spikes

    // --- TIME STEP 0 ---
    const out0 = testLayer.computeStep(inputs, 0);
    expect(out0._data).toEqual(new Float32Array([0, 0]));

    // The history records the pre-spike potentials [0.6, 0.8].
    expect(testLayer.historyPotentials[0]._data[0]).toBeCloseTo(0.6, 5);
    expect(testLayer.historyPotentials[0]._data[1]).toBeCloseTo(0.8, 5);

    // --- TIME STEP 1 ---
    const out1 = testLayer.computeStep(inputs, 1);
    expect(out1._data).toEqual(new Float32Array([0, 1]));

    expect(testLayer.historyPotentials[1]._data[0]).toBeCloseTo(0.9, 5);
    expect(testLayer.historyPotentials[1]._data[1]).toBeCloseTo(1.0, 5);

    // --- TIME STEP 2 ---
    const out2 = testLayer.computeStep(inputs, 2);
    expect(out2._data).toEqual(new Float32Array([1, 0]));

    expect(testLayer.historyPotentials[2]._data[0]).toBeCloseTo(1.0, 5);
    expect(testLayer.historyPotentials[2]._data[1]).toBeCloseTo(0.8, 5);
  });
});
@@ -0,0 +1,148 @@
1
+ import { describe, it, expect, beforeEach } from 'vitest';
2
+ import { Matrix } from '@oxide-js/core';
3
+ import { SpikingSelfAttention } from '../src/layers/SpikingSelfAttention.js';
4
+
5
describe('SpikingSelfAttention Layer', () => {
  const d_model = 4;
  const sequenceLength = 3;
  const batch = 2;
  const batchSeq = batch * sequenceLength;

  let attentionLayer: SpikingSelfAttention;

  beforeEach(() => {
    attentionLayer = new SpikingSelfAttention({
      d_model,
      sequenceLength,
      kernelInitializer: 'ones' // static initial weights for predictable behaviour
    });
    attentionLayer.build([batchSeq, d_model]);
  });

  it('should initialize parameters correctly', () => {
    expect(attentionLayer.d_model).toBe(d_model);
    expect(attentionLayer.sequenceLength).toBe(sequenceLength);
    expect(attentionLayer.kernelQ).toBeDefined();
    expect(attentionLayer.kernelK).toBeDefined();
    expect(attentionLayer.kernelV).toBeDefined();

    // Potentials start empty; they are sized on the first forward pass.
    expect(attentionLayer.potentialsQ._data.length).toBe(0);
    expect(attentionLayer.potentialsScores._data.length).toBe(0);
  });

  it('should compute output with the correct shape and binary spike format', () => {
    // Deterministic binary (0/1) input pattern, so any failure is reproducible.
    const inputData = Float32Array.from({ length: batchSeq * d_model }, (_, i) => (i % 3 === 0 ? 1 : 0));
    const inputs = Matrix.fromFlat(inputData, [batchSeq, d_model]);

    const output = attentionLayer.forward(inputs) as Matrix;

    // Expected output shape: [batch * seqLen, d_model].
    expect(output._shape[0]).toBe(batchSeq);
    expect(output._shape[1]).toBe(d_model);

    // The output must only contain binary spike values (0.0 or 1.0).
    const outputData = output._data;
    for (let i = 0; i < outputData.length; i++) {
      expect([0, 1]).toContain(outputData[i]);
    }

    // The potential state must now be sized to match the input.
    expect(attentionLayer.potentialsQ._data.length).toBe(batchSeq * d_model);
    expect(attentionLayer.potentialsScores._data.length).toBe(batchSeq * sequenceLength);
  });

  it('should properly accumulate potentials in sequential steps', () => {
    const inputData = new Float32Array(batchSeq * d_model).fill(1);
    const inputs = Matrix.fromFlat(inputData, [batchSeq, d_model]);

    // First time step (forward pass).
    attentionLayer.forward(inputs);

    // Snapshot potentialsQ to compare against the next step.
    const firstStepPotentials = new Float32Array(attentionLayer.potentialsQ._data);

    attentionLayer.forward(inputs);

    // With constant input the potentials either grow or are reset after a spike.
    // At minimum the state must keep its shape and stay finite across steps.
    const secondStepPotentials = attentionLayer.potentialsQ._data;
    expect(secondStepPotentials.length).toBe(firstStepPotentials.length);
    expect(Array.from(secondStepPotentials).every(Number.isFinite)).toBe(true);

    // resetState() must bring the potentials back to 0.
    attentionLayer.resetState();
    const zeros = new Float32Array(attentionLayer.potentialsQ._data.length).fill(0);
    expect(attentionLayer.potentialsQ._data).toEqual(zeros);
  });

  it('should throw error if input batch length is not multiple of sequence length', () => {
    const invalidBatchSeq = 7; // not a multiple of sequenceLength (3)
    const inputData = new Float32Array(invalidBatchSeq * d_model).fill(1);
    const inputs = Matrix.fromFlat(inputData, [invalidBatchSeq, d_model]);

    expect(() => {
      attentionLayer.forward(inputs);
    }).toThrowError(/Jumlah baris input/);
  });

  it('should correctly compute exact Self-Attention math (Deterministic Correctness)', () => {
    // Fully deterministic set-up.
    const testLayer = new SpikingSelfAttention({
      d_model: 2,
      sequenceLength: 2,
      kernelInitializer: 'zeros' // overridden manually below
    });

    testLayer.build([2, 2]); // batchSeq = 2, d_model = 2

    // Replace the weights with identity matrices.
    const identity = new Float32Array([1, 0, 0, 1]);
    testLayer.kernelQ!._data.set(identity);
    testLayer.kernelK!._data.set(identity);
    testLayer.kernelV!._data.set(identity);

    // LIF settings: beta = 0 ignores the previous potential, threshold 0.5 fires on an input of 1.
    testLayer.betaQKV.fill(0.0);
    testLayer.thresholdQKV.fill(0.5);

    testLayer.betaScores.fill(0.0);
    testLayer.thresholdScores.fill(0.5); // low threshold so a score of 1 fires immediately

    // Input: token 0 = [1, 0], token 1 = [0, 1].
    const inputs = Matrix.fromFlat(new Float32Array([1, 0, 0, 1]), [2, 2]);

    const output = testLayer.forward(inputs) as Matrix;

    // Expected result:
    // S_Q, S_K and S_V equal the input (identity kernels, threshold below 1.0).
    // S_Q · S_K^T:
    //   token 0 · token 0 = 1, token 0 · token 1 = 0
    //   token 1 · token 0 = 0, token 1 · token 1 = 1
    // The scores form the identity matrix [1, 0, 0, 1],
    // so scores × S_V = [1, 0, 0, 1].

    expect(output._shape).toEqual([2, 2]);
    expect(output._data).toEqual(new Float32Array([1, 0, 0, 1]));

    // Identical tokens: token 0 = [1, 0], token 1 = [1, 0].
    // beta = 0 means the state left by the previous call does not affect this step.
    const inputsSame = Matrix.fromFlat(new Float32Array([1, 0, 1, 0]), [2, 2]);
    const outputSame = testLayer.forward(inputsSame) as Matrix;

    // S_Q · S_K^T is 1 for every (i, j) pair, so every score is 1.
    // S_V = [1, 0, 1, 0].
    // Out token 0 = score[0][0] * S_V[0] + score[0][1] * S_V[1] = [2, 0], clamped to [1, 0].
    // Out token 1 = [1, 0] likewise.
    expect(outputSame._data).toEqual(new Float32Array([1, 0, 1, 0]));
  });
});