@stellarapp/tfjs-stellar 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/jest.config.ts +203 -0
- package/package.json +24 -0
- package/src/index.ts +93 -0
- package/src/kv_cache.ts +205 -0
- package/src/layers/cached_rope_multihead_attention.test.ts +59 -0
- package/src/layers/cached_rope_multihead_attention.ts +113 -0
- package/src/layers/gpt_decoder_block.ts +77 -0
- package/src/layers/multihead_attention.test.ts +212 -0
- package/src/layers/multihead_attention.ts +371 -0
- package/src/layers/positional_encoding.test.ts +113 -0
- package/src/layers/positional_encoding.ts +158 -0
- package/src/layers/rotary_position_embedding.test.ts +107 -0
- package/src/layers/rotary_position_embedding.ts +163 -0
- package/src/layers/token_and_positional_embedding.test.ts +81 -0
- package/src/layers/token_and_positional_embedding.ts +149 -0
- package/src/layers/transformer_decoder.test.ts +100 -0
- package/src/layers/transformer_decoder.ts +236 -0
- package/src/layers/transformer_encoder.test.ts +85 -0
- package/src/layers/transformer_encoder.ts +224 -0
- package/src/losses/dice.ts +156 -0
- package/src/losses/index.ts +1 -0
- package/src/metrics.ts +32 -0
- package/src/models/gpt_model.ts +232 -0
- package/src/models/index.ts +2 -0
- package/src/models/llm_model.ts +355 -0
- package/src/models/u_net.ts +240 -0
- package/src/packing_mask.ts +28 -0
- package/src/testing.ts +1 -0
- package/src/tfjs_types.ts +15 -0
- package/src/utils.test.ts +101 -0
- package/src/utils.ts +86 -0
- package/tsconfig.json +49 -0
|
@@ -0,0 +1,371 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
|
|
3
|
+
import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
|
|
4
|
+
import { generateCausalAttentionMask } from "@/utils";
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
/**
 * Constructor arguments for the MultiHeadAttention layer, on top of the standard layer args.
 */
export interface MultiHeadAttentionArgs extends LayerArgs {
    /** number of attention heads; `embedDim` must be divisible by it */
    numHeads: number;
    /** size of the query/key/value projections, split evenly across the heads */
    embedDim: number;

    /** whether the dense projection sublayers use a bias term (the layer defaults this to `true`) */
    useBias?: boolean;
    /** dropout rate applied to the attention weights while training (the layer defaults this to `0.0`) */
    dropout?: number;
    /** whether tokens are masked from attending to later positions (the layer defaults this to `false`) */
    causal?: boolean;
}
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
/**
 * Optional keyword arguments for `MultiHeadAttention.scaledDotProductionAttention`.
 *
 * NOTE(review): that function only reads `training` and `scaling_factor` from this object;
 * dropout and causal masking are passed to it as separate positional arguments.
 */
export interface ScaledDotProductionAttentionKwargs {
    /** value whose square root divides QK^T; falls back to the key's last dimension when omitted */
    scaling_factor?: number;
    /** dropout is only applied to the attention weights when this is true */
    training?: boolean;
    dropout?: number;
    causal?: boolean;
}
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
/**
 * This MultiHead Attention layer implements the algorithm as described in
 * the paper "Attention is all you Need" Vaswani et al., 2017.
 *
 * @param numHeads number of attention heads to use, a positive integer that divides `embedDim`
 * @param embedDim size of the query/key/value projections (typically the inputs' last dimension);
 *   each head works on `embedDim / numHeads` features and the output's last dimension is `embedDim`
 * @param causal use causal masking, default `false`
 * @param dropout dropout rate applied to the attention weights during training, within [0, 1), default `0.0`
 * @param useBias use bias for the dense sublayers, default `true`
 *
 * The TensorFlow version uses tf.einsum, whose gradient op has not yet been
 * implemented (https://github.com/tensorflow/tfjs/pull/4955#discussion_r619219334),
 * therefore we follow the PyTorch implementation described in:
 * https://docs.pytorch.org/tutorials/intermediate/transformer_building_blocks.html#multiheadattention
 * https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
 *
 * This implementation is different from TensorFlow's, whose attention weights
 * are shaped [embed dim, heads, head dim], whereas PyTorch and OpenAI's attention weights
 * are shaped [embed dim, embed dim]
 * https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/nn/modules/activation.py#L1080
 * https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/model.py#L53
 *
 * TODO: implement a fast track for self attention (query = key = value)
 * where a single dense layer combines and replaces the query, key and value projection layers
 *
 * TODO: add kDim and vDim to accept key and values whose embedding dimensions differ from query's.
 */
