@oxide-js/spiking 1.1.0 → 1.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,303 @@
1
+ import { BaseLayer, LayerConfig, ForwardOptions } from "@oxide-js/layers";
2
+ import { Matrix, mj } from "@oxide-js/core";
3
+ import {
4
+ isNativeAvailable,
5
+ lifStepNativeWrapper,
6
+ maskSurrogateNativeWrapper,
7
+ applyAddOnlyDeltaNativeWrapper
8
+ } from "../native_backend.js";
9
+ import dotProductAddOnly from "../math/dotProductAddOnly.js";
10
+
11
/**
 * Constructor options for the BPTT-trained spiking dense layer.
 *
 * Only `units` is required; every other field falls back to a default chosen by the
 * layer's constructor when omitted.
 */
export interface SpikingDenseBPTTConfig extends LayerConfig {
  /** Number of spiking neurons, i.e. the width of the layer's output. */
  units: number;

  /** Whether a trainable bias vector is created (constructor default: `true`). */
  useBias?: boolean;

  /** Name of the kernel initializer (constructor default: `"glorot_normal"`). */
  kernelInitializer?: string;
  /** Name of the bias initializer (constructor default: `"zeros"`). */
  biasInitializer?: string;

  /** Range for the per-unit firing thresholds (constructor default: `[0.5, 1.0]`). */
  thresholdRange?: [low: number, high: number];
  /**
   * Range for the per-unit leak factor (constructor default: `[0.8, 0.99]`).
   * NOTE(review): the layer's `build()` does not currently read this range — confirm intent.
   */
  betaRange?: [low: number, high: number];
}
19
+
20
/**
 * Fully connected layer of leaky integrate-and-fire (LIF) spiking neurons trained with
 * Backpropagation Through Time (BPTT) and a boxcar surrogate gradient.
 *
 * Per-sequence protocol:
 *   1. `resetSequence(T)` allocates T history slots and zeroes the membrane potentials.
 *   2. `computeStep(inputs, t)` for t = 0 … T-1 runs one forward step and records its state.
 *   3. `learnThroughTime(errors, B, lr)` walks the recorded steps backwards and updates the
 *      kernel (and bias) in place.
 *
 * `compute()` is intentionally disabled. When `isNativeAvailable()` returns true the LIF
 * step, surrogate masking and weight update are delegated to the native wrappers; otherwise
 * the JavaScript fallbacks in this class are used.
 */
export class SpikingDenseBPTT extends BaseLayer {
  public units: number;
  public useBias: boolean;
  public kernelInitializer: string;
  public biasInitializer: string;
  /**
   * Configured leak-factor range. NOTE(review): `build()` does not read it; beta is drawn
   * from {1 - 2^-s : s = 2..5} = {0.75, 0.875, 0.9375, 0.96875} regardless — confirm intent.
   */
  public betaRange: [number, number];
  /** Range from which each unit's firing threshold is drawn uniformly in `build()`. */
  public thresholdRange: [number, number];
  /** Per-unit leak factor, allocated in `build()`. */
  public beta!: Float32Array;
  /** Per-unit firing threshold, allocated in `build()`. */
  public threshold!: Float32Array;

  /** Membrane potentials carried between time steps, shape [batch, units]. */
  public potentials!: Matrix;

  // History buffers for BPTT, indexed by time step and filled by computeStep().
  public historyInputs: Matrix[] = [];
  /** Potentials at each step after integration but before fire/reset. */
  public historyPotentials: Matrix[] = [];
  public historySpikes: Matrix[] = [];
  public maxTimeSteps: number = 0;

  // Reusable forward-pass output buffer. The Matrix returned by computeStep() is built on
  // it, so (assuming Matrix.fromFlat wraps rather than copies) its contents are overwritten
  // by the next computeStep() call; historySpikes keeps independent copies.
  private outSpikesDataBuffer?: Float32Array;

