@stellarapp/tfjs-stellar 1.0.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. package/LICENSE +21 -0
  2. package/README.md +47 -0
  3. package/dist/index.d.ts +7 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +7 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/jest.config.d.ts +8 -0
  8. package/dist/jest.config.d.ts.map +1 -0
  9. package/{jest.config.ts → dist/jest.config.js} +8 -64
  10. package/dist/jest.config.js.map +1 -0
  11. package/dist/kv_cache.d.ts +53 -0
  12. package/dist/kv_cache.d.ts.map +1 -0
  13. package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
  14. package/dist/kv_cache.js.map +1 -0
  15. package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
  16. package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  17. package/dist/layers/cached_rope_multihead_attention.js +76 -0
  18. package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
  19. package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  20. package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  21. package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
  22. package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
  23. package/dist/layers/gpt_decoder_block.d.ts +34 -0
  24. package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
  25. package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
  26. package/dist/layers/gpt_decoder_block.js.map +1 -0
  27. package/dist/layers/index.d.ts +17 -0
  28. package/dist/layers/index.d.ts.map +1 -0
  29. package/dist/layers/index.js +33 -0
  30. package/dist/layers/index.js.map +1 -0
  31. package/dist/layers/multihead_attention.d.ts +106 -0
  32. package/dist/layers/multihead_attention.d.ts.map +1 -0
  33. package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
  34. package/dist/layers/multihead_attention.js.map +1 -0
  35. package/dist/layers/multihead_attention.test.d.ts +2 -0
  36. package/dist/layers/multihead_attention.test.d.ts.map +1 -0
  37. package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
  38. package/dist/layers/multihead_attention.test.js.map +1 -0
  39. package/dist/layers/positional_encoding.d.ts +37 -0
  40. package/dist/layers/positional_encoding.d.ts.map +1 -0
  41. package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
  42. package/dist/layers/positional_encoding.js.map +1 -0
  43. package/dist/layers/positional_encoding.test.d.ts +2 -0
  44. package/dist/layers/positional_encoding.test.d.ts.map +1 -0
  45. package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
  46. package/dist/layers/positional_encoding.test.js.map +1 -0
  47. package/dist/layers/rotary_position_embedding.d.ts +39 -0
  48. package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
  49. package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
  50. package/dist/layers/rotary_position_embedding.js.map +1 -0
  51. package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
  52. package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
  53. package/dist/layers/rotary_position_embedding.test.js +88 -0
  54. package/dist/layers/rotary_position_embedding.test.js.map +1 -0
  55. package/dist/layers/token_and_positional_embedding.d.ts +47 -0
  56. package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
  57. package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
  58. package/dist/layers/token_and_positional_embedding.js.map +1 -0
  59. package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
  60. package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  61. package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
  62. package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
  63. package/dist/layers/transformer_decoder.d.ts +69 -0
  64. package/dist/layers/transformer_decoder.d.ts.map +1 -0
  65. package/dist/layers/transformer_decoder.js +182 -0
  66. package/dist/layers/transformer_decoder.js.map +1 -0
  67. package/dist/layers/transformer_decoder.test.d.ts +2 -0
  68. package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
  69. package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
  70. package/dist/layers/transformer_decoder.test.js.map +1 -0
  71. package/dist/layers/transformer_encoder.d.ts +55 -0
  72. package/dist/layers/transformer_encoder.d.ts.map +1 -0
  73. package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
  74. package/dist/layers/transformer_encoder.js.map +1 -0
  75. package/dist/layers/transformer_encoder.test.d.ts +2 -0
  76. package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
  77. package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
  78. package/dist/layers/transformer_encoder.test.js.map +1 -0
  79. package/dist/losses/dice.d.ts +30 -0
  80. package/dist/losses/dice.d.ts.map +1 -0
  81. package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
  82. package/dist/losses/dice.js.map +1 -0
  83. package/dist/losses/index.d.ts +2 -0
  84. package/dist/losses/index.d.ts.map +1 -0
  85. package/dist/losses/index.js +2 -0
  86. package/dist/losses/index.js.map +1 -0
  87. package/dist/masks.d.ts +20 -0
  88. package/dist/masks.d.ts.map +1 -0
  89. package/{src/packing_mask.ts → dist/masks.js} +16 -7
  90. package/dist/masks.js.map +1 -0
  91. package/dist/metrics.d.ts +20 -0
  92. package/dist/metrics.d.ts.map +1 -0
  93. package/{src/metrics.ts → dist/metrics.js} +8 -12
  94. package/dist/metrics.js.map +1 -0
  95. package/dist/models/gpt_model.d.ts +94 -0
  96. package/dist/models/gpt_model.d.ts.map +1 -0
  97. package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
  98. package/dist/models/gpt_model.js.map +1 -0
  99. package/dist/models/index.d.ts +7 -0
  100. package/dist/models/index.d.ts.map +1 -0
  101. package/dist/models/index.js +13 -0
  102. package/dist/models/index.js.map +1 -0
  103. package/dist/models/llm_model.d.ts +87 -0
  104. package/dist/models/llm_model.d.ts.map +1 -0
  105. package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
  106. package/dist/models/llm_model.js.map +1 -0
  107. package/dist/models/u_net.d.ts +40 -0
  108. package/dist/models/u_net.d.ts.map +1 -0
  109. package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
  110. package/dist/models/u_net.js.map +1 -0
  111. package/dist/src/index.d.ts +6 -0
  112. package/dist/src/index.d.ts.map +1 -0
  113. package/dist/src/index.js +6 -0
  114. package/dist/src/index.js.map +1 -0
  115. package/dist/src/kv_cache.d.ts +53 -0
  116. package/dist/src/kv_cache.d.ts.map +1 -0
  117. package/dist/src/kv_cache.js +135 -0
  118. package/dist/src/kv_cache.js.map +1 -0
  119. package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
  120. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  121. package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
  122. package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
  123. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  124. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  125. package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
  126. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
  127. package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
  128. package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
  129. package/dist/src/layers/gpt_decoder_block.js +51 -0
  130. package/dist/src/layers/gpt_decoder_block.js.map +1 -0
  131. package/dist/src/layers/index.d.ts +17 -0
  132. package/dist/src/layers/index.d.ts.map +1 -0
  133. package/dist/src/layers/index.js +33 -0
  134. package/dist/src/layers/index.js.map +1 -0
  135. package/dist/src/layers/multihead_attention.d.ts +106 -0
  136. package/dist/src/layers/multihead_attention.d.ts.map +1 -0
  137. package/dist/src/layers/multihead_attention.js +269 -0
  138. package/dist/src/layers/multihead_attention.js.map +1 -0
  139. package/dist/src/layers/multihead_attention.test.d.ts +2 -0
  140. package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
  141. package/dist/src/layers/multihead_attention.test.js +160 -0
  142. package/dist/src/layers/multihead_attention.test.js.map +1 -0
  143. package/dist/src/layers/positional_encoding.d.ts +37 -0
  144. package/dist/src/layers/positional_encoding.d.ts.map +1 -0
  145. package/dist/src/layers/positional_encoding.js +115 -0
  146. package/dist/src/layers/positional_encoding.js.map +1 -0
  147. package/dist/src/layers/positional_encoding.test.d.ts +2 -0
  148. package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
  149. package/dist/src/layers/positional_encoding.test.js +95 -0
  150. package/dist/src/layers/positional_encoding.test.js.map +1 -0
  151. package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
  152. package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
  153. package/dist/src/layers/rotary_position_embedding.js +99 -0
  154. package/dist/src/layers/rotary_position_embedding.js.map +1 -0
  155. package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
  156. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
  157. package/dist/src/layers/rotary_position_embedding.test.js +88 -0
  158. package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
  159. package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
  160. package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
  161. package/dist/src/layers/token_and_positional_embedding.js +109 -0
  162. package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
  163. package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
  164. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  165. package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
  166. package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
  167. package/dist/src/layers/transformer_decoder.d.ts +69 -0
  168. package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
  169. package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
  170. package/dist/src/layers/transformer_decoder.js.map +1 -0
  171. package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
  172. package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
  173. package/dist/src/layers/transformer_decoder.test.js +72 -0
  174. package/dist/src/layers/transformer_decoder.test.js.map +1 -0
  175. package/dist/src/layers/transformer_encoder.d.ts +55 -0
  176. package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
  177. package/dist/src/layers/transformer_encoder.js +175 -0
  178. package/dist/src/layers/transformer_encoder.js.map +1 -0
  179. package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
  180. package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
  181. package/dist/src/layers/transformer_encoder.test.js +58 -0
  182. package/dist/src/layers/transformer_encoder.test.js.map +1 -0
  183. package/dist/src/losses/dice.d.ts +30 -0
  184. package/dist/src/losses/dice.d.ts.map +1 -0
  185. package/dist/src/losses/dice.js +93 -0
  186. package/dist/src/losses/dice.js.map +1 -0
  187. package/dist/src/losses/index.d.ts +2 -0
  188. package/dist/src/losses/index.d.ts.map +1 -0
  189. package/dist/src/losses/index.js +2 -0
  190. package/dist/src/losses/index.js.map +1 -0
  191. package/dist/src/masks.d.ts +20 -0
  192. package/dist/src/masks.d.ts.map +1 -0
  193. package/dist/src/masks.js +37 -0
  194. package/dist/src/masks.js.map +1 -0
  195. package/dist/src/metrics.d.ts +20 -0
  196. package/dist/src/metrics.d.ts.map +1 -0
  197. package/dist/src/metrics.js +28 -0
  198. package/dist/src/metrics.js.map +1 -0
  199. package/dist/src/models/gpt_model.d.ts +94 -0
  200. package/dist/src/models/gpt_model.d.ts.map +1 -0
  201. package/dist/src/models/gpt_model.js +154 -0
  202. package/dist/src/models/gpt_model.js.map +1 -0
  203. package/dist/src/models/index.d.ts +3 -0
  204. package/dist/src/models/index.d.ts.map +1 -0
  205. package/{src/models/index.ts → dist/src/models/index.js} +1 -0
  206. package/dist/src/models/index.js.map +1 -0
  207. package/dist/src/models/llm_model.d.ts +87 -0
  208. package/dist/src/models/llm_model.d.ts.map +1 -0
  209. package/dist/src/models/llm_model.js +245 -0
  210. package/dist/src/models/llm_model.js.map +1 -0
  211. package/dist/src/models/u_net.d.ts +40 -0
  212. package/dist/src/models/u_net.d.ts.map +1 -0
  213. package/dist/src/models/u_net.js +151 -0
  214. package/dist/src/models/u_net.js.map +1 -0
  215. package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
  216. package/dist/src/tfjs_types.d.ts.map +1 -0
  217. package/dist/src/tfjs_types.js +2 -0
  218. package/dist/src/tfjs_types.js.map +1 -0
  219. package/dist/src/utils.d.ts +28 -0
  220. package/dist/src/utils.d.ts.map +1 -0
  221. package/{src/utils.ts → dist/src/utils.js} +10 -33
  222. package/dist/src/utils.js.map +1 -0
  223. package/dist/src/utils.test.d.ts +2 -0
  224. package/dist/src/utils.test.d.ts.map +1 -0
  225. package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
  226. package/dist/src/utils.test.js.map +1 -0
  227. package/dist/tfjs_types.d.ts +10 -0
  228. package/dist/tfjs_types.d.ts.map +1 -0
  229. package/dist/tfjs_types.js +2 -0
  230. package/dist/tfjs_types.js.map +1 -0
  231. package/dist/utils.d.ts +28 -0
  232. package/dist/utils.d.ts.map +1 -0
  233. package/dist/utils.js +63 -0
  234. package/dist/utils.js.map +1 -0
  235. package/dist/utils.test.d.ts +2 -0
  236. package/dist/utils.test.d.ts.map +1 -0
  237. package/dist/utils.test.js +73 -0
  238. package/dist/utils.test.js.map +1 -0
  239. package/package.json +10 -4
  240. package/src/index.ts +0 -93
  241. package/src/layers/rotary_position_embedding.test.ts +0 -107
  242. package/src/losses/index.ts +0 -1
  243. package/src/testing.ts +0 -1
  244. package/tsconfig.json +0 -49
