@stellarapp/tfjs-stellar 1.0.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +47 -0
  3. package/dist/index.d.ts +7 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +7 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/jest.config.d.ts +8 -0
  8. package/dist/jest.config.d.ts.map +1 -0
  9. package/{jest.config.ts → dist/jest.config.js} +8 -64
  10. package/dist/jest.config.js.map +1 -0
  11. package/dist/kv_cache.d.ts +53 -0
  12. package/dist/kv_cache.d.ts.map +1 -0
  13. package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
  14. package/dist/kv_cache.js.map +1 -0
  15. package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
  16. package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  17. package/dist/layers/cached_rope_multihead_attention.js +76 -0
  18. package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
  19. package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  20. package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  21. package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
  22. package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
  23. package/dist/layers/gpt_decoder_block.d.ts +34 -0
  24. package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
  25. package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
  26. package/dist/layers/gpt_decoder_block.js.map +1 -0
  27. package/dist/layers/index.d.ts +17 -0
  28. package/dist/layers/index.d.ts.map +1 -0
  29. package/dist/layers/index.js +33 -0
  30. package/dist/layers/index.js.map +1 -0
  31. package/dist/layers/multihead_attention.d.ts +106 -0
  32. package/dist/layers/multihead_attention.d.ts.map +1 -0
  33. package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
  34. package/dist/layers/multihead_attention.js.map +1 -0
  35. package/dist/layers/multihead_attention.test.d.ts +2 -0
  36. package/dist/layers/multihead_attention.test.d.ts.map +1 -0
  37. package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
  38. package/dist/layers/multihead_attention.test.js.map +1 -0
  39. package/dist/layers/positional_encoding.d.ts +37 -0
  40. package/dist/layers/positional_encoding.d.ts.map +1 -0
  41. package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
  42. package/dist/layers/positional_encoding.js.map +1 -0
  43. package/dist/layers/positional_encoding.test.d.ts +2 -0
  44. package/dist/layers/positional_encoding.test.d.ts.map +1 -0
  45. package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
  46. package/dist/layers/positional_encoding.test.js.map +1 -0
  47. package/dist/layers/rotary_position_embedding.d.ts +39 -0
  48. package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
  49. package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
  50. package/dist/layers/rotary_position_embedding.js.map +1 -0
  51. package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
  52. package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
  53. package/dist/layers/rotary_position_embedding.test.js +88 -0
  54. package/dist/layers/rotary_position_embedding.test.js.map +1 -0
  55. package/dist/layers/token_and_positional_embedding.d.ts +47 -0
  56. package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
  57. package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
  58. package/dist/layers/token_and_positional_embedding.js.map +1 -0
  59. package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
  60. package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  61. package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
  62. package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
  63. package/dist/layers/transformer_decoder.d.ts +69 -0
  64. package/dist/layers/transformer_decoder.d.ts.map +1 -0
  65. package/dist/layers/transformer_decoder.js +182 -0
  66. package/dist/layers/transformer_decoder.js.map +1 -0
  67. package/dist/layers/transformer_decoder.test.d.ts +2 -0
  68. package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
  69. package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
  70. package/dist/layers/transformer_decoder.test.js.map +1 -0
  71. package/dist/layers/transformer_encoder.d.ts +55 -0
  72. package/dist/layers/transformer_encoder.d.ts.map +1 -0
  73. package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
  74. package/dist/layers/transformer_encoder.js.map +1 -0
  75. package/dist/layers/transformer_encoder.test.d.ts +2 -0
  76. package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
  77. package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
  78. package/dist/layers/transformer_encoder.test.js.map +1 -0
  79. package/dist/losses/dice.d.ts +30 -0
  80. package/dist/losses/dice.d.ts.map +1 -0
  81. package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
  82. package/dist/losses/dice.js.map +1 -0
  83. package/dist/losses/index.d.ts +2 -0
  84. package/dist/losses/index.d.ts.map +1 -0
  85. package/dist/losses/index.js +2 -0
  86. package/dist/losses/index.js.map +1 -0
  87. package/dist/masks.d.ts +20 -0
  88. package/dist/masks.d.ts.map +1 -0
  89. package/{src/packing_mask.ts → dist/masks.js} +16 -7
  90. package/dist/masks.js.map +1 -0
  91. package/dist/metrics.d.ts +20 -0
  92. package/dist/metrics.d.ts.map +1 -0
  93. package/{src/metrics.ts → dist/metrics.js} +8 -12
  94. package/dist/metrics.js.map +1 -0
  95. package/dist/models/gpt_model.d.ts +94 -0
  96. package/dist/models/gpt_model.d.ts.map +1 -0
  97. package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
  98. package/dist/models/gpt_model.js.map +1 -0
  99. package/dist/models/index.d.ts +7 -0
  100. package/dist/models/index.d.ts.map +1 -0
  101. package/dist/models/index.js +13 -0
  102. package/dist/models/index.js.map +1 -0
  103. package/dist/models/llm_model.d.ts +87 -0
  104. package/dist/models/llm_model.d.ts.map +1 -0
  105. package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
  106. package/dist/models/llm_model.js.map +1 -0
  107. package/dist/models/u_net.d.ts +40 -0
  108. package/dist/models/u_net.d.ts.map +1 -0
  109. package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
  110. package/dist/models/u_net.js.map +1 -0
  111. package/dist/src/index.d.ts +6 -0
  112. package/dist/src/index.d.ts.map +1 -0
  113. package/dist/src/index.js +6 -0
  114. package/dist/src/index.js.map +1 -0
  115. package/dist/src/kv_cache.d.ts +53 -0
  116. package/dist/src/kv_cache.d.ts.map +1 -0
  117. package/dist/src/kv_cache.js +135 -0
  118. package/dist/src/kv_cache.js.map +1 -0
  119. package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
  120. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  121. package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
  122. package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
  123. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  124. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  125. package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
  126. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
  127. package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
  128. package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
  129. package/dist/src/layers/gpt_decoder_block.js +51 -0
  130. package/dist/src/layers/gpt_decoder_block.js.map +1 -0
  131. package/dist/src/layers/index.d.ts +17 -0
  132. package/dist/src/layers/index.d.ts.map +1 -0
  133. package/dist/src/layers/index.js +33 -0
  134. package/dist/src/layers/index.js.map +1 -0
  135. package/dist/src/layers/multihead_attention.d.ts +106 -0
  136. package/dist/src/layers/multihead_attention.d.ts.map +1 -0
  137. package/dist/src/layers/multihead_attention.js +269 -0
  138. package/dist/src/layers/multihead_attention.js.map +1 -0
  139. package/dist/src/layers/multihead_attention.test.d.ts +2 -0
  140. package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
  141. package/dist/src/layers/multihead_attention.test.js +160 -0
  142. package/dist/src/layers/multihead_attention.test.js.map +1 -0
  143. package/dist/src/layers/positional_encoding.d.ts +37 -0
  144. package/dist/src/layers/positional_encoding.d.ts.map +1 -0
  145. package/dist/src/layers/positional_encoding.js +115 -0
  146. package/dist/src/layers/positional_encoding.js.map +1 -0
  147. package/dist/src/layers/positional_encoding.test.d.ts +2 -0
  148. package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
  149. package/dist/src/layers/positional_encoding.test.js +95 -0
  150. package/dist/src/layers/positional_encoding.test.js.map +1 -0
  151. package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
  152. package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
  153. package/dist/src/layers/rotary_position_embedding.js +99 -0
  154. package/dist/src/layers/rotary_position_embedding.js.map +1 -0
  155. package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
  156. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
  157. package/dist/src/layers/rotary_position_embedding.test.js +88 -0
  158. package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
  159. package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
  160. package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
  161. package/dist/src/layers/token_and_positional_embedding.js +109 -0
  162. package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
  163. package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
  164. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  165. package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
  166. package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
  167. package/dist/src/layers/transformer_decoder.d.ts +69 -0
  168. package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
  169. package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
  170. package/dist/src/layers/transformer_decoder.js.map +1 -0
  171. package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
  172. package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
  173. package/dist/src/layers/transformer_decoder.test.js +72 -0
  174. package/dist/src/layers/transformer_decoder.test.js.map +1 -0
  175. package/dist/src/layers/transformer_encoder.d.ts +55 -0
  176. package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
  177. package/dist/src/layers/transformer_encoder.js +175 -0
  178. package/dist/src/layers/transformer_encoder.js.map +1 -0
  179. package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
  180. package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
  181. package/dist/src/layers/transformer_encoder.test.js +58 -0
  182. package/dist/src/layers/transformer_encoder.test.js.map +1 -0
  183. package/dist/src/losses/dice.d.ts +30 -0
  184. package/dist/src/losses/dice.d.ts.map +1 -0
  185. package/dist/src/losses/dice.js +93 -0
  186. package/dist/src/losses/dice.js.map +1 -0
  187. package/dist/src/losses/index.d.ts +2 -0
  188. package/dist/src/losses/index.d.ts.map +1 -0
  189. package/dist/src/losses/index.js +2 -0
  190. package/dist/src/losses/index.js.map +1 -0
  191. package/dist/src/masks.d.ts +20 -0
  192. package/dist/src/masks.d.ts.map +1 -0
  193. package/dist/src/masks.js +37 -0
  194. package/dist/src/masks.js.map +1 -0
  195. package/dist/src/metrics.d.ts +20 -0
  196. package/dist/src/metrics.d.ts.map +1 -0
  197. package/dist/src/metrics.js +28 -0
  198. package/dist/src/metrics.js.map +1 -0
  199. package/dist/src/models/gpt_model.d.ts +94 -0
  200. package/dist/src/models/gpt_model.d.ts.map +1 -0
  201. package/dist/src/models/gpt_model.js +154 -0
  202. package/dist/src/models/gpt_model.js.map +1 -0
  203. package/dist/src/models/index.d.ts +3 -0
  204. package/dist/src/models/index.d.ts.map +1 -0
  205. package/{src/models/index.ts → dist/src/models/index.js} +1 -0
  206. package/dist/src/models/index.js.map +1 -0
  207. package/dist/src/models/llm_model.d.ts +87 -0
  208. package/dist/src/models/llm_model.d.ts.map +1 -0
  209. package/dist/src/models/llm_model.js +245 -0
  210. package/dist/src/models/llm_model.js.map +1 -0
  211. package/dist/src/models/u_net.d.ts +40 -0
  212. package/dist/src/models/u_net.d.ts.map +1 -0
  213. package/dist/src/models/u_net.js +151 -0
  214. package/dist/src/models/u_net.js.map +1 -0
  215. package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
  216. package/dist/src/tfjs_types.d.ts.map +1 -0
  217. package/dist/src/tfjs_types.js +2 -0
  218. package/dist/src/tfjs_types.js.map +1 -0
  219. package/dist/src/utils.d.ts +28 -0
  220. package/dist/src/utils.d.ts.map +1 -0
  221. package/{src/utils.ts → dist/src/utils.js} +10 -33
  222. package/dist/src/utils.js.map +1 -0
  223. package/dist/src/utils.test.d.ts +2 -0
  224. package/dist/src/utils.test.d.ts.map +1 -0
  225. package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
  226. package/dist/src/utils.test.js.map +1 -0
  227. package/dist/tfjs_types.d.ts +10 -0
  228. package/dist/tfjs_types.d.ts.map +1 -0
  229. package/dist/tfjs_types.js +2 -0
  230. package/dist/tfjs_types.js.map +1 -0
  231. package/dist/utils.d.ts +28 -0
  232. package/dist/utils.d.ts.map +1 -0
  233. package/dist/utils.js +63 -0
  234. package/dist/utils.js.map +1 -0
  235. package/dist/utils.test.d.ts +2 -0
  236. package/dist/utils.test.d.ts.map +1 -0
  237. package/dist/utils.test.js +73 -0
  238. package/dist/utils.test.js.map +1 -0
  239. package/package.json +10 -4
  240. package/src/index.ts +0 -93
  241. package/src/layers/rotary_position_embedding.test.ts +0 -107
  242. package/src/losses/index.ts +0 -1
  243. package/src/testing.ts +0 -1
  244. package/tsconfig.json +0 -49