  public get kernel(): Matrix | undefined {
    return this.getParameter("kernel");
  }

  public get bias(): Matrix | undefined {
    return this.getParameter("bias");
  }

  constructor(config: SpikingDenseBPTTConfig) {
    super(config);
    this.units = config.units;
    this.useBias = config.useBias ?? true;
    this.kernelInitializer = config.kernelInitializer || "glorot_normal";
    this.biasInitializer = config.biasInitializer || "zeros";
    this.betaRange = config.betaRange || [0.8, 0.99];
    this.thresholdRange = config.thresholdRange || [0.5, 1.0];
  }

  /** Output shape is [batch, units]; batch is -1 when the input shape has no first dimension. */
  public computeOutputShape(inputShape: number[]): number[] {
    const batch = inputShape[0] ?? -1;
    return [batch, this.units];
  }

  /**
   * Creates the kernel ([inFeatures, units]) and optional bias ([units, 1]) parameters,
   * draws the per-unit beta/threshold values and allocates a [1, units] potential state.
   */
  public build(inputShape: number[]): void {
    super.build(inputShape);

    const inFeatures = inputShape[inputShape.length - 1];

    const kernelVal = this.createInitializer(this.kernelInitializer, [inFeatures, this.units]);
    this.addParameter("kernel", kernelVal, true, [inFeatures, this.units]);

    if (this.useBias) {
      const biasVal = this.createInitializer(this.biasInitializer, [this.units, 1]);
      this.addParameter("bias", biasVal, true, [this.units, 1]);
    }

    // Beta is 1 - 2^-shift with a random shift in 2..5 (a leak expressible as a bit shift),
    // precomputed as a float multiplier so the step loop needs no extra work.
    // Thresholds are drawn uniformly from thresholdRange.
    this.beta = new Float32Array(this.units);
    this.threshold = new Float32Array(this.units);
    for (let i = 0; i < this.units; i++) {
      const shift = Math.floor(2 + Math.random() * 4);
      this.beta[i] = 1.0 - (1.0 / Math.pow(2, shift));
      this.threshold[i] = this.thresholdRange[0] + Math.random() * (this.thresholdRange[1] - this.thresholdRange[0]);
    }

    // Initial state; reallocated by ensurePotentialsShape() once the batch size is known.
    this.potentials = Matrix.fromFlat(new Float32Array(this.units), [1, this.units]);
  }

  /**
   * Reallocates potentials (zeroed) and the output buffer when the batch size changes or the
   * buffer does not exist yet; returns the output buffer.
   */
  private ensurePotentialsShape(batch: number): Float32Array {
    let buffer = this.outSpikesDataBuffer;
    if (this.potentials._shape[0] !== batch || !buffer) {
      this.potentials = Matrix.fromFlat(new Float32Array(batch * this.units), [batch, this.units]);
      buffer = new Float32Array(batch * this.units);
      this.outSpikesDataBuffer = buffer;
    }
    return buffer;
  }

  /** Returns the kernel, throwing a descriptive error when `build()` has not run yet. */
  private requireBuilt(): Matrix {
    const kernel = this.kernel;
    if (!kernel || !this.potentials || !this.beta || !this.threshold) {
      throw new Error(
        "[SpikingDenseBPTT] Layer has not been built. Call build(inputShape) before computeStep()/learnThroughTime()."
      );
    }
    return kernel;
  }

  /**
   * Call this BEFORE feeding a new sentence/sequence: allocates `timeSteps` empty history
   * slots and zeroes the current membrane potentials.
   */
  public resetSequence(timeSteps: number) {
    this.maxTimeSteps = timeSteps;
    this.historyInputs = new Array(timeSteps);
    this.historyPotentials = new Array(timeSteps);
    this.historySpikes = new Array(timeSteps);

    if (this.potentials) this.potentials._data.fill(0);
  }