@@ -1,212 +1,160 @@
1
1
  import * as tf from '@tensorflow/tfjs';
2
-
3
2
  import { CachedRoPEMultiHeadAttention } from '@/layers/cached_rope_multihead_attention';
4
- import { generateCausalAttentionMask } from '@/utils';
3
+ import { causal as generateCausalMask } from "@/masks";
5
4
  import { MultiHeadAttention } from '@/layers/multihead_attention';
6
-
7
5
  // disables warning for using the faster node backend,
8
6
  // https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
9
7
  tf.env().set('IS_NODE', false);
10
-
11
-
12
8
  describe("MultiHeadAttention tests", () => {
13
9
  it("should fail to instantiate a layer if heads count is not divisible by the input's embedding dimension", () => {
14
10
  expect(() => new CachedRoPEMultiHeadAttention({ numHeads: 3, embedDim: 10 })).toThrow();
15
11
  expect(() => new CachedRoPEMultiHeadAttention({ numHeads: 15, embedDim: 60 })).not.toThrow();
16
- })
17
-
18
-
12
+ });
19
13
  test("successfull forward calls", () => {
20
14
  const input = tf.randomUniform([2, 3, 12]);
21
-
22
- const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1)! });
15
+ const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1) });
23
16
  expect(() => attention.apply(input)).not.toThrow();