@@ -1,113 +1,95 @@
1
1
  import * as tf from '@tensorflow/tfjs';
2
-
3
2
  import { PositionalEncoding } from '@/layers/positional_encoding';
4
-
5
3
  // disables warning for using the faster node backend,
6
4
  // https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
7
5
  tf.env().set('IS_NODE', false);
8
-
9
-
10
6
  describe("PositionalEncoding tests", () => {
11
7
  it("should fail to instantiate a layer", () => {
12
8
  expect(() => new PositionalEncoding({ maxSequenceLength: 3, embedDim: 0 })).toThrow();
13
9
  expect(() => new PositionalEncoding({ maxSequenceLength: 3, embedDim: -1 })).toThrow();
14
10
  expect(() => new PositionalEncoding({ maxSequenceLength: 0, embedDim: 32 })).toThrow();
15
11
  expect(() => new PositionalEncoding({ maxSequenceLength: -1, embedDim: 32 })).toThrow();
16
- })
17
-
18
-
12
+ });
19
13
  test("successfull forward calls", () => {
20
14
  const embed_dims = 32;
21
15
  const sequences = 4;
22
16
  const input = tf.randomUniform([2, sequences, embed_dims]);
23
-
24
17
  const positional = new PositionalEncoding({ embedDim: embed_dims });
25
18
  expect(() => positional.apply(input)).not.toThrow();
26
19
  expect(() => positional.apply([input])).not.toThrow();
27
20
  expect(positional.computeOutputShape(input.shape)).toEqual(input.shape);
28
- })
29
-
30
-
21
+ });
31
22
  it("should throw when input sequences are too large, embedding dims don't match, input aren't rank 3", () => {
32
23
  const sequences_too_long = tf.randomUniform([100, 32]);
33
24
  const embeddings_too_large = tf.randomUniform([32, 100]);
34
25
  const wrong_rank = tf.randomUniform([10, 32, 32]);
35
-
36
26
  const positional = new PositionalEncoding({ maxSequenceLength: 10, embedDim: 32 });
37
-
38
27
  expect(() => positional.apply(sequences_too_long)).toThrow();
39
28
  expect(() => positional.apply(embeddings_too_large)).toThrow();
40
29
  expect(() => positional.apply(wrong_rank)).toThrow();
41
- })
42
-
43
-
30
+ });
44
31
  it("should return a non-empty config dict", () => {
45
32
  const attention = new PositionalEncoding({ embedDim: 32 });
46
33
  expect(Object.keys(attention.getConfig())).not.toBe(0);
47
- })
48
-
49
-
34
+ });
50
35
  // PyTorch implementation found at
51
36
  // https://pytorch-tutorials-preview.netlify.app/beginner/transformer_tutorial.html