export class MultiHeadAttention extends tf.layers.Layer {
    static className = "MultiHeadAttention";
    protected readonly numHeads: number;
    protected readonly embedDim: number; // size of the projected embeddings, each head gets embedDim / numHeads of it
    protected readonly useBias: boolean;
    protected readonly dropout: number;
    protected readonly causal: boolean; // use causal attention to mask future words

    // projection simply means matrix multiplying query, key, and value
    // with weights to create a representation of the inputs
    protected readonly queryProjection: tf.layers.Layer;
    protected readonly keyProjection: tf.layers.Layer;
    protected readonly valueProjection: tf.layers.Layer;
    protected readonly outputProjection: tf.layers.Layer;


    /**
     * @throws Error if embedDim is not divisible by numHeads, numHeads is not a positive
     *   integer, or dropout is outside [0, 1)
     */
    constructor({ numHeads, embedDim, useBias = true, dropout = 0.0, causal = false, ...args }: MultiHeadAttentionArgs) {
        super(args);

        if (embedDim % numHeads != 0) {
            throw Error(`${this.getClassName()}::constructor ${this.name} embedDim (${embedDim}) is not divisible by numHeads (${numHeads})`);
        }

        // negative or fractional head counts can still divide embedDim evenly (e.g. 32 % -4 == 0)
        if (!Number.isInteger(numHeads) || numHeads < 1) {
            throw Error(`${this.getClassName()}::constructor ${this.name} numHeads (${numHeads}) must be a positive integer`);
        }

        this.numHeads = numHeads;
        this.embedDim = embedDim;
        this.useBias = useBias;
        this.dropout = dropout;
        this.causal = causal;

        // written as a negated range check so that negative rates and NaN are rejected too
        if (!(this.dropout >= 0 && this.dropout < 1)) {
            throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
        }

        // initialize the projection sublayers here rather than in build() to avoid
        // linting complaints; their weights are only created in build()
        this.queryProjection = tf.layers.dense({ useBias, units: embedDim });
        this.keyProjection = tf.layers.dense({ useBias, units: embedDim });
        this.valueProjection = tf.layers.dense({ useBias, units: embedDim });
        this.outputProjection = tf.layers.dense({ useBias, units: embedDim });
    }


    /**
     * Forward propagation. Provide one input tensor or three identical tensors to self-attention.
     * @param inputs a single tensor for self-attention or an array of exactly three
     * tensors [query, key, value] that are either identical (self-attention) or different
     * (cross-attention); every tensor must be rank 3, [batch, seq, embed_dim]
     * @param kwargs.packingMask a mask to prevent tokens from attending across document boundaries
     * @param kwargs.causalMask a precomputed causal mask, used instead of generating one when the
     * layer is causal and the query has more than one position
     * @param kwargs.attentionMask an optional mask forwarded to scaledDotProductionAttention
     * @throws Error if the number of inputs is not one or three, or any input is not rank 3
     */
    override call(
        inputs: tf.Tensor | tf.Tensor[],
        kwargs: Kwargs & {
            packingMask?: tf.Tensor,
            causalMask?: tf.Tensor,
        }
    ): tf.Tensor | tf.Tensor[] {
        // validate the input tensors
        if (!Array.isArray(inputs)) {
            inputs = [inputs];
        }

        // accept only 1 input (self attention) or 3 inputs (self or cross attention)
        if (inputs.length != 1 && inputs.length != 3) {
            throw Error(`${this.getClassName()}::call ${this.name} expects exactly one or three input tensors, ${inputs.length} were provided`);
        }

        for (const input of inputs) {
            if (input.shape.length != 3) {
                throw Error(`${this.getClassName()}::call ${this.name} expected input shapes of [batch, seq, embed_dim], got ${JSON.stringify(input.shape)}`);
            }
        }

        const [query, key, value] = inputs;
        const packingMask = kwargs.packingMask ?? null;
        const causalMask = kwargs.causalMask ?? null;

        return inputs.length == 3
            // cross-attention
            ? this.forward(query!, key!, value!, packingMask, causalMask, kwargs)
            // self-attention
            : this.forward(query!, query!, query!, packingMask, causalMask, kwargs);
    }


