@oxide-js/spiking 1.1.0 → 1.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,462 +1,16 @@
1
1
  #![deny(clippy::all)]
2
2
 
3
- use napi_derive::napi;
4
- use rayon::prelude::*;
5
- use serde::{Deserialize, Serialize};
6
- use std::fs::File;
7
- use std::io::{BufReader, BufWriter};
8
-
9
- #[derive(Serialize, Deserialize)]
10
- pub struct NetworkState {
11
- pub num_neurons: u32,
12
- pub beta: f32,
13
- pub thresholds: Vec<f32>,
14
- pub post_synaptic_indices: Vec<Vec<u32>>,
15
- pub weights: Vec<Vec<f32>>,
16
- }
17
-
18
- #[napi]
19
- pub struct NativeSpikingNetwork {
20
- num_neurons: u32,
21
- potentials: Vec<f32>,
22
- thresholds: Vec<f32>,
23
- spikes: Vec<u8>,
24
-
25
- // Adjacency lists
26
- post_synaptic_indices: Vec<Vec<u32>>,
27
- weights: Vec<Vec<f32>>,
28
-
29
- // STDP Traces
30
- pre_traces: Vec<f32>,
31
- post_traces: Vec<f32>,
32
-
33
- beta: f32,
34
- }
35
-
36
- #[napi]
37
- impl NativeSpikingNetwork {
38
- #[napi(constructor)]
39
- pub fn new(num_neurons: u32, beta: f64, default_threshold: f64) -> Self {
40
- Self {
41
- num_neurons,
42
- potentials: vec![0.0; num_neurons as usize],
43
- thresholds: vec![default_threshold as f32; num_neurons as usize],
44
- spikes: vec![0; num_neurons as usize],
45
- post_synaptic_indices: vec![Vec::new(); num_neurons as usize],
46
- weights: vec![Vec::new(); num_neurons as usize],
47
- pre_traces: vec![0.0; num_neurons as usize],
48
- post_traces: vec![0.0; num_neurons as usize],
49
- beta: beta as f32,
50
- }
51
- }
52
-
53
- #[napi]
54
- pub fn connect(&mut self, pre: u32, post: u32, weight: f64) {
55
- let pre_idx = pre as usize;
56
- if pre_idx >= self.num_neurons as usize || post as usize >= self.num_neurons as usize {
57
- return;
58
- }
59
-
60
- if let Some(pos) = self.post_synaptic_indices[pre_idx].iter().position(|&p| p == post) {
61
- self.weights[pre_idx][pos] = weight as f32;
62
- } else {
63
- self.post_synaptic_indices[pre_idx].push(post);
64
- self.weights[pre_idx].push(weight as f32);
65
- }
66
- }
67
-
68
- #[napi]
69
- pub fn inject_current(&mut self, neuron_idx: u32, current: f64) {
70
- let idx = neuron_idx as usize;
71
- if idx < self.potentials.len() {
72
- self.potentials[idx] += current as f32;
73
- }
74
- }
75
-
76
- #[napi]
77
- pub fn inhibit_range(&mut self, start_idx: u32, end_idx: u32, except_idx: u32, current: f64) {
78
- let start = start_idx as usize;
79
- let end = end_idx as usize;
80
- let ex = except_idx as usize;
81
- let c = current as f32;
82
-
83
- for i in start..end {
84
- if i != ex && i < self.potentials.len() {
85
- self.potentials[i] += c;
86
- }
87
- }
88
- }
89
-
90
- #[napi]
91
- pub fn step(&mut self) {
92
- // 1. Accumulate pre-synaptic spikes (Sequential for now to avoid data races)
93
- // We only process neurons that spiked in the previous step
94
- for i in 0..self.num_neurons as usize {
95
- if self.spikes[i] == 1 {
96
- let targets = &self.post_synaptic_indices[i];
97
- let w = &self.weights[i];
98
- for k in 0..targets.len() {
99
- let j = targets[k] as usize;
100
- self.potentials[j] += w[k];
101
- }
102
- }
103
- }
104
-
105
- // Reset spikes safely (Paralel)
106
- self.spikes.par_iter_mut().for_each(|s| *s = 0);
107
-
108
- // 2. Membrane decay, threshold check, and fire (Paralel via Rayon)
109
- let beta = self.beta;
110
- self.potentials
111
- .par_iter_mut()
112
- .zip(self.thresholds.par_iter())
113
- .zip(self.spikes.par_iter_mut())
114
- .for_each(|((pot, &thresh), spike)| {
115
- *pot *= beta; // Leaky integration
116
- if *pot < 1e-5 { *pot = 0.0; } // Prevent denormals
117
- if *pot >= thresh {
118
- *spike = 1;
119
- *pot -= thresh; // Soft reset
120
- }
121
- });
122
- }
123
-
124
- #[napi]
125
- pub fn reset_state(&mut self) {
126
- self.potentials.par_iter_mut().for_each(|p| *p = 0.0);
127
- self.spikes.par_iter_mut().for_each(|s| *s = 0);
128
- }
129
-
130
- // Helpers to get data to JS if needed
131
- #[napi]
132
- pub fn get_spikes(&self) -> Vec<u8> {
133
- self.spikes.clone()
134
- }
135
-
136
- #[napi]
137
- pub fn get_potentials(&self) -> Vec<f32> {
138
- self.potentials.clone()
139
- }
140
-
141
- #[napi]
142
- pub fn update_stdp(
143
- &mut self,
144
- learning_rate: f64,
145
- tau_plus: f64,
146
- tau_minus: f64,
147
- a_plus: f64,
148
- a_minus: f64,
149
- w_max: f64,
150
- w_min: f64,
151
- ) {
152
- let lr = learning_rate as f32;
153
- let t_plus = tau_plus as f32;
154
- let t_minus = tau_minus as f32;
155
- let ap = a_plus as f32;
156
- let am = a_minus as f32;
157
- let wmax = w_max as f32;
158
- let wmin = w_min as f32;
159
-
160
- // 1. Update trace decay and trace spikes
161
- self.pre_traces
162
- .par_iter_mut()
163
- .zip(self.post_traces.par_iter_mut())
164
- .zip(self.spikes.par_iter())
165
- .for_each(|((pre, post), &spike)| {
166
- *pre *= t_plus;
167
- *post *= t_minus;
168
-
169
- if *pre < 1e-5 { *pre = 0.0; }
170
- if *post < 1e-5 { *post = 0.0; }
171
-
172
- if spike == 1 {
173
- *pre = 1.0;
174
- *post = 1.0;
175
- }
176
- });
177
-
178
- // 2. Event-driven weight updates
179
- // Since we mutate self.weights, we can parallelize over the outer slice
180
- // because each sub-vector weights[i] belongs to exclusively one thread.
181
- // However, we need immutable access to self.post_traces and self.spikes.
182
- // To do this cleanly in Rust with Rayon without violating borrow rules:
183
-
184
- let spikes = &self.spikes;
185
- let pre_traces = &self.pre_traces;
186
- let post_traces = &self.post_traces;
187
- let post_synaptic_indices = &self.post_synaptic_indices;
188
-
189
- self.weights
190
- .par_iter_mut()
191
- .enumerate()
192
- .for_each(|(i, w_list)| {
193
- let pre_spiked = spikes[i] == 1;
194
- let pre_trace = pre_traces[i];
195
-
196
- if pre_trace == 0.0 && !pre_spiked {
197
- return;
198
- }
199
-
200
- let targets = &post_synaptic_indices[i];
201
-
202
- for k in 0..targets.len() {
203
- let j = targets[k] as usize;
204
- let post_spiked = spikes[j] == 1;
205
-
206
- let mut dw = 0.0;
207
-
208
- // LTP
209
- if post_spiked {
210
- dw += lr * ap * pre_trace;
211
- }
212
- // LTD
213
- if pre_spiked {
214
- dw -= lr * am * post_traces[j];
215
- }
216
-
217
- if dw != 0.0 {
218
- w_list[k] += dw;
219
- if w_list[k] > wmax {
220
- w_list[k] = wmax;
221
- } else if w_list[k] < wmin {
222
- w_list[k] = wmin;
223
- }
224
- }
225
- }
226
- });
227
- }
228
-
229
- #[napi]
230
- pub fn save_to_file(&self, filepath: String) {
231
- let state = NetworkState {
232
- num_neurons: self.num_neurons,
233
- beta: self.beta,
234
- thresholds: self.thresholds.clone(),
235
- post_synaptic_indices: self.post_synaptic_indices.clone(),
236
- weights: self.weights.clone(),
237
- };
238
- let file = File::create(filepath).expect("Gagal membuat file penyimpanan");
239
- let writer = BufWriter::new(file);
240
- serde_json::to_writer(writer, &state).expect("Gagal menyimpan NetworkState");
241
- }
242
-
243
- #[napi]
244
- pub fn load_from_file(&mut self, filepath: String) {
245
- let file = File::open(filepath).expect("Gagal membuka file penyimpanan");
246
- let reader = BufReader::new(file);
247
- let state: NetworkState = serde_json::from_reader(reader).expect("Gagal memuat NetworkState");
248
-
249
- self.num_neurons = state.num_neurons;
250
- self.beta = state.beta;
251
- self.thresholds = state.thresholds;
252
- self.post_synaptic_indices = state.post_synaptic_indices;
253
- self.weights = state.weights;
254
-
255
- // Sesuaikan memori untuk traces & potentials
256
- let len = self.num_neurons as usize;
257
- self.potentials.resize(len, 0.0);
258
- self.spikes.resize(len, 0);
259
- self.pre_traces.resize(len, 0.0);
260
- self.post_traces.resize(len, 0.0);
261
-
262
- // Reset state internal
263
- self.potentials.fill(0.0);
264
- self.spikes.fill(0);
265
- self.pre_traces.fill(0.0);
266
- self.post_traces.fill(0.0);
267
- }
268
- }
269
-
270
- #[napi]
271
- pub fn dot_product_add_only_native(
272
- a_data: napi::bindgen_prelude::Float32Array,
273
- a_rows_orig: u32,
274
- a_cols_orig: u32,
275
- b_data: napi::bindgen_prelude::Float32Array,
276
- b_rows_orig: u32,
277
- b_cols_orig: u32,
278
- trans_a: bool,
279
- trans_b: bool,
280
- mut out_data: napi::bindgen_prelude::Float32Array,
281
- ) {
282
- let a_rows = if trans_a { a_cols_orig } else { a_rows_orig } as usize;
283
- let a_cols = if trans_a { a_rows_orig } else { a_cols_orig } as usize;
284
- let b_rows = if trans_b { b_cols_orig } else { b_rows_orig } as usize;
285
- let b_cols = if trans_b { b_rows_orig } else { b_cols_orig } as usize;
286
-
287
- let a_slice = &*a_data;
288
- let b_slice = &*b_data;
289
- let out_slice = &mut *out_data;
290
-
291
- let mut a_is_binary = true;
292
- for &val in a_slice {
293
- if val != 0.0 && val != 1.0 {
294
- a_is_binary = false;
295
- break;
296
- }
297
- }
298
-
299
- let mut b_is_binary = true;
300
- if !a_is_binary {
301
- for &val in b_slice {
302
- if val != 0.0 && val != 1.0 {
303
- b_is_binary = false;
304
- break;
305
- }
306
- }
307
- }
308
-
309
- if !a_is_binary && !b_is_binary {
310
- panic!("SNN Error: Kedua matriks adalah floating-point. Setidaknya salah satu matriks harus hanya berisi 0 dan 1.");
311
- }
312
-
313
- out_slice.par_iter_mut().for_each(|x| *x = 0.0);
314
-
315
- let a_rows_orig = a_rows_orig as usize;
316
- let a_cols_orig = a_cols_orig as usize;
317
- let b_cols_orig = b_cols_orig as usize;
318
-
319
- out_slice.par_chunks_mut(b_cols).enumerate().for_each(|(i, out_row)| {
320
- if !trans_b {
321
- for k in 0..a_cols {
322
- let aik = if trans_a { a_slice[k * a_rows_orig + i] } else { a_slice[i * a_cols_orig + k] };
323
- if aik == 0.0 { continue; }
324
- let k_offset = k * b_cols;
325
-
326
- if a_is_binary {
327
- // aik must be 1.0
328
- for j in 0..b_cols {
329
- out_row[j] += b_slice[k_offset + j];
330
- }
331
- } else {
332
- // b_is_binary
333
- for j in 0..b_cols {
334
- if b_slice[k_offset + j] == 1.0 {
335
- out_row[j] += aik;
336
- }
337
- }
338
- }
339
- }
340
- } else {
341
- // trans_b == true
342
- for j in 0..b_cols {
343
- let mut sum = 0.0;
344
- for k in 0..a_cols {
345
- let aik = if trans_a { a_slice[k * a_rows_orig + i] } else { a_slice[i * a_cols_orig + k] };
346
- let bjk = b_slice[j * b_cols_orig + k];
347
- if a_is_binary {
348
- if aik == 1.0 { sum += bjk; }
349
- } else {
350
- if bjk == 1.0 { sum += aik; }
351
- }
352
- }
353
- out_row[j] = sum;
354
- }
355
- }
356
- });
357
- }
358
- });
359
- }
360
-
361
- #[napi]
362
- pub fn lif_step_native(
363
- mut potentials: napi::bindgen_prelude::Float32Array,
364
- dot: napi::bindgen_prelude::Float32Array,
365
- mut spikes: napi::bindgen_prelude::Float32Array,
366
- mut last_potentials: napi::bindgen_prelude::Float32Array,
367
- beta: f64,
368
- threshold: f64,
369
- ) {
370
- let pot_slice = &mut *potentials;
371
- let dot_slice = &*dot;
372
- let spike_slice = &mut *spikes;
373
- let lp_slice = &mut *last_potentials;
374
- let b = beta as f32;
375
- let th = threshold as f32;
376
-
377
- pot_slice.par_iter_mut()
378
- .zip(dot_slice.par_iter())
379
- .zip(spike_slice.par_iter_mut())
380
- .zip(lp_slice.par_iter_mut())
381
- .for_each(|(((p, d), s), lp)| {
382
- *p = (*p * b) + d;
383
- *lp = *p;
384
- if *p >= th {
385
- *s = 1.0;
386
- *p -= th;
387
- } else {
388
- *s = 0.0;
389
- }
390
- });
391
- }
392
-
393
- #[napi]
394
- pub fn mask_surrogate_native(
395
- mut error_signal: napi::bindgen_prelude::Float32Array,
396
- potentials: napi::bindgen_prelude::Float32Array,
397
- threshold: f64,
398
- window_size: f64,
399
- ) {
400
- let err_slice = &mut *error_signal;
401
- let pot_slice = &*potentials;
402
- let th = threshold as f32;
403
- let win = window_size as f32;
404
-
405
- err_slice.par_iter_mut()
406
- .zip(pot_slice.par_iter())
407
- .for_each(|(e, p)| {
408
- if (*p - th).abs() > win {
409
- *e = 0.0;
410
- }
411
- });
412
- }
413
-
414
- #[napi]
415
- pub fn apply_add_only_delta_native(
416
- mut kernel: napi::bindgen_prelude::Float32Array,
417
- mut bias: napi::bindgen_prelude::Float32Array,
418
- inputs: napi::bindgen_prelude::Float32Array,
419
- error_signal: napi::bindgen_prelude::Float32Array,
420
- learning_rate: f64,
421
- batch: u32,
422
- in_features: u32,
423
- units: u32,
424
- use_bias: bool,
425
- ) {
426
- let k_slice = &mut *kernel;
427
- let b_slice = &mut *bias;
428
- let in_slice = &*inputs;
429
- let err_slice = &*error_signal;
430
- let lr = learning_rate as f32;
431
-
432
- let batch = batch as usize;
433
- let in_f = in_features as usize;
434
- let u = units as usize;
435
-
436
- k_slice.par_chunks_mut(u).enumerate().for_each(|(k, k_row)| {
437
- let mut row_update = vec![0.0; u];
438
- for b in 0..batch {
439
- let in_offset = b * in_f;
440
- if in_slice[in_offset + k] == 1.0 {
441
- let err_offset = b * u;
442
- for j in 0..u {
443
- row_update[j] += err_slice[err_offset + j];
444
- }
445
- }
446
- }
447
- for j in 0..u {
448
- k_row[j] += lr * row_update[j];
449
- }
450
- });
451
-
452
- if use_bias && b_slice.len() >= u {
453
- b_slice.par_iter_mut().enumerate().for_each(|(j, b_val)| {
454
- let mut b_update = 0.0;
455
- for b in 0..batch {
456
- let err_offset = b * u;
457
- b_update += err_slice[err_offset + j];
458
- }
459
- *b_val += lr * b_update;
460
- });
461
- }
462
- }
3
+ #[macro_use]
4
+ extern crate napi_derive;
5
+
6
+ mod dot_product;
7
+ mod lif;
8
+ mod surrogate;
9
+ mod delta;
10
+ mod embedding;
11
+
12
+ pub use dot_product::*;
13
+ pub use lif::*;
14
+ pub use surrogate::*;
15
+ pub use delta::*;
16
+ pub use embedding::*;
@@ -0,0 +1,44 @@
1
+ use napi_derive::napi;
2
+ use napi::bindgen_prelude::Float32Array;
3
+ use rayon::prelude::*;
4
+
5
+ #[napi]
6
+ pub fn lif_step_native(
7
+ mut potentials: Float32Array,
8
+ dot: Float32Array,
9
+ mut spikes: Float32Array,
10
+ mut last_potentials: Float32Array,
11
+ beta: Float32Array,
12
+ threshold: Float32Array
13
+ ) {
14
+ let units = beta.len();
15
+ if units == 0 { return; }
16
+ let batch = potentials.len() / units;
17
+
18
+ let pot_slice: &mut [f32] = &mut potentials;
19
+ let dot_slice: &[f32] = &dot;
20
+ let spikes_slice: &mut [f32] = &mut spikes;
21
+ let last_pot_slice: &mut [f32] = &mut last_potentials;
22
+ let beta_slice: &[f32] = &beta;
23
+ let thresh_slice: &[f32] = &threshold;
24
+
25
+ pot_slice.par_chunks_mut(units)
26
+ .zip(dot_slice.par_chunks(units))
27
+ .zip(spikes_slice.par_chunks_mut(units))
28
+ .zip(last_pot_slice.par_chunks_mut(units))
29
+ .for_each(|(((pot_chunk, dot_chunk), spike_chunk), last_pot_chunk)| {
30
+ for i in 0..units {
31
+ let mut pot = (pot_chunk[i] * beta_slice[i]) + dot_chunk[i];
32
+ pot = pot.min(1.0); // Clamp potential max 1.0
33
+ last_pot_chunk[i] = pot;
34
+
35
+ if pot >= thresh_slice[i] {
36
+ spike_chunk[i] = 1.0;
37
+ pot -= thresh_slice[i];
38
+ } else {
39
+ spike_chunk[i] = 0.0;
40
+ }
41
+ pot_chunk[i] = pot;
42
+ }
43
+ });
44
+ }
@@ -0,0 +1,28 @@
1
+ use napi_derive::napi;
2
+ use napi::bindgen_prelude::Float32Array;
3
+ use rayon::prelude::*;
4
+
5
+ #[napi]
6
+ pub fn mask_surrogate_native(
7
+ mut error_signal: Float32Array,
8
+ potentials: Float32Array,
9
+ threshold: Float32Array,
10
+ window_size: f64
11
+ ) {
12
+ let units = threshold.len();
13
+ if units == 0 { return; }
14
+
15
+ let err_slice: &mut [f32] = &mut error_signal;
16
+ let pot_slice: &[f32] = &potentials;
17
+ let thresh_slice: &[f32] = &threshold;
18
+
19
+ err_slice.par_chunks_mut(units)
20
+ .zip(pot_slice.par_chunks(units))
21
+ .for_each(|(err_chunk, pot_chunk)| {
22
+ for i in 0..units {
23
+ if (pot_chunk[i] - thresh_slice[i]).abs() > window_size as f32 {
24
+ err_chunk[i] = 0.0;
25
+ }
26
+ }
27
+ });
28
+ }