@oxide-js/spiking 1.1.0 → 1.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,335 @@
1
+ import { BaseLayer, LayerConfig, ForwardOptions } from "@oxide-js/layers";
2
+ import { Matrix, mj } from "@oxide-js/core";
3
+ import { isNativeAvailable, lifStepNativeWrapper } from "../native_backend.js";
4
+ import dotProductAddOnly from "../math/dotProductAddOnly.js";
5
+
6
/**
 * Options accepted by the SpikingSelfAttention constructor.
 *
 * Only `d_model` and `sequenceLength` are required; the optional entries fall
 * back to the defaults applied in the layer constructor.
 */
export interface SpikingSelfAttentionConfig extends LayerConfig {
  /** Width of the Q/K/V projections and of the layer output. */
  d_model: number;
  /** Tokens per sequence; the number of input rows must be a multiple of this. */
  sequenceLength: number;
  /** [min, max] sampling range for the per-feature Q/K/V LIF thresholds (default [0.1, 0.3]). */
  thresholdRange?: [number, number];
  /** [min, max] sampling range for the per-feature Q/K/V LIF leak factors (default [0.8, 0.99]). */
  betaRange?: [number, number];
  /** Initializer name used for the Q/K/V kernels (default "glorot_normal"). */
  kernelInitializer?: string;
}
13
+
14
/**
 * Spiking self-attention layer.
 *
 * Inputs are flattened as [batch * sequenceLength, features]. Q, K and V are
 * projected with `dotProductAddOnly`, converted to binary spike trains by
 * leaky integrate-and-fire (LIF) neurons, and attention weights come from
 * counting coincident Q/K spikes (normalized per query row) instead of a softmax.
 *
 * LIF membrane potentials persist across forward calls until `resetState()`
 * is called or the number of input rows changes.
 */
export class SpikingSelfAttention extends BaseLayer {
  public d_model: number;
  public sequenceLength: number;
  public kernelInitializer: string;
  public betaRange: [number, number];
  public thresholdRange: [number, number];

  // Q, K, V projection kernels, each created in build() with shape [inFeatures, d_model].
  public get kernelQ(): Matrix | undefined { return this.getParameter("kernelQ"); }
  public get kernelK(): Matrix | undefined { return this.getParameter("kernelK"); }
  public get kernelV(): Matrix | undefined { return this.getParameter("kernelV"); }

  // Per-feature LIF leak factors / thresholds shared by the Q, K and V neurons, plus
  // their membrane potentials ([batch * seqLen, d_model], sized lazily in compute()).
  public betaQKV!: Float32Array;
  public thresholdQKV!: Float32Array;
  public potentialsQ!: Matrix;
  public potentialsK!: Matrix;
  public potentialsV!: Matrix;

  // LIF parameters and potentials applied to the attention match scores (softmax replacement).
  public betaScores!: Float32Array;
  public thresholdScores!: Float32Array;
  public potentialsScores!: Matrix;

  // Input of the most recent forward pass, consumed by learnAttention().
  public lastInputs?: Matrix;

  // Scratch buffers reused across forward passes; (re)allocated by ensurePotentialsShape().
  private sqDataBuffer?: Float32Array;
  private skDataBuffer?: Float32Array;
  private svDataBuffer?: Float32Array;
  private dummyLpBuffer?: Float32Array;
  private matchScoresBuffer?: Float32Array;
  private qGatedVBuffer?: Float32Array;
  private outSpikesBuffer?: Float32Array;
  private sScoresDataBuffer?: Float32Array;
  private dummyLpScoresBuffer?: Float32Array;
  private tempMatchesBuffer?: Float32Array;

  constructor(config: SpikingSelfAttentionConfig) {
    super(config);
    this.d_model = config.d_model;
    this.sequenceLength = config.sequenceLength;
    this.kernelInitializer = config.kernelInitializer || "glorot_normal";
    this.betaRange = config.betaRange || [0.8, 0.99];
    this.thresholdRange = config.thresholdRange || [0.1, 0.3];
  }

  /** Output keeps the input row count ([batch * seqLen]) and has d_model columns. */
  public computeOutputShape(inputShape: number[]): number[] {
    const rows = inputShape[0] ?? -1;
    return [rows, this.d_model];
  }

