@stellarapp/tfjs-stellar 1.0.4 → 1.0.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140) hide show
  1. package/README.md +17 -0
  2. package/dist/index.d.ts +2 -1
  3. package/dist/index.d.ts.map +1 -1
  4. package/dist/index.js +2 -1
  5. package/dist/index.js.map +1 -1
  6. package/dist/kv_cache.d.ts +2 -0
  7. package/dist/kv_cache.d.ts.map +1 -1
  8. package/dist/kv_cache.js +6 -0
  9. package/dist/kv_cache.js.map +1 -1
  10. package/dist/masks.test.d.ts +2 -0
  11. package/dist/masks.test.d.ts.map +1 -0
  12. package/dist/masks.test.js +55 -0
  13. package/dist/masks.test.js.map +1 -0
  14. package/dist/models/index.d.ts +2 -1
  15. package/dist/models/index.d.ts.map +1 -1
  16. package/dist/models/index.js +2 -1
  17. package/dist/models/index.js.map +1 -1
  18. package/dist/utils.test.js +0 -15
  19. package/dist/utils.test.js.map +1 -1
  20. package/package.json +1 -1
  21. package/dist/jest.config.d.ts +0 -8
  22. package/dist/jest.config.d.ts.map +0 -1
  23. package/dist/jest.config.js +0 -147
  24. package/dist/jest.config.js.map +0 -1
  25. package/dist/src/index.d.ts +0 -6
  26. package/dist/src/index.d.ts.map +0 -1
  27. package/dist/src/index.js +0 -6
  28. package/dist/src/index.js.map +0 -1
  29. package/dist/src/kv_cache.d.ts +0 -53
  30. package/dist/src/kv_cache.d.ts.map +0 -1
  31. package/dist/src/kv_cache.js +0 -135
  32. package/dist/src/kv_cache.js.map +0 -1
  33. package/dist/src/layers/cached_rope_multihead_attention.d.ts +0 -31
  34. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +0 -1
  35. package/dist/src/layers/cached_rope_multihead_attention.js +0 -76
  36. package/dist/src/layers/cached_rope_multihead_attention.js.map +0 -1
  37. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +0 -2
  38. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +0 -1
  39. package/dist/src/layers/cached_rope_multihead_attention.test.js +0 -43
  40. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +0 -1
  41. package/dist/src/layers/gpt_decoder_block.d.ts +0 -34
  42. package/dist/src/layers/gpt_decoder_block.d.ts.map +0 -1
  43. package/dist/src/layers/gpt_decoder_block.js +0 -51
  44. package/dist/src/layers/gpt_decoder_block.js.map +0 -1
  45. package/dist/src/layers/index.d.ts +0 -17
  46. package/dist/src/layers/index.d.ts.map +0 -1
  47. package/dist/src/layers/index.js +0 -33
  48. package/dist/src/layers/index.js.map +0 -1
  49. package/dist/src/layers/multihead_attention.d.ts +0 -106
  50. package/dist/src/layers/multihead_attention.d.ts.map +0 -1
  51. package/dist/src/layers/multihead_attention.js +0 -269
  52. package/dist/src/layers/multihead_attention.js.map +0 -1
  53. package/dist/src/layers/multihead_attention.test.d.ts +0 -2
  54. package/dist/src/layers/multihead_attention.test.d.ts.map +0 -1
  55. package/dist/src/layers/multihead_attention.test.js +0 -160
  56. package/dist/src/layers/multihead_attention.test.js.map +0 -1
  57. package/dist/src/layers/positional_encoding.d.ts +0 -37
  58. package/dist/src/layers/positional_encoding.d.ts.map +0 -1
  59. package/dist/src/layers/positional_encoding.js +0 -115
  60. package/dist/src/layers/positional_encoding.js.map +0 -1
  61. package/dist/src/layers/positional_encoding.test.d.ts +0 -2
  62. package/dist/src/layers/positional_encoding.test.d.ts.map +0 -1
  63. package/dist/src/layers/positional_encoding.test.js +0 -95
  64. package/dist/src/layers/positional_encoding.test.js.map +0 -1
  65. package/dist/src/layers/rotary_position_embedding.d.ts +0 -39
  66. package/dist/src/layers/rotary_position_embedding.d.ts.map +0 -1
  67. package/dist/src/layers/rotary_position_embedding.js +0 -99
  68. package/dist/src/layers/rotary_position_embedding.js.map +0 -1
  69. package/dist/src/layers/rotary_position_embedding.test.d.ts +0 -2
  70. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +0 -1
  71. package/dist/src/layers/rotary_position_embedding.test.js +0 -88
  72. package/dist/src/layers/rotary_position_embedding.test.js.map +0 -1
  73. package/dist/src/layers/token_and_positional_embedding.d.ts +0 -47
  74. package/dist/src/layers/token_and_positional_embedding.d.ts.map +0 -1
  75. package/dist/src/layers/token_and_positional_embedding.js +0 -109
  76. package/dist/src/layers/token_and_positional_embedding.js.map +0 -1
  77. package/dist/src/layers/token_and_positional_embedding.test.d.ts +0 -2
  78. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +0 -1
  79. package/dist/src/layers/token_and_positional_embedding.test.js +0 -58
  80. package/dist/src/layers/token_and_positional_embedding.test.js.map +0 -1
  81. package/dist/src/layers/transformer_decoder.d.ts +0 -69
  82. package/dist/src/layers/transformer_decoder.d.ts.map +0 -1
  83. package/dist/src/layers/transformer_decoder.js +0 -182
  84. package/dist/src/layers/transformer_decoder.js.map +0 -1
  85. package/dist/src/layers/transformer_decoder.test.d.ts +0 -2
  86. package/dist/src/layers/transformer_decoder.test.d.ts.map +0 -1
  87. package/dist/src/layers/transformer_decoder.test.js +0 -72
  88. package/dist/src/layers/transformer_decoder.test.js.map +0 -1
  89. package/dist/src/layers/transformer_encoder.d.ts +0 -55
  90. package/dist/src/layers/transformer_encoder.d.ts.map +0 -1
  91. package/dist/src/layers/transformer_encoder.js +0 -175
  92. package/dist/src/layers/transformer_encoder.js.map +0 -1
  93. package/dist/src/layers/transformer_encoder.test.d.ts +0 -2
  94. package/dist/src/layers/transformer_encoder.test.d.ts.map +0 -1
  95. package/dist/src/layers/transformer_encoder.test.js +0 -58
  96. package/dist/src/layers/transformer_encoder.test.js.map +0 -1
  97. package/dist/src/losses/dice.d.ts +0 -30
  98. package/dist/src/losses/dice.d.ts.map +0 -1
  99. package/dist/src/losses/dice.js +0 -93
  100. package/dist/src/losses/dice.js.map +0 -1
  101. package/dist/src/losses/index.d.ts +0 -2
  102. package/dist/src/losses/index.d.ts.map +0 -1
  103. package/dist/src/losses/index.js +0 -2
  104. package/dist/src/losses/index.js.map +0 -1
  105. package/dist/src/masks.d.ts +0 -20
  106. package/dist/src/masks.d.ts.map +0 -1
  107. package/dist/src/masks.js +0 -37
  108. package/dist/src/masks.js.map +0 -1
  109. package/dist/src/metrics.d.ts +0 -20
  110. package/dist/src/metrics.d.ts.map +0 -1
  111. package/dist/src/metrics.js +0 -28
  112. package/dist/src/metrics.js.map +0 -1
  113. package/dist/src/models/gpt_model.d.ts +0 -94
  114. package/dist/src/models/gpt_model.d.ts.map +0 -1
  115. package/dist/src/models/gpt_model.js +0 -154
  116. package/dist/src/models/gpt_model.js.map +0 -1
  117. package/dist/src/models/index.d.ts +0 -3
  118. package/dist/src/models/index.d.ts.map +0 -1
  119. package/dist/src/models/index.js +0 -3
  120. package/dist/src/models/index.js.map +0 -1
  121. package/dist/src/models/llm_model.d.ts +0 -87
  122. package/dist/src/models/llm_model.d.ts.map +0 -1
  123. package/dist/src/models/llm_model.js +0 -245
  124. package/dist/src/models/llm_model.js.map +0 -1
  125. package/dist/src/models/u_net.d.ts +0 -40
  126. package/dist/src/models/u_net.d.ts.map +0 -1
  127. package/dist/src/models/u_net.js +0 -151
  128. package/dist/src/models/u_net.js.map +0 -1
  129. package/dist/src/tfjs_types.d.ts +0 -10
  130. package/dist/src/tfjs_types.d.ts.map +0 -1
  131. package/dist/src/tfjs_types.js +0 -2
  132. package/dist/src/tfjs_types.js.map +0 -1
  133. package/dist/src/utils.d.ts +0 -28
  134. package/dist/src/utils.d.ts.map +0 -1
  135. package/dist/src/utils.js +0 -63
  136. package/dist/src/utils.js.map +0 -1
  137. package/dist/src/utils.test.d.ts +0 -2
  138. package/dist/src/utils.test.d.ts.map +0 -1
  139. package/dist/src/utils.test.js +0 -73
  140. package/dist/src/utils.test.js.map +0 -1
