@stellarapp/tfjs-stellar 1.0.0 → 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. package/LICENSE +21 -0
  2. package/README.md +47 -0
  3. package/dist/index.d.ts +7 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +7 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/jest.config.d.ts +8 -0
  8. package/dist/jest.config.d.ts.map +1 -0
  9. package/{jest.config.ts → dist/jest.config.js} +8 -64
  10. package/dist/jest.config.js.map +1 -0
  11. package/dist/kv_cache.d.ts +53 -0
  12. package/dist/kv_cache.d.ts.map +1 -0
  13. package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
  14. package/dist/kv_cache.js.map +1 -0
  15. package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
  16. package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  17. package/dist/layers/cached_rope_multihead_attention.js +76 -0
  18. package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
  19. package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  20. package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  21. package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
  22. package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
  23. package/dist/layers/gpt_decoder_block.d.ts +34 -0
  24. package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
  25. package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
  26. package/dist/layers/gpt_decoder_block.js.map +1 -0
  27. package/dist/layers/index.d.ts +17 -0
  28. package/dist/layers/index.d.ts.map +1 -0
  29. package/dist/layers/index.js +33 -0
  30. package/dist/layers/index.js.map +1 -0
  31. package/dist/layers/multihead_attention.d.ts +106 -0
  32. package/dist/layers/multihead_attention.d.ts.map +1 -0
  33. package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
  34. package/dist/layers/multihead_attention.js.map +1 -0
  35. package/dist/layers/multihead_attention.test.d.ts +2 -0
  36. package/dist/layers/multihead_attention.test.d.ts.map +1 -0
  37. package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
  38. package/dist/layers/multihead_attention.test.js.map +1 -0
  39. package/dist/layers/positional_encoding.d.ts +37 -0
  40. package/dist/layers/positional_encoding.d.ts.map +1 -0
  41. package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
  42. package/dist/layers/positional_encoding.js.map +1 -0
  43. package/dist/layers/positional_encoding.test.d.ts +2 -0
  44. package/dist/layers/positional_encoding.test.d.ts.map +1 -0
  45. package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
  46. package/dist/layers/positional_encoding.test.js.map +1 -0
  47. package/dist/layers/rotary_position_embedding.d.ts +39 -0
  48. package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
  49. package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
  50. package/dist/layers/rotary_position_embedding.js.map +1 -0
  51. package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
  52. package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
  53. package/dist/layers/rotary_position_embedding.test.js +88 -0
  54. package/dist/layers/rotary_position_embedding.test.js.map +1 -0
  55. package/dist/layers/token_and_positional_embedding.d.ts +47 -0
  56. package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
  57. package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
  58. package/dist/layers/token_and_positional_embedding.js.map +1 -0
  59. package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
  60. package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  61. package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
  62. package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
  63. package/dist/layers/transformer_decoder.d.ts +69 -0
  64. package/dist/layers/transformer_decoder.d.ts.map +1 -0
  65. package/dist/layers/transformer_decoder.js +182 -0
  66. package/dist/layers/transformer_decoder.js.map +1 -0
  67. package/dist/layers/transformer_decoder.test.d.ts +2 -0
  68. package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
  69. package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
  70. package/dist/layers/transformer_decoder.test.js.map +1 -0
  71. package/dist/layers/transformer_encoder.d.ts +55 -0
  72. package/dist/layers/transformer_encoder.d.ts.map +1 -0
  73. package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
  74. package/dist/layers/transformer_encoder.js.map +1 -0
  75. package/dist/layers/transformer_encoder.test.d.ts +2 -0
  76. package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
  77. package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
  78. package/dist/layers/transformer_encoder.test.js.map +1 -0
  79. package/dist/losses/dice.d.ts +30 -0
  80. package/dist/losses/dice.d.ts.map +1 -0
  81. package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
  82. package/dist/losses/dice.js.map +1 -0
  83. package/dist/losses/index.d.ts +2 -0
  84. package/dist/losses/index.d.ts.map +1 -0
  85. package/dist/losses/index.js +2 -0
  86. package/dist/losses/index.js.map +1 -0
  87. package/dist/masks.d.ts +20 -0
  88. package/dist/masks.d.ts.map +1 -0
  89. package/{src/packing_mask.ts → dist/masks.js} +16 -7
  90. package/dist/masks.js.map +1 -0
  91. package/dist/metrics.d.ts +20 -0
  92. package/dist/metrics.d.ts.map +1 -0
  93. package/{src/metrics.ts → dist/metrics.js} +8 -12
  94. package/dist/metrics.js.map +1 -0
  95. package/dist/models/gpt_model.d.ts +94 -0
  96. package/dist/models/gpt_model.d.ts.map +1 -0
  97. package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
  98. package/dist/models/gpt_model.js.map +1 -0
  99. package/dist/models/index.d.ts +7 -0
  100. package/dist/models/index.d.ts.map +1 -0
  101. package/dist/models/index.js +13 -0
  102. package/dist/models/index.js.map +1 -0
  103. package/dist/models/llm_model.d.ts +87 -0
  104. package/dist/models/llm_model.d.ts.map +1 -0
  105. package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
  106. package/dist/models/llm_model.js.map +1 -0
  107. package/dist/models/u_net.d.ts +40 -0
  108. package/dist/models/u_net.d.ts.map +1 -0
  109. package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
  110. package/dist/models/u_net.js.map +1 -0
  111. package/dist/src/index.d.ts +6 -0
  112. package/dist/src/index.d.ts.map +1 -0
  113. package/dist/src/index.js +6 -0
  114. package/dist/src/index.js.map +1 -0
  115. package/dist/src/kv_cache.d.ts +53 -0
  116. package/dist/src/kv_cache.d.ts.map +1 -0
  117. package/dist/src/kv_cache.js +135 -0
  118. package/dist/src/kv_cache.js.map +1 -0
  119. package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
  120. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  121. package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
  122. package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
  123. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  124. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  125. package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
  126. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
  127. package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
  128. package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
  129. package/dist/src/layers/gpt_decoder_block.js +51 -0
  130. package/dist/src/layers/gpt_decoder_block.js.map +1 -0
  131. package/dist/src/layers/index.d.ts +17 -0
  132. package/dist/src/layers/index.d.ts.map +1 -0
  133. package/dist/src/layers/index.js +33 -0
  134. package/dist/src/layers/index.js.map +1 -0
  135. package/dist/src/layers/multihead_attention.d.ts +106 -0
  136. package/dist/src/layers/multihead_attention.d.ts.map +1 -0
  137. package/dist/src/layers/multihead_attention.js +269 -0
  138. package/dist/src/layers/multihead_attention.js.map +1 -0
  139. package/dist/src/layers/multihead_attention.test.d.ts +2 -0
  140. package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
  141. package/dist/src/layers/multihead_attention.test.js +160 -0
  142. package/dist/src/layers/multihead_attention.test.js.map +1 -0
  143. package/dist/src/layers/positional_encoding.d.ts +37 -0
  144. package/dist/src/layers/positional_encoding.d.ts.map +1 -0
  145. package/dist/src/layers/positional_encoding.js +115 -0
  146. package/dist/src/layers/positional_encoding.js.map +1 -0
  147. package/dist/src/layers/positional_encoding.test.d.ts +2 -0
  148. package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
  149. package/dist/src/layers/positional_encoding.test.js +95 -0
  150. package/dist/src/layers/positional_encoding.test.js.map +1 -0
  151. package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
  152. package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
  153. package/dist/src/layers/rotary_position_embedding.js +99 -0
  154. package/dist/src/layers/rotary_position_embedding.js.map +1 -0
  155. package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
  156. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
  157. package/dist/src/layers/rotary_position_embedding.test.js +88 -0
  158. package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
  159. package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
  160. package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
  161. package/dist/src/layers/token_and_positional_embedding.js +109 -0
  162. package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
  163. package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
  164. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  165. package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
  166. package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
  167. package/dist/src/layers/transformer_decoder.d.ts +69 -0
  168. package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
  169. package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
  170. package/dist/src/layers/transformer_decoder.js.map +1 -0
  171. package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
  172. package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
  173. package/dist/src/layers/transformer_decoder.test.js +72 -0
  174. package/dist/src/layers/transformer_decoder.test.js.map +1 -0
  175. package/dist/src/layers/transformer_encoder.d.ts +55 -0
  176. package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
  177. package/dist/src/layers/transformer_encoder.js +175 -0
  178. package/dist/src/layers/transformer_encoder.js.map +1 -0
  179. package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
  180. package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
  181. package/dist/src/layers/transformer_encoder.test.js +58 -0
  182. package/dist/src/layers/transformer_encoder.test.js.map +1 -0
  183. package/dist/src/losses/dice.d.ts +30 -0
  184. package/dist/src/losses/dice.d.ts.map +1 -0
  185. package/dist/src/losses/dice.js +93 -0
  186. package/dist/src/losses/dice.js.map +1 -0
  187. package/dist/src/losses/index.d.ts +2 -0
  188. package/dist/src/losses/index.d.ts.map +1 -0
  189. package/dist/src/losses/index.js +2 -0
  190. package/dist/src/losses/index.js.map +1 -0
  191. package/dist/src/masks.d.ts +20 -0
  192. package/dist/src/masks.d.ts.map +1 -0
  193. package/dist/src/masks.js +37 -0
  194. package/dist/src/masks.js.map +1 -0
  195. package/dist/src/metrics.d.ts +20 -0
  196. package/dist/src/metrics.d.ts.map +1 -0
  197. package/dist/src/metrics.js +28 -0
  198. package/dist/src/metrics.js.map +1 -0
  199. package/dist/src/models/gpt_model.d.ts +94 -0
  200. package/dist/src/models/gpt_model.d.ts.map +1 -0
  201. package/dist/src/models/gpt_model.js +154 -0
  202. package/dist/src/models/gpt_model.js.map +1 -0
  203. package/dist/src/models/index.d.ts +3 -0
  204. package/dist/src/models/index.d.ts.map +1 -0
  205. package/{src/models/index.ts → dist/src/models/index.js} +1 -0
  206. package/dist/src/models/index.js.map +1 -0
  207. package/dist/src/models/llm_model.d.ts +87 -0
  208. package/dist/src/models/llm_model.d.ts.map +1 -0
  209. package/dist/src/models/llm_model.js +245 -0
  210. package/dist/src/models/llm_model.js.map +1 -0
  211. package/dist/src/models/u_net.d.ts +40 -0
  212. package/dist/src/models/u_net.d.ts.map +1 -0
  213. package/dist/src/models/u_net.js +151 -0
  214. package/dist/src/models/u_net.js.map +1 -0
  215. package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
  216. package/dist/src/tfjs_types.d.ts.map +1 -0
  217. package/dist/src/tfjs_types.js +2 -0
  218. package/dist/src/tfjs_types.js.map +1 -0
  219. package/dist/src/utils.d.ts +28 -0
  220. package/dist/src/utils.d.ts.map +1 -0
  221. package/{src/utils.ts → dist/src/utils.js} +10 -33
  222. package/dist/src/utils.js.map +1 -0
  223. package/dist/src/utils.test.d.ts +2 -0
  224. package/dist/src/utils.test.d.ts.map +1 -0
  225. package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
  226. package/dist/src/utils.test.js.map +1 -0
  227. package/dist/tfjs_types.d.ts +10 -0
  228. package/dist/tfjs_types.d.ts.map +1 -0
  229. package/dist/tfjs_types.js +2 -0
  230. package/dist/tfjs_types.js.map +1 -0
  231. package/dist/utils.d.ts +28 -0
  232. package/dist/utils.d.ts.map +1 -0
  233. package/dist/utils.js +63 -0
  234. package/dist/utils.js.map +1 -0
  235. package/dist/utils.test.d.ts +2 -0
  236. package/dist/utils.test.d.ts.map +1 -0
  237. package/dist/utils.test.js +73 -0
  238. package/dist/utils.test.js.map +1 -0
  239. package/package.json +14 -4
  240. package/src/index.ts +0 -93
  241. package/src/layers/rotary_position_embedding.test.ts +0 -107
  242. package/src/losses/index.ts +0 -1
  243. package/src/testing.ts +0 -1
  244. package/tsconfig.json +0 -49
