@oxide-js/spiking 1.1.0 → 1.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +19 -0
- package/index.cjs +322 -0
- package/index.d.ts +5 -13
- package/index.js +6 -2
- package/package.json +1 -1
- package/spiking-native.linux-x64-gnu.node +0 -0
- package/src/index.ts +4 -2
- package/src/layers/SpikingDense.ts +71 -42
- package/src/layers/SpikingDenseBPTT.ts +303 -0
- package/src/layers/SpikingEmbedding.ts +154 -142
- package/src/layers/SpikingSelfAttention.ts +335 -0
- package/src/native_backend.ts +39 -3
- package/src-rust/src/contrastive.rs +85 -0
- package/src-rust/src/delta.rs +51 -0
- package/src-rust/src/dot_product.rs +47 -0
- package/src-rust/src/embedding.rs +28 -0
- package/src-rust/src/lib.rs +16 -460
- package/src-rust/src/lif.rs +44 -0
- package/src-rust/src/surrogate.rs +28 -0
- package/test/SpikingDenseBPTT.test.ts +151 -0
- package/test/SpikingSelfAttention.test.ts +148 -0
- package/test/test_embedding_overlap.ts +181 -0
- package/examples/demo.ts +0 -101
- package/src/models/SpikingSentenceEmbedder.ts +0 -135
package/src-rust/src/lib.rs
CHANGED
|
@@ -1,462 +1,18 @@
|
|
|
1
1
|
#![deny(clippy::all)]
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
/// Event-driven spiking network simulated natively.
///
/// Synapses are stored as per-neuron adjacency lists: `post_synaptic_indices[i]` holds the
/// targets of neuron `i` and `weights[i]` the matching weights, index for index.
pub struct NativeSpikingNetwork {
    /// Number of neurons; the per-neuron vectors below are sized to this.
    num_neurons: u32,
    /// Membrane decay factor multiplied into every potential once per step.
    beta: f32,

    // --- Per-neuron dynamic state ---
    /// Current membrane potential of each neuron.
    potentials: Vec<f32>,
    /// Firing threshold of each neuron.
    thresholds: Vec<f32>,
    /// Spike flag of each neuron (`1` = fired in the last step, `0` otherwise).
    spikes: Vec<u8>,

    // --- STDP traces ---
    /// Presynaptic trace of each neuron.
    pre_traces: Vec<f32>,
    /// Postsynaptic trace of each neuron.
    post_traces: Vec<f32>,

