@stellarapp/tfjs-stellar 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,107 @@
1
+ import { RotaryPositionEmbedding } from "@/layers/rotary_position_embedding";
2
+ import * as tf from "@tensorflow/tfjs";
3
+
4
// disables warning for using the faster node backend,
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
tf.env().set('IS_NODE', false);


describe("RotaryPositionEmbedding tests", () => {
    /**
     * Largest absolute element-wise difference between `actual` and `expected`.
     *
     * Used instead of `actual.sub(expected).sum()`: a signed sum lets positive and
     * negative errors cancel out, and any negative sum passes a `<= tolerance`
     * check, so it cannot detect a wrong cache or a wrong rotation.
     */
    async function maxAbsDiff(actual: tf.Tensor, expected: tf.Tensor): Promise<number> {
        const diff = tf.tidy(() => tf.sub(actual, expected).abs().max());
        const value = await diff.array() as number;
        diff.dispose();
        return value;
    }

    // float32 round-off in pow/cos/sin/mul is far below this for O(1) values
    const TOLERANCE = 1e-5;

    test("create cache", async () => {
        const rope = new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 15 });
        rope.build([]);

        const expected_cosine_cache = tf.tensor([[[
            [1, 1, 1, 1, 1, 1, 1, 1],
            [0.5403022766113281, 0.5403022766113281, 0.9950041770935059, 0.9950041770935059, 0.9999499917030334, 0.9999499917030334, 0.9999995231628418, 0.9999995231628418],
            [-0.416146844625473, -0.416146844625473, 0.9800665974617004, 0.9800665974617004, 0.9998000264167786, 0.9998000264167786, 0.9999979734420776, 0.9999979734420776],
            [-0.9899924993515015, -0.9899924993515015, 0.9553365111351013, 0.9553365111351013, 0.9995500445365906, 0.9995500445365906, 0.9999955296516418, 0.9999955296516418],
            [-0.6536436080932617, -0.6536436080932617, 0.9210609793663025, 0.9210609793663025, 0.9992001056671143, 0.9992001056671143, 0.9999920129776001, 0.9999920129776001],
            [0.28366219997406006, 0.28366219997406006, 0.8775825500488281, 0.8775825500488281, 0.9987502694129944, 0.9987502694129944, 0.9999874830245972, 0.9999874830245972],
            [0.9601702690124512, 0.9601702690124512, 0.8253356218338013, 0.8253356218338013, 0.998200535774231, 0.998200535774231, 0.9999819993972778, 0.9999819993972778],
            [0.7539022564888, 0.7539022564888, 0.7648422122001648, 0.7648422122001648, 0.9975510239601135, 0.9975510239601135, 0.9999755024909973, 0.9999755024909973],
            [-0.1455000340938568, -0.1455000340938568, 0.6967067122459412, 0.6967067122459412, 0.9968017339706421, 0.9968017339706421, 0.9999679923057556, 0.9999679923057556],
            [-0.9111302495002747, -0.9111302495002747, 0.6216099262237549, 0.6216099262237549, 0.9959527254104614, 0.9959527254104614, 0.9999595284461975, 0.9999595284461975],
            [-0.83907151222229, -0.83907151222229, 0.5403022766113281, 0.5403022766113281, 0.9950041770935059, 0.9950041770935059, 0.9999499917030334, 0.9999499917030334],
            [0.004425697959959507, 0.004425697959959507, 0.4535960853099823, 0.4535960853099823, 0.9939560890197754, 0.9939560890197754, 0.999939501285553, 0.999939501285553],
            [0.8438539505004883, 0.8438539505004883, 0.3623577058315277, 0.3623577058315277, 0.9928086400032043, 0.9928086400032043, 0.9999279975891113, 0.9999279975891113],
            [0.9074468016624451, 0.9074468016624451, 0.26749876141548157, 0.26749876141548157, 0.9915618896484375, 0.9915618896484375, 0.9999154806137085, 0.9999154806137085],
            [0.13673721253871918, 0.13673721253871918, 0.1699671596288681, 0.1699671596288681, 0.9902160167694092, 0.9902160167694092, 0.9999020099639893, 0.9999020099639893]
        ]]]);

        const expected_sine_cache = tf.tensor([[[
            [0, 0, 0, 0, 0, 0, 0, 0],
            [0.8414709568023682, 0.8414709568023682, 0.0998334214091301, 0.0998334214091301, 0.009999833069741726, 0.009999833069741726, 0.0009999999310821295, 0.0009999999310821295],
            [0.9092974066734314, 0.9092974066734314, 0.19866932928562164, 0.19866932928562164, 0.019998665899038315, 0.019998665899038315, 0.0019999986980110407, 0.0019999986980110407],
            [0.14112000167369843, 0.14112000167369843, 0.29552021622657776, 0.29552021622657776, 0.029995499178767204, 0.029995499178767204, 0.0029999956022948027, 0.0029999956022948027],
            [-0.756802499294281, -0.756802499294281, 0.3894183337688446, 0.3894183337688446, 0.03998933359980583, 0.03998933359980583, 0.003999989479780197, 0.003999989479780197],
            [-0.9589242935180664, -0.9589242935180664, 0.4794255495071411, 0.4794255495071411, 0.04997916519641876, 0.04997916519641876, 0.0049999793991446495, 0.0049999793991446495],
            [-0.279415488243103, -0.279415488243103, 0.5646424889564514, 0.5646424889564514, 0.059964004904031754, 0.059964004904031754, 0.0059999641962349415, 0.0059999641962349415],
            [0.6569865942001343, 0.6569865942001343, 0.6442176699638367, 0.6442176699638367, 0.06994284689426422, 0.06994284689426422, 0.0069999429397284985, 0.0069999429397284985],
            [0.9893582463264465, 0.9893582463264465, 0.7173560857772827, 0.7173560857772827, 0.07991468906402588, 0.07991468906402588, 0.007999914698302746, 0.007999914698302746],
            [0.41211849451065063, 0.41211849451065063, 0.7833269238471985, 0.7833269238471985, 0.08987854421138763, 0.08987854421138763, 0.008999879471957684, 0.008999879471957684],
            [-0.5440211296081543, -0.5440211296081543, 0.8414709568023682, 0.8414709568023682, 0.0998334139585495, 0.0998334139585495, 0.0099998340010643, 0.0099998340010643],
            [-0.9999902248382568, -0.9999902248382568, 0.8912073969841003, 0.8912073969841003, 0.10977829992771149, 0.10977829992771149, 0.010999779216945171, 0.010999779216945171],
            [-0.5365729331970215, -0.5365729331970215, 0.9320390820503235, 0.9320390820503235, 0.11971220374107361, 0.11971220374107361, 0.011999712325632572, 0.011999712325632572],
            [0.4201670289039612, 0.4201670289039612, 0.9635581970214844, 0.9635581970214844, 0.12963414192199707, 0.12963414192199707, 0.012999634258449078, 0.012999634258449078],
            [0.9906073808670044, 0.9906073808670044, 0.9854497313499451, 0.9854497313499451, 0.13954311609268188, 0.13954311609268188, 0.013999543152749538, 0.013999543152749538]
        ]]]);

        const [cosine_cache, sine_cache] = rope.getWeights();

        // shapes must match exactly, otherwise broadcasting could hide a bad cache
        expect(cosine_cache!.shape).toEqual(expected_cosine_cache.shape);
        expect(sine_cache!.shape).toEqual(expected_sine_cache.shape);

        expect(await maxAbsDiff(cosine_cache!, expected_cosine_cache)).toBeLessThanOrEqual(TOLERANCE);
        expect(await maxAbsDiff(sine_cache!, expected_sine_cache)).toBeLessThanOrEqual(TOLERANCE);
    })


    test("rotate inputs", async () => {
        const rope = new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 15 });

        // batch=1, heads=1, seq=4, dim=8 (the layer reads axis 2 as the sequence axis)
        const x = tf.tensor([[[
            [0.0766048, 0.5706575, 0.6705932, 0.5273118, 0.4794086, 0.9378104, 0.9888024, 0.6926053],
            [0.9064133, 0.5875182, 0.1681865, 0.3833345, 0.9901192, 0.4677338, 0.3353315, 0.02699],
            [0.3033573, 0.4139377, 0.4062586, 0.9705839, 0.3582608, 0.328775, 0.1340587, 0.2193414],
            [0.5565202, 0.4334963, 0.9912352, 0.3388563, 0.7991487, 0.1911893, 0.1140554, 0.6949552]
        ]]]);

        const expected_output = tf.tensor([[[
            [0.07660479843616486, 0.57065749168396, 0.6705932021141052, 0.5273118019104004, 0.4794085919857025, 0.9378104209899902, 0.9888023734092712, 0.6926053166389465],
            [-0.004642367362976074, 1.08015775680542, 0.12907665967941284, 0.39820998907089233, 0.9853923320770264, 0.47761136293411255, 0.33530429005622864, 0.027325313538312912],
            [-0.5026336908340454, 0.10358311235904694, 0.20533521473407745, 1.0319478511810303, 0.3516140580177307, 0.33587393164634705, 0.1336197406053543, 0.21960905194282532],
            [-0.6121258735656738, -0.3506217896938324, 0.8468242287635803, 0.6166517734527588, 0.7930541634559631, 0.2150741070508957, 0.11197001487016678, 0.695294201374054]
        ]]]);

        const output = rope.apply(x) as tf.Tensor;

        expect(output.shape).toEqual(expected_output.shape);
        expect(await maxAbsDiff(output, expected_output)).toBeLessThanOrEqual(TOLERANCE);
        expect(rope.computeOutputShape(x.shape)).toEqual(x.shape);
        expect(rope.computeOutputShape([x.shape])).toEqual(x.shape);
    })


    test("expand cache when input sequences are larger than rope's max sequence length", async () => {
        const dim = 8;
        const rope = new RotaryPositionEmbedding({ dim, maxSequenceLength: 15, theta: 1_000_000 });
        const larger_sequence = 20;
        const even_larger_sequence = 50;

        rope.apply(tf.randomUniform([1, 1, larger_sequence, dim]));

        // 20 positions -> cache grows to the next power of two, 32
        rope.getWeights().forEach(weight => {
            expect(weight.shape).toEqual([1, 1, 32, dim]);
        });

        rope.apply([tf.randomUniform([1, 1, even_larger_sequence, dim])]);

        // 50 positions -> 64
        rope.getWeights().forEach(weight => {
            expect(weight.shape).toEqual([1, 1, 64, dim]);
        });
    })


    test("create layer", async () => {
        // dim must be even
        expect(() => new RotaryPositionEmbedding({ dim: 7, maxSequenceLength: 15 })).toThrow();
        expect(() => new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 25 })).not.toThrow();
    })
});
@@ -0,0 +1,163 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { type LayerArgs } from "@tensorflow/tfjs-layers/dist/engine/topology";
3
+
4
+
5
/**
 * Applies rotary position embedding to `x`.
 *
 * The first `x.shape[2]` positions of each cache are sliced out and combined as
 * `x * cos + rotateHalf(x) * sin`.
 *
 * @param x input tensor; axis 2 is read as the sequence axis (presumably
 *          `[batch, heads, sequence, dim]` — TODO confirm at call sites)
 * @param dim size of the last axis of `x` and of both caches
 * @param cosine_cache rank-4 cosine table, e.g. `[1, 1, maxSequenceLength, dim]`
 * @param sine_cache rank-4 sine table with the same shape as `cosine_cache`
 * @returns the rotated tensor; intermediates are released by `tf.tidy`
 */