@@ -1,24 +1,13 @@
1
1
  import * as tf from '@tensorflow/tfjs';
2
- import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
3
- import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
4
-
5
- import { PositionalEncoding, type PositionalEncodingArgs } from '@/layers/positional_encoding';
6
-
7
-
8
- export interface TokenAndPositionalEmbeddingArgs extends LayerArgs, PositionalEncodingArgs {
9
- vocabularySize: number;
10
- dropout?: number
11
- }
12
-
13
-
2
+ import { PositionalEncoding } from '../layers/positional_encoding';
14
3
  /**
15
4
  * This class implements combines sinusoidal positional encoding from the
16
5
  * 2017 paper "Attention Is All You Need" with a normal embedding layer to
17
6
  * form a simplified single embedding layer.
18
- *
7
+ *
19
8
  * This layer accepts tokenized inputs of the shape `[ batch, tokens ]` and runs
20
9
  * it through an embedding layer before adding sinusoidal positional encoding.
21
- *
10
+ *
22
11
  * @param embedDim the size of each token/word's embedding
23
12
  * @param vocabularySize the number of tokens to embed
24
13
  * @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`
@@ -26,124 +15,95 @@ export interface TokenAndPositionalEmbeddingArgs extends LayerArgs, PositionalEn
26
15
  */