  /** Disabled: this layer must be driven step by step with resetSequence()/computeStep(). */
  protected compute(_inputs: Matrix, _options?: ForwardOptions): Matrix {
    throw new Error("[SpikingDenseBPTT] Harap gunakan computeStep(inputs, t) dan resetSequence(t) untuk model BPTT, jangan gunakan compute().");
  }

  /**
   * Forward pass for the token at time step `t`.
   *
   * Records a copy of the input, the pre-fire potentials and a copy of the spikes at index `t`.
   *
   * @param inputs - batch of inputs for this step (first dimension is the batch).
   * @param t - time step index; must be an integer in [0, maxTimeSteps).
   * @returns spikes of shape [batch, units] (0 or 1), backed by a buffer reused across calls.
   * @throws Error if `t` is out of range or not an integer, or the layer is not built.
   */
  public computeStep(inputs: Matrix, t: number): Matrix {
    if (!Number.isInteger(t) || t < 0) {
      throw new Error(`[SpikingDenseBPTT] Time step ${t} must be a non-negative integer.`);
    }
    if (t >= this.maxTimeSteps) {
      throw new Error(`[SpikingDenseBPTT] Time step ${t} melebihi batas maxTimeSteps ${this.maxTimeSteps}`);
    }

    const kernel = this.requireBuilt();
    const batch = inputs._shape[0];
    const outData = this.ensurePotentialsShape(batch);

    // 1. Record a copy of the input at index t (the caller may reuse its buffer).
    this.historyInputs[t] = Matrix.fromFlat(new Float32Array(inputs._data), inputs._shape);

    // 2. Add-only dot product, plus bias when enabled.
    const dot = dotProductAddOnly(inputs, kernel);
    if (this.useBias && this.bias) {
      mj.addBiasRow(dot, this.bias);
    }

    outData.fill(0);
    const outSpikes = Matrix.fromFlat(outData, [batch, this.units]);

    // Potentials at step t AFTER integrating the input but BEFORE fire/reset. The surrogate
    // gradient in learnThroughTime() measures their distance to the threshold.
    const potAtT = new Float32Array(batch * this.units);

    if (isNativeAvailable()) {
      lifStepNativeWrapper(
        this.potentials._data,
        dot._data,
        outSpikes._data,
        potAtT, // filled by the native step with the pre-fire potentials
        this.beta,
        this.threshold
      );
    } else {
      // Write into outSpikes._data, exactly like the native path, so the returned Matrix
      // always holds the spikes.
      this.lifStepFallback(this.potentials._data, dot._data, outSpikes._data, potAtT, batch);
    }

    // 3. Record pre-fire potentials and a copy of the spikes at index t.
    this.historyPotentials[t] = Matrix.fromFlat(potAtT, [batch, this.units]);
    this.historySpikes[t] = Matrix.fromFlat(new Float32Array(outSpikes._data), [batch, this.units]);

    return outSpikes;
  }

  /**
   * JavaScript fallback for one LIF step: leak and integrate (potential capped at 1.0),
   * record the pre-fire potential, then fire with a soft reset (threshold subtracted).
   */
  private lifStepFallback(
    potData: Float32Array,
    dotData: Float32Array,
    outData: Float32Array,
    potAtT: Float32Array,
    batch: number
  ): void {
    const units = this.units;
    for (let b = 0; b < batch; b++) {
      const offset = b * units;
      // Leak & integrate.
      for (let i = 0; i < units; i++) {
        const idx = offset + i;
        potData[idx] = Math.min((potData[idx] * this.beta[i]) + dotData[idx], 1.0);
        potAtT[idx] = potData[idx];
      }
      // Fire & soft reset.
      for (let i = 0; i < units; i++) {
        const idx = offset + i;
        if (potData[idx] >= this.threshold[i]) {
          outData[idx] = 1;
          potData[idx] -= this.threshold[i];
        } else {
          outData[idx] = 0;
        }
      }
    }
  }

