@stellarapp/tfjs-stellar 1.0.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +47 -0
  3. package/dist/index.d.ts +7 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +7 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/jest.config.d.ts +8 -0
  8. package/dist/jest.config.d.ts.map +1 -0
  9. package/{jest.config.ts → dist/jest.config.js} +8 -64
  10. package/dist/jest.config.js.map +1 -0
  11. package/dist/kv_cache.d.ts +53 -0
  12. package/dist/kv_cache.d.ts.map +1 -0
  13. package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
  14. package/dist/kv_cache.js.map +1 -0
  15. package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
  16. package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  17. package/dist/layers/cached_rope_multihead_attention.js +76 -0
  18. package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
  19. package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  20. package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  21. package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
  22. package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
  23. package/dist/layers/gpt_decoder_block.d.ts +34 -0
  24. package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
  25. package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
  26. package/dist/layers/gpt_decoder_block.js.map +1 -0
  27. package/dist/layers/index.d.ts +17 -0
  28. package/dist/layers/index.d.ts.map +1 -0
  29. package/dist/layers/index.js +33 -0
  30. package/dist/layers/index.js.map +1 -0
  31. package/dist/layers/multihead_attention.d.ts +106 -0
  32. package/dist/layers/multihead_attention.d.ts.map +1 -0
  33. package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
  34. package/dist/layers/multihead_attention.js.map +1 -0
  35. package/dist/layers/multihead_attention.test.d.ts +2 -0
  36. package/dist/layers/multihead_attention.test.d.ts.map +1 -0
  37. package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
  38. package/dist/layers/multihead_attention.test.js.map +1 -0
  39. package/dist/layers/positional_encoding.d.ts +37 -0
  40. package/dist/layers/positional_encoding.d.ts.map +1 -0
  41. package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
  42. package/dist/layers/positional_encoding.js.map +1 -0
  43. package/dist/layers/positional_encoding.test.d.ts +2 -0
  44. package/dist/layers/positional_encoding.test.d.ts.map +1 -0
  45. package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
  46. package/dist/layers/positional_encoding.test.js.map +1 -0
  47. package/dist/layers/rotary_position_embedding.d.ts +39 -0
  48. package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
  49. package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
  50. package/dist/layers/rotary_position_embedding.js.map +1 -0
  51. package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
  52. package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
  53. package/dist/layers/rotary_position_embedding.test.js +88 -0
  54. package/dist/layers/rotary_position_embedding.test.js.map +1 -0
  55. package/dist/layers/token_and_positional_embedding.d.ts +47 -0
  56. package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
  57. package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
  58. package/dist/layers/token_and_positional_embedding.js.map +1 -0
  59. package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
  60. package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  61. package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
  62. package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
  63. package/dist/layers/transformer_decoder.d.ts +69 -0
  64. package/dist/layers/transformer_decoder.d.ts.map +1 -0
  65. package/dist/layers/transformer_decoder.js +182 -0
  66. package/dist/layers/transformer_decoder.js.map +1 -0
  67. package/dist/layers/transformer_decoder.test.d.ts +2 -0
  68. package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
  69. package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
  70. package/dist/layers/transformer_decoder.test.js.map +1 -0
  71. package/dist/layers/transformer_encoder.d.ts +55 -0
  72. package/dist/layers/transformer_encoder.d.ts.map +1 -0
  73. package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
  74. package/dist/layers/transformer_encoder.js.map +1 -0
  75. package/dist/layers/transformer_encoder.test.d.ts +2 -0
  76. package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
  77. package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
  78. package/dist/layers/transformer_encoder.test.js.map +1 -0
  79. package/dist/losses/dice.d.ts +30 -0
  80. package/dist/losses/dice.d.ts.map +1 -0
  81. package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
  82. package/dist/losses/dice.js.map +1 -0
  83. package/dist/losses/index.d.ts +2 -0
  84. package/dist/losses/index.d.ts.map +1 -0
  85. package/dist/losses/index.js +2 -0
  86. package/dist/losses/index.js.map +1 -0
  87. package/dist/masks.d.ts +20 -0
  88. package/dist/masks.d.ts.map +1 -0
  89. package/{src/packing_mask.ts → dist/masks.js} +16 -7
  90. package/dist/masks.js.map +1 -0
  91. package/dist/metrics.d.ts +20 -0
  92. package/dist/metrics.d.ts.map +1 -0
  93. package/{src/metrics.ts → dist/metrics.js} +8 -12
  94. package/dist/metrics.js.map +1 -0
  95. package/dist/models/gpt_model.d.ts +94 -0
  96. package/dist/models/gpt_model.d.ts.map +1 -0
  97. package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
  98. package/dist/models/gpt_model.js.map +1 -0
  99. package/dist/models/index.d.ts +7 -0
  100. package/dist/models/index.d.ts.map +1 -0
  101. package/dist/models/index.js +13 -0
  102. package/dist/models/index.js.map +1 -0
  103. package/dist/models/llm_model.d.ts +87 -0
  104. package/dist/models/llm_model.d.ts.map +1 -0
  105. package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
  106. package/dist/models/llm_model.js.map +1 -0
  107. package/dist/models/u_net.d.ts +40 -0
  108. package/dist/models/u_net.d.ts.map +1 -0
  109. package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
  110. package/dist/models/u_net.js.map +1 -0
  111. package/dist/src/index.d.ts +6 -0
  112. package/dist/src/index.d.ts.map +1 -0
  113. package/dist/src/index.js +6 -0
  114. package/dist/src/index.js.map +1 -0
  115. package/dist/src/kv_cache.d.ts +53 -0
  116. package/dist/src/kv_cache.d.ts.map +1 -0
  117. package/dist/src/kv_cache.js +135 -0
  118. package/dist/src/kv_cache.js.map +1 -0
  119. package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
  120. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  121. package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
  122. package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
  123. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  124. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  125. package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
  126. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
  127. package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
  128. package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
  129. package/dist/src/layers/gpt_decoder_block.js +51 -0
  130. package/dist/src/layers/gpt_decoder_block.js.map +1 -0
  131. package/dist/src/layers/index.d.ts +17 -0
  132. package/dist/src/layers/index.d.ts.map +1 -0
  133. package/dist/src/layers/index.js +33 -0
  134. package/dist/src/layers/index.js.map +1 -0
  135. package/dist/src/layers/multihead_attention.d.ts +106 -0
  136. package/dist/src/layers/multihead_attention.d.ts.map +1 -0
  137. package/dist/src/layers/multihead_attention.js +269 -0
  138. package/dist/src/layers/multihead_attention.js.map +1 -0
  139. package/dist/src/layers/multihead_attention.test.d.ts +2 -0
  140. package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
  141. package/dist/src/layers/multihead_attention.test.js +160 -0
  142. package/dist/src/layers/multihead_attention.test.js.map +1 -0
  143. package/dist/src/layers/positional_encoding.d.ts +37 -0
  144. package/dist/src/layers/positional_encoding.d.ts.map +1 -0
  145. package/dist/src/layers/positional_encoding.js +115 -0
  146. package/dist/src/layers/positional_encoding.js.map +1 -0
  147. package/dist/src/layers/positional_encoding.test.d.ts +2 -0
  148. package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
  149. package/dist/src/layers/positional_encoding.test.js +95 -0
  150. package/dist/src/layers/positional_encoding.test.js.map +1 -0
  151. package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
  152. package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
  153. package/dist/src/layers/rotary_position_embedding.js +99 -0
  154. package/dist/src/layers/rotary_position_embedding.js.map +1 -0
  155. package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
  156. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
  157. package/dist/src/layers/rotary_position_embedding.test.js +88 -0
  158. package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
  159. package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
  160. package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
  161. package/dist/src/layers/token_and_positional_embedding.js +109 -0
  162. package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
  163. package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
  164. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  165. package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
  166. package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
  167. package/dist/src/layers/transformer_decoder.d.ts +69 -0
  168. package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
  169. package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
  170. package/dist/src/layers/transformer_decoder.js.map +1 -0
  171. package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
  172. package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
  173. package/dist/src/layers/transformer_decoder.test.js +72 -0
  174. package/dist/src/layers/transformer_decoder.test.js.map +1 -0
  175. package/dist/src/layers/transformer_encoder.d.ts +55 -0
  176. package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
  177. package/dist/src/layers/transformer_encoder.js +175 -0
  178. package/dist/src/layers/transformer_encoder.js.map +1 -0
  179. package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
  180. package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
  181. package/dist/src/layers/transformer_encoder.test.js +58 -0
  182. package/dist/src/layers/transformer_encoder.test.js.map +1 -0
  183. package/dist/src/losses/dice.d.ts +30 -0
  184. package/dist/src/losses/dice.d.ts.map +1 -0
  185. package/dist/src/losses/dice.js +93 -0
  186. package/dist/src/losses/dice.js.map +1 -0
  187. package/dist/src/losses/index.d.ts +2 -0
  188. package/dist/src/losses/index.d.ts.map +1 -0
  189. package/dist/src/losses/index.js +2 -0
  190. package/dist/src/losses/index.js.map +1 -0
  191. package/dist/src/masks.d.ts +20 -0
  192. package/dist/src/masks.d.ts.map +1 -0
  193. package/dist/src/masks.js +37 -0
  194. package/dist/src/masks.js.map +1 -0
  195. package/dist/src/metrics.d.ts +20 -0
  196. package/dist/src/metrics.d.ts.map +1 -0
  197. package/dist/src/metrics.js +28 -0
  198. package/dist/src/metrics.js.map +1 -0
  199. package/dist/src/models/gpt_model.d.ts +94 -0
  200. package/dist/src/models/gpt_model.d.ts.map +1 -0
  201. package/dist/src/models/gpt_model.js +154 -0
  202. package/dist/src/models/gpt_model.js.map +1 -0
  203. package/dist/src/models/index.d.ts +3 -0
  204. package/dist/src/models/index.d.ts.map +1 -0
  205. package/{src/models/index.ts → dist/src/models/index.js} +1 -0
  206. package/dist/src/models/index.js.map +1 -0
  207. package/dist/src/models/llm_model.d.ts +87 -0
  208. package/dist/src/models/llm_model.d.ts.map +1 -0
  209. package/dist/src/models/llm_model.js +245 -0
  210. package/dist/src/models/llm_model.js.map +1 -0
  211. package/dist/src/models/u_net.d.ts +40 -0
  212. package/dist/src/models/u_net.d.ts.map +1 -0
  213. package/dist/src/models/u_net.js +151 -0
  214. package/dist/src/models/u_net.js.map +1 -0
  215. package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
  216. package/dist/src/tfjs_types.d.ts.map +1 -0
  217. package/dist/src/tfjs_types.js +2 -0
  218. package/dist/src/tfjs_types.js.map +1 -0
  219. package/dist/src/utils.d.ts +28 -0
  220. package/dist/src/utils.d.ts.map +1 -0
  221. package/{src/utils.ts → dist/src/utils.js} +10 -33
  222. package/dist/src/utils.js.map +1 -0
  223. package/dist/src/utils.test.d.ts +2 -0
  224. package/dist/src/utils.test.d.ts.map +1 -0
  225. package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
  226. package/dist/src/utils.test.js.map +1 -0
  227. package/dist/tfjs_types.d.ts +10 -0
  228. package/dist/tfjs_types.d.ts.map +1 -0
  229. package/dist/tfjs_types.js +2 -0
  230. package/dist/tfjs_types.js.map +1 -0
  231. package/dist/utils.d.ts +28 -0
  232. package/dist/utils.d.ts.map +1 -0
  233. package/dist/utils.js +63 -0
  234. package/dist/utils.js.map +1 -0
  235. package/dist/utils.test.d.ts +2 -0
  236. package/dist/utils.test.d.ts.map +1 -0
  237. package/dist/utils.test.js +73 -0
  238. package/dist/utils.test.js.map +1 -0
  239. package/package.json +10 -4
  240. package/src/index.ts +0 -93
  241. package/src/layers/rotary_position_embedding.test.ts +0 -107
  242. package/src/losses/index.ts +0 -1
  243. package/src/testing.ts +0 -1
  244. package/tsconfig.json +0 -49
