@oxide-js/spiking 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +12 -0
- package/examples/demo.ts +101 -0
- package/index.d.ts +19 -0
- package/index.js +316 -0
- package/package.json +47 -0
- package/src/index.ts +5 -0
- package/src/layers/SpikingDense.ts +237 -0
- package/src/layers/SpikingEmbedding.ts +227 -0
- package/src/math/dotProductAddOnly.ts +229 -0
- package/src/models/SpikingSentenceEmbedder.ts +135 -0
- package/src/native_backend.ts +90 -0
- package/src-rust/Cargo.lock +324 -0
- package/src-rust/Cargo.toml +17 -0
- package/src-rust/build.rs +5 -0
- package/src-rust/src/lib.rs +462 -0
- package/test/test_embedding.ts +126 -0
- package/test/test_xor.ts +122 -0
- package/tsconfig.json +9 -0
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
import { BaseLayer, LayerConfig, ForwardOptions } from "@oxide-js/layers";
|
|
2
|
+
import { Matrix, mj } from "@oxide-js/core";
|
|
3
|
+
import {
|
|
4
|
+
isNativeAvailable,
|
|
5
|
+
lifStepNativeWrapper,
|
|
6
|
+
maskSurrogateNativeWrapper,
|
|
7
|
+
applyAddOnlyDeltaNativeWrapper
|
|
8
|
+
} from "../native_backend.js";
|
|
9
|
+
import dotProductAddOnly from "../math/dotProductAddOnly.js";
|
|
10
|
+
/**
 * Construction options for a fully connected spiking (LIF) layer.
 */
export interface SpikingDenseConfig extends LayerConfig {
  /** Number of output neurons. */
  units: number;
  /** Membrane decay factor applied to the previous potential each step (layer default: 0.9). */
  beta?: number;
  /** Potential at which a neuron fires (layer default: 1.0). */
  threshold?: number;
  /** Whether a bias parameter is created and added (layer default: true). */
  useBias?: boolean;
  /** Initializer name for the kernel (layer default: "glorot_normal"). */
  kernelInitializer?: string;
  /** Initializer name for the bias (layer default: "zeros"). */
  biasInitializer?: string;
}
|
|
18
|
+
|
|
19
|
+
/**
 * Fully connected layer of leaky integrate-and-fire (LIF) neurons.
 *
 * One forward call performs:
 *   1. current = inputs · kernel via `dotProductAddOnly`, which requires one of the
 *      two operands to contain only 0/1 values;
 *   2. the bias is added when `useBias` is true;
 *   3. potentials = beta * potentials + current;
 *   4. every unit whose potential is >= threshold emits 1 and has `threshold`
 *      subtracted from its potential (reset by subtraction); others emit 0.
 *
 * Potentials persist across calls until `resetState()` runs or the batch size
 * changes. `learnOutput` / `learnHidden` apply an add-only delta rule using the
 * inputs and potentials stored by the most recent forward call.
 */
export class SpikingDense extends BaseLayer {
  /** Number of output neurons. */
  public units: number;
  /** Whether a bias parameter is created and added. */
  public useBias: boolean;
  /** Initializer name used for the kernel. */
  public kernelInitializer: string;
  /** Initializer name used for the bias. */
  public biasInitializer: string;
  /** Membrane decay factor applied to the previous potential each step. */
  public beta: number;
  /** Potential at which a neuron fires. */
  public threshold: number;

  /** Membrane potentials carried between forward calls, shape [batch, units]. */
  public potentials!: Matrix;
  /** Potentials after integration but before reset, from the last forward call. */
  public lastPotentials?: Matrix;
  /** Inputs of the last forward call (stored by reference). */
  public lastInputs?: Matrix;
  /** Output spikes of the last forward call. */
  public lastSpikes?: Matrix;

  /** Kernel parameter, shape [inFeatures, units] once built. */
  public get kernel(): Matrix | undefined {
    return this.getParameter("kernel");
  }

  /** Bias parameter, shape [units, 1]; only created when `useBias` is true. */
  public get bias(): Matrix | undefined {
    return this.getParameter("bias");
  }

  constructor(config: SpikingDenseConfig) {
    super(config);
    this.units = config.units;
    this.useBias = config.useBias ?? true;
    // `||` so an empty initializer name also falls back to the default.
    this.kernelInitializer = config.kernelInitializer || "glorot_normal";
    this.biasInitializer = config.biasInitializer || "zeros";
    this.beta = config.beta ?? 0.9;
    this.threshold = config.threshold ?? 1.0;
  }

