@stellarapp/tfjs-stellar 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,113 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { KvCacheContainer } from "@/kv_cache";
3
+ import { MultiHeadAttention, type MultiHeadAttentionArgs } from '@/layers/multihead_attention';
4
+ import { RotaryPositionEmbedding } from '@/layers/rotary_position_embedding';
5
+ import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
6
+
7
+
8
/**
 * MultiHeadAttention with rotary position embeddings (RoPE) and optional
 * key/value (KV) caching.
 *
 * RoPE is applied to the query and key heads right after they are split (see
 * `splitHeads`); value heads are not rotated.
 *
 * The KV cache is only consulted when `kwargs.training !== true` and a
 * `KvCacheContainer` is passed through `kwargs.kvCache` to `layer.apply()`.
 * Because the cache must be threaded through `kwargs`, caching needs a custom
 * inference/generation loop rather than the standard `model.fit()` path.
 *
 * If no KV cache is provided, this layer operates as MultiHeadAttention with RoPE.
 */
export class CachedRoPEMultiHeadAttention extends MultiHeadAttention {
  static className = "CachedRoPEMultiHeadAttention";

  /** Rotary embedding applied to the per-head query and key tensors. */
  protected rope: tf.layers.Layer;

  constructor(args: MultiHeadAttentionArgs) {
    super(args);
    // one rotary embedding sized to a single head
    this.rope = new RotaryPositionEmbedding({ dim: Math.floor(this.embedDim / this.numHeads) });
  }


  /**
   * Projects the inputs, splits them into RoPE-rotated heads, optionally
   * extends the keys/values with the KV cache, runs scaled dot-product
   * attention, merges the heads and applies the output projection.
   *
   * @param query_input query tensor (assumed [batch, seq, embedDim] — TODO confirm)
   * @param key_input key tensor
   * @param value_input value tensor
   * @param packing_mask optional packing mask forwarded to the attention op
   * @param causal_mask optional pre-computed causal mask forwarded to the attention op
   * @param kwargs layer kwargs; `training`, `kvCache` and `attentionMask` are read here
   * @returns the projected attention output reshaped to
   *   [query batch, query seq, embedDim]
   */
  protected override forward(
    query_input: tf.Tensor,
    key_input: tf.Tensor,
    value_input: tf.Tensor,
    packing_mask: tf.Tensor | null,
    causal_mask: tf.Tensor | null,
    kwargs: Kwargs): tf.Tensor {

    return tf.tidy(() => {
      const { query, key, value } = this.applyInputProjections(query_input, key_input, value_input);

      // swaps the seq and heads axes: [batch, seq, heads, head_dim] <-> [batch, heads, seq, head_dim];
      // the permutation is its own inverse, so it is reused below to undo the swap
      const move_head_dim_forward = [0, 2, 1, 3];

      const split = this.splitHeads(query, key, value, move_head_dim_forward);

      const query_split = split.query_split;
      let key_split = split.key_split;
      let value_split = split.value_split;

      if (kwargs.training !== true && kwargs.kvCache) {
        // inference only: append the new keys/values to the cache and attend over
        // the full cached history instead of just the current step.
        // NOTE(review): RoPE is applied in splitHeads *before* the cache is consulted,
        // so during incremental decoding the new positions are presumably counted
        // from 0 on every call unless RotaryPositionEmbedding tracks an offset, and
        // `causal_mask` was built for the uncached key length — confirm both.
        // NOTE(review): this runs inside tf.tidy, which disposes tensors that are not
        // returned; the cache must tf.keep() what it stores — confirm in KvCacheContainer.
        const cached_kv = this.getCachedKV(
          kwargs.kvCache as KvCacheContainer, key_split, value_split);

        key_split = cached_kv.keyCache;
        value_split = cached_kv.valueCache;
      }

      // attention output, presumably [batch, numHeads, querySeq, headDim]: it is
      // transposed with the same permutation before the heads are merged below
      const spda = MultiHeadAttention.scaledDotProductionAttention(
        query_split, key_split, value_split,
        kwargs.attentionMask ?? null, packing_mask, causal_mask,
        this.dropout, this.causal, kwargs);

      // [batch, heads, seq, head_dim] -> [batch, seq, heads, head_dim] -> [batch, seq, embedDim],
      // i.e. concatenate the heads, then apply the output projection
      const output = this.outputProjection.apply(
        spda.transpose(move_head_dim_forward).reshape(
          [query_input.shape[0], query_input.shape[1]!, this.embedDim]));

      return output as tf.Tensor;
    });
  }


