@stellarapp/tfjs-stellar 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,371 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
3
+ import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
4
+ import { generateCausalAttentionMask } from "@/utils";
5
+
6
+
7
/**
 * Construction options for the MultiHeadAttention layer; any remaining
 * entries are forwarded to the base tfjs layer as LayerArgs.
 */
export interface MultiHeadAttentionArgs extends LayerArgs {
  /** number of attention heads; the layer rejects values that do not evenly divide `embedDim` */
  numHeads: number;
  /** width (units) of the query/key/value/output dense projections */
  embedDim: number;
  /** whether the dense projection sublayers carry a bias term (layer default: `true`) */
  useBias?: boolean;
  /** dropout rate applied to the attention weights (layer default: `0.0`) */
  dropout?: number;
  /** whether to apply causal masking (layer default: `false`) */
  causal?: boolean;
}
14
+
15
+
16
/**
 * Optional settings accepted by `MultiHeadAttention.scaledDotProductionAttention`.
 */
export interface ScaledDotProductionAttentionKwargs {
  /** when `true`, dropout is applied to the attention weights; defaults to `false` */
  training?: boolean;
  /** declared for convenience; the static method takes its dropout rate as an explicit argument instead */
  dropout?: number;
  /** declared for convenience; the static method takes its causal flag as an explicit argument instead */
  causal?: boolean;
  /** value whose square root divides QK^T; defaults to the key's last dimension */
  scaling_factor?: number;
}
22
+
23
+
24
/**
 * This MultiHead Attention layer implements the algorithm as described in
 * the paper "Attention is all you Need" Vaswani et al., 2017.
 *
 * @param numHeads number of attention heads to use, must be a positive divisor of `embedDim`
 * @param embedDim the size queries, keys and values are projected to (typically the input
 * embedding size, i.e. the last input dimension); also the last dimension of the output
 * @param causal use causal masking, default `false`
 * @param dropout dropout rate applied to the attention weights while training, within [0, 1), default `0.0`
 * @param useBias use bias for the dense sublayers, default `true`
 *
 * The TensorFlow version uses tf.einsum, whose gradient op has not yet been
 * implemented (https://github.com/tensorflow/tfjs/pull/4955#discussion_r619219334),
 * therefore we follow the PyTorch implementation described in:
 * https://docs.pytorch.org/tutorials/intermediate/transformer_building_blocks.html#multiheadattention
 * https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
 *
 * This implementation is different from TensorFlow's whose attention weights
 * are shaped [embed dim, heads, embed dim] where as PyTorch and OpenAI's attention weights
 * are shaped [embed dim, embed dim]
 * https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/nn/modules/activation.py#L1080
 * https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/model.py#L53
 *
 * TODO: implement a fast track for self attention (query = key = value)
 * where a single dense layer combines and replaces the query, key and projection layers
 *
 * TODO: add kDim and vDim to accept key and values whose embedding dimensions differ from query's.
 */
export class MultiHeadAttention extends tf.layers.Layer {
  static className = "MultiHeadAttention";
  protected readonly numHeads: number;
  // size queries, keys and values are projected to; each head works on embedDim / numHeads of it
  protected readonly embedDim: number;
  protected readonly useBias: boolean;
  protected readonly dropout: number;
  protected readonly causal: boolean; // use causal attention to mask future words

  // projection simply means matrix multiplying query, key, and value
  // with weights to create a representation of the inputs
  protected readonly queryProjection: tf.layers.Layer;
  protected readonly keyProjection: tf.layers.Layer;
  protected readonly valueProjection: tf.layers.Layer;
  protected readonly outputProjection: tf.layers.Layer;