export function applyRope(x: tf.Tensor, dim: number, cosine_cache: tf.Tensor, sine_cache: tf.Tensor) {
    return tf.tidy(() => {
        const seq_length = x.shape[2]!;
        const begin = [0, 0, 0, 0];
        const size = [1, 1, seq_length, dim];

        // x * cos: every coordinate scaled by the cosine of its angle
        const cos_term = tf.mul(x, cosine_cache.slice(begin, size));
        // rotateHalf(x) * sin: pairs (a, b) become (-b, a) before scaling
        const sin_term = tf.mul(rotateHalf(x, dim), sine_cache.slice(begin, size));

        return tf.add(cos_term, sin_term);
    });
}
19
+
20
+
21
/**
 * Rotates every adjacent coordinate pair of the last axis by 90 degrees:
 * `[x1, x2, x3, x4, ...] -> [-x2, x1, -x4, x3, ...]`.
 *
 * @param x tensor whose last axis has size `dim`
 * @param dim size of the last axis; expected to be even
 * @returns a tensor with the same shape as `x`
 */
export function rotateHalf(x: tf.Tensor, dim: number): tf.Tensor {
    return tf.tidy(() => {
        // Group the last axis into (first, second) pairs: [..., dim] -> [N, dim / 2, 2].
        // All leading axes are folded into one because TFJS has issues during
        // backpropagation with 5D slicing.
        const pairs = x.reshape([-1, dim / 2, 2]);

        const first = tf.slice(pairs, [0, 0, 0], [-1, -1, 1]);
        const second = tf.slice(pairs, [0, 0, 1], [-1, -1, 1]);

        // (first, second) -> (-second, first), then restore the original layout
        const swapped = tf.concat([tf.neg(second), first], -1);
        return swapped.reshape(x.shape);
    });
}
38
+
39
+
40
/**
 * Pre-computes the cosine and sine tables used by RoPE.
 *
 * Position `p` and frequency index `i` map to the angle `p / theta^(2i / dim)`.
 * Every angle is written twice in a row, `[a0, a0, a1, a1, ...]`, so that it
 * lines up with the adjacent coordinate pairs rotated by `rotateHalf`.
 *
 * @param dim size of each head's embedding (the last axis of the caches)
 * @param max_sequence_length number of positions to pre-compute
 * @param theta base of the geometric progression of frequencies, default `10_000`
 * @returns `[cosine, sine]`, each shaped `[1, 1, max_sequence_length, dim]`;
 *          both escape the tidy scope, so the caller is responsible for disposing them
 */