  /** Output is [batch, units]; batch is -1 when the input shape does not provide it. */
  public computeOutputShape(inputShape: number[]): number[] {
    const batch = inputShape[0] ?? -1;
    return [batch, this.units];
  }

  /**
   * Creates the kernel ([inFeatures, units], where inFeatures is the last input
   * dimension), the optional bias ([units, 1]) and a zeroed [1, units] potential buffer.
   */
  public build(inputShape: number[]): void {
    super.build(inputShape);

    const inFeatures = inputShape[inputShape.length - 1];

    const kernelVal = this.createInitializer(this.kernelInitializer, [inFeatures, this.units]);
    this.addParameter("kernel", kernelVal, true, [inFeatures, this.units]);

    if (this.useBias) {
      const biasVal = this.createInitializer(this.biasInitializer, [this.units, 1]);
      this.addParameter("bias", biasVal, true, [this.units, 1]);
    }

    // Initial (zero) membrane state.
    this.potentials = Matrix.fromFlat(new Float32Array(this.units), [1, this.units]);
  }

  /**
   * Makes `potentials` a [batch, units] buffer. A fresh zeroed buffer is allocated
   * when none exists yet or the batch size changed, which discards previous state.
   */
  private ensurePotentialsShape(batch: number): void {
    if (!this.potentials || this.potentials._shape[0] !== batch) {
      this.potentials = Matrix.fromFlat(new Float32Array(batch * this.units), [batch, this.units]);
    }
  }

  /** Zeroes the membrane potentials and forgets the data stored for learning. */
  public resetState() {
    if (this.potentials) this.potentials._data.fill(0);
    this.lastPotentials = undefined;
    this.lastInputs = undefined;
    this.lastSpikes = undefined;
  }

  /**
   * One LIF time step for a batch of inputs.
   *
   * @param inputs Matrix of shape [batch, inFeatures].
   * @returns Spike matrix of shape [batch, units] containing 0/1 values.
   */
  protected compute(inputs: Matrix, options?: ForwardOptions): Matrix {
    const kernel = this.getKernelOrThrow();
    const batch = inputs._shape[0];
    this.ensurePotentialsShape(batch);

    // 1. Add-only matrix product (one operand must be binary).
    const current = dotProductAddOnly(inputs, kernel);

    // 2. Bias.
    if (this.useBias && this.bias) {
      mj.addBiasRow(current, this.bias);
    }

    // 3 & 4. Leaky integrate, fire and reset.
    const size = batch * this.units;
    const outSpikes = Matrix.fromFlat(new Float32Array(size), [batch, this.units]);
    const lastPotentials = Matrix.fromFlat(new Float32Array(size), [batch, this.units]);
    this.lastPotentials = lastPotentials;

    if (isNativeAvailable()) {
      lifStepNativeWrapper(
        this.potentials._data,
        current._data,
        outSpikes._data,
        lastPotentials._data,
        this.beta,
        this.threshold
      );
    } else {
      const pot = this.potentials._data;
      const cur = current._data;
      const spikes = outSpikes._data;
      const prePot = lastPotentials._data;
      const beta = this.beta;
      const thresh = this.threshold;
      for (let i = 0; i < pot.length; i++) {
        pot[i] = pot[i] * beta + cur[i];
        // Re-read from the typed array so the comparison uses the stored (rounded) value.
        const v = pot[i];
        prePot[i] = v;
        if (v >= thresh) {
          spikes[i] = 1;
          pot[i] -= thresh;
        } else {
          spikes[i] = 0;
        }
      }
    }

    // Keep what the learning rules need.
    this.lastInputs = inputs;
    this.lastSpikes = outSpikes;

    return outSpikes;
  }

  /**
   * Delta-rule update for an output layer: the error signal is applied directly.
   *
   * @param errorSignal Error of shape [batch, units] for the last forward call.
   * @returns The same `errorSignal` instance.
   */
  public learnOutput(errorSignal: Matrix, learningRate: number = 0.01): Matrix {
    this.applyAddOnlyDelta(errorSignal, learningRate);
    return errorSignal;
  }

