@epfml/discojs 2.1.2-p20240528164510.0 → 2.1.2-p20240603114517.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/dataset/dataset_builder.d.ts +2 -11
- package/dist/dataset/dataset_builder.js +22 -46
- package/dist/default_tasks/cifar10.d.ts +2 -0
- package/dist/default_tasks/{cifar10/index.js → cifar10.js} +2 -2
- package/dist/default_tasks/index.d.ts +3 -2
- package/dist/default_tasks/index.js +3 -2
- package/dist/default_tasks/lus_covid.js +1 -1
- package/dist/default_tasks/simple_face.d.ts +2 -0
- package/dist/default_tasks/{simple_face/index.js → simple_face.js} +3 -3
- package/dist/default_tasks/skin_condition.d.ts +2 -0
- package/dist/default_tasks/skin_condition.js +79 -0
- package/dist/models/gpt/config.d.ts +32 -0
- package/dist/models/gpt/config.js +42 -0
- package/dist/models/gpt/evaluate.d.ts +7 -0
- package/dist/models/gpt/evaluate.js +44 -0
- package/dist/models/gpt/index.d.ts +35 -0
- package/dist/models/gpt/index.js +104 -0
- package/dist/models/gpt/layers.d.ts +13 -0
- package/dist/models/gpt/layers.js +272 -0
- package/dist/models/gpt/model.d.ts +43 -0
- package/dist/models/gpt/model.js +191 -0
- package/dist/models/gpt/optimizers.d.ts +4 -0
- package/dist/models/gpt/optimizers.js +95 -0
- package/dist/models/index.d.ts +5 -0
- package/dist/models/index.js +4 -0
- package/dist/{default_tasks/simple_face/model.js → models/mobileNetV2_35_alpha_2_classes.js} +2 -0
- package/dist/{default_tasks/cifar10/model.js → models/mobileNet_v1_025_224.js} +1 -0
- package/dist/models/model.d.ts +51 -0
- package/dist/models/model.js +8 -0
- package/dist/models/tfjs.d.ts +24 -0
- package/dist/models/tfjs.js +107 -0
- package/dist/models/tokenizer.d.ts +14 -0
- package/dist/models/tokenizer.js +22 -0
- package/dist/validation/validator.js +8 -7
- package/package.json +1 -1
- package/dist/default_tasks/cifar10/index.d.ts +0 -2
- package/dist/default_tasks/simple_face/index.d.ts +0 -2
- /package/dist/{default_tasks/simple_face/model.d.ts → models/mobileNetV2_35_alpha_2_classes.d.ts} +0 -0
- /package/dist/{default_tasks/cifar10/model.d.ts → models/mobileNet_v1_025_224.d.ts} +0 -0

package/dist/models/gpt/layers.js
@@ -0,0 +1,272 @@
+import * as tf from '@tensorflow/tfjs';
+/**
+ * Defines a range, from 0 to T, that is used to create positional embeddings
+ */
+class Range extends tf.layers.Layer {
+    static className = 'Range';
+    computeOutputShape(inputShape) {
+        return inputShape;
+    }
+    call(input, kwargs) {
+        return tf.tidy(() => {
+            if (Array.isArray(input)) {
+                // TODO support multitensor
+                input = input[0];
+            }
+            this.invokeCallHook(input, kwargs);
+            const T = input.shape[1];
+            if (T === undefined)
+                throw new Error('unexpected shape');
+            return tf.reshape(tf.range(0, T, 1, 'int32'), [1, T]);
+        });
+    }
+}
+tf.serialization.registerClass(Range);
+class LogLayer extends tf.layers.Layer {
+    static className = 'LogLayer';
+    computeOutputShape(inputShape) {
+        return inputShape;
+    }
+    call(input, kwargs) {
+        return tf.tidy(() => {
+            if (Array.isArray(input)) {
+                input = input[0];
+            }
+            this.invokeCallHook(input, kwargs);
+            return input;
+        });
+    }
+}
+tf.serialization.registerClass(LogLayer);
+class CausalSelfAttention extends tf.layers.Layer {
+    config;
+    peakMemory;
+    static className = 'CausalSelfAttention';
+    nHead;
+    nEmbd;
+    dropout;
+    bias;
+    mask;
+    cAttnKernel;
+    cAttnBias;
+    cProjKernel;
+    cProjBias;
+    constructor(config, disposalRefs, peakMemory) {
+        super(config);
+        this.config = config;
+        this.peakMemory = peakMemory;
+        this.nEmbd = config.nEmbd;
+        this.nHead = config.nHead;
+        this.dropout = config.dropout;
+        this.bias = config.bias;
+        // mask is a lower triangular matrix filled with 1
+        // calling bandPart zero out the upper triangular part of the all-ones matrix
+        // from the doc: tf.linalg.band_part(input, -1, 0) ==> Lower triangular part
+        this.mask = tf.linalg.bandPart(tf.ones([config.blockSize, config.blockSize]), -1, 0);
+        disposalRefs.push(this.mask); // Push a reference to dispose this matrix later
+    }
+    build() {
+        this.cAttnKernel = this.addWeight('c_attn/kernel', [this.nEmbd, 3 * this.nEmbd], 'float32', tf.initializers.glorotNormal({}));
+        this.cAttnBias = this.addWeight('c_attn/bias', [3 * this.nEmbd], 'float32', tf.initializers.zeros());
+        this.cProjKernel = this.addWeight('c_proj/kernel', [this.nEmbd, this.nEmbd], 'float32', tf.initializers.glorotNormal({}));
+        this.cProjBias = this.addWeight('c_proj/bias', [this.nEmbd], 'float32', tf.initializers.zeros());
+    }
+    computeOutputShape(inputShape) {
+        return inputShape;
+    }
+    getConfig() {
+        const config = super.getConfig();
+        return Object.assign({}, config, this.config);
+    }
+    call(input, kwargs) {
+        return tf.tidy(() => {
+            if (this.cAttnKernel === undefined ||
+                this.cAttnBias === undefined ||
+                this.cProjKernel === undefined ||
+                this.cProjBias === undefined) {
+                throw new Error('not built');
+            }
+            if (Array.isArray(input)) {
+                input = input[0];
+            }
+            this.invokeCallHook(input, kwargs);
+            const dense = (x, kernel, bias) => {
+                const k = kernel.read().expandDims(0).tile([x.shape[0], 1, 1]);
+                const m = x.matMul(k);
+                if (this.bias) {
+                    return tf.add(m, bias.read());
+                }
+                else {
+                    return m;
+                }
+            };
+            // Apply attention weights to inputs as one big matrix which is then split into the
+            // query, key and value submatrices
+            const cAttn = dense(input, this.cAttnKernel, this.cAttnBias);
+            let [q, k, v] = tf.split(cAttn, 3, -1);
+            const [B, T, C] = k.shape;
+            const splitHeads = (x) => tf.transpose(tf.reshape(x, [B, T, this.nHead, C / this.nHead]), [0, 2, 1, 3]);
+            q = splitHeads(q);
+            k = splitHeads(k);
+            v = splitHeads(v);
+            // Scaled self attention: query @ key / sqrt(n_heads)
+            let att = tf.mul(tf.matMul(q, k, false, true), tf.div(1, tf.sqrt(tf.cast(k.shape[k.shape.length - 1], 'float32'))));
+            // The next operations apply attention to the past tokens, which is
+            // essentially a weighted average of the past tokens with complicated weights,
+            // and makes sure to not pay any attention to future tokens
+            // mask is lower triangular matrix filled with 1
+            const mask = this.mask.slice([0, 0], [T, T]);
+            // 1 - mask => upper triangular matrix filled with 1
+            // (1 - mask) * -10^9 => upper triangular matrix filled with -inf
+            // att + ((1 - mask) * -10^9) => lower triangular part is the same as the `att` matrix
+            // upper triangular part is -inf
+            att = tf.add(att, tf.mul(tf.sub(1, mask), -1e9));
+            // applying softmax zeros out the upper triangular part
+            //(which are the attention weights of future tokens)
+            // and creates a probability distribution for the lower triangular
+            // (attention weights of past tokens). The probability distribution ensures
+            // that the attention weights of past tokens for a particular token sum to one
+            att = tf.softmax(att, -1);
+            att = kwargs.training === true ? tf.dropout(att, this.dropout) : att;
+            // This is where the (attention-)weighted sum of past values is performed
+            let y = tf.matMul(att, v);
+            y = tf.transpose(y, [0, 2, 1, 3]);
+            y = tf.reshape(y, [B, T, C]);
+            y = dense(y, this.cProjKernel, this.cProjBias);
+            y = kwargs.training === true ? tf.dropout(y, this.dropout) : y;
+            const memoryAllocated = tf.memory().numBytes / 1024 / 1024 / 1024; // GB
+            if (memoryAllocated > this.peakMemory.value) {
+                this.peakMemory.value = memoryAllocated;
+            }
+            return y;
+        });
+    }
+}
+tf.serialization.registerClass(CausalSelfAttention);
+class GELU extends tf.layers.Layer {
+    static className = 'GELU';
+    constructor() {
+        super({});
+    }
+    computeOutputShape(inputShape) {
+        return inputShape;
+    }
+    call(input, kwargs) {
+        return tf.tidy(() => {
+            if (Array.isArray(input)) {
+                // TODO support multitensor
+                input = input[0];
+            }
+            this.invokeCallHook(input, kwargs);
+            const cdf = tf.mul(0.5, tf.add(1, tf.tanh(tf.mul(tf.sqrt(tf.div(2, Math.PI)), tf.add(input, tf.mul(0.044715, tf.pow(input, 3)))))));
+            return tf.mul(input, cdf);
+        });
+    }
+}
+tf.serialization.registerClass(GELU);
+function MLP(config) {
+    return tf.sequential({ layers: [
+            tf.layers.dense({
+                name: 'mlp/c_fc',
+                units: 4 * config.nEmbd,
+                inputDim: config.nEmbd,
+                inputShape: [config.blockSize, config.nEmbd]
+            }),
+            new GELU(),
+            tf.layers.dense({
+                name: 'mlp/c_proj',
+                units: config.nEmbd,
+                inputDim: 4 * config.nEmbd,
+                inputShape: [config.blockSize, 4 * config.nEmbd]
+            }),
+            tf.layers.dropout({
+                name: 'mlp/drop',
+                rate: config.residDrop
+            }),
+        ] });
+}
+function TransformerBlock(conf, disposalRefs, peakMemory) {
+    const config = Object.assign({ name: 'h' }, conf);
+    const inputs = tf.input({ shape: [config.blockSize, config.nEmbd] });
+    let x1, x2;
+    // input normalization
+    x1 = tf.layers.layerNormalization({ name: config.name + '/ln_1', epsilon: 1e-5 })
+        .apply(inputs);
+    if (config.debug) {
+        x1 = new LogLayer({ name: config.name + '/ln_1_log' }).apply(x1);
+    }
+    // self attention layer
+    x1 = new CausalSelfAttention(Object.assign({}, config, { name: config.name + '/attn' }), disposalRefs, peakMemory).apply(x1);
+    // Residual connection
+    x1 = tf.layers.add().apply([inputs, x1]);
+    // normalization
+    x2 = tf.layers
+        .layerNormalization({ name: config.name + '/ln_2', epsilon: 1e-5 })
+        .apply(x1);
+    // MLP
+    x2 = MLP(Object.assign({}, config, { name: config.name + '/mlp' })).apply(x2);
+    // add attention output to mlp output
+    x2 = tf.layers.add().apply([x1, x2]);
+    return tf.model({ name: config.name, inputs, outputs: x2 });
+}
+/**
+ * The GPTArchitecture specifically defines a GPT forward pass, i.e.,
+ * what are the inputs, the successive transformer blocks and the outputs. It is then
+ * used to create a GPTModel
+ *
+ * @param conf GPTConfig
+ * @returns model, tf.LayersModel, which supports model(inputs), model.predict and model.apply
+ */
+export function GPTArchitecture(config, disposalRefs, peakMemory) {
+    const inputs = tf.input({ shape: [null] });
+    //Token embedding
+    const tokEmb = config.tokEmb
+        ? tf.layers.embedding({
+            name: config.name + '/wte',
+            inputDim: config.vocabSize,
+            outputDim: config.nEmbd,
+            embeddingsInitializer: 'zeros',
+            embeddingsRegularizer: undefined,
+            activityRegularizer: undefined
+        }).apply(inputs)
+        : inputs;
+    // Positional embedding
+    const range = new Range({}).apply(inputs);
+    let posEmb = tf.layers.embedding({
+        name: config.name + '/wpe',
+        inputDim: config.blockSize,
+        outputDim: config.nEmbd,
+        embeddingsInitializer: 'zeros'
+    }).apply(range);
+    if (config.debug) {
+        posEmb = new LogLayer({ name: 'posEmb' }).apply(posEmb);
+    }
+    // token and positional embeddings are added together
+    let x = tf.layers.add().apply([tokEmb, posEmb]);
+    // dropout
+    x = tf.layers.dropout({ name: 'drop', rate: config.embdDrop }).apply(x);
+    if (config.debug) {
+        x = new LogLayer({ name: 'dropadd' }).apply(x);
+    }
+    //Apply successively transformer blocks, attention and dense layers
+    for (let i = 0; i < config.nLayer; i++) {
+        x = TransformerBlock(Object.assign({}, config, { name: config.name + '/h/' + i }), disposalRefs, peakMemory).apply(x);
+    }
+    // Normalization
+    x = tf.layers.layerNormalization({ name: config.name + '/ln_f', epsilon: 1e-5 })
+        .apply(x);
+    if (config.debug) {
+        x = new LogLayer({ name: 'fin/ln' }).apply(x);
+    }
+    // Append a language modeling head if specified
+    if (config.lmHead) {
+        x = tf.layers.dense({
+            name: 'lm_head',
+            units: config.vocabSize,
+            inputDim: config.nEmbd,
+            inputShape: [config.blockSize, config.nEmbd],
+            useBias: false
+        }).apply(x);
+    }
+    return tf.model({ inputs, outputs: x });
+}
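
The heart of `CausalSelfAttention` above is the additive causal mask: a lower-triangular matrix built once with `tf.linalg.bandPart`, turned into a `-1e9` penalty on future positions so that the softmax assigns them zero weight. A minimal standalone sketch (not part of the package, plain tfjs only) to see the trick in isolation:

```ts
import * as tf from '@tensorflow/tfjs';

const T = 3;
// Lower-triangular ones, exactly as in the layer's constructor
const mask = tf.linalg.bandPart(tf.ones([T, T]), -1, 0);
// Fake attention scores
const att = tf.randomNormal([T, T]);
// (1 - mask) * -1e9 pushes future positions towards -inf before the softmax,
// so each row of the result is a distribution over past tokens only
const masked = tf.add(att, tf.mul(tf.sub(1, mask), -1e9));
tf.softmax(masked, -1).print(); // row i is non-zero only for columns 0..i
```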

package/dist/models/gpt/model.d.ts
@@ -0,0 +1,43 @@
+import * as tf from '@tensorflow/tfjs';
+import type { GPTConfig } from './config.js';
+/**
+ * tfjs does not export LazyIterator and Dataset...
+ */
+declare abstract class LazyIterator<T> {
+    abstract next(): Promise<IteratorResult<T>>;
+}
+export declare abstract class Dataset<T> {
+    abstract iterator(): Promise<LazyIterator<T>>;
+    size: number;
+}
+/**
+ * GPTModel extends tf.LayersModel and overrides tfjs' default training loop
+ *
+ */
+declare class GPTModel extends tf.LayersModel {
+    protected readonly config: Required<GPTConfig>;
+    private readonly disposalRefs;
+    protected peakMemory: {
+        value: number;
+    };
+    constructor(partialConfig?: GPTConfig);
+    disposeRefs(): void;
+    get getGPTConfig(): Required<GPTConfig>;
+    compile(): void;
+    fitDataset<T>(dataset: Dataset<T>, trainingArgs: tf.ModelFitDatasetArgs<T>): Promise<tf.History>;
+}
+interface GenerateConfig {
+    maxNewTokens: number;
+    temperature: number;
+    doSample: boolean;
+}
+/**
+ * GPTForCausalLM stands for GPT model for Causal Language Modeling. Causal because it only looks at past tokens and not future ones
+ * This class extends GPTModel and adds supports for text generation
+ *
+ */
+export declare class GPTForCausalLM extends GPTModel {
+    generate(idxRaw: tf.TensorLike, conf: GenerateConfig): Promise<number[][]>;
+    private generateOnce;
+}
+export {};
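
For orientation, this is how the declared generation API could be driven. A sketch, assuming `GPTForCausalLM` is re-exported from the package root (the actual export surface lives in `models/index.js`, not shown here); the prompt token ids are made up:

```ts
import { GPTForCausalLM } from '@epfml/discojs'; // assumed export path

async function generateDemo(): Promise<void> {
  const model = new GPTForCausalLM(); // partialConfig is optional
  const promptIds = [[464, 3290, 318]]; // hypothetical token ids, shape [1, T]
  const output = await model.generate(promptIds, {
    maxNewTokens: 20, // matches defaultGenerateConfig
    temperature: 1.0,
    doSample: false,  // greedy argmax decoding
  });
  console.log(output); // number[][]: prompt ids followed by the new token ids
}
```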

package/dist/models/gpt/model.js
@@ -0,0 +1,191 @@
+import * as tf from '@tensorflow/tfjs';
+import { getModelSizes, DEFAULT_CONFIG } from './config.js';
+import { getCustomAdam, clipByGlobalNormObj } from './optimizers.js';
+import evaluate from './evaluate.js';
+import { GPTArchitecture } from './layers.js';
+/**
+ * GPTModel extends tf.LayersModel and overrides tfjs' default training loop
+ *
+ */
+class GPTModel extends tf.LayersModel {
+    config;
+    disposalRefs; // Array to store tensor to dispose manually
+    // Object to pass down to layers to store max memory allocated
+    // This is an object rather than a primitive to pass the reference
+    peakMemory;
+    constructor(partialConfig) {
+        // Fill missing config parameters with default values
+        let completeConfig = { ...DEFAULT_CONFIG, ...partialConfig };
+        // Add layer sizes depending on which model has been specified
+        completeConfig = { ...completeConfig, ...getModelSizes(completeConfig.modelType) };
+        // Init the tf.LayersModel and assign it to this
+        const disposalRefs = [];
+        const peakMemory = { value: 0 };
+        const gpt = GPTArchitecture(completeConfig, disposalRefs, peakMemory);
+        const { inputs, outputs, name } = gpt;
+        super({ inputs, outputs, name });
+        this.config = completeConfig;
+        this.disposalRefs = disposalRefs;
+        this.peakMemory = peakMemory;
+    }
+    // Some tensors are not cleaned up when model.dispose is called
+    // So we dispose them manually
+    disposeRefs() {
+        for (const tensorContainer of this.disposalRefs) {
+            tf.dispose([tensorContainer]);
+        }
+    }
+    get getGPTConfig() {
+        return this.config;
+    }
+    compile() {
+        this.optimizer = this.config.weightDecay !== 0
+            ? getCustomAdam(this, this.config.lr, this.config.weightDecay)
+            : tf.train.adam(this.config.lr);
+        this.peakMemory.value = 0;
+    }
+    async fitDataset(dataset, trainingArgs) {
+        const callbacks = trainingArgs.callbacks;
+        const evalDataset = trainingArgs.validationData;
+        await callbacks.onTrainBegin?.();
+        for (let epoch = 1; epoch <= trainingArgs.epochs; epoch++) {
+            let averageLoss = 0;
+            let iteration = 1;
+            const iterator = await dataset.iterator();
+            let preprocessingTime = performance.now();
+            let next = await iterator.next();
+            preprocessingTime = performance.now() - preprocessingTime;
+            while (next.done !== true && iteration <= this.config.maxIter) {
+                let weightUpdateTime = performance.now();
+                await callbacks.onEpochBegin?.(epoch);
+                const { xs, ys } = next.value;
+                const lossFn = () => {
+                    const logits = this.apply(xs);
+                    if (Array.isArray(logits)) {
+                        throw new Error('model outputs too many tensor');
+                    }
+                    if (logits instanceof tf.SymbolicTensor) {
+                        throw new Error('model outputs symbolic tensor');
+                    }
+                    return tf.losses.softmaxCrossEntropy(ys, logits);
+                };
+                let backwardPassMemory = 0;
+                const lossTensor = tf.tidy(() => {
+                    const { grads, value: lossTensor } = this.optimizer.computeGradients(lossFn);
+                    const gradsClipped = clipByGlobalNormObj(grads, 1);
+                    this.optimizer.applyGradients(gradsClipped);
+                    backwardPassMemory = tf.memory().numBytes / 1024 / 1024 / 1024;
+                    return lossTensor;
+                });
+                const loss = await lossTensor.array();
+                averageLoss += loss;
+                weightUpdateTime = performance.now() - weightUpdateTime;
+                // Probably never the case. Empirically the attention mechanism always allocates
+                // more memory than the backward pass
+                if (backwardPassMemory > this.peakMemory.value) {
+                    this.peakMemory.value = backwardPassMemory;
+                }
+                tf.dispose([xs, ys, lossTensor]);
+                if (evalDataset !== undefined &&
+                    this.config.evaluateEvery !== undefined &&
+                    iteration % this.config.evaluateEvery == 0) {
+                    const iterationLogs = await evaluate(this, evalDataset, this.config.maxEvalBatches);
+                    console.log(iterationLogs);
+                }
+                console.log(`Epoch: ${epoch}`, `\tStep: ${iteration} / ${this.config.maxIter}`, `\tLoss: ${loss.toFixed(3)}`, `\tPeak memory: ${this.peakMemory.value.toFixed(2)} GB`, `\tNumber of tensors allocated: ${tf.memory().numTensors}`, `\tPreprocessing time: ${preprocessingTime.toFixed(0)} ms`, `\tWeight update time: ${weightUpdateTime.toFixed(0)} ms`);
+                iteration++;
+                next = await iterator.next();
+            }
+            // Memory leak: If we reached the last iteration rather than the end of the dataset, cleanup the tensors
+            if (next.done != true && iteration > this.config.maxIter) {
+                const { xs, ys } = next.value;
+                tf.dispose([xs, ys]);
+            }
+            let logs = {
+                'loss': averageLoss / iteration,
+                'peakMemory': this.peakMemory.value
+            };
+            if (evalDataset !== undefined) {
+                logs = { ...logs, ...await evaluate(this, evalDataset, this.config.maxEvalBatches) };
+                console.log(logs);
+            }
+            await callbacks.onEpochEnd?.(epoch, logs);
+        }
+        await callbacks.onTrainEnd?.();
+        return new tf.History();
+    }
+}
+const defaultGenerateConfig = {
+    maxNewTokens: 20,
+    temperature: 1.0,
+    doSample: false
+};
+function prepareIdx(idx) {
+    return tf.tidy(() => {
+        let ret;
+        if (idx instanceof tf.Tensor) {
+            ret = idx.clone();
+        }
+        else {
+            ret = tf.tensor(idx);
+        }
+        if (ret.dtype !== 'int32') {
+            ret = ret.toInt();
+        }
+        switch (ret.shape.length) {
+            case 1:
+                return ret.expandDims(0);
+            case 2:
+                return ret;
+            default:
+                throw new Error('unexpected shape');
+        }
+    });
+}
+/**
+ * GPTForCausalLM stands for GPT model for Causal Language Modeling. Causal because it only looks at past tokens and not future ones
+ * This class extends GPTModel and adds supports for text generation
+ *
+ */
+export class GPTForCausalLM extends GPTModel {
+    async generate(idxRaw, conf) {
+        const config = Object.assign({}, defaultGenerateConfig, conf);
+        let idx = prepareIdx(idxRaw);
+        for (let step = 0; step < config.maxNewTokens; step++) {
+            const idxNext = this.generateOnce(this, idx, config);
+            const idxNew = idx.concat(idxNext, 1);
+            tf.dispose(idx);
+            idx = idxNew;
+            tf.dispose(idxNext);
+        }
+        const idxArr = await idx.array();
+        tf.dispose(idx);
+        return idxArr;
+    }
+    generateOnce(model, idx, config) {
+        const idxNext = tf.tidy(() => {
+            // slice input tokens if longer than context length
+            const blockSize = this.config.blockSize;
+            idx = idx.shape[1] <= blockSize
+                ? idx : idx.slice([0, idx.shape[1] - blockSize]);
+            const output = model.predict(idx);
+            if (Array.isArray(output))
+                throw new Error('The model outputs too multiple values');
+            if (output.shape.length !== 3)
+                throw new Error('The model outputs wrong shape');
+            const logits = output;
+            const logitsScaled = logits
+                .slice([0, idx.shape[1] - 1, 0])
+                .reshape([logits.shape[0], logits.shape[2]])
+                .div(tf.scalar(config.temperature));
+            const probs = logitsScaled.softmax(-1);
+            if (config.doSample) {
+                return tf.multinomial(probs, 1);
+            }
+            else {
+                return probs.argMax(-1).expandDims(1);
+            }
+        });
+        return idxNext;
+    }
+}
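
Note how this training loop bypasses tfjs' built-in implementation: each step calls `optimizer.computeGradients` on a softmax cross-entropy loss, clips the gradients to global norm 1, applies them manually, and disposes every batch tensor itself. A sketch of the batch format it expects, `{xs, ys}` with one-hot next-token targets (the vocabulary size, token ids, and root export path are all assumptions for illustration):

```ts
import * as tf from '@tensorflow/tfjs';
import { GPTForCausalLM } from '@epfml/discojs'; // assumed export path

async function trainDemo(): Promise<void> {
  const vocabSize = 50257; // illustrative; must match the model's config
  const model = new GPTForCausalLM();
  model.compile(); // custom AdamW if weightDecay !== 0, plain Adam otherwise
  // One toy batch: xs are token ids [B, T], ys are one-hot targets [B, T, vocab]
  const xs = tf.tensor2d([[464, 3290, 318]], [1, 3], 'int32');
  const ys = tf.oneHot(tf.tensor2d([[3290, 318, 257]], [1, 3], 'int32'), vocabSize);
  const dataset = tf.data.array([{ xs, ys }]);
  await model.fitDataset(dataset, { epochs: 1, callbacks: {} });
}
```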

package/dist/models/gpt/optimizers.d.ts
@@ -0,0 +1,4 @@
+import * as tf from '@tensorflow/tfjs';
+declare function clipByGlobalNormObj(tensorsObj: Record<string, tf.Tensor>, clipNorm: number, useNorm?: tf.Tensor): Record<string, tf.Tensor>;
+declare function getCustomAdam(model: tf.LayersModel, lr: number, weightDecay: number): tf.Optimizer;
+export { getCustomAdam, clipByGlobalNormObj };

package/dist/models/gpt/optimizers.js
@@ -0,0 +1,95 @@
+import * as tf from '@tensorflow/tfjs';
+function l2Loss(tensor) {
+    return tf.div(tf.sum(tf.square(tensor)), 2);
+}
+function globalNorm(tensors) {
+    const halfSquaredNorms = [];
+    tensors.forEach((tensor) => {
+        halfSquaredNorms.push(l2Loss(tensor));
+    });
+    const halfSquaredNorm = tf.sum(tf.stack(halfSquaredNorms));
+    const norm = tf.sqrt(tf.mul(halfSquaredNorm, tf.scalar(2.0, halfSquaredNorm.dtype)));
+    return norm;
+}
+function clipByGlobalNorm(tensors, clipNorm, useNorm) {
+    return tf.tidy(() => {
+        useNorm = useNorm ?? globalNorm(tensors);
+        const scale = tf.mul(clipNorm, tf.minimum(tf.div(tf.scalar(1.0), useNorm), tf.div(tf.scalar(1.0, useNorm.dtype), clipNorm)));
+        const tensorsClipped = [];
+        tensors.forEach((tensor) => {
+            tensorsClipped.push(tf.clone(tf.mul(tensor, scale)));
+        });
+        return tensorsClipped;
+    });
+}
+function clipByGlobalNormObj(tensorsObj, clipNorm, useNorm) {
+    const varNames = Object.keys(tensorsObj);
+    const tensorsArr = varNames.map((n) => tensorsObj[n]);
+    const tensorsArrClipped = clipByGlobalNorm(tensorsArr, clipNorm, useNorm);
+    const tensorsObjClipped = {};
+    tensorsArrClipped.forEach((t, ti) => {
+        tensorsObjClipped[varNames[ti]] = t;
+    });
+    return tensorsObjClipped;
+}
+class AdamW extends tf.AdamOptimizer {
+    weightDecayRate;
+    includeInWeightDecay;
+    excludeFromWeightDecay;
+    gradientClipNorm;
+    constructor(params) {
+        console.log('Using custom AdamW optimizer');
+        const defaultParams = {
+            learningRate: 0.1,
+            beta1: 0.9,
+            beta2: 0.999,
+            epsilon: 1e-7,
+            weightDecayRate: 0,
+            includeInWeightDecay: [],
+            excludeFromWeightDecay: [],
+            gradientClipNorm: 1.0
+        };
+        const p = Object.assign({}, defaultParams, params);
+        super(p.learningRate, p.beta1, p.beta2, p.epsilon);
+        this.weightDecayRate = p.weightDecayRate;
+        this.includeInWeightDecay = p.includeInWeightDecay;
+        this.excludeFromWeightDecay = p.excludeFromWeightDecay;
+        this.gradientClipNorm = p.gradientClipNorm;
+    }
+    applyGradients(variableGradients) {
+        const varNames = Array.isArray(variableGradients)
+            ? variableGradients.map((v) => v.name)
+            : Object.keys(variableGradients);
+        varNames.forEach((name) => {
+            if (this.includeInWeightDecay.includes(name)) {
+                const value = tf.engine().registeredVariables[name];
+                const newValue = tf.sub(value, tf.mul(this.learningRate, tf.mul(value, this.weightDecayRate)));
+                value.assign(newValue);
+            }
+        });
+        super.applyGradients(variableGradients);
+    }
+}
+function getCustomAdam(model, lr, weightDecay) {
+    const includeInWeightDecay = [];
+    const excludeFromWeightDecay = [];
+    // TODO unsafe cast
+    const namedWeights = model.getNamedWeights();
+    namedWeights.forEach((v) => {
+        if (v.name.includes('bias') ||
+            v.name.includes('normalization') ||
+            v.name.includes('emb')) {
+            excludeFromWeightDecay.push(v.name);
+        }
+        else {
+            includeInWeightDecay.push(v.name);
+        }
+    });
+    return new AdamW({
+        learningRate: lr,
+        weightDecayRate: weightDecay,
+        includeInWeightDecay,
+        excludeFromWeightDecay
+    });
+}
+export { getCustomAdam, clipByGlobalNormObj };
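
Two things worth noting here: `AdamW` applies decoupled weight decay (w ← w − lr·wd·w) directly to the included variables before delegating to the regular Adam step, and the clipping helper implements the standard global-norm rule, scaling every gradient by `clipNorm * min(1/norm, 1/clipNorm)` so that small updates pass through unchanged while large ones are rescaled to a joint L2 norm of `clipNorm`. A quick numerical sketch of the latter (the deep import path is an assumption based on the dist layout):

```ts
import * as tf from '@tensorflow/tfjs';
import { clipByGlobalNormObj } from '@epfml/discojs/dist/models/gpt/optimizers.js'; // assumed path

const grads: Record<string, tf.Tensor> = {
  'layer1/kernel': tf.tensor1d([3, 0]),
  'layer2/kernel': tf.tensor1d([0, 4]),
}; // global L2 norm = sqrt(3^2 + 4^2) = 5
const clipped = clipByGlobalNormObj(grads, 1);
// scale = 1 * min(1/5, 1/1) = 0.2, so the clipped set has global norm 1
clipped['layer1/kernel'].print(); // [0.6, 0]
clipped['layer2/kernel'].print(); // [0, 0.8]
```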

package/dist/models/model.d.ts
@@ -0,0 +1,51 @@
+/// <reference types="node" resolution-mode="require"/>
+import type tf from "@tensorflow/tfjs";
+import type { WeightsContainer } from "../index.js";
+import type { Dataset } from "../dataset/index.js";
+export interface EpochLogs {
+    epoch: number;
+    training: {
+        loss: number;
+        accuracy?: number;
+    };
+    validation?: {
+        loss: number;
+        accuracy: number;
+    };
+    peakMemory: number;
+}
+export type Prediction = tf.Tensor;
+export type Sample = tf.Tensor;
+/**
+ * Trainable predictor
+ *
+ * Allow for various implementation of models (various train function, tensor-library, ...)
+ **/
+export declare abstract class Model implements Disposable {
+    /** Return training state */
+    abstract get weights(): WeightsContainer;
+    /** Set training state */
+    abstract set weights(ws: WeightsContainer);
+    /**
+     * Improve predictor
+     *
+     * @param trainingData dataset to optimize for
+     * @param validationData dataset to measure how well it is training
+     * @param epochs number of pass over the training dataset
+     * @param tracker watch the various steps
+     * @yields on every epoch, training can be stop by `return`ing it
+     */
+    abstract train(trainingData: Dataset, validationData?: Dataset, epochs?: number): AsyncGenerator<EpochLogs, void>;
+    /** Predict likely values */
+    abstract predict(input: Sample): Promise<Prediction>;
+    /**
+     * This method is automatically called to cleanup the memory occupied by the model
+     * when leaving the definition scope if the instance has been defined with the `using` keyword.
+     * For example:
+     * function f() {
+     *   using model = new Model();
+     * }
+     * Calling f() will call the model's dispose method when exiting the function.
+     */
+    abstract [Symbol.dispose](): void;
+}
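
The `[Symbol.dispose]` member makes every `Model` usable with TC39 explicit resource management (`using`, available since TypeScript 5.2). A minimal sketch of the pattern the doc comment describes, with a hypothetical concrete subclass standing in for the abstract `Model`:

```ts
// `SomeModel` is a hypothetical stand-in for any concrete Model subclass
declare class SomeModel {
  predict(input: unknown): Promise<unknown>;
  [Symbol.dispose](): void;
}

async function predictOnce(sample: unknown): Promise<unknown> {
  using model = new SomeModel();
  return await model.predict(sample);
  // model[Symbol.dispose]() runs automatically when this scope exits,
  // freeing the tensors the model holds
}
```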