export function createRoPECache(dim: number, max_sequence_length: number, theta: number = 10_000) {
    return tf.tidy(() => {
        // exponents 2i / dim for i in [0, dim / 2)
        const exponents = tf.range(0, Math.floor(dim / 2) * 2, 2, "float32").div(dim);
        // [dim / 2]
        const inv_frequencies = tf.div<tf.Tensor1D>(1, tf.pow(theta, exponents));

        // [max_sequence_length]
        const positions = tf.range(0, max_sequence_length);
        // [max_sequence_length, dim / 2]: one angle per (position, frequency)
        const angles = tf.outerProduct(positions, inv_frequencies);

        // [max_sequence_length, dim]: each angle duplicated for its coordinate pair
        const paired_angles = tf.stack([angles, angles], -1).reshape([max_sequence_length, dim]);

        // add the leading [1, 1] axes so the tables broadcast over batch and heads
        const toCache = (table: tf.Tensor) => tf.keep(table.expandDims(0).expandDims(0));

        return [toCache(tf.cos(paired_angles)), toCache(tf.sin(paired_angles))];
    });
}
62
+
63
+
64
export interface RotaryPositionEmbeddingArgs extends LayerArgs {
    /**
     * The dimension of each head, e.g. `Math.floor(embedDim / numHeads)`.
     * Must be even, because coordinates are rotated in adjacent pairs.
     */
    dim: number,
    /**
     * Number of positions the RoPE cache is pre-calculated for. Longer inputs
     * grow the cache on the fly to the next power of two. Defaults to `4096`.
     */
    maxSequenceLength?: number,
    /**
     * The base for the geometric progression used to compute the rotation angles. Defaults to `10_000`.
     */
    theta?: number,
}
78
+
79
+
80
/**
 * Implements RoPE from the RoFormer: Enhanced Transformer with Rotary Position Embedding paper
 * Inspired by: https://meta-pytorch.org/torchtune/stable/_modules/torchtune/modules/position_embeddings.html#RotaryPositionalEmbeddings
 *
 * Axis 2 of the input is read as the sequence axis (presumably
 * `[batch, heads, sequence, head_dim]`). The cosine/sine tables are stored as
 * non-trainable weights. When an input is longer than the current cache, the
 * tables are grown to the next power of two.
 */