27
16
  export class TokenAndPositionalEmbedding extends tf.layers.Layer {
28
17
  static className = "TokenAndPositionalEmbedding";
29
-
30
- private readonly embedDim: number;
31
- private readonly vocabularySize: number;
32
- private embedding: tf.layers.Layer;
33
-
34
- private positional: tf.layers.Layer
35
- private readonly maxSequenceLength: number;
36
- private readonly dropout: number;
37
-
38
- private dropoutLayer: tf.layers.Layer;
39
-
40
-
41
- constructor({ embedDim, vocabularySize, maxSequenceLength, dropout, ...args }: TokenAndPositionalEmbeddingArgs) {
18
+ embedDim;
19
+ vocabularySize;
20
+ embedding;
21
+ positional;
22
+ maxSequenceLength;
23
+ dropout;
24
+ dropoutLayer;
25
+ constructor({ embedDim, vocabularySize, maxSequenceLength, dropout, ...args }) {
42
26
  super(args);
43
-
44
27
  this.embedDim = embedDim;
45
28
  this.vocabularySize = vocabularySize;
46
29
  this.maxSequenceLength = maxSequenceLength ?? 5120;
47
30
  this.dropout = dropout ?? 0.1;
48
-
49
31
  if (this.dropout >= 1) {
50
32
  throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
51
33
  }
52
-
53
34
  this.embedding = tf.layers.embedding({
54
35
  inputDim: this.vocabularySize,
55
36
  outputDim: this.embedDim,
56
37
  });
57
-
58
38
  this.positional = new PositionalEncoding({
59
39
  maxSequenceLength: this.maxSequenceLength,
60
40
  embedDim: this.embedDim,
61
41
  });
62
-
63
42
  this.dropoutLayer = tf.layers.dropout({ rate: this.dropout });
64
43
  }
65
-
66
-
67
44
  /**
68
- * Forward propagation.
45
+ * Forward propagation.
69
46
  */
70
- override call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs) {
47
+ call(inputs, kwargs) {
71
48
  if (Array.isArray(inputs) && inputs.length != 1) {
72
49
  throw Error(`${this.getClassName()}::call ${this.name} expects exactly` +
73
50
  ` 1 tensor input, received ${inputs.length}`);
74
51
  }
75
-
76
52
  return tf.tidy(() => {
77
- let output = this.positional.apply(this.embedding.apply(inputs)) as tf.Tensor;
78
- output = this.dropoutLayer.apply(output) as tf.Tensor;
79
-
53
+ let output = this.positional.apply(this.embedding.apply(inputs));
54
+ output = this.dropoutLayer.apply(output);
80
55
  return output;
81
- })
56
+ });
82
57
  }
83
-
84
-
85
58
  /**
86
59
  * Build the sublayers and enable serialization
87
60
  */
88
- override build(inputShape: tf.Shape | tf.Shape[]): void {
89
- let input_shapes: tf.Shape[] = [];
90
-
61
+ build(inputShape) {
62
+ let input_shapes = [];
91
63
  // only consider the first shape if multiple provided
92
64
  if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
93
65
  // input is an array of shapes
94
- input_shapes = inputShape as tf.Shape[];
95
- } else if (inputShape.length != 0) {
66
+ input_shapes = inputShape;
67
+ }
68
+ else if (inputShape.length != 0) {
96
69
  // input is a single shape
97
- input_shapes = [inputShape as tf.Shape];
70
+ input_shapes = [inputShape];
98
71
  }
99
-
100
- if (input_shapes[0].length != 2 || input_shapes[0][1]! > this.maxSequenceLength) {
72
+ if (input_shapes[0].length != 2 || input_shapes[0][1] > this.maxSequenceLength) {
101
73
  throw Error(`${this.getClassName()}::build ${this.name} expected an input of` +
102
74
  ` shape [batch, tokens] where tokens < ${this.maxSequenceLength},` +
103
75
  ` received ${JSON.stringify(input_shapes[0])}`);
104
76
  }
105
-
106
77
  // initialize the sublayers' weights
107
78
  this.embedding.build(input_shapes[0]);
108
79
  this.positional.build(this.embedding.computeOutputShape(input_shapes[0]));
109
-
110
80
  // no need to rename weights, haven't found a case where their names collide
111
81
  this.trainableWeights = [
112
82
  ...this.embedding.trainableWeights,
113
83
  ...this.positional.trainableWeights
114
84
  ];
115
-
116
85
  super.build(input_shapes[0]);
117
86
  }
118
-
119
-
120
87
  /**
121
88
  * The output shape, for an input shape of [batch, sequences], is
122
89
  * [batch, sequences, embedDim]
123
90
  */
124
- override computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
91
+ computeOutputShape(inputShape) {
125
92
  const embedding_shape = this.embedding.computeOutputShape(inputShape);
126
93
  const positional_shape = this.positional.computeOutputShape(embedding_shape);
127
-
128
94
  return positional_shape;
129
95
  }
130
-
131
-
132
- override getConfig(): tf.serialization.ConfigDict {
96
+ getConfig() {
133
97
  const base_config = super.getConfig();
134
-
135
98
  const config = {
136
99
  embedDim: this.embedDim,
137
100
  vocabularySize: this.vocabularySize,
138
101
  maxSequenceLength: this.maxSequenceLength,
139
102
  dropout: this.dropout,
140
- }
141
-
103
+ };
142
104
  Object.assign(config, base_config);
143
-
144
105
  return config;
145
106
  }
146
107
  }
147
-
148
-
149
108
  tf.serialization.registerClass(TokenAndPositionalEmbedding);