  /**
   * Appends the new key/value heads to this layer's entry in the KV cache
   * container, creating the entry (keyed by the layer name) on first use.
   *
   * @param kv_container cache container shared across layers
   * @param key_split new key heads, [batch, heads, seq, head_dim]
   * @param value_split new value heads, [batch, heads, seq, head_dim]
   * @returns the updated cache entry (its `keyCache`/`valueCache` are read by `forward`)
   * @throws Error prefixed with the class and layer name if the container throws,
   *   or if it still has no entry for this layer after creating one
   */
  protected getCachedKV(kv_container: KvCacheContainer, key_split: tf.Tensor4D, value_split: tf.Tensor4D) {
    try {
      let kv_cache = kv_container.update(this.name, key_split, value_split);

      if (!kv_cache) {
        // first call for this layer: create its entry, then retry the update
        kv_container.create(this.name, {
          batchSize: key_split.shape[0],
          numHeads: this.numHeads,
          headDim: this.embedDim / this.numHeads,
        });

        kv_cache = kv_container.update(this.name, key_split, value_split);

        // fail loudly instead of silently returning undefined to the caller
        if (!kv_cache) {
          throw new Error("cache entry is still missing after it was created");
        }
      }

      return kv_cache;
    } catch (error: unknown) {
      // String() also handles non-Error throwables (e.g. undefined/null), on which
      // calling `.toString()` would itself throw a TypeError and hide the real cause
      throw new Error(`${this.getClassName()}::getCachedKV ${this.name} ${String(error)}`);
    }
  }


  /**
   * Splits the projected query/key/value into heads and applies RoPE to the
   * query and key heads right after the split (values are not rotated).
   *
   * @param query projected query, reshaped with the query's batch size
   * @param key projected key
   * @param value projected value
   * @param shuffle permutation applied after the split, e.g. [0, 2, 1, 3] to move
   *   the heads axis in front of the sequence axis
   * @returns the three split (and permuted) tensors
   */
  protected override splitHeads(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, shuffle: number[]) {
    const batch_size = query.shape[0];
    // [batch, seq, embedDim] -> [batch, seq, heads, head_dim]; -1 infers the sequence length
    const split_heads = [batch_size, -1, this.numHeads, this.embedDim / this.numHeads];

    return tf.tidy(() => {
      return {
        // RoPE is applied before the permutation, i.e. on [batch, seq, heads, head_dim]
        // (assumes RotaryPositionEmbedding expects that layout — TODO confirm)
        query_split: (this.rope.apply(query.reshape(split_heads)) as tf.Tensor)
          .transpose(shuffle) as tf.Tensor4D,
        key_split: (this.rope.apply(key.reshape(split_heads)) as tf.Tensor)
          .transpose(shuffle) as tf.Tensor4D,
        value_split: value.reshape(split_heads).transpose(shuffle) as tf.Tensor4D,
      };
    });
  }
}


