@stellarapp/tfjs-stellar 1.0.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +47 -0
  3. package/dist/index.d.ts +7 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +7 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/jest.config.d.ts +8 -0
  8. package/dist/jest.config.d.ts.map +1 -0
  9. package/{jest.config.ts → dist/jest.config.js} +8 -64
  10. package/dist/jest.config.js.map +1 -0
  11. package/dist/kv_cache.d.ts +53 -0
  12. package/dist/kv_cache.d.ts.map +1 -0
  13. package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
  14. package/dist/kv_cache.js.map +1 -0
  15. package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
  16. package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  17. package/dist/layers/cached_rope_multihead_attention.js +76 -0
  18. package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
  19. package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  20. package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  21. package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
  22. package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
  23. package/dist/layers/gpt_decoder_block.d.ts +34 -0
  24. package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
  25. package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
  26. package/dist/layers/gpt_decoder_block.js.map +1 -0
  27. package/dist/layers/index.d.ts +17 -0
  28. package/dist/layers/index.d.ts.map +1 -0
  29. package/dist/layers/index.js +33 -0
  30. package/dist/layers/index.js.map +1 -0
  31. package/dist/layers/multihead_attention.d.ts +106 -0
  32. package/dist/layers/multihead_attention.d.ts.map +1 -0
  33. package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
  34. package/dist/layers/multihead_attention.js.map +1 -0
  35. package/dist/layers/multihead_attention.test.d.ts +2 -0
  36. package/dist/layers/multihead_attention.test.d.ts.map +1 -0
  37. package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
  38. package/dist/layers/multihead_attention.test.js.map +1 -0
  39. package/dist/layers/positional_encoding.d.ts +37 -0
  40. package/dist/layers/positional_encoding.d.ts.map +1 -0
  41. package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
  42. package/dist/layers/positional_encoding.js.map +1 -0
  43. package/dist/layers/positional_encoding.test.d.ts +2 -0
  44. package/dist/layers/positional_encoding.test.d.ts.map +1 -0
  45. package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
  46. package/dist/layers/positional_encoding.test.js.map +1 -0
  47. package/dist/layers/rotary_position_embedding.d.ts +39 -0
  48. package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
  49. package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
  50. package/dist/layers/rotary_position_embedding.js.map +1 -0
  51. package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
  52. package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
  53. package/dist/layers/rotary_position_embedding.test.js +88 -0
  54. package/dist/layers/rotary_position_embedding.test.js.map +1 -0
  55. package/dist/layers/token_and_positional_embedding.d.ts +47 -0
  56. package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
  57. package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
  58. package/dist/layers/token_and_positional_embedding.js.map +1 -0
  59. package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
  60. package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  61. package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
  62. package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
  63. package/dist/layers/transformer_decoder.d.ts +69 -0
  64. package/dist/layers/transformer_decoder.d.ts.map +1 -0
  65. package/dist/layers/transformer_decoder.js +182 -0
  66. package/dist/layers/transformer_decoder.js.map +1 -0
  67. package/dist/layers/transformer_decoder.test.d.ts +2 -0
  68. package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
  69. package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
  70. package/dist/layers/transformer_decoder.test.js.map +1 -0
  71. package/dist/layers/transformer_encoder.d.ts +55 -0
  72. package/dist/layers/transformer_encoder.d.ts.map +1 -0
  73. package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
  74. package/dist/layers/transformer_encoder.js.map +1 -0
  75. package/dist/layers/transformer_encoder.test.d.ts +2 -0
  76. package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
  77. package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
  78. package/dist/layers/transformer_encoder.test.js.map +1 -0
  79. package/dist/losses/dice.d.ts +30 -0
  80. package/dist/losses/dice.d.ts.map +1 -0
  81. package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
  82. package/dist/losses/dice.js.map +1 -0
  83. package/dist/losses/index.d.ts +2 -0
  84. package/dist/losses/index.d.ts.map +1 -0
  85. package/dist/losses/index.js +2 -0
  86. package/dist/losses/index.js.map +1 -0
  87. package/dist/masks.d.ts +20 -0
  88. package/dist/masks.d.ts.map +1 -0
  89. package/{src/packing_mask.ts → dist/masks.js} +16 -7
  90. package/dist/masks.js.map +1 -0
  91. package/dist/metrics.d.ts +20 -0
  92. package/dist/metrics.d.ts.map +1 -0
  93. package/{src/metrics.ts → dist/metrics.js} +8 -12
  94. package/dist/metrics.js.map +1 -0
  95. package/dist/models/gpt_model.d.ts +94 -0
  96. package/dist/models/gpt_model.d.ts.map +1 -0
  97. package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
  98. package/dist/models/gpt_model.js.map +1 -0
  99. package/dist/models/index.d.ts +7 -0
  100. package/dist/models/index.d.ts.map +1 -0
  101. package/dist/models/index.js +13 -0
  102. package/dist/models/index.js.map +1 -0
  103. package/dist/models/llm_model.d.ts +87 -0
  104. package/dist/models/llm_model.d.ts.map +1 -0
  105. package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
  106. package/dist/models/llm_model.js.map +1 -0
  107. package/dist/models/u_net.d.ts +40 -0
  108. package/dist/models/u_net.d.ts.map +1 -0
  109. package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
  110. package/dist/models/u_net.js.map +1 -0
  111. package/dist/src/index.d.ts +6 -0
  112. package/dist/src/index.d.ts.map +1 -0
  113. package/dist/src/index.js +6 -0
  114. package/dist/src/index.js.map +1 -0
  115. package/dist/src/kv_cache.d.ts +53 -0
  116. package/dist/src/kv_cache.d.ts.map +1 -0
  117. package/dist/src/kv_cache.js +135 -0
  118. package/dist/src/kv_cache.js.map +1 -0
  119. package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
  120. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  121. package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
  122. package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
  123. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  124. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  125. package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
  126. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
  127. package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
  128. package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
  129. package/dist/src/layers/gpt_decoder_block.js +51 -0
  130. package/dist/src/layers/gpt_decoder_block.js.map +1 -0
  131. package/dist/src/layers/index.d.ts +17 -0
  132. package/dist/src/layers/index.d.ts.map +1 -0
  133. package/dist/src/layers/index.js +33 -0
  134. package/dist/src/layers/index.js.map +1 -0
  135. package/dist/src/layers/multihead_attention.d.ts +106 -0
  136. package/dist/src/layers/multihead_attention.d.ts.map +1 -0
  137. package/dist/src/layers/multihead_attention.js +269 -0
  138. package/dist/src/layers/multihead_attention.js.map +1 -0
  139. package/dist/src/layers/multihead_attention.test.d.ts +2 -0
  140. package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
  141. package/dist/src/layers/multihead_attention.test.js +160 -0
  142. package/dist/src/layers/multihead_attention.test.js.map +1 -0
  143. package/dist/src/layers/positional_encoding.d.ts +37 -0
  144. package/dist/src/layers/positional_encoding.d.ts.map +1 -0
  145. package/dist/src/layers/positional_encoding.js +115 -0
  146. package/dist/src/layers/positional_encoding.js.map +1 -0
  147. package/dist/src/layers/positional_encoding.test.d.ts +2 -0
  148. package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
  149. package/dist/src/layers/positional_encoding.test.js +95 -0
  150. package/dist/src/layers/positional_encoding.test.js.map +1 -0
  151. package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
  152. package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
  153. package/dist/src/layers/rotary_position_embedding.js +99 -0
  154. package/dist/src/layers/rotary_position_embedding.js.map +1 -0
  155. package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
  156. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
  157. package/dist/src/layers/rotary_position_embedding.test.js +88 -0
  158. package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
  159. package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
  160. package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
  161. package/dist/src/layers/token_and_positional_embedding.js +109 -0
  162. package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
  163. package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
  164. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  165. package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
  166. package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
  167. package/dist/src/layers/transformer_decoder.d.ts +69 -0
  168. package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
  169. package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
  170. package/dist/src/layers/transformer_decoder.js.map +1 -0
  171. package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
  172. package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
  173. package/dist/src/layers/transformer_decoder.test.js +72 -0
  174. package/dist/src/layers/transformer_decoder.test.js.map +1 -0
  175. package/dist/src/layers/transformer_encoder.d.ts +55 -0
  176. package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
  177. package/dist/src/layers/transformer_encoder.js +175 -0
  178. package/dist/src/layers/transformer_encoder.js.map +1 -0
  179. package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
  180. package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
  181. package/dist/src/layers/transformer_encoder.test.js +58 -0
  182. package/dist/src/layers/transformer_encoder.test.js.map +1 -0
  183. package/dist/src/losses/dice.d.ts +30 -0
  184. package/dist/src/losses/dice.d.ts.map +1 -0
  185. package/dist/src/losses/dice.js +93 -0
  186. package/dist/src/losses/dice.js.map +1 -0
  187. package/dist/src/losses/index.d.ts +2 -0
  188. package/dist/src/losses/index.d.ts.map +1 -0
  189. package/dist/src/losses/index.js +2 -0
  190. package/dist/src/losses/index.js.map +1 -0
  191. package/dist/src/masks.d.ts +20 -0
  192. package/dist/src/masks.d.ts.map +1 -0
  193. package/dist/src/masks.js +37 -0
  194. package/dist/src/masks.js.map +1 -0
  195. package/dist/src/metrics.d.ts +20 -0
  196. package/dist/src/metrics.d.ts.map +1 -0
  197. package/dist/src/metrics.js +28 -0
  198. package/dist/src/metrics.js.map +1 -0
  199. package/dist/src/models/gpt_model.d.ts +94 -0
  200. package/dist/src/models/gpt_model.d.ts.map +1 -0
  201. package/dist/src/models/gpt_model.js +154 -0
  202. package/dist/src/models/gpt_model.js.map +1 -0
  203. package/dist/src/models/index.d.ts +3 -0
  204. package/dist/src/models/index.d.ts.map +1 -0
  205. package/{src/models/index.ts → dist/src/models/index.js} +1 -0
  206. package/dist/src/models/index.js.map +1 -0
  207. package/dist/src/models/llm_model.d.ts +87 -0
  208. package/dist/src/models/llm_model.d.ts.map +1 -0
  209. package/dist/src/models/llm_model.js +245 -0
  210. package/dist/src/models/llm_model.js.map +1 -0
  211. package/dist/src/models/u_net.d.ts +40 -0
  212. package/dist/src/models/u_net.d.ts.map +1 -0
  213. package/dist/src/models/u_net.js +151 -0
  214. package/dist/src/models/u_net.js.map +1 -0
  215. package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
  216. package/dist/src/tfjs_types.d.ts.map +1 -0
  217. package/dist/src/tfjs_types.js +2 -0
  218. package/dist/src/tfjs_types.js.map +1 -0
  219. package/dist/src/utils.d.ts +28 -0
  220. package/dist/src/utils.d.ts.map +1 -0
  221. package/{src/utils.ts → dist/src/utils.js} +10 -33
  222. package/dist/src/utils.js.map +1 -0
  223. package/dist/src/utils.test.d.ts +2 -0
  224. package/dist/src/utils.test.d.ts.map +1 -0
  225. package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
  226. package/dist/src/utils.test.js.map +1 -0
  227. package/dist/tfjs_types.d.ts +10 -0
  228. package/dist/tfjs_types.d.ts.map +1 -0
  229. package/dist/tfjs_types.js +2 -0
  230. package/dist/tfjs_types.js.map +1 -0
  231. package/dist/utils.d.ts +28 -0
  232. package/dist/utils.d.ts.map +1 -0
  233. package/dist/utils.js +63 -0
  234. package/dist/utils.js.map +1 -0
  235. package/dist/utils.test.d.ts +2 -0
  236. package/dist/utils.test.d.ts.map +1 -0
  237. package/dist/utils.test.js +73 -0
  238. package/dist/utils.test.js.map +1 -0
  239. package/package.json +10 -4
  240. package/src/index.ts +0 -93
  241. package/src/layers/rotary_position_embedding.test.ts +0 -107
  242. package/src/losses/index.ts +0 -1
  243. package/src/testing.ts +0 -1
  244. package/tsconfig.json +0 -49
