@stellarapp/tfjs-stellar 1.0.0 → 1.0.1

This diff compares the contents of two publicly available versions of this package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (244)
  1. package/LICENSE +21 -0
  2. package/README.md +47 -0
  3. package/dist/index.d.ts +7 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +7 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/jest.config.d.ts +8 -0
  8. package/dist/jest.config.d.ts.map +1 -0
  9. package/{jest.config.ts → dist/jest.config.js} +8 -64
  10. package/dist/jest.config.js.map +1 -0
  11. package/dist/kv_cache.d.ts +53 -0
  12. package/dist/kv_cache.d.ts.map +1 -0
  13. package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
  14. package/dist/kv_cache.js.map +1 -0
  15. package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
  16. package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  17. package/dist/layers/cached_rope_multihead_attention.js +76 -0
  18. package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
  19. package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  20. package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  21. package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
  22. package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
  23. package/dist/layers/gpt_decoder_block.d.ts +34 -0
  24. package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
  25. package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
  26. package/dist/layers/gpt_decoder_block.js.map +1 -0
  27. package/dist/layers/index.d.ts +17 -0
  28. package/dist/layers/index.d.ts.map +1 -0
  29. package/dist/layers/index.js +33 -0
  30. package/dist/layers/index.js.map +1 -0
  31. package/dist/layers/multihead_attention.d.ts +106 -0
  32. package/dist/layers/multihead_attention.d.ts.map +1 -0
  33. package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
  34. package/dist/layers/multihead_attention.js.map +1 -0
  35. package/dist/layers/multihead_attention.test.d.ts +2 -0
  36. package/dist/layers/multihead_attention.test.d.ts.map +1 -0
  37. package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
  38. package/dist/layers/multihead_attention.test.js.map +1 -0
  39. package/dist/layers/positional_encoding.d.ts +37 -0
  40. package/dist/layers/positional_encoding.d.ts.map +1 -0
  41. package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
  42. package/dist/layers/positional_encoding.js.map +1 -0
  43. package/dist/layers/positional_encoding.test.d.ts +2 -0
  44. package/dist/layers/positional_encoding.test.d.ts.map +1 -0
  45. package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
  46. package/dist/layers/positional_encoding.test.js.map +1 -0
  47. package/dist/layers/rotary_position_embedding.d.ts +39 -0
  48. package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
  49. package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
  50. package/dist/layers/rotary_position_embedding.js.map +1 -0
  51. package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
  52. package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
  53. package/dist/layers/rotary_position_embedding.test.js +88 -0
  54. package/dist/layers/rotary_position_embedding.test.js.map +1 -0
  55. package/dist/layers/token_and_positional_embedding.d.ts +47 -0
  56. package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
  57. package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
  58. package/dist/layers/token_and_positional_embedding.js.map +1 -0
  59. package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
  60. package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  61. package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
  62. package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
  63. package/dist/layers/transformer_decoder.d.ts +69 -0
  64. package/dist/layers/transformer_decoder.d.ts.map +1 -0
  65. package/dist/layers/transformer_decoder.js +182 -0
  66. package/dist/layers/transformer_decoder.js.map +1 -0
  67. package/dist/layers/transformer_decoder.test.d.ts +2 -0
  68. package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
  69. package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
  70. package/dist/layers/transformer_decoder.test.js.map +1 -0
  71. package/dist/layers/transformer_encoder.d.ts +55 -0
  72. package/dist/layers/transformer_encoder.d.ts.map +1 -0
  73. package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
  74. package/dist/layers/transformer_encoder.js.map +1 -0
  75. package/dist/layers/transformer_encoder.test.d.ts +2 -0
  76. package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
  77. package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
  78. package/dist/layers/transformer_encoder.test.js.map +1 -0
  79. package/dist/losses/dice.d.ts +30 -0
  80. package/dist/losses/dice.d.ts.map +1 -0
  81. package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
  82. package/dist/losses/dice.js.map +1 -0
  83. package/dist/losses/index.d.ts +2 -0
  84. package/dist/losses/index.d.ts.map +1 -0
  85. package/dist/losses/index.js +2 -0
  86. package/dist/losses/index.js.map +1 -0
  87. package/dist/masks.d.ts +20 -0
  88. package/dist/masks.d.ts.map +1 -0
  89. package/{src/packing_mask.ts → dist/masks.js} +16 -7
  90. package/dist/masks.js.map +1 -0
  91. package/dist/metrics.d.ts +20 -0
  92. package/dist/metrics.d.ts.map +1 -0
  93. package/{src/metrics.ts → dist/metrics.js} +8 -12
  94. package/dist/metrics.js.map +1 -0
  95. package/dist/models/gpt_model.d.ts +94 -0
  96. package/dist/models/gpt_model.d.ts.map +1 -0
  97. package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
  98. package/dist/models/gpt_model.js.map +1 -0
  99. package/dist/models/index.d.ts +7 -0
  100. package/dist/models/index.d.ts.map +1 -0
  101. package/dist/models/index.js +13 -0
  102. package/dist/models/index.js.map +1 -0
  103. package/dist/models/llm_model.d.ts +87 -0
  104. package/dist/models/llm_model.d.ts.map +1 -0
  105. package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
  106. package/dist/models/llm_model.js.map +1 -0
  107. package/dist/models/u_net.d.ts +40 -0
  108. package/dist/models/u_net.d.ts.map +1 -0
  109. package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
  110. package/dist/models/u_net.js.map +1 -0
  111. package/dist/src/index.d.ts +6 -0
  112. package/dist/src/index.d.ts.map +1 -0
  113. package/dist/src/index.js +6 -0
  114. package/dist/src/index.js.map +1 -0
  115. package/dist/src/kv_cache.d.ts +53 -0
  116. package/dist/src/kv_cache.d.ts.map +1 -0
  117. package/dist/src/kv_cache.js +135 -0
  118. package/dist/src/kv_cache.js.map +1 -0
  119. package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
  120. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  121. package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
  122. package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
  123. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  124. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  125. package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
  126. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
  127. package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
  128. package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
  129. package/dist/src/layers/gpt_decoder_block.js +51 -0
  130. package/dist/src/layers/gpt_decoder_block.js.map +1 -0
  131. package/dist/src/layers/index.d.ts +17 -0
  132. package/dist/src/layers/index.d.ts.map +1 -0
  133. package/dist/src/layers/index.js +33 -0
  134. package/dist/src/layers/index.js.map +1 -0
  135. package/dist/src/layers/multihead_attention.d.ts +106 -0
  136. package/dist/src/layers/multihead_attention.d.ts.map +1 -0
  137. package/dist/src/layers/multihead_attention.js +269 -0
  138. package/dist/src/layers/multihead_attention.js.map +1 -0
  139. package/dist/src/layers/multihead_attention.test.d.ts +2 -0
  140. package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
  141. package/dist/src/layers/multihead_attention.test.js +160 -0
  142. package/dist/src/layers/multihead_attention.test.js.map +1 -0
  143. package/dist/src/layers/positional_encoding.d.ts +37 -0
  144. package/dist/src/layers/positional_encoding.d.ts.map +1 -0
  145. package/dist/src/layers/positional_encoding.js +115 -0
  146. package/dist/src/layers/positional_encoding.js.map +1 -0
  147. package/dist/src/layers/positional_encoding.test.d.ts +2 -0
  148. package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
  149. package/dist/src/layers/positional_encoding.test.js +95 -0
  150. package/dist/src/layers/positional_encoding.test.js.map +1 -0
  151. package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
  152. package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
  153. package/dist/src/layers/rotary_position_embedding.js +99 -0
  154. package/dist/src/layers/rotary_position_embedding.js.map +1 -0
  155. package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
  156. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
  157. package/dist/src/layers/rotary_position_embedding.test.js +88 -0
  158. package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
  159. package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
  160. package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
  161. package/dist/src/layers/token_and_positional_embedding.js +109 -0
  162. package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
  163. package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
  164. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  165. package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
  166. package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
  167. package/dist/src/layers/transformer_decoder.d.ts +69 -0
  168. package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
  169. package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
  170. package/dist/src/layers/transformer_decoder.js.map +1 -0
  171. package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
  172. package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
  173. package/dist/src/layers/transformer_decoder.test.js +72 -0
  174. package/dist/src/layers/transformer_decoder.test.js.map +1 -0
  175. package/dist/src/layers/transformer_encoder.d.ts +55 -0
  176. package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
  177. package/dist/src/layers/transformer_encoder.js +175 -0
  178. package/dist/src/layers/transformer_encoder.js.map +1 -0
  179. package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
  180. package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
  181. package/dist/src/layers/transformer_encoder.test.js +58 -0
  182. package/dist/src/layers/transformer_encoder.test.js.map +1 -0
  183. package/dist/src/losses/dice.d.ts +30 -0
  184. package/dist/src/losses/dice.d.ts.map +1 -0
  185. package/dist/src/losses/dice.js +93 -0
  186. package/dist/src/losses/dice.js.map +1 -0
  187. package/dist/src/losses/index.d.ts +2 -0
  188. package/dist/src/losses/index.d.ts.map +1 -0
  189. package/dist/src/losses/index.js +2 -0
  190. package/dist/src/losses/index.js.map +1 -0
  191. package/dist/src/masks.d.ts +20 -0
  192. package/dist/src/masks.d.ts.map +1 -0
  193. package/dist/src/masks.js +37 -0
  194. package/dist/src/masks.js.map +1 -0
  195. package/dist/src/metrics.d.ts +20 -0
  196. package/dist/src/metrics.d.ts.map +1 -0
  197. package/dist/src/metrics.js +28 -0
  198. package/dist/src/metrics.js.map +1 -0
  199. package/dist/src/models/gpt_model.d.ts +94 -0
  200. package/dist/src/models/gpt_model.d.ts.map +1 -0
  201. package/dist/src/models/gpt_model.js +154 -0
  202. package/dist/src/models/gpt_model.js.map +1 -0
  203. package/dist/src/models/index.d.ts +3 -0
  204. package/dist/src/models/index.d.ts.map +1 -0
  205. package/{src/models/index.ts → dist/src/models/index.js} +1 -0
  206. package/dist/src/models/index.js.map +1 -0
  207. package/dist/src/models/llm_model.d.ts +87 -0
  208. package/dist/src/models/llm_model.d.ts.map +1 -0
  209. package/dist/src/models/llm_model.js +245 -0
  210. package/dist/src/models/llm_model.js.map +1 -0
  211. package/dist/src/models/u_net.d.ts +40 -0
  212. package/dist/src/models/u_net.d.ts.map +1 -0
  213. package/dist/src/models/u_net.js +151 -0
  214. package/dist/src/models/u_net.js.map +1 -0
  215. package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
  216. package/dist/src/tfjs_types.d.ts.map +1 -0
  217. package/dist/src/tfjs_types.js +2 -0
  218. package/dist/src/tfjs_types.js.map +1 -0
  219. package/dist/src/utils.d.ts +28 -0
  220. package/dist/src/utils.d.ts.map +1 -0
  221. package/{src/utils.ts → dist/src/utils.js} +10 -33
  222. package/dist/src/utils.js.map +1 -0
  223. package/dist/src/utils.test.d.ts +2 -0
  224. package/dist/src/utils.test.d.ts.map +1 -0
  225. package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
  226. package/dist/src/utils.test.js.map +1 -0
  227. package/dist/tfjs_types.d.ts +10 -0
  228. package/dist/tfjs_types.d.ts.map +1 -0
  229. package/dist/tfjs_types.js +2 -0
  230. package/dist/tfjs_types.js.map +1 -0
  231. package/dist/utils.d.ts +28 -0
  232. package/dist/utils.d.ts.map +1 -0
  233. package/dist/utils.js +63 -0
  234. package/dist/utils.js.map +1 -0
  235. package/dist/utils.test.d.ts +2 -0
  236. package/dist/utils.test.d.ts.map +1 -0
  237. package/dist/utils.test.js +73 -0
  238. package/dist/utils.test.js.map +1 -0
  239. package/package.json +10 -4
  240. package/src/index.ts +0 -93
  241. package/src/layers/rotary_position_embedding.test.ts +0 -107
  242. package/src/losses/index.ts +0 -1
  243. package/src/testing.ts +0 -1
  244. package/tsconfig.json +0 -49