tf.serialization.registerClass(CachedRoPEMultiHeadAttention);
@@ -0,0 +1,77 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
3
+
4
+ import { type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
5
+ import { TransformerDecoder, type TransformerDecoderArgs } from "@/layers/transformer_decoder";
6
+
7
+
8
/**
 * Arguments for a GPT-style decoder block: the MultiHeadAttention arguments
 * without `causal`, plus the size of the intermediate feed-forward layer.
 *
 * NOTE(review): `GPT2DecoderBlock` below is constructed from
 * `TransformerDecoderArgs`, not from this interface — confirm where this type
 * is consumed.
 */
export interface GPTDecoderBlockArgs extends Omit<MultiHeadAttentionArgs, "causal"> {
  /** Size of the intermediate feed-forward layer. */
  dimsFeedForward?: number;
}


/**
 * GPT-2 style transformer block: the transformer decoder block rearranged to
 * use pre-layer-normalization, i.e. each sub-block computes
 * `x + dropout(subLayer(layerNorm(x)))`.
 *
 * GPT-2 also uses GELU instead of ReLU in the feed-forward layer.
 * NOTE(review): this class does not configure an activation itself;
 * `feedforward1` is created by the base `TransformerDecoder`, so whether GELU
 * is used depends on the base class / constructor args — confirm.
 *
 * Constructor arguments (defaults as stated by the original author; they are
 * applied by the base `TransformerDecoder`):
 * @param numHeads number of attention heads to use
 * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
 * @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
 * @param dropout use dropout during the attention calculations, default `0.1`
 * @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
 * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
 */
export class GPT2DecoderBlock extends TransformerDecoder {
  static className = "GPT2DecoderBlock";


  constructor(args: TransformerDecoderArgs) {
    super(args);
  }


  /**
   * Self-attention sub-block. Unlike the original (post-LN) transformer, layer
   * normalization is applied at the beginning, before attention:
   * `x + dropout(attention(layerNorm(x)))`.
   */
  protected override causalSelfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
    return tf.tidy(() => {
      const residual = x;

      let attention = this.causalSelfAttentionNorm.apply(x, kwargs) as tf.Tensor;
      attention = this.causalSelfAttention.apply(attention, kwargs) as tf.Tensor;
      attention = this.causalSelfAttentionDropout.apply(attention, kwargs) as tf.Tensor;

      // residual connection around the normalized sub-layer
      return tf.add(attention, residual);
    });
  }


  /**
   * Feed-forward sub-block. Unlike the original (post-LN) transformer, layer
   * normalization is applied at the beginning:
   * `x + dropout(feedforward2(feedforward1(layerNorm(x))))`.
   */
  protected override feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
    return tf.tidy(() => {
      const residual = x;

      // each `apply` result is narrowed to a concrete Tensor, matching the
      // attention sub-block ('feedFowardNorm' is the base class's property name)
      let feedForward = this.feedFowardNorm.apply(x, kwargs) as tf.Tensor;
      feedForward = this.feedforward1.apply(feedForward, kwargs) as tf.Tensor;
      feedForward = this.feedforward2.apply(feedForward, kwargs) as tf.Tensor;
      feedForward = this.feedForwardDropout.apply(feedForward, kwargs) as tf.Tensor;

      // residual connection around the normalized sub-layer
      return tf.add(feedForward, residual);
    });
  }


  // the build() function does not need overriding because the layer normalization
  // outputs the same shape as its input, its position as a sub-layer doesn't affect
  // other sub-layer weight and output shapes
}


tf.serialization.registerClass(GPT2DecoderBlock);
@@ -0,0 +1,212 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+
3
+ import { CachedRoPEMultiHeadAttention } from '@/layers/cached_rope_multihead_attention';
4
+ import { generateCausalAttentionMask } from '@/utils';
5
+ import { MultiHeadAttention } from '@/layers/multihead_attention';
6
+
7
// Disables the warning about not using the faster node backend (tfjs-node), see
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
tf.env().set('IS_NODE', false);


/**
 * Returns the largest absolute element-wise difference between two tensors.
 *
 * Used by the hard-coded value tests instead of `expected.sub(actual).sum()`:
 * a signed sum lets positive and negative errors cancel out, and any output
 * larger than expected gives a negative sum that trivially passes `toBeLessThan`.
 */
function maxAbsDiff(a: tf.Tensor, b: tf.Tensor): number {
  return tf.tidy(() => a.sub(b).abs().max().dataSync()[0] ?? Number.NaN);
}


// input for the hard-coded value tests, shape [2, 3, 10]
const HARD_CODED_INPUT = [
  [[0.2109915, 0.6158954, 0.6012088, 0.9867562, 0.8728716, 0.7496274, 0.8173883, 0.2958342, 0.9650571, 0.2075207],
  [0.2946285, 0.9779906, 0.3203818, 0.4037617, 0.3762881, 0.9863171, 0.6655593, 0.7707329, 0.3216831, 0.7984023],
  [0.9080769, 0.0026282, 0.379492, 0.0162054, 0.1939302, 0.2201049, 0.8190675, 0.0203963, 0.0114392, 0.5015539]],

  [[0.6241482, 0.7631097, 0.6687831, 0.7259795, 0.0457698, 0.6889264, 0.0853676, 0.8697655, 0.3637198, 0.2105307],
  [0.5221761, 0.4476321, 0.1244729, 0.8863543, 0.7319002, 0.2954829, 0.3200496, 0.0905503, 0.607977, 0.1309131],
  [0.4693873, 0.4609751, 0.9170766, 0.7065565, 0.4795104, 0.3225758, 0.1353116, 0.7083887, 0.1928891, 0.967386]]
];