  /**
   * @throws Error when `numHeads` is not positive, `embedDim` is not divisible
   * by `numHeads`, or `dropout` lies outside [0, 1)
   */
  constructor({ numHeads, embedDim, useBias = true, dropout = 0.0, causal = false, ...args }: MultiHeadAttentionArgs) {
    super(args);

    // negated comparison so NaN is rejected too
    if (!(numHeads >= 1)) {
      throw Error(`${this.getClassName()}::constructor ${this.name} numHeads (${numHeads}) must be greater than 0`);
    }

    if (embedDim % numHeads !== 0) {
      throw Error(`${this.getClassName()}::constructor ${this.name} embedDim (${embedDim}) is not divisible by numHeads (${numHeads})`);
    }

    // negated range test so negative rates and NaN are rejected, not only rates >= 1
    if (!(dropout >= 0 && dropout < 1)) {
      throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
    }

    this.numHeads = numHeads;
    this.embedDim = embedDim;
    this.useBias = useBias;
    this.dropout = dropout;
    this.causal = causal;

    // initialize the projection sublayers, this should be in the
    // build() function but is done here to avoid linting complaints
    this.queryProjection = tf.layers.dense({ useBias, units: embedDim });
    this.keyProjection = tf.layers.dense({ useBias, units: embedDim });
    this.valueProjection = tf.layers.dense({ useBias, units: embedDim });
    this.outputProjection = tf.layers.dense({ useBias, units: embedDim });
  }


  /**
   * Forward propagation. Provide one input tensor or three identical tensors to self-attention.
   * @param inputs a single tensor for self-attention or an array of exactly three
   * tensors that are either identical (self-attention) or different (cross-attention),
   * each of rank 3 ([batch, seq, dims])
   * @param kwargs.packingMask a mask to prevent tokens from attending across document boundaries
   * @param kwargs.causalMask a precomputed causal mask used instead of generating one (only when `causal` is set)
   * @throws Error when the number of inputs is not 1 or 3, or an input is not rank 3
   */
  override call(
    inputs: tf.Tensor | tf.Tensor[],
    kwargs: Kwargs & {
      packingMask?: tf.Tensor,
      causalMask?: tf.Tensor,
    }
  ): tf.Tensor | tf.Tensor[] {
    const tensors = Array.isArray(inputs) ? inputs : [inputs];

    // accept only 1 input (self attention) or 3 inputs (self or cross attention)
    if (tensors.length !== 1 && tensors.length !== 3) {
      throw Error(`${this.getClassName()}::call ${this.name} expects exactly one or three input tensors, ${tensors.length} were provided`);
    }

    for (const input of tensors) {
      if (input.shape.length !== 3) {
        throw Error(`${this.getClassName()}::call ${this.name} expected input shapes of [batch, seq, embed_dim], got ${JSON.stringify(input.shape)}`);
      }
    }

    const packingMask = kwargs.packingMask ?? null;
    const causalMask = kwargs.causalMask ?? null;

    // with a single input (self-attention) the query doubles as key and value
    const [query, key = query, value = query] = tensors;

    return this.forward(query!, key!, value!, packingMask, causalMask, kwargs);
  }


  /**
   * Forward propagation: project the inputs, split them into heads, run scaled
   * dot-product attention, merge the heads and apply the output projection.
   * Returns a tensor of shape [batch, query seq, embedDim].
   */
  protected forward(
    query_input: tf.Tensor,
    key_input: tf.Tensor,
    value_input: tf.Tensor,
    packing_mask: tf.Tensor | null,
    causal_mask: tf.Tensor | null,
    kwargs: Kwargs): tf.Tensor {

    // dimensions abbreviations
    // batch = the number of sequences in the input
    // seq = the length of each sequence in the input
    // dims = the size of each token's embedding
    return tf.tidy(() => {
      const { query, key, value } = this.applyInputProjections(query_input, key_input, value_input);

      // swap the seq and heads dimensions: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
      // (the permutation is its own inverse, so it is reused to undo the swap below)
      const move_head_dim_forward = [0, 2, 1, 3];

      const {
        query_split, key_split, value_split
      } = this.splitHeads(query, key, value, move_head_dim_forward);

      // apply scaled dot production attention to get [batch, heads, seq, head_dim]
      const spda = MultiHeadAttention.scaledDotProductionAttention(
        query_split, key_split, value_split,
        kwargs.attentionMask ?? null, packing_mask, causal_mask,
        this.dropout, this.causal, kwargs);

      // [batch, heads, seq, head_dim] -> [batch, seq, heads, head_dim] -> [batch, seq, embedDim],
      // i.e. concatenate the heads, then apply the output projection
      const output = this.outputProjection.apply(
        spda.transpose(move_head_dim_forward).reshape([query_input.shape[0], -1, this.embedDim]));

      return output as tf.Tensor;
    });
  }