24
17
  expect(() => attention.apply([input])).not.toThrow();
25
-
26
- const causal = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1)!, causal: true });
18
+ const causal = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1), causal: true });
27
19
  expect(() => causal.apply(input)).not.toThrow();
28
20
  expect(() => causal.apply([input])).not.toThrow();
29
- })
30
-
31
-
21
+ });
32
22
  test("query and value must have the same shape for scaled dot product attention to succeed", () => {
33
23
  const query = tf.randomUniform([2, 3, 12]);
34
24
  const key = tf.randomUniform([2, 3, 12]);
35
25
  const value = tf.randomUniform([2, 3, 12]);
36
26
  const value_thats_too_long = tf.randomUniform([2, 100, 12]);
37
-
38
- const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: query.shape.at(-1)! });
27
+ const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: query.shape.at(-1) });
39
28
  expect(() => attention.apply([query, key, value])).not.toThrow();
40
29
  expect(() => attention.apply([query, key, value_thats_too_long])).toThrow();
41
- })
42
-
43
-
30
+ });
44
31
  it("should only accept rank 3 tensors", () => {
45
32
  const embed_dims = 12;
46
-
47
33
  const BAD_RANK2 = tf.randomUniform([2, embed_dims]);
48
34
  const GOOD = tf.randomUniform([2, 3, embed_dims]);
49
35
  const BAD_RANK4 = tf.randomUniform([2, 3, 10, embed_dims]);
50
-
51
36
  const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: embed_dims });
52
-
53
37
  // BAD
54
38
  expect(() => attention.apply(BAD_RANK2)).toThrow();
55
39
  expect(() => attention.apply([BAD_RANK2])).toThrow();
56
40
  expect(() => attention.apply([BAD_RANK2, BAD_RANK2, BAD_RANK2])).toThrow();
57
-
58
41
  // OK
59
42
  expect(() => attention.apply(GOOD)).not.toThrow();
60
43
  expect(() => attention.apply([GOOD])).not.toThrow();
61
44
  expect(() => attention.apply([GOOD, GOOD, GOOD])).not.toThrow();
62
-
63
45
  // BAD
64
46
  expect(() => attention.apply(BAD_RANK4)).toThrow();
65
47
  expect(() => attention.apply([BAD_RANK4])).toThrow();
66
48
  expect(() => attention.apply([BAD_RANK4, BAD_RANK4, BAD_RANK4])).toThrow();
67
-
68
49
  // BAD
69
50
  expect(() => attention.apply([GOOD, BAD_RANK2, BAD_RANK4])).toThrow();
70
51
  expect(() => attention.apply([BAD_RANK2, GOOD, BAD_RANK4])).toThrow();
71
52
  expect(() => attention.apply([BAD_RANK2, BAD_RANK4, GOOD])).toThrow();
72
53
  expect(() => attention.apply([BAD_RANK2, GOOD, GOOD])).toThrow();
73
54
  expect(() => attention.apply([GOOD, GOOD, BAD_RANK4])).toThrow();
74
- })
75
-
76
-
55
+ });
77
56
  it("should only 1 or 3 inputs total", () => {
78
57
  const input = tf.randomUniform([2, 3, 12]);
79
-
80
- let attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1)! });
81
-
58
+ let attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1) });
82
59
  // OK
83
60
  expect(() => attention.apply(input, { packingMask: undefined })).not.toThrow();
84
61
  expect(() => attention.apply([input])).not.toThrow();
85
62
  // reinitialize to rerun build()
86
- attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1)! });
63
+ attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1) });
87
64
  expect(() => attention.apply([input, input, input])).not.toThrow();
88
-
89
65
  // BAD
90
66
  expect(() => attention.apply([])).toThrow();
91
67
  expect(() => attention.apply([input, input])).toThrow();
92
68
  // reinitialize to rerun build()
93
- attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1)! });
69
+ attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1) });
94
70
  expect(() => attention.apply([input, input, input, input])).toThrow();
95
- })
96
-
97
-
71
+ });
98
72
  test("attention masking", () => {
99
73
  const query = tf.randomUniform([2, 3, 12]);
100
74
  const key = tf.randomUniform([2, 3, 12]);
101
75
  const value = tf.randomUniform([2, 3, 12]);
102
-
103
- const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: query.shape.at(-1)!, causal: true });
104
-
76
+ const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: query.shape.at(-1), causal: true });
105
77
  expect(() => attention.call(query, {})).not.toThrow();
106
-
107
78
  // cross attention
108
79
  expect(() => attention.call([query, key, value], {})).not.toThrow();
109
-
110
-
111
80
  const query5 = tf.randomUniform([2, 5, 10]);
112
81
  const key4 = tf.randomUniform([2, 4, 10]);
113
82
  const value5 = tf.randomUniform([2, 4, 10]);