109
+ //# sourceMappingURL=token_and_positional_embedding.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"token_and_positional_embedding.js","sourceRoot":"","sources":["../../src/layers/token_and_positional_embedding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAIvC,OAAO,EAAE,kBAAkB,EAA+B,MAAM,+BAA+B,CAAC;AAShG;;;;;;;;;;;;GAYG;AACH,MAAM,OAAO,2BAA4B,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IAC5D,MAAM,CAAC,SAAS,GAAG,6BAA6B,CAAC;IAEhC,QAAQ,CAAS;IACjB,cAAc,CAAS;IAChC,SAAS,CAAkB;IAE3B,UAAU,CAAiB;IAClB,iBAAiB,CAAS;IAC1B,OAAO,CAAS;IAEzB,YAAY,CAAkB;IAGtC,YAAY,EAAE,QAAQ,EAAE,cAAc,EAAE,iBAAiB,EAAE,OAAO,EAAE,GAAG,IAAI,EAAmC;QAC1G,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,cAAc,GAAG,cAAc,CAAC;QACrC,IAAI,CAAC,iBAAiB,GAAG,iBAAiB,IAAI,IAAI,CAAC;QACnD,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,GAAG,CAAC;QAE9B,IAAI,IAAI,CAAC,OAAO,IAAI,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,6CAA6C,CAAC,CAAC;QACrF,CAAC;QAED,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,MAAM,CAAC,SAAS,CAAC;YACjC,QAAQ,EAAE,IAAI,CAAC,cAAc;YAC7B,SAAS,EAAE,IAAI,CAAC,QAAQ;SAC3B,CAAC,CAAC;QAEH,IAAI,CAAC,UAAU,GAAG,IAAI,kBAAkB,CAAC;YACrC,iBAAiB,EAAE,IAAI,CAAC,iBAAiB;YACzC,QAAQ,EAAE,IAAI,CAAC,QAAQ;SAC1B,CAAC,CAAC;QAEH,IAAI,CAAC,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC;IAClE,CAAC;IAGD;;OAEG;IACM,IAAI,CAAC,MAA+B,EAAE,MAAc;QACzD,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAC9C,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,kBAAkB;gBACnE,6BAA6B,MAAM,CAAC,MAAM,EAAE,CAAC,CAAC;QACtD,CAAC;QAED,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,IAAI,MAAM,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,MAAM,CAAC,CAAc,CAAC;YAC9E,MAAM,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,MAAM,CAAc,CAAC;YAEtD,OAAO,MAAM,CAAC;QAClB,CAAC,CAAC,CAAA;IACN,CAAC;IAGD;;OAEG;IACM,KAAK,CAAC,UAAiC;QAC5C,IAAI,YAAY,GAAe,EAAE,CAAC;QAElC,qDAAqD;QACrD,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAC5D,8BAA8B;YAC9B,YAAY,GAAG,UAAwB,CAAC;QAC5C,CAAC;aAAM,IAAI,UAAU,CAAC,MAAM,IAAI,CAA
C,EAAE,CAAC;YAChC,0BAA0B;YAC1B,YAAY,GAAG,CAAC,UAAsB,CAAC,CAAC;QAC5C,CAAC;QAED,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC,MAAM,IAAI,CAAC,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC,CAAE,GAAG,IAAI,CAAC,iBAAiB,EAAE,CAAC;YAC9E,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,uBAAuB;gBACzE,yCAAyC,IAAI,CAAC,iBAAiB,GAAG;gBAClE,aAAa,IAAI,CAAC,SAAS,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC;QACxD,CAAC;QAED,oCAAoC;QACpC,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,IAAI,CAAC,SAAS,CAAC,kBAAkB,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAE1E,4EAA4E;QAC5E,IAAI,CAAC,gBAAgB,GAAG;YACpB,GAAG,IAAI,CAAC,SAAS,CAAC,gBAAgB;YAClC,GAAG,IAAI,CAAC,UAAU,CAAC,gBAAgB;SACtC,CAAC;QAEF,KAAK,CAAC,KAAK,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC;IACjC,CAAC;IAGD;;;OAGG;IACM,kBAAkB,CAAC,UAAiC;QACzD,MAAM,eAAe,GAAG,IAAI,CAAC,SAAS,CAAC,kBAAkB,CAAC,UAAU,CAAC,CAAC;QACtE,MAAM,gBAAgB,GAAG,IAAI,CAAC,UAAU,CAAC,kBAAkB,CAAC,eAAe,CAAC,CAAC;QAE7E,OAAO,gBAAgB,CAAC;IAC5B,CAAC;IAGQ,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,cAAc,EAAE,IAAI,CAAC,cAAc;YACnC,iBAAiB,EAAE,IAAI,CAAC,iBAAiB;YACzC,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,2BAA2B,CAAC,CAAC"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=token_and_positional_embedding.test.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"token_and_positional_embedding.test.d.ts","sourceRoot":"","sources":["../../src/layers/token_and_positional_embedding.test.ts"],"names":[],"mappings":""}
@@ -1,81 +1,58 @@
1
1
  import * as tf from '@tensorflow/tfjs';
2
-
3
2
  import { TokenAndPositionalEmbedding } from '@/layers/token_and_positional_embedding';
4
-
5
3
  // disables warning for using the faster node backend,
6
4
  // https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
7
5
  tf.env().set('IS_NODE', false);