@@ -1,113 +1,76 @@
1
1
  import * as tf from '@tensorflow/tfjs';
2
- import { KvCacheContainer } from "@/kv_cache";
3
- import { MultiHeadAttention, type MultiHeadAttentionArgs } from '@/layers/multihead_attention';
2
+ import { MultiHeadAttention } from '@/layers/multihead_attention';
4
3
  import { RotaryPositionEmbedding } from '@/layers/rotary_position_embedding';
5
- import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
6
-
7
-
8
4
  /**
9
5
  * MultiHeadAttention with RoPE and KV caching. If using KV caching, this layer
10
6
  * should be used in a custom training loop because it requires the cache to be
11
7
  * passed through the `kwargs.kvCache` argument during the `layer.apply()`
12
8
  * forward propagation.
13
- *
9
+ *
14
10
  * If a KV cache is not provided, then this layer operates as MultiHeadAttention with RoPE.
15
11
  */
16
12
  export class CachedRoPEMultiHeadAttention extends MultiHeadAttention {
17
13
  static className = "CachedRoPEMultiHeadAttention";
18
-
19
- protected rope: tf.layers.Layer;
20
-
21
- constructor(args: MultiHeadAttentionArgs) {
14
+ rope;
15
+ constructor(args) {
22
16
  super(args);
23
17
  this.rope = new RotaryPositionEmbedding({ dim: Math.floor(this.embedDim / this.numHeads) });
24
18
  }
25
-
26
-
27
- protected override forward(
28
- query_input: tf.Tensor,
29
- key_input: tf.Tensor,
30
- value_input: tf.Tensor,
31
- packing_mask: tf.Tensor | null,
32
- causal_mask: tf.Tensor | null,
33
- kwargs: Kwargs): tf.Tensor {
34
-
19
+ forward(query_input, key_input, value_input, packing_mask, causal_mask, kwargs) {
35
20
  return tf.tidy(() => {
36
21
  const { query, key, value } = this.applyInputProjections(query_input, key_input, value_input);
37
-
38
22
  // swap the seq and heads dimensions: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
39
23
  const move_head_dim_forward = [0, 2, 1, 3];
40
-
41
24
  const split = this.splitHeads(query, key, value, move_head_dim_forward);
42
-
43
25
  const query_split = split.query_split;
44
26
  let key_split = split.key_split;
45
27
  let value_split = split.value_split;
46
-
47
28
  if (kwargs.training !== true && kwargs.kvCache) {
48
29
  // runs on inference, updates the KV cache and get the historical key and value
49
- const cached_kv = this.getCachedKV(
50
- kwargs.kvCache as KvCacheContainer, key_split, value_split);
51
-
30
+ const cached_kv = this.getCachedKV(kwargs.kvCache, key_split, value_split);
52
31
  key_split = cached_kv.keyCache;
53
32
  value_split = cached_kv.valueCache;
54
33
  }
55
-
56
34
  // apply scaled dot production attention to get [batch, seq, numHeads, embedDim]
57
- const spda = MultiHeadAttention.scaledDotProductionAttention(
58
- query_split, key_split, value_split,
59
- kwargs.attentionMask ?? null, packing_mask, causal_mask,
60
- this.dropout, this.causal, kwargs);
61
-
35
+ const spda = MultiHeadAttention.scaledDotProductionAttention(query_split, key_split, value_split, kwargs.attentionMask ?? null, packing_mask, causal_mask, this.dropout, this.causal, kwargs);
62
36
  // concat heads and apply the output projection
63
- const output = this.outputProjection.apply(
64
- spda.transpose(move_head_dim_forward).reshape(
65
- [query_input.shape[0], query_input.shape[1]!, this.embedDim]));
66
-
67
- return output as tf.Tensor;
68
- })
37
+ const output = this.outputProjection.apply(spda.transpose(move_head_dim_forward).reshape([query_input.shape[0], query_input.shape[1], this.embedDim]));
38
+ return output;
39
+ });
69
40
  }
70
-
71
-
72
- protected getCachedKV(kv_container: KvCacheContainer, key_split: tf.Tensor4D, value_split: tf.Tensor4D) {
41
+ getCachedKV(kv_container, key_split, value_split) {
73
42
  try {
74
43
  let kv_cache = kv_container.update(this.name, key_split, value_split);
75
-
76
44
  if (!kv_cache) {
77
45
  kv_container.create(this.name, {
78
46
  batchSize: key_split.shape[0],
79
47
  numHeads: this.numHeads,
80
48
  headDim: this.embedDim / this.numHeads,
81
- })
82
-
83
- kv_cache = kv_container.update(this.name, key_split, value_split)!;
49
+ });
50
+ kv_cache = kv_container.update(this.name, key_split, value_split);
84
51
  }
85
-
86
- return kv_cache!;
87
- } catch (error: any) {
52
+ return kv_cache;
53
+ }
54
+ catch (error) {
88
55
  throw Error(`${this.getClassName()}::getCachedKV ${this.name} ${error.toString()}`);
89
56
  }
90
57
  }