@@ -1,86 +1,58 @@
1
1
  import * as tf from '@tensorflow/tfjs';
2
- import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
3
- import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
4
- import { generateCausalAttentionMask } from "@/utils";
5
-
6
-
7
- export interface MultiHeadAttentionArgs extends LayerArgs {
8
- numHeads: number;
9
- embedDim: number;
10
- useBias?: boolean;
11
- dropout?: number;
12
- causal?: boolean;
13
- }
14
-
15
-
16
- export interface ScaledDotProductionAttentionKwargs {
17
- training?: boolean;
18
- dropout?: number;
19
- causal?: boolean;
20
- scaling_factor?: number;
21
- }
22
-
23
-
2
+ import { causal as generateCausalMask } from "../masks";
24
3
  /**
25
4
  * This MultiHead Attention layer implements the algorithm as described in
26
5
  * the paper "Attention is all you Need" Vaswani et al., 2017.
27
- *
6
+ *
28
7
  * @param numHeads number of attention heads to use
29
8
  * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
30
9
  * @param causal use causal masking, default `false`
31
10
  * @param dropout use dropout during the attention calculations, default `0.0`
32
11
  * @param useBias use bias for the dense sublayers, default `true`
33
- *
12
+ *
34
13
  * The TensorFlow version uses tf.einsum, whose gradient op has not yet been
35
14
  * implemented (https://github.com/tensorflow/tfjs/pull/4955#discussion_r619219334),
36
15
  * therefore we follow the PyTorch implementation described in:
37
16
  * https://docs.pytorch.org/tutorials/intermediate/transformer_building_blocks.html#multiheadattention
38
17
  * https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
39
- *
18
+ *
40
19
  * This implementation is different from TensorFlow's whose attention weights
41
20
  * are shaped [embed dim, heads, embed dim] where as PyTorch and OpenAI's attention weights
42
21
  * are shaped [embed dim, embed dim]
43
22
  * https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/nn/modules/activation.py#L1080
44
23
  * https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/model.py#L53
45
- *
24
+ *
46
25
  * TODO: implement a fast track for self attention (query = key = value)
47
26
  * where a single dense layer combines and replaces the query, key and projection layers
48
- *
27
+ *
49
28
  * TODO: add kDim and vDim to accept key and values whose embedding dimensions differ from query's.
50
29
  */