8
-
9
-
10
6
  describe("PositionalEncoding tests", () => {
11
7
  test("layer initialization", () => {
12
8
  expect(() => new TokenAndPositionalEmbedding({ maxSequenceLength: 0, embedDim: 10, vocabularySize: 10_000 })).toThrow();
13
9
  expect(() => new TokenAndPositionalEmbedding({ embedDim: 0, vocabularySize: 10_000 })).toThrow();
14
10
  expect(() => new TokenAndPositionalEmbedding({ embedDim: 10, vocabularySize: 0 })).toThrow();
15
-
16
11
  expect(() => new TokenAndPositionalEmbedding({ embedDim: 10, vocabularySize: 10_000 })).not.toThrow();
17
12
  expect(() => new TokenAndPositionalEmbedding({ embedDim: 10, vocabularySize: 10_000 })).not.toThrow();
18
- })
19
-
20
-
13
+ });
21
14
  test("successfull forward calls", () => {
22
15
  const embed_dims = 32;
23
16
  const sequences = 4;
24
17
  const vocab_size = 10_000;
25
18
  const input = tf.randomUniform([2, sequences]);
26
-
27
19
  const embedding = new TokenAndPositionalEmbedding({ embedDim: embed_dims, dropout: 0.1, vocabularySize: vocab_size });
28
20
  expect(() => embedding.apply(input)).not.toThrow();
29
21
  expect(() => embedding.apply([input])).not.toThrow();
30
- })
31
-
32
-
22
+ });
33
23
  test("layer build", () => {
34
24
  const input_ok = tf.randomUniform([2, 4]);
35
25
  const input_too_many_words = tf.randomUniform([2, 700]);
36
26
  const input_is_image = tf.randomUniform([1, 32, 32, 3]);
37
-
38
27
  let embedding = new TokenAndPositionalEmbedding({ embedDim: 32, maxSequenceLength: 500, vocabularySize: 1_000 });
39
28
  expect(() => embedding.build(input_ok.shape)).not.toThrow();
40
-
41
29
  embedding = new TokenAndPositionalEmbedding({ embedDim: 32, maxSequenceLength: 500, vocabularySize: 1_000 });
42
30
  expect(() => embedding.build([input_ok.shape, input_ok.shape])).not.toThrow();
43
-
44
31
  new TokenAndPositionalEmbedding({ embedDim: 32, maxSequenceLength: 500, vocabularySize: 1_000 });
45
32
  expect(() => embedding.build(input_too_many_words.shape)).toThrow();
46
33
  expect(() => embedding.build(input_is_image.shape)).toThrow();
47
- })
48
-
49
-
34
+ });
50
35
  it("should throw when more than one input provided, input sequences are too large, or incorrect input rank", () => {
51
36
  const sequences_too_long = tf.randomUniform([10, 1000]);
52
37
  const multiple_correct_inputs = [tf.randomUniform([2, 3]), tf.randomUniform([2, 3])];
53
38
  const wrong_rank = tf.randomUniform([10, 32, 32]);
54
-
55
39
  const positional = new TokenAndPositionalEmbedding({ maxSequenceLength: 10, embedDim: 32, vocabularySize: 10_000 });
56
40
  positional.build([2, 3]); // get past the initial build call to test forward prop
57
-
58
41
  expect(() => positional.apply(sequences_too_long)).toThrow();
59
42
  expect(() => positional.apply(multiple_correct_inputs)).toThrow();
60
43
  expect(() => positional.apply(wrong_rank)).toThrow();
61
- })
62
-
63
-
44
+ });
64
45
  it("should return a non-empty config dict", () => {
65
46
  const embedding = new TokenAndPositionalEmbedding({ embedDim: 32, vocabularySize: 10_000 });
66
47
  expect(Object.keys(embedding.getConfig())).not.toBe(0);
67
- })
68
-
69
-
48
+ });
70
49
  it("should return an output shape of [batch, sequences, embed dims]", () => {
71
50
  const words = 100;
72
51
  const batch = 2;
73
52
  const embed_dims = 64;
74
-
75
53
  const input = tf.randomUniform([batch, words]);
76
-
77
54
  const embedding = new TokenAndPositionalEmbedding({ embedDim: embed_dims, vocabularySize: 10_000 });
78
-
79
55
  expect(embedding.computeOutputShape(input.shape)).toEqual([batch, words, embed_dims]);
80
- })
56
+ });
81
57
  });
58
+ //# sourceMappingURL=token_and_positional_embedding.test.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"token_and_positional_embedding.test.js","sourceRoot":"","sources":["../../src/layers/token_and_positional_embedding.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,2BAA2B,EAAE,MAAM,yCAAyC,CAAC;AAEtF,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,IAAI,CAAC,sBAAsB,EAAE,GAAG,EAAE;QAC9B,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,iBAAiB,EAAE,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACxH,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACjG,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAE7F,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACtG,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC1G,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,2BAA2B,EAAE,GAAG,EAAE;QACnC,MAAM,UAAU,GAAG,EAAE,CAAC;QACtB,MAAM,SAAS,GAAG,CAAC,CAAC;QACpB,MAAM,UAAU,GAAG,MAAM,CAAC;QAC1B,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;QAE/C,MAAM,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,OAAO,EAAE,GAAG,EAAE,cAAc,EAAE,UAAU,EAAE,CAAC,CAAC;QACtH,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,aAAa,EAAE,GAAG,EAAE;QACrB,MAAM,QAAQ,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC1C,MAAM,oBAAoB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC,CAAC;QACxD,MAAM,cAAc,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;QAExD,IAAI,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,iBAAiB,EAA
E,GAAG,EAAE,cAAc,EAAE,KAAK,EAAE,CAAC,CAAC;QACjH,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAE5D,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,iBAAiB,EAAE,GAAG,EAAE,cAAc,EAAE,KAAK,EAAE,CAAC,CAAC;QAC7G,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,QAAQ,CAAC,KAAK,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAE9E,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,iBAAiB,EAAE,GAAG,EAAE,cAAc,EAAE,KAAK,EAAE,CAAC,CAAC;QACjG,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,oBAAoB,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACpE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,cAAc,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAClE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,wGAAwG,EAAE,GAAG,EAAE;QAC9G,MAAM,kBAAkB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,IAAI,CAAC,CAAC,CAAC;QACxD,MAAM,uBAAuB,GAAG,CAAC,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QACrF,MAAM,UAAU,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;QAElD,MAAM,UAAU,GAAG,IAAI,2BAA2B,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC;QACpH,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,uDAAuD;QAEjF,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,kBAAkB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,uBAAuB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAClE,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC;QAC5F,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,SAAS,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,iEAAiE,EAAE,GAAG,EAAE;QACvE,MAAM,KAAK,GAAG,GAAG,CAAC;QAClB,MAAM,KAAK,GAAG,CAAC,CAAC;QAChB,MAAM,UAAU,GAAG,EAAE,CAAC;QAEtB,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC;QAE/C,MAAM,SA
AS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC;QAEpG,MAAM,CAAC,SAAS,CAAC,kBAAkB,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,UAAU,CAAC,CAAC,CAAC;IAC1F,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAC"}
@@ -0,0 +1,69 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
3
+ import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
4
+ import { type MultiHeadAttentionArgs } from "../layers/multihead_attention";
5
export interface TransformerDecoderArgs extends Omit<MultiHeadAttentionArgs, "causal"> {
    /** activation of the intermediate feed-forward layer; when omitted the implementation uses `"relu"` */
    activation?: "relu" | "gelu";
    /** width of the intermediate feed-forward layer; when omitted the implementation uses `4 * embedDim` */
    dimsFeedForward?: number;
    /** causal-masking flag for the self-attention sub-layer */
    causal?: boolean;
}

/**
 * This class implements the transformer decoder architecture from
 * the 2017 paper "Attention Is All You Need".
 *
 * This decoder-only transformer layer accepts one tensor input.
 * The input tensor should have the shape
 * `[ batch, sequences, embedding dims ]`.
 *
 * Causal masking is enabled by default for the initial attention sub-layer.
 *
 * @param numHeads number of attention heads to use
 * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
 * @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
 * @param dropout dropout rate used inside the attention layer and on both sub-block outputs;
 *                must be within [0, 1), default `0.1`
 * @param activation the activation of the intermediate feed forward layer, default `relu`
 * @param dimsFeedForward the size of the intermediate feed forward layer, default `4 * embedDim`
 * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
 */