  protected applyInputProjections(query_input: tf.Tensor, key_input: tf.Tensor, value_input: tf.Tensor) {
    // apply input projections, this is a batched matrix multiplication operated on the last
    // dimension of the input and the first dimension of the dense layer weights,
    // [batch, seq, dims] x [dims, embedDim] = [batch, seq, embedDim]
    return tf.tidy(() => {
      return {
        query: this.queryProjection.apply(query_input) as tf.Tensor,
        key: this.keyProjection.apply(key_input) as tf.Tensor,
        value: this.valueProjection.apply(value_input) as tf.Tensor
      };
    });
  }


  protected splitHeads(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, shuffle: number[]) {
    // split heads and prepare for scaled dot product attention by splitting the
    // last dimension to get the heads, then permuting the axes with `shuffle`
    // [batch, seq, embedDim] -> [batch, seq, heads, embedDim / heads] -> [batch, heads, seq, head_dim]
    const batch_size = query.shape[0];
    const split_heads = [batch_size, -1, this.numHeads, this.embedDim / this.numHeads];

    return tf.tidy(() => {
      return {
        query_split: query.reshape(split_heads).transpose(shuffle) as tf.Tensor4D,
        key_split: key.reshape(split_heads).transpose(shuffle) as tf.Tensor4D,
        value_split: value.reshape(split_heads).transpose(shuffle) as tf.Tensor4D
      };
    });
  }


  /**
   * Applies the scaled dot-product formula: softmax(QK_t / sqrt(d_k))V,
   * formula (1) of the 2017 paper Attention Is All You Need
   *
   * @param attentionMask a mask to prevent tokens from being
   * attended to (usually for padding tokens). It is added to the attention
   * scores, so it should be broadcastable to
   * [batch, head, query_sequence_len, key_sequence_len]. To use in
   * conjunction with causal masking, the tensor should be a boolean type
   * where false indicates a masked token.
   * @param packingMask a mask to prevent tokens from attending across document boundaries
   * @param causalMask a precomputed causal mask, used instead of generating one when `causal` is set
   * @param dropout dropout rate applied to the attention weights when `kwargs.training` is true
   * @param causal whether to apply causal masking (skipped for a single query position)
   * @throws Error when key and value shapes differ, or a non-boolean attention mask is used with causal masking
   */
  static scaledDotProductionAttention(
    query: tf.Tensor,
    key: tf.Tensor,
    value: tf.Tensor,
    attentionMask: tf.Tensor | null,
    packingMask: tf.Tensor | null,
    causalMask: tf.Tensor | null,
    dropout: number,
    causal: boolean,
    kwargs: ScaledDotProductionAttentionKwargs = {}
  ): tf.Tensor {
    return tf.tidy(() => {
      const { training = false, scaling_factor } = kwargs;

      // compare the rank too, otherwise a value with extra trailing axes slips through
      const same_shape = key.shape.length === value.shape.length &&
        key.shape.every((dim, axis) => dim === value.shape[axis]);

      if (!same_shape) {
        throw Error(`scaledDotProductionAttention: expected key and value` +
          ` to have the same shape, got ${JSON.stringify(key.shape)} (key) and` +
          ` ${JSON.stringify(value.shape)} (value)`);
      }

      // mask's shape is [query seq, key seq] where seq is the number of words/tokens in the input,
      // not adding the batch dimension yet to lessen the calculations (add() broadcasts it)
      const causal_mask_shape = [
        query.shape[query.shape.length - 2],
        key.shape[key.shape.length - 2]];

      let mask = tf.zeros(causal_mask_shape);

      if (causal && causal_mask_shape[0] > 1) {
        if (attentionMask && attentionMask.dtype !== "bool") {
          throw Error(`scaledDotProductionAttention: the attention mask must be undefined or a boolean type if used with causal attention`);
        }

        // apply a causal attention mask so that tokens can only attend to preceding tokens,
        // prevents looking ahead
        mask = causalMask ?? generateCausalAttentionMask(causal_mask_shape[0], causal_mask_shape[1]);
      }

      if (attentionMask) {
        if (attentionMask.dtype === "bool") {
          // convert the boolean mask to float: true -> 0, false -> -1e7
          // warning: do not use 1e9, it will overflow, use something smaller like 1e7
          mask = mask.add(attentionMask.cast("float32").sub(1).mul(1e7));
        } else {
          // this will occur only when not using causal masking,
          // if the attention mask is not boolean, it's assumed the masking is already calculated,
          mask = attentionMask;
        }
      }

      // 1. matrix multiply query and transposed key
      // 2. divide by the square root of the scaling factor (key's last dimension by default)
      // 3. add the attention and/or causal mask (and the packing mask, if any)
      // 4. apply softmax to the result
      // 5. apply dropout (training only)
      // 6. matrix multiply softmax result with value
      let pre_softmax = query
        .matMul(key, false, true)
        .div(Math.sqrt(scaling_factor ?? key.shape[key.shape.length - 1]))
        .add(mask);

      if (packingMask) {
        // packing mask is added separately because each mask within a batch may be different
        pre_softmax = pre_softmax.add(packingMask);
      }

      const spda = tf.softmax(pre_softmax);

      const spda_dropout = tf.dropout(spda, training ? dropout : 0);
      const attention = spda_dropout.matMul(value);

      return attention;
    });
  }