    /**
     * Forward propagation: project the inputs, split them into heads, apply scaled
     * dot-product attention, merge the heads back and apply the output projection.
     * Returns a tensor of shape [batch, query seq, embedDim].
     */
    protected forward(
        query_input: tf.Tensor,
        key_input: tf.Tensor,
        value_input: tf.Tensor,
        packing_mask: tf.Tensor | null,
        causal_mask: tf.Tensor | null,
        kwargs: Kwargs): tf.Tensor {

        // dimensions abbreviations
        // batch = the number of sequences in the input
        // seq = the length of each sequence in the input
        // dims = the size of each token's embedding
        return tf.tidy(() => {
            const { query, key, value } = this.applyInputProjections(query_input, key_input, value_input);

            // swap the seq and heads dimensions: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim],
            // the permutation is its own inverse so it also undoes the swap after attention
            const move_head_dim_forward = [0, 2, 1, 3];

            const {
                query_split, key_split, value_split
            } = this.splitHeads(query, key, value, move_head_dim_forward);

            // apply scaled dot production attention to get [batch, heads, seq, head_dim]
            const spda = MultiHeadAttention.scaledDotProductionAttention(
                query_split, key_split, value_split,
                kwargs.attentionMask ?? null, packing_mask, causal_mask,
                this.dropout, this.causal, kwargs);

            // concat heads: [batch, heads, seq, head_dim] -> [batch, seq, heads, head_dim] -> [batch, seq, embedDim]
            const concatenated = spda.transpose(move_head_dim_forward).reshape([query_input.shape[0], -1, this.embedDim]);

            // and apply the output projection
            return this.outputProjection.apply(concatenated) as tf.Tensor;
        });
    }


    protected applyInputProjections(query_input: tf.Tensor, key_input: tf.Tensor, value_input: tf.Tensor) {
        // apply input projections, this is a batched matrix multiplication operated on the last
        // dimension of query_input and first dimension of the dense layer weights,
        // [batch, seq, dims] x [dims, dims] = [batch x seq, dims] x [dims, dims] = [batch x seq, dims] = [batch, seq, dims]
        return tf.tidy(() => {
            return {
                query: this.queryProjection.apply(query_input) as tf.Tensor,
                key: this.keyProjection.apply(key_input) as tf.Tensor,
                value: this.valueProjection.apply(value_input) as tf.Tensor
            }
        })
    }


    protected splitHeads(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, shuffle: number[]) {
        // split heads and prepare for scaled dot product attention by splitting the
        // last dimension to get the heads, bring the heads forward
        // [batch, seq, dims] -> [batch, seq, heads, dims / heads] -> [batch, heads, seq, head_dim]
        const batch_size = query.shape[0];
        const split_heads = [batch_size, -1, this.numHeads, this.embedDim / this.numHeads];

        return tf.tidy(() => {
            return {
                query_split: query.reshape(split_heads).transpose(shuffle) as tf.Tensor4D,
                key_split: key.reshape(split_heads).transpose(shuffle) as tf.Tensor4D,
                value_split: value.reshape(split_heads).transpose(shuffle) as tf.Tensor4D
            }
        })
    }