export class RotaryPositionEmbedding extends tf.layers.Layer {
    static className = "RotaryPositionEmbedding";

    protected dim: number;
    protected max_sequence_length: number;
    protected theta: number;

    // cached cosine and sine tables, [1, 1, max_sequence_length, dim], untrainable weights
    protected cosine_cache: tf.LayerVariable;
    protected sine_cache: tf.LayerVariable;

    constructor({ dim, maxSequenceLength = 4096, theta = 10_000, ...args }: RotaryPositionEmbeddingArgs) {
        super(args);

        if (!Number.isInteger(dim) || dim <= 0 || dim % 2 !== 0) {
            throw Error(`${this.getClassName()}::constructor ${this.name} expected dim to be a positive even integer, got ${dim}`);
        }

        this.dim = dim;
        this.max_sequence_length = maxSequenceLength;
        this.theta = theta;

        // Placeholders so the weights exist (correct names and shapes) before build();
        // build() disposes them and installs the real tables.
        this.cosine_cache = this.addWeight("cosine_cache",
            [1, 1, maxSequenceLength, this.dim],
            "float32", tf.initializers.zeros(), undefined, false);

        this.sine_cache = this.addWeight("sine_cache",
            [1, 1, maxSequenceLength, this.dim],
            "float32", tf.initializers.zeros(), undefined, false);
    }


