@oxide-js/spiking 1.1.0 → 1.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +19 -0
- package/index.cjs +322 -0
- package/index.d.ts +5 -13
- package/index.js +6 -2
- package/package.json +1 -1
- package/spiking-native.linux-x64-gnu.node +0 -0
- package/src/index.ts +4 -2
- package/src/layers/SpikingDense.ts +71 -42
- package/src/layers/SpikingDenseBPTT.ts +303 -0
- package/src/layers/SpikingEmbedding.ts +154 -142
- package/src/layers/SpikingSelfAttention.ts +335 -0
- package/src/native_backend.ts +39 -3
- package/src-rust/src/contrastive.rs +85 -0
- package/src-rust/src/delta.rs +51 -0
- package/src-rust/src/dot_product.rs +47 -0
- package/src-rust/src/embedding.rs +28 -0
- package/src-rust/src/lib.rs +16 -460
- package/src-rust/src/lif.rs +44 -0
- package/src-rust/src/surrogate.rs +28 -0
- package/test/SpikingDenseBPTT.test.ts +151 -0
- package/test/SpikingSelfAttention.test.ts +148 -0
- package/test/test_embedding_overlap.ts +181 -0
- package/examples/demo.ts +0 -101
- package/src/models/SpikingSentenceEmbedder.ts +0 -135
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
import { BaseLayer, LayerConfig, ForwardOptions } from "@oxide-js/layers";
|
|
2
|
+
import { Matrix, mj } from "@oxide-js/core";
|
|
3
|
+
import {
|
|
4
|
+
isNativeAvailable,
|
|
5
|
+
lifStepNativeWrapper,
|
|
6
|
+
maskSurrogateNativeWrapper,
|
|
7
|
+
applyAddOnlyDeltaNativeWrapper
|
|
8
|
+
} from "../native_backend.js";
|
|
9
|
+
import dotProductAddOnly from "../math/dotProductAddOnly.js";
|
|
10
|
+
|
|
11
|
+
/**
 * Construction options for {@link SpikingDenseBPTT}.
 *
 * Every field except `units` is optional; the defaults listed below are the
 * ones applied by the `SpikingDenseBPTT` constructor.
 */
export interface SpikingDenseBPTTConfig extends LayerConfig {
  /** Number of spiking neurons, i.e. the size of the layer's output axis. */
  units: number;

  /** Whether a trainable bias parameter is created. Defaults to `true`. */
  useBias?: boolean;

  /** Name of the initializer used for the kernel. Defaults to `"glorot_normal"`. */
  kernelInitializer?: string;

  /** Name of the initializer used for the bias. Defaults to `"zeros"`. */
  biasInitializer?: string;

  /**
   * `[min, max]` pair for the per-neuron leak factor. Defaults to `[0.8, 0.99]`.
   * NOTE(review): the layer stores this value, but `build()` derives beta from
   * a random power-of-two shift and does not read this range — confirm intent.
   */
  betaRange?: [number, number];

  /**
   * `[min, max]` pair from which each neuron's firing threshold is drawn
   * uniformly at build time. Defaults to `[0.5, 1.0]`.
   */
  thresholdRange?: [number, number];
}
|
|
19
|
+
|
|
20
|
+
/**
 * Dense layer of leaky integrate-and-fire neurons that records per-time-step
 * history so it can be trained with backpropagation through time (BPTT).
 *
 * Usage protocol, as enforced by the methods below:
 *  1. `build(inputShape)` creates the parameters and per-neuron constants.
 *  2. `resetSequence(timeSteps)` is called before each new sequence.
 *  3. `computeStep(inputs, t)` is called once per time step `t`.
 *  4. `learnThroughTime(errors, B, lr)` is called once after the sequence.
 *
 * The generic `compute()` entry point is intentionally disabled.
 */
