@oxide-js/spiking 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +12 -0
- package/examples/demo.ts +101 -0
- package/index.d.ts +19 -0
- package/index.js +316 -0
- package/package.json +47 -0
- package/src/index.ts +5 -0
- package/src/layers/SpikingDense.ts +237 -0
- package/src/layers/SpikingEmbedding.ts +227 -0
- package/src/math/dotProductAddOnly.ts +229 -0
- package/src/models/SpikingSentenceEmbedder.ts +135 -0
- package/src/native_backend.ts +90 -0
- package/src-rust/Cargo.lock +324 -0
- package/src-rust/Cargo.toml +17 -0
- package/src-rust/build.rs +5 -0
- package/src-rust/src/lib.rs +462 -0
- package/test/test_embedding.ts +126 -0
- package/test/test_xor.ts +122 -0
- package/tsconfig.json +9 -0
|
@@ -0,0 +1,462 @@
|
|
|
1
|
+
#![deny(clippy::all)]
|
|
2
|
+
|
|
3
|
+
use napi_derive::napi;
|
|
4
|
+
use rayon::prelude::*;
|
|
5
|
+
use serde::{Deserialize, Serialize};
|
|
6
|
+
use std::fs::File;
|
|
7
|
+
use std::io::{BufReader, BufWriter};
|
|
8
|
+
|
|
9
|
+
#[derive(Serialize, Deserialize)]
|
|
10
|
+
pub struct NetworkState {
|
|
11
|
+
pub num_neurons: u32,
|
|
12
|
+
pub beta: f32,
|
|
13
|
+
pub thresholds: Vec<f32>,
|
|
14
|
+
pub post_synaptic_indices: Vec<Vec<u32>>,
|
|
15
|
+
pub weights: Vec<Vec<f32>>,
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
#[napi]
|
|
19
|
+
pub struct NativeSpikingNetwork {
|
|
20
|
+
num_neurons: u32,
|
|
21
|
+
potentials: Vec<f32>,
|
|
22
|
+
thresholds: Vec<f32>,
|
|
23
|
+
spikes: Vec<u8>,
|
|
24
|
+
|
|
25
|
+
// Adjacency lists
|
|
26
|
+
post_synaptic_indices: Vec<Vec<u32>>,
|
|
27
|
+
weights: Vec<Vec<f32>>,
|
|
28
|
+
|
|
29
|
+
// STDP Traces
|
|
30
|
+
pre_traces: Vec<f32>,
|
|
31
|
+
post_traces: Vec<f32>,
|
|
32
|
+
|
|
33
|
+
beta: f32,
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
#[napi]
|
|
37
|
+
impl NativeSpikingNetwork {
|
|
38
|
+
#[napi(constructor)]
|
|
39
|
+
pub fn new(num_neurons: u32, beta: f64, default_threshold: f64) -> Self {
|
|
40
|
+
Self {
|
|
41
|
+
num_neurons,
|
|
42
|
+
potentials: vec![0.0; num_neurons as usize],
|
|
43
|
+
thresholds: vec![default_threshold as f32; num_neurons as usize],
|
|
44
|
+
spikes: vec![0; num_neurons as usize],
|
|
45
|
+
post_synaptic_indices: vec![Vec::new(); num_neurons as usize],
|
|
46
|
+
weights: vec![Vec::new(); num_neurons as usize],
|
|
47
|
+
pre_traces: vec![0.0; num_neurons as usize],
|
|
48
|
+
post_traces: vec![0.0; num_neurons as usize],
|
|
49
|
+
beta: beta as f32,
|
|
50
|
+
}
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
#[napi]
|
|
54
|
+
pub fn connect(&mut self, pre: u32, post: u32, weight: f64) {
|
|
55
|
+
let pre_idx = pre as usize;
|
|
56
|
+
if pre_idx >= self.num_neurons as usize || post as usize >= self.num_neurons as usize {
|
|
57
|
+
return;
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
if let Some(pos) = self.post_synaptic_indices[pre_idx].iter().position(|&p| p == post) {
|
|
61
|
+
self.weights[pre_idx][pos] = weight as f32;
|
|
62
|
+
} else {
|
|
63
|
+
self.post_synaptic_indices[pre_idx].push(post);
|
|
64
|
+
self.weights[pre_idx].push(weight as f32);
|
|
65
|
+
}
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
#[napi]
|
|
69
|
+
pub fn inject_current(&mut self, neuron_idx: u32, current: f64) {
|
|
70
|
+
let idx = neuron_idx as usize;
|
|
71
|
+
if idx < self.potentials.len() {
|
|
72
|
+
self.potentials[idx] += current as f32;
|
|
73
|
+
}
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
#[napi]
|
|
77
|
+
pub fn inhibit_range(&mut self, start_idx: u32, end_idx: u32, except_idx: u32, current: f64) {
|
|
78
|
+
let start = start_idx as usize;
|
|
79
|
+
let end = end_idx as usize;
|
|
80
|
+
let ex = except_idx as usize;
|
|
81
|
+
let c = current as f32;
|
|
82
|
+
|
|
83
|
+
for i in start..end {
|
|
84
|
+
if i != ex && i < self.potentials.len() {
|
|
85
|
+
self.potentials[i] += c;
|
|
86
|
+
}
|
|
87
|
+
}
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
#[napi]
|
|
91
|
+
pub fn step(&mut self) {
|
|
92
|
+
// 1. Accumulate pre-synaptic spikes (Sequential for now to avoid data races)
|
|
93
|
+
// We only process neurons that spiked in the previous step
|
|
94
|
+
for i in 0..self.num_neurons as usize {
|
|
95
|
+
if self.spikes[i] == 1 {
|
|
96
|
+
let targets = &self.post_synaptic_indices[i];
|
|
97
|
+
let w = &self.weights[i];
|
|
98
|
+
for k in 0..targets.len() {
|
|
99
|
+
let j = targets[k] as usize;
|
|
100
|
+
self.potentials[j] += w[k];
|
|
101
|
+
}
|
|
102
|
+
}
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
// Reset spikes safely (Paralel)
|
|
106
|
+
self.spikes.par_iter_mut().for_each(|s| *s = 0);
|
|
107
|
+
|
|
108
|
+
// 2. Membrane decay, threshold check, and fire (Paralel via Rayon)
|
|
109
|
+
let beta = self.beta;
|
|
110
|
+
self.potentials
|
|
111
|
+
.par_iter_mut()
|
|
112
|
+
.zip(self.thresholds.par_iter())
|
|
113
|
+
.zip(self.spikes.par_iter_mut())
|
|
114
|
+
.for_each(|((pot, &thresh), spike)| {
|
|
115
|
+
*pot *= beta; // Leaky integration
|
|
116
|
+
if *pot < 1e-5 { *pot = 0.0; } // Prevent denormals
|
|
117
|
+
if *pot >= thresh {
|
|
118
|
+
*spike = 1;
|
|
119
|
+
*pot -= thresh; // Soft reset
|
|
120
|
+
}
|
|
121
|
+
});
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
#[napi]
|
|
125
|
+
pub fn reset_state(&mut self) {
|
|
126
|
+
self.potentials.par_iter_mut().for_each(|p| *p = 0.0);
|
|
127
|
+
self.spikes.par_iter_mut().for_each(|s| *s = 0);
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
// Helpers to get data to JS if needed
|
|
131
|
+
#[napi]
|
|
132
|
+
pub fn get_spikes(&self) -> Vec<u8> {
|
|
133
|
+
self.spikes.clone()
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
#[napi]
|
|
137
|
+
pub fn get_potentials(&self) -> Vec<f32> {
|
|
138
|
+
self.potentials.clone()
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
#[napi]
|
|
142
|
+
pub fn update_stdp(
|
|
143
|
+
&mut self,
|
|
144
|
+
learning_rate: f64,
|
|
145
|
+
tau_plus: f64,
|
|
146
|
+
tau_minus: f64,
|
|
147
|
+
a_plus: f64,
|
|
148
|
+
a_minus: f64,
|
|
149
|
+
w_max: f64,
|
|
150
|
+
w_min: f64,
|
|
151
|
+
) {
|
|
152
|
+
let lr = learning_rate as f32;
|
|
153
|
+
let t_plus = tau_plus as f32;
|
|
154
|
+
let t_minus = tau_minus as f32;
|
|
155
|
+
let ap = a_plus as f32;
|
|
156
|
+
let am = a_minus as f32;
|
|
157
|
+
let wmax = w_max as f32;
|
|
158
|
+
let wmin = w_min as f32;
|
|
159
|
+
|
|
160
|
+
// 1. Update trace decay and trace spikes
|
|
161
|
+
self.pre_traces
|
|
162
|
+
.par_iter_mut()
|
|
163
|
+
.zip(self.post_traces.par_iter_mut())
|
|
164
|
+
.zip(self.spikes.par_iter())
|
|
165
|
+
.for_each(|((pre, post), &spike)| {
|
|
166
|
+
*pre *= t_plus;
|
|
167
|
+
*post *= t_minus;
|
|
168
|
+
|
|
169
|
+
if *pre < 1e-5 { *pre = 0.0; }
|
|
170
|
+
if *post < 1e-5 { *post = 0.0; }
|
|
171
|
+
|
|
172
|
+
if spike == 1 {
|
|
173
|
+
*pre = 1.0;
|
|
174
|
+
*post = 1.0;
|
|
175
|
+
}
|
|
176
|
+
});
|
|
177
|
+
|
|
178
|
+
// 2. Event-driven weight updates
|
|
179
|
+
// Since we mutate self.weights, we can parallelize over the outer slice
|
|
180
|
+
// because each sub-vector weights[i] belongs to exclusively one thread.
|
|
181
|
+
// However, we need immutable access to self.post_traces and self.spikes.
|
|
182
|
+
// To do this cleanly in Rust with Rayon without violating borrow rules:
|
|
183
|
+
|
|
184
|
+
let spikes = &self.spikes;
|
|
185
|
+
let pre_traces = &self.pre_traces;
|
|
186
|
+
let post_traces = &self.post_traces;
|
|
187
|
+
let post_synaptic_indices = &self.post_synaptic_indices;
|
|
188
|
+
|
|
189
|
+
self.weights
|
|
190
|
+
.par_iter_mut()
|
|
191
|
+
.enumerate()
|
|
192
|
+
.for_each(|(i, w_list)| {
|
|
193
|
+
let pre_spiked = spikes[i] == 1;
|
|
194
|
+
let pre_trace = pre_traces[i];
|
|
195
|
+
|
|
196
|
+
if pre_trace == 0.0 && !pre_spiked {
|
|
197
|
+
return;
|
|
198
|
+
}
|
|
199
|
+
|
|
200
|
+
let targets = &post_synaptic_indices[i];
|
|
201
|
+
|
|
202
|
+
for k in 0..targets.len() {
|
|
203
|
+
let j = targets[k] as usize;
|
|
204
|
+
let post_spiked = spikes[j] == 1;
|
|
205
|
+
|
|
206
|
+
let mut dw = 0.0;
|
|
207
|
+
|
|
208
|
+
// LTP
|
|
209
|
+
if post_spiked {
|
|
210
|
+
dw += lr * ap * pre_trace;
|
|
211
|
+
}
|
|
212
|
+
// LTD
|
|
213
|
+
if pre_spiked {
|
|
214
|
+
dw -= lr * am * post_traces[j];
|
|
215
|
+
}
|
|
216
|
+
|
|
217
|
+
if dw != 0.0 {
|
|
218
|
+
w_list[k] += dw;
|
|
219
|
+
if w_list[k] > wmax {
|
|
220
|
+
w_list[k] = wmax;
|
|
221
|
+
} else if w_list[k] < wmin {
|
|
222
|
+
w_list[k] = wmin;
|
|
223
|
+
}
|
|
224
|
+
}
|
|
225
|
+
}
|
|
226
|
+
});
|
|
227
|
+
}
|
|
228
|
+
|
|
229
|
+
#[napi]
|
|
230
|
+
pub fn save_to_file(&self, filepath: String) {
|
|
231
|
+
let state = NetworkState {
|
|
232
|
+
num_neurons: self.num_neurons,
|
|
233
|
+
beta: self.beta,
|
|
234
|
+
thresholds: self.thresholds.clone(),
|
|
235
|
+
post_synaptic_indices: self.post_synaptic_indices.clone(),
|
|
236
|
+
weights: self.weights.clone(),
|
|
237
|
+
};
|
|
238
|
+
let file = File::create(filepath).expect("Gagal membuat file penyimpanan");
|
|
239
|
+
let writer = BufWriter::new(file);
|
|
240
|
+
serde_json::to_writer(writer, &state).expect("Gagal menyimpan NetworkState");
|
|
241
|
+
}
|
|
242
|
+
|
|
243
|
+
#[napi]
|
|
244
|
+
pub fn load_from_file(&mut self, filepath: String) {
|
|
245
|
+
let file = File::open(filepath).expect("Gagal membuka file penyimpanan");
|
|
246
|
+
let reader = BufReader::new(file);
|
|
247
|
+
let state: NetworkState = serde_json::from_reader(reader).expect("Gagal memuat NetworkState");
|
|
248
|
+
|
|
249
|
+
self.num_neurons = state.num_neurons;
|
|
250
|
+
self.beta = state.beta;
|
|
251
|
+
self.thresholds = state.thresholds;
|
|
252
|
+
self.post_synaptic_indices = state.post_synaptic_indices;
|
|
253
|
+
self.weights = state.weights;
|
|
254
|
+
|
|
255
|
+
// Sesuaikan memori untuk traces & potentials
|
|
256
|
+
let len = self.num_neurons as usize;
|
|
257
|
+
self.potentials.resize(len, 0.0);
|
|
258
|
+
self.spikes.resize(len, 0);
|
|
259
|
+
self.pre_traces.resize(len, 0.0);
|
|
260
|
+
self.post_traces.resize(len, 0.0);
|
|
261
|
+
|
|
262
|
+
// Reset state internal
|
|
263
|
+
self.potentials.fill(0.0);
|
|
264
|
+
self.spikes.fill(0);
|
|
265
|
+
self.pre_traces.fill(0.0);
|
|
266
|
+
self.post_traces.fill(0.0);
|
|
267
|
+
}
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
#[napi]
|
|
271
|
+
pub fn dot_product_add_only_native(
|
|
272
|
+
a_data: napi::bindgen_prelude::Float32Array,
|
|
273
|
+
a_rows_orig: u32,
|
|
274
|
+
a_cols_orig: u32,
|
|
275
|
+
b_data: napi::bindgen_prelude::Float32Array,
|
|
276
|
+
b_rows_orig: u32,
|
|
277
|
+
b_cols_orig: u32,
|
|
278
|
+
trans_a: bool,
|
|
279
|
+
trans_b: bool,
|
|
280
|
+
mut out_data: napi::bindgen_prelude::Float32Array,
|
|
281
|
+
) {
|
|
282
|
+
let a_rows = if trans_a { a_cols_orig } else { a_rows_orig } as usize;
|
|
283
|
+
let a_cols = if trans_a { a_rows_orig } else { a_cols_orig } as usize;
|
|
284
|
+
let b_rows = if trans_b { b_cols_orig } else { b_rows_orig } as usize;
|
|
285
|
+
let b_cols = if trans_b { b_rows_orig } else { b_cols_orig } as usize;
|
|
286
|
+
|
|
287
|
+
let a_slice = &*a_data;
|
|
288
|
+
let b_slice = &*b_data;
|
|
289
|
+
let out_slice = &mut *out_data;
|
|
290
|
+
|
|
291
|
+
let mut a_is_binary = true;
|
|
292
|
+
for &val in a_slice {
|
|
293
|
+
if val != 0.0 && val != 1.0 {
|
|
294
|
+
a_is_binary = false;
|
|
295
|
+
break;
|
|
296
|
+
}
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
let mut b_is_binary = true;
|
|
300
|
+
if !a_is_binary {
|
|
301
|
+
for &val in b_slice {
|
|
302
|
+
if val != 0.0 && val != 1.0 {
|
|
303
|
+
b_is_binary = false;
|
|
304
|
+
break;
|
|
305
|
+
}
|
|
306
|
+
}
|
|
307
|
+
}
|
|
308
|
+
|
|
309
|
+
if !a_is_binary && !b_is_binary {
|
|
310
|
+
panic!("SNN Error: Kedua matriks adalah floating-point. Setidaknya salah satu matriks harus hanya berisi 0 dan 1.");
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
out_slice.par_iter_mut().for_each(|x| *x = 0.0);
|
|
314
|
+
|
|
315
|
+
let a_rows_orig = a_rows_orig as usize;
|
|
316
|
+
let a_cols_orig = a_cols_orig as usize;
|
|
317
|
+
let b_cols_orig = b_cols_orig as usize;
|
|
318
|
+
|
|
319
|
+
out_slice.par_chunks_mut(b_cols).enumerate().for_each(|(i, out_row)| {
|
|
320
|
+
if !trans_b {
|
|
321
|
+
for k in 0..a_cols {
|
|
322
|
+
let aik = if trans_a { a_slice[k * a_rows_orig + i] } else { a_slice[i * a_cols_orig + k] };
|
|
323
|
+
if aik == 0.0 { continue; }
|
|
324
|
+
let k_offset = k * b_cols;
|
|
325
|
+
|
|
326
|
+
if a_is_binary {
|
|
327
|
+
// aik must be 1.0
|
|
328
|
+
for j in 0..b_cols {
|
|
329
|
+
out_row[j] += b_slice[k_offset + j];
|
|
330
|
+
}
|
|
331
|
+
} else {
|
|
332
|
+
// b_is_binary
|
|
333
|
+
for j in 0..b_cols {
|
|
334
|
+
if b_slice[k_offset + j] == 1.0 {
|
|
335
|
+
out_row[j] += aik;
|
|
336
|
+
}
|
|
337
|
+
}
|
|
338
|
+
}
|
|
339
|
+
}
|
|
340
|
+
} else {
|
|
341
|
+
// trans_b == true
|
|
342
|
+
for j in 0..b_cols {
|
|
343
|
+
let mut sum = 0.0;
|
|
344
|
+
for k in 0..a_cols {
|
|
345
|
+
let aik = if trans_a { a_slice[k * a_rows_orig + i] } else { a_slice[i * a_cols_orig + k] };
|
|
346
|
+
let bjk = b_slice[j * b_cols_orig + k];
|
|
347
|
+
if a_is_binary {
|
|
348
|
+
if aik == 1.0 { sum += bjk; }
|
|
349
|
+
} else {
|
|
350
|
+
if bjk == 1.0 { sum += aik; }
|
|
351
|
+
}
|
|
352
|
+
}
|
|
353
|
+
out_row[j] = sum;
|
|
354
|
+
}
|
|
355
|
+
}
|
|
356
|
+
});
|
|
357
|
+
}
|
|
358
|
+
});
|
|
359
|
+
}
|
|
360
|
+
|
|
361
|
+
#[napi]
|
|
362
|
+
pub fn lif_step_native(
|
|
363
|
+
mut potentials: napi::bindgen_prelude::Float32Array,
|
|
364
|
+
dot: napi::bindgen_prelude::Float32Array,
|
|
365
|
+
mut spikes: napi::bindgen_prelude::Float32Array,
|
|
366
|
+
mut last_potentials: napi::bindgen_prelude::Float32Array,
|
|
367
|
+
beta: f64,
|
|
368
|
+
threshold: f64,
|
|
369
|
+
) {
|
|
370
|
+
let pot_slice = &mut *potentials;
|
|
371
|
+
let dot_slice = &*dot;
|
|
372
|
+
let spike_slice = &mut *spikes;
|
|
373
|
+
let lp_slice = &mut *last_potentials;
|
|
374
|
+
let b = beta as f32;
|
|
375
|
+
let th = threshold as f32;
|
|
376
|
+
|
|
377
|
+
pot_slice.par_iter_mut()
|
|
378
|
+
.zip(dot_slice.par_iter())
|
|
379
|
+
.zip(spike_slice.par_iter_mut())
|
|
380
|
+
.zip(lp_slice.par_iter_mut())
|
|
381
|
+
.for_each(|(((p, d), s), lp)| {
|
|
382
|
+
*p = (*p * b) + d;
|
|
383
|
+
*lp = *p;
|
|
384
|
+
if *p >= th {
|
|
385
|
+
*s = 1.0;
|
|
386
|
+
*p -= th;
|
|
387
|
+
} else {
|
|
388
|
+
*s = 0.0;
|
|
389
|
+
}
|
|
390
|
+
});
|
|
391
|
+
}
|
|
392
|
+
|
|
393
|
+
#[napi]
|
|
394
|
+
pub fn mask_surrogate_native(
|
|
395
|
+
mut error_signal: napi::bindgen_prelude::Float32Array,
|
|
396
|
+
potentials: napi::bindgen_prelude::Float32Array,
|
|
397
|
+
threshold: f64,
|
|
398
|
+
window_size: f64,
|
|
399
|
+
) {
|
|
400
|
+
let err_slice = &mut *error_signal;
|
|
401
|
+
let pot_slice = &*potentials;
|
|
402
|
+
let th = threshold as f32;
|
|
403
|
+
let win = window_size as f32;
|
|
404
|
+
|
|
405
|
+
err_slice.par_iter_mut()
|
|
406
|
+
.zip(pot_slice.par_iter())
|
|
407
|
+
.for_each(|(e, p)| {
|
|
408
|
+
if (*p - th).abs() > win {
|
|
409
|
+
*e = 0.0;
|
|
410
|
+
}
|
|
411
|
+
});
|
|
412
|
+
}
|
|
413
|
+
|
|
414
|
+
#[napi]
|
|
415
|
+
pub fn apply_add_only_delta_native(
|
|
416
|
+
mut kernel: napi::bindgen_prelude::Float32Array,
|
|
417
|
+
mut bias: napi::bindgen_prelude::Float32Array,
|
|
418
|
+
inputs: napi::bindgen_prelude::Float32Array,
|
|
419
|
+
error_signal: napi::bindgen_prelude::Float32Array,
|
|
420
|
+
learning_rate: f64,
|
|
421
|
+
batch: u32,
|
|
422
|
+
in_features: u32,
|
|
423
|
+
units: u32,
|
|
424
|
+
use_bias: bool,
|
|
425
|
+
) {
|
|
426
|
+
let k_slice = &mut *kernel;
|
|
427
|
+
let b_slice = &mut *bias;
|
|
428
|
+
let in_slice = &*inputs;
|
|
429
|
+
let err_slice = &*error_signal;
|
|
430
|
+
let lr = learning_rate as f32;
|
|
431
|
+
|
|
432
|
+
let batch = batch as usize;
|
|
433
|
+
let in_f = in_features as usize;
|
|
434
|
+
let u = units as usize;
|
|
435
|
+
|
|
436
|
+
k_slice.par_chunks_mut(u).enumerate().for_each(|(k, k_row)| {
|
|
437
|
+
let mut row_update = vec![0.0; u];
|
|
438
|
+
for b in 0..batch {
|
|
439
|
+
let in_offset = b * in_f;
|
|
440
|
+
if in_slice[in_offset + k] == 1.0 {
|
|
441
|
+
let err_offset = b * u;
|
|
442
|
+
for j in 0..u {
|
|
443
|
+
row_update[j] += err_slice[err_offset + j];
|
|
444
|
+
}
|
|
445
|
+
}
|
|
446
|
+
}
|
|
447
|
+
for j in 0..u {
|
|
448
|
+
k_row[j] += lr * row_update[j];
|
|
449
|
+
}
|
|
450
|
+
});
|
|
451
|
+
|
|
452
|
+
if use_bias && b_slice.len() >= u {
|
|
453
|
+
b_slice.par_iter_mut().enumerate().for_each(|(j, b_val)| {
|
|
454
|
+
let mut b_update = 0.0;
|
|
455
|
+
for b in 0..batch {
|
|
456
|
+
let err_offset = b * u;
|
|
457
|
+
b_update += err_slice[err_offset + j];
|
|
458
|
+
}
|
|
459
|
+
*b_val += lr * b_update;
|
|
460
|
+
});
|
|
461
|
+
}
|
|
462
|
+
}
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import { Matrix } from "@oxide-js/core";
|
|
2
|
+
import { SpikingEmbedding } from "../src/layers/SpikingEmbedding.js";
|
|
3
|
+
import { SpikingDense } from "../src/layers/SpikingDense.js";
|
|
4
|
+
|
|
5
|
+
// Dataset: teach the SNN to recognise 4 vocabulary tokens.
// Each token id maps to a one-hot class target:
//   token 0 -> class A [1, 0, 0, 0]
//   token 1 -> class B [0, 1, 0, 0]
//   token 2 -> class C [0, 0, 1, 0]
//   token 3 -> class D [0, 0, 0, 1]
const xData = [[0], [1], [2], [3]];
const yData = [
  [1, 0, 0, 0],
  [0, 1, 0, 0],
  [0, 0, 1, 0],
  [0, 0, 0, 1]
];

const vocabSize = 4;
const embedDim = 16; // width of the embedding layer
const numClasses = 4;
const timesteps = 5; // steps each token is presented for, so neurons can integrate and fire

console.log("Inisialisasi SpikingEmbedding & SpikingDense...");

const embedding = new SpikingEmbedding({
  inputDim: vocabSize,
  outputDim: embedDim,
  beta: 0.9,
  threshold: 1.0,
  embeddingsInitializer: "glorot_normal"
});

const outputLayer = new SpikingDense({
  units: numClasses,
  beta: 0.9,
  threshold: 1.0,
  useBias: true,
  kernelInitializer: "glorot_normal"
});

embedding.build([1, 1]); // input shape [batch=1, num_tokens=1]
outputLayer.build([1, embedDim]);

// Fixed random feedback matrix B ([numClasses, embedDim], entries in [-1, 1)) that
// routes the output error back to the embedding layer (feedback alignment).
const bData = Float32Array.from({ length: numClasses * embedDim }, () => Math.random() * 2 - 1);
const B = Matrix.fromFlat(bData, [numClasses, embedDim]);

const epochs = 200;
const learningRate = 0.05;

/**
 * Fills `errOut` with the per-class spiking error for one timestep and returns its L1 norm.
 * A class that should fire gets +1 until it has fired at least once during the current
 * presentation (tracked in `fired`, which is updated in place). A class that should stay
 * silent gets -1 on every step where it fires.
 */
function spikingError(
  actual: ArrayLike<number>,
  target: ArrayLike<number>,
  fired: boolean[],
  errOut: Float32Array
): number {
  let l1 = 0;
  for (let j = 0; j < errOut.length; j++) {
    if (actual[j] === 1) fired[j] = true;

    if (target[j] === 1) {
      errOut[j] = fired[j] ? 0 : 1;
    } else {
      errOut[j] = 0 - actual[j];
    }
    l1 += Math.abs(errOut[j]);
  }
  return l1;
}

console.log("Mulai training SNN Word-to-Class (Feedback Alignment)...");

for (let epoch = 0; epoch < epochs; epoch++) {
  let totalError = 0;

  xData.forEach((token, i) => {
    const x = Matrix.fromFlat(new Float32Array(token), [1, 1]);
    const y = Matrix.fromFlat(new Float32Array(yData[i]), [1, numClasses]);
    const fired: boolean[] = new Array(numClasses).fill(false);

    embedding.resetState();
    outputLayer.resetState();

    for (let t = 0; t < timesteps; t++) {
      const eSpikes = embedding.forward(x) as Matrix;
      const outSpikes = outputLayer.forward(eSpikes) as Matrix;

      const errData = new Float32Array(numClasses);
      const stepError = spikingError(outSpikes._data, y._data, fired, errData);
      totalError += stepError;
      if (stepError === 0) continue;

      const errorSignal = Matrix.fromFlat(errData, [1, numClasses]);
      // 1. The output layer learns directly from its error.
      outputLayer.learnOutput(errorSignal, learningRate);
      // 2. The embedding layer learns through the random feedback matrix B.
      embedding.learnEmbedding(errorSignal, B, learningRate);
    }
  });

  if (epoch % 50 === 0 || epoch === epochs - 1) {
    console.log(`Epoch ${epoch} | Total Spiking Error: ${totalError}`);
  }
}

// Evaluation: count how often each class fires for every token.
console.log("\n--- HASIL PENGUJIAN ---");
xData.forEach((token, i) => {
  const x = Matrix.fromFlat(new Float32Array(token), [1, 1]);
  embedding.resetState();
  outputLayer.resetState();

  const fireCounts = new Float32Array(numClasses);
  for (let t = 0; t < timesteps; t++) {
    const eSpikes = embedding.forward(x) as Matrix;
    const outSpikes = outputLayer.forward(eSpikes) as Matrix;
    for (let j = 0; j < numClasses; j++) {
      fireCounts[j] += outSpikes._data[j];
    }
  }

  console.log(`Token Input: [${token[0]}] -> Prediksi Spike Kelas: [${fireCounts.join(", ")}] | Target Seharusnya: [${yData[i].join(", ")}]`);
});
|
package/test/test_xor.ts
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
import { Matrix } from "@oxide-js/core";
|
|
2
|
+
import { SpikingDense } from "../src/layers/SpikingDense.js";
|
|
3
|
+
|
|
4
|
+
// XOR dataset: the target is 1 exactly when one (and only one) input is 1.
const xData = [
  [0, 0],
  [0, 1],
  [1, 0],
  [1, 1]
];

const yData = [
  [0],
  [1],
  [1],
  [0]
];

// Initialize layers
const hiddenUnits = 8;
const hiddenLayer = new SpikingDense({
  units: hiddenUnits,
  beta: 0.9,
  threshold: 1.0,
  useBias: true,
  kernelInitializer: "glorot_normal"
});

const outputLayer = new SpikingDense({
  units: 1,
  beta: 0.9,
  threshold: 1.0,
  useBias: true,
  kernelInitializer: "glorot_normal"
});

// Build layers
hiddenLayer.build([1, 2]);
outputLayer.build([1, hiddenUnits]);

// Random feedback matrix B for feedback alignment, shape [outputUnits, hiddenUnits] -> [1, 8].
// Fill the buffer BEFORE wrapping it in a Matrix. If Matrix.fromFlat copies its input,
// filling afterwards would leave B all zeros and the hidden layer would never receive an
// error signal.
const bData = new Float32Array(hiddenUnits);
for (let i = 0; i < bData.length; i++) {
  bData[i] = (Math.random() * 2) - 1; // random value in [-1, 1)
}
const B = Matrix.fromFlat(bData, [1, hiddenUnits]);

const epochs = 500;
const learningRate = 0.01;
const timesteps = 5; // steps each input is presented for, so the SNN can accumulate

console.log("Mulai training SNN XOR dengan Feedback Alignment...");

for (let epoch = 0; epoch < epochs; epoch++) {
  let totalError = 0;

  for (let i = 0; i < xData.length; i++) {
    const x = Matrix.fromFlat(new Float32Array(xData[i]), [1, 2]);
    const y = Matrix.fromFlat(new Float32Array(yData[i]), [1, 1]);
    const target = y._data[0];

    // Whether the output neuron has fired at least once during this presentation.
    let hasSpiked = false;

    for (let t = 0; t < timesteps; t++) {
      const hSpikes = hiddenLayer.forward(x) as Matrix;
      const outSpikes = outputLayer.forward(hSpikes) as Matrix;
      const actual = outSpikes._data[0];

      if (actual === 1) hasSpiked = true;

      let err: number;
      if (target === 1) {
        // Keep pushing until the neuron fires; once it has, leave it alone.
        err = hasSpiked ? 0 : 1;
      } else {
        // Target 0: the neuron must stay silent the whole time, so penalise every spike.
        err = 0 - actual;
      }

      totalError += Math.abs(err);

      if (err !== 0) {
        const errorSignal = Matrix.fromFlat(new Float32Array([err]), [1, 1]);

        // 1. Output layer learns with the standard delta rule.
        outputLayer.learnOutput(errorSignal, learningRate);

        // 2. Hidden layer learns via the feedback-alignment broadcast through B.
        hiddenLayer.learnHidden(errorSignal, B, learningRate);
      }
    }

    // Reset state between data points.
    hiddenLayer.resetState();
    outputLayer.resetState();
  }

  if (epoch % 100 === 0) {
    console.log(`Epoch ${epoch} | Total Spiking Error: ${totalError}`);
  }
}

// Evaluate the trained network.
console.log("\nHasil Pengujian:");
for (let i = 0; i < xData.length; i++) {
  const x = Matrix.fromFlat(new Float32Array(xData[i]), [1, 2]);
  hiddenLayer.resetState();
  outputLayer.resetState();

  let sumSpikes = 0;
  for (let t = 0; t < timesteps; t++) {
    const hSpikes = hiddenLayer.forward(x) as Matrix;
    const outSpikes = outputLayer.forward(hSpikes) as Matrix;
    sumSpikes += outSpikes._data[0];
  }
  // Predict 1 if the output neuron fired at least once within the presentation window.
  const pred = sumSpikes >= 1 ? 1 : 0;
  console.log(`Input [${xData[i]}] -> Target: ${yData[i][0]} | Prediksi Spike: ${pred} (Total tembakan: ${sumSpikes})`);
}
|