  /**
   * Creates the Q/K/V kernels and samples the LIF parameters.
   * @param inputShape shape whose last entry is the input feature width.
   */
  public build(inputShape: number[]): void {
    super.build(inputShape);

    // Expected to equal d_model when stacked on another d_model-wide layer.
    const inFeatures = inputShape[inputShape.length - 1];

    // 1. Q, K, V projection kernels.
    this.addParameter("kernelQ", this.createInitializer(this.kernelInitializer, [inFeatures, this.d_model]), true, [inFeatures, this.d_model]);
    this.addParameter("kernelK", this.createInitializer(this.kernelInitializer, [inFeatures, this.d_model]), true, [inFeatures, this.d_model]);
    this.addParameter("kernelV", this.createInitializer(this.kernelInitializer, [inFeatures, this.d_model]), true, [inFeatures, this.d_model]);

    // Scale initial weights up by sqrt(inFeatures) so the LIF neurons actually
    // reach threshold (guards against silent deeper layers).
    const scale = Math.sqrt(inFeatures);
    const kQ = this.kernelQ!._data;
    const kK = this.kernelK!._data;
    const kV = this.kernelV!._data;
    for (let i = 0; i < kQ.length; i++) {
      kQ[i] *= scale;
      kK[i] *= scale;
      kV[i] *= scale;
    }

    // 2. Per-feature LIF parameters for Q/K/V, sampled uniformly from the configured ranges.
    this.betaQKV = new Float32Array(this.d_model);
    this.thresholdQKV = new Float32Array(this.d_model);
    for (let i = 0; i < this.d_model; i++) {
      this.betaQKV[i] = this.betaRange[0] + Math.random() * (this.betaRange[1] - this.betaRange[0]);
      this.thresholdQKV[i] = this.thresholdRange[0] + Math.random() * (this.thresholdRange[1] - this.thresholdRange[0]);
    }

    // 3. Fixed LIF parameters for the attention scores (one neuron per key position).
    // NOTE(review): an earlier comment claimed this threshold was "lowered sharply" to
    // avoid dead neurons, but 1.0 equals the maximum normalized match score — confirm.
    this.betaScores = new Float32Array(this.sequenceLength).fill(0.9);
    this.thresholdScores = new Float32Array(this.sequenceLength).fill(1.0);

    // Potentials are sized on the first forward pass (ensurePotentialsShape).
    this.potentialsQ = Matrix.fromFlat(new Float32Array(0), [0, 0]);
    this.potentialsK = Matrix.fromFlat(new Float32Array(0), [0, 0]);
    this.potentialsV = Matrix.fromFlat(new Float32Array(0), [0, 0]);
    this.potentialsScores = Matrix.fromFlat(new Float32Array(0), [0, 0]);
  }

  /**
   * (Re)allocates potentials and scratch buffers when the input row count changes
   * or on first use. Reallocation discards any accumulated membrane potential.
   */
  private ensurePotentialsShape(batchSeq: number, seqLen: number) {
    if (this.potentialsQ._shape[0] === batchSeq && this.sqDataBuffer) return;

    const qkvSize = batchSeq * this.d_model;
    const scoreSize = batchSeq * seqLen;

    this.potentialsQ = Matrix.fromFlat(new Float32Array(qkvSize), [batchSeq, this.d_model]);
    this.potentialsK = Matrix.fromFlat(new Float32Array(qkvSize), [batchSeq, this.d_model]);
    this.potentialsV = Matrix.fromFlat(new Float32Array(qkvSize), [batchSeq, this.d_model]);
    this.potentialsScores = Matrix.fromFlat(new Float32Array(scoreSize), [batchSeq, seqLen]);

    this.sqDataBuffer = new Float32Array(qkvSize);
    this.skDataBuffer = new Float32Array(qkvSize);
    this.svDataBuffer = new Float32Array(qkvSize);
    this.dummyLpBuffer = new Float32Array(qkvSize);
    this.matchScoresBuffer = new Float32Array(scoreSize);
    this.qGatedVBuffer = new Float32Array(qkvSize);
    this.outSpikesBuffer = new Float32Array(qkvSize);
    this.sScoresDataBuffer = new Float32Array(scoreSize);
    this.dummyLpScoresBuffer = new Float32Array(scoreSize);
    this.tempMatchesBuffer = new Float32Array(seqLen);
  }