    // --- Connectivity (adjacency lists) ---
    /// Outgoing synapse targets, one list per presynaptic neuron.
    post_synaptic_indices: Vec<Vec<u32>>,
    /// Outgoing synapse weights, parallel to `post_synaptic_indices`.
    weights: Vec<Vec<f32>>,
}
|
|
35
|
-
|
|
36
|
-
#[napi]
|
|
37
|
-
impl NativeSpikingNetwork {
|
|
38
|
-
#[napi(constructor)]
|
|
39
|
-
pub fn new(num_neurons: u32, beta: f64, default_threshold: f64) -> Self {
|
|
40
|
-
Self {
|
|
41
|
-
num_neurons,
|
|
42
|
-
potentials: vec![0.0; num_neurons as usize],
|
|
43
|
-
thresholds: vec![default_threshold as f32; num_neurons as usize],
|
|
44
|
-
spikes: vec![0; num_neurons as usize],
|
|
45
|
-
post_synaptic_indices: vec![Vec::new(); num_neurons as usize],
|
|
46
|
-
weights: vec![Vec::new(); num_neurons as usize],
|
|
47
|
-
pre_traces: vec![0.0; num_neurons as usize],
|
|
48
|
-
post_traces: vec![0.0; num_neurons as usize],
|
|
49
|
-
beta: beta as f32,
|
|
50
|
-
}
|
|
51
|
-
}
|
|
52
|
-
|
|
53
|
-
#[napi]
|
|
54
|
-
pub fn connect(&mut self, pre: u32, post: u32, weight: f64) {
|
|
55
|
-
let pre_idx = pre as usize;
|
|
56
|
-
if pre_idx >= self.num_neurons as usize || post as usize >= self.num_neurons as usize {
|
|
57
|
-
return;
|
|
58
|
-
}
|
|
59
|
-
|
|
60
|
-
if let Some(pos) = self.post_synaptic_indices[pre_idx].iter().position(|&p| p == post) {
|
|
61
|
-
self.weights[pre_idx][pos] = weight as f32;
|
|
62
|
-
} else {
|
|
63
|
-
self.post_synaptic_indices[pre_idx].push(post);
|
|
64
|
-
self.weights[pre_idx].push(weight as f32);
|
|
65
|
-
}
|
|
66
|
-
}
|
|
67
|
-
|
|
68
|
-
#[napi]
|
|
69
|
-
pub fn inject_current(&mut self, neuron_idx: u32, current: f64) {
|
|
70
|
-
let idx = neuron_idx as usize;
|
|
71
|
-
if idx < self.potentials.len() {
|
|
72
|
-
self.potentials[idx] += current as f32;
|
|
73
|
-
}
|
|
74
|
-
}
|
|
75
|
-
|
|
76
|
-
#[napi]
|
|
77
|
-
pub fn inhibit_range(&mut self, start_idx: u32, end_idx: u32, except_idx: u32, current: f64) {
|
|
78
|
-
let start = start_idx as usize;
|
|
79
|
-
let end = end_idx as usize;
|
|
80
|
-
let ex = except_idx as usize;
|
|
81
|
-
let c = current as f32;
|
|
82
|
-
|
|
83
|
-
for i in start..end {
|
|
84
|
-
if i != ex && i < self.potentials.len() {
|
|
85
|
-
self.potentials[i] += c;
|
|
86
|
-
}
|
|
87
|
-
}
|
|
88
|
-
}
|
|
89
|
-
|
|
90
|
-
#[napi]
|
|
91
|
-
pub fn step(&mut self) {
|
|
92
|
-
// 1. Accumulate pre-synaptic spikes (Sequential for now to avoid data races)
|
|
93
|
-
// We only process neurons that spiked in the previous step
|
|
94
|
-
for i in 0..self.num_neurons as usize {
|
|
95
|
-
if self.spikes[i] == 1 {
|
|
96
|
-
let targets = &self.post_synaptic_indices[i];
|
|
97
|
-
let w = &self.weights[i];
|
|
98
|
-
for k in 0..targets.len() {
|
|
99
|
-
let j = targets[k] as usize;
|
|
100
|
-
self.potentials[j] += w[k];
|
|
101
|
-
}
|
|
102
|
-
}
|
|
103
|
-
}
|
|
104
|
-
|
|
105
|
-
// Reset spikes safely (Paralel)
|
|
106
|
-
self.spikes.par_iter_mut().for_each(|s| *s = 0);
|
|
107
|
-
|
|
108
|
-
// 2. Membrane decay, threshold check, and fire (Paralel via Rayon)
|
|
109
|
-
let beta = self.beta;
|
|
110
|
-
self.potentials
|
|
111
|
-
.par_iter_mut()
|
|
112
|
-
.zip(self.thresholds.par_iter())
|
|
113
|
-
.zip(self.spikes.par_iter_mut())
|
|
114
|
-
.for_each(|((pot, &thresh), spike)| {
|
|
115
|
-
*pot *= beta; // Leaky integration
|
|
116
|
-
if *pot < 1e-5 { *pot = 0.0; } // Prevent denormals
|
|
117
|
-
if *pot >= thresh {
|
|
118
|
-
*spike = 1;
|
|
119
|
-
*pot -= thresh; // Soft reset
|
|
120
|
-
}
|
|
121
|
-
});
|
|
122
|
-
}
|
|
123
|
-
|
|
124
|
-
#[napi]
|
|
125
|
-
pub fn reset_state(&mut self) {
|
|
126
|
-
self.potentials.par_iter_mut().for_each(|p| *p = 0.0);
|
|
127
|
-
self.spikes.par_iter_mut().for_each(|s| *s = 0);
|
|
128
|
-
}
|
|
129
|
-
|
|
130
|
-
// Helpers to get data to JS if needed
|
|
131
|
-
#[napi]
|
|
132
|
-
pub fn get_spikes(&self) -> Vec<u8> {
|
|
133
|
-
self.spikes.clone()
|
|
134
|
-
}
|
|
135
|
-
|
|
136
|
-
#[napi]
|
|
137
|
-
pub fn get_potentials(&self) -> Vec<f32> {
|
|
138
|
-
self.potentials.clone()
|
|
139
|
-
}
|
|
140
|
-
|
|
141
|
-
#[napi]
|
|
142
|
-
pub fn update_stdp(
|
|
143
|
-
&mut self,
|
|
144
|
-
learning_rate: f64,
|
|
145
|
-
tau_plus: f64,
|
|
146
|
-
tau_minus: f64,
|
|
147
|
-
a_plus: f64,
|
|
148
|
-
a_minus: f64,
|
|
149
|
-
w_max: f64,
|
|
150
|
-
w_min: f64,
|
|
151
|
-
) {
|
|
152
|
-
let lr = learning_rate as f32;
|
|
153
|
-
let t_plus = tau_plus as f32;
|
|
154
|
-
let t_minus = tau_minus as f32;
|
|
155
|
-
let ap = a_plus as f32;
|
|
156
|
-
let am = a_minus as f32;
|
|
157
|
-
let wmax = w_max as f32;
|
|
158
|
-
let wmin = w_min as f32;
|
|
159
|
-
|
|
160
|
-
// 1. Update trace decay and trace spikes
|
|
161
|
-
self.pre_traces
|
|
162
|
-
.par_iter_mut()
|
|
163
|
-
.zip(self.post_traces.par_iter_mut())
|
|
164
|
-
.zip(self.spikes.par_iter())
|
|
165
|
-
.for_each(|((pre, post), &spike)| {
|
|
166
|
-
*pre *= t_plus;
|
|
167
|
-
*post *= t_minus;
|
|
168
|
-
|
|
169
|
-
if *pre < 1e-5 { *pre = 0.0; }
|
|
170
|
-
if *post < 1e-5 { *post = 0.0; }
|
|
171
|
-
|
|
172
|
-
if spike == 1 {
|
|
173
|
-
*pre = 1.0;
|
|
174
|
-
*post = 1.0;
|
|
175
|
-
}
|
|
176
|
-
});
|
|
177
|
-
|
|
178
|
-
// 2. Event-driven weight updates
|
|
179
|
-
// Since we mutate self.weights, we can parallelize over the outer slice
|
|
180
|
-
// because each sub-vector weights[i] belongs to exclusively one thread.
|
|
181
|
-
// However, we need immutable access to self.post_traces and self.spikes.
|
|
182
|
-
// To do this cleanly in Rust with Rayon without violating borrow rules:
|
|
183
|
-
|
|
184
|
-
let spikes = &self.spikes;
|
|
185
|
-
let pre_traces = &self.pre_traces;
|
|
186
|
-
let post_traces = &self.post_traces;
|
|
187
|
-
let post_synaptic_indices = &self.post_synaptic_indices;
|
|
188
|
-
|
|
189
|
-
self.weights
|
|
190
|
-
.par_iter_mut()
|
|
191
|
-
.enumerate()
|
|
192
|
-
.for_each(|(i, w_list)| {
|
|
193
|
-
let pre_spiked = spikes[i] == 1;
|
|
194
|
-
let pre_trace = pre_traces[i];
|
|
195
|
-
|
|
196
|
-
if pre_trace == 0.0 && !pre_spiked {
|
|
197
|
-
return;
|
|
198
|
-
}
|
|
199
|
-
|
|
200
|
-
let targets = &post_synaptic_indices[i];
|
|
201
|
-
|
|
202
|
-
for k in 0..targets.len() {
|
|
203
|
-
let j = targets[k] as usize;
|
|
204
|
-
let post_spiked = spikes[j] == 1;
|
|
205
|
-
|
|
206
|
-
let mut dw = 0.0;
|
|
207
|
-
|
|
208
|
-
// LTP
|
|
209
|
-
if post_spiked {
|
|
210
|
-
dw += lr * ap * pre_trace;
|
|
211
|
-
}
|
|
212
|
-
// LTD
|
|
213
|
-
if pre_spiked {
|
|
214
|
-
dw -= lr * am * post_traces[j];
|
|
215
|
-
}
|
|
216
|
-
|
|
217
|
-
if dw != 0.0 {
|
|
218
|
-
w_list[k] += dw;
|
|
219
|
-
if w_list[k] > wmax {
|
|
220
|
-
w_list[k] = wmax;
|
|
221
|
-
} else if w_list[k] < wmin {
|
|
222
|
-
w_list[k] = wmin;
|
|
223
|
-
}
|
|
224
|
-
}
|
|
225
|
-
}
|
|
226
|
-
});
|
|
227
|
-
}
|
|
228
|
-
|
|
229
|
-
#[napi]
|
|
230
|
-
pub fn save_to_file(&self, filepath: String) {
|
|
231
|
-
let state = NetworkState {
|
|
232
|
-
num_neurons: self.num_neurons,
|
|
233
|
-
beta: self.beta,
|
|
234
|
-
thresholds: self.thresholds.clone(),
|
|
235
|
-
post_synaptic_indices: self.post_synaptic_indices.clone(),
|
|
236
|
-
weights: self.weights.clone(),
|
|
237
|
-
};
|
|
238
|
-
let file = File::create(filepath).expect("Gagal membuat file penyimpanan");
|
|
239
|
-
let writer = BufWriter::new(file);
|
|
240
|
-
serde_json::to_writer(writer, &state).expect("Gagal menyimpan NetworkState");
|
|
241
|
-
}
|
|
242
|
-
|
|
243
|
-
#[napi]
|
|
244
|
-
pub fn load_from_file(&mut self, filepath: String) {
|
|
245
|
-
let file = File::open(filepath).expect("Gagal membuka file penyimpanan");
|
|
246
|
-
let reader = BufReader::new(file);
|
|
247
|
-
let state: NetworkState = serde_json::from_reader(reader).expect("Gagal memuat NetworkState");
|
|
248
|
-
|
|
249
|
-
self.num_neurons = state.num_neurons;
|
|
250
|
-
self.beta = state.beta;
|
|
251
|
-
self.thresholds = state.thresholds;
|
|
252
|
-
self.post_synaptic_indices = state.post_synaptic_indices;
|
|
253
|
-
self.weights = state.weights;
|
|
254
|
-
|
|
255
|
-
// Sesuaikan memori untuk traces & potentials
|
|
256
|
-
let len = self.num_neurons as usize;
|
|
257
|
-
self.potentials.resize(len, 0.0);
|
|
258
|
-
self.spikes.resize(len, 0);
|
|
259
|
-
self.pre_traces.resize(len, 0.0);
|
|
260
|
-
self.post_traces.resize(len, 0.0);
|
|
261
|
-
|
|
262
|
-
// Reset state internal
|
|
263
|
-
self.potentials.fill(0.0);
|
|
264
|
-
self.spikes.fill(0);
|
|
265
|
-
self.pre_traces.fill(0.0);
|
|
266
|
-
self.post_traces.fill(0.0);
|
|
267
|
-
}
|
|
268
|
-
}
|
|
269
|
-
|
|
270
|
-
#[napi]
|
|
271
|
-
pub fn dot_product_add_only_native(
|
|
272
|
-
a_data: napi::bindgen_prelude::Float32Array,
|
|
273
|
-
a_rows_orig: u32,
|
|
274
|
-
a_cols_orig: u32,
|
|
275
|
-
b_data: napi::bindgen_prelude::Float32Array,
|
|
276
|
-
b_rows_orig: u32,
|
|
277
|
-
b_cols_orig: u32,
|
|
278
|
-
trans_a: bool,
|
|
279
|
-
trans_b: bool,
|
|
280
|
-
mut out_data: napi::bindgen_prelude::Float32Array,
|
|
281
|
-
) {
|
|
282
|
-
let a_rows = if trans_a { a_cols_orig } else { a_rows_orig } as usize;
|
|
283
|
-
let a_cols = if trans_a { a_rows_orig } else { a_cols_orig } as usize;
|
|
284
|
-
let b_rows = if trans_b { b_cols_orig } else { b_rows_orig } as usize;
|
|
285
|
-
let b_cols = if trans_b { b_rows_orig } else { b_cols_orig } as usize;
|
|
286
|
-
|
|
287
|
-
let a_slice = &*a_data;
|
|
288
|
-
let b_slice = &*b_data;
|
|
289
|
-
let out_slice = &mut *out_data;
|
|
290
|
-
|
|
291
|
-
let mut a_is_binary = true;
|
|
292
|
-
for &val in a_slice {
|
|
293
|
-
if val != 0.0 && val != 1.0 {
|
|
294
|
-
a_is_binary = false;
|
|
295
|
-
break;
|
|
296
|
-
}
|
|
297
|
-
}
|
|
298
|
-
|
|
299
|
-
let mut b_is_binary = true;
|
|
300
|
-
if !a_is_binary {
|
|
301
|
-
for &val in b_slice {
|
|
302
|
-
if val != 0.0 && val != 1.0 {
|
|
303
|
-
b_is_binary = false;
|
|
304
|
-
break;
|
|
305
|
-
}
|
|
306
|
-
}
|
|
307
|
-
}
|
|
308
|
-
|
|
309
|
-
if !a_is_binary && !b_is_binary {
|
|
310
|
-
panic!("SNN Error: Kedua matriks adalah floating-point. Setidaknya salah satu matriks harus hanya berisi 0 dan 1.");
|
|
311
|
-
}
|
|
312
|
-
|
|
313
|
-
out_slice.par_iter_mut().for_each(|x| *x = 0.0);
|
|
314
|
-
|
|
315
|
-
let a_rows_orig = a_rows_orig as usize;
|
|
316
|
-
let a_cols_orig = a_cols_orig as usize;
|
|
317
|
-
let b_cols_orig = b_cols_orig as usize;
|
|
318
|
-
|
|
319
|
-
out_slice.par_chunks_mut(b_cols).enumerate().for_each(|(i, out_row)| {
|
|
320
|
-
if !trans_b {
|
|
321
|
-
for k in 0..a_cols {
|
|
322
|
-
let aik = if trans_a { a_slice[k * a_rows_orig + i] } else { a_slice[i * a_cols_orig + k] };
|
|
323
|
-
if aik == 0.0 { continue; }
|
|
324
|
-
let k_offset = k * b_cols;
|
|
325
|
-
|
|
326
|
-
if a_is_binary {
|
|
327
|
-
// aik must be 1.0
|
|
328
|
-
for j in 0..b_cols {
|
|
329
|
-
out_row[j] += b_slice[k_offset + j];
|
|
330
|
-
}
|
|
331
|
-
} else {
|
|
332
|
-
// b_is_binary
|
|
333
|
-
for j in 0..b_cols {
|
|
334
|
-
if b_slice[k_offset + j] == 1.0 {
|
|
335
|
-
out_row[j] += aik;
|
|
336
|
-
}
|
|
337
|
-
}
|
|
338
|
-
}
|
|
339
|
-
}
|
|
340
|
-
} else {
|
|
341
|
-
// trans_b == true
|
|
342
|
-
for j in 0..b_cols {
|
|
343
|
-
let mut sum = 0.0;
|
|
344
|
-
for k in 0..a_cols {
|
|
345
|
-
let aik = if trans_a { a_slice[k * a_rows_orig + i] } else { a_slice[i * a_cols_orig + k] };
|
|
346
|
-
let bjk = b_slice[j * b_cols_orig + k];
|
|
347
|
-
if a_is_binary {
|
|
348
|
-
if aik == 1.0 { sum += bjk; }
|
|
349
|
-
} else {
|
|
350
|
-
if bjk == 1.0 { sum += aik; }
|
|
351
|
-
}
|
|
352
|
-
}
|
|
353
|
-
out_row[j] = sum;
|
|
354
|
-
}
|
|
355
|
-
}
|
|
356
|
-
});
|
|
357
|
-
}
|
|
358
|
-
});
|
|
359
|
-
}
|
|
360
|
-
|
|
361
|
-
#[napi]
|
|
362
|
-
pub fn lif_step_native(
|
|
363
|
-
mut potentials: napi::bindgen_prelude::Float32Array,
|
|
364
|
-
dot: napi::bindgen_prelude::Float32Array,
|
|
365
|
-
mut spikes: napi::bindgen_prelude::Float32Array,
|
|
366
|
-
mut last_potentials: napi::bindgen_prelude::Float32Array,
|
|
367
|
-
beta: f64,
|
|
368
|
-
threshold: f64,
|
|
369
|
-
) {
|
|
370
|
-
let pot_slice = &mut *potentials;
|
|
371
|
-
let dot_slice = &*dot;
|
|
372
|
-
let spike_slice = &mut *spikes;
|
|
373
|
-
let lp_slice = &mut *last_potentials;
|
|
374
|
-
let b = beta as f32;
|
|
375
|
-
let th = threshold as f32;
|
|
376
|
-
|
|
377
|
-
pot_slice.par_iter_mut()
|
|
378
|
-
.zip(dot_slice.par_iter())
|
|
379
|
-
.zip(spike_slice.par_iter_mut())
|
|
380
|
-
.zip(lp_slice.par_iter_mut())
|
|
381
|
-
.for_each(|(((p, d), s), lp)| {
|
|
382
|
-
*p = (*p * b) + d;
|
|
383
|
-
*lp = *p;
|
|
384
|
-
if *p >= th {
|
|
385
|
-
*s = 1.0;
|
|
386
|
-
*p -= th;
|
|
387
|
-
} else {
|
|
388
|
-
*s = 0.0;
|
|
389
|
-
}
|
|
390
|
-
});
|
|
391
|
-
}
|
|
392
|
-
|
|
393
|
-
#[napi]
|
|
394
|
-
pub fn mask_surrogate_native(
|
|
395
|
-
mut error_signal: napi::bindgen_prelude::Float32Array,
|
|
396
|
-
potentials: napi::bindgen_prelude::Float32Array,
|
|
397
|
-
threshold: f64,
|
|
398
|
-
window_size: f64,
|
|
399
|
-
) {
|
|
400
|
-
let err_slice = &mut *error_signal;
|
|
401
|
-
let pot_slice = &*potentials;
|
|
402
|
-
let th = threshold as f32;
|
|
403
|
-
let win = window_size as f32;
|
|
404
|
-
|
|
405
|
-
err_slice.par_iter_mut()
|
|
406
|
-
.zip(pot_slice.par_iter())
|
|
407
|
-
.for_each(|(e, p)| {
|
|
408
|
-
if (*p - th).abs() > win {
|
|
409
|
-
*e = 0.0;
|
|
410
|
-
}
|
|
411
|
-
});
|
|
412
|
-
}
|
|
413
|
-
|
|
414
|
-
#[napi]
|
|
415
|
-
pub fn apply_add_only_delta_native(
|
|
416
|
-
mut kernel: napi::bindgen_prelude::Float32Array,
|
|
417
|
-
mut bias: napi::bindgen_prelude::Float32Array,
|
|
418
|
-
inputs: napi::bindgen_prelude::Float32Array,
|
|
419
|
-
error_signal: napi::bindgen_prelude::Float32Array,
|
|
420
|
-
learning_rate: f64,
|
|
421
|
-
batch: u32,
|
|
422
|
-
in_features: u32,
|
|
423
|
-
units: u32,
|
|
424
|
-
use_bias: bool,
|
|
425
|
-
) {
|
|
426
|
-
let k_slice = &mut *kernel;
|
|
427
|
-
let b_slice = &mut *bias;
|
|
428
|
-
let in_slice = &*inputs;
|
|
429
|
-
let err_slice = &*error_signal;
|
|
430
|
-
let lr = learning_rate as f32;
|
|
431
|
-
|
|
432
|
-
let batch = batch as usize;
|
|
433
|
-
let in_f = in_features as usize;
|
|
434
|
-
let u = units as usize;
|
|
435
|
-
|
|
436
|
-
k_slice.par_chunks_mut(u).enumerate().for_each(|(k, k_row)| {
|
|
437
|
-
let mut row_update = vec![0.0; u];
|
|
438
|
-
for b in 0..batch {
|
|
439
|
-
let in_offset = b * in_f;
|
|
440
|
-
if in_slice[in_offset + k] == 1.0 {
|
|
441
|
-
let err_offset = b * u;
|
|
442
|
-
for j in 0..u {
|
|
443
|
-
row_update[j] += err_slice[err_offset + j];
|
|
444
|
-
}
|
|
445
|
-
}
|
|
446
|
-
}
|
|
447
|
-
for j in 0..u {
|
|
448
|
-
k_row[j] += lr * row_update[j];
|
|
449
|
-
}
|
|
450
|
-
});
|
|
451
|
-
|
|
452
|
-
if use_bias && b_slice.len() >= u {
|
|
453
|
-
b_slice.par_iter_mut().enumerate().for_each(|(j, b_val)| {
|
|
454
|
-
let mut b_update = 0.0;
|
|
455
|
-
for b in 0..batch {
|
|
456
|
-
let err_offset = b * u;
|
|
457
|
-
b_update += err_slice[err_offset + j];
|
|
458
|
-
}
|
|
459
|
-
*b_val += lr * b_update;
|
|
460
|
-
});
|
|
461
|
-
}
|
|
462
|
-
}
|
|
3
|
+
#[macro_use]
|
|
4
|
+
extern crate napi_derive;
|
|
5
|
+
|
|
6
|
+
mod dot_product;
|
|
7
|
+
mod lif;
|
|
8
|
+
mod surrogate;
|
|
9
|
+
mod delta;
|
|
10
|
+
mod embedding;
|
|
11
|
+
mod contrastive;
|
|
12
|
+
|
|
13
|
+
pub use dot_product::*;
|
|
14
|
+
pub use lif::*;
|
|
15
|
+
pub use surrogate::*;
|
|
16
|
+
pub use delta::*;
|
|
17
|
+
pub use embedding::*;
|
|
18
|
+
pub use contrastive::*;
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
use napi_derive::napi;
|
|
2
|
+
use napi::bindgen_prelude::Float32Array;
|
|
3
|
+
use rayon::prelude::*;
|
|
4
|
+
|
|
5
|
+
#[napi]
|
|
6
|
+
pub fn lif_step_native(
|
|
7
|
+
mut potentials: Float32Array,
|
|
8
|
+
dot: Float32Array,
|
|
9
|
+
mut spikes: Float32Array,
|
|
10
|
+
mut last_potentials: Float32Array,
|
|
11
|
+
beta: Float32Array,
|
|
12
|
+
threshold: Float32Array
|
|
13
|
+
) {
|
|
14
|
+
let units = beta.len();
|
|
15
|
+
if units == 0 { return; }
|
|
16
|
+
let batch = potentials.len() / units;
|
|
17
|
+
|
|
18
|
+
let pot_slice: &mut [f32] = &mut potentials;
|
|
19
|
+
let dot_slice: &[f32] = ˙
|
|
20
|
+
let spikes_slice: &mut [f32] = &mut spikes;
|
|
21
|
+
let last_pot_slice: &mut [f32] = &mut last_potentials;
|
|
22
|
+
let beta_slice: &[f32] = β
|
|
23
|
+
let thresh_slice: &[f32] = &threshold;
|
|
24
|
+
|
|
25
|
+
pot_slice.par_chunks_mut(units)
|
|
26
|
+
.zip(dot_slice.par_chunks(units))
|
|
27
|
+
.zip(spikes_slice.par_chunks_mut(units))
|
|
28
|
+
.zip(last_pot_slice.par_chunks_mut(units))
|
|
29
|
+
.for_each(|(((pot_chunk, dot_chunk), spike_chunk), last_pot_chunk)| {
|
|
30
|
+
for i in 0..units {
|
|
31
|
+
let mut pot = (pot_chunk[i] * beta_slice[i]) + dot_chunk[i];
|
|
32
|
+
pot = pot.min(1.0); // Clamp potential max 1.0
|
|
33
|
+
last_pot_chunk[i] = pot;
|
|
34
|
+
|
|
35
|
+
if pot >= thresh_slice[i] {
|
|
36
|
+
spike_chunk[i] = 1.0;
|
|
37
|
+
pot -= thresh_slice[i];
|
|
38
|
+
} else {
|
|
39
|
+
spike_chunk[i] = 0.0;
|
|
40
|
+
}
|
|
41
|
+
pot_chunk[i] = pot;
|
|
42
|
+
}
|
|
43
|
+
});
|
|
44
|
+
}
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
use napi_derive::napi;
|
|
2
|
+
use napi::bindgen_prelude::Float32Array;
|
|
3
|
+
use rayon::prelude::*;
|
|
4
|
+
|
|
5
|
+
#[napi]
|
|
6
|
+
pub fn mask_surrogate_native(
|
|
7
|
+
mut error_signal: Float32Array,
|
|
8
|
+
potentials: Float32Array,
|
|
9
|
+
threshold: Float32Array,
|
|
10
|
+
window_size: f64
|
|
11
|
+
) {
|
|
12
|
+
let units = threshold.len();
|
|
13
|
+
if units == 0 { return; }
|
|
14
|
+
|
|
15
|
+
let err_slice: &mut [f32] = &mut error_signal;
|
|
16
|
+
let pot_slice: &[f32] = &potentials;
|
|
17
|
+
let thresh_slice: &[f32] = &threshold;
|
|
18
|
+
|
|
19
|
+
err_slice.par_chunks_mut(units)
|
|
20
|
+
.zip(pot_slice.par_chunks(units))
|
|
21
|
+
.for_each(|(err_chunk, pot_chunk)| {
|
|
22
|
+
for i in 0..units {
|
|
23
|
+
if (pot_chunk[i] - thresh_slice[i]).abs() > window_size as f32 {
|
|
24
|
+
err_chunk[i] = 0.0;
|
|
25
|
+
}
|
|
26
|
+
}
|
|
27
|
+
});
|
|
28
|
+
}
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
import { describe, it, expect, beforeEach } from 'vitest';
|
|
2
|
+
import { Matrix } from '@oxide-js/core';
|
|
3
|
+
import { SpikingDenseBPTT } from '../src/layers/SpikingDenseBPTT.js';
|
|
4
|
+
|
|
5
|
+
describe('SpikingDenseBPTT Layer', () => {
  const units = 4;
  const inFeatures = 3;
  const batch = 2;
  const maxTimeSteps = 3;

  let layer: SpikingDenseBPTT;

  // Fresh layer for every test: all-ones kernel, no bias.
  beforeEach(() => {
    layer = new SpikingDenseBPTT({
      units,
      kernelInitializer: 'ones',
      useBias: false
    });
    layer.build([batch, inFeatures]);
  });

  it('should initialize parameters correctly', () => {
    expect(layer.units).toBe(units);
    expect(layer.kernel).toBeDefined();

    // Checks that the per-unit dynamic beta ("bit-shift decay" float) initialization ran.
    expect(layer.beta).toBeDefined();
    expect(layer.beta.length).toBe(units);

    for (let i = 0; i < units; i++) {
      // beta is pre-computed as the multiplier 1.0 - 1.0 / 2^shift with shift in [2, 5]
      // (a decay of 1/4 down to 1/32), so valid values lie in [0.75, 0.96875].
      expect(layer.beta[i]).toBeGreaterThanOrEqual(0.75);
      expect(layer.beta[i]).toBeLessThanOrEqual(0.96875);
    }
  });

  it('should throw error when calling compute() directly', () => {
    const inputs = Matrix.fromFlat(new Float32Array(batch * inFeatures), [batch, inFeatures]);
    expect(() => {
      // @ts-ignore
      layer.compute(inputs);
    }).toThrowError(/Harap gunakan computeStep/);
  });

  it('should process sequence, enforce BPTT limits, and store history correctly', () => {
    layer.resetSequence(maxTimeSteps);

    expect(layer.maxTimeSteps).toBe(maxTimeSteps);
    expect(layer.historyInputs.length).toBe(maxTimeSteps);

    // Dummy binary spike input: every input fires.
    const inputs = Matrix.fromFlat(new Float32Array(batch * inFeatures).fill(1), [batch, inFeatures]);

    // Time step 0
    const out0 = layer.computeStep(inputs, 0);
    expect(out0._shape).toEqual([batch, units]);
    expect(layer.historyInputs[0]).toBeDefined();
    expect(layer.historyPotentials[0]).toBeDefined();
    expect(layer.historySpikes[0]).toBeDefined();

    // Time step 1
    const out1 = layer.computeStep(inputs, 1);
    expect(out1._shape).toEqual([batch, units]);
    expect(layer.historyInputs[1]).toBeDefined();

    // Time step 2
    const out2 = layer.computeStep(inputs, 2);
    expect(out2._shape).toEqual([batch, units]);

    // Time step 3 exceeds maxTimeSteps and must throw.
    expect(() => {
      layer.computeStep(inputs, 3);
    }).toThrowError(/melebihi batas maxTimeSteps/);
  });

  it('should run learnThroughTime properly without crashing', () => {
    layer.resetSequence(maxTimeSteps);

    const inputs = Matrix.fromFlat(new Float32Array(batch * inFeatures).fill(1), [batch, inFeatures]);

    // Run the whole sequence forward.
    for (let t = 0; t < maxTimeSteps; t++) {
      layer.computeStep(inputs, t);
    }

    // One constant fake error matrix per time step (t = 0, 1, 2).
    const errors = Array.from({ length: maxTimeSteps }, () =>
      Matrix.fromFlat(new Float32Array(batch * units).fill(0.1), [batch, units])
    );

    // BPTT as an output layer: no feedback matrix (B = undefined).
    expect(() => {
      layer.learnThroughTime(errors, undefined, 0.01);
    }).not.toThrow();

    // BPTT as a hidden layer, using an all-ones units x units feedback matrix B.
    // This is deliberately not an identity matrix; it only needs the right shape.
    const B = Matrix.fromFlat(new Float32Array(units * units).fill(1), [units, units]);
    expect(() => {
      layer.learnThroughTime(errors, B, 0.01);
    }).not.toThrow();
  });

  it('should correctly accumulate potentials and trigger spikes deterministically over time', () => {
    // Deterministic layer: kernel, beta and threshold are all set by hand below.
    const testLayer = new SpikingDenseBPTT({
      units: 2,
      useBias: false,
      kernelInitializer: 'zeros'
    });
    testLayer.build([1, 2]); // batch = 1, inFeatures = 2

    // Manual kernel: [[0.6, 0.0], [0.0, 0.8]]
    testLayer.kernel!._data.set([0.6, 0.0, 0.0, 0.8]);

    // beta = 0.5 keeps the arithmetic easy; threshold = 1.0.
    testLayer.beta.fill(0.5);
    testLayer.threshold.fill(1.0);

    testLayer.resetSequence(3);

    // Both inputs fire at every time step: [1, 1].
    const inputs = Matrix.fromFlat(new Float32Array([1, 1]), [1, 2]);

    // --- TIME STEP 0 ---
    // v = 0 * 0.5 + [0.6, 0.8] = [0.6, 0.8]; both are below threshold, so no spikes.
    const out0 = testLayer.computeStep(inputs, 0);
    expect(out0._data).toEqual(new Float32Array([0, 0]));

    // historyPotentials records the potential before the spike reset.
    expect(testLayer.historyPotentials[0]._data[0]).toBeCloseTo(0.6, 5);
    expect(testLayer.historyPotentials[0]._data[1]).toBeCloseTo(0.8, 5);

    // --- TIME STEP 1 ---
    // v = [0.3 + 0.6, 0.4 + 0.8] = [0.9, 1.2]. Unit 1 is recorded as 1.0, not 1.2, which
    // relies on the potential being clamped at 1.0 (as the native lif_step_native does).
    // Unit 1 fires and soft-resets to 0.
    const out1 = testLayer.computeStep(inputs, 1);
    expect(out1._data).toEqual(new Float32Array([0, 1]));

    expect(testLayer.historyPotentials[1]._data[0]).toBeCloseTo(0.9, 5);
    expect(testLayer.historyPotentials[1]._data[1]).toBeCloseTo(1.0, 5);

    // --- TIME STEP 2 ---
    // v = [0.45 + 0.6, 0 + 0.8] = [1.05 -> clamped to 1.0, 0.8]. Only unit 0 fires.
    const out2 = testLayer.computeStep(inputs, 2);
    expect(out2._data).toEqual(new Float32Array([1, 0]));

    expect(testLayer.historyPotentials[2]._data[0]).toBeCloseTo(1.0, 5);
    expect(testLayer.historyPotentials[2]._data[1]).toBeCloseTo(0.8, 5);
  });
});
|