@oxide-js/spiking 1.1.0 → 1.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +11 -0
- package/index.cjs +322 -0
- package/index.d.ts +4 -13
- package/index.js +5 -2
- package/package.json +1 -1
- package/spiking-native.linux-x64-gnu.node +0 -0
- package/src/index.ts +1 -2
- package/src/layers/SpikingDense.ts +55 -36
- package/src/layers/SpikingEmbedding.ts +123 -135
- package/src/native_backend.ts +22 -3
- package/src-rust/src/delta.rs +51 -0
- package/src-rust/src/dot_product.rs +47 -0
- package/src-rust/src/embedding.rs +28 -0
- package/src-rust/src/lib.rs +14 -460
- package/src-rust/src/lif.rs +44 -0
- package/src-rust/src/surrogate.rs +28 -0
- package/test/test_embedding_overlap.ts +181 -0
- package/examples/demo.ts +0 -101
- package/src/models/SpikingSentenceEmbedder.ts +0 -135
|
package/test/test_embedding_overlap.ts
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
import { Matrix } from "@oxide-js/core";
|
|
2
|
+
import { SpikingEmbedding } from "../src/layers/SpikingEmbedding.js";
|
|
3
|
+
import { SpikingDense } from "../src/layers/SpikingDense.js";
|
|
4
|
+
|
|
5
|
+
// Overlapping-context test for SpikingEmbedding + SpikingDense.
//
// Four three-word sentences share their verb/object words ("makan ikan",
// "tidur kasur") and differ only in the subject, so the class signal has to be
// carried by the subject token while the shared tokens pull in both directions.
// The script trains the two layers online, then prints per-sentence and
// per-word output spike counts. It makes no assertions; results are inspected
// from the console output (which is intentionally left in Indonesian).

// Vocabulary: the position in this list is the token id fed to the embedding.
//   0: Kucing (cat)   1: Manusia (human)   2: Makan (eat)
//   3: Ikan (fish)    4: Tidur (sleep)     5: Kasur (mattress)
const vocabulary = ["Kucing", "Manusia", "Makan", "Ikan", "Tidur", "Kasur"];
const vocabSize = vocabulary.length;
const embedDim = 16;
const numClasses = 2; // Class 0: Hewan (animal), class 1: Manusia (human)

// Sentence dataset (sequences of token ids) and the matching one-hot targets.
const sentences = [
  [0, 2, 3], // Kucing makan ikan   -> Hewan
  [1, 2, 3], // Manusia makan ikan  -> Manusia
  [0, 4, 5], // Kucing tidur kasur  -> Hewan
  [1, 4, 5], // Manusia tidur kasur -> Manusia
];

const targets = [
  [1, 0], // Hewan
  [0, 1], // Manusia
  [1, 0], // Hewan
  [0, 1], // Manusia
];

console.log("Inisialisasi SpikingEmbedding & SpikingDense (Tes Overlap)...");

const embedding = new SpikingEmbedding({
  inputDim: vocabSize,
  outputDim: embedDim,
  beta: 0.9,
  threshold: 1.0,
  embeddingsInitializer: "glorot_normal",
});

const outputLayer = new SpikingDense({
  units: numClasses,
  beta: 0.9,
  threshold: 1.0,
  useBias: true,
  kernelInitializer: "glorot_normal",
});

// Batch size 1, one token per call: the sentence context is fed sequentially,
// one token per timestep.
embedding.build([1, 1]);
outputLayer.build([1, embedDim]);

// Feedback-alignment matrix B: random in [-1, 1), drawn once and then fixed.
const bData = new Float32Array(numClasses * embedDim);
for (let i = 0; i < bData.length; i++) bData[i] = Math.random() * 2 - 1;
const B = Matrix.fromFlat(bData, [numClasses, embedDim]);

const epochs = 300; // More epochs are needed because the shared words pull the gradient both ways.
const learningRate = 0.005;
const sentenceTimesteps = 12; // 3 words x 4 read-through cycles
const wordTimesteps = 10; // single-word probe duration

/** Runs one timestep: token id -> embedding spikes -> output-layer spikes. */
function stepNetwork(tokenIdx: number): Matrix {
  const x = Matrix.fromFlat(new Float32Array([tokenIdx]), [1, 1]);
  const eSpikes = embedding.forward(x) as Matrix;
  return outputLayer.forward(eSpikes) as Matrix;
}

/**
 * Resets both layers, presents `tokens` round-robin for `timesteps` steps and
 * returns the per-class sum of the output values seen over those steps.
 */