export class SpikingDenseBPTT extends BaseLayer {
  public units: number;
  public useBias: boolean;
  public kernelInitializer: string;
  public biasInitializer: string;
  /** Stored from the config; `build()` does not read it (beta comes from a random shift). */
  public betaRange: [number, number];
  /** Range from which each neuron's threshold is sampled uniformly in `build()`. */
  public thresholdRange: [number, number];
  /** Per-neuron leak multiplier, filled in by `build()`. */
  public beta!: Float32Array;
  /** Per-neuron firing threshold, filled in by `build()`. */
  public threshold!: Float32Array;

  /** Current membrane potentials, shape `[batch, units]`. */
  public potentials!: Matrix;

  // History buffers for backpropagation through time, indexed by time step.
  public historyInputs: Matrix[] = [];
  public historyPotentials: Matrix[] = [];
  public historySpikes: Matrix[] = [];
  public maxTimeSteps: number = 0;

  // Reusable output buffer for the forward pass (avoids one allocation per step).
  private outSpikesDataBuffer?: Float32Array;

  /** The trainable kernel parameter, or `undefined` before `build()`. */
  public get kernel(): Matrix | undefined {
    return this.getParameter("kernel");
  }

  /** The trainable bias parameter, or `undefined` when absent. */
  public get bias(): Matrix | undefined {
    return this.getParameter("bias");
  }

  /**
   * @param config Layer options; see `SpikingDenseBPTTConfig` for defaults.
   */
  constructor(config: SpikingDenseBPTTConfig) {
    super(config);
    this.units = config.units;
    this.useBias = config.useBias ?? true;
    // `||` is kept for the initializer names so an empty string still falls back.
    this.kernelInitializer = config.kernelInitializer || "glorot_normal";
    this.biasInitializer = config.biasInitializer || "zeros";
    this.betaRange = config.betaRange ?? [0.8, 0.99];
    this.thresholdRange = config.thresholdRange ?? [0.5, 1.0];
  }

  /**
   * @param inputShape Shape of the incoming tensor; only the first entry is used.
   * @returns `[batch, units]`, with `-1` as batch when `inputShape` is empty.
   */
  public computeOutputShape(inputShape: number[]): number[] {
    const batch = inputShape[0] ?? -1;
    return [batch, this.units];
  }

  /**
   * Creates the kernel (and optional bias) parameters, samples the per-neuron
   * leak factor and threshold, and allocates a `[1, units]` potential state.
   *
   * @param inputShape Shape of the incoming tensor; its last entry is the
   *   number of input features.
   */
  public build(inputShape: number[]): void {
    super.build(inputShape);

    const inFeatures = inputShape[inputShape.length - 1];

    const kernelVal = this.createInitializer(this.kernelInitializer, [inFeatures, this.units]);
    this.addParameter("kernel", kernelVal, true, [inFeatures, this.units]);

    if (this.useBias) {
      const biasVal = this.createInitializer(this.biasInitializer, [this.units, 1]);
      this.addParameter("bias", biasVal, true, [this.units, 1]);
    }

    // Per-neuron constants: beta is 1 - 2^-shift with a random shift in 2..5,
    // pre-computed as a float multiplier; threshold is uniform in thresholdRange.
    this.beta = new Float32Array(this.units);
    this.threshold = new Float32Array(this.units);
    const [thresholdMin, thresholdMax] = this.thresholdRange;
    for (let i = 0; i < this.units; i++) {
      const shift = Math.floor(2 + Math.random() * 4);
      this.beta[i] = 1.0 - (1.0 / Math.pow(2, shift));
      this.threshold[i] = thresholdMin + Math.random() * (thresholdMax - thresholdMin);
    }

    // Initial state.
    this.potentials = Matrix.fromFlat(new Float32Array(this.units), [1, this.units]);
  }

  /**
   * Re-allocates the potential state and the spike output buffer when the
   * batch size changes or the output buffer has not been allocated yet.
   * Re-allocation starts the potentials from zero.
   */
  private ensurePotentialsShape(batch: number): void {
    if (this.potentials._shape[0] !== batch || !this.outSpikesDataBuffer) {
      this.potentials = Matrix.fromFlat(new Float32Array(batch * this.units), [batch, this.units]);
      this.outSpikesDataBuffer = new Float32Array(batch * this.units);
    }
  }