@@ -1,135 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
/**
 * A container for KV caches, keyed by an id chosen by the caller. A model
 * should initialize one container; every cache created through it shares the
 * container's maximum sequence length.
 */
export class KvCacheContainer {
    caches = new Map();
    max_sequence_length;

    /**
     * @param maxSequenceLength the maximum sequence length of every cache created by this container
     * @throws Error when `maxSequenceLength` is not greater than 0
     */
    constructor(maxSequenceLength) {
        // `!(x > 0)` rejects 0, undefined and NaN like `!x` did, and additionally
        // rejects negative lengths, which the error message already rules out
        if (!(maxSequenceLength > 0)) {
            throw new Error(`KvCacheContainer: expected KV cache maximum sequence length to be greater than 0, got: ${String(maxSequenceLength)}`);
        }
        this.max_sequence_length = maxSequenceLength;
    }

    /**
     * Creates a new, empty `KvCache` and registers it under `id`. A cache that
     * was previously registered under the same id is disposed and replaced.
     *
     * @param id the key used to look the cache up in `update()`
     * @param args the `KvCache` constructor arguments; `maxSequenceLength` is always taken from this container
     */
    create(id, args) {
        const new_cache = new KvCache({
            ...args,
            maxSequenceLength: this.max_sequence_length
        });
        // build the replacement first so a failing constructor leaves the old
        // cache intact, then release the old tensors instead of leaking them
        this.caches.get(id)?.dispose();
        this.caches.set(id, new_cache);
    }

    /**
     * The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
     *
     * @returns the cached keys and values sliced to the first `size` positions of
     * the sequence axis, or `undefined` when no cache is registered under `id`
     */
    update(id, key, value) {
        const kv_cache = this.caches.get(id);
        if (!kv_cache) {
            return undefined;
        }
        const { keyCache, valueCache } = kv_cache.update(key, value);
        // slicing to get only the past key and value projections, but normally
        // in TensorFlow and PyTorch the full cache is returned and masked for
        // graph purposes
        return tf.tidy(() => {
            const k_cache = keyCache.slice([0, 0, 0, 0], [keyCache.shape[0], keyCache.shape[1], kv_cache.size, keyCache.shape[3]]);
            const v_cache = valueCache.slice([0, 0, 0, 0], [valueCache.shape[0], valueCache.shape[1], kv_cache.size, valueCache.shape[3]]);
            return {
                keyCache: k_cache,
                valueCache: v_cache
            };
        });
    }

    /** Resets every registered cache. */
    reset() {
        this.caches.forEach(cache => {
            cache.reset();
        });
    }

    /**
     * Disposes every registered cache and forgets them, so a later `update()`
     * reports a missing cache instead of handing out disposed tensors.
     */
    dispose() {
        this.caches.forEach(cache => {
            cache.dispose();
        });
        this.caches.clear();
    }

    /** The size of the first registered cache, or 0 when the container is empty. */
    get size() {
        // the size of all KV caches are expected to be the same, just use the first one
        return this.caches.values().next().value?.size ?? 0;
    }

    /** The maximum sequence length given to the constructor. */
    get maxSequenceLength() {
        return this.max_sequence_length;
    }
}
60
/**
 * A fixed-size key/value cache. Both caches are allocated up front with the
 * shape `[batchSize, numHeads, maxSequenceLength, headDim]` and filled along
 * the sequence axis by successive `update()` calls.
 */
export class KvCache {
    key_cache;
    value_cache;
    // the size of the KV cache, represents the number of tokens since the first chat token
    current_position = 0;
    batch_size;
    max_sequence_length;
    num_kv_heads;
    head_dim;

    /**
     * @param batchSize size of the first (batch) axis of the cache
     * @param maxSequenceLength number of sequence positions the cache can hold
     * @param numHeads size of the heads axis of the cache
     * @param headDim size of the last axis of the cache
     * @param dtype dtype of the cache tensors, default `"float32"`
     */
    constructor({ batchSize, maxSequenceLength, numHeads, headDim, dtype = "float32" }) {
        const cache_shape = [batchSize, numHeads, maxSequenceLength, headDim];
        this.key_cache = tf.variable(tf.zeros(cache_shape, dtype), false);
        this.value_cache = tf.variable(tf.zeros(cache_shape, dtype), false);
        this.batch_size = batchSize;
        this.max_sequence_length = maxSequenceLength;
        this.num_kv_heads = numHeads;
        this.head_dim = headDim;
    }

    /**
     * The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
     *
     * @returns the full key and value cache variables (not sliced to `size`)
     * @throws Error when the batch size differs from the cache's batch size, or
     * when the new tokens do not fit into the remaining cache space
     */
    update(key, value) {
        const batch_size = key.shape[0];
        const seq_len = key.shape[2];
        // mergeIntoCache always slices the cache with the configured batch size, so
        // a smaller batch cannot be concatenated either; reject every mismatch here
        // rather than failing later inside tf.concat
        if (batch_size !== this.key_cache.shape[0]) {
            throw new Error(`The current KV cache has been set up with a batch size of` +
                ` ${this.key_cache.shape[0]}, but found new key tensors with batch size ${batch_size}`);
        }
        if (this.current_position + seq_len > this.max_sequence_length) {
            throw new Error(`The KV cache has exceeded its maximum sequence length of ${this.max_sequence_length}. Use a larger value.`);
        }
        // build both merged tensors before assigning either of them, and let tidy
        // release them even when a merge or an assign throws
        tf.tidy(() => {
            const new_key_cache = this.mergeIntoCache(key, this.key_cache);
            const new_value_cache = this.mergeIntoCache(value, this.value_cache);
            this.key_cache.assign(new_key_cache);
            this.value_cache.assign(new_value_cache);
        });
        // advance the pointer to reflect the updated cache's current position
        this.current_position += seq_len;
        return {
            keyCache: this.key_cache,
            valueCache: this.value_cache,
        };
    }

    /**
     * Builds a tensor with the shape of `current_cache` in which the positions
     * starting at `current_position` are replaced by `new_value`.
     */
    mergeIntoCache(new_value, current_cache) {
        const seq_len = new_value.shape[2];
        return tf.tidy(() => {
            const historical = current_cache.slice([0, 0, 0, 0], [this.batch_size, this.num_kv_heads, this.current_position, this.head_dim]);
            const future = current_cache.slice([0, 0, this.current_position + seq_len, 0], [this.batch_size, this.num_kv_heads, this.max_sequence_length - this.current_position - seq_len, this.head_dim]);
            // merge the new tensor into the current cache to create a new cache of the same size,
            // this is different from Python implementations because TFJS tensors are immutable,
            // because we cannot update a slice, we must slice and concat
            return tf.concat([historical, new_value, future], 2);
        });
    }

    /** Rewinds the cache to position 0 and overwrites both caches with zeros. */
    reset() {
        this.current_position = 0;
        tf.tidy(() => {
            // keep the dtype the cache was created with, tf.zeros defaults to float32
            this.key_cache.assign(tf.zeros(this.key_cache.shape, this.key_cache.dtype));
            this.value_cache.assign(tf.zeros(this.value_cache.shape, this.value_cache.dtype));
        });
    }