52
37
  it("should be within 1e-6 of PyTorch's implementation", () => {
53
38
  const pytorch_embed4 = tf.tensor([
54
39
  [[0.0000000, 1.0000000, 0.0000000, 1.0000000],
55
- [0.8414710, 0.5403023, 0.0099998, 0.9999500],
56
- [0.9092974, -0.4161468, 0.0199987, 0.9998000],
57
- [0.1411200, -0.9899925, 0.0299955, 0.9995500],
58
- [-0.7568025, -0.6536436, 0.0399893, 0.9992001],
59
- [-0.9589243, 0.2836622, 0.0499792, 0.9987503],
60
- [-0.2794155, 0.9601703, 0.0599640, 0.9982005],
61
- [0.6569866, 0.7539023, 0.0699428, 0.9975510],
62
- [0.9893582, -0.1455000, 0.0799147, 0.9968017],
63
- [0.4121185, -0.9111302, 0.0898785, 0.9959527]]]);
64
-
40
+ [0.8414710, 0.5403023, 0.0099998, 0.9999500],
41
+ [0.9092974, -0.4161468, 0.0199987, 0.9998000],
42
+ [0.1411200, -0.9899925, 0.0299955, 0.9995500],
43
+ [-0.7568025, -0.6536436, 0.0399893, 0.9992001],
44
+ [-0.9589243, 0.2836622, 0.0499792, 0.9987503],
45
+ [-0.2794155, 0.9601703, 0.0599640, 0.9982005],
46
+ [0.6569866, 0.7539023, 0.0699428, 0.9975510],
47
+ [0.9893582, -0.1455000, 0.0799147, 0.9968017],
48
+ [0.4121185, -0.9111302, 0.0898785, 0.9959527]]
49
+ ]);
65
50
  const pytorch_embed8 = tf.tensor([
66
51
  [[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.0000000e+00,
67
- 0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.0000000e+00],
68
- [8.4147096e-01, 5.4030234e-01, 9.9833414e-02, 9.9500418e-01,
69
- 9.9998331e-03, 9.9994999e-01, 9.9999981e-04, 9.9999952e-01],
70
- [9.0929741e-01, -4.1614684e-01, 1.9866931e-01, 9.8006660e-01,
71
- 1.9998666e-02, 9.9980003e-01, 1.9999985e-03, 9.9999803e-01],
72
- [1.4112000e-01, -9.8999250e-01, 2.9552019e-01, 9.5533651e-01,
73
- 2.9995499e-02, 9.9955004e-01, 2.9999954e-03, 9.9999553e-01],
74
- [-7.5680250e-01, -6.5364361e-01, 3.8941833e-01, 9.2106098e-01,
75
- 3.9989334e-02, 9.9920011e-01, 3.9999890e-03, 9.9999201e-01],
76
- [-9.5892429e-01, 2.8366220e-01, 4.7942552e-01, 8.7758255e-01,
77
- 4.9979165e-02, 9.9875027e-01, 4.9999789e-03, 9.9998754e-01],
78
- [-2.7941549e-01, 9.6017027e-01, 5.6464243e-01, 8.2533562e-01,
79
- 5.9964005e-02, 9.9820054e-01, 5.9999637e-03, 9.9998200e-01],
80
- [6.5698659e-01, 7.5390226e-01, 6.4421761e-01, 7.6484221e-01,
81
- 6.9942847e-02, 9.9755102e-01, 6.9999420e-03, 9.9997550e-01],
82
- [9.8935825e-01, -1.4550003e-01, 7.1735609e-01, 6.9670677e-01,
83
- 7.9914689e-02, 9.9680167e-01, 7.9999138e-03, 9.9996799e-01],
84
- [4.1211849e-01, -9.1113025e-01, 7.8332686e-01, 6.2160999e-01,
85
- 8.9878544e-02, 9.9595273e-01, 8.9998785e-03, 9.9995953e-01]]]);
86
-
52
+ 0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.0000000e+00],
53
+ [8.4147096e-01, 5.4030234e-01, 9.9833414e-02, 9.9500418e-01,
54
+ 9.9998331e-03, 9.9994999e-01, 9.9999981e-04, 9.9999952e-01],
55
+ [9.0929741e-01, -4.1614684e-01, 1.9866931e-01, 9.8006660e-01,
56
+ 1.9998666e-02, 9.9980003e-01, 1.9999985e-03, 9.9999803e-01],
57
+ [1.4112000e-01, -9.8999250e-01, 2.9552019e-01, 9.5533651e-01,
58
+ 2.9995499e-02, 9.9955004e-01, 2.9999954e-03, 9.9999553e-01],
59
+ [-7.5680250e-01, -6.5364361e-01, 3.8941833e-01, 9.2106098e-01,
60
+ 3.9989334e-02, 9.9920011e-01, 3.9999890e-03, 9.9999201e-01],
61
+ [-9.5892429e-01, 2.8366220e-01, 4.7942552e-01, 8.7758255e-01,
62
+ 4.9979165e-02, 9.9875027e-01, 4.9999789e-03, 9.9998754e-01],
63
+ [-2.7941549e-01, 9.6017027e-01, 5.6464243e-01, 8.2533562e-01,
64
+ 5.9964005e-02, 9.9820054e-01, 5.9999637e-03, 9.9998200e-01],
65
+ [6.5698659e-01, 7.5390226e-01, 6.4421761e-01, 7.6484221e-01,
66
+ 6.9942847e-02, 9.9755102e-01, 6.9999420e-03, 9.9997550e-01],
67
+ [9.8935825e-01, -1.4550003e-01, 7.1735609e-01, 6.9670677e-01,
68
+ 7.9914689e-02, 9.9680167e-01, 7.9999138e-03, 9.9996799e-01],
69
+ [4.1211849e-01, -9.1113025e-01, 7.8332686e-01, 6.2160999e-01,
70
+ 8.9878544e-02, 9.9595273e-01, 8.9998785e-03, 9.9995953e-01]]
71
+ ]);
87
72
  const positional4 = new PositionalEncoding({ embedDim: 4, maxSequenceLength: 10 });
88
73
  positional4.build([]);
89
-
90
74
  const positional8 = new PositionalEncoding({ embedDim: 8, maxSequenceLength: 10 });
91
75
  positional8.build([]);
92
-
93
76
  const margin_of_error = 1e-6;
94
-
95
77
  // the difference between this and PyTorch's implementation
96
78
  //should be insignificantly small
97
- expect((positional4.getWeights()[0]
79
+ expect(positional4.getWeights()[0]
98
80
  .sub(pytorch_embed4)
99
81
  .abs()
100
- .arraySync() as [])
82
+ .arraySync()
101
83
  .flat(2)
102
84
  .filter(i => i > margin_of_error)
103
85
  .length).toBe(0);
104
-
105
- expect((positional8.getWeights()[0]
86
+ expect(positional8.getWeights()[0]
106
87
  .sub(pytorch_embed8)
107
88
  .abs()
108
- .arraySync() as [])
89
+ .arraySync()
109
90
  .flat(2)
110
91
  .filter(i => i > margin_of_error)
111
92
  .length).toBe(0);
112
93
  });
113
94
  });