  /**
   * Update for a hidden layer. The downstream error is projected back through `B`
   * (eHidden = errorFromNext · B) and masked with a boxcar surrogate gradient.
   *
   * @returns The masked hidden error of shape [batch, units].
   */
  public learnHidden(errorFromNext: Matrix, B: Matrix, learningRate: number = 0.01): Matrix {
    const eHidden = mj.dotProduct(errorFromNext, B, undefined, false, false);

    // Boxcar surrogate: the error is kept only where the pre-reset potential was
    // within 1.0 of the threshold; no float multiplication is needed.
    if (this.lastPotentials) {
      this.assertErrorShape(eHidden, this.lastPotentials._shape[0]);
      if (isNativeAvailable()) {
        maskSurrogateNativeWrapper(
          eHidden._data,
          this.lastPotentials._data,
          this.threshold,
          1.0
        );
      } else {
        const eData = eHidden._data;
        const pData = this.lastPotentials._data;
        const thresh = this.threshold;
        const windowSize = 1.0;
        for (let i = 0; i < eData.length; i++) {
          if (Math.abs(pData[i] - thresh) > windowSize) {
            eData[i] = 0;
          }
        }
      }
    }

    this.applyAddOnlyDelta(eHidden, learningRate);
    return eHidden;
  }

  /**
   * Add-only delta rule. Kernel row k is changed only for samples whose input k
   * was exactly 1; the bias is changed for every sample.
   *
   * @throws Error if no forward call is recorded or the error signal has the wrong size.
   */
  private applyAddOnlyDelta(errorSignal: Matrix, learningRate: number): void {
    const lastInputs = this.lastInputs;
    if (!lastInputs || !this.lastSpikes) {
      throw new Error("[SpikingDense] Cannot run learning before forward() is executed. 'lastInputs' or 'lastSpikes' is undefined.");
    }

    const batch = lastInputs._shape[0];
    const inFeatures = lastInputs._shape[1];
    const units = this.units;
    // A wrong-sized error would otherwise read past its end (NaN weights in JS).
    this.assertErrorShape(errorSignal, batch);

    const kernel = this.getKernelOrThrow()._data;
    const inputs = lastInputs._data;
    const err = errorSignal._data;
    const bias = this.useBias ? this.bias : undefined;

    if (isNativeAvailable()) {
      // An empty array stands in when there is no bias; the flag passed alongside
      // is kept consistent with the array so the native side never sees
      // "use bias" together with an empty buffer.
      applyAddOnlyDeltaNativeWrapper(
        kernel,
        bias ? bias._data : new Float32Array(0),
        inputs,
        err,
        learningRate,
        batch,
        inFeatures,
        units,
        bias !== undefined
      );
      return;
    }

    for (let b = 0; b < batch; b++) {
      const inOffset = b * inFeatures;
      const errOffset = b * units;

      for (let k = 0; k < inFeatures; k++) {
        // Only inputs that spiked (exactly 1) contribute: add-only update.
        if (inputs[inOffset + k] !== 1) continue;
        const kOffset = k * units;
        for (let j = 0; j < units; j++) {
          kernel[kOffset + j] += learningRate * err[errOffset + j];
        }
      }

      if (bias) {
        const biasData = bias._data;
        for (let j = 0; j < units; j++) {
          biasData[j] += learningRate * err[errOffset + j];
        }
      }
    }
  }

  /** Throws unless `signal` holds exactly batch * units values. */
  private assertErrorShape(signal: Matrix, batch: number): void {
    const expected = batch * this.units;
    if (signal._data.length !== expected) {
      throw new Error(
        `[SpikingDense] Error signal has ${signal._data.length} values, expected ${expected} (batch ${batch} x ${this.units} units).`
      );
    }
  }

  /** Returns the kernel, or throws a descriptive error if the layer was never built. */
  private getKernelOrThrow(): Matrix {
    const kernel = this.kernel;
    if (!kernel) {
      throw new Error("[SpikingDense] 'kernel' parameter is missing; build() must run before forward() or learning.");
    }
    return kernel;
  }

  /** Serializable configuration: the base config plus this layer's hyper-parameters. */
  public getConfig(): Record<string, any> {
    return {
      ...super.getConfig(),
      units: this.units,
      useBias: this.useBias,
      kernelInitializer: this.kernelInitializer,
      biasInitializer: this.biasInitializer,
      beta: this.beta,
      threshold: this.threshold
    };
  }
}
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
import { BaseLayer, LayerConfig, ForwardOptions } from "@oxide-js/layers";
|
|
2
|
+
import { Matrix, mj } from "@oxide-js/core";
|
|
3
|
+
import {
|
|
4
|
+
isNativeAvailable,
|
|
5
|
+
lifStepNativeWrapper,
|
|
6
|
+
maskSurrogateNativeWrapper
|
|
7
|
+
} from "../native_backend.js";
|
|
8
|
+
|
|
9
|
+
/**
 * Construction options for a spiking embedding (token lookup + LIF) layer.
 */