    /** Disposes the key and value cache variables. */
    dispose() {
        this.key_cache.dispose();
        this.value_cache.dispose();
    }

    /**
     * The size of the KV cache, also the number of tokens since the first one.
     */
    get size() {
        return this.current_position;
    }
}
135
- //# sourceMappingURL=kv_cache.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"kv_cache.js","sourceRoot":"","sources":["../../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAYvC;;GAEG;AACH,MAAM,OAAO,gBAAgB;IACf,MAAM,GAAG,IAAI,GAAG,EAAmB,CAAC;IACpC,mBAAmB,CAAS;IAGtC,YAAY,iBAAyB;QACjC,IAAI,CAAC,iBAAiB,EAAE,CAAC;YACrB,MAAM,KAAK,CAAC,0FAA0F,MAAM,CAAC,iBAAiB,CAAC,EAAE,CAAC,CAAC;QACvI,CAAC;QAED,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;IACjD,CAAC;IAGM,MAAM,CAAC,EAAU,EAAE,IAA4C;QAClE,MAAM,SAAS,GAAG,IAAI,OAAO,CAAC;YAC1B,GAAG,IAAI;YACP,iBAAiB,EAAE,IAAI,CAAC,mBAAmB;SAC9C,CAAC,CAAC;QAEH,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,EAAE,SAAS,CAAC,CAAC;IACnC,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,EAAU,EAAE,GAAgB,EAAE,KAAkB;QAC1D,MAAM,QAAQ,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;QAErC,IAAI,CAAC,QAAQ,EAAE,CAAC;YACZ,OAAO,SAAS,CAAC;QACrB,CAAC;QAED,MAAM,EAAE,QAAQ,EAAE,UAAU,EAAE,GAAG,QAAQ,CAAC,MAAM,CAAC,GAAG,EAAE,KAAK,CAAC,CAAC;QAE7D,uEAAuE;QACvE,sEAAsE;QACtE,iBAAiB;QACjB,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,OAAO,GAAG,QAAQ,CAAC,KAAK,CAC1B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAC9E,MAAM,OAAO,GAAG,UAAU,CAAC,KAAK,CAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAEpF,OAAO;gBACH,QAAQ,EAAE,OAAO;gBACjB,UAAU,EAAE,OAAO;aACtB,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,KAAK,EAAE,CAAC;QAClB,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,OAAO,EAAE,CAAC;QACpB,CAAC,CAAC,CAAA;IACN,CAAC;IAGD,IAAW,IAAI;QACX,gFAAgF;QAChF,OAAO,IAAI,CAAC,MAAM,CAAC,OAAO,EAAE,CAAC,IAAI,EAAE,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,IAAI,CAAC,CAAC;IAC7D,CAAC;IAGD,IAAW,iBAAiB;QACxB,OAAO,IAAI,CAAC,mBAAmB,CAAC;IACpC,CAAC;CACJ;AAGD,MAAM,OAAO,OAAO
;IAEN,SAAS,CAA0B;IACnC,WAAW,CAAyB;IAE9C,uFAAuF;IAC7E,gBAAgB,GAAW,CAAC,CAAC;IAE7B,UAAU,CAAS;IACnB,mBAAmB,CAAS;IAC5B,YAAY,CAAS;IACrB,QAAQ,CAAS;IAE3B,YAAY,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAK,GAAG,SAAS,EAAe;QAC3F,MAAM,WAAW,GAAG,CAAC,SAAS,EAAE,QAAQ,EAAE,iBAAiB,EAAE,OAAO,CAAqC,CAAC;QAE1G,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAClE,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAEpE,IAAI,CAAC,UAAU,GAAG,SAAS,CAAC;QAC5B,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;QAC7C,IAAI,CAAC,YAAY,GAAG,QAAQ,CAAC;QAC7B,IAAI,CAAC,QAAQ,GAAG,OAAO,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,GAAgB,EAAE,KAAkB;QAC9C,MAAM,UAAU,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAChC,MAAM,OAAO,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAE7B,IAAI,UAAU,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC;YACvC,MAAM,KAAK,CAAC,2DAA2D;gBACnE,IAAI,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,+CAA+C,UAAU,EAAE,CAAC,CAAA;QAC/F,CAAC;QAED,IAAI,IAAI,CAAC,gBAAgB,GAAG,OAAO,GAAG,IAAI,CAAC,mBAAmB,EAAE,CAAC;YAC7D,MAAM,KAAK,CAAC,4DAA4D,IAAI,CAAC,mBAAmB,uBAAuB,CAAC,CAAC;QAC7H,CAAC;QAED,MAAM,aAAa,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;QAC/D,MAAM,eAAe,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC;QAErE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,aAAa,CAAC,CAAC;QACrC,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,eAAe,CAAC,CAAC;QAEzC,aAAa,CAAC,OAAO,EAAE,CAAC;QACxB,eAAe,CAAC,OAAO,EAAE,CAAC;QAE1B,6DAA6D;QAC7D,IAAI,CAAC,gBAAgB,IAAI,OAAO,CAAC;QAEjC,OAAO;YACH,QAAQ,EAAE,IAAI,CAAC,SAAS;YACxB,UAAU,EAAE,IAAI,CAAC,WAAW;SAC/B,CAAA;IACL,CAAC;IAGS,cAAc,CAAC,SAAsB,EAAE,aAA0B;QACvE,MAAM,OAAO,GAAG,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAEnC,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAEhB,MAAM,UAAU,GAAG,aAAa,CAAC,KAAK,CAClC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,gBAAgB,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAEhF,MAAM,MAAM,GAAG,aAAa,CAAC,KAAK,CAC9B,CAAC,CAAC,EAAE,CAAC,EAAE,IAAI,C
AAC,gBAAgB,GAAG,OAAO,EAAE,CAAC,CAAC,EAC1C,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,gBAAgB,GAAG,OAAO,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAErH,8EAA8E;YAC9E,qFAAqF;YACrF,6DAA6D;YAC7D,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,UAAU,EAAE,SAAS,EAAE,MAAM,CAAC,EAAE,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,gBAAgB,GAAG,CAAC,CAAC;QAE1B,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YACT,MAAM,eAAe,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC;YAC7C,MAAM,iBAAiB,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC;YAEjD,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,eAAe,CAAC,CAAC,CAAC;YACjD,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAC;IACP,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,SAAS,CAAC,OAAO,EAAE,CAAC;QACzB,IAAI,CAAC,WAAW,CAAC,OAAO,EAAE,CAAC;IAC/B,CAAC;IAGD;;OAEG;IACH,IAAI,IAAI;QACJ,OAAO,IAAI,CAAC,gBAAgB,CAAC;IACjC,CAAC;CAEJ"}
@@ -1,31 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import { KvCacheContainer } from "@/kv_cache";
3
- import { MultiHeadAttention, type MultiHeadAttentionArgs } from '@/layers/multihead_attention';
4
- import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
5
/**
 * Multi-head attention that applies RoPE and can read from / write to a KV cache.
 *
 * To use the cache, pass it as `kwargs.kvCache` to `layer.apply()`; this is why
 * the layer is meant for custom loops when caching is wanted. Without a KV
 * cache the layer behaves as MultiHeadAttention with RoPE.
 */
export declare class CachedRoPEMultiHeadAttention extends MultiHeadAttention {
    /** Class name of this layer. */
    static className: string;

    /** The layer used to apply the rotary position encoding. */
    protected rope: tf.layers.Layer;

    constructor(args: MultiHeadAttentionArgs);

    /** Forward pass; the KV cache, if any, is read from `kwargs`. */
    protected forward(
        query_input: tf.Tensor,
        key_input: tf.Tensor,
        value_input: tf.Tensor,
        packing_mask: tf.Tensor | null,
        causal_mask: tf.Tensor | null,
        kwargs: Kwargs
    ): tf.Tensor;

    /** Updates the container with the new key/value heads and returns the cached keys and values. */
    protected getCachedKV(
        kv_container: KvCacheContainer,
        key_split: tf.Tensor4D,
        value_split: tf.Tensor4D
    ): {
        keyCache: tf.Variable<tf.Rank.R4>;
        valueCache: tf.Variable<tf.Rank.R4>;
    };

