@epfml/discojs 3.0.1-p20241119093954.0 → 3.0.1-p20241206133538.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/client/client.js +2 -0
- package/dist/client/federated/federated_client.js +2 -2
- package/dist/dataset/dataset.d.ts +18 -5
- package/dist/dataset/dataset.js +58 -23
- package/dist/dataset/types.d.ts +1 -0
- package/dist/default_tasks/index.d.ts +1 -0
- package/dist/default_tasks/index.js +1 -0
- package/dist/default_tasks/tinder_dog.d.ts +2 -0
- package/dist/default_tasks/tinder_dog.js +72 -0
- package/dist/default_tasks/wikitext.js +5 -3
- package/dist/models/gpt/config.d.ts +11 -6
- package/dist/models/gpt/config.js +11 -7
- package/dist/models/gpt/index.d.ts +5 -9
- package/dist/models/gpt/index.js +36 -15
- package/dist/models/gpt/layers.js +260 -82
- package/dist/models/gpt/model.d.ts +1 -1
- package/dist/models/gpt/model.js +4 -4
- package/dist/processing/index.js +8 -9
- package/dist/processing/text.d.ts +16 -6
- package/dist/processing/text.js +29 -26
- package/dist/task/task_handler.js +5 -1
- package/dist/task/training_information.d.ts +1 -1
- package/dist/task/training_information.js +3 -4
- package/dist/training/disco.js +6 -3
- package/dist/types/data_format.d.ts +2 -2
- package/dist/validator.js +2 -2
- package/package.json +1 -1
package/dist/models/gpt/layers.js
CHANGED

@@ -1,4 +1,6 @@
+import createDebug from "debug";
 import * as tf from '@tensorflow/tfjs';
+const debug = createDebug("discojs:models:gpt:layers");
 /**
  * Defines a range, from 0 to T, that is used to create positional embeddings
  */
@@ -10,7 +12,8 @@ class Range extends tf.layers.Layer {
     call(input, kwargs) {
         return tf.tidy(() => {
             if (Array.isArray(input)) {
-
+                if (input.length !== 1)
+                    throw new Error('expected exactly one tensor');
                 input = input[0];
             }
             this.invokeCallHook(input, kwargs);
@@ -22,6 +25,11 @@ class Range extends tf.layers.Layer {
     }
 }
 tf.serialization.registerClass(Range);
+/**
+ * LogLayer is a layer that allows debugging the input that is fed to this layer
+ * This layer allows to inspect the input tensor at a specific point
+ * in the model by adding a log layer in the model definition
+ */
 class LogLayer extends tf.layers.Layer {
     static className = 'LogLayer';
     computeOutputShape(inputShape) {
@@ -30,9 +38,19 @@ class LogLayer extends tf.layers.Layer {
     call(input, kwargs) {
         return tf.tidy(() => {
             if (Array.isArray(input)) {
+                if (input.length !== 1)
+                    throw new Error('expected exactly one tensor');
                 input = input[0];
             }
             this.invokeCallHook(input, kwargs);
+            const logs = {
+                'shape': input.shape,
+                'is_only_zero': !!input.equal(tf.tensor(0)).all().dataSync()[0],
+                'has_some_NaN': !!input.isNaN().any().dataSync()[0],
+                'min': +input.min().dataSync()[0].toPrecision(3),
+                'max': +input.max().dataSync()[0].toPrecision(3),
+            };
+            debug("%s logged: %o", this.name, logs);
             return input;
         });
     }
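Note: LogLayer reports through the `debug` package rather than console.log, so its output is off by default. A minimal sketch of enabling it, assuming a Node entry point:

import createDebug from "debug";

// Enable the namespace used by layers.js before the model runs;
// equivalent to launching with DEBUG="discojs:models:gpt:layers".
createDebug.enable("discojs:models:gpt:layers");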
@@ -43,8 +61,9 @@ class CausalSelfAttention extends tf.layers.Layer {
     static className = 'CausalSelfAttention';
     nHead;
     nEmbd;
+    nLayer;
     dropout;
-
+    seed;
     mask;
     cAttnKernel;
     cAttnBias;
@@ -53,20 +72,34 @@ class CausalSelfAttention extends tf.layers.Layer {
     constructor(config) {
         super(config);
         this.config = config;
+        if (config.nEmbd % config.nHead !== 0)
+            throw new Error('The embedding dimension `nEmbd` must be divisible by the number of attention heads `nHead`');
         this.nEmbd = config.nEmbd;
         this.nHead = config.nHead;
+        this.nLayer = config.nLayer;
         this.dropout = config.dropout;
-        this.
+        this.seed = config.seed;
         // mask is a lower triangular matrix filled with 1
         // calling bandPart zero out the upper triangular part of the all-ones matrix
         // from the doc: tf.linalg.band_part(input, -1, 0) ==> Lower triangular part
-        this.mask = tf.linalg.bandPart(tf.ones([config.
+        this.mask = tf.linalg.bandPart(tf.ones([config.contextLength, config.contextLength]), -1, 0);
     }
     build() {
-
-        this.
-
-        this.
+        // key, query, value projections for all heads, but in a batch
+        this.cAttnKernel = this.addWeight('c_attn.weight', [this.nEmbd, 3 * this.nEmbd], 'float32', tf.initializers.randomNormal({ mean: 0, stddev: 0.02, seed: this.seed }) // use same init as GPT2
+        );
+        this.cAttnBias = this.addWeight('c_attn.bias', [3 * this.nEmbd], 'float32', tf.initializers.zeros());
+        // output projection
+        this.cProjKernel = this.addWeight('c_proj.kernel', [this.nEmbd, this.nEmbd], 'float32',
+        // the input keeps accumulating through the residual stream so we
+        // scale the initialization with the nb of layers to keep a unit std
+        // Sources:
+        // https://github.com/karpathy/build-nanogpt/blob/6104ab1b53920f6e2159749676073ff7d815c1fa/train_gpt2.py#L103
+        // https://youtu.be/l8pRSuU81PU?si=5GcKfi_kPgLgvtg2&t=4640
+        tf.initializers.randomNormal({
+            mean: 0, stddev: 0.02 * Math.sqrt(2 * this.nLayer), seed: this.seed
+        }));
+        this.cProjBias = this.addWeight('c_proj.bias', [this.nEmbd], 'float32', tf.initializers.zeros());
     }
     computeOutputShape(inputShape) {
         return inputShape;
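Aside: a small standalone sketch (toy size T = 4, not from the package) of what the bandPart call above produces, and of how the additive mask used later in call() interacts with softmax:

import * as tf from '@tensorflow/tfjs';

const T = 4;
// Lower triangular matrix of ones: row t has 1s only for positions <= t.
const mask = tf.linalg.bandPart(tf.ones([T, T]), -1, 0);
mask.print();
// [[1, 0, 0, 0],
//  [1, 1, 0, 0],
//  [1, 1, 1, 0],
//  [1, 1, 1, 1]]

// Adding (1 - mask) * -1e9 drives future positions toward -inf, so softmax
// assigns them ~0 weight while each row still sums to 1.
const scores = tf.randomNormal([T, T]);
const masked = tf.add(scores, tf.mul(tf.sub(1, mask), -1e9));
tf.softmax(masked).print();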
@@ -84,58 +117,72 @@ class CausalSelfAttention extends tf.layers.Layer {
                 throw new Error('not built');
             }
             if (Array.isArray(input)) {
+                if (input.length !== 1)
+                    throw new Error('expected exactly one tensor');
                 input = input[0];
             }
             this.invokeCallHook(input, kwargs);
             const dense = (x, kernel, bias) => {
+                // TODO: use broadcasting when tfjs will support backpropagating through broadcasting
                 const k = kernel.read().expandDims(0).tile([x.shape[0], 1, 1]);
                 const m = x.matMul(k);
-
-                    return tf.add(m, bias.read());
-                }
-                else {
-                    return m;
-                }
+                return tf.add(m, bias.read());
             };
             // Apply attention weights to inputs as one big matrix which is then split into the
             // query, key and value submatrices
+            // nHead is "number of heads", hs is "head size", and C (number of channels) = n_embd = nHead * hs
+            // e.g. in GPT-2 (124M), nHead = 12, hs = 64, so nHead * hs = C = 768 channels in the Transformer
             const cAttn = dense(input, this.cAttnKernel, this.cAttnBias);
             let [q, k, v] = tf.split(cAttn, 3, -1);
-
-            const
-
-
-
-
-
-            //
-            //
+            // Follow naming conventions in https://github.com/karpathy/build-nanogpt/
+            const [B, T, C] = k.shape; // batch size, sequence length, embedding dimensionality (number of channels)
+            const splitHeads = (x) => tf.transpose(tf.reshape(x, [B, T, this.nHead, C / this.nHead]), // (B, T, nHead, head size)
+            [0, 2, 1, 3] // (B, nHead, T, hs)
+            );
+            q = splitHeads(q); // (B, nHead, T, hs)
+            k = splitHeads(k); // (B, nHead, T, hs)
+            v = splitHeads(v); // (B, nHead, T, hs)
+            // Scaled self attention: query @ key / sqrt(hs)
+            // Matrix representing the token-to-token attention (B, nHead, T, T)
+            let att = tf.mul(tf.matMul(q, k, false, true), // (B, nHead, T, hs) x (B, nHead, hs, T) -> (B, nHead, T, T)
+            tf.div(1, tf.sqrt(tf.cast(k.shape[k.shape.length - 1], 'float32'))) // 1 / sqrt(hs)
+            );
+            /**
+             * The next operations apply attention only on the past tokens, which is
+             * essentially a weighted average of the past tokens with complicated weights,
+             * it relies on a mask to not "pay any attention" to future tokens
+             */
             // mask is lower triangular matrix filled with 1
-            const mask = this.mask.slice([0, 0], [T, T]);
+            const mask = this.mask.slice([0, 0], [T, T]); // (T, T)
             // 1 - mask => upper triangular matrix filled with 1
             // (1 - mask) * -10^9 => upper triangular matrix filled with -inf
             // att + ((1 - mask) * -10^9) => lower triangular part is the same as the `att` matrix
             // upper triangular part is -inf
-            att = tf.add(att, tf.mul(tf.sub(1, mask), -1e9));
-            // applying softmax
-            //
+            att = tf.add(att, tf.mul(tf.sub(1, mask), -1e9)); // (B, nHead, T, T)
+            // applying softmax zeroes out the upper triangular part (softmax(-inf) = 0)
+            // i.e., zeroes out future tokens's attention weights
             // and creates a probability distribution for the lower triangular
             // (attention weights of past tokens). The probability distribution ensures
             // that the attention weights of past tokens for a particular token sum to one
             att = tf.softmax(att, -1);
-            att = kwargs.training === true ? tf.dropout(att, this.dropout) : att;
+            att = kwargs.training === true ? tf.dropout(att, this.dropout, undefined, this.seed) : att;
             // This is where the (attention-)weighted sum of past values is performed
-            let y = tf.matMul(att, v);
-            y = tf.transpose(y, [0, 2, 1, 3]);
-            y = tf.reshape(y, [B, T, C]);
-            y = dense(y, this.cProjKernel, this.cProjBias);
-            y = kwargs.training === true ? tf.dropout(y, this.dropout) : y;
+            let y = tf.matMul(att, v); // (B, nHead, T, T) x (B, nHead, T, hs) -> (B, nHead, T, hs)
+            y = tf.transpose(y, [0, 2, 1, 3]); // (B, T, nHead, hs)
+            y = tf.reshape(y, [B, T, C]); // (B, T, C = nHead * hs)
+            y = dense(y, this.cProjKernel, this.cProjBias); // output projection (B, T, C)
+            y = kwargs.training === true ? tf.dropout(y, this.dropout, undefined, this.seed) : y;
            return y;
         });
     }
 }
 tf.serialization.registerClass(CausalSelfAttention);
+/**
+ * GELU with tanh approximate
+ * GELU(x) = x * 0.5 * (1 + Tanh[sqrt(2/π) * (x + 0.044715 * x^3)])
+ *
+ * https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
+ */
 class GELU extends tf.layers.Layer {
     static className = 'GELU';
     constructor() {
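Aside: the splitHeads reshape/transpose above and its inverse at the end of call() are easy to check on toy sizes (B=2, T=3, nHead=4, hs=8, so C=32; sizes are illustrative, not the package defaults):

import * as tf from '@tensorflow/tfjs';

const [B, T, nHead, hs] = [2, 3, 4, 8];
const C = nHead * hs;
const x = tf.randomNormal([B, T, C]);

// (B, T, C) -> (B, nHead, T, hs): each head gets its own hs-wide slice.
const split = tf.transpose(tf.reshape(x, [B, T, nHead, C / nHead]), [0, 2, 1, 3]);
console.log(split.shape); // [2, 4, 3, 8]

// The inverse transpose + reshape restores (B, T, C).
const merged = tf.reshape(tf.transpose(split, [0, 2, 1, 3]), [B, T, C]);
console.log(merged.shape); // [2, 3, 32]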
@@ -148,11 +195,17 @@ class GELU extends tf.layers.Layer {
         return tf.tidy(() => {
             if (Array.isArray(input)) {
                 // TODO support multitensor
+                if (input.length !== 1)
+                    throw new Error('expected exactly one tensor');
                 input = input[0];
             }
             this.invokeCallHook(input, kwargs);
-            const cdf = tf.mul(0.5
-
+            const cdf = tf.mul(// 0.5 * (1 + Tanh[sqrt(2/π) * (x + 0.044715 * x^3)])
+            0.5, tf.add(1, tf.tanh(// Tanh[sqrt(2/π) * (x + 0.044715 * x^3)]
+            tf.mul(tf.sqrt(tf.div(2, Math.PI)), // (sqrt(2/π)
+            tf.add(input, tf.mul(0.044715, tf.pow(input, 3))) // (x + 0.044715 * x^3)
+            ))));
+            return tf.mul(input, cdf); // x * 0.5 * (1 + Tanh[sqrt(2/π) * (x + 0.044715 * x^3)])
         });
     }
 }
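As a quick numeric cross-check of the formula quoted above, evaluated on plain numbers outside tfjs:

// GELU(x) = x * 0.5 * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
function geluTanh(x: number): number {
    return x * 0.5 * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (x + 0.044715 * x ** 3)));
}
console.log(geluTanh(0));  // 0
console.log(geluTanh(1));  // ≈ 0.841
console.log(geluTanh(-1)); // ≈ -0.159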
@@ -160,48 +213,173 @@ tf.serialization.registerClass(GELU);
 function MLP(config) {
     return tf.sequential({ layers: [
             tf.layers.dense({
-                name: config.name +
+                name: config.name + `.mlp.c_fc`,
                 units: 4 * config.nEmbd,
                 inputDim: config.nEmbd,
-                inputShape: [config.
+                inputShape: [config.contextLength, config.nEmbd],
+                kernelInitializer: tf.initializers.randomNormal({
+                    mean: 0, stddev: 0.02, seed: config.seed
+                }),
             }),
             new GELU(),
             tf.layers.dense({
-                name: config.name + '
+                name: config.name + '.mlp.c_proj',
                 units: config.nEmbd,
                 inputDim: 4 * config.nEmbd,
-                inputShape: [config.
+                inputShape: [config.contextLength, 4 * config.nEmbd],
+                kernelInitializer: tf.initializers.randomNormal({
+                    mean: 0, stddev: 0.02 * Math.sqrt(2 * config.nLayer), seed: config.seed
+                }),
             }),
             tf.layers.dropout({
-                name: config.name + '
-                rate: config.residDrop
+                name: config.name + '.mlp.drop',
+                rate: config.residDrop,
+                seed: config.seed
             }),
         ] });
 }
+/**
+ * Performs the following operations:
+ * x1 = input + mlp(layernorm_1(input))
+ * output = x1 + mlp(layernorm_2(x1))
+ */
 function TransformerBlock(conf) {
-    const config = Object.assign({ name: 'h' }, conf);
-    const inputs = tf.input({ shape: [config.
+    const config = Object.assign({ name: '.h' }, conf);
+    const inputs = tf.input({ shape: [config.contextLength, config.nEmbd] });
     let x1, x2;
     // input normalization
-    x1 = tf.layers.layerNormalization({
-    .
+    x1 = tf.layers.layerNormalization({
+        name: config.name + '.ln_1',
+        epsilon: 1e-5,
+        gammaInitializer: 'ones', // already the default but make it explicit
+        betaInitializer: 'zeros',
+    }).apply(inputs);
     if (config.debug) {
-        x1 = new LogLayer({ name: config.name + '
+        x1 = new LogLayer({ name: config.name + '.ln_1_log' }).apply(x1);
     }
     // self attention layer
-    x1 = new CausalSelfAttention(Object.assign({}, config, { name: config.name + '
+    x1 = new CausalSelfAttention(Object.assign({}, config, { name: config.name + '.attn' })).apply(x1);
+    if (config.debug) {
+        x1 = new LogLayer({ name: config.name + '.attn_log' }).apply(x1);
+    }
     // Residual connection
     x1 = tf.layers.add().apply([inputs, x1]);
+    if (config.debug) {
+        x1 = new LogLayer({ name: config.name + '.residual_log' }).apply(x1);
+    }
     // normalization
-    x2 = tf.layers
-
-
+    x2 = tf.layers.layerNormalization({
+        name: config.name + '.ln_2',
+        epsilon: 1e-5,
+        gammaInitializer: 'ones',
+        betaInitializer: 'zeros',
+    }).apply(x1);
+    if (config.debug) {
+        x2 = new LogLayer({ name: config.name + '.ln_2_log' }).apply(x2);
+    }
     // MLP
-    x2 = MLP(Object.assign({}, config, { name: config.name })).apply(x2);
+    x2 = MLP(Object.assign({}, config, { name: config.name + '.mlp' })).apply(x2);
+    if (config.debug) {
+        x2 = new LogLayer({ name: config.name + '.mlp_log' }).apply(x2);
+    }
     // add attention output to mlp output
     x2 = tf.layers.add().apply([x1, x2]);
+    if (config.debug) {
+        x2 = new LogLayer({ name: config.name + '.add_log' }).apply(x2);
+    }
     return tf.model({ name: config.name, inputs, outputs: x2 });
 }
+/**
+ * LanguageModelEmbedding is a layer that combines the token embeddings and the language modeling head
+ * I.e. LMEmbedding is used to translate token indices into token embeddings
+ * as well as to project embeddings back into token indices
+ * The GPT2 model uses the same embedding matrix for both the token embeddings and the language modeling head
+ * Because Tensorflow.js doesn't offer an easy weight sharing mechanism, we need to define a custom layer
+ * that can be used for both the token embeddings and the language modeling head.
+ * In the GPT2 model definition, this layers corresponds to wte and lm_head (which reuses wte)
+ */
+class LMEmbedding extends tf.layers.Layer {
+    vocabSize;
+    nEmbd;
+    seed;
+    static className = 'LMEmbedding';
+    embeddings;
+    constructor(vocabSize, nEmbd, seed) {
+        super({});
+        this.vocabSize = vocabSize;
+        this.nEmbd = nEmbd;
+        this.seed = seed;
+    }
+    build() {
+        this.embeddings = this.addWeight('wte', //use same name as GPT2
+        [this.vocabSize, this.nEmbd], 'float32', tf.initializers.randomNormal({ mean: 0, stddev: 0.02, seed: this.seed }));
+    }
+    computeOutputShape(inputShape) {
+        let shape;
+        if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
+            if (inputShape.length !== 1)
+                throw new Error('Expected exactly one Shape');
+            shape = inputShape[0];
+        }
+        else
+            shape = inputShape;
+        // input shape for the token embedding
+        if (shape.length === 2) {
+            // https://github.com/tensorflow/tfjs/blob/3daf152cb794f4da58fce5e21e09e8a4f89c8f80/tfjs-layers/src/layers/embeddings.ts#L155
+            // batch size and sequence length are undetermined
+            // so the output shape is [null, null, nEmbd]
+            if (shape[0] !== null || shape[1] !== null)
+                throw new Error('expected shape [null, null, ...]');
+            return [null, null, this.nEmbd];
+        }
+        // input shape for the language modeling head
+        // https://github.com/tensorflow/tfjs/blob/3daf152cb794f4da58fce5e21e09e8a4f89c8f80/tfjs-layers/src/layers/core.ts#L258
+        else if (shape.length === 3) {
+            // batch size and sequence length are undetermined
+            // so the output shape is [null, null, nEmbd]
+            if (shape[0] !== null || shape[1] !== null)
+                throw new Error('expected shape [null, null, ...]');
+            return [null, null, this.vocabSize];
+        }
+        else
+            throw new Error('unexpected input shape');
+    }
+    call(input, kwargs) {
+        return tf.tidy(() => {
+            if (this.embeddings === undefined)
+                throw new Error('not built');
+            if (Array.isArray(input)) {
+                if (input.length !== 1)
+                    throw new Error('expected exactly one tensor');
+                input = input[0];
+            }
+            this.invokeCallHook(input, kwargs);
+            // If the input is a 2D tensor, it is a batch of sequences of tokens
+            // so we translate the tokens into embeddings
+            // using `this.embeddings` as a lookup table
+            if (input.shape.length === 2) {
+                // (batch_size, sequence_length) => (batch_size, sequence_length, nEmbd)
+                return tf.gather(this.embeddings.read(), tf.cast(input, 'int32'), 0);
+            }
+            // If the input is a 3D tensor, it is a sequence of embeddings
+            // so we apply a dense layer to project the embeddings back into the vocabulary space
+            else if (input.shape.length === 3 && input.shape[2] === this.nEmbd) {
+                // Replicate the kernel for each batch element
+                const kernel = this.embeddings.read().expandDims(0).tile([input.shape[0], 1, 1]);
+                // TODO: rely on broadcasting when tfjs will support backpropagating through broadcasting
+                // Remove the tile, or use tf.einsum('BTE,VE->BTV', input, this.embeddings.read())
+                // to prevent tensor duplication but tensorflow.js fails to backpropagate einsum
+                // https://github.com/tensorflow/tfjs/issues/5690
+                // (batch_size, sequence_length, nEmbd) x (vocabSize, nEmbd)^T -> (batch_size, sequence_length, vocabSize)
+                return tf.matMul(input, kernel, false, true);
+            }
+            else {
+                throw new Error('unexpected input shape for token embeddings');
+            }
+        });
+    }
+}
+tf.serialization.registerClass(LMEmbedding);
 /**
  * The GPTArchitecture specifically defines a GPT forward pass, i.e.,
  * what are the inputs, the successive transformer blocks and the outputs. It is then
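Aside: the two directions of LMEmbedding's weight tying can be sketched standalone with toy sizes (vocabSize=5, nEmbd=3, B=2, T=3; sizes are illustrative):

import * as tf from '@tensorflow/tfjs';

const wte = tf.randomNormal([5, 3]); // (vocabSize, nEmbd)

// Embedding direction: token ids index rows of wte.
const tokens = tf.tensor2d([[0, 4, 2], [1, 1, 3]], [2, 3], 'int32');
const embedded = tf.gather(wte, tokens, 0);
console.log(embedded.shape); // [2, 3, 3] = (B, T, nEmbd)

// LM-head direction: multiply by wte^T, tiled per batch element as in call().
const kernel = wte.expandDims(0).tile([2, 1, 1]);
const logits = tf.matMul(embedded, kernel, false, true);
console.log(logits.shape); // [2, 3, 5] = (B, T, vocabSize)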
@@ -212,54 +390,54 @@ function TransformerBlock(conf) {
  */
 export function GPTArchitecture(config) {
     const inputs = tf.input({ shape: [null] });
-    //
-    const
-
-
-
-
-        embeddingsInitializer: 'zeros',
-        embeddingsRegularizer: undefined,
-        activityRegularizer: undefined
-    }).apply(inputs)
-    : inputs;
+    // token embedding
+    const wte = new LMEmbedding(config.vocabSize, config.nEmbd, config.seed);
+    let tokEmb = wte.apply(inputs); // (batch_size, input length T, nEmbd)
+    if (config.debug) {
+        tokEmb = new LogLayer({ name: 'tokEmb_log' }).apply(tokEmb);
+    }
     // Positional embedding
     const range = new Range({}).apply(inputs);
     let posEmb = tf.layers.embedding({
-        name: config.name + '
-        inputDim: config.
+        name: config.name + '.wpe',
+        inputDim: config.contextLength,
         outputDim: config.nEmbd,
-        embeddingsInitializer:
+        embeddingsInitializer: tf.initializers.randomNormal({
+            mean: 0, stddev: 0.02, seed: config.seed
+        }),
     }).apply(range);
     if (config.debug) {
-        posEmb = new LogLayer({ name: '
+        posEmb = new LogLayer({ name: 'posEmb_log' }).apply(posEmb);
     }
     // token and positional embeddings are added together
     let x = tf.layers.add().apply([tokEmb, posEmb]);
     // dropout
-    x = tf.layers.dropout({
+    x = tf.layers.dropout({
+        name: 'drop', rate: config.embdDrop, seed: config.seed
+    }).apply(x);
     if (config.debug) {
-        x = new LogLayer({ name: '
+        x = new LogLayer({ name: 'drop_log' }).apply(x);
     }
-    //
+    // apply successively transformer blocks, attention and dense layers
     for (let i = 0; i < config.nLayer; i++) {
-        x = TransformerBlock(Object.assign({}, config, { name: config.name + '
+        x = TransformerBlock(Object.assign({}, config, { name: config.name + '.h' + i })).apply(x);
     }
     // Normalization
-    x = tf.layers.layerNormalization({
+    x = tf.layers.layerNormalization({
+        name: config.name + '.ln_f',
+        epsilon: 1e-5,
+        gammaInitializer: 'ones',
+        betaInitializer: 'zeros',
+    })
         .apply(x);
     if (config.debug) {
-        x = new LogLayer({ name: '
+        x = new LogLayer({ name: 'ln_f_log' }).apply(x);
     }
-    //
-
-
-
-        inputDim: config.nEmbd,
-        inputShape: [config.blockSize, config.nEmbd],
-        useBias: false
-    }).apply(x);
+    // language modeling head
+    // GPT2 uses the same matrix for the token embedding and the modeling head
+    x = wte.apply(x);
+    if (config.debug) {
+        x = new LogLayer({ name: 'lm_head_log' }).apply(x);
     }
     return tf.model({ inputs, outputs: x });
 }
package/dist/models/gpt/model.d.ts
CHANGED

@@ -16,7 +16,7 @@ export declare abstract class Dataset<T> {
  */
 export declare class GPTModel extends tf.LayersModel {
     protected readonly config: Required<GPTConfig>;
-    constructor(partialConfig?: GPTConfig
+    constructor(partialConfig?: Partial<GPTConfig>, layersModel?: tf.LayersModel);
     get getGPTConfig(): Required<GPTConfig>;
     compile(): void;
     fitDataset<T>(dataset: Dataset<T>, trainingArgs: tf.ModelFitDatasetArgs<T>): Promise<tf.History>;
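With the widened signature, callers can pass any subset of the GPT config. A hypothetical sketch (the `models` import surface is an assumption, not confirmed by this diff):

import { models } from "@epfml/discojs"; // assumed export surface

// Fields left out fall back to DefaultGPTConfig (see model.js below).
const model = new models.GPTModel({ contextLength: 128, nLayer: 4 });
model.compile();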
package/dist/models/gpt/model.js
CHANGED
@@ -1,10 +1,10 @@
 import createDebug from "debug";
 import * as tf from '@tensorflow/tfjs';
-import { getModelSizes,
+import { getModelSizes, DefaultGPTConfig } from './config.js';
 import { getCustomAdam, clipByGlobalNormObj } from './optimizers.js';
 import evaluate from './evaluate.js';
 import { GPTArchitecture } from './layers.js';
-const debug = createDebug("discojs:models:gpt");
+const debug = createDebug("discojs:models:gpt:model");
 /**
  * GPTModel extends tf.LayersModel and overrides tfjs' default training loop
  *
@@ -13,7 +13,7 @@ export class GPTModel extends tf.LayersModel {
     config;
     constructor(partialConfig, layersModel) {
         // Fill missing config parameters with default values
-        let completeConfig = { ...
+        let completeConfig = { ...DefaultGPTConfig, ...partialConfig };
         // Add layer sizes depending on which model has been specified
         completeConfig = { ...completeConfig, ...getModelSizes(completeConfig.modelType) };
         if (layersModel !== undefined) {
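The merge relies on spread precedence: later spreads win, so user-provided fields override the defaults, and the model sizes merged afterwards override both. A sketch with illustrative values (not the package's actual defaults):

const DefaultGPTConfig = { modelType: 'gpt-nano', contextLength: 128, dropout: 0.2 }; // illustrative
const partialConfig = { dropout: 0.0 };
const completeConfig = { ...DefaultGPTConfig, ...partialConfig };
console.log(completeConfig); // { modelType: 'gpt-nano', contextLength: 128, dropout: 0 }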
@@ -112,7 +112,7 @@ export class GPTModel extends tf.LayersModel {
             tf.dispose([xs, ys]);
         }
         let logs = {
-            'loss': averageLoss / iteration,
+            'loss': averageLoss / (iteration - 1), // -1 because iteration got incremented at the end of the loop
             'acc': accuracyFraction[0] / accuracyFraction[1],
         };
         if (evalDataset !== undefined) {
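The added comment explains the fix: if `iteration` starts at 1 and is incremented after each batch, it ends one past the batch count. A minimal sketch of the arithmetic, with the loop structure assumed from that comment and hypothetical per-batch losses:

let averageLoss = 0;
let iteration = 1;
for (const loss of [0.9, 0.7, 0.5]) { // hypothetical per-batch losses
    averageLoss += loss;
    iteration++;
}
console.log(averageLoss / iteration);       // 0.525: divides by 4, off by one
console.log(averageLoss / (iteration - 1)); // 0.7: the intended mean over 3 batches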
package/dist/processing/index.js
CHANGED
@@ -33,11 +33,11 @@ export async function preprocess(task, dataset) {
         // cast as typescript doesn't reduce generic type
         const d = dataset;
         const t = task;
+        const contextLength = task.trainingInformation.contextLength;
         const tokenizer = await models.getTaskTokenizer(t);
-
-
-
-            .map((line) => processing.tokenizeAndLeftPad(line, tokenizer, totalTokenCount))
+        return d.map(text => processing.tokenize(tokenizer, text))
+            .flatten()
+            .batch(contextLength + 1, 1)
             .map((tokens) => [tokens.pop(), tokens.last()]);
     }
 }
@@ -60,12 +60,11 @@ export async function preprocessWithoutLabel(task, dataset) {
         // cast as typescript doesn't reduce generic type
         const d = dataset;
         const t = task;
+        const contextLength = task.trainingInformation.contextLength;
         const tokenizer = await models.getTaskTokenizer(t);
-
-
-
-            .map((line) => processing.tokenizeAndLeftPad(line, tokenizer, totalTokenCount))
-            .map((tokens) => tokens.pop());
+        return d.map(text => processing.tokenize(tokenizer, text))
+            .flatten()
+            .batch(contextLength);
     }
 }
 }
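The new pipeline tokenizes each text, flattens everything into one token stream, and slides a window of contextLength + 1 tokens so that each sample's label is the token following its input. A plain-array sketch of that windowing, under my reading of `.batch(contextLength + 1, 1)` as a size-(contextLength + 1), stride-1 window:

const tokens = [10, 11, 12, 13, 14]; // hypothetical flattened token stream
const contextLength = 3;

const samples: Array<[number[], number]> = [];
for (let i = 0; i + contextLength < tokens.length; i++) {
    const window = tokens.slice(i, i + contextLength + 1);
    // mirrors [tokens.pop(), tokens.last()]: drop the last element for the
    // input, keep it as the label
    samples.push([window.slice(0, -1), window[window.length - 1]]);
}
console.log(samples); // [[[10, 11, 12], 13], [[11, 12, 13], 14]]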
package/dist/processing/text.d.ts
CHANGED

@@ -1,11 +1,21 @@
-import { List } from "immutable";
 import { PreTrainedTokenizer } from "@xenova/transformers";
-type
+import type { Text, TokenizedText } from '../index.js';
+interface TokenizingConfig {
+    padding?: boolean;
+    padding_side?: 'left' | 'right';
+    truncation?: boolean;
+    max_length?: number;
+}
 /**
- * Tokenize
+ * Tokenize one line of text.
+ * Wrapper around Transformers.js tokenizer to handle type checking and format the output.
+ * Note that Transformers.js's tokenizer can tokenize multiple lines of text at once
+ * but we are currently not making use of it. Can be useful when padding a batch
  *
- * @param
- * @
+ * @param tokenizer the tokenizer object
+ * @param text the text to tokenize
+ * @param config TokenizingConfig, the tokenizing parameters when using `tokenizer`
+ * @returns List<number> the tokenized text
  */
-export declare function
+export declare function tokenize(tokenizer: PreTrainedTokenizer, text: Text, config?: TokenizingConfig): TokenizedText;
 export {};