  /** Zeroes every LIF membrane potential (start of a new, independent sequence). */
  public resetState() {
    if (this.potentialsQ) this.potentialsQ._data.fill(0);
    if (this.potentialsK) this.potentialsK._data.fill(0);
    if (this.potentialsV) this.potentialsV._data.fill(0);
    if (this.potentialsScores) this.potentialsScores._data.fill(0);
  }

  /**
   * Forward pass.
   * @param inputs flattened [batch * sequenceLength, features] matrix.
   * @returns a fresh [batch * sequenceLength, d_model] matrix with values in [0, 1].
   * @throws Error when the row count is not a multiple of sequenceLength.
   */
  protected compute(inputs: Matrix, options?: ForwardOptions): Matrix {
    const batchSeq = inputs._shape[0];
    const seqLen = this.sequenceLength;
    const batch = batchSeq / seqLen;
    const d_model = this.d_model;

    if (!Number.isInteger(batch)) {
      throw new Error(`[SpikingSelfAttention] Jumlah baris input (${batchSeq}) harus merupakan kelipatan dari sequenceLength (${seqLen}).`);
    }

    this.ensurePotentialsShape(batchSeq, seqLen);
    this.lastInputs = inputs; // kept for learnAttention()

    // 1. Q/K/V projections (add-only kernel; presumably expects binary spike inputs).
    const dotQ = dotProductAddOnly(inputs, this.kernelQ!);
    const dotK = dotProductAddOnly(inputs, this.kernelK!);
    const dotV = dotProductAddOnly(inputs, this.kernelV!);

    // 2. LIF step -> spike matrices S_Q, S_K, S_V.
    const sqData = this.sqDataBuffer!.fill(0);
    const skData = this.skDataBuffer!.fill(0);
    const svData = this.svDataBuffer!.fill(0);
    const dummyLp = this.dummyLpBuffer!.fill(0);

    if (isNativeAvailable()) {
      lifStepNativeWrapper(this.potentialsQ._data, dotQ._data, sqData, dummyLp, this.betaQKV, this.thresholdQKV);
      lifStepNativeWrapper(this.potentialsK._data, dotK._data, skData, dummyLp, this.betaQKV, this.thresholdQKV);
      lifStepNativeWrapper(this.potentialsV._data, dotV._data, svData, dummyLp, this.betaQKV, this.thresholdQKV);
    } else {
      this.runLIF(this.potentialsQ._data, dotQ._data, sqData, batchSeq, d_model, this.betaQKV, this.thresholdQKV);
      this.runLIF(this.potentialsK._data, dotK._data, skData, batchSeq, d_model, this.betaQKV, this.thresholdQKV);
      this.runLIF(this.potentialsV._data, dotV._data, svData, batchSeq, d_model, this.betaQKV, this.thresholdQKV);
    }

    // 3. Match scores [batch * seqLen, seqLen]: for each query row, count key rows that
    //    spike on the same features, then divide by the row maximum (best key -> 1.0).
    const matchScores = this.matchScoresBuffer!.fill(0);
    const tempMatches = this.tempMatchesBuffer!;
    const activeDims: number[] = []; // scratch: indices of spiking features in the current row

    for (let b = 0; b < batch; b++) {
      const blockBase = b * seqLen * d_model;
      const scoreBlockBase = b * seqLen * seqLen;
      for (let i = 0; i < seqLen; i++) {
        const qBase = blockBase + i * d_model;
        activeDims.length = 0;
        for (let d = 0; d < d_model; d++) {
          if (sqData[qBase + d] > 0) activeDims.push(d);
        }
        if (activeDims.length === 0) continue;

        let maxMatch = 0;
        for (let j = 0; j < seqLen; j++) {
          const kBase = blockBase + j * d_model;
          let matchCount = 0;
          for (let k = 0; k < activeDims.length; k++) {
            if (skData[kBase + activeDims[k]] > 0) matchCount++;
          }
          tempMatches[j] = matchCount;
          if (matchCount > maxMatch) maxMatch = matchCount;
        }

        // Rows without any match keep the zeros written by fill(0) above.
        if (maxMatch > 0) {
          const rowBase = scoreBlockBase + i * seqLen;
          for (let j = 0; j < seqLen; j++) {
            matchScores[rowBase + j] = tempMatches[j] / maxMatch;
          }
        }
      }
    }

    // 4. Softmax replacement: feed the match scores through a LIF layer.
    // NOTE(review): sScoresData (the score spikes) is never read below — the output uses the
    // graded matchScores, so this step only advances potentialsScores. Confirm intent.
    const sScoresData = this.sScoresDataBuffer!.fill(0);
    const dummyLpScores = this.dummyLpScoresBuffer!.fill(0);

    if (isNativeAvailable()) {
      lifStepNativeWrapper(this.potentialsScores._data, matchScores, sScoresData, dummyLpScores, this.betaScores, this.thresholdScores);
    } else {
      this.runLIF(this.potentialsScores._data, matchScores, sScoresData, batchSeq, seqLen, this.betaScores, this.thresholdScores);
    }

    // 5. out[i, :] = sum_j matchScores[i, j] * S_V[j, :], skipping silent V rows.
    const outData = this.outSpikesBuffer!.fill(0);
    for (let b = 0; b < batch; b++) {
      const blockBase = b * seqLen * d_model;
      const scoreBlockBase = b * seqLen * seqLen;
      for (let j = 0; j < seqLen; j++) {
        const vBase = blockBase + j * d_model;
        activeDims.length = 0;
        for (let d = 0; d < d_model; d++) {
          if (svData[vBase + d] > 0) activeDims.push(d);
        }
        if (activeDims.length === 0) continue;

        for (let i = 0; i < seqLen; i++) {
          const gradedScore = matchScores[scoreBlockBase + i * seqLen + j];
          if (gradedScore > 0) {
            const outBase = blockBase + i * d_model;
            for (let k = 0; k < activeDims.length; k++) {
              const d = activeDims[k];
              outData[outBase + d] += gradedScore * svData[vBase + d];
            }
          }
        }
      }
    }

    // Cap at 1.0. Values stay graded in [0, 1]; this is a saturation, not binarization.
    for (let i = 0; i < outData.length; i++) {
      if (outData[i] > 1.0) outData[i] = 1.0;
    }

    // Return a copy: outSpikesBuffer is zero-filled and overwritten on the next forward
    // pass, which would otherwise silently mutate every previously returned output
    // (assumes Matrix.fromFlat does not copy — TODO confirm).
    return Matrix.fromFlat(outData.slice(), [batchSeq, d_model]);
  }