    /**
     * Splits the heads and adds the RoPE position encoding right afterwards.
     */
    protected splitHeads(
        query: tf.Tensor,
        key: tf.Tensor,
        value: tf.Tensor,
        shuffle: number[]
    ): {
        query_split: tf.Tensor4D;
        key_split: tf.Tensor4D;
        value_split: tf.Tensor4D;
    };
}
31
- //# sourceMappingURL=cached_rope_multihead_attention.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"cached_rope_multihead_attention.d.ts","sourceRoot":"","sources":["../../../src/layers/cached_rope_multihead_attention.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,gBAAgB,EAAE,MAAM,YAAY,CAAC;AAC9C,OAAO,EAAE,kBAAkB,EAAE,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAE/F,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAGjE;;;;;;;GAOG;AACH,qBAAa,4BAA6B,SAAQ,kBAAkB;IAChE,MAAM,CAAC,SAAS,SAAkC;IAElD,SAAS,CAAC,IAAI,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;gBAEpB,IAAI,EAAE,sBAAsB;cAMrB,OAAO,CACtB,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,SAAS,EAAE,EAAE,CAAC,MAAM,EACpB,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,YAAY,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC9B,WAAW,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC7B,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAuC9B,SAAS,CAAC,WAAW,CAAC,YAAY,EAAE,gBAAgB,EAAE,SAAS,EAAE,EAAE,CAAC,QAAQ,EAAE,WAAW,EAAE,EAAE,CAAC,QAAQ;;;;IAqBtG;;OAEG;cACgB,UAAU,CAAC,KAAK,EAAE,EAAE,CAAC,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC,MAAM,EAAE,OAAO,EAAE,MAAM,EAAE;qBAO5D,EAAE,CAAC,QAAQ;mBAEX,EAAE,CAAC,QAAQ;qBACwB,EAAE,CAAC,QAAQ;;CAIxF"}
@@ -1,76 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import { MultiHeadAttention } from '@/layers/multihead_attention';
3
- import { RotaryPositionEmbedding } from '@/layers/rotary_position_embedding';
4
/**
 * MultiHeadAttention with RoPE and KV caching. If using KV caching, this layer
 * should be used in a custom training loop because it requires the cache to be
 * passed through the `kwargs.kvCache` argument during the `layer.apply()`
 * forward propagation.
 *
 * If a KV cache is not provided, then this layer operates as MultiHeadAttention with RoPE.
 */
export class CachedRoPEMultiHeadAttention extends MultiHeadAttention {
    static className = "CachedRoPEMultiHeadAttention";
    rope;

    constructor(args) {
        super(args);
        this.rope = new RotaryPositionEmbedding({ dim: Math.floor(this.embedDim / this.numHeads) });
    }

    /**
     * Projects the inputs, splits them into heads (with RoPE), optionally swaps the
     * key/value heads for the cached ones, and applies attention followed by the
     * output projection.
     *
     * The cache is only used when `kwargs.training` is not `true` and
     * `kwargs.kvCache` is set.
     */
    forward(query_input, key_input, value_input, packing_mask, causal_mask, kwargs) {
        return tf.tidy(() => {
            const { query, key, value } = this.applyInputProjections(query_input, key_input, value_input);
            // swap the seq and heads dimensions: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
            const move_head_dim_forward = [0, 2, 1, 3];
            const split = this.splitHeads(query, key, value, move_head_dim_forward);
            const query_split = split.query_split;
            let key_split = split.key_split;
            let value_split = split.value_split;
            if (kwargs.training !== true && kwargs.kvCache) {
                // runs on inference, updates the KV cache and get the historical key and value
                const cached_kv = this.getCachedKV(kwargs.kvCache, key_split, value_split);
                key_split = cached_kv.keyCache;
                value_split = cached_kv.valueCache;
            }
            // apply scaled dot product attention on the head-split tensors
            const spda = MultiHeadAttention.scaledDotProductionAttention(query_split, key_split, value_split, kwargs.attentionMask ?? null, packing_mask, causal_mask, this.dropout, this.causal, kwargs);
            // swap heads and seq back, concat heads and apply the output projection
            const output = this.outputProjection.apply(spda.transpose(move_head_dim_forward).reshape([query_input.shape[0], query_input.shape[1], this.embedDim]));
            return output;
        });
    }

    /**
     * Stores the new key/value heads in this layer's cache (keyed by the layer
     * name) and returns the cached keys and values. The cache is created on first
     * use, when the container does not know this layer yet.
     *
     * @throws Error prefixed with the class and layer name; the original error is kept as `cause`
     */
    getCachedKV(kv_container, key_split, value_split) {
        try {
            let kv_cache = kv_container.update(this.name, key_split, value_split);
            if (!kv_cache) {
                kv_container.create(this.name, {
                    batchSize: key_split.shape[0],
                    numHeads: this.numHeads,
                    headDim: this.embedDim / this.numHeads,
                });
                kv_cache = kv_container.update(this.name, key_split, value_split);
            }
            return kv_cache;
        }
        catch (error) {
            // keep the original error (and its stack) reachable through `cause`;
            // String() also copes with a thrown null or undefined
            throw new Error(`${this.getClassName()}::getCachedKV ${this.name} ${String(error)}`, { cause: error });
        }
    }