114
-
115
83
  const expected_mask = tf.tensor([[
116
- // vertical represents query, false means that token cannot attend to the keys
117
- // horizontal represents key, false means that token cannot attend to the queries
118
- [false, false, false, false],
119
- [true, true, true, false,],
120
- [true, true, true, false,],
121
- [false, false, false, false],
122
- [true, true, true, false,],
123
- ]]);
124
-
84
+ // vertical represents query, false means that token cannot attend to the keys
85
+ // horizontal represents key, false means that token cannot attend to the queries
86
+ [false, false, false, false],
87
+ [true, true, true, false,],
88
+ [true, true, true, false,],
89
+ [false, false, false, false],
90
+ [true, true, true, false,],
91
+ ]]);
125
92
  const packing_mask = tf.tensor([
126
93
  [0, 0, 0, -1e7, -1e7],
127
94
  [0, 0, 0, -1e7, -1e7],
128
95
  [0, 0, 0, -1e7, -1e7],
129
96
  [-1e7, -1e7, -1e7, 0, 0],
130
97
  [-1e7, -1e7, -1e7, 0, 0]
131
- ])
132
-
98
+ ]);
133
99
  // for causal attention, the attention mask must be boolean
134
100
  expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, key4, value5, expected_mask.asType("float32"), null, null, 0.1, true, { scaling_factor: 10 })).toThrow();
135
101
  // for causal attention, using pre-calculated causal mask
136
- expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, key4, value5, expected_mask.asType("float32"), null, generateCausalAttentionMask(query5.shape[1]!, key4.shape[1]!), 0.2, true, { scaling_factor: 10 })).toThrow();
102
+ expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, key4, value5, expected_mask.asType("float32"), null, generateCausalMask(query5.shape[1], key4.shape[1]), 0.2, true, { scaling_factor: 10 })).toThrow();
137
103
  // when not using causal attention, the attention mask can be a float32 tensor
138
104
  expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, key4, value5, expected_mask.asType("float32"), null, null, 0, false)).not.toThrow();
139
105
  // packing mask for self attention
140
106
  expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, query5, query5, null, packing_mask, null, 0.9, true)).not.toThrow();
141
- })
142
-
143
-
107
+ });
144
108
  it("should return a non-empty config dict", () => {
145
109
  const input = tf.randomUniform([2, 3, 10]);
146
-
147
- const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1)! });
110
+ const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1) });
148
111
  expect(Object.keys(attention.getConfig())).not.toBe(0);
149
- })
150
-
151
-
112
+ });
152
113
  test("causal attention hard coded values", () => {
153
114
  // input and output shapes: [2, 3, 10]
154
115
  const input = tf.tensor([
155
116
  [[0.2109915, 0.6158954, 0.6012088, 0.9867562, 0.8728716, 0.7496274, 0.8173883, 0.2958342, 0.9650571, 0.2075207],
156
- [0.2946285, 0.9779906, 0.3203818, 0.4037617, 0.3762881, 0.9863171, 0.6655593, 0.7707329, 0.3216831, 0.7984023],
157
- [0.9080769, 0.0026282, 0.379492, 0.0162054, 0.1939302, 0.2201049, 0.8190675, 0.0203963, 0.0114392, 0.5015539]],
158
-
117
+ [0.2946285, 0.9779906, 0.3203818, 0.4037617, 0.3762881, 0.9863171, 0.6655593, 0.7707329, 0.3216831, 0.7984023],
118
+ [0.9080769, 0.0026282, 0.379492, 0.0162054, 0.1939302, 0.2201049, 0.8190675, 0.0203963, 0.0114392, 0.5015539]],
159
119
  [[0.6241482, 0.7631097, 0.6687831, 0.7259795, 0.0457698, 0.6889264, 0.0853676, 0.8697655, 0.3637198, 0.2105307],
160
- [0.5221761, 0.4476321, 0.1244729, 0.8863543, 0.7319002, 0.2954829, 0.3200496, 0.0905503, 0.607977, 0.1309131],
161
- [0.4693873, 0.4609751, 0.9170766, 0.7065565, 0.4795104, 0.3225758, 0.1353116, 0.7083887, 0.1928891, 0.967386]]
120
+ [0.5221761, 0.4476321, 0.1244729, 0.8863543, 0.7319002, 0.2954829, 0.3200496, 0.0905503, 0.607977, 0.1309131],
121
+ [0.4693873, 0.4609751, 0.9170766, 0.7065565, 0.4795104, 0.3225758, 0.1353116, 0.7083887, 0.1928891, 0.967386]]
162
122
  ]);
163
-
164
123
  const expected = tf.tensor([
165
124
  [[0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344],
166
- [0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376],
167
- [0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539]],
168
-
125
+ [0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376],
126
+ [0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539]],
169
127
  [[0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718],
170
- [0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268],
171
- [0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877]]
128
+ [0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268],
129
+ [0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877]]
172
130
  ]);
173
-
174
-
175
- const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1)!, causal: true });
131
+ const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1), causal: true });
176
132
  attention.build(input.shape);
177
133
  attention.setWeights(attention.getWeights().map(weight => tf.onesLike(weight).mul(0.05)));
