@epfml/discojs 3.0.1-p20241203151748.0 → 3.0.1-p20241206133538.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,15 +25,22 @@ export declare class Dataset<T> implements AsyncIterable<T> {
   * @param ratio between 0 (all on left) and 1 (all on right)
   */
  split(ratio: number): [Dataset<T>, Dataset<T>];
- /** Slice into chunks
+ /** Create batches of `size` elements with potential overlap.
+ * Last batch is smaller if dataset isn't perfectly divisible
  *
- * Last slice is smaller if dataset isn't perfectly divisible
+ * If overlap is set to a positive integer, the last `overlap` elements of a batch
+ * are the first `overlap` elements of the next batch.
+ *
+ * This method is tailored to create text sequences where each token's label is the following token.
+ * In order to have a label for the last token of the input sequence, we include the first token
+ * of the next sequence (i.e. with an overlap of 1).
  *
  * @param size count of element per chunk
+ * @param overlap number of elements overlapping between two consecutive batches
  */
- batch(size: number): Dataset<Batched<T>>;
- /** Flatten chunks */
- unbatch<U>(this: Dataset<Batched<U>>): Dataset<U>;
+ batch(size: number, overlap?: number): Dataset<Batched<T>>;
+ /** Flatten batches/arrays of elements */
+ flatten<U>(this: Dataset<DatasetLike<U>>): Dataset<U>;
  /** Join side-by-side
  *
  * Stops as soon as one runs out
@@ -41,6 +48,12 @@ export declare class Dataset<T> implements AsyncIterable<T> {
  * @param other right side
  **/
  zip<U>(other: Dataset<U> | DatasetLike<U>): Dataset<[T, U]>;
+ /**
+ * Repeat the dataset `times` times
+ * @param times number of times to repeat the dataset, if undefined, the dataset is repeated indefinitely
+ * @returns a dataset repeated `times` times
+ */
+ repeat(times?: number): Dataset<T>;
  /** Compute size
  *
  * This is a costly operation as we need to go through the whole Dataset.
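A minimal usage sketch of the new overlapping `batch`, assuming `Dataset` is exported from the package root and accepts a plain iterable (as `DatasetLike` suggests); the token values are made up:

```ts
import { Dataset } from "@epfml/discojs";

async function overlapDemo(): Promise<void> {
  const tokens = new Dataset([10, 11, 12, 13, 14, 15, 16]);
  // Windows of 4 elements; the last element of each window is repeated
  // as the first element of the next one (overlap of 1).
  for await (const window of tokens.batch(4, 1))
    console.log(window.toArray());
  // [ 10, 11, 12, 13 ]
  // [ 13, 14, 15, 16 ]
}
void overlapDemo();
```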
@@ -1,6 +1,22 @@
  import createDebug from "debug";
  import { List, Range } from "immutable";
  const debug = createDebug("discojs:dataset");
+ /** Convert a DatasetLike object to an async generator */
+ async function* datasetLikeToGenerator(content) {
+ let iter;
+ if (typeof content === "function")
+ iter = content();
+ else if (Symbol.asyncIterator in content)
+ iter = content[Symbol.asyncIterator]();
+ else
+ iter = content[Symbol.iterator]();
+ while (true) {
+ const result = await iter.next();
+ if (result.done === true)
+ break;
+ yield result.value;
+ }
+ }
  /** Immutable series of data */
  export class Dataset {
  #content;
@@ -11,19 +27,7 @@ export class Dataset {
  */
  constructor(content) {
  this.#content = async function* () {
- let iter;
- if (typeof content === "function")
- iter = content();
- else if (Symbol.asyncIterator in content)
- iter = content[Symbol.asyncIterator]();
- else
- iter = content[Symbol.iterator]();
- while (true) {
- const result = await iter.next();
- if (result.done === true)
- break;
- yield result.value;
- }
+ yield* datasetLikeToGenerator(content);
  };
  }
  [Symbol.asyncIterator]() {
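The constructor now delegates to `datasetLikeToGenerator`, which accepts the three `DatasetLike` shapes: a function returning an iterator, an async iterable, or a sync iterable. A hedged sketch of each (assuming `Dataset` is exported from the package root):

```ts
import { Dataset } from "@epfml/discojs";

// Sync iterable, e.g. a plain array
const fromArray = new Dataset([1, 2, 3]);

// Async iterable
async function* stream() { yield 1; yield 2; yield 3; }
const fromAsyncIterable = new Dataset(stream());

// Function returning an iterator; it is called anew each time the dataset is iterated
const fromFactory = new Dataset(() => stream());
```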
@@ -87,19 +91,31 @@ export class Dataset {
  }.bind(this)),
  ];
  }
- /** Slice into chunks
+ /** Create batches of `size` elements with potential overlap.
+ * Last batch is smaller if dataset isn't perfectly divisible
+ *
+ * If overlap is set to a positive integer, the last `overlap` elements of a batch
+ * are the first `overlap` elements of the next batch.
  *
- * Last slice is smaller if dataset isn't perfectly divisible
+ * This method is tailored to create text sequences where each token's label is the following token.
+ * In order to have a label for the last token of the input sequence, we include the first token
+ * of the next sequence (i.e. with an overlap of 1).
  *
  * @param size count of element per chunk
+ * @param overlap number of elements overlapping between two consecutive batches
  */
- batch(size) {
+ batch(size, overlap = 0) {
  if (size <= 0 || !Number.isInteger(size))
  throw new Error("invalid size");
+ if (overlap >= size || !Number.isInteger(overlap))
+ throw new Error("invalid overlap");
  return new Dataset(async function* () {
  const iter = this[Symbol.asyncIterator]();
+ let overlapped = List();
  for (;;) {
- const batch = List(await Promise.all(Range(0, size).map(() => iter.next()))).flatMap((res) => {
+ const batch = List(
+ // get the first elements of the next batch
+ await Promise.all(Range(overlapped.size, size).map(() => iter.next()))).flatMap((res) => {
  if (res.done)
  return [];
  else
@@ -107,18 +123,21 @@ export class Dataset {
  });
  if (batch.isEmpty())
  break;
- yield batch;
+ // yield the current batch with the first elements of the next batch
+ yield overlapped.concat(batch);
+ overlapped = batch.takeLast(overlap);
  // iterator couldn't generate more
- if (batch.size < size)
+ if (batch.size < size - overlap)
  break;
  }
  }.bind(this));
  }
- /** Flatten chunks */
- unbatch() {
+ /** Flatten batches/arrays of elements */
+ flatten() {
  return new Dataset(async function* () {
- for await (const batch of this)
- yield* batch;
+ for await (const batch of this) {
+ yield* datasetLikeToGenerator(batch);
+ }
  }.bind(this));
  }
  /** Join side-by-side
@@ -141,6 +160,22 @@ export class Dataset {
  }
  }.bind(this));
  }
+ /**
+ * Repeat the dataset `times` times
+ * @param times number of times to repeat the dataset, if undefined, the dataset is repeated indefinitely
+ * @returns a dataset repeated `times` times
+ */
+ repeat(times) {
+ if (times !== undefined && (!Number.isInteger(times) || times < 1))
+ throw new Error("times needs to be a positive integer or undefined");
+ return new Dataset(async function* () {
+ let loop = 0;
+ do {
+ yield* this;
+ loop++;
+ } while (times === undefined || loop < times);
+ }.bind(this));
+ }
  /** Compute size
  *
  * This is a costly operation as we need to go through the whole Dataset.
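A short sketch of the new `repeat`, e.g. to loop over a small dataset for several epochs (assuming `size()` resolves to the element count, as its doc comment suggests):

```ts
import { Dataset } from "@epfml/discojs";

async function repeatDemo(): Promise<void> {
  const examples = new Dataset(["a", "b", "c"]);
  const threeEpochs = examples.repeat(3); // finite: 3 x 3 = 9 elements
  console.log(await threeEpochs.size());  // 9 (size() assumed to count elements)

  const endless = examples.repeat();      // no argument: repeats indefinitely,
  void endless;                           // so consumers must stop iterating themselves
}
void repeatDemo();
```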
@@ -4,3 +4,4 @@ export type Batched<T> = List<T>;
  export { Image };
  export type Tabular = Partial<Record<string, string>>;
  export type Text = string;
+ export type TokenizedText = List<number>;
@@ -31,14 +31,16 @@ export const wikitext = {
  // But if set to 0 then the webapp doesn't display the validation metrics
  validationSplit: 0.1,
  roundDuration: 2,
- batchSize: 1, // If set too high (e.g. 16) firefox raises a WebGL error
+ batchSize: 8, // If set too high firefox raises a WebGL error
  tokenizer: 'Xenova/gpt2',
- maxSequenceLength: 128,
+ contextLength: 64,
  tensorBackend: 'gpt'
  }
  };
  },
  getModel() {
- return Promise.resolve(new models.GPT());
+ return Promise.resolve(new models.GPT({
+ contextLength: this.getTask().trainingInformation.contextLength,
+ }));
  }
  };
@@ -1,8 +1,8 @@
  type GPTModelType = 'gpt2' | 'gpt2-medium' | 'gpt2-large' | 'gpt2-xl' | 'gpt-mini' | 'gpt-micro' | 'gpt-nano';
  export interface GPTConfig {
  lr: number;
- blockSize: number;
- vocabSize: number;
+ contextLength: number;
+ vocabSize?: number;
  modelType: GPTModelType;
  name?: string;
  evaluate?: boolean;
@@ -11,22 +11,27 @@ export interface GPTConfig {
  maxIter?: number;
  weightDecay?: number;
  verbose?: 0 | 1;
- bias?: boolean;
  debug?: boolean;
  dropout?: number;
  residDrop?: number;
  embdDrop?: number;
- tokEmb?: boolean;
- lmHead?: boolean;
  nLayer?: number;
  nHead?: number;
  nEmbd?: number;
+ seed?: number;
  }
- export declare const DEFAULT_CONFIG: Required<GPTConfig>;
+ export declare const DefaultGPTConfig: Required<GPTConfig>;
  export type ModelSize = {
  nLayer: number;
  nHead: number;
  nEmbd: number;
  };
  export declare function getModelSizes(modelType: GPTModelType): Required<ModelSize>;
+ export interface GenerationConfig {
+ doSample: boolean;
+ temperature: number;
+ topk: number;
+ seed: number;
+ }
+ export declare const DefaultGenerationConfig: Required<GenerationConfig>;
  export {};
@@ -1,6 +1,6 @@
  // for a benchmark of performance, see https://github.com/epfml/disco/pull/659
- export const DEFAULT_CONFIG = {
- name: 'transformer',
+ export const DefaultGPTConfig = {
+ name: 'transformer', // prefix for the model layer names
  lr: 0.001,
  weightDecay: 0,
  maxIter: 10,
@@ -9,18 +9,16 @@ export const DEFAULT_CONFIG = {
  evaluate: true,
  maxEvalBatches: 12,
  evaluateEvery: 100,
- blockSize: 128,
- vocabSize: 50258,
- bias: true,
+ contextLength: 128,
+ vocabSize: 50257,
  debug: false,
  dropout: 0.2,
  residDrop: 0.2,
  embdDrop: 0.2,
- tokEmb: true,
- lmHead: true,
  nLayer: 3,
  nHead: 3,
  nEmbd: 48,
+ seed: Math.random(),
  };
  export function getModelSizes(modelType) {
  switch (modelType) {
@@ -40,3 +38,9 @@ export function getModelSizes(modelType) {
  return { nLayer: 3, nHead: 3, nEmbd: 48 };
  }
  }
+ export const DefaultGenerationConfig = {
+ temperature: 1.0,
+ doSample: false,
+ seed: Math.random(),
+ topk: 50
+ };
@@ -1,23 +1,20 @@
  /**
- * this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
+ * Source: https://github.com/zemlyansky/gpt-tfjs and https://github.com/karpathy/build-nanogpt
+ * With modifications from @peacefulotter, @lukemovement and the Disco team
  **/
  import * as tf from "@tensorflow/tfjs";
  import type { Batched, Dataset, DataFormat } from "../../index.js";
  import { WeightsContainer } from "../../index.js";
  import { BatchLogs, Model, EpochLogs } from "../index.js";
- import { type GPTConfig } from "./config.js";
+ import type { GPTConfig, GenerationConfig } from './config.js';
  export type GPTSerialization = {
  weights: WeightsContainer;
  config?: GPTConfig;
  };
- interface PredictConfig {
- temperature: number;
- doSample: boolean;
- }
  export declare class GPT extends Model<"text"> {
  #private;
  private readonly model;
- constructor(partialConfig?: GPTConfig, layersModel?: tf.LayersModel);
+ constructor(partialConfig?: Partial<GPTConfig>, layersModel?: tf.LayersModel);
  /**
  * The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
  * This allows for getting logs and stopping training without callbacks.
@@ -28,7 +25,7 @@ export declare class GPT extends Model<"text"> {
  * @param tracker
  */
  train(trainingDataset: Dataset<Batched<DataFormat.ModelEncoded["text"]>>, validationDataset?: Dataset<Batched<DataFormat.ModelEncoded["text"]>>): AsyncGenerator<BatchLogs, EpochLogs>;
- predict(batch: Batched<DataFormat.ModelEncoded["text"][0]>, options?: Partial<PredictConfig>): Promise<Batched<DataFormat.ModelEncoded["text"][1]>>;
+ predict(batch: Batched<DataFormat.ModelEncoded["text"][0]>, options?: Partial<GenerationConfig>): Promise<Batched<DataFormat.ModelEncoded["text"][1]>>;
  get config(): Required<GPTConfig>;
  get weights(): WeightsContainer;
  set weights(ws: WeightsContainer);
@@ -37,4 +34,3 @@ export declare class GPT extends Model<"text"> {
  extract(): tf.LayersModel;
  [Symbol.dispose](): void;
  }
- export {};
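The declared `predict` now takes the shared `GenerationConfig` instead of the removed `PredictConfig`. A hedged usage sketch (an untrained model and made-up token ids, purely for illustration; `models` is assumed to be exported from the package root):

```ts
import { List } from "immutable";
import { models } from "@epfml/discojs";

async function generateDemo(): Promise<void> {
  const model = new models.GPT({ contextLength: 64 }); // mirrors the wikitext task above
  const prompt = List([464, 2068, 7586]);              // hypothetical GPT-2 token ids
  const nextTokens = await model.predict(List([prompt]), {
    doSample: true,   // sample from the top-k distribution instead of greedy argmax
    topk: 50,
    temperature: 0.8,
  });
  console.log(nextTokens.toArray()); // one predicted token per input sequence
}
void generateDemo();
```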
@@ -1,5 +1,6 @@
  /**
- * this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
+ * Source: https://github.com/zemlyansky/gpt-tfjs and https://github.com/karpathy/build-nanogpt
+ * With modifications from @peacefulotter, @lukemovement and the Disco team
  **/
  import createDebug from "debug";
  import { List, Range } from "immutable";
@@ -7,12 +8,12 @@ import * as tf from "@tensorflow/tfjs";
  import { WeightsContainer } from "../../index.js";
  import { Model, EpochLogs } from "../index.js";
  import { GPTModel } from "./model.js";
- import { DEFAULT_CONFIG } from "./config.js";
  import evaluate from "./evaluate.js";
+ import { DefaultGPTConfig, DefaultGenerationConfig } from './config.js';
  const debug = createDebug("discojs:models:gpt");
  export class GPT extends Model {
  model;
- #blockSize;
+ #contextLength;
  #maxBatchCount;
  #vocabSize;
  constructor(partialConfig, layersModel) {
@@ -20,9 +21,9 @@ export class GPT extends Model {
  const model = new GPTModel(partialConfig, layersModel);
  model.compile();
  this.model = model;
- this.#blockSize = partialConfig?.blockSize ?? DEFAULT_CONFIG.blockSize;
- this.#maxBatchCount = partialConfig?.maxIter ?? DEFAULT_CONFIG.maxIter;
- this.#vocabSize = partialConfig?.vocabSize ?? DEFAULT_CONFIG.vocabSize;
+ this.#contextLength = partialConfig?.contextLength ?? DefaultGPTConfig.contextLength;
+ this.#maxBatchCount = partialConfig?.maxIter ?? DefaultGPTConfig.maxIter;
+ this.#vocabSize = partialConfig?.vocabSize ?? DefaultGPTConfig.vocabSize;
  }
  /**
  * The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
@@ -85,16 +86,21 @@ export class GPT extends Model {
  }));
  }
  async predict(batch, options) {
- const config = {
- temperature: 1.0,
- doSample: false,
- ...options,
- };
+ // overwrite default with user config
+ const config = Object.assign({}, DefaultGenerationConfig, options);
  return List(await Promise.all(batch.map((tokens) => this.#predictSingle(tokens, config))));
  }
+ /**
+ * Generate the next token after the input sequence.
+ * In other words, takes an input tensor of shape (prompt length T) and returns a tensor of shape (T+1)
+ *
+ * @param token input tokens of shape (T,). T is truncated to the model's context length
+ * @param config generation config: temperature, doSample, topk
+ * @returns the next token predicted by the model
+ */
  async #predictSingle(tokens, config) {
  // slice input tokens if longer than context length
- tokens = tokens.slice(-this.#blockSize);
+ tokens = tokens.slice(-this.#contextLength);
  const input = tf.tidy(() => tf.tensor1d(tokens.toArray(), "int32").expandDims(0));
  const logits = tf.tidy(() => {
  const output = this.model.predict(input);
@@ -111,9 +117,24 @@ export class GPT extends Model {
  .div(config.temperature)
  .softmax());
  logits.dispose();
- const next = tf.tidy(() => config.doSample
- ? tf.multinomial(probs, 1).squeeze([0])
- : probs.argMax());
+ const next = tf.tidy(() => {
+ if (config.doSample) {
+ // returns topk biggest values among the `vocab_size` probabilities and the corresponding tokens indices
+ // both shapes are (config.topk,)
+ const { values: topkProbs, indices: topkTokens } = tf.topk(probs, config.topk);
+ // sample an index from the top-k probabilities
+ // e.g. [[0.1, 0.4, 0.3], [0.1, 0.2, 0.5]] -> [[1], [2]]
+ // note: multinomial does not need the input to sum to 1
+ const selectedIndices = tf.multinomial(topkProbs, 1, config.seed, false); // (B, )
+ // return the corresponding token from the sampled indices (one per sequence in the batch).
+ // if for some reason the probabilities are NaN, selectedIndices will be out of bounds
+ return topkTokens.gather(selectedIndices).squeeze([0]); // (1)
+ }
+ else {
+ // greedy decoding: return the token with the highest probability
+ return probs.argMax();
+ }
+ });
  probs.dispose();
  const ret = await next.array();
  next.dispose();
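For reference, a standalone sketch of the top-k sampling step used in `#predictSingle`, written against plain TensorFlow.js with an illustrative probability vector:

```ts
import * as tf from "@tensorflow/tfjs";

// Toy distribution over a 6-token vocabulary (values are made up).
const probs = tf.tensor1d([0.05, 0.30, 0.02, 0.40, 0.03, 0.20]);

// Keep the k most likely tokens and their probabilities.
const k = 3;
const { values: topkProbs, indices: topkTokens } = tf.topk(probs, k);

// Sample one position among the top-k entries; with `normalized` set to false
// the inputs are treated as unnormalized scores, so they don't need to sum to 1.
const seed = 42;
const sampled = tf.multinomial(topkProbs, 1, seed, false);

// Map the sampled position back to the original vocabulary index.
const nextToken = topkTokens.gather(sampled).squeeze([0]);
nextToken.array().then((id) => console.log("next token id:", id));
```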
@@ -1,4 +1,6 @@
+ import createDebug from "debug";
  import * as tf from '@tensorflow/tfjs';
+ const debug = createDebug("discojs:models:gpt:layers");
  /**
  * Defines a range, from 0 to T, that is used to create positional embeddings
  */
@@ -10,7 +12,8 @@ class Range extends tf.layers.Layer {
  call(input, kwargs) {
  return tf.tidy(() => {
  if (Array.isArray(input)) {
- // TODO support multitensor
+ if (input.length !== 1)
+ throw new Error('expected exactly one tensor');
  input = input[0];
  }
  this.invokeCallHook(input, kwargs);
@@ -22,6 +25,11 @@ class Range extends tf.layers.Layer {
  }
  }
  tf.serialization.registerClass(Range);
+ /**
+ * LogLayer is a layer that allows debugging the input that is fed to this layer
+ * This layer allows to inspect the input tensor at a specific point
+ * in the model by adding a log layer in the model definition
+ */
  class LogLayer extends tf.layers.Layer {
  static className = 'LogLayer';
  computeOutputShape(inputShape) {
@@ -30,9 +38,19 @@ class LogLayer extends tf.layers.Layer {
  call(input, kwargs) {
  return tf.tidy(() => {
  if (Array.isArray(input)) {
+ if (input.length !== 1)
+ throw new Error('expected exactly one tensor');
  input = input[0];
  }
  this.invokeCallHook(input, kwargs);
+ const logs = {
+ 'shape': input.shape,
+ 'is_only_zero': !!input.equal(tf.tensor(0)).all().dataSync()[0],
+ 'has_some_NaN': !!input.isNaN().any().dataSync()[0],
+ 'min': +input.min().dataSync()[0].toPrecision(3),
+ 'max': +input.max().dataSync()[0].toPrecision(3),
+ };
+ debug("%s logged: %o", this.name, logs);
  return input;
  });
  }
@@ -43,8 +61,9 @@ class CausalSelfAttention extends tf.layers.Layer {
  static className = 'CausalSelfAttention';
  nHead;
  nEmbd;
+ nLayer;
  dropout;
- bias;
+ seed;
  mask;
  cAttnKernel;
  cAttnBias;
@@ -53,20 +72,34 @@ class CausalSelfAttention extends tf.layers.Layer {
  constructor(config) {
  super(config);
  this.config = config;
+ if (config.nEmbd % config.nHead !== 0)
+ throw new Error('The embedding dimension `nEmbd` must be divisible by the number of attention heads `nHead`');
  this.nEmbd = config.nEmbd;
  this.nHead = config.nHead;
+ this.nLayer = config.nLayer;
  this.dropout = config.dropout;
- this.bias = config.bias;
+ this.seed = config.seed;
  // mask is a lower triangular matrix filled with 1
  // calling bandPart zero out the upper triangular part of the all-ones matrix
  // from the doc: tf.linalg.band_part(input, -1, 0) ==> Lower triangular part
- this.mask = tf.linalg.bandPart(tf.ones([config.blockSize, config.blockSize]), -1, 0);
+ this.mask = tf.linalg.bandPart(tf.ones([config.contextLength, config.contextLength]), -1, 0);
  }
  build() {
- this.cAttnKernel = this.addWeight('c_attn/kernel', [this.nEmbd, 3 * this.nEmbd], 'float32', tf.initializers.glorotNormal({}));
- this.cAttnBias = this.addWeight('c_attn/bias', [3 * this.nEmbd], 'float32', tf.initializers.zeros());
- this.cProjKernel = this.addWeight('c_proj/kernel', [this.nEmbd, this.nEmbd], 'float32', tf.initializers.glorotNormal({}));
- this.cProjBias = this.addWeight('c_proj/bias', [this.nEmbd], 'float32', tf.initializers.zeros());
+ // key, query, value projections for all heads, but in a batch
+ this.cAttnKernel = this.addWeight('c_attn.weight', [this.nEmbd, 3 * this.nEmbd], 'float32', tf.initializers.randomNormal({ mean: 0, stddev: 0.02, seed: this.seed }) // use same init as GPT2
+ );
+ this.cAttnBias = this.addWeight('c_attn.bias', [3 * this.nEmbd], 'float32', tf.initializers.zeros());
+ // output projection
+ this.cProjKernel = this.addWeight('c_proj.kernel', [this.nEmbd, this.nEmbd], 'float32',
+ // the input keeps accumulating through the residual stream so we
+ // scale the initialization with the nb of layers to keep a unit std
+ // Sources:
+ // https://github.com/karpathy/build-nanogpt/blob/6104ab1b53920f6e2159749676073ff7d815c1fa/train_gpt2.py#L103
+ // https://youtu.be/l8pRSuU81PU?si=5GcKfi_kPgLgvtg2&t=4640
+ tf.initializers.randomNormal({
+ mean: 0, stddev: 0.02 * Math.sqrt(2 * this.nLayer), seed: this.seed
+ }));
+ this.cProjBias = this.addWeight('c_proj.bias', [this.nEmbd], 'float32', tf.initializers.zeros());
  }
  computeOutputShape(inputShape) {
  return inputShape;
@@ -84,58 +117,72 @@ class CausalSelfAttention extends tf.layers.Layer {
  throw new Error('not built');
  }
  if (Array.isArray(input)) {
+ if (input.length !== 1)
+ throw new Error('expected exactly one tensor');
  input = input[0];
  }
  this.invokeCallHook(input, kwargs);
  const dense = (x, kernel, bias) => {
+ // TODO: use broadcasting when tfjs will support backpropagating through broadcasting
  const k = kernel.read().expandDims(0).tile([x.shape[0], 1, 1]);
  const m = x.matMul(k);
- if (this.bias) {
- return tf.add(m, bias.read());
- }
- else {
- return m;
- }
+ return tf.add(m, bias.read());
  };
  // Apply attention weights to inputs as one big matrix which is then split into the
  // query, key and value submatrices
+ // nHead is "number of heads", hs is "head size", and C (number of channels) = n_embd = nHead * hs
+ // e.g. in GPT-2 (124M), nHead = 12, hs = 64, so nHead * hs = C = 768 channels in the Transformer
  const cAttn = dense(input, this.cAttnKernel, this.cAttnBias);
  let [q, k, v] = tf.split(cAttn, 3, -1);
- const [B, T, C] = k.shape;
- const splitHeads = (x) => tf.transpose(tf.reshape(x, [B, T, this.nHead, C / this.nHead]), [0, 2, 1, 3]);
- q = splitHeads(q);
- k = splitHeads(k);
- v = splitHeads(v);
- // Scaled self attention: query @ key / sqrt(n_heads)
- let att = tf.mul(tf.matMul(q, k, false, true), tf.div(1, tf.sqrt(tf.cast(k.shape[k.shape.length - 1], 'float32'))));
- // The next operations apply attention to the past tokens, which is
- // essentially a weighted average of the past tokens with complicated weights,
- // and makes sure to not pay any attention to future tokens
+ // Follow naming conventions in https://github.com/karpathy/build-nanogpt/
+ const [B, T, C] = k.shape; // batch size, sequence length, embedding dimensionality (number of channels)
+ const splitHeads = (x) => tf.transpose(tf.reshape(x, [B, T, this.nHead, C / this.nHead]), // (B, T, nHead, head size)
+ [0, 2, 1, 3] // (B, nHead, T, hs)
+ );
+ q = splitHeads(q); // (B, nHead, T, hs)
+ k = splitHeads(k); // (B, nHead, T, hs)
+ v = splitHeads(v); // (B, nHead, T, hs)
+ // Scaled self attention: query @ key / sqrt(hs)
+ // Matrix representing the token-to-token attention (B, nHead, T, T)
+ let att = tf.mul(tf.matMul(q, k, false, true), // (B, nHead, T, hs) x (B, nHead, hs, T) -> (B, nHead, T, T)
+ tf.div(1, tf.sqrt(tf.cast(k.shape[k.shape.length - 1], 'float32'))) // 1 / sqrt(hs)
+ );
+ /**
+ * The next operations apply attention only on the past tokens, which is
+ * essentially a weighted average of the past tokens with complicated weights,
+ * it relies on a mask to not "pay any attention" to future tokens
+ */
  // mask is lower triangular matrix filled with 1
- const mask = this.mask.slice([0, 0], [T, T]);
+ const mask = this.mask.slice([0, 0], [T, T]); // (T, T)
  // 1 - mask => upper triangular matrix filled with 1
  // (1 - mask) * -10^9 => upper triangular matrix filled with -inf
  // att + ((1 - mask) * -10^9) => lower triangular part is the same as the `att` matrix
  // upper triangular part is -inf
- att = tf.add(att, tf.mul(tf.sub(1, mask), -1e9));
- // applying softmax zeros out the upper triangular part
- //(which are the attention weights of future tokens)
+ att = tf.add(att, tf.mul(tf.sub(1, mask), -1e9)); // (B, nHead, T, T)
+ // applying softmax zeroes out the upper triangular part (softmax(-inf) = 0)
+ // i.e., zeroes out future tokens's attention weights
  // and creates a probability distribution for the lower triangular
  // (attention weights of past tokens). The probability distribution ensures
  // that the attention weights of past tokens for a particular token sum to one
  att = tf.softmax(att, -1);
- att = kwargs.training === true ? tf.dropout(att, this.dropout) : att;
+ att = kwargs.training === true ? tf.dropout(att, this.dropout, undefined, this.seed) : att;
  // This is where the (attention-)weighted sum of past values is performed
- let y = tf.matMul(att, v);
- y = tf.transpose(y, [0, 2, 1, 3]);
- y = tf.reshape(y, [B, T, C]);
- y = dense(y, this.cProjKernel, this.cProjBias);
- y = kwargs.training === true ? tf.dropout(y, this.dropout) : y;
+ let y = tf.matMul(att, v); // (B, nHead, T, T) x (B, nHead, T, hs) -> (B, nHead, T, hs)
+ y = tf.transpose(y, [0, 2, 1, 3]); // (B, T, nHead, hs)
+ y = tf.reshape(y, [B, T, C]); // (B, T, C = nHead * hs)
+ y = dense(y, this.cProjKernel, this.cProjBias); // output projection (B, T, C)
+ y = kwargs.training === true ? tf.dropout(y, this.dropout, undefined, this.seed) : y;
  return y;
  });
  }
  }
  tf.serialization.registerClass(CausalSelfAttention);
+ /**
+ * GELU with tanh approximate
+ * GELU(x) = x * 0.5 * (1 + Tanh[sqrt(2/π) * (x + 0.044715 * x^3)])
+ *
+ * https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
+ */
  class GELU extends tf.layers.Layer {
  static className = 'GELU';
  constructor() {
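The masking trick above (add a large negative number to future positions, then softmax) can be seen in isolation with plain TensorFlow.js; a small sketch with made-up scores:

```ts
import * as tf from "@tensorflow/tfjs";

const T = 4; // sequence length
// Lower-triangular mask of ones: position i may only attend to positions <= i.
const mask = tf.linalg.bandPart(tf.ones([T, T]), -1, 0);

// Pretend attention scores of shape (T, T); values are illustrative only.
const scores = tf.randomNormal([T, T], 0, 1, "float32", 0);

// Future positions (upper triangle) receive -1e9, so softmax gives them ~0 weight.
const masked = tf.add(scores, tf.mul(tf.sub(1, mask), -1e9));
const attention = tf.softmax(masked, -1);

attention.print(); // each row is a probability distribution over positions <= the row index
```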
@@ -148,11 +195,17 @@ class GELU extends tf.layers.Layer {
  return tf.tidy(() => {
  if (Array.isArray(input)) {
  // TODO support multitensor
+ if (input.length !== 1)
+ throw new Error('expected exactly one tensor');
  input = input[0];
  }
  this.invokeCallHook(input, kwargs);
- const cdf = tf.mul(0.5, tf.add(1, tf.tanh(tf.mul(tf.sqrt(tf.div(2, Math.PI)), tf.add(input, tf.mul(0.044715, tf.pow(input, 3)))))));
- return tf.mul(input, cdf);
+ const cdf = tf.mul(// 0.5 * (1 + Tanh[sqrt(2) * (x + 0.044715 * x^3)])
+ 0.5, tf.add(1, tf.tanh(// Tanh[sqrt(2/π) * (x + 0.044715 * x^3)]
+ tf.mul(tf.sqrt(tf.div(2, Math.PI)), // (sqrt(2/π)
+ tf.add(input, tf.mul(0.044715, tf.pow(input, 3))) // (x + 0.044715 * x^3)
+ ))));
+ return tf.mul(input, cdf); // x * 0.5 * (1 + Tanh[sqrt(2/π) * (x + 0.044715 * x^3)])
  });
  }
  }
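A quick spot check of the tanh-approximate GELU formula documented above, written as a standalone function (output values are approximate):

```ts
import * as tf from "@tensorflow/tfjs";

// GELU(x) = x * 0.5 * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
function geluTanh(x: tf.Tensor): tf.Tensor {
  return tf.tidy(() => {
    const inner = tf.mul(
      Math.sqrt(2 / Math.PI),
      tf.add(x, tf.mul(0.044715, tf.pow(x, 3))),
    );
    return tf.mul(x, tf.mul(0.5, tf.add(1, tf.tanh(inner))));
  });
}

// GELU(0) = 0 and GELU(x) ≈ x for large positive x.
geluTanh(tf.tensor1d([-2, 0, 2])).print(); // ≈ [-0.0454, 0, 1.9546]
```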
@@ -160,48 +213,173 @@ tf.serialization.registerClass(GELU);
  function MLP(config) {
  return tf.sequential({ layers: [
  tf.layers.dense({
- name: config.name + `/mlp/c_fc`,
+ name: config.name + `.mlp.c_fc`,
  units: 4 * config.nEmbd,
  inputDim: config.nEmbd,
- inputShape: [config.blockSize, config.nEmbd]
+ inputShape: [config.contextLength, config.nEmbd],
+ kernelInitializer: tf.initializers.randomNormal({
+ mean: 0, stddev: 0.02, seed: config.seed
+ }),
  }),
  new GELU(),
  tf.layers.dense({
- name: config.name + '/mlp/c_proj',
+ name: config.name + '.mlp.c_proj',
  units: config.nEmbd,
  inputDim: 4 * config.nEmbd,
- inputShape: [config.blockSize, 4 * config.nEmbd]
+ inputShape: [config.contextLength, 4 * config.nEmbd],
+ kernelInitializer: tf.initializers.randomNormal({
+ mean: 0, stddev: 0.02 * Math.sqrt(2 * config.nLayer), seed: config.seed
+ }),
  }),
  tf.layers.dropout({
- name: config.name + '/mlp/drop',
- rate: config.residDrop
+ name: config.name + '.mlp.drop',
+ rate: config.residDrop,
+ seed: config.seed
  }),
  ] });
  }
+ /**
+ * Performs the following operations:
+ * x1 = input + mlp(layernorm_1(input))
+ * output = x1 + mlp(layernorm_2(x1))
+ */
  function TransformerBlock(conf) {
- const config = Object.assign({ name: 'h' }, conf);
- const inputs = tf.input({ shape: [config.blockSize, config.nEmbd] });
+ const config = Object.assign({ name: '.h' }, conf);
+ const inputs = tf.input({ shape: [config.contextLength, config.nEmbd] });
  let x1, x2;
  // input normalization
- x1 = tf.layers.layerNormalization({ name: config.name + '/ln_1', epsilon: 1e-5 })
- .apply(inputs);
+ x1 = tf.layers.layerNormalization({
+ name: config.name + '.ln_1',
+ epsilon: 1e-5,
+ gammaInitializer: 'ones', // already the default but make it explicit
+ betaInitializer: 'zeros',
+ }).apply(inputs);
  if (config.debug) {
- x1 = new LogLayer({ name: config.name + '/ln_1_log' }).apply(x1);
+ x1 = new LogLayer({ name: config.name + '.ln_1_log' }).apply(x1);
  }
  // self attention layer
- x1 = new CausalSelfAttention(Object.assign({}, config, { name: config.name + '/attn' })).apply(x1);
+ x1 = new CausalSelfAttention(Object.assign({}, config, { name: config.name + '.attn' })).apply(x1);
+ if (config.debug) {
+ x1 = new LogLayer({ name: config.name + '.attn_log' }).apply(x1);
+ }
  // Residual connection
  x1 = tf.layers.add().apply([inputs, x1]);
+ if (config.debug) {
+ x1 = new LogLayer({ name: config.name + '.residual_log' }).apply(x1);
+ }
  // normalization
- x2 = tf.layers
- .layerNormalization({ name: config.name + '/ln_2', epsilon: 1e-5 })
- .apply(x1);
+ x2 = tf.layers.layerNormalization({
+ name: config.name + '.ln_2',
+ epsilon: 1e-5,
+ gammaInitializer: 'ones',
+ betaInitializer: 'zeros',
+ }).apply(x1);
+ if (config.debug) {
+ x2 = new LogLayer({ name: config.name + '.ln_2_log' }).apply(x2);
+ }
  // MLP
- x2 = MLP(Object.assign({}, config, { name: config.name })).apply(x2);
+ x2 = MLP(Object.assign({}, config, { name: config.name + '.mlp' })).apply(x2);
+ if (config.debug) {
+ x2 = new LogLayer({ name: config.name + '.mlp_log' }).apply(x2);
+ }
  // add attention output to mlp output
  x2 = tf.layers.add().apply([x1, x2]);
+ if (config.debug) {
+ x2 = new LogLayer({ name: config.name + '.add_log' }).apply(x2);
+ }
  return tf.model({ name: config.name, inputs, outputs: x2 });
  }
+ /**
+ * LanguageModelEmbedding is a layer that combines the token embeddings and the language modeling head
+ * I.e. LMEmbedding is used to translate token indices into token embeddings
+ * as well as to project embeddings back into token indices
+ * The GPT2 model uses the same embedding matrix for both the token embeddings and the language modeling head
+ * Because Tensorflow.js doesn't offer an easy weight sharing mechanism, we need to define a custom layer
+ * that can be used for both the token embeddings and the language modeling head.
+ * In the GPT2 model definition, this layers corresponds to wte and lm_head (which reuses wte)
+ */
+ class LMEmbedding extends tf.layers.Layer {
+ vocabSize;
+ nEmbd;
+ seed;
+ static className = 'LMEmbedding';
+ embeddings;
+ constructor(vocabSize, nEmbd, seed) {
+ super({});
+ this.vocabSize = vocabSize;
+ this.nEmbd = nEmbd;
+ this.seed = seed;
+ }
+ build() {
+ this.embeddings = this.addWeight('wte', //use same name as GPT2
+ [this.vocabSize, this.nEmbd], 'float32', tf.initializers.randomNormal({ mean: 0, stddev: 0.02, seed: this.seed }));
+ }
+ computeOutputShape(inputShape) {
+ let shape;
+ if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
+ if (inputShape.length !== 1)
+ throw new Error('Expected exactly one Shape');
+ shape = inputShape[0];
+ }
+ else
+ shape = inputShape;
+ // input shape for the token embedding
+ if (shape.length === 2) {
+ // https://github.com/tensorflow/tfjs/blob/3daf152cb794f4da58fce5e21e09e8a4f89c8f80/tfjs-layers/src/layers/embeddings.ts#L155
+ // batch size and sequence length are undetermined
+ // so the output shape is [null, null, nEmbd]
+ if (shape[0] !== null || shape[1] !== null)
+ throw new Error('expected shape [null, null, ...]');
+ return [null, null, this.nEmbd];
+ }
+ // input shape for the language modeling head
+ // https://github.com/tensorflow/tfjs/blob/3daf152cb794f4da58fce5e21e09e8a4f89c8f80/tfjs-layers/src/layers/core.ts#L258
+ else if (shape.length === 3) {
+ // batch size and sequence length are undetermined
+ // so the output shape is [null, null, nEmbd]
+ if (shape[0] !== null || shape[1] !== null)
+ throw new Error('expected shape [null, null, ...]');
+ return [null, null, this.vocabSize];
+ }
+ else
+ throw new Error('unexpected input shape');
+ }
+ call(input, kwargs) {
+ return tf.tidy(() => {
+ if (this.embeddings === undefined)
+ throw new Error('not built');
+ if (Array.isArray(input)) {
+ if (input.length !== 1)
+ throw new Error('expected exactly one tensor');
+ input = input[0];
+ }
+ this.invokeCallHook(input, kwargs);
+ // If the input is a 2D tensor, it is a batch of sequences of tokens
+ // so we translate the tokens into embeddings
+ // using `this.embeddings` as a lookup table
+ if (input.shape.length === 2) {
+ // (batch_size, sequence_length) => (batch_size, sequence_length, nEmbd)
+ return tf.gather(this.embeddings.read(), tf.cast(input, 'int32'), 0);
+ }
+ // If the input is a 3D tensor, it is a sequence of embeddings
+ // so we apply a dense layer to project the embeddings back into the vocabulary space
+ else if (input.shape.length === 3 && input.shape[2] === this.nEmbd) {
+ // Replicate the kernel for each batch element
+ const kernel = this.embeddings.read().expandDims(0).tile([input.shape[0], 1, 1]);
+ // TODO: rely on broadcasting when tfjs will support backpropagating through broadcasting
+ // Remove the tile, or use tf.einsum('BTE,VE->BTV', input, this.embeddings.read())
+ // to prevent tensor duplication but tensorflow.js fails to backpropagate einsum
+ // https://github.com/tensorflow/tfjs/issues/5690
+ // (batch_size, sequence_length, nEmbd) x (vocabSize, nEmbd)^T -> (batch_size, sequence_length, vocabSize)
+ return tf.matMul(input, kernel, false, true);
+ }
+ else {
+ throw new Error('unexpected input shape for token embeddings');
+ }
+ });
+ }
+ }
+ tf.serialization.registerClass(LMEmbedding);
  /**
  * The GPTArchitecture specifically defines a GPT forward pass, i.e.,
  * what are the inputs, the successive transformer blocks and the outputs. It is then
@@ -212,54 +390,54 @@ function TransformerBlock(conf) {
  */
  export function GPTArchitecture(config) {
  const inputs = tf.input({ shape: [null] });
- //Token embedding
- const tokEmb = config.tokEmb
- ? tf.layers.embedding({
- name: config.name + '/wte',
- inputDim: config.vocabSize,
- outputDim: config.nEmbd,
- embeddingsInitializer: 'zeros',
- embeddingsRegularizer: undefined,
- activityRegularizer: undefined
- }).apply(inputs)
- : inputs;
+ // token embedding
+ const wte = new LMEmbedding(config.vocabSize, config.nEmbd, config.seed);
+ let tokEmb = wte.apply(inputs); // (batch_size, input length T, nEmbd)
+ if (config.debug) {
+ tokEmb = new LogLayer({ name: 'tokEmb_log' }).apply(tokEmb);
+ }
  // Positional embedding
  const range = new Range({}).apply(inputs);
  let posEmb = tf.layers.embedding({
- name: config.name + '/wpe',
- inputDim: config.blockSize,
+ name: config.name + '.wpe',
+ inputDim: config.contextLength,
  outputDim: config.nEmbd,
- embeddingsInitializer: 'zeros'
+ embeddingsInitializer: tf.initializers.randomNormal({
+ mean: 0, stddev: 0.02, seed: config.seed
+ }),
  }).apply(range);
  if (config.debug) {
- posEmb = new LogLayer({ name: 'posEmb' }).apply(posEmb);
+ posEmb = new LogLayer({ name: 'posEmb_log' }).apply(posEmb);
  }
  // token and positional embeddings are added together
  let x = tf.layers.add().apply([tokEmb, posEmb]);
  // dropout
- x = tf.layers.dropout({ name: 'drop', rate: config.embdDrop }).apply(x);
+ x = tf.layers.dropout({
+ name: 'drop', rate: config.embdDrop, seed: config.seed
+ }).apply(x);
  if (config.debug) {
- x = new LogLayer({ name: 'dropadd' }).apply(x);
+ x = new LogLayer({ name: 'drop_log' }).apply(x);
  }
- //Apply successively transformer blocks, attention and dense layers
+ // apply successively transformer blocks, attention and dense layers
  for (let i = 0; i < config.nLayer; i++) {
- x = TransformerBlock(Object.assign({}, config, { name: config.name + '/h/' + i })).apply(x);
+ x = TransformerBlock(Object.assign({}, config, { name: config.name + '.h' + i })).apply(x);
  }
  // Normalization
- x = tf.layers.layerNormalization({ name: config.name + '/ln_f', epsilon: 1e-5 })
+ x = tf.layers.layerNormalization({
+ name: config.name + '.ln_f',
+ epsilon: 1e-5,
+ gammaInitializer: 'ones',
+ betaInitializer: 'zeros',
+ })
  .apply(x);
  if (config.debug) {
- x = new LogLayer({ name: 'fin/ln' }).apply(x);
+ x = new LogLayer({ name: 'ln_f_log' }).apply(x);
  }
- // Append a language modeling head if specified
- if (config.lmHead) {
- x = tf.layers.dense({
- name: 'lm_head',
- units: config.vocabSize,
- inputDim: config.nEmbd,
- inputShape: [config.blockSize, config.nEmbd],
- useBias: false
- }).apply(x);
+ // language modeling head
+ // GPT2 uses the same matrix for the token embedding and the modeling head
+ x = wte.apply(x);
+ if (config.debug) {
+ x = new LogLayer({ name: 'lm_head_log' }).apply(x);
  }
  return tf.model({ inputs, outputs: x });
  }
@@ -16,7 +16,7 @@ export declare abstract class Dataset<T> {
  */
  export declare class GPTModel extends tf.LayersModel {
  protected readonly config: Required<GPTConfig>;
- constructor(partialConfig?: GPTConfig, layersModel?: tf.LayersModel);
+ constructor(partialConfig?: Partial<GPTConfig>, layersModel?: tf.LayersModel);
  get getGPTConfig(): Required<GPTConfig>;
  compile(): void;
  fitDataset<T>(dataset: Dataset<T>, trainingArgs: tf.ModelFitDatasetArgs<T>): Promise<tf.History>;
@@ -1,10 +1,10 @@
  import createDebug from "debug";
  import * as tf from '@tensorflow/tfjs';
- import { getModelSizes, DEFAULT_CONFIG } from './config.js';
+ import { getModelSizes, DefaultGPTConfig } from './config.js';
  import { getCustomAdam, clipByGlobalNormObj } from './optimizers.js';
  import evaluate from './evaluate.js';
  import { GPTArchitecture } from './layers.js';
- const debug = createDebug("discojs:models:gpt");
+ const debug = createDebug("discojs:models:gpt:model");
  /**
  * GPTModel extends tf.LayersModel and overrides tfjs' default training loop
  *
@@ -13,7 +13,7 @@ export class GPTModel extends tf.LayersModel {
  config;
  constructor(partialConfig, layersModel) {
  // Fill missing config parameters with default values
- let completeConfig = { ...DEFAULT_CONFIG, ...partialConfig };
+ let completeConfig = { ...DefaultGPTConfig, ...partialConfig };
  // Add layer sizes depending on which model has been specified
  completeConfig = { ...completeConfig, ...getModelSizes(completeConfig.modelType) };
  if (layersModel !== undefined) {
@@ -112,7 +112,7 @@ export class GPTModel extends tf.LayersModel {
  tf.dispose([xs, ys]);
  }
  let logs = {
- 'loss': averageLoss / iteration,
+ 'loss': averageLoss / (iteration - 1), // -1 because iteration got incremented at the end of the loop
  'acc': accuracyFraction[0] / accuracyFraction[1],
  };
  if (evalDataset !== undefined) {
@@ -33,11 +33,11 @@ export async function preprocess(task, dataset) {
  // cast as typescript doesn't reduce generic type
  const d = dataset;
  const t = task;
+ const contextLength = task.trainingInformation.contextLength;
  const tokenizer = await models.getTaskTokenizer(t);
- const totalTokenCount = task.trainingInformation.maxSequenceLength ??
- tokenizer.model_max_length;
- return d
- .map((line) => processing.tokenizeAndLeftPad(line, tokenizer, totalTokenCount))
+ return d.map(text => processing.tokenize(tokenizer, text))
+ .flatten()
+ .batch(contextLength + 1, 1)
  .map((tokens) => [tokens.pop(), tokens.last()]);
  }
  }
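The rewritten pipeline tokenizes each line, flattens all tokens into one stream, then uses `batch(contextLength + 1, 1)` so every window of `contextLength` input tokens carries the following token as its label. A worked sketch with a tiny context length and made-up token ids (assuming `Dataset` is exported from the package root):

```ts
import { Dataset } from "@epfml/discojs";

async function windowsDemo(): Promise<void> {
  const contextLength = 3;
  const tokens = new Dataset([5, 8, 2, 9, 4, 7]); // flattened token stream
  const pairs = tokens
    .batch(contextLength + 1, 1) // windows of 4 tokens, overlapping by 1
    .map((window) => [window.pop().toArray(), window.last()] as const);
  for await (const [input, label] of pairs)
    console.log(input, "->", label);
  // [ 5, 8, 2 ] -> 9
  // [ 9, 4 ] -> 7   (the final window is shorter)
}
void windowsDemo();
```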
@@ -60,12 +60,11 @@ export async function preprocessWithoutLabel(task, dataset) {
  // cast as typescript doesn't reduce generic type
  const d = dataset;
  const t = task;
+ const contextLength = task.trainingInformation.contextLength;
  const tokenizer = await models.getTaskTokenizer(t);
- const totalTokenCount = t.trainingInformation.maxSequenceLength ??
- tokenizer.model_max_length;
- return d
- .map((line) => processing.tokenizeAndLeftPad(line, tokenizer, totalTokenCount))
- .map((tokens) => tokens.pop());
+ return d.map(text => processing.tokenize(tokenizer, text))
+ .flatten()
+ .batch(contextLength);
  }
  }
  }
@@ -1,11 +1,21 @@
- import { List } from "immutable";
  import { PreTrainedTokenizer } from "@xenova/transformers";
- type Token = number;
+ import type { Text, TokenizedText } from '../index.js';
+ interface TokenizingConfig {
+ padding?: boolean;
+ padding_side?: 'left' | 'right';
+ truncation?: boolean;
+ max_length?: number;
+ }
  /**
- * Tokenize and truncates input strings
+ * Tokenize one line of text.
+ * Wrapper around Transformers.js tokenizer to handle type checking and format the output.
+ * Note that Transformers.js's tokenizer can tokenize multiple lines of text at once
+ * but we are currently not making use of it. Can be useful when padding a batch
  *
- * @param length number of tokens
- * @returns encoded string in an array of token, size of max_length
+ * @param tokenizer the tokenizer object
+ * @param text the text to tokenize
+ * @param config TokenizingConfig, the tokenizing parameters when using `tokenizer`
+ * @returns List<number> the tokenized text
  */
- export declare function tokenizeAndLeftPad(line: string, tokenizer: PreTrainedTokenizer, length: number): List<Token>;
+ export declare function tokenize(tokenizer: PreTrainedTokenizer, text: Text, config?: TokenizingConfig): TokenizedText;
  export {};
@@ -1,33 +1,36 @@
- import { Repeat } from "immutable";
+ import { List } from "immutable";
  function isArrayOfNumber(raw) {
  return Array.isArray(raw) && raw.every((e) => typeof e === "number");
  }
  /**
- * Tokenize and truncates input strings
+ * Tokenize one line of text.
+ * Wrapper around Transformers.js tokenizer to handle type checking and format the output.
+ * Note that Transformers.js's tokenizer can tokenize multiple lines of text at once
+ * but we are currently not making use of it. Can be useful when padding a batch
  *
- * @param length number of tokens
- * @returns encoded string in an array of token, size of max_length
+ * @param tokenizer the tokenizer object
+ * @param text the text to tokenize
+ * @param config TokenizingConfig, the tokenizing parameters when using `tokenizer`
+ * @returns List<number> the tokenized text
  */
- export function tokenizeAndLeftPad(line, tokenizer, length) {
- if (!Number.isInteger(length))
- throw new Error("length should be an integer");
- // Transformers.js currently only supports right padding while we need left for text generation
- // Right padding should be supported in the future, once it is, we can directly pad while tokenizing
- // https://github.com/xenova/transformers.js/blob/8804c36591d11d8456788d1bb4b16489121b3be2/src/tokenizers.js#L2517
- const tokenized = tokenizer(line, {
- padding: false,
- truncation: true,
- return_tensor: false,
- max_length: length,
- });
- if (typeof tokenized !== "object" ||
- tokenized === null ||
- !("input_ids" in tokenized) ||
- !isArrayOfNumber(tokenized.input_ids))
- throw new Error("tokenizer returns unexpected type");
- const tokens = tokenized.input_ids;
- const paddingSize = length - tokens.length;
- if (paddingSize < 0)
- throw new Error("tokenized returned more token than expected");
- return Repeat(tokenizer.pad_token_id, paddingSize).concat(tokens).toList();
+ export function tokenize(tokenizer, text, config) {
+ config = { ...config }; // create a config if undefined
+ if (config.padding || config.truncation) {
+ if (config.max_length === undefined)
+ throw new Error("max_length needs to be specified to use padding or truncation");
+ if (!Number.isInteger(config.max_length))
+ throw new Error("max_length should be an integer");
+ }
+ if (config.padding) {
+ // The padding side is set as an attribute, not in the config
+ tokenizer.padding_side = config.padding_side ?? 'left';
+ config.truncation = true; // for a single sequence, padding implies truncation to max_length
+ }
+ const tokenizerResult = tokenizer(text, { ...config, return_tensor: false });
+ if (typeof tokenizerResult !== "object" ||
+ tokenizerResult === null ||
+ !("input_ids" in tokenizerResult) ||
+ !isArrayOfNumber(tokenizerResult.input_ids))
+ throw new Error("tokenizer returned unexpected type");
+ return List(tokenizerResult.input_ids);
  }
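A hedged usage sketch of the new `tokenize` wrapper (assuming it is re-exported under the `processing` namespace, as the preprocessing code above suggests, and that the tokenizer is loaded with Transformers.js' `AutoTokenizer`):

```ts
import { AutoTokenizer } from "@xenova/transformers";
import { processing } from "@epfml/discojs";

async function tokenizeDemo(): Promise<void> {
  const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");

  // Plain tokenization: returns an immutable List<number> of token ids.
  const tokens = processing.tokenize(tokenizer, "Hello world");

  // Left-padded and truncated to a fixed length, e.g. for fixed-size generation prompts.
  const padded = processing.tokenize(tokenizer, "Hello world", {
    padding: true,
    padding_side: "left",
    max_length: 8,
  });

  console.log(tokens.toArray(), padded.toArray());
}
void tokenizeDemo();
```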
@@ -31,7 +31,7 @@ interface DataTypeToTrainingInformation {
  text: {
  dataType: "text";
  tokenizer: string | PreTrainedTokenizer;
- maxSequenceLength?: number;
+ contextLength: number;
  };
  }
  export declare function isTrainingInformation(raw: unknown): raw is TrainingInformation<DataType>;
@@ -94,16 +94,15 @@ export function isTrainingInformation(raw) {
  return true;
  }
  case "text": {
- const { maxSequenceLength, tokenizer, } = raw;
+ const { contextLength, tokenizer, } = raw;
  if ((typeof tokenizer !== "string" &&
  !(tokenizer instanceof PreTrainedTokenizer)) ||
- (maxSequenceLength !== undefined &&
- typeof maxSequenceLength !== "number"))
+ (typeof contextLength !== "number"))
  return false;
  const _ = {
  ...repack,
  dataType,
- maxSequenceLength,
+ contextLength,
  tokenizer,
  };
  return true;
@@ -1,5 +1,5 @@
  import { List } from "immutable";
- import type { Image, processing, Tabular, Text } from "../index.js";
+ import type { Image, processing, Tabular, Text, TokenizedText } from "../index.js";
  /**
  * The data & label format goes through various stages.
  * Raw* is preprocessed into ModelEncoded.
@@ -29,7 +29,7 @@ type Token = number;
  export interface ModelEncoded {
  image: [image: processing.NormalizedImage<3>, label: number];
  tabular: [row: List<number>, number];
- text: [line: List<Token>, next: Token];
+ text: [line: TokenizedText, next: Token];
  }
  /** what gets outputted by the Validator, for humans */
  export interface Inferred {
package/dist/validator.js CHANGED
@@ -13,7 +13,7 @@ export class Validator {
  .map(async (batch) => (await this.#model.predict(batch.map(([inputs, _]) => inputs)))
  .zip(batch.map(([_, outputs]) => outputs))
  .map(([inferred, truth]) => inferred === truth))
- .unbatch();
+ .flatten();
  for await (const e of results)
  yield e;
  }
@@ -22,7 +22,7 @@ export class Validator {
  const modelPredictions = (await processing.preprocessWithoutLabel(this.task, dataset))
  .batch(this.task.trainingInformation.batchSize)
  .map((batch) => this.#model.predict(batch))
- .unbatch();
+ .flatten();
  const predictions = await processing.postprocess(this.task, modelPredictions);
  for await (const e of predictions)
  yield e;
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "@epfml/discojs",
- "version": "3.0.1-p20241203151748.0",
+ "version": "3.0.1-p20241206133538.0",
  "type": "module",
  "main": "dist/index.js",
  "types": "dist/index.d.ts",