  /**
   * Builds the four dense sublayers. Accepts either a single shape (self-attention)
   * or a list of one or three shapes ([query, key, value]).
   * @throws Error when a list with other than one or three shapes is given
   */
  override build(inputShape: tf.Shape | tf.Shape[]): void {
    const input_shape: tf.Shape[] = Array.isArray(inputShape) && Array.isArray(inputShape[0])
      ? inputShape as tf.Shape[]
      : [inputShape as tf.Shape];

    if (input_shape.length !== 1 && input_shape.length !== 3) {
      throw Error(`${this.getClassName()}::build ${this.name} accepts either exactly one or three inputs, received ${JSON.stringify(inputShape)}`);
    }

    // a single shape means self-attention: query, key and value all share it
    const query_shape = input_shape[0]!;
    const key_shape = input_shape[1] ?? query_shape;
    const value_shape = input_shape[2] ?? query_shape;

    // initialize the sublayer weights
    this.queryProjection.build(query_shape);
    this.keyProjection.build(key_shape);
    this.valueProjection.build(value_shape);
    // the output projection consumes the concatenated heads, whose last dimension is always embedDim
    this.outputProjection.build([...query_shape.slice(0, -1), this.embedDim]);

    // the sublayer weights need to be tracked by this layer otherwise
    // backpropagation will complain about no trainable parameters found,
    // this is an extra step that TF's Python version does not need
    this.trainableWeights = [
      ...this.queryProjection.trainableWeights,
      ...this.keyProjection.trainableWeights,
      ...this.valueProjection.trainableWeights,
      ...this.outputProjection.trainableWeights
    ];

    // rename the weights otherwise they'll take on the default naming and overlap
    // each other which breaks model loading due to duplicate weight names
    this.trainableWeights.forEach((weight, index) => {
      const unique_name = `${this.getClassName()}_${index}`;
      // name/originalName are readonly on LayerVariable, hence the cast
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      (weight as any).name += unique_name;
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      (weight as any).originalName += unique_name;
    });

    super.build(inputShape);
  }


  /**
   * MultiHead attention's output has the query's shape, with the last
   * dimension set to `embedDim` (the output projection's size).
   */
  override computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
    const query_shape = (Array.isArray(inputShape) && Array.isArray(inputShape[0])
      ? inputShape[0]
      : inputShape) as tf.Shape;

    if (query_shape.length === 0) {
      return query_shape;
    }

    return [...query_shape.slice(0, -1), this.embedDim];
  }


  override getConfig(): tf.serialization.ConfigDict {
    const config = {
      numHeads: this.numHeads,
      embedDim: this.embedDim,
      useBias: this.useBias,
      causal: this.causal,
      dropout: this.dropout,
      name: this.name,
    };

    // entries from the base layer config override same-named entries above
    return { ...config, ...super.getConfig() };
  }
}