178
-
179
- expect(expected.sub(attention.apply(input) as tf.Tensor).sum().dataSync()[0]).toBeLessThan(1e-6);
180
- })
181
-
182
-
134
+ expect(expected.sub(attention.apply(input)).sum().dataSync()[0]).toBeLessThan(1e-6);
135
+ });
183
136
  test("non-causal attention hard coded values", () => {
184
137
  // input and output shapes: [2, 3, 10]
185
138
  const input = tf.tensor([
186
139
  [[0.2109915, 0.6158954, 0.6012088, 0.9867562, 0.8728716, 0.7496274, 0.8173883, 0.2958342, 0.9650571, 0.2075207],
187
- [0.2946285, 0.9779906, 0.3203818, 0.4037617, 0.3762881, 0.9863171, 0.6655593, 0.7707329, 0.3216831, 0.7984023],
188
- [0.9080769, 0.0026282, 0.379492, 0.0162054, 0.1939302, 0.2201049, 0.8190675, 0.0203963, 0.0114392, 0.5015539]],
189
-
140
+ [0.2946285, 0.9779906, 0.3203818, 0.4037617, 0.3762881, 0.9863171, 0.6655593, 0.7707329, 0.3216831, 0.7984023],
141
+ [0.9080769, 0.0026282, 0.379492, 0.0162054, 0.1939302, 0.2201049, 0.8190675, 0.0203963, 0.0114392, 0.5015539]],
190
142
  [[0.6241482, 0.7631097, 0.6687831, 0.7259795, 0.0457698, 0.6889264, 0.0853676, 0.8697655, 0.3637198, 0.2105307],
191
- [0.5221761, 0.4476321, 0.1244729, 0.8863543, 0.7319002, 0.2954829, 0.3200496, 0.0905503, 0.607977, 0.1309131],
192
- [0.4693873, 0.4609751, 0.9170766, 0.7065565, 0.4795104, 0.3225758, 0.1353116, 0.7083887, 0.1928891, 0.967386]]
143
+ [0.5221761, 0.4476321, 0.1244729, 0.8863543, 0.7319002, 0.2954829, 0.3200496, 0.0905503, 0.607977, 0.1309131],
144
+ [0.4693873, 0.4609751, 0.9170766, 0.7065565, 0.4795104, 0.3225758, 0.1353116, 0.7083887, 0.1928891, 0.967386]]
193
145
  ]);
194
-
195
-
196
146
  const expected = tf.tensor([
197
147
  [[0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344],
198
- [0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376],
199
- [0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539]],
200
-
148
+ [0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376],
149
+ [0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539]],
201
150
  [[0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718],
202
- [0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268],
203
- [0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877]]
151
+ [0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268],
152
+ [0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877]]
204
153
  ]);
205
-
206
- const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1)!, causal: false });
154
+ const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1), causal: false });
207
155
  attention.build(input.shape);
208
156
  attention.setWeights(attention.getWeights().map(weight => tf.onesLike(weight).mul(0.05)));
209
-
210
- expect(expected.sub(attention.apply(input) as tf.Tensor).sum().dataSync()[0]).toBeLessThan(1e-6);
157
+ expect(expected.sub(attention.apply(input)).sum().dataSync()[0]).toBeLessThan(1e-6);
211
158
  });
212
159
  });