    /**
     * Applies the scaled dot-product formula: softmax(QK_t / sqrt(d_k))V,
     * formula (1) of the 2017 paper Attention Is All You Need
     *
     * @param attentionMask a mask to prevent tokens from being
     * attended to (usually for padding tokens). It should have the shape
     * [batch, head, query_sequence_len, key_sequence_len]. To use in
     * conjunction with causal masking, the tensor should be a boolean type
     * where false indicates a masked token. A non-boolean mask is added as-is.
     * @param packingMask a mask to prevent tokens from attending across document boundaries
     * @param causalMask a precomputed causal mask; when omitted and causal masking applies,
     * one is generated with generateCausalAttentionMask
     * @param dropout dropout rate applied to the attention weights when `kwargs.training` is true
     * @param causal whether to apply causal masking (skipped when the query has a single position)
     * @param kwargs.scaling_factor overrides d_k, which defaults to the key's last dimension
     * @throws Error if key and value shapes differ, or a non-boolean attention mask is combined
     * with causal masking
     */
    static scaledDotProductionAttention(
        query: tf.Tensor,
        key: tf.Tensor,
        value: tf.Tensor,
        attentionMask: tf.Tensor | null,
        packingMask: tf.Tensor | null,
        causalMask: tf.Tensor | null,
        dropout: number,
        causal: boolean,
        kwargs: ScaledDotProductionAttentionKwargs = {}
    ): tf.Tensor {
        return tf.tidy(() => {
            const { training = false, scaling_factor } = kwargs;

            // compare ranks as well, otherwise a value with extra trailing dims would slip through
            const same_shape = key.shape.length == value.shape.length
                && key.shape.every((dim, axis) => dim == value.shape[axis]);

            if (!same_shape) {
                throw Error(`scaledDotProductionAttention: expected key and value` +
                    ` to have the same shape, got ${JSON.stringify(key.shape)} (key) and` +
                    ` ${JSON.stringify(value.shape)} (value)`);
            }

            // mask's shape is [..., seq, seq] where seq is the number of words/tokens in the input,
            // not adding the batch dimension yet to lessen the calculations
            const causal_mask_shape = [
                query.shape[query.shape.length - 2],
                key.shape[key.shape.length - 2]];

            let mask = tf.zeros(causal_mask_shape);

            if (causal && causal_mask_shape[0] > 1) {
                if (attentionMask && attentionMask.dtype != "bool") {
                    throw Error(`scaledDotProductionAttention: the attention mask must be undefined or a boolean type if used with causal attention`);
                }

                // apply a causal attention mask so that tokens can only attend to preceding tokens,
                // prevents looking ahead
                if (causalMask) {
                    mask = causalMask;
                } else {
                    mask = generateCausalAttentionMask(causal_mask_shape[0], causal_mask_shape[1]);
                }
            }

            if (attentionMask) {
                if (attentionMask.dtype == "bool") {
                    // convert the boolean mask to float: true -> 0, false -> -1e7
                    // warning: do not use 1e9, it will overflow, use something smaller like 1e7
                    mask = mask.add(attentionMask.cast("float32").sub(1).mul(1e7));
                } else {
                    // this will occur only when not using causal masking,
                    // if the attention mask is not boolean, it's assumed the masking is already calculated
                    mask = attentionMask;
                }
            }

            // 1. matrix multiply query and transposed key
            // 2. divide by the square root of the scaling factor
            // 3. add the attention and/or causal mask (and the packing mask)
            // 4. apply softmax to the result
            // 5. apply dropout
            // 6. matrix multiply softmax result with value
            let pre_softmax = query
                .matMul(key, false, true)
                .div(Math.sqrt(scaling_factor ?? key.shape[key.shape.length - 1]))
                .add(mask);

            if (packingMask) {
                // packing mask is added separately because each mask within a batch may be different,
                // so it cannot be broadcasted
                pre_softmax = pre_softmax.add(packingMask);
            }

            const spda = tf.softmax(pre_softmax);

            const spda_dropout = tf.dropout(spda, training ? dropout : 0);
            return spda_dropout.matMul(value);
        });
    }


    /**
     * Creates the projection sublayers' weights. Accepts a single shape (self-attention)
     * or a list of one or three shapes [query, key, value].
     * @throws Error if a list of shapes does not hold exactly one or three shapes
     */
    override build(inputShape: tf.Shape | tf.Shape[]): void {
        const input_shape: tf.Shape[] = Array.isArray(inputShape) && Array.isArray(inputShape[0])
            ? inputShape as tf.Shape[]
            : [inputShape as tf.Shape];

        if (input_shape.length != 1 && input_shape.length != 3) {
            throw Error(`${this.getClassName()}::build ${this.name} accepts either exactly one or three inputs, received ${JSON.stringify(inputShape)}`);
        }

        // a single shape means self-attention, so key and value share the query's shape
        const query_shape = input_shape[0]!;
        const key_shape = input_shape[1] ?? query_shape;
        const value_shape = input_shape[2] ?? query_shape;

        // initialize the sublayer weights
        this.queryProjection.build(query_shape);
        this.keyProjection.build(key_shape);
        this.valueProjection.build(value_shape);
        // the output projection consumes the concatenated heads, whose last dimension is embedDim
        this.outputProjection.build([...query_shape.slice(0, -1), this.embedDim]);

        // the sublayer weights need to be tracked by this layer otherwise
        // backpropagation will complain about no trainable parameters found,
        // this is an extra step that TF's Python version does not need
        this.trainableWeights = [
            ...this.queryProjection.trainableWeights,
            ...this.keyProjection.trainableWeights,
            ...this.valueProjection.trainableWeights,
            ...this.outputProjection.trainableWeights
        ];

        // rename the weights otherwise they'll take on the default naming and overlap
        // each other which breaks model loading due to duplicate weight names;
        // the casts bypass the typings to rename the weights in place
        let indexing = 0;

        for (const weight of this.trainableWeights) {
            const unique_name = `${this.getClassName()}_${indexing}`;
            (weight as any).name += unique_name;
            (weight as any).originalName += unique_name;
            indexing++;
        }

        super.build(inputShape);
    }