tf.serialization.registerClass(MultiHeadAttention);
@@ -0,0 +1,113 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+
3
+ import { PositionalEncoding } from '@/layers/positional_encoding';
4
+
5
// disables warning for using the faster node backend,
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
tf.env().set('IS_NODE', false);


describe("PositionalEncoding tests", () => {
  it("should fail to instantiate a layer", () => {
    expect(() => new PositionalEncoding({ maxSequenceLength: 3, embedDim: 0 })).toThrow();
    expect(() => new PositionalEncoding({ maxSequenceLength: 3, embedDim: -1 })).toThrow();
    expect(() => new PositionalEncoding({ maxSequenceLength: 0, embedDim: 32 })).toThrow();
    expect(() => new PositionalEncoding({ maxSequenceLength: -1, embedDim: 32 })).toThrow();
  });


  it("successful forward calls", () => {
    const embed_dims = 32;
    const sequences = 4;
    const input = tf.randomUniform([2, sequences, embed_dims]);

    const positional = new PositionalEncoding({ embedDim: embed_dims });
    expect(() => positional.apply(input)).not.toThrow();
    expect(() => positional.apply([input])).not.toThrow();
    expect(positional.computeOutputShape(input.shape)).toEqual(input.shape);
  });


  it("should throw when input sequences are too large, embedding dims don't match, input aren't rank 3", () => {
    // each input is rank 3 unless stated otherwise, so it fails for exactly one reason
    const sequences_too_long = tf.randomUniform([1, 100, 32]); // 100 > maxSequenceLength (10)
    const embeddings_too_large = tf.randomUniform([1, 5, 100]); // 100 != embedDim (32)
    const wrong_rank = tf.randomUniform([10, 32]); // rank 2

    const positional = new PositionalEncoding({ maxSequenceLength: 10, embedDim: 32 });

    expect(() => positional.apply(sequences_too_long)).toThrow();
    expect(() => positional.apply(embeddings_too_large)).toThrow();
    expect(() => positional.apply(wrong_rank)).toThrow();
  });


  it("should return a non-empty config dict", () => {
    const positional = new PositionalEncoding({ embedDim: 32 });
    expect(Object.keys(positional.getConfig()).length).toBeGreaterThan(0);
  });


  // PyTorch implementation at found at
  // https://pytorch-tutorials-preview.netlify.app/beginner/transformer_tutorial.html
  it("should be within 1e-6 of PyTorch's implementation", () => {
    const pytorch_embed4 = tf.tensor([
      [[0.0000000, 1.0000000, 0.0000000, 1.0000000],
      [0.8414710, 0.5403023, 0.0099998, 0.9999500],
      [0.9092974, -0.4161468, 0.0199987, 0.9998000],
      [0.1411200, -0.9899925, 0.0299955, 0.9995500],
      [-0.7568025, -0.6536436, 0.0399893, 0.9992001],
      [-0.9589243, 0.2836622, 0.0499792, 0.9987503],
      [-0.2794155, 0.9601703, 0.0599640, 0.9982005],
      [0.6569866, 0.7539023, 0.0699428, 0.9975510],
      [0.9893582, -0.1455000, 0.0799147, 0.9968017],
      [0.4121185, -0.9111302, 0.0898785, 0.9959527]]]);

    const pytorch_embed8 = tf.tensor([
      [[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.0000000e+00,
        0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.0000000e+00],
      [8.4147096e-01, 5.4030234e-01, 9.9833414e-02, 9.9500418e-01,
        9.9998331e-03, 9.9994999e-01, 9.9999981e-04, 9.9999952e-01],
      [9.0929741e-01, -4.1614684e-01, 1.9866931e-01, 9.8006660e-01,
        1.9998666e-02, 9.9980003e-01, 1.9999985e-03, 9.9999803e-01],
      [1.4112000e-01, -9.8999250e-01, 2.9552019e-01, 9.5533651e-01,
        2.9995499e-02, 9.9955004e-01, 2.9999954e-03, 9.9999553e-01],
      [-7.5680250e-01, -6.5364361e-01, 3.8941833e-01, 9.2106098e-01,
        3.9989334e-02, 9.9920011e-01, 3.9999890e-03, 9.9999201e-01],
      [-9.5892429e-01, 2.8366220e-01, 4.7942552e-01, 8.7758255e-01,
        4.9979165e-02, 9.9875027e-01, 4.9999789e-03, 9.9998754e-01],
      [-2.7941549e-01, 9.6017027e-01, 5.6464243e-01, 8.2533562e-01,
        5.9964005e-02, 9.9820054e-01, 5.9999637e-03, 9.9998200e-01],
      [6.5698659e-01, 7.5390226e-01, 6.4421761e-01, 7.6484221e-01,
        6.9942847e-02, 9.9755102e-01, 6.9999420e-03, 9.9997550e-01],
      [9.8935825e-01, -1.4550003e-01, 7.1735609e-01, 6.9670677e-01,
        7.9914689e-02, 9.9680167e-01, 7.9999138e-03, 9.9996799e-01],
      [4.1211849e-01, -9.1113025e-01, 7.8332686e-01, 6.2160999e-01,
        8.9878544e-02, 9.9595273e-01, 8.9998785e-03, 9.9995953e-01]]]);

    const positional4 = new PositionalEncoding({ embedDim: 4, maxSequenceLength: 10 });
    positional4.build([]);

    const positional8 = new PositionalEncoding({ embedDim: 8, maxSequenceLength: 10 });
    positional8.build([]);

    const margin_of_error = 1e-6;

    // counts the entries differing by more than the margin; written as !(diff <= margin)
    // so that NaN differences are counted as mismatches instead of silently passing
    const countMismatches = (ours: tf.Tensor, reference: tf.Tensor): number =>
      Array.from(ours.sub(reference).abs().dataSync())
        .filter(diff => !(diff <= margin_of_error))
        .length;

    // the difference between this and PyTorch's implementation
    // should be insignificantly small
    expect(countMismatches(positional4.getWeights()[0]!, pytorch_embed4)).toBe(0);
    expect(countMismatches(positional8.getWeights()[0]!, pytorch_embed8)).toBe(0);
  });
});
@@ -0,0 +1,158 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
3
+ import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
4
+
5
+
6
/**
 * Construction options for the PositionalEncoding layer; any remaining
 * entries are forwarded to the base tfjs layer as LayerArgs.
 */