  /**
   * JS fallback LIF step over a [batch, dim] buffer, with per-column beta/threshold:
   * pot = min(pot * beta + input, 1); emit 1 where pot >= threshold and subtract the
   * threshold from pot (reset by subtraction), else emit 0.
   */
  private runLIF(pot: Float32Array, input: Float32Array, output: Float32Array, batch: number, dim: number, beta: Float32Array, threshold: Float32Array) {
    for (let b = 0; b < batch; b++) {
      const offset = b * dim;
      for (let i = 0; i < dim; i++) {
        const idx = offset + i;
        pot[idx] = Math.min((pot[idx] * beta[i]) + input[idx], 1.0);
      }
      for (let i = 0; i < dim; i++) {
        const idx = offset + i;
        if (pot[idx] >= threshold[i]) {
          output[idx] = 1.0;
          pot[idx] -= threshold[i];
        } else {
          output[idx] = 0.0;
        }
      }
    }
  }

  /**
   * Local (non-backprop) update of the Q, K and V kernels.
   *
   * For every positive input entry, `learningRate * err * input` plus a small constant
   * drive is added identically to the matching row of all three kernels, and the
   * result is clamped to [-1, 1].
   *
   * @param errorSignal error for the last forward pass, laid out as [batchSeq, d_model].
   * @param learningRate step size (default 0.01).
   * @throws Error if forward() has not run yet or errorSignal is too small.
   */
  public learnAttention(errorSignal: Matrix, learningRate: number = 0.01) {
    if (!this.lastInputs) {
      throw new Error("[SpikingSelfAttention] Cannot run learning before forward() is executed.");
    }

    const err = errorSignal._data;
    const inputs = this.lastInputs._data;
    const batchSeq = this.lastInputs._shape[0];
    const inFeatures = this.lastInputs._shape[1] || this.d_model;
    const d_model = this.d_model;

    // A short error signal would read `undefined` -> NaN, which survives the clamp
    // below and permanently corrupts the kernels.
    const required = batchSeq * d_model;
    if (err.length < required) {
      throw new Error(`[SpikingSelfAttention] errorSignal has ${err.length} values but ${required} ([${batchSeq}, ${d_model}]) are required.`);
    }

    const kQ = this.kernelQ!._data;
    const kK = this.kernelK!._data;
    const kV = this.kernelV!._data;

    // Tiny constant drive that nudges silent neurons upward without saturating them.
    const dopamineDrive = 0.00005;

    for (let b = 0; b < batchSeq; b++) {
      const errOffset = b * d_model;
      const inOffset = b * inFeatures;
      for (let i = 0; i < inFeatures; i++) {
        const inVal = inputs[inOffset + i];
        if (inVal > 0) { // sparse update: only rows of active inputs change
          const kOffset = i * d_model;
          for (let d = 0; d < d_model; d++) {
            const delta = (learningRate * err[errOffset + d] * inVal) + dopamineDrive;
            const idx = kOffset + d;
            kQ[idx] = Math.max(-1.0, Math.min(1.0, kQ[idx] + delta));
            kK[idx] = Math.max(-1.0, Math.min(1.0, kK[idx] + delta));
            kV[idx] = Math.max(-1.0, Math.min(1.0, kV[idx] + delta));
          }
        }
      }
    }
  }

