@dniskav/neuron 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +122 -1
- package/dist/index.d.mts +290 -1
- package/dist/index.d.ts +290 -1
- package/dist/index.js +1118 -0
- package/dist/index.mjs +1110 -0
- package/package.json +1 -1
package/dist/index.mjs
CHANGED
|
@@ -2531,6 +2531,155 @@ var DataLoader = class _DataLoader {
|
|
|
2531
2531
|
}
|
|
2532
2532
|
};
|
|
2533
2533
|
|
|
2534
|
+
// src/DatasetLoader.ts
var DatasetLoader = class _DatasetLoader {
  // ── CSV ─────────────────────────────────────────────────────────────────────
  /**
   * Parse a CSV string into a DataPair.
   *
   * - The first non-empty, non-comment row is treated as a header.
   * - A column is numeric when every cell is blank or parses with Number();
   *   numeric cells are converted with Number() (so detection and conversion
   *   agree, e.g. "0x10" → 16) and blank cells become NaN.
   * - String columns are one-hot encoded (one column → N binary columns).
   * - Empty rows and comment lines (starting with #) are skipped.
   * - Quoted fields may contain commas, doubled quotes ("") and line breaks.
   *
   * @param csv - raw CSV text
   * @param options - which columns to use as features / targets
   * @throws {Error} when there is no header plus at least one data row, a
   *   requested column is missing, or a string column is found while
   *   `encodeStrings` is false
   */
  static fromCSV(csv, options) {
    const rows = _DatasetLoader._parseCSV(csv);
    if (rows.length < 2) throw new Error("DatasetLoader.fromCSV: CSV must have a header row and at least one data row.");
    const [header, ...dataRows] = rows;
    return _DatasetLoader._buildDataPair(header, dataRows, options);
  }

  // ── JSON ─────────────────────────────────────────────────────────────────────
  /**
   * Parse a JSON string (array of objects) into a DataPair.
   *
   * Expected format:
   *   [{ "col1": 1.0, "col2": "cat", "label": "dog" }, ...]
   *
   * The header is taken from the keys of the first record. Every value is
   * stringified (missing / null → "") and then run through the same pipeline
   * as fromCSV.
   *
   * @param json - raw JSON text or a pre-parsed array of objects
   * @param options - which columns to use as features / targets
   * @throws {Error} when the input is not a non-empty array
   */
  static fromJSON(json, options) {
    const records = typeof json === "string" ? JSON.parse(json) : json;
    if (!Array.isArray(records) || records.length === 0) {
      throw new Error("DatasetLoader.fromJSON: expected a non-empty JSON array of objects.");
    }
    const header = Object.keys(records[0]);
    const dataRows = records.map((row) => header.map((col) => String(row[col] ?? "")));
    return _DatasetLoader._buildDataPair(header, dataRows, options);
  }

  // ── Private: shared pipeline ──────────────────────────────────────────────
  // Turns string rows into numeric input / target matrices, one-hot encoding
  // non-numeric columns.
  static _buildDataPair(header, dataRows, options) {
    const { featureCols, targetCols, encodeStrings = true } = options;

    // Column name → index of its FIRST occurrence (same as header.indexOf),
    // computed once instead of per row.
    const colIndex = new Map();
    header.forEach((name, i) => {
      if (!colIndex.has(name)) colIndex.set(name, i);
    });
    for (const col of [...featureCols, ...targetCols]) {
      if (!colIndex.has(col)) {
        throw new Error(`DatasetLoader: column "${col}" not found in header [${header.join(", ")}].`);
      }
    }

    // Categorical columns. Maps (not plain objects) so header names such as
    // "constructor" or "__proto__" cannot collide with Object.prototype.
    //   categories:   col → sorted distinct values (defines one-hot order)
    //   indexByValue: col → Map(value → one-hot position)
    const categories = new Map();
    const indexByValue = new Map();
    const buildEncoder = (cols) => {
      for (const col of cols) {
        const colIdx = colIndex.get(col);
        const values = dataRows.map((row) => row[colIdx]);
        const isNumeric = values.every((v) => v === "" || !Number.isNaN(Number(v)));
        if (isNumeric) continue;
        if (!encodeStrings) {
          throw new Error(`DatasetLoader: column "${col}" contains non-numeric values. Set encodeStrings: true to one-hot encode them.`);
        }
        const unique = [...new Set(values)].sort();
        categories.set(col, unique);
        indexByValue.set(col, new Map(unique.map((v, i) => [v, i])));
      }
    };
    buildEncoder(featureCols);
    buildEncoder(targetCols);

    const encodeValue = (col, raw) => {
      const lookup = indexByValue.get(col);
      if (lookup === void 0) {
        // Blank cells are missing values; everything else uses Number(), the
        // same parser that classified the column as numeric.
        return [raw.trim() === "" ? NaN : Number(raw)];
      }
      const vec = new Array(lookup.size).fill(0);
      const idx = lookup.get(raw);
      if (idx !== void 0) vec[idx] = 1;
      return vec;
    };

    // Names come from the same sorted list that defines one-hot positions, so
    // they stay aligned even for integer-like categories ("10" vs "2"), which
    // Object.keys would enumerate in numeric order.
    const expandNames = (cols) => cols.flatMap((col) => {
      const cats = categories.get(col);
      return cats === void 0 ? [col] : cats.map((cat) => `${col}_${cat}`);
    });
    const featureNames = expandNames(featureCols);
    const targetNames = expandNames(targetCols);

    const featureIdx = featureCols.map((col) => colIndex.get(col));
    const targetIdx = targetCols.map((col) => colIndex.get(col));
    const encodeRow = (row, cols, idxs) => cols.flatMap((col, k) => encodeValue(col, row[idxs[k]]));
    const inputs = dataRows.map((row) => encodeRow(row, featureCols, featureIdx));
    const targets = dataRows.map((row) => encodeRow(row, targetCols, targetIdx));

    // Public shape is unchanged: { col: { value: position } }. fromEntries
    // defines own properties, so even a "__proto__" key is stored safely.
    const categoricalMaps = Object.fromEntries(
      [...categories].map(([col, unique]) => [col, Object.fromEntries(unique.map((v, i) => [v, i]))])
    );

    return {
      inputs,
      targets,
      categoricalMaps,
      featureNames,
      targetNames,
      numRows: dataRows.length
    };
  }

  // ── Private: RFC 4180-compatible CSV parser ───────────────────────────────
  // Groups physical lines into records: a record ends only when it contains an
  // even number of '"' characters, i.e. no quoted field is still open. Blank
  // and '#' lines are skipped only BETWEEN records (inside an open quoted
  // field they are content). An unterminated quote at EOF is parsed leniently.
  static _parseCSV(csv) {
    const rows = [];
    let pending = null;
    let quoteCount = 0;
    for (const line of csv.split(/\r?\n/)) {
      if (pending === null) {
        const trimmed = line.trim();
        if (!trimmed || trimmed.startsWith("#")) continue;
        pending = line;
        quoteCount = 0;
      } else {
        pending += "\n" + line;
      }
      quoteCount += line.split('"').length - 1;
      if (quoteCount % 2 === 0) {
        rows.push(_DatasetLoader._parseCSVRow(pending.trim()));
        pending = null;
      }
    }
    if (pending !== null) rows.push(_DatasetLoader._parseCSVRow(pending.trim()));
    return rows;
  }

  // Splits one logical record into trimmed fields. Commas inside quotes are
  // literal and "" inside quotes is an escaped quote.
  static _parseCSVRow(line) {
    const fields = [];
    let current = "";
    let inQuotes = false;
    for (let i = 0; i < line.length; i++) {
      const ch = line[i];
      if (inQuotes) {
        if (ch === '"' && line[i + 1] === '"') {
          current += '"';
          i++;
        } else if (ch === '"') {
          inQuotes = false;
        } else {
          current += ch;
        }
      } else if (ch === '"') {
        inQuotes = true;
      } else if (ch === ",") {
        fields.push(current.trim());
        current = "";
      } else {
        current += ch;
      }
    }
    fields.push(current.trim());
    return fields;
  }
};
|
|
2682
|
+
|
|
2534
2683
|
// src/LRScheduler.ts
|
|
2535
2684
|
var LRScheduler = class {
|
|
2536
2685
|
// ── Step Decay ────────────────────────────────────────────────────────────
|
|
@@ -4736,6 +4885,749 @@ var TCN = class {
|
|
|
4736
4885
|
}
|
|
4737
4886
|
};
|
|
4738
4887
|
|
|
4888
|
+
// src/Word2Vec.ts
var Word2Vec = class {
  /**
   * @param {number} [embeddingDim=50] - length of each word vector
   * @param {object} [options]
   * @param {number} [options.windowSize=2] - context words taken on each side
   * @param {"skipgram"|string} [options.model="skipgram"] - "skipgram";
   *   any other value selects the CBOW update
   * @param {number} [options.minCount=1] - words rarer than this are dropped
   */
  constructor(embeddingDim = 50, options = {}) {
    this._trained = false;
    this.embeddingDim = embeddingDim;
    this._windowSize = options.windowSize ?? 2;
    this._model = options.model ?? "skipgram";
    this._minCount = options.minCount ?? 1;
    this.embeddings = [];
    this._W2 = [];
    this.vocab = /* @__PURE__ */ new Map();
    this._indexToWord = [];
    this.vocabSize = 0;
  }

  // ── buildVocab ─────────────────────────────────────────────────────────────
  // Scans the corpus, counts word frequencies, discards rare words (< minCount),
  // and assigns each remaining word a unique integer index (first-seen order).
  // Re-initialises both weight matrices uniformly at random.
  // Throws if no word survives the minCount filter.
  buildVocab(sentences) {
    const freq = /* @__PURE__ */ new Map();
    for (const sentence of sentences) {
      for (const word of sentence) {
        freq.set(word, (freq.get(word) ?? 0) + 1);
      }
    }
    this.vocab = /* @__PURE__ */ new Map();
    this._indexToWord = [];
    for (const [word, count] of freq) {
      if (count >= this._minCount) {
        const idx = this._indexToWord.length;
        this.vocab.set(word, idx);
        this._indexToWord.push(word);
      }
    }
    this.vocabSize = this._indexToWord.length;
    if (this.vocabSize === 0) {
      throw new Error("Word2Vec.buildVocab: vocabulary is empty after applying minCount filter");
    }
    const scale1 = Math.sqrt(1 / this.embeddingDim);
    const scale2 = Math.sqrt(1 / this.vocabSize);
    // embeddings: [vocabSize][embeddingDim] (input → hidden)
    this.embeddings = Array.from(
      { length: this.vocabSize },
      () => Array.from({ length: this.embeddingDim }, () => (Math.random() * 2 - 1) * scale1)
    );
    // _W2: [embeddingDim][vocabSize] (hidden → output scores)
    this._W2 = Array.from(
      { length: this.embeddingDim },
      () => Array.from({ length: this.vocabSize }, () => (Math.random() * 2 - 1) * scale2)
    );
    this._trained = false;
  }

  // ── tokenize ───────────────────────────────────────────────────────────────
  // Simple tokenizer: lowercase, replace punctuation/symbols with spaces, split
  // on whitespace. Letters, combining marks and digits from any script are kept
  // (so "café" stays "café"), as are apostrophes and hyphens ("it's", "x-ray").
  // Returns an array of tokens suitable for buildVocab / train.
  static tokenize(text) {
    return text.toLowerCase().replace(/[^\p{L}\p{M}\p{N}\s'-]/gu, " ").split(/\s+/).filter((t) => t.length > 0);
  }

  // ── train ──────────────────────────────────────────────────────────────────
  // Runs SGD over all (center, context) pairs in the corpus for `epochs` passes.
  // Builds the vocabulary first if none exists yet. Words not in the vocabulary
  // are dropped before windows are formed.
  // Returns the average cross-entropy loss per epoch.
  //
  // Note: uses full-vocabulary softmax (not negative sampling) for educational
  // clarity. This is O(vocabSize) per step — for large vocabularies you would
  // normally switch to negative sampling or hierarchical softmax.
  train(sentences, lr = 0.025, epochs = 5) {
    if (this.vocabSize === 0) this.buildVocab(sentences);
    const lossHistory = [];
    for (let epoch = 0; epoch < epochs; epoch++) {
      let totalLoss = 0;
      let nPairs = 0;
      for (const sentence of sentences) {
        const indices = sentence.map((w) => this.vocab.get(w)).filter((idx) => idx !== void 0);
        for (let t = 0; t < indices.length; t++) {
          const centerIdx = indices[t];
          const contextIndices = [];
          for (let offset = -this._windowSize; offset <= this._windowSize; offset++) {
            if (offset === 0) continue;
            const pos = t + offset;
            if (pos >= 0 && pos < indices.length) {
              contextIndices.push(indices[pos]);
            }
          }
          if (contextIndices.length === 0) continue;
          if (this._model === "skipgram") {
            for (const contextIdx of contextIndices) {
              totalLoss += this._skipgramStep(centerIdx, contextIdx, lr);
              nPairs++;
            }
          } else {
            totalLoss += this._cbowStep(centerIdx, contextIndices, lr);
            nPairs++;
          }
        }
      }
      lossHistory.push(nPairs > 0 ? totalLoss / nPairs : 0);
    }
    this._trained = true;
    return lossHistory;
  }

  // ── getEmbedding ───────────────────────────────────────────────────────────
  // Returns the learned embedding vector for a word (the live row, not a copy).
  // Throws if unknown.
  getEmbedding(word) {
    const idx = this.vocab.get(word);
    if (idx === void 0) throw new Error(`Word2Vec: unknown word "${word}"`);
    return this.embeddings[idx];
  }

  // ── similarity ─────────────────────────────────────────────────────────────
  // Cosine similarity between two words.
  //   cos(v1, v2) = (v1 · v2) / (‖v1‖ · ‖v2‖)
  // Returns a value in [-1, 1] (0 if either vector is ~zero).
  similarity(word1, word2) {
    const v1 = this.getEmbedding(word1);
    const v2 = this.getEmbedding(word2);
    return this._cosine(v1, v2);
  }

  // ── mostSimilar ────────────────────────────────────────────────────────────
  // Returns the topK words (excluding `word` itself) sorted by cosine similarity.
  mostSimilar(word, topK = 10) {
    const v = this.getEmbedding(word);
    return this._nearestByVector(v, topK, /* @__PURE__ */ new Set([word]));
  }

  // ── analogy ───────────────────────────────────────────────────────────────
  // Vector arithmetic analogy: positive1 - negative + positive2 ≈ result
  //
  //   analogy('king', 'man', 'woman') finds the words closest to
  //   vec('king') - vec('man') + vec('woman')   (ideally 'queen')
  //
  // The three input words are excluded from the results so they don't
  // pollute the top-K.
  analogy(positive1, negative, positive2, topK = 5) {
    const vPos1 = this.getEmbedding(positive1);
    const vNeg = this.getEmbedding(negative);
    const vPos2 = this.getEmbedding(positive2);
    const target = vPos1.map((v, i) => v - vNeg[i] + vPos2[i]);
    const exclude = /* @__PURE__ */ new Set([positive1, negative, positive2]);
    return this._nearestByVector(target, topK, exclude);
  }

  // ── Private: skip-gram step ───────────────────────────────────────────────
  // Forward + backward for one (center, target) pair.
  // Returns the cross-entropy loss for this pair.
  _skipgramStep(centerIdx, targetIdx, lr) {
    const h = this.embeddings[centerIdx];
    const probs = _softmax(this._hiddenToScores(h));
    const loss = -Math.log(probs[targetIdx] + 1e-12);
    const err = probs.map((p, j) => j === targetIdx ? p - 1 : p);
    // Must run before the embedding row (aliased by h) is modified.
    const dh = this._backpropOutput(h, err, lr);
    for (let d = 0; d < this.embeddingDim; d++) {
      this.embeddings[centerIdx][d] -= lr * dh[d];
    }
    return loss;
  }

  // ── Private: CBOW step ────────────────────────────────────────────────────
  // Forward + backward for one (contextIndices → centerIdx) pair.
  // h is the mean of all context embeddings. The gradient is distributed
  // equally back to each context word's embedding row.
  _cbowStep(centerIdx, contextIndices, lr) {
    const k = contextIndices.length;
    const h = new Array(this.embeddingDim).fill(0);
    for (const ci of contextIndices) {
      for (let d = 0; d < this.embeddingDim; d++) {
        h[d] += this.embeddings[ci][d];
      }
    }
    for (let d = 0; d < this.embeddingDim; d++) h[d] /= k;
    const probs = _softmax(this._hiddenToScores(h));
    const loss = -Math.log(probs[centerIdx] + 1e-12);
    const err = probs.map((p, j) => j === centerIdx ? p - 1 : p);
    const dh = this._backpropOutput(h, err, lr);
    for (const ci of contextIndices) {
      for (let d = 0; d < this.embeddingDim; d++) {
        this.embeddings[ci][d] -= lr * dh[d] / k;
      }
    }
    return loss;
  }

  // Applies the SGD update to W2 for hidden vector h and output error `err`,
  // and returns dL/dh. dL/dh = W2 · err must use the weights from the FORWARD
  // pass, so each weight is read before it is updated.
  _backpropOutput(h, err, lr) {
    const dh = new Array(this.embeddingDim).fill(0);
    for (let d = 0; d < this.embeddingDim; d++) {
      const row = this._W2[d];
      const hd = h[d];
      let acc = 0;
      for (let j = 0; j < this.vocabSize; j++) {
        acc += row[j] * err[j];
        row[j] -= lr * hd * err[j];
      }
      dh[d] = acc;
    }
    return dh;
  }

  // Computes scores = h · W2 → [vocabSize]
  _hiddenToScores(h) {
    const scores = new Array(this.vocabSize).fill(0);
    for (let d = 0; d < this.embeddingDim; d++) {
      for (let j = 0; j < this.vocabSize; j++) {
        scores[j] += h[d] * this._W2[d][j];
      }
    }
    return scores;
  }

  // Returns topK words (from all embeddings) sorted by cosine similarity to v,
  // skipping any word in the exclude set.
  _nearestByVector(v, topK, exclude) {
    const results = [];
    for (let i = 0; i < this.vocabSize; i++) {
      const w = this._indexToWord[i];
      if (exclude.has(w)) continue;
      results.push({ word: w, score: this._cosine(v, this.embeddings[i]) });
    }
    results.sort((a, b) => b.score - a.score);
    return results.slice(0, topK);
  }

  // Cosine similarity: (v1 · v2) / (‖v1‖ · ‖v2‖); 0 when a norm is ~zero.
  _cosine(v1, v2) {
    let dot = 0, n1 = 0, n2 = 0;
    for (let i = 0; i < v1.length; i++) {
      dot += v1[i] * v2[i];
      n1 += v1[i] * v1[i];
      n2 += v2[i] * v2[i];
    }
    const denom = Math.sqrt(n1) * Math.sqrt(n2);
    return denom < 1e-12 ? 0 : dot / denom;
  }
};

// Numerically stable softmax (subtracts the max score before exponentiating).
// The max is found with a loop: Math.max(...scores) spreads every score as a
// call argument and throws RangeError once the vocabulary is large.
function _softmax(scores) {
  let max = -Infinity;
  for (const s of scores) {
    if (s > max) max = s;
  }
  const exps = scores.map((s) => Math.exp(s - max));
  let sum = 0;
  for (const e of exps) sum += e;
  return exps.map((e) => e / sum);
}
|
|
5114
|
+
|
|
5115
|
+
// src/TSNE.ts
var TSNE = class {
  /**
   * @param {object} [options]
   * @param {number} [options.nComponents=2] - output dimensionality
   * @param {number} [options.perplexity=30] - target perplexity; must be < n at fit time
   * @param {number} [options.lr=200] - gradient-descent step size
   * @param {number} [options.nIter=1000] - number of optimisation iterations
   * @param {number} [options.seed] - seed for the mulberry32 PRNG used to
   *   initialise the embedding; Math.random is used when omitted
   */
  constructor(options = {}) {
    // KL divergence tracked during the last fit() call.
    this._klDivergence = 0;
    // P matrix stored for kl() reporting.
    this._P = [];
    this._nComponents = options.nComponents ?? 2;
    this._perplexity = options.perplexity ?? 30;
    this._lr = options.lr ?? 200;
    this._nIter = options.nIter ?? 1e3;
    this._seed = options.seed;
    this.embedding = [];
  }

  // ── fit ────────────────────────────────────────────────────────────────────
  // Runs the full t-SNE algorithm on X (shape [n][d]).
  // Stores the result in this.embedding ([n][nComponents]).
  // Throws if n < 2 or perplexity >= n.
  fit(X) {
    const n = X.length;
    if (n < 2) throw new Error("TSNE.fit: need at least 2 data points");
    if (this._perplexity >= n) {
      throw new Error(
        `TSNE.fit: perplexity (${this._perplexity}) must be less than n (${n})`
      );
    }
    const rng = this._seed !== void 0 ? _mulberry32(this._seed) : Math.random;
    const distSq = _pairwiseDistSq(X, n);
    const Pcond = this._computePcond(distSq, n);
    const P = _symmetrize(Pcond, n);
    this._P = P;
    // Initialise Y with small Gaussian noise (Box-Muller, σ = 0.01).
    let Y = Array.from({ length: n }, () => {
      return Array.from({ length: this._nComponents }, () => {
        const u1 = Math.max(rng(), 1e-12);
        const u2 = rng();
        const z = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
        return z * 0.01;
      });
    });
    let Yprev = Y.map((row) => [...row]);
    // Early exaggeration multiplies P for the first iterations; momentum is
    // raised after MOMENTUM_SWITCH iterations.
    const EXAGGERATION_ITERS = 50;
    const EXAGGERATION_FACTOR = 4;
    const MOMENTUM_SWITCH = 20;
    for (let iter = 0; iter < this._nIter; iter++) {
      const momentum = iter < MOMENTUM_SWITCH ? 0.5 : 0.8;
      const pScale = iter < EXAGGERATION_ITERS ? EXAGGERATION_FACTOR : 1;
      const { Q, invDist } = _computeQ(Y, n, this._nComponents);
      // dC/dy_i = 4 Σ_j (p_ij − q_ij)(y_i − y_j)(1 + ‖y_i − y_j‖²)⁻¹
      const grad = Array.from(
        { length: n },
        () => new Array(this._nComponents).fill(0)
      );
      for (let i = 0; i < n; i++) {
        for (let j = 0; j < n; j++) {
          if (i === j) continue;
          const pq = pScale * P[i][j] - Q[i][j];
          const c = 4 * pq * invDist[i][j];
          for (let d = 0; d < this._nComponents; d++) {
            grad[i][d] += c * (Y[i][d] - Y[j][d]);
          }
        }
      }
      const Ynext = Array.from(
        { length: n },
        (_, i) => Array.from(
          { length: this._nComponents },
          (_2, d) => Y[i][d] - this._lr * grad[i][d] + momentum * (Y[i][d] - Yprev[i][d])
        )
      );
      Yprev = Y;
      Y = Ynext;
    }
    this.embedding = Y;
    // Final KL(P ‖ Q) with the un-exaggerated P.
    const { Q: Qfinal } = _computeQ(Y, n, this._nComponents);
    let kl = 0;
    for (let i = 0; i < n; i++) {
      for (let j = 0; j < n; j++) {
        if (i === j) continue;
        const p = P[i][j];
        if (p > 1e-12) {
          kl += p * Math.log(p / (Qfinal[i][j] + 1e-12));
        }
      }
    }
    this._klDivergence = kl;
  }

  // ── fitTransform ───────────────────────────────────────────────────────────
  // Convenience: fit() then return this.embedding.
  fitTransform(X) {
    this.fit(X);
    return this.embedding;
  }

  // ── kl ─────────────────────────────────────────────────────────────────────
  // Returns the KL divergence KL(P ‖ Q) from the last fit() call.
  // Lower is better. Useful for comparing perplexity settings or iteration counts.
  kl() {
    return this._klDivergence;
  }

  // ── Private: binary search for σi ─────────────────────────────────────────
  // For each point i, find σi such that the Shannon entropy of P(·|i) equals
  // log₂(perplexity). We use binary search on σ² (at most 50 steps, tolerance
  // 1e-5 bits).
  //
  // Distances are shifted by the nearest-neighbour distance before
  // exponentiating. After normalisation P(·|i) is unchanged, but the nearest
  // neighbour always contributes exp(0) = 1, so the sum can no longer
  // underflow to 0 when raw distances are large relative to σ². Without the
  // shift such a row was abandoned with all-zero probabilities.
  _computePcond(distSq, n) {
    const targetEntropy = Math.log2(this._perplexity);
    const Pcond = Array.from({ length: n }, () => new Array(n).fill(0));
    for (let i = 0; i < n; i++) {
      const dists = distSq[i];
      let minDist = Infinity;
      for (let j = 0; j < n; j++) {
        if (j !== i && dists[j] < minDist) minDist = dists[j];
      }
      let sigmaLo = 0;
      let sigmaHi = 1e10;
      let sigma2 = 1;
      for (let attempt = 0; attempt < 50; attempt++) {
        let sumExp = 0;
        const exps = new Array(n).fill(0);
        for (let j = 0; j < n; j++) {
          if (j === i) continue;
          const e = Math.exp(-(dists[j] - minDist) / (2 * sigma2));
          exps[j] = e;
          sumExp += e;
        }
        // Defensive only: unreachable for finite distances after the shift.
        if (sumExp < 1e-12) break;
        let H = 0;
        for (let j = 0; j < n; j++) {
          if (j === i) continue;
          const p = exps[j] / sumExp;
          Pcond[i][j] = p;
          if (p > 1e-12) H -= p * Math.log2(p);
        }
        const delta = H - targetEntropy;
        if (Math.abs(delta) < 1e-5) break;
        if (delta > 0) {
          // Entropy too high → distribution too flat → shrink σ².
          sigmaHi = sigma2;
          sigma2 = (sigmaLo + sigma2) / 2;
        } else {
          // Entropy too low → grow σ² (double until an upper bound is known).
          sigmaLo = sigma2;
          sigma2 = sigmaHi < 1e9 ? (sigma2 + sigmaHi) / 2 : sigma2 * 2;
        }
      }
    }
    return Pcond;
  }
};
|
|
5253
|
+
// Squared Euclidean distance between every pair of rows of X.
// Returns a symmetric [n][n] matrix with zeros on the diagonal; each pair is
// computed once and mirrored.
function _pairwiseDistSq(X, n) {
  const D = [];
  for (let r = 0; r < n; r++) D.push(new Array(n).fill(0));
  for (let a = 0; a < n; a++) {
    const xa = X[a];
    for (let b = a + 1; b < n; b++) {
      const xb = X[b];
      let sum = 0;
      for (let k = 0; k < xa.length; k++) {
        const delta = xa[k] - xb[k];
        sum += delta * delta;
      }
      D[a][b] = sum;
      D[b][a] = sum;
    }
  }
  return D;
}
|
|
5268
|
+
// Joint t-SNE affinities from conditional ones:
//   p_ij = (p_{j|i} + p_{i|j}) / (2n)
function _symmetrize(Pcond, n) {
  return Array.from({ length: n }, (_, i) =>
    Array.from({ length: n }, (_2, j) => (Pcond[i][j] + Pcond[j][i]) / (2 * n))
  );
}
|
|
5277
|
+
// Student-t (1 degree of freedom) affinities in the low-dimensional space.
// Returns:
//   invDist[i][j] = (1 + ‖y_i − y_j‖²)⁻¹   (0 on the diagonal)
//   Q[i][j]       = invDist[i][j] / Σ_{k≠l} invDist[k][l]
// The normaliser is clamped to at least 1e-12 to avoid dividing by zero.
function _computeQ(Y, n, nComponents) {
  const invDist = Array.from({ length: n }, () => new Array(n).fill(0));
  let total = 0;
  for (let a = 0; a < n; a++) {
    for (let b = a + 1; b < n; b++) {
      let sq = 0;
      for (let d = 0; d < nComponents; d++) {
        const diff = Y[a][d] - Y[b][d];
        sq += diff * diff;
      }
      const kernel = 1 / (1 + sq);
      invDist[a][b] = kernel;
      invDist[b][a] = kernel;
      total += 2 * kernel;
    }
  }
  const Z = total < 1e-12 ? 1e-12 : total;
  const Q = invDist.map((row) => row.map((v) => v / Z));
  return { Q, invDist };
}
|
|
5300
|
+
// mulberry32: tiny seedable 32-bit PRNG. Returns a generator producing
// floats in [0, 1); the same seed always yields the same sequence.
function _mulberry32(seed) {
  let state = seed >>> 0;
  const next = () => {
    state = (state + 0x6d2b79f5) >>> 0;
    let t = Math.imul(state ^ (state >>> 15), state | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
  return next;
}
|
|
5311
|
+
|
|
5312
|
+
// src/PositionalEncoding.ts
var PositionalEncoding = class _PositionalEncoding {
  // Compute the full PE vector for one token at position `pos`.
  // Returns an array of length `dModel`.
  //
  // Each pair of dimensions (2i, 2i+1) shares the same frequency 1/10000^(2i/dModel)
  // but is 90° out of phase (sin vs cos). When dModel is odd, the final
  // dimension gets a lone sine term.
  static encode(pos, dModel) {
    const pe = new Array(dModel);
    for (let i = 0; i < Math.floor(dModel / 2); i++) {
      const freq = Math.pow(1e4, 2 * i / dModel);
      pe[2 * i] = Math.sin(pos / freq);
      pe[2 * i + 1] = Math.cos(pos / freq);
    }
    if (dModel % 2 !== 0) {
      const i = Math.floor(dModel / 2);
      const freq = Math.pow(1e4, 2 * i / dModel);
      pe[dModel - 1] = Math.sin(pos / freq);
    }
    return pe;
  }

  // Build the full positional encoding matrix for a sequence of `seqLen` tokens.
  // Returns shape [seqLen][dModel].
  //
  // The matrix depends only on seqLen and dModel, so callers can compute it
  // once and reuse it.
  static encodeSequence(seqLen, dModel) {
    return Array.from(
      { length: seqLen },
      (_, pos) => _PositionalEncoding.encode(pos, dModel)
    );
  }

  // Add positional encoding to an existing embedding matrix. Returns a new
  // matrix; the input is not modified.
  //
  // `embeddings` shape: [seqLen][dModel]; dModel is taken from the first row.
  // An empty matrix returns [].
  //
  // `seqLen` is accepted for API compatibility but cannot change the result:
  // each row's encoding depends only on its own position and dModel. Exactly
  // one encoding per row is generated, so a `seqLen` smaller than
  // embeddings.length no longer indexes past the end of the table.
  //
  // The sum e = token_embedding + PE is what actually enters the first
  // Transformer layer. Summing (rather than concatenating) keeps the model
  // dimension fixed.
  static apply(embeddings, seqLen) {
    if (embeddings.length === 0) return [];
    const dModel = embeddings[0].length;
    const pe = _PositionalEncoding.encodeSequence(embeddings.length, dModel);
    return embeddings.map(
      (emb, pos) => emb.map((val, d) => val + pe[pos][d])
    );
  }
};
|
|
5363
|
+
var LearnedPositionalEncoding = class {
  // Trainable position table `weights` of shape [maxSeqLen][dModel],
  // initialised uniformly in [-sqrt(1/dModel), sqrt(1/dModel)).
  constructor(maxSeqLen, dModel) {
    this.maxSeqLen = maxSeqLen;
    this.dModel = dModel;
    const limit = Math.sqrt(1 / dModel);
    this.weights = Array.from(
      { length: maxSeqLen },
      () => Array.from({ length: dModel }, () => (Math.random() * 2 - 1) * limit)
    );
  }

  // Return the learned encoding for one position.
  // Returns a copy so callers cannot accidentally mutate the weight table.
  // Throws if `pos` is >= maxSeqLen, negative, or not an integer. Invalid
  // indices used to fail with an opaque TypeError from spreading `undefined`.
  getEncoding(pos) {
    if (pos >= this.maxSeqLen) {
      throw new Error(
        `Position ${pos} exceeds maxSeqLen=${this.maxSeqLen}. Learned encodings cannot generalize beyond their training length.`
      );
    }
    if (!Number.isInteger(pos) || pos < 0) {
      throw new Error(
        `Position ${pos} is not a valid index; expected an integer in [0, ${this.maxSeqLen}).`
      );
    }
    return [...this.weights[pos]];
  }

  // Add learned positional encodings to `embeddings` (returns a new matrix).
  // Shape: [seqLen][dModel] → [seqLen][dModel].
  //
  // Both the explicit `seqLen` and the actual number of rows are checked
  // against maxSeqLen. Previously only `seqLen` was checked, so passing a small
  // seqLen with too many rows crashed on a missing weight row.
  apply(embeddings, seqLen) {
    const len = Math.max(seqLen ?? embeddings.length, embeddings.length);
    if (len > this.maxSeqLen) {
      throw new Error(
        `Sequence length ${len} exceeds maxSeqLen=${this.maxSeqLen}.`
      );
    }
    return embeddings.map(
      (emb, pos) => emb.map((val, d) => val + this.weights[pos][d])
    );
  }

  // Apply an update to the position encoding weights:
  //   weights[pos][d] += lr * dWeights[pos][d]   for every pos < maxSeqLen.
  //
  // `dWeights` must have the same shape as `weights`: [maxSeqLen][dModel].
  //
  // NOTE(review): the step ADDS lr * dWeights. If dWeights is the loss
  // gradient dL/dW, this is gradient ascent. Callers presumably pass the
  // descent direction (−dL/dW); confirm against the convention used by
  // EmbeddingMatrix in MatMul.ts before changing the sign.
  update(dWeights, lr) {
    for (let pos = 0; pos < this.maxSeqLen; pos++) {
      for (let d = 0; d < this.dModel; d++) {
        this.weights[pos][d] += lr * dWeights[pos][d];
      }
    }
  }
};
|
|
5412
|
+
|
|
5413
|
+
// src/ContrastiveLearning.ts
var Augmenter = class _Augmenter {
  /**
   * Perturb every feature with zero-mean Gaussian noise of std-dev `sigma`.
   *
   * Normal samples come from the Box-Muller transform:
   *   z = √(-2·ln(u₁)) · cos(2π·u₂),  u₁, u₂ ~ Uniform(0, 1)
   * with u₁ clamped to ≥ 1e-10 so ln(u₁) stays finite. Two Math.random()
   * draws are consumed per element, u₁ first.
   *
   * @param {number[]} x
   * @param {number} [sigma=0.05]
   * @returns {number[]} a new array
   */
  static addNoise(x, sigma = 0.05) {
    const gaussian = () => {
      const u1 = Math.max(1e-10, Math.random());
      const u2 = Math.random();
      return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
    };
    return x.map((value) => value + sigma * gaussian());
  }

  /**
   * Replace each feature with 0 with probability `rate` (one Math.random()
   * draw per element), forcing the encoder to cope with missing features.
   *
   * @param {number[]} x
   * @param {number} [rate=0.1]
   * @returns {number[]} a new array
   */
  static dropoutFeatures(x, rate = 0.1) {
    return x.map((value) => {
      const dropped = Math.random() < rate;
      return dropped ? 0 : value;
    });
  }

  /**
   * Noise first, then feature dropout — a combined augmentation in the spirit
   * of SimCLR's stacked transforms.
   *
   * @param {number[]} x
   * @param {number} [noiseStd=0.05]
   * @param {number} [dropRate=0.1]
   * @returns {number[]}
   */
  static augment(x, noiseStd = 0.05, dropRate = 0.1) {
    const noisy = _Augmenter.addNoise(x, noiseStd);
    return _Augmenter.dropoutFeatures(noisy, dropRate);
  }

  /**
   * Build a positive pair for NT-Xent: the original `x` (same reference) and
   * an augmented copy made with the default augmentation settings.
   *
   * @param {number[]} x
   * @returns {[number[], number[]]}
   */
  static makePair(x) {
    const view = _Augmenter.augment(x);
    return [x, view];
  }
};
|
|
5453
|
+
var ContrastiveLearning = class _ContrastiveLearning {
  // SimCLR-style contrastive learner: an encoder network followed by a
  // projection head, trained with the NT-Xent loss.
  //
  // encoderHidden: hidden layer sizes for the encoder (not counting input/output).
  //   e.g. inputSize=64, encoderHidden=[256, 128] → NetworkN([64, 256, 128])
  //   The encoder output dimension is encoderHidden[last].
  //
  // projectionDim: dimension of the projection head output (the z space).
  //   e.g. 64. Typically smaller than the encoder's output.
  //
  // options.temperature: NT-Xent temperature τ, must be > 0. Default 0.5.
  // options.encoderOptions: extra options spread into the encoder's NetworkN config.
  //
  // Throws if encoderHidden is empty or the temperature is not a positive number.
  //
  // The encoder uses ReLU activations throughout — empirically stronger than
  // sigmoid for representation learning because it doesn't saturate.
  constructor(inputSize, encoderHidden, projectionDim, options = {}) {
    if (encoderHidden.length === 0) {
      throw new Error("encoderHidden must have at least one element.");
    }
    this.temperature = options.temperature ?? 0.5;
    // τ divides every similarity; zero, negative or NaN values make the loss meaningless.
    if (!(this.temperature > 0)) {
      throw new Error("temperature must be a positive number.");
    }
    const encoderStructure = [inputSize, ...encoderHidden];
    const encoderActivations = encoderHidden.map(() => relu);
    this.encoder = new NetworkN(encoderStructure, {
      activations: encoderActivations,
      ...options.encoderOptions
    });
    const encoderOut = encoderHidden[encoderHidden.length - 1];
    // Hidden width of the 2-layer projection MLP: half the encoder output,
    // but never narrower than the projection itself.
    const projHidden = Math.max(projectionDim, Math.floor(encoderOut / 2));
    // NOTE(review): the projection head's output layer also uses ReLU, so every z
    // is non-negative and pairwise cosine similarities fall in [0, 1]; SimCLR
    // normally uses a linear output here — confirm this is intentional.
    this.projectionHead = new NetworkN(
      [encoderOut, projHidden, projectionDim],
      { activations: [relu, relu] }
    );
  }
  // ── Inference (downstream tasks use this, not project()) ─────────────────
  //
  // Returns h — the encoder representation before the projection head.
  // This is the vector to use for classification, clustering, retrieval, etc.
  //
  // The projection head is only active during training.
  encode(x) {
    return this.encoder.predict(x);
  }
  // ── Training path: encode then project ───────────────────────────────────
  //
  // Returns z = projectionHead(encoder(x)) — the vector NT-Xent is computed on.
  // Do NOT use this for downstream tasks (see encode() above).
  project(x) {
    const h = this.encoder.predict(x);
    return this.projectionHead.predict(h);
  }
  // ── Cosine similarity ─────────────────────────────────────────────────────
  //
  // sim(u, v) = uᵀv / (||u|| · ||v||)
  //
  // Range: [-1, 1]. We use cosine rather than Euclidean distance because it is
  // scale-invariant — only the direction of the projection matters, not its
  // magnitude. This prevents the trivial solution of making ||z|| → ∞.
  // Returns 0 (instead of dividing by ~0) when either vector has near-zero norm.
  static cosineSimilarity(a, b) {
    let dot = 0, normA = 0, normB = 0;
    for (let d = 0; d < a.length; d++) {
      dot += a[d] * b[d];
      normA += a[d] * a[d];
      normB += b[d] * b[d];
    }
    const denom = Math.sqrt(normA) * Math.sqrt(normB);
    return denom < 1e-10 ? 0 : dot / denom;
  }
  // ── NT-Xent loss (no weight update) ──────────────────────────────────────
  //
  // Forward-only pass. Used for validation / monitoring during training.
  //
  // pairs: array of [x, xAugmented] input pairs (see Augmenter.makePair).
  // Returns the mean NT-Xent loss over all 2N anchors.
  // Throws if `pairs` is empty (the mean over zero anchors is undefined).
  computeLoss(pairs) {
    const { projections, N } = this._forwardProjections(pairs);
    return this._ntXentLoss(projections, N);
  }
  // ── Training step ─────────────────────────────────────────────────────────
  //
  // Given a batch of positive pairs, compute NT-Xent loss and update weights
  // via finite-difference gradient approximation.
  //
  // Full analytical backprop through NT-Xent is complex to implement from
  // scratch without an autograd engine. Finite differences are slower but
  // correct and keep the implementation readable for educational purposes.
  // For production use, couple this with the Tape (autograd) module.
  //
  // Step-by-step:
  //   1. Forward all 2N inputs through encoder + projection head → { z_i }.
  //   2. Build the 2N×2N cosine similarity matrix (scaled by 1/τ).
  //   3. For each anchor i, identify its positive pair and all 2N-2 negatives.
  //   4. Apply softmax over the row; loss = -log(softmax at positive index).
  //   5. Average over all 2N anchors.
  //   6. Estimate ∂L/∂θ for EVERY parameter by central differences at the
  //      current parameters, then apply θ ← θ − lr·∂L/∂θ to all of them at once.
  //      (Updating while still differentiating would evaluate later gradients
  //      at already-shifted parameters, i.e. coordinate descent, not SGD.)
  //
  // Cost: two full-batch forward passes per parameter.
  // Returns: NT-Xent loss before the weight update.
  // Throws (without touching any weight) if `pairs` is empty.
  trainStep(pairs, lr) {
    const loss = this.computeLoss(pairs);
    const eps = 1e-4;
    // Pass 1: gradients for all parameters, measured at the unmodified model.
    const updates = [];
    for (const network of [this.encoder, this.projectionHead]) {
      for (const layer of network.layers) {
        for (const neuron of layer.neurons) {
          const weightGrads = Array.from(
            { length: neuron.weights.length },
            (_, j) => this._finiteDifference(
              pairs,
              eps,
              () => neuron.weights[j],
              (value) => {
                neuron.weights[j] = value;
              }
            )
          );
          const biasGrad = this._finiteDifference(
            pairs,
            eps,
            () => neuron.bias,
            (value) => {
              neuron.bias = value;
            }
          );
          updates.push({ neuron, weightGrads, biasGrad });
        }
      }
    }
    // Pass 2: simultaneous gradient-descent step.
    for (const { neuron, weightGrads, biasGrad } of updates) {
      for (let j = 0; j < weightGrads.length; j++) {
        neuron.weights[j] -= lr * weightGrads[j];
      }
      neuron.bias -= lr * biasGrad;
    }
    return loss;
  }
  // ── Private: central-difference derivative of the batch loss ─────────────
  //
  // Estimates ∂L/∂θ ≈ (L(θ+eps) − L(θ−eps)) / (2·eps) for the single scalar θ
  // exposed through `read` / `write`. θ is restored to its exact original value
  // afterwards (even if computeLoss throws), avoiding the rounding drift of
  // undoing the perturbation with ±eps arithmetic.
  _finiteDifference(pairs, eps, read, write) {
    const original = read();
    try {
      write(original + eps);
      const lossPlus = this.computeLoss(pairs);
      write(original - eps);
      const lossMinus = this.computeLoss(pairs);
      return (lossPlus - lossMinus) / (2 * eps);
    } finally {
      write(original);
    }
  }
  // ── Private: forward all pairs through the projection head ───────────────
  //
  // Returns a flat array of 2N projections.
  // Layout: [ z_0, z_0', z_1, z_1', ..., z_{N-1}, z_{N-1}' ]
  //   Even indices 2i   → original view of pair i
  //   Odd  indices 2i+1 → augmented view of pair i (the positive)
  _forwardProjections(pairs) {
    const N = pairs.length;
    const projections = [];
    for (const [x, xAug] of pairs) {
      projections.push(this.project(x));
      projections.push(this.project(xAug));
    }
    return { projections, N };
  }
  // ── Private: NT-Xent loss over a set of 2N projections ───────────────────
  //
  // projections[2i] and projections[2i+1] are positives.
  // All other 2N-2 samples are negatives for each anchor.
  //
  // Per anchor i with logits s_ik = cos(z_i, z_k) / τ:
  //   ℓ_i = -log( exp(s_i,pos) / Σ_{k≠i} exp(s_ik) ) = logΣexp_{k≠i}(s_ik) − s_i,pos
  // The log-sum-exp form subtracts the row maximum before exponentiating, so
  // small temperatures cannot overflow exp() into Infinity/NaN.
  // Throws if N is 0.
  _ntXentLoss(projections, N) {
    if (N === 0) {
      throw new Error("ContrastiveLearning: NT-Xent needs at least one positive pair.");
    }
    const total = 2 * N;
    const tau = this.temperature;
    // Cosine similarity is symmetric, so each off-diagonal pair is computed once.
    // The diagonal (self-similarity) is never read.
    const sim = Array.from({ length: total }, () => new Array(total).fill(0));
    for (let i = 0; i < total; i++) {
      for (let j = i + 1; j < total; j++) {
        const s = _ContrastiveLearning.cosineSimilarity(projections[i], projections[j]) / tau;
        sim[i][j] = s;
        sim[j][i] = s;
      }
    }
    let totalLoss = 0;
    for (let i = 0; i < total; i++) {
      const posIdx = i % 2 === 0 ? i + 1 : i - 1;
      let maxLogit = -Infinity;
      for (let k = 0; k < total; k++) {
        if (k !== i && sim[i][k] > maxLogit) maxLogit = sim[i][k];
      }
      let sumExp = 0;
      for (let k = 0; k < total; k++) {
        if (k !== i) sumExp += Math.exp(sim[i][k] - maxLogit);
      }
      totalLoss += maxLogit + Math.log(sumExp) - sim[i][posIdx];
    }
    return totalLoss / total;
  }
};
|
|
5630
|
+
|
|
4739
5631
|
// src/GAN.ts
|
|
4740
5632
|
var GAN = class {
|
|
4741
5633
|
constructor(latentDim, generatorHidden, outputDim, discriminatorHidden, options) {
|
|
@@ -5272,6 +6164,216 @@ function _binaryRecall(yTrue, yPred, pos) {
|
|
|
5272
6164
|
return tp + fn > 0 ? tp / (tp + fn) : 0;
|
|
5273
6165
|
}
|
|
5274
6166
|
|
|
6167
|
+
// src/Tokenizer.ts
|
|
6168
|
+
var _Tokenizer = class _Tokenizer {
  /**
   * @param options.mode - "word" (default), "whitespace" or "char"; any other value behaves like "word"
   * @param options.lowercase - lowercase text before tokenizing. Default: true
   * @param options.maxVocab - cap on vocabulary size INCLUDING special tokens; 0 = unlimited. Default: 0
   * @param options.specialTokens - extra special tokens registered right after <PAD>/<UNK>/<BOS>/<EOS>
   */
  constructor(options = {}) {
    this._token2id = /* @__PURE__ */ new Map();
    this._id2token = /* @__PURE__ */ new Map();
    this._fitted = false;
    this._mode = options.mode ?? "word";
    this._lowercase = options.lowercase ?? true;
    this._maxVocab = options.maxVocab ?? 0;
    this._extraSpecial = options.specialTokens ?? [];
  }
  // ── Fit ───────────────────────────────────────────────────────────────────
  /**
   * Build vocabulary from an array of text strings.
   * Calling fit() again resets and rebuilds the vocabulary from scratch.
   *
   * Special tokens get the lowest IDs (<PAD>=0, <UNK>=1, <BOS>=2, <EOS>=3, then
   * specialTokens). Corpus tokens follow, ordered by descending frequency, with
   * ties broken by localeCompare. When maxVocab > 0 the vocabulary never exceeds
   * max(maxVocab, number of special tokens).
   *
   * @param texts - corpus to build the vocabulary from
   * @returns this (chainable)
   */
  fit(texts) {
    this._token2id = /* @__PURE__ */ new Map();
    this._id2token = /* @__PURE__ */ new Map();
    const specials = [
      _Tokenizer.PAD,
      _Tokenizer.UNK,
      _Tokenizer.BOS,
      _Tokenizer.EOS,
      ...this._extraSpecial
    ];
    for (const s of specials) this._register(s);
    const freq = /* @__PURE__ */ new Map();
    for (const text of texts) {
      for (const token of this.tokenize(text)) {
        // Tokens that are already special keep their reserved ID and must not
        // consume one of the maxVocab slots.
        if (this._token2id.has(token)) continue;
        freq.set(token, (freq.get(token) ?? 0) + 1);
      }
    }
    let entries = [...freq.entries()].sort(
      ([a, fa], [b, fb]) => fb - fa || a.localeCompare(b)
    );
    if (this._maxVocab > 0) {
      // Count the IDs actually registered (duplicate specialTokens collapse) and
      // clamp at 0: a negative slice end would drop tokens from the END instead
      // of keeping none.
      const slots = Math.max(0, this._maxVocab - this._token2id.size);
      entries = entries.slice(0, slots);
    }
    for (const [token] of entries) this._register(token);
    this._fitted = true;
    return this;
  }
  // ── Tokenize ──────────────────────────────────────────────────────────────
  /**
   * Split raw text into an array of string tokens (no ID conversion yet).
   * Useful for inspecting what the tokenizer produces before encoding.
   *
   * - "char": one token per Unicode code point (emoji and other astral
   *   characters stay whole instead of splitting into surrogate halves).
   * - "whitespace": split on runs of whitespace.
   * - "word": each run of letters (any script, including combining marks) and
   *   digits is one token; every other non-space character, including "_",
   *   is its own token.
   */
  tokenize(text) {
    const t = this._lowercase ? text.toLowerCase() : text;
    switch (this._mode) {
      case "char":
        return Array.from(t);
      case "whitespace":
        return t.split(/\s+/).filter(Boolean);
      case "word":
      default:
        return t.match(/[\p{L}\p{M}\p{N}]+|[^\p{L}\p{M}\p{N}\s]/gu) ?? [];
    }
  }
  // ── Encode ────────────────────────────────────────────────────────────────
  /**
   * Convert a text string to a sequence of token IDs.
   * Unknown tokens map to <UNK> (id 1).
   *
   * @param text - input text
   * @param options - addBOS / addEOS flags
   * @throws Error if fit() has not been called
   */
  encode(text, options = {}) {
    this._assertFitted();
    const ids = [];
    if (options.addBOS) ids.push(this._token2id.get(_Tokenizer.BOS));
    for (const token of this.tokenize(text)) {
      ids.push(this._token2id.get(token) ?? this._token2id.get(_Tokenizer.UNK));
    }
    if (options.addEOS) ids.push(this._token2id.get(_Tokenizer.EOS));
    return ids;
  }
  // ── Encode batch ──────────────────────────────────────────────────────────
  /**
   * Encode an array of texts, optionally padding/truncating to a fixed length.
   *
   * @param texts - array of input texts
   * @param options - addBOS / addEOS / padTo (sequences longer than padTo are
   *   truncated, shorter ones are right-padded with <PAD>)
   * @throws Error if padTo is given but is not a non-negative integer
   */
  encodeBatch(texts, options = {}) {
    if (options.padTo !== void 0 && (!Number.isInteger(options.padTo) || options.padTo < 0)) {
      throw new Error("Tokenizer.encodeBatch: padTo must be a non-negative integer.");
    }
    const sequences = texts.map((t) => this.encode(t, options));
    if (options.padTo !== void 0) {
      const len = options.padTo;
      const padId = this._token2id.get(_Tokenizer.PAD);
      return sequences.map((seq) => {
        if (seq.length >= len) return seq.slice(0, len);
        return [...seq, ...Array(len - seq.length).fill(padId)];
      });
    }
    return sequences;
  }
  // ── Decode ────────────────────────────────────────────────────────────────
  /**
   * Convert a sequence of token IDs back to a human-readable string.
   * Unknown IDs decode as <UNK>. Tokens are joined with "" in char mode and
   * with a single space otherwise.
   *
   * @param ids - array of token IDs
   * @param stripSpecial - remove PAD/BOS/EOS tokens from output. Default: true
   */
  decode(ids, stripSpecial = true) {
    this._assertFitted();
    const specials = /* @__PURE__ */ new Set([_Tokenizer.PAD, _Tokenizer.BOS, _Tokenizer.EOS]);
    const tokens = [];
    for (const id of ids) {
      const token = this._id2token.get(id) ?? _Tokenizer.UNK;
      if (stripSpecial && specials.has(token)) continue;
      tokens.push(token);
    }
    return this._mode === "char" ? tokens.join("") : tokens.join(" ");
  }
  // ── One-hot encoding ──────────────────────────────────────────────────────
  /**
   * Convert a sequence of token IDs to one-hot vectors.
   * Each vector has length `vocabSize` with a single 1 at the token's position;
   * out-of-range IDs produce an all-zero vector.
   * Useful when feeding tokens directly into a Network without an embedding layer.
   *
   * @param ids - array of token IDs (e.g. from encode())
   * @returns - 2D array of shape [seqLen, vocabSize]
   */
  oneHot(ids) {
    this._assertFitted();
    const V = this.vocabSize;
    return ids.map((id) => {
      const vec = new Array(V).fill(0);
      if (id >= 0 && id < V) vec[id] = 1;
      return vec;
    });
  }
  // ── Vocabulary helpers ────────────────────────────────────────────────────
  /** Number of tokens in the vocabulary (including special tokens). */
  get vocabSize() {
    return this._token2id.size;
  }
  /** True if fit() has been called at least once. */
  get isFitted() {
    return this._fitted;
  }
  /** Get the integer ID for a token string, or undefined if not in vocabulary. */
  tokenToId(token) {
    return this._token2id.get(token);
  }
  /** Get the token string for an integer ID, or undefined if out of range. */
  idToToken(id) {
    return this._id2token.get(id);
  }
  /**
   * Return the full vocabulary as an array ordered by ID.
   * Index i of the returned array is the token with ID i.
   */
  getVocabulary() {
    return Array.from({ length: this.vocabSize }, (_, i) => this._id2token.get(i));
  }
  // ── Persistence ───────────────────────────────────────────────────────────
  /**
   * Serialize the fitted tokenizer to a plain JSON-compatible object.
   * Store it with JSON.stringify(); reload with Tokenizer.fromJSON().
   * The extra special tokens are included so that a restored tokenizer
   * re-registers them if fit() is called again.
   */
  toJSON() {
    this._assertFitted();
    return {
      mode: this._mode,
      lowercase: this._lowercase,
      maxVocab: this._maxVocab,
      specialTokens: [...this._extraSpecial],
      token2id: Object.fromEntries(this._token2id)
    };
  }
  /**
   * Restore a Tokenizer from a snapshot produced by toJSON().
   * Snapshots without `specialTokens` restore with no extra special tokens.
   */
  static fromJSON(snapshot) {
    const tok = new _Tokenizer({
      mode: snapshot.mode,
      lowercase: snapshot.lowercase,
      maxVocab: snapshot.maxVocab,
      specialTokens: snapshot.specialTokens
    });
    for (const [token, id] of Object.entries(snapshot.token2id)) {
      tok._token2id.set(token, id);
      tok._id2token.set(id, token);
    }
    tok._fitted = true;
    return tok;
  }
  // ── Private ───────────────────────────────────────────────────────────────
  // Assign the next sequential ID to `token` unless it is already registered.
  _register(token) {
    if (this._token2id.has(token)) return;
    const id = this._token2id.size;
    this._token2id.set(token, id);
    this._id2token.set(id, token);
  }
  _assertFitted() {
    if (!this._fitted) {
      throw new Error("Tokenizer: call fit() before encoding or decoding.");
    }
  }
};
// ── Built-in special tokens ────────────────────────────────────────────────
_Tokenizer.PAD = "<PAD>";
_Tokenizer.UNK = "<UNK>";
_Tokenizer.BOS = "<BOS>";
_Tokenizer.EOS = "<EOS>";
var Tokenizer = _Tokenizer;
|
|
6376
|
+
|
|
5275
6377
|
// src/EarlyStopping.ts
|
|
5276
6378
|
var EarlyStopping = class {
|
|
5277
6379
|
constructor(options) {
|
|
@@ -5546,16 +6648,19 @@ function _sampleNormal() {
|
|
|
5546
6648
|
export {
|
|
5547
6649
|
Adam,
|
|
5548
6650
|
AttentionHead,
|
|
6651
|
+
Augmenter,
|
|
5549
6652
|
Autoencoder,
|
|
5550
6653
|
BatchNorm,
|
|
5551
6654
|
BiasVector,
|
|
5552
6655
|
CausalConv1D,
|
|
5553
6656
|
ClipOptimizer,
|
|
5554
6657
|
ClippedOptimizerFactory,
|
|
6658
|
+
ContrastiveLearning,
|
|
5555
6659
|
Conv1D,
|
|
5556
6660
|
Conv2D,
|
|
5557
6661
|
DataAugmentation,
|
|
5558
6662
|
DataLoader,
|
|
6663
|
+
DatasetLoader,
|
|
5559
6664
|
DecisionTree,
|
|
5560
6665
|
Dropout,
|
|
5561
6666
|
EarlyStopping,
|
|
@@ -5570,6 +6675,7 @@ export {
|
|
|
5570
6675
|
LSTMLayer,
|
|
5571
6676
|
Layer,
|
|
5572
6677
|
LayerNorm,
|
|
6678
|
+
LearnedPositionalEncoding,
|
|
5573
6679
|
LinearRegression,
|
|
5574
6680
|
LogisticRegression,
|
|
5575
6681
|
LossPlotter,
|
|
@@ -5586,18 +6692,22 @@ export {
|
|
|
5586
6692
|
NeuronN,
|
|
5587
6693
|
PCA,
|
|
5588
6694
|
Perceptron,
|
|
6695
|
+
PositionalEncoding,
|
|
5589
6696
|
RNN,
|
|
5590
6697
|
SGD,
|
|
5591
6698
|
SOM,
|
|
5592
6699
|
Seq2Seq,
|
|
5593
6700
|
SoftmaxRegression,
|
|
5594
6701
|
TCN,
|
|
6702
|
+
TSNE,
|
|
6703
|
+
Tokenizer,
|
|
5595
6704
|
Trainer,
|
|
5596
6705
|
TransformerBlock,
|
|
5597
6706
|
VAE,
|
|
5598
6707
|
Value,
|
|
5599
6708
|
WeightInspector,
|
|
5600
6709
|
WeightMatrix,
|
|
6710
|
+
Word2Vec,
|
|
5601
6711
|
accuracy,
|
|
5602
6712
|
auc,
|
|
5603
6713
|
classificationReport,
|