export interface SpikingEmbeddingConfig extends LayerConfig {
  /** Vocabulary size, i.e. number of rows in the embedding table. */
  inputDim: number;
  /** Embedding width, i.e. number of spiking neurons per token. */
  outputDim: number;
  /** Initializer name for the embedding table (layer default: "glorot_normal"). */
  embeddingsInitializer?: string;
  /** LIF membrane decay factor (layer default: 0.9). */
  beta?: number;
  /** Firing threshold (layer default: 1.0). */
  threshold?: number;
}
|
|
16
|
+
|
|
17
|
+
/**
 * Embedding layer that feeds the looked-up row of each token into leaky
 * integrate-and-fire (LIF) neurons.
 *
 * Each forward call uses row `tokenId` of the [inputDim, outputDim] table as the
 * input current. It then integrates with decay `beta` and spikes (1) where the
 * potential reaches `threshold`, subtracting `threshold` on a spike.
 */
export class SpikingEmbedding extends BaseLayer {
  /** Vocabulary size (rows of the embedding table). */
  public inputDim: number;
  /** Embedding width (number of neurons). */
  public outputDim: number;
  /** Membrane decay factor. */
  public beta: number;
  /** Firing threshold. */
  public threshold: number;

  /** Membrane potentials carried between forward calls, shape [batch, outputDim]. */
  public potentials!: Matrix;
  /** Potentials after integration but before reset, from the last forward call. */
  public lastPotentials?: Matrix;
  /** Token-id inputs of the last forward call (stored by reference). */
  public lastInputs?: Matrix;
  /** Output spikes of the last forward call. */
  public lastSpikes?: Matrix;

  /** Initializer name used for the embedding table. */
  public embeddingsInitializer: string;

  /** Embedding table, shape [inputDim, outputDim] once built. */
  public get kernel(): Matrix | undefined {
    return this.getParameter("kernel");
  }

  constructor(config: SpikingEmbeddingConfig) {
    super(config);
    this.inputDim = config.inputDim;
    this.outputDim = config.outputDim;
    this.beta = config.beta ?? 0.9;
    this.threshold = config.threshold ?? 1.0;
    this.embeddingsInitializer = config.embeddingsInitializer || "glorot_normal";
  }

  /** Output is [batch, outputDim]; batch is -1 when the input shape does not provide it. */
  public computeOutputShape(inputShape: number[]): number[] {
    const batch = inputShape[0] ?? -1;
    return [batch, this.outputDim];
  }

  /** Creates the [inputDim, outputDim] embedding table; the input shape only goes to the base class. */
  public build(inputShape: number[]): void {
    super.build(inputShape);
    const kernelVal = this.createInitializer(this.embeddingsInitializer, [this.inputDim, this.outputDim]);
    this.addParameter("kernel", kernelVal, true, [this.inputDim, this.outputDim]);
  }

  /** Zeroes the membrane potentials and forgets the data stored for learning. */
  public resetState() {
    if (this.potentials) this.potentials._data.fill(0);
    this.lastPotentials = undefined;
    this.lastInputs = undefined;
    this.lastSpikes = undefined;
  }

  /**
   * Makes `potentials` a [batch, outputDim] buffer. It is (re)allocated zeroed when
   * missing or when the batch size changed.
   */
  private ensurePotentialsShape(batch: number) {
    if (!this.potentials || this.potentials._shape[0] !== batch) {
      this.potentials = Matrix.fromFlat(
        new Float32Array(batch * this.outputDim),
        [batch, this.outputDim]
      );
    }
  }