160
+ //# sourceMappingURL=multihead_attention.test.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"multihead_attention.test.js","sourceRoot":"","sources":["../../src/layers/multihead_attention.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,4BAA4B,EAAE,MAAM,0CAA0C,CAAC;AACxF,OAAO,EAAE,MAAM,IAAI,kBAAkB,EAAE,MAAM,SAAS,CAAC;AACvD,OAAO,EAAE,kBAAkB,EAAE,MAAM,8BAA8B,CAAC;AAElE,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,EAAE,CAAC,uGAAuG,EAAE,GAAG,EAAE;QAC7G,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACxF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACjG,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,2BAA2B,EAAE,GAAG,EAAE;QACnC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACnG,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAErD,MAAM,MAAM,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QAC9G,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACtD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,sFAAsF,EAAE,GAAG,EAAE;QAC9F,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3C,MAAM,GAAG,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QACzC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3C,MAAM,oBAAoB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,GAAG,EAAE,EAAE,CAAC,CAAC,CAAC;QAE5D,MA
AM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACnG,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,GAAG,EAAE,oBAAoB,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAChF,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,mCAAmC,EAAE,GAAG,EAAE;QACzC,MAAM,UAAU,GAAG,EAAE,CAAC;QAEtB,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC;QACpD,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC;QAClD,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,UAAU,CAAC,CAAC,CAAC;QAE3D,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,CAAC,CAAC;QAE1F,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACrD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAE3E,KAAK;QACL,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAClD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACpD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEhE,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACrD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAE3E,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACtE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,
CAAC;QACtE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACtE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACjE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,IAAI,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACrE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,iCAAiC,EAAE,GAAG,EAAE;QACvC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,IAAI,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QAEjG,KAAK;QACL,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,EAAE,EAAE,WAAW,EAAE,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAC/E,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACrD,gCAAgC;QAChC,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QAC7F,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnE,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC5C,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACxD,gCAAgC;QAChC,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QAC7F,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAC1E,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,mBAAmB,EAAE,GAAG,EAAE;QAC3B,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3C,MAAM,GAAG,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QACzC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAA
K,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QAEjH,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,IAAI,CAAC,KAAK,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEtD,kBAAkB;QAClB,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,IAAI,CAAC,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAGpE,MAAM,MAAM,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC5C,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC1C,MAAM,MAAM,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE5C,MAAM,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC;gBAC7B,8EAA8E;gBAC9E,iFAAiF;gBACjF,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC;gBAC5B,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,EAAE,KAAK,EAAE;gBAC1B,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,EAAE,KAAK,EAAE;gBAC1B,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC;gBAC5B,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,EAAE,KAAK,EAAE;aAC7B,CAAC,CAAC,CAAC;QAEJ,MAAM,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC;YAC3B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC;YACrB,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC;YACrB,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC;YACrB,CAAC,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,CAAC;YACxB,CAAC,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,CAAC;SAC3B,CAAC,CAAA;QAEF,2DAA2D;QAC3D,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,4BAA4B,CAAC,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,aAAa,CAAC,MAAM,CAAC,SAAS,CAAC,EAAE,IAAI,EAAE,IAAI,EAAE,GAAG,EAAE,IAAI,EAAE,EAAE,cAAc,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC9K,yDAAyD;QACzD,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,4BAA4B,CAAC,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,aAAa,CAAC,MAAM,CAAC,SAAS,CAAC,EAAE,IAAI,EAAE,kBAAkB,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC,EAAE,GAAG,EAAE,IAAI,EAAE,EAAE,cAAc,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC9N,8EAA8E;QAC9E,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,4BAA4B,CAAC,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,aAAa
,CAAC,MAAM,CAAC,SAAS,CAAC,EAAE,IAAI,EAAE,IAAI,EAAE,CAAC,EAAE,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACzJ,kCAAkC;QAClC,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,4BAA4B,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,IAAI,EAAE,YAAY,EAAE,IAAI,EAAE,GAAG,EAAE,IAAI,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC7I,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACnG,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,SAAS,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,oCAAoC,EAAE,GAAG,EAAE;QAC5C,sCAAsC;QACtC,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC;YACpB,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9G,CAAC,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;YAE9G,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,CAAC;gBAC7G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,CAAC,CAAC;SACjH,CAAC,CAAC;QAEH,MAAM,QAAQ,GAAG,EAAE,CAAC,MAAM,CAAC;YACvB,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,CAAC;gBACpG,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;YAE/G,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CA
AC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;SAClH,CAAC,CAAC;QAGH,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QACjH,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;QAC7B,SAAS,CAAC,UAAU,CAAC,SAAS,CAAC,UAAU,EAAE,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAE1F,MAAM,CAAC,QAAQ,CAAC,GAAG,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;IACrG,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,wCAAwC,EAAE,GAAG,EAAE;QAChD,sCAAsC;QACtC,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC;YACpB,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9G,CAAC,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;YAE9G,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,CAAC;gBAC7G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,CAAC,CAAC;SACjH,CAAC,CAAC;QAGH,MAAM,QAAQ,GAAG,EAAE,CAAC,MAAM,CAAC;YACvB,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,CAAC;gBACpG,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;YAE/G,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,
SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;SAClH,CAAC,CAAC;QAEH,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,KAAK,EAAE,CAAC,CAAC;QAClH,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;QAC7B,SAAS,CAAC,UAAU,CAAC,SAAS,CAAC,UAAU,EAAE,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAE1F,MAAM,CAAC,QAAQ,CAAC,GAAG,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;IACrG,CAAC,CAAC,CAAC;AACP,CAAC,CAAC,CAAC"}
@@ -0,0 +1,37 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
3
+ import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
4
+ export interface PositionalEncodingArgs extends LayerArgs {
5
+ embedDim: number;
6
+ maxSequenceLength?: number;
7
+ }
8
+ /**
9
+ * This class implements the position encoding logic described in the
10
+ * 2017 paper "Attention Is All You Need".
11
+ *
12
+ * This layer is untrainable and accepts inputs of shape `[ batch, sequences, embedding dims ]`
13
+ * and adds positional encoding to return an output tensor of the same shape.
14
+ *
15
+ * @param embedDim the size of each token/word's embedding
16
+ * @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`
17
+ */
18
+ export declare class PositionalEncoding extends tf.layers.Layer {
19
+ static className: string;
20
+ private readonly maxSequenceLength;
21
+ private readonly embedDim;
22
+ private positionalEncodings;
23
+ constructor(args: PositionalEncodingArgs);
24
+ /**
25
+ * Forward propagation. Injects positional encoding to the input embeddings
26
+ */
27
+ call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[];
28
+ /**
29
+ * Generate the positional encoding from the paper Attention Is All You Need.
30
+ * Note that because the inner term of the position formula is the same for both even
31
+ * and odd indices, we only create half of it and apply sine and cosine individually.
32
+ */
33
+ build(inputShape: tf.Shape | tf.Shape[]): void;
34
+ computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[];
35
+ getConfig(): tf.serialization.ConfigDict;
36
+ }
37
+ //# sourceMappingURL=positional_encoding.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"positional_encoding.d.ts","sourceRoot":"","sources":["../../src/layers/positional_encoding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,SAAS,EAAE,MAAM,8CAA8C,CAAC;AAC9E,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAGjE,MAAM,WAAW,sBAAuB,SAAQ,SAAS;IAErD,QAAQ,EAAE,MAAM,CAAC;IAEjB,iBAAiB,CAAC,EAAE,MAAM,CAAC;CAC9B;AAGD;;;;;;;;;GASG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IACxC,OAAO,CAAC,QAAQ,CAAC,iBAAiB,CAAS;IAC3C,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,mBAAmB,CAAmB;gBAGlC,IAAI,EAAE,sBAAsB;IAuBxC;;OAEG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAyBvF;;;;OAIG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IAmD9C,kBAAkB,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;IAK5E,SAAS,IAAI,EAAE,CAAC,aAAa,CAAC,UAAU;CAYpD"}
@@ -1,112 +1,81 @@
1
1
  import * as tf from "@tensorflow/tfjs";
2
- import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
3
- import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
4
-
5
-
6
- export interface PositionalEncodingArgs extends LayerArgs {
7
- // embedding size of each word/token, aka d_model from the paper
8
- embedDim: number;
9
- // the max length of each sentence, any more or less are truncated or padded
10
- maxSequenceLength?: number;
11
- }
12
-
13
-
14
2
  /**
15
3
  * This class implements the position encoding logic described in the
16
4
  * 2017 paper "Attention Is All You Need".
17
- *
5
+ *
18
6
  * This layer is untrainable and accepts inputs of shape `[ batch, sequences, embedding dims ]`
19
7
  * and adds positional encoding to return an output tensor of the same shape.
20
- *
8
+ *
21
9
  * @param embedDim the size of each token/word's embedding
22
10
  * @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`
23
11
  */