  /**
   * Backward pass (BPTT). Call ONCE, after the whole sequence has been processed.
   *
   * Walks t = maxTimeSteps-1 … 0. At each step:
   *   1. Forms the error as (spatial error) + (temporal error from t+1).
   *   2. Masks it with a boxcar surrogate of half-width 1.0 around each threshold.
   *   3. Applies the add-only delta rule to the kernel (and bias).
   *   4. Passes masked * beta back to step t-1.
   *
   * @param errorSequence - one error Matrix per time step (at least maxTimeSteps entries).
   * @param B - when given, each error is projected through it via mj.dotProduct (hidden
   *   layer); when undefined the errors are used directly (output layer).
   * @param learningRate - step size for kernel/bias updates; weights are clamped to [-1, 1].
   * @throws Error if no sequence was recorded, the history or errors are incomplete, or
   *   sizes do not match.
   */
  public learnThroughTime(errorSequence: Matrix[], B: Matrix | undefined, learningRate: number = 0.01): void {
    if (this.maxTimeSteps === 0 || !this.historyInputs[0]) {
      throw new Error("[SpikingDenseBPTT] Belum ada data di memory. Panggil computeStep(inputs, t) terlebih dahulu.");
    }
    if (errorSequence.length < this.maxTimeSteps) {
      throw new Error(
        `[SpikingDenseBPTT] errorSequence has ${errorSequence.length} entries but maxTimeSteps is ${this.maxTimeSteps}.`
      );
    }

    const kernel = this.requireBuilt()._data;
    const units = this.units;
    const batch = errorSequence[0]._shape[0];
    const inFeatures = this.historyInputs[0]._shape[1];
    // Validate everything up front so a bad step cannot leave the weights half-updated.
    this.assertHistoryComplete(batch, inFeatures);

    const stateSize = batch * units;
    const windowSize = 1.0; // half-width of the boxcar surrogate gradient
    // Error carried backwards in time through the leak (beta) pathway, from t+1 to t.
    const temporalErrorData = new Float32Array(stateSize);
    // Allocated once and reused for every step to avoid per-step garbage.
    const maskedErrorData = new Float32Array(stateSize);
    const biasData = (this.useBias && this.bias) ? this.bias._data : new Float32Array(0);
    const useNative = isNativeAvailable();

    for (let t = this.maxTimeSteps - 1; t >= 0; t--) {
      // Spatial error at step t: projected through B for a hidden layer, raw for the output layer.
      const currentErrorData: Float32Array = B
        ? mj.dotProduct(errorSequence[t], B, undefined, false, false)._data
        : errorSequence[t]._data;
      if (currentErrorData.length < stateSize) {
        throw new Error(
          `[SpikingDenseBPTT] Error at time step ${t} has ${currentErrorData.length} values; expected ${stateSize} (batch ${batch} x units ${units}).`
        );
      }

      const pData = this.historyPotentials[t]._data;
      const inputData = this.historyInputs[t]._data;

      // Step A: combine spatial and temporal error. The mask below works in place.
      for (let i = 0; i < stateSize; i++) {
        maskedErrorData[i] = currentErrorData[i] + temporalErrorData[i];
      }

      // Step B: surrogate mask, then the add-only delta rule on kernel/bias.
      if (useNative) {
        maskSurrogateNativeWrapper(maskedErrorData, pData, this.threshold, windowSize);
        applyAddOnlyDeltaNativeWrapper(
          kernel,
          biasData,
          inputData,
          maskedErrorData,
          learningRate,
          batch,
          inFeatures,
          units,
          this.useBias
        );
      } else {
        this.surrogateMaskFallback(maskedErrorData, pData, batch, windowSize);
        this.deltaRuleFallback(kernel, biasData, inputData, maskedErrorData, learningRate, batch, inFeatures);
      }

      // Step C: temporal error for step t-1 through the leak pathway.
      for (let b = 0; b < batch; b++) {
        const offset = b * units;
        for (let i = 0; i < units; i++) {
          const idx = offset + i;
          temporalErrorData[idx] = maskedErrorData[idx] * this.beta[i];
        }
      }
    }
  }