    /**
     * Adds RoPE position encoding right after splitting heads.
     *
     * NOTE(review): RoPE is applied to the tensors as given, with no position
     * offset visible in this block. When tokens are fed one at a time with a KV
     * cache, confirm that RotaryPositionEmbedding accounts for the number of
     * tokens already cached.
     */
    splitHeads(query, key, value, shuffle) {
        const batch_size = query.shape[0];
        const split_heads = [batch_size, -1, this.numHeads, this.embedDim / this.numHeads];
        return tf.tidy(() => {
            return {
                query_split: this.rope.apply(query.reshape(split_heads))
                    .transpose(shuffle),
                key_split: this.rope.apply(key.reshape(split_heads))
                    .transpose(shuffle),
                value_split: value.reshape(split_heads).transpose(shuffle)
            };
        });
    }
}
tf.serialization.registerClass(CachedRoPEMultiHeadAttention);
76
- //# sourceMappingURL=cached_rope_multihead_attention.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"cached_rope_multihead_attention.js","sourceRoot":"","sources":["../../../src/layers/cached_rope_multihead_attention.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,kBAAkB,EAA+B,MAAM,8BAA8B,CAAC;AAC/F,OAAO,EAAE,uBAAuB,EAAE,MAAM,oCAAoC,CAAC;AAI7E;;;;;;;GAOG;AACH,MAAM,OAAO,4BAA6B,SAAQ,kBAAkB;IAChE,MAAM,CAAC,SAAS,GAAG,8BAA8B,CAAC;IAExC,IAAI,CAAkB;IAEhC,YAAY,IAA4B;QACpC,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,IAAI,CAAC,IAAI,GAAG,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC;IAChG,CAAC;IAGkB,OAAO,CACtB,WAAsB,EACtB,SAAoB,EACpB,WAAsB,EACtB,YAA8B,EAC9B,WAA6B,EAC7B,MAAc;QAEd,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,EAAE,KAAK,EAAE,GAAG,EAAE,KAAK,EAAE,GAAG,IAAI,CAAC,qBAAqB,CAAC,WAAW,EAAE,SAAS,EAAE,WAAW,CAAC,CAAC;YAE9F,oGAAoG;YACpG,MAAM,qBAAqB,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;YAE3C,MAAM,KAAK,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,EAAE,qBAAqB,CAAC,CAAC;YAExE,MAAM,WAAW,GAAG,KAAK,CAAC,WAAW,CAAC;YACtC,IAAI,SAAS,GAAG,KAAK,CAAC,SAAS,CAAC;YAChC,IAAI,WAAW,GAAG,KAAK,CAAC,WAAW,CAAC;YAEpC,IAAI,MAAM,CAAC,QAAQ,KAAK,IAAI,IAAI,MAAM,CAAC,OAAO,EAAE,CAAC;gBAC7C,+EAA+E;gBAC/E,MAAM,SAAS,GAAG,IAAI,CAAC,WAAW,CAC9B,MAAM,CAAC,OAA2B,EAAE,SAAS,EAAE,WAAW,CAAC,CAAC;gBAEhE,SAAS,GAAG,SAAS,CAAC,QAAQ,CAAC;gBAC/B,WAAW,GAAG,SAAS,CAAC,UAAU,CAAC;YACvC,CAAC;YAED,gFAAgF;YAChF,MAAM,IAAI,GAAG,kBAAkB,CAAC,4BAA4B,CACxD,WAAW,EAAE,SAAS,EAAE,WAAW,EACnC,MAAM,CAAC,aAAa,IAAI,IAAI,EAAE,YAAY,EAAE,WAAW,EACvD,IAAI,CAAC,OAAO,EAAE,IAAI,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAEvC,+CAA+C;YAC/C,MAAM,MAAM,GAAG,IAAI,CAAC,gBAAgB,CAAC,KAAK,CACtC,IAAI,CAAC,SAAS,CAAC,qBAAqB,CAAC,CAAC,OAAO,CACzC,CAAC,WAAW,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,WAAW,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC;YAEvE,OAAO,MAAmB,CAAC;QAC/B,CAAC,CAAC,CAAA;IACN,CAAC;IAGS,WAAW,CAAC,YAA8B,EAAE,SAAsB,EAAE,WAAwB;QAClG,IAAI,CAAC;YACD,IAAI,QAAQ,GAAG,YAAY,CAAC,MAAM,CAAC,IAAI,CAAC,IAAI,EAAE,SAAS,EAAE,WAAW,CAAC,CAAC;YAEtE,IAAI
,CAAC,QAAQ,EAAE,CAAC;gBACZ,YAAY,CAAC,MAAM,CAAC,IAAI,CAAC,IAAI,EAAE;oBAC3B,SAAS,EAAE,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC;oBAC7B,QAAQ,EAAE,IAAI,CAAC,QAAQ;oBACvB,OAAO,EAAE,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ;iBACzC,CAAC,CAAA;gBAEF,QAAQ,GAAG,YAAY,CAAC,MAAM,CAAC,IAAI,CAAC,IAAI,EAAE,SAAS,EAAE,WAAW,CAAE,CAAC;YACvE,CAAC;YAED,OAAO,QAAS,CAAC;QACrB,CAAC;QAAC,OAAO,KAAU,EAAE,CAAC;YAClB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,IAAI,KAAK,CAAC,QAAQ,EAAE,EAAE,CAAC,CAAC;QACxF,CAAC;IACL,CAAC;IAGD;;OAEG;IACgB,UAAU,CAAC,KAAgB,EAAE,GAAc,EAAE,KAAgB,EAAE,OAAiB;QAC/F,MAAM,UAAU,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAClC,MAAM,WAAW,GAAG,CAAC,UAAU,EAAE,CAAC,CAAC,EAAE,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC;QAEnF,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,OAAO;gBACH,WAAW,EAAG,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAe;qBAClE,SAAS,CAAC,OAAO,CAAgB;gBACtC,SAAS,EAAG,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,WAAW,CAAC,CAAe;qBAC9D,SAAS,CAAC,OAAO,CAAgB;gBACtC,WAAW,EAAE,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,SAAS,CAAC,OAAO,CAAgB;aAC5E,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,4BAA4B,CAAC,CAAC"}
@@ -1,2 +0,0 @@
1
- export {};
2
- //# sourceMappingURL=cached_rope_multihead_attention.test.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"cached_rope_multihead_attention.test.d.ts","sourceRoot":"","sources":["../../../src/layers/cached_rope_multihead_attention.test.ts"],"names":[],"mappings":""}
@@ -1,43 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import { KvCacheContainer } from '@/kv_cache';
3
- import { CachedRoPEMultiHeadAttention } from '@/layers/cached_rope_multihead_attention';
4
// disables warning for using the faster node backend,
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
tf.env().set('IS_NODE', false);