  /**
   * Prepares the layer for a new sequence. Call this BEFORE feeding the first
   * step: it sets the step limit, replaces the history buffers with empty
   * arrays of length `timeSteps`, and zeroes the membrane potentials.
   *
   * @param timeSteps Number of time steps the coming sequence will contain.
   */
  public resetSequence(timeSteps: number): void {
    this.maxTimeSteps = timeSteps;
    this.historyInputs = new Array(timeSteps);
    this.historyPotentials = new Array(timeSteps);
    this.historySpikes = new Array(timeSteps);

    if (this.potentials) this.potentials._data.fill(0);
  }

  /**
   * Disabled: this layer must be driven step by step.
   *
   * @throws Error always.
   */
  protected compute(inputs: Matrix, options?: ForwardOptions): Matrix {
    throw new Error("[SpikingDenseBPTT] Harap gunakan computeStep(inputs, t) dan resetSequence(t) untuk model BPTT, jangan gunakan compute().");
  }

  /**
   * Forward pass for a single time step.
   *
   * Stores a copy of the input, integrates `dotProductAddOnly(inputs, kernel)`
   * (plus bias) into the membrane potentials, emits spikes where the potential
   * reaches the neuron's threshold, and records the pre-fire potentials and the
   * spikes at index `t` for the backward pass.
   *
   * @param inputs Input matrix whose first dimension is the batch size.
   * @param t Zero-based time step; must be an integer in `[0, maxTimeSteps)`.
   * @returns The `[batch, units]` spike matrix. It wraps a buffer that is
   *   reused by the next call, so copy it if it must outlive the step (the
   *   copy kept in `historySpikes[t]` is independent).
   * @throws Error if `t` is not a valid step index or the layer is not built.
   */
  public computeStep(inputs: Matrix, t: number): Matrix {
    // Reject negative / fractional / NaN steps: they would silently land in
    // the history arrays as non-index properties and be lost for BPTT.
    if (!Number.isInteger(t) || t < 0) {
      throw new Error(`[SpikingDenseBPTT] Time step ${t} tidak valid: harus bilangan bulat >= 0.`);
    }
    if (t >= this.maxTimeSteps) {
      throw new Error(`[SpikingDenseBPTT] Time step ${t} melebihi batas maxTimeSteps ${this.maxTimeSteps}`);
    }

    const kernel = this.kernel;
    if (!kernel || !this.potentials) {
      throw new Error("[SpikingDenseBPTT] Layer belum di-build. Panggil build(inputShape) sebelum computeStep().");
    }

    const batch = inputs._shape[0];
    this.ensurePotentialsShape(batch);

    // 1. Keep a private copy of the input for the backward pass.
    this.historyInputs[t] = Matrix.fromFlat(new Float32Array(inputs._data), inputs._shape);

    // 2. Add-only spiking dot product, then the optional bias.
    const dot = dotProductAddOnly(inputs, kernel);
    if (this.useBias && this.bias) {
      mj.addBiasRow(dot, this.bias);
    }

    const outData = this.outSpikesDataBuffer!;
    outData.fill(0);
    const outSpikes = Matrix.fromFlat(outData, [batch, this.units]);

    // Membrane potential at time t AFTER integrating the input but BEFORE the
    // fire/reset stage. The surrogate gradient in BPTT is evaluated on how
    // close this value is to the threshold.
    const potAtT = new Float32Array(batch * this.units);

    if (isNativeAvailable()) {
      // NOTE(review): per the original author's note the native step fills the
      // 4th argument with the pre-fire potentials — confirm in native_backend.
      lifStepNativeWrapper(
        this.potentials._data,
        dot._data,
        outSpikes._data,
        potAtT,
        this.beta,
        this.threshold
      );
    } else {
      const potData = this.potentials._data;
      const dotData = dot._data;
      // Write through the matrix's own storage, exactly like the native path,
      // so the result does not depend on whether fromFlat wraps or copies.
      const spikeData = outSpikes._data;
      const total = batch * this.units;

      for (let idx = 0; idx < total; idx++) {
        const unit = idx % this.units;

        // Leak & integrate, capped at 1.0; remember the pre-fire value.
        potData[idx] = Math.min((potData[idx] * this.beta[unit]) + dotData[idx], 1.0);
        potAtT[idx] = potData[idx];

        // Fire & soft reset (subtract the threshold instead of zeroing).
        if (potData[idx] >= this.threshold[unit]) {
          spikeData[idx] = 1;
          potData[idx] -= this.threshold[unit];
        } else {
          spikeData[idx] = 0;
        }
      }
    }

    // 3. Record the state for time step t.
    this.historyPotentials[t] = Matrix.fromFlat(potAtT, [batch, this.units]);
    this.historySpikes[t] = Matrix.fromFlat(new Float32Array(outSpikes._data), [batch, this.units]);

    return outSpikes;
  }

