@oxide-js/spiking 1.1.0 → 1.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,462 +1,18 @@
1
1
  #![deny(clippy::all)]
2
2
 
3
- use napi_derive::napi;
4
- use rayon::prelude::*;
5
- use serde::{Deserialize, Serialize};
6
- use std::fs::File;
7
- use std::io::{BufReader, BufWriter};
8
-
9
- #[derive(Serialize, Deserialize)]
10
- pub struct NetworkState {
11
- pub num_neurons: u32,
12
- pub beta: f32,
13
- pub thresholds: Vec<f32>,
14
- pub post_synaptic_indices: Vec<Vec<u32>>,
15
- pub weights: Vec<Vec<f32>>,
16
- }
17
-
18
- #[napi]
19
- pub struct NativeSpikingNetwork {
20
- num_neurons: u32,
21
- potentials: Vec<f32>,
22
- thresholds: Vec<f32>,
23
- spikes: Vec<u8>,
24
-
25
- // Adjacency lists
26
- post_synaptic_indices: Vec<Vec<u32>>,
27
- weights: Vec<Vec<f32>>,
28
-
29
- // STDP Traces
30
- pre_traces: Vec<f32>,
31
- post_traces: Vec<f32>,
32
-
33
- beta: f32,
34
- }
35
-
36
- #[napi]
37
- impl NativeSpikingNetwork {
38
- #[napi(constructor)]
39
- pub fn new(num_neurons: u32, beta: f64, default_threshold: f64) -> Self {
40
- Self {
41
- num_neurons,
42
- potentials: vec![0.0; num_neurons as usize],
43
- thresholds: vec![default_threshold as f32; num_neurons as usize],
44
- spikes: vec![0; num_neurons as usize],
45
- post_synaptic_indices: vec![Vec::new(); num_neurons as usize],
46
- weights: vec![Vec::new(); num_neurons as usize],
47
- pre_traces: vec![0.0; num_neurons as usize],
48
- post_traces: vec![0.0; num_neurons as usize],
49
- beta: beta as f32,
50
- }
51
- }
52
-
53
- #[napi]
54
- pub fn connect(&mut self, pre: u32, post: u32, weight: f64) {
55
- let pre_idx = pre as usize;
56
- if pre_idx >= self.num_neurons as usize || post as usize >= self.num_neurons as usize {
57
- return;
58
- }
59
-
60
- if let Some(pos) = self.post_synaptic_indices[pre_idx].iter().position(|&p| p == post) {
61
- self.weights[pre_idx][pos] = weight as f32;
62
- } else {
63
- self.post_synaptic_indices[pre_idx].push(post);
64
- self.weights[pre_idx].push(weight as f32);
65
- }
66
- }
67
-
68
- #[napi]
69
- pub fn inject_current(&mut self, neuron_idx: u32, current: f64) {
70
- let idx = neuron_idx as usize;
71
- if idx < self.potentials.len() {
72
- self.potentials[idx] += current as f32;
73
- }
74
- }
75
-
76
- #[napi]
77
- pub fn inhibit_range(&mut self, start_idx: u32, end_idx: u32, except_idx: u32, current: f64) {
78
- let start = start_idx as usize;
79
- let end = end_idx as usize;
80
- let ex = except_idx as usize;
81
- let c = current as f32;
82
-
83
- for i in start..end {
84
- if i != ex && i < self.potentials.len() {
85
- self.potentials[i] += c;
86
- }
87
- }
88
- }
89
-
90
- #[napi]
91
- pub fn step(&mut self) {
92
- // 1. Accumulate pre-synaptic spikes (Sequential for now to avoid data races)
93
- // We only process neurons that spiked in the previous step
94
- for i in 0..self.num_neurons as usize {
95
- if self.spikes[i] == 1 {
96
- let targets = &self.post_synaptic_indices[i];
97
- let w = &self.weights[i];
98
- for k in 0..targets.len() {
99
- let j = targets[k] as usize;
100
- self.potentials[j] += w[k];
101
- }
102
- }
103
- }
104
-
105
- // Reset spikes safely (Paralel)
106
- self.spikes.par_iter_mut().for_each(|s| *s = 0);
107
-
108
- // 2. Membrane decay, threshold check, and fire (Paralel via Rayon)
109
- let beta = self.beta;
110
- self.potentials
111
- .par_iter_mut()
112
- .zip(self.thresholds.par_iter())
113
- .zip(self.spikes.par_iter_mut())
114
- .for_each(|((pot, &thresh), spike)| {
115
- *pot *= beta; // Leaky integration
116
- if *pot < 1e-5 { *pot = 0.0; } // Prevent denormals
117
- if *pot >= thresh {
118
- *spike = 1;
119
- *pot -= thresh; // Soft reset
120
- }
121
- });
122
- }
123
-
124
- #[napi]
125
- pub fn reset_state(&mut self) {
126
- self.potentials.par_iter_mut().for_each(|p| *p = 0.0);
127
- self.spikes.par_iter_mut().for_each(|s| *s = 0);
128
- }
129
-
130
- // Helpers to get data to JS if needed
131
- #[napi]
132
- pub fn get_spikes(&self) -> Vec<u8> {
133
- self.spikes.clone()
134
- }
135
-
136
- #[napi]
137
- pub fn get_potentials(&self) -> Vec<f32> {
138
- self.potentials.clone()
139
- }
140
-
141
- #[napi]
142
- pub fn update_stdp(
143
- &mut self,
144
- learning_rate: f64,
145
- tau_plus: f64,
146
- tau_minus: f64,
147
- a_plus: f64,
148
- a_minus: f64,
149
- w_max: f64,
150
- w_min: f64,
151
- ) {
152
- let lr = learning_rate as f32;
153
- let t_plus = tau_plus as f32;
154
- let t_minus = tau_minus as f32;
155
- let ap = a_plus as f32;
156
- let am = a_minus as f32;
157
- let wmax = w_max as f32;
158
- let wmin = w_min as f32;
159
-
160
- // 1. Update trace decay and trace spikes
161
- self.pre_traces
162
- .par_iter_mut()
163
- .zip(self.post_traces.par_iter_mut())
164
- .zip(self.spikes.par_iter())
165
- .for_each(|((pre, post), &spike)| {
166
- *pre *= t_plus;
167
- *post *= t_minus;
168
-
169
- if *pre < 1e-5 { *pre = 0.0; }
170
- if *post < 1e-5 { *post = 0.0; }
171
-
172
- if spike == 1 {
173
- *pre = 1.0;
174
- *post = 1.0;
175
- }
176
- });
177
-
178
- // 2. Event-driven weight updates
179
- // Since we mutate self.weights, we can parallelize over the outer slice
180
- // because each sub-vector weights[i] belongs to exclusively one thread.
181
- // However, we need immutable access to self.post_traces and self.spikes.
182
- // To do this cleanly in Rust with Rayon without violating borrow rules:
183
-
184
- let spikes = &self.spikes;
185
- let pre_traces = &self.pre_traces;
186
- let post_traces = &self.post_traces;
187
- let post_synaptic_indices = &self.post_synaptic_indices;
188
-
189
- self.weights
190
- .par_iter_mut()
191
- .enumerate()
192
- .for_each(|(i, w_list)| {
193
- let pre_spiked = spikes[i] == 1;
194
- let pre_trace = pre_traces[i];
195
-
196
- if pre_trace == 0.0 && !pre_spiked {
197
- return;
198
- }
199
-
200
- let targets = &post_synaptic_indices[i];
201
-
202
- for k in 0..targets.len() {
203
- let j = targets[k] as usize;
204
- let post_spiked = spikes[j] == 1;
205
-
206
- let mut dw = 0.0;
207
-
208
- // LTP
209
- if post_spiked {
210
- dw += lr * ap * pre_trace;
211
- }
212
- // LTD
213
- if pre_spiked {
214
- dw -= lr * am * post_traces[j];
215
- }
216
-
217
- if dw != 0.0 {
218
- w_list[k] += dw;
219
- if w_list[k] > wmax {
220
- w_list[k] = wmax;
221
- } else if w_list[k] < wmin {
222
- w_list[k] = wmin;
223
- }
224
- }
225
- }
226
- });
227
- }
228
-
229
- #[napi]
230
- pub fn save_to_file(&self, filepath: String) {
231
- let state = NetworkState {
232
- num_neurons: self.num_neurons,
233
- beta: self.beta,
234
- thresholds: self.thresholds.clone(),
235
- post_synaptic_indices: self.post_synaptic_indices.clone(),
236
- weights: self.weights.clone(),
237
- };
238
- let file = File::create(filepath).expect("Gagal membuat file penyimpanan");
239
- let writer = BufWriter::new(file);
240
- serde_json::to_writer(writer, &state).expect("Gagal menyimpan NetworkState");
241
- }
242
-
243
- #[napi]
244
- pub fn load_from_file(&mut self, filepath: String) {
245
- let file = File::open(filepath).expect("Gagal membuka file penyimpanan");
246
- let reader = BufReader::new(file);
247
- let state: NetworkState = serde_json::from_reader(reader).expect("Gagal memuat NetworkState");
248
-
249
- self.num_neurons = state.num_neurons;
250
- self.beta = state.beta;
251
- self.thresholds = state.thresholds;
252
- self.post_synaptic_indices = state.post_synaptic_indices;
253
- self.weights = state.weights;
254
-
255
- // Sesuaikan memori untuk traces & potentials
256
- let len = self.num_neurons as usize;
257
- self.potentials.resize(len, 0.0);
258
- self.spikes.resize(len, 0);
259
- self.pre_traces.resize(len, 0.0);
260
- self.post_traces.resize(len, 0.0);
261
-
262
- // Reset state internal
263
- self.potentials.fill(0.0);
264
- self.spikes.fill(0);
265
- self.pre_traces.fill(0.0);
266
- self.post_traces.fill(0.0);
267
- }
268
- }
269
-
270
- #[napi]
271
- pub fn dot_product_add_only_native(
272
- a_data: napi::bindgen_prelude::Float32Array,
273
- a_rows_orig: u32,
274
- a_cols_orig: u32,
275
- b_data: napi::bindgen_prelude::Float32Array,
276
- b_rows_orig: u32,
277
- b_cols_orig: u32,
278
- trans_a: bool,
279
- trans_b: bool,
280
- mut out_data: napi::bindgen_prelude::Float32Array,
281
- ) {
282
- let a_rows = if trans_a { a_cols_orig } else { a_rows_orig } as usize;
283
- let a_cols = if trans_a { a_rows_orig } else { a_cols_orig } as usize;
284
- let b_rows = if trans_b { b_cols_orig } else { b_rows_orig } as usize;
285
- let b_cols = if trans_b { b_rows_orig } else { b_cols_orig } as usize;
286
-
287
- let a_slice = &*a_data;
288
- let b_slice = &*b_data;
289
- let out_slice = &mut *out_data;
290
-
291
- let mut a_is_binary = true;
292
- for &val in a_slice {
293
- if val != 0.0 && val != 1.0 {
294
- a_is_binary = false;
295
- break;
296
- }
297
- }
298
-
299
- let mut b_is_binary = true;
300
- if !a_is_binary {
301
- for &val in b_slice {
302
- if val != 0.0 && val != 1.0 {
303
- b_is_binary = false;
304
- break;
305
- }
306
- }
307
- }
308
-
309
- if !a_is_binary && !b_is_binary {
310
- panic!("SNN Error: Kedua matriks adalah floating-point. Setidaknya salah satu matriks harus hanya berisi 0 dan 1.");
311
- }
312
-
313
- out_slice.par_iter_mut().for_each(|x| *x = 0.0);
314
-
315
- let a_rows_orig = a_rows_orig as usize;
316
- let a_cols_orig = a_cols_orig as usize;
317
- let b_cols_orig = b_cols_orig as usize;
318
-
319
- out_slice.par_chunks_mut(b_cols).enumerate().for_each(|(i, out_row)| {
320
- if !trans_b {
321
- for k in 0..a_cols {
322
- let aik = if trans_a { a_slice[k * a_rows_orig + i] } else { a_slice[i * a_cols_orig + k] };
323
- if aik == 0.0 { continue; }
324
- let k_offset = k * b_cols;
325
-
326
- if a_is_binary {
327
- // aik must be 1.0
328
- for j in 0..b_cols {
329
- out_row[j] += b_slice[k_offset + j];
330
- }
331
- } else {
332
- // b_is_binary
333
- for j in 0..b_cols {
334
- if b_slice[k_offset + j] == 1.0 {
335
- out_row[j] += aik;
336
- }
337
- }
338
- }
339
- }
340
- } else {
341
- // trans_b == true
342
- for j in 0..b_cols {
343
- let mut sum = 0.0;
344
- for k in 0..a_cols {
345
- let aik = if trans_a { a_slice[k * a_rows_orig + i] } else { a_slice[i * a_cols_orig + k] };
346
- let bjk = b_slice[j * b_cols_orig + k];
347
- if a_is_binary {
348
- if aik == 1.0 { sum += bjk; }
349
- } else {
350
- if bjk == 1.0 { sum += aik; }
351
- }
352
- }
353
- out_row[j] = sum;
354
- }
355
- }
356
- });
357
- }
358
- });
359
- }
360
-
361
- #[napi]
362
- pub fn lif_step_native(
363
- mut potentials: napi::bindgen_prelude::Float32Array,
364
- dot: napi::bindgen_prelude::Float32Array,
365
- mut spikes: napi::bindgen_prelude::Float32Array,
366
- mut last_potentials: napi::bindgen_prelude::Float32Array,
367
- beta: f64,
368
- threshold: f64,
369
- ) {
370
- let pot_slice = &mut *potentials;
371
- let dot_slice = &*dot;
372
- let spike_slice = &mut *spikes;
373
- let lp_slice = &mut *last_potentials;
374
- let b = beta as f32;
375
- let th = threshold as f32;
376
-
377
- pot_slice.par_iter_mut()
378
- .zip(dot_slice.par_iter())
379
- .zip(spike_slice.par_iter_mut())
380
- .zip(lp_slice.par_iter_mut())
381
- .for_each(|(((p, d), s), lp)| {
382
- *p = (*p * b) + d;
383
- *lp = *p;
384
- if *p >= th {
385
- *s = 1.0;
386
- *p -= th;
387
- } else {
388
- *s = 0.0;
389
- }
390
- });
391
- }
392
-
393
- #[napi]
394
- pub fn mask_surrogate_native(
395
- mut error_signal: napi::bindgen_prelude::Float32Array,
396
- potentials: napi::bindgen_prelude::Float32Array,
397
- threshold: f64,
398
- window_size: f64,
399
- ) {
400
- let err_slice = &mut *error_signal;
401
- let pot_slice = &*potentials;
402
- let th = threshold as f32;
403
- let win = window_size as f32;
404
-
405
- err_slice.par_iter_mut()
406
- .zip(pot_slice.par_iter())
407
- .for_each(|(e, p)| {
408
- if (*p - th).abs() > win {
409
- *e = 0.0;
410
- }
411
- });
412
- }
413
-
414
- #[napi]
415
- pub fn apply_add_only_delta_native(
416
- mut kernel: napi::bindgen_prelude::Float32Array,
417
- mut bias: napi::bindgen_prelude::Float32Array,
418
- inputs: napi::bindgen_prelude::Float32Array,
419
- error_signal: napi::bindgen_prelude::Float32Array,
420
- learning_rate: f64,
421
- batch: u32,
422
- in_features: u32,
423
- units: u32,
424
- use_bias: bool,
425
- ) {
426
- let k_slice = &mut *kernel;
427
- let b_slice = &mut *bias;
428
- let in_slice = &*inputs;
429
- let err_slice = &*error_signal;
430
- let lr = learning_rate as f32;
431
-
432
- let batch = batch as usize;
433
- let in_f = in_features as usize;
434
- let u = units as usize;
435
-
436
- k_slice.par_chunks_mut(u).enumerate().for_each(|(k, k_row)| {
437
- let mut row_update = vec![0.0; u];
438
- for b in 0..batch {
439
- let in_offset = b * in_f;
440
- if in_slice[in_offset + k] == 1.0 {
441
- let err_offset = b * u;
442
- for j in 0..u {
443
- row_update[j] += err_slice[err_offset + j];
444
- }
445
- }
446
- }
447
- for j in 0..u {
448
- k_row[j] += lr * row_update[j];
449
- }
450
- });
451
-
452
- if use_bias && b_slice.len() >= u {
453
- b_slice.par_iter_mut().enumerate().for_each(|(j, b_val)| {
454
- let mut b_update = 0.0;
455
- for b in 0..batch {
456
- let err_offset = b * u;
457
- b_update += err_slice[err_offset + j];
458
- }
459
- *b_val += lr * b_update;
460
- });
461
- }
462
- }
3
+ #[macro_use]
4
+ extern crate napi_derive;
5
+
6
+ mod dot_product;
7
+ mod lif;
8
+ mod surrogate;
9
+ mod delta;
10
+ mod embedding;
11
+ mod contrastive;
12
+
13
+ pub use dot_product::*;
14
+ pub use lif::*;
15
+ pub use surrogate::*;
16
+ pub use delta::*;
17
+ pub use embedding::*;
18
+ pub use contrastive::*;
@@ -0,0 +1,44 @@
1
+ use napi_derive::napi;
2
+ use napi::bindgen_prelude::Float32Array;
3
+ use rayon::prelude::*;
4
+
5
+ #[napi]
6
+ pub fn lif_step_native(
7
+ mut potentials: Float32Array,
8
+ dot: Float32Array,
9
+ mut spikes: Float32Array,
10
+ mut last_potentials: Float32Array,
11
+ beta: Float32Array,
12
+ threshold: Float32Array
13
+ ) {
14
+ let units = beta.len();
15
+ if units == 0 { return; }
16
+ let batch = potentials.len() / units;
17
+
18
+ let pot_slice: &mut [f32] = &mut potentials;
19
+ let dot_slice: &[f32] = &dot;
20
+ let spikes_slice: &mut [f32] = &mut spikes;
21
+ let last_pot_slice: &mut [f32] = &mut last_potentials;
22
+ let beta_slice: &[f32] = &beta;
23
+ let thresh_slice: &[f32] = &threshold;
24
+
25
+ pot_slice.par_chunks_mut(units)
26
+ .zip(dot_slice.par_chunks(units))
27
+ .zip(spikes_slice.par_chunks_mut(units))
28
+ .zip(last_pot_slice.par_chunks_mut(units))
29
+ .for_each(|(((pot_chunk, dot_chunk), spike_chunk), last_pot_chunk)| {
30
+ for i in 0..units {
31
+ let mut pot = (pot_chunk[i] * beta_slice[i]) + dot_chunk[i];
32
+ pot = pot.min(1.0); // Clamp potential max 1.0
33
+ last_pot_chunk[i] = pot;
34
+
35
+ if pot >= thresh_slice[i] {
36
+ spike_chunk[i] = 1.0;
37
+ pot -= thresh_slice[i];
38
+ } else {
39
+ spike_chunk[i] = 0.0;
40
+ }
41
+ pot_chunk[i] = pot;
42
+ }
43
+ });
44
+ }
@@ -0,0 +1,28 @@
1
+ use napi_derive::napi;
2
+ use napi::bindgen_prelude::Float32Array;
3
+ use rayon::prelude::*;
4
+
5
+ #[napi]
6
+ pub fn mask_surrogate_native(
7
+ mut error_signal: Float32Array,
8
+ potentials: Float32Array,
9
+ threshold: Float32Array,
10
+ window_size: f64
11
+ ) {
12
+ let units = threshold.len();
13
+ if units == 0 { return; }
14
+
15
+ let err_slice: &mut [f32] = &mut error_signal;
16
+ let pot_slice: &[f32] = &potentials;
17
+ let thresh_slice: &[f32] = &threshold;
18
+
19
+ err_slice.par_chunks_mut(units)
20
+ .zip(pot_slice.par_chunks(units))
21
+ .for_each(|(err_chunk, pot_chunk)| {
22
+ for i in 0..units {
23
+ if (pot_chunk[i] - thresh_slice[i]).abs() > window_size as f32 {
24
+ err_chunk[i] = 0.0;
25
+ }
26
+ }
27
+ });
28
+ }
@@ -0,0 +1,151 @@
1
+ import { describe, it, expect, beforeEach } from 'vitest';
2
+ import { Matrix } from '@oxide-js/core';
3
+ import { SpikingDenseBPTT } from '../src/layers/SpikingDenseBPTT.js';
4
+
5
describe('SpikingDenseBPTT Layer', () => {
  const units = 4;
  const inFeatures = 3;
  const batch = 2;
  const maxTimeSteps = 3;

  let layer: SpikingDenseBPTT;

  beforeEach(() => {
    layer = new SpikingDenseBPTT({
      units,
      kernelInitializer: 'ones',
      useBias: false
    });
    layer.build([batch, inFeatures]);
  });

  it('should initialize parameters correctly', () => {
    expect(layer.units).toBe(units);
    expect(layer.kernel).toBeDefined();

    // Check that the per-unit dynamic beta (bit-shift decay, stored as a float) is initialized.
    expect(layer.beta).toBeDefined();
    expect(layer.beta.length).toBe(units);

    for (let i = 0; i < units; i++) {
      // The multiplier is precomputed as 1.0 - (1.0 / Math.pow(2, shift)),
      // with shift in 2..5 (a decay of 1/4 down to 1/32), so every valid beta
      // lies in [0.75, 0.96875].
      expect(layer.beta[i]).toBeGreaterThanOrEqual(0.75);
      expect(layer.beta[i]).toBeLessThanOrEqual(0.96875);
    }
  });

  it('should throw error when calling compute() directly', () => {
    const inputs = Matrix.fromFlat(new Float32Array(batch * inFeatures), [batch, inFeatures]);
    // The layer's error message ("Harap gunakan computeStep") asks the caller to use computeStep.
    expect(() => {
      // @ts-ignore
      layer.compute(inputs);
    }).toThrowError(/Harap gunakan computeStep/);
  });

  it('should process sequence, enforce BPTT limits, and store history correctly', () => {
    layer.resetSequence(maxTimeSteps);

    expect(layer.maxTimeSteps).toBe(maxTimeSteps);
    expect(layer.historyInputs.length).toBe(maxTimeSteps);

    // Dummy binary spike input: every input fires.
    const inputData = new Float32Array(batch * inFeatures).fill(1);
    const inputs = Matrix.fromFlat(inputData, [batch, inFeatures]);

    // Time step 0
    const out0 = layer.computeStep(inputs, 0);
    expect(out0._shape).toEqual([batch, units]);
    expect(layer.historyInputs[0]).toBeDefined();
    expect(layer.historyPotentials[0]).toBeDefined();
    expect(layer.historySpikes[0]).toBeDefined();

    // Time step 1
    const out1 = layer.computeStep(inputs, 1);
    expect(out1._shape).toEqual([batch, units]);
    expect(layer.historyInputs[1]).toBeDefined();

    // Time step 2
    const out2 = layer.computeStep(inputs, 2);
    expect(out2._shape).toEqual([batch, units]);

    // Time step 3 exceeds maxTimeSteps and must throw.
    expect(() => {
      layer.computeStep(inputs, 3);
    }).toThrowError(/melebihi batas maxTimeSteps/);
  });

  it('should run learnThroughTime properly without crashing', () => {
    layer.resetSequence(maxTimeSteps);

    const inputData = new Float32Array(batch * inFeatures).fill(1);
    const inputs = Matrix.fromFlat(inputData, [batch, inFeatures]);

    // Run the full sequence.
    for (let t = 0; t < maxTimeSteps; t++) {
      layer.computeStep(inputs, t);
    }

    // One constant dummy error matrix per time step (t = 0, 1, 2).
    const errors = Array.from({ length: maxTimeSteps }, () =>
      Matrix.fromFlat(new Float32Array(batch * units).fill(0.1), [batch, units])
    );

    // BPTT as an output layer: no feedback matrix B.
    expect(() => {
      layer.learnThroughTime(errors, undefined, 0.01);
    }).not.toThrow();

    // BPTT as a hidden layer, using an all-ones [units, units] feedback
    // matrix B (note: this is NOT an identity matrix).
    const B = Matrix.fromFlat(new Float32Array(units * units).fill(1), [units, units]);
    expect(() => {
      layer.learnThroughTime(errors, B, 0.01);
    }).not.toThrow();
  });

  it('should correctly accumulate potentials and trigger spikes deterministically over time', () => {
    // Build a deterministic layer.
    const testLayer = new SpikingDenseBPTT({
      units: 2,
      useBias: false,
      kernelInitializer: 'zeros'
    });
    testLayer.build([1, 2]); // batch = 1, inFeatures = 2

    // Manually set the kernel to [[0.6, 0.0], [0.0, 0.8]].
    testLayer.kernel!._data.set([0.6, 0.0, 0.0, 0.8]);

    // Fix beta at 0.5 (easy to compute by hand) and the threshold at 1.0.
    testLayer.beta.fill(0.5);
    testLayer.threshold.fill(1.0);

    // Prepare a 3-step sequence.
    testLayer.resetSequence(3);

    // Both inputs fire at every time step: [1, 1], so the input current is [0.6, 0.8].
    const inputs = Matrix.fromFlat(new Float32Array([1, 1]), [1, 2]);

    // --- TIME STEP 0 ---
    // Potentials: 0 * 0.5 + [0.6, 0.8] = [0.6, 0.8]; nothing reaches 1.0.
    const out0 = testLayer.computeStep(inputs, 0);
    expect(out0._data).toEqual(new Float32Array([0, 0]));

    // The pre-spike potential history must record [0.6, 0.8].
    expect(testLayer.historyPotentials[0]._data[0]).toBeCloseTo(0.6, 5);
    expect(testLayer.historyPotentials[0]._data[1]).toBeCloseTo(0.8, 5);

    // --- TIME STEP 1 ---
    // Unit 0: 0.6 * 0.5 + 0.6 = 0.9 (no spike).
    // Unit 1: 0.8 * 0.5 + 0.8 = 1.2; expecting 1.0 presumably relies on the
    // layer clamping potentials at 1.0. It spikes and soft-resets to 0.
    const out1 = testLayer.computeStep(inputs, 1);
    expect(out1._data).toEqual(new Float32Array([0, 1]));

    expect(testLayer.historyPotentials[1]._data[0]).toBeCloseTo(0.9, 5);
    expect(testLayer.historyPotentials[1]._data[1]).toBeCloseTo(1.0, 5);

    // --- TIME STEP 2 ---
    // Unit 0: 0.9 * 0.5 + 0.6 = 1.05, recorded as 1.0 (clamped); it spikes.
    // Unit 1: 0 * 0.5 + 0.8 = 0.8 (no spike).
    const out2 = testLayer.computeStep(inputs, 2);
    expect(out2._data).toEqual(new Float32Array([1, 0]));

    expect(testLayer.historyPotentials[2]._data[0]).toBeCloseTo(1.0, 5);
    expect(testLayer.historyPotentials[2]._data[1]).toBeCloseTo(0.8, 5);
  });
});