51
30
  export class MultiHeadAttention extends tf.layers.Layer {
52
31
  static className = "MultiHeadAttention";
53
- protected readonly numHeads: number;
54
- protected readonly embedDim: number; // size of embedding dim of inputs, also per attention head
55
- protected readonly useBias: boolean;
56
- protected readonly dropout: number;
57
- protected readonly causal: boolean; // use causal attention to mask future words
58
-
32
+ numHeads;
33
+ embedDim; // size of embedding dim of inputs, also per attention head
34
+ useBias;
35
+ dropout;
36
+ causal; // use causal attention to mask future words
59
37
  // projection simply means matrix multiplying query, key, and value
60
38
  // with weights to create a representation of the inputs
61
- protected readonly queryProjection: tf.layers.Layer;
62
- protected readonly keyProjection: tf.layers.Layer;
63
- protected readonly valueProjection: tf.layers.Layer;
64
- protected readonly outputProjection: tf.layers.Layer;
65
-
66
-
67
- constructor({ numHeads, embedDim, useBias = true, dropout = 0.0, causal = false, ...args }: MultiHeadAttentionArgs) {
39
+ queryProjection;
40
+ keyProjection;
41
+ valueProjection;
42
+ outputProjection;
43
+ constructor({ numHeads, embedDim, useBias = true, dropout = 0.0, causal = false, ...args }) {
68
44
  super(args);
69
-
70
45
  if (embedDim % numHeads != 0) {
71
46
  throw Error(`${this.getClassName()}::constructor ${this.name} embedDim (${embedDim}) is not divisible by numHeads (${numHeads})`);
72
47
  }
73
-
74
48
  this.numHeads = numHeads;
75
49
  this.embedDim = embedDim;
76
50
  this.useBias = useBias;
77
51
  this.dropout = dropout;
78
52
  this.causal = causal;
79
-
80
53
  if (this.dropout >= 1) {
81
54
  throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
82
55
  }
83
-
84
56
  // intialize the projection weights, this should be in the
85
57
  // build() function but is done here to avoid linting complaints
86
58
  this.queryProjection = tf.layers.dense({ useBias, units: embedDim });
@@ -88,188 +60,134 @@ export class MultiHeadAttention extends tf.layers.Layer {
88
60
  this.valueProjection = tf.layers.dense({ useBias, units: embedDim });
89
61
  this.outputProjection = tf.layers.dense({ useBias, units: embedDim });
90
62
  }
91
-
92
-
93
63
  /**
94
64
  * Forward propagation. Provide one input tensor or three identical tensors to self-attention.
95
65
  * @param inputs a single tensor for self-attention or an array of exactly three
96
66
  * tensors that are either identical (self-attention) or different (cross-attention)
97
67
  * @param kwargs.packingMask a mask to prevent tokens from attending across document boundaries
98
68
  */
99
- override call(
100
- inputs: tf.Tensor | tf.Tensor[],
101
- kwargs: Kwargs & {
102
- packingMask?: tf.Tensor,
103
- causalMask?: tf.Tensor,
104
- }
105
- ): tf.Tensor | tf.Tensor[] {
69
+ call(inputs, kwargs) {
106
70
  // validate the input tensors
107
71
  if (!Array.isArray(inputs)) {
108
72
  inputs = [inputs];
109
73
  }
110
-
111
74
  // accept only 1 input (self attention) or 3 inputs (self or cross attention)
112
75
  if (inputs.length != 1 && inputs.length != 3) {
113
76
  throw Error(`${this.getClassName()}::call ${this.name} expects exactly one or three input tensors, ${inputs.length} were provided`);
114
77
  }
115
-
116
78
  for (const input of inputs) {
117
79
  if (input.shape.length != 3) {
118
80
  throw Error(`${this.getClassName()}::call ${this.name} expected input shapes of [batch, seq, embed_dim], got ${JSON.stringify(input.shape)}`);
119
81
  }
120
82
  }
121
-
122
83
  const [query, key, value] = inputs;
123
84
  const packingMask = kwargs.packingMask ?? null;
124
85
  const causalMask = kwargs.causalMask ?? null;
125
-
126
86
  return inputs.length == 3
127
87
  // cross-attention
128
- ? this.forward(query!, key!, value!, packingMask, causalMask, kwargs)
88
+ ? this.forward(query, key, value, packingMask, causalMask, kwargs)
129
89
  // self-attention
130
- : this.forward(query!, query!, query!, packingMask, causalMask, kwargs);
90
+ : this.forward(query, query, query, packingMask, causalMask, kwargs);
131
91
  }
132
-
133
-
134
92
  /**
135
93
  * Forward propagation
136
94
  */
137
- protected forward(
138
- query_input: tf.Tensor,
139
- key_input: tf.Tensor,
140
- value_input: tf.Tensor,
141
- packing_mask: tf.Tensor | null,
142
- causal_mask: tf.Tensor | null,
143
- kwargs: Kwargs): tf.Tensor {
144
-
95
+ forward(query_input, key_input, value_input, packing_mask, causal_mask, kwargs) {
145
96
  // dimensions abbreviations
146
97
  // batch = the number of sequences in the input
147
98
  // seq = the length of each sequence in the input
148
99
  // dims = the size of each token's embedding
149
100
  return tf.tidy(() => {
150
101
  const { query, key, value } = this.applyInputProjections(query_input, key_input, value_input);
151
-
152
102
  // swap the seq and heads dimensions: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
153
103
  const move_head_dim_forward = [0, 2, 1, 3];
154
-
155
- const {
156
- query_split, key_split, value_split
157
- } = this.splitHeads(query, key, value, move_head_dim_forward);
158
-
104
+ const { query_split, key_split, value_split } = this.splitHeads(query, key, value, move_head_dim_forward);
159
105
  // apply scaled dot production attention to get [batch, seq, numHeads, embedDim]
160
- const spda = MultiHeadAttention.scaledDotProductionAttention(
161
- query_split, key_split, value_split,
162
- kwargs.attentionMask ?? null, packing_mask, causal_mask,
163
- this.dropout, this.causal, kwargs);
164
-
106
+ const spda = MultiHeadAttention.scaledDotProductionAttention(query_split, key_split, value_split, kwargs.attentionMask ?? null, packing_mask, causal_mask, this.dropout, this.causal, kwargs);
165
107
  // concat heads and apply the output projection
166
- const output = this.outputProjection.apply(
167
- spda.transpose(move_head_dim_forward).reshape([query_input.shape[0], -1, this.embedDim]));
168
-
169
- return output as tf.Tensor;
170
- })
108
+ const output = this.outputProjection.apply(spda.transpose(move_head_dim_forward).reshape([query_input.shape[0], -1, this.embedDim]));
109
+ return output;
110
+ });
171
111
  }
172
-
173
-
174
- protected applyInputProjections(query_input: tf.Tensor, key_input: tf.Tensor, value_input: tf.Tensor) {
112
+ applyInputProjections(query_input, key_input, value_input) {
175
113
  // apply input projections, this is a batched matrix multiplication operated on the last
176
114
  // dimension of query_input and first dimension of the dense layer weights,
177
115
  // [batch, seq, dims] x [dims, dims] = [batch x seq, dims] x [dims, dims] = [batch x seq, dims] = [batch, seq, dims]
178
116
  return tf.tidy(() => {
179
117
  return {
180
- query: this.queryProjection.apply(query_input) as tf.Tensor,
181
- key: this.keyProjection.apply(key_input) as tf.Tensor,
182
- value: this.valueProjection.apply(value_input) as tf.Tensor
183
- }
184
- })
118
+ query: this.queryProjection.apply(query_input),
119
+ key: this.keyProjection.apply(key_input),
120
+ value: this.valueProjection.apply(value_input)
121
+ };
122
+ });
185
123
  }
186
-
187
-
188
- protected splitHeads(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, shuffle: number[]) {
124
+ splitHeads(query, key, value, shuffle) {
189
125
  // split heads and prepare for scaled dot product attention by splitting the
190
126
  // last dimension to get the heads, bring the heads forward
191
127
  // [batch, seq, dims] -> [batch, seq, heads, dims / heads] -> [batch, heads, seq, head_dim]
192
128
  const batch_size = query.shape[0];
193
129
  const split_heads = [batch_size, -1, this.numHeads, this.embedDim / this.numHeads];
194
-
195
130
  return tf.tidy(() => {
196
131
  return {
197
- query_split: query.reshape(split_heads).transpose(shuffle) as tf.Tensor4D,
198
- key_split: key.reshape(split_heads).transpose(shuffle) as tf.Tensor4D,
199
- value_split: value.reshape(split_heads).transpose(shuffle) as tf.Tensor4D
200
- }
201
- })
132
+ query_split: query.reshape(split_heads).transpose(shuffle),
133
+ key_split: key.reshape(split_heads).transpose(shuffle),
134
+ value_split: value.reshape(split_heads).transpose(shuffle)
135
+ };
136
+ });
202
137
  }