    /**
     * MultiHead attention's output has the query's shape with the last dimension set to embedDim.
     */
    override computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        const query_shape = (Array.isArray(inputShape) && Array.isArray(inputShape[0])
            ? inputShape[0]
            : inputShape) as tf.Shape;

        return query_shape.length == 0 ? query_shape : [...query_shape.slice(0, -1), this.embedDim];
    }


    override getConfig() {
        const base_config = super.getConfig();

        const config = {
            numHeads: this.numHeads,
            embedDim: this.embedDim,
            useBias: this.useBias,
            causal: this.causal,
            dropout: this.dropout,
            name: this.name,
        }

        Object.assign(config, base_config);

        return config;
    }
}


tf.serialization.registerClass(MultiHeadAttention);
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
|
|
3
|
+
import { PositionalEncoding } from '@/layers/positional_encoding';
|
|
4
|
+
|
|
5
|
+
// disables warning for using the faster node backend,
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
tf.env().set('IS_NODE', false);


describe("PositionalEncoding tests", () => {
    it("should fail to instantiate a layer", () => {
        expect(() => new PositionalEncoding({ maxSequenceLength: 3, embedDim: 0 })).toThrow();
        expect(() => new PositionalEncoding({ maxSequenceLength: 3, embedDim: -1 })).toThrow();
        expect(() => new PositionalEncoding({ maxSequenceLength: 0, embedDim: 32 })).toThrow();
        expect(() => new PositionalEncoding({ maxSequenceLength: -1, embedDim: 32 })).toThrow();
    })


    test("successful forward calls", () => {
        const embed_dims = 32;
        const sequences = 4;
        const input = tf.randomUniform([2, sequences, embed_dims]);

        const positional = new PositionalEncoding({ embedDim: embed_dims });
        expect(() => positional.apply(input)).not.toThrow();
        expect(() => positional.apply([input])).not.toThrow();
        expect(positional.computeOutputShape(input.shape)).toEqual(input.shape);
    })


    it("should throw when input sequences are too large, embedding dims don't match, input aren't rank 3", () => {
        // each input breaks exactly one constraint of a layer with maxSequenceLength 10 and embedDim 32,
        // so every expectation exercises a different validation in call()
        const sequences_too_long = tf.randomUniform([2, 100, 32]);
        const embeddings_too_large = tf.randomUniform([2, 5, 100]);
        const wrong_rank = tf.randomUniform([10, 32]);

        const positional = new PositionalEncoding({ maxSequenceLength: 10, embedDim: 32 });

        expect(() => positional.apply(sequences_too_long)).toThrow();
        expect(() => positional.apply(embeddings_too_large)).toThrow();
        expect(() => positional.apply(wrong_rank)).toThrow();
    })


    it("should return a non-empty config dict", () => {
        const positional = new PositionalEncoding({ embedDim: 32 });
        // compare the key count, comparing the key array itself with toBe(0) can never fail
        expect(Object.keys(positional.getConfig()).length).toBeGreaterThan(0);
    })


    // PyTorch implementation found at
    // https://pytorch-tutorials-preview.netlify.app/beginner/transformer_tutorial.html
    it("should be within 1e-6 of PyTorch's implementation", () => {
        const pytorch_embed4 = tf.tensor([
            [[0.0000000, 1.0000000, 0.0000000, 1.0000000],
            [0.8414710, 0.5403023, 0.0099998, 0.9999500],
            [0.9092974, -0.4161468, 0.0199987, 0.9998000],
            [0.1411200, -0.9899925, 0.0299955, 0.9995500],
            [-0.7568025, -0.6536436, 0.0399893, 0.9992001],
            [-0.9589243, 0.2836622, 0.0499792, 0.9987503],
            [-0.2794155, 0.9601703, 0.0599640, 0.9982005],
            [0.6569866, 0.7539023, 0.0699428, 0.9975510],
            [0.9893582, -0.1455000, 0.0799147, 0.9968017],
            [0.4121185, -0.9111302, 0.0898785, 0.9959527]]]);

        const pytorch_embed8 = tf.tensor([
            [[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.0000000e+00,
                0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.0000000e+00],
            [8.4147096e-01, 5.4030234e-01, 9.9833414e-02, 9.9500418e-01,
                9.9998331e-03, 9.9994999e-01, 9.9999981e-04, 9.9999952e-01],
            [9.0929741e-01, -4.1614684e-01, 1.9866931e-01, 9.8006660e-01,
                1.9998666e-02, 9.9980003e-01, 1.9999985e-03, 9.9999803e-01],
            [1.4112000e-01, -9.8999250e-01, 2.9552019e-01, 9.5533651e-01,
                2.9995499e-02, 9.9955004e-01, 2.9999954e-03, 9.9999553e-01],
            [-7.5680250e-01, -6.5364361e-01, 3.8941833e-01, 9.2106098e-01,
                3.9989334e-02, 9.9920011e-01, 3.9999890e-03, 9.9999201e-01],
            [-9.5892429e-01, 2.8366220e-01, 4.7942552e-01, 8.7758255e-01,
                4.9979165e-02, 9.9875027e-01, 4.9999789e-03, 9.9998754e-01],
            [-2.7941549e-01, 9.6017027e-01, 5.6464243e-01, 8.2533562e-01,
                5.9964005e-02, 9.9820054e-01, 5.9999637e-03, 9.9998200e-01],
            [6.5698659e-01, 7.5390226e-01, 6.4421761e-01, 7.6484221e-01,
                6.9942847e-02, 9.9755102e-01, 6.9999420e-03, 9.9997550e-01],
            [9.8935825e-01, -1.4550003e-01, 7.1735609e-01, 6.9670677e-01,
                7.9914689e-02, 9.9680167e-01, 7.9999138e-03, 9.9996799e-01],
            [4.1211849e-01, -9.1113025e-01, 7.8332686e-01, 6.2160999e-01,
                8.9878544e-02, 9.9595273e-01, 8.9998785e-03, 9.9995953e-01]]]);

        const positional4 = new PositionalEncoding({ embedDim: 4, maxSequenceLength: 10 });
        positional4.build([]);

        const positional8 = new PositionalEncoding({ embedDim: 8, maxSequenceLength: 10 });
        positional8.build([]);

        const margin_of_error = 1e-6;

        // the difference between this and PyTorch's implementation
        // should be insignificantly small
        expect((positional4.getWeights()[0]
            .sub(pytorch_embed4)
            .abs()
            .arraySync() as [])
            .flat(2)
            .filter(i => i > margin_of_error)
            .length).toBe(0);

        expect((positional8.getWeights()[0]
            .sub(pytorch_embed8)
            .abs()
            .arraySync() as [])
            .flat(2)
            .filter(i => i > margin_of_error)
            .length).toBe(0);
    });
});
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
|
|
3
|
+
import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
/**
 * Constructor arguments for the PositionalEncoding layer, on top of the standard layer args.
 */