function countOutputSpikes(tokens: number[], timesteps: number): Float32Array {
  embedding.resetState();
  outputLayer.resetState();

  const spikeCounts = new Float32Array(numClasses);
  for (let t = 0; t < timesteps; t++) {
    const outSpikes = stepNetwork(tokens[t % tokens.length]);
    for (let j = 0; j < numClasses; j++) {
      spikeCounts[j] += outSpikes._data[j];
    }
  }
  return spikeCounts;
}

console.log("Mulai training SNN Word-to-Class dengan overlapping context...");

for (let epoch = 0; epoch < epochs; epoch++) {
  let totalError = 0;

  for (let i = 0; i < sentences.length; i++) {
    const sentence = sentences[i];
    const y = Matrix.fromFlat(new Float32Array(targets[i]), [1, numClasses]);
    const target = y._data;

    // Latches to true once a class has produced a spike for this sentence.
    const hasSpiked: boolean[] = new Array(numClasses).fill(false);

    embedding.resetState();
    outputLayer.resetState();

    // Cycle through the sentence's words repeatedly to simulate the time the
    // network spends "reading" it.
    for (let t = 0; t < sentenceTimesteps; t++) {
      const actual = stepNetwork(sentence[t % sentence.length])._data;

      const errData = new Float32Array(numClasses);
      let stepError = 0;

      for (let j = 0; j < numClasses; j++) {
        if (actual[j] === 1) hasSpiked[j] = true;

        if (target[j] === 1) {
          // Keep pushing the correct class until its first spike, then relax.
          errData[j] = hasSpiked[j] ? 0 : 1;
        } else {
          // Penalise every spike emitted by a wrong class.
          errData[j] = 0 - actual[j];
        }
        stepError += Math.abs(errData[j]);
      }

      totalError += stepError;

      // Only update weights on timesteps that actually produced an error.
      if (stepError !== 0) {
        const errorSignal = Matrix.fromFlat(errData, [1, numClasses]);
        outputLayer.learnOutput(errorSignal, learningRate);
        embedding.learnEmbedding(errorSignal, B, learningRate);
      }
    }
  }

  if (epoch % 50 === 0 || epoch === epochs - 1) {
    console.log(`Epoch ${epoch} | Total Spiking Error: ${totalError}`);
  }
}

// Sentence-level inference: replay every training sentence and count spikes.
console.log("\n--- HASIL PENGUJIAN KALIMAT ---");
for (let i = 0; i < sentences.length; i++) {
  const sentence = sentences[i];
  const spikeCounts = countOutputSpikes(sentence, sentenceTimesteps);

  // Unknown token ids render as an empty string.
  const teks = sentence.map((tokenIdx) => vocabulary[tokenIdx] ?? "").join(" ");

  console.log(`Kalimat: "${teks}"`);
  console.log(` -> Prediksi Spike [Hewan, Manusia]: [${spikeCounts.join(", ")}] | Target Seharusnya: [${targets[i].join(", ")}]`);
}

// Single-word probe: feed one word alone for several timesteps to see which
// class the learned embedding of that word leans towards.
console.log("\n--- ANALISIS SENTIMEN KATA INDIVIDU ---");
for (let i = 0; i < vocabSize; i++) {
  const spikeCounts = countOutputSpikes([i], wordTimesteps);

  console.log(`Kata ${vocabulary[i]} (${i}) -> Pola Spike [Hewan, Manusia]: [${spikeCounts.join(", ")}]`);
}
|
package/examples/demo.ts
DELETED
|
@@ -1,101 +0,0 @@
|
|
|
1
|
-
import { SpikingNetwork } from "../src/core/SpikingNetwork.js";
|
|
2
|
-
import { STDP } from "../src/learning/STDP.js";
|
|
3
|
-
import { PoissonEncoder } from "../src/encoding/PoissonEncoder.js";
|
|
4
|
-
|
|
5
|
-
/**
 * Simple pattern-recognition demo built on a spiking neural network (SNN).
 *
 * Three input neurons feed a single output neuron (four neurons in total).
 * The input pattern [0.9, 0.1, 0.8] gives each input's signal intensity. On
 * every time step the encoder turns that pattern into spikes, each spiking
 * input receives an injected current, the network advances one step, output
 * spikes are counted and the STDP rule updates the weights. At the end the
 * final weight of every input is printed and classified relative to the
 * initial weight of 1.0.
 *
 * Console output is in Indonesian and is kept verbatim.
 */