export declare class TransformerDecoder extends tf.layers.Layer {
    static className: string;
    /** self-attention sub-block: attention layer, dropout on its output, layer normalization */
    protected readonly causalSelfAttention: tf.layers.Layer;
    protected readonly causalSelfAttentionDropout: tf.layers.Layer;
    protected readonly causalSelfAttentionNorm: tf.layers.Layer;
    /** feed-forward sub-block: expanding dense, projecting dense, dropout, layer normalization */
    protected readonly feedforward1: tf.layers.Layer;
    protected readonly feedforward2: tf.layers.Layer;
    protected readonly feedForwardDropout: tf.layers.Layer;
    protected readonly feedFowardNorm: tf.layers.Layer;
    /** resolved hyperparameters (defaults already applied) */
    protected readonly numHeads: number;
    protected readonly embedDim: number;
    protected readonly useBias: boolean;
    protected readonly dropout: number;
    protected readonly activation: ActivationIdentifier;
    protected readonly dimsFeedForward: number;
    constructor({ numHeads, embedDim, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerDecoderArgs);
    /**
     * Forward propagation
     *
     * @param inputs input tensor (or an array whose first element is the input tensor)
     * @param kwargs forwarded to every sub-layer
     * @return the output tensor
     */
    call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[];
    /** attention -> dropout -> residual add -> layer norm */
    protected causalSelfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
    /** dense -> dense -> dropout -> residual add -> layer norm */
    protected feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
    /**
     * Initialize the sublayers' weights and track them to enable serialization
     */
    build(inputShape: tf.Shape | tf.Shape[]): void;
    /**
     * Save the layer's hyperparameters for serialization
     */
    getConfig(): {
        numHeads: number;
        embedDim: number;
        useBias: boolean;
        dropout: number;
        activation: ActivationIdentifier;
        dimsFeedForward: number;
    };
}
69
+ //# sourceMappingURL=transformer_decoder.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"transformer_decoder.d.ts","sourceRoot":"","sources":["../../src/layers/transformer_decoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AACjE,OAAO,EAAE,KAAK,oBAAoB,EAAE,MAAM,6DAA6D,CAAC;AAExG,OAAO,EAAE,KAAK,sBAAsB,EAAE,MAAM,+BAA+B,CAAC;AAI5E,MAAM,WAAW,sBAAuB,SAAQ,IAAI,CAAC,sBAAsB,EAAE,QAAQ,CAAC;IAClF,UAAU,CAAC,EAAE,MAAM,GAAG,MAAM,CAAC;IAC7B,eAAe,CAAC,EAAE,MAAM,CAAC;IACzB,MAAM,CAAC,EAAE,OAAO,CAAC;CACpB;AAGD;;;;;;;;;;;;;;;;;GAiBG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IAExC,SAAS,CAAC,QAAQ,CAAC,mBAAmB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACxD,SAAS,CAAC,QAAQ,CAAC,0BAA0B,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAC/D,SAAS,CAAC,QAAQ,CAAC,uBAAuB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAE5D,SAAS,CAAC,QAAQ,CAAC,YAAY,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACjD,SAAS,CAAC,QAAQ,CAAC,YAAY,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACjD,SAAS,CAAC,QAAQ,CAAC,kBAAkB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACvD,SAAS,CAAC,QAAQ,CAAC,cAAc,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAEnD,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,OAAO,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,MAAM,CAAC;IACnC,SAAS,CAAC,QAAQ,CAAC,UAAU,EAAE,oBAAoB,CAAC;IACpD,SAAS,CAAC,QAAQ,CAAC,eAAe,EAAE,MAAM,CAAC;gBAE/B,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAAE,EAAE,sBAAsB;IAyClH;;;;;OAKG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAoBvF,SAAS,CAAC,wBAAwB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAc3E,SAAS,CAAC,gBAAgB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAenE;;OAEG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IA6DvD;;OAEG;IACM,SAAS;;;;;;;;CAiBrB"}
@@ -0,0 +1,182 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { CachedRoPEMultiHeadAttention } from "../layers/cached_rope_multihead_attention";
3
/**
 * This class implements the transformer decoder architecture from
 * the 2017 paper "Attention Is All You Need".
 *
 * This decoder-only transformer layer accepts one tensor input.
 * The input tensor should have the shape
 * `[ batch, sequences, embedding dims ]`.
 *
 * Causal masking is enabled by default for the initial attention sub-layer.
 * Both sub-blocks are post-norm: `norm(x + dropout(sublayer(x)))`.
 *
 * @param numHeads number of attention heads to use
 * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
 * @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
 * @param dropout dropout rate used inside the attention layer and on both sub-block outputs;
 *                must be within [0, 1), default `0.1`
 * @param activation the activation of the intermediate feed forward layer, default `relu`
 * @param dimsFeedForward the size of the intermediate feed forward layer, default `4 * embedDim`
 * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
 */
export class TransformerDecoder extends tf.layers.Layer {
    static className = "TransformerDecoder";
    causalSelfAttention;
    causalSelfAttentionDropout;
    causalSelfAttentionNorm;
    feedforward1;
    feedforward2;
    feedForwardDropout;
    feedFowardNorm;
    numHeads;
    embedDim;
    useBias;
    dropout;
    activation;
    dimsFeedForward;
    causal;

    /**
     * @throws {Error} when `dropout` is not within [0, 1)
     */
    constructor({ numHeads, embedDim, useBias, dropout, activation, dimsFeedForward, causal, ...args }) {
        // `causal` is destructured so it reaches the attention sub-layer instead of
        // being silently forwarded to the base Layer (where it was ignored)
        super(args);
        this.numHeads = numHeads;
        this.embedDim = embedDim;
        this.useBias = useBias ?? true;
        this.dropout = dropout ?? 0.1;
        this.activation = activation ?? "relu";
        this.causal = causal ?? true;
        // written as `!(lo && hi)` so NaN is rejected along with out-of-range values
        if (!(this.dropout >= 0 && this.dropout < 1)) {
            throw new Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
        }
        // in the paper section 3.3, d_model=512 (embedDim) and first dense layer outputs d_ff=2048,
        // i.e. a 4x expansion, which is the default used here
        this.dimsFeedForward = dimsFeedForward ?? embedDim * 4;

        // self attention sub-block
        this.causalSelfAttention = new CachedRoPEMultiHeadAttention({
            numHeads: this.numHeads, embedDim: this.embedDim,
            useBias: this.useBias, dropout: this.dropout,
            causal: this.causal,
        });
        this.causalSelfAttentionDropout = tf.layers.dropout({ rate: this.dropout });
        this.causalSelfAttentionNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });

        // feed forward sub-block
        this.feedforward1 = tf.layers.dense({
            units: this.dimsFeedForward,
            activation: this.activation,
            useBias: this.useBias,
        });
        this.feedforward2 = tf.layers.dense({
            units: this.embedDim,
            activation: "linear",
            useBias: this.useBias,
        });
        this.feedForwardDropout = tf.layers.dropout({ rate: this.dropout });
        this.feedFowardNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
    }

    /**
     * Forward propagation
     *
     * @param inputs input tensor, or an array of one or two tensors of which only the first is used
     * @param kwargs forwarded to every sub-layer's `apply`
     * @return the output tensor
     */
    call(inputs, kwargs) {
        // validate the input tensors (a second tensor is tolerated but ignored)
        if (Array.isArray(inputs) && inputs.length !== 1 && inputs.length !== 2) {
            throw new Error(`${this.getClassName()}::call ${this.name} expects one input tensor, got ${inputs.length} inputs.`);
        }
        const x = Array.isArray(inputs) ? inputs[0] : inputs;

        // perform forward propagation
        return tf.tidy(() => {
            const attended = this.causalSelfAttentionBlock(x, kwargs);
            return this.feedForwardBlock(attended, kwargs);
        });
    }

    /** attention -> dropout -> residual add -> layer norm */
    causalSelfAttentionBlock(x, kwargs) {
        return tf.tidy(() => {
            let attention = this.causalSelfAttention.apply(x, kwargs);
            attention = this.causalSelfAttentionDropout.apply(attention, kwargs);
            attention = tf.add(attention, x);
            return this.causalSelfAttentionNorm.apply(attention, kwargs);
        });
    }

    /** dense (expand) -> dense (project back to embedDim) -> dropout -> residual add -> layer norm */
    feedForwardBlock(x, kwargs) {
        return tf.tidy(() => {
            let feedForward = this.feedforward1.apply(x, kwargs);
            feedForward = this.feedforward2.apply(feedForward, kwargs);
            feedForward = this.feedForwardDropout.apply(feedForward, kwargs);
            feedForward = tf.add(feedForward, x);
            return this.feedFowardNorm.apply(feedForward, kwargs);
        });
    }

    /**
     * Initialize the sublayers' weights and track them to enable serialization
     *
     * @param inputShape a single `[batch, seq, embed_dim]` shape, or an array of one or two shapes
     *                   of which only the first is used
     * @throws {Error} when the (first) input shape is not rank 3
     */
    build(inputShape) {
        let inputShapes = [];
        if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
            // input is an array of shapes
            inputShapes = inputShape;
        } else if (inputShape.length !== 0) {
            // input is a single shape
            inputShapes = [inputShape];
        }

        if (inputShapes.length !== 1 && inputShapes.length !== 2) {
            throw new Error(`${this.getClassName()}::build ${this.name} expects an input shape` +
                ` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`);
        }
        const [decoderInputShape] = inputShapes;
        if (decoderInputShape?.length !== 3) {
            throw new Error(`${this.getClassName()}::build ${this.name} expects an input shape` +
                ` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`);
        }

        // initialize causal self attention sub-block's weights
        this.causalSelfAttention.build(decoderInputShape);
        this.causalSelfAttentionNorm.build(this.causalSelfAttention.computeOutputShape(decoderInputShape));

        // initialize feedforward sub-block's weights
        const feedforward1OutputShape = this.feedforward1.computeOutputShape(decoderInputShape);
        const feedforward2OutputShape = this.feedforward2.computeOutputShape(feedforward1OutputShape);
        this.feedforward1.build(decoderInputShape);
        this.feedforward2.build(feedforward1OutputShape);
        this.feedFowardNorm.build(feedforward2OutputShape);

        // track sublayers' weights (dropout layers contribute none but are listed for completeness)
        this.trainableWeights = [
            ...this.causalSelfAttention.trainableWeights,
            ...this.causalSelfAttentionDropout.trainableWeights,
            ...this.causalSelfAttentionNorm.trainableWeights,
            ...this.feedforward1.trainableWeights,
            ...this.feedforward2.trainableWeights,
            ...this.feedForwardDropout.trainableWeights,
            ...this.feedFowardNorm.trainableWeights,
        ];

        // rename the weights otherwise they'll take on the default naming and overlap
        // each other which breaks model loading due to duplicate weight names.
        // NOTE: the suffix format must stay stable, saved models depend on these names.
        this.trainableWeights.forEach((weight, index) => {
            const uniqueName = `${this.getClassName()}_${index}`;
            weight.name += uniqueName;
            weight.originalName += uniqueName;
        });

        super.build(inputShape);
    }

    /**
     * Save the layer's hyperparameters for serialization
     *
     * @return this layer's hyperparameters merged with the base Layer config
     */
    getConfig() {
        const baseConfig = super.getConfig();
        const config = {
            numHeads: this.numHeads,
            embedDim: this.embedDim,
            useBias: this.useBias,
            dropout: this.dropout,
            activation: this.activation,
            dimsFeedForward: this.dimsFeedForward,
            // configs saved before this key existed deserialize with the default (`true`)
            causal: this.causal,
        };
        Object.assign(config, baseConfig);
        return config;
    }
}
181
+ tf.serialization.registerClass(TransformerDecoder);
182
+ //# sourceMappingURL=transformer_decoder.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"transformer_decoder.js","sourceRoot":"","sources":["../../src/layers/transformer_decoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAKvC,OAAO,EAAE,4BAA4B,EAAE,MAAM,2CAA2C,CAAC;AAUzF;;;;;;;;;;;;;;;;;GAiBG;AACH,MAAM,OAAO,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,GAAG,oBAAoB,CAAC;IAErB,mBAAmB,CAAkB;IACrC,0BAA0B,CAAkB;IAC5C,uBAAuB,CAAkB;IAEzC,YAAY,CAAkB;IAC9B,YAAY,CAAkB;IAC9B,kBAAkB,CAAkB;IACpC,cAAc,CAAkB;IAEhC,QAAQ,CAAS;IACjB,QAAQ,CAAS;IACjB,OAAO,CAAU;IACjB,OAAO,CAAS;IAChB,UAAU,CAAuB;IACjC,eAAe,CAAS;IAE3C,YAAY,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAA0B;QAC9G,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,IAAI,CAAC;QAC/B,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,GAAG,CAAC;QAC9B,IAAI,CAAC,UAAU,GAAG,UAAU,IAAI,MAAM,CAAC;QAEvC,IAAI,IAAI,CAAC,OAAO,IAAI,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,6CAA6C,CAAC,CAAC;QACrF,CAAC;QAED,2FAA2F;QAC3F,IAAI,CAAC,eAAe,GAAG,eAAe,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEvD,2BAA2B;QAC3B,IAAI,CAAC,mBAAmB,GAAG,IAAI,4BAA4B,CAAC;YACxD,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,QAAQ,EAAE,IAAI,CAAC,QAAQ;YAChD,OAAO,EAAE,IAAI,CAAC,OAAO,EAAE,OAAO,EAAE,IAAI,CAAC,OAAO;YAC5C,MAAM,EAAE,IAAI;SACf,CAAC,CAAC;QACH,IAAI,CAAC,0BAA0B,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAA;QAC3E,IAAI,CAAC,uBAAuB,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAAO,EAAE,IAAI,EAAE,CAAC,CAAC;QAE/E,yBAAyB;QACzB,IAAI,CAAC,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAChC,KAAK,EAAE,IAAI,CAAC,eAAe;YAC3B,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAChC,KAAK,EAAE,IAAI,CAAC,QAAQ;YACpB,UAAU,EAAE,QAAQ;YACpB,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,kBAAkB,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC;QACpE,IAAI,CAAC,cAAc,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAAO,E
AAE,IAAI,EAAE,CAAC,CAAC;IAC1E,CAAC;IAGD;;;;;OAKG;IACM,IAAI,CAAC,MAA+B,EAAE,MAAc;QACzD,6BAA6B;QAC7B,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YACpE,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,kCAAkC,MAAM,CAAC,MAAM,UAAU,CAAC,CAAC;QACpH,CAAC;QAED,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC;YACxB,MAAM,GAAG,MAAM,CAAC,CAAC,CAAc,CAAC;QACpC,CAAC;QAED,8BAA8B;QAC9B,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,IAAI,MAAM,GAAG,IAAI,CAAC,wBAAwB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAC3D,MAAM,GAAG,IAAI,CAAC,gBAAgB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAE/C,OAAO,MAAM,CAAC;QAClB,CAAC,CAAC,CAAC;IACP,CAAC;IAGS,wBAAwB,CAAC,CAAY,EAAE,MAAc;QAC3D,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,SAAS,GAAG,IAAI,CAAC,mBAAmB,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAc,CAAC;YACvE,SAAS,GAAG,IAAI,CAAC,0BAA0B,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAClF,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,QAAQ,CAAC,CAAC;YACxC,SAAS,GAAG,IAAI,CAAC,uBAAuB,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAE/E,OAAO,SAAS,CAAC;QACrB,CAAC,CAAC,CAAC;IACP,CAAC;IAGS,gBAAgB,CAAC,CAAY,EAAE,MAAc;QACnD,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YACrD,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAC,CAAC;YAC3D,WAAW,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAC9E,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,QAAQ,CAAC,CAAC;YAC5C,WAAW,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAE1E,OAAO,WAAW,CAAC;QACvB,CAAC,CAAC,CAAC;IACP,CAAC;IAGD;;OAEG;IACM,KAAK,CAAC,UAAiC;QAC5C,IAAI,YAAY,GAAe,EAAE,CAAC;QAElC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAC5D,8BAA8B;YAC9B,YAAY,GAAG,UAAwB,CAAC;QAC5C,CAAC;aAAM,IAAI,UAAU,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAChC,0BAA0B;YAC1B,YAAY,GAAG,CAAC,UAAsB,CAAC,CAAC;QAC5C,CAAC;QAED,IAAI,YAAY,CAAC,MAAM,IAAI,CAAC,IAAI,YAAY,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC
;YACvD,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,yBAAyB;gBAC3E,oCAAoC,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAA;QACzE,CAAC;QAED,MAAM,CAAC,iBAAiB,CAAC,GAAG,YAAY,CAAC;QAEzC,IAAI,iBAAiB,EAAE,MAAM,IAAI,CAAC,EAAE,CAAC;YACjC,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,yBAAyB;gBAC3E,oCAAoC,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAA;QACzE,CAAC;QAED,uDAAuD;QACvD,IAAI,CAAC,mBAAmB,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC;QAClD,IAAI,CAAC,uBAAuB,CAAC,KAAK,CAAC,IAAI,CAAC,mBAAmB,CAAC,kBAAkB,CAAC,iBAAiB,CAAC,CAAC,CAAC;QAEnG,6CAA6C;QAC7C,MAAM,uBAAuB,GAAG,IAAI,CAAC,YAAY,CAAC,kBAAkB,CAAC,iBAAiB,CAAC,CAAC;QACxF,MAAM,uBAAuB,GAAG,IAAI,CAAC,YAAY,CAAC,kBAAkB,CAAC,uBAAuB,CAAC,CAAC;QAE9F,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC;QAC3C,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,uBAAuB,CAAC,CAAC;QACjD,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,uBAAuB,CAAC,CAAC;QAEnD,2BAA2B;QAC3B,IAAI,CAAC,gBAAgB,GAAG;YACpB,GAAG,IAAI,CAAC,mBAAmB,CAAC,gBAAgB;YAC5C,GAAG,IAAI,CAAC,0BAA0B,CAAC,gBAAgB;YACnD,GAAG,IAAI,CAAC,uBAAuB,CAAC,gBAAgB;YAChD,GAAG,IAAI,CAAC,YAAY,CAAC,gBAAgB;YACrC,GAAG,IAAI,CAAC,YAAY,CAAC,gBAAgB;YACrC,GAAG,IAAI,CAAC,kBAAkB,CAAC,gBAAgB;YAC3C,GAAG,IAAI,CAAC,cAAc,CAAC,gBAAgB;SAC1C,CAAC;QAEF,8EAA8E;QAC9E,sEAAsE;QACtE,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEjB,KAAK,MAAM,MAAM,IAAI,IAAI,CAAC,gBAAgB,EAAE,CAAC;YACzC,MAAM,WAAW,GAAG,GAAG,IAAI,CAAC,YAAY,EAAE,IAAI,QAAQ,EAAE,CAAC;YACxD,MAAc,CAAC,IAAI,IAAI,WAAW,CAAC;YACnC,MAAc,CAAC,YAAY,IAAI,WAAW,CAAC;YAC5C,QAAQ,EAAE,CAAC;QACf,CAAC;QAED,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACM,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,eAAe,EAAE,IAAI,CAAC,eAAe;SACxC,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAKL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=transformer_decoder.test.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"transformer_decoder.test.d.ts","sourceRoot":"","sources":["../../src/layers/transformer_decoder.test.ts"],"names":[],"mappings":""}