95
+ //# sourceMappingURL=positional_encoding.test.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"positional_encoding.test.js","sourceRoot":"","sources":["../../src/layers/positional_encoding.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,kBAAkB,EAAE,MAAM,8BAA8B,CAAC;AAElE,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,EAAE,CAAC,oCAAoC,EAAE,GAAG,EAAE;QAC1C,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACtF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,CAAC,EAAE,QAAQ,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,CAAC,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAC5F,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,2BAA2B,EAAE,GAAG,EAAE;QACnC,MAAM,UAAU,GAAG,EAAE,CAAC;QACtB,MAAM,SAAS,GAAG,CAAC,CAAC;QACpB,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,SAAS,EAAE,UAAU,CAAC,CAAC,CAAC;QAE3D,MAAM,UAAU,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,CAAC,CAAC;QACpE,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACpD,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACtD,MAAM,CAAC,UAAU,CAAC,kBAAkB,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;IAC5E,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,kGAAkG,EAAE,GAAG,EAAE;QACxG,MAAM,kBAAkB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,GAAG,EAAE,EAAE,CAAC,CAAC,CAAC;QACvD,MAAM,oBAAoB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,GAAG,CAAC,CAAC,CAAC;QACzD,MAAM,UAAU,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;QAElD,MAAM,UAAU,GAAG,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC;QAEnF,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,kBAAkB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG,EA
AE,CAAC,UAAU,CAAC,KAAK,CAAC,oBAAoB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC/D,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,SAAS,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC;QAC3D,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,SAAS,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAA;IAGF,qCAAqC;IACrC,mFAAmF;IACnF,EAAE,CAAC,mDAAmD,EAAE,GAAG,EAAE;QACzD,MAAM,cAAc,GAAG,EAAE,CAAC,MAAM,CAAC;YAC7B,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC5C,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9C,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC5C,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;SAAC,CAAC,CAAC;QAErD,MAAM,cAAc,GAAG,EAAE,CAAC,MAAM,CAAC;YAC7B,CAAC,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACvD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACzD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACvD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EA
AE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC,CAAC;SAAC,CAAC,CAAC;QAEvE,MAAM,WAAW,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC;QACnF,WAAW,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;QAEtB,MAAM,WAAW,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC;QACnF,WAAW,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;QAEtB,MAAM,eAAe,GAAG,IAAI,CAAC;QAE7B,2DAA2D;QAC3D,iCAAiC;QACjC,MAAM,CAAE,WAAW,CAAC,UAAU,EAAE,CAAC,CAAC,CAAC;aAC9B,GAAG,CAAC,cAAc,CAAC;aACnB,GAAG,EAAE;aACL,SAAS,EAAS;aAClB,IAAI,CAAC,CAAC,CAAC;aACP,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,eAAe,CAAC;aAChC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAErB,MAAM,CAAE,WAAW,CAAC,UAAU,EAAE,CAAC,CAAC,CAAC;aAC9B,GAAG,CAAC,cAAc,CAAC;aACnB,GAAG,EAAE;aACL,SAAS,EAAS;aAClB,IAAI,CAAC,CAAC,CAAC;aACP,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,eAAe,CAAC;aAChC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACzB,CAAC,CAAC,CAAC;AACP,CAAC,CAAC,CAAC"}
@@ -0,0 +1,39 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { type LayerArgs } from "@tensorflow/tfjs-layers/dist/engine/topology";
3
+ export declare function applyRope(x: tf.Tensor, dim: number, cosine_cache: tf.Tensor, sine_cache: tf.Tensor): tf.Tensor<tf.Rank>;
4
+ export declare function rotateHalf(x: tf.Tensor, dim: number): tf.Tensor;
5
+ export declare function createRoPECache(dim: number, max_sequence_length: number, theta?: number): tf.Tensor<tf.Rank>[];
6
+ export interface RotaryPositionEmbeddingArgs extends LayerArgs {
7
+ /**
8
+ * The dimension of each head (rounded down), e.g. `Math.floor(embedDim / numHeads)`
9
+ */
10
+ dim: number;
11
+ /**
12
+ * The RoPE cache will be pre-calculated up to the max sequence length, and recalculated as needed. Defaults to `4096`.
13
+ */
14
+ maxSequenceLength?: number;
15
+ /**
16
+ * The base for the geometric progression used to compute the rotation angles. Defaults to `10_000`.
17
+ */
18
+ theta?: number;
19
+ }
20
+ /**
21
+ * Implements RoPE from the RoFormer: Enhanced Transformer with Rotary Position Embedding paper
22
+ * Inspired by: https://meta-pytorch.org/torchtune/stable/_modules/torchtune/modules/position_embeddings.html#RotaryPositionalEmbeddings
23
+ */
24
+ export declare class RotaryPositionEmbedding extends tf.layers.Layer {
25
+ static className: string;
26
+ protected dim: number;
27
+ protected max_sequence_length: number;
28
+ protected theta: number;
29
+ protected cosine_cache: tf.LayerVariable;
30
+ protected sine_cache: tf.LayerVariable;
31
+ constructor({ dim, maxSequenceLength, theta, ...args }: RotaryPositionEmbeddingArgs);
32
+ call(inputs: tf.Tensor | tf.Tensor[], kwargs: any): tf.Tensor | tf.Tensor[];
33
+ build(input_shape: tf.Shape | tf.Shape[]): void;
34
+ /**
35
+ * Output shape: [batch, head, sequence, head_dim]
36
+ */
37
+ computeOutputShape(input_shape: tf.Shape | tf.Shape[]): tf.Shape;
38
+ }
39
+ //# sourceMappingURL=rotary_position_embedding.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"rotary_position_embedding.d.ts","sourceRoot":"","sources":["../../src/layers/rotary_position_embedding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,SAAS,EAAE,MAAM,8CAA8C,CAAC;AAG9E,wBAAgB,SAAS,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,GAAG,EAAE,MAAM,EAAE,YAAY,EAAE,EAAE,CAAC,MAAM,EAAE,UAAU,EAAE,EAAE,CAAC,MAAM,sBAalG;AAGD,wBAAgB,UAAU,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,GAAG,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,CAgB/D;AAGD,wBAAgB,eAAe,CAAC,GAAG,EAAE,MAAM,EAAE,mBAAmB,EAAE,MAAM,EAAE,KAAK,GAAE,MAAe,wBAqB/F;AAGD,MAAM,WAAW,2BAA4B,SAAQ,SAAS;IAC1D;;OAEG;IACH,GAAG,EAAE,MAAM,CAAC;IACZ;;OAEG;IACH,iBAAiB,CAAC,EAAE,MAAM,CAAC;IAC3B;;OAEG;IACH,KAAK,CAAC,EAAE,MAAM,CAAC;CAClB;AAGD;;;GAGG;AACH,qBAAa,uBAAwB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACxD,MAAM,CAAC,SAAS,SAA6B;IAE7C,SAAS,CAAC,GAAG,EAAE,MAAM,CAAC;IACtB,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;IACtC,SAAS,CAAC,KAAK,EAAE,MAAM,CAAC;IAGxB,SAAS,CAAC,YAAY,EAAE,EAAE,CAAC,aAAa,CAAC;IACzC,SAAS,CAAC,UAAU,EAAE,EAAE,CAAC,aAAa,CAAC;gBAE3B,EAAE,GAAG,EAAE,iBAAwB,EAAE,KAAc,EAAE,GAAG,IAAI,EAAE,EAAE,2BAA2B;IAqB1F,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,GAAG,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAkB3E,KAAK,CAAC,WAAW,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;IAmBjD;;OAEG;IACI,kBAAkB,CAAC,WAAW,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;CAK/D"}
@@ -1,163 +1,99 @@
1
1
  import * as tf from "@tensorflow/tfjs";