async function runDemo() {
  console.log("=== Memulai Simulasi SNN Add-Only ===\n");

  // Configuration
  const inputCount = 3;
  const outputCount = 1;
  const neuronCount = inputCount + outputCount;
  const stepCount = 100;
  const initialWeight = 1.0;

  // 1. Network and components.
  // NOTE(review): per the original comments the last constructor argument
  // (2.0) is the firing threshold, chosen low so the output fires easily —
  // confirm against SpikingNetwork.
  const net = new SpikingNetwork(neuronCount, 0.8, 2.0);
  const stdp = new STDP(net, {
    learningRate: 0.05,
    aPlus: 1.5,
    aMinus: 0.5, // LTD lowered so frequently active (strongly correlated) inputs grow
    wMax: 10.0,
    wMin: 0.0, // weights may not go negative in this example
  });
  const encoder = new PoissonEncoder(1.0);

  // Indices 0..2 are the inputs, so the output neuron sits at index 3.
  const outputIndex = 3;

  // 2. Wire every input to the output with the same starting weight.
  console.log("Membangun koneksi awal:");
  for (let source = 0; source < inputCount; source += 1) {
    net.connect(source, outputIndex, initialWeight);
    console.log(`- Input Neuron ${source} -> Output Neuron ${outputIndex} (Bobot Awal: 1.0)`);
  }
  console.log("\n");

  // 3. Input pattern: neuron 0 strong, neuron 1 weak, neuron 2 strong.
  const inputPattern = [0.9, 0.1, 0.8];

  let totalOutputSpikes = 0;

  console.log(`Menjalankan simulasi selama ${stepCount} time-steps...`);

  for (let step = 0; step < stepCount; step += 1) {
    // A. Encode the continuous input signal into spikes.
    const spikeTrain = encoder.encodeArray(inputPattern);

    // B. Drive each spiking input with a current of 10.0, which the original
    //    author chose to be far above the 2.0 threshold so it fires this step.
    for (let source = 0; source < inputCount; source += 1) {
      if (spikeTrain[source] === 1) {
        net.injectCurrent(source, 10.0);
      }
    }

    // C. Advance the network by one step.
    net.step();

    // D. Count output spikes.
    totalOutputSpikes += net.spikes[outputIndex] === 1 ? 1 : 0;

    // E. Apply the plasticity (learning) rule.
    stdp.updateWeights();
  }

  // Labels a final weight relative to the weight every connection started with.
  const describeChange = (weight: number): string => {
    if (weight > initialWeight) return "📈 Diperkuat (LTP)";
    if (weight < initialWeight) return "📉 Diperlemah (LTD)";
    return "➖ Tetap";
  };

  // 4. Results
  console.log("\n=== Hasil Simulasi ===");
  console.log(`Total tembakan (spikes) dari Neuron Output: ${totalOutputSpikes}`);
  console.log("\nPerubahan Bobot Akhir (Setelah proses Belajar STDP):");

  inputPattern.forEach((intensity, source) => {
    // Weight at index 0 of this input neuron's connection list.
    const finalWeight = net.weights[source][0];
    const status = describeChange(finalWeight);

    console.log(`Neuron ${source} (Intensitas Sinyal: ${intensity}) -> Bobot Akhir: ${finalWeight.toFixed(4)} ${status}`);
  });

  console.log("\nKesimpulan:");
  console.log("Seperti yang terlihat, SNN secara otomatis mengenali dan memperkuat koneksi dari neuron input yang aktif (Intensitas 0.9 & 0.8),");
  console.log("sementara neuron yang jarang aktif (Intensitas 0.1) bobotnya melemah secara natural melalui aturan STDP.");
}

runDemo().catch(console.error);
|
|
package/src/models/SpikingSentenceEmbedder.ts
DELETED
|
@@ -1,135 +0,0 @@
|
|
|
1
|
-
import { Matrix } from "@oxide-js/core";
|
|
2
|
-
import { BaseModel } from "@oxide-js/models";
|
|
3
|
-
import { SpikingEmbedding } from "../layers/SpikingEmbedding.js";
|
|
4
|
-
import { SpikingDense } from "../layers/SpikingDense.js";
|
|
5
|
-
|
|
6
|
-
/** Construction options for {@link SpikingSentenceEmbedder}. */
export interface SpikingSentenceConfig {
  /** Number of token ids; passed to the embedding layer as `inputDim`. */
  vocabSize: number;
  /** Embedding width; also the unit count of the context layer. */
  embedDim: number;
  /** `beta` handed to both layers; 0.9 when omitted. */
  beta?: number;
  /** `threshold` handed to both layers; 1.0 when omitted. */
  threshold?: number;
}

/**
 * Two-layer spiking model (SpikingEmbedding followed by SpikingDense) that
 * reads a token sequence one word at a time and summarises it as a vector of
 * accumulated context-layer outputs.
 */
export class SpikingSentenceEmbedder extends BaseModel {
  public vocabSize: number;
  public embedDim: number;

  public embedding: SpikingEmbedding;
  public contextLayer: SpikingDense;