@@ -1,100 +1,72 @@
1
1
  import * as tf from '@tensorflow/tfjs';
2
-
3
2
  import { TransformerDecoder } from '@/layers/transformer_decoder';
4
-
5
3
  // disables warning for using the faster node backend,
6
4
  // https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
7
5
  tf.env().set('IS_NODE', false);
8
-
9
-
10
6
  describe("TransformerDecoder tests", () => {
11
7
  it("should return an output with the same shape as the input", () => {
12
8
  const input = tf.randomUniform([2, 3, 12]);
13
-
14
9
  const decoder = new TransformerDecoder({
15
- numHeads: 2, embedDim: input.shape.at(-1)!,
10
+ numHeads: 2, embedDim: input.shape.at(-1),
16
11
  dropout: 0.5, activation: "gelu", dimsFeedForward: 321, useBias: false
17
12
  });
18
-
19
- const output = decoder.apply(input) as tf.Tensor;
20
-
13
+ const output = decoder.apply(input);
21
14
  expect(output.shape.length).toBe(input.shape.length);
22
- })
23
-
24
-
15
+ });
25
16
  test("forward calls", () => {
26
17
  const input = tf.randomUniform([2, 3, 12]);
27
- const mask = tf.randomUniform([input.shape[0]!, input.shape[1]!], -1, 2, "bool");
18
+ const mask = tf.randomUniform([input.shape[0], input.shape[1]], -1, 2, "bool");
28
19
  const incorrect_mask = tf.randomUniform([2, 5, 12], -1, 2, "bool");
29
-
30
-
31
- const decoder = new TransformerDecoder({ numHeads: 2, embedDim: input.shape.at(-1)! });
20
+ const decoder = new TransformerDecoder({ numHeads: 2, embedDim: input.shape.at(-1) });
32
21
  expect(() => decoder.apply(input)).not.toThrow();
33
22
  expect(() => decoder.apply([input])).not.toThrow();
34
-
35
23
  // causal masking
36
- const causal = new TransformerDecoder({ numHeads: 2, embedDim: input.shape.at(-1)!, causal: true });
24
+ const causal = new TransformerDecoder({ numHeads: 2, embedDim: input.shape.at(-1), causal: true });
37
25
  expect(() => causal.apply(input)).not.toThrow();
38
26
  expect(() => causal.apply([input])).not.toThrow();
39
- })
40
-
41
-
27
+ });
42
28
  it("should fail to instantiate a layer if heads count is not divisible by the input's embedding dimension", () => {
43
29
  const input = tf.randomUniform([2, 3, 12]);
44
-
45
- expect(() => new TransformerDecoder({ numHeads: 3, embedDim: input.shape.at(-1)! })).not.toThrow();
46
- expect(() => new TransformerDecoder({ numHeads: 5, embedDim: input.shape.at(-1)! })).toThrow();
47
- })
48
-
49
-
30
+ expect(() => new TransformerDecoder({ numHeads: 3, embedDim: input.shape.at(-1) })).not.toThrow();
31
+ expect(() => new TransformerDecoder({ numHeads: 5, embedDim: input.shape.at(-1) })).toThrow();
32
+ });
50
33
  it("should not accept non-rank 3 tensor inputs", () => {
51
34
  const embed_dim = 12;
52
-
53
35
  const BAD_RANK4 = tf.randomUniform([2, 3, 12, embed_dim]);
54
36
  const BAD_RANK2 = tf.randomUniform([2, embed_dim]);
55
37
  const GOOD = tf.randomUniform([2, 3, embed_dim]);
56
- const mask = tf.randomUniform([GOOD.shape[0]!, GOOD.shape[1]!], -1, 2, "bool");
57
-
38
+ const mask = tf.randomUniform([GOOD.shape[0], GOOD.shape[1]], -1, 2, "bool");
58
39
  let decoder = new TransformerDecoder({ numHeads: 2, embedDim: embed_dim });
59
-
60
40
  // BAD
61
41
  expect(() => decoder.apply(BAD_RANK4)).toThrow();
62
42
  expect(() => decoder.apply(BAD_RANK2)).toThrow();
63
-
64
43
  // OK
65
44
  decoder = new TransformerDecoder({ numHeads: 2, embedDim: embed_dim });
66
45
  expect(() => decoder.apply(GOOD)).not.toThrow();
67
46
  expect(() => decoder.apply([GOOD])).not.toThrow();
68
47
  expect(() => decoder.apply([GOOD, mask])).not.toThrow();
69
- })
70
-
71
-
48
+ });
72
49
  it("should not accept inputs that are less or more than 1 and 2 tensors", () => {
73
50
  const input = tf.randomUniform([2, 3, 12]);
74
-
75
- let decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1)! });
51
+ let decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
76
52
  // OK