203
-
204
-
205
138
  /**
206
139
  * Applies the scaled dot-product formula: softmax(QK_t / sqrt(d_k))V,
207
140
  * formula (1) of the 2017 paper Attention Is All You Need
208
- *
209
- * @param attentionMask a mask to prevent tokens from being
141
+ *
142
+ * @param attentionMask a mask to prevent tokens from being
210
143
  * attended to (usually for padding tokens). It should have the shape
211
144
  * [batch, head, query_sequence_len, key_sequence_len]. To use in
212
145
  * conjunction with causal masking, the tensor should be a boolean type
213
146
  * where false indicates a masked token.
214
147
  * @param packingMask a mask to prevent tokens from attending across document boundaries
215
148
  */
216
- static scaledDotProductionAttention(
217
- query: tf.Tensor,
218
- key: tf.Tensor,
219
- value: tf.Tensor,
220
- attentionMask: tf.Tensor | null,
221
- packingMask: tf.Tensor | null,
222
- causalMask: tf.Tensor | null,
223
- dropout: number,
224
- causal: boolean,
225
- kwargs: ScaledDotProductionAttentionKwargs = {}
226
- ): tf.Tensor {
149
+ static scaledDotProductionAttention(query, key, value, attentionMask, packingMask, causalMask, dropout, causal, kwargs = {}) {
227
150
  return tf.tidy(() => {
228
151
  const { training = false, scaling_factor } = kwargs;
229
-
230
152
  key.shape.forEach((val, index) => {
231
153
  if (key.shape[index] != value.shape[index]) {
232
154
  throw Error(`scaledDotProductionAttention: expected key and value` +
233
155
  ` to have the same shape, got ${JSON.stringify(key.shape)} (key) and` +
234
156
  ` ${JSON.stringify(value.shape)} (value)`);
235
157
  }
236
- })
237
-
238
-
158
+ });
239
159
  // mask's shape is [..., seq, seq] where seq is the number of words/tokens in the input,
240
160
  // not adding the batch dimension yet to lessen the calculations
241
161
  const causal_mask_shape = [
242
162
  query.shape[query.shape.length - 2],
243
- key.shape[key.shape.length - 2]];
244
-
163
+ key.shape[key.shape.length - 2]
164
+ ];
245
165
  let mask = tf.zeros(causal_mask_shape);
246
-
247
166
  if (causal && causal_mask_shape[0] > 1) {
248
167
  if (attentionMask && attentionMask.dtype != "bool") {
249
168
  throw Error(`scaledDotProductionAttention: the attention mask must be undefined or a boolean type if used with causal attention`);
250
169
  }
251
-
252
170
  // apply a causal attention mask so that tokens can only attend to preceding tokens,
253
171
  // prevents looking at head
254
172
  if (causalMask) {
255
173
  mask = causalMask;
256
- } else {
257
- mask = generateCausalAttentionMask(causal_mask_shape[0], causal_mask_shape[1]);
174
+ }
175
+ else {
176
+ mask = generateCausalMask(causal_mask_shape[0], causal_mask_shape[1]);
258
177
  }
259
178
  }
260
-
261
179
  if (attentionMask) {
262
180
  if (attentionMask.dtype == "bool") {
263
181
  // convert the boolean mask to float
264
182
  // warning: do not use 1e9, it will overflow, use something smaller like 1e7
265
183
  mask = mask.add(attentionMask.cast("float32").sub(1).mul(1e7));
266
- } else {
184
+ }
185
+ else {
267
186
  // this will occur only when not using causal masking,
268
187
  // if the attention mask is not boolean, it's assumed the masking is already calculated,
269
188
  mask = attentionMask;
270
189
  }
271
190
  }
272
-
273
191
  // 1. matrix multiply query and transposed key
274
192
  // 2. divide by scaling factor
275
193
  // 3. apply softmax to the result
@@ -280,42 +198,33 @@ export class MultiHeadAttention extends tf.layers.Layer {
280
198
  .matMul(key, false, true)
281
199
  .div(Math.sqrt(scaling_factor ?? key.shape[key.shape.length - 1]))
282
200
  .add(mask);
283
-
284
201
  if (packingMask) {
285
202
  // packing mask is added separately because each mask within a batch may be different,
286
203
  // so it cannot be broadcasted
287
204
  pre_softmax = pre_softmax.add(packingMask);
288
205
  }
289
-
290
206
  const spda = tf.softmax(pre_softmax);
291
-
292
207
  const spda_dropout = tf.dropout(spda, training ? dropout : 0);
293
208
  const attention = spda_dropout.matMul(value);
294
-
295
209
  return attention;
296
210
  });