export interface PositionalEncodingArgs extends LayerArgs {
    // embedding size of each word/token, aka d_model from the paper; must be a positive integer
    embedDim: number;
    // the max number of tokens per input, default 5120; inputs longer than this are rejected
    // with an error (not truncated) and shorter inputs are not padded, they simply use the
    // first rows of the encoding table
    maxSequenceLength?: number;
}
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
/**
 * This class implements the position encoding logic described in the
 * 2017 paper "Attention Is All You Need".
 *
 * This layer is untrainable and accepts inputs of shape `[ batch, sequences, embedding dims ]`
 * and adds positional encoding to return an output tensor of the same shape.
 *
 * @param embedDim the size of each token/word's embedding, a positive integer
 * @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`;
 *   longer inputs are rejected
 */
export class PositionalEncoding extends tf.layers.Layer {
    static className = "PositionalEncoding";
    private readonly maxSequenceLength: number;
    private readonly embedDim: number;
    private readonly positionalEncodings: tf.LayerVariable;


    /**
     * @throws Error if maxSequenceLength or embedDim is not a positive integer
     */
    constructor(args: PositionalEncodingArgs) {
        super(args);

        this.maxSequenceLength = args.maxSequenceLength ?? 5120;
        this.embedDim = args.embedDim;

        // Number.isInteger also rejects NaN and fractional sizes, which `< 1` alone lets through
        if (!Number.isInteger(this.maxSequenceLength) || this.maxSequenceLength < 1) {
            throw Error(`${this.getClassName()}::constructor ${this.name} maxSequenceLength` +
                ` (${args.maxSequenceLength}) must be a positive integer`);
        }

        if (!Number.isInteger(this.embedDim) || this.embedDim < 1) {
            throw Error(`${this.getClassName()}::constructor ${this.name} embedDim` +
                ` (${args.embedDim}) must be a positive integer`);
        }

        // positional encodings are not trainable, the zeros are replaced in build()
        this.positionalEncodings = this.addWeight('positional_encodings',
            [this.maxSequenceLength, this.embedDim], "float32",
            tf.initializers.zeros(), undefined, false);
    }