    override call(inputs: tf.Tensor | tf.Tensor[], kwargs: any): tf.Tensor | tf.Tensor[] {
        const input = Array.isArray(inputs) ? inputs[0] : inputs;
        const seq_length = input.shape[2];

        // unknown (symbolic) lengths never trigger a rebuild
        if (seq_length != null && seq_length > this.max_sequence_length) {
            // expand cache to the nearest power of 2
            this.max_sequence_length = Math.pow(2, Math.ceil(Math.log2(seq_length)));
            this.build([]);
        }

        return applyRope(
            input,
            this.dim,
            this.cosine_cache.read(),
            this.sine_cache.read());
    }


    /**
     * (Re)computes the cosine/sine tables for the current `max_sequence_length`.
     * The variables read by `call()` are the very same ones registered as the
     * layer's non-trainable weights. As a result, `getWeights()`/`setWeights()`
     * (e.g. when a saved model is loaded) operate on the cache in use, and the
     * previous tables are released on every rebuild.
     */
    override build(input_shape: tf.Shape | tf.Shape[]) {
        const [cosine, sine] = createRoPECache(
            this.dim, this.max_sequence_length, this.theta);

        // release the previous tables (the constructor placeholders on the first build)
        this.cosine_cache.dispose();
        this.sine_cache.dispose();

        this.cosine_cache = new tf.LayerVariable(cosine, "float32", "cosine_cache", false);
        this.sine_cache = new tf.LayerVariable(sine, "float32", "sine_cache", false);

        // the variables hold their own reference to the data
        tf.dispose([cosine, sine]);

        this.nonTrainableWeights = [this.cosine_cache, this.sine_cache];
    }


    /**
     * Output shape: [batch, head, sequence, head_dim]
     */
    override computeOutputShape(input_shape: tf.Shape | tf.Shape[]) {
        return Array.isArray(input_shape[0])
            ? input_shape[0] as tf.Shape
            : input_shape as tf.Shape;
    }


    /**
     * Serializable config. This is required for `registerClass`/`fromConfig`;
     * without it a deserialized layer is constructed with `dim` undefined and
     * throws. `maxSequenceLength` reports the current (possibly grown) cache
     * length, so saved weights match the re-created placeholders' shapes.
     */
    override getConfig(): tf.serialization.ConfigDict {
        return {
            ...super.getConfig(),
            dim: this.dim,
            maxSequenceLength: this.max_sequence_length,
            theta: this.theta,
        };
    }
}


tf.serialization.registerClass(RotaryPositionEmbedding);
@@ -0,0 +1,81 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+
3
+ import { TokenAndPositionalEmbedding } from '@/layers/token_and_positional_embedding';
4
+
5
// disables warning for using the faster node backend,
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
tf.env().set('IS_NODE', false);


