@stellarapp/tfjs-stellar 1.0.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. package/LICENSE +21 -0
  2. package/README.md +47 -0
  3. package/dist/index.d.ts +7 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +7 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/jest.config.d.ts +8 -0
  8. package/dist/jest.config.d.ts.map +1 -0
  9. package/{jest.config.ts → dist/jest.config.js} +8 -64
  10. package/dist/jest.config.js.map +1 -0
  11. package/dist/kv_cache.d.ts +53 -0
  12. package/dist/kv_cache.d.ts.map +1 -0
  13. package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
  14. package/dist/kv_cache.js.map +1 -0
  15. package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
  16. package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  17. package/dist/layers/cached_rope_multihead_attention.js +76 -0
  18. package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
  19. package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  20. package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  21. package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
  22. package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
  23. package/dist/layers/gpt_decoder_block.d.ts +34 -0
  24. package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
  25. package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
  26. package/dist/layers/gpt_decoder_block.js.map +1 -0
  27. package/dist/layers/index.d.ts +17 -0
  28. package/dist/layers/index.d.ts.map +1 -0
  29. package/dist/layers/index.js +33 -0
  30. package/dist/layers/index.js.map +1 -0
  31. package/dist/layers/multihead_attention.d.ts +106 -0
  32. package/dist/layers/multihead_attention.d.ts.map +1 -0
  33. package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
  34. package/dist/layers/multihead_attention.js.map +1 -0
  35. package/dist/layers/multihead_attention.test.d.ts +2 -0
  36. package/dist/layers/multihead_attention.test.d.ts.map +1 -0
  37. package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
  38. package/dist/layers/multihead_attention.test.js.map +1 -0
  39. package/dist/layers/positional_encoding.d.ts +37 -0
  40. package/dist/layers/positional_encoding.d.ts.map +1 -0
  41. package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
  42. package/dist/layers/positional_encoding.js.map +1 -0
  43. package/dist/layers/positional_encoding.test.d.ts +2 -0
  44. package/dist/layers/positional_encoding.test.d.ts.map +1 -0
  45. package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
  46. package/dist/layers/positional_encoding.test.js.map +1 -0
  47. package/dist/layers/rotary_position_embedding.d.ts +39 -0
  48. package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
  49. package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
  50. package/dist/layers/rotary_position_embedding.js.map +1 -0
  51. package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
  52. package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
  53. package/dist/layers/rotary_position_embedding.test.js +88 -0
  54. package/dist/layers/rotary_position_embedding.test.js.map +1 -0
  55. package/dist/layers/token_and_positional_embedding.d.ts +47 -0
  56. package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
  57. package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
  58. package/dist/layers/token_and_positional_embedding.js.map +1 -0
  59. package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
  60. package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  61. package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
  62. package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
  63. package/dist/layers/transformer_decoder.d.ts +69 -0
  64. package/dist/layers/transformer_decoder.d.ts.map +1 -0
  65. package/dist/layers/transformer_decoder.js +182 -0
  66. package/dist/layers/transformer_decoder.js.map +1 -0
  67. package/dist/layers/transformer_decoder.test.d.ts +2 -0
  68. package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
  69. package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
  70. package/dist/layers/transformer_decoder.test.js.map +1 -0
  71. package/dist/layers/transformer_encoder.d.ts +55 -0
  72. package/dist/layers/transformer_encoder.d.ts.map +1 -0
  73. package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
  74. package/dist/layers/transformer_encoder.js.map +1 -0
  75. package/dist/layers/transformer_encoder.test.d.ts +2 -0
  76. package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
  77. package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
  78. package/dist/layers/transformer_encoder.test.js.map +1 -0
  79. package/dist/losses/dice.d.ts +30 -0
  80. package/dist/losses/dice.d.ts.map +1 -0
  81. package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
  82. package/dist/losses/dice.js.map +1 -0
  83. package/dist/losses/index.d.ts +2 -0
  84. package/dist/losses/index.d.ts.map +1 -0
  85. package/dist/losses/index.js +2 -0
  86. package/dist/losses/index.js.map +1 -0
  87. package/dist/masks.d.ts +20 -0
  88. package/dist/masks.d.ts.map +1 -0
  89. package/{src/packing_mask.ts → dist/masks.js} +16 -7
  90. package/dist/masks.js.map +1 -0
  91. package/dist/metrics.d.ts +20 -0
  92. package/dist/metrics.d.ts.map +1 -0
  93. package/{src/metrics.ts → dist/metrics.js} +8 -12
  94. package/dist/metrics.js.map +1 -0
  95. package/dist/models/gpt_model.d.ts +94 -0
  96. package/dist/models/gpt_model.d.ts.map +1 -0
  97. package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
  98. package/dist/models/gpt_model.js.map +1 -0
  99. package/dist/models/index.d.ts +7 -0
  100. package/dist/models/index.d.ts.map +1 -0
  101. package/dist/models/index.js +13 -0
  102. package/dist/models/index.js.map +1 -0
  103. package/dist/models/llm_model.d.ts +87 -0
  104. package/dist/models/llm_model.d.ts.map +1 -0
  105. package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
  106. package/dist/models/llm_model.js.map +1 -0
  107. package/dist/models/u_net.d.ts +40 -0
  108. package/dist/models/u_net.d.ts.map +1 -0
  109. package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
  110. package/dist/models/u_net.js.map +1 -0
  111. package/dist/src/index.d.ts +6 -0
  112. package/dist/src/index.d.ts.map +1 -0
  113. package/dist/src/index.js +6 -0
  114. package/dist/src/index.js.map +1 -0
  115. package/dist/src/kv_cache.d.ts +53 -0
  116. package/dist/src/kv_cache.d.ts.map +1 -0
  117. package/dist/src/kv_cache.js +135 -0
  118. package/dist/src/kv_cache.js.map +1 -0
  119. package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
  120. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  121. package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
  122. package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
  123. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  124. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  125. package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
  126. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
  127. package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
  128. package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
  129. package/dist/src/layers/gpt_decoder_block.js +51 -0
  130. package/dist/src/layers/gpt_decoder_block.js.map +1 -0
  131. package/dist/src/layers/index.d.ts +17 -0
  132. package/dist/src/layers/index.d.ts.map +1 -0
  133. package/dist/src/layers/index.js +33 -0
  134. package/dist/src/layers/index.js.map +1 -0
  135. package/dist/src/layers/multihead_attention.d.ts +106 -0
  136. package/dist/src/layers/multihead_attention.d.ts.map +1 -0
  137. package/dist/src/layers/multihead_attention.js +269 -0
  138. package/dist/src/layers/multihead_attention.js.map +1 -0
  139. package/dist/src/layers/multihead_attention.test.d.ts +2 -0
  140. package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
  141. package/dist/src/layers/multihead_attention.test.js +160 -0
  142. package/dist/src/layers/multihead_attention.test.js.map +1 -0
  143. package/dist/src/layers/positional_encoding.d.ts +37 -0
  144. package/dist/src/layers/positional_encoding.d.ts.map +1 -0
  145. package/dist/src/layers/positional_encoding.js +115 -0
  146. package/dist/src/layers/positional_encoding.js.map +1 -0
  147. package/dist/src/layers/positional_encoding.test.d.ts +2 -0
  148. package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
  149. package/dist/src/layers/positional_encoding.test.js +95 -0
  150. package/dist/src/layers/positional_encoding.test.js.map +1 -0
  151. package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
  152. package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
  153. package/dist/src/layers/rotary_position_embedding.js +99 -0
  154. package/dist/src/layers/rotary_position_embedding.js.map +1 -0
  155. package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
  156. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
  157. package/dist/src/layers/rotary_position_embedding.test.js +88 -0
  158. package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
  159. package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
  160. package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
  161. package/dist/src/layers/token_and_positional_embedding.js +109 -0
  162. package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
  163. package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
  164. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  165. package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
  166. package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
  167. package/dist/src/layers/transformer_decoder.d.ts +69 -0
  168. package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
  169. package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
  170. package/dist/src/layers/transformer_decoder.js.map +1 -0
  171. package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
  172. package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
  173. package/dist/src/layers/transformer_decoder.test.js +72 -0
  174. package/dist/src/layers/transformer_decoder.test.js.map +1 -0
  175. package/dist/src/layers/transformer_encoder.d.ts +55 -0
  176. package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
  177. package/dist/src/layers/transformer_encoder.js +175 -0
  178. package/dist/src/layers/transformer_encoder.js.map +1 -0
  179. package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
  180. package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
  181. package/dist/src/layers/transformer_encoder.test.js +58 -0
  182. package/dist/src/layers/transformer_encoder.test.js.map +1 -0
  183. package/dist/src/losses/dice.d.ts +30 -0
  184. package/dist/src/losses/dice.d.ts.map +1 -0
  185. package/dist/src/losses/dice.js +93 -0
  186. package/dist/src/losses/dice.js.map +1 -0
  187. package/dist/src/losses/index.d.ts +2 -0
  188. package/dist/src/losses/index.d.ts.map +1 -0
  189. package/dist/src/losses/index.js +2 -0
  190. package/dist/src/losses/index.js.map +1 -0
  191. package/dist/src/masks.d.ts +20 -0
  192. package/dist/src/masks.d.ts.map +1 -0
  193. package/dist/src/masks.js +37 -0
  194. package/dist/src/masks.js.map +1 -0
  195. package/dist/src/metrics.d.ts +20 -0
  196. package/dist/src/metrics.d.ts.map +1 -0
  197. package/dist/src/metrics.js +28 -0
  198. package/dist/src/metrics.js.map +1 -0
  199. package/dist/src/models/gpt_model.d.ts +94 -0
  200. package/dist/src/models/gpt_model.d.ts.map +1 -0
  201. package/dist/src/models/gpt_model.js +154 -0
  202. package/dist/src/models/gpt_model.js.map +1 -0
  203. package/dist/src/models/index.d.ts +3 -0
  204. package/dist/src/models/index.d.ts.map +1 -0
  205. package/{src/models/index.ts → dist/src/models/index.js} +1 -0
  206. package/dist/src/models/index.js.map +1 -0
  207. package/dist/src/models/llm_model.d.ts +87 -0
  208. package/dist/src/models/llm_model.d.ts.map +1 -0
  209. package/dist/src/models/llm_model.js +245 -0
  210. package/dist/src/models/llm_model.js.map +1 -0
  211. package/dist/src/models/u_net.d.ts +40 -0
  212. package/dist/src/models/u_net.d.ts.map +1 -0
  213. package/dist/src/models/u_net.js +151 -0
  214. package/dist/src/models/u_net.js.map +1 -0
  215. package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
  216. package/dist/src/tfjs_types.d.ts.map +1 -0
  217. package/dist/src/tfjs_types.js +2 -0
  218. package/dist/src/tfjs_types.js.map +1 -0
  219. package/dist/src/utils.d.ts +28 -0
  220. package/dist/src/utils.d.ts.map +1 -0
  221. package/{src/utils.ts → dist/src/utils.js} +10 -33
  222. package/dist/src/utils.js.map +1 -0
  223. package/dist/src/utils.test.d.ts +2 -0
  224. package/dist/src/utils.test.d.ts.map +1 -0
  225. package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
  226. package/dist/src/utils.test.js.map +1 -0
  227. package/dist/tfjs_types.d.ts +10 -0
  228. package/dist/tfjs_types.d.ts.map +1 -0
  229. package/dist/tfjs_types.js +2 -0
  230. package/dist/tfjs_types.js.map +1 -0
  231. package/dist/utils.d.ts +28 -0
  232. package/dist/utils.d.ts.map +1 -0
  233. package/dist/utils.js +63 -0
  234. package/dist/utils.js.map +1 -0
  235. package/dist/utils.test.d.ts +2 -0
  236. package/dist/utils.test.d.ts.map +1 -0
  237. package/dist/utils.test.js +73 -0
  238. package/dist/utils.test.js.map +1 -0
  239. package/package.json +10 -4
  240. package/src/index.ts +0 -93
  241. package/src/layers/rotary_position_embedding.test.ts +0 -107
  242. package/src/losses/index.ts +0 -1
  243. package/src/testing.ts +0 -1
  244. package/tsconfig.json +0 -49