  /** Serializable configuration; includes the LIF sampling ranges so a rebuilt layer matches. */
  public getConfig(): Record<string, any> {
    return {
      ...super.getConfig(),
      d_model: this.d_model,
      sequenceLength: this.sequenceLength,
      kernelInitializer: this.kernelInitializer,
      betaRange: [this.betaRange[0], this.betaRange[1]],
      thresholdRange: [this.thresholdRange[0], this.thresholdRange[1]],
    };
  }
}
@@ -47,8 +47,8 @@ export const lifStepNativeWrapper = (
47
47
  dot: Float32Array,
48
48
  spikes: Float32Array,
49
49
  lastPotentials: Float32Array,
50
- beta: number,
51
- threshold: number
50
+ beta: Float32Array,
51
+ threshold: Float32Array
52
52
  ): void => {
53
53
  if (!native) throw new Error("Spiking Native backend not available");
54
54
  native.lifStepNative(potentials, dot, spikes, lastPotentials, beta, threshold);
@@ -57,7 +57,7 @@ export const lifStepNativeWrapper = (
57
57
  export const maskSurrogateNativeWrapper = (
58
58
  errorSignal: Float32Array,
59
59
  potentials: Float32Array,
60
- threshold: number,
60
+ threshold: Float32Array,
61
61
  windowSize: number
62
62
  ): void => {
63
63
  if (!native) throw new Error("Spiking Native backend not available");
@@ -88,3 +88,39 @@ export const applyAddOnlyDeltaNativeWrapper = (
88
88
  useBias
89
89
  );
90
90
  };
91
+
92
/**
 * Forwards an embedding-table update to the native `applyEmbeddingDeltaNative` binding.
 * @throws Error when the native backend is not loaded.
 */
export const applyEmbeddingDeltaNativeWrapper = (
  embeddings: Float32Array,
  inputs: Float32Array,
  errorSignal: Float32Array,
  learningRate: number,
  inputDim: number,
  outputDim: number
): void => {
  const backend = native;
  if (!backend) {
    throw new Error("Spiking Native backend not available");
  }
  backend.applyEmbeddingDeltaNative(embeddings, inputs, errorSignal, learningRate, inputDim, outputDim);
};
110
+
111
/**
 * Forwards to the native `contrastiveHebbianNative` binding and returns the loss
 * value it reports (the binding writes its error signal into `errData`).
 * @throws Error when the native backend is not loaded.
 */
export const contrastiveHebbianNativeWrapper = (
  spikes: Float32Array,
  errData: Float32Array,
  numPairs: number,
  sequenceLength: number,
  dModel: number
): number => {
  const backend = native;
  if (!backend) {
    throw new Error("Spiking Native backend not available");
  }
  const loss: number = backend.contrastiveHebbianNative(spikes, errData, numPairs, sequenceLength, dModel);
  return loss;
};
@@ -0,0 +1,85 @@
1
+ use napi_derive::napi;
2
+ use napi::bindgen_prelude::Float32Array;
3
+ use rayon::prelude::*;
4
+
5
+ #[napi]
6
+ pub fn contrastive_hebbian_native(
7
+ spikes: Float32Array,
8
+ mut err_data: Float32Array,
9
+ num_pairs: u32,
10
+ sequence_length: u32,
11
+ d_model: u32,
12
+ ) -> f64 {
13
+ let spikes_slice: &[f32] = &spikes;
14
+ let err_slice: &mut [f32] = &mut err_data;
15
+
16
+ let num_pairs = num_pairs as usize;
17
+ let seq_len = sequence_length as usize;
18
+ let d_model = d_model as usize;
19
+ let chunk_size = seq_len * d_model;
20
+
21
+ let total_loss: f32 = err_slice.par_chunks_mut(chunk_size).enumerate().map(|(b, chunk)| {
22
+ let mut local_loss = 0.0f32;
23
+
24
+ if b < num_pairs {
25
+ // Ini adalah vektor Q
26
+ let i = b;
27
+ let p_offset = (num_pairs + i) * chunk_size;
28
+ let n_offset = (num_pairs + ((i + 1) % num_pairs)) * chunk_size;
29
+ let q_offset = i * chunk_size;
30
+
31
+ for rem in 0..chunk_size {
32
+ let q_spike = spikes_slice[q_offset + rem];
33
+ let p_spike = spikes_slice[p_offset + rem];
34
+ let n_spike = spikes_slice[n_offset + rem];
35
+
36
+ let mut pull = p_spike - q_spike;
37
+ if q_spike == 0.0 && p_spike == 0.0 && n_spike == 0.0 {
38
+ pull = 0.05; // Suntik energi
39
+ }
40
+ let push = (q_spike * n_spike) * 0.2;
41
+
42
+ chunk[rem] = pull - push;
43
+
44
+ if pull != 0.0 || push != 0.0 {
45
+ local_loss += pull.abs() + push;
46
+ }
47
+ }
48
+ } else {
49
+ // Ini adalah vektor P atau N
50
+ let p_index = b - num_pairs;
51
+
52
+ // Peran sebagai P untuk i = p_index
53
+ let q_offset_p = p_index * chunk_size;
54
+ let n_offset_p = (num_pairs + ((p_index + 1) % num_pairs)) * chunk_size;
55
+
56
+ // Peran sebagai N untuk i = p_index - 1 (dengan wrap around)
57
+ let i_n = if p_index == 0 { num_pairs - 1 } else { p_index - 1 };
58
+ let q_offset_n = i_n * chunk_size;
59
+
60
+ for rem in 0..chunk_size {
61
+ let q_spike_p = spikes_slice[q_offset_p + rem];
62
+ let p_spike_p = spikes_slice[b * chunk_size + rem];
63
+ let n_spike_p = spikes_slice[n_offset_p + rem];
64
+
65
+ let mut pull_p = p_spike_p - q_spike_p;
66
+ if q_spike_p == 0.0 && p_spike_p == 0.0 && n_spike_p == 0.0 {
67
+ pull_p = 0.05;
68
+ }
69
+ let contrib_p = -pull_p;
70
+
71
+ let q_spike_n = spikes_slice[q_offset_n + rem];
72
+ let n_spike_n = spikes_slice[b * chunk_size + rem];
73
+
74
+ let push_n = (q_spike_n * n_spike_n) * 0.2;
75
+ let contrib_n = -push_n;
76
+
77
+ chunk[rem] = contrib_p + contrib_n;
78
+ }
79
+ }
80
+
81
+ local_loss
82
+ }).sum();
83
+
84
+ total_loss as f64
85
+ }
@@ -0,0 +1,51 @@
1
+ use napi_derive::napi;
2
+ use napi::bindgen_prelude::Float32Array;
3
+ use rayon::prelude::*;
4
+
5
+ #[napi]
6
+ pub fn apply_add_only_delta_native(
7
+ mut kernel: Float32Array,
8
+ mut bias: Float32Array,
9
+ inputs: Float32Array,
10
+ error_signal: Float32Array,
11
+ learning_rate: f64,
12
+ batch: u32,
13
+ in_features: u32,
14
+ units: u32,
15
+ use_bias: bool
16
+ ) {
17
+ let in_feat = in_features as usize;
18
+ let u = units as usize;
19
+ let b_size = batch as usize;
20
+ let lr = learning_rate as f32;
21
+
22
+ let kernel_slice: &mut [f32] = &mut kernel;
23
+ let bias_slice: &mut [f32] = &mut bias;
24
+ let in_slice: &[f32] = &inputs;
25
+ let err_slice: &[f32] = &error_signal;
26
+
27
+ kernel_slice.par_chunks_mut(u).enumerate().for_each(|(k, kernel_row)| {
28
+ for b in 0..b_size {
29
+ if in_slice[b * in_feat + k] > 0.5 {
30
+ let err_offset = b * u;
31
+ for j in 0..u {
32
+ kernel_row[j] += lr * err_slice[err_offset + j];
33
+ }
34
+ }
35
+ }
36
+ for j in 0..u {
37
+ kernel_row[j] = kernel_row[j].clamp(-1.0, 1.0);
38
+ }
39
+ });
40
+
41
+ if use_bias {
42
+ bias_slice.par_iter_mut().enumerate().for_each(|(j, b_val)| {
43
+ let mut sum = 0.0;
44
+ for b in 0..b_size {
45
+ sum += err_slice[b * u + j];
46
+ }
47
+ *b_val += (lr * sum) / (b_size as f32);
48
+ *b_val = b_val.clamp(-1.0, 1.0);
49
+ });
50
+ }
51
+ }
@@ -0,0 +1,47 @@
1
+ use napi_derive::napi;
2
+ use napi::bindgen_prelude::Float32Array;
3
+ use rayon::prelude::*;
4
+
5
+ #[napi]
6
+ pub fn dot_product_add_only_native(
7
+ a_data: Float32Array,
8
+ a_rows_orig: u32,
9
+ a_cols_orig: u32,
10
+ b_data: Float32Array,
11
+ b_rows_orig: u32,
12
+ b_cols_orig: u32,
13
+ trans_a: bool,
14
+ trans_b: bool,
15
+ mut out_data: Float32Array
16
+ ) {
17
+ let a_rows = if trans_a { a_cols_orig } else { a_rows_orig } as usize;
18
+ let a_cols = if trans_a { a_rows_orig } else { a_cols_orig } as usize;
19
+ let b_cols = if trans_b { b_rows_orig } else { b_cols_orig } as usize;
20
+
21
+ let a_slice: &[f32] = &a_data;
22
+ let b_slice: &[f32] = &b_data;
23
+ let out_slice: &mut [f32] = &mut out_data;
24
+
25
+ out_slice.par_chunks_mut(b_cols).enumerate().for_each(|(i, out_row)| {
26
+ let a_offset = i * a_cols;
27
+ for k in 0..a_cols {
28
+ let a_val = if trans_a {
29
+ a_slice[k * a_rows + i]
30
+ } else {
31
+ a_slice[a_offset + k]
32
+ };
33
+
34
+ if a_val > 0.5 {
35
+ let b_offset = k * b_cols;
36
+ for j in 0..b_cols {
37
+ let b_val = if trans_b {
38
+ b_slice[j * a_cols + k]
39
+ } else {
40
+ b_slice[b_offset + j]
41
+ };
42
+ out_row[j] += b_val;
43
+ }
44
+ }
45
+ }
46
+ });
47
+ }
@@ -0,0 +1,28 @@
1
+ use napi_derive::napi;
2
+ use napi::bindgen_prelude::Float32Array;
3
+
4
+ #[napi]
5
+ pub fn apply_embedding_delta_native(
6
+ mut embeddings: Float32Array,
7
+ inputs: Float32Array,
8
+ error_signal: Float32Array,
9
+ learning_rate: f64,
10
+ input_dim: u32,
11
+ output_dim: u32
12
+ ) {
13
+ let batch = inputs.len();
14
+ let out_dim = output_dim as usize;
15
+
16
+ for b in 0..batch {
17
+ let token_idx = inputs[b] as i32;
18
+ if token_idx >= 0 && token_idx < input_dim as i32 {
19
+ let token_idx = token_idx as usize;
20
+ let emb_offset = token_idx * out_dim;
21
+ let err_offset = b * out_dim;
22
+ for j in 0..out_dim {
23
+ embeddings[emb_offset + j] += learning_rate as f32 * error_signal[err_offset + j];
24
+ embeddings[emb_offset + j] = embeddings[emb_offset + j].clamp(-1.0, 1.0);
25
+ }
26
+ }
27
+ }
28
+ }