297
211
  }
298
-
299
-
300
- override build(inputShape: tf.Shape | tf.Shape[]): void {
301
- let input_shape: tf.Shape[] = [];
302
-
212
+ build(inputShape) {
213
+ let input_shape = [];
303
214
  if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
304
- input_shape = inputShape as tf.Shape[];
305
- } else {
306
- input_shape = [inputShape as tf.Shape, inputShape as tf.Shape, inputShape as tf.Shape];
215
+ input_shape = inputShape;
216
+ }
217
+ else {
218
+ input_shape = [inputShape, inputShape, inputShape];
307
219
  }
308
-
309
220
  if (input_shape.length != 1 && input_shape.length != 3) {
310
221
  throw Error(`${this.getClassName()}::build ${this.name} accepts either exactly one or three inputs, received ${JSON.stringify(inputShape)}`);
311
222
  }
312
-
313
223
  // initialize the sublayer weights
314
224
  this.queryProjection.build(input_shape[0]);
315
225
  this.keyProjection.build(input_shape[1]);
316
226
  this.valueProjection.build(input_shape[2]);
317
227
  this.outputProjection.build(input_shape[0]);
318
-
319
228
  // the sublayer weights need to be tracked by this layer otherwise
320
229
  // backpropagation will complain about no trainable parameters found,
321
230
  // this is an extra step that TF's Python version does not need
@@ -325,33 +234,25 @@ export class MultiHeadAttention extends tf.layers.Layer {
325
234
  ...this.valueProjection.trainableWeights,
326
235
  ...this.outputProjection.trainableWeights
327
236
  ];
328
-
329
237
  // rename the weights otherwise they'll take on the default naming and overlap
330
238
  // each other which breaks model loading due to duplicate weight names
331
239
  let indexing = 0;
332
-
333
240
  for (const weight of this.trainableWeights) {
334
241
  const unique_name = `${this.getClassName()}_${indexing}`;
335
- (weight as any).name += unique_name;
336
- (weight as any).originalName += unique_name;
242
+ weight.name += unique_name;
243
+ weight.originalName += unique_name;
337
244
  indexing++;
338
245
  }
339
-
340
246
  super.build(inputShape);
341
247
  }