describe("CachedRoPEMultiHeadAttention tests", () => {
    test("aggregate forward passes output are identical normal multihead attention", () => {
        compareNormalWithCachedAttention(tf.randomUniform([2, 10, 16]), 123);
        compareNormalWithCachedAttention(tf.randomUniform([1, 10, 16]), 123);
        compareNormalWithCachedAttention(tf.randomUniform([1, 1, 16]), 123);
        compareNormalWithCachedAttention(tf.randomUniform([3, 2, 16]), 123);
        // input exceeds KV cache size
        expect(() => compareNormalWithCachedAttention(tf.randomUniform([1, 10, 16]), 5)).toThrow();

        /**
         * Runs the whole input through one layer without a cache, then feeds the
         * same input token by token through two cached layers sharing one
         * container, and expects all outputs to match.
         */
        function compareNormalWithCachedAttention(input, max_sequence_length) {
            const embed_dim = input.shape[2];
            const batch = input.shape[0];
            const heads = 2;
            const kv_cache = new KvCacheContainer(max_sequence_length);
            const normal_mha = new CachedRoPEMultiHeadAttention({ numHeads: heads, embedDim: embed_dim, causal: true });
            const normal_mha_output = normal_mha.apply(input);
            // initialize cached attention with identical configuration and weights
            const cached_mha1 = new CachedRoPEMultiHeadAttention({ ...normal_mha.getConfig(), name: "cache_test1" });
            cached_mha1.build(input.shape);
            cached_mha1.setWeights(normal_mha.getWeights());
            const cached_mha2 = new CachedRoPEMultiHeadAttention({ ...normal_mha.getConfig(), name: "cache_test2" });
            cached_mha2.build(input.shape);
            cached_mha2.setWeights(normal_mha.getWeights());
            const cached_mha_outputs1 = [];
            const cached_mha_outputs2 = [];
            for (let i = 0; i < input.shape[1]; i++) {
                const current_token = input.slice([0, i, 0], [batch, 1, embed_dim]);
                cached_mha_outputs1.push(cached_mha1.apply(current_token, { kvCache: kv_cache }));
                cached_mha_outputs2.push(cached_mha2.apply(current_token, { kvCache: kv_cache }));
            }
            // a bare `expect(a == b)` asserts nothing, a matcher is required
            expect(kv_cache.size).toBe(input.shape[1]);
            // compare with the largest absolute difference: a signed sum lets positive
            // and negative errors cancel and passes for any negative total; the
            // tolerance leaves room for float32 round-off
            const max_difference1 = normal_mha_output.sub(tf.concat(cached_mha_outputs1, 1)).abs().max().dataSync()[0];
            const max_difference2 = normal_mha_output.sub(tf.concat(cached_mha_outputs2, 1)).abs().max().dataSync()[0];
            expect(max_difference1).toBeLessThan(1e-5);
            expect(max_difference2).toBeLessThan(1e-5);
        }
    });
});
43
- //# sourceMappingURL=cached_rope_multihead_attention.test.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"cached_rope_multihead_attention.test.js","sourceRoot":"","sources":["../../../src/layers/cached_rope_multihead_attention.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,gBAAgB,EAAE,MAAM,YAAY,CAAC;AAC9C,OAAO,EAAE,4BAA4B,EAAE,MAAM,0CAA0C,CAAC;AAGxF,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,oCAAoC,EAAE,GAAG,EAAE;IAEhD,IAAI,CAAC,0EAA0E,EAAE,GAAG,EAAE;QAClF,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QACjF,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QACjF,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QAChF,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QAEhF,6BAA6B;QAC7B,MAAM,CAAC,GAAG,EAAE,CAAC,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEvG,SAAS,gCAAgC,CAAC,KAAkB,EAAE,mBAA2B;YACrF,MAAM,SAAS,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YACjC,MAAM,KAAK,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAC7B,MAAM,KAAK,GAAG,CAAC,CAAC;YAEhB,MAAM,QAAQ,GAAG,IAAI,gBAAgB,CAAC,mBAAmB,CAAC,CAAC;YAE3D,MAAM,UAAU,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,KAAK,EAAE,QAAQ,EAAE,SAAS,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;YAC5G,MAAM,iBAAiB,GAAG,UAAU,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC;YAE/D,uEAAuE;YACvE,MAAM,WAAW,GAAG,IAAI,4BAA4B,CAAC,EAAE,GAAG,UAAU,CAAC,SAAS,EAAE,EAAE,IAAI,EAAE,aAAa,EAAE,CAAC,CAAC;YACzG,WAAW,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;YAC/B,WAAW,CAAC,UAAU,CAAC,UAAU,CAAC,UAAU,EAAE,CAAC,CAAC;YAEhD,MAAM,WAAW,GAAG,IAAI,4BAA4B,CAAC,EAAE,GAAG,UAAU,CAAC,SAAS,EAAE,EAAE,IAAI,EAAE,aAAa,EAAE,CAAC,CAAC;YACzG,WAAW,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;YAC/B,WAAW,CAAC,UAAU,CAAC,UAAU,CAAC,UAAU,EAAE,CAAC,CAAC;YAEhD,MAAM,mBAAmB,GAAgB,EAAE,CAAC;YAC5C,MAAM,mBAAmB,GAAgB,EAAE,CAAC;YAE5C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,
CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;gBACtC,MAAM,aAAa,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,KAAK,EAAE,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;gBAEpE,mBAAmB,CAAC,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,aAAa,EAAE,EAAE,OAAO,EAAE,QAAQ,EAAE,CAAc,CAAC,CAAC;gBAC/F,mBAAmB,CAAC,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,aAAa,EAAE,EAAE,OAAO,EAAE,QAAQ,EAAE,CAAc,CAAC,CAAC;YACnG,CAAC;YAED,MAAM,CAAC,QAAQ,CAAC,IAAI,IAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;YACvC,MAAM,CAAC,QAAQ,CAAC,IAAI,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;YAExC,MAAM,CAAC,iBAAiB,CAAC,GAAG,CAAC,EAAE,CAAC,MAAM,CAAC,mBAAmB,EAAE,CAAC,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;YACxG,MAAM,CAAC,iBAAiB,CAAC,GAAG,CAAC,EAAE,CAAC,MAAM,CAAC,mBAAmB,EAAE,CAAC,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;QAC5G,CAAC;IACL,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAC"}
@@ -1,34 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
3
- import { type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
4
- import { TransformerDecoder, type TransformerDecoderArgs } from "@/layers/transformer_decoder";
5
- export interface GPTDecoderBlockArgs extends Omit<MultiHeadAttentionArgs, "causal"> {
6
- dimsFeedForward?: number;
7
- }
8
- /**
9
- * This implements the GPT-2 transformer block by modifying the transformer
10
- * decoder block to use pre-layer-normalization and replacing ReLU activation
11
- * with GELU.
12
- *
13
- * @param numHeads number of attention heads to use
14
- * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
15
- * @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
16
- * @param dropout use dropout during the attention calculations, default `0.1`
17
- * @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
18
- * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
19
- */
20
- export declare class GPT2DecoderBlock extends TransformerDecoder {
21
- static className: string;
22
- constructor(args: TransformerDecoderArgs);
23
- /**
24
- * Attention sub-block which is similar to the original transformer except
25
- * layer normalization is applied at the beginning
26
- */
27
- protected causalSelfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
28
- /**
29
- * Feedforward sub-block which is similar to the original transformer except
30
- * layer normalization is applied at the beginning and gelu activation is used
31
- */
32
- protected feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
33
- }
34
- //# sourceMappingURL=gpt_decoder_block.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"gpt_decoder_block.d.ts","sourceRoot":"","sources":["../../../src/layers/gpt_decoder_block.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAEjE,OAAO,EAAE,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAC3E,OAAO,EAAE,kBAAkB,EAAE,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAG/F,MAAM,WAAW,mBAAoB,SAAQ,IAAI,CAAC,sBAAsB,EAAE,QAAQ,CAAC;IAC/E,eAAe,CAAC,EAAE,MAAM,CAAC;CAC5B;AAGD;;;;;;;;;;;GAWG;AACH,qBAAa,gBAAiB,SAAQ,kBAAkB;IACpD,MAAM,CAAC,SAAS,SAAsB;gBAG1B,IAAI,EAAE,sBAAsB;IAKxC;;;OAGG;cACgB,wBAAwB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAcpF;;;OAGG;cACgB,gBAAgB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;CAkB/E"}
@@ -1,51 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- import { TransformerDecoder } from "@/layers/transformer_decoder";
3
- /**
4
- * This implements the GPT-2 transformer block by modifying the transformer
5
- * decoder block to use pre-layer-normalization and replacing ReLU activation
6
- * with GELU.
7
- *
8
- * @param numHeads number of attention heads to use
9
- * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
10
- * @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
11
- * @param dropout use dropout during the attention calculations, default `0.1`
12
- * @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
13
- * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
14
- */
15
- export class GPT2DecoderBlock extends TransformerDecoder {
16
- static className = "GPT2DecoderBlock";
17
- constructor(args) {
18
- super(args);
19
- }
20
- /**
21
- * Attention sub-block which is similar to the original transformer except
22
- * layer normalization is applied at the beginning
23
- */
24
- causalSelfAttentionBlock(x, kwargs) {
25
- return tf.tidy(() => {
26
- const residual = x;
27
- let attention = this.causalSelfAttentionNorm.apply(x, kwargs);
28
- attention = this.causalSelfAttention.apply(attention, kwargs);
29
- attention = this.causalSelfAttentionDropout.apply(attention, kwargs);
30
- attention = tf.add(attention, residual);
31
- return attention;
32
- });
33
- }
34
- /**
35
- * Feedforward sub-block which is similar to the original transformer except
36
- * layer normalization is applied at the beginning and gelu activation is used
37
- */
38
- feedForwardBlock(x, kwargs) {
39
- return tf.tidy(() => {
40
- const residual = x;
41
- let feedForward = this.feedFowardNorm.apply(x, kwargs);
42
- feedForward = this.feedforward1.apply(feedForward, kwargs);
43
- feedForward = this.feedforward2.apply(feedForward, kwargs);
44
- feedForward = this.feedForwardDropout.apply(feedForward, kwargs);
45
- feedForward = tf.add(feedForward, residual);
46
- return feedForward;
47
- });
48
- }
49
- }
50
- tf.serialization.registerClass(GPT2DecoderBlock);
51
- //# sourceMappingURL=gpt_decoder_block.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"gpt_decoder_block.js","sourceRoot":"","sources":["../../../src/layers/gpt_decoder_block.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAIvC,OAAO,EAAE,kBAAkB,EAA+B,MAAM,8BAA8B,CAAC;AAQ/F;;;;;;;;;;;GAWG;AACH,MAAM,OAAO,gBAAiB,SAAQ,kBAAkB;IACpD,MAAM,CAAC,SAAS,GAAG,kBAAkB,CAAC;IAGtC,YAAY,IAA4B;QACpC,KAAK,CAAC,IAAI,CAAC,CAAC;IAChB,CAAC;IAGD;;;OAGG;IACgB,wBAAwB,CAAC,CAAY,EAAE,MAAc;QACpE,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,SAAS,GAAG,IAAI,CAAC,uBAAuB,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAc,CAAC;YAC3E,SAAS,GAAG,IAAI,CAAC,mBAAmB,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAC3E,SAAS,GAAG,IAAI,CAAC,0BAA0B,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAClF,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,QAAQ,CAAC,CAAC;YAExC,OAAO,SAAS,CAAC;QACrB,CAAC,CAAC,CAAC;IACP,CAAC;IAGD;;;OAGG;IACgB,gBAAgB,CAAC,CAAY,EAAE,MAAc;QAC5D,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,WAAW,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YACvD,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAC,CAAC;YAC3D,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAC,CAAC;YAC3D,WAAW,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAC9E,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,QAAQ,CAAC,CAAC;YAE5C,OAAO,WAAW,CAAC;QACvB,CAAC,CAAC,CAAC;IACP,CAAC;;AASL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,gBAAgB,CAAC,CAAC"}
@@ -1,17 +0,0 @@
1
- import { CachedRoPEMultiHeadAttention } from "./cached_rope_multihead_attention";
2
- import { GPT2DecoderBlock, GPTDecoderBlockArgs } from "./gpt_decoder_block";
3
- import { MultiHeadAttention, MultiHeadAttentionArgs } from "./multihead_attention";
4
- import { PositionalEncoding, PositionalEncodingArgs } from "./positional_encoding";
5
- import { RotaryPositionEmbedding, RotaryPositionEmbeddingArgs } from "./rotary_position_embedding";
6
- import { TokenAndPositionalEmbedding, TokenAndPositionalEmbeddingArgs } from "./token_and_positional_embedding";
7
- import { TransformerDecoder, TransformerDecoderArgs } from "./transformer_decoder";
8
- import { TransformerEncoder, TransformerEncoderArgs } from "./transformer_encoder";
9
- export declare function tokenAndPositionalEmbedding(args: TokenAndPositionalEmbeddingArgs): TokenAndPositionalEmbedding;
10
- export declare function transformerEncoder(args: TransformerEncoderArgs): TransformerEncoder;
11
- export declare function transformerDecoder(args: TransformerDecoderArgs): TransformerDecoder;
12
- export declare function multiheadAttention(args: MultiHeadAttentionArgs): MultiHeadAttention;
13
- export declare function cachedRopeMultiheadAttention(args: MultiHeadAttentionArgs): CachedRoPEMultiHeadAttention;
14
- export declare function positionalEncoding(args: PositionalEncodingArgs): PositionalEncoding;
15
- export declare function gpt2DecoderBlock(args: GPTDecoderBlockArgs): GPT2DecoderBlock;
16
- export declare function rotaryPositionEmbedding(args: RotaryPositionEmbeddingArgs): RotaryPositionEmbedding;
17
- //# sourceMappingURL=index.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../../src/layers/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,4BAA4B,EAAE,MAAM,mCAAmC,CAAC;AACjF,OAAO,EAAE,gBAAgB,EAAE,mBAAmB,EAAE,MAAM,qBAAqB,CAAC;AAC5E,OAAO,EAAE,kBAAkB,EAAE,sBAAsB,EAAE,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,kBAAkB,EAAE,sBAAsB,EAAE,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,uBAAuB,EAAE,2BAA2B,EAAE,MAAM,6BAA6B,CAAC;AACnG,OAAO,EAAE,2BAA2B,EAAE,+BAA+B,EAAE,MAAM,kCAAkC,CAAC;AAChH,OAAO,EAAE,kBAAkB,EAAE,sBAAsB,EAAE,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,kBAAkB,EAAE,sBAAsB,EAAE,MAAM,uBAAuB,CAAC;AAGnF,wBAAgB,2BAA2B,CAAC,IAAI,EAAE,+BAA+B,+BAEhF;AAGD,wBAAgB,kBAAkB,CAAC,IAAI,EAAE,sBAAsB,sBAE9D;AAGD,wBAAgB,kBAAkB,CAAC,IAAI,EAAE,sBAAsB,sBAE9D;AAGD,wBAAgB,kBAAkB,CAAC,IAAI,EAAE,sBAAsB,sBAE9D;AAGD,wBAAgB,4BAA4B,CAAC,IAAI,EAAE,sBAAsB,gCAExE;AAGD,wBAAgB,kBAAkB,CAAC,IAAI,EAAE,sBAAsB,sBAE9D;AAGD,wBAAgB,gBAAgB,CAAC,IAAI,EAAE,mBAAmB,oBAEzD;AAGD,wBAAgB,uBAAuB,CAAC,IAAI,EAAE,2BAA2B,2BAExE"}
@@ -1,33 +0,0 @@
1
- import { CachedRoPEMultiHeadAttention } from "./cached_rope_multihead_attention";
2
- import { GPT2DecoderBlock } from "./gpt_decoder_block";
3
- import { MultiHeadAttention } from "./multihead_attention";
4
- import { PositionalEncoding } from "./positional_encoding";
5
- import { RotaryPositionEmbedding } from "./rotary_position_embedding";
6
- import { TokenAndPositionalEmbedding } from "./token_and_positional_embedding";
7
- import { TransformerDecoder } from "./transformer_decoder";
8
- import { TransformerEncoder } from "./transformer_encoder";
9
- export function tokenAndPositionalEmbedding(args) {
10
- return new TokenAndPositionalEmbedding(args);
11
- }
12
- export function transformerEncoder(args) {
13
- return new TransformerEncoder(args);
14
- }
15
- export function transformerDecoder(args) {
16
- return new TransformerDecoder(args);
17
- }
18
- export function multiheadAttention(args) {
19
- return new MultiHeadAttention(args);
20
- }
21
- export function cachedRopeMultiheadAttention(args) {
22
- return new CachedRoPEMultiHeadAttention(args);
23
- }
24
- export function positionalEncoding(args) {
25
- return new PositionalEncoding(args);
26
- }
27
- export function gpt2DecoderBlock(args) {
28
- return new GPT2DecoderBlock(args);
29
- }
30
- export function rotaryPositionEmbedding(args) {
31
- return new RotaryPositionEmbedding(args);
32
- }
33
- //# sourceMappingURL=index.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"index.js","sourceRoot":"","sources":["../../../src/layers/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,4BAA4B,EAAE,MAAM,mCAAmC,CAAC;AACjF,OAAO,EAAE,gBAAgB,EAAuB,MAAM,qBAAqB,CAAC;AAC5E,OAAO,EAAE,kBAAkB,EAA0B,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,kBAAkB,EAA0B,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,uBAAuB,EAA+B,MAAM,6BAA6B,CAAC;AACnG,OAAO,EAAE,2BAA2B,EAAmC,MAAM,kCAAkC,CAAC;AAChH,OAAO,EAAE,kBAAkB,EAA0B,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,kBAAkB,EAA0B,MAAM,uBAAuB,CAAC;AAGnF,MAAM,UAAU,2BAA2B,CAAC,IAAqC;IAC7E,OAAO,IAAI,2BAA2B,CAAC,IAAI,CAAC,CAAC;AACjD,CAAC;AAGD,MAAM,UAAU,kBAAkB,CAAC,IAA4B;IAC3D,OAAO,IAAI,kBAAkB,CAAC,IAAI,CAAC,CAAC;AACxC,CAAC;AAGD,MAAM,UAAU,kBAAkB,CAAC,IAA4B;IAC3D,OAAO,IAAI,kBAAkB,CAAC,IAAI,CAAC,CAAC;AACxC,CAAC;AAGD,MAAM,UAAU,kBAAkB,CAAC,IAA4B;IAC3D,OAAO,IAAI,kBAAkB,CAAC,IAAI,CAAC,CAAC;AACxC,CAAC;AAGD,MAAM,UAAU,4BAA4B,CAAC,IAA4B;IACrE,OAAO,IAAI,4BAA4B,CAAC,IAAI,CAAC,CAAC;AAClD,CAAC;AAGD,MAAM,UAAU,kBAAkB,CAAC,IAA4B;IAC3D,OAAO,IAAI,kBAAkB,CAAC,IAAI,CAAC,CAAC;AACxC,CAAC;AAGD,MAAM,UAAU,gBAAgB,CAAC,IAAyB;IACtD,OAAO,IAAI,gBAAgB,CAAC,IAAI,CAAC,CAAC;AACtC,CAAC;AAGD,MAAM,UAAU,uBAAuB,CAAC,IAAiC;IACrE,OAAO,IAAI,uBAAuB,CAAC,IAAI,CAAC,CAAC;AAC7C,CAAC"}
@@ -1,106 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
3
- import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
4
- export interface MultiHeadAttentionArgs extends LayerArgs {
5
- numHeads: number;
6
- embedDim: number;
7
- useBias?: boolean;
8
- dropout?: number;
9
- causal?: boolean;
10
- }
11
- export interface ScaledDotProductionAttentionKwargs {
12
- training?: boolean;
13
- dropout?: number;
14
- causal?: boolean;
15
- scaling_factor?: number;
16
- }
17
- /**
18
- * This MultiHead Attention layer implements the algorithm as described in
19
- * the paper "Attention is all you Need" Vaswani et al., 2017.
20
- *
21
- * @param numHeads number of attention heads to use
22
- * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
23
- * @param causal use causal masking, default `false`
24
- * @param dropout use dropout during the attention calculations, default `0.0`
25
- * @param useBias use bias for the dense sublayers, default `true`
26
- *
27
- * The TensorFlow version uses tf.einsum, whose gradient op has not yet been
28
- * implemented (https://github.com/tensorflow/tfjs/pull/4955#discussion_r619219334),
29
- * therefore we follow the PyTorch implementation described in:
30
- * https://docs.pytorch.org/tutorials/intermediate/transformer_building_blocks.html#multiheadattention
31
- * https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
32
- *
33
- * This implementation is different from TensorFlow's whose attention weights
34
- * are shaped [embed dim, heads, embed dim], whereas PyTorch and OpenAI's attention weights
35
- * are shaped [embed dim, embed dim]
36
- * https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/nn/modules/activation.py#L1080
37
- * https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/model.py#L53
38
- *
39
- * TODO: implement a fast track for self attention (query = key = value)
40
- * where a single dense layer combines and replaces the query, key and projection layers
41
- *
42
- * TODO: add kDim and vDim to accept key and values whose embedding dimensions differ from query's.
43
- */
44
- export declare class MultiHeadAttention extends tf.layers.Layer {
45
- static className: string;
46
- protected readonly numHeads: number;
47
- protected readonly embedDim: number;
48
- protected readonly useBias: boolean;
49
- protected readonly dropout: number;
50
- protected readonly causal: boolean;
51
- protected readonly queryProjection: tf.layers.Layer;
52
- protected readonly keyProjection: tf.layers.Layer;
53
- protected readonly valueProjection: tf.layers.Layer;
54
- protected readonly outputProjection: tf.layers.Layer;
55
- constructor({ numHeads, embedDim, useBias, dropout, causal, ...args }: MultiHeadAttentionArgs);
56
- /**
57
- * Forward propagation. Provide one input tensor or three identical tensors for self-attention.
58
- * @param inputs a single tensor for self-attention or an array of exactly three
59
- * tensors that are either identical (self-attention) or different (cross-attention)
60
- * @param kwargs.packingMask a mask to prevent tokens from attending across document boundaries
61
- */
62
- call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs & {
63
- packingMask?: tf.Tensor;
64
- causalMask?: tf.Tensor;
65
- }): tf.Tensor | tf.Tensor[];
66
- /**
67
- * Forward propagation
68
- */
69
- protected forward(query_input: tf.Tensor, key_input: tf.Tensor, value_input: tf.Tensor, packing_mask: tf.Tensor | null, causal_mask: tf.Tensor | null, kwargs: Kwargs): tf.Tensor;
70
- protected applyInputProjections(query_input: tf.Tensor, key_input: tf.Tensor, value_input: tf.Tensor): {
71
- query: tf.Tensor;
72
- key: tf.Tensor;
73
- value: tf.Tensor;
74
- };
75
- protected splitHeads(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, shuffle: number[]): {
76
- query_split: tf.Tensor4D;
77
- key_split: tf.Tensor4D;
78
- value_split: tf.Tensor4D;
79
- };
80
- /**
81
- * Applies the scaled dot-product formula: softmax(QK_t / sqrt(d_k))V,
82
- * formula (1) of the 2017 paper Attention Is All You Need
83
- *
84
- * @param attentionMask a mask to prevent tokens from being
85
- * attended to (usually for padding tokens). It should have the shape
86
- * [batch, head, query_sequence_len, key_sequence_len]. To use in
87
- * conjunction with causal masking, the tensor should be a boolean type
88
- * where false indicates a masked token.
89
- * @param packingMask a mask to prevent tokens from attending across document boundaries
90
- */
91
- static scaledDotProductionAttention(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, attentionMask: tf.Tensor | null, packingMask: tf.Tensor | null, causalMask: tf.Tensor | null, dropout: number, causal: boolean, kwargs?: ScaledDotProductionAttentionKwargs): tf.Tensor;
92
- build(inputShape: tf.Shape | tf.Shape[]): void;
93
- /**
94
- * MultiHead attention's output is the same shape the query's.
95
- */
96
- computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[];
97
- getConfig(): {
98
- numHeads: number;
99
- embedDim: number;
100
- useBias: boolean;
101
- causal: boolean;
102
- dropout: number;
103
- name: string;
104
- };
105
- }
106
- //# sourceMappingURL=multihead_attention.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"multihead_attention.d.ts","sourceRoot":"","sources":["../../../src/layers/multihead_attention.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,SAAS,EAAE,MAAM,8CAA8C,CAAC;AAC9E,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAIjE,MAAM,WAAW,sBAAuB,SAAQ,SAAS;IACrD,QAAQ,EAAE,MAAM,CAAC;IACjB,QAAQ,EAAE,MAAM,CAAC;IACjB,OAAO,CAAC,EAAE,OAAO,CAAC;IAClB,OAAO,CAAC,EAAE,MAAM,CAAC;IACjB,MAAM,CAAC,EAAE,OAAO,CAAC;CACpB;AAGD,MAAM,WAAW,kCAAkC;IAC/C,QAAQ,CAAC,EAAE,OAAO,CAAC;IACnB,OAAO,CAAC,EAAE,MAAM,CAAC;IACjB,MAAM,CAAC,EAAE,OAAO,CAAC;IACjB,cAAc,CAAC,EAAE,MAAM,CAAC;CAC3B;AAGD;;;;;;;;;;;;;;;;;;;;;;;;;;GA0BG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IACxC,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,OAAO,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,MAAM,CAAC;IACnC,SAAS,CAAC,QAAQ,CAAC,MAAM,EAAE,OAAO,CAAC;IAInC,SAAS,CAAC,QAAQ,CAAC,eAAe,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACpD,SAAS,CAAC,QAAQ,CAAC,aAAa,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAClD,SAAS,CAAC,QAAQ,CAAC,eAAe,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACpD,SAAS,CAAC,QAAQ,CAAC,gBAAgB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;gBAGzC,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAc,EAAE,OAAa,EAAE,MAAc,EAAE,GAAG,IAAI,EAAE,EAAE,sBAAsB;IA0BlH;;;;;OAKG;IACM,IAAI,CACT,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAC/B,MAAM,EAAE,MAAM,GAAG;QACb,WAAW,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC;QACxB,UAAU,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC;KAC1B,GACF,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IA6B1B;;OAEG;IACH,SAAS,CAAC,OAAO,CACb,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,SAAS,EAAE,EAAE,CAAC,MAAM,EACpB,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,YAAY,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC9B,WAAW,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC7B,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IA+B9B,SAAS,CAAC,qBAAqB,CAAC,WAAW,EAAE,EAAE,CAAC,MAAM,EAAE,SAAS,EAAE,EAAE,CAAC,MAAM,EAAE,WAAW,EAAE,EAAE,CAAC,MAAM;eAMtC,EAAE,CAAC,MAAM;aACf,EAAE,CAAC,MAAM;eACH,EAAE,CAAC,MAAM;;IAMvE,SAAS,CAAC,UAAU,CAAC,K
AAK,EAAE,EAAE,CAAC,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC,MAAM,EAAE,OAAO,EAAE,MAAM,EAAE;qBAShB,EAAE,CAAC,QAAQ;mBACf,EAAE,CAAC,QAAQ;qBACP,EAAE,CAAC,QAAQ;;IAMrF;;;;;;;;;;OAUG;IACH,MAAM,CAAC,4BAA4B,CAC/B,KAAK,EAAE,EAAE,CAAC,MAAM,EAChB,GAAG,EAAE,EAAE,CAAC,MAAM,EACd,KAAK,EAAE,EAAE,CAAC,MAAM,EAChB,aAAa,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC/B,WAAW,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC7B,UAAU,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC5B,OAAO,EAAE,MAAM,EACf,MAAM,EAAE,OAAO,EACf,MAAM,GAAE,kCAAuC,GAChD,EAAE,CAAC,MAAM;IA0EH,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IA4CvD;;OAEG;IACM,kBAAkB,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;IAK5E,SAAS;;;;;;;;CAgBrB"}