  /**
   * Ensures every step in [0, maxTimeSteps) was recorded by computeStep() with buffers large
   * enough for `batch`/`inFeatures`. The native wrappers receive those sizes explicitly, so
   * they must never see shorter buffers.
   */
  private assertHistoryComplete(batch: number, inFeatures: number): void {
    const stateSize = batch * this.units;
    for (let t = 0; t < this.maxTimeSteps; t++) {
      const input = this.historyInputs[t];
      const pot = this.historyPotentials[t];
      if (!input || !pot) {
        throw new Error(
          `[SpikingDenseBPTT] Time step ${t} was not recorded. Call computeStep(inputs, t) for every step before learnThroughTime().`
        );
      }
      if (pot._data.length !== stateSize || input._data.length < batch * inFeatures) {
        throw new Error(
          `[SpikingDenseBPTT] Recorded state at time step ${t} does not match batch ${batch} (inFeatures ${inFeatures}, units ${this.units}).`
        );
      }
    }
  }

  /**
   * JavaScript fallback for the boxcar surrogate. It zeroes each error whose pre-fire
   * potential is not within `windowSize` of the unit's threshold. Written as !(<=) so a NaN
   * potential also zeroes the error.
   */
  private surrogateMaskFallback(
    maskedErrorData: Float32Array,
    pData: Float32Array,
    batch: number,
    windowSize: number
  ): void {
    const units = this.units;
    for (let b = 0; b < batch; b++) {
      const offset = b * units;
      for (let i = 0; i < units; i++) {
        const idx = offset + i;
        if (!(Math.abs(pData[idx] - this.threshold[i]) <= windowSize)) {
          maskedErrorData[idx] = 0;
        }
      }
    }
  }

  /**
   * JavaScript fallback for the add-only delta rule.
   *
   * For every input entry > 0.5, the matching kernel row gets learningRate * error added.
   * The input value is never multiplied in. The bias gets the per-row error averaged over
   * the batch. All updated values are clamped to [-1, 1].
   */
  private deltaRuleFallback(
    kernel: Float32Array,
    biasData: Float32Array,
    inputData: Float32Array,
    maskedErrorData: Float32Array,
    learningRate: number,
    batch: number,
    inFeatures: number
  ): void {
    const units = this.units;
    const updateBias = this.useBias && !!this.bias;
    for (let b = 0; b < batch; b++) {
      const inOffset = b * inFeatures;
      const errOffset = b * units;
      for (let k = 0; k < inFeatures; k++) {
        if (inputData[inOffset + k] > 0.5) {
          const kOffset = k * units;
          for (let j = 0; j < units; j++) {
            kernel[kOffset + j] += learningRate * maskedErrorData[errOffset + j];
            kernel[kOffset + j] = Math.max(-1.0, Math.min(1.0, kernel[kOffset + j]));
          }
        }
      }
      if (updateBias) {
        for (let j = 0; j < units; j++) {
          biasData[j] += (learningRate * maskedErrorData[errOffset + j]) / batch;
          biasData[j] = Math.max(-1.0, Math.min(1.0, biasData[j]));
        }
      }
    }
  }

  /**
   * Serializable configuration. Includes betaRange/thresholdRange (as copies) so a layer
   * rebuilt from this config does not silently fall back to the default ranges.
   */
  public getConfig(): Record<string, any> {
    return {
      ...super.getConfig(),
      units: this.units,
      useBias: this.useBias,
      kernelInitializer: this.kernelInitializer,
      biasInitializer: this.biasInitializer,
      betaRange: [this.betaRange[0], this.betaRange[1]],
      thresholdRange: [this.thresholdRange[0], this.thresholdRange[1]]
    };
  }
}