@oxide-js/spiking 1.1.0 → 1.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +19 -0
- package/index.cjs +322 -0
- package/index.d.ts +5 -13
- package/index.js +6 -2
- package/package.json +1 -1
- package/spiking-native.linux-x64-gnu.node +0 -0
- package/src/index.ts +4 -2
- package/src/layers/SpikingDense.ts +71 -42
- package/src/layers/SpikingDenseBPTT.ts +303 -0
- package/src/layers/SpikingEmbedding.ts +154 -142
- package/src/layers/SpikingSelfAttention.ts +335 -0
- package/src/native_backend.ts +39 -3
- package/src-rust/src/contrastive.rs +85 -0
- package/src-rust/src/delta.rs +51 -0
- package/src-rust/src/dot_product.rs +47 -0
- package/src-rust/src/embedding.rs +28 -0
- package/src-rust/src/lib.rs +16 -460
- package/src-rust/src/lif.rs +44 -0
- package/src-rust/src/surrogate.rs +28 -0
- package/test/SpikingDenseBPTT.test.ts +151 -0
- package/test/SpikingSelfAttention.test.ts +148 -0
- package/test/test_embedding_overlap.ts +181 -0
- package/examples/demo.ts +0 -101
- package/src/models/SpikingSentenceEmbedder.ts +0 -135
|
@@ -2,42 +2,45 @@ import { BaseLayer, LayerConfig, ForwardOptions } from "@oxide-js/layers";
|
|
|
2
2
|
import { Matrix, mj } from "@oxide-js/core";
|
|
3
3
|
import {
|
|
4
4
|
isNativeAvailable,
|
|
5
|
-
lifStepNativeWrapper,
|
|
6
|
-
maskSurrogateNativeWrapper
|
|
5
|
+
lifStepNativeWrapper,
|
|
6
|
+
maskSurrogateNativeWrapper,
|
|
7
|
+
applyEmbeddingDeltaNativeWrapper
|
|
7
8
|
} from "../native_backend.js";
|
|
8
9
|
|
|
9
10
|
export interface SpikingEmbeddingConfig extends LayerConfig {
|
|
10
|
-
inputDim: number;
|
|
11
|
-
outputDim: number;
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
11
|
+
inputDim: number;
|
|
12
|
+
outputDim: number;
|
|
13
|
+
embeddingsInitializer?: string;
|
|
14
|
+
betaRange?: [number, number];
|
|
15
|
+
thresholdRange?: [number, number];
|
|
15
16
|
}
|
|
16
17
|
|
|
17
18
|
export class SpikingEmbedding extends BaseLayer {
|
|
18
19
|
public inputDim: number;
|
|
19
20
|
public outputDim: number;
|
|
20
|
-
public
|
|
21
|
-
|
|
21
|
+
public embeddingsInitializer: string;
|
|
22
|
+
|
|
23
|
+
public betaRange: [number, number];
|
|
24
|
+
public thresholdRange: [number, number];
|
|
25
|
+
public beta!: Float32Array;
|
|
26
|
+
public threshold!: Float32Array;
|
|
22
27
|
|
|
23
28
|
public potentials!: Matrix;
|
|
24
29
|
public lastPotentials?: Matrix;
|
|
25
30
|
public lastInputs?: Matrix;
|
|
26
31
|
public lastSpikes?: Matrix;
|
|
27
32
|
|
|
28
|
-
public
|
|
29
|
-
|
|
30
|
-
public get kernel(): Matrix | undefined {
|
|
31
|
-
return this.getParameter("kernel");
|
|
33
|
+
public get embeddings(): Matrix | undefined {
|
|
34
|
+
return this.getParameter("embeddings");
|
|
32
35
|
}
|
|
33
36
|
|
|
34
37
|
constructor(config: SpikingEmbeddingConfig) {
|
|
35
38
|
super(config);
|
|
36
39
|
this.inputDim = config.inputDim;
|
|
37
40
|
this.outputDim = config.outputDim;
|
|
38
|
-
this.beta = config.beta ?? 0.9;
|
|
39
|
-
this.threshold = config.threshold ?? 1.0;
|
|
40
41
|
this.embeddingsInitializer = config.embeddingsInitializer || "glorot_normal";
|
|
42
|
+
this.betaRange = config.betaRange || [0.8, 0.99];
|
|
43
|
+
this.thresholdRange = config.thresholdRange || [0.01, 0.1];
|
|
41
44
|
}
|
|
42
45
|
|
|
43
46
|
public computeOutputShape(inputShape: number[]): number[] {
|
|
@@ -47,181 +50,190 @@ export class SpikingEmbedding extends BaseLayer {
|
|
|
47
50
|
|
|
48
51
|
public build(inputShape: number[]): void {
|
|
49
52
|
super.build(inputShape);
|
|
50
|
-
|
|
51
|
-
this.
|
|
53
|
+
|
|
54
|
+
const embVal = this.createInitializer(this.embeddingsInitializer, [this.inputDim, this.outputDim]);
|
|
55
|
+
|
|
56
|
+
// Scale up the embedding values because Glorot Normal makes them too small for large vocabularies (e.g. 32000)
|
|
57
|
+
// which prevents the LIF neurons from ever reaching the threshold.
|
|
58
|
+
const scaleFactor = Math.sqrt(this.inputDim);
|
|
59
|
+
for (let i = 0; i < embVal._data.length; i++) {
|
|
60
|
+
embVal._data[i] *= scaleFactor;
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
this.addParameter("embeddings", embVal, true, [this.inputDim, this.outputDim]);
|
|
64
|
+
|
|
65
|
+
// Inisialisasi beta dan threshold secara acak untuk setiap neuron
|
|
66
|
+
this.beta = new Float32Array(this.outputDim);
|
|
67
|
+
this.threshold = new Float32Array(this.outputDim);
|
|
68
|
+
for (let i = 0; i < this.outputDim; i++) {
|
|
69
|
+
this.beta[i] = this.betaRange[0] + Math.random() * (this.betaRange[1] - this.betaRange[0]);
|
|
70
|
+
this.threshold[i] = this.thresholdRange[0] + Math.random() * (this.thresholdRange[1] - this.thresholdRange[0]);
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
// Potentials start at 0, shape [batch, outputDim].
|
|
74
|
+
this.potentials = Matrix.fromFlat(new Float32Array(this.outputDim), [1, this.outputDim]);
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
private dotDataBuffer?: Float32Array;
|
|
78
|
+
private outDataBuffer?: Float32Array;
|
|
79
|
+
|
|
80
|
+
private ensurePotentialsShape(batch: number) {
|
|
81
|
+
if (this.potentials._shape[0] !== batch || !this.dotDataBuffer) {
|
|
82
|
+
this.potentials = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
|
|
83
|
+
this.dotDataBuffer = new Float32Array(batch * this.outputDim);
|
|
84
|
+
this.outDataBuffer = new Float32Array(batch * this.outputDim);
|
|
85
|
+
this.lastPotentials = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
|
|
86
|
+
}
|
|
52
87
|
}
|
|
53
88
|
|
|
54
89
|
public resetState() {
|
|
55
90
|
if (this.potentials) this.potentials._data.fill(0);
|
|
56
|
-
this.lastPotentials
|
|
91
|
+
if (this.lastPotentials) this.lastPotentials._data.fill(0);
|
|
57
92
|
this.lastInputs = undefined;
|
|
58
93
|
this.lastSpikes = undefined;
|
|
59
94
|
}
|
|
60
95
|
|
|
61
|
-
private ensurePotentialsShape(batch: number) {
|
|
62
|
-
if (!this.potentials || this.potentials._shape[0] !== batch) {
|
|
63
|
-
this.potentials = Matrix.fromFlat(
|
|
64
|
-
new Float32Array(batch * this.outputDim),
|
|
65
|
-
[batch, this.outputDim]
|
|
66
|
-
);
|
|
67
|
-
}
|
|
68
|
-
}
|
|
69
|
-
|
|
70
96
|
protected compute(inputs: Matrix, options?: ForwardOptions): Matrix {
|
|
71
|
-
const kernel = this.kernel!._data;
|
|
72
97
|
const batch = inputs._shape[0];
|
|
73
|
-
const inputData = inputs._data;
|
|
74
|
-
|
|
75
98
|
this.ensurePotentialsShape(batch);
|
|
76
|
-
|
|
77
|
-
// 1.
|
|
78
|
-
const
|
|
99
|
+
|
|
100
|
+
// 1. Embedding lookup
|
|
101
|
+
const emb = this.embeddings!;
|
|
102
|
+
const dotData = this.dotDataBuffer!;
|
|
103
|
+
dotData.fill(0);
|
|
104
|
+
|
|
79
105
|
for (let b = 0; b < batch; b++) {
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
}
|
|
89
|
-
}
|
|
106
|
+
const tokenIdx = inputs._data[b];
|
|
107
|
+
if (tokenIdx >= 0 && tokenIdx < this.inputDim) {
|
|
108
|
+
const embOffset = tokenIdx * this.outputDim;
|
|
109
|
+
const dotOffset = b * this.outputDim;
|
|
110
|
+
for (let i = 0; i < this.outputDim; i++) {
|
|
111
|
+
dotData[dotOffset + i] = emb._data[embOffset + i];
|
|
112
|
+
}
|
|
113
|
+
}
|
|
90
114
|
}
|
|
91
|
-
|
|
92
|
-
// 2
|
|
93
|
-
const outData =
|
|
115
|
+
|
|
116
|
+
// 2. Leaky Integrate and Fire (LIF Restore untuk Spiking Murni)
|
|
117
|
+
const outData = this.outDataBuffer!;
|
|
118
|
+
outData.fill(0);
|
|
94
119
|
const outSpikes = Matrix.fromFlat(outData, [batch, this.outputDim]);
|
|
95
|
-
|
|
120
|
+
// lastPotentials is already ensured in shape
|
|
96
121
|
|
|
97
122
|
if (isNativeAvailable()) {
|
|
98
123
|
lifStepNativeWrapper(
|
|
99
124
|
this.potentials._data,
|
|
100
125
|
dotData,
|
|
101
|
-
|
|
102
|
-
this.lastPotentials
|
|
126
|
+
outData,
|
|
127
|
+
this.lastPotentials!._data,
|
|
103
128
|
this.beta,
|
|
104
129
|
this.threshold
|
|
105
130
|
);
|
|
106
131
|
} else {
|
|
107
132
|
const potData = this.potentials._data;
|
|
108
|
-
const
|
|
109
|
-
|
|
110
|
-
for (let
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
133
|
+
const lpData = this.lastPotentials!._data;
|
|
134
|
+
|
|
135
|
+
for (let b = 0; b < batch; b++) {
|
|
136
|
+
const offset = b * this.outputDim;
|
|
137
|
+
for (let i = 0; i < this.outputDim; i++) {
|
|
138
|
+
const idx = offset + i;
|
|
139
|
+
potData[idx] = Math.min((potData[idx] * this.beta[i]) + dotData[idx], 1.0); // Clamp potential max 1.0
|
|
140
|
+
lpData[idx] = potData[idx];
|
|
141
|
+
}
|
|
142
|
+
for (let i = 0; i < this.outputDim; i++) {
|
|
143
|
+
const idx = offset + i;
|
|
144
|
+
if (potData[idx] >= this.threshold[i]) {
|
|
145
|
+
outData[idx] = 1;
|
|
146
|
+
potData[idx] -= this.threshold[i];
|
|
147
|
+
} else {
|
|
148
|
+
outData[idx] = 0;
|
|
149
|
+
}
|
|
150
|
+
}
|
|
121
151
|
}
|
|
122
152
|
}
|
|
123
|
-
|
|
124
|
-
// Simpan memori untuk update bobot
|
|
153
|
+
|
|
125
154
|
this.lastInputs = inputs;
|
|
126
155
|
this.lastSpikes = outSpikes;
|
|
127
156
|
|
|
128
157
|
return outSpikes;
|
|
129
158
|
}
|
|
130
159
|
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
throw new Error("[SpikingEmbedding] Cannot run learnEmbedding() before forward() is executed. 'lastInputs' is undefined.");
|
|
135
|
-
}
|
|
136
|
-
|
|
137
|
-
const kernel = this.kernel!._data;
|
|
138
|
-
const inputData = this.lastInputs._data;
|
|
139
|
-
const batch = this.lastInputs._shape[0];
|
|
140
|
-
|
|
141
|
-
// Hitung error yang mampir ke embedding
|
|
142
|
-
// E * B (Feedback Alignment)
|
|
143
|
-
// Gunakan matmul biasa karena B adalah float, dan errorFromNext mungkin float
|
|
144
|
-
const eHidden = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
|
|
145
|
-
// Namun karena OxideJS Matrix belum memiliki fungsi dot produk standar terbuka yang stabil,
|
|
146
|
-
// kita harus hati-hati di sini. Untuk simplifikasi, eHidden = errorFromNext * B.
|
|
147
|
-
// Kita asumsikan ada utilitas dotProduct standar dari core.
|
|
148
|
-
// Jika B adalah matriks Dense (dimensi: outUnits x hiddenUnits), maka
|
|
149
|
-
// eHidden [batch, hiddenUnits] = errorFromNext [batch, outUnits] dot B [outUnits, hiddenUnits]
|
|
160
|
+
public learnEmbedding(errorSignal: Matrix, B: Matrix, learningRate: number = 0.01): Matrix {
|
|
161
|
+
// Broadcast error mundur (Feedback Alignment)
|
|
162
|
+
let eHidden = mj.dotProduct(errorSignal, B, undefined, false, false); // E * B
|
|
150
163
|
|
|
151
|
-
// Kita panggil dot product standar (bukan Add-Only, karena error dan B sama-sama float)
|
|
152
|
-
let eHiddenMatrix = mj.dotProduct(errorFromNext, B, undefined, false, false);
|
|
153
|
-
|
|
154
164
|
// Surrogate Mask: Boxcar
|
|
155
165
|
if (this.lastPotentials) {
|
|
166
|
+
const eData = eHidden._data;
|
|
167
|
+
const pData = this.lastPotentials._data;
|
|
168
|
+
const windowSize = 1.0;
|
|
169
|
+
|
|
156
170
|
if (isNativeAvailable()) {
|
|
157
171
|
maskSurrogateNativeWrapper(
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
this.threshold,
|
|
161
|
-
|
|
172
|
+
eData,
|
|
173
|
+
pData,
|
|
174
|
+
this.threshold,
|
|
175
|
+
windowSize
|
|
162
176
|
);
|
|
163
177
|
} else {
|
|
164
|
-
const
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
178
|
+
const batch = eHidden._shape[0];
|
|
179
|
+
for (let b = 0; b < batch; b++) {
|
|
180
|
+
const offset = b * this.outputDim;
|
|
181
|
+
for (let i = 0; i < this.outputDim; i++) {
|
|
182
|
+
const idx = offset + i;
|
|
183
|
+
if (Math.abs(pData[idx] - this.threshold[i]) > windowSize) {
|
|
184
|
+
eData[idx] = 0;
|
|
185
|
+
}
|
|
172
186
|
}
|
|
173
187
|
}
|
|
174
188
|
}
|
|
175
189
|
}
|
|
176
190
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
for (let b = 0; b < batch; b++) {
|
|
180
|
-
const tokenId = Math.round(inputData[b]);
|
|
181
|
-
if (tokenId >= 0 && tokenId < this.inputDim) {
|
|
182
|
-
const kOffset = tokenId * this.outputDim;
|
|
183
|
-
const errOffset = b * this.outputDim;
|
|
184
|
-
for (let j = 0; j < this.outputDim; j++) {
|
|
185
|
-
kernel[kOffset + j] += learningRate * err[errOffset + j];
|
|
186
|
-
}
|
|
187
|
-
}
|
|
188
|
-
}
|
|
189
|
-
|
|
190
|
-
return eHiddenMatrix;
|
|
191
|
+
this.applyEmbeddingDelta(eHidden, learningRate);
|
|
192
|
+
return eHidden;
|
|
191
193
|
}
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
194
|
+
|
|
195
|
+
private applyEmbeddingDelta(errorSignal: Matrix, learningRate: number) {
|
|
196
|
+
if (!this.lastInputs || !this.lastSpikes) {
|
|
197
|
+
throw new Error("[SpikingEmbedding] Cannot run learning before forward() is executed.");
|
|
198
|
+
}
|
|
199
|
+
|
|
200
|
+
const embeddings = this.embeddings!._data;
|
|
201
|
+
const inputs = this.lastInputs._data;
|
|
202
|
+
const err = errorSignal._data;
|
|
203
|
+
|
|
204
|
+
const batch = this.lastInputs._shape[0];
|
|
205
|
+
const outputDim = this.outputDim;
|
|
206
|
+
|
|
207
|
+
if (isNativeAvailable()) {
|
|
208
|
+
applyEmbeddingDeltaNativeWrapper(
|
|
209
|
+
embeddings,
|
|
210
|
+
inputs,
|
|
211
|
+
err,
|
|
212
|
+
learningRate,
|
|
213
|
+
this.inputDim,
|
|
214
|
+
outputDim
|
|
215
|
+
);
|
|
216
|
+
} else {
|
|
217
|
+
for (let b = 0; b < batch; b++) {
|
|
218
|
+
const tokenIdx = inputs[b];
|
|
219
|
+
if (tokenIdx >= 0 && tokenIdx < this.inputDim) {
|
|
220
|
+
const embOffset = tokenIdx * outputDim;
|
|
221
|
+
const errOffset = b * outputDim;
|
|
222
|
+
for (let j = 0; j < outputDim; j++) {
|
|
223
|
+
embeddings[embOffset + j] += learningRate * err[errOffset + j];
|
|
224
|
+
embeddings[embOffset + j] = Math.max(-1.0, Math.min(1.0, embeddings[embOffset + j])); // Clamp weight [-1, 1]
|
|
222
225
|
}
|
|
223
226
|
}
|
|
224
227
|
}
|
|
225
228
|
}
|
|
226
229
|
}
|
|
230
|
+
|
|
231
|
+
public getConfig(): Record<string, any> {
|
|
232
|
+
return {
|
|
233
|
+
...super.getConfig(),
|
|
234
|
+
inputDim: this.inputDim,
|
|
235
|
+
outputDim: this.outputDim,
|
|
236
|
+
embeddingsInitializer: this.embeddingsInitializer
|
|
237
|
+
};
|
|
238
|
+
}
|
|
227
239
|
}
|