91
-
92
-
93
58
  /**
94
59
  * Adds RoPE position encoding right after splitting heads.
95
60
  */
96
- protected override splitHeads(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, shuffle: number[]) {
61
+ splitHeads(query, key, value, shuffle) {
97
62
  const batch_size = query.shape[0];
98
63
  const split_heads = [batch_size, -1, this.numHeads, this.embedDim / this.numHeads];
99
-
100
64
  return tf.tidy(() => {
101
65
  return {
102
- query_split: (this.rope.apply(query.reshape(split_heads)) as tf.Tensor)
103
- .transpose(shuffle) as tf.Tensor4D,
104
- key_split: (this.rope.apply(key.reshape(split_heads)) as tf.Tensor)
105
- .transpose(shuffle) as tf.Tensor4D,
106
- value_split: value.reshape(split_heads).transpose(shuffle) as tf.Tensor4D
107
- }
108
- })
66
+ query_split: this.rope.apply(query.reshape(split_heads))
67
+ .transpose(shuffle),
68
+ key_split: this.rope.apply(key.reshape(split_heads))
69
+ .transpose(shuffle),
70
+ value_split: value.reshape(split_heads).transpose(shuffle)
71
+ };
72
+ });
109
73
  }
110
74
  }
111
-
112
-
113
75
  tf.serialization.registerClass(CachedRoPEMultiHeadAttention);