  /**
   * One LIF step driven by an embedding lookup.
   *
   * @param inputs Token ids; assumed to be [batch, 1] (only the first `batch` values
   *   are read). Ids are rounded, and out-of-range ids produce a zero current.
   * @returns Spike matrix of shape [batch, outputDim] containing 0/1 values.
   */
  protected compute(inputs: Matrix, options?: ForwardOptions): Matrix {
    const kernel = this.getKernelOrThrow()._data;
    const batch = inputs._shape[0];
    const dim = this.outputDim;
    const tokenIds = inputs._data;

    this.ensurePotentialsShape(batch);

    // 1. Row lookup replaces the dot product: the token's row is the input current.
    const current = new Float32Array(batch * dim);
    for (let b = 0; b < batch; b++) {
      const tokenId = Math.round(tokenIds[b]);
      if (tokenId < 0 || tokenId >= this.inputDim) continue;
      const kernelOffset = tokenId * dim;
      const rowOffset = b * dim;
      for (let j = 0; j < dim; j++) {
        current[rowOffset + j] = kernel[kernelOffset + j];
      }
    }

    // 2 & 3. Leaky integrate, fire and reset.
    const outSpikes = Matrix.fromFlat(new Float32Array(batch * dim), [batch, dim]);
    const lastPotentials = Matrix.fromFlat(new Float32Array(batch * dim), [batch, dim]);
    this.lastPotentials = lastPotentials;

    if (isNativeAvailable()) {
      lifStepNativeWrapper(
        this.potentials._data,
        current,
        outSpikes._data,
        lastPotentials._data,
        this.beta,
        this.threshold
      );
    } else {
      const pot = this.potentials._data;
      const spikes = outSpikes._data;
      const prePot = lastPotentials._data;
      const beta = this.beta;
      const thresh = this.threshold;
      for (let i = 0; i < pot.length; i++) {
        pot[i] = pot[i] * beta + current[i];
        // Re-read so the comparison uses the stored (rounded) value.
        const v = pot[i];
        prePot[i] = v;
        if (v >= thresh) {
          spikes[i] = 1;
          pot[i] -= thresh;
        } else {
          spikes[i] = 0;
        }
      }
    }

    // Keep what the learning rules need.
    this.lastInputs = inputs;
    this.lastSpikes = outSpikes;

    return outSpikes;
  }

  /**
   * Learning signal from the layer above. The error is projected back through the
   * fixed matrix `B` (eHidden = errorFromNext · B, feedback-alignment style) and
   * masked with a boxcar surrogate gradient. The result is then added to the
   * looked-up rows of the embedding table.
   *
   * @returns The masked error of shape [batch, outputDim].
   * @throws Error if no forward call is recorded.
   */
  public learnEmbedding(errorFromNext: Matrix, B: Matrix, learningRate: number = 0.01): Matrix {
    if (!this.lastInputs) {
      throw new Error("[SpikingEmbedding] Cannot run learnEmbedding() before forward() is executed. 'lastInputs' is undefined.");
    }

    const kernel = this.getKernelOrThrow()._data;
    const tokenIds = this.lastInputs._data;
    const batch = this.lastInputs._shape[0];
    const dim = this.outputDim;

    // Ordinary float matmul: both the error and B are floating point.
    const eHidden = mj.dotProduct(errorFromNext, B, undefined, false, false);

    // Boxcar surrogate: keep the error only where the pre-reset potential was within 1.0 of threshold.
    if (this.lastPotentials) {
      if (isNativeAvailable()) {
        maskSurrogateNativeWrapper(
          eHidden._data,
          this.lastPotentials._data,
          this.threshold,
          1.0
        );
      } else {
        const eData = eHidden._data;
        const pData = this.lastPotentials._data;
        const thresh = this.threshold;
        const windowSize = 1.0;
        for (let i = 0; i < eData.length; i++) {
          if (Math.abs(pData[i] - thresh) > windowSize) {
            eData[i] = 0;
          }
        }
      }
    }

    // Delta rule restricted to the looked-up rows.
    const err = eHidden._data;
    for (let b = 0; b < batch; b++) {
      const tokenId = Math.round(tokenIds[b]);
      if (tokenId < 0 || tokenId >= this.inputDim) continue;
      const kOffset = tokenId * dim;
      const errOffset = b * dim;
      for (let j = 0; j < dim; j++) {
        kernel[kOffset + j] += learningRate * err[errOffset + j];
      }
    }

    return eHidden;
  }

