@epfml/discojs 3.0.1-p20250402145848.0 → 3.0.1-p20250625140656.0
- package/dist/models/gpt/layers.d.ts +67 -0
- package/dist/models/gpt/layers.js +59 -45
- package/dist/models/hellaswag.d.ts +31 -0
- package/dist/models/hellaswag.js +120 -0
- package/dist/models/index.d.ts +3 -0
- package/dist/models/index.js +3 -0
- package/dist/models/onnx.d.ts +19 -0
- package/dist/models/onnx.js +71 -0
- package/package.json +1 -1
package/dist/models/gpt/layers.d.ts
CHANGED

@@ -1,5 +1,72 @@
 import * as tf from '@tensorflow/tfjs';
 import type { GPTConfig } from './config.js';
+import type { ModelSize } from './config.js';
+/**
+ * Defines a range, from 0 to T, that is used to create positional embeddings
+ */
+export declare class Range extends tf.layers.Layer {
+    static readonly className = "Range";
+    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[];
+    call(input: tf.Tensor | tf.Tensor[], kwargs: Record<string, unknown>): tf.Tensor | tf.Tensor[];
+}
+export type CausalSelfAttentionConfig = ConstructorParameters<typeof tf.layers.Layer>[0] & Record<'contextLength' | 'nHead' | 'nEmbd' | 'dropout' | 'nLayer' | 'seed', number>;
+export declare class CausalSelfAttention extends tf.layers.Layer {
+    private readonly config;
+    static readonly className = "CausalSelfAttention";
+    private readonly nHead;
+    private readonly nEmbd;
+    private readonly nLayer;
+    private readonly dropout;
+    private readonly seed;
+    private readonly mask;
+    cAttnKernel?: tf.LayerVariable;
+    cAttnBias?: tf.LayerVariable;
+    cProjKernel?: tf.LayerVariable;
+    cProjBias?: tf.LayerVariable;
+    constructor(config: CausalSelfAttentionConfig);
+    build(): void;
+    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[];
+    getConfig(): tf.serialization.ConfigDict;
+    call(input: tf.Tensor | tf.Tensor[], kwargs: Record<string, unknown>): tf.Tensor;
+    dense(x: tf.Tensor, kernel: tf.LayerVariable, bias: tf.LayerVariable): tf.Tensor;
+    splitHeads(x: tf.Tensor, B: number, T: number, nHead: number): tf.Tensor;
+    applyCausalMask(att: tf.Tensor, T: number): tf.Tensor;
+    computeAttention(q: tf.Tensor, k: tf.Tensor, training: boolean, T: number): tf.Tensor;
+}
+/**
+ * GELU with tanh approximate
+ * GELU(x) = x * 0.5 * (1 + Tanh[sqrt(2/π) * (x + 0.044715 * x^3)])
+ *
+ * https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
+ */
+export declare class GELU extends tf.layers.Layer {
+    static readonly className = "GELU";
+    constructor();
+    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[];
+    call(input: tf.Tensor | tf.Tensor[], kwargs: Record<string, unknown>): tf.Tensor | tf.Tensor[];
+}
+export type MLPConfig = ConstructorParameters<typeof tf.layers.Layer>[0] & Required<ModelSize> & Record<'contextLength' | 'residDrop' | 'nLayer' | 'seed', number>;
+export declare function MLP(config: MLPConfig): tf.LayersModel;
+/**
+ * LanguageModelEmbedding is a layer that combines the token embeddings and the language modeling head
+ * I.e. LMEmbedding is used to translate token indices into token embeddings
+ * as well as to project embeddings back into token indices
+ * The GPT2 model uses the same embedding matrix for both the token embeddings and the language modeling head
+ * Because Tensorflow.js doesn't offer an easy weight sharing mechanism, we need to define a custom layer
+ * that can be used for both the token embeddings and the language modeling head.
+ * In the GPT2 model definition, this layers corresponds to wte and lm_head (which reuses wte)
+ */
+export declare class LMEmbedding extends tf.layers.Layer {
+    private readonly vocabSize;
+    private readonly nEmbd;
+    private readonly seed;
+    static readonly className = "LMEmbedding";
+    embeddings?: tf.LayerVariable;
+    constructor(vocabSize: number, nEmbd: number, seed: number);
+    build(): void;
+    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[];
+    call(input: tf.Tensor | tf.Tensor[], kwargs: Record<string, unknown>): tf.Tensor | tf.Tensor[];
+}
 /**
  * The GPTArchitecture specifically defines a GPT forward pass, i.e.,
  * what are the inputs, the successive transformer blocks and the outputs. It is then
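
These declarations export the custom GPT building blocks (Range, CausalSelfAttention, GELU, MLP, LMEmbedding) as part of the package's typed surface instead of keeping them module-private. A minimal sketch of what that enables, assuming the compiled dist/models/gpt/layers.js module is reachable from consumer code (the import path is illustrative, not guaranteed by this diff):

import * as tf from '@tensorflow/tfjs';
// Hypothetical subpath import; the diff shows only the compiled file, not the package "exports" map.
import { GELU } from '@epfml/discojs/dist/models/gpt/layers.js';

// Apply the tanh-approximated GELU declared above to a few values.
const gelu = new GELU();
const x = tf.tensor2d([[-1, 0, 1, 2]]);
const y = gelu.apply(x) as tf.Tensor;
y.print(); // ≈ [[-0.159, 0, 0.841, 1.955]], i.e. x * 0.5 * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))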
package/dist/models/gpt/layers.js
CHANGED

@@ -4,7 +4,7 @@ const debug = createDebug("discojs:models:gpt:layers");
 /**
  * Defines a range, from 0 to T, that is used to create positional embeddings
  */
-class Range extends tf.layers.Layer {
+export class Range extends tf.layers.Layer {
     static className = 'Range';
     computeOutputShape(inputShape) {
         return inputShape;
@@ -56,7 +56,7 @@ class LogLayer extends tf.layers.Layer {
     }
 }
 tf.serialization.registerClass(LogLayer);
-class CausalSelfAttention extends tf.layers.Layer {
+export class CausalSelfAttention extends tf.layers.Layer {
     config;
     static className = 'CausalSelfAttention';
     nHead;
@@ -86,8 +86,7 @@ class CausalSelfAttention extends tf.layers.Layer {
     }
     build() {
         // key, query, value projections for all heads, but in a batch
-        this.cAttnKernel = this.addWeight('c_attn.weight', [this.nEmbd, 3 * this.nEmbd], 'float32', tf.initializers.randomNormal({ mean: 0, stddev: 0.02, seed: this.seed })
-        );
+        this.cAttnKernel = this.addWeight('c_attn.weight', [this.nEmbd, 3 * this.nEmbd], 'float32', tf.initializers.randomNormal({ mean: 0, stddev: 0.02, seed: this.seed }));
         this.cAttnBias = this.addWeight('c_attn.bias', [3 * this.nEmbd], 'float32', tf.initializers.zeros());
         // output projection
         this.cProjKernel = this.addWeight('c_proj.kernel', [this.nEmbd, this.nEmbd], 'float32',
@@ -97,7 +96,9 @@ class CausalSelfAttention extends tf.layers.Layer {
         // https://github.com/karpathy/build-nanogpt/blob/6104ab1b53920f6e2159749676073ff7d815c1fa/train_gpt2.py#L103
         // https://youtu.be/l8pRSuU81PU?si=5GcKfi_kPgLgvtg2&t=4640
         tf.initializers.randomNormal({
-            mean: 0,
+            mean: 0,
+            stddev: 0.02 * Math.sqrt(2 * this.nLayer),
+            seed: this.seed
         }));
         this.cProjBias = this.addWeight('c_proj.bias', [this.nEmbd], 'float32', tf.initializers.zeros());
     }
@@ -122,59 +123,72 @@ class CausalSelfAttention extends tf.layers.Layer {
             input = input[0];
         }
         this.invokeCallHook(input, kwargs);
-
-        // TODO: use broadcasting when tfjs will support backpropagating through broadcasting
-            const k = kernel.read().expandDims(0).tile([x.shape[0], 1, 1]);
-            const m = x.matMul(k);
-            return tf.add(m, bias.read());
-        };
+        // --- Use helper methods below to build the computation ---
         // Apply attention weights to inputs as one big matrix which is then split into the
         // query, key and value submatrices
         // nHead is "number of heads", hs is "head size", and C (number of channels) = n_embd = nHead * hs
-        // e.g. in GPT-2 (124M), nHead = 12, hs = 64, so nHead * hs = C = 768 channels in the Transformer
-        const cAttn = dense(input, this.cAttnKernel, this.cAttnBias);
+        // e.g. in GPT-2 (124M), nHead = 12, hs = 64, so nHead * hs = C = 768 channels in the Transformer const cAttn = dense(input, this.cAttnKernel, this.cAttnBias);
+        const cAttn = this.dense(input, this.cAttnKernel, this.cAttnBias);
         let [q, k, v] = tf.split(cAttn, 3, -1);
         // Follow naming conventions in https://github.com/karpathy/build-nanogpt/
         const [B, T, C] = k.shape; // batch size, sequence length, embedding dimensionality (number of channels)
-
-
-        );
-
-        k = splitHeads(k); // (B, nHead, T, hs)
-        v = splitHeads(v); // (B, nHead, T, hs)
+        // Split into attention heads.
+        q = this.splitHeads(q, B, T, this.nHead);
+        k = this.splitHeads(k, B, T, this.nHead);
+        v = this.splitHeads(v, B, T, this.nHead);
         // Scaled self attention: query @ key / sqrt(hs)
         // Matrix representing the token-to-token attention (B, nHead, T, T)
-
-        tf.div(1, tf.sqrt(tf.cast(k.shape[k.shape.length - 1], 'float32'))) // 1 / sqrt(hs)
-        );
-        /**
-         * The next operations apply attention only on the past tokens, which is
-         * essentially a weighted average of the past tokens with complicated weights,
-         * it relies on a mask to not "pay any attention" to future tokens
-         */
-        // mask is lower triangular matrix filled with 1
-        const mask = this.mask.slice([0, 0], [T, T]); // (T, T)
-        // 1 - mask => upper triangular matrix filled with 1
-        // (1 - mask) * -10^9 => upper triangular matrix filled with -inf
-        // att + ((1 - mask) * -10^9) => lower triangular part is the same as the `att` matrix
-        // upper triangular part is -inf
-        att = tf.add(att, tf.mul(tf.sub(1, mask), -1e9)); // (B, nHead, T, T)
-        // applying softmax zeroes out the upper triangular part (softmax(-inf) = 0)
-        // i.e., zeroes out future tokens's attention weights
-        // and creates a probability distribution for the lower triangular
-        // (attention weights of past tokens). The probability distribution ensures
-        // that the attention weights of past tokens for a particular token sum to one
-        att = tf.softmax(att, -1);
-        att = kwargs.training === true ? tf.dropout(att, this.dropout, undefined, this.seed) : att;
+        const att = this.computeAttention(q, k, kwargs.training === true, T);
         // This is where the (attention-)weighted sum of past values is performed
         let y = tf.matMul(att, v); // (B, nHead, T, T) x (B, nHead, T, hs) -> (B, nHead, T, hs)
         y = tf.transpose(y, [0, 2, 1, 3]); // (B, T, nHead, hs)
         y = tf.reshape(y, [B, T, C]); // (B, T, C = nHead * hs)
-        y = dense(y, this.cProjKernel, this.cProjBias); // output projection (B, T, C)
+        y = this.dense(y, this.cProjKernel, this.cProjBias); // output projection (B, T, C)
         y = kwargs.training === true ? tf.dropout(y, this.dropout, undefined, this.seed) : y;
         return y;
         });
     }
+    // --- Helper Methods ---
+    dense(x, kernel, bias) {
+        const k = kernel.read().expandDims(0).tile([x.shape[0], 1, 1]);
+        const m = x.matMul(k);
+        return tf.add(m, bias.read());
+    }
+    splitHeads(x, B, T, nHead) {
+        return tf.transpose(tf.reshape(x, [B, T, nHead, (x.shape[2] ?? 0) / nHead]), [0, 2, 1, 3]);
+    }
+    applyCausalMask(att, T) {
+        // mask is lower triangular matrix filled with 1
+        const mask = this.mask.slice([0, 0], [T, T]);
+        // 1 - mask => upper triangular matrix filled with 1
+        // (1 - mask) * -10^9 => upper triangular matrix filled with -inf
+        // att + ((1 - mask) * -10^9) => lower triangular part is the same as the `att` matrix
+        // upper triangular part is -inf
+        return tf.add(att, tf.mul(tf.sub(1, mask), -1e9)); // (B, nHead, T, T)
+    }
+    computeAttention(q, k, training, T) {
+        /**
+         * The next operations apply attention only on the past tokens, which is
+         * essentially a weighted average of the past tokens with complicated weights,
+         * it relies on a mask to not "pay any attention" to future tokens
+         */
+        const headSize = k.shape[k.shape.length - 1];
+        // Scaled self attention: query @ key / sqrt(hs)
+        // Matrix representing the token-to-token attention (B, nHead, T, T)
+        let att = tf.matMul(q, k, false, true); // (B, nHead, T, hs) x (B, nHead, hs, T) -> (B, nHead, T, T)
+        att = tf.mul(att, tf.div(1, tf.sqrt(tf.cast(headSize, 'float32')))); // 1 / sqrt(hs)
+        att = this.applyCausalMask(att, T);
+        // applying softmax zeroes out the upper triangular part (softmax(-inf) = 0)
+        // i.e., zeroes out future tokens's attention weights
+        // and creates a probability distribution for the lower triangular
+        // (attention weights of past tokens). The probability distribution ensures
+        // that the attention weights of past tokens for a particular token sum to one
+        att = tf.softmax(att, -1);
+        if (training) {
+            att = tf.dropout(att, this.dropout, undefined, this.seed);
+        }
+        return att;
+    }
 }
 tf.serialization.registerClass(CausalSelfAttention);
 /**
@@ -183,7 +197,7 @@ tf.serialization.registerClass(CausalSelfAttention);
  *
  * https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
  */
-class GELU extends tf.layers.Layer {
+export class GELU extends tf.layers.Layer {
     static className = 'GELU';
     constructor() {
         super({});
@@ -210,7 +224,7 @@ class GELU extends tf.layers.Layer {
     }
 }
 tf.serialization.registerClass(GELU);
-function MLP(config) {
+export function MLP(config) {
     return tf.sequential({ layers: [
             tf.layers.dense({
                 name: config.name + `.mlp.c_fc`,
@@ -298,7 +312,7 @@ function TransformerBlock(conf) {
  * that can be used for both the token embeddings and the language modeling head.
  * In the GPT2 model definition, this layers corresponds to wte and lm_head (which reuses wte)
  */
-class LMEmbedding extends tf.layers.Layer {
+export class LMEmbedding extends tf.layers.Layer {
     vocabSize;
     nEmbd;
    seed;
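
The call() refactor above moves the attention math into named helpers (dense, splitHeads, applyCausalMask, computeAttention) without changing its behaviour: scores are scaled by 1/sqrt(hs) and future positions are pushed towards -infinity before the softmax. A standalone sketch of that masking trick with plain tfjs ops (the variable names here are illustrative, not part of the package):

import * as tf from '@tensorflow/tfjs';

const T = 4;
const mask = tf.linalg.bandPart(tf.ones([T, T]), -1, 0);       // lower-triangular ones, like this.mask
const scores = tf.randomNormal([T, T]);                        // stand-in for q @ kᵀ / sqrt(hs)
const masked = tf.add(scores, tf.mul(tf.sub(1, mask), -1e9));  // upper triangle becomes ≈ -inf
const attn = tf.softmax(masked, -1);                           // each row sums to 1 over past tokens only
attn.print();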
package/dist/models/hellaswag.d.ts
ADDED

@@ -0,0 +1,31 @@
+import { GPT } from './index.js';
+import { PreTrainedTokenizer } from '@xenova/transformers';
+import { ONNXModel } from './onnx.js';
+export declare const HELLASWAG_URL = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl";
+/**
+ * Represents a single example from the HellaSwag dataset.
+ *
+ * ctx - The context sentence or paragraph that sets up the situation.
+ * endings - An array of four possible continuations of the context.
+ * label - The index (0–3) of the correct ending in the `endings` array.
+ */
+export interface HellaSwagExample {
+    ctx: string;
+    endings: string[];
+    label: number;
+}
+export type HellaSwagDataset = HellaSwagExample[];
+type Tokenizer = PreTrainedTokenizer;
+type ModelType = GPT | ONNXModel;
+/**
+ * Evaluates the model on a given HellaSwag dataset.
+ *
+ * @param model - The model to evaluate (GPT or ONNXModel)
+ * @param tokenizer - The tokenizer to use
+ * @param dataset - An array of HellaSwagExample to evaluate on
+ * @param limit - Number of examples to evaluate (default: all)
+ * @param print - Whether to print results (default: true)
+ * @returns The accuracy of the model on the dataset
+ */
+export declare function evaluate(model: ModelType, tokenizer: Tokenizer, dataset: HellaSwagExample[], print?: boolean): Promise<number>;
+export {};
package/dist/models/hellaswag.js
ADDED

@@ -0,0 +1,120 @@
+import * as tf from '@tensorflow/tfjs';
+import { GPT } from './index.js';
+import { tokenize } from '../processing/text.js';
+import { List } from 'immutable';
+export const HELLASWAG_URL = 'https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl';
+// Computes the log likelihood of the input sequence using the tfjs model
+// The input sequence is expected to be a concatenation of the context and the ending
+// The function computes the log likelihood of each ending and returns the one with the loss of each ending
+// Sources:
+// https://github.com/karpathy/build-nanogpt/blob/master/hellaswag.py
+//https://www.youtube.com/watch?v=l8pRSuU81PU
+async function computeLogLikelihood(gpt, inputIds, ctxLength) {
+    const lossTensor = tf.tidy(() => {
+        // Convert input sequence to shape [1, seq_len]
+        const inputTensor = tf.tensor2d([inputIds], [1, inputIds.length], 'int32');
+        // Get model logits: [1, seq_len, vocab_size]
+        const logits3D = gpt.extract().predict(inputTensor);
+        // Shift logits to align with next-token targets
+        const shiftedLogits = logits3D.slice([0, 0, 0], [1, inputIds.length - 1, -1]);
+        // Target tokens (next tokens), same length as shifted logits
+        const shiftedTargets = inputIds.slice(1);
+        const targetTensor = tf.tensor1d(shiftedTargets, 'int32');
+        // One-hot encode targets for cross-entropy loss
+        const oneHotLabels = tf.oneHot(targetTensor, shiftedLogits.shape[2]);
+        // Compute per-token cross-entropy log-probabilities (unnormalized loss)
+        const logProbs = tf.losses.softmaxCrossEntropy(oneHotLabels, shiftedLogits.squeeze());
+        // Create a mask to only include loss after the context length
+        const mask = tf.tensor1d(inputIds.map((_, i) => (i >= ctxLength ? 1 : 0)), 'float32').slice(1);
+        // Apply the mask and average over the selected tokens
+        const masked = logProbs.mul(mask);
+        const loss = masked.sum().div(mask.sum());
+        return loss;
+    });
+    const lossNumber = await lossTensor.array();
+    if (typeof lossNumber !== 'number') {
+        throw new Error('got multiple loss');
+    }
+    return lossNumber;
+}
+// Computes the log likelihood of the input sequence using the ONNX model
+// The input sequence is expected to be a concatenation of the context and the ending
+// The function computes the log likelihood of each ending and returns the one with the loss of each ending
+// Sources:
+// https://github.com/karpathy/build-nanogpt/blob/master/hellaswag.py
+// https://www.youtube.com/watch?v=l8pRSuU81PU
+async function computeONNXLogLikelihood(model, inputIds, ctxLength) {
+    const batchInput = List([List(inputIds)]); // [1, seq_len]
+    // Run model to get logits: flattened [T * V]
+    const logitsTensor = await model.getLogits(batchInput);
+    const logits = logitsTensor.data;
+    const [_B, T, V] = logitsTensor.dims;
+    // Reshape flattened logits into [T][V]
+    const reshaped = Array.from({ length: T }, (_, t) => logits.slice(t * V, (t + 1) * V));
+    // Shift targets (next-token prediction)
+    const targets = inputIds.slice(1); // length = T - 1
+    const logitsShifted = reshaped.slice(0, T - 1); // also length = T - 1
+    // Compute per-token cross-entropy loss manually
+    const losses = logitsShifted.map((logit, i) => {
+        const maxLogit = Math.max(...logit); // for numerical stability
+        const exp = logit.map(x => Math.exp(x - maxLogit));
+        const sumExp = exp.reduce((a, b) => a + b, 0);
+        const probs = exp.map(e => e / sumExp); // softmax
+        return -Math.log(probs[targets[i]]); // cross-entropy loss
+    });
+    // Create a binary mask for non-context tokens
+    const mask = inputIds.map((_, i) => (i >= ctxLength ? 1 : 0)).slice(1);
+    // Apply the mask to the losses
+    const maskedLosses = losses.map((l, i) => l * mask[i]);
+    // Average the masked losses
+    const totalLoss = maskedLosses.reduce((a, b) => a + b, 0);
+    const sum = mask.reduce((a, b) => a + b, 0);
+    return totalLoss / (sum || 1); // avoid division by 0
+}
+/**
+ * Evaluates the model on a given HellaSwag dataset.
+ *
+ * @param model - The model to evaluate (GPT or ONNXModel)
+ * @param tokenizer - The tokenizer to use
+ * @param dataset - An array of HellaSwagExample to evaluate on
+ * @param limit - Number of examples to evaluate (default: all)
+ * @param print - Whether to print results (default: true)
+ * @returns The accuracy of the model on the dataset
+ */
+export async function evaluate(model, tokenizer, dataset, print = true) {
+    let correct = 0;
+    let total = 0;
+    for (const example of dataset) {
+        const endingTokens = example.endings.map(e => tokenize(tokenizer, example.ctx + ' ' + e, {
+            truncation: true,
+            max_length: 128
+        }).toArray());
+        const ctxTokens = tokenize(tokenizer, example.ctx, {
+            truncation: true,
+            max_length: 128
+        }).toArray();
+        let losses = [];
+        if (model instanceof GPT) {
+            losses = await Promise.all(endingTokens.map(e => computeLogLikelihood(model, e, ctxTokens.length)));
+        }
+        else {
+            losses = await Promise.all(endingTokens.map(e => computeONNXLogLikelihood(model, e, ctxTokens.length)));
+        }
+        const pred = losses.indexOf(Math.min(...losses));
+        if (pred === example.label)
+            correct++;
+        total++;
+        if (print) {
+            console.log(`\nExample #${total}`);
+            console.log(`Context: ${example.ctx}`);
+            example.endings.forEach((end, i) => {
+                console.log(`  ${i}: ${end} (loss: ${losses[i].toFixed(4)})${i === example.label ? ' <-- correct' : ''}${i === pred ? ' <-- picked' : ''}`);
+            });
+            const accuracy_temp = correct / total;
+            console.log(`\n Accuracy on ${total} examples: ${(accuracy_temp * 100).toFixed(2)}%`);
+        }
+    }
+    const accuracy = correct / total;
+    console.log(`\nFinal accuracy on ${total} examples: ${(accuracy * 100).toFixed(2)}%`);
+    return accuracy;
+}
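
Together, hellaswag.d.ts and hellaswag.js add a completion-style benchmark: each candidate ending is scored by its average per-token cross-entropy over the ending tokens only (context tokens are masked out), and the lowest-loss ending is taken as the prediction. A minimal usage sketch, assuming the symbols are re-exported from the package entry point as the index diffs below suggest (the import path and the Xenova/gpt2 checkpoint are illustrative):

import { AutoTokenizer } from '@xenova/transformers';
// Hypothetical entry point; the diff only shows the dist/models/index.js re-exports.
import { ONNXModel, evaluate, HELLASWAG_URL } from '@epfml/discojs';
import type { HellaSwagExample } from '@epfml/discojs';

// Fetch the validation split (JSON Lines) and keep a few examples for a quick check.
const raw = await (await fetch(HELLASWAG_URL)).text();
const dataset: HellaSwagExample[] = raw.trim().split('\n').slice(0, 10).map(line => JSON.parse(line));

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2');
const model = await ONNXModel.init_pretrained('Xenova/gpt2');

const accuracy = await evaluate(model, tokenizer, dataset, false); // print = false
console.log(`HellaSwag accuracy on ${dataset.length} examples: ${(accuracy * 100).toFixed(2)}%`);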
package/dist/models/index.d.ts
CHANGED

@@ -1,6 +1,9 @@
 export { Model } from './model.js';
 export { BatchLogs, EpochLogs, ValidationMetrics } from "./logs.js";
 export { GPT } from './gpt/index.js';
+export { ONNXModel } from './onnx.js';
 export { GPTConfig } from './gpt/config.js';
+export { evaluate as evaluate_hellaswag } from './hellaswag.js';
 export { TFJS } from './tfjs.js';
 export { getTaskTokenizer } from './tokenizer.js';
+export { evaluate, HellaSwagDataset, HellaSwagExample, HELLASWAG_URL } from './hellaswag.js';
package/dist/models/index.js
CHANGED

@@ -1,5 +1,8 @@
 export { Model } from './model.js';
 export { EpochLogs } from "./logs.js";
 export { GPT } from './gpt/index.js';
+export { ONNXModel } from './onnx.js';
+export { evaluate as evaluate_hellaswag } from './hellaswag.js';
 export { TFJS } from './tfjs.js';
 export { getTaskTokenizer } from './tokenizer.js';
+export { evaluate, HELLASWAG_URL } from './hellaswag.js';
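
Note that the models barrel now exposes the HellaSwag entry point twice, once as evaluate and once aliased as evaluate_hellaswag, alongside ONNXModel. Downstream code can therefore import the new pieces from the index rather than from the individual files, e.g. (the subpath shown is an assumption; the actual public path depends on the package's "exports" map):

import { ONNXModel, evaluate_hellaswag, HELLASWAG_URL } from '@epfml/discojs/dist/models/index.js';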
package/dist/models/onnx.d.ts
ADDED

@@ -0,0 +1,19 @@
+import { Tensor } from '@xenova/transformers';
+import { Model } from './index.js';
+import type { WeightsContainer } from '../index.js';
+import { List } from 'immutable';
+import type { GenerationConfig as TFJSGenerationConfig } from './gpt/config.js';
+import type { Batched, DataFormat } from "../index.js";
+export declare class ONNXModel extends Model<'text'> {
+    #private;
+    private model;
+    private constructor();
+    static init_pretrained(modelName?: string): Promise<ONNXModel>;
+    getConfig(): Record<string, unknown>;
+    predict(batch: Batched<DataFormat.ModelEncoded["text"][0]>, options?: Partial<TFJSGenerationConfig>): Promise<Batched<DataFormat.ModelEncoded["text"][1]>>;
+    getLogits(batch: List<List<number>>): Promise<Tensor>;
+    train(): AsyncGenerator<never, never>;
+    get weights(): WeightsContainer;
+    set weights(_: WeightsContainer);
+    [Symbol.dispose](): void;
+}
package/dist/models/onnx.js
ADDED

@@ -0,0 +1,71 @@
+import { AutoModelForCausalLM, Tensor } from '@xenova/transformers';
+import { Model } from './index.js';
+import { List } from 'immutable';
+import { DefaultGenerationConfig } from './gpt/config.js';
+export class ONNXModel extends Model {
+    model;
+    constructor(model) {
+        super();
+        this.model = model;
+    }
+    static async init_pretrained(modelName = 'Xenova/gpt2') {
+        const model = await AutoModelForCausalLM.from_pretrained(modelName);
+        return new ONNXModel(model);
+    }
+    getConfig() {
+        return this.model.config;
+    }
+    async predict(batch, options) {
+        const config = Object.assign({}, DefaultGenerationConfig, options);
+        return List(await Promise.all(batch.map(tokens => this.#predictSingle(tokens, config))));
+    }
+    async #predictSingle(tokens, config) {
+        const contextLength = this.model.config.max_position_embeddings ?? 1024;
+        const truncated = tokens.slice(-contextLength).toArray();
+        if (truncated.length === 0) {
+            throw new Error('Token list is empty. Cannot run generate().');
+        }
+        const input_ids = new Tensor('int64', truncated.map(BigInt), [1, truncated.length]);
+        const output = await this.model.generate(input_ids, {
+            max_new_tokens: 1,
+            temperature: config.temperature,
+            do_sample: config.doSample,
+            top_k: config.topk,
+        });
+        if (!Array.isArray(output) || output.length === 0 || !Array.isArray(output[0])) {
+            throw new Error('ONNX model.generate() did not return valid sequences.');
+        }
+        const predicted_id = output[0].at(-1);
+        return Number(predicted_id);
+    }
+    async getLogits(batch) {
+        const input_ids_array = batch.toArray().map(seq => seq.toArray());
+        const attention_mask_array = input_ids_array.map((seq) => new Array(seq.length).fill(1));
+        const input_ids_flat = input_ids_array.flat();
+        const attention_mask_flat = attention_mask_array.flat();
+        const shape = [input_ids_array.length, input_ids_array[0].length];
+        // use BigInt for int64 compatibility
+        const input_ids = new Tensor('int64', input_ids_flat.map(BigInt), shape);
+        const attention_mask = new Tensor('int64', attention_mask_flat.map(BigInt), shape);
+        // run model forward
+        const outputs = await this.model.forward({ input_ids, attention_mask });
+        return outputs.logits;
+    }
+    async *train() {
+        await Promise.resolve(); // dummy await
+        const yieldFlag = false;
+        if (yieldFlag)
+            yield undefined; // satisfy 'require-yield'
+        throw new Error('Training not supported for ONNX models');
+    }
+    get weights() {
+        throw new Error('Weights access not supported in ONNX models');
+    }
+    set weights(_) {
+        throw new Error('Weights setting not supported in ONNX models');
+    }
+    [Symbol.dispose]() {
+        // Dispose of the model to free up memory
+        void this.model.dispose();
+    }
+}
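
ONNXModel wraps a transformers.js causal language model for inference only: predict() produces a single next token per sequence, getLogits() exposes the raw forward pass that the HellaSwag evaluator uses, and train() and the weights accessors deliberately throw. A minimal usage sketch, assuming the class is reachable from the package entry point (the import path is illustrative):

import { List } from 'immutable';
import { AutoTokenizer } from '@xenova/transformers';
// Hypothetical entry point; see the dist/models/index.js re-export above.
import { ONNXModel } from '@epfml/discojs';

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2');
const model = await ONNXModel.init_pretrained('Xenova/gpt2'); // default checkpoint in the diff

// predict() expects a batch of token sequences and returns one next-token id per sequence.
const ids = tokenizer.encode('The capital of Switzerland is');
const nextIds = await model.predict(List([List(ids)]));
console.log(tokenizer.decode(nextIds.toArray()));

model[Symbol.dispose](); // free the underlying ONNX session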