76
+ //# sourceMappingURL=cached_rope_multihead_attention.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"cached_rope_multihead_attention.js","sourceRoot":"","sources":["../../../src/layers/cached_rope_multihead_attention.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,kBAAkB,EAA+B,MAAM,8BAA8B,CAAC;AAC/F,OAAO,EAAE,uBAAuB,EAAE,MAAM,oCAAoC,CAAC;AAI7E;;;;;;;GAOG;AACH,MAAM,OAAO,4BAA6B,SAAQ,kBAAkB;IAChE,MAAM,CAAC,SAAS,GAAG,8BAA8B,CAAC;IAExC,IAAI,CAAkB;IAEhC,YAAY,IAA4B;QACpC,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,IAAI,CAAC,IAAI,GAAG,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC;IAChG,CAAC;IAGkB,OAAO,CACtB,WAAsB,EACtB,SAAoB,EACpB,WAAsB,EACtB,YAA8B,EAC9B,WAA6B,EAC7B,MAAc;QAEd,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,EAAE,KAAK,EAAE,GAAG,EAAE,KAAK,EAAE,GAAG,IAAI,CAAC,qBAAqB,CAAC,WAAW,EAAE,SAAS,EAAE,WAAW,CAAC,CAAC;YAE9F,oGAAoG;YACpG,MAAM,qBAAqB,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;YAE3C,MAAM,KAAK,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,EAAE,qBAAqB,CAAC,CAAC;YAExE,MAAM,WAAW,GAAG,KAAK,CAAC,WAAW,CAAC;YACtC,IAAI,SAAS,GAAG,KAAK,CAAC,SAAS,CAAC;YAChC,IAAI,WAAW,GAAG,KAAK,CAAC,WAAW,CAAC;YAEpC,IAAI,MAAM,CAAC,QAAQ,KAAK,IAAI,IAAI,MAAM,CAAC,OAAO,EAAE,CAAC;gBAC7C,+EAA+E;gBAC/E,MAAM,SAAS,GAAG,IAAI,CAAC,WAAW,CAC9B,MAAM,CAAC,OAA2B,EAAE,SAAS,EAAE,WAAW,CAAC,CAAC;gBAEhE,SAAS,GAAG,SAAS,CAAC,QAAQ,CAAC;gBAC/B,WAAW,GAAG,SAAS,CAAC,UAAU,CAAC;YACvC,CAAC;YAED,gFAAgF;YAChF,MAAM,IAAI,GAAG,kBAAkB,CAAC,4BAA4B,CACxD,WAAW,EAAE,SAAS,EAAE,WAAW,EACnC,MAAM,CAAC,aAAa,IAAI,IAAI,EAAE,YAAY,EAAE,WAAW,EACvD,IAAI,CAAC,OAAO,EAAE,IAAI,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAEvC,+CAA+C;YAC/C,MAAM,MAAM,GAAG,IAAI,CAAC,gBAAgB,CAAC,KAAK,CACtC,IAAI,CAAC,SAAS,CAAC,qBAAqB,CAAC,CAAC,OAAO,CACzC,CAAC,WAAW,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,WAAW,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC;YAEvE,OAAO,MAAmB,CAAC;QAC/B,CAAC,CAAC,CAAA;IACN,CAAC;IAGS,WAAW,CAAC,YAA8B,EAAE,SAAsB,EAAE,WAAwB;QAClG,IAAI,CAAC;YACD,IAAI,QAAQ,GAAG,YAAY,CAAC,MAAM,CAAC,IAAI,CAAC,IAAI,EAAE,SAAS,EAAE,WAAW,CAAC,CAAC;YAEtE,IAAI
,CAAC,QAAQ,EAAE,CAAC;gBACZ,YAAY,CAAC,MAAM,CAAC,IAAI,CAAC,IAAI,EAAE;oBAC3B,SAAS,EAAE,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC;oBAC7B,QAAQ,EAAE,IAAI,CAAC,QAAQ;oBACvB,OAAO,EAAE,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ;iBACzC,CAAC,CAAA;gBAEF,QAAQ,GAAG,YAAY,CAAC,MAAM,CAAC,IAAI,CAAC,IAAI,EAAE,SAAS,EAAE,WAAW,CAAE,CAAC;YACvE,CAAC;YAED,OAAO,QAAS,CAAC;QACrB,CAAC;QAAC,OAAO,KAAU,EAAE,CAAC;YAClB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,IAAI,KAAK,CAAC,QAAQ,EAAE,EAAE,CAAC,CAAC;QACxF,CAAC;IACL,CAAC;IAGD;;OAEG;IACgB,UAAU,CAAC,KAAgB,EAAE,GAAc,EAAE,KAAgB,EAAE,OAAiB;QAC/F,MAAM,UAAU,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAClC,MAAM,WAAW,GAAG,CAAC,UAAU,EAAE,CAAC,CAAC,EAAE,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC;QAEnF,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,OAAO;gBACH,WAAW,EAAG,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAe;qBAClE,SAAS,CAAC,OAAO,CAAgB;gBACtC,SAAS,EAAG,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,WAAW,CAAC,CAAe;qBAC9D,SAAS,CAAC,OAAO,CAAgB;gBACtC,WAAW,EAAE,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,SAAS,CAAC,OAAO,CAAgB;aAC5E,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,4BAA4B,CAAC,CAAC"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=cached_rope_multihead_attention.test.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"cached_rope_multihead_attention.test.d.ts","sourceRoot":"","sources":["../../../src/layers/cached_rope_multihead_attention.test.ts"],"names":[],"mappings":""}
@@ -0,0 +1,43 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { KvCacheContainer } from '@/kv_cache';
3
+ import { CachedRoPEMultiHeadAttention } from '@/layers/cached_rope_multihead_attention';
4
+ // disables warning for using the faster node backend,
5
+ // https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
6
+ tf.env().set('IS_NODE', false);
7
+ describe("CachedRoPEMultiHeadAttention tests", () => {
8
+ test("aggregate forward passes output are identical normal multihead attention", () => {
9
+ compareNormalWithCachedAttention(tf.randomUniform([2, 10, 16]), 123);
10
+ compareNormalWithCachedAttention(tf.randomUniform([1, 10, 16]), 123);
11
+ compareNormalWithCachedAttention(tf.randomUniform([1, 1, 16]), 123);
12
+ compareNormalWithCachedAttention(tf.randomUniform([3, 2, 16]), 123);
13
+ // input exceeds KV cach size
14
+ expect(() => compareNormalWithCachedAttention(tf.randomUniform([1, 10, 16]), 5)).toThrow();
15
+ function compareNormalWithCachedAttention(input, max_sequence_length) {
16
+ const embed_dim = input.shape[2];
17
+ const batch = input.shape[0];
18
+ const heads = 2;
19
+ const kv_cache = new KvCacheContainer(max_sequence_length);
20
+ const normal_mha = new CachedRoPEMultiHeadAttention({ numHeads: heads, embedDim: embed_dim, causal: true });
21
+ const normal_mha_output = normal_mha.apply(input);
22
+ // initialize cached attention with identical configuration and weights
23
+ const cached_mha1 = new CachedRoPEMultiHeadAttention({ ...normal_mha.getConfig(), name: "cache_test1" });
24
+ cached_mha1.build(input.shape);
25
+ cached_mha1.setWeights(normal_mha.getWeights());
26
+ const cached_mha2 = new CachedRoPEMultiHeadAttention({ ...normal_mha.getConfig(), name: "cache_test2" });
27
+ cached_mha2.build(input.shape);
28
+ cached_mha2.setWeights(normal_mha.getWeights());
29
+ const cached_mha_outputs1 = [];
30
+ const cached_mha_outputs2 = [];
31
+ for (let i = 0; i < input.shape[1]; i++) {
32
+ const current_token = input.slice([0, i, 0], [batch, 1, embed_dim]);
33
+ cached_mha_outputs1.push(cached_mha1.apply(current_token, { kvCache: kv_cache }));
34
+ cached_mha_outputs2.push(cached_mha2.apply(current_token, { kvCache: kv_cache }));
35
+ }
36
+ expect(kv_cache.size == input.shape[1]);
37
+ expect(kv_cache.size == input.shape[1]);
38
+ expect(normal_mha_output.sub(tf.concat(cached_mha_outputs1, 1)).sum().dataSync()[0]).toBeLessThan(1e-6);
39
+ expect(normal_mha_output.sub(tf.concat(cached_mha_outputs2, 1)).sum().dataSync()[0]).toBeLessThan(1e-6);
40
+ }
41
+ });
42
+ });
43
+ //# sourceMappingURL=cached_rope_multihead_attention.test.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"cached_rope_multihead_attention.test.js","sourceRoot":"","sources":["../../../src/layers/cached_rope_multihead_attention.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,gBAAgB,EAAE,MAAM,YAAY,CAAC;AAC9C,OAAO,EAAE,4BAA4B,EAAE,MAAM,0CAA0C,CAAC;AAGxF,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,oCAAoC,EAAE,GAAG,EAAE;IAEhD,IAAI,CAAC,0EAA0E,EAAE,GAAG,EAAE;QAClF,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QACjF,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QACjF,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QAChF,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QAEhF,6BAA6B;QAC7B,MAAM,CAAC,GAAG,EAAE,CAAC,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEvG,SAAS,gCAAgC,CAAC,KAAkB,EAAE,mBAA2B;YACrF,MAAM,SAAS,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YACjC,MAAM,KAAK,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAC7B,MAAM,KAAK,GAAG,CAAC,CAAC;YAEhB,MAAM,QAAQ,GAAG,IAAI,gBAAgB,CAAC,mBAAmB,CAAC,CAAC;YAE3D,MAAM,UAAU,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,KAAK,EAAE,QAAQ,EAAE,SAAS,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;YAC5G,MAAM,iBAAiB,GAAG,UAAU,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC;YAE/D,uEAAuE;YACvE,MAAM,WAAW,GAAG,IAAI,4BAA4B,CAAC,EAAE,GAAG,UAAU,CAAC,SAAS,EAAE,EAAE,IAAI,EAAE,aAAa,EAAE,CAAC,CAAC;YACzG,WAAW,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;YAC/B,WAAW,CAAC,UAAU,CAAC,UAAU,CAAC,UAAU,EAAE,CAAC,CAAC;YAEhD,MAAM,WAAW,GAAG,IAAI,4BAA4B,CAAC,EAAE,GAAG,UAAU,CAAC,SAAS,EAAE,EAAE,IAAI,EAAE,aAAa,EAAE,CAAC,CAAC;YACzG,WAAW,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;YAC/B,WAAW,CAAC,UAAU,CAAC,UAAU,CAAC,UAAU,EAAE,CAAC,CAAC;YAEhD,MAAM,mBAAmB,GAAgB,EAAE,CAAC;YAC5C,MAAM,mBAAmB,GAAgB,EAAE,CAAC;YAE5C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,
CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;gBACtC,MAAM,aAAa,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,KAAK,EAAE,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;gBAEpE,mBAAmB,CAAC,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,aAAa,EAAE,EAAE,OAAO,EAAE,QAAQ,EAAE,CAAc,CAAC,CAAC;gBAC/F,mBAAmB,CAAC,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,aAAa,EAAE,EAAE,OAAO,EAAE,QAAQ,EAAE,CAAc,CAAC,CAAC;YACnG,CAAC;YAED,MAAM,CAAC,QAAQ,CAAC,IAAI,IAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;YACvC,MAAM,CAAC,QAAQ,CAAC,IAAI,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;YAExC,MAAM,CAAC,iBAAiB,CAAC,GAAG,CAAC,EAAE,CAAC,MAAM,CAAC,mBAAmB,EAAE,CAAC,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;YACxG,MAAM,CAAC,iBAAiB,CAAC,GAAG,CAAC,EAAE,CAAC,MAAM,CAAC,mBAAmB,EAAE,CAAC,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;QAC5G,CAAC;IACL,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAC"}
@@ -0,0 +1,34 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
3
+ import { type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
4
+ import { TransformerDecoder, type TransformerDecoderArgs } from "@/layers/transformer_decoder";
5
+ export interface GPTDecoderBlockArgs extends Omit<MultiHeadAttentionArgs, "causal"> {
6
+ dimsFeedForward?: number;
7
+ }
8
+ /**
9
+ * This implements the GPT-2 transformer block by modifying the transformer
10
+ * decoder block to use pre-layer-normalization and replacing ReLU activation
11
+ * with GELU.
12
+ *
13
+ * @param numHeads number of attention heads to use
14
+ * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
15
+ * @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
16
+ * @param dropout use dropout during the attention calculations, default `0.1`
17
+ * @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
18
+ * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
19
+ */
20
+ export declare class GPT2DecoderBlock extends TransformerDecoder {
21
+ static className: string;
22
+ constructor(args: TransformerDecoderArgs);
23
+ /**
24
+ * Attention sub-block which is similar to the original transformer except
25
+ * layer normalization is applied beginning
26
+ */
27
+ protected causalSelfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
28
+ /**
29
+ * Feedforward sub-block which is similar to the original transformer except
30
+ * layer normalization is applied at the beginning and gelu activation is used
31
+ */
32
+ protected feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
33
+ }
34
+ //# sourceMappingURL=gpt_decoder_block.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"gpt_decoder_block.d.ts","sourceRoot":"","sources":["../../../src/layers/gpt_decoder_block.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAEjE,OAAO,EAAE,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAC3E,OAAO,EAAE,kBAAkB,EAAE,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAG/F,MAAM,WAAW,mBAAoB,SAAQ,IAAI,CAAC,sBAAsB,EAAE,QAAQ,CAAC;IAC/E,eAAe,CAAC,EAAE,MAAM,CAAC;CAC5B;AAGD;;;;;;;;;;;GAWG;AACH,qBAAa,gBAAiB,SAAQ,kBAAkB;IACpD,MAAM,CAAC,SAAS,SAAsB;gBAG1B,IAAI,EAAE,sBAAsB;IAKxC;;;OAGG;cACgB,wBAAwB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAcpF;;;OAGG;cACgB,gBAAgB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;CAkB/E"}
@@ -0,0 +1,51 @@
1
import * as tf from "@tensorflow/tfjs";
// Relative specifier, matching the sibling dist/layers/index.js. The "@/..."
// form is a TypeScript `paths` alias; tsc does not rewrite it when emitting
// JavaScript, so "@/layers/transformer_decoder" cannot be resolved by
// consumers of the published package.
import { TransformerDecoder } from "./transformer_decoder";
/**
 * This implements the GPT-2 transformer block by modifying the transformer
 * decoder block to use pre-layer-normalization and replacing ReLU activation
 * with GELU.
 *
 * NOTE(review): this subclass only overrides the two residual sub-blocks; it
 * does not itself configure the feed-forward activation, so the GELU
 * behavior presumably comes from TransformerDecoder or `args` — confirm.
 *
 * @param numHeads number of attention heads to use
 * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
 * @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
 * @param dropout use dropout during the attention calculations, default `0.1`
 * @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
 * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
 */
export class GPT2DecoderBlock extends TransformerDecoder {
    // Identifier used by tf.serialization when (de)serializing this layer.
    static className = "GPT2DecoderBlock";
    /**
     * @param args decoder configuration, forwarded unchanged to TransformerDecoder
     */
    constructor(args) {
        super(args);
    }
    /**
     * Attention sub-block which is similar to the original transformer except
     * layer normalization is applied at the beginning (pre-LN):
     * returns x + dropout(selfAttention(norm(x))).
     *
     * @param x input tensor; also used as the residual branch
     * @param kwargs forwarded to every sublayer's apply()
     * @returns the sum of the attention branch and the residual
     */
    causalSelfAttentionBlock(x, kwargs) {
        // tf.tidy disposes intermediate tensors, keeping only the returned one.
        return tf.tidy(() => {
            const residual = x;
            let attention = this.causalSelfAttentionNorm.apply(x, kwargs);
            attention = this.causalSelfAttention.apply(attention, kwargs);
            attention = this.causalSelfAttentionDropout.apply(attention, kwargs);
            attention = tf.add(attention, residual);
            return attention;
        });
    }
    /**
     * Feedforward sub-block which is similar to the original transformer except
     * layer normalization is applied at the beginning and gelu activation is used:
     * returns x + dropout(feedforward2(feedforward1(norm(x)))).
     *
     * @param x input tensor; also used as the residual branch
     * @param kwargs forwarded to every sublayer's apply()
     * @returns the sum of the feed-forward branch and the residual
     */
    feedForwardBlock(x, kwargs) {
        return tf.tidy(() => {
            const residual = x;
            // NOTE(review): "feedFowardNorm" (sic) must match the property name
            // defined by TransformerDecoder, so it is intentionally not renamed here.
            let feedForward = this.feedFowardNorm.apply(x, kwargs);
            feedForward = this.feedforward1.apply(feedForward, kwargs);
            feedForward = this.feedforward2.apply(feedForward, kwargs);
            feedForward = this.feedForwardDropout.apply(feedForward, kwargs);
            feedForward = tf.add(feedForward, residual);
            return feedForward;
        });
    }
}
// Register with the TF.js serialization map so saved models containing this
// layer can be deserialized.
tf.serialization.registerClass(GPT2DecoderBlock);
51
+ //# sourceMappingURL=gpt_decoder_block.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"gpt_decoder_block.js","sourceRoot":"","sources":["../../../src/layers/gpt_decoder_block.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAIvC,OAAO,EAAE,kBAAkB,EAA+B,MAAM,8BAA8B,CAAC;AAQ/F;;;;;;;;;;;GAWG;AACH,MAAM,OAAO,gBAAiB,SAAQ,kBAAkB;IACpD,MAAM,CAAC,SAAS,GAAG,kBAAkB,CAAC;IAGtC,YAAY,IAA4B;QACpC,KAAK,CAAC,IAAI,CAAC,CAAC;IAChB,CAAC;IAGD;;;OAGG;IACgB,wBAAwB,CAAC,CAAY,EAAE,MAAc;QACpE,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,SAAS,GAAG,IAAI,CAAC,uBAAuB,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAc,CAAC;YAC3E,SAAS,GAAG,IAAI,CAAC,mBAAmB,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAC3E,SAAS,GAAG,IAAI,CAAC,0BAA0B,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAClF,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,QAAQ,CAAC,CAAC;YAExC,OAAO,SAAS,CAAC;QACrB,CAAC,CAAC,CAAC;IACP,CAAC;IAGD;;;OAGG;IACgB,gBAAgB,CAAC,CAAY,EAAE,MAAc;QAC5D,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,WAAW,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YACvD,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAC,CAAC;YAC3D,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAC,CAAC;YAC3D,WAAW,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAC9E,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,QAAQ,CAAC,CAAC;YAE5C,OAAO,WAAW,CAAC;QACvB,CAAC,CAAC,CAAC;IACP,CAAC;;AASL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,gBAAgB,CAAC,CAAC"}
@@ -0,0 +1,17 @@
1
+ import { CachedRoPEMultiHeadAttention } from "./cached_rope_multihead_attention";
2
+ import { GPT2DecoderBlock, GPTDecoderBlockArgs } from "./gpt_decoder_block";
3
+ import { MultiHeadAttention, MultiHeadAttentionArgs } from "./multihead_attention";
4
+ import { PositionalEncoding, PositionalEncodingArgs } from "./positional_encoding";
5
+ import { RotaryPositionEmbedding, RotaryPositionEmbeddingArgs } from "./rotary_position_embedding";
6
+ import { TokenAndPositionalEmbedding, TokenAndPositionalEmbeddingArgs } from "./token_and_positional_embedding";
7
+ import { TransformerDecoder, TransformerDecoderArgs } from "./transformer_decoder";
8
+ import { TransformerEncoder, TransformerEncoderArgs } from "./transformer_encoder";
9
/** Convenience factory; equivalent to `new TokenAndPositionalEmbedding(args)`. */
export declare function tokenAndPositionalEmbedding(args: TokenAndPositionalEmbeddingArgs): TokenAndPositionalEmbedding;
/** Convenience factory; equivalent to `new TransformerEncoder(args)`. */
export declare function transformerEncoder(args: TransformerEncoderArgs): TransformerEncoder;
/** Convenience factory; equivalent to `new TransformerDecoder(args)`. */
export declare function transformerDecoder(args: TransformerDecoderArgs): TransformerDecoder;
/** Convenience factory; equivalent to `new MultiHeadAttention(args)`. */
export declare function multiheadAttention(args: MultiHeadAttentionArgs): MultiHeadAttention;
/** Convenience factory; equivalent to `new CachedRoPEMultiHeadAttention(args)` (takes the plain MultiHeadAttention args). */
export declare function cachedRopeMultiheadAttention(args: MultiHeadAttentionArgs): CachedRoPEMultiHeadAttention;
/** Convenience factory; equivalent to `new PositionalEncoding(args)`. */
export declare function positionalEncoding(args: PositionalEncodingArgs): PositionalEncoding;
/** Convenience factory; equivalent to `new GPT2DecoderBlock(args)`. */
export declare function gpt2DecoderBlock(args: GPTDecoderBlockArgs): GPT2DecoderBlock;
/** Convenience factory; equivalent to `new RotaryPositionEmbedding(args)`. */
export declare function rotaryPositionEmbedding(args: RotaryPositionEmbeddingArgs): RotaryPositionEmbedding;
17
+ //# sourceMappingURL=index.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../../src/layers/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,4BAA4B,EAAE,MAAM,mCAAmC,CAAC;AACjF,OAAO,EAAE,gBAAgB,EAAE,mBAAmB,EAAE,MAAM,qBAAqB,CAAC;AAC5E,OAAO,EAAE,kBAAkB,EAAE,sBAAsB,EAAE,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,kBAAkB,EAAE,sBAAsB,EAAE,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,uBAAuB,EAAE,2BAA2B,EAAE,MAAM,6BAA6B,CAAC;AACnG,OAAO,EAAE,2BAA2B,EAAE,+BAA+B,EAAE,MAAM,kCAAkC,CAAC;AAChH,OAAO,EAAE,kBAAkB,EAAE,sBAAsB,EAAE,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,kBAAkB,EAAE,sBAAsB,EAAE,MAAM,uBAAuB,CAAC;AAGnF,wBAAgB,2BAA2B,CAAC,IAAI,EAAE,+BAA+B,+BAEhF;AAGD,wBAAgB,kBAAkB,CAAC,IAAI,EAAE,sBAAsB,sBAE9D;AAGD,wBAAgB,kBAAkB,CAAC,IAAI,EAAE,sBAAsB,sBAE9D;AAGD,wBAAgB,kBAAkB,CAAC,IAAI,EAAE,sBAAsB,sBAE9D;AAGD,wBAAgB,4BAA4B,CAAC,IAAI,EAAE,sBAAsB,gCAExE;AAGD,wBAAgB,kBAAkB,CAAC,IAAI,EAAE,sBAAsB,sBAE9D;AAGD,wBAAgB,gBAAgB,CAAC,IAAI,EAAE,mBAAmB,oBAEzD;AAGD,wBAAgB,uBAAuB,CAAC,IAAI,EAAE,2BAA2B,2BAExE"}
@@ -0,0 +1,33 @@
1
+ import { CachedRoPEMultiHeadAttention } from "./cached_rope_multihead_attention";
2
+ import { GPT2DecoderBlock } from "./gpt_decoder_block";
3
+ import { MultiHeadAttention } from "./multihead_attention";
4
+ import { PositionalEncoding } from "./positional_encoding";
5
+ import { RotaryPositionEmbedding } from "./rotary_position_embedding";
6
+ import { TokenAndPositionalEmbedding } from "./token_and_positional_embedding";
7
+ import { TransformerDecoder } from "./transformer_decoder";
8
+ import { TransformerEncoder } from "./transformer_encoder";
9
/**
 * Builds a TokenAndPositionalEmbedding layer.
 *
 * @param args layer configuration, passed straight to the constructor
 * @returns the newly constructed layer
 */
export function tokenAndPositionalEmbedding(args) {
    const layer = new TokenAndPositionalEmbedding(args);
    return layer;
}
12
/**
 * Builds a TransformerEncoder layer.
 *
 * @param args layer configuration, passed straight to the constructor
 * @returns the newly constructed layer
 */
export function transformerEncoder(args) {
    const layer = new TransformerEncoder(args);
    return layer;
}
15
/**
 * Builds a TransformerDecoder layer.
 *
 * @param args layer configuration, passed straight to the constructor
 * @returns the newly constructed layer
 */
export function transformerDecoder(args) {
    const layer = new TransformerDecoder(args);
    return layer;
}
18
/**
 * Builds a MultiHeadAttention layer.
 *
 * @param args layer configuration, passed straight to the constructor
 * @returns the newly constructed layer
 */
export function multiheadAttention(args) {
    const layer = new MultiHeadAttention(args);
    return layer;
}
21
/**
 * Builds a CachedRoPEMultiHeadAttention layer.
 *
 * @param args layer configuration, passed straight to the constructor
 * @returns the newly constructed layer
 */
export function cachedRopeMultiheadAttention(args) {
    const layer = new CachedRoPEMultiHeadAttention(args);
    return layer;
}
24
/**
 * Builds a PositionalEncoding layer.
 *
 * @param args layer configuration, passed straight to the constructor
 * @returns the newly constructed layer
 */
export function positionalEncoding(args) {
    const layer = new PositionalEncoding(args);
    return layer;
}
27
/**
 * Builds a GPT2DecoderBlock layer.
 *
 * @param args layer configuration, passed straight to the constructor
 * @returns the newly constructed layer
 */
export function gpt2DecoderBlock(args) {
    const layer = new GPT2DecoderBlock(args);
    return layer;
}
30
/**
 * Builds a RotaryPositionEmbedding layer.
 *
 * @param args layer configuration, passed straight to the constructor
 * @returns the newly constructed layer
 */
export function rotaryPositionEmbedding(args) {
    const layer = new RotaryPositionEmbedding(args);
    return layer;
}
33
+ //# sourceMappingURL=index.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"index.js","sourceRoot":"","sources":["../../../src/layers/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,4BAA4B,EAAE,MAAM,mCAAmC,CAAC;AACjF,OAAO,EAAE,gBAAgB,EAAuB,MAAM,qBAAqB,CAAC;AAC5E,OAAO,EAAE,kBAAkB,EAA0B,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,kBAAkB,EAA0B,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,uBAAuB,EAA+B,MAAM,6BAA6B,CAAC;AACnG,OAAO,EAAE,2BAA2B,EAAmC,MAAM,kCAAkC,CAAC;AAChH,OAAO,EAAE,kBAAkB,EAA0B,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,kBAAkB,EAA0B,MAAM,uBAAuB,CAAC;AAGnF,MAAM,UAAU,2BAA2B,CAAC,IAAqC;IAC7E,OAAO,IAAI,2BAA2B,CAAC,IAAI,CAAC,CAAC;AACjD,CAAC;AAGD,MAAM,UAAU,kBAAkB,CAAC,IAA4B;IAC3D,OAAO,IAAI,kBAAkB,CAAC,IAAI,CAAC,CAAC;AACxC,CAAC;AAGD,MAAM,UAAU,kBAAkB,CAAC,IAA4B;IAC3D,OAAO,IAAI,kBAAkB,CAAC,IAAI,CAAC,CAAC;AACxC,CAAC;AAGD,MAAM,UAAU,kBAAkB,CAAC,IAA4B;IAC3D,OAAO,IAAI,kBAAkB,CAAC,IAAI,CAAC,CAAC;AACxC,CAAC;AAGD,MAAM,UAAU,4BAA4B,CAAC,IAA4B;IACrE,OAAO,IAAI,4BAA4B,CAAC,IAAI,CAAC,CAAC;AAClD,CAAC;AAGD,MAAM,UAAU,kBAAkB,CAAC,IAA4B;IAC3D,OAAO,IAAI,kBAAkB,CAAC,IAAI,CAAC,CAAC;AACxC,CAAC;AAGD,MAAM,UAAU,gBAAgB,CAAC,IAAyB;IACtD,OAAO,IAAI,gBAAgB,CAAC,IAAI,CAAC,CAAC;AACtC,CAAC;AAGD,MAAM,UAAU,uBAAuB,CAAC,IAAiC;IACrE,OAAO,IAAI,uBAAuB,CAAC,IAAI,CAAC,CAAC;AAC7C,CAAC"}
@@ -0,0 +1,106 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
3
+ import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
4
/** Constructor arguments for MultiHeadAttention (in addition to the base LayerArgs). */
export interface MultiHeadAttentionArgs extends LayerArgs {
    /** Number of attention heads to use. */
    numHeads: number;
    /** Embedding size of the input (typically its last dimension). */
    embedDim: number;
    /** Use bias for the dense sublayers; documented default `true`. */
    useBias?: boolean;
    /** Dropout rate used during the attention calculations; documented default `0.0`. */
    dropout?: number;
    /** Use causal masking; documented default `false`. */
    causal?: boolean;
}
11
/**
 * Optional keyword arguments for MultiHeadAttention.scaledDotProductionAttention.
 * NOTE(review): field semantics below are inferred from names — confirm against the implementation.
 */
export interface ScaledDotProductionAttentionKwargs {
    /** Presumably enables training-only behavior such as dropout. */
    training?: boolean;
    /** Presumably a dropout rate for the attention weights. */
    dropout?: number;
    /** Presumably toggles causal masking. */
    causal?: boolean;
    /** Presumably overrides the default 1/sqrt(d_k) score scaling. */
    scaling_factor?: number;
}
17
/**
 * This MultiHead Attention layer implements the algorithm as described in
 * the paper "Attention Is All You Need", Vaswani et al., 2017.
 *
 * @param numHeads number of attention heads to use
 * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
 * @param causal use causal masking, default `false`
 * @param dropout use dropout during the attention calculations, default `0.0`
 * @param useBias use bias for the dense sublayers, default `true`
 *
 * The TensorFlow version uses tf.einsum, whose gradient op has not yet been
 * implemented (https://github.com/tensorflow/tfjs/pull/4955#discussion_r619219334),
 * therefore we follow the PyTorch implementation described in:
 * https://docs.pytorch.org/tutorials/intermediate/transformer_building_blocks.html#multiheadattention
 * https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
 *
 * This implementation differs from TensorFlow's, whose attention weights
 * are shaped [embed dim, heads, embed dim], whereas PyTorch's and OpenAI's attention weights
 * are shaped [embed dim, embed dim]
 * https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/nn/modules/activation.py#L1080
 * https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/model.py#L53
 *
 * TODO: implement a fast track for self attention (query = key = value)
 * where a single dense layer combines and replaces the query, key and value projection layers
 *
 * TODO: add kDim and vDim to accept key and values whose embedding dimensions differ from query's.
 */
export declare class MultiHeadAttention extends tf.layers.Layer {
    /** Serialization identifier (static `className`, as used by tf.serialization). */
    static className: string;
    /** Number of attention heads. */
    protected readonly numHeads: number;
    /** Embedding size of the inputs. */
    protected readonly embedDim: number;
    /** Whether the dense sublayers use a bias. */
    protected readonly useBias: boolean;
    /** Dropout rate applied during the attention calculations. */
    protected readonly dropout: number;
    /** Whether causal masking is applied. */
    protected readonly causal: boolean;
    // Dense sublayers projecting the query, key and value inputs, and the
    // attention output, respectively.
    protected readonly queryProjection: tf.layers.Layer;
    protected readonly keyProjection: tf.layers.Layer;
    protected readonly valueProjection: tf.layers.Layer;
    protected readonly outputProjection: tf.layers.Layer;
    constructor({ numHeads, embedDim, useBias, dropout, causal, ...args }: MultiHeadAttentionArgs);
    /**
     * Forward propagation. Provide one input tensor, or three identical tensors, for self-attention.
     * @param inputs a single tensor for self-attention or an array of exactly three
     * tensors that are either identical (self-attention) or different (cross-attention)
     * @param kwargs.packingMask a mask to prevent tokens from attending across document boundaries
     * @param kwargs.causalMask NOTE(review): presumably a precomputed causal mask — confirm
     */
    call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs & {
        packingMask?: tf.Tensor;
        causalMask?: tf.Tensor;
    }): tf.Tensor | tf.Tensor[];
    /**
     * Forward propagation on already-separated query/key/value inputs.
     * Either mask may be `null` when not used.
     */
    protected forward(query_input: tf.Tensor, key_input: tf.Tensor, value_input: tf.Tensor, packing_mask: tf.Tensor | null, causal_mask: tf.Tensor | null, kwargs: Kwargs): tf.Tensor;
    /** Presumably runs each input through its projection sublayer; returns the projected query, key and value. */
    protected applyInputProjections(query_input: tf.Tensor, key_input: tf.Tensor, value_input: tf.Tensor): {
        query: tf.Tensor;
        key: tf.Tensor;
        value: tf.Tensor;
    };
    /**
     * Splits the projected tensors into per-head 4D tensors.
     * NOTE(review): `shuffle` is presumably the transpose permutation applied after reshaping — confirm.
     */
    protected splitHeads(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, shuffle: number[]): {
        query_split: tf.Tensor4D;
        key_split: tf.Tensor4D;
        value_split: tf.Tensor4D;
    };
    /**
     * Applies the scaled dot-product formula: softmax(QK_t / sqrt(d_k))V,
     * formula (1) of the 2017 paper Attention Is All You Need.
     * (The method name's "Production" spelling is part of the public API.)
     *
     * @param attentionMask a mask to prevent tokens from being
     * attended to (usually for padding tokens). It should have the shape
     * [batch, head, query_sequence_len, key_sequence_len]. To use in
     * conjunction with causal masking, the tensor should be a boolean type
     * where false indicates a masked token.
     * @param packingMask a mask to prevent tokens from attending across document boundaries
     * @param causalMask NOTE(review): presumably a precomputed causal mask, used instead of building one — confirm
     * @param dropout dropout rate for the attention calculation
     * @param causal whether to apply causal masking
     * @param kwargs optional extra settings (see ScaledDotProductionAttentionKwargs)
     */
    static scaledDotProductionAttention(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, attentionMask: tf.Tensor | null, packingMask: tf.Tensor | null, causalMask: tf.Tensor | null, dropout: number, causal: boolean, kwargs?: ScaledDotProductionAttentionKwargs): tf.Tensor;
    /** Builds the layer (and presumably its projection sublayers) for the given input shape(s). */
    build(inputShape: tf.Shape | tf.Shape[]): void;
    /**
     * MultiHead attention's output has the same shape as the query's.
     */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[];
    /** Serializable configuration of this layer. */
    getConfig(): {
        numHeads: number;
        embedDim: number;
        useBias: boolean;
        causal: boolean;
        dropout: number;
        name: string;
    };
}
106
+ //# sourceMappingURL=multihead_attention.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"multihead_attention.d.ts","sourceRoot":"","sources":["../../../src/layers/multihead_attention.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,SAAS,EAAE,MAAM,8CAA8C,CAAC;AAC9E,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAIjE,MAAM,WAAW,sBAAuB,SAAQ,SAAS;IACrD,QAAQ,EAAE,MAAM,CAAC;IACjB,QAAQ,EAAE,MAAM,CAAC;IACjB,OAAO,CAAC,EAAE,OAAO,CAAC;IAClB,OAAO,CAAC,EAAE,MAAM,CAAC;IACjB,MAAM,CAAC,EAAE,OAAO,CAAC;CACpB;AAGD,MAAM,WAAW,kCAAkC;IAC/C,QAAQ,CAAC,EAAE,OAAO,CAAC;IACnB,OAAO,CAAC,EAAE,MAAM,CAAC;IACjB,MAAM,CAAC,EAAE,OAAO,CAAC;IACjB,cAAc,CAAC,EAAE,MAAM,CAAC;CAC3B;AAGD;;;;;;;;;;;;;;;;;;;;;;;;;;GA0BG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IACxC,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,OAAO,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,MAAM,CAAC;IACnC,SAAS,CAAC,QAAQ,CAAC,MAAM,EAAE,OAAO,CAAC;IAInC,SAAS,CAAC,QAAQ,CAAC,eAAe,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACpD,SAAS,CAAC,QAAQ,CAAC,aAAa,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAClD,SAAS,CAAC,QAAQ,CAAC,eAAe,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACpD,SAAS,CAAC,QAAQ,CAAC,gBAAgB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;gBAGzC,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAc,EAAE,OAAa,EAAE,MAAc,EAAE,GAAG,IAAI,EAAE,EAAE,sBAAsB;IA0BlH;;;;;OAKG;IACM,IAAI,CACT,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAC/B,MAAM,EAAE,MAAM,GAAG;QACb,WAAW,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC;QACxB,UAAU,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC;KAC1B,GACF,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IA6B1B;;OAEG;IACH,SAAS,CAAC,OAAO,CACb,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,SAAS,EAAE,EAAE,CAAC,MAAM,EACpB,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,YAAY,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC9B,WAAW,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC7B,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IA+B9B,SAAS,CAAC,qBAAqB,CAAC,WAAW,EAAE,EAAE,CAAC,MAAM,EAAE,SAAS,EAAE,EAAE,CAAC,MAAM,EAAE,WAAW,EAAE,EAAE,CAAC,MAAM;eAMtC,EAAE,CAAC,MAAM;aACf,EAAE,CAAC,MAAM;eACH,EAAE,CAAC,MAAM;;IAMvE,SAAS,CAAC,UAAU,CAAC,K
AAK,EAAE,EAAE,CAAC,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC,MAAM,EAAE,OAAO,EAAE,MAAM,EAAE;qBAShB,EAAE,CAAC,QAAQ;mBACf,EAAE,CAAC,QAAQ;qBACP,EAAE,CAAC,QAAQ;;IAMrF;;;;;;;;;;OAUG;IACH,MAAM,CAAC,4BAA4B,CAC/B,KAAK,EAAE,EAAE,CAAC,MAAM,EAChB,GAAG,EAAE,EAAE,CAAC,MAAM,EACd,KAAK,EAAE,EAAE,CAAC,MAAM,EAChB,aAAa,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC/B,WAAW,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC7B,UAAU,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC5B,OAAO,EAAE,MAAM,EACf,MAAM,EAAE,OAAO,EACf,MAAM,GAAE,kCAAuC,GAChD,EAAE,CAAC,MAAM;IA0EH,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IA4CvD;;OAEG;IACM,kBAAkB,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;IAK5E,SAAS;;;;;;;;CAgBrB"}