    /**
     * Forward propagation. Injects positional encoding to the input embeddings
     * @throws Error if the input is not [batch, seq, embedDim] or seq exceeds maxSequenceLength
     */
    override call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[] {
        // validate the input tensors
        const input = Array.isArray(inputs) ? inputs[0] : inputs;

        if (input.shape.length != 3 || input.shape[2] != this.embedDim) {
            throw Error(`${this.getClassName()}::call ${this.name} expected an input shape of` +
                ` [batch, (up to ${this.maxSequenceLength}), ${this.embedDim}], instead got ${input.shape}`);
        }

        const sequences = input.shape[1]!;

        if (sequences > this.maxSequenceLength) {
            // unexpected sequence length
            throw Error(`${this.getClassName()}::call ${this.name} received an input with` +
                ` sequence length (${sequences}) which is greater than the max sequence length` +
                ` ${this.maxSequenceLength}`);
        }

        // perform forward propagation
        return tf.tidy(() => {
            return input.add(this.positionalEncodings.read()
                .slice([0, 0], [sequences, this.embedDim]) // gets the first "sequences" rows
                .expandDims(0)); // introduce the batch dimension and let add() broadcast it
        })
    }

    /**
     * Generate the positional encoding from the paper Attention Is All You Need.
     * Note that because the inner term of the position formula is the same for both even
     * and odd indices, we only create half of it and apply sine and cosine individually.
     */
    override build(inputShape: tf.Shape | tf.Shape[]): void {
        tf.tidy(() => {
            // the position of each token as a column vector [[0], [1], [2], etc]
            const positions = tf.range(0, this.maxSequenceLength, 1)
                .reshape([this.maxSequenceLength, 1]);

            // the denominator formula is 10_000^(2i/d_model) where each "i" is a sine cosine pair,
            // so its exponent's numerator is [0, 2, 4, etc] with one entry per pair,
            // i.e. ceil(embedDim / 2) entries
            const denominator = tf.pow(10_000, tf.range(0, this.embedDim, 2).div(this.embedDim));

            // broadcasting gives [maxSequenceLength, ceil(embedDim / 2)]
            const inner_term = positions.div(denominator);

            // horizontally interweave the sine and cosine columns together to form
            // [sin, cos, sin, cos, etc] on every row: stacking on a new last axis gives
            // [maxSequenceLength, ceil(embedDim / 2), 2] and a row-major reshape flattens each pair in order
            const paired_columns = 2 * Math.ceil(this.embedDim / 2);
            const interweaved = tf.stack([tf.sin(inner_term), tf.cos(inner_term)], 2)
                .reshape([this.maxSequenceLength, paired_columns])
                // for odd numbered embedDim sizes drop the last cosine column,
                // e.g. if embedDim = 5, keep [ i=0 (sin), i=0 (cos), i=1 (sin), i=1 (cos), i=2 (sin) ]
                .slice([0, 0], [this.maxSequenceLength, this.embedDim]);

            // add the positional encoding
            this.setWeights([interweaved]);
        });

        super.build(inputShape);
    }


    override computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        return inputShape;
    }


    override getConfig(): tf.serialization.ConfigDict {
        const base_config = super.getConfig();

        const config = {
            maxSequenceLength: this.maxSequenceLength,
            embedDim: this.embedDim,
        }

        Object.assign(config, base_config);

        return config;
    }
}


tf.serialization.registerClass(PositionalEncoding);
|