  constructor(config: SpikingSentenceConfig) {
    super();
    this.vocabSize = config.vocabSize;
    this.embedDim = config.embedDim;

    // Both layers share the same neuron parameters.
    const neuronParams = {
      beta: config.beta ?? 0.9,
      threshold: config.threshold ?? 1.0,
    };

    this.embedding = new SpikingEmbedding({
      inputDim: this.vocabSize,
      outputDim: this.embedDim,
      ...neuronParams,
      embeddingsInitializer: "glorot_normal",
    });

    this.contextLayer = new SpikingDense({
      units: this.embedDim,
      ...neuronParams,
      useBias: true,
      kernelInitializer: "glorot_normal",
    });

    this.add(this.embedding);
    this.add(this.contextLayer);
  }

  /** Resets the state of the embedding layer and then of the context layer. */
  public resetState() {
    this.embedding.resetState();
    this.contextLayer.resetState();
  }

  /**
   * Reads a whole sentence and reduces it to one vector of summed
   * context-layer outputs (a spike count per dimension).
   *
   * Layer state is reset first. Each token is then presented for 3
   * consecutive timesteps and every context-layer output is added to the
   * running total.
   *
   * @param inputs Token-id matrix of shape `[1, seq_len]`. For a rank-1 shape
   *   the first dimension is used as the sequence length. Only the first
   *   `seq_len` entries of the data are read.
   * @param optionsOrTraining Unused here; accepted so the signature matches
   *   the base class's `forward`.
   * @returns A `[1, embedDim]` matrix holding the accumulated outputs.
   */
  public forward(inputs: Matrix, optionsOrTraining?: any): Matrix {
    if (!this.isBuilt) {
      // The spiking layers always process a single word per call.
      this.build([1, 1]);
    }

    this.resetState();

    const dim = this.embedDim;
    const spikeTotals = new Float32Array(dim);

    const shape = inputs._shape;
    const seqLen = shape.length > 1 ? shape[1] : shape[0];
    const tokenIds = inputs._data;

    for (let pos = 0; pos < seqLen; pos++) {
      const tokenInput = Matrix.fromFlat(new Float32Array([tokenIds[pos]]), [1, 1]);

      // Fixed reading time of 3 timesteps per word.
      for (let tick = 0; tick < 3; tick++) {
        const wordSpikes = this.embedding.forward(tokenInput) as Matrix;
        const contextSpikes = this.contextLayer.forward(wordSpikes) as Matrix;

        const stepOutput = contextSpikes._data;
        for (let d = 0; d < dim; d++) {
          spikeTotals[d] += stepOutput[d];
        }
      }
    }

    return Matrix.fromFlat(spikeTotals, [1, dim]);
  }

  /**
   * One contrastive training step on the embedding table, described by the
   * original author as Word2Vec CBOW-style Hebbian contrastive learning.
   *
   * The embedding rows of all in-vocabulary tokens are averaged to form the
   * positive context. That mean is passed, with the caller's negative
   * contexts, to `embedding.learnHebbian`. The context layer is not involved;
   * the original author's stated intent is to avoid representation collapse.
   * When no token is in range, nothing is learned and a zero vector is
   * returned.
   *
   * @param tokens Token ids of the current sentence; each value is rounded to
   *   the nearest integer and ignored when outside `[0, vocabSize)`.
   * @param negativeContexts Mean vectors to contrast against.
   * @param learningRate Forwarded to `learnHebbian`.
   * @param marginPositive Forwarded to `learnHebbian`.
   * @param marginNegative Forwarded to `learnHebbian`.
   * @returns The mean embedding of the current sentence, computed before the
   *   update is applied, so the caller can store it in a history buffer.
   */
  public learnContrastive(
    tokens: number[] | Float32Array,
    negativeContexts: Float32Array[],
    learningRate: number = 0.01,
    marginPositive: number = 0.1,
    marginNegative: number = 0.05
  ): Float32Array {
    const kernel = this.embedding.getParameter('kernel')!._data;
    const dim = this.embedDim;

    const meanVec = new Float32Array(dim);
    let usedTokens = 0;

    for (let i = 0; i < tokens.length; i++) {
      const tokenId = Math.round(tokens[i]);
      // Written as a negated range test so NaN ids are skipped as well.
      if (!(tokenId >= 0 && tokenId < this.vocabSize)) continue;

      const rowStart = tokenId * dim;
      for (let d = 0; d < dim; d++) {
        // fround keeps the single-precision rounding of a Float32Array copy.
        meanVec[d] += Math.fround(kernel[rowStart + d]);
      }
      usedTokens += 1;
    }

    if (usedTokens === 0) {
      return meanVec;
    }

    for (let d = 0; d < dim; d++) meanVec[d] /= usedTokens;

    this.embedding.learnHebbian(
      tokens,
      meanVec,
      negativeContexts,
      learningRate,
      marginPositive,
      marginNegative
    );

    return meanVec;
  }
}
|