77
53
  expect(() => decoder.apply(input)).not.toThrow();
78
54
  expect(() => decoder.apply([input])).not.toThrow();
79
-
80
55
  // BAD
81
- decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1)! });
56
+ decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
82
57
  expect(() => decoder.apply([])).toThrow(); // stops at build()
83
58
  decoder.apply(input); // get past the initial build
84
59
  expect(() => decoder.apply([input, input, input])).toThrow();
85
60
  expect(() => decoder.apply([input, input, input, input])).toThrow();
86
-
87
61
  // BAD (tests build())
88
- decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1)! });
62
+ decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
89
63
  expect(() => decoder.apply([input, input, input])).toThrow();
90
64
  expect(() => decoder.apply([input, input, input, input])).toThrow();
91
- })
92
-
93
-
65
+ });
94
66
  it("should return a non-empty config dict", () => {
95
67
  const input = tf.randomUniform([2, 3, 12]);
96
-
97
- const decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1)! });
68
+ const decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
98
69
  expect(Object.keys(decoder.getConfig())).not.toBe(0);
99
- })
100
- })
70
+ });
71
+ });
72
+ //# sourceMappingURL=transformer_decoder.test.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"transformer_decoder.test.js","sourceRoot":"","sources":["../../src/layers/transformer_decoder.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,kBAAkB,EAAE,MAAM,8BAA8B,CAAC;AAElE,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,EAAE,CAAC,0DAA0D,EAAE,GAAG,EAAE;QAChE,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC;YACnC,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE;YAC1C,OAAO,EAAE,GAAG,EAAE,UAAU,EAAE,MAAM,EAAE,eAAe,EAAE,GAAG,EAAE,OAAO,EAAE,KAAK;SACzE,CAAC,CAAC;QAEH,MAAM,MAAM,GAAG,OAAO,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC;QAEjD,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,eAAe,EAAE,GAAG,EAAE;QACvB,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3C,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,CAAC,CAAC;QACjF,MAAM,cAAc,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,CAAC,CAAC;QAGnE,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,iBAAiB;QACjB,MAAM,MAAM,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QACpG,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACtD,CAAC,CAAC,CAAA
;IAGF,EAAE,CAAC,uGAAuG,EAAE,GAAG,EAAE;QAC7G,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACnG,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACnG,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,4CAA4C,EAAE,GAAG,EAAE;QAClD,MAAM,SAAS,GAAG,EAAE,CAAC;QAErB,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,SAAS,CAAC,CAAC,CAAC;QAC1D,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;QACnD,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;QACjD,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,CAAC,CAAC;QAE/E,IAAI,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,CAAC,CAAC;QAE3E,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEjD,KAAK;QACL,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,CAAC,CAAC;QACvE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAClD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC5D,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,qEAAqE,EAAE,GAAG,EAAE;QAC3E,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,IAAI,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACrF,KAAK;QACL,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,KAA
K,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,MAAM;QACN,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACjF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC,CAAC,mBAAmB;QAC9D,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,6BAA6B;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEpE,sBAAsB;QACtB,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACjF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACxE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,OAAO,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAA"}
@@ -0,0 +1,55 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
3
+ import { type MultiHeadAttentionArgs } from "../layers/multihead_attention";
4
+ export interface TransformerEncoderArgs extends MultiHeadAttentionArgs {
5
+ activation?: "relu" | "gelu";
6
+ dimsFeedForward?: number;
7
+ }
8
+ /**
9
+ * This class implements the transformer encoder architecture from the 2017 paper
10
+ * Attention Is All You Need.
11
+ *
12
+ * This layer accepts exactly one tensor input with the shape
13
+ * `[ batch, sequences, embedding dims ]`.
14
+ *
15
+ * @param numHeads number of attention heads to use
16
+ * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
17
+ * @param causal use causal masking, default `false` for encoders
18
+ * @param dropout use dropout during the attention calculations, default `0.1`
19
+ * @param activation the activation of the intermediate feed forward layer, default `relu`
20
+ * @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
21
+ * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
22
+ */
23
+ export declare class TransformerEncoder extends tf.layers.Layer {
24
+ static className: string;
25
+ private readonly selfAttention;
26
+ private readonly selfAttentionDropout;
27
+ private readonly selfAttentionNorm;
28
+ private readonly reluLayer;
29
+ private readonly linearLayer;
30
+ private readonly feedForwardDropout;
31
+ private readonly feedFowardNorm;
32
+ private readonly numHeads;
33
+ private readonly embedDim;
34
+ private readonly causal;
35
+ private readonly useBias;
36
+ private readonly dropout;
37
+ private readonly activation;
38
+ private readonly dimsFeedForward;
39
+ constructor({ numHeads, embedDim, causal, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerEncoderArgs);
40
+ /**
41
+ * Forward propagation
42
+ */
43
+ call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[];
44
+ private selfAttentionBlock;
45
+ private feedForwardBlock;
46
+ /**
47
+ * Initialize the sublayers' weights and track them to enable backpropagation.
48
+ */
49
+ build(inputShape: tf.Shape | tf.Shape[]): void;
50
+ /**
51
+ * Save the layer's hyperparameters for serialization
52
+ */
53
+ getConfig(): tf.serialization.ConfigDict;
54
+ }
55
+ //# sourceMappingURL=transformer_encoder.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"transformer_encoder.d.ts","sourceRoot":"","sources":["../../src/layers/transformer_encoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAGjE,OAAO,EAAsB,KAAK,sBAAsB,EAAE,MAAM,+BAA+B,CAAC;AAGhG,MAAM,WAAW,sBAAuB,SAAQ,sBAAsB;IAClE,UAAU,CAAC,EAAE,MAAM,GAAG,MAAM,CAAC;IAC7B,eAAe,CAAC,EAAE,MAAM,CAAC;CAC5B;AAGD;;;;;;;;;;;;;;GAcG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IAExC,OAAO,CAAC,QAAQ,CAAC,aAAa,CAAkB;IAChD,OAAO,CAAC,QAAQ,CAAC,oBAAoB,CAAkB;IACvD,OAAO,CAAC,QAAQ,CAAC,iBAAiB,CAAkB;IAEpD,OAAO,CAAC,QAAQ,CAAC,SAAS,CAAkB;IAC5C,OAAO,CAAC,QAAQ,CAAC,WAAW,CAAkB;IAC9C,OAAO,CAAC,QAAQ,CAAC,kBAAkB,CAAkB;IACrD,OAAO,CAAC,QAAQ,CAAC,cAAc,CAAkB;IAEjD,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,QAAQ,CAAC,MAAM,CAAU;IACjC,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAU;IAClC,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAS;IACjC,OAAO,CAAC,QAAQ,CAAC,UAAU,CAAuB;IAClD,OAAO,CAAC,QAAQ,CAAC,eAAe,CAAS;gBAG7B,EAAE,QAAQ,EAAE,QAAQ,EAAE,MAAM,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAAE,EAAE,sBAAsB;IAqC1H;;OAEG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAyBvF,OAAO,CAAC,kBAAkB;IAc1B,OAAO,CAAC,gBAAgB;IAexB;;OAEG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IAsDvD;;OAEG;IACM,SAAS,IAAI,EAAE,CAAC,aAAa,CAAC,UAAU;CAiBpD"}
@@ -1,23 +1,12 @@
1
1
  import * as tf from "@tensorflow/tfjs";
2
- import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
3
- import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
4
-
5
- import { MultiHeadAttention, type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
6
-
7
-
8
- export interface TransformerEncoderArgs extends MultiHeadAttentionArgs {
9
- activation?: "relu" | "gelu";
10
- dimsFeedForward?: number;
11
- }
12
-
13
-
2
+ import { MultiHeadAttention } from "../layers/multihead_attention";
14
3
  /**
15
4
  * This class implements the transformer encoder architecture from the 2017 paper
16
5
  * Attention Is All You Need.
17
- *
6
+ *
18
7
  * This layer accepts exactly one tensor input with the shape
19
8
  * `[ batch, sequences, embedding dims ]`.
20
- *
9
+ *
21
10
  * @param numHeads number of attention heads to use
22
11
  * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
23
12
  * @param causal use causal masking, default `false` for encoders
@@ -28,28 +17,22 @@ export interface TransformerEncoderArgs extends MultiHeadAttentionArgs {
28
17
  */
29
18
  export class TransformerEncoder extends tf.layers.Layer {
30
19
  static className = "TransformerEncoder";
31
-
32
- private readonly selfAttention: tf.layers.Layer;
33
- private readonly selfAttentionDropout: tf.layers.Layer;
34
- private readonly selfAttentionNorm: tf.layers.Layer;
35
-
36
- private readonly reluLayer: tf.layers.Layer;
37
- private readonly linearLayer: tf.layers.Layer;
38
- private readonly feedForwardDropout: tf.layers.Layer;
39
- private readonly feedFowardNorm: tf.layers.Layer;
40
-
41
- private readonly numHeads: number;
42
- private readonly embedDim: number;
43
- private readonly causal: boolean;
44
- private readonly useBias: boolean;
45
- private readonly dropout: number;
46
- private readonly activation: ActivationIdentifier;
47
- private readonly dimsFeedForward: number;
48
-
49
-
50
- constructor({ numHeads, embedDim, causal, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerEncoderArgs) {
20
+ selfAttention;
21
+ selfAttentionDropout;
22
+ selfAttentionNorm;
23
+ reluLayer;
24
+ linearLayer;
25
+ feedForwardDropout;
26
+ feedFowardNorm;
27
+ numHeads;
28
+ embedDim;
29
+ causal;
30
+ useBias;
31
+ dropout;
32
+ activation;
33
+ dimsFeedForward;
34
+ constructor({ numHeads, embedDim, causal, useBias, dropout, activation, dimsFeedForward, ...args }) {
51
35
  super(args);
52
-
53
36
  this.numHeads = numHeads;
54
37
  this.embedDim = embedDim;
55
38
  this.causal = causal ?? false;
@@ -57,19 +40,16 @@ export class TransformerEncoder extends tf.layers.Layer {
57
40
  this.dropout = dropout ?? 0.1;
58
41
  this.activation = activation ?? "relu";
59
42
  this.dimsFeedForward = dimsFeedForward ?? 2048;
60
-
61
43
  if (this.dropout >= 1) {
62
44
  throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
63
45
  }
64
-
65
46
  // self attention sub-block
66
47
  this.selfAttention = new MultiHeadAttention({
67
48
  numHeads: this.numHeads, embedDim: this.embedDim, useBias: this.useBias,
68
49
  dropout: this.dropout, causal: this.causal
69
50
  });
70
- this.selfAttentionDropout = tf.layers.dropout({ rate: this.dropout })
51
+ this.selfAttentionDropout = tf.layers.dropout({ rate: this.dropout });
71
52
  this.selfAttentionNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
72
-
73
53
  // feed forward sub-block
74
54
  this.reluLayer = tf.layers.dense({
75
55
  units: this.dimsFeedForward, activation: this.activation,
@@ -82,96 +62,76 @@ export class TransformerEncoder extends tf.layers.Layer {
82
62
  this.feedForwardDropout = tf.layers.dropout({ rate: this.dropout });
83
63
  this.feedFowardNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
84
64
  }
85
-
86
-
87
65
  /**
88
66
  * Forward propagation
89
67
  */
90
- override call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[] {
68
+ call(inputs, kwargs) {
91
69
  // validate the input tensors
92
- let input: tf.Tensor;
93
-
70
+ let input;
94
71
  if (Array.isArray(inputs)) {
95
72
  if (inputs.length != 1) {
96
73
  throw Error(`${this.getClassName}::call ${this.name} expects exactly 1 tensor` +
97
74
  ` input, got ${inputs.length} inputs instead.`);
98
75
  }
99
-
100
76
  input = inputs[0];
101
- } else {
77
+ }
78
+ else {
102
79
  input = inputs;
103
80
  }
104
-
105
81
  // perform forward propagation
106
82
  return tf.tidy(() => {
107
83
  const attention = this.selfAttentionBlock(input, kwargs);
108
84
  const feedforward = this.feedForwardBlock(attention, kwargs);
109
-
110
85
  return feedforward;
111
86
  });
112
87
  }
113
-
114
-
115
- private selfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
88
+ selfAttentionBlock(x, kwargs) {
116
89
  return tf.tidy(() => {
117
90
  const residual = x;
118
-
119
- let attention = this.selfAttention.apply(x, kwargs) as tf.Tensor;
120
- attention = this.selfAttentionDropout.apply(attention, kwargs) as tf.Tensor;
91
+ let attention = this.selfAttention.apply(x, kwargs);
92
+ attention = this.selfAttentionDropout.apply(attention, kwargs);
121
93
  attention = tf.add(attention, residual);
122
- attention = this.selfAttentionNorm.apply(attention) as tf.Tensor;
123
-
94
+ attention = this.selfAttentionNorm.apply(attention);
124
95
  return attention;
125
96
  });
126
97
  }
127
-
128
-
129
- private feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
98
+ feedForwardBlock(x, kwargs) {
130
99
  return tf.tidy(() => {
131
100
  const residual = x;
132
-
133
101
  let feedForward = this.reluLayer.apply(x);
134
102
  feedForward = this.linearLayer.apply(feedForward);
135
- feedForward = this.feedForwardDropout.apply(feedForward, kwargs) as tf.Tensor;
103
+ feedForward = this.feedForwardDropout.apply(feedForward, kwargs);
136
104
  feedForward = tf.add(feedForward, residual);
137
- feedForward = this.feedFowardNorm.apply(feedForward) as tf.Tensor;
138
-
105
+ feedForward = this.feedFowardNorm.apply(feedForward);
139
106
  return feedForward;
140
107
  });
141
108
  }
142
-
143
-
144
109
  /**
145
110
  * Initialize the sublayers' weights and track them to enable backpropagation.
146
111
  */
147
- override build(inputShape: tf.Shape | tf.Shape[]): void {
148
- let input_shapes: tf.Shape[] = [];
149
-
112
+ build(inputShape) {
113
+ let input_shapes = [];
150
114
  if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
151
115
  // input is an array of shapes
152
- input_shapes = inputShape as tf.Shape[];
153
- } else if (inputShape.length != 0) {
116
+ input_shapes = inputShape;
117
+ }
118
+ else if (inputShape.length != 0) {
154
119
  // input is a single shape
155
- input_shapes = [inputShape as tf.Shape];
120
+ input_shapes = [inputShape];
156
121
  }
157
-
158
122
  // expects only 1 rank 3 tensor input
159
123
  if (input_shapes.length != 1 || input_shapes[0].length != 3) {
160
- throw Error(`${this.getClassName()}::build ${this.name} expects a single input shape of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`)
124
+ throw Error(`${this.getClassName()}::build ${this.name} expects a single input shape of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`);
161
125
  }
162
-
163
126
  // initialize self attention sub-block's weights
164
127
  this.selfAttention.build(inputShape);
165
128
  this.selfAttentionNorm.build(inputShape);
166
-
167
129
  // inintialize feedforward sub-block's weights
168
130
  const reluLayerOutputShape = this.reluLayer.computeOutputShape(inputShape);
169
131
  const linearLayerOutputShape = this.linearLayer.computeOutputShape(reluLayerOutputShape);
170
-
171
132
  this.reluLayer.build(inputShape);
172
133
  this.linearLayer.build(reluLayerOutputShape);
173
134
  this.feedFowardNorm.build(linearLayerOutputShape);
174
-
175
135
  // track sublayers' weights
176
136
  this.trainableWeights = [
177
137
  ...this.selfAttention.trainableWeights,
@@ -182,28 +142,22 @@ export class TransformerEncoder extends tf.layers.Layer {
182
142
  ...this.feedForwardDropout.trainableWeights,
183
143
  ...this.feedFowardNorm.trainableWeights
184
144
  ];
185
-
186
145
  // rename the weights otherwise they'll take on the default naming and overlap
187
146
  // each other which breaks model loading due to duplicate weight names
188
147
  let indexing = 0;
189
-
190
148
  for (const weight of this.trainableWeights) {
191
149
  const unique_name = `${this.getClassName()}_${indexing}`;
192
- (weight as any).name += unique_name;
193
- (weight as any).originalName += unique_name;
150
+ weight.name += unique_name;
151
+ weight.originalName += unique_name;
194
152
  indexing++;
195
153
  }
196
-
197
154
  super.build(inputShape);
198
155
  }
199
-
200
-
201
156
  /**
202
157
  * Save the layer's hyperparameters for serialization
203
158
  */
204
- override getConfig(): tf.serialization.ConfigDict {
159
+ getConfig() {
205
160
  const base_config = super.getConfig();
206
-
207
161
  const config = {
208
162
  numHeads: this.numHeads,
209
163
  embedDim: this.embedDim,
@@ -213,12 +167,9 @@ export class TransformerEncoder extends tf.layers.Layer {
213
167
  activation: this.activation,
214
168
  dimsFeedForward: this.dimsFeedForward
215
169
  };
216
-
217
170
  Object.assign(config, base_config);
218
-
219
171
  return config;
220
172
  }
221
173
  }
222
-
223
-
224
174
  tf.serialization.registerClass(TransformerEncoder);
175
+ //# sourceMappingURL=transformer_encoder.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"transformer_encoder.js","sourceRoot":"","sources":["../../src/layers/transformer_encoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAIvC,OAAO,EAAE,kBAAkB,EAA+B,MAAM,+BAA+B,CAAC;AAShG;;;;;;;;;;;;;;GAcG;AACH,MAAM,OAAO,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,GAAG,oBAAoB,CAAC;IAEvB,aAAa,CAAkB;IAC/B,oBAAoB,CAAkB;IACtC,iBAAiB,CAAkB;IAEnC,SAAS,CAAkB;IAC3B,WAAW,CAAkB;IAC7B,kBAAkB,CAAkB;IACpC,cAAc,CAAkB;IAEhC,QAAQ,CAAS;IACjB,QAAQ,CAAS;IACjB,MAAM,CAAU;IAChB,OAAO,CAAU;IACjB,OAAO,CAAS;IAChB,UAAU,CAAuB;IACjC,eAAe,CAAS;IAGzC,YAAY,EAAE,QAAQ,EAAE,QAAQ,EAAE,MAAM,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAA0B;QACtH,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,MAAM,GAAG,MAAM,IAAI,KAAK,CAAC;QAC9B,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,IAAI,CAAC;QAC/B,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,GAAG,CAAC;QAC9B,IAAI,CAAC,UAAU,GAAG,UAAU,IAAI,MAAM,CAAC;QACvC,IAAI,CAAC,eAAe,GAAG,eAAe,IAAI,IAAI,CAAC;QAE/C,IAAI,IAAI,CAAC,OAAO,IAAI,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,6CAA6C,CAAC,CAAC;QACrF,CAAC;QAED,2BAA2B;QAC3B,IAAI,CAAC,aAAa,GAAG,IAAI,kBAAkB,CAAC;YACxC,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,OAAO,EAAE,IAAI,CAAC,OAAO;YACvE,OAAO,EAAE,IAAI,CAAC,OAAO,EAAE,MAAM,EAAE,IAAI,CAAC,MAAM;SAC7C,CAAC,CAAC;QACH,IAAI,CAAC,oBAAoB,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAA;QACrE,IAAI,CAAC,iBAAiB,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAAO,EAAE,IAAI,EAAE,CAAC,CAAC;QAEzE,yBAAyB;QACzB,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAC7B,KAAK,EAAE,IAAI,CAAC,eAAe,EAAE,UAAU,EAAE,IAAI,CAAC,UAAU;YACxD,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAC/B,KAAK,EAAE,IAAI,CAAC,QAAQ,EAAE,UAAU,EAAE,QAAQ;YAC1C,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,kBAAkB,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC;QACpE,IAAI,CAAC,cAA
c,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAAO,EAAE,IAAI,EAAE,CAAC,CAAC;IAC1E,CAAC;IAGD;;OAEG;IACM,IAAI,CAAC,MAA+B,EAAE,MAAc;QACzD,6BAA6B;QAC7B,IAAI,KAAgB,CAAC;QAErB,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC;YACxB,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;gBACrB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,UAAU,IAAI,CAAC,IAAI,2BAA2B;oBAC1E,eAAe,MAAM,CAAC,MAAM,kBAAkB,CAAC,CAAC;YACxD,CAAC;YAED,KAAK,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;QACtB,CAAC;aAAM,CAAC;YACJ,KAAK,GAAG,MAAM,CAAC;QACnB,CAAC;QAED,8BAA8B;QAC9B,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,SAAS,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC;YACzD,MAAM,WAAW,GAAG,IAAI,CAAC,gBAAgB,CAAC,SAAS,EAAE,MAAM,CAAC,CAAC;YAE7D,OAAO,WAAW,CAAC;QACvB,CAAC,CAAC,CAAC;IACP,CAAC;IAGO,kBAAkB,CAAC,CAAY,EAAE,MAAc;QACnD,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,SAAS,GAAG,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAc,CAAC;YACjE,SAAS,GAAG,IAAI,CAAC,oBAAoB,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAC5E,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,QAAQ,CAAC,CAAC;YACxC,SAAS,GAAG,IAAI,CAAC,iBAAiB,CAAC,KAAK,CAAC,SAAS,CAAc,CAAC;YAEjE,OAAO,SAAS,CAAC;QACrB,CAAC,CAAC,CAAC;IACP,CAAC;IAGO,gBAAgB,CAAC,CAAY,EAAE,MAAc;QACjD,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,WAAW,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAC1C,WAAW,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC;YAClD,WAAW,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAC9E,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,QAAQ,CAAC,CAAC;YAC5C,WAAW,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,WAAW,CAAc,CAAC;YAElE,OAAO,WAAW,CAAC;QACvB,CAAC,CAAC,CAAC;IACP,CAAC;IAGD;;OAEG;IACM,KAAK,CAAC,UAAiC;QAC5C,IAAI,YAAY,GAAe,EAAE,CAAC;QAElC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAC5D,8BAA8B;YAC9B,YAAY,GAAG,UAAwB,CAAC;QAC5C,CAAC;aAAM,IAAI,UAAU,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAChC,0BAA0B;YAC1B,YAAY,GAAG,CAAC,UAAsB,CAAC,CAAC;QAC5C,CAAC;QAED,qCAAqC;QACrC,IAAI,YAAY,CAAC,MAAM,IAAI,CAAC,IAA
I,YAAY,CAAC,CAAC,CAAC,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAC1D,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,iEAAiE,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAA;QACxJ,CAAC;QAED,gDAAgD;QAChD,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;QACrC,IAAI,CAAC,iBAAiB,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;QAEzC,8CAA8C;QAC9C,MAAM,oBAAoB,GAAG,IAAI,CAAC,SAAS,CAAC,kBAAkB,CAAC,UAAU,CAAC,CAAC;QAC3E,MAAM,sBAAsB,GAAG,IAAI,CAAC,WAAW,CAAC,kBAAkB,CAAC,oBAAoB,CAAC,CAAC;QAEzF,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;QACjC,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,oBAAoB,CAAC,CAAC;QAC7C,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,sBAAsB,CAAC,CAAC;QAElD,2BAA2B;QAC3B,IAAI,CAAC,gBAAgB,GAAG;YACpB,GAAG,IAAI,CAAC,aAAa,CAAC,gBAAgB;YACtC,GAAG,IAAI,CAAC,oBAAoB,CAAC,gBAAgB;YAC7C,GAAG,IAAI,CAAC,iBAAiB,CAAC,gBAAgB;YAC1C,GAAG,IAAI,CAAC,SAAS,CAAC,gBAAgB;YAClC,GAAG,IAAI,CAAC,WAAW,CAAC,gBAAgB;YACpC,GAAG,IAAI,CAAC,kBAAkB,CAAC,gBAAgB;YAC3C,GAAG,IAAI,CAAC,cAAc,CAAC,gBAAgB;SAC1C,CAAC;QAEF,8EAA8E;QAC9E,sEAAsE;QACtE,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEjB,KAAK,MAAM,MAAM,IAAI,IAAI,CAAC,gBAAgB,EAAE,CAAC;YACzC,MAAM,WAAW,GAAG,GAAG,IAAI,CAAC,YAAY,EAAE,IAAI,QAAQ,EAAE,CAAC;YACxD,MAAc,CAAC,IAAI,IAAI,WAAW,CAAC;YACnC,MAAc,CAAC,YAAY,IAAI,WAAW,CAAC;YAC5C,QAAQ,EAAE,CAAC;QACf,CAAC;QAED,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACM,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,MAAM,EAAE,IAAI,CAAC,MAAM;YACnB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,eAAe,EAAE,IAAI,CAAC,eAAe;SACxC,CAAC;QAEF,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=transformer_encoder.test.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"transformer_encoder.test.d.ts","sourceRoot":"","sources":["../../src/layers/transformer_encoder.test.ts"],"names":[],"mappings":""}
@@ -1,85 +1,58 @@
1
1
  import * as tf from "@tensorflow/tfjs";
2
-
3
2
  import { TransformerEncoder } from "@/layers/transformer_encoder";
4
-
5
3
  // disables warning for using the faster node backend,
6
4
  // https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
7
5
  tf.env().set('IS_NODE', false);
8
-
9
-
10
6
  describe("TransformerEncoder tests", () => {
11
7
  it("should return an output with the same shape as the input", () => {
12
8
  const input = tf.randomUniform([2, 3, 10]);
13
-
14
9
  const decoder = new TransformerEncoder({
15
- numHeads: 2, embedDim: input.shape.at(-1)!,
10
+ numHeads: 2, embedDim: input.shape.at(-1),
16
11
  dropout: 0.5, activation: "gelu", dimsFeedForward: 512, useBias: true
17
12
  });
18
-
19
- const output = decoder.apply(input) as tf.Tensor;
20
-
13
+ const output = decoder.apply(input);
21
14
  expect(output.shape.length).toBe(input.shape.length);
22
- })
23
-
24
-
15
+ });
25
16
  test("correct forward calls", () => {
26
17
  const input = tf.randomUniform([2, 3, 10]);
27
-
28
- const encoder = new TransformerEncoder({ numHeads: 2, embedDim: input.shape.at(-1)! });
18
+ const encoder = new TransformerEncoder({ numHeads: 2, embedDim: input.shape.at(-1) });
29
19
  expect(() => encoder.apply(input)).not.toThrow();
30
20
  expect(() => encoder.apply([input])).not.toThrow();
31
-
32
- const causal = new TransformerEncoder({ numHeads: 2, embedDim: input.shape.at(-1)!, causal: true });
21
+ const causal = new TransformerEncoder({ numHeads: 2, embedDim: input.shape.at(-1), causal: true });
33
22
  expect(() => causal.apply(input)).not.toThrow();
34
23
  expect(() => causal.apply([input])).not.toThrow();
35
- })
36
-
37
-
24
+ });
38
25
  it("should fail to instantiate a layer if heads count is not divisible by the input's embedding dimension", () => {
39
26
  const input = tf.randomUniform([2, 3, 10]);
40
-
41
- expect(() => new TransformerEncoder({ numHeads: 3, embedDim: input.shape.at(-1)! })).toThrow();
42
- expect(() => new TransformerEncoder({ numHeads: 5, embedDim: input.shape.at(-1)! })).not.toThrow();
43
- })
44
-
45
-
27
+ expect(() => new TransformerEncoder({ numHeads: 3, embedDim: input.shape.at(-1) })).toThrow();
28
+ expect(() => new TransformerEncoder({ numHeads: 5, embedDim: input.shape.at(-1) })).not.toThrow();
29
+ });
46
30
  it("should not accept non-rank 3 tensor inputs", () => {
47
31
  const incorrect_input = tf.randomUniform([2, 3, 10, 10]);
48
32
  const incorrect_input2 = tf.randomUniform([2, 3]);
49
33
  const correct_input = tf.randomUniform([2, 3, 10]);
50
-
51
-
52
- const encoder = new TransformerEncoder({ numHeads: 2, embedDim: incorrect_input.shape.at(-1)! });
34
+ const encoder = new TransformerEncoder({ numHeads: 2, embedDim: incorrect_input.shape.at(-1) });
53
35
  expect(() => encoder.apply([correct_input, correct_input])).toThrow();
54
-
55
36
  expect(() => encoder.apply(incorrect_input)).toThrow();
56
37
  expect(() => encoder.apply(incorrect_input2)).toThrow();
57
-
58
38
  expect(() => encoder.apply([correct_input, incorrect_input])).toThrow();
59
39
  expect(() => encoder.apply([incorrect_input, correct_input])).toThrow();
60
-
61
40
  expect(() => encoder.apply([correct_input, incorrect_input2])).toThrow();
62
41
  expect(() => encoder.apply([incorrect_input2, correct_input])).toThrow();
63
- })
64
-
65
-
42
+ });
66
43
  it("should accept exactly one input", () => {
67
44
  const input = tf.randomUniform([2, 3, 10]);
68
-
69
- const encoder = new TransformerEncoder({ numHeads: 1, embedDim: input.shape.at(-1)! });
45
+ const encoder = new TransformerEncoder({ numHeads: 1, embedDim: input.shape.at(-1) });
70
46
  expect(() => encoder.apply(input)).not.toThrow();
71
47
  expect(() => encoder.apply([input])).not.toThrow();
72
-
73
48
  expect(() => encoder.apply([])).toThrow();
74
49
  expect(() => encoder.apply([input, input])).toThrow();
75
- expect(() => encoder.apply([input, input, input])).toThrow()
76
- })
77
-
78
-
50
+ expect(() => encoder.apply([input, input, input])).toThrow();
51
+ });
79
52
  it("should return a non-empty config dict", () => {
80
53
  const input = tf.randomUniform([2, 3, 10]);
81
-
82
- const encoder = new TransformerEncoder({ numHeads: 1, embedDim: input.shape.at(-1)! });
54
+ const encoder = new TransformerEncoder({ numHeads: 1, embedDim: input.shape.at(-1) });
83
55
  expect(Object.keys(encoder.getConfig())).not.toBe(0);
84
- })
85
- })
56
+ });
57
+ });
58
+ //# sourceMappingURL=transformer_encoder.test.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"transformer_encoder.test.js","sourceRoot":"","sources":["../../src/layers/transformer_encoder.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,kBAAkB,EAAE,MAAM,8BAA8B,CAAC;AAElE,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,EAAE,CAAC,0DAA0D,EAAE,GAAG,EAAE;QAChE,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC;YACnC,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE;YAC1C,OAAO,EAAE,GAAG,EAAE,UAAU,EAAE,MAAM,EAAE,eAAe,EAAE,GAAG,EAAE,OAAO,EAAE,IAAI;SACxE,CAAC,CAAC;QAEH,MAAM,MAAM,GAAG,OAAO,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC;QAEjD,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,uBAAuB,EAAE,GAAG,EAAE;QAC/B,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,MAAM,MAAM,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QACpG,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACtD,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uGAAuG,EAAE,GAAG,EAAE;QAC7G,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC/F,
MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACvG,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,4CAA4C,EAAE,GAAG,EAAE;QAClD,MAAM,eAAe,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;QACzD,MAAM,gBAAgB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAClD,MAAM,aAAa,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAGnD,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,eAAe,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACjG,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,aAAa,EAAE,aAAa,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEtE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,eAAe,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACvD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,gBAAgB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAExD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,aAAa,EAAE,eAAe,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACxE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,eAAe,EAAE,aAAa,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAExE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,aAAa,EAAE,gBAAgB,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACzE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,gBAAgB,EAAE,aAAa,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAC7E,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,iCAAiC,EAAE,GAAG,EAAE;QACvC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC1C,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC
tD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAA;IAChE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,OAAO,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAA"}
@@ -0,0 +1,30 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ export declare function diceBinaryStandard(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
3
+ export declare function diceBinaryGlobal(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
4
+ export declare function diceCategoricalStandard(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
5
+ export declare function diceCategoricalGeneralized(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
6
+ export declare function diceCategoricalGlobal(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
7
+ /**
8
+ * Calculates the Sorensen-Dice coefficient and the binary cross entropy losses.
9
+ * Both have equal weight.
10
+ *
11
+ * @param y_true the label tensor
12
+ * @param y_pred the prediction tensor (not sparse)
13
+ * @returns a tensor of shape `[ batch ]`
14
+ */
15
+ export declare function diceBinaryCrossentropy(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
16
+ /**
17
+ * Calculates the Sorensen-Dice coefficient and the categorical cross entropy losses.
18
+ * Both have equal weight. Expects dense (non-sparse) label tensors.
19
+ *
20
+ * This does not support sparse tensors because TFJS's
21
+ * sparseCategoricalCrossentropy loss onehots the label
22
+ * and calls categoricalCrossentropy. See
23
+ * https://github.com/tensorflow/tfjs/blob/0fc04d958ea592f3b8db79a8b3b497b5c8904097/tfjs-layers/src/losses.ts#L143-L146
24
+ *
25
+ * @param y_true the label
26
+ * @param y_pred the prediction tensor (not sparse)
27
+ * @returns a tensor of shape `[ batch ]`
28
+ */
29
+ export declare function diceCategoricalCrossentropy(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
30
+ //# sourceMappingURL=dice.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"dice.d.ts","sourceRoot":"","sources":["../../src/losses/dice.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAWvC,wBAAgB,kBAAkB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAclF;AAQD,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAahF;AAOD,wBAAgB,uBAAuB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAUvF;AAOD,wBAAgB,0BAA0B,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAgB1F;AAOD,wBAAgB,qBAAqB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAWrF;AAOD;;;;;;;GAOG;AACH,wBAAgB,sBAAsB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAMtF;AAOD;;;;;;;;;;;;GAYG;AACH,wBAAgB,2BAA2B,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAM3F"}