@oxide-js/spiking 1.1.0 → 1.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +11 -0
- package/index.cjs +322 -0
- package/index.d.ts +4 -13
- package/index.js +5 -2
- package/package.json +1 -1
- package/spiking-native.linux-x64-gnu.node +0 -0
- package/src/index.ts +1 -2
- package/src/layers/SpikingDense.ts +55 -36
- package/src/layers/SpikingEmbedding.ts +123 -135
- package/src/native_backend.ts +22 -3
- package/src-rust/src/delta.rs +51 -0
- package/src-rust/src/dot_product.rs +47 -0
- package/src-rust/src/embedding.rs +28 -0
- package/src-rust/src/lib.rs +14 -460
- package/src-rust/src/lif.rs +44 -0
- package/src-rust/src/surrogate.rs +28 -0
- package/test/test_embedding_overlap.ts +181 -0
- package/examples/demo.ts +0 -101
- package/src/models/SpikingSentenceEmbedder.ts +0 -135
|
@@ -2,41 +2,36 @@ import { BaseLayer, LayerConfig, ForwardOptions } from "@oxide-js/layers";
|
|
|
2
2
|
import { Matrix, mj } from "@oxide-js/core";
|
|
3
3
|
import {
|
|
4
4
|
isNativeAvailable,
|
|
5
|
-
lifStepNativeWrapper,
|
|
6
|
-
maskSurrogateNativeWrapper
|
|
5
|
+
lifStepNativeWrapper,
|
|
6
|
+
maskSurrogateNativeWrapper
|
|
7
7
|
} from "../native_backend.js";
|
|
8
8
|
|
|
9
9
|
export interface SpikingEmbeddingConfig extends LayerConfig {
|
|
10
|
-
inputDim: number;
|
|
11
|
-
outputDim: number;
|
|
12
|
-
|
|
13
|
-
threshold?: number; // Ambang batas Spike
|
|
14
|
-
embeddingsInitializer?: string; // Tipe inisialisasi bobot
|
|
10
|
+
inputDim: number;
|
|
11
|
+
outputDim: number;
|
|
12
|
+
embeddingsInitializer?: string;
|
|
15
13
|
}
|
|
16
14
|
|
|
17
15
|
export class SpikingEmbedding extends BaseLayer {
|
|
18
16
|
public inputDim: number;
|
|
19
17
|
public outputDim: number;
|
|
20
|
-
public
|
|
21
|
-
public
|
|
18
|
+
public embeddingsInitializer: string;
|
|
19
|
+
public beta!: Float32Array;
|
|
20
|
+
public threshold!: Float32Array;
|
|
22
21
|
|
|
23
22
|
public potentials!: Matrix;
|
|
24
23
|
public lastPotentials?: Matrix;
|
|
25
24
|
public lastInputs?: Matrix;
|
|
26
25
|
public lastSpikes?: Matrix;
|
|
27
26
|
|
|
28
|
-
public
|
|
29
|
-
|
|
30
|
-
public get kernel(): Matrix | undefined {
|
|
31
|
-
return this.getParameter("kernel");
|
|
27
|
+
public get embeddings(): Matrix | undefined {
|
|
28
|
+
return this.getParameter("embeddings");
|
|
32
29
|
}
|
|
33
30
|
|
|
34
31
|
constructor(config: SpikingEmbeddingConfig) {
|
|
35
32
|
super(config);
|
|
36
33
|
this.inputDim = config.inputDim;
|
|
37
34
|
this.outputDim = config.outputDim;
|
|
38
|
-
this.beta = config.beta ?? 0.9;
|
|
39
|
-
this.threshold = config.threshold ?? 1.0;
|
|
40
35
|
this.embeddingsInitializer = config.embeddingsInitializer || "glorot_normal";
|
|
41
36
|
}
|
|
42
37
|
|
|
@@ -47,8 +42,26 @@ export class SpikingEmbedding extends BaseLayer {
|
|
|
47
42
|
|
|
48
43
|
public build(inputShape: number[]): void {
|
|
49
44
|
super.build(inputShape);
|
|
50
|
-
|
|
51
|
-
this.
|
|
45
|
+
|
|
46
|
+
const embVal = this.createInitializer(this.embeddingsInitializer, [this.inputDim, this.outputDim]);
|
|
47
|
+
this.addParameter("embeddings", embVal, true, [this.inputDim, this.outputDim]);
|
|
48
|
+
|
|
49
|
+
// Inisialisasi beta dan threshold secara acak untuk setiap neuron
|
|
50
|
+
this.beta = new Float32Array(this.outputDim);
|
|
51
|
+
this.threshold = new Float32Array(this.outputDim);
|
|
52
|
+
for (let i = 0; i < this.outputDim; i++) {
|
|
53
|
+
this.beta[i] = 0.8 + Math.random() * 0.19; // Random 0.8 - 0.99
|
|
54
|
+
this.threshold[i] = 0.5 + Math.random() * 0.5; // Random 0.5 - 1.0 (Max 1.0)
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
// Potentials start at 0, shape [batch, outputDim].
|
|
58
|
+
this.potentials = Matrix.fromFlat(new Float32Array(this.outputDim), [1, this.outputDim]);
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
private ensurePotentialsShape(batch: number) {
|
|
62
|
+
if (this.potentials._shape[0] !== batch) {
|
|
63
|
+
this.potentials = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
|
|
64
|
+
}
|
|
52
65
|
}
|
|
53
66
|
|
|
54
67
|
public resetState() {
|
|
@@ -58,38 +71,26 @@ export class SpikingEmbedding extends BaseLayer {
|
|
|
58
71
|
this.lastSpikes = undefined;
|
|
59
72
|
}
|
|
60
73
|
|
|
61
|
-
private ensurePotentialsShape(batch: number) {
|
|
62
|
-
if (!this.potentials || this.potentials._shape[0] !== batch) {
|
|
63
|
-
this.potentials = Matrix.fromFlat(
|
|
64
|
-
new Float32Array(batch * this.outputDim),
|
|
65
|
-
[batch, this.outputDim]
|
|
66
|
-
);
|
|
67
|
-
}
|
|
68
|
-
}
|
|
69
|
-
|
|
70
74
|
protected compute(inputs: Matrix, options?: ForwardOptions): Matrix {
|
|
71
|
-
const kernel = this.kernel!._data;
|
|
72
75
|
const batch = inputs._shape[0];
|
|
73
|
-
const inputData = inputs._data;
|
|
74
|
-
|
|
75
76
|
this.ensurePotentialsShape(batch);
|
|
76
|
-
|
|
77
|
-
// 1.
|
|
77
|
+
|
|
78
|
+
// 1. Embedding lookup
|
|
79
|
+
const emb = this.embeddings!;
|
|
78
80
|
const dotData = new Float32Array(batch * this.outputDim);
|
|
81
|
+
|
|
79
82
|
for (let b = 0; b < batch; b++) {
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
}
|
|
89
|
-
}
|
|
83
|
+
const tokenIdx = inputs._data[b];
|
|
84
|
+
if (tokenIdx >= 0 && tokenIdx < this.inputDim) {
|
|
85
|
+
const embOffset = tokenIdx * this.outputDim;
|
|
86
|
+
const dotOffset = b * this.outputDim;
|
|
87
|
+
for (let i = 0; i < this.outputDim; i++) {
|
|
88
|
+
dotData[dotOffset + i] = emb._data[embOffset + i];
|
|
89
|
+
}
|
|
90
|
+
}
|
|
90
91
|
}
|
|
91
|
-
|
|
92
|
-
// 2
|
|
92
|
+
|
|
93
|
+
// 2. Leaky Integrate and Fire
|
|
93
94
|
const outData = new Float32Array(batch * this.outputDim);
|
|
94
95
|
const outSpikes = Matrix.fromFlat(outData, [batch, this.outputDim]);
|
|
95
96
|
this.lastPotentials = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
|
|
@@ -98,130 +99,117 @@ export class SpikingEmbedding extends BaseLayer {
|
|
|
98
99
|
lifStepNativeWrapper(
|
|
99
100
|
this.potentials._data,
|
|
100
101
|
dotData,
|
|
101
|
-
|
|
102
|
+
outData,
|
|
102
103
|
this.lastPotentials._data,
|
|
103
104
|
this.beta,
|
|
104
105
|
this.threshold
|
|
105
106
|
);
|
|
106
107
|
} else {
|
|
107
108
|
const potData = this.potentials._data;
|
|
108
|
-
const thresh = this.threshold;
|
|
109
109
|
const lpData = this.lastPotentials._data;
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
110
|
+
|
|
111
|
+
for (let b = 0; b < batch; b++) {
|
|
112
|
+
const offset = b * this.outputDim;
|
|
113
|
+
for (let i = 0; i < this.outputDim; i++) {
|
|
114
|
+
const idx = offset + i;
|
|
115
|
+
potData[idx] = Math.min((potData[idx] * this.beta[i]) + dotData[idx], 1.0); // Clamp potential max 1.0
|
|
116
|
+
lpData[idx] = potData[idx];
|
|
117
|
+
}
|
|
118
|
+
for (let i = 0; i < this.outputDim; i++) {
|
|
119
|
+
const idx = offset + i;
|
|
120
|
+
if (potData[idx] >= this.threshold[i]) {
|
|
121
|
+
outData[idx] = 1;
|
|
122
|
+
potData[idx] -= this.threshold[i];
|
|
123
|
+
} else {
|
|
124
|
+
outData[idx] = 0;
|
|
125
|
+
}
|
|
126
|
+
}
|
|
121
127
|
}
|
|
122
128
|
}
|
|
123
129
|
|
|
124
|
-
// Simpan memori untuk update bobot
|
|
125
130
|
this.lastInputs = inputs;
|
|
126
131
|
this.lastSpikes = outSpikes;
|
|
127
132
|
|
|
128
133
|
return outSpikes;
|
|
129
134
|
}
|
|
130
135
|
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
throw new Error("[SpikingEmbedding] Cannot run learnEmbedding() before forward() is executed. 'lastInputs' is undefined.");
|
|
135
|
-
}
|
|
136
|
-
|
|
137
|
-
const kernel = this.kernel!._data;
|
|
138
|
-
const inputData = this.lastInputs._data;
|
|
139
|
-
const batch = this.lastInputs._shape[0];
|
|
140
|
-
|
|
141
|
-
// Hitung error yang mampir ke embedding
|
|
142
|
-
// E * B (Feedback Alignment)
|
|
143
|
-
// Gunakan matmul biasa karena B adalah float, dan errorFromNext mungkin float
|
|
144
|
-
const eHidden = Matrix.fromFlat(new Float32Array(batch * this.outputDim), [batch, this.outputDim]);
|
|
145
|
-
// Namun karena OxideJS Matrix belum memiliki fungsi dot produk standar terbuka yang stabil,
|
|
146
|
-
// kita harus hati-hati di sini. Untuk simplifikasi, eHidden = errorFromNext * B.
|
|
147
|
-
// Kita asumsikan ada utilitas dotProduct standar dari core.
|
|
148
|
-
// Jika B adalah matriks Dense (dimensi: outUnits x hiddenUnits), maka
|
|
149
|
-
// eHidden [batch, hiddenUnits] = errorFromNext [batch, outUnits] dot B [outUnits, hiddenUnits]
|
|
136
|
+
public learnEmbedding(errorSignal: Matrix, B: Matrix, learningRate: number = 0.01): Matrix {
|
|
137
|
+
// Broadcast error mundur (Feedback Alignment)
|
|
138
|
+
let eHidden = mj.dotProduct(errorSignal, B, undefined, false, false); // E * B
|
|
150
139
|
|
|
151
|
-
// Kita panggil dot product standar (bukan Add-Only, karena error dan B sama-sama float)
|
|
152
|
-
let eHiddenMatrix = mj.dotProduct(errorFromNext, B, undefined, false, false);
|
|
153
|
-
|
|
154
140
|
// Surrogate Mask: Boxcar
|
|
155
141
|
if (this.lastPotentials) {
|
|
142
|
+
const eData = eHidden._data;
|
|
143
|
+
const pData = this.lastPotentials._data;
|
|
144
|
+
const windowSize = 1.0;
|
|
145
|
+
|
|
156
146
|
if (isNativeAvailable()) {
|
|
157
147
|
maskSurrogateNativeWrapper(
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
this.threshold,
|
|
161
|
-
|
|
148
|
+
eData,
|
|
149
|
+
pData,
|
|
150
|
+
this.threshold,
|
|
151
|
+
windowSize
|
|
162
152
|
);
|
|
163
153
|
} else {
|
|
164
|
-
const
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
154
|
+
const batch = eHidden._shape[0];
|
|
155
|
+
for (let b = 0; b < batch; b++) {
|
|
156
|
+
const offset = b * this.outputDim;
|
|
157
|
+
for (let i = 0; i < this.outputDim; i++) {
|
|
158
|
+
const idx = offset + i;
|
|
159
|
+
if (Math.abs(pData[idx] - this.threshold[i]) > windowSize) {
|
|
160
|
+
eData[idx] = 0;
|
|
161
|
+
}
|
|
172
162
|
}
|
|
173
163
|
}
|
|
174
164
|
}
|
|
175
165
|
}
|
|
176
166
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
for (let b = 0; b < batch; b++) {
|
|
180
|
-
const tokenId = Math.round(inputData[b]);
|
|
181
|
-
if (tokenId >= 0 && tokenId < this.inputDim) {
|
|
182
|
-
const kOffset = tokenId * this.outputDim;
|
|
183
|
-
const errOffset = b * this.outputDim;
|
|
184
|
-
for (let j = 0; j < this.outputDim; j++) {
|
|
185
|
-
kernel[kOffset + j] += learningRate * err[errOffset + j];
|
|
186
|
-
}
|
|
187
|
-
}
|
|
188
|
-
}
|
|
189
|
-
|
|
190
|
-
return eHiddenMatrix;
|
|
167
|
+
this.applyEmbeddingDelta(eHidden, learningRate);
|
|
168
|
+
return eHidden;
|
|
191
169
|
}
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
170
|
+
|
|
171
|
+
private applyEmbeddingDelta(errorSignal: Matrix, learningRate: number) {
|
|
172
|
+
if (!this.lastInputs || !this.lastSpikes) {
|
|
173
|
+
throw new Error("[SpikingEmbedding] Cannot run learning before forward() is executed.");
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
const embeddings = this.embeddings!._data;
|
|
177
|
+
const inputs = this.lastInputs._data;
|
|
178
|
+
const err = errorSignal._data;
|
|
179
|
+
|
|
180
|
+
const batch = this.lastInputs._shape[0];
|
|
181
|
+
const outputDim = this.outputDim;
|
|
182
|
+
|
|
183
|
+
if (isNativeAvailable()) {
|
|
184
|
+
applyEmbeddingDeltaNativeWrapper(
|
|
185
|
+
embeddings,
|
|
186
|
+
inputs,
|
|
187
|
+
err,
|
|
188
|
+
learningRate,
|
|
189
|
+
this.inputDim,
|
|
190
|
+
outputDim
|
|
191
|
+
);
|
|
192
|
+
} else {
|
|
193
|
+
for (let b = 0; b < batch; b++) {
|
|
194
|
+
const tokenIdx = inputs[b];
|
|
195
|
+
if (tokenIdx >= 0 && tokenIdx < this.inputDim) {
|
|
196
|
+
const embOffset = tokenIdx * outputDim;
|
|
197
|
+
const errOffset = b * outputDim;
|
|
198
|
+
for (let j = 0; j < outputDim; j++) {
|
|
199
|
+
embeddings[embOffset + j] += learningRate * err[errOffset + j];
|
|
200
|
+
embeddings[embOffset + j] = Math.max(-1.0, Math.min(1.0, embeddings[embOffset + j])); // Clamp weight [-1, 1]
|
|
222
201
|
}
|
|
223
202
|
}
|
|
224
203
|
}
|
|
225
204
|
}
|
|
226
205
|
}
|
|
206
|
+
|
|
207
|
+
public getConfig(): Record<string, any> {
|
|
208
|
+
return {
|
|
209
|
+
...super.getConfig(),
|
|
210
|
+
inputDim: this.inputDim,
|
|
211
|
+
outputDim: this.outputDim,
|
|
212
|
+
embeddingsInitializer: this.embeddingsInitializer
|
|
213
|
+
};
|
|
214
|
+
}
|
|
227
215
|
}
|
package/src/native_backend.ts
CHANGED
|
@@ -47,8 +47,8 @@ export const lifStepNativeWrapper = (
|
|
|
47
47
|
dot: Float32Array,
|
|
48
48
|
spikes: Float32Array,
|
|
49
49
|
lastPotentials: Float32Array,
|
|
50
|
-
beta:
|
|
51
|
-
threshold:
|
|
50
|
+
beta: Float32Array,
|
|
51
|
+
threshold: Float32Array
|
|
52
52
|
): void => {
|
|
53
53
|
if (!native) throw new Error("Spiking Native backend not available");
|
|
54
54
|
native.lifStepNative(potentials, dot, spikes, lastPotentials, beta, threshold);
|
|
@@ -57,7 +57,7 @@ export const lifStepNativeWrapper = (
|
|
|
57
57
|
export const maskSurrogateNativeWrapper = (
|
|
58
58
|
errorSignal: Float32Array,
|
|
59
59
|
potentials: Float32Array,
|
|
60
|
-
threshold:
|
|
60
|
+
threshold: Float32Array,
|
|
61
61
|
windowSize: number
|
|
62
62
|
): void => {
|
|
63
63
|
if (!native) throw new Error("Spiking Native backend not available");
|
|
@@ -88,3 +88,22 @@ export const applyAddOnlyDeltaNativeWrapper = (
|
|
|
88
88
|
useBias
|
|
89
89
|
);
|
|
90
90
|
};
|
|
91
|
+
|
|
92
|
+
export const applyEmbeddingDeltaNativeWrapper = (
|
|
93
|
+
embeddings: Float32Array,
|
|
94
|
+
inputs: Float32Array,
|
|
95
|
+
errorSignal: Float32Array,
|
|
96
|
+
learningRate: number,
|
|
97
|
+
inputDim: number,
|
|
98
|
+
outputDim: number
|
|
99
|
+
): void => {
|
|
100
|
+
if (!native) throw new Error("Spiking Native backend not available");
|
|
101
|
+
native.applyEmbeddingDeltaNative(
|
|
102
|
+
embeddings,
|
|
103
|
+
inputs,
|
|
104
|
+
errorSignal,
|
|
105
|
+
learningRate,
|
|
106
|
+
inputDim,
|
|
107
|
+
outputDim
|
|
108
|
+
);
|
|
109
|
+
};
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
use napi_derive::napi;
|
|
2
|
+
use napi::bindgen_prelude::Float32Array;
|
|
3
|
+
use rayon::prelude::*;
|
|
4
|
+
|
|
5
|
+
#[napi]
|
|
6
|
+
pub fn apply_add_only_delta_native(
|
|
7
|
+
mut kernel: Float32Array,
|
|
8
|
+
mut bias: Float32Array,
|
|
9
|
+
inputs: Float32Array,
|
|
10
|
+
error_signal: Float32Array,
|
|
11
|
+
learning_rate: f64,
|
|
12
|
+
batch: u32,
|
|
13
|
+
in_features: u32,
|
|
14
|
+
units: u32,
|
|
15
|
+
use_bias: bool
|
|
16
|
+
) {
|
|
17
|
+
let in_feat = in_features as usize;
|
|
18
|
+
let u = units as usize;
|
|
19
|
+
let b_size = batch as usize;
|
|
20
|
+
let lr = learning_rate as f32;
|
|
21
|
+
|
|
22
|
+
let kernel_slice: &mut [f32] = &mut kernel;
|
|
23
|
+
let bias_slice: &mut [f32] = &mut bias;
|
|
24
|
+
let in_slice: &[f32] = &inputs;
|
|
25
|
+
let err_slice: &[f32] = &error_signal;
|
|
26
|
+
|
|
27
|
+
kernel_slice.par_chunks_mut(u).enumerate().for_each(|(k, kernel_row)| {
|
|
28
|
+
for b in 0..b_size {
|
|
29
|
+
if in_slice[b * in_feat + k] > 0.5 {
|
|
30
|
+
let err_offset = b * u;
|
|
31
|
+
for j in 0..u {
|
|
32
|
+
kernel_row[j] += lr * err_slice[err_offset + j];
|
|
33
|
+
}
|
|
34
|
+
}
|
|
35
|
+
}
|
|
36
|
+
for j in 0..u {
|
|
37
|
+
kernel_row[j] = kernel_row[j].clamp(-1.0, 1.0);
|
|
38
|
+
}
|
|
39
|
+
});
|
|
40
|
+
|
|
41
|
+
if use_bias {
|
|
42
|
+
bias_slice.par_iter_mut().enumerate().for_each(|(j, b_val)| {
|
|
43
|
+
let mut sum = 0.0;
|
|
44
|
+
for b in 0..b_size {
|
|
45
|
+
sum += err_slice[b * u + j];
|
|
46
|
+
}
|
|
47
|
+
*b_val += (lr * sum) / (b_size as f32);
|
|
48
|
+
*b_val = b_val.clamp(-1.0, 1.0);
|
|
49
|
+
});
|
|
50
|
+
}
|
|
51
|
+
}
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
use napi_derive::napi;
|
|
2
|
+
use napi::bindgen_prelude::Float32Array;
|
|
3
|
+
use rayon::prelude::*;
|
|
4
|
+
|
|
5
|
+
#[napi]
|
|
6
|
+
pub fn dot_product_add_only_native(
|
|
7
|
+
a_data: Float32Array,
|
|
8
|
+
a_rows_orig: u32,
|
|
9
|
+
a_cols_orig: u32,
|
|
10
|
+
b_data: Float32Array,
|
|
11
|
+
b_rows_orig: u32,
|
|
12
|
+
b_cols_orig: u32,
|
|
13
|
+
trans_a: bool,
|
|
14
|
+
trans_b: bool,
|
|
15
|
+
mut out_data: Float32Array
|
|
16
|
+
) {
|
|
17
|
+
let a_rows = if trans_a { a_cols_orig } else { a_rows_orig } as usize;
|
|
18
|
+
let a_cols = if trans_a { a_rows_orig } else { a_cols_orig } as usize;
|
|
19
|
+
let b_cols = if trans_b { b_rows_orig } else { b_cols_orig } as usize;
|
|
20
|
+
|
|
21
|
+
let a_slice: &[f32] = &a_data;
|
|
22
|
+
let b_slice: &[f32] = &b_data;
|
|
23
|
+
let out_slice: &mut [f32] = &mut out_data;
|
|
24
|
+
|
|
25
|
+
out_slice.par_chunks_mut(b_cols).enumerate().for_each(|(i, out_row)| {
|
|
26
|
+
let a_offset = i * a_cols;
|
|
27
|
+
for k in 0..a_cols {
|
|
28
|
+
let a_val = if trans_a {
|
|
29
|
+
a_slice[k * a_rows + i]
|
|
30
|
+
} else {
|
|
31
|
+
a_slice[a_offset + k]
|
|
32
|
+
};
|
|
33
|
+
|
|
34
|
+
if a_val > 0.5 {
|
|
35
|
+
let b_offset = k * b_cols;
|
|
36
|
+
for j in 0..b_cols {
|
|
37
|
+
let b_val = if trans_b {
|
|
38
|
+
b_slice[j * a_cols + k]
|
|
39
|
+
} else {
|
|
40
|
+
b_slice[b_offset + j]
|
|
41
|
+
};
|
|
42
|
+
out_row[j] += b_val;
|
|
43
|
+
}
|
|
44
|
+
}
|
|
45
|
+
}
|
|
46
|
+
});
|
|
47
|
+
}
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
use napi_derive::napi;
|
|
2
|
+
use napi::bindgen_prelude::Float32Array;
|
|
3
|
+
|
|
4
|
+
#[napi]
|
|
5
|
+
pub fn apply_embedding_delta_native(
|
|
6
|
+
mut embeddings: Float32Array,
|
|
7
|
+
inputs: Float32Array,
|
|
8
|
+
error_signal: Float32Array,
|
|
9
|
+
learning_rate: f64,
|
|
10
|
+
input_dim: u32,
|
|
11
|
+
output_dim: u32
|
|
12
|
+
) {
|
|
13
|
+
let batch = inputs.len();
|
|
14
|
+
let out_dim = output_dim as usize;
|
|
15
|
+
|
|
16
|
+
for b in 0..batch {
|
|
17
|
+
let token_idx = inputs[b] as i32;
|
|
18
|
+
if token_idx >= 0 && token_idx < input_dim as i32 {
|
|
19
|
+
let token_idx = token_idx as usize;
|
|
20
|
+
let emb_offset = token_idx * out_dim;
|
|
21
|
+
let err_offset = b * out_dim;
|
|
22
|
+
for j in 0..out_dim {
|
|
23
|
+
embeddings[emb_offset + j] += learning_rate as f32 * error_signal[err_offset + j];
|
|
24
|
+
embeddings[emb_offset + j] = embeddings[emb_offset + j].clamp(-1.0, 1.0);
|
|
25
|
+
}
|
|
26
|
+
}
|
|
27
|
+
}
|
|
28
|
+
}
|