@stellarapp/tfjs-stellar 1.0.3 → 1.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134) hide show
  1. package/README.md +17 -0
  2. package/dist/index.d.ts +3 -1
  3. package/dist/index.d.ts.map +1 -1
  4. package/dist/index.js +3 -1
  5. package/dist/index.js.map +1 -1
  6. package/dist/kv_cache.d.ts +2 -0
  7. package/dist/kv_cache.d.ts.map +1 -1
  8. package/dist/kv_cache.js +6 -0
  9. package/dist/kv_cache.js.map +1 -1
  10. package/dist/models/index.d.ts +2 -1
  11. package/dist/models/index.d.ts.map +1 -1
  12. package/dist/models/index.js +2 -1
  13. package/dist/models/index.js.map +1 -1
  14. package/package.json +1 -1
  15. package/dist/jest.config.d.ts +0 -8
  16. package/dist/jest.config.d.ts.map +0 -1
  17. package/dist/jest.config.js +0 -147
  18. package/dist/jest.config.js.map +0 -1
  19. package/dist/src/index.d.ts +0 -6
  20. package/dist/src/index.d.ts.map +0 -1
  21. package/dist/src/index.js +0 -6
  22. package/dist/src/index.js.map +0 -1
  23. package/dist/src/kv_cache.d.ts +0 -53
  24. package/dist/src/kv_cache.d.ts.map +0 -1
  25. package/dist/src/kv_cache.js +0 -135
  26. package/dist/src/kv_cache.js.map +0 -1
  27. package/dist/src/layers/cached_rope_multihead_attention.d.ts +0 -31
  28. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +0 -1
  29. package/dist/src/layers/cached_rope_multihead_attention.js +0 -76
  30. package/dist/src/layers/cached_rope_multihead_attention.js.map +0 -1
  31. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +0 -2
  32. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +0 -1
  33. package/dist/src/layers/cached_rope_multihead_attention.test.js +0 -43
  34. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +0 -1
  35. package/dist/src/layers/gpt_decoder_block.d.ts +0 -34
  36. package/dist/src/layers/gpt_decoder_block.d.ts.map +0 -1
  37. package/dist/src/layers/gpt_decoder_block.js +0 -51
  38. package/dist/src/layers/gpt_decoder_block.js.map +0 -1
  39. package/dist/src/layers/index.d.ts +0 -17
  40. package/dist/src/layers/index.d.ts.map +0 -1
  41. package/dist/src/layers/index.js +0 -33
  42. package/dist/src/layers/index.js.map +0 -1
  43. package/dist/src/layers/multihead_attention.d.ts +0 -106
  44. package/dist/src/layers/multihead_attention.d.ts.map +0 -1
  45. package/dist/src/layers/multihead_attention.js +0 -269
  46. package/dist/src/layers/multihead_attention.js.map +0 -1
  47. package/dist/src/layers/multihead_attention.test.d.ts +0 -2
  48. package/dist/src/layers/multihead_attention.test.d.ts.map +0 -1
  49. package/dist/src/layers/multihead_attention.test.js +0 -160
  50. package/dist/src/layers/multihead_attention.test.js.map +0 -1
  51. package/dist/src/layers/positional_encoding.d.ts +0 -37
  52. package/dist/src/layers/positional_encoding.d.ts.map +0 -1
  53. package/dist/src/layers/positional_encoding.js +0 -115
  54. package/dist/src/layers/positional_encoding.js.map +0 -1
  55. package/dist/src/layers/positional_encoding.test.d.ts +0 -2
  56. package/dist/src/layers/positional_encoding.test.d.ts.map +0 -1
  57. package/dist/src/layers/positional_encoding.test.js +0 -95
  58. package/dist/src/layers/positional_encoding.test.js.map +0 -1
  59. package/dist/src/layers/rotary_position_embedding.d.ts +0 -39
  60. package/dist/src/layers/rotary_position_embedding.d.ts.map +0 -1
  61. package/dist/src/layers/rotary_position_embedding.js +0 -99
  62. package/dist/src/layers/rotary_position_embedding.js.map +0 -1
  63. package/dist/src/layers/rotary_position_embedding.test.d.ts +0 -2
  64. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +0 -1
  65. package/dist/src/layers/rotary_position_embedding.test.js +0 -88
  66. package/dist/src/layers/rotary_position_embedding.test.js.map +0 -1
  67. package/dist/src/layers/token_and_positional_embedding.d.ts +0 -47
  68. package/dist/src/layers/token_and_positional_embedding.d.ts.map +0 -1
  69. package/dist/src/layers/token_and_positional_embedding.js +0 -109
  70. package/dist/src/layers/token_and_positional_embedding.js.map +0 -1
  71. package/dist/src/layers/token_and_positional_embedding.test.d.ts +0 -2
  72. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +0 -1
  73. package/dist/src/layers/token_and_positional_embedding.test.js +0 -58
  74. package/dist/src/layers/token_and_positional_embedding.test.js.map +0 -1
  75. package/dist/src/layers/transformer_decoder.d.ts +0 -69
  76. package/dist/src/layers/transformer_decoder.d.ts.map +0 -1
  77. package/dist/src/layers/transformer_decoder.js +0 -182
  78. package/dist/src/layers/transformer_decoder.js.map +0 -1
  79. package/dist/src/layers/transformer_decoder.test.d.ts +0 -2
  80. package/dist/src/layers/transformer_decoder.test.d.ts.map +0 -1
  81. package/dist/src/layers/transformer_decoder.test.js +0 -72
  82. package/dist/src/layers/transformer_decoder.test.js.map +0 -1
  83. package/dist/src/layers/transformer_encoder.d.ts +0 -55
  84. package/dist/src/layers/transformer_encoder.d.ts.map +0 -1
  85. package/dist/src/layers/transformer_encoder.js +0 -175
  86. package/dist/src/layers/transformer_encoder.js.map +0 -1
  87. package/dist/src/layers/transformer_encoder.test.d.ts +0 -2
  88. package/dist/src/layers/transformer_encoder.test.d.ts.map +0 -1
  89. package/dist/src/layers/transformer_encoder.test.js +0 -58
  90. package/dist/src/layers/transformer_encoder.test.js.map +0 -1
  91. package/dist/src/losses/dice.d.ts +0 -30
  92. package/dist/src/losses/dice.d.ts.map +0 -1
  93. package/dist/src/losses/dice.js +0 -93
  94. package/dist/src/losses/dice.js.map +0 -1
  95. package/dist/src/losses/index.d.ts +0 -2
  96. package/dist/src/losses/index.d.ts.map +0 -1
  97. package/dist/src/losses/index.js +0 -2
  98. package/dist/src/losses/index.js.map +0 -1
  99. package/dist/src/masks.d.ts +0 -20
  100. package/dist/src/masks.d.ts.map +0 -1
  101. package/dist/src/masks.js +0 -37
  102. package/dist/src/masks.js.map +0 -1
  103. package/dist/src/metrics.d.ts +0 -20
  104. package/dist/src/metrics.d.ts.map +0 -1
  105. package/dist/src/metrics.js +0 -28
  106. package/dist/src/metrics.js.map +0 -1
  107. package/dist/src/models/gpt_model.d.ts +0 -94
  108. package/dist/src/models/gpt_model.d.ts.map +0 -1
  109. package/dist/src/models/gpt_model.js +0 -154
  110. package/dist/src/models/gpt_model.js.map +0 -1
  111. package/dist/src/models/index.d.ts +0 -3
  112. package/dist/src/models/index.d.ts.map +0 -1
  113. package/dist/src/models/index.js +0 -3
  114. package/dist/src/models/index.js.map +0 -1
  115. package/dist/src/models/llm_model.d.ts +0 -87
  116. package/dist/src/models/llm_model.d.ts.map +0 -1
  117. package/dist/src/models/llm_model.js +0 -245
  118. package/dist/src/models/llm_model.js.map +0 -1
  119. package/dist/src/models/u_net.d.ts +0 -40
  120. package/dist/src/models/u_net.d.ts.map +0 -1
  121. package/dist/src/models/u_net.js +0 -151
  122. package/dist/src/models/u_net.js.map +0 -1
  123. package/dist/src/tfjs_types.d.ts +0 -10
  124. package/dist/src/tfjs_types.d.ts.map +0 -1
  125. package/dist/src/tfjs_types.js +0 -2
  126. package/dist/src/tfjs_types.js.map +0 -1
  127. package/dist/src/utils.d.ts +0 -28
  128. package/dist/src/utils.d.ts.map +0 -1
  129. package/dist/src/utils.js +0 -63
  130. package/dist/src/utils.js.map +0 -1
  131. package/dist/src/utils.test.d.ts +0 -2
  132. package/dist/src/utils.test.d.ts.map +0 -1
  133. package/dist/src/utils.test.js +0 -73
  134. package/dist/src/utils.test.js.map +0 -1