24
12
  export class PositionalEncoding extends tf.layers.Layer {
25
13
  static className = "PositionalEncoding";
26
- private readonly maxSequenceLength: number;
27
- private readonly embedDim: number;
28
- private positionalEncodings: tf.LayerVariable;
29
-
30
-
31
- constructor(args: PositionalEncodingArgs) {
14
+ maxSequenceLength;
15
+ embedDim;
16
+ positionalEncodings;
17
+ constructor(args) {
32
18
  super(args);
33
-
34
19
  this.maxSequenceLength = args.maxSequenceLength ?? 5120;
35
20
  this.embedDim = args.embedDim;
36
-
37
21
  if (this.maxSequenceLength < 1) {
38
22
  throw Error(`${this.getClassName()}::constructor ${this.name} maxSequenceLength` +
39
23
  ` (${args.maxSequenceLength}) must be greater than 0`);
40
24
  }
41
-
42
25
  if (this.embedDim < 1) {
43
26
  throw Error(`${this.getClassName()}::constructor ${this.name} embedDim` +
44
27
  ` (${args.embedDim}) must be greater than 0`);
45
28
  }
46
-
47
29
  // positional encodings are not trainable
48
- this.positionalEncodings = this.addWeight('positional_encodings',
49
- [this.maxSequenceLength, this.embedDim], "float32",
50
- tf.initializers.zeros(), undefined, false);
30
+ this.positionalEncodings = this.addWeight('positional_encodings', [this.maxSequenceLength, this.embedDim], "float32", tf.initializers.zeros(), undefined, false);
51
31
  }
52
-
53
-
54
32
  /**
55
33
  * Forward propagation. Injects positional encoding to the input embeddings
56
34
  */
57
- override call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[] {
35
+ call(inputs, kwargs) {
58
36
  // validate the input tensors
59
37
  const input = Array.isArray(inputs) ? inputs[0] : inputs;
60
- const sequences = input.shape[1]!;
61
-
38
+ const sequences = input.shape[1];
62
39
  if (input.shape.length != 3 || input.shape[2] != this.embedDim) {
63
40
  throw Error(`${this.getClassName()}::call ${this.name} expected an input shape of` +
64
41
  ` [batch, (up to ${this.maxSequenceLength}), ${this.embedDim}], instead got ${input.shape}`);
65
42
  }
66
-
67
43
  if (sequences > this.maxSequenceLength) {
68
44
  // unexpected sequence length
69
45
  throw Error(`${this.getClassName()}::call ${this.name} received an input with` +
70
46
  ` sequence length (${sequences}) which is greater than the max sequence length` +
71
47
  ` ${this.maxSequenceLength}`);
72
48
  }
73
-
74
49
  // perform forward propagation
75
50
  return tf.tidy(() => {
76
51
  return input.add(this.positionalEncodings.read()
77
52
  .slice([0, 0], [sequences, this.embedDim]) // gets the first "sequences" rows
78
53
  .expandDims(0)); // introduce the batch dimension and let add() broadcast it
79
- })
54
+ });
80
55
  }
81
-
82
56
  /**
83
57
  * Generate the positional encoding from the paper Attention Is All You Need.
84
58
  * Note that because the inner term of the position formula is the same for both even
85
59
  * and odd indices, we only create half of it and apply sine and cosine individually.
86
60
  */
87
- override build(inputShape: tf.Shape | tf.Shape[]): void {
61
+ build(inputShape) {
88
62
  tf.tidy(() => {
89
63
  const embedDimHalved = Math.ceil(this.embedDim / 2);
90
-
91
64
  // create the position matrix as [ 0, 1, 2, 3, etc ],
92
65
  // and broadcast it horizontally to match the number of embeddings,
93
66
  const numerator = tf.range(0, this.maxSequenceLength, 1)
94
67
  .reshape([this.maxSequenceLength, 1])
95
68
  // this creates an extra, unsued positional encoding column later on for odd embedding sizes
96
69
  .broadcastTo([this.maxSequenceLength, embedDimHalved]);
97
-
98
70
  // the inner term's denominator's exponent's numerator is created as
99
71
  // [ 0, 0, 2, 2, 4, 4, etc ] ( technically [0, 2, 4] as explained above ) and not
100
72
  // [ 0, 2, 4, 6, 8, 10, etc ] because the even and odd indices are counted as pairs
101
73
  // when incrementing "i",
102
74
  // the denominator formula is 10_000^(2i/d_model) where each "i" is a sine cosine pair
103
75
  const denominator = tf.pow(10_000, tf.range(0, this.embedDim, 2).div(this.embedDim));
104
-
105
76
  const inner_term = numerator.div(denominator);
106
-
107
77
  const sine = tf.sin(inner_term);
108
78
  const cosine = tf.cos(inner_term);
109
-
110
79
  // horizontally interweave the sine and cosine columns together to form
111
80
  // [sin, cos, sin, cos, etc]
112
81
  // [sin, cos, sin, cos, etc]
@@ -115,44 +84,32 @@ export class PositionalEncoding extends tf.layers.Layer {
115
84
  const ALL_ROWS = -1;
116
85
  const ONE_COL = 1;
117
86
  const FIRST_ROW = 0;
118
-
119
87
  for (let targetCol = 0; targetCol < this.embedDim / 2; targetCol++) {
120
- interweaved.push(sine.slice([FIRST_ROW, targetCol], [ALL_ROWS, ONE_COL]))
121
-
88
+ interweaved.push(sine.slice([FIRST_ROW, targetCol], [ALL_ROWS, ONE_COL]));
122
89
  if (targetCol != Math.floor(this.embedDim / 2)) {
123
90
  // for odd numbered embedDim sizes skip the last cosine column
124
91
  // e.g. if embedDim = 5, create [ i=0 (sin), i=0 (cos), i=1 (sin), i=1 (cos), i=2 (sin) ]
125
92
  // and the final i=2 (cos) is ignored
126
- interweaved.push(cosine.slice([FIRST_ROW, targetCol], [ALL_ROWS, ONE_COL]))
93
+ interweaved.push(cosine.slice([FIRST_ROW, targetCol], [ALL_ROWS, ONE_COL]));
127
94
  }
128
95
  }
129
-
130
96
  // add the positional encoding
131
97
  this.setWeights([tf.concat(interweaved, 1)]);
132
98
  });
133
-
134
99
  super.build(inputShape);
135
100
  }
136
-
137
-
138
- override computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
101
+ computeOutputShape(inputShape) {
139
102
  return inputShape;
140
103
  }
141
-
142
-
143
- override getConfig(): tf.serialization.ConfigDict {
104
+ getConfig() {
144
105
  const base_config = super.getConfig();
145
-
146
106
  const config = {
147
107
  maxSequenceLength: this.maxSequenceLength,
148
108
  embedDim: this.embedDim,
149
- }
150
-
109
+ };
151
110
  Object.assign(config, base_config);
152
-
153
111
  return config;
154
112
  }
155
113
  }
156
-
157
-
158
114
  tf.serialization.registerClass(PositionalEncoding);
115
+ //# sourceMappingURL=positional_encoding.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"positional_encoding.js","sourceRoot":"","sources":["../../src/layers/positional_encoding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAavC;;;;;;;;;GASG;AACH,MAAM,OAAO,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,GAAG,oBAAoB,CAAC;IACvB,iBAAiB,CAAS;IAC1B,QAAQ,CAAS;IAC1B,mBAAmB,CAAmB;IAG9C,YAAY,IAA4B;QACpC,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,iBAAiB,GAAG,IAAI,CAAC,iBAAiB,IAAI,IAAI,CAAC;QACxD,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;QAE9B,IAAI,IAAI,CAAC,iBAAiB,GAAG,CAAC,EAAE,CAAC;YAC7B,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,oBAAoB;gBAC5E,KAAK,IAAI,CAAC,iBAAiB,0BAA0B,CAAC,CAAC;QAC/D,CAAC;QAED,IAAI,IAAI,CAAC,QAAQ,GAAG,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,WAAW;gBACnE,KAAK,IAAI,CAAC,QAAQ,0BAA0B,CAAC,CAAC;QACtD,CAAC;QAED,yCAAyC;QACzC,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,SAAS,CAAC,sBAAsB,EAC5D,CAAC,IAAI,CAAC,iBAAiB,EAAE,IAAI,CAAC,QAAQ,CAAC,EAAE,SAAS,EAClD,EAAE,CAAC,YAAY,CAAC,KAAK,EAAE,EAAE,SAAS,EAAE,KAAK,CAAC,CAAC;IACnD,CAAC;IAGD;;OAEG;IACM,IAAI,CAAC,MAA+B,EAAE,MAAc;QACzD,6BAA6B;QAC7B,MAAM,KAAK,GAAG,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;QACzD,MAAM,SAAS,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC;QAElC,IAAI,KAAK,CAAC,KAAK,CAAC,MAAM,IAAI,CAAC,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,IAAI,CAAC,QAAQ,EAAE,CAAC;YAC7D,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,6BAA6B;gBAC9E,mBAAmB,IAAI,CAAC,iBAAiB,MAAM,IAAI,CAAC,QAAQ,kBAAkB,KAAK,CAAC,KAAK,EAAE,CAAC,CAAC;QACrG,CAAC;QAED,IAAI,SAAS,GAAG,IAAI,CAAC,iBAAiB,EAAE,CAAC;YACrC,6BAA6B;YAC7B,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,yBAAyB;gBAC1E,qBAAqB,SAAS,iDAAiD;gBAC/E,IAAI,IAAI,CAAC,iBAAiB,EAAE,CAAC,CAAC;QACtC,CAAC;QAED,8BAA8B;QAC9B,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,OAAO,KAAK,CAAC,GAAG,CAAC,IAAI,CAAC,mBAAmB,CAAC,IAAI,EAAE;iBAC3C,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,SAAS,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,kCAAkC;iBAC5E,UAAU,CA
AC,CAAC,CAAC,CAAC,CAAC,CAAC,2DAA2D;QACpF,CAAC,CAAC,CAAA;IACN,CAAC;IAED;;;;OAIG;IACM,KAAK,CAAC,UAAiC;QAC5C,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YACT,MAAM,cAAc,GAAG,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,GAAG,CAAC,CAAC,CAAC;YAEpD,qDAAqD;YACrD,mEAAmE;YACnE,MAAM,SAAS,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,iBAAiB,EAAE,CAAC,CAAC;iBACnD,OAAO,CAAC,CAAC,IAAI,CAAC,iBAAiB,EAAE,CAAC,CAAC,CAAC;gBACrC,4FAA4F;iBAC3F,WAAW,CAAC,CAAC,IAAI,CAAC,iBAAiB,EAAE,cAAc,CAAC,CAAC,CAAC;YAE3D,oEAAoE;YACpE,iFAAiF;YACjF,mFAAmF;YACnF,yBAAyB;YACzB,sFAAsF;YACtF,MAAM,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAErF,MAAM,UAAU,GAAG,SAAS,CAAC,GAAG,CAAC,WAAW,CAAC,CAAC;YAE9C,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC;YAChC,MAAM,MAAM,GAAG,EAAE,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC;YAElC,uEAAuE;YACvE,4BAA4B;YAC5B,4BAA4B;YAC5B,MAAM;YACN,MAAM,WAAW,GAAG,EAAE,CAAC;YACvB,MAAM,QAAQ,GAAG,CAAC,CAAC,CAAC;YACpB,MAAM,OAAO,GAAG,CAAC,CAAC;YAClB,MAAM,SAAS,GAAG,CAAC,CAAC;YAEpB,KAAK,IAAI,SAAS,GAAG,CAAC,EAAE,SAAS,GAAG,IAAI,CAAC,QAAQ,GAAG,CAAC,EAAE,SAAS,EAAE,EAAE,CAAC;gBACjE,WAAW,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,CAAC,EAAE,CAAC,QAAQ,EAAE,OAAO,CAAC,CAAC,CAAC,CAAA;gBAEzE,IAAI,SAAS,IAAI,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,QAAQ,GAAG,CAAC,CAAC,EAAE,CAAC;oBAC7C,8DAA8D;oBAC9D,yFAAyF;oBACzF,qCAAqC;oBACrC,WAAW,CAAC,IAAI,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,CAAC,EAAE,CAAC,QAAQ,EAAE,OAAO,CAAC,CAAC,CAAC,CAAA;gBAC/E,CAAC;YACL,CAAC;YAED,8BAA8B;YAC9B,IAAI,CAAC,UAAU,CAAC,CAAC,EAAE,CAAC,MAAM,CAAC,WAAW,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QACjD,CAAC,CAAC,CAAC;QAEH,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGQ,kBAAkB,CAAC,UAAiC;QACzD,OAAO,UAAU,CAAC;IACtB,CAAC;IAGQ,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,iBAAiB,EAAE,IAAI,CAAC,iBAAiB;YACzC,QAAQ,EAAE,IAAI,CAAC,QAAQ;SAC1B,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,kBAAk
B,CAAC,CAAC"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=positional_encoding.test.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"positional_encoding.test.d.ts","sourceRoot":"","sources":["../../src/layers/positional_encoding.test.ts"],"names":[],"mappings":""}