2
- import { type LayerArgs } from "@tensorflow/tfjs-layers/dist/engine/topology";
3
-
4
-
5
- export function applyRope(x: tf.Tensor, dim: number, cosine_cache: tf.Tensor, sine_cache: tf.Tensor) {
2
+ export function applyRope(x, dim, cosine_cache, sine_cache) {
6
3
  return tf.tidy(() => {
7
- const seq_length = x.shape[2]!;
8
-
4
+ const seq_length = x.shape[2];
9
5
  // get a slice of the pre-computed cache, up to the input's sequence length
10
6
  const cosine = cosine_cache.slice([0, 0, 0, 0], [1, 1, seq_length, dim]);
11
7
  const sine = sine_cache.slice([0, 0, 0, 0], [1, 1, seq_length, dim]);
12
-
13
8
  // apply RoPE formula: (x * cosine) + (rotateHalf(x) * sine)
14
9
  const rotated_x = rotateHalf(x, dim);
15
-
16
10
  return tf.add(tf.mul(x, cosine), tf.mul(rotated_x, sine));
17
11
  });
18
12
  }
19
-
20
-
21
- export function rotateHalf(x: tf.Tensor, dim: number): tf.Tensor {
13
+ export function rotateHalf(x, dim) {
22
14
  return tf.tidy(() => {
23
15
  // reshape the last dimension such that adjacent coordinates are paired together
24
16
  // [x1, x2, x3, x4] -> [[x1, x2], [x3, x4]]
25
17
  // the leading dimensions are flattened because TFJS has issues during
26
18
  // backpropagation with 5D slicing
27
19
  const reshaped = x.reshape([-1, dim / 2, 2]);
28
-
29
20
  const x1 = reshaped.slice([0, 0, 0], [-1, -1, 1]);
30
21
  const x2 = reshaped.slice([0, 0, 1], [-1, -1, 1]);
31
-
32
22
  // [x1, x2] -> [-x2, x1]
33
23
  const rotated = tf.concat([tf.neg(x2), x1], -1);
34
-
35
24
  return rotated.reshape(x.shape);
36
25
  });
37
26
  }
38
-
39
-
40
- export function createRoPECache(dim: number, max_sequence_length: number, theta: number = 10_000) {
27
+ export function createRoPECache(dim, max_sequence_length, theta = 10_000) {
41
28
  return tf.tidy(() => {
42
29
  // [dim / 2]
43
- const inv_frequencies = tf.div<tf.Tensor1D>(1, tf.pow(
44
- theta,
45
- tf.range(0, Math.floor(dim / 2) * 2, 2, "float32").div(dim)));
46
-
30
+ const inv_frequencies = tf.div(1, tf.pow(theta, tf.range(0, Math.floor(dim / 2) * 2, 2, "float32").div(dim)));
47
31
  // [max_sequene_length]
48
32
  const sequence_indices = tf.range(0, max_sequence_length);
49
33
  // [max_sequence_length, dim / 2]
50
34
  const freq = tf.outerProduct(sequence_indices, inv_frequencies);
51
-
52
35
  // cache final shape [max_sequence_length, dim]
53
36
  const freq_pairs = tf.stack([freq, freq], -1)
54
37
  .reshape([max_sequence_length, dim]);
55
-
56
38
  return [
57
39
  tf.keep(tf.cos(freq_pairs).expandDims(0).expandDims(0)),
58
40
  tf.keep(tf.sin(freq_pairs).expandDims(0).expandDims(0))
59
- ]
41
+ ];
60
42
  });
61
43
  }
62
-
63
-
64
- export interface RotaryPositionEmbeddingArgs extends LayerArgs {
65
- /**
66
- * The dimension of each head (rounded down), e.g. `Math.floor(embedDim / numHeads)`
67
- */
68
- dim: number,
69
- /**
70
- * The RoPE cache will be pre-calculated up to the max sequence length, and re-calculated as needed. Defaults to `4096`.
71
- */
72
- maxSequenceLength?: number,
73
- /**
74
- * The base for the geometric progression used to compute the rotation angles. Defaults to `10_000`.
75
- */
76
- theta?: number,
77
- }
78
-
79
-
80
44
  /**
81
45
  * Implements RoPE from the RoFormer: Enhanced Transformer with Rotary Position Embedding paper
82
46
  * Inspired by: https://meta-pytorch.org/torchtune/stable/_modules/torchtune/modules/position_embeddings.html#RotaryPositionalEmbeddings
83
47
  */
84
48
  export class RotaryPositionEmbedding extends tf.layers.Layer {
85
49
  static className = "RotaryPositionEmbedding";
86
-
87
- protected dim: number;
88
- protected max_sequence_length: number;
89
- protected theta: number;
90
-
50
+ dim;
51
+ max_sequence_length;
52
+ theta;
91
53
  // cached sine and cosine frequencies, untrainable weights
92
- protected cosine_cache: tf.LayerVariable;
93
- protected sine_cache: tf.LayerVariable;
94
-
95
- constructor({ dim, maxSequenceLength = 4096, theta = 10_000, ...args }: RotaryPositionEmbeddingArgs) {
54
+ cosine_cache;
55
+ sine_cache;
56
+ constructor({ dim, maxSequenceLength = 4096, theta = 10_000, ...args }) {
96
57
  super(args);
97
-
98
58
  if (dim % 2 !== 0) {
99
59
  throw Error(`${this.getClassName()}::constructor ${this.name} expected dim to be even, got ${dim}`);
100
60
  }
101
-
102
61
  this.dim = dim;
103
62
  this.max_sequence_length = maxSequenceLength;
104
63
  this.theta = theta;
105
-
106
- this.cosine_cache = this.addWeight("sine_cache",
107
- [1, 1, maxSequenceLength, Math.floor(this.dim)],
108
- "float32", tf.initializers.zeros(), undefined, false);
109
-
110
- this.sine_cache = this.addWeight("cosine_cache",
111
- [1, 1, maxSequenceLength, Math.floor(this.dim)],
112
- "float32", tf.initializers.zeros(), undefined, false);
64
+ this.cosine_cache = this.addWeight("sine_cache", [1, 1, maxSequenceLength, Math.floor(this.dim)], "float32", tf.initializers.zeros(), undefined, false);
65
+ this.sine_cache = this.addWeight("cosine_cache", [1, 1, maxSequenceLength, Math.floor(this.dim)], "float32", tf.initializers.zeros(), undefined, false);
113
66
  }
114
-
115
-
116
- override call(inputs: tf.Tensor | tf.Tensor[], kwargs: any): tf.Tensor | tf.Tensor[] {
67
+ call(inputs, kwargs) {
117
68
  const shape = Array.isArray(inputs) ? inputs[0].shape : inputs.shape;
118
69
  const seq_length = shape[2];
119
-
120
70
  if (seq_length > this.max_sequence_length) {
121
71
  // expand cache to the next power of 2 that covers seq_length
122
72
  this.max_sequence_length = Math.pow(2, Math.ceil(Math.log2(seq_length)));
123
73
  this.build([]);
124
74
  }
125
-
126
- return applyRope(
127
- Array.isArray(inputs) ? inputs[0] : inputs,
128
- this.dim,
129
- this.cosine_cache.read(),
130
- this.sine_cache.read())
75
+ return applyRope(Array.isArray(inputs) ? inputs[0] : inputs, this.dim, this.cosine_cache.read(), this.sine_cache.read());
131
76
  }
132
-
133
-
134
- override build(input_shape: tf.Shape | tf.Shape[]) {
135
- const [cosine, sine] = createRoPECache(
136
- this.dim, this.max_sequence_length, this.theta);
137
-
77
+ build(input_shape) {
78
+ const [cosine, sine] = createRoPECache(this.dim, this.max_sequence_length, this.theta);
138
79
  this.cosine_cache.dispose();
139
80
  this.sine_cache.dispose();
140
-
141
81
  this.cosine_cache = new tf.LayerVariable(cosine);
142
82
  this.sine_cache = new tf.LayerVariable(sine);
143
-
144
83
  this.nonTrainableWeights = [
145
84
  new tf.LayerVariable(cosine),
146
85
  new tf.LayerVariable(sine)
147
86
  ];
148
-
149
87
  this.setWeights([cosine, sine]);
150
88
  }
151
-
152
-
153
89
  /**
154
90
  * Output shape: [batch, head, sequence, head_dim]
155
91
  */
156
- public computeOutputShape(input_shape: tf.Shape | tf.Shape[]) {
92
+ computeOutputShape(input_shape) {
157
93
  return Array.isArray(input_shape[0])
158
- ? input_shape[0] as tf.Shape
159
- : input_shape as tf.Shape;
94
+ ? input_shape[0]
95
+ : input_shape;
160
96
  }
161
97
  }
162
-
163
98
  tf.serialization.registerClass(RotaryPositionEmbedding);
99
+ //# sourceMappingURL=rotary_position_embedding.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"rotary_position_embedding.js","sourceRoot":"","sources":["../../src/layers/rotary_position_embedding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAIvC,MAAM,UAAU,SAAS,CAAC,CAAY,EAAE,GAAW,EAAE,YAAuB,EAAE,UAAqB;IAC/F,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;QAChB,MAAM,UAAU,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC;QAE/B,2EAA2E;QAC3E,MAAM,MAAM,GAAG,YAAY,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,UAAU,EAAE,GAAG,CAAC,CAAC,CAAC;QACzE,MAAM,IAAI,GAAG,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,UAAU,EAAE,GAAG,CAAC,CAAC,CAAC;QAErE,0DAA0D;QAC1D,MAAM,SAAS,GAAG,UAAU,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QAErC,OAAO,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,EAAE,MAAM,CAAC,EAAE,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,IAAI,CAAC,CAAC,CAAC;IAC9D,CAAC,CAAC,CAAC;AACP,CAAC;AAGD,MAAM,UAAU,UAAU,CAAC,CAAY,EAAE,GAAW;IAChD,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;QAChB,gFAAgF;QAChF,2CAA2C;QAC3C,sEAAsE;QACtE,kCAAkC;QAClC,MAAM,QAAQ,GAAG,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,EAAE,GAAG,GAAG,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAE7C,MAAM,EAAE,GAAG,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAClD,MAAM,EAAE,GAAG,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAElD,wBAAwB;QACxB,MAAM,OAAO,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAEhD,OAAO,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;IACpC,CAAC,CAAC,CAAC;AACP,CAAC;AAGD,MAAM,UAAU,eAAe,CAAC,GAAW,EAAE,mBAA2B,EAAE,QAAgB,MAAM;IAC5F,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;QAChB,QAAQ;QACR,MAAM,eAAe,GAAG,EAAE,CAAC,GAAG,CAAc,CAAC,EAAE,EAAE,CAAC,GAAG,CACjD,KAAK,EACL,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,KAAK,CAAC,GAAG,GAAG,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,SAAS,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;QAElE,uBAAuB;QACvB,MAAM,gBAAgB,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,E
AAE,mBAAmB,CAAC,CAAC;QAC1D,GAAG;QACH,MAAM,IAAI,GAAG,EAAE,CAAC,YAAY,CAAC,gBAAgB,EAAE,eAAe,CAAC,CAAC;QAEhE,+CAA+C;QAC/C,MAAM,UAAU,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC;aACxC,OAAO,CAAC,CAAC,mBAAmB,EAAE,GAAG,CAAC,CAAC,CAAC;QAEzC,OAAO;YACH,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;YACvD,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;SAC1D,CAAA;IACL,CAAC,CAAC,CAAC;AACP,CAAC;AAmBD;;;GAGG;AACH,MAAM,OAAO,uBAAwB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACxD,MAAM,CAAC,SAAS,GAAG,yBAAyB,CAAC;IAEnC,GAAG,CAAS;IACZ,mBAAmB,CAAS;IAC5B,KAAK,CAAS;IAExB,0DAA0D;IAChD,YAAY,CAAmB;IAC/B,UAAU,CAAmB;IAEvC,YAAY,EAAE,GAAG,EAAE,iBAAiB,GAAG,IAAI,EAAE,KAAK,GAAG,MAAM,EAAE,GAAG,IAAI,EAA+B;QAC/F,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,GAAG,GAAG,CAAC,KAAK,CAAC,EAAE,CAAC;YAChB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,iCAAiC,GAAG,EAAE,CAAC,CAAC;QACxG,CAAC;QAED,IAAI,CAAC,GAAG,GAAG,GAAG,CAAC;QACf,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;QAC7C,IAAI,CAAC,KAAK,GAAG,KAAK,CAAC;QAEnB,IAAI,CAAC,YAAY,GAAG,IAAI,CAAC,SAAS,CAAC,YAAY,EAC3C,CAAC,CAAC,EAAE,CAAC,EAAE,iBAAiB,EAAE,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,EAC/C,SAAS,EAAE,EAAE,CAAC,YAAY,CAAC,KAAK,EAAE,EAAE,SAAS,EAAE,KAAK,CAAC,CAAC;QAE1D,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,SAAS,CAAC,cAAc,EAC3C,CAAC,CAAC,EAAE,CAAC,EAAE,iBAAiB,EAAE,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,EAC/C,SAAS,EAAE,EAAE,CAAC,YAAY,CAAC,KAAK,EAAE,EAAE,SAAS,EAAE,KAAK,CAAC,CAAC;IAC9D,CAAC;IAGQ,IAAI,CAAC,MAA+B,EAAE,MAAW;QACtD,MAAM,KAAK,GAAG,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,MAAM,CAAC,KAAK,CAAC;QACrE,MAAM,UAAU,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC;QAE5B,IAAI,UAAU,GAAG,IAAI,CAAC,mBAAmB,EAAE,CAAC;YACxC,yCAAyC;YACzC,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;YACzE,IAAI,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;QACnB,CAAC;QAED,OAAO,SAAS,C
ACZ,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,EAC1C,IAAI,CAAC,GAAG,EACR,IAAI,CAAC,YAAY,CAAC,IAAI,EAAE,EACxB,IAAI,CAAC,UAAU,CAAC,IAAI,EAAE,CAAC,CAAA;IAC/B,CAAC;IAGQ,KAAK,CAAC,WAAkC;QAC7C,MAAM,CAAC,MAAM,EAAE,IAAI,CAAC,GAAG,eAAe,CAClC,IAAI,CAAC,GAAG,EAAE,IAAI,CAAC,mBAAmB,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC;QAEpD,IAAI,CAAC,YAAY,CAAC,OAAO,EAAE,CAAC;QAC5B,IAAI,CAAC,UAAU,CAAC,OAAO,EAAE,CAAC;QAE1B,IAAI,CAAC,YAAY,GAAG,IAAI,EAAE,CAAC,aAAa,CAAC,MAAM,CAAC,CAAC;QACjD,IAAI,CAAC,UAAU,GAAG,IAAI,EAAE,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;QAE7C,IAAI,CAAC,mBAAmB,GAAG;YACvB,IAAI,EAAE,CAAC,aAAa,CAAC,MAAM,CAAC;YAC5B,IAAI,EAAE,CAAC,aAAa,CAAC,IAAI,CAAC;SAC7B,CAAC;QAEF,IAAI,CAAC,UAAU,CAAC,CAAC,MAAM,EAAE,IAAI,CAAC,CAAC,CAAC;IACpC,CAAC;IAGD;;OAEG;IACI,kBAAkB,CAAC,WAAkC;QACxD,OAAO,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC;YAChC,CAAC,CAAC,WAAW,CAAC,CAAC,CAAa;YAC5B,CAAC,CAAC,WAAuB,CAAC;IAClC,CAAC;;AAGL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,uBAAuB,CAAC,CAAC"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=rotary_position_embedding.test.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"rotary_position_embedding.test.d.ts","sourceRoot":"","sources":["../../src/layers/rotary_position_embedding.test.ts"],"names":[],"mappings":""}
@@ -0,0 +1,88 @@
1
+ import { RotaryPositionEmbedding } from "@/layers/rotary_position_embedding";
2
+ import * as tf from "@tensorflow/tfjs";
3
+ // disables warning for using the faster node backend,
4
+ // https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
5
+ tf.env().set('IS_NODE', false);
6
+ describe("RotaryPositionEmbedding tests", () => {
7
+ test("create cache", async () => {
8
+ const rope = new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 15 });
9
+ rope.build([]);
10
+ const expected_cosine_cache = tf.tensor([[[
11
+ [1, 1, 1, 1, 1, 1, 1, 1],
12
+ [0.5403022766113281, 0.5403022766113281, 0.9950041770935059, 0.9950041770935059, 0.9999499917030334, 0.9999499917030334, 0.9999995231628418, 0.9999995231628418],
13
+ [-0.416146844625473, -0.416146844625473, 0.9800665974617004, 0.9800665974617004, 0.9998000264167786, 0.9998000264167786, 0.9999979734420776, 0.9999979734420776],
14
+ [-0.9899924993515015, -0.9899924993515015, 0.9553365111351013, 0.9553365111351013, 0.9995500445365906, 0.9995500445365906, 0.9999955296516418, 0.9999955296516418],
15
+ [-0.6536436080932617, -0.6536436080932617, 0.9210609793663025, 0.9210609793663025, 0.9992001056671143, 0.9992001056671143, 0.9999920129776001, 0.9999920129776001],
16
+ [0.28366219997406006, 0.28366219997406006, 0.8775825500488281, 0.8775825500488281, 0.9987502694129944, 0.9987502694129944, 0.9999874830245972, 0.9999874830245972],
17
+ [0.9601702690124512, 0.9601702690124512, 0.8253356218338013, 0.8253356218338013, 0.998200535774231, 0.998200535774231, 0.9999819993972778, 0.9999819993972778],
18
+ [0.7539022564888, 0.7539022564888, 0.7648422122001648, 0.7648422122001648, 0.9975510239601135, 0.9975510239601135, 0.9999755024909973, 0.9999755024909973],
19
+ [-0.1455000340938568, -0.1455000340938568, 0.6967067122459412, 0.6967067122459412, 0.9968017339706421, 0.9968017339706421, 0.9999679923057556, 0.9999679923057556],
20
+ [-0.9111302495002747, -0.9111302495002747, 0.6216099262237549, 0.6216099262237549, 0.9959527254104614, 0.9959527254104614, 0.9999595284461975, 0.9999595284461975],
21
+ [-0.83907151222229, -0.83907151222229, 0.5403022766113281, 0.5403022766113281, 0.9950041770935059, 0.9950041770935059, 0.9999499917030334, 0.9999499917030334],
22
+ [0.004425697959959507, 0.004425697959959507, 0.4535960853099823, 0.4535960853099823, 0.9939560890197754, 0.9939560890197754, 0.999939501285553, 0.999939501285553],
23
+ [0.8438539505004883, 0.8438539505004883, 0.3623577058315277, 0.3623577058315277, 0.9928086400032043, 0.9928086400032043, 0.9999279975891113, 0.9999279975891113],
24
+ [0.9074468016624451, 0.9074468016624451, 0.26749876141548157, 0.26749876141548157, 0.9915618896484375, 0.9915618896484375, 0.9999154806137085, 0.9999154806137085],
25
+ [0.13673721253871918, 0.13673721253871918, 0.1699671596288681, 0.1699671596288681, 0.9902160167694092, 0.9902160167694092, 0.9999020099639893, 0.9999020099639893]
26
+ ]]]);
27
+ const expected_sine_cache = tf.tensor([[[
28
+ [0, 0, 0, 0, 0, 0, 0, 0],
29
+ [0.8414709568023682, 0.8414709568023682, 0.0998334214091301, 0.0998334214091301, 0.009999833069741726, 0.009999833069741726, 0.0009999999310821295, 0.0009999999310821295],
30
+ [0.9092974066734314, 0.9092974066734314, 0.19866932928562164, 0.19866932928562164, 0.019998665899038315, 0.019998665899038315, 0.0019999986980110407, 0.0019999986980110407],
31
+ [0.14112000167369843, 0.14112000167369843, 0.29552021622657776, 0.29552021622657776, 0.029995499178767204, 0.029995499178767204, 0.0029999956022948027, 0.0029999956022948027],
32
+ [-0.756802499294281, -0.756802499294281, 0.3894183337688446, 0.3894183337688446, 0.03998933359980583, 0.03998933359980583, 0.003999989479780197, 0.003999989479780197],
33
+ [-0.9589242935180664, -0.9589242935180664, 0.4794255495071411, 0.4794255495071411, 0.04997916519641876, 0.04997916519641876, 0.0049999793991446495, 0.0049999793991446495],
34
+ [-0.279415488243103, -0.279415488243103, 0.5646424889564514, 0.5646424889564514, 0.059964004904031754, 0.059964004904031754, 0.0059999641962349415, 0.0059999641962349415],
35
+ [0.6569865942001343, 0.6569865942001343, 0.6442176699638367, 0.6442176699638367, 0.06994284689426422, 0.06994284689426422, 0.0069999429397284985, 0.0069999429397284985],
36
+ [0.9893582463264465, 0.9893582463264465, 0.7173560857772827, 0.7173560857772827, 0.07991468906402588, 0.07991468906402588, 0.007999914698302746, 0.007999914698302746],
37
+ [0.41211849451065063, 0.41211849451065063, 0.7833269238471985, 0.7833269238471985, 0.08987854421138763, 0.08987854421138763, 0.008999879471957684, 0.008999879471957684],
38
+ [-0.5440211296081543, -0.5440211296081543, 0.8414709568023682, 0.8414709568023682, 0.0998334139585495, 0.0998334139585495, 0.0099998340010643, 0.0099998340010643],
39
+ [-0.9999902248382568, -0.9999902248382568, 0.8912073969841003, 0.8912073969841003, 0.10977829992771149, 0.10977829992771149, 0.010999779216945171, 0.010999779216945171],
40
+ [-0.5365729331970215, -0.5365729331970215, 0.9320390820503235, 0.9320390820503235, 0.11971220374107361, 0.11971220374107361, 0.011999712325632572, 0.011999712325632572],
41
+ [0.4201670289039612, 0.4201670289039612, 0.9635581970214844, 0.9635581970214844, 0.12963414192199707, 0.12963414192199707, 0.012999634258449078, 0.012999634258449078],
42
+ [0.9906073808670044, 0.9906073808670044, 0.9854497313499451, 0.9854497313499451, 0.13954311609268188, 0.13954311609268188, 0.013999543152749538, 0.013999543152749538]
43
+ ]]]);
44
+ const [cosine_cache, sine_cache] = rope.getWeights();
45
+ expect(await cosine_cache?.sub(expected_cosine_cache).sum().array()).toBeLessThanOrEqual(1e-6);
46
+ expect(await sine_cache?.sub(expected_sine_cache).sum().array()).toBeLessThanOrEqual(1e-6);
47
+ });
48
+ test("rotate inputs", async () => {
49
+ const rope = new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 15 });
50
+ const x = tf.tensor([[[
51
+ [0.0766048, 0.5706575, 0.6705932, 0.5273118, 0.4794086, 0.9378104, 0.9888024, 0.6926053],
52
+ [0.9064133, 0.5875182, 0.1681865, 0.3833345, 0.9901192, 0.4677338, 0.3353315, 0.02699],
53
+ [0.3033573, 0.4139377, 0.4062586, 0.9705839, 0.3582608, 0.328775, 0.1340587, 0.2193414],
54
+ [0.5565202, 0.4334963, 0.9912352, 0.3388563, 0.7991487, 0.1911893, 0.1140554, 0.6949552]
55
+ ]]
56
+ ]); // batch=1, heads=1, seq=4, headDim=8
57
+ const expected_output = tf.tensor([[[
58
+ [0.07660479843616486, 0.57065749168396, 0.6705932021141052, 0.5273118019104004, 0.4794085919857025, 0.9378104209899902, 0.9888023734092712, 0.6926053166389465],
59
+ [-0.004642367362976074, 1.08015775680542, 0.12907665967941284, 0.39820998907089233, 0.9853923320770264, 0.47761136293411255, 0.33530429005622864, 0.027325313538312912],
60
+ [-0.5026336908340454, 0.10358311235904694, 0.20533521473407745, 1.0319478511810303, 0.3516140580177307, 0.33587393164634705, 0.1336197406053543, 0.21960905194282532],
61
+ [-0.6121258735656738, -0.3506217896938324, 0.8468242287635803, 0.6166517734527588, 0.7930541634559631, 0.2150741070508957, 0.11197001487016678, 0.695294201374054]
62
+ ]]]);
63
+ const output = rope.apply(x);
64
+ expect(await expected_output.sub(output).sum().array()).toBeLessThan(1e-6);
65
+ expect(rope.computeOutputShape(x.shape)).toEqual(x.shape);
66
+ expect(rope.computeOutputShape([x.shape])).toEqual(x.shape);
67
+ });
68
+ test("expand cache when input sequences are larger than rope's max sequence length", async () => {
69
+ const dim = 8;
70
+ const rope = new RotaryPositionEmbedding({ dim, maxSequenceLength: 15, theta: 1_000_000 });
71
+ const larger_sequence = 20;
72
+ const even_larger_sequence = 50;
73
+ rope.apply(tf.randomUniform([1, 1, larger_sequence, dim]));
74
+ rope.getWeights().forEach(weight => {
75
+ expect(weight.shape).toEqual([1, 1, 32, dim]);
76
+ });
77
+ rope.apply([tf.randomUniform([1, 1, even_larger_sequence, dim])]);
78
+ rope.getWeights().forEach(weight => {
79
+ expect(weight.shape).toEqual([1, 1, 64, dim]);
80
+ });
81
+ });
82
+ test("create layer", async () => {
83
+ // dim must be even
84
+ expect(() => new RotaryPositionEmbedding({ dim: 7, maxSequenceLength: 15 })).toThrow();
85
+ expect(() => new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 25 })).not.toThrow();
86
+ });
87
+ });
88
+ //# sourceMappingURL=rotary_position_embedding.test.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"rotary_position_embedding.test.js","sourceRoot":"","sources":["../../src/layers/rotary_position_embedding.test.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,uBAAuB,EAAE,MAAM,oCAAoC,CAAC;AAC7E,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,+BAA+B,EAAE,GAAG,EAAE;IAC3C,IAAI,CAAC,cAAc,EAAE,KAAK,IAAI,EAAE;QAC5B,MAAM,IAAI,GAAG,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC;QAC5E,IAAI,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;QAEf,MAAM,qBAAqB,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC;oBACtC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;oBACxB,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAChK,CAAC,CAAC,iBAAiB,EAAE,CAAC,iBAAiB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAChK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,iBAAiB,EAAE,iBAAiB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAC9J,CAAC,eAAe,EAAE,eAAe,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAC1J,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,CAAC,gBAAgB,EAAE,CAAC,gBAAgB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAC9J,CAAC,oBAAoB,EAAE,oBAAoB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,iBAAiB,EAAE,iBAAiB,CAAC;oBAClK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,k
BAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAChK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;iBACrK,CAAC,CAAC,CAAC,CAAC;QAEL,MAAM,mBAAmB,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC;oBACpC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;oBACxB,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC1K,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC5K,CAAC,mBAAmB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC9K,CAAC,CAAC,iBAAiB,EAAE,CAAC,iBAAiB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACtK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC1K,CAAC,CAAC,iBAAiB,EAAE,CAAC,iBAAiB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC1K,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBACxK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACtK,CAAC,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACxK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACxK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACxK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACtK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB
,CAAC;iBACzK,CAAC,CAAC,CAAC,CAAC;QAEL,MAAM,CAAC,YAAY,EAAE,UAAU,CAAC,GAAG,IAAI,CAAC,UAAU,EAAE,CAAC;QAErD,MAAM,CAAC,MAAM,YAAY,EAAE,GAAG,CAAC,qBAAqB,CAAC,CAAC,GAAG,EAAE,CAAC,KAAK,EAAY,CAAC,CAAC,mBAAmB,CAAC,IAAI,CAAC,CAAC;QACzG,MAAM,CAAC,MAAM,UAAU,EAAE,GAAG,CAAC,mBAAmB,CAAC,CAAC,GAAG,EAAE,CAAC,KAAK,EAAY,CAAC,CAAC,mBAAmB,CAAC,IAAI,CAAC,CAAC;IACzG,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,eAAe,EAAE,KAAK,IAAI,EAAE;QAC7B,MAAM,IAAI,GAAG,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC;QAE5E,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC;oBAClB,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;oBACxF,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,OAAO,CAAC;oBACtF,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,EAAE,SAAS,CAAC;oBACvF,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;iBAAC,CAAC;SAC7F,CAAC,CAAC,CAAC,wCAAwC;QAE5C,MAAM,eAAe,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC;oBAChC,CAAC,mBAAmB,EAAE,gBAAgB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAC/J,CAAC,CAAC,oBAAoB,EAAE,gBAAgB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,CAAC;oBACvK,CAAC,CAAC,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,mBAAmB,CAAC;oBACrK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,iBAAiB,CAAC;iBACrK,CAAC,CAAC,CAAC,CAAC;QAEL,MAAM,MAAM,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAc,CAAC;QAE1C,MAAM,CAAC,MAAM,eAAe,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,GAAG,EAAE,CAAC,KAAK,EAAY,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;QACrF,MAAM,CAAC,IAAI,CAAC,kBAAkB,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QAC1D,MAAM,CAAC,IAAI,CAAC,kBAAkB,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;IAChE,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,8EAA8E,EAAE,KAAK,IAAI,EAAE;QAC5F,MAAM,GAAG,GAAG,CAAC,CAAC;QACd,MAAM,IAAI,GAAG,IAAI,uBAAuB,CAAC,
EAAE,GAAG,EAAE,iBAAiB,EAAE,EAAE,EAAE,KAAK,EAAE,SAAS,EAAE,CAAC,CAAC;QAC3F,MAAM,eAAe,GAAG,EAAE,CAAC;QAC3B,MAAM,oBAAoB,GAAG,EAAE,CAAC;QAEhC,IAAI,CAAC,KAAK,CAAC,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,eAAe,EAAE,GAAG,CAAC,CAAC,CAAC,CAAC;QAE3D,IAAI,CAAC,UAAU,EAAE,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;YAC/B,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,GAAG,CAAC,CAAC,CAAC;QAClD,CAAC,CAAC,CAAC;QAEH,IAAI,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,oBAAoB,EAAE,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC;QAElE,IAAI,CAAC,UAAU,EAAE,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;YAC/B,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,GAAG,CAAC,CAAC,CAAC;QAClD,CAAC,CAAC,CAAC;IACP,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,cAAc,EAAE,KAAK,IAAI,EAAE;QAC5B,mBAAmB;QACnB,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC/F,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAC"}
@@ -0,0 +1,47 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
3
+ import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
4
+ import { type PositionalEncodingArgs } from '../layers/positional_encoding';
5
+ export interface TokenAndPositionalEmbeddingArgs extends LayerArgs, PositionalEncodingArgs {
6
+ vocabularySize: number;
7
+ dropout?: number;
8
+ }
9
+ /**
10
+ * This class combines sinusoidal positional encoding from the
11
+ * 2017 paper "Attention Is All You Need" with a normal embedding layer to
12
+ * form a simplified single embedding layer.
13
+ *
14
+ * This layer accepts tokenized inputs of the shape `[ batch, tokens ]` and runs
15
+ * it through an embedding layer before adding sinusoidal positional encoding.
16
+ *
17
+ * @param embedDim the size of each token/word's embedding
18
+ * @param vocabularySize the number of tokens to embed
19
+ * @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`
20
+ * @param dropout applies dropout to the positionally encoded embeddings, default `0.1`
21
+ */
22
+ export declare class TokenAndPositionalEmbedding extends tf.layers.Layer {
23
+ static className: string;
24
+ private readonly embedDim;
25
+ private readonly vocabularySize;
26
+ private embedding;
27
+ private positional;
28
+ private readonly maxSequenceLength;
29
+ private readonly dropout;
30
+ private dropoutLayer;
31
+ constructor({ embedDim, vocabularySize, maxSequenceLength, dropout, ...args }: TokenAndPositionalEmbeddingArgs);
32
+ /**
33
+ * Forward propagation.
34
+ */
35
+ call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor<tf.Rank>;
36
+ /**
37
+ * Build the sublayers and enable serialization
38
+ */
39
+ build(inputShape: tf.Shape | tf.Shape[]): void;
40
+ /**
41
+ * The output shape, for an input shape of [batch, sequences], is
42
+ * [batch, sequences, embedDim]
43
+ */
44
+ computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[];
45
+ getConfig(): tf.serialization.ConfigDict;
46
+ }
47
+ //# sourceMappingURL=token_and_positional_embedding.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"token_and_positional_embedding.d.ts","sourceRoot":"","sources":["../../src/layers/token_and_positional_embedding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,SAAS,EAAE,MAAM,8CAA8C,CAAC;AAC9E,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAEjE,OAAO,EAAsB,KAAK,sBAAsB,EAAE,MAAM,+BAA+B,CAAC;AAGhG,MAAM,WAAW,+BAAgC,SAAQ,SAAS,EAAE,sBAAsB;IACtF,cAAc,EAAE,MAAM,CAAC;IACvB,OAAO,CAAC,EAAE,MAAM,CAAA;CACnB;AAGD;;;;;;;;;;;;GAYG;AACH,qBAAa,2BAA4B,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IAC5D,MAAM,CAAC,SAAS,SAAiC;IAEjD,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,QAAQ,CAAC,cAAc,CAAS;IACxC,OAAO,CAAC,SAAS,CAAkB;IAEnC,OAAO,CAAC,UAAU,CAAiB;IACnC,OAAO,CAAC,QAAQ,CAAC,iBAAiB,CAAS;IAC3C,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAS;IAEjC,OAAO,CAAC,YAAY,CAAkB;gBAG1B,EAAE,QAAQ,EAAE,cAAc,EAAE,iBAAiB,EAAE,OAAO,EAAE,GAAG,IAAI,EAAE,EAAE,+BAA+B;IA0B9G;;OAEG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM;IAe7D;;OAEG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IAgCvD;;;OAGG;IACM,kBAAkB,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;IAQ5E,SAAS,IAAI,EAAE,CAAC,aAAa,CAAC,UAAU;CAcpD"}