  /**
   * Backward pass (BPTT). Call it ONCE, after the whole sequence has been fed
   * through `computeStep`. Walks the time steps from last to first; at each
   * step it adds the error carried back from the future to the step's own
   * error, masks it with a boxcar surrogate gradient around the threshold,
   * updates kernel/bias in place, and propagates `masked * beta` to `t - 1`.
   *
   * @param errorSequence One error matrix per time step (index = `t`).
   * @param B Optional feedback matrix. When given, each step's error is
   *   `errorSequence[t] · B` (hidden layer); otherwise the error is used
   *   as-is (output layer).
   * @param learningRate Step size of the weight update. Defaults to `0.01`.
   * @throws Error if no step has been recorded, the layer is not built, or
   *   any time step lacks an error matrix or recorded history. All checks run
   *   before any weight is modified.
   */
  public learnThroughTime(errorSequence: Matrix[], B: Matrix | undefined, learningRate: number = 0.01): void {
    if (this.maxTimeSteps === 0 || !this.historyInputs[0]) {
      throw new Error("[SpikingDenseBPTT] Belum ada data di memory. Panggil computeStep(inputs, t) terlebih dahulu.");
    }

    const kernelMatrix = this.kernel;
    if (!kernelMatrix) {
      throw new Error("[SpikingDenseBPTT] Layer belum di-build. Panggil build(inputShape) sebelum learnThroughTime().");
    }

    // Validate every step up front so a gap in the middle of the sequence
    // cannot leave the weights half-updated.
    for (let t = 0; t < this.maxTimeSteps; t++) {
      if (!errorSequence[t] || !this.historyInputs[t] || !this.historyPotentials[t]) {
        throw new Error(`[SpikingDenseBPTT] Data untuk time step ${t} tidak lengkap (errorSequence/history). Pastikan computeStep(inputs, t) dipanggil untuk setiap t dan errorSequence berisi ${this.maxTimeSteps} elemen.`);
      }
    }

    const batch = errorSequence[0]._shape[0];
    const inFeatures = this.historyInputs[0]._shape[1];
    const units = this.units;
    const kernel = kernelMatrix._data;

    const biasMatrix = this.useBias ? this.bias : undefined;
    const biasData = biasMatrix ? biasMatrix._data : new Float32Array(0);

    // Half-width of the boxcar surrogate gradient around the threshold.
    const windowSize = 1.0;

    // Buffers allocated once and reused by every time step.
    // temporalErrorData: error flowing back from t + 1 (starts at zero).
    // maskedErrorData: total error of the current step, masked in place.
    const temporalErrorData = new Float32Array(batch * units);
    const maskedErrorData = new Float32Array(batch * units);

    const useNative = isNativeAvailable();

    // Walk backwards, from the end of the sequence to its start.
    for (let t = this.maxTimeSteps - 1; t >= 0; t--) {
      const stepError = errorSequence[t];
      const currentErrorData: Float32Array = B
        ? mj.dotProduct(stepError, B, undefined, false, false)._data
        : stepError._data;

      const pData = this.historyPotentials[t]._data;
      const inputData = this.historyInputs[t]._data;

      // Step A: combine the spatial error (from above) with the temporal
      // error (from the future) directly into the buffer that gets masked.
      for (let i = 0; i < maskedErrorData.length; i++) {
        maskedErrorData[i] = currentErrorData[i] + temporalErrorData[i];
      }

      // Step B: surrogate mask, then the add-only delta rule for step t.
      if (useNative) {
        maskSurrogateNativeWrapper(maskedErrorData, pData, this.threshold, windowSize);
        applyAddOnlyDeltaNativeWrapper(
          kernel,
          biasData,
          inputData,
          maskedErrorData,
          learningRate,
          batch,
          inFeatures,
          units,
          this.useBias
        );
      } else {
        this.maskSurrogateFallback(maskedErrorData, pData, windowSize, batch);
        this.applyAddOnlyDeltaFallback(
          kernel,
          biasMatrix ? biasData : undefined,
          inputData,
          maskedErrorData,
          learningRate,
          batch,
          inFeatures
        );
      }

      // Step C: error handed to t - 1 through the leaky (beta) pathway.
      for (let idx = 0; idx < maskedErrorData.length; idx++) {
        temporalErrorData[idx] = maskedErrorData[idx] * this.beta[idx % units];
      }
    }
  }