  /**
   * CBOW-style contrastive Hebbian update (intended to learn semantic embeddings
   * without representation collapse).
   *
   * Each valid token's row is pulled toward `positiveContext` once, scaled by
   * `marginPositive`. It is pushed away from every entry of `negativeContexts`,
   * scaled by `marginNegative`. The pull is applied even when
   * `negativeContexts` is empty; previously no update happened in that case.
   */
  public learnHebbian(
    tokens: number[] | Float32Array,
    positiveContext: Float32Array,
    negativeContexts: Float32Array[],
    learningRate: number = 0.01,
    marginPositive: number = 0.1,
    marginNegative: number = 0.05
  ): void {
    const kernel = this.getKernelOrThrow()._data;
    const dim = this.outputDim;
    // At least one pass, so the positive pull happens without negatives too.
    const passes = Math.max(1, negativeContexts.length);

    for (let n = 0; n < passes; n++) {
      const negMean: Float32Array | undefined = negativeContexts[n];
      const applyPositive = n === 0;
      for (let i = 0; i < tokens.length; i++) {
        const tokenId = Math.round(tokens[i]);
        if (tokenId < 0 || tokenId >= this.inputDim) continue;
        const offset = tokenId * dim;
        for (let j = 0; j < dim; j++) {
          const w = kernel[offset + j];
          // Pull toward the sentence context (first pass only).
          const posGradient = applyPositive ? positiveContext[j] - w : 0;
          // Push away from a random sentence context.
          const negGradient = negMean ? w - negMean[j] : 0;
          const update = (posGradient * marginPositive) - (negGradient * marginNegative);
          kernel[offset + j] += learningRate * update;
        }
      }
    }
  }

  /** Returns the embedding table, or throws a descriptive error if the layer was never built. */
  private getKernelOrThrow(): Matrix {
    const kernel = this.kernel;
    if (!kernel) {
      throw new Error("[SpikingEmbedding] 'kernel' parameter is missing; build() must run before forward() or learning.");
    }
    return kernel;
  }

  /** Serializable configuration: the base config plus this layer's hyper-parameters. */
  public getConfig(): Record<string, any> {
    return {
      ...super.getConfig(),
      inputDim: this.inputDim,
      outputDim: this.outputDim,
      beta: this.beta,
      threshold: this.threshold,
      embeddingsInitializer: this.embeddingsInitializer
    };
  }
}
|
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
import { Matrix } from "@oxide-js/core";
|
|
2
|
+
import { engine } from "@oxide-js/core";
|
|
3
|
+
import { isNativeAvailable, dotProductAddOnlyNativeWrapper } from "../native_backend.js";
|
|
4
|
+
|
|
5
|
+
/** Returns true when every element of `data` is exactly 0 or 1. */
function isBinaryData(data: ArrayLike<number>): boolean {
  for (let i = 0; i < data.length; i++) {
    const v = data[i];
    if (v !== 0 && v !== 1) return false;
  }
  return true;
}

/**
 * Plain floating-point product op(a) · op(b) (op = optional transpose). The backward
 * pass uses it when neither operand is binary. It is not recorded on the tape.
 */
function denseProduct(a: Matrix, b: Matrix, transA: boolean, transB: boolean): Matrix {
  const aStride = a._shape[1];
  const bStride = b._shape[1];
  const rows = transA ? a._shape[1] : a._shape[0];
  const inner = transA ? a._shape[0] : a._shape[1];
  const innerB = transB ? b._shape[1] : b._shape[0];
  const cols = transB ? b._shape[0] : b._shape[1];
  if (inner !== innerB) {
    throw new Error(`[dotProductAddOnly] backward dimension mismatch: [${rows}x${inner}] * [${innerB}x${cols}]`);
  }

  const aData = a._data;
  const bData = b._data;
  const result = new Float32Array(rows * cols);
  for (let i = 0; i < rows; i++) {
    for (let j = 0; j < cols; j++) {
      let sum = 0;
      for (let k = 0; k < inner; k++) {
        const aik = transA ? aData[k * aStride + i] : aData[i * aStride + k];
        const bkj = transB ? bData[j * bStride + k] : bData[k * bStride + j];
        sum += aik * bkj;
      }
      result[i * cols + j] = sum;
    }
  }
  return Matrix.fromFlat(result, [rows, cols]);
}

/**
 * Matrix product specialised for spiking networks ("add-only").
 *
 * Computes op(a) · op(b), where op(x) is x or its transpose. At least one matrix
 * must contain only 0 and 1. A 0 entry is skipped, and a 1 entry contributes the
 * other operand's value by plain addition, without multiplying.
 *
 * When an autodiff tape is active, the operation is recorded. The backward pass
 * uses this add-only product where one operand is binary. Otherwise it uses a dense
 * float product, because the upstream gradient and the weights are usually both float.
 *
 * @param a Left operand.
 * @param b Right operand.
 * @param out Optional matrix to receive the result, shape [rows(op(a)) x cols(op(b))].
 * @param transA If true, use a^T instead of a.
 * @param transB If true, use b^T instead of b.
 * @returns `out` when provided, otherwise a newly allocated matrix.
 * @throws Error on mismatched inner dimensions, a wrong `out` shape, or when neither operand is binary.
 */