@@ -1,269 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import { causal as generateCausalMask } from "@/masks";
3
- /**
4
- * This MultiHead Attention layer implements the algorithm as described in
5
- * the paper "Attention is all you Need" Vaswani et al., 2017.
6
- *
7
- * @param numHeads number of attention heads to use
8
- * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
9
- * @param causal use causal masking, default `false`
10
- * @param dropout use dropout during the attention calculations, default `0.0`
11
- * @param useBias use bias for the dense sublayers, default `true`
12
- *
13
- * The TensorFlow version uses tf.einsum, whose gradient op has not yet been
14
- * implemented (https://github.com/tensorflow/tfjs/pull/4955#discussion_r619219334),
15
- * therefore we follow the PyTorch implementation described in:
16
- * https://docs.pytorch.org/tutorials/intermediate/transformer_building_blocks.html#multiheadattention
17
- * https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
18
- *
19
- * This implementation is different from TensorFlow's whose attention weights
20
- * are shaped [embed dim, heads, embed dim] whereas PyTorch and OpenAI's attention weights
21
- * are shaped [embed dim, embed dim]
22
- * https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/nn/modules/activation.py#L1080
23
- * https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/model.py#L53
24
- *
25
- * TODO: implement a fast track for self attention (query = key = value)
26
- * where a single dense layer combines and replaces the query, key and value projection layers
27
- *
28
- * TODO: add kDim and vDim to accept key and values whose embedding dimensions differ from query's.
29
- */
30
- export class MultiHeadAttention extends tf.layers.Layer {
31
- static className = "MultiHeadAttention";
32
- numHeads;
33
- embedDim; // size of embedding dim of inputs, also per attention head
34
- useBias;
35
- dropout;
36
- causal; // use causal attention to mask future words
37
- // projection simply means matrix multiplying query, key, and value
38
- // with weights to create a representation of the inputs
39
- queryProjection;
40
- keyProjection;
41
- valueProjection;
42
- outputProjection;
43
- constructor({ numHeads, embedDim, useBias = true, dropout = 0.0, causal = false, ...args }) {
44
- super(args);
45
- if (embedDim % numHeads != 0) {
46
- throw Error(`${this.getClassName()}::constructor ${this.name} embedDim (${embedDim}) is not divisible by numHeads (${numHeads})`);
47
- }
48
- this.numHeads = numHeads;
49
- this.embedDim = embedDim;
50
- this.useBias = useBias;
51
- this.dropout = dropout;
52
- this.causal = causal;
53
- if (this.dropout >= 1) {
54
- throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
55
- }
56
- // initialize the projection weights, this should be in the
57
- // build() function but is done here to avoid linting complaints
58
- this.queryProjection = tf.layers.dense({ useBias, units: embedDim });
59
- this.keyProjection = tf.layers.dense({ useBias, units: embedDim });
60
- this.valueProjection = tf.layers.dense({ useBias, units: embedDim });
61
- this.outputProjection = tf.layers.dense({ useBias, units: embedDim });
62
- }
63
- /**
64
- * Forward propagation. Provide one input tensor or three identical tensors for self-attention.
65
- * @param inputs a single tensor for self-attention or an array of exactly three
66
- * tensors that are either identical (self-attention) or different (cross-attention)
67
- * @param kwargs.packingMask a mask to prevent tokens from attending across document boundaries
68
- */
69
- call(inputs, kwargs) {
70
- // validate the input tensors
71
- if (!Array.isArray(inputs)) {
72
- inputs = [inputs];
73
- }
74
- // accept only 1 input (self attention) or 3 inputs (self or cross attention)
75
- if (inputs.length != 1 && inputs.length != 3) {
76
- throw Error(`${this.getClassName()}::call ${this.name} expects exactly one or three input tensors, ${inputs.length} were provided`);
77
- }
78
- for (const input of inputs) {
79
- if (input.shape.length != 3) {
80
- throw Error(`${this.getClassName()}::call ${this.name} expected input shapes of [batch, seq, embed_dim], got ${JSON.stringify(input.shape)}`);
81
- }
82
- }
83
- const [query, key, value] = inputs;
84
- const packingMask = kwargs.packingMask ?? null;
85
- const causalMask = kwargs.causalMask ?? null;
86
- return inputs.length == 3
87
- // cross-attention
88
- ? this.forward(query, key, value, packingMask, causalMask, kwargs)
89
- // self-attention
90
- : this.forward(query, query, query, packingMask, causalMask, kwargs);
91
- }
92
- /**
93
- * Forward propagation
94
- */
95
- forward(query_input, key_input, value_input, packing_mask, causal_mask, kwargs) {
96
- // dimensions abbreviations
97
- // batch = the number of sequences in the input
98
- // seq = the length of each sequence in the input
99
- // dims = the size of each token's embedding
100
- return tf.tidy(() => {
101
- const { query, key, value } = this.applyInputProjections(query_input, key_input, value_input);
102
- // swap the seq and heads dimensions: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
103
- const move_head_dim_forward = [0, 2, 1, 3];
104
- const { query_split, key_split, value_split } = this.splitHeads(query, key, value, move_head_dim_forward);
105
- // apply scaled dot-product attention to get [batch, heads, seq, head_dim]
106
- const spda = MultiHeadAttention.scaledDotProductionAttention(query_split, key_split, value_split, kwargs.attentionMask ?? null, packing_mask, causal_mask, this.dropout, this.causal, kwargs);
107
- // concat heads and apply the output projection
108
- const output = this.outputProjection.apply(spda.transpose(move_head_dim_forward).reshape([query_input.shape[0], -1, this.embedDim]));
109
- return output;
110
- });
111
- }
112
- applyInputProjections(query_input, key_input, value_input) {
113
- // apply input projections, this is a batched matrix multiplication operated on the last
114
- // dimension of query_input and first dimension of the dense layer weights,
115
- // [batch, seq, dims] x [dims, dims] = [batch x seq, dims] x [dims, dims] = [batch x seq, dims] = [batch, seq, dims]
116
- return tf.tidy(() => {
117
- return {
118
- query: this.queryProjection.apply(query_input),
119
- key: this.keyProjection.apply(key_input),
120
- value: this.valueProjection.apply(value_input)
121
- };
122
- });
123
- }
124
- splitHeads(query, key, value, shuffle) {
125
- // split heads and prepare for scaled dot product attention by splitting the
126
- // last dimension to get the heads, bring the heads forward
127
- // [batch, seq, dims] -> [batch, seq, heads, dims / heads] -> [batch, heads, seq, head_dim]
128
- const batch_size = query.shape[0];
129
- const split_heads = [batch_size, -1, this.numHeads, this.embedDim / this.numHeads];
130
- return tf.tidy(() => {
131
- return {
132
- query_split: query.reshape(split_heads).transpose(shuffle),
133
- key_split: key.reshape(split_heads).transpose(shuffle),
134
- value_split: value.reshape(split_heads).transpose(shuffle)
135
- };
136
- });
137
- }
138
- /**
139
- * Applies the scaled dot-product formula: softmax(QK_t / sqrt(d_k))V,
140
- * formula (1) of the 2017 paper Attention Is All You Need
141
- *
142
- * @param attentionMask a mask to prevent tokens from being
143
- * attended to (usually for padding tokens). It should have the shape
144
- * [batch, head, query_sequence_len, key_sequence_len]. To use in
145
- * conjunction with causal masking, the tensor should be a boolean type
146
- * where false indicates a masked token.
147
- * @param packingMask a mask to prevent tokens from attending across document boundaries
148
- */
149
- static scaledDotProductionAttention(query, key, value, attentionMask, packingMask, causalMask, dropout, causal, kwargs = {}) {
150
- return tf.tidy(() => {
151
- const { training = false, scaling_factor } = kwargs;
152
- key.shape.forEach((val, index) => {
153
- if (key.shape[index] != value.shape[index]) {
154
- throw Error(`scaledDotProductionAttention: expected key and value` +
155
- ` to have the same shape, got ${JSON.stringify(key.shape)} (key) and` +
156
- ` ${JSON.stringify(value.shape)} (value)`);
157
- }
158
- });
159
- // mask's shape is [..., seq, seq] where seq is the number of words/tokens in the input,
160
- // not adding the batch dimension yet to lessen the calculations
161
- const causal_mask_shape = [
162
- query.shape[query.shape.length - 2],
163
- key.shape[key.shape.length - 2]
164
- ];
165
- let mask = tf.zeros(causal_mask_shape);
166
- if (causal && causal_mask_shape[0] > 1) {
167
- if (attentionMask && attentionMask.dtype != "bool") {
168
- throw Error(`scaledDotProductionAttention: the attention mask must be undefined or a boolean type if used with causal attention`);
169
- }
170
- // apply a causal attention mask so that tokens can only attend to preceding tokens,
171
- // prevents looking ahead
172
- if (causalMask) {
173
- mask = causalMask;
174
- }
175
- else {
176
- mask = generateCausalMask(causal_mask_shape[0], causal_mask_shape[1]);
177
- }
178
- }
179
- if (attentionMask) {
180
- if (attentionMask.dtype == "bool") {
181
- // convert the boolean mask to float
182
- // warning: do not use 1e9, it will overflow, use something smaller like 1e7
183
- mask = mask.add(attentionMask.cast("float32").sub(1).mul(1e7));
184
- }
185
- else {
186
- // this will occur only when not using causal masking,
187
- // if the attention mask is not boolean, it's assumed the masking is already calculated,
188
- mask = attentionMask;
189
- }
190
- }
191
- // 1. matrix multiply query and transposed key
192
- // 2. divide by scaling factor
193
- // 3. add the attention and/or causal mask
194
- // 4. apply softmax to the result
195
- // 5. apply dropout
196
- // 6. matrix multiply softmax result with value
197
- let pre_softmax = query
198
- .matMul(key, false, true)
199
- .div(Math.sqrt(scaling_factor ?? key.shape[key.shape.length - 1]))
200
- .add(mask);
201
- if (packingMask) {
202
- // packing mask is added separately because each mask within a batch may be different,
203
- // so it cannot be broadcasted
204
- pre_softmax = pre_softmax.add(packingMask);
205
- }
206
- const spda = tf.softmax(pre_softmax);
207
- const spda_dropout = tf.dropout(spda, training ? dropout : 0);
208
- const attention = spda_dropout.matMul(value);
209
- return attention;
210
- });
211
- }
212
- build(inputShape) {
213
- let input_shape = [];
214
- if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
215
- input_shape = inputShape;
216
- }
217
- else {
218
- input_shape = [inputShape, inputShape, inputShape];
219
- }
220
- if (input_shape.length != 1 && input_shape.length != 3) {
221
- throw Error(`${this.getClassName()}::build ${this.name} accepts either exactly one or three inputs, received ${JSON.stringify(inputShape)}`);
222
- }
223
- // initialize the sublayer weights
224
- this.queryProjection.build(input_shape[0]);
225
- this.keyProjection.build(input_shape[1]);
226
- this.valueProjection.build(input_shape[2]);
227
- this.outputProjection.build(input_shape[0]);
228
- // the sublayer weights need to be tracked by this layer otherwise
229
- // backpropagation will complain about no trainable parameters found,
230
- // this is an extra step that TF's Python version does not need
231
- this.trainableWeights = [
232
- ...this.queryProjection.trainableWeights,
233
- ...this.keyProjection.trainableWeights,
234
- ...this.valueProjection.trainableWeights,
235
- ...this.outputProjection.trainableWeights
236
- ];
237
- // rename the weights otherwise they'll take on the default naming and overlap
238
- // each other which breaks model loading due to duplicate weight names
239
- let indexing = 0;
240
- for (const weight of this.trainableWeights) {
241
- const unique_name = `${this.getClassName()}_${indexing}`;
242
- weight.name += unique_name;
243
- weight.originalName += unique_name;
244
- indexing++;
245
- }
246
- super.build(inputShape);
247
- }
248
- /**
249
- * MultiHead attention's output is the same shape as the query's.
250
- */
251
- computeOutputShape(inputShape) {
252
- return Array.isArray(inputShape) && Array.isArray(inputShape[0]) ? inputShape[0] : inputShape;
253
- }
254
- getConfig() {
255
- const base_config = super.getConfig();
256
- const config = {
257
- numHeads: this.numHeads,
258
- embedDim: this.embedDim,
259
- useBias: this.useBias,
260
- causal: this.causal,
261
- dropout: this.dropout,
262
- name: this.name,
263
- };
264
- Object.assign(config, base_config);
265
- return config;
266
- }
267
- }
268
- tf.serialization.registerClass(MultiHeadAttention);
269
- //# sourceMappingURL=multihead_attention.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"multihead_attention.js","sourceRoot":"","sources":["../../../src/layers/multihead_attention.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC,OAAO,EAAE,MAAM,IAAI,kBAAkB,EAAE,MAAM,SAAS,CAAC;AAoBvD;;;;;;;;;;;;;;;;;;;;;;;;;;GA0BG;AACH,MAAM,OAAO,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,GAAG,oBAAoB,CAAC;IACrB,QAAQ,CAAS;IACjB,QAAQ,CAAS,CAAC,2DAA2D;IAC7E,OAAO,CAAU;IACjB,OAAO,CAAS;IAChB,MAAM,CAAU,CAAC,4CAA4C;IAEhF,mEAAmE;IACnE,wDAAwD;IACrC,eAAe,CAAkB;IACjC,aAAa,CAAkB;IAC/B,eAAe,CAAkB;IACjC,gBAAgB,CAAkB;IAGrD,YAAY,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAO,GAAG,IAAI,EAAE,OAAO,GAAG,GAAG,EAAE,MAAM,GAAG,KAAK,EAAE,GAAG,IAAI,EAA0B;QAC9G,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,QAAQ,GAAG,QAAQ,IAAI,CAAC,EAAE,CAAC;YAC3B,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,cAAc,QAAQ,mCAAmC,QAAQ,GAAG,CAAC,CAAC;QACtI,CAAC;QAED,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,OAAO,GAAG,OAAO,CAAC;QACvB,IAAI,CAAC,OAAO,GAAG,OAAO,CAAC;QACvB,IAAI,CAAC,MAAM,GAAG,MAAM,CAAC;QAErB,IAAI,IAAI,CAAC,OAAO,IAAI,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,6CAA6C,CAAC,CAAC;QACrF,CAAC;QAED,0DAA0D;QAC1D,gEAAgE;QAChE,IAAI,CAAC,eAAe,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,QAAQ,EAAE,CAAC,CAAC;QACrE,IAAI,CAAC,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,QAAQ,EAAE,CAAC,CAAC;QACnE,IAAI,CAAC,eAAe,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,QAAQ,EAAE,CAAC,CAAC;QACrE,IAAI,CAAC,gBAAgB,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,QAAQ,EAAE,CAAC,CAAC;IAC1E,CAAC;IAGD;;;;;OAKG;IACM,IAAI,CACT,MAA+B,EAC/B,MAGC;QAED,6BAA6B;QAC7B,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC;YACzB,MAAM,GAAG,CAAC,MAAM,CAAC,CAAC;QACtB,CAAC;QAED,6EAA6E;QAC7E,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAC3C,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,gDAAgD,MAAM,CAAC,MAAM,gBAAgB,CAAC,CAAC;QACxI,CAAC;QAED,KAAK,MAAM,KAAK,IAAI,MAAM
,EAAE,CAAC;YACzB,IAAI,KAAK,CAAC,KAAK,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;gBAC1B,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,0DAA0D,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;YAClJ,CAAC;QACL,CAAC;QAED,MAAM,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,CAAC,GAAG,MAAM,CAAC;QACnC,MAAM,WAAW,GAAG,MAAM,CAAC,WAAW,IAAI,IAAI,CAAC;QAC/C,MAAM,UAAU,GAAG,MAAM,CAAC,UAAU,IAAI,IAAI,CAAC;QAE7C,OAAO,MAAM,CAAC,MAAM,IAAI,CAAC;YACrB,kBAAkB;YAClB,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,KAAM,EAAE,GAAI,EAAE,KAAM,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,CAAC;YACrE,iBAAiB;YACjB,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,KAAM,EAAE,KAAM,EAAE,KAAM,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,CAAC,CAAC;IAChF,CAAC;IAGD;;OAEG;IACO,OAAO,CACb,WAAsB,EACtB,SAAoB,EACpB,WAAsB,EACtB,YAA8B,EAC9B,WAA6B,EAC7B,MAAc;QAEd,2BAA2B;QAC3B,+CAA+C;QAC/C,iDAAiD;QACjD,4CAA4C;QAC5C,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,EAAE,KAAK,EAAE,GAAG,EAAE,KAAK,EAAE,GAAG,IAAI,CAAC,qBAAqB,CAAC,WAAW,EAAE,SAAS,EAAE,WAAW,CAAC,CAAC;YAE9F,oGAAoG;YACpG,MAAM,qBAAqB,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;YAE3C,MAAM,EACF,WAAW,EAAE,SAAS,EAAE,WAAW,EACtC,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,EAAE,qBAAqB,CAAC,CAAC;YAE9D,gFAAgF;YAChF,MAAM,IAAI,GAAG,kBAAkB,CAAC,4BAA4B,CACxD,WAAW,EAAE,SAAS,EAAE,WAAW,EACnC,MAAM,CAAC,aAAa,IAAI,IAAI,EAAE,YAAY,EAAE,WAAW,EACvD,IAAI,CAAC,OAAO,EAAE,IAAI,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAEvC,+CAA+C;YAC/C,MAAM,MAAM,GAAG,IAAI,CAAC,gBAAgB,CAAC,KAAK,CACtC,IAAI,CAAC,SAAS,CAAC,qBAAqB,CAAC,CAAC,OAAO,CAAC,CAAC,WAAW,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC;YAE9F,OAAO,MAAmB,CAAC;QAC/B,CAAC,CAAC,CAAA;IACN,CAAC;IAGS,qBAAqB,CAAC,WAAsB,EAAE,SAAoB,EAAE,WAAsB;QAChG,wFAAwF;QACxF,2EAA2E;QAC3E,oHAAoH;QACpH,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,OAAO;gBACH,KAAK,EAAE,IAAI,CAAC,eAAe,CAAC,KAAK,CAAC,WAAW,CAAc;gBAC3D,GAAG,EAAE,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,SAAS,CAAc;gBACrD,KAAK,EAAE,IAAI,CAAC,eAAe,CAAC,KAAK,CAAC,WAAW,CAAc;aAC9D,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGS,UAAU,CAAC,KAAgB,EAAE,GAAc,EAAE,KAAgB,EAAE,OAAiB;QA
CtF,4EAA4E;QAC5E,2DAA2D;QAC3D,2FAA2F;QAC3F,MAAM,UAAU,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAClC,MAAM,WAAW,GAAG,CAAC,UAAU,EAAE,CAAC,CAAC,EAAE,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC;QAEnF,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,OAAO;gBACH,WAAW,EAAE,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,SAAS,CAAC,OAAO,CAAgB;gBACzE,SAAS,EAAE,GAAG,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,SAAS,CAAC,OAAO,CAAgB;gBACrE,WAAW,EAAE,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,SAAS,CAAC,OAAO,CAAgB;aAC5E,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGD;;;;;;;;;;OAUG;IACH,MAAM,CAAC,4BAA4B,CAC/B,KAAgB,EAChB,GAAc,EACd,KAAgB,EAChB,aAA+B,EAC/B,WAA6B,EAC7B,UAA4B,EAC5B,OAAe,EACf,MAAe,EACf,SAA6C,EAAE;QAE/C,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,EAAE,QAAQ,GAAG,KAAK,EAAE,cAAc,EAAE,GAAG,MAAM,CAAC;YAEpD,GAAG,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,KAAK,EAAE,EAAE;gBAC7B,IAAI,GAAG,CAAC,KAAK,CAAC,KAAK,CAAC,IAAI,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC;oBACzC,MAAM,KAAK,CAAC,sDAAsD;wBAC9D,gCAAgC,IAAI,CAAC,SAAS,CAAC,GAAG,CAAC,KAAK,CAAC,YAAY;wBACrE,IAAI,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;gBACnD,CAAC;YACL,CAAC,CAAC,CAAA;YAGF,wFAAwF;YACxF,gEAAgE;YAChE,MAAM,iBAAiB,GAAG;gBACtB,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;gBACnC,GAAG,CAAC,KAAK,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;aAAC,CAAC;YAErC,IAAI,IAAI,GAAG,EAAE,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC;YAEvC,IAAI,MAAM,IAAI,iBAAiB,CAAC,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC;gBACrC,IAAI,aAAa,IAAI,aAAa,CAAC,KAAK,IAAI,MAAM,EAAE,CAAC;oBACjD,MAAM,KAAK,CAAC,oHAAoH,CAAC,CAAC;gBACtI,CAAC;gBAED,oFAAoF;gBACpF,2BAA2B;gBAC3B,IAAI,UAAU,EAAE,CAAC;oBACb,IAAI,GAAG,UAAU,CAAC;gBACtB,CAAC;qBAAM,CAAC;oBACJ,IAAI,GAAG,kBAAkB,CAAC,iBAAiB,CAAC,CAAC,CAAC,EAAE,iBAAiB,CAAC,CAAC,CAAC,CAAC,CAAC;gBAC1E,CAAC;YACL,CAAC;YAED,IAAI,aAAa,EAAE,CAAC;gBAChB,IAAI,aAAa,CAAC,KAAK,IAAI,MAAM,EAAE,CAAC;oBAChC,oCAAoC;oBACpC,4EAA4E;oBAC5E,IAAI,GAAG,IAAI,CAAC,GAAG,CAAC,aAAa,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;gBACnE,CAAC;qBAAM,CAAC;oBACJ,s
DAAsD;oBACtD,wFAAwF;oBACxF,IAAI,GAAG,aAAa,CAAC;gBACzB,CAAC;YACL,CAAC;YAED,8CAA8C;YAC9C,8BAA8B;YAC9B,iCAAiC;YACjC,wCAAwC;YACxC,mBAAmB;YACnB,+CAA+C;YAC/C,IAAI,WAAW,GAAG,KAAK;iBAClB,MAAM,CAAC,GAAG,EAAE,KAAK,EAAE,IAAI,CAAC;iBACxB,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,cAAc,IAAI,GAAG,CAAC,KAAK,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC;iBACjE,GAAG,CAAC,IAAI,CAAC,CAAC;YAEf,IAAI,WAAW,EAAE,CAAC;gBACd,sFAAsF;gBACtF,8BAA8B;gBAC9B,WAAW,GAAG,WAAW,CAAC,GAAG,CAAC,WAAW,CAAC,CAAC;YAC/C,CAAC;YAED,MAAM,IAAI,GAAG,EAAE,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC;YAErC,MAAM,YAAY,GAAG,EAAE,CAAC,OAAO,CAAC,IAAI,EAAE,QAAQ,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAC9D,MAAM,SAAS,GAAG,YAAY,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;YAE7C,OAAO,SAAS,CAAC;QACrB,CAAC,CAAC,CAAC;IACP,CAAC;IAGQ,KAAK,CAAC,UAAiC;QAC5C,IAAI,WAAW,GAAe,EAAE,CAAC;QAEjC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAC5D,WAAW,GAAG,UAAwB,CAAC;QAC3C,CAAC;aAAM,CAAC;YACJ,WAAW,GAAG,CAAC,UAAsB,EAAE,UAAsB,EAAE,UAAsB,CAAC,CAAC;QAC3F,CAAC;QAED,IAAI,WAAW,CAAC,MAAM,IAAI,CAAC,IAAI,WAAW,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YACrD,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,yDAAyD,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAC;QACjJ,CAAC;QAED,kCAAkC;QAClC,IAAI,CAAC,eAAe,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3C,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QACzC,IAAI,CAAC,eAAe,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3C,IAAI,CAAC,gBAAgB,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAE5C,kEAAkE;QAClE,qEAAqE;QACrE,+DAA+D;QAC/D,IAAI,CAAC,gBAAgB,GAAG;YACpB,GAAG,IAAI,CAAC,eAAe,CAAC,gBAAgB;YACxC,GAAG,IAAI,CAAC,aAAa,CAAC,gBAAgB;YACtC,GAAG,IAAI,CAAC,eAAe,CAAC,gBAAgB;YACxC,GAAG,IAAI,CAAC,gBAAgB,CAAC,gBAAgB;SAC5C,CAAC;QAEF,8EAA8E;QAC9E,sEAAsE;QACtE,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEjB,KAAK,MAAM,MAAM,IAAI,IAAI,CAAC,gBAAgB,EAAE,CAAC;YACzC,MAAM,WAAW,GAAG,GAAG,IAAI,CAAC,YAAY,EAAE,IAAI,QAAQ,EAAE,CAAC;YACxD,MAAc,CAAC,IAAI,IAAI,WAAW,CAAC;YACnC,MAAc,CAAC,YAAY,IAAI,WAAW,CAAC;YAC5C,QAAQ,EAAE,C
AAC;QACf,CAAC;QAED,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACM,kBAAkB,CAAC,UAAiC;QACzD,OAAO,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC;IAClG,CAAC;IAGQ,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,MAAM,EAAE,IAAI,CAAC,MAAM;YACnB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,IAAI,EAAE,IAAI,CAAC,IAAI;SAClB,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC"}
@@ -1,2 +0,0 @@
1
- export {};
2
- //# sourceMappingURL=multihead_attention.test.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"multihead_attention.test.d.ts","sourceRoot":"","sources":["../../../src/layers/multihead_attention.test.ts"],"names":[],"mappings":""}
@@ -1,160 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import { CachedRoPEMultiHeadAttention } from '@/layers/cached_rope_multihead_attention';
3
- import { causal as generateCausalMask } from "@/masks";
4
- import { MultiHeadAttention } from '@/layers/multihead_attention';
5
- // disables warning for using the faster node backend,
6
- // https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
7
- tf.env().set('IS_NODE', false);
8
- describe("MultiHeadAttention tests", () => {
9
- it("should fail to instantiate a layer if heads count is not divisible by the input's embedding dimension", () => { // constructor-time check of numHeads against embedDim
10
- expect(() => new CachedRoPEMultiHeadAttention({ numHeads: 3, embedDim: 10 })).toThrow(); // 10 is not a multiple of 3
11
- expect(() => new CachedRoPEMultiHeadAttention({ numHeads: 15, embedDim: 60 })).not.toThrow(); // 60 = 15 * 4
12
- });
13
- test("successfull forward calls", () => { // apply() must accept a bare tensor and a one-element array, with and without `causal`
14
- const input = tf.randomUniform([2, 3, 12]);
15
- const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1) }); // embedDim is taken from the input's last dimension (12)
16
- expect(() => attention.apply(input)).not.toThrow();
17
- expect(() => attention.apply([input])).not.toThrow();
18
- const causal = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1), causal: true }); // same checks with the causal option enabled
19
- expect(() => causal.apply(input)).not.toThrow();
20
- expect(() => causal.apply([input])).not.toThrow();
21
- });
22
- test("query and value must have the same shape for scaled dot product attention to succeed", () => { // three-input form: [query, key, value]
23
- const query = tf.randomUniform([2, 3, 12]);
24
- const key = tf.randomUniform([2, 3, 12]);
25
- const value = tf.randomUniform([2, 3, 12]);
26
- const value_thats_too_long = tf.randomUniform([2, 100, 12]); // second dimension is 100 while query and key use 3
27
- const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: query.shape.at(-1) });
28
- expect(() => attention.apply([query, key, value])).not.toThrow(); // all three inputs share the shape [2, 3, 12]
29
- expect(() => attention.apply([query, key, value_thats_too_long])).toThrow(); // only the value's second dimension differs
30
- });
31
- it("should only accept rank 3 tensors", () => { // covers the one-input and three-input forms of apply()
32
- const embed_dims = 12;
33
- const BAD_RANK2 = tf.randomUniform([2, embed_dims]);
34
- const GOOD = tf.randomUniform([2, 3, embed_dims]);
35
- const BAD_RANK4 = tf.randomUniform([2, 3, 10, embed_dims]);
36
- const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: embed_dims });
37
- // BAD: every input is rank 2
38
- expect(() => attention.apply(BAD_RANK2)).toThrow();
39
- expect(() => attention.apply([BAD_RANK2])).toThrow();
40
- expect(() => attention.apply([BAD_RANK2, BAD_RANK2, BAD_RANK2])).toThrow();
41
- // OK: every input is rank 3
42
- expect(() => attention.apply(GOOD)).not.toThrow();
43
- expect(() => attention.apply([GOOD])).not.toThrow();
44
- expect(() => attention.apply([GOOD, GOOD, GOOD])).not.toThrow();
45
- // BAD: every input is rank 4
46
- expect(() => attention.apply(BAD_RANK4)).toThrow();
47
- expect(() => attention.apply([BAD_RANK4])).toThrow();
48
- expect(() => attention.apply([BAD_RANK4, BAD_RANK4, BAD_RANK4])).toThrow();
49
- // BAD: at least one of the three inputs is not rank 3
50
- expect(() => attention.apply([GOOD, BAD_RANK2, BAD_RANK4])).toThrow();
51
- expect(() => attention.apply([BAD_RANK2, GOOD, BAD_RANK4])).toThrow();
52
- expect(() => attention.apply([BAD_RANK2, BAD_RANK4, GOOD])).toThrow();
53
- expect(() => attention.apply([BAD_RANK2, GOOD, GOOD])).toThrow();
54
- expect(() => attention.apply([GOOD, GOOD, BAD_RANK4])).toThrow();
55
- });
56
- it("should only 1 or 3 inputs total", () => { // apply() accepts one input or a [query, key, value] triple, nothing else
57
- const input = tf.randomUniform([2, 3, 12]);
58
- let attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1) });
59
- // OK: a single input, bare (with an explicit undefined packingMask) or wrapped in an array
60
- expect(() => attention.apply(input, { packingMask: undefined })).not.toThrow();
61
- expect(() => attention.apply([input])).not.toThrow();
62
- // use a fresh layer so that build() runs again for the three-input form
63
- attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1) });
64
- expect(() => attention.apply([input, input, input])).not.toThrow();
65
- // BAD: zero inputs and two inputs
66
- expect(() => attention.apply([])).toThrow();
67
- expect(() => attention.apply([input, input])).toThrow();
68
- // use a fresh layer so that build() runs again for the four-input form
69
- attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1) });
70
- expect(() => attention.apply([input, input, input, input])).toThrow();
71
- });
72
- test("attention masking", () => { // exercises call() directly, then the static MultiHeadAttention.scaledDotProductionAttention with several mask combinations
73
- const query = tf.randomUniform([2, 3, 12]);
74
- const key = tf.randomUniform([2, 3, 12]);
75
- const value = tf.randomUniform([2, 3, 12]);
76
- const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: query.shape.at(-1), causal: true });
77
- expect(() => attention.call(query, {})).not.toThrow();
78
- // cross attention: separate query, key and value tensors
79
- expect(() => attention.call([query, key, value], {})).not.toThrow();
80
- const query5 = tf.randomUniform([2, 5, 10]);
81
- const key4 = tf.randomUniform([2, 4, 10]);
82
- const value5 = tf.randomUniform([2, 4, 10]); // despite the name, this tensor has 4 positions, matching key4
83
- const expected_mask = tf.tensor([[
84
- // rows are the 5 query positions; a false entry means that query may not attend to that key
85
- // columns are the 4 key positions; a false entry means that key may not be attended to by that query
86
- [false, false, false, false],
87
- [true, true, true, false,],
88
- [true, true, true, false,],
89
- [false, false, false, false],
90
- [true, true, true, false,],
91
- ]]);
92
- const packing_mask = tf.tensor([
93
- [0, 0, 0, -1e7, -1e7],
94
- [0, 0, 0, -1e7, -1e7],
95
- [0, 0, 0, -1e7, -1e7],
96
- [-1e7, -1e7, -1e7, 0, 0],
97
- [-1e7, -1e7, -1e7, 0, 0]
98
- ]);
99
- // for causal attention, the attention mask must be boolean, so a float32 mask is expected to throw
100
- expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, key4, value5, expected_mask.asType("float32"), null, null, 0.1, true, { scaling_factor: 10 })).toThrow();
101
- // for causal attention, using a pre-calculated causal mask; the float32 attention mask is still expected to throw
102
- expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, key4, value5, expected_mask.asType("float32"), null, generateCausalMask(query5.shape[1], key4.shape[1]), 0.2, true, { scaling_factor: 10 })).toThrow();
103
- // when not using causal attention, the attention mask can be a float32 tensor
104
- expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, key4, value5, expected_mask.asType("float32"), null, null, 0, false)).not.toThrow();
105
- // packing mask for self attention: query5 is used as query, key and value, with the 5x5 packing mask
106
- expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, query5, query5, null, packing_mask, null, 0.9, true)).not.toThrow();
107
- });
108
- it("should return a non-empty config dict", () => {
109
- const input = tf.randomUniform([2, 3, 10]);
110
- const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1) });
111
- expect(Object.keys(attention.getConfig())).not.toBe(0);
112
- });
113
- test("causal attention hard coded values", () => {
114
- // input and output shapes: [2, 3, 10]
115
- const input = tf.tensor([
116
- [[0.2109915, 0.6158954, 0.6012088, 0.9867562, 0.8728716, 0.7496274, 0.8173883, 0.2958342, 0.9650571, 0.2075207],
117
- [0.2946285, 0.9779906, 0.3203818, 0.4037617, 0.3762881, 0.9863171, 0.6655593, 0.7707329, 0.3216831, 0.7984023],
118
- [0.9080769, 0.0026282, 0.379492, 0.0162054, 0.1939302, 0.2201049, 0.8190675, 0.0203963, 0.0114392, 0.5015539]],
119
- [[0.6241482, 0.7631097, 0.6687831, 0.7259795, 0.0457698, 0.6889264, 0.0853676, 0.8697655, 0.3637198, 0.2105307],
120
- [0.5221761, 0.4476321, 0.1244729, 0.8863543, 0.7319002, 0.2954829, 0.3200496, 0.0905503, 0.607977, 0.1309131],
121
- [0.4693873, 0.4609751, 0.9170766, 0.7065565, 0.4795104, 0.3225758, 0.1353116, 0.7083887, 0.1928891, 0.967386]]
122
- ]);
123
- const expected = tf.tensor([
124
- [[0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344],
125
- [0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376],
126
- [0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539]],
127
- [[0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718],
128
- [0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268],
129
- [0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877]]
130
- ]);
131
- const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1), causal: true });
132
- attention.build(input.shape);
133
- attention.setWeights(attention.getWeights().map(weight => tf.onesLike(weight).mul(0.05)));
134
- expect(expected.sub(attention.apply(input)).sum().dataSync()[0]).toBeLessThan(1e-6);
135
- });
136
- test("non-causal attention hard coded values", () => {
137
- // input and output shapes: [2, 3, 10]
138
- const input = tf.tensor([
139
- [[0.2109915, 0.6158954, 0.6012088, 0.9867562, 0.8728716, 0.7496274, 0.8173883, 0.2958342, 0.9650571, 0.2075207],
140
- [0.2946285, 0.9779906, 0.3203818, 0.4037617, 0.3762881, 0.9863171, 0.6655593, 0.7707329, 0.3216831, 0.7984023],
141
- [0.9080769, 0.0026282, 0.379492, 0.0162054, 0.1939302, 0.2201049, 0.8190675, 0.0203963, 0.0114392, 0.5015539]],
142
- [[0.6241482, 0.7631097, 0.6687831, 0.7259795, 0.0457698, 0.6889264, 0.0853676, 0.8697655, 0.3637198, 0.2105307],
143
- [0.5221761, 0.4476321, 0.1244729, 0.8863543, 0.7319002, 0.2954829, 0.3200496, 0.0905503, 0.607977, 0.1309131],
144
- [0.4693873, 0.4609751, 0.9170766, 0.7065565, 0.4795104, 0.3225758, 0.1353116, 0.7083887, 0.1928891, 0.967386]]
145
- ]);
146
- const expected = tf.tensor([
147
- [[0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344],
148
- [0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376],
149
- [0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539]],
150
- [[0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718],
151
- [0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268],
152
- [0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877]]
153
- ]);
154
- const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1), causal: false });
155
- attention.build(input.shape);
156
- attention.setWeights(attention.getWeights().map(weight => tf.onesLike(weight).mul(0.05)));
157
- expect(expected.sub(attention.apply(input)).sum().dataSync()[0]).toBeLessThan(1e-6);
158
- });
159
- });
160
- //# sourceMappingURL=multihead_attention.test.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"multihead_attention.test.js","sourceRoot":"","sources":["../../../src/layers/multihead_attention.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,4BAA4B,EAAE,MAAM,0CAA0C,CAAC;AACxF,OAAO,EAAE,MAAM,IAAI,kBAAkB,EAAE,MAAM,SAAS,CAAC;AACvD,OAAO,EAAE,kBAAkB,EAAE,MAAM,8BAA8B,CAAC;AAElE,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,EAAE,CAAC,uGAAuG,EAAE,GAAG,EAAE;QAC7G,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACxF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACjG,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,2BAA2B,EAAE,GAAG,EAAE;QACnC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACnG,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAErD,MAAM,MAAM,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QAC9G,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACtD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,sFAAsF,EAAE,GAAG,EAAE;QAC9F,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3C,MAAM,GAAG,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QACzC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3C,MAAM,oBAAoB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,GAAG,EAAE,EAAE,CAAC,CAAC,CAAC;QAE5D
,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACnG,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,GAAG,EAAE,oBAAoB,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAChF,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,mCAAmC,EAAE,GAAG,EAAE;QACzC,MAAM,UAAU,GAAG,EAAE,CAAC;QAEtB,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC;QACpD,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC;QAClD,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,UAAU,CAAC,CAAC,CAAC;QAE3D,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,CAAC,CAAC;QAE1F,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACrD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAE3E,KAAK;QACL,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAClD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACpD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEhE,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACrD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAE3E,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACtE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EA
AE,CAAC;QACtE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACtE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACjE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,IAAI,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACrE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,iCAAiC,EAAE,GAAG,EAAE;QACvC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,IAAI,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QAEjG,KAAK;QACL,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,EAAE,EAAE,WAAW,EAAE,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAC/E,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACrD,gCAAgC;QAChC,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QAC7F,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnE,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC5C,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACxD,gCAAgC;QAChC,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QAC7F,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAC1E,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,mBAAmB,EAAE,GAAG,EAAE;QAC3B,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3C,MAAM,GAAG,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QACzC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,
KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QAEjH,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,IAAI,CAAC,KAAK,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEtD,kBAAkB;QAClB,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,IAAI,CAAC,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAGpE,MAAM,MAAM,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC5C,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC1C,MAAM,MAAM,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE5C,MAAM,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC;gBAC7B,8EAA8E;gBAC9E,iFAAiF;gBACjF,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC;gBAC5B,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,EAAE,KAAK,EAAE;gBAC1B,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,EAAE,KAAK,EAAE;gBAC1B,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC;gBAC5B,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,EAAE,KAAK,EAAE;aAC7B,CAAC,CAAC,CAAC;QAEJ,MAAM,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC;YAC3B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC;YACrB,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC;YACrB,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC;YACrB,CAAC,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,CAAC;YACxB,CAAC,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,CAAC;SAC3B,CAAC,CAAA;QAEF,2DAA2D;QAC3D,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,4BAA4B,CAAC,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,aAAa,CAAC,MAAM,CAAC,SAAS,CAAC,EAAE,IAAI,EAAE,IAAI,EAAE,GAAG,EAAE,IAAI,EAAE,EAAE,cAAc,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC9K,yDAAyD;QACzD,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,4BAA4B,CAAC,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,aAAa,CAAC,MAAM,CAAC,SAAS,CAAC,EAAE,IAAI,EAAE,kBAAkB,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC,EAAE,GAAG,EAAE,IAAI,EAAE,EAAE,cAAc,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC9N,8EAA8E;QAC9E,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,4BAA4B,CAAC,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,a
AAa,CAAC,MAAM,CAAC,SAAS,CAAC,EAAE,IAAI,EAAE,IAAI,EAAE,CAAC,EAAE,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACzJ,kCAAkC;QAClC,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,4BAA4B,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,IAAI,EAAE,YAAY,EAAE,IAAI,EAAE,GAAG,EAAE,IAAI,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC7I,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACnG,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,SAAS,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,oCAAoC,EAAE,GAAG,EAAE;QAC5C,sCAAsC;QACtC,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC;YACpB,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9G,CAAC,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;YAE9G,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,CAAC;gBAC7G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,CAAC,CAAC;SACjH,CAAC,CAAC;QAEH,MAAM,QAAQ,GAAG,EAAE,CAAC,MAAM,CAAC;YACvB,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,CAAC;gBACpG,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;YAE/G,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G
,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;SAClH,CAAC,CAAC;QAGH,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QACjH,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;QAC7B,SAAS,CAAC,UAAU,CAAC,SAAS,CAAC,UAAU,EAAE,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAE1F,MAAM,CAAC,QAAQ,CAAC,GAAG,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;IACrG,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,wCAAwC,EAAE,GAAG,EAAE;QAChD,sCAAsC;QACtC,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC;YACpB,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9G,CAAC,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;YAE9G,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,CAAC;gBAC7G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,CAAC,CAAC;SACjH,CAAC,CAAC;QAGH,MAAM,QAAQ,GAAG,EAAE,CAAC,MAAM,CAAC;YACvB,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,CAAC;gBACpG,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;YAE/G,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EA
AE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;SAClH,CAAC,CAAC;QAEH,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,KAAK,EAAE,CAAC,CAAC;QAClH,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;QAC7B,SAAS,CAAC,UAAU,CAAC,SAAS,CAAC,UAAU,EAAE,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAE1F,MAAM,CAAC,QAAQ,CAAC,GAAG,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;IACrG,CAAC,CAAC,CAAC;AACP,CAAC,CAAC,CAAC"}
@@ -1,37 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
3
- import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
4
- export interface PositionalEncodingArgs extends LayerArgs {
5
- /** The size of each token/word's embedding. */ embedDim: number;
6
- /** The max number of tokens (words) per input (sentence); defaults to `5120` when omitted. */ maxSequenceLength?: number;
7
- }
8
- /**
9
- * This class implements the position encoding logic described in the
10
- * 2017 paper "Attention Is All You Need".
11
- *
12
- * This layer is untrainable and accepts inputs of shape `[ batch, sequences, embedding dims ]`
13
- * and adds positional encoding to return an output tensor of the same shape.
14
- *
15
- * @param embedDim the size of each token/word's embedding
16
- * @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`
17
- */
18
- export declare class PositionalEncoding extends tf.layers.Layer {
19
- /** Class name of the layer; the implementation sets it to "PositionalEncoding". */ static className: string;
20
- /** Number of rows of the encoding table. */ private readonly maxSequenceLength;
21
- /** Number of columns of the encoding table; inputs must have this size in their last dimension. */ private readonly embedDim;
22
- /** Non-trainable weight holding the encoding table. */ private positionalEncodings;
23
- /** Throws when `maxSequenceLength` or `embedDim` is smaller than 1. */ constructor(args: PositionalEncodingArgs);
24
- /**
25
- * Forward propagation. Injects positional encoding to the input embeddings; when an array is passed only its first tensor is used.
26
- */
27
- call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[];
28
- /**
29
- * Generate the positional encoding from the paper Attention Is All You Need.
30
- * Note that because the inner term of the position formula is the same for both even
31
- * and odd indices, we only create half of it and apply sine and cosine individually.
32
- */
33
- build(inputShape: tf.Shape | tf.Shape[]): void;
34
- /** Returns `inputShape` unchanged. */ computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[];
35
- /** Returns the base layer config together with `maxSequenceLength` and `embedDim`. */ getConfig(): tf.serialization.ConfigDict;
36
- }
37
- //# sourceMappingURL=positional_encoding.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"positional_encoding.d.ts","sourceRoot":"","sources":["../../../src/layers/positional_encoding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,SAAS,EAAE,MAAM,8CAA8C,CAAC;AAC9E,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAGjE,MAAM,WAAW,sBAAuB,SAAQ,SAAS;IAErD,QAAQ,EAAE,MAAM,CAAC;IAEjB,iBAAiB,CAAC,EAAE,MAAM,CAAC;CAC9B;AAGD;;;;;;;;;GASG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IACxC,OAAO,CAAC,QAAQ,CAAC,iBAAiB,CAAS;IAC3C,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,mBAAmB,CAAmB;gBAGlC,IAAI,EAAE,sBAAsB;IAuBxC;;OAEG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAyBvF;;;;OAIG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IAmD9C,kBAAkB,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;IAK5E,SAAS,IAAI,EAAE,CAAC,aAAa,CAAC,UAAU;CAYpD"}
@@ -1,115 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- /**
3
- * This class implements the position encoding logic described in the
4
- * 2017 paper "Attention Is All You Need".
5
- *
6
- * This layer is untrainable and accepts inputs of shape `[ batch, sequences, embedding dims ]`
7
- * and adds positional encoding to return an output tensor of the same shape.
8
- *
9
- * @param embedDim the size of each token/word's embedding
10
- * @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`
11
- */
12
- export class PositionalEncoding extends tf.layers.Layer {
13
- static className = "PositionalEncoding";
14
- maxSequenceLength;
15
- embedDim;
16
- positionalEncodings;
17
- constructor(args) {
18
- super(args);
19
- this.maxSequenceLength = args.maxSequenceLength ?? 5120;
20
- this.embedDim = args.embedDim;
21
- if (this.maxSequenceLength < 1) {
22
- throw Error(`${this.getClassName()}::constructor ${this.name} maxSequenceLength` +
23
- ` (${args.maxSequenceLength}) must be greater than 0`);
24
- }
25
- if (this.embedDim < 1) {
26
- throw Error(`${this.getClassName()}::constructor ${this.name} embedDim` +
27
- ` (${args.embedDim}) must be greater than 0`);
28
- }
29
- // positional encodings are not trainable
30
- this.positionalEncodings = this.addWeight('positional_encodings', [this.maxSequenceLength, this.embedDim], "float32", tf.initializers.zeros(), undefined, false);
31
- }
32
- /**
33
- * Forward propagation. Injects positional encoding to the input embeddings
34
- */
35
- call(inputs, kwargs) {
36
- // validate the input tensors
37
- const input = Array.isArray(inputs) ? inputs[0] : inputs;
38
- const sequences = input.shape[1];
39
- if (input.shape.length != 3 || input.shape[2] != this.embedDim) {
40
- throw Error(`${this.getClassName()}::call ${this.name} expected an input shape of` +
41
- ` [batch, (up to ${this.maxSequenceLength}), ${this.embedDim}], instead got ${input.shape}`);
42
- }
43
- if (sequences > this.maxSequenceLength) {
44
- // unexpected sequence length
45
- throw Error(`${this.getClassName()}::call ${this.name} received an input with` +
46
- ` sequence length (${sequences}) which is greater than the max sequence length` +
47
- ` ${this.maxSequenceLength}`);
48
- }
49
- // perform forward propagation
50
- return tf.tidy(() => {
51
- return input.add(this.positionalEncodings.read()
52
- .slice([0, 0], [sequences, this.embedDim]) // gets the first "sequences" rows
53
- .expandDims(0)); // introduce the batch dimension and let add() broadcast it
54
- });
55
- }
56
- /**
57
- * Generate the positional encoding from the paper Attention Is All You Need.
58
- * Note that because the inner term of the position formula is the same for both even
59
- * and odd indices, we only create half of it and apply sine and cosine individually.
60
- */
61
- build(inputShape) {
62
- tf.tidy(() => {
63
- const embedDimHalved = Math.ceil(this.embedDim / 2);
64
- // create the position matrix as [ 0, 1, 2, 3, etc ],
65
- // and broadcast it horizontally to match the number of embeddings,
66
- const numerator = tf.range(0, this.maxSequenceLength, 1)
67
- .reshape([this.maxSequenceLength, 1])
68
- // this creates an extra, unsued positional encoding column later on for odd embedding sizes
69
- .broadcastTo([this.maxSequenceLength, embedDimHalved]);
70
- // the inner term's denominator's exponent's numerator is created as
71
- // [ 0, 0, 2, 2, 4, 4, etc ] ( technically [0, 2, 4] as explained above ) and not
72
- // [ 0, 2, 4, 6, 8, 10, etc ] because the even and odd indices are counted as pairs
73
- // when incrementing "i",
74
- // the denominator formula is 10_000^(2i/d_model) where each "i" is a sine cosine pair
75
- const denominator = tf.pow(10_000, tf.range(0, this.embedDim, 2).div(this.embedDim));
76
- const inner_term = numerator.div(denominator);
77
- const sine = tf.sin(inner_term);
78
- const cosine = tf.cos(inner_term);
79
- // horizontally interweave the sine and cosine columns together to form
80
- // [sin, cos, sin, cos, etc]
81
- // [sin, cos, sin, cos, etc]
82
- // etc
83
- const interweaved = [];
84
- const ALL_ROWS = -1;
85
- const ONE_COL = 1;
86
- const FIRST_ROW = 0;
87
- for (let targetCol = 0; targetCol < this.embedDim / 2; targetCol++) {
88
- interweaved.push(sine.slice([FIRST_ROW, targetCol], [ALL_ROWS, ONE_COL]));
89
- if (targetCol != Math.floor(this.embedDim / 2)) {
90
- // for odd numbered embedDim sizes skip the last cosine column
91
- // e.g. if embedDim = 5, create [ i=0 (sin), i=0 (cos), i=1 (sin), i=1 (cos), i=2 (sin) ]
92
- // and the final i=2 (cos) is ignored
93
- interweaved.push(cosine.slice([FIRST_ROW, targetCol], [ALL_ROWS, ONE_COL]));
94
- }
95
- }
96
- // add the positional encoding
97
- this.setWeights([tf.concat(interweaved, 1)]);
98
- });
99
- super.build(inputShape);
100
- }
101
- computeOutputShape(inputShape) {
102
- return inputShape;
103
- }
104
- getConfig() {
105
- const base_config = super.getConfig();
106
- const config = {
107
- maxSequenceLength: this.maxSequenceLength,
108
- embedDim: this.embedDim,
109
- };
110
- Object.assign(config, base_config);
111
- return config;
112
- }
113
- }
114
- tf.serialization.registerClass(PositionalEncoding);
115
- //# sourceMappingURL=positional_encoding.js.map