@oxide-js/spiking 1.1.0 → 1.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,41 +2,36 @@ import { BaseLayer, LayerConfig, ForwardOptions } from "@oxide-js/layers";
2
2
  import { Matrix, mj } from "@oxide-js/core";
  import {
    isNativeAvailable,
    lifStepNativeWrapper,
    maskSurrogateNativeWrapper,
    applyEmbeddingDeltaNativeWrapper
  } from "../native_backend.js";
8
8
 
9
9
  export interface SpikingEmbeddingConfig extends LayerConfig {
10
- inputDim: number; // Ukuran vocabulary
11
- outputDim: number; // Dimensi embedding (jumlah neuron)
12
- beta?: number; // Decay factor LIF
13
- threshold?: number; // Ambang batas Spike
14
- embeddingsInitializer?: string; // Tipe inisialisasi bobot
10
+ inputDim: number;
11
+ outputDim: number;
12
+ embeddingsInitializer?: string;
15
13
  }
16
14
 
17
15
  export class SpikingEmbedding extends BaseLayer {
18
16
  public inputDim: number;
19
17
  public outputDim: number;
20
- public beta: number;
21
- public threshold: number;
18
+ public embeddingsInitializer: string;
19
+ public beta!: Float32Array;
20
+ public threshold!: Float32Array;
22
21
 
23
22
  public potentials!: Matrix;
24
23
  public lastPotentials?: Matrix;
25
24
  public lastInputs?: Matrix;
26
25
  public lastSpikes?: Matrix;
27
26
 
28
- public embeddingsInitializer: string;
29
-
30
- public get kernel(): Matrix | undefined {
31
- return this.getParameter("kernel");
27
+ public get embeddings(): Matrix | undefined {
28
+ return this.getParameter("embeddings");
32
29
  }
33
30
 
34
31
  constructor(config: SpikingEmbeddingConfig) {
35
32
  super(config);
36
33
  this.inputDim = config.inputDim;
37
34
  this.outputDim = config.outputDim;
38
- this.beta = config.beta ?? 0.9;
39
- this.threshold = config.threshold ?? 1.0;
40
35
  this.embeddingsInitializer = config.embeddingsInitializer || "glorot_normal";
41
36
  }
42
37
 
@@ -47,8 +42,26 @@ export class SpikingEmbedding extends BaseLayer {
47
42
 
48
43
  public build(inputShape: number[]): void {
49
44
  super.build(inputShape);
50
- const kernelVal = this.createInitializer(this.embeddingsInitializer, [this.inputDim, this.outputDim]);
51
- this.addParameter("kernel", kernelVal, true, [this.inputDim, this.outputDim]);
45
+
46
+ const embVal = this.createInitializer(this.embeddingsInitializer, [this.inputDim, this.outputDim]);
47
+ this.addParameter("embeddings", embVal, true, [this.inputDim, this.outputDim]);
48
+
49
+ // Inisialisasi beta dan threshold secara acak untuk setiap neuron
50
+ this.beta = new Float32Array(this.outputDim);
51
+ this.threshold = new Float32Array(this.outputDim);
52
+ for (let i = 0; i < this.outputDim; i++) {
53
+ this.beta[i] = 0.8 + Math.random() * 0.19; // Random 0.8 - 0.99
54
+ this.threshold[i] = 0.5 + Math.random() * 0.5; // Random 0.5 - 1.0 (Max 1.0)
55
+ }
56
+
57
+ // Potentials start at 0, shape [batch, outputDim].
58
+ this.potentials = Matrix.fromFlat(new Float32Array(this.outputDim), [1, this.outputDim]);
59
+ }
60
+
61
/**
 * Ensures `potentials` is a zero-initialised [batch, outputDim] matrix whenever it is
 * missing or its batch dimension differs from `batch`. Reallocation discards any
 * previously accumulated membrane state.
 *
 * @param batch - number of rows (samples) the next forward pass will process.
 */
private ensurePotentialsShape(batch: number) {
  // `potentials` is declared with `!` and is only assigned in build() and here, so it can
  // still be undefined (e.g. compute() before build()); guard it as 1.1.0 did.
  if (!this.potentials || this.potentials._shape[0] !== batch) {
    this.potentials = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
  }
}
53
66
 
54
67
  public resetState() {
@@ -58,38 +71,26 @@ export class SpikingEmbedding extends BaseLayer {
58
71
  this.lastSpikes = undefined;
59
72
  }
60
73
 
61
- private ensurePotentialsShape(batch: number) {
62
- if (!this.potentials || this.potentials._shape[0] !== batch) {
63
- this.potentials = Matrix.fromFlat(
64
- new Float32Array(batch * this.outputDim),
65
- [batch, this.outputDim]
66
- );
67
- }
68
- }
69
-
70
74
  protected compute(inputs: Matrix, options?: ForwardOptions): Matrix {
71
- const kernel = this.kernel!._data;
72
75
  const batch = inputs._shape[0];
73
- const inputData = inputs._data;
74
-
75
76
  this.ensurePotentialsShape(batch);
76
-
77
- // 1. Lookup Row (Pengganti dot-product)
77
+
78
+ // 1. Embedding lookup
79
+ const emb = this.embeddings!;
78
80
  const dotData = new Float32Array(batch * this.outputDim);
81
+
79
82
  for (let b = 0; b < batch; b++) {
80
- const tokenId = Math.round(inputData[b]); // Asumsi input adalah ID token berukuran [batch, 1]
81
-
82
- // Jika token valid, ekstrak barisnya sebagai Arus (Current)
83
- if (tokenId >= 0 && tokenId < this.inputDim) {
84
- const kernelOffset = tokenId * this.outputDim;
85
- const dotOffset = b * this.outputDim;
86
- for (let j = 0; j < this.outputDim; j++) {
87
- dotData[dotOffset + j] = kernel[kernelOffset + j];
88
- }
89
- }
83
+ const tokenIdx = inputs._data[b];
84
+ if (tokenIdx >= 0 && tokenIdx < this.inputDim) {
85
+ const embOffset = tokenIdx * this.outputDim;
86
+ const dotOffset = b * this.outputDim;
87
+ for (let i = 0; i < this.outputDim; i++) {
88
+ dotData[dotOffset + i] = emb._data[embOffset + i];
89
+ }
90
+ }
90
91
  }
91
-
92
- // 2 & 3. Leaky Integrate, Fire & Reset
92
+
93
+ // 2. Leaky Integrate and Fire
93
94
  const outData = new Float32Array(batch * this.outputDim);
94
95
  const outSpikes = Matrix.fromFlat(outData, [batch, this.outputDim]);
95
96
  this.lastPotentials = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
@@ -98,130 +99,117 @@ export class SpikingEmbedding extends BaseLayer {
98
99
  lifStepNativeWrapper(
99
100
  this.potentials._data,
100
101
  dotData,
101
- outSpikes._data,
102
+ outData,
102
103
  this.lastPotentials._data,
103
104
  this.beta,
104
105
  this.threshold
105
106
  );
106
107
  } else {
107
108
  const potData = this.potentials._data;
108
- const thresh = this.threshold;
109
109
  const lpData = this.lastPotentials._data;
110
- for (let i = 0; i < potData.length; i++) {
111
- potData[i] = (potData[i] * this.beta) + dotData[i];
112
- lpData[i] = potData[i];
113
- }
114
- for (let i = 0; i < potData.length; i++) {
115
- if (potData[i] >= thresh) {
116
- outData[i] = 1;
117
- potData[i] -= thresh;
118
- } else {
119
- outData[i] = 0;
120
- }
110
+
111
+ for (let b = 0; b < batch; b++) {
112
+ const offset = b * this.outputDim;
113
+ for (let i = 0; i < this.outputDim; i++) {
114
+ const idx = offset + i;
115
+ potData[idx] = Math.min((potData[idx] * this.beta[i]) + dotData[idx], 1.0); // Clamp potential max 1.0
116
+ lpData[idx] = potData[idx];
117
+ }
118
+ for (let i = 0; i < this.outputDim; i++) {
119
+ const idx = offset + i;
120
+ if (potData[idx] >= this.threshold[i]) {
121
+ outData[idx] = 1;
122
+ potData[idx] -= this.threshold[i];
123
+ } else {
124
+ outData[idx] = 0;
125
+ }
126
+ }
121
127
  }
122
128
  }
123
129
 
124
- // Simpan memori untuk update bobot
125
130
  this.lastInputs = inputs;
126
131
  this.lastSpikes = outSpikes;
127
132
 
128
133
  return outSpikes;
129
134
  }
130
135
 
131
- // Embedding hanya menerima instruksi belajar dari layer atasnya (eHidden yang sudah dikalikan matriks B)
132
- public learnEmbedding(errorFromNext: Matrix, B: Matrix, learningRate: number = 0.01): Matrix {
133
- if (!this.lastInputs) {
134
- throw new Error("[SpikingEmbedding] Cannot run learnEmbedding() before forward() is executed. 'lastInputs' is undefined.");
135
- }
136
-
137
- const kernel = this.kernel!._data;
138
- const inputData = this.lastInputs._data;
139
- const batch = this.lastInputs._shape[0];
140
-
141
- // Hitung error yang mampir ke embedding
142
- // E * B (Feedback Alignment)
143
- // Gunakan matmul biasa karena B adalah float, dan errorFromNext mungkin float
144
- const eHidden = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
145
- // Namun karena OxideJS Matrix belum memiliki fungsi dot produk standar terbuka yang stabil,
146
- // kita harus hati-hati di sini. Untuk simplifikasi, eHidden = errorFromNext * B.
147
- // Kita asumsikan ada utilitas dotProduct standar dari core.
148
- // Jika B adalah matriks Dense (dimensi: outUnits x hiddenUnits), maka
149
- // eHidden [batch, hiddenUnits] = errorFromNext [batch, outUnits] dot B [outUnits, hiddenUnits]
136
+ public learnEmbedding(errorSignal: Matrix, B: Matrix, learningRate: number = 0.01): Matrix {
137
+ // Broadcast error mundur (Feedback Alignment)
138
+ let eHidden = mj.dotProduct(errorSignal, B, undefined, false, false); // E * B
150
139
 
151
- // Kita panggil dot product standar (bukan Add-Only, karena error dan B sama-sama float)
152
- let eHiddenMatrix = mj.dotProduct(errorFromNext, B, undefined, false, false);
153
-
154
140
  // Surrogate Mask: Boxcar
155
141
  if (this.lastPotentials) {
142
+ const eData = eHidden._data;
143
+ const pData = this.lastPotentials._data;
144
+ const windowSize = 1.0;
145
+
156
146
  if (isNativeAvailable()) {
157
147
  maskSurrogateNativeWrapper(
158
- eHiddenMatrix._data,
159
- this.lastPotentials._data,
160
- this.threshold,
161
- 1.0
148
+ eData,
149
+ pData,
150
+ this.threshold,
151
+ windowSize
162
152
  );
163
153
  } else {
164
- const eData = eHiddenMatrix._data;
165
- const pData = this.lastPotentials._data;
166
- const thresh = this.threshold;
167
- const windowSize = 1.0;
168
-
169
- for (let i = 0; i < eData.length; i++) {
170
- if (Math.abs(pData[i] - thresh) > windowSize) {
171
- eData[i] = 0;
154
+ const batch = eHidden._shape[0];
155
+ for (let b = 0; b < batch; b++) {
156
+ const offset = b * this.outputDim;
157
+ for (let i = 0; i < this.outputDim; i++) {
158
+ const idx = offset + i;
159
+ if (Math.abs(pData[idx] - this.threshold[i]) > windowSize) {
160
+ eData[idx] = 0;
161
+ }
172
162
  }
173
163
  }
174
164
  }
175
165
  }
176
166
 
177
- // Delta Rule Update pada baris Lookup (sangat efisien)
178
- const err = eHiddenMatrix._data;
179
- for (let b = 0; b < batch; b++) {
180
- const tokenId = Math.round(inputData[b]);
181
- if (tokenId >= 0 && tokenId < this.inputDim) {
182
- const kOffset = tokenId * this.outputDim;
183
- const errOffset = b * this.outputDim;
184
- for (let j = 0; j < this.outputDim; j++) {
185
- kernel[kOffset + j] += learningRate * err[errOffset + j];
186
- }
187
- }
188
- }
189
-
190
- return eHiddenMatrix;
167
+ this.applyEmbeddingDelta(eHidden, learningRate);
168
+ return eHidden;
191
169
  }
192
-
193
- /**
194
- * Word2Vec CBOW-style Hebbian Contrastive Learning
195
- * Memungkinkan pembelajaran embedding semantik secara topologis tanpa representation collapse.
196
- */
197
- public learnHebbian(
198
- tokens: number[] | Float32Array,
199
- positiveContext: Float32Array,
200
- negativeContexts: Float32Array[],
201
- learningRate: number = 0.01,
202
- marginPositive: number = 0.1,
203
- marginNegative: number = 0.05
204
- ): void {
205
- const kernel = this.kernel!._data;
206
- const dim = this.outputDim;
207
-
208
- for (let n = 0; n < negativeContexts.length; n++) {
209
- const negMean = negativeContexts[n];
210
- for (let i = 0; i < tokens.length; i++) {
211
- const tokenId = Math.round(tokens[i]);
212
- if (tokenId >= 0 && tokenId < this.inputDim) {
213
- const offset = tokenId * dim;
214
- for (let j = 0; j < dim; j++) {
215
- // Tarik kata ke arah konteks kalimatnya (Positive) - hanya sekali per token
216
- const posGradient = (n === 0) ? (positiveContext[j] - kernel[offset + j]) : 0;
217
- // Tolak kata dari konteks kalimat acak (Negative)
218
- const negGradient = kernel[offset + j] - negMean[j];
219
-
220
- const update = (posGradient * marginPositive) - (negGradient * marginNegative);
221
- kernel[offset + j] += learningRate * update;
170
+
171
/**
 * Applies an error signal to the embedding rows of the tokens seen in the last forward
 * pass: each selected weight becomes clamp(weight + learningRate * error, -1, 1).
 *
 * Token ids are read from the first `batch` entries of `lastInputs` (one id per batch row)
 * and rounded to the nearest integer; ids outside [0, inputDim) (including NaN) are skipped.
 * When the native backend is available, the raw buffers are handed to it unchanged.
 *
 * @param errorSignal - error per batch row, indexed as [batch, outputDim] row-major.
 * @param learningRate - step size applied to the error.
 * @throws Error if forward() has not populated lastInputs/lastSpikes yet.
 */
private applyEmbeddingDelta(errorSignal: Matrix, learningRate: number) {
  if (!this.lastInputs || !this.lastSpikes) {
    throw new Error("[SpikingEmbedding] Cannot run learning before forward() is executed.");
  }

  const embeddings = this.embeddings!._data;
  const inputs = this.lastInputs._data;
  const err = errorSignal._data;

  const batch = this.lastInputs._shape[0];
  const outputDim = this.outputDim;

  if (isNativeAvailable()) {
    applyEmbeddingDeltaNativeWrapper(
      embeddings,
      inputs,
      err,
      learningRate,
      this.inputDim,
      outputDim
    );
    return;
  }

  for (let b = 0; b < batch; b++) {
    // Ids arrive as floats; without rounding a value like 2.9999 yields a fractional
    // typed-array index, which reads undefined and silently drops the update.
    const tokenIdx = Math.round(inputs[b]);
    // Written as a negated range test so NaN ids are skipped too.
    if (!(tokenIdx >= 0 && tokenIdx < this.inputDim)) {
      continue;
    }

    const embOffset = tokenIdx * outputDim;
    const errOffset = b * outputDim;
    for (let j = 0; j < outputDim; j++) {
      const updated = embeddings[embOffset + j] + learningRate * err[errOffset + j];
      // Keep weights bounded to [-1, 1].
      embeddings[embOffset + j] = Math.max(-1.0, Math.min(1.0, updated));
    }
  }
}
206
+
207
/**
 * Returns the serialisable layer configuration: the base layer's config extended with
 * this layer's constructor options (inputDim, outputDim, embeddingsInitializer).
 *
 * NOTE(review): the per-neuron `beta` / `threshold` arrays are randomised in build() and
 * are neither part of this config nor registered via addParameter, so a layer rebuilt
 * from this config gets fresh random values. Confirm that this is intended.
 */
public getConfig(): Record<string, any> {
  const base = super.getConfig();
  const { inputDim, outputDim, embeddingsInitializer } = this;
  return { ...base, inputDim, outputDim, embeddingsInitializer };
}
227
215
  }
@@ -47,8 +47,8 @@ export const lifStepNativeWrapper = (
47
47
  dot: Float32Array,
48
48
  spikes: Float32Array,
49
49
  lastPotentials: Float32Array,
50
- beta: number,
51
- threshold: number
50
+ beta: Float32Array,
51
+ threshold: Float32Array
52
52
  ): void => {
53
53
  if (!native) throw new Error("Spiking Native backend not available");
54
54
  native.lifStepNative(potentials, dot, spikes, lastPotentials, beta, threshold);
@@ -57,7 +57,7 @@ export const lifStepNativeWrapper = (
57
57
  export const maskSurrogateNativeWrapper = (
58
58
  errorSignal: Float32Array,
59
59
  potentials: Float32Array,
60
- threshold: number,
60
+ threshold: Float32Array,
61
61
  windowSize: number
62
62
  ): void => {
63
63
  if (!native) throw new Error("Spiking Native backend not available");
@@ -88,3 +88,22 @@ export const applyAddOnlyDeltaNativeWrapper = (
88
88
  useBias
89
89
  );
90
90
  };
91
+
92
/**
 * Forwards an embedding-row update to the native backend's `applyEmbeddingDeltaNative`,
 * passing every argument through unchanged.
 *
 * @throws Error when the native backend has not been loaded.
 */
export const applyEmbeddingDeltaNativeWrapper = (
  embeddings: Float32Array,
  inputs: Float32Array,
  errorSignal: Float32Array,
  learningRate: number,
  inputDim: number,
  outputDim: number
): void => {
  if (native) {
    native.applyEmbeddingDeltaNative(embeddings, inputs, errorSignal, learningRate, inputDim, outputDim);
    return;
  }
  throw new Error("Spiking Native backend not available");
};
@@ -0,0 +1,51 @@
1
+ use napi_derive::napi;
2
+ use napi::bindgen_prelude::Float32Array;
3
+ use rayon::prelude::*;
4
+
5
+ #[napi]
6
+ pub fn apply_add_only_delta_native(
7
+ mut kernel: Float32Array,
8
+ mut bias: Float32Array,
9
+ inputs: Float32Array,
10
+ error_signal: Float32Array,
11
+ learning_rate: f64,
12
+ batch: u32,
13
+ in_features: u32,
14
+ units: u32,
15
+ use_bias: bool
16
+ ) {
17
+ let in_feat = in_features as usize;
18
+ let u = units as usize;
19
+ let b_size = batch as usize;
20
+ let lr = learning_rate as f32;
21
+
22
+ let kernel_slice: &mut [f32] = &mut kernel;
23
+ let bias_slice: &mut [f32] = &mut bias;
24
+ let in_slice: &[f32] = &inputs;
25
+ let err_slice: &[f32] = &error_signal;
26
+
27
+ kernel_slice.par_chunks_mut(u).enumerate().for_each(|(k, kernel_row)| {
28
+ for b in 0..b_size {
29
+ if in_slice[b * in_feat + k] > 0.5 {
30
+ let err_offset = b * u;
31
+ for j in 0..u {
32
+ kernel_row[j] += lr * err_slice[err_offset + j];
33
+ }
34
+ }
35
+ }
36
+ for j in 0..u {
37
+ kernel_row[j] = kernel_row[j].clamp(-1.0, 1.0);
38
+ }
39
+ });
40
+
41
+ if use_bias {
42
+ bias_slice.par_iter_mut().enumerate().for_each(|(j, b_val)| {
43
+ let mut sum = 0.0;
44
+ for b in 0..b_size {
45
+ sum += err_slice[b * u + j];
46
+ }
47
+ *b_val += (lr * sum) / (b_size as f32);
48
+ *b_val = b_val.clamp(-1.0, 1.0);
49
+ });
50
+ }
51
+ }
@@ -0,0 +1,47 @@
1
+ use napi_derive::napi;
2
+ use napi::bindgen_prelude::Float32Array;
3
+ use rayon::prelude::*;
4
+
5
+ #[napi]
6
+ pub fn dot_product_add_only_native(
7
+ a_data: Float32Array,
8
+ a_rows_orig: u32,
9
+ a_cols_orig: u32,
10
+ b_data: Float32Array,
11
+ b_rows_orig: u32,
12
+ b_cols_orig: u32,
13
+ trans_a: bool,
14
+ trans_b: bool,
15
+ mut out_data: Float32Array
16
+ ) {
17
+ let a_rows = if trans_a { a_cols_orig } else { a_rows_orig } as usize;
18
+ let a_cols = if trans_a { a_rows_orig } else { a_cols_orig } as usize;
19
+ let b_cols = if trans_b { b_rows_orig } else { b_cols_orig } as usize;
20
+
21
+ let a_slice: &[f32] = &a_data;
22
+ let b_slice: &[f32] = &b_data;
23
+ let out_slice: &mut [f32] = &mut out_data;
24
+
25
+ out_slice.par_chunks_mut(b_cols).enumerate().for_each(|(i, out_row)| {
26
+ let a_offset = i * a_cols;
27
+ for k in 0..a_cols {
28
+ let a_val = if trans_a {
29
+ a_slice[k * a_rows + i]
30
+ } else {
31
+ a_slice[a_offset + k]
32
+ };
33
+
34
+ if a_val > 0.5 {
35
+ let b_offset = k * b_cols;
36
+ for j in 0..b_cols {
37
+ let b_val = if trans_b {
38
+ b_slice[j * a_cols + k]
39
+ } else {
40
+ b_slice[b_offset + j]
41
+ };
42
+ out_row[j] += b_val;
43
+ }
44
+ }
45
+ }
46
+ });
47
+ }
@@ -0,0 +1,28 @@
1
+ use napi_derive::napi;
2
+ use napi::bindgen_prelude::Float32Array;
3
+
4
+ #[napi]
5
+ pub fn apply_embedding_delta_native(
6
+ mut embeddings: Float32Array,
7
+ inputs: Float32Array,
8
+ error_signal: Float32Array,
9
+ learning_rate: f64,
10
+ input_dim: u32,
11
+ output_dim: u32
12
+ ) {
13
+ let batch = inputs.len();
14
+ let out_dim = output_dim as usize;
15
+
16
+ for b in 0..batch {
17
+ let token_idx = inputs[b] as i32;
18
+ if token_idx >= 0 && token_idx < input_dim as i32 {
19
+ let token_idx = token_idx as usize;
20
+ let emb_offset = token_idx * out_dim;
21
+ let err_offset = b * out_dim;
22
+ for j in 0..out_dim {
23
+ embeddings[emb_offset + j] += learning_rate as f32 * error_signal[err_offset + j];
24
+ embeddings[emb_offset + j] = embeddings[emb_offset + j].clamp(-1.0, 1.0);
25
+ }
26
+ }
27
+ }
28
+ }