@epfml/discojs 3.0.1-p20250429140233.0 → 3.0.1-p20250625140656.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,30 @@
+ import { GPT } from './index.js';
+ import { PreTrainedTokenizer } from '@xenova/transformers';
+ import { ONNXModel } from './onnx.js';
+ export declare const HELLASWAG_URL = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl";
+ /**
+  * Represents a single example from the HellaSwag dataset.
+  *
+  * ctx - The context sentence or paragraph that sets up the situation.
+  * endings - An array of four possible continuations of the context.
+  * label - The index (0–3) of the correct ending in the `endings` array.
+  */
+ export interface HellaSwagExample {
+     ctx: string;
+     endings: string[];
+     label: number;
+ }
+ export type HellaSwagDataset = HellaSwagExample[];
+ type Tokenizer = PreTrainedTokenizer;
+ type ModelType = GPT | ONNXModel;
+ /**
+  * Evaluates the model on a given HellaSwag dataset.
+  *
+  * @param model - The model to evaluate (GPT or ONNXModel)
+  * @param tokenizer - The tokenizer to use
+  * @param dataset - An array of HellaSwagExample to evaluate on
+  * @param print - Whether to print results (default: true)
+  * @returns The accuracy of the model on the dataset
+  */
+ export declare function evaluate(model: ModelType, tokenizer: Tokenizer, dataset: HellaSwagExample[], print?: boolean): Promise<number>;
+ export {};
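
For orientation, each line of the JSONL file behind HELLASWAG_URL parses to an object that is structurally compatible with HellaSwagExample; the raw records also carry extra metadata fields, which this interface simply ignores. An illustrative record, with the ending strings shortened:

{
  "ctx": "A man is sitting on a roof. He",
  "endings": ["ending 0", "ending 1", "ending 2", "ending 3"],
  "label": 3
}
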
@@ -0,0 +1,119 @@
+ import * as tf from '@tensorflow/tfjs';
+ import { GPT } from './index.js';
+ import { tokenize } from '../processing/text.js';
+ import { List } from 'immutable';
+ export const HELLASWAG_URL = 'https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl';
+ // Computes the log likelihood of the input sequence using the tfjs model.
+ // The input sequence is expected to be a concatenation of the context and the ending.
+ // The function returns the average cross-entropy loss over the ending tokens (the tokens after the context).
+ // Sources:
+ // https://github.com/karpathy/build-nanogpt/blob/master/hellaswag.py
+ // https://www.youtube.com/watch?v=l8pRSuU81PU
+ async function computeLogLikelihood(gpt, inputIds, ctxLength) {
+     const lossTensor = tf.tidy(() => {
+         // Convert input sequence to shape [1, seq_len]
+         const inputTensor = tf.tensor2d([inputIds], [1, inputIds.length], 'int32');
+         // Get model logits: [1, seq_len, vocab_size]
+         const logits3D = gpt.extract().predict(inputTensor);
+         // Shift logits to align with next-token targets
+         const shiftedLogits = logits3D.slice([0, 0, 0], [1, inputIds.length - 1, -1]);
+         // Target tokens (next tokens), same length as shifted logits
+         const shiftedTargets = inputIds.slice(1);
+         const targetTensor = tf.tensor1d(shiftedTargets, 'int32');
+         // One-hot encode targets for cross-entropy loss
+         const oneHotLabels = tf.oneHot(targetTensor, shiftedLogits.shape[2]);
+         // Compute per-token cross-entropy losses (Reduction.NONE keeps one loss per token instead of a pre-averaged scalar, so the mask below has an effect)
+         const perTokenLoss = tf.losses.softmaxCrossEntropy(oneHotLabels, shiftedLogits.squeeze(), undefined, undefined, tf.Reduction.NONE);
+         // Create a mask to only include loss after the context length
+         const mask = tf.tensor1d(inputIds.map((_, i) => (i >= ctxLength ? 1 : 0)), 'float32').slice(1);
+         // Apply the mask and average over the selected tokens
+         const masked = perTokenLoss.mul(mask);
+         const loss = masked.sum().div(mask.sum());
+         return loss;
+     });
+     const lossNumber = await lossTensor.array();
+     if (typeof lossNumber !== 'number') {
+         throw new Error('expected a scalar loss');
+     }
+     return lossNumber;
+ }
+ // Computes the log likelihood of the input sequence using the ONNX model.
+ // The input sequence is expected to be a concatenation of the context and the ending.
+ // The function returns the average cross-entropy loss over the ending tokens (the tokens after the context).
+ // Sources:
+ // https://github.com/karpathy/build-nanogpt/blob/master/hellaswag.py
+ // https://www.youtube.com/watch?v=l8pRSuU81PU
+ async function computeONNXLogLikelihood(model, inputIds, ctxLength) {
+     const batchInput = List([List(inputIds)]); // [1, seq_len]
+     // Run model to get logits: flattened [T * V]
+     const logitsTensor = await model.getLogits(batchInput);
+     const logits = logitsTensor.data;
+     const [_B, T, V] = logitsTensor.dims;
+     // Reshape flattened logits into [T][V]
+     const reshaped = Array.from({ length: T }, (_, t) => logits.slice(t * V, (t + 1) * V));
+     // Shift targets (next-token prediction)
+     const targets = inputIds.slice(1); // length = T - 1
+     const logitsShifted = reshaped.slice(0, T - 1); // also length = T - 1
+     // Compute per-token cross-entropy loss manually
+     const losses = logitsShifted.map((logit, i) => {
+         const maxLogit = Math.max(...logit); // for numerical stability
+         const exp = logit.map(x => Math.exp(x - maxLogit));
+         const sumExp = exp.reduce((a, b) => a + b, 0);
+         const probs = exp.map(e => e / sumExp); // softmax
+         return -Math.log(probs[targets[i]]); // cross-entropy loss
+     });
+     // Create a binary mask for non-context tokens
+     const mask = inputIds.map((_, i) => (i >= ctxLength ? 1 : 0)).slice(1);
+     // Apply the mask to the losses
+     const maskedLosses = losses.map((l, i) => l * mask[i]);
+     // Average the masked losses
+     const totalLoss = maskedLosses.reduce((a, b) => a + b, 0);
+     const sum = mask.reduce((a, b) => a + b, 0);
+     return totalLoss / (sum || 1); // avoid division by 0
+ }
+ /**
+  * Evaluates the model on a given HellaSwag dataset.
+  *
+  * @param model - The model to evaluate (GPT or ONNXModel)
+  * @param tokenizer - The tokenizer to use
+  * @param dataset - An array of HellaSwagExample to evaluate on
+  * @param print - Whether to print results (default: true)
+  * @returns The accuracy of the model on the dataset
+  */
+ export async function evaluate(model, tokenizer, dataset, print = true) {
+     let correct = 0;
+     let total = 0;
+     for (const example of dataset) {
+         const endingTokens = example.endings.map(e => tokenize(tokenizer, example.ctx + ' ' + e, {
+             truncation: true,
+             max_length: 128
+         }).toArray());
+         const ctxTokens = tokenize(tokenizer, example.ctx, {
+             truncation: true,
+             max_length: 128
+         }).toArray();
+         let losses = [];
+         if (model instanceof GPT) {
+             losses = await Promise.all(endingTokens.map(e => computeLogLikelihood(model, e, ctxTokens.length)));
+         }
+         else {
+             losses = await Promise.all(endingTokens.map(e => computeONNXLogLikelihood(model, e, ctxTokens.length)));
+         }
+         const pred = losses.indexOf(Math.min(...losses));
+         if (pred === example.label)
+             correct++;
+         total++;
+         if (print) {
+             console.log(`\nExample #${total}`);
+             console.log(`Context: ${example.ctx}`);
+             example.endings.forEach((end, i) => {
+                 console.log(` ${i}: ${end} (loss: ${losses[i].toFixed(4)})${i === example.label ? ' <-- correct' : ''}${i === pred ? ' <-- picked' : ''}`);
+             });
+             const runningAccuracy = correct / total;
+             console.log(`\nAccuracy on ${total} examples: ${(runningAccuracy * 100).toFixed(2)}%`);
+         }
+     }
+     const accuracy = correct / total;
+     console.log(`\nFinal accuracy on ${total} examples: ${(accuracy * 100).toFixed(2)}%`);
+     return accuracy;
+ }
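
Outside the diff itself, here is a minimal usage sketch of the new evaluation entry point. It assumes a runtime with a global fetch, the Xenova/gpt2 checkpoint, and @xenova/transformers' AutoTokenizer; the model name and the 100-example cap are illustrative, and since evaluate takes no limit parameter the dataset is sliced beforehand:

import { AutoTokenizer } from '@xenova/transformers';
import { ONNXModel, evaluate, HELLASWAG_URL } from '@epfml/discojs';
import type { HellaSwagDataset } from '@epfml/discojs';

// Download the validation split and parse one example per JSONL line
const raw = await (await fetch(HELLASWAG_URL)).text();
const dataset: HellaSwagDataset = raw.trim().split('\n').map((line) => JSON.parse(line));

const model = await ONNXModel.init_pretrained('Xenova/gpt2');
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2');

// evaluate prints per-example results when its last argument is true
const accuracy = await evaluate(model, tokenizer, dataset.slice(0, 100), true);
console.log(`HellaSwag accuracy: ${(accuracy * 100).toFixed(2)}%`);
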
@@ -1,6 +1,9 @@
  export { Model } from './model.js';
  export { BatchLogs, EpochLogs, ValidationMetrics } from "./logs.js";
  export { GPT } from './gpt/index.js';
+ export { ONNXModel } from './onnx.js';
  export { GPTConfig } from './gpt/config.js';
+ export { evaluate as evaluate_hellaswag } from './hellaswag.js';
  export { TFJS } from './tfjs.js';
  export { getTaskTokenizer } from './tokenizer.js';
+ export { evaluate, HellaSwagDataset, HellaSwagExample, HELLASWAG_URL } from './hellaswag.js';
@@ -1,5 +1,8 @@
  export { Model } from './model.js';
  export { EpochLogs } from "./logs.js";
  export { GPT } from './gpt/index.js';
+ export { ONNXModel } from './onnx.js';
+ export { evaluate as evaluate_hellaswag } from './hellaswag.js';
  export { TFJS } from './tfjs.js';
  export { getTaskTokenizer } from './tokenizer.js';
+ export { evaluate, HELLASWAG_URL } from './hellaswag.js';
@@ -0,0 +1,19 @@
+ import { Tensor } from '@xenova/transformers';
+ import { Model } from './index.js';
+ import type { WeightsContainer } from '../index.js';
+ import { List } from 'immutable';
+ import type { GenerationConfig as TFJSGenerationConfig } from './gpt/config.js';
+ import type { Batched, DataFormat } from "../index.js";
+ export declare class ONNXModel extends Model<'text'> {
+     #private;
+     private model;
+     private constructor();
+     static init_pretrained(modelName?: string): Promise<ONNXModel>;
+     getConfig(): Record<string, unknown>;
+     predict(batch: Batched<DataFormat.ModelEncoded["text"][0]>, options?: Partial<TFJSGenerationConfig>): Promise<Batched<DataFormat.ModelEncoded["text"][1]>>;
+     getLogits(batch: List<List<number>>): Promise<Tensor>;
+     train(): AsyncGenerator<never, never>;
+     get weights(): WeightsContainer;
+     set weights(_: WeightsContainer);
+     [Symbol.dispose](): void;
+ }
@@ -0,0 +1,71 @@
+ import { AutoModelForCausalLM, Tensor } from '@xenova/transformers';
+ import { Model } from './index.js';
+ import { List } from 'immutable';
+ import { DefaultGenerationConfig } from './gpt/config.js';
+ export class ONNXModel extends Model {
+     model;
+     constructor(model) {
+         super();
+         this.model = model;
+     }
+     static async init_pretrained(modelName = 'Xenova/gpt2') {
+         const model = await AutoModelForCausalLM.from_pretrained(modelName);
+         return new ONNXModel(model);
+     }
+     getConfig() {
+         return this.model.config;
+     }
+     async predict(batch, options) {
+         const config = Object.assign({}, DefaultGenerationConfig, options);
+         return List(await Promise.all(batch.map(tokens => this.#predictSingle(tokens, config))));
+     }
+     async #predictSingle(tokens, config) {
+         const contextLength = this.model.config.max_position_embeddings ?? 1024;
+         const truncated = tokens.slice(-contextLength).toArray();
+         if (truncated.length === 0) {
+             throw new Error('Token list is empty. Cannot run generate().');
+         }
+         const input_ids = new Tensor('int64', truncated.map(BigInt), [1, truncated.length]);
+         const output = await this.model.generate(input_ids, {
+             max_new_tokens: 1,
+             temperature: config.temperature,
+             do_sample: config.doSample,
+             top_k: config.topk,
+         });
+         if (!Array.isArray(output) || output.length === 0 || !Array.isArray(output[0])) {
+             throw new Error('ONNX model.generate() did not return valid sequences.');
+         }
+         const predicted_id = output[0].at(-1);
+         return Number(predicted_id);
+     }
+     async getLogits(batch) {
+         const input_ids_array = batch.toArray().map(seq => seq.toArray());
+         const attention_mask_array = input_ids_array.map((seq) => new Array(seq.length).fill(1));
+         const input_ids_flat = input_ids_array.flat();
+         const attention_mask_flat = attention_mask_array.flat();
+         const shape = [input_ids_array.length, input_ids_array[0].length];
+         // use BigInt for int64 compatibility
+         const input_ids = new Tensor('int64', input_ids_flat.map(BigInt), shape);
+         const attention_mask = new Tensor('int64', attention_mask_flat.map(BigInt), shape);
+         // run model forward
+         const outputs = await this.model.forward({ input_ids, attention_mask });
+         return outputs.logits;
+     }
+     async *train() {
+         await Promise.resolve(); // dummy await
+         const yieldFlag = false;
+         if (yieldFlag)
+             yield undefined; // satisfy 'require-yield'
+         throw new Error('Training not supported for ONNX models');
+     }
+     get weights() {
+         throw new Error('Weights access not supported in ONNX models');
+     }
+     set weights(_) {
+         throw new Error('Weights setting not supported in ONNX models');
+     }
+     [Symbol.dispose]() {
+         // Dispose of the model to free up memory
+         void this.model.dispose();
+     }
+ }
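
Also outside the diff, a short sketch of driving ONNXModel directly for logits; the model name and the two token ids (the GPT-2 BPE ids for "Hello" and " world", to my reading) are illustrative:

import { List } from 'immutable';
import { ONNXModel } from '@epfml/discojs';

const model = await ONNXModel.init_pretrained('Xenova/gpt2');
// getLogits expects a batch of equal-length token-id sequences as nested immutable Lists
const logits = await model.getLogits(List([List([15496, 995])]));
console.log(logits.dims); // [batch_size, seq_len, vocab_size]
// Release the underlying session explicitly (or via a `using` declaration in TS 5.2+)
model[Symbol.dispose]();
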
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
    "name": "@epfml/discojs",
-   "version": "3.0.1-p20250429140233.0",
+   "version": "3.0.1-p20250625140656.0",
    "type": "module",
    "main": "dist/index.js",
    "types": "dist/index.d.ts",