@oxide-js/spiking 1.2.0 → 1.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +8 -0
- package/index.d.ts +1 -0
- package/index.js +2 -1
- package/package.json +1 -1
- package/spiking-native.linux-x64-gnu.node +0 -0
- package/src/index.ts +3 -0
- package/src/layers/SpikingDense.ts +18 -8
- package/src/layers/SpikingDenseBPTT.ts +303 -0
- package/src/layers/SpikingEmbedding.ts +36 -12
- package/src/layers/SpikingSelfAttention.ts +335 -0
- package/src/native_backend.ts +17 -0
- package/src-rust/src/contrastive.rs +85 -0
- package/src-rust/src/lib.rs +2 -0
- package/test/SpikingDenseBPTT.test.ts +151 -0
- package/test/SpikingSelfAttention.test.ts +148 -0
package/CHANGELOG.md
CHANGED
|
@@ -1,5 +1,13 @@
|
|
|
1
1
|
# @oxide-js/spiking
|
|
2
2
|
|
|
3
|
+
## 1.3.0
|
|
4
|
+
|
|
5
|
+
### Minor Changes
|
|
6
|
+
|
|
7
|
+
- 1f6489c: Added `SpikingDenseBPTT` temporal pooler with Sequence-as-Time dynamics.
|
|
8
|
+
- Implemented Spike-Count Accumulation with L2 Normalization to solve membrane saturation in long sequences.
|
|
9
|
+
- Updated documentation for SpikingDenseBPTT in API layers.
|
|
10
|
+
|
|
3
11
|
## 1.2.0
|
|
4
12
|
|
|
5
13
|
### Minor Changes
|
package/index.d.ts
CHANGED
|
@@ -8,3 +8,4 @@ export declare function lifStepNative(potentials: Float32Array, dot: Float32Arra
|
|
|
8
8
|
export declare function maskSurrogateNative(errorSignal: Float32Array, potentials: Float32Array, threshold: Float32Array, windowSize: number): void
|
|
9
9
|
export declare function applyAddOnlyDeltaNative(kernel: Float32Array, bias: Float32Array, inputs: Float32Array, errorSignal: Float32Array, learningRate: number, batch: number, inFeatures: number, units: number, useBias: boolean): void
|
|
10
10
|
export declare function applyEmbeddingDeltaNative(embeddings: Float32Array, inputs: Float32Array, errorSignal: Float32Array, learningRate: number, inputDim: number, outputDim: number): void
|
|
11
|
+
export declare function contrastiveHebbianNative(spikes: Float32Array, errData: Float32Array, numPairs: number, sequenceLength: number, dModel: number): number
|
package/index.js
CHANGED
|
@@ -310,10 +310,11 @@ if (!nativeBinding) {
|
|
|
310
310
|
throw new Error(`Failed to load native binding`)
|
|
311
311
|
}
|
|
312
312
|
|
|
313
|
-
const { dotProductAddOnlyNative, lifStepNative, maskSurrogateNative, applyAddOnlyDeltaNative, applyEmbeddingDeltaNative } = nativeBinding
|
|
313
|
+
const { dotProductAddOnlyNative, lifStepNative, maskSurrogateNative, applyAddOnlyDeltaNative, applyEmbeddingDeltaNative, contrastiveHebbianNative } = nativeBinding
|
|
314
314
|
|
|
315
315
|
module.exports.dotProductAddOnlyNative = dotProductAddOnlyNative
|
|
316
316
|
module.exports.lifStepNative = lifStepNative
|
|
317
317
|
module.exports.maskSurrogateNative = maskSurrogateNative
|
|
318
318
|
module.exports.applyAddOnlyDeltaNative = applyAddOnlyDeltaNative
|
|
319
319
|
module.exports.applyEmbeddingDeltaNative = applyEmbeddingDeltaNative
|
|
320
|
+
module.exports.contrastiveHebbianNative = contrastiveHebbianNative
|
package/package.json
CHANGED
|
Binary file
|
package/src/index.ts
CHANGED
|
@@ -1,4 +1,7 @@
|
|
|
1
1
|
export { default as dotProductAddOnly } from "./math/dotProductAddOnly.js";
|
|
2
2
|
export { SpikingDense, type SpikingDenseConfig } from "./layers/SpikingDense.js";
|
|
3
|
+
export { SpikingDenseBPTT, type SpikingDenseBPTTConfig } from "./layers/SpikingDenseBPTT.js";
|
|
3
4
|
export { SpikingEmbedding, type SpikingEmbeddingConfig } from "./layers/SpikingEmbedding.js";
|
|
5
|
+
export { SpikingSelfAttention, type SpikingSelfAttentionConfig } from "./layers/SpikingSelfAttention.js";
|
|
4
6
|
// export { SpikingSentenceEmbedder, type SpikingSentenceConfig } from "./models/SpikingSentenceEmbedder.js";
|
|
7
|
+
export { contrastiveHebbianNativeWrapper, isNativeAvailable } from "./native_backend.js";
|
|
@@ -12,6 +12,8 @@ export interface SpikingDenseConfig extends LayerConfig {
|
|
|
12
12
|
useBias?: boolean;
|
|
13
13
|
kernelInitializer?: string;
|
|
14
14
|
biasInitializer?: string;
|
|
15
|
+
betaRange?: [number, number];
|
|
16
|
+
thresholdRange?: [number, number];
|
|
15
17
|
}
|
|
16
18
|
|
|
17
19
|
export class SpikingDense extends BaseLayer {
|
|
@@ -19,6 +21,8 @@ export class SpikingDense extends BaseLayer {
|
|
|
19
21
|
public useBias: boolean;
|
|
20
22
|
public kernelInitializer: string;
|
|
21
23
|
public biasInitializer: string;
|
|
24
|
+
public betaRange: [number, number];
|
|
25
|
+
public thresholdRange: [number, number];
|
|
22
26
|
public beta!: Float32Array;
|
|
23
27
|
public threshold!: Float32Array;
|
|
24
28
|
|
|
@@ -41,6 +45,8 @@ export class SpikingDense extends BaseLayer {
|
|
|
41
45
|
this.useBias = config.useBias ?? true;
|
|
42
46
|
this.kernelInitializer = config.kernelInitializer || "glorot_normal";
|
|
43
47
|
this.biasInitializer = config.biasInitializer || "zeros";
|
|
48
|
+
this.betaRange = config.betaRange || [0.8, 0.99];
|
|
49
|
+
this.thresholdRange = config.thresholdRange || [0.5, 1.0];
|
|
44
50
|
}
|
|
45
51
|
|
|
46
52
|
public computeOutputShape(inputShape: number[]): number[] {
|
|
@@ -65,23 +71,27 @@ export class SpikingDense extends BaseLayer {
|
|
|
65
71
|
this.beta = new Float32Array(this.units);
|
|
66
72
|
this.threshold = new Float32Array(this.units);
|
|
67
73
|
for (let i = 0; i < this.units; i++) {
|
|
68
|
-
this.beta[i] = 0
|
|
69
|
-
this.threshold[i] = 0
|
|
74
|
+
this.beta[i] = this.betaRange[0] + Math.random() * (this.betaRange[1] - this.betaRange[0]);
|
|
75
|
+
this.threshold[i] = this.thresholdRange[0] + Math.random() * (this.thresholdRange[1] - this.thresholdRange[0]);
|
|
70
76
|
}
|
|
71
77
|
|
|
72
78
|
// Inisialisasi state
|
|
73
79
|
this.potentials = Matrix.fromFlat(new Float32Array(this.units), [1, this.units]);
|
|
74
80
|
}
|
|
75
81
|
|
|
82
|
+
private outSpikesDataBuffer?: Float32Array;
|
|
83
|
+
|
|
76
84
|
private ensurePotentialsShape(batch: number) {
|
|
77
|
-
if (this.potentials._shape[0] !== batch) {
|
|
85
|
+
if (this.potentials._shape[0] !== batch || !this.outSpikesDataBuffer) {
|
|
78
86
|
this.potentials = Matrix.fromFlat(new Float32Array(batch * this.units), [batch, this.units]);
|
|
87
|
+
this.outSpikesDataBuffer = new Float32Array(batch * this.units);
|
|
88
|
+
this.lastPotentials = Matrix.fromFlat(new Float32Array(batch * this.units), [batch, this.units]);
|
|
79
89
|
}
|
|
80
90
|
}
|
|
81
91
|
|
|
82
92
|
public resetState() {
|
|
83
93
|
if (this.potentials) this.potentials._data.fill(0);
|
|
84
|
-
this.lastPotentials
|
|
94
|
+
if (this.lastPotentials) this.lastPotentials._data.fill(0);
|
|
85
95
|
this.lastInputs = undefined;
|
|
86
96
|
this.lastSpikes = undefined;
|
|
87
97
|
}
|
|
@@ -100,23 +110,23 @@ export class SpikingDense extends BaseLayer {
|
|
|
100
110
|
}
|
|
101
111
|
|
|
102
112
|
// 3 & 4. Leaky Integrate, Fire & Reset
|
|
103
|
-
const outData =
|
|
113
|
+
const outData = this.outSpikesDataBuffer!;
|
|
114
|
+
outData.fill(0);
|
|
104
115
|
const outSpikes = Matrix.fromFlat(outData, [batch, this.units]);
|
|
105
|
-
this.lastPotentials = Matrix.fromFlat(new Float32Array(batch * this.units), [batch, this.units]);
|
|
106
116
|
|
|
107
117
|
if (isNativeAvailable()) {
|
|
108
118
|
lifStepNativeWrapper(
|
|
109
119
|
this.potentials._data,
|
|
110
120
|
dot._data,
|
|
111
121
|
outSpikes._data,
|
|
112
|
-
this.lastPotentials
|
|
122
|
+
this.lastPotentials!._data,
|
|
113
123
|
this.beta,
|
|
114
124
|
this.threshold
|
|
115
125
|
);
|
|
116
126
|
} else {
|
|
117
127
|
const potData = this.potentials._data;
|
|
118
128
|
const dotData = dot._data;
|
|
119
|
-
const lpData = this.lastPotentials
|
|
129
|
+
const lpData = this.lastPotentials!._data;
|
|
120
130
|
|
|
121
131
|
for (let b = 0; b < batch; b++) {
|
|
122
132
|
const offset = b * this.units;
|
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
import { BaseLayer, LayerConfig, ForwardOptions } from "@oxide-js/layers";
|
|
2
|
+
import { Matrix, mj } from "@oxide-js/core";
|
|
3
|
+
import {
|
|
4
|
+
isNativeAvailable,
|
|
5
|
+
lifStepNativeWrapper,
|
|
6
|
+
maskSurrogateNativeWrapper,
|
|
7
|
+
applyAddOnlyDeltaNativeWrapper
|
|
8
|
+
} from "../native_backend.js";
|
|
9
|
+
import dotProductAddOnly from "../math/dotProductAddOnly.js";
|
|
10
|
+
|
|
11
|
+
/**
 * Construction options for {@link SpikingDenseBPTT}.
 *
 * Every field except `units` is optional; the defaults listed below are the
 * ones applied by the `SpikingDenseBPTT` constructor.
 */
export interface SpikingDenseBPTTConfig extends LayerConfig {
  /** Number of output neurons (width of the kernel's second dimension). */
  units: number;

  /** Whether a trainable bias is created and added to the pre-activation. Default: `true`. */
  useBias?: boolean;

  /** Name of the initializer used for the kernel. Default: `"glorot_normal"`. */
  kernelInitializer?: string;

  /** Name of the initializer used for the bias. Default: `"zeros"`. */
  biasInitializer?: string;

  /**
   * `[min, max]` range for the per-neuron leak factor. Default: `[0.8, 0.99]`.
   * NOTE(review): the layer stores this value, but its `build()` derives beta
   * from a random power-of-two shift instead of sampling this range.
   */
  betaRange?: [number, number];

  /**
   * `[min, max]` range from which each neuron's firing threshold is sampled
   * uniformly when the layer is built. Default: `[0.5, 1.0]`.
   */
  thresholdRange?: [number, number];
}
|
|
19
|
+
|
|
20
|
+
/**
 * Spiking dense layer trained with Backpropagation Through Time (BPTT).
 *
 * The layer treats the sequence dimension as time: call
 * {@link SpikingDenseBPTT.resetSequence} once before a new sequence, then
 * {@link SpikingDenseBPTT.computeStep} once per time step, and finally
 * {@link SpikingDenseBPTT.learnThroughTime} once to update the weights from
 * the recorded history. The regular `compute()` entry point is intentionally
 * disabled.
 */
export class SpikingDenseBPTT extends BaseLayer {
  public units: number;
  public useBias: boolean;
  public kernelInitializer: string;
  public biasInitializer: string;
  public betaRange: [number, number];
  public thresholdRange: [number, number];

  /** Per-neuron leak multiplier applied to the membrane potential each step. */
  public beta!: Float32Array;
  /** Per-neuron firing threshold. */
  public threshold!: Float32Array;

  /** Current membrane potentials, shape `[batch, units]`. */
  public potentials!: Matrix;

  // History buffers for Backpropagation Through Time (BPTT), indexed by time step.
  public historyInputs: Matrix[] = [];
  public historyPotentials: Matrix[] = [];
  public historySpikes: Matrix[] = [];
  public maxTimeSteps: number = 0;

  // Reusable output buffer so the forward pass does not allocate per step.
  private outSpikesDataBuffer?: Float32Array;

  /** Trainable kernel, or `undefined` before `build()` has registered it. */
  public get kernel(): Matrix | undefined {
    return this.getParameter("kernel");
  }

  /** Trainable bias, or `undefined` when absent or not yet built. */
  public get bias(): Matrix | undefined {
    return this.getParameter("bias");
  }

  /**
   * @param config Layer options; see {@link SpikingDenseBPTTConfig} for defaults.
   */
  constructor(config: SpikingDenseBPTTConfig) {
    super(config);
    this.units = config.units;
    this.useBias = config.useBias ?? true;
    this.kernelInitializer = config.kernelInitializer || "glorot_normal";
    this.biasInitializer = config.biasInitializer || "zeros";
    this.betaRange = config.betaRange || [0.8, 0.99];
    this.thresholdRange = config.thresholdRange || [0.5, 1.0];
  }

  /**
   * @param inputShape Shape of the incoming tensor; only the first entry is read.
   * @returns `[batch, units]`, with `-1` as the batch size when it is unknown.
   */
  public computeOutputShape(inputShape: number[]): number[] {
    const batch = inputShape[0] ?? -1;
    return [batch, this.units];
  }

  /**
   * Registers the kernel (and bias when enabled) and draws the per-neuron
   * leak factors and thresholds.
   *
   * @param inputShape Shape of the incoming tensor; the last entry is used as
   *   the number of input features.
   */
  public build(inputShape: number[]): void {
    super.build(inputShape);

    const inFeatures = inputShape[inputShape.length - 1];

    const kernelVal = this.createInitializer(this.kernelInitializer, [inFeatures, this.units]);
    this.addParameter("kernel", kernelVal, true, [inFeatures, this.units]);

    if (this.useBias) {
      const biasVal = this.createInitializer(this.biasInitializer, [this.units, 1]);
      this.addParameter("bias", biasVal, true, [this.units, 1]);
    }

    // Beta is derived from a random bit-shift exponent; thresholds are sampled
    // uniformly from thresholdRange.
    this.beta = new Float32Array(this.units);
    this.threshold = new Float32Array(this.units);
    const [thresholdMin, thresholdMax] = this.thresholdRange;
    for (let i = 0; i < this.units; i++) {
      // Random shift exponent in 2..5, i.e. beta = 1 - 2^-shift.
      const shift = Math.floor(2 + Math.random() * 4);
      // Pre-compute the float multiplier once so the per-step loop stays cheap.
      this.beta[i] = 1.0 - (1.0 / Math.pow(2, shift));
      this.threshold[i] = thresholdMin + Math.random() * (thresholdMax - thresholdMin);
    }

    // Initial state: a single row of zeroed potentials.
    this.potentials = Matrix.fromFlat(new Float32Array(this.units), [1, this.units]);
  }

  /**
   * (Re)allocates the potentials and the output buffer when the batch size
   * changes or the output buffer does not exist yet. Reallocation zeroes the
   * potentials.
   */
  private ensurePotentialsShape(batch: number) {
    if (this.potentials._shape[0] !== batch || !this.outSpikesDataBuffer) {
      this.potentials = Matrix.fromFlat(new Float32Array(batch * this.units), [batch, this.units]);
      this.outSpikesDataBuffer = new Float32Array(batch * this.units);
    }
  }

  /**
   * Prepares the layer for a new sequence. Call this BEFORE feeding the first
   * step of each new sentence/sequence.
   *
   * @param timeSteps Number of time steps that will be fed via `computeStep`.
   * @throws RangeError when `timeSteps` is not a non-negative integer.
   */
  public resetSequence(timeSteps: number) {
    // Validate before touching any state so a bad argument cannot leave
    // maxTimeSteps out of sync with the history arrays.
    if (!Number.isInteger(timeSteps) || timeSteps < 0) {
      throw new RangeError(`[SpikingDenseBPTT] Invalid timeSteps ${timeSteps}: expected a non-negative integer.`);
    }

    this.maxTimeSteps = timeSteps;
    this.historyInputs = new Array(timeSteps);
    this.historyPotentials = new Array(timeSteps);
    this.historySpikes = new Array(timeSteps);

    if (this.potentials) this.potentials._data.fill(0);
  }

  /**
   * Disabled on purpose: this layer must be driven step by step.
   *
   * @throws Error always.
   */
  protected compute(inputs: Matrix, options?: ForwardOptions): Matrix {
    throw new Error("[SpikingDenseBPTT] Harap gunakan computeStep(inputs, t) dan resetSequence(t) untuk model BPTT, jangan gunakan compute().");
  }

  /**
   * Forward pass for a single time step.
   *
   * Records a copy of the inputs, the pre-fire potentials and the emitted
   * spikes at index `t` of the history buffers.
   *
   * @param inputs Input matrix whose first dimension is the batch size.
   * @param t Zero-based time step; must be below `maxTimeSteps`.
   * @returns Spike matrix of shape `[batch, units]`. It wraps an internal
   *   buffer that is overwritten by the next call; use `historySpikes[t]` for
   *   a stable copy.
   * @throws RangeError when `t` is not a non-negative integer.
   * @throws Error when `t` exceeds the sequence length or the layer is not built.
   */
  public computeStep(inputs: Matrix, t: number): Matrix {
    // A negative, fractional or NaN index would be stored as a plain property
    // on the history arrays and never be visited by the backward pass.
    if (!Number.isInteger(t) || t < 0) {
      throw new RangeError(`[SpikingDenseBPTT] Invalid time step ${t}: expected a non-negative integer.`);
    }
    if (t >= this.maxTimeSteps) {
      throw new Error(`[SpikingDenseBPTT] Time step ${t} melebihi batas maxTimeSteps ${this.maxTimeSteps}`);
    }

    const kernel = this.kernel;
    if (!kernel) {
      throw new Error("[SpikingDenseBPTT] Layer is not built: kernel parameter is missing. Call build(inputShape) first.");
    }

    const batch = inputs._shape[0];
    this.ensurePotentialsShape(batch);

    // 1. Store a copy of the input in the history buffer at index t.
    this.historyInputs[t] = Matrix.fromFlat(new Float32Array(inputs._data), inputs._shape);

    // 2. Add-only spiking dot product.
    const dot = dotProductAddOnly(inputs, kernel);

    if (this.useBias && this.bias) {
      mj.addBiasRow(dot, this.bias);
    }

    const outData = this.outSpikesDataBuffer!;
    outData.fill(0);
    const outSpikes = Matrix.fromFlat(outData, [batch, this.units]);

    // Membrane potential at time t AFTER integrating the input but BEFORE
    // fire/reset. The backward pass compares it against the threshold for the
    // surrogate gradient.
    const potAtT = new Float32Array(batch * this.units);

    if (isNativeAvailable()) {
      lifStepNativeWrapper(
        this.potentials._data,
        dot._data,
        outSpikes._data,
        potAtT, // 4th argument: filled by the native step with the pre-fire potentials
        this.beta,
        this.threshold
      );
    } else {
      const potData = this.potentials._data;
      const dotData = dot._data;

      for (let b = 0; b < batch; b++) {
        const offset = b * this.units;

        // Leak & integrate (potential is capped at 1.0).
        for (let i = 0; i < this.units; i++) {
          const idx = offset + i;
          potData[idx] = Math.min((potData[idx] * this.beta[i]) + dotData[idx], 1.0);
          potAtT[idx] = potData[idx]; // remember the potential at time t
        }

        // Fire & reset. outData is already zeroed, so only spikes are written.
        for (let i = 0; i < this.units; i++) {
          const idx = offset + i;
          if (potData[idx] >= this.threshold[i]) {
            outData[idx] = 1;
            potData[idx] -= this.threshold[i]; // soft reset
          }
        }
      }
    }

    // 3. Store the state (potentials & output spikes) at index t.
    this.historyPotentials[t] = Matrix.fromFlat(potAtT, [batch, this.units]);
    this.historySpikes[t] = Matrix.fromFlat(new Float32Array(outSpikes._data), [batch, this.units]);

    return outSpikes;
  }

  /**
   * Backward pass: updates kernel and bias in place from the recorded
   * history. Call it ONCE, after the whole sequence has been processed.
   *
   * @param errorSequence One error matrix per time step (index = time step).
   * @param B Optional matrix that each step's error is multiplied with before
   *   use (hidden layers); when `undefined` the error is used directly.
   * @param learningRate Step size of the delta rule. Default: `0.01`.
   * @throws Error when no step was recorded, when the layer is not built, or
   *   when any time step lacks an error matrix or recorded history.
   */
  public learnThroughTime(errorSequence: Matrix[], B: Matrix | undefined, learningRate: number = 0.01): void {
    if (this.maxTimeSteps === 0 || !this.historyInputs[0]) {
      throw new Error("[SpikingDenseBPTT] Belum ada data di memory. Panggil computeStep(inputs, t) terlebih dahulu.");
    }

    const steps = this.maxTimeSteps;

    // Check every step up front: the loop below mutates the weights, so a gap
    // discovered half-way would leave them partially updated.
    for (let t = 0; t < steps; t++) {
      if (!errorSequence[t]) {
        throw new Error(`[SpikingDenseBPTT] errorSequence has no entry for time step ${t} (expected ${steps} entries).`);
      }
      if (!this.historyInputs[t] || !this.historyPotentials[t]) {
        throw new Error(`[SpikingDenseBPTT] No recorded history for time step ${t}. Call computeStep(inputs, ${t}) before learnThroughTime().`);
      }
    }

    const kernelParam = this.kernel;
    if (!kernelParam) {
      throw new Error("[SpikingDenseBPTT] Layer is not built: kernel parameter is missing. Call build(inputShape) first.");
    }

    const batch = errorSequence[0]._shape[0];
    const inFeatures = this.historyInputs[0]._shape[1];
    const units = this.units;
    const kernel = kernelParam._data;

    // Error signal travelling backwards from later time steps to earlier ones.
    const temporalErrorData = new Float32Array(batch * units);
    const windowSize = 1.0; // width of the boxcar surrogate gradient

    // Allocated once so the time loop itself does not allocate.
    const maskedErrorData = new Float32Array(batch * units);
    const biasData = (this.useBias && this.bias) ? this.bias._data : new Float32Array(0);

    // Walk backwards, from the end of the sequence to its start.
    for (let t = steps - 1; t >= 0; t--) {
      // Hidden layer: project the error through B. Output layer: use it as is.
      const currentErrorData: Float32Array = B
        ? mj.dotProduct(errorSequence[t], B, undefined, false, false)._data
        : errorSequence[t]._data;

      const pData = this.historyPotentials[t]._data;
      const inputData = this.historyInputs[t]._data;

      // Step A: combine this step's error with the error carried over from
      // step t+1. The sum is written straight into the buffer that gets masked.
      for (let i = 0; i < maskedErrorData.length; i++) {
        maskedErrorData[i] = currentErrorData[i] + temporalErrorData[i];
      }

      if (isNativeAvailable()) {
        // 1. Surrogate mask (mutates maskedErrorData in place).
        maskSurrogateNativeWrapper(maskedErrorData, pData, this.threshold, windowSize);

        // 2. Add-only delta rule (mutates kernel and bias in place).
        applyAddOnlyDeltaNativeWrapper(
          kernel,
          biasData,
          inputData,
          maskedErrorData,
          learningRate,
          batch,
          inFeatures,
          units,
          this.useBias
        );

        // 3. Temporal error handed to step t-1 through the leak (beta) path.
        for (let b = 0; b < batch; b++) {
          const offset = b * units;
          for (let i = 0; i < units; i++) {
            const idx = offset + i;
            temporalErrorData[idx] = maskedErrorData[idx] * this.beta[i];
          }
        }
      } else {
        // ============ JavaScript fallback ============
        for (let b = 0; b < batch; b++) {
          const offset = b * units;
          for (let i = 0; i < units; i++) {
            const idx = offset + i;

            // Boxcar surrogate: keep the error only within windowSize of the
            // threshold. Written as a negated `<=` so a NaN distance is zeroed.
            if (!(Math.abs(pData[idx] - this.threshold[i]) <= windowSize)) {
              maskedErrorData[idx] = 0;
            }

            // Temporal error handed to step t-1 through the leak (beta) path.
            temporalErrorData[idx] = maskedErrorData[idx] * this.beta[i];
          }
        }

        // Step B: add-only delta rule for the weights at time t.
        for (let b = 0; b < batch; b++) {
          const inOffset = b * inFeatures;
          const errOffset = b * units;

          for (let k = 0; k < inFeatures; k++) {
            // Only inputs above 0.5 contribute to the update.
            if (inputData[inOffset + k] > 0.5) {
              const kOffset = k * units;
              for (let j = 0; j < units; j++) {
                kernel[kOffset + j] += learningRate * maskedErrorData[errOffset + j];
                kernel[kOffset + j] = Math.max(-1.0, Math.min(1.0, kernel[kOffset + j]));
              }
            }
          }

          if (this.useBias && this.bias) {
            for (let j = 0; j < units; j++) {
              biasData[j] += (learningRate * maskedErrorData[errOffset + j]) / batch;
              biasData[j] = Math.max(-1.0, Math.min(1.0, biasData[j]));
            }
          }
        }
      }
    }
  }

  /**
   * @returns The base layer config extended with every option accepted by the
   *   constructor, so the result can be passed back to it.
   */
  public getConfig(): Record<string, any> {
    return {
      ...super.getConfig(),
      units: this.units,
      useBias: this.useBias,
      kernelInitializer: this.kernelInitializer,
      biasInitializer: this.biasInitializer,
      // Copied so callers cannot mutate the layer's ranges through the config.
      betaRange: [...this.betaRange],
      thresholdRange: [...this.thresholdRange]
    };
  }
}
|
|
@@ -3,19 +3,25 @@ import { Matrix, mj } from "@oxide-js/core";
|
|
|
3
3
|
import {
|
|
4
4
|
isNativeAvailable,
|
|
5
5
|
lifStepNativeWrapper,
|
|
6
|
-
maskSurrogateNativeWrapper
|
|
6
|
+
maskSurrogateNativeWrapper,
|
|
7
|
+
applyEmbeddingDeltaNativeWrapper
|
|
7
8
|
} from "../native_backend.js";
|
|
8
9
|
|
|
9
10
|
export interface SpikingEmbeddingConfig extends LayerConfig {
|
|
10
11
|
inputDim: number;
|
|
11
12
|
outputDim: number;
|
|
12
13
|
embeddingsInitializer?: string;
|
|
14
|
+
betaRange?: [number, number];
|
|
15
|
+
thresholdRange?: [number, number];
|
|
13
16
|
}
|
|
14
17
|
|
|
15
18
|
export class SpikingEmbedding extends BaseLayer {
|
|
16
19
|
public inputDim: number;
|
|
17
20
|
public outputDim: number;
|
|
18
21
|
public embeddingsInitializer: string;
|
|
22
|
+
|
|
23
|
+
public betaRange: [number, number];
|
|
24
|
+
public thresholdRange: [number, number];
|
|
19
25
|
public beta!: Float32Array;
|
|
20
26
|
public threshold!: Float32Array;
|
|
21
27
|
|
|
@@ -33,6 +39,8 @@ export class SpikingEmbedding extends BaseLayer {
|
|
|
33
39
|
this.inputDim = config.inputDim;
|
|
34
40
|
this.outputDim = config.outputDim;
|
|
35
41
|
this.embeddingsInitializer = config.embeddingsInitializer || "glorot_normal";
|
|
42
|
+
this.betaRange = config.betaRange || [0.8, 0.99];
|
|
43
|
+
this.thresholdRange = config.thresholdRange || [0.01, 0.1];
|
|
36
44
|
}
|
|
37
45
|
|
|
38
46
|
public computeOutputShape(inputShape: number[]): number[] {
|
|
@@ -44,29 +52,43 @@ export class SpikingEmbedding extends BaseLayer {
|
|
|
44
52
|
super.build(inputShape);
|
|
45
53
|
|
|
46
54
|
const embVal = this.createInitializer(this.embeddingsInitializer, [this.inputDim, this.outputDim]);
|
|
55
|
+
|
|
56
|
+
// Scale up the embedding values because Glorot Normal makes them too small for large vocabularies (e.g. 32000)
|
|
57
|
+
// which prevents the LIF neurons from ever reaching the threshold.
|
|
58
|
+
const scaleFactor = Math.sqrt(this.inputDim);
|
|
59
|
+
for (let i = 0; i < embVal._data.length; i++) {
|
|
60
|
+
embVal._data[i] *= scaleFactor;
|
|
61
|
+
}
|
|
62
|
+
|
|
47
63
|
this.addParameter("embeddings", embVal, true, [this.inputDim, this.outputDim]);
|
|
48
64
|
|
|
49
65
|
// Inisialisasi beta dan threshold secara acak untuk setiap neuron
|
|
50
66
|
this.beta = new Float32Array(this.outputDim);
|
|
51
67
|
this.threshold = new Float32Array(this.outputDim);
|
|
52
68
|
for (let i = 0; i < this.outputDim; i++) {
|
|
53
|
-
this.beta[i] = 0
|
|
54
|
-
this.threshold[i] = 0
|
|
69
|
+
this.beta[i] = this.betaRange[0] + Math.random() * (this.betaRange[1] - this.betaRange[0]);
|
|
70
|
+
this.threshold[i] = this.thresholdRange[0] + Math.random() * (this.thresholdRange[1] - this.thresholdRange[0]);
|
|
55
71
|
}
|
|
56
72
|
|
|
57
73
|
// Potentials start at 0, shape [batch, outputDim].
|
|
58
74
|
this.potentials = Matrix.fromFlat(new Float32Array(this.outputDim), [1, this.outputDim]);
|
|
59
75
|
}
|
|
60
76
|
|
|
77
|
+
private dotDataBuffer?: Float32Array;
|
|
78
|
+
private outDataBuffer?: Float32Array;
|
|
79
|
+
|
|
61
80
|
private ensurePotentialsShape(batch: number) {
|
|
62
|
-
if (this.potentials._shape[0] !== batch) {
|
|
81
|
+
if (this.potentials._shape[0] !== batch || !this.dotDataBuffer) {
|
|
63
82
|
this.potentials = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
|
|
83
|
+
this.dotDataBuffer = new Float32Array(batch * this.outputDim);
|
|
84
|
+
this.outDataBuffer = new Float32Array(batch * this.outputDim);
|
|
85
|
+
this.lastPotentials = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
|
|
64
86
|
}
|
|
65
87
|
}
|
|
66
88
|
|
|
67
89
|
public resetState() {
|
|
68
90
|
if (this.potentials) this.potentials._data.fill(0);
|
|
69
|
-
this.lastPotentials
|
|
91
|
+
if (this.lastPotentials) this.lastPotentials._data.fill(0);
|
|
70
92
|
this.lastInputs = undefined;
|
|
71
93
|
this.lastSpikes = undefined;
|
|
72
94
|
}
|
|
@@ -77,7 +99,8 @@ export class SpikingEmbedding extends BaseLayer {
|
|
|
77
99
|
|
|
78
100
|
// 1. Embedding lookup
|
|
79
101
|
const emb = this.embeddings!;
|
|
80
|
-
const dotData =
|
|
102
|
+
const dotData = this.dotDataBuffer!;
|
|
103
|
+
dotData.fill(0);
|
|
81
104
|
|
|
82
105
|
for (let b = 0; b < batch; b++) {
|
|
83
106
|
const tokenIdx = inputs._data[b];
|
|
@@ -90,23 +113,24 @@ export class SpikingEmbedding extends BaseLayer {
|
|
|
90
113
|
}
|
|
91
114
|
}
|
|
92
115
|
|
|
93
|
-
// 2. Leaky Integrate and Fire
|
|
94
|
-
const outData =
|
|
116
|
+
// 2. Leaky Integrate and Fire (LIF Restore untuk Spiking Murni)
|
|
117
|
+
const outData = this.outDataBuffer!;
|
|
118
|
+
outData.fill(0);
|
|
95
119
|
const outSpikes = Matrix.fromFlat(outData, [batch, this.outputDim]);
|
|
96
|
-
|
|
120
|
+
// lastPotentials is already ensured in shape
|
|
97
121
|
|
|
98
122
|
if (isNativeAvailable()) {
|
|
99
123
|
lifStepNativeWrapper(
|
|
100
124
|
this.potentials._data,
|
|
101
125
|
dotData,
|
|
102
126
|
outData,
|
|
103
|
-
this.lastPotentials
|
|
127
|
+
this.lastPotentials!._data,
|
|
104
128
|
this.beta,
|
|
105
129
|
this.threshold
|
|
106
130
|
);
|
|
107
131
|
} else {
|
|
108
132
|
const potData = this.potentials._data;
|
|
109
|
-
const lpData = this.lastPotentials
|
|
133
|
+
const lpData = this.lastPotentials!._data;
|
|
110
134
|
|
|
111
135
|
for (let b = 0; b < batch; b++) {
|
|
112
136
|
const offset = b * this.outputDim;
|
|
@@ -126,7 +150,7 @@ export class SpikingEmbedding extends BaseLayer {
|
|
|
126
150
|
}
|
|
127
151
|
}
|
|
128
152
|
}
|
|
129
|
-
|
|
153
|
+
|
|
130
154
|
this.lastInputs = inputs;
|
|
131
155
|
this.lastSpikes = outSpikes;
|
|
132
156
|
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
import { BaseLayer, LayerConfig, ForwardOptions } from "@oxide-js/layers";
|
|
2
|
+
import { Matrix, mj } from "@oxide-js/core";
|
|
3
|
+
import { isNativeAvailable, lifStepNativeWrapper } from "../native_backend.js";
|
|
4
|
+
import dotProductAddOnly from "../math/dotProductAddOnly.js";
|
|
5
|
+
|
|
6
|
+
export interface SpikingSelfAttentionConfig extends LayerConfig {
|
|
7
|
+
d_model: number;
|
|
8
|
+
sequenceLength: number;
|
|
9
|
+
kernelInitializer?: string;
|
|
10
|
+
betaRange?: [number, number];
|
|
11
|
+
thresholdRange?: [number, number];
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
export class SpikingSelfAttention extends BaseLayer {
|
|
15
|
+
public d_model: number;
|
|
16
|
+
public sequenceLength: number;
|
|
17
|
+
public kernelInitializer: string;
|
|
18
|
+
public betaRange: [number, number];
|
|
19
|
+
public thresholdRange: [number, number];
|
|
20
|
+
|
|
21
|
+
// Q, K, V kernels
|
|
22
|
+
public get kernelQ(): Matrix | undefined { return this.getParameter("kernelQ"); }
|
|
23
|
+
public get kernelK(): Matrix | undefined { return this.getParameter("kernelK"); }
|
|
24
|
+
public get kernelV(): Matrix | undefined { return this.getParameter("kernelV"); }
|
|
25
|
+
|
|
26
|
+
// LIF state untuk Q, K, V (opsional, jika ingin akumulasi temporal)
|
|
27
|
+
public betaQKV!: Float32Array;
|
|
28
|
+
public thresholdQKV!: Float32Array;
|
|
29
|
+
public potentialsQ!: Matrix;
|
|
30
|
+
public potentialsK!: Matrix;
|
|
31
|
+
public potentialsV!: Matrix;
|
|
32
|
+
|
|
33
|
+
// LIF state untuk Attention Scores (Pengganti Softmax)
|
|
34
|
+
public betaScores!: Float32Array;
|
|
35
|
+
public thresholdScores!: Float32Array;
|
|
36
|
+
public potentialsScores!: Matrix;
|
|
37
|
+
|
|
38
|
+
// Cache input untuk Local Learning
|
|
39
|
+
public lastInputs?: Matrix;
|
|
40
|
+
|
|
41
|
+
constructor(config: SpikingSelfAttentionConfig) {
|
|
42
|
+
super(config);
|
|
43
|
+
this.d_model = config.d_model;
|
|
44
|
+
this.sequenceLength = config.sequenceLength;
|
|
45
|
+
this.kernelInitializer = config.kernelInitializer || "glorot_normal";
|
|
46
|
+
this.betaRange = config.betaRange || [0.8, 0.99];
|
|
47
|
+
this.thresholdRange = config.thresholdRange || [0.1, 0.3];
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
public computeOutputShape(inputShape: number[]): number[] {
|
|
51
|
+
const batch = inputShape[0] ?? -1;
|
|
52
|
+
// Asumsi input shape: [batch * seqLen, d_model]
|
|
53
|
+
return [batch, this.d_model]; // Actually [batch * seqLen, d_model]
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
public build(inputShape: number[]): void {
|
|
57
|
+
super.build(inputShape);
|
|
58
|
+
|
|
59
|
+
const inFeatures = inputShape[inputShape.length - 1]; // Seharusnya sama dengan d_model
|
|
60
|
+
|
|
61
|
+
// 1. Inisialisasi Bobot Q, K, V
|
|
62
|
+
this.addParameter("kernelQ", this.createInitializer(this.kernelInitializer, [inFeatures, this.d_model]), true, [inFeatures, this.d_model]);
|
|
63
|
+
this.addParameter("kernelK", this.createInitializer(this.kernelInitializer, [inFeatures, this.d_model]), true, [inFeatures, this.d_model]);
|
|
64
|
+
this.addParameter("kernelV", this.createInitializer(this.kernelInitializer, [inFeatures, this.d_model]), true, [inFeatures, this.d_model]);
|
|
65
|
+
|
|
66
|
+
// OPTIMIZATION: Scale up initial weights so neurons actually spike (prevent Layer 2 death)
|
|
67
|
+
const scale = Math.sqrt(inFeatures);
|
|
68
|
+
const kQ = this.kernelQ!._data;
|
|
69
|
+
const kK = this.kernelK!._data;
|
|
70
|
+
const kV = this.kernelV!._data;
|
|
71
|
+
for(let i = 0; i < kQ.length; i++) {
|
|
72
|
+
kQ[i] *= scale;
|
|
73
|
+
kK[i] *= scale;
|
|
74
|
+
kV[i] *= scale;
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
// 2. Inisialisasi parameter LIF untuk Q, K, V
|
|
78
|
+
this.betaQKV = new Float32Array(this.d_model);
|
|
79
|
+
this.thresholdQKV = new Float32Array(this.d_model);
|
|
80
|
+
for (let i = 0; i < this.d_model; i++) {
|
|
81
|
+
this.betaQKV[i] = this.betaRange[0] + Math.random() * (this.betaRange[1] - this.betaRange[0]);
|
|
82
|
+
this.thresholdQKV[i] = this.thresholdRange[0] + Math.random() * (this.thresholdRange[1] - this.thresholdRange[0]);
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
// 3. Inisialisasi parameter LIF untuk Attention Scores (Pengganti Softmax)
|
|
86
|
+
this.betaScores = new Float32Array(this.sequenceLength);
|
|
87
|
+
this.thresholdScores = new Float32Array(this.sequenceLength);
|
|
88
|
+
for (let i = 0; i < this.sequenceLength; i++) {
|
|
89
|
+
this.betaScores[i] = 0.9;
|
|
90
|
+
// Ambang batas diturunkan tajam untuk mencegah Dead Neurons
|
|
91
|
+
this.thresholdScores[i] = 1.0;
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
// Inisialisasi Potentials akan dilakukan secara dinamis pada saat forward
|
|
95
|
+
this.potentialsQ = Matrix.fromFlat(new Float32Array(0), [0, 0]);
|
|
96
|
+
this.potentialsK = Matrix.fromFlat(new Float32Array(0), [0, 0]);
|
|
97
|
+
this.potentialsV = Matrix.fromFlat(new Float32Array(0), [0, 0]);
|
|
98
|
+
this.potentialsScores = Matrix.fromFlat(new Float32Array(0), [0, 0]);
|
|
99
|
+
}
|
|
100
|
+
private sqDataBuffer?: Float32Array;
|
|
101
|
+
private skDataBuffer?: Float32Array;
|
|
102
|
+
private svDataBuffer?: Float32Array;
|
|
103
|
+
private dummyLpBuffer?: Float32Array;
|
|
104
|
+
private matchScoresBuffer?: Float32Array;
|
|
105
|
+
private qGatedVBuffer?: Float32Array;
|
|
106
|
+
private outSpikesBuffer?: Float32Array;
|
|
107
|
+
private sScoresDataBuffer?: Float32Array;
|
|
108
|
+
private dummyLpScoresBuffer?: Float32Array;
|
|
109
|
+
private tempMatchesBuffer?: Float32Array;
|
|
110
|
+
|
|
111
|
+
/**
 * Lazily (re)allocates the membrane-potential matrices and every scratch
 * buffer used by `compute`.
 *
 * Buffers are rebuilt when they have never been created, when the row count
 * differs from `batchSeq`, or when the width of the score state differs from
 * `seqLen`. The last check keeps the score buffers from going stale if the
 * sequence length changes while the row count happens to stay the same.
 * Rebuilding replaces all potentials with zero-filled arrays.
 *
 * @param batchSeq Number of input rows (batch * sequence length).
 * @param seqLen   Number of positions each row is scored against.
 */
private ensurePotentialsShape(batchSeq: number, seqLen: number) {
  const rowsMatch = this.potentialsQ._shape[0] === batchSeq;
  const scoreWidthMatches = this.potentialsScores._shape[1] === seqLen;
  if (rowsMatch && scoreWidthMatches && this.sqDataBuffer) {
    return;
  }

  const projectionSize = batchSeq * this.d_model;
  const scoreSize = batchSeq * seqLen;

  // Persistent LIF state.
  this.potentialsQ = Matrix.fromFlat(new Float32Array(projectionSize), [batchSeq, this.d_model]);
  this.potentialsK = Matrix.fromFlat(new Float32Array(projectionSize), [batchSeq, this.d_model]);
  this.potentialsV = Matrix.fromFlat(new Float32Array(projectionSize), [batchSeq, this.d_model]);
  this.potentialsScores = Matrix.fromFlat(new Float32Array(scoreSize), [batchSeq, seqLen]);

  // Scratch buffers reused on every forward pass to avoid per-call allocation.
  this.sqDataBuffer = new Float32Array(projectionSize);
  this.skDataBuffer = new Float32Array(projectionSize);
  this.svDataBuffer = new Float32Array(projectionSize);
  this.dummyLpBuffer = new Float32Array(projectionSize);
  this.matchScoresBuffer = new Float32Array(scoreSize);
  this.qGatedVBuffer = new Float32Array(projectionSize);
  this.outSpikesBuffer = new Float32Array(projectionSize);
  this.sScoresDataBuffer = new Float32Array(scoreSize);
  this.dummyLpScoresBuffer = new Float32Array(scoreSize);
  this.tempMatchesBuffer = new Float32Array(seqLen);
}
|
|
130
|
+
|
|
131
|
+
/**
 * Zeroes the membrane potentials of the Q, K, V and score neurons in place.
 * The matrices themselves (and their shapes) are kept.
 */
public resetState() {
  const states = [this.potentialsQ, this.potentialsK, this.potentialsV, this.potentialsScores];
  for (const state of states) {
    if (state) {
      state._data.fill(0);
    }
  }
}
|
|
137
|
+
|
|
138
|
+
/**
 * Forward pass of the spiking self-attention layer.
 *
 * Steps:
 *  1. Project the inputs with the Q, K and V kernels.
 *  2. Run one LIF step per projection to obtain the Q/K/V spike buffers.
 *  3. For every query row, count the units that spike in both the query and
 *     each key of the same batch element, then divide by the row maximum.
 *  4. Feed the normalised scores through the score LIF neurons.
 *  5. Accumulate the V spikes of every position weighted by its score and cap
 *     the result at 1.0.
 *
 * @param inputs  Matrix whose row count must be a multiple of `sequenceLength`
 *                (the original comment assumes a flat [batch * seqLen, d_model] layout).
 * @param options Forward options; not read by this method.
 * @returns A [rows, d_model] matrix built from the layer's reusable output buffer.
 * @throws Error when the row count is not a multiple of `sequenceLength`.
 */
protected compute(inputs: Matrix, options?: ForwardOptions): Matrix {
  const seqLen = this.sequenceLength;
  const batchSeq = inputs._shape[0];
  const batch = batchSeq / seqLen;
  const width = this.d_model;

  if (!Number.isInteger(batch)) {
    throw new Error(`[SpikingSelfAttention] Jumlah baris input (${batchSeq}) harus merupakan kelipatan dari sequenceLength (${seqLen}).`);
  }

  this.ensurePotentialsShape(batchSeq, seqLen);
  this.lastInputs = inputs; // kept for the local learning rule

  // 1. Q, K, V projections.
  const driveQ = dotProductAddOnly(inputs, this.kernelQ!);
  const driveK = dotProductAddOnly(inputs, this.kernelK!);
  const driveV = dotProductAddOnly(inputs, this.kernelV!);

  // 2. LIF step producing the Q/K/V spike buffers.
  const qSpikes = this.sqDataBuffer!;
  const kSpikes = this.skDataBuffer!;
  const vSpikes = this.svDataBuffer!;
  const lifScratch = this.dummyLpBuffer!;
  qSpikes.fill(0);
  kSpikes.fill(0);
  vSpikes.fill(0);
  lifScratch.fill(0);

  if (isNativeAvailable()) {
    lifStepNativeWrapper(this.potentialsQ._data, driveQ._data, qSpikes, lifScratch, this.betaQKV, this.thresholdQKV);
    lifStepNativeWrapper(this.potentialsK._data, driveK._data, kSpikes, lifScratch, this.betaQKV, this.thresholdQKV);
    lifStepNativeWrapper(this.potentialsV._data, driveV._data, vSpikes, lifScratch, this.betaQKV, this.thresholdQKV);
  } else {
    this.runLIF(this.potentialsQ._data, driveQ._data, qSpikes, batchSeq, width, this.betaQKV, this.thresholdQKV);
    this.runLIF(this.potentialsK._data, driveK._data, kSpikes, batchSeq, width, this.betaQKV, this.thresholdQKV);
    this.runLIF(this.potentialsV._data, driveV._data, vSpikes, batchSeq, width, this.betaQKV, this.thresholdQKV);
  }

  // 3. Match scores: per (query, key) pair, the number of units spiking in
  //    both, normalised by the best match of that query row.
  //    Layout: [batch * seqLen, seqLen].
  const matchScores = this.matchScoresBuffer!;
  matchScores.fill(0);
  const rowMatches = this.tempMatchesBuffer!;

  for (let b = 0; b < batch; b++) {
    const tokenBase = b * seqLen * width;
    const scoreBase = b * seqLen * seqLen;

    for (let i = 0; i < seqLen; i++) {
      const qBase = tokenBase + i * width;

      // Only the spiking query units can contribute, so collect them once.
      const activeQ: number[] = [];
      for (let d = 0; d < width; d++) {
        if (qSpikes[qBase + d] > 0) activeQ.push(d);
      }
      // A silent query keeps its all-zero score row.
      if (activeQ.length === 0) continue;

      rowMatches.fill(0);
      let bestMatch = 0;

      for (let j = 0; j < seqLen; j++) {
        const kBase = tokenBase + j * width;
        let hits = 0;
        for (const d of activeQ) {
          if (kSpikes[kBase + d] > 0) hits++;
        }
        rowMatches[j] = hits;
        if (hits > bestMatch) bestMatch = hits;
      }

      const rowBase = scoreBase + i * seqLen;
      for (let j = 0; j < seqLen; j++) {
        matchScores[rowBase + j] = bestMatch > 0 ? rowMatches[j] / bestMatch : 0;
      }
    }
  }

  // 4. Softmax replacement: pass the match scores through the score LIF neurons.
  // NOTE(review): the spikes written to scoreSpikes are not read again in this
  // method; step 5 is gated by matchScores instead. Confirm this is intended.
  const scoreSpikes = this.sScoresDataBuffer!;
  const scoreScratch = this.dummyLpScoresBuffer!;
  scoreSpikes.fill(0);
  scoreScratch.fill(0);

  if (isNativeAvailable()) {
    lifStepNativeWrapper(this.potentialsScores._data, matchScores, scoreSpikes, scoreScratch, this.betaScores, this.thresholdScores);
  } else {
    this.runLIF(this.potentialsScores._data, matchScores, scoreSpikes, batchSeq, seqLen, this.betaScores, this.thresholdScores);
  }

  // 5. Score-weighted accumulation of the V spikes.
  const outData = this.outSpikesBuffer!;
  outData.fill(0);

  for (let b = 0; b < batch; b++) {
    const tokenBase = b * seqLen * width;
    const scoreBase = b * seqLen * seqLen;

    for (let j = 0; j < seqLen; j++) {
      const vBase = tokenBase + j * width;

      // Only the spiking value units can contribute.
      const activeV: number[] = [];
      for (let d = 0; d < width; d++) {
        if (vSpikes[vBase + d] > 0) activeV.push(d);
      }
      if (activeV.length === 0) continue;

      for (let i = 0; i < seqLen; i++) {
        const gradedScore = matchScores[scoreBase + i * seqLen + j];
        if (!(gradedScore > 0)) continue;

        const outBase = tokenBase + i * width;
        for (const d of activeV) {
          outData[outBase + d] += gradedScore * vSpikes[vBase + d];
        }
      }
    }
  }

  // Cap the accumulated output at 1.0.
  // NOTE(review): values below 1.0 are left as they are, so the output can be
  // fractional rather than strictly binary. Confirm downstream layers accept that.
  for (let idx = 0; idx < outData.length; idx++) {
    if (outData[idx] > 1.0) outData[idx] = 1.0;
  }

  // NOTE(review): outData is the layer's reusable buffer. If Matrix.fromFlat does
  // not copy, a previously returned matrix is overwritten by the next call. Confirm.
  return Matrix.fromFlat(outData, [batchSeq, width]);
}
|
|
265
|
+
|
|
266
|
+
/**
 * Pure-TypeScript leaky integrate-and-fire step over a row-major [batch, dim] layout.
 *
 * Each potential is decayed by its unit's `beta`, driven by `input`, and capped
 * at 1.0. Units whose potential then reaches their `threshold` emit 1.0 and have
 * the threshold subtracted from the potential; all other units emit 0.0.
 *
 * NOTE(review): because potentials are capped at 1.0, a unit whose threshold is
 * above 1.0 can never fire here. Confirm the configured threshold range.
 *
 * @param pot       Membrane potentials, updated in place.
 * @param input     Input drive, same layout as `pot`.
 * @param output    Receives the spikes (1.0 or 0.0), same layout as `pot`.
 * @param batch     Number of rows.
 * @param dim       Number of units per row.
 * @param beta      Per-unit decay factor (length `dim`).
 * @param threshold Per-unit firing threshold (length `dim`).
 */
private runLIF(pot: Float32Array, input: Float32Array, output: Float32Array, batch: number, dim: number, beta: Float32Array, threshold: Float32Array) {
  for (let row = 0; row < batch; row++) {
    const base = row * dim;

    // Integrate: leak, add the drive, cap at 1.0.
    for (let unit = 0; unit < dim; unit++) {
      const at = base + unit;
      pot[at] = Math.min((pot[at] * beta[unit]) + input[at], 1.0);
    }

    // Fire: emit a spike and subtract the threshold where it is reached.
    for (let unit = 0; unit < dim; unit++) {
      const at = base + unit;
      const fired = pot[at] >= threshold[unit];
      output[at] = fired ? 1.0 : 0.0;
      if (fired) {
        pot[at] -= threshold[unit];
      }
    }
  }
}
|
|
284
|
+
|
|
285
|
+
/**
 * Local learning rule for the Q, K and V kernels.
 *
 * For every input of the last forward pass that is greater than zero, the
 * matching kernel row of all three kernels receives the same update,
 * `learningRate * error * input + dopamine`, and the result is clamped to
 * [-1, 1]. The constant `dopamine` term is added even when the error is zero.
 *
 * NOTE(review): the clamp to [-1, 1] also applies to weights that are larger
 * than 1 before the update. Confirm this interacts as intended with the
 * initial weight scaling.
 *
 * @param errorSignal  Error values read as a flat row-major [rows, d_model]
 *                     array, where `rows` is the row count of the last input.
 * @param learningRate Step size applied to the error term (default 0.01).
 * @throws Error when called before a forward pass, or when `errorSignal`
 *         holds fewer than `rows * d_model` values.
 */
public learnAttention(errorSignal: Matrix, learningRate: number = 0.01) {
  if (!this.lastInputs) {
    throw new Error("[SpikingSelfAttention] Cannot run learning before forward() is executed.");
  }

  const err = errorSignal._data;
  const inputs = this.lastInputs._data;
  const batchSeq = this.lastInputs._shape[0];
  // Falls back to d_model when the stored input shape has no second dimension.
  const inFeatures = this.lastInputs._shape[1] || this.d_model;
  const d_model = this.d_model;

  // A short error signal would be read out of bounds, turn into NaN and
  // poison every kernel it touches, so reject it up front.
  const requiredErrLength = batchSeq * d_model;
  if (err.length < requiredErrLength) {
    throw new Error(`[SpikingSelfAttention] errorSignal holds ${err.length} values but learnAttention needs at least ${requiredErrLength} (${batchSeq} rows x ${d_model} units).`);
  }

  const kQ = this.kernelQ!._data;
  const kK = this.kernelK!._data;
  const kV = this.kernelV!._data;

  // Very small constant drive meant to revive silent neurons without saturating them.
  const dopamine = 0.00005;

  for (let b = 0; b < batchSeq; b++) {
    const errOffset = b * d_model;
    const inOffset = b * inFeatures;

    for (let i = 0; i < inFeatures; i++) {
      const inVal = inputs[inOffset + i];
      // Sparse update: only inputs greater than zero change weights.
      if (!(inVal > 0)) continue;

      const kOffset = i * d_model;
      for (let d = 0; d < d_model; d++) {
        // The error is distributed equally, so Q, K and V share one delta.
        const delta = (learningRate * err[errOffset + d] * inVal) + dopamine;
        const at = kOffset + d;

        kQ[at] = Math.max(-1.0, Math.min(1.0, kQ[at] + delta));
        kK[at] = Math.max(-1.0, Math.min(1.0, kK[at] + delta));
        kV[at] = Math.max(-1.0, Math.min(1.0, kV[at] + delta));
      }
    }
  }
}
|
|
326
|
+
|
|
327
|
+
/**
 * Returns the base layer configuration extended with this layer's
 * `d_model`, `sequenceLength` and `kernelInitializer`.
 */
public getConfig(): Record<string, any> {
  const baseConfig = super.getConfig();
  return {
    ...baseConfig,
    d_model: this.d_model,
    sequenceLength: this.sequenceLength,
    kernelInitializer: this.kernelInitializer
  };
}
|
|
335
|
+
}
|
package/src/native_backend.ts
CHANGED
|
@@ -107,3 +107,20 @@ export const applyEmbeddingDeltaNativeWrapper = (
|
|
|
107
107
|
outputDim
|
|
108
108
|
);
|
|
109
109
|
};
|
|
110
|
+
|
|
111
|
+
/**
 * Forwards to the native `contrastiveHebbianNative` binding.
 *
 * @param spikes         Spike values handed to the native routine.
 * @param errData        Error buffer handed to the native routine.
 * @param numPairs       Number of pairs.
 * @param sequenceLength Sequence length.
 * @param dModel         Model width.
 * @returns The number returned by the native routine.
 * @throws Error when the native backend is not loaded.
 */
export const contrastiveHebbianNativeWrapper = (
  spikes: Float32Array,
  errData: Float32Array,
  numPairs: number,
  sequenceLength: number,
  dModel: number
): number => {
  if (native) {
    return native.contrastiveHebbianNative(spikes, errData, numPairs, sequenceLength, dModel);
  }
  throw new Error("Spiking Native backend not available");
};
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
use napi_derive::napi;
|
|
2
|
+
use napi::bindgen_prelude::Float32Array;
|
|
3
|
+
use rayon::prelude::*;
|
|
4
|
+
|
|
5
|
+
#[napi]
|
|
6
|
+
pub fn contrastive_hebbian_native(
|
|
7
|
+
spikes: Float32Array,
|
|
8
|
+
mut err_data: Float32Array,
|
|
9
|
+
num_pairs: u32,
|
|
10
|
+
sequence_length: u32,
|
|
11
|
+
d_model: u32,
|
|
12
|
+
) -> f64 {
|
|
13
|
+
let spikes_slice: &[f32] = &spikes;
|
|
14
|
+
let err_slice: &mut [f32] = &mut err_data;
|
|
15
|
+
|
|
16
|
+
let num_pairs = num_pairs as usize;
|
|
17
|
+
let seq_len = sequence_length as usize;
|
|
18
|
+
let d_model = d_model as usize;
|
|
19
|
+
let chunk_size = seq_len * d_model;
|
|
20
|
+
|
|
21
|
+
let total_loss: f32 = err_slice.par_chunks_mut(chunk_size).enumerate().map(|(b, chunk)| {
|
|
22
|
+
let mut local_loss = 0.0f32;
|
|
23
|
+
|
|
24
|
+
if b < num_pairs {
|
|
25
|
+
// Ini adalah vektor Q
|
|
26
|
+
let i = b;
|
|
27
|
+
let p_offset = (num_pairs + i) * chunk_size;
|
|
28
|
+
let n_offset = (num_pairs + ((i + 1) % num_pairs)) * chunk_size;
|
|
29
|
+
let q_offset = i * chunk_size;
|
|
30
|
+
|
|
31
|
+
for rem in 0..chunk_size {
|
|
32
|
+
let q_spike = spikes_slice[q_offset + rem];
|
|
33
|
+
let p_spike = spikes_slice[p_offset + rem];
|
|
34
|
+
let n_spike = spikes_slice[n_offset + rem];
|
|
35
|
+
|
|
36
|
+
let mut pull = p_spike - q_spike;
|
|
37
|
+
if q_spike == 0.0 && p_spike == 0.0 && n_spike == 0.0 {
|
|
38
|
+
pull = 0.05; // Suntik energi
|
|
39
|
+
}
|
|
40
|
+
let push = (q_spike * n_spike) * 0.2;
|
|
41
|
+
|
|
42
|
+
chunk[rem] = pull - push;
|
|
43
|
+
|
|
44
|
+
if pull != 0.0 || push != 0.0 {
|
|
45
|
+
local_loss += pull.abs() + push;
|
|
46
|
+
}
|
|
47
|
+
}
|
|
48
|
+
} else {
|
|
49
|
+
// Ini adalah vektor P atau N
|
|
50
|
+
let p_index = b - num_pairs;
|
|
51
|
+
|
|
52
|
+
// Peran sebagai P untuk i = p_index
|
|
53
|
+
let q_offset_p = p_index * chunk_size;
|
|
54
|
+
let n_offset_p = (num_pairs + ((p_index + 1) % num_pairs)) * chunk_size;
|
|
55
|
+
|
|
56
|
+
// Peran sebagai N untuk i = p_index - 1 (dengan wrap around)
|
|
57
|
+
let i_n = if p_index == 0 { num_pairs - 1 } else { p_index - 1 };
|
|
58
|
+
let q_offset_n = i_n * chunk_size;
|
|
59
|
+
|
|
60
|
+
for rem in 0..chunk_size {
|
|
61
|
+
let q_spike_p = spikes_slice[q_offset_p + rem];
|
|
62
|
+
let p_spike_p = spikes_slice[b * chunk_size + rem];
|
|
63
|
+
let n_spike_p = spikes_slice[n_offset_p + rem];
|
|
64
|
+
|
|
65
|
+
let mut pull_p = p_spike_p - q_spike_p;
|
|
66
|
+
if q_spike_p == 0.0 && p_spike_p == 0.0 && n_spike_p == 0.0 {
|
|
67
|
+
pull_p = 0.05;
|
|
68
|
+
}
|
|
69
|
+
let contrib_p = -pull_p;
|
|
70
|
+
|
|
71
|
+
let q_spike_n = spikes_slice[q_offset_n + rem];
|
|
72
|
+
let n_spike_n = spikes_slice[b * chunk_size + rem];
|
|
73
|
+
|
|
74
|
+
let push_n = (q_spike_n * n_spike_n) * 0.2;
|
|
75
|
+
let contrib_n = -push_n;
|
|
76
|
+
|
|
77
|
+
chunk[rem] = contrib_p + contrib_n;
|
|
78
|
+
}
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
local_loss
|
|
82
|
+
}).sum();
|
|
83
|
+
|
|
84
|
+
total_loss as f64
|
|
85
|
+
}
|
package/src-rust/src/lib.rs
CHANGED
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
import { describe, it, expect, beforeEach } from 'vitest';
|
|
2
|
+
import { Matrix } from '@oxide-js/core';
|
|
3
|
+
import { SpikingDenseBPTT } from '../src/layers/SpikingDenseBPTT.js';
|
|
4
|
+
|
|
5
|
+
// Test suite for the SpikingDenseBPTT layer: parameter initialisation, the
// computeStep/learnThroughTime API, and a hand-computed LIF trace.
describe('SpikingDenseBPTT Layer', () => {
  const units = 4;
  const inFeatures = 3;
  const batch = 2;
  const maxTimeSteps = 3;

  let layer: SpikingDenseBPTT;

  /** Builds a [batch, inFeatures] input in which every entry is an active spike. */
  const allSpikes = (): Matrix =>
    Matrix.fromFlat(new Float32Array(batch * inFeatures).fill(1), [batch, inFeatures]);

  beforeEach(() => {
    layer = new SpikingDenseBPTT({
      units,
      kernelInitializer: 'ones',
      useBias: false
    });
    layer.build([batch, inFeatures]);
  });

  it('should initialize parameters correctly', () => {
    expect(layer.units).toBe(units);
    expect(layer.kernel).toBeDefined();

    // The per-unit decay factors must exist, one per unit.
    expect(layer.beta).toBeDefined();
    expect(layer.beta.length).toBe(units);

    // The multiplier is pre-computed as 1.0 - (1.0 / 2^shift) with shift in
    // 2..5 (1/4 down to 1/32), so every beta must lie in [0.75, 0.96875].
    for (let unit = 0; unit < units; unit++) {
      expect(layer.beta[unit]).toBeGreaterThanOrEqual(0.75);
      expect(layer.beta[unit]).toBeLessThanOrEqual(0.96875);
    }
  });

  it('should throw error when calling compute() directly', () => {
    const inputs = Matrix.fromFlat(new Float32Array(batch * inFeatures), [batch, inFeatures]);
    const callCompute = () => {
      // @ts-ignore
      layer.compute(inputs);
    };
    expect(callCompute).toThrowError(/Harap gunakan computeStep/);
  });

  it('should process sequence, enforce BPTT limits, and store history correctly', () => {
    layer.resetSequence(maxTimeSteps);

    expect(layer.maxTimeSteps).toBe(maxTimeSteps);
    expect(layer.historyInputs.length).toBe(maxTimeSteps);

    // Dummy binary spike input.
    const inputs = allSpikes();

    for (let t = 0; t < maxTimeSteps; t++) {
      const out = layer.computeStep(inputs, t);
      expect(out._shape).toEqual([batch, units]);

      if (t === 0) {
        // The first step must record inputs, potentials and spikes.
        expect(layer.historyInputs[0]).toBeDefined();
        expect(layer.historyPotentials[0]).toBeDefined();
        expect(layer.historySpikes[0]).toBeDefined();
      } else if (t === 1) {
        expect(layer.historyInputs[1]).toBeDefined();
      }
    }

    // Step 3 exceeds maxTimeSteps and must be rejected.
    const stepPastLimit = () => {
      layer.computeStep(inputs, 3);
    };
    expect(stepPastLimit).toThrowError(/melebihi batas maxTimeSteps/);
  });

  it('should run learnThroughTime properly without crashing', () => {
    layer.resetSequence(maxTimeSteps);

    const inputs = allSpikes();

    // Run the whole sequence.
    for (let t = 0; t < maxTimeSteps; t++) {
      layer.computeStep(inputs, t);
    }

    // One fake error matrix per time step (t = 0, 1, 2).
    const errors = Array.from({ length: maxTimeSteps }, () =>
      Matrix.fromFlat(new Float32Array(batch * units).fill(0.1), [batch, units])
    );

    // Output-layer BPTT: no feedback matrix B.
    expect(() => {
      layer.learnThroughTime(errors, undefined, 0.01);
    }).not.toThrow();

    // Hidden-layer BPTT: B is a [units, units] matrix filled with ones.
    const B = Matrix.fromFlat(new Float32Array(units * units).fill(1), [units, units]);
    expect(() => {
      layer.learnThroughTime(errors, B, 0.01);
    }).not.toThrow();
  });

  it('should correctly accumulate potentials and trigger spikes deterministically over time', () => {
    // Fully deterministic layer: batch = 1, inFeatures = 2, units = 2.
    const testLayer = new SpikingDenseBPTT({
      units: 2,
      useBias: false,
      kernelInitializer: 'zeros'
    });
    testLayer.build([1, 2]);

    // Manual kernel override: [ [0.6, 0.0], [0.0, 0.8] ].
    testLayer.kernel!._data.set([0.6, 0.0, 0.0, 0.8]);

    // Constant beta (0.5, easy to compute by hand) and threshold (1.0).
    testLayer.beta.fill(0.5);
    testLayer.threshold.fill(1.0);

    // Three time steps with both inputs active on every step.
    testLayer.resetSequence(3);
    const inputs = Matrix.fromFlat(new Float32Array([1, 1]), [1, 2]);

    // Expected spikes and recorded potentials for each time step.
    const expectedTrace = [
      { spikes: [0, 0], potentials: [0.6, 0.8] },
      { spikes: [0, 1], potentials: [0.9, 1.0] },
      { spikes: [1, 0], potentials: [1.0, 0.8] }
    ];

    expectedTrace.forEach((expected, t) => {
      const out = testLayer.computeStep(inputs, t);
      expect(out._data).toEqual(new Float32Array(expected.spikes));

      expect(testLayer.historyPotentials[t]._data[0]).toBeCloseTo(expected.potentials[0], 5);
      expect(testLayer.historyPotentials[t]._data[1]).toBeCloseTo(expected.potentials[1], 5);
    });
  });
});
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
import { describe, it, expect, beforeEach } from 'vitest';
|
|
2
|
+
import { Matrix } from '@oxide-js/core';
|
|
3
|
+
import { SpikingSelfAttention } from '../src/layers/SpikingSelfAttention.js';
|
|
4
|
+
|
|
5
|
+
// Test suite for the SpikingSelfAttention layer: initialisation, output shape
// and format, state handling across forward passes, input validation, and a
// hand-computed deterministic case.
describe('SpikingSelfAttention Layer', () => {
  const d_model = 4;
  const sequenceLength = 3;
  const batch = 2;
  const batchSeq = batch * sequenceLength;

  let attentionLayer: SpikingSelfAttention;

  beforeEach(() => {
    attentionLayer = new SpikingSelfAttention({
      d_model,
      sequenceLength,
      kernelInitializer: 'ones' // static initial weights keep the tests predictable
    });
    attentionLayer.build([batchSeq, d_model]);
  });

  it('should initialize parameters correctly', () => {
    expect(attentionLayer.d_model).toBe(d_model);
    expect(attentionLayer.sequenceLength).toBe(sequenceLength);
    expect(attentionLayer.kernelQ).toBeDefined();
    expect(attentionLayer.kernelK).toBeDefined();
    expect(attentionLayer.kernelV).toBeDefined();

    // Potentials start out empty; they are sized on the first forward pass.
    expect(attentionLayer.potentialsQ._data.length).toBe(0);
    expect(attentionLayer.potentialsScores._data.length).toBe(0);
  });

  it('should compute output with the correct shape and binary spike format', () => {
    // Random binary (0/1) dummy input.
    const inputData = new Float32Array(batchSeq * d_model);
    for (let i = 0; i < inputData.length; i++) {
      inputData[i] = Math.random() > 0.5 ? 1 : 0;
    }
    const inputs = Matrix.fromFlat(inputData, [batchSeq, d_model]);

    const output = attentionLayer.forward(inputs) as Matrix;

    // Expected output shape: [batch * seqLen, d_model].
    expect(output._shape[0]).toBe(batchSeq);
    expect(output._shape[1]).toBe(d_model);

    // Every output value must be a binary spike (0.0 or 1.0).
    const outputData = output._data;
    for (let i = 0; i < outputData.length; i++) {
      expect([0, 1]).toContain(outputData[i]);
    }

    // The potentials must now be sized to match the input.
    expect(attentionLayer.potentialsQ._data.length).toBe(batchSeq * d_model);
    expect(attentionLayer.potentialsScores._data.length).toBe(batchSeq * sequenceLength);
  });

  it('should properly accumulate potentials in sequential steps', () => {
    const inputData = new Float32Array(batchSeq * d_model).fill(1);
    const inputs = Matrix.fromFlat(inputData, [batchSeq, d_model]);

    // First time step.
    attentionLayer.forward(inputs);

    // Snapshot of the Q potentials after the first step.
    const firstStepPotentials = new Float32Array(attentionLayer.potentialsQ._data);

    // Second time step with the same input.
    attentionLayer.forward(inputs);

    // With constant input the potentials either rise or are reduced after a
    // spike; at minimum the second step must run and keep the state buffer
    // at the same size as after the first step.
    expect(attentionLayer.potentialsQ._data).toBeDefined();
    expect(attentionLayer.potentialsQ._data.length).toBe(firstStepPotentials.length);

    // After a reset every potential must be back at 0.
    attentionLayer.resetState();
    const zeros = new Float32Array(attentionLayer.potentialsQ._data.length).fill(0);
    expect(attentionLayer.potentialsQ._data).toEqual(zeros);
  });

  it('should throw error if input batch length is not multiple of sequence length', () => {
    const invalidBatchSeq = 7; // not a multiple of sequenceLength (3)
    const inputData = new Float32Array(invalidBatchSeq * d_model).fill(1);
    const inputs = Matrix.fromFlat(inputData, [invalidBatchSeq, d_model]);

    expect(() => {
      attentionLayer.forward(inputs);
    }).toThrowError(/Jumlah baris input/);
  });

  it('should correctly compute exact Self-Attention math (Deterministic Correctness)', () => {
    // Fully deterministic setup.
    const testLayer = new SpikingSelfAttention({
      d_model: 2,
      sequenceLength: 2,
      kernelInitializer: 'zeros' // overridden manually below
    });

    testLayer.build([2, 2]); // batchSeq = 2, d_model = 2

    // Replace all three kernels with the identity matrix.
    const identity = new Float32Array([1, 0, 0, 1]);
    testLayer.kernelQ!._data.set(identity);
    testLayer.kernelK!._data.set(identity);
    testLayer.kernelV!._data.set(identity);

    // No leak and a low threshold, so an input of 1 spikes immediately.
    testLayer.betaQKV.fill(0.0);
    testLayer.thresholdQKV.fill(0.5);

    testLayer.betaScores.fill(0.0);
    testLayer.thresholdScores.fill(0.5); // a score of 1 fires at once

    // Input: token 0 = [1, 0], token 1 = [0, 1].
    const inputs = Matrix.fromFlat(new Float32Array([1, 0, 0, 1]), [2, 2]);

    const output = testLayer.forward(inputs) as Matrix;

    // Expected:
    // SQ, SK and SV equal the input (identity kernels, threshold below 1.0).
    // SQ dot SK^T:
    // - token 0 . token 0 = 1 (match at index 0)
    // - token 0 . token 1 = 0
    // - token 1 . token 0 = 0
    // - token 1 . token 1 = 1
    // The scores form the identity matrix [1, 0, 0, 1],
    // so scores times SV gives [1, 0, 0, 1].
    expect(output._shape).toEqual([2, 2]);
    expect(output._data).toEqual(new Float32Array([1, 0, 0, 1]));

    // Identical tokens: token 0 = [1, 0], token 1 = [1, 0].
    const inputsSame = Matrix.fromFlat(new Float32Array([1, 0, 1, 0]), [2, 2]);
    const outputSame = testLayer.forward(inputsSame) as Matrix;

    // SQ dot SK^T is 1 for every pair, so the scores are [1, 1, 1, 1].
    // SV = [1, 0, 1, 0].
    // Out token 0 = score[0] * SV[0] + score[1] * SV[1] = 1*1 + 1*1 = 2,
    // which is capped at 1, giving [1, 0]. Out token 1 is [1, 0] as well.
    expect(outputSame._data).toEqual(new Float32Array([1, 0, 1, 0]));
  });
});
|