@@ -0,0 +1,58 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { TokenAndPositionalEmbedding } from '@/layers/token_and_positional_embedding';
3
// disables warning for using the faster node backend,
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
tf.env().set('IS_NODE', false);

/**
 * Unit tests for TokenAndPositionalEmbedding: constructor validation, forward
 * calls, build-time shape validation, config serialization and output shape.
 */
describe("TokenAndPositionalEmbedding tests", () => {
    test("layer initialization", () => {
        // zero-sized sequence length, embedding or vocabulary must be rejected
        expect(() => new TokenAndPositionalEmbedding({ maxSequenceLength: 0, embedDim: 10, vocabularySize: 10_000 })).toThrow();
        expect(() => new TokenAndPositionalEmbedding({ embedDim: 0, vocabularySize: 10_000 })).toThrow();
        expect(() => new TokenAndPositionalEmbedding({ embedDim: 10, vocabularySize: 0 })).toThrow();

        // valid configurations, with and without an explicit maxSequenceLength
        expect(() => new TokenAndPositionalEmbedding({ embedDim: 10, vocabularySize: 10_000 })).not.toThrow();
        expect(() => new TokenAndPositionalEmbedding({ maxSequenceLength: 500, embedDim: 10, vocabularySize: 10_000 })).not.toThrow();
    });

    test("successfull forward calls", () => {
        const embedDims = 32;
        const sequences = 4;
        const vocabSize = 10_000;
        // Token ids should be integers in [0, vocabSize); float values in [0, 1)
        // would only ever exercise id 0 once truncated to an integer index.
        const input = tf.randomUniform([2, sequences], 0, vocabSize, 'int32');

        const embedding = new TokenAndPositionalEmbedding({ embedDim: embedDims, dropout: 0.1, vocabularySize: vocabSize });
        expect(() => embedding.apply(input)).not.toThrow();
        expect(() => embedding.apply([input])).not.toThrow();
    });

    test("layer build", () => {
        // only shapes are needed by build(), so no tensors are allocated here
        const shapeOk = [2, 4];
        const shapeTooManyWords = [2, 700]; // longer than maxSequenceLength (500)
        const shapeIsImage = [1, 32, 32, 3]; // wrong rank for token ids
        const makeEmbedding = () => new TokenAndPositionalEmbedding({ embedDim: 32, maxSequenceLength: 500, vocabularySize: 1_000 });

        let embedding = makeEmbedding();
        expect(() => embedding.build(shapeOk)).not.toThrow();

        embedding = makeEmbedding();
        expect(() => embedding.build([shapeOk, shapeOk])).not.toThrow();

        // each failure case gets a fresh, unbuilt layer so the error must come from
        // shape validation rather than from whatever state an earlier build left behind
        embedding = makeEmbedding();
        expect(() => embedding.build(shapeTooManyWords)).toThrow();

        embedding = makeEmbedding();
        expect(() => embedding.build(shapeIsImage)).toThrow();
    });

    it("should throw when more than one input provided, input sequences are too large, or incorrect input rank", () => {
        const sequencesTooLong = tf.randomUniform([10, 1000]);
        const multipleCorrectInputs = [tf.randomUniform([2, 3]), tf.randomUniform([2, 3])];
        const wrongRank = tf.randomUniform([10, 32, 32]);

        const positional = new TokenAndPositionalEmbedding({ maxSequenceLength: 10, embedDim: 32, vocabularySize: 10_000 });
        positional.build([2, 3]); // get past the initial build call to test forward prop

        expect(() => positional.apply(sequencesTooLong)).toThrow();
        expect(() => positional.apply(multipleCorrectInputs)).toThrow();
        expect(() => positional.apply(wrongRank)).toThrow();
    });

    it("should return a non-empty config dict", () => {
        const embedding = new TokenAndPositionalEmbedding({ embedDim: 32, vocabularySize: 10_000 });
        // compare the key count: an array is never `toBe(0)`, so asserting on the
        // array itself could not fail
        expect(Object.keys(embedding.getConfig()).length).toBeGreaterThan(0);
    });

    it("should return an output shape of [batch, sequences, embed dims]", () => {
        const words = 100;
        const batch = 2;
        const embedDims = 64;

        const embedding = new TokenAndPositionalEmbedding({ embedDim: embedDims, vocabularySize: 10_000 });
        expect(embedding.computeOutputShape([batch, words])).toEqual([batch, words, embedDims]);
    });
});
58
+ //# sourceMappingURL=token_and_positional_embedding.test.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"token_and_positional_embedding.test.js","sourceRoot":"","sources":["../../../src/layers/token_and_positional_embedding.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,2BAA2B,EAAE,MAAM,yCAAyC,CAAC;AAEtF,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,IAAI,CAAC,sBAAsB,EAAE,GAAG,EAAE;QAC9B,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,iBAAiB,EAAE,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACxH,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACjG,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAE7F,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACtG,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC1G,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,2BAA2B,EAAE,GAAG,EAAE;QACnC,MAAM,UAAU,GAAG,EAAE,CAAC;QACtB,MAAM,SAAS,GAAG,CAAC,CAAC;QACpB,MAAM,UAAU,GAAG,MAAM,CAAC;QAC1B,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;QAE/C,MAAM,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,OAAO,EAAE,GAAG,EAAE,cAAc,EAAE,UAAU,EAAE,CAAC,CAAC;QACtH,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,aAAa,EAAE,GAAG,EAAE;QACrB,MAAM,QAAQ,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC1C,MAAM,oBAAoB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC,CAAC;QACxD,MAAM,cAAc,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;QAExD,IAAI,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,iBAAiB,
EAAE,GAAG,EAAE,cAAc,EAAE,KAAK,EAAE,CAAC,CAAC;QACjH,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAE5D,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,iBAAiB,EAAE,GAAG,EAAE,cAAc,EAAE,KAAK,EAAE,CAAC,CAAC;QAC7G,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,QAAQ,CAAC,KAAK,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAE9E,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,iBAAiB,EAAE,GAAG,EAAE,cAAc,EAAE,KAAK,EAAE,CAAC,CAAC;QACjG,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,oBAAoB,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACpE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,cAAc,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAClE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,wGAAwG,EAAE,GAAG,EAAE;QAC9G,MAAM,kBAAkB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,IAAI,CAAC,CAAC,CAAC;QACxD,MAAM,uBAAuB,GAAG,CAAC,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QACrF,MAAM,UAAU,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;QAElD,MAAM,UAAU,GAAG,IAAI,2BAA2B,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC;QACpH,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,uDAAuD;QAEjF,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,kBAAkB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,uBAAuB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAClE,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC;QAC5F,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,SAAS,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,iEAAiE,EAAE,GAAG,EAAE;QACvE,MAAM,KAAK,GAAG,GAAG,CAAC;QAClB,MAAM,KAAK,GAAG,CAAC,CAAC;QAChB,MAAM,UAAU,GAAG,EAAE,CAAC;QAEtB,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC;QAE/C,MAAM
,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC;QAEpG,MAAM,CAAC,SAAS,CAAC,kBAAkB,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,UAAU,CAAC,CAAC,CAAC;IAC1F,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAC"}
@@ -0,0 +1,69 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
3
+ import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
4
+ import { type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
5
/**
 * Constructor options for `TransformerDecoder`.
 *
 * Re-uses the multi-head attention options, replacing their `causal` flag with
 * the decoder-level one declared below.
 */
export interface TransformerDecoderArgs extends Omit<MultiHeadAttentionArgs, "causal"> {
    /** Use a causal mask for attention on inputs (prevents attending to future positions). */
    causal?: boolean;
    /** Activation applied by the intermediate feed-forward layer. */
    activation?: "relu" | "gelu";
    /** Output width of the intermediate feed-forward layer. */
    dimsFeedForward?: number;
}
10
/**
 * This class implements the transformer decoder architecture from
 * the 2017 paper "Attention Is All You Need".
 *
 * This decoder-only transformer layer accepts one tensor input.
 * The input tensor should have the shape
 * `[ batch, sequences, embedding dims ]`.
 *
 * Causal masking is enabled by default for the initial attention sub-layer.
 *
 * @param numHeads number of attention heads to use
 * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
 * @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`.
 *   NOTE(review): the compiled constructor does not destructure `causal` and always builds the
 *   self-attention sub-layer with `causal: true`, so passing `false` currently has no effect — TODO confirm.
 * @param dropout use dropout during the attention calculations, default `0.1`; the constructor
 *   throws for values `>= 1`
 * @param activation the activation of the intermediate feed forward layer, default `relu`
 * @param dimsFeedForward the size of the intermediate feed forward layer, default `embedDim * 4`
 *   (the paper's d_ff = 2048 corresponds to d_model = 512)
 * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
 * @throws Error from the constructor when `dropout >= 1`
 */
export declare class TransformerDecoder extends tf.layers.Layer {
    static className: string;
    /** Self-attention sub-layer (a `CachedRoPEMultiHeadAttention` in the compiled implementation). */
    protected readonly causalSelfAttention: tf.layers.Layer;
    /** Dropout applied to the self-attention output before the residual add. */
    protected readonly causalSelfAttentionDropout: tf.layers.Layer;
    /** Layer normalization applied after the self-attention residual add. */
    protected readonly causalSelfAttentionNorm: tf.layers.Layer;
    /** First (expanding) dense layer of the feed-forward sub-block. */
    protected readonly feedforward1: tf.layers.Layer;
    /** Second dense layer of the feed-forward sub-block. */
    protected readonly feedforward2: tf.layers.Layer;
    /** Dropout used in the feed-forward sub-block. */
    protected readonly feedForwardDropout: tf.layers.Layer;
    /** Layer normalization of the feed-forward sub-block (name misspelled; kept for compatibility). */
    protected readonly feedFowardNorm: tf.layers.Layer;
    protected readonly numHeads: number;
    protected readonly embedDim: number;
    protected readonly useBias: boolean;
    protected readonly dropout: number;
    protected readonly activation: ActivationIdentifier;
    protected readonly dimsFeedForward: number;
    constructor({ numHeads, embedDim, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerDecoderArgs);
    /**
     * Forward propagation: the self-attention sub-block followed by the feed-forward sub-block.
     *
     * @param inputs input tensor; if an array is given it must hold 1 or 2 tensors and only the
     *   first one is used
     * @return the output tensor
     * @throws Error when an array with a length other than 1 or 2 is passed
     */
    call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[];
    /** Self-attention, dropout, residual add, then layer normalization. */
    protected causalSelfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
    protected feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
    /**
     * Initialize the sublayers' weights and track them to enable serialization
     */
    build(inputShape: tf.Shape | tf.Shape[]): void;
    /**
     * Save the layer's hyperparameters for serialization
     */
    getConfig(): {
        numHeads: number;
        embedDim: number;
        useBias: boolean;
        dropout: number;
        activation: ActivationIdentifier;
        dimsFeedForward: number;
    };
}
69
+ //# sourceMappingURL=transformer_decoder.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"transformer_decoder.d.ts","sourceRoot":"","sources":["../../../src/layers/transformer_decoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AACjE,OAAO,EAAE,KAAK,oBAAoB,EAAE,MAAM,6DAA6D,CAAC;AAExG,OAAO,EAAE,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAI3E,MAAM,WAAW,sBAAuB,SAAQ,IAAI,CAAC,sBAAsB,EAAE,QAAQ,CAAC;IAClF,UAAU,CAAC,EAAE,MAAM,GAAG,MAAM,CAAC;IAC7B,eAAe,CAAC,EAAE,MAAM,CAAC;IACzB,MAAM,CAAC,EAAE,OAAO,CAAC;CACpB;AAGD;;;;;;;;;;;;;;;;;GAiBG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IAExC,SAAS,CAAC,QAAQ,CAAC,mBAAmB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACxD,SAAS,CAAC,QAAQ,CAAC,0BAA0B,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAC/D,SAAS,CAAC,QAAQ,CAAC,uBAAuB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAE5D,SAAS,CAAC,QAAQ,CAAC,YAAY,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACjD,SAAS,CAAC,QAAQ,CAAC,YAAY,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACjD,SAAS,CAAC,QAAQ,CAAC,kBAAkB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACvD,SAAS,CAAC,QAAQ,CAAC,cAAc,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAEnD,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,OAAO,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,MAAM,CAAC;IACnC,SAAS,CAAC,QAAQ,CAAC,UAAU,EAAE,oBAAoB,CAAC;IACpD,SAAS,CAAC,QAAQ,CAAC,eAAe,EAAE,MAAM,CAAC;gBAE/B,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAAE,EAAE,sBAAsB;IAyClH;;;;;OAKG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAoBvF,SAAS,CAAC,wBAAwB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAc3E,SAAS,CAAC,gBAAgB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAenE;;OAEG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IA6DvD;;OAEG;IACM,SAAS;;;;;;;;CAiBrB"}
@@ -1,28 +1,15 @@
1
1
  import * as tf from "@tensorflow/tfjs";
2
- import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
3
- import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
4
-
5
- import { type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
6
2
  import { CachedRoPEMultiHeadAttention } from "@/layers/cached_rope_multihead_attention";
7
-
8
-
9
- export interface TransformerDecoderArgs extends Omit<MultiHeadAttentionArgs, "causal"> {
10
- activation?: "relu" | "gelu";
11
- dimsFeedForward?: number;
12
- causal?: boolean; // use causal mask for attention on inputs
13
- }
14
-
15
-
16
3
  /**
17
4
  * This class implements the transformer decoder architecture from
18
5
  * the 2017 paper "Attention Is All You Need".
19
- *
6
+ *
20
7
  * This decoder-only transformer layer accepts one tensor input.
21
8
  * The input tensor should have the shape
22
9
  * `[ batch, sequences, embedding dims ]`.
23
- *
10
+ *
24
11
  * Causal masking is enabled by default for the initial attention sub-layer.
25
- *
12
+ *
26
13
  * @param numHeads number of attention heads to use
27
14
  * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
28
15
  * @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
@@ -33,48 +20,39 @@ export interface TransformerDecoderArgs extends Omit<MultiHeadAttentionArgs, "ca
33
20
  */
34
21
  export class TransformerDecoder extends tf.layers.Layer {
35
22
  static className = "TransformerDecoder";
36
-
37
- protected readonly causalSelfAttention: tf.layers.Layer;
38
- protected readonly causalSelfAttentionDropout: tf.layers.Layer;
39
- protected readonly causalSelfAttentionNorm: tf.layers.Layer;
40
-
41
- protected readonly feedforward1: tf.layers.Layer;
42
- protected readonly feedforward2: tf.layers.Layer;
43
- protected readonly feedForwardDropout: tf.layers.Layer;
44
- protected readonly feedFowardNorm: tf.layers.Layer;
45
-
46
- protected readonly numHeads: number;
47
- protected readonly embedDim: number;
48
- protected readonly useBias: boolean;
49
- protected readonly dropout: number;
50
- protected readonly activation: ActivationIdentifier;
51
- protected readonly dimsFeedForward: number;
52
-
53
- constructor({ numHeads, embedDim, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerDecoderArgs) {
23
+ causalSelfAttention;
24
+ causalSelfAttentionDropout;
25
+ causalSelfAttentionNorm;
26
+ feedforward1;
27
+ feedforward2;
28
+ feedForwardDropout;
29
+ feedFowardNorm;
30
+ numHeads;
31
+ embedDim;
32
+ useBias;
33
+ dropout;
34
+ activation;
35
+ dimsFeedForward;
36
+ constructor({ numHeads, embedDim, useBias, dropout, activation, dimsFeedForward, ...args }) {
54
37
  super(args);
55
-
56
38
  this.numHeads = numHeads;
57
39
  this.embedDim = embedDim;
58
40
  this.useBias = useBias ?? true;
59
41
  this.dropout = dropout ?? 0.1;
60
42
  this.activation = activation ?? "relu";
61
-
62
43
  if (this.dropout >= 1) {
63
44
  throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
64
45
  }
65
-
66
46
  // in the paper section 3.3, d_model=512 (embedDim) and first dense layer outputs d_ff=2048
67
47
  this.dimsFeedForward = dimsFeedForward ?? embedDim * 4;
68
-
69
48
  // self attention sub-block
70
49
  this.causalSelfAttention = new CachedRoPEMultiHeadAttention({
71
50
  numHeads: this.numHeads, embedDim: this.embedDim,
72
51
  useBias: this.useBias, dropout: this.dropout,
73
52
  causal: true
74
53
  });
75
- this.causalSelfAttentionDropout = tf.layers.dropout({ rate: this.dropout })
54
+ this.causalSelfAttentionDropout = tf.layers.dropout({ rate: this.dropout });
76
55
  this.causalSelfAttentionNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
77
-
78
56
  // feed forward sub-block
79
57
  this.feedforward1 = tf.layers.dense({
80
58
  units: this.dimsFeedForward,
@@ -89,101 +67,79 @@ export class TransformerDecoder extends tf.layers.Layer {
89
67
  this.feedForwardDropout = tf.layers.dropout({ rate: this.dropout });
90
68
  this.feedFowardNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
91
69
  }
92
-
93
-
94
70
  /**
95
71
  * Forward propagation
96
- *
72
+ *
97
73
  * @param inputs input tensor
98
74
  * @return the output tensor
99
75
  */
100
- override call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[] {
76
+ call(inputs, kwargs) {
101
77
  // validate the input tensors
102
78
  if (Array.isArray(inputs) && inputs.length != 1 && inputs.length != 2) {
103
79
  throw Error(`${this.getClassName()}::call ${this.name} expects one input tensor, got ${inputs.length} inputs.`);
104
80
  }
105
-
106
81
  if (Array.isArray(inputs)) {
107
- inputs = inputs[0] as tf.Tensor;
82
+ inputs = inputs[0];
108
83
  }
109
-
110
84
  // perform forward propagation
111
85
  return tf.tidy(() => {
112
86
  let output = this.causalSelfAttentionBlock(inputs, kwargs);
113
87
  output = this.feedForwardBlock(output, kwargs);
114
-
115
88
  return output;
116
89
  });
117
90
  }
118
-
119
-
120
- protected causalSelfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
91
+ causalSelfAttentionBlock(x, kwargs) {
121
92
  return tf.tidy(() => {
122
93
  const residual = x;
123
-
124
- let attention = this.causalSelfAttention.apply(x, kwargs) as tf.Tensor;
125
- attention = this.causalSelfAttentionDropout.apply(attention, kwargs) as tf.Tensor;
94
+ let attention = this.causalSelfAttention.apply(x, kwargs);
95
+ attention = this.causalSelfAttentionDropout.apply(attention, kwargs);
126
96
  attention = tf.add(attention, residual);
127
- attention = this.causalSelfAttentionNorm.apply(attention, kwargs) as tf.Tensor;
128
-
97
+ attention = this.causalSelfAttentionNorm.apply(attention, kwargs);
129
98
  return attention;
130
99
  });
131
100
  }
132
-
133
-
134
- protected feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
101
+ feedForwardBlock(x, kwargs) {
135
102
  return tf.tidy(() => {
136
103
  const residual = x;
137
-
138
104
  let feedForward = this.feedforward1.apply(x, kwargs);
139
105
  feedForward = this.feedforward2.apply(feedForward, kwargs);
140
- feedForward = this.feedForwardDropout.apply(feedForward, kwargs) as tf.Tensor;
106
+ feedForward = this.feedForwardDropout.apply(feedForward, kwargs);
141
107
  feedForward = tf.add(feedForward, residual);
142
- feedForward = this.feedFowardNorm.apply(feedForward, kwargs) as tf.Tensor;
143
-
108
+ feedForward = this.feedFowardNorm.apply(feedForward, kwargs);
144
109
  return feedForward;
145
110
  });
146
111
  }
147
-
148
-
149
112
  /**
150
113
  * Initialize the sublayers' weights and track them to enable serialization
151
114
  */
152
- override build(inputShape: tf.Shape | tf.Shape[]): void {
153
- let input_shapes: tf.Shape[] = [];
154
-
115
+ build(inputShape) {
116
+ let input_shapes = [];
155
117
  if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
156
118
  // input is an array of shapes
157
- input_shapes = inputShape as tf.Shape[];
158
- } else if (inputShape.length != 0) {
119
+ input_shapes = inputShape;
120
+ }
121
+ else if (inputShape.length != 0) {
159
122
  // input is a single shape
160
- input_shapes = [inputShape as tf.Shape];
123
+ input_shapes = [inputShape];
161
124
  }
162
-
163
125
  if (input_shapes.length != 1 && input_shapes.length != 2) {
164
126
  throw Error(`${this.getClassName()}::build ${this.name} expects an input shape` +
165
- ` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`)
127
+ ` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`);
166
128
  }
167
-
168
129
  const [decoderInputShape] = input_shapes;
169
-
170
130
  if (decoderInputShape?.length != 3) {
171
131
  throw Error(`${this.getClassName()}::build ${this.name} expects an input shape` +
172
- ` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`)
132
+ ` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`);
173
133
  }
174
-
175
134
  // initialize causal self attention sub-block's weights
176
135
  this.causalSelfAttention.build(decoderInputShape);
177
136
  this.causalSelfAttentionNorm.build(this.causalSelfAttention.computeOutputShape(decoderInputShape));
178
-
179
137
  // initialize feedforward sub-block's weights
180
138
  const feedforward1OutputShape = this.feedforward1.computeOutputShape(decoderInputShape);
181
139
  const feedforward2OutputShape = this.feedforward2.computeOutputShape(feedforward1OutputShape);
182
-
183
140
  this.feedforward1.build(decoderInputShape);
184
141
  this.feedforward2.build(feedforward1OutputShape);
185
142
  this.feedFowardNorm.build(feedforward2OutputShape);
186
-
187
143
  // track sublayers' weights
188
144
  this.trainableWeights = [
189
145
  ...this.causalSelfAttention.trainableWeights,
@@ -194,28 +150,22 @@ export class TransformerDecoder extends tf.layers.Layer {
194
150
  ...this.feedForwardDropout.trainableWeights,
195
151
  ...this.feedFowardNorm.trainableWeights
196
152
  ];
197
-
198
153
  // rename the weights otherwise they'll take on the default naming and overlap
199
154
  // each other which breaks model loading due to duplicate weight names
200
155
  let indexing = 0;
201
-
202
156
  for (const weight of this.trainableWeights) {
203
157
  const unique_name = `${this.getClassName()}_${indexing}`;
204
- (weight as any).name += unique_name;
205
- (weight as any).originalName += unique_name;
158
+ weight.name += unique_name;
159
+ weight.originalName += unique_name;
206
160
  indexing++;
207
161
  }
208
-
209
162
  super.build(inputShape);
210
163
  }
211
-
212
-
213
164
  /**
214
165
  * Save the layer's hyperparameters for serialization
215
166
  */
216
- override getConfig() {
167
+ getConfig() {
217
168
  const base_config = super.getConfig();
218
-
219
169
  const config = {
220
170
  numHeads: this.numHeads,
221
171
  embedDim: this.embedDim,
@@ -223,14 +173,10 @@ export class TransformerDecoder extends tf.layers.Layer {
223
173
  dropout: this.dropout,
224
174
  activation: this.activation,
225
175
  dimsFeedForward: this.dimsFeedForward
226
- }
227
-
176
+ };
228
177
  Object.assign(config, base_config);
229
-
230
178
  return config;
231
179
  }
232
-
233
180
  }
234
-
235
-
236
181
  tf.serialization.registerClass(TransformerDecoder);
182
+ //# sourceMappingURL=transformer_decoder.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"transformer_decoder.js","sourceRoot":"","sources":["../../../src/layers/transformer_decoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAKvC,OAAO,EAAE,4BAA4B,EAAE,MAAM,0CAA0C,CAAC;AAUxF;;;;;;;;;;;;;;;;;GAiBG;AACH,MAAM,OAAO,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,GAAG,oBAAoB,CAAC;IAErB,mBAAmB,CAAkB;IACrC,0BAA0B,CAAkB;IAC5C,uBAAuB,CAAkB;IAEzC,YAAY,CAAkB;IAC9B,YAAY,CAAkB;IAC9B,kBAAkB,CAAkB;IACpC,cAAc,CAAkB;IAEhC,QAAQ,CAAS;IACjB,QAAQ,CAAS;IACjB,OAAO,CAAU;IACjB,OAAO,CAAS;IAChB,UAAU,CAAuB;IACjC,eAAe,CAAS;IAE3C,YAAY,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAA0B;QAC9G,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,IAAI,CAAC;QAC/B,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,GAAG,CAAC;QAC9B,IAAI,CAAC,UAAU,GAAG,UAAU,IAAI,MAAM,CAAC;QAEvC,IAAI,IAAI,CAAC,OAAO,IAAI,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,6CAA6C,CAAC,CAAC;QACrF,CAAC;QAED,2FAA2F;QAC3F,IAAI,CAAC,eAAe,GAAG,eAAe,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEvD,2BAA2B;QAC3B,IAAI,CAAC,mBAAmB,GAAG,IAAI,4BAA4B,CAAC;YACxD,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,QAAQ,EAAE,IAAI,CAAC,QAAQ;YAChD,OAAO,EAAE,IAAI,CAAC,OAAO,EAAE,OAAO,EAAE,IAAI,CAAC,OAAO;YAC5C,MAAM,EAAE,IAAI;SACf,CAAC,CAAC;QACH,IAAI,CAAC,0BAA0B,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAA;QAC3E,IAAI,CAAC,uBAAuB,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAAO,EAAE,IAAI,EAAE,CAAC,CAAC;QAE/E,yBAAyB;QACzB,IAAI,CAAC,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAChC,KAAK,EAAE,IAAI,CAAC,eAAe;YAC3B,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAChC,KAAK,EAAE,IAAI,CAAC,QAAQ;YACpB,UAAU,EAAE,QAAQ;YACpB,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,kBAAkB,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC;QACpE,IAAI,CAAC,cAAc,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAA
O,EAAE,IAAI,EAAE,CAAC,CAAC;IAC1E,CAAC;IAGD;;;;;OAKG;IACM,IAAI,CAAC,MAA+B,EAAE,MAAc;QACzD,6BAA6B;QAC7B,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YACpE,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,kCAAkC,MAAM,CAAC,MAAM,UAAU,CAAC,CAAC;QACpH,CAAC;QAED,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC;YACxB,MAAM,GAAG,MAAM,CAAC,CAAC,CAAc,CAAC;QACpC,CAAC;QAED,8BAA8B;QAC9B,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,IAAI,MAAM,GAAG,IAAI,CAAC,wBAAwB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAC3D,MAAM,GAAG,IAAI,CAAC,gBAAgB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAE/C,OAAO,MAAM,CAAC;QAClB,CAAC,CAAC,CAAC;IACP,CAAC;IAGS,wBAAwB,CAAC,CAAY,EAAE,MAAc;QAC3D,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,SAAS,GAAG,IAAI,CAAC,mBAAmB,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAc,CAAC;YACvE,SAAS,GAAG,IAAI,CAAC,0BAA0B,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAClF,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,QAAQ,CAAC,CAAC;YACxC,SAAS,GAAG,IAAI,CAAC,uBAAuB,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAE/E,OAAO,SAAS,CAAC;QACrB,CAAC,CAAC,CAAC;IACP,CAAC;IAGS,gBAAgB,CAAC,CAAY,EAAE,MAAc;QACnD,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YACrD,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAC,CAAC;YAC3D,WAAW,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAC9E,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,QAAQ,CAAC,CAAC;YAC5C,WAAW,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAE1E,OAAO,WAAW,CAAC;QACvB,CAAC,CAAC,CAAC;IACP,CAAC;IAGD;;OAEG;IACM,KAAK,CAAC,UAAiC;QAC5C,IAAI,YAAY,GAAe,EAAE,CAAC;QAElC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAC5D,8BAA8B;YAC9B,YAAY,GAAG,UAAwB,CAAC;QAC5C,CAAC;aAAM,IAAI,UAAU,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAChC,0BAA0B;YAC1B,YAAY,GAAG,CAAC,UAAsB,CAAC,CAAC;QAC5C,CAAC;QAED,IAAI,YAAY,CAAC,MAAM,IAAI,CAAC,IAAI,YAAY,CAAC,MAAM,IAAI,CAAC,EAAE,C
AAC;YACvD,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,yBAAyB;gBAC3E,oCAAoC,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAA;QACzE,CAAC;QAED,MAAM,CAAC,iBAAiB,CAAC,GAAG,YAAY,CAAC;QAEzC,IAAI,iBAAiB,EAAE,MAAM,IAAI,CAAC,EAAE,CAAC;YACjC,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,yBAAyB;gBAC3E,oCAAoC,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAA;QACzE,CAAC;QAED,uDAAuD;QACvD,IAAI,CAAC,mBAAmB,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC;QAClD,IAAI,CAAC,uBAAuB,CAAC,KAAK,CAAC,IAAI,CAAC,mBAAmB,CAAC,kBAAkB,CAAC,iBAAiB,CAAC,CAAC,CAAC;QAEnG,6CAA6C;QAC7C,MAAM,uBAAuB,GAAG,IAAI,CAAC,YAAY,CAAC,kBAAkB,CAAC,iBAAiB,CAAC,CAAC;QACxF,MAAM,uBAAuB,GAAG,IAAI,CAAC,YAAY,CAAC,kBAAkB,CAAC,uBAAuB,CAAC,CAAC;QAE9F,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC;QAC3C,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,uBAAuB,CAAC,CAAC;QACjD,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,uBAAuB,CAAC,CAAC;QAEnD,2BAA2B;QAC3B,IAAI,CAAC,gBAAgB,GAAG;YACpB,GAAG,IAAI,CAAC,mBAAmB,CAAC,gBAAgB;YAC5C,GAAG,IAAI,CAAC,0BAA0B,CAAC,gBAAgB;YACnD,GAAG,IAAI,CAAC,uBAAuB,CAAC,gBAAgB;YAChD,GAAG,IAAI,CAAC,YAAY,CAAC,gBAAgB;YACrC,GAAG,IAAI,CAAC,YAAY,CAAC,gBAAgB;YACrC,GAAG,IAAI,CAAC,kBAAkB,CAAC,gBAAgB;YAC3C,GAAG,IAAI,CAAC,cAAc,CAAC,gBAAgB;SAC1C,CAAC;QAEF,8EAA8E;QAC9E,sEAAsE;QACtE,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEjB,KAAK,MAAM,MAAM,IAAI,IAAI,CAAC,gBAAgB,EAAE,CAAC;YACzC,MAAM,WAAW,GAAG,GAAG,IAAI,CAAC,YAAY,EAAE,IAAI,QAAQ,EAAE,CAAC;YACxD,MAAc,CAAC,IAAI,IAAI,WAAW,CAAC;YACnC,MAAc,CAAC,YAAY,IAAI,WAAW,CAAC;YAC5C,QAAQ,EAAE,CAAC;QACf,CAAC;QAED,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACM,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,eAAe,EAAE,IAAI,CAAC,eAAe;SACxC,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAKL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=transformer_decoder.test.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"transformer_decoder.test.d.ts","sourceRoot":"","sources":["../../../src/layers/transformer_decoder.test.ts"],"names":[],"mappings":""}
@@ -0,0 +1,72 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { TransformerDecoder } from '@/layers/transformer_decoder';
3
+ // disables warning for using the faster node backend,
4
+ // https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
5
+ tf.env().set('IS_NODE', false);
6
+ describe("TransformerDecoder tests", () => {
7
+ it("should return an output with the same shape as the input", () => {
8
+ const input = tf.randomUniform([2, 3, 12]);
9
+ const decoder = new TransformerDecoder({
10
+ numHeads: 2, embedDim: input.shape.at(-1),
11
+ dropout: 0.5, activation: "gelu", dimsFeedForward: 321, useBias: false
12
+ });
13
+ const output = decoder.apply(input);
14
+ expect(output.shape.length).toBe(input.shape.length);
15
+ });
16
+ test("forward calls", () => {
17
+ const input = tf.randomUniform([2, 3, 12]);
18
+ const mask = tf.randomUniform([input.shape[0], input.shape[1]], -1, 2, "bool");
19
+ const incorrect_mask = tf.randomUniform([2, 5, 12], -1, 2, "bool");
20
+ const decoder = new TransformerDecoder({ numHeads: 2, embedDim: input.shape.at(-1) });
21
+ expect(() => decoder.apply(input)).not.toThrow();
22
+ expect(() => decoder.apply([input])).not.toThrow();
23
+ // causal masking
24
+ const causal = new TransformerDecoder({ numHeads: 2, embedDim: input.shape.at(-1), causal: true });
25
+ expect(() => causal.apply(input)).not.toThrow();
26
+ expect(() => causal.apply([input])).not.toThrow();
27
+ });
28
+ it("should fail to instantiate a layer if heads count is not divisible by the input's embedding dimension", () => {
29
+ const input = tf.randomUniform([2, 3, 12]);
30
+ expect(() => new TransformerDecoder({ numHeads: 3, embedDim: input.shape.at(-1) })).not.toThrow();
31
+ expect(() => new TransformerDecoder({ numHeads: 5, embedDim: input.shape.at(-1) })).toThrow();
32
+ });
33
+ it("should not accept non-rank 3 tensor inputs", () => {
34
+ const embed_dim = 12;
35
+ const BAD_RANK4 = tf.randomUniform([2, 3, 12, embed_dim]);
36
+ const BAD_RANK2 = tf.randomUniform([2, embed_dim]);
37
+ const GOOD = tf.randomUniform([2, 3, embed_dim]);
38
+ const mask = tf.randomUniform([GOOD.shape[0], GOOD.shape[1]], -1, 2, "bool");
39
+ let decoder = new TransformerDecoder({ numHeads: 2, embedDim: embed_dim });
40
+ // BAD
41
+ expect(() => decoder.apply(BAD_RANK4)).toThrow();
42
+ expect(() => decoder.apply(BAD_RANK2)).toThrow();
43
+ // OK
44
+ decoder = new TransformerDecoder({ numHeads: 2, embedDim: embed_dim });
45
+ expect(() => decoder.apply(GOOD)).not.toThrow();
46
+ expect(() => decoder.apply([GOOD])).not.toThrow();
47
+ expect(() => decoder.apply([GOOD, mask])).not.toThrow();
48
+ });
49
+ it("should not accept inputs that are less or more than 1 and 2 tensors", () => {
50
+ const input = tf.randomUniform([2, 3, 12]);
51
+ let decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
52
+ // OK
53
+ expect(() => decoder.apply(input)).not.toThrow();
54
+ expect(() => decoder.apply([input])).not.toThrow();
55
+ // BAD
56
+ decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
57
+ expect(() => decoder.apply([])).toThrow(); // stops at build()
58
+ decoder.apply(input); // get past the initial build
59
+ expect(() => decoder.apply([input, input, input])).toThrow();
60
+ expect(() => decoder.apply([input, input, input, input])).toThrow();
61
+ // BAD (tests build())
62
+ decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
63
+ expect(() => decoder.apply([input, input, input])).toThrow();
64
+ expect(() => decoder.apply([input, input, input, input])).toThrow();
65
+ });
66
+ it("should return a non-empty config dict", () => {
67
+ const input = tf.randomUniform([2, 3, 12]);
68
+ const decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
69
+ expect(Object.keys(decoder.getConfig())).not.toBe(0);
70
+ });
71
+ });
72
+ //# sourceMappingURL=transformer_decoder.test.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"transformer_decoder.test.js","sourceRoot":"","sources":["../../../src/layers/transformer_decoder.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,kBAAkB,EAAE,MAAM,8BAA8B,CAAC;AAElE,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,EAAE,CAAC,0DAA0D,EAAE,GAAG,EAAE;QAChE,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC;YACnC,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE;YAC1C,OAAO,EAAE,GAAG,EAAE,UAAU,EAAE,MAAM,EAAE,eAAe,EAAE,GAAG,EAAE,OAAO,EAAE,KAAK;SACzE,CAAC,CAAC;QAEH,MAAM,MAAM,GAAG,OAAO,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC;QAEjD,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,eAAe,EAAE,GAAG,EAAE;QACvB,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3C,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,CAAC,CAAC;QACjF,MAAM,cAAc,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,CAAC,CAAC;QAGnE,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,iBAAiB;QACjB,MAAM,MAAM,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QACpG,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACtD,CAAC,CAAC,C
AAA;IAGF,EAAE,CAAC,uGAAuG,EAAE,GAAG,EAAE;QAC7G,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACnG,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACnG,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,4CAA4C,EAAE,GAAG,EAAE;QAClD,MAAM,SAAS,GAAG,EAAE,CAAC;QAErB,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,SAAS,CAAC,CAAC,CAAC;QAC1D,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;QACnD,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;QACjD,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,CAAC,CAAC;QAE/E,IAAI,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,CAAC,CAAC;QAE3E,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEjD,KAAK;QACL,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,CAAC,CAAC;QACvE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAClD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC5D,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,qEAAqE,EAAE,GAAG,EAAE;QAC3E,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,IAAI,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACrF,KAAK;QACL,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,
KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,MAAM;QACN,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACjF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC,CAAC,mBAAmB;QAC9D,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,6BAA6B;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEpE,sBAAsB;QACtB,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACjF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACxE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,OAAO,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAA"}
@@ -0,0 +1,55 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
3
+ import { type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
4
+ export interface TransformerEncoderArgs extends MultiHeadAttentionArgs {
5
+ activation?: "relu" | "gelu";
6
+ dimsFeedForward?: number;
7
+ }
8
+ /**
9
+ * This class implements the transformer encoder architecture from the 2017 paper
10
+ * Attention Is All You Need.
11
+ *
12
+ * This layer accepts exactly one tensor input with the shape
13
+ * `[ batch, sequences, embedding dims ]`.
14
+ *
15
+ * @param numHeads number of attention heads to use
16
+ * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
17
+ * @param causal use causal masking, default `false` for encoders
18
+ * @param dropout use dropout during the attention calculations, default `0.1`
19
+ * @param activation the activation of the intermediate feed forward layer, default `relu`
20
+ * @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
21
+ * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
22
+ */
23
+ export declare class TransformerEncoder extends tf.layers.Layer {
24
+ static className: string;
25
+ private readonly selfAttention;
26
+ private readonly selfAttentionDropout;
27
+ private readonly selfAttentionNorm;
28
+ private readonly reluLayer;
29
+ private readonly linearLayer;
30
+ private readonly feedForwardDropout;
31
+ private readonly feedFowardNorm;
32
+ private readonly numHeads;
33
+ private readonly embedDim;
34
+ private readonly causal;
35
+ private readonly useBias;
36
+ private readonly dropout;
37
+ private readonly activation;
38
+ private readonly dimsFeedForward;
39
+ constructor({ numHeads, embedDim, causal, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerEncoderArgs);
40
+ /**
41
+ * Forward propagation
42
+ */
43
+ call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[];
44
+ private selfAttentionBlock;
45
+ private feedForwardBlock;
46
+ /**
47
+ * Initialize the sublayers' weights and track them to enable backpropagation.
48
+ */
49
+ build(inputShape: tf.Shape | tf.Shape[]): void;
50
+ /**
51
+ * Save the layer's hyperparameters for serialization
52
+ */
53
+ getConfig(): tf.serialization.ConfigDict;
54
+ }
55
+ //# sourceMappingURL=transformer_encoder.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"transformer_encoder.d.ts","sourceRoot":"","sources":["../../../src/layers/transformer_encoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAGjE,OAAO,EAAsB,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAG/F,MAAM,WAAW,sBAAuB,SAAQ,sBAAsB;IAClE,UAAU,CAAC,EAAE,MAAM,GAAG,MAAM,CAAC;IAC7B,eAAe,CAAC,EAAE,MAAM,CAAC;CAC5B;AAGD;;;;;;;;;;;;;;GAcG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IAExC,OAAO,CAAC,QAAQ,CAAC,aAAa,CAAkB;IAChD,OAAO,CAAC,QAAQ,CAAC,oBAAoB,CAAkB;IACvD,OAAO,CAAC,QAAQ,CAAC,iBAAiB,CAAkB;IAEpD,OAAO,CAAC,QAAQ,CAAC,SAAS,CAAkB;IAC5C,OAAO,CAAC,QAAQ,CAAC,WAAW,CAAkB;IAC9C,OAAO,CAAC,QAAQ,CAAC,kBAAkB,CAAkB;IACrD,OAAO,CAAC,QAAQ,CAAC,cAAc,CAAkB;IAEjD,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,QAAQ,CAAC,MAAM,CAAU;IACjC,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAU;IAClC,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAS;IACjC,OAAO,CAAC,QAAQ,CAAC,UAAU,CAAuB;IAClD,OAAO,CAAC,QAAQ,CAAC,eAAe,CAAS;gBAG7B,EAAE,QAAQ,EAAE,QAAQ,EAAE,MAAM,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAAE,EAAE,sBAAsB;IAqC1H;;OAEG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAyBvF,OAAO,CAAC,kBAAkB;IAc1B,OAAO,CAAC,gBAAgB;IAexB;;OAEG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IAsDvD;;OAEG;IACM,SAAS,IAAI,EAAE,CAAC,aAAa,CAAC,UAAU;CAiBpD"}