export default function dotProductAddOnly(
  a: Matrix,
  b: Matrix,
  out?: Matrix,
  transA: boolean = false,
  transB: boolean = false
): Matrix {
  const aRowsOrig = a._shape[0], aColsOrig = a._shape[1];
  const bRowsOrig = b._shape[0], bColsOrig = b._shape[1];

  const aRows = transA ? aColsOrig : aRowsOrig;
  const aCols = transA ? aRowsOrig : aColsOrig;
  const bRows = transB ? bColsOrig : bRowsOrig;
  const bCols = transB ? bRowsOrig : bColsOrig;

  if (aCols !== bRows) {
    throw new Error(`Dimensi matrix tidak cocok untuk dot product: [${aRows}x${aCols}] * [${bRows}x${bCols}]`);
  }

  if (out && (out._shape[0] !== aRows || out._shape[1] !== bCols)) {
    throw new Error(`Output matrix shape mismatch: expected [${aRows}x${bCols}], got [${out._shape[0]}x${out._shape[1]}]`);
  }

  // The kernels below branch on `aIsBinary` only; b is scanned just to reject
  // the case where neither operand is binary.
  const aIsBinary = isBinaryData(a._data);
  if (!aIsBinary && !isBinaryData(b._data)) {
    throw new Error("SNN Error: Kedua matriks adalah floating-point. Setidaknya salah satu matriks harus hanya berisi 0 dan 1.");
  }

  const resultData = out ? out._data : new Float32Array(aRows * bCols);
  const aData = a._data;
  const bData = b._data;

  if (isNativeAvailable()) {
    dotProductAddOnlyNativeWrapper(
      aData,
      aRowsOrig,
      aColsOrig,
      bData,
      bRowsOrig,
      bColsOrig,
      transA,
      transB,
      resultData
    );
  } else if (!transB) {
    // op(a) · B: for every non-zero op(a)[i][k], accumulate row k of B into result row i.
    if (out) resultData.fill(0);
    for (let i = 0; i < aRows; i++) {
      const rOffset = i * bCols;
      for (let k = 0; k < aCols; k++) {
        const aik = transA ? aData[k * aRows + i] : aData[i * aCols + k];

        // A zero entry contributes nothing whichever operand is binary.
        if (aik === 0) continue;

        const kOffset = k * bCols;
        const jBound = bCols - 8;
        let j = 0;

        if (aIsBinary) {
          // aik is exactly 1 here: add row k of B unchanged (unrolled by 8).
          for (; j <= jBound; j += 8) {
            resultData[rOffset + j] += bData[kOffset + j];
            resultData[rOffset + j + 1] += bData[kOffset + j + 1];
            resultData[rOffset + j + 2] += bData[kOffset + j + 2];
            resultData[rOffset + j + 3] += bData[kOffset + j + 3];
            resultData[rOffset + j + 4] += bData[kOffset + j + 4];
            resultData[rOffset + j + 5] += bData[kOffset + j + 5];
            resultData[rOffset + j + 6] += bData[kOffset + j + 6];
            resultData[rOffset + j + 7] += bData[kOffset + j + 7];
          }
          for (; j < bCols; j++) {
            resultData[rOffset + j] += bData[kOffset + j];
          }
        } else {
          // B is binary and aik is a float: add aik wherever B[k][j] is 1.
          for (; j <= jBound; j += 8) {
            if (bData[kOffset + j] === 1) resultData[rOffset + j] += aik;
            if (bData[kOffset + j + 1] === 1) resultData[rOffset + j + 1] += aik;
            if (bData[kOffset + j + 2] === 1) resultData[rOffset + j + 2] += aik;
            if (bData[kOffset + j + 3] === 1) resultData[rOffset + j + 3] += aik;
            if (bData[kOffset + j + 4] === 1) resultData[rOffset + j + 4] += aik;
            if (bData[kOffset + j + 5] === 1) resultData[rOffset + j + 5] += aik;
            if (bData[kOffset + j + 6] === 1) resultData[rOffset + j + 6] += aik;
            if (bData[kOffset + j + 7] === 1) resultData[rOffset + j + 7] += aik;
          }
          for (; j < bCols; j++) {
            if (bData[kOffset + j] === 1) resultData[rOffset + j] += aik;
          }
        }
      }
    }
  } else {
    // op(a) · B^T: B^T[k][j] = B[j][k] = bData[j * aCols + k] (B's row length is aCols).
    for (let i = 0; i < aRows; i++) {
      const rOffset = i * bCols;
      for (let j = 0; j < bCols; j++) {
        const bRow = j * aCols;
        let sum = 0;
        let k = 0;
        const kBound = aCols - 8;

        if (aIsBinary) {
          for (; k <= kBound; k += 8) {
            const aik0 = transA ? aData[k * aRows + i] : aData[i * aCols + k];
            if (aik0 === 1) sum += bData[bRow + k];
            const aik1 = transA ? aData[(k + 1) * aRows + i] : aData[i * aCols + (k + 1)];
            if (aik1 === 1) sum += bData[bRow + k + 1];
            const aik2 = transA ? aData[(k + 2) * aRows + i] : aData[i * aCols + (k + 2)];
            if (aik2 === 1) sum += bData[bRow + k + 2];
            const aik3 = transA ? aData[(k + 3) * aRows + i] : aData[i * aCols + (k + 3)];
            if (aik3 === 1) sum += bData[bRow + k + 3];
            const aik4 = transA ? aData[(k + 4) * aRows + i] : aData[i * aCols + (k + 4)];
            if (aik4 === 1) sum += bData[bRow + k + 4];
            const aik5 = transA ? aData[(k + 5) * aRows + i] : aData[i * aCols + (k + 5)];
            if (aik5 === 1) sum += bData[bRow + k + 5];
            const aik6 = transA ? aData[(k + 6) * aRows + i] : aData[i * aCols + (k + 6)];
            if (aik6 === 1) sum += bData[bRow + k + 6];
            const aik7 = transA ? aData[(k + 7) * aRows + i] : aData[i * aCols + (k + 7)];
            if (aik7 === 1) sum += bData[bRow + k + 7];
          }
          for (; k < aCols; k++) {
            const aik = transA ? aData[k * aRows + i] : aData[i * aCols + k];
            if (aik === 1) sum += bData[bRow + k];
          }
        } else {
          // B is binary: add op(a)[i][k] wherever B[j][k] is 1.
          for (; k <= kBound; k += 8) {
            if (bData[bRow + k] === 1) sum += transA ? aData[k * aRows + i] : aData[i * aCols + k];
            if (bData[bRow + k + 1] === 1) sum += transA ? aData[(k + 1) * aRows + i] : aData[i * aCols + (k + 1)];
            if (bData[bRow + k + 2] === 1) sum += transA ? aData[(k + 2) * aRows + i] : aData[i * aCols + (k + 2)];
            if (bData[bRow + k + 3] === 1) sum += transA ? aData[(k + 3) * aRows + i] : aData[i * aCols + (k + 3)];
            if (bData[bRow + k + 4] === 1) sum += transA ? aData[(k + 4) * aRows + i] : aData[i * aCols + (k + 4)];
            if (bData[bRow + k + 5] === 1) sum += transA ? aData[(k + 5) * aRows + i] : aData[i * aCols + (k + 5)];
            if (bData[bRow + k + 6] === 1) sum += transA ? aData[(k + 6) * aRows + i] : aData[i * aCols + (k + 6)];
            if (bData[bRow + k + 7] === 1) sum += transA ? aData[(k + 7) * aRows + i] : aData[i * aCols + (k + 7)];
          }
          for (; k < aCols; k++) {
            if (bData[bRow + k] === 1) sum += transA ? aData[k * aRows + i] : aData[i * aCols + k];
          }
        }
        resultData[rOffset + j] = sum;
      }
    }
  }

  const res = out ? out : Matrix.fromFlat(resultData, [aRows, bCols]);

  // Record for autodiff when a tape is active.
  if (engine && engine.tape) {
    engine.record([a, b], [res], (grad: Matrix) => {
      // The upstream gradient is generally float. Use add-only only when the other
      // factor is binary; otherwise a dense product (add-only would throw here).
      const product = (x: Matrix, y: Matrix, tx: boolean, ty: boolean): Matrix =>
        isBinaryData(x._data) || isBinaryData(y._data)
          ? dotProductAddOnly(x, y, undefined, tx, ty)
          : denseProduct(x, y, tx, ty);

      const gA = !transA
        ? product(grad, b, false, !transB)
        : product(b, grad, transB, true);
      const gB = !transB
        ? product(a, grad, !transA, false)
        : product(grad, a, true, transA);
      return [gA, gB];
    }, { saveInput: false, saveOutput: false, requireInputStability: true });
  }

  return res;
}
|