342
-
343
-
344
248
  /**
345
249
  * MultiHead attention's output is the same shape the query's.
346
250
  */
347
- override computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
251
+ computeOutputShape(inputShape) {
348
252
  return Array.isArray(inputShape) && Array.isArray(inputShape[0]) ? inputShape[0] : inputShape;
349
253
  }
350
-
351
-
352
- override getConfig() {
254
+ getConfig() {
353
255
  const base_config = super.getConfig();
354
-
355
256
  const config = {
356
257
  numHeads: this.numHeads,
357
258
  embedDim: this.embedDim,
@@ -359,13 +260,10 @@ export class MultiHeadAttention extends tf.layers.Layer {
359
260
  causal: this.causal,
360
261
  dropout: this.dropout,
361
262
  name: this.name,
362
- }
363
-
263
+ };
364
264
  Object.assign(config, base_config);
365
-
366
265
  return config;
367
266
  }
368
267
  }
369
-
370
-
371
268
  tf.serialization.registerClass(MultiHeadAttention);
269
+ //# sourceMappingURL=multihead_attention.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"multihead_attention.js","sourceRoot":"","sources":["../../src/layers/multihead_attention.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC,OAAO,EAAE,MAAM,IAAI,kBAAkB,EAAE,MAAM,UAAU,CAAC;AAoBxD;;;;;;;;;;;;;;;;;;;;;;;;;;GA0BG;AACH,MAAM,OAAO,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,GAAG,oBAAoB,CAAC;IACrB,QAAQ,CAAS;IACjB,QAAQ,CAAS,CAAC,2DAA2D;IAC7E,OAAO,CAAU;IACjB,OAAO,CAAS;IAChB,MAAM,CAAU,CAAC,4CAA4C;IAEhF,mEAAmE;IACnE,wDAAwD;IACrC,eAAe,CAAkB;IACjC,aAAa,CAAkB;IAC/B,eAAe,CAAkB;IACjC,gBAAgB,CAAkB;IAGrD,YAAY,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAO,GAAG,IAAI,EAAE,OAAO,GAAG,GAAG,EAAE,MAAM,GAAG,KAAK,EAAE,GAAG,IAAI,EAA0B;QAC9G,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,QAAQ,GAAG,QAAQ,IAAI,CAAC,EAAE,CAAC;YAC3B,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,cAAc,QAAQ,mCAAmC,QAAQ,GAAG,CAAC,CAAC;QACtI,CAAC;QAED,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,OAAO,GAAG,OAAO,CAAC;QACvB,IAAI,CAAC,OAAO,GAAG,OAAO,CAAC;QACvB,IAAI,CAAC,MAAM,GAAG,MAAM,CAAC;QAErB,IAAI,IAAI,CAAC,OAAO,IAAI,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,6CAA6C,CAAC,CAAC;QACrF,CAAC;QAED,0DAA0D;QAC1D,gEAAgE;QAChE,IAAI,CAAC,eAAe,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,QAAQ,EAAE,CAAC,CAAC;QACrE,IAAI,CAAC,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,QAAQ,EAAE,CAAC,CAAC;QACnE,IAAI,CAAC,eAAe,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,QAAQ,EAAE,CAAC,CAAC;QACrE,IAAI,CAAC,gBAAgB,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,QAAQ,EAAE,CAAC,CAAC;IAC1E,CAAC;IAGD;;;;;OAKG;IACM,IAAI,CACT,MAA+B,EAC/B,MAGC;QAED,6BAA6B;QAC7B,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC;YACzB,MAAM,GAAG,CAAC,MAAM,CAAC,CAAC;QACtB,CAAC;QAED,6EAA6E;QAC7E,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAC3C,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,gDAAgD,MAAM,CAAC,MAAM,gBAAgB,CAAC,CAAC;QACxI,CAAC;QAED,KAAK,MAAM,KAAK,IAAI,MAAM,EAAE,CAAC;YACzB,IAAI,KAAK,CAAC,KAAK,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;gBAC1B,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,0DAA0D,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;YAClJ,CAAC;QACL,CAAC;QAED,MAAM,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,CAAC,GAAG,MAAM,CAAC;QACnC,MAAM,WAAW,GAAG,MAAM,CAAC,WAAW,IAAI,IAAI,CAAC;QAC/C,MAAM,UAAU,GAAG,MAAM,CAAC,UAAU,IAAI,IAAI,CAAC;QAE7C,OAAO,MAAM,CAAC,MAAM,IAAI,CAAC;YACrB,kBAAkB;YAClB,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,KAAM,EAAE,GAAI,EAAE,KAAM,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,CAAC;YACrE,iBAAiB;YACjB,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,KAAM,EAAE,KAAM,EAAE,KAAM,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,CAAC,CAAC;IAChF,CAAC;IAGD;;OAEG;IACO,OAAO,CACb,WAAsB,EACtB,SAAoB,EACpB,WAAsB,EACtB,YAA8B,EAC9B,WAA6B,EAC7B,MAAc;QAEd,2BAA2B;QAC3B,+CAA+C;QAC/C,iDAAiD;QACjD,4CAA4C;QAC5C,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,EAAE,KAAK,EAAE,GAAG,EAAE,KAAK,EAAE,GAAG,IAAI,CAAC,qBAAqB,CAAC,WAAW,EAAE,SAAS,EAAE,WAAW,CAAC,CAAC;YAE9F,oGAAoG;YACpG,MAAM,qBAAqB,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;YAE3C,MAAM,EACF,WAAW,EAAE,SAAS,EAAE,WAAW,EACtC,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,EAAE,qBAAqB,CAAC,CAAC;YAE9D,gFAAgF;YAChF,MAAM,IAAI,GAAG,kBAAkB,CAAC,4BAA4B,CACxD,WAAW,EAAE,SAAS,EAAE,WAAW,EACnC,MAAM,CAAC,aAAa,IAAI,IAAI,EAAE,YAAY,EAAE,WAAW,EACvD,IAAI,CAAC,OAAO,EAAE,IAAI,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAEvC,+CAA+C;YAC/C,MAAM,MAAM,GAAG,IAAI,CAAC,gBAAgB,CAAC,KAAK,CACtC,IAAI,CAAC,SAAS,CAAC,qBAAqB,CAAC,CAAC,OAAO,CAAC,CAAC,WAAW,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC;YAE9F,OAAO,MAAmB,CAAC;QAC/B,CAAC,CAAC,CAAA;IACN,CAAC;IAGS,qBAAqB,CAAC,WAAsB,EAAE,SAAoB,EAAE,WAAsB;QAChG,wFAAwF;QACxF,2EAA2E;QAC3E,oHAAoH;QACpH,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,OAAO;gBACH,KAAK,EAAE,IAAI,CAAC,eAAe,CAAC,KAAK,CAAC,WAAW,CAAc;gBAC3D,GAAG,EAAE,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,SAAS,CAAc;gBACrD,KAAK,EAAE,IAAI,CAAC,eAAe,CAAC,KAAK,CAAC,WAAW,CAAc;aAC9D,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGS,UAAU,CAAC,KAAgB,EAAE,GAAc,EAAE,KAAgB,EAAE,OAAiB;QACtF,4EAA4E;QAC5E,2DAA2D;QAC3D,2FAA2F;QAC3F,MAAM,UAAU,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAClC,MAAM,WAAW,GAAG,CAAC,UAAU,EAAE,CAAC,CAAC,EAAE,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC;QAEnF,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,OAAO;gBACH,WAAW,EAAE,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,SAAS,CAAC,OAAO,CAAgB;gBACzE,SAAS,EAAE,GAAG,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,SAAS,CAAC,OAAO,CAAgB;gBACrE,WAAW,EAAE,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,SAAS,CAAC,OAAO,CAAgB;aAC5E,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGD;;;;;;;;;;OAUG;IACH,MAAM,CAAC,4BAA4B,CAC/B,KAAgB,EAChB,GAAc,EACd,KAAgB,EAChB,aAA+B,EAC/B,WAA6B,EAC7B,UAA4B,EAC5B,OAAe,EACf,MAAe,EACf,SAA6C,EAAE;QAE/C,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,EAAE,QAAQ,GAAG,KAAK,EAAE,cAAc,EAAE,GAAG,MAAM,CAAC;YAEpD,GAAG,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,KAAK,EAAE,EAAE;gBAC7B,IAAI,GAAG,CAAC,KAAK,CAAC,KAAK,CAAC,IAAI,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC;oBACzC,MAAM,KAAK,CAAC,sDAAsD;wBAC9D,gCAAgC,IAAI,CAAC,SAAS,CAAC,GAAG,CAAC,KAAK,CAAC,YAAY;wBACrE,IAAI,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;gBACnD,CAAC;YACL,CAAC,CAAC,CAAA;YAGF,wFAAwF;YACxF,gEAAgE;YAChE,MAAM,iBAAiB,GAAG;gBACtB,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;gBACnC,GAAG,CAAC,KAAK,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;aAAC,CAAC;YAErC,IAAI,IAAI,GAAG,EAAE,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC;YAEvC,IAAI,MAAM,IAAI,iBAAiB,CAAC,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC;gBACrC,IAAI,aAAa,IAAI,aAAa,CAAC,KAAK,IAAI,MAAM,EAAE,CAAC;oBACjD,MAAM,KAAK,CAAC,oHAAoH,CAAC,CAAC;gBACtI,CAAC;gBAED,oFAAoF;gBACpF,2BAA2B;gBAC3B,IAAI,UAAU,EAAE,CAAC;oBACb,IAAI,GAAG,UAAU,CAAC;gBACtB,CAAC;qBAAM,CAAC;oBACJ,IAAI,GAAG,kBAAkB,CAAC,iBAAiB,CAAC,CAAC,CAAC,EAAE,iBAAiB,CAAC,CAAC,CAAC,CAAC,CAAC;gBAC1E,CAAC;YACL,CAAC;YAED,IAAI,aAAa,EAAE,CAAC;gBAChB,IAAI,aAAa,CAAC,KAAK,IAAI,MAAM,EAAE,CAAC;oBAChC,oCAAoC;oBACpC,4EAA4E;oBAC5E,IAAI,GAAG,IAAI,CAAC,GAAG,CAAC,aAAa,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;gBACnE,CAAC;qBAAM,CAAC;oBACJ,sDAAsD;oBACtD,wFAAwF;oBACxF,IAAI,GAAG,aAAa,CAAC;gBACzB,CAAC;YACL,CAAC;YAED,8CAA8C;YAC9C,8BAA8B;YAC9B,iCAAiC;YACjC,wCAAwC;YACxC,mBAAmB;YACnB,+CAA+C;YAC/C,IAAI,WAAW,GAAG,KAAK;iBAClB,MAAM,CAAC,GAAG,EAAE,KAAK,EAAE,IAAI,CAAC;iBACxB,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,cAAc,IAAI,GAAG,CAAC,KAAK,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC;iBACjE,GAAG,CAAC,IAAI,CAAC,CAAC;YAEf,IAAI,WAAW,EAAE,CAAC;gBACd,sFAAsF;gBACtF,8BAA8B;gBAC9B,WAAW,GAAG,WAAW,CAAC,GAAG,CAAC,WAAW,CAAC,CAAC;YAC/C,CAAC;YAED,MAAM,IAAI,GAAG,EAAE,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC;YAErC,MAAM,YAAY,GAAG,EAAE,CAAC,OAAO,CAAC,IAAI,EAAE,QAAQ,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAC9D,MAAM,SAAS,GAAG,YAAY,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;YAE7C,OAAO,SAAS,CAAC;QACrB,CAAC,CAAC,CAAC;IACP,CAAC;IAGQ,KAAK,CAAC,UAAiC;QAC5C,IAAI,WAAW,GAAe,EAAE,CAAC;QAEjC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAC5D,WAAW,GAAG,UAAwB,CAAC;QAC3C,CAAC;aAAM,CAAC;YACJ,WAAW,GAAG,CAAC,UAAsB,EAAE,UAAsB,EAAE,UAAsB,CAAC,CAAC;QAC3F,CAAC;QAED,IAAI,WAAW,CAAC,MAAM,IAAI,CAAC,IAAI,WAAW,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YACrD,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,yDAAyD,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAC;QACjJ,CAAC;QAED,kCAAkC;QAClC,IAAI,CAAC,eAAe,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3C,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QACzC,IAAI,CAAC,eAAe,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3C,IAAI,CAAC,gBAAgB,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAE5C,kEAAkE;QAClE,qEAAqE;QACrE,+DAA+D;QAC/D,IAAI,CAAC,gBAAgB,GAAG;YACpB,GAAG,IAAI,CAAC,eAAe,CAAC,gBAAgB;YACxC,GAAG,IAAI,CAAC,aAAa,CAAC,gBAAgB;YACtC,GAAG,IAAI,CAAC,eAAe,CAAC,gBAAgB;YACxC,GAAG,IAAI,CAAC,gBAAgB,CAAC,gBAAgB;SAC5C,CAAC;QAEF,8EAA8E;QAC9E,sEAAsE;QACtE,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEjB,KAAK,MAAM,MAAM,IAAI,IAAI,CAAC,gBAAgB,EAAE,CAAC;YACzC,MAAM,WAAW,GAAG,GAAG,IAAI,CAAC,YAAY,EAAE,IAAI,QAAQ,EAAE,CAAC;YACxD,MAAc,CAAC,IAAI,IAAI,WAAW,CAAC;YACnC,MAAc,CAAC,YAAY,IAAI,WAAW,CAAC;YAC5C,QAAQ,EAAE,CAAC;QACf,CAAC;QAED,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACM,kBAAkB,CAAC,UAAiC;QACzD,OAAO,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC;IAClG,CAAC;IAGQ,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,MAAM,EAAE,IAAI,CAAC,MAAM;YACnB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,IAAI,EAAE,IAAI,CAAC,IAAI;SAClB,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=multihead_attention.test.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"multihead_attention.test.d.ts","sourceRoot":"","sources":["../../src/layers/multihead_attention.test.ts"],"names":[],"mappings":""}