describe("TokenAndPositionalEmbedding tests", () => {
    test("layer initialization", () => {
        expect(() => new TokenAndPositionalEmbedding({ maxSequenceLength: 0, embedDim: 10, vocabularySize: 10_000 })).toThrow();
        expect(() => new TokenAndPositionalEmbedding({ embedDim: 0, vocabularySize: 10_000 })).toThrow();
        expect(() => new TokenAndPositionalEmbedding({ embedDim: 10, vocabularySize: 0 })).toThrow();

        expect(() => new TokenAndPositionalEmbedding({ embedDim: 10, vocabularySize: 10_000 })).not.toThrow();
    })


    test("successfull forward calls", () => {
        const embed_dims = 32;
        const sequences = 4;
        const vocab_size = 10_000;
        const input = tf.randomUniform([2, sequences]);

        const embedding = new TokenAndPositionalEmbedding({ embedDim: embed_dims, dropout: 0.1, vocabularySize: vocab_size });
        expect(() => embedding.apply(input)).not.toThrow();
        expect(() => embedding.apply([input])).not.toThrow();
    })


    test("layer build", () => {
        const input_ok = tf.randomUniform([2, 4]);
        const input_too_many_words = tf.randomUniform([2, 700]);
        const input_is_image = tf.randomUniform([1, 32, 32, 3]);

        let embedding = new TokenAndPositionalEmbedding({ embedDim: 32, maxSequenceLength: 500, vocabularySize: 1_000 });
        expect(() => embedding.build(input_ok.shape)).not.toThrow();

        embedding = new TokenAndPositionalEmbedding({ embedDim: 32, maxSequenceLength: 500, vocabularySize: 1_000 });
        expect(() => embedding.build([input_ok.shape, input_ok.shape])).not.toThrow();

        // fresh, unbuilt layer so the failures below come from the shape checks alone
        embedding = new TokenAndPositionalEmbedding({ embedDim: 32, maxSequenceLength: 500, vocabularySize: 1_000 });
        expect(() => embedding.build(input_too_many_words.shape)).toThrow();
        expect(() => embedding.build(input_is_image.shape)).toThrow();
    })


    it("should throw when more than one input provided, input sequences are too large, or incorrect input rank", () => {
        const sequences_too_long = tf.randomUniform([10, 1000]);
        const multiple_correct_inputs = [tf.randomUniform([2, 3]), tf.randomUniform([2, 3])];
        const wrong_rank = tf.randomUniform([10, 32, 32]);

        const positional = new TokenAndPositionalEmbedding({ maxSequenceLength: 10, embedDim: 32, vocabularySize: 10_000 });
        positional.build([2, 3]); // get past the initial build call to test forward prop

        expect(() => positional.apply(sequences_too_long)).toThrow();
        expect(() => positional.apply(multiple_correct_inputs)).toThrow();
        expect(() => positional.apply(wrong_rank)).toThrow();
    })


    it("should return a non-empty config dict", () => {
        const embedding = new TokenAndPositionalEmbedding({ embedDim: 32, vocabularySize: 10_000 });
        // `expect(array).not.toBe(0)` always passes; check the key count instead
        expect(Object.keys(embedding.getConfig()).length).toBeGreaterThan(0);
    })


    it("should return an output shape of [batch, sequences, embed dims]", () => {
        const words = 100;
        const batch = 2;
        const embed_dims = 64;

        const input = tf.randomUniform([batch, words]);

        const embedding = new TokenAndPositionalEmbedding({ embedDim: embed_dims, vocabularySize: 10_000 });

        expect(embedding.computeOutputShape(input.shape)).toEqual([batch, words, embed_dims]);
    })
});
@@ -0,0 +1,149 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
3
+ import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
4
+
5
+ import { PositionalEncoding, type PositionalEncodingArgs } from '@/layers/positional_encoding';
6
+
7
+
8
export interface TokenAndPositionalEmbeddingArgs extends LayerArgs, PositionalEncodingArgs {
    /** Number of distinct token ids; used as the embedding table's `inputDim`. */
    vocabularySize: number;
    /** Dropout rate applied after positional encoding; the layer defaults it to `0.1`. */
    dropout?: number;
}
12
+
13
+
14
/**
 * Combines the sinusoidal positional encoding from the 2017 paper
 * "Attention Is All You Need" with a regular embedding layer to form a single,
 * simplified embedding layer.
 *
 * This layer accepts tokenized inputs of the shape `[ batch, tokens ]`. It runs
 * them through an embedding layer, adds sinusoidal positional encoding, and
 * then applies dropout.
 *
 * @param embedDim the size of each token/word's embedding
 * @param vocabularySize the number of tokens to embed
 * @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`
 * @param dropout dropout rate applied to the positionally encoded embeddings, within `[0, 1)`, default `0.1`
 */
