@epfml/discojs 3.0.1-p20241203151748.0 → 3.0.1-p20241206154707.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/dataset/dataset.d.ts +18 -5
- package/dist/dataset/dataset.js +58 -23
- package/dist/dataset/types.d.ts +1 -0
- package/dist/default_tasks/wikitext.js +5 -3
- package/dist/models/gpt/config.d.ts +11 -6
- package/dist/models/gpt/config.js +11 -7
- package/dist/models/gpt/index.d.ts +5 -9
- package/dist/models/gpt/index.js +36 -15
- package/dist/models/gpt/layers.js +260 -82
- package/dist/models/gpt/model.d.ts +1 -1
- package/dist/models/gpt/model.js +4 -4
- package/dist/processing/index.js +8 -9
- package/dist/processing/text.d.ts +16 -6
- package/dist/processing/text.js +29 -26
- package/dist/task/training_information.d.ts +1 -1
- package/dist/task/training_information.js +3 -4
- package/dist/types/data_format.d.ts +2 -2
- package/dist/validator.js +2 -2
- package/package.json +1 -1
package/dist/dataset/dataset.d.ts
CHANGED
@@ -25,15 +25,22 @@ export declare class Dataset<T> implements AsyncIterable<T> {
      * @param ratio between 0 (all on left) and 1 (all on right)
      */
     split(ratio: number): [Dataset<T>, Dataset<T>];
-    /**
+    /** Create batches of `size` elements with potential overlap.
+     * Last batch is smaller if dataset isn't perfectly divisible
      *
-     *
+     * If overlap is set to a positive integer, the last `overlap` elements of a batch
+     * are the first `overlap` elements of the next batch.
+     *
+     * This method is tailored to create text sequences where each token's label is the following token.
+     * In order to have a label for the last token of the input sequence, we include the first token
+     * of the next sequence (i.e. with an overlap of 1).
      *
      * @param size count of element per chunk
+     * @param overlap number of elements overlapping between two consecutive batches
      */
-    batch(size: number): Dataset<Batched<T>>;
-    /** Flatten
-
+    batch(size: number, overlap?: number): Dataset<Batched<T>>;
+    /** Flatten batches/arrays of elements */
+    flatten<U>(this: Dataset<DatasetLike<U>>): Dataset<U>;
     /** Join side-by-side
      *
      * Stops as soon as one runs out
@@ -41,6 +48,12 @@ export declare class Dataset<T> implements AsyncIterable<T> {
      * @param other right side
      **/
     zip<U>(other: Dataset<U> | DatasetLike<U>): Dataset<[T, U]>;
+    /**
+     * Repeat the dataset `times` times
+     * @param times number of times to repeat the dataset, if undefined, the dataset is repeated indefinitely
+     * @returns a dataset repeated `times` times
+     */
+    repeat(times?: number): Dataset<T>;
     /** Compute size
      *
      * This is a costly operation as we need to go through the whole Dataset.
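For orientation, a minimal usage sketch of the new `batch` overlap and `repeat` declarations above — assuming `Dataset` is re-exported from the package root; the numeric values are made up:

import { Dataset } from "@epfml/discojs";

const tokens = new Dataset([1, 2, 3, 4, 5]);

// With overlap 1, each batch repeats the last element of the previous one:
// [1, 2, 3] then [3, 4, 5], so every element still has a successor in its batch.
for await (const b of tokens.batch(3, 1))
  console.log(b.toArray());

// repeat(2) iterates the underlying content twice; repeat() with no
// argument would loop indefinitely.
console.log(await tokens.repeat(2).size()); // 10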
package/dist/dataset/dataset.js
CHANGED
@@ -1,6 +1,22 @@
 import createDebug from "debug";
 import { List, Range } from "immutable";
 const debug = createDebug("discojs:dataset");
+/** Convert a DatasetLike object to an async generator */
+async function* datasetLikeToGenerator(content) {
+    let iter;
+    if (typeof content === "function")
+        iter = content();
+    else if (Symbol.asyncIterator in content)
+        iter = content[Symbol.asyncIterator]();
+    else
+        iter = content[Symbol.iterator]();
+    while (true) {
+        const result = await iter.next();
+        if (result.done === true)
+            break;
+        yield result.value;
+    }
+}
 /** Immutable series of data */
 export class Dataset {
     #content;
@@ -11,19 +27,7 @@ export class Dataset {
      */
     constructor(content) {
        this.#content = async function* () {
-
-            if (typeof content === "function")
-                iter = content();
-            else if (Symbol.asyncIterator in content)
-                iter = content[Symbol.asyncIterator]();
-            else
-                iter = content[Symbol.iterator]();
-            while (true) {
-                const result = await iter.next();
-                if (result.done === true)
-                    break;
-                yield result.value;
-            }
+            yield* datasetLikeToGenerator(content);
        };
     }
     [Symbol.asyncIterator]() {
@@ -87,19 +91,31 @@ export class Dataset {
            }.bind(this)),
        ];
    }
-    /**
+    /** Create batches of `size` elements with potential overlap.
+     * Last batch is smaller if dataset isn't perfectly divisible
+     *
+     * If overlap is set to a positive integer, the last `overlap` elements of a batch
+     * are the first `overlap` elements of the next batch.
     *
-     *
+     * This method is tailored to create text sequences where each token's label is the following token.
+     * In order to have a label for the last token of the input sequence, we include the first token
+     * of the next sequence (i.e. with an overlap of 1).
     *
     * @param size count of element per chunk
+     * @param overlap number of elements overlapping between two consecutive batches
     */
-    batch(size) {
+    batch(size, overlap = 0) {
        if (size <= 0 || !Number.isInteger(size))
            throw new Error("invalid size");
+        if (overlap >= size || !Number.isInteger(overlap))
+            throw new Error("invalid overlap");
        return new Dataset(async function* () {
            const iter = this[Symbol.asyncIterator]();
+            let overlapped = List();
            for (;;) {
-                const batch = List(
+                const batch = List(
+                // get the first elements of the next batch
+                await Promise.all(Range(overlapped.size, size).map(() => iter.next()))).flatMap((res) => {
                    if (res.done)
                        return [];
                    else
@@ -107,18 +123,21 @@ export class Dataset {
                });
                if (batch.isEmpty())
                    break;
-                yield batch
+                // yield the current batch with the first elements of the next batch
+                yield overlapped.concat(batch);
+                overlapped = batch.takeLast(overlap);
                // iterator couldn't generate more
-                if (batch.size < size)
+                if (batch.size < size - overlap)
                    break;
            }
        }.bind(this));
    }
-    /** Flatten
-
+    /** Flatten batches/arrays of elements */
+    flatten() {
        return new Dataset(async function* () {
-            for await (const batch of this)
-                yield* batch;
+            for await (const batch of this) {
+                yield* datasetLikeToGenerator(batch);
+            }
        }.bind(this));
    }
    /** Join side-by-side
@@ -141,6 +160,22 @@ export class Dataset {
            }
        }.bind(this));
    }
+    /**
+     * Repeat the dataset `times` times
+     * @param times number of times to repeat the dataset, if undefined, the dataset is repeated indefinitely
+     * @returns a dataset repeated `times` times
+     */
+    repeat(times) {
+        if (times !== undefined && (!Number.isInteger(times) || times < 1))
+            throw new Error("times needs to be a positive integer or undefined");
+        return new Dataset(async function* () {
+            let loop = 0;
+            do {
+                yield* this;
+                loop++;
+            } while (times === undefined || loop < times);
+        }.bind(this));
+    }
    /** Compute size
     *
     * This is a costly operation as we need to go through the whole Dataset.
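A short sketch of what the extracted `datasetLikeToGenerator` helper accepts — the same `DatasetLike` shapes the constructor already took, now also usable by `flatten` (values are illustrative):

// Any of these are DatasetLike: a sync iterable, an async iterable,
// or a zero-argument (async) generator function.
const fromArray = new Dataset([1, 2, 3]);
const fromGenerator = new Dataset(async function* () { yield 4; yield 5; });

// flatten() now unwraps any nested DatasetLike, not only batched Lists:
const flat = new Dataset([[1, 2], [3]]).flatten(); // yields 1, 2, 3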
package/dist/dataset/types.d.ts
CHANGED
package/dist/default_tasks/wikitext.js
CHANGED
@@ -31,14 +31,16 @@ export const wikitext = {
                // But if set to 0 then the webapp doesn't display the validation metrics
                validationSplit: 0.1,
                roundDuration: 2,
-                batchSize:
+                batchSize: 8, // If set too high firefox raises a WebGL error
                tokenizer: 'Xenova/gpt2',
-
+                contextLength: 64,
                tensorBackend: 'gpt'
            }
        };
    },
    getModel() {
-        return Promise.resolve(new models.GPT(
+        return Promise.resolve(new models.GPT({
+            contextLength: this.getTask().trainingInformation.contextLength,
+        }));
    }
};
package/dist/models/gpt/config.d.ts
CHANGED
@@ -1,8 +1,8 @@
 type GPTModelType = 'gpt2' | 'gpt2-medium' | 'gpt2-large' | 'gpt2-xl' | 'gpt-mini' | 'gpt-micro' | 'gpt-nano';
 export interface GPTConfig {
     lr: number;
-
-    vocabSize
+    contextLength: number;
+    vocabSize?: number;
     modelType: GPTModelType;
     name?: string;
     evaluate?: boolean;
@@ -11,22 +11,27 @@ export interface GPTConfig {
     maxIter?: number;
     weightDecay?: number;
     verbose?: 0 | 1;
-    bias?: boolean;
     debug?: boolean;
     dropout?: number;
     residDrop?: number;
     embdDrop?: number;
-    tokEmb?: boolean;
-    lmHead?: boolean;
     nLayer?: number;
     nHead?: number;
     nEmbd?: number;
+    seed?: number;
 }
-export declare const
+export declare const DefaultGPTConfig: Required<GPTConfig>;
 export type ModelSize = {
     nLayer: number;
     nHead: number;
     nEmbd: number;
 };
 export declare function getModelSizes(modelType: GPTModelType): Required<ModelSize>;
+export interface GenerationConfig {
+    doSample: boolean;
+    temperature: number;
+    topk: number;
+    seed: number;
+}
+export declare const DefaultGenerationConfig: Required<GenerationConfig>;
 export {};
package/dist/models/gpt/config.js
CHANGED
@@ -1,6 +1,6 @@
 // for a benchmark of performance, see https://github.com/epfml/disco/pull/659
-export const
-    name: 'transformer',
+export const DefaultGPTConfig = {
+    name: 'transformer', // prefix for the model layer names
     lr: 0.001,
     weightDecay: 0,
     maxIter: 10,
@@ -9,18 +9,16 @@ export const DEFAULT_CONFIG = {
     evaluate: true,
     maxEvalBatches: 12,
     evaluateEvery: 100,
-
-    vocabSize:
-    bias: true,
+    contextLength: 128,
+    vocabSize: 50257,
     debug: false,
     dropout: 0.2,
     residDrop: 0.2,
     embdDrop: 0.2,
-    tokEmb: true,
-    lmHead: true,
     nLayer: 3,
     nHead: 3,
     nEmbd: 48,
+    seed: Math.random(),
 };
 export function getModelSizes(modelType) {
     switch (modelType) {
@@ -40,3 +38,9 @@ export function getModelSizes(modelType) {
             return { nLayer: 3, nHead: 3, nEmbd: 48 };
     }
 }
+export const DefaultGenerationConfig = {
+    temperature: 1.0,
+    doSample: false,
+    seed: Math.random(),
+    topk: 50
+};
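A hedged sketch of how these defaults combine with a user-supplied partial config, mirroring the `{ ...DefaultGPTConfig, ...partialConfig }` merge in model.js further down; it assumes the package root re-exports `models`:

import { models } from "@epfml/discojs";

// Unset fields fall back to DefaultGPTConfig; modelType selects layer sizes.
const gpt = new models.GPT({ modelType: "gpt-nano", contextLength: 64 });
console.log(gpt.config.contextLength); // 64 (user override)
console.log(gpt.config.vocabSize);     // 50257 (from DefaultGPTConfig)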
package/dist/models/gpt/index.d.ts
CHANGED
@@ -1,23 +1,20 @@
 /**
- *
+ * Source: https://github.com/zemlyansky/gpt-tfjs and https://github.com/karpathy/build-nanogpt
+ * With modifications from @peacefulotter, @lukemovement and the Disco team
 **/
 import * as tf from "@tensorflow/tfjs";
 import type { Batched, Dataset, DataFormat } from "../../index.js";
 import { WeightsContainer } from "../../index.js";
 import { BatchLogs, Model, EpochLogs } from "../index.js";
-import {
+import type { GPTConfig, GenerationConfig } from './config.js';
 export type GPTSerialization = {
     weights: WeightsContainer;
     config?: GPTConfig;
 };
-interface PredictConfig {
-    temperature: number;
-    doSample: boolean;
-}
 export declare class GPT extends Model<"text"> {
     #private;
     private readonly model;
-    constructor(partialConfig?: GPTConfig
+    constructor(partialConfig?: Partial<GPTConfig>, layersModel?: tf.LayersModel);
     /**
      * The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
      * This allows for getting logs and stopping training without callbacks.
@@ -28,7 +25,7 @@ export declare class GPT extends Model<"text"> {
      * @param tracker
      */
     train(trainingDataset: Dataset<Batched<DataFormat.ModelEncoded["text"]>>, validationDataset?: Dataset<Batched<DataFormat.ModelEncoded["text"]>>): AsyncGenerator<BatchLogs, EpochLogs>;
-    predict(batch: Batched<DataFormat.ModelEncoded["text"][0]>, options?: Partial<
+    predict(batch: Batched<DataFormat.ModelEncoded["text"][0]>, options?: Partial<GenerationConfig>): Promise<Batched<DataFormat.ModelEncoded["text"][1]>>;
     get config(): Required<GPTConfig>;
     get weights(): WeightsContainer;
     set weights(ws: WeightsContainer);
@@ -37,4 +34,3 @@ export declare class GPT extends Model<"text"> {
     extract(): tf.LayersModel;
     [Symbol.dispose](): void;
 }
-export {};
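A usage sketch for the new `predict` signature; `gpt` and `prompts` are hypothetical stand-ins (a constructed GPT instance and a batched list of token sequences):

// Omitted GenerationConfig fields fall back to DefaultGenerationConfig,
// so predict(prompts) or predict(prompts, {}) gives greedy decoding.
const nextTokens = await gpt.predict(prompts, {
  doSample: true,   // sample among the topk most likely tokens
  topk: 10,
  temperature: 0.8, // >1 flattens, <1 sharpens the distribution
});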
package/dist/models/gpt/index.js
CHANGED
@@ -1,5 +1,6 @@
 /**
- *
+ * Source: https://github.com/zemlyansky/gpt-tfjs and https://github.com/karpathy/build-nanogpt
+ * With modifications from @peacefulotter, @lukemovement and the Disco team
 **/
 import createDebug from "debug";
 import { List, Range } from "immutable";
@@ -7,12 +8,12 @@ import * as tf from "@tensorflow/tfjs";
 import { WeightsContainer } from "../../index.js";
 import { Model, EpochLogs } from "../index.js";
 import { GPTModel } from "./model.js";
-import { DEFAULT_CONFIG } from "./config.js";
 import evaluate from "./evaluate.js";
+import { DefaultGPTConfig, DefaultGenerationConfig } from './config.js';
 const debug = createDebug("discojs:models:gpt");
 export class GPT extends Model {
     model;
-    #
+    #contextLength;
     #maxBatchCount;
     #vocabSize;
     constructor(partialConfig, layersModel) {
@@ -20,9 +21,9 @@ export class GPT extends Model {
         const model = new GPTModel(partialConfig, layersModel);
         model.compile();
         this.model = model;
-        this.#
-        this.#maxBatchCount = partialConfig?.maxIter ??
-        this.#vocabSize = partialConfig?.vocabSize ??
+        this.#contextLength = partialConfig?.contextLength ?? DefaultGPTConfig.contextLength;
+        this.#maxBatchCount = partialConfig?.maxIter ?? DefaultGPTConfig.maxIter;
+        this.#vocabSize = partialConfig?.vocabSize ?? DefaultGPTConfig.vocabSize;
     }
     /**
      * The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
@@ -85,16 +86,21 @@ export class GPT extends Model {
        }));
    }
    async predict(batch, options) {
-
-
-            doSample: false,
-            ...options,
-        };
+        // overwrite default with user config
+        const config = Object.assign({}, DefaultGenerationConfig, options);
        return List(await Promise.all(batch.map((tokens) => this.#predictSingle(tokens, config))));
    }
+    /**
+     * Generate the next token after the input sequence.
+     * In other words, takes an input tensor of shape (prompt length T) and returns a tensor of shape (T+1)
+     *
+     * @param token input tokens of shape (T,). T is truncated to the model's context length
+     * @param config generation config: temperature, doSample, topk
+     * @returns the next token predicted by the model
+     */
    async #predictSingle(tokens, config) {
        // slice input tokens if longer than context length
-        tokens = tokens.slice(-this.#
+        tokens = tokens.slice(-this.#contextLength);
        const input = tf.tidy(() => tf.tensor1d(tokens.toArray(), "int32").expandDims(0));
        const logits = tf.tidy(() => {
            const output = this.model.predict(input);
@@ -111,9 +117,24 @@ export class GPT extends Model {
            .div(config.temperature)
            .softmax());
        logits.dispose();
-        const next = tf.tidy(() =>
-
-
+        const next = tf.tidy(() => {
+            if (config.doSample) {
+                // returns topk biggest values among the `vocab_size` probabilities and the corresponding tokens indices
+                // both shapes are (config.topk,)
+                const { values: topkProbs, indices: topkTokens } = tf.topk(probs, config.topk);
+                // sample an index from the top-k probabilities
+                // e.g. [[0.1, 0.4, 0.3], [0.1, 0.2, 0.5]] -> [[1], [2]]
+                // note: multinomial does not need the input to sum to 1
+                const selectedIndices = tf.multinomial(topkProbs, 1, config.seed, false); // (B, )
+                // return the corresponding token from the sampled indices (one per sequence in the batch).
+                // if for some reason the probabilities are NaN, selectedIndices will be out of bounds
+                return topkTokens.gather(selectedIndices).squeeze([0]); // (1)
+            }
+            else {
+                // greedy decoding: return the token with the highest probability
+                return probs.argMax();
+            }
+        });
        probs.dispose();
        const ret = await next.array();
        next.dispose();
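The sampling step of `#predictSingle` in isolation — a self-contained sketch using a toy probability vector instead of model logits (the seed is arbitrary):

import * as tf from "@tensorflow/tfjs";

const probs = tf.tensor1d([0.05, 0.4, 0.3, 0.25]);
// Keep the 2 most likely entries and their token indices, both of shape (2,)
const { values: topkProbs, indices: topkTokens } = tf.topk(probs, 2);
// Sample one index into the top-k values; normalized=true treats the
// values as (unnormalized) probabilities rather than logits.
const sampled = tf.multinomial(topkProbs, 1, 42, true);
topkTokens.gather(sampled).print(); // token id 1 or 2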
package/dist/models/gpt/layers.js
CHANGED
@@ -1,4 +1,6 @@
+import createDebug from "debug";
 import * as tf from '@tensorflow/tfjs';
+const debug = createDebug("discojs:models:gpt:layers");
 /**
  * Defines a range, from 0 to T, that is used to create positional embeddings
  */
@@ -10,7 +12,8 @@ class Range extends tf.layers.Layer {
     call(input, kwargs) {
         return tf.tidy(() => {
             if (Array.isArray(input)) {
-
+                if (input.length !== 1)
+                    throw new Error('expected exactly one tensor');
                 input = input[0];
             }
             this.invokeCallHook(input, kwargs);
@@ -22,6 +25,11 @@ class Range extends tf.layers.Layer {
     }
 }
 tf.serialization.registerClass(Range);
+/**
+ * LogLayer is a layer that allows debugging the input that is fed to this layer
+ * This layer allows to inspect the input tensor at a specific point
+ * in the model by adding a log layer in the model definition
+ */
 class LogLayer extends tf.layers.Layer {
     static className = 'LogLayer';
     computeOutputShape(inputShape) {
@@ -30,9 +38,19 @@ class LogLayer extends tf.layers.Layer {
     call(input, kwargs) {
         return tf.tidy(() => {
             if (Array.isArray(input)) {
+                if (input.length !== 1)
+                    throw new Error('expected exactly one tensor');
                 input = input[0];
             }
             this.invokeCallHook(input, kwargs);
+            const logs = {
+                'shape': input.shape,
+                'is_only_zero': !!input.equal(tf.tensor(0)).all().dataSync()[0],
+                'has_some_NaN': !!input.isNaN().any().dataSync()[0],
+                'min': +input.min().dataSync()[0].toPrecision(3),
+                'max': +input.max().dataSync()[0].toPrecision(3),
+            };
+            debug("%s logged: %o", this.name, logs);
             return input;
         });
     }
@@ -43,8 +61,9 @@ class CausalSelfAttention extends tf.layers.Layer {
     static className = 'CausalSelfAttention';
     nHead;
     nEmbd;
+    nLayer;
     dropout;
-
+    seed;
     mask;
     cAttnKernel;
     cAttnBias;
@@ -53,20 +72,34 @@ class CausalSelfAttention extends tf.layers.Layer {
     constructor(config) {
         super(config);
         this.config = config;
+        if (config.nEmbd % config.nHead !== 0)
+            throw new Error('The embedding dimension `nEmbd` must be divisible by the number of attention heads `nHead`');
         this.nEmbd = config.nEmbd;
         this.nHead = config.nHead;
+        this.nLayer = config.nLayer;
         this.dropout = config.dropout;
-        this.
+        this.seed = config.seed;
         // mask is a lower triangular matrix filled with 1
         // calling bandPart zero out the upper triangular part of the all-ones matrix
         // from the doc: tf.linalg.band_part(input, -1, 0) ==> Lower triangular part
-        this.mask = tf.linalg.bandPart(tf.ones([config.
+        this.mask = tf.linalg.bandPart(tf.ones([config.contextLength, config.contextLength]), -1, 0);
     }
     build() {
-
-        this.
-
-        this.
+        // key, query, value projections for all heads, but in a batch
+        this.cAttnKernel = this.addWeight('c_attn.weight', [this.nEmbd, 3 * this.nEmbd], 'float32', tf.initializers.randomNormal({ mean: 0, stddev: 0.02, seed: this.seed }) // use same init as GPT2
+        );
+        this.cAttnBias = this.addWeight('c_attn.bias', [3 * this.nEmbd], 'float32', tf.initializers.zeros());
+        // output projection
+        this.cProjKernel = this.addWeight('c_proj.kernel', [this.nEmbd, this.nEmbd], 'float32',
+        // the input keeps accumulating through the residual stream so we
+        // scale the initialization with the nb of layers to keep a unit std
+        // Sources:
+        // https://github.com/karpathy/build-nanogpt/blob/6104ab1b53920f6e2159749676073ff7d815c1fa/train_gpt2.py#L103
+        // https://youtu.be/l8pRSuU81PU?si=5GcKfi_kPgLgvtg2&t=4640
+        tf.initializers.randomNormal({
+            mean: 0, stddev: 0.02 * Math.sqrt(2 * this.nLayer), seed: this.seed
+        }));
+        this.cProjBias = this.addWeight('c_proj.bias', [this.nEmbd], 'float32', tf.initializers.zeros());
     }
     computeOutputShape(inputShape) {
         return inputShape;
@@ -84,58 +117,72 @@ class CausalSelfAttention extends tf.layers.Layer {
                 throw new Error('not built');
             }
             if (Array.isArray(input)) {
+                if (input.length !== 1)
+                    throw new Error('expected exactly one tensor');
                 input = input[0];
             }
             this.invokeCallHook(input, kwargs);
             const dense = (x, kernel, bias) => {
+                // TODO: use broadcasting when tfjs will support backpropagating through broadcasting
                 const k = kernel.read().expandDims(0).tile([x.shape[0], 1, 1]);
                 const m = x.matMul(k);
-
-                return tf.add(m, bias.read());
-                }
-                else {
-                    return m;
-                }
+                return tf.add(m, bias.read());
             };
             // Apply attention weights to inputs as one big matrix which is then split into the
             // query, key and value submatrices
+            // nHead is "number of heads", hs is "head size", and C (number of channels) = n_embd = nHead * hs
+            // e.g. in GPT-2 (124M), nHead = 12, hs = 64, so nHead * hs = C = 768 channels in the Transformer
             const cAttn = dense(input, this.cAttnKernel, this.cAttnBias);
             let [q, k, v] = tf.split(cAttn, 3, -1);
-
-            const
-
-
-
-
-
-
-            //
-            //
+            // Follow naming conventions in https://github.com/karpathy/build-nanogpt/
+            const [B, T, C] = k.shape; // batch size, sequence length, embedding dimensionality (number of channels)
+            const splitHeads = (x) => tf.transpose(tf.reshape(x, [B, T, this.nHead, C / this.nHead]), // (B, T, nHead, head size)
+            [0, 2, 1, 3] // (B, nHead, T, hs)
+            );
+            q = splitHeads(q); // (B, nHead, T, hs)
+            k = splitHeads(k); // (B, nHead, T, hs)
+            v = splitHeads(v); // (B, nHead, T, hs)
+            // Scaled self attention: query @ key / sqrt(hs)
+            // Matrix representing the token-to-token attention (B, nHead, T, T)
+            let att = tf.mul(tf.matMul(q, k, false, true), // (B, nHead, T, hs) x (B, nHead, hs, T) -> (B, nHead, T, T)
+            tf.div(1, tf.sqrt(tf.cast(k.shape[k.shape.length - 1], 'float32'))) // 1 / sqrt(hs)
+            );
+            /**
+             * The next operations apply attention only on the past tokens, which is
+             * essentially a weighted average of the past tokens with complicated weights,
+             * it relies on a mask to not "pay any attention" to future tokens
+             */
             // mask is lower triangular matrix filled with 1
-            const mask = this.mask.slice([0, 0], [T, T]);
+            const mask = this.mask.slice([0, 0], [T, T]); // (T, T)
             // 1 - mask => upper triangular matrix filled with 1
             // (1 - mask) * -10^9 => upper triangular matrix filled with -inf
             // att + ((1 - mask) * -10^9) => lower triangular part is the same as the `att` matrix
             // upper triangular part is -inf
-            att = tf.add(att, tf.mul(tf.sub(1, mask), -1e9));
-            // applying softmax
-            //
+            att = tf.add(att, tf.mul(tf.sub(1, mask), -1e9)); // (B, nHead, T, T)
+            // applying softmax zeroes out the upper triangular part (softmax(-inf) = 0)
+            // i.e., zeroes out future tokens's attention weights
             // and creates a probability distribution for the lower triangular
             // (attention weights of past tokens). The probability distribution ensures
             // that the attention weights of past tokens for a particular token sum to one
             att = tf.softmax(att, -1);
-            att = kwargs.training === true ? tf.dropout(att, this.dropout) : att;
+            att = kwargs.training === true ? tf.dropout(att, this.dropout, undefined, this.seed) : att;
             // This is where the (attention-)weighted sum of past values is performed
-            let y = tf.matMul(att, v);
-            y = tf.transpose(y, [0, 2, 1, 3]);
-            y = tf.reshape(y, [B, T, C]);
-            y = dense(y, this.cProjKernel, this.cProjBias);
-            y = kwargs.training === true ? tf.dropout(y, this.dropout) : y;
+            let y = tf.matMul(att, v); // (B, nHead, T, T) x (B, nHead, T, hs) -> (B, nHead, T, hs)
+            y = tf.transpose(y, [0, 2, 1, 3]); // (B, T, nHead, hs)
+            y = tf.reshape(y, [B, T, C]); // (B, T, C = nHead * hs)
+            y = dense(y, this.cProjKernel, this.cProjBias); // output projection (B, T, C)
+            y = kwargs.training === true ? tf.dropout(y, this.dropout, undefined, this.seed) : y;
             return y;
         });
     }
 }
 tf.serialization.registerClass(CausalSelfAttention);
+/**
+ * GELU with tanh approximate
+ * GELU(x) = x * 0.5 * (1 + Tanh[sqrt(2/π) * (x + 0.044715 * x^3)])
+ *
+ * https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
+ */
 class GELU extends tf.layers.Layer {
     static className = 'GELU';
     constructor() {
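The masking trick in the hunk above, shown standalone (a sketch; T = 4 is arbitrary):

import * as tf from "@tensorflow/tfjs";

const T = 4;
// bandPart(x, -1, 0) keeps the lower triangle: token i may attend to j <= i
const mask = tf.linalg.bandPart(tf.ones([T, T]), -1, 0);
// (1 - mask) * -1e9 puts -inf on the upper triangle; after softmax those
// positions become 0, so future tokens receive no attention weight.
tf.mul(tf.sub(1, mask), -1e9).print();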
@@ -148,11 +195,17 @@ class GELU extends tf.layers.Layer {
         return tf.tidy(() => {
             if (Array.isArray(input)) {
                 // TODO support multitensor
+                if (input.length !== 1)
+                    throw new Error('expected exactly one tensor');
                 input = input[0];
             }
             this.invokeCallHook(input, kwargs);
-            const cdf = tf.mul(0.5
-
+            const cdf = tf.mul(// 0.5 * (1 + Tanh[sqrt(2/π) * (x + 0.044715 * x^3)])
+            0.5, tf.add(1, tf.tanh(// Tanh[sqrt(2/π) * (x + 0.044715 * x^3)]
+            tf.mul(tf.sqrt(tf.div(2, Math.PI)), // (sqrt(2/π)
+            tf.add(input, tf.mul(0.044715, tf.pow(input, 3))) // (x + 0.044715 * x^3)
+            ))));
+            return tf.mul(input, cdf); // x * 0.5 * (1 + Tanh[sqrt(2/π) * (x + 0.044715 * x^3)])
         });
     }
 }
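The same tanh-approximate GELU as a plain function, for checking against the formula in the comments (output values rounded):

import * as tf from "@tensorflow/tfjs";

function gelu(x: tf.Tensor): tf.Tensor {
  // x * 0.5 * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
  const inner = tf.mul(Math.sqrt(2 / Math.PI),
    tf.add(x, tf.mul(0.044715, tf.pow(x, 3))));
  return tf.mul(x, tf.mul(0.5, tf.add(1, tf.tanh(inner))));
}

gelu(tf.tensor1d([-1, 0, 1])).print(); // ≈ [-0.159, 0, 0.841]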
@@ -160,48 +213,173 @@ tf.serialization.registerClass(GELU);
 function MLP(config) {
     return tf.sequential({ layers: [
             tf.layers.dense({
-                name: config.name +
+                name: config.name + `.mlp.c_fc`,
                 units: 4 * config.nEmbd,
                 inputDim: config.nEmbd,
-                inputShape: [config.
+                inputShape: [config.contextLength, config.nEmbd],
+                kernelInitializer: tf.initializers.randomNormal({
+                    mean: 0, stddev: 0.02, seed: config.seed
+                }),
             }),
             new GELU(),
             tf.layers.dense({
-                name: config.name + '
+                name: config.name + '.mlp.c_proj',
                 units: config.nEmbd,
                 inputDim: 4 * config.nEmbd,
-                inputShape: [config.
+                inputShape: [config.contextLength, 4 * config.nEmbd],
+                kernelInitializer: tf.initializers.randomNormal({
+                    mean: 0, stddev: 0.02 * Math.sqrt(2 * config.nLayer), seed: config.seed
+                }),
             }),
             tf.layers.dropout({
-                name: config.name + '
-                rate: config.residDrop
+                name: config.name + '.mlp.drop',
+                rate: config.residDrop,
+                seed: config.seed
             }),
         ] });
 }
+/**
+ * Performs the following operations:
+ * x1 = input + mlp(layernorm_1(input))
+ * output = x1 + mlp(layernorm_2(x1))
+ */
 function TransformerBlock(conf) {
-    const config = Object.assign({ name: 'h' }, conf);
-    const inputs = tf.input({ shape: [config.
+    const config = Object.assign({ name: '.h' }, conf);
+    const inputs = tf.input({ shape: [config.contextLength, config.nEmbd] });
     let x1, x2;
     // input normalization
-    x1 = tf.layers.layerNormalization({
-    .
+    x1 = tf.layers.layerNormalization({
+        name: config.name + '.ln_1',
+        epsilon: 1e-5,
+        gammaInitializer: 'ones', // already the default but make it explicit
+        betaInitializer: 'zeros',
+    }).apply(inputs);
     if (config.debug) {
-        x1 = new LogLayer({ name: config.name + '
+        x1 = new LogLayer({ name: config.name + '.ln_1_log' }).apply(x1);
     }
     // self attention layer
-    x1 = new CausalSelfAttention(Object.assign({}, config, { name: config.name + '
+    x1 = new CausalSelfAttention(Object.assign({}, config, { name: config.name + '.attn' })).apply(x1);
+    if (config.debug) {
+        x1 = new LogLayer({ name: config.name + '.attn_log' }).apply(x1);
+    }
     // Residual connection
     x1 = tf.layers.add().apply([inputs, x1]);
+    if (config.debug) {
+        x1 = new LogLayer({ name: config.name + '.residual_log' }).apply(x1);
+    }
     // normalization
-    x2 = tf.layers
-
-
+    x2 = tf.layers.layerNormalization({
+        name: config.name + '.ln_2',
+        epsilon: 1e-5,
+        gammaInitializer: 'ones',
+        betaInitializer: 'zeros',
+    }).apply(x1);
+    if (config.debug) {
+        x2 = new LogLayer({ name: config.name + '.ln_2_log' }).apply(x2);
+    }
     // MLP
-    x2 = MLP(Object.assign({}, config, { name: config.name })).apply(x2);
+    x2 = MLP(Object.assign({}, config, { name: config.name + '.mlp' })).apply(x2);
+    if (config.debug) {
+        x2 = new LogLayer({ name: config.name + '.mlp_log' }).apply(x2);
+    }
     // add attention output to mlp output
     x2 = tf.layers.add().apply([x1, x2]);
+    if (config.debug) {
+        x2 = new LogLayer({ name: config.name + '.add_log' }).apply(x2);
+    }
     return tf.model({ name: config.name, inputs, outputs: x2 });
 }
+/**
+ * LanguageModelEmbedding is a layer that combines the token embeddings and the language modeling head
+ * I.e. LMEmbedding is used to translate token indices into token embeddings
+ * as well as to project embeddings back into token indices
+ * The GPT2 model uses the same embedding matrix for both the token embeddings and the language modeling head
+ * Because Tensorflow.js doesn't offer an easy weight sharing mechanism, we need to define a custom layer
+ * that can be used for both the token embeddings and the language modeling head.
+ * In the GPT2 model definition, this layers corresponds to wte and lm_head (which reuses wte)
+ */
+class LMEmbedding extends tf.layers.Layer {
+    vocabSize;
+    nEmbd;
+    seed;
+    static className = 'LMEmbedding';
+    embeddings;
+    constructor(vocabSize, nEmbd, seed) {
+        super({});
+        this.vocabSize = vocabSize;
+        this.nEmbd = nEmbd;
+        this.seed = seed;
+    }
+    build() {
+        this.embeddings = this.addWeight('wte', //use same name as GPT2
+        [this.vocabSize, this.nEmbd], 'float32', tf.initializers.randomNormal({ mean: 0, stddev: 0.02, seed: this.seed }));
+    }
+    computeOutputShape(inputShape) {
+        let shape;
+        if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
+            if (inputShape.length !== 1)
+                throw new Error('Expected exactly one Shape');
+            shape = inputShape[0];
+        }
+        else
+            shape = inputShape;
+        // input shape for the token embedding
+        if (shape.length === 2) {
+            // https://github.com/tensorflow/tfjs/blob/3daf152cb794f4da58fce5e21e09e8a4f89c8f80/tfjs-layers/src/layers/embeddings.ts#L155
+            // batch size and sequence length are undetermined
+            // so the output shape is [null, null, nEmbd]
+            if (shape[0] !== null || shape[1] !== null)
+                throw new Error('expected shape [null, null, ...]');
+            return [null, null, this.nEmbd];
+        }
+        // input shape for the language modeling head
+        // https://github.com/tensorflow/tfjs/blob/3daf152cb794f4da58fce5e21e09e8a4f89c8f80/tfjs-layers/src/layers/core.ts#L258
+        else if (shape.length === 3) {
+            // batch size and sequence length are undetermined
+            // so the output shape is [null, null, nEmbd]
+            if (shape[0] !== null || shape[1] !== null)
+                throw new Error('expected shape [null, null, ...]');
+            return [null, null, this.vocabSize];
+        }
+        else
+            throw new Error('unexpected input shape');
+    }
+    call(input, kwargs) {
+        return tf.tidy(() => {
+            if (this.embeddings === undefined)
+                throw new Error('not built');
+            if (Array.isArray(input)) {
+                if (input.length !== 1)
+                    throw new Error('expected exactly one tensor');
+                input = input[0];
+            }
+            this.invokeCallHook(input, kwargs);
+            // If the input is a 2D tensor, it is a batch of sequences of tokens
+            // so we translate the tokens into embeddings
+            // using `this.embeddings` as a lookup table
+            if (input.shape.length === 2) {
+                // (batch_size, sequence_length) => (batch_size, sequence_length, nEmbd)
+                return tf.gather(this.embeddings.read(), tf.cast(input, 'int32'), 0);
+            }
+            // If the input is a 3D tensor, it is a sequence of embeddings
+            // so we apply a dense layer to project the embeddings back into the vocabulary space
+            else if (input.shape.length === 3 && input.shape[2] === this.nEmbd) {
+                // Replicate the kernel for each batch element
+                const kernel = this.embeddings.read().expandDims(0).tile([input.shape[0], 1, 1]);
+                // TODO: rely on broadcasting when tfjs will support backpropagating through broadcasting
+                // Remove the tile, or use tf.einsum('BTE,VE->BTV', input, this.embeddings.read())
+                // to prevent tensor duplication but tensorflow.js fails to backpropagate einsum
+                // https://github.com/tensorflow/tfjs/issues/5690
+                // (batch_size, sequence_length, nEmbd) x (vocabSize, nEmbd)^T -> (batch_size, sequence_length, vocabSize)
+                return tf.matMul(input, kernel, false, true);
+            }
+            else {
+                throw new Error('unexpected input shape for token embeddings');
+            }
+        });
+    }
+}
+tf.serialization.registerClass(LMEmbedding);
 /**
  * The GPTArchitecture specifically defines a GPT forward pass, i.e.,
  * what are the inputs, the successive transformer blocks and the outputs. It is then
@@ -212,54 +390,54 @@ function TransformerBlock(conf) {
  */
 export function GPTArchitecture(config) {
     const inputs = tf.input({ shape: [null] });
-    //
-    const
-
-
-
-
-        embeddingsInitializer: 'zeros',
-        embeddingsRegularizer: undefined,
-        activityRegularizer: undefined
-    }).apply(inputs)
-        : inputs;
+    // token embedding
+    const wte = new LMEmbedding(config.vocabSize, config.nEmbd, config.seed);
+    let tokEmb = wte.apply(inputs); // (batch_size, input length T, nEmbd)
+    if (config.debug) {
+        tokEmb = new LogLayer({ name: 'tokEmb_log' }).apply(tokEmb);
+    }
     // Positional embedding
     const range = new Range({}).apply(inputs);
     let posEmb = tf.layers.embedding({
-        name: config.name + '
-        inputDim: config.
+        name: config.name + '.wpe',
+        inputDim: config.contextLength,
         outputDim: config.nEmbd,
-        embeddingsInitializer:
+        embeddingsInitializer: tf.initializers.randomNormal({
+            mean: 0, stddev: 0.02, seed: config.seed
+        }),
     }).apply(range);
     if (config.debug) {
-        posEmb = new LogLayer({ name: '
+        posEmb = new LogLayer({ name: 'posEmb_log' }).apply(posEmb);
     }
     // token and positional embeddings are added together
     let x = tf.layers.add().apply([tokEmb, posEmb]);
     // dropout
-    x = tf.layers.dropout({
+    x = tf.layers.dropout({
+        name: 'drop', rate: config.embdDrop, seed: config.seed
+    }).apply(x);
     if (config.debug) {
-        x = new LogLayer({ name: '
+        x = new LogLayer({ name: 'drop_log' }).apply(x);
     }
-    //
+    // apply successively transformer blocks, attention and dense layers
     for (let i = 0; i < config.nLayer; i++) {
-        x = TransformerBlock(Object.assign({}, config, { name: config.name + '
+        x = TransformerBlock(Object.assign({}, config, { name: config.name + '.h' + i })).apply(x);
     }
     // Normalization
-    x = tf.layers.layerNormalization({
+    x = tf.layers.layerNormalization({
+        name: config.name + '.ln_f',
+        epsilon: 1e-5,
+        gammaInitializer: 'ones',
+        betaInitializer: 'zeros',
+    })
         .apply(x);
     if (config.debug) {
-        x = new LogLayer({ name: '
+        x = new LogLayer({ name: 'ln_f_log' }).apply(x);
     }
-    //
-
-
-
-
-        inputDim: config.nEmbd,
-        inputShape: [config.blockSize, config.nEmbd],
-        useBias: false
-    }).apply(x);
+    // language modeling head
+    // GPT2 uses the same matrix for the token embedding and the modeling head
+    x = wte.apply(x);
+    if (config.debug) {
+        x = new LogLayer({ name: 'lm_head_log' }).apply(x);
     }
     return tf.model({ inputs, outputs: x });
 }
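The weight sharing that `LMEmbedding` implements, sketched with raw tensors: a single (vocabSize, nEmbd) matrix serves both directions (sizes are toy values):

import * as tf from "@tensorflow/tfjs";

const vocabSize = 6, nEmbd = 4;
const wte = tf.randomNormal([vocabSize, nEmbd]);

// Embedding direction: token ids -> embeddings, (B, T) -> (B, T, nEmbd)
const tokens = tf.tensor2d([[1, 3, 2]], [1, 3], "int32");
const emb = tf.gather(wte, tokens);

// LM head direction: embeddings -> logits over the vocabulary,
// (B, T, nEmbd) x (vocabSize, nEmbd)^T -> (B, T, vocabSize)
const logits = tf.matMul(emb, wte.expandDims(0), false, true);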
package/dist/models/gpt/model.d.ts
CHANGED
@@ -16,7 +16,7 @@ export declare abstract class Dataset<T> {
  */
 export declare class GPTModel extends tf.LayersModel {
     protected readonly config: Required<GPTConfig>;
-    constructor(partialConfig?: GPTConfig
+    constructor(partialConfig?: Partial<GPTConfig>, layersModel?: tf.LayersModel);
     get getGPTConfig(): Required<GPTConfig>;
     compile(): void;
     fitDataset<T>(dataset: Dataset<T>, trainingArgs: tf.ModelFitDatasetArgs<T>): Promise<tf.History>;
package/dist/models/gpt/model.js
CHANGED
@@ -1,10 +1,10 @@
 import createDebug from "debug";
 import * as tf from '@tensorflow/tfjs';
-import { getModelSizes,
+import { getModelSizes, DefaultGPTConfig } from './config.js';
 import { getCustomAdam, clipByGlobalNormObj } from './optimizers.js';
 import evaluate from './evaluate.js';
 import { GPTArchitecture } from './layers.js';
-const debug = createDebug("discojs:models:gpt");
+const debug = createDebug("discojs:models:gpt:model");
 /**
  * GPTModel extends tf.LayersModel and overrides tfjs' default training loop
  *
@@ -13,7 +13,7 @@ export class GPTModel extends tf.LayersModel {
     config;
     constructor(partialConfig, layersModel) {
         // Fill missing config parameters with default values
-        let completeConfig = { ...
+        let completeConfig = { ...DefaultGPTConfig, ...partialConfig };
         // Add layer sizes depending on which model has been specified
         completeConfig = { ...completeConfig, ...getModelSizes(completeConfig.modelType) };
         if (layersModel !== undefined) {
@@ -112,7 +112,7 @@ export class GPTModel extends tf.LayersModel {
             tf.dispose([xs, ys]);
         }
         let logs = {
-            'loss': averageLoss / iteration,
+            'loss': averageLoss / (iteration - 1), // -1 because iteration got incremented at the end of the loop
             'acc': accuracyFraction[0] / accuracyFraction[1],
         };
         if (evalDataset !== undefined) {
package/dist/processing/index.js
CHANGED
@@ -33,11 +33,11 @@ export async function preprocess(task, dataset) {
             // cast as typescript doesn't reduce generic type
             const d = dataset;
             const t = task;
+            const contextLength = task.trainingInformation.contextLength;
             const tokenizer = await models.getTaskTokenizer(t);
-
-
-
-                .map((line) => processing.tokenizeAndLeftPad(line, tokenizer, totalTokenCount))
+            return d.map(text => processing.tokenize(tokenizer, text))
+                .flatten()
+                .batch(contextLength + 1, 1)
                 .map((tokens) => [tokens.pop(), tokens.last()]);
         }
     }
@@ -60,12 +60,11 @@ export async function preprocessWithoutLabel(task, dataset) {
             // cast as typescript doesn't reduce generic type
             const d = dataset;
             const t = task;
+            const contextLength = task.trainingInformation.contextLength;
             const tokenizer = await models.getTaskTokenizer(t);
-
-
-
-                .map((line) => processing.tokenizeAndLeftPad(line, tokenizer, totalTokenCount))
-                .map((tokens) => tokens.pop());
+            return d.map(text => processing.tokenize(tokenizer, text))
+                .flatten()
+                .batch(contextLength);
         }
     }
 }
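The effect of the new preprocessing pipeline, sketched with toy token ids and contextLength = 4 (so windows of 5 tokens overlapping by 1), reusing the `Dataset` class from earlier in this diff:

const stream = new Dataset([10, 11, 12, 13, 14, 15, 16, 17, 18]);

// Windows: [10..14] and [14..18]; each yields (inputs, next-token label):
// ([10, 11, 12, 13], 14) and ([14, 15, 16, 17], 18)
const pairs = stream
  .batch(4 + 1, 1)
  .map((window) => [window.pop(), window.last()]);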
package/dist/processing/text.d.ts
CHANGED
@@ -1,11 +1,21 @@
-import { List } from "immutable";
 import { PreTrainedTokenizer } from "@xenova/transformers";
-type
+import type { Text, TokenizedText } from '../index.js';
+interface TokenizingConfig {
+    padding?: boolean;
+    padding_side?: 'left' | 'right';
+    truncation?: boolean;
+    max_length?: number;
+}
 /**
- * Tokenize
+ * Tokenize one line of text.
+ * Wrapper around Transformers.js tokenizer to handle type checking and format the output.
+ * Note that Transformers.js's tokenizer can tokenize multiple lines of text at once
+ * but we are currently not making use of it. Can be useful when padding a batch
 *
- * @param
- * @
+ * @param tokenizer the tokenizer object
+ * @param text the text to tokenize
+ * @param config TokenizingConfig, the tokenizing parameters when using `tokenizer`
+ * @returns List<number> the tokenized text
 */
-export declare function
+export declare function tokenize(tokenizer: PreTrainedTokenizer, text: Text, config?: TokenizingConfig): TokenizedText;
 export {};
package/dist/processing/text.js
CHANGED
@@ -1,33 +1,36 @@
-import {
+import { List } from "immutable";
 function isArrayOfNumber(raw) {
     return Array.isArray(raw) && raw.every((e) => typeof e === "number");
 }
 /**
- * Tokenize
+ * Tokenize one line of text.
+ * Wrapper around Transformers.js tokenizer to handle type checking and format the output.
+ * Note that Transformers.js's tokenizer can tokenize multiple lines of text at once
+ * but we are currently not making use of it. Can be useful when padding a batch
 *
- * @param
- * @
+ * @param tokenizer the tokenizer object
+ * @param text the text to tokenize
+ * @param config TokenizingConfig, the tokenizing parameters when using `tokenizer`
+ * @returns List<number> the tokenized text
 */
-export function
-if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    throw new Error("tokenized returned more token than expected");
-    return Repeat(tokenizer.pad_token_id, paddingSize).concat(tokens).toList();
+export function tokenize(tokenizer, text, config) {
+    config = { ...config }; // create a config if undefined
+    if (config.padding || config.truncation) {
+        if (config.max_length === undefined)
+            throw new Error("max_length needs to be specified to use padding or truncation");
+        if (!Number.isInteger(config.max_length))
+            throw new Error("max_length should be an integer");
+    }
+    if (config.padding) {
+        // The padding side is set as an attribute, not in the config
+        tokenizer.padding_side = config.padding_side ?? 'left';
+        config.truncation = true; // for a single sequence, padding implies truncation to max_length
+    }
+    const tokenizerResult = tokenizer(text, { ...config, return_tensor: false });
+    if (typeof tokenizerResult !== "object" ||
+        tokenizerResult === null ||
+        !("input_ids" in tokenizerResult) ||
+        !isArrayOfNumber(tokenizerResult.input_ids))
+        throw new Error("tokenizer returned unexpected type");
+    return List(tokenizerResult.input_ids);
 }
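A usage sketch for the new `tokenize` wrapper, assuming it is imported from the package's processing module and a GPT-2 tokenizer is loaded through Transformers.js:

import { AutoTokenizer } from "@xenova/transformers";

const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");

// Plain call: returns an immutable List of token ids.
const ids = tokenize(tokenizer, "Hello world");

// With padding, max_length is required; padding implies truncation and
// the padding side defaults to 'left'.
const padded = tokenize(tokenizer, "Hello world", { padding: true, max_length: 8 });
console.log(padded.toArray()); // 8 ids, left-padded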
package/dist/task/training_information.d.ts
CHANGED
@@ -31,7 +31,7 @@ interface DataTypeToTrainingInformation {
     text: {
         dataType: "text";
         tokenizer: string | PreTrainedTokenizer;
-
+        contextLength: number;
     };
 }
 export declare function isTrainingInformation(raw: unknown): raw is TrainingInformation<DataType>;
package/dist/task/training_information.js
CHANGED
@@ -94,16 +94,15 @@ export function isTrainingInformation(raw) {
             return true;
         }
         case "text": {
-            const {
+            const { contextLength, tokenizer, } = raw;
             if ((typeof tokenizer !== "string" &&
                 !(tokenizer instanceof PreTrainedTokenizer)) ||
-                (
-                typeof maxSequenceLength !== "number"))
+                (typeof contextLength !== "number"))
                 return false;
             const _ = {
                 ...repack,
                 dataType,
-
+                contextLength,
                 tokenizer,
             };
             return true;
package/dist/types/data_format.d.ts
CHANGED
@@ -1,5 +1,5 @@
 import { List } from "immutable";
-import type { Image, processing, Tabular, Text } from "../index.js";
+import type { Image, processing, Tabular, Text, TokenizedText } from "../index.js";
 /**
  * The data & label format goes through various stages.
  * Raw* is preprocessed into ModelEncoded.
@@ -29,7 +29,7 @@ type Token = number;
 export interface ModelEncoded {
     image: [image: processing.NormalizedImage<3>, label: number];
     tabular: [row: List<number>, number];
-    text: [line:
+    text: [line: TokenizedText, next: Token];
 }
 /** what gets outputted by the Validator, for humans */
 export interface Inferred {
package/dist/validator.js
CHANGED
@@ -13,7 +13,7 @@ export class Validator {
             .map(async (batch) => (await this.#model.predict(batch.map(([inputs, _]) => inputs)))
             .zip(batch.map(([_, outputs]) => outputs))
             .map(([inferred, truth]) => inferred === truth))
-            .
+            .flatten();
         for await (const e of results)
             yield e;
     }
@@ -22,7 +22,7 @@ export class Validator {
         const modelPredictions = (await processing.preprocessWithoutLabel(this.task, dataset))
             .batch(this.task.trainingInformation.batchSize)
             .map((batch) => this.#model.predict(batch))
-            .
+            .flatten();
         const predictions = await processing.postprocess(this.task, modelPredictions);
         for await (const e of predictions)
             yield e;