// expected output for the hard-coded value tests, shape [2, 3, 10]
const HARD_CODED_EXPECTED = [
  [[0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344],
  [0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376],
  [0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539]],

  [[0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718],
  [0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268],
  [0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877]]
];

// the expected values are quoted to ~7 decimal places; allow for that rounding plus float32 error
const HARD_CODED_TOLERANCE = 1e-5;


describe("CachedRoPEMultiHeadAttention tests", () => {
  it("should fail to instantiate a layer if the input's embedding dimension is not divisible by the heads count", () => {
    expect(() => new CachedRoPEMultiHeadAttention({ numHeads: 3, embedDim: 10 })).toThrow();
    expect(() => new CachedRoPEMultiHeadAttention({ numHeads: 15, embedDim: 60 })).not.toThrow();
  });


  test("successful forward calls", () => {
    const input = tf.randomUniform([2, 3, 12]);

    const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1)! });
    expect(() => attention.apply(input)).not.toThrow();
    expect(() => attention.apply([input])).not.toThrow();

    const causal = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1)!, causal: true });
    expect(() => causal.apply(input)).not.toThrow();
    expect(() => causal.apply([input])).not.toThrow();
  });


  test("key and value must have the same sequence length for scaled dot product attention to succeed", () => {
    const query = tf.randomUniform([2, 3, 12]);
    const key = tf.randomUniform([2, 3, 12]);
    const value = tf.randomUniform([2, 3, 12]);
    const value_thats_too_long = tf.randomUniform([2, 100, 12]);

    const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: query.shape.at(-1)! });
    expect(() => attention.apply([query, key, value])).not.toThrow();
    expect(() => attention.apply([query, key, value_thats_too_long])).toThrow();
  });


  it("should only accept rank 3 tensors", () => {
    const embed_dims = 12;

    const BAD_RANK2 = tf.randomUniform([2, embed_dims]);
    const GOOD = tf.randomUniform([2, 3, embed_dims]);
    const BAD_RANK4 = tf.randomUniform([2, 3, 10, embed_dims]);

    const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: embed_dims });

    // BAD
    expect(() => attention.apply(BAD_RANK2)).toThrow();
    expect(() => attention.apply([BAD_RANK2])).toThrow();
    expect(() => attention.apply([BAD_RANK2, BAD_RANK2, BAD_RANK2])).toThrow();

    // OK
    expect(() => attention.apply(GOOD)).not.toThrow();
    expect(() => attention.apply([GOOD])).not.toThrow();
    expect(() => attention.apply([GOOD, GOOD, GOOD])).not.toThrow();

    // BAD
    expect(() => attention.apply(BAD_RANK4)).toThrow();
    expect(() => attention.apply([BAD_RANK4])).toThrow();
    expect(() => attention.apply([BAD_RANK4, BAD_RANK4, BAD_RANK4])).toThrow();

    // BAD: mixed ranks
    expect(() => attention.apply([GOOD, BAD_RANK2, BAD_RANK4])).toThrow();
    expect(() => attention.apply([BAD_RANK2, GOOD, BAD_RANK4])).toThrow();
    expect(() => attention.apply([BAD_RANK2, BAD_RANK4, GOOD])).toThrow();
    expect(() => attention.apply([BAD_RANK2, GOOD, GOOD])).toThrow();
    expect(() => attention.apply([GOOD, GOOD, BAD_RANK4])).toThrow();
  });


  it("should only accept 1 or 3 inputs total", () => {
    const input = tf.randomUniform([2, 3, 12]);

    let attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1)! });

    // OK
    expect(() => attention.apply(input, { packingMask: undefined })).not.toThrow();
    expect(() => attention.apply([input])).not.toThrow();
    // reinitialize to rerun build()
    attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1)! });
    expect(() => attention.apply([input, input, input])).not.toThrow();

    // BAD
    expect(() => attention.apply([])).toThrow();
    expect(() => attention.apply([input, input])).toThrow();
    // reinitialize to rerun build()
    attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1)! });
    expect(() => attention.apply([input, input, input, input])).toThrow();
  });


  test("attention masking", () => {
    const query = tf.randomUniform([2, 3, 12]);
    const key = tf.randomUniform([2, 3, 12]);
    const value = tf.randomUniform([2, 3, 12]);

    const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: query.shape.at(-1)!, causal: true });

    expect(() => attention.call(query, {})).not.toThrow();

    // cross attention
    expect(() => attention.call([query, key, value], {})).not.toThrow();


    const query5 = tf.randomUniform([2, 5, 10]);
    const key4 = tf.randomUniform([2, 4, 10]);
    const value5 = tf.randomUniform([2, 4, 10]);

    const expected_mask = tf.tensor([[
      // vertical represents query, false means that token cannot attend to the keys
      // horizontal represents key, false means that token cannot attend to the queries
      [false, false, false, false],
      [true, true, true, false,],
      [true, true, true, false,],
      [false, false, false, false],
      [true, true, true, false,],
    ]]);

    const packing_mask = tf.tensor([
      [0, 0, 0, -1e7, -1e7],
      [0, 0, 0, -1e7, -1e7],
      [0, 0, 0, -1e7, -1e7],
      [-1e7, -1e7, -1e7, 0, 0],
      [-1e7, -1e7, -1e7, 0, 0]
    ]);

    // for causal attention, the attention mask must be boolean
    expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, key4, value5, expected_mask.asType("float32"), null, null, 0.1, true, { scaling_factor: 10 })).toThrow();
    // for causal attention, a float32 attention mask is rejected even with a pre-calculated causal mask
    expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, key4, value5, expected_mask.asType("float32"), null, generateCausalAttentionMask(query5.shape[1]!, key4.shape[1]!), 0.2, true, { scaling_factor: 10 })).toThrow();
    // when not using causal attention, the attention mask can be a float32 tensor
    expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, key4, value5, expected_mask.asType("float32"), null, null, 0, false)).not.toThrow();
    // packing mask for self attention
    expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, query5, query5, null, packing_mask, null, 0.9, true)).not.toThrow();
  });


  it("should return a non-empty config dict", () => {
    const input = tf.randomUniform([2, 3, 10]);

    const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1)! });
    // compare the key count, not the key array itself (an array is never `toBe(0)`)
    expect(Object.keys(attention.getConfig()).length).toBeGreaterThan(0);
  });


  test("causal attention hard coded values", () => {
    const input = tf.tensor(HARD_CODED_INPUT);
    // NOTE(review): these expected values are identical to the non-causal test's. Under a
    // causal mask, position 0 can only attend to itself, so its row should generally differ
    // from the non-causal result. A signed-sum comparison could not detect this (an output
    // larger than expected made the sum negative). Regenerate the values if this test fails.
    const expected = tf.tensor(HARD_CODED_EXPECTED);

    const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1)!, causal: true });
    attention.build(input.shape);
    // deterministic weights: every tensor returned by getWeights() is filled with 0.05
    attention.setWeights(attention.getWeights().map(weight => tf.onesLike(weight).mul(0.05)));

    expect(maxAbsDiff(expected, attention.apply(input) as tf.Tensor)).toBeLessThan(HARD_CODED_TOLERANCE);
  });


  test("non-causal attention hard coded values", () => {
    const input = tf.tensor(HARD_CODED_INPUT);
    const expected = tf.tensor(HARD_CODED_EXPECTED);

    const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1)!, causal: false });
    attention.build(input.shape);
    // deterministic weights: every tensor returned by getWeights() is filled with 0.05
    attention.setWeights(attention.getWeights().map(weight => tf.onesLike(weight).mul(0.05)));

    expect(maxAbsDiff(expected, attention.apply(input) as tf.Tensor)).toBeLessThan(HARD_CODED_TOLERANCE);
  });
});