export class TokenAndPositionalEmbedding extends tf.layers.Layer {
    static className = "TokenAndPositionalEmbedding";

    private readonly embedDim: number;
    private readonly vocabularySize: number;
    private embedding: tf.layers.Layer;

    private positional: tf.layers.Layer;
    private readonly maxSequenceLength: number;
    private readonly dropout: number;

    private dropoutLayer: tf.layers.Layer;


    constructor({ embedDim, vocabularySize, maxSequenceLength, dropout, ...args }: TokenAndPositionalEmbeddingArgs) {
        super(args);

        this.embedDim = embedDim;
        this.vocabularySize = vocabularySize;
        this.maxSequenceLength = maxSequenceLength ?? 5120;
        this.dropout = dropout ?? 0.1;

        // negated range check so that negative rates and NaN are rejected too
        if (!(this.dropout >= 0 && this.dropout < 1)) {
            throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
        }

        this.embedding = tf.layers.embedding({
            inputDim: this.vocabularySize,
            outputDim: this.embedDim,
        });

        this.positional = new PositionalEncoding({
            maxSequenceLength: this.maxSequenceLength,
            embedDim: this.embedDim,
        });

        this.dropoutLayer = tf.layers.dropout({ rate: this.dropout });
    }


    /**
     * Forward propagation: embedding -> positional encoding -> dropout.
     *
     * @throws if more than one input tensor is provided
     */
    override call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs) {
        if (Array.isArray(inputs) && inputs.length !== 1) {
            throw Error(`${this.getClassName()}::call ${this.name} expects exactly` +
                ` 1 tensor input, received ${inputs.length}`);
        }

        return tf.tidy(() => {
            const embedded = this.embedding.apply(inputs) as tf.Tensor;
            const encoded = this.positional.apply(embedded) as tf.Tensor;

            // Forward kwargs so that `training: true` (set during fit) reaches the
            // dropout layer; without it, dropout treats every call as inference.
            return this.dropoutLayer.apply(encoded, kwargs) as tf.Tensor;
        });
    }


    /**
     * Build the sublayers and enable serialization.
     *
     * @throws if no shape is given, or the first shape is not `[batch, tokens]`
     *         with `tokens <= maxSequenceLength`
     */
    override build(inputShape: tf.Shape | tf.Shape[]): void {
        let input_shapes: tf.Shape[] = [];

        // only consider the first shape if multiple provided
        if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
            // input is an array of shapes
            input_shapes = inputShape as tf.Shape[];
        } else if (inputShape.length !== 0) {
            // input is a single shape
            input_shapes = [inputShape as tf.Shape];
        }

        const shape = input_shapes[0];

        // an explicit error instead of a TypeError when no shape was provided
        if (shape === undefined || shape.length !== 2 || shape[1]! > this.maxSequenceLength) {
            throw Error(`${this.getClassName()}::build ${this.name} expected an input of` +
                ` shape [batch, tokens] where tokens <= ${this.maxSequenceLength},` +
                ` received ${JSON.stringify(shape ?? inputShape)}`);
        }

        // initialize the sublayers' weights
        this.embedding.build(shape);
        this.positional.build(this.embedding.computeOutputShape(shape));

        // no need to rename weights, haven't found a case where their names collide
        this.trainableWeights = [
            ...this.embedding.trainableWeights,
            ...this.positional.trainableWeights
        ];

        super.build(shape);
    }


    /**
     * The output shape, for an input shape of [batch, sequences], is
     * [batch, sequences, embedDim]
     */
    override computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        const embedding_shape = this.embedding.computeOutputShape(inputShape);
        const positional_shape = this.positional.computeOutputShape(embedding_shape);

        return positional_shape;
    }


    /**
     * Serializable config: this layer's constructor arguments merged with the
     * base layer config (base keys take precedence, as before).
     */
    override getConfig(): tf.serialization.ConfigDict {
        const base_config = super.getConfig();

        const config = {
            embedDim: this.embedDim,
            vocabularySize: this.vocabularySize,
            maxSequenceLength: this.maxSequenceLength,
            dropout: this.dropout,
        };

        Object.assign(config, base_config);

        return config;
    }
}