export interface PositionalEncodingArgs extends LayerArgs {
  /** embedding size of each word/token, aka d_model from the paper; must be greater than 0 */
  embedDim: number;
  /**
   * the max length of each sentence (defaults to `5120` in PositionalEncoding), must be
   * greater than 0; inputs with a longer sequence are rejected with an error, not truncated
   * or padded
   */
  maxSequenceLength?: number;
}
12
+
13
+
14
/**
 * This class implements the position encoding logic described in the
 * 2017 paper "Attention Is All You Need".
 *
 * This layer is untrainable and accepts inputs of shape `[ batch, sequences, embedding dims ]`
 * and adds positional encoding to return an output tensor of the same shape.
 *
 * @param embedDim the size of each token/word's embedding, must be greater than 0
 * @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`;
 * longer inputs are rejected in `call`
 */
export class PositionalEncoding extends tf.layers.Layer {
  static className = "PositionalEncoding";
  private readonly maxSequenceLength: number;
  private readonly embedDim: number;
  // [maxSequenceLength, embedDim] table, filled in build()
  private readonly positionalEncodings: tf.LayerVariable;


  /**
   * @throws Error when `maxSequenceLength` or `embedDim` is less than 1
   */
  constructor(args: PositionalEncodingArgs) {
    super(args);

    this.maxSequenceLength = args.maxSequenceLength ?? 5120;
    this.embedDim = args.embedDim;

    if (this.maxSequenceLength < 1) {
      throw Error(`${this.getClassName()}::constructor ${this.name} maxSequenceLength` +
        ` (${args.maxSequenceLength}) must be greater than 0`);
    }

    if (this.embedDim < 1) {
      throw Error(`${this.getClassName()}::constructor ${this.name} embedDim` +
        ` (${args.embedDim}) must be greater than 0`);
    }

    // positional encodings are not trainable
    this.positionalEncodings = this.addWeight('positional_encodings',
      [this.maxSequenceLength, this.embedDim], "float32",
      tf.initializers.zeros(), undefined, false);
  }