  /** JS fallback: zeroes every error whose pre-fire potential is outside the boxcar window. */
  private maskSurrogateFallback(
    maskedError: Float32Array,
    prePotentials: Float32Array,
    windowSize: number,
    batch: number
  ): void {
    const units = this.units;
    for (let b = 0; b < batch; b++) {
      const rowStart = b * units;
      for (let i = 0; i < units; i++) {
        const idx = rowStart + i;
        // Written as a negated `<=` so a NaN distance is masked out as well.
        if (!(Math.abs(prePotentials[idx] - this.threshold[i]) <= windowSize)) {
          maskedError[idx] = 0;
        }
      }
    }
  }

  /** JS fallback: add-only delta rule; updates kernel (and bias, if given) in place, clamped to [-1, 1]. */
  private applyAddOnlyDeltaFallback(
    kernel: Float32Array,
    bias: Float32Array | undefined,
    input: Float32Array,
    maskedError: Float32Array,
    learningRate: number,
    batch: number,
    inFeatures: number
  ): void {
    const units = this.units;
    for (let b = 0; b < batch; b++) {
      const inOffset = b * inFeatures;
      const errOffset = b * units;

      // Only inputs that spiked (> 0.5) contribute: the error is added to
      // their kernel row, never multiplied by the input value.
      for (let k = 0; k < inFeatures; k++) {
        if (input[inOffset + k] > 0.5) {
          const kOffset = k * units;
          for (let j = 0; j < units; j++) {
            kernel[kOffset + j] += learningRate * maskedError[errOffset + j];
            kernel[kOffset + j] = Math.max(-1.0, Math.min(1.0, kernel[kOffset + j]));
          }
        }
      }

      if (bias) {
        // Each batch row contributes 1/batch of its error to the bias.
        for (let j = 0; j < units; j++) {
          bias[j] += (learningRate * maskedError[errOffset + j]) / batch;
          bias[j] = Math.max(-1.0, Math.min(1.0, bias[j]));
        }
      }
    }
  }

  /**
   * @returns The base layer config extended with this layer's constructor
   *   options, including `betaRange` and `thresholdRange`, so a layer rebuilt
   *   from the config uses the same ranges.
   */
  public getConfig(): Record<string, any> {
    return {
      ...super.getConfig(),
      units: this.units,
      useBias: this.useBias,
      kernelInitializer: this.kernelInitializer,
      biasInitializer: this.biasInitializer,
      betaRange: [...this.betaRange],
      thresholdRange: [...this.thresholdRange]
    };
  }
}
|