tf.serialization.registerClass(TokenAndPositionalEmbedding);
@@ -0,0 +1,100 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+
3
+ import { TransformerDecoder } from '@/layers/transformer_decoder';
4
+
5
// disables warning for using the faster node backend,
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
tf.env().set('IS_NODE', false);


describe("TransformerDecoder tests", () => {
    it("should return an output with the same shape as the input", () => {
        const input = tf.randomUniform([2, 3, 12]);

        const decoder = new TransformerDecoder({
            numHeads: 2, embedDim: input.shape.at(-1)!,
            dropout: 0.5, activation: "gelu", dimsFeedForward: 321, useBias: false
        });

        const output = decoder.apply(input) as tf.Tensor;

        // compare the full shape, not just the rank, as the test name promises
        expect(output.shape).toEqual(input.shape);
    })


    test("forward calls", () => {
        const input = tf.randomUniform([2, 3, 12]);

        const decoder = new TransformerDecoder({ numHeads: 2, embedDim: input.shape.at(-1)! });
        expect(() => decoder.apply(input)).not.toThrow();
        expect(() => decoder.apply([input])).not.toThrow();

        // causal masking
        const causal = new TransformerDecoder({ numHeads: 2, embedDim: input.shape.at(-1)!, causal: true });
        expect(() => causal.apply(input)).not.toThrow();
        expect(() => causal.apply([input])).not.toThrow();
    })


    it("should fail to instantiate a layer if the input's embedding dimension is not divisible by the heads count", () => {
        const input = tf.randomUniform([2, 3, 12]);

        expect(() => new TransformerDecoder({ numHeads: 3, embedDim: input.shape.at(-1)! })).not.toThrow();
        expect(() => new TransformerDecoder({ numHeads: 5, embedDim: input.shape.at(-1)! })).toThrow();
    })


    it("should not accept non-rank 3 tensor inputs", () => {
        const embed_dim = 12;

        const BAD_RANK4 = tf.randomUniform([2, 3, 12, embed_dim]);
        const BAD_RANK2 = tf.randomUniform([2, embed_dim]);
        const GOOD = tf.randomUniform([2, 3, embed_dim]);
        const mask = tf.randomUniform([GOOD.shape[0]!, GOOD.shape[1]!], -1, 2, "bool");

        let decoder = new TransformerDecoder({ numHeads: 2, embedDim: embed_dim });

        // BAD
        expect(() => decoder.apply(BAD_RANK4)).toThrow();
        expect(() => decoder.apply(BAD_RANK2)).toThrow();

        // OK
        decoder = new TransformerDecoder({ numHeads: 2, embedDim: embed_dim });
        expect(() => decoder.apply(GOOD)).not.toThrow();
        expect(() => decoder.apply([GOOD])).not.toThrow();
        expect(() => decoder.apply([GOOD, mask])).not.toThrow();
    })


    it("should not accept inputs that are less or more than 1 and 2 tensors", () => {
        const input = tf.randomUniform([2, 3, 12]);

        let decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1)! });
        // OK
        expect(() => decoder.apply(input)).not.toThrow();
        expect(() => decoder.apply([input])).not.toThrow();

        // BAD
        decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1)! });
        expect(() => decoder.apply([])).toThrow(); // stops at build()
        decoder.apply(input); // get past the initial build
        expect(() => decoder.apply([input, input, input])).toThrow();
        expect(() => decoder.apply([input, input, input, input])).toThrow();

        // BAD (tests build())
        decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1)! });
        expect(() => decoder.apply([input, input, input])).toThrow();
        expect(() => decoder.apply([input, input, input, input])).toThrow();
    })


    it("should return a non-empty config dict", () => {
        const input = tf.randomUniform([2, 3, 12]);

        const decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1)! });
        // `expect(array).not.toBe(0)` always passes; check the key count instead
        expect(Object.keys(decoder.getConfig()).length).toBeGreaterThan(0);
    })
})