  /**
   * Forward propagation. Injects positional encoding to the input embeddings.
   * When given an array, only its first tensor is used.
   * @throws Error when no tensor is given, the input is not [batch, seq, embedDim],
   * or seq exceeds maxSequenceLength
   */
  override call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[] {
    // validate the input tensors
    const input = Array.isArray(inputs) ? inputs[0] : inputs;

    if (input === undefined) {
      throw Error(`${this.getClassName()}::call ${this.name} expected an input tensor, received none`);
    }

    if (input.shape.length != 3 || input.shape[2] != this.embedDim) {
      throw Error(`${this.getClassName()}::call ${this.name} expected an input shape of` +
        ` [batch, (up to ${this.maxSequenceLength}), ${this.embedDim}], instead got ${input.shape}`);
    }

    // rank 3 was verified above, so the sequence axis exists
    const sequences = input.shape[1]!;

    if (sequences > this.maxSequenceLength) {
      // unexpected sequence length
      throw Error(`${this.getClassName()}::call ${this.name} received an input with` +
        ` sequence length (${sequences}) which is greater than the max sequence length` +
        ` ${this.maxSequenceLength}`);
    }

    // perform forward propagation
    return tf.tidy(() => {
      return input.add(this.positionalEncodings.read()
        .slice([0, 0], [sequences, this.embedDim]) // gets the first "sequences" rows
        .expandDims(0)); // introduce the batch dimension and let add() broadcast it
    });
  }

  /**
   * Generate the positional encoding from the paper Attention Is All You Need.
   * Note that because the inner term of the position formula is the same for both even
   * and odd indices, we only create half of it and apply sine and cosine individually.
   * The input shape is not used: the table depends only on maxSequenceLength and embedDim.
   */
  override build(inputShape: tf.Shape | tf.Shape[]): void {
    tf.tidy(() => {
      const embedDimHalved = Math.ceil(this.embedDim / 2);

      // create the position column [ 0, 1, 2, 3, etc ]^T, shaped [maxSequenceLength, 1];
      // div() below broadcasts it across the embedDimHalved columns
      const positions = tf.range(0, this.maxSequenceLength, 1)
        .reshape([this.maxSequenceLength, 1]);

      // the inner term's denominator's exponent's numerator is created as
      // [ 0, 0, 2, 2, 4, 4, etc ] ( technically [0, 2, 4] as explained above ) and not
      // [ 0, 2, 4, 6, 8, 10, etc ] because the even and odd indices are counted as pairs
      // when incrementing "i",
      // the denominator formula is 10_000^(2i/d_model) where each "i" is a sine cosine pair
      const denominator = tf.pow(10_000, tf.range(0, this.embedDim, 2).div(this.embedDim));

      // [maxSequenceLength, embedDimHalved]
      const innerTerm = positions.div(denominator);

      // interweave the sine and cosine columns: stacking on a new last axis gives
      // [maxSequenceLength, embedDimHalved, 2] and flattening the last two axes yields
      // [sin, cos, sin, cos, etc] on every row. For odd embedDim the final cosine column
      // is unused, e.g. embedDim = 5 keeps [ i=0 (sin), i=0 (cos), i=1 (sin), i=1 (cos), i=2 (sin) ],
      // so the slice keeps only the first embedDim columns
      const encodings = tf.stack([tf.sin(innerTerm), tf.cos(innerTerm)], 2)
        .reshape([this.maxSequenceLength, 2 * embedDimHalved])
        .slice([0, 0], [this.maxSequenceLength, this.embedDim]);

      // add the positional encoding
      this.setWeights([encodings]);
    });

    super.build(inputShape);
  }


  override computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
    return inputShape;
  }


  override getConfig(): tf.serialization.ConfigDict {
    const config = {
      maxSequenceLength: this.maxSequenceLength,
      embedDim: this.embedDim,
    };

    // entries from the base layer config override same-named entries above
    return { ...config, ...super.getConfig() };
  }
}


tf.serialization.registerClass(PositionalEncoding);