@stellarapp/tfjs-stellar 1.0.4 → 1.0.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +17 -0
- package/dist/index.d.ts +2 -1
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +2 -1
- package/dist/index.js.map +1 -1
- package/dist/kv_cache.d.ts +2 -0
- package/dist/kv_cache.d.ts.map +1 -1
- package/dist/kv_cache.js +6 -0
- package/dist/kv_cache.js.map +1 -1
- package/dist/masks.test.d.ts +2 -0
- package/dist/masks.test.d.ts.map +1 -0
- package/dist/masks.test.js +55 -0
- package/dist/masks.test.js.map +1 -0
- package/dist/models/index.d.ts +2 -1
- package/dist/models/index.d.ts.map +1 -1
- package/dist/models/index.js +2 -1
- package/dist/models/index.js.map +1 -1
- package/dist/utils.test.js +0 -15
- package/dist/utils.test.js.map +1 -1
- package/package.json +1 -1
- package/dist/jest.config.d.ts +0 -8
- package/dist/jest.config.d.ts.map +0 -1
- package/dist/jest.config.js +0 -147
- package/dist/jest.config.js.map +0 -1
- package/dist/src/index.d.ts +0 -6
- package/dist/src/index.d.ts.map +0 -1
- package/dist/src/index.js +0 -6
- package/dist/src/index.js.map +0 -1
- package/dist/src/kv_cache.d.ts +0 -53
- package/dist/src/kv_cache.d.ts.map +0 -1
- package/dist/src/kv_cache.js +0 -135
- package/dist/src/kv_cache.js.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.d.ts +0 -31
- package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.js +0 -76
- package/dist/src/layers/cached_rope_multihead_attention.js.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +0 -2
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.test.js +0 -43
- package/dist/src/layers/cached_rope_multihead_attention.test.js.map +0 -1
- package/dist/src/layers/gpt_decoder_block.d.ts +0 -34
- package/dist/src/layers/gpt_decoder_block.d.ts.map +0 -1
- package/dist/src/layers/gpt_decoder_block.js +0 -51
- package/dist/src/layers/gpt_decoder_block.js.map +0 -1
- package/dist/src/layers/index.d.ts +0 -17
- package/dist/src/layers/index.d.ts.map +0 -1
- package/dist/src/layers/index.js +0 -33
- package/dist/src/layers/index.js.map +0 -1
- package/dist/src/layers/multihead_attention.d.ts +0 -106
- package/dist/src/layers/multihead_attention.d.ts.map +0 -1
- package/dist/src/layers/multihead_attention.js +0 -269
- package/dist/src/layers/multihead_attention.js.map +0 -1
- package/dist/src/layers/multihead_attention.test.d.ts +0 -2
- package/dist/src/layers/multihead_attention.test.d.ts.map +0 -1
- package/dist/src/layers/multihead_attention.test.js +0 -160
- package/dist/src/layers/multihead_attention.test.js.map +0 -1
- package/dist/src/layers/positional_encoding.d.ts +0 -37
- package/dist/src/layers/positional_encoding.d.ts.map +0 -1
- package/dist/src/layers/positional_encoding.js +0 -115
- package/dist/src/layers/positional_encoding.js.map +0 -1
- package/dist/src/layers/positional_encoding.test.d.ts +0 -2
- package/dist/src/layers/positional_encoding.test.d.ts.map +0 -1
- package/dist/src/layers/positional_encoding.test.js +0 -95
- package/dist/src/layers/positional_encoding.test.js.map +0 -1
- package/dist/src/layers/rotary_position_embedding.d.ts +0 -39
- package/dist/src/layers/rotary_position_embedding.d.ts.map +0 -1
- package/dist/src/layers/rotary_position_embedding.js +0 -99
- package/dist/src/layers/rotary_position_embedding.js.map +0 -1
- package/dist/src/layers/rotary_position_embedding.test.d.ts +0 -2
- package/dist/src/layers/rotary_position_embedding.test.d.ts.map +0 -1
- package/dist/src/layers/rotary_position_embedding.test.js +0 -88
- package/dist/src/layers/rotary_position_embedding.test.js.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.d.ts +0 -47
- package/dist/src/layers/token_and_positional_embedding.d.ts.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.js +0 -109
- package/dist/src/layers/token_and_positional_embedding.js.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.test.d.ts +0 -2
- package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.test.js +0 -58
- package/dist/src/layers/token_and_positional_embedding.test.js.map +0 -1
- package/dist/src/layers/transformer_decoder.d.ts +0 -69
- package/dist/src/layers/transformer_decoder.d.ts.map +0 -1
- package/dist/src/layers/transformer_decoder.js +0 -182
- package/dist/src/layers/transformer_decoder.js.map +0 -1
- package/dist/src/layers/transformer_decoder.test.d.ts +0 -2
- package/dist/src/layers/transformer_decoder.test.d.ts.map +0 -1
- package/dist/src/layers/transformer_decoder.test.js +0 -72
- package/dist/src/layers/transformer_decoder.test.js.map +0 -1
- package/dist/src/layers/transformer_encoder.d.ts +0 -55
- package/dist/src/layers/transformer_encoder.d.ts.map +0 -1
- package/dist/src/layers/transformer_encoder.js +0 -175
- package/dist/src/layers/transformer_encoder.js.map +0 -1
- package/dist/src/layers/transformer_encoder.test.d.ts +0 -2
- package/dist/src/layers/transformer_encoder.test.d.ts.map +0 -1
- package/dist/src/layers/transformer_encoder.test.js +0 -58
- package/dist/src/layers/transformer_encoder.test.js.map +0 -1
- package/dist/src/losses/dice.d.ts +0 -30
- package/dist/src/losses/dice.d.ts.map +0 -1
- package/dist/src/losses/dice.js +0 -93
- package/dist/src/losses/dice.js.map +0 -1
- package/dist/src/losses/index.d.ts +0 -2
- package/dist/src/losses/index.d.ts.map +0 -1
- package/dist/src/losses/index.js +0 -2
- package/dist/src/losses/index.js.map +0 -1
- package/dist/src/masks.d.ts +0 -20
- package/dist/src/masks.d.ts.map +0 -1
- package/dist/src/masks.js +0 -37
- package/dist/src/masks.js.map +0 -1
- package/dist/src/metrics.d.ts +0 -20
- package/dist/src/metrics.d.ts.map +0 -1
- package/dist/src/metrics.js +0 -28
- package/dist/src/metrics.js.map +0 -1
- package/dist/src/models/gpt_model.d.ts +0 -94
- package/dist/src/models/gpt_model.d.ts.map +0 -1
- package/dist/src/models/gpt_model.js +0 -154
- package/dist/src/models/gpt_model.js.map +0 -1
- package/dist/src/models/index.d.ts +0 -3
- package/dist/src/models/index.d.ts.map +0 -1
- package/dist/src/models/index.js +0 -3
- package/dist/src/models/index.js.map +0 -1
- package/dist/src/models/llm_model.d.ts +0 -87
- package/dist/src/models/llm_model.d.ts.map +0 -1
- package/dist/src/models/llm_model.js +0 -245
- package/dist/src/models/llm_model.js.map +0 -1
- package/dist/src/models/u_net.d.ts +0 -40
- package/dist/src/models/u_net.d.ts.map +0 -1
- package/dist/src/models/u_net.js +0 -151
- package/dist/src/models/u_net.js.map +0 -1
- package/dist/src/tfjs_types.d.ts +0 -10
- package/dist/src/tfjs_types.d.ts.map +0 -1
- package/dist/src/tfjs_types.js +0 -2
- package/dist/src/tfjs_types.js.map +0 -1
- package/dist/src/utils.d.ts +0 -28
- package/dist/src/utils.d.ts.map +0 -1
- package/dist/src/utils.js +0 -63
- package/dist/src/utils.js.map +0 -1
- package/dist/src/utils.test.d.ts +0 -2
- package/dist/src/utils.test.d.ts.map +0 -1
- package/dist/src/utils.test.js +0 -73
- package/dist/src/utils.test.js.map +0 -1
package/dist/src/kv_cache.js
DELETED
|
@@ -1,135 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
/**
 * A container for KV caches. A model should initialize one container and pass
 * it to its attention layers; each layer keeps its own `KvCache` stored under a
 * unique id (for example, the layer name). Every cache created through the
 * container shares the container's maximum sequence length.
 */
export class KvCacheContainer {
    caches = new Map();
    max_sequence_length;
    /**
     * @param {number} maxSequenceLength maximum number of tokens each cache can
     *   hold; must be a positive integer.
     * @throws {Error} if `maxSequenceLength` is not a positive integer.
     */
    constructor(maxSequenceLength) {
        // a zero, negative, fractional or non-numeric length cannot describe a
        // valid cache shape, so reject it up front instead of failing later
        if (!Number.isInteger(maxSequenceLength) || maxSequenceLength <= 0) {
            throw new Error(`KvCacheContainer: expected KV cache maximum sequence length to be an integer greater than 0, got: ${String(maxSequenceLength)}`);
        }
        this.max_sequence_length = maxSequenceLength;
    }
    /**
     * Creates the cache stored under `id`, replacing (and disposing) any cache
     * previously stored under the same id.
     *
     * @param {string} id identifier of the cache, e.g. a layer name.
     * @param {object} args `KvCache` constructor arguments except
     *   `maxSequenceLength`, which is taken from the container.
     */
    create(id, args) {
        // build the replacement first so a failing constructor keeps the old cache intact
        const new_cache = new KvCache({
            ...args,
            maxSequenceLength: this.max_sequence_length
        });
        // release the variables of a cache being replaced so they do not leak
        this.caches.get(id)?.dispose();
        this.caches.set(id, new_cache);
    }
    /**
     * Appends `key`/`value` to the cache stored under `id`.
     *
     * The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
     *
     * @returns {{keyCache: tf.Tensor, valueCache: tf.Tensor} | undefined} the
     *   filled part of the caches (all tokens stored so far), or `undefined` when
     *   no cache exists for `id`.
     */
    update(id, key, value) {
        const kv_cache = this.caches.get(id);
        if (!kv_cache) {
            return undefined;
        }
        const { keyCache, valueCache } = kv_cache.update(key, value);
        const filled = kv_cache.size;
        // slicing to get only the past key and value projections, but normally
        // in TensorFlow and PyTorch the full cache is returned and masked for
        // graph purposes
        return tf.tidy(() => ({
            keyCache: keyCache.slice([0, 0, 0, 0], [keyCache.shape[0], keyCache.shape[1], filled, keyCache.shape[3]]),
            valueCache: valueCache.slice([0, 0, 0, 0], [valueCache.shape[0], valueCache.shape[1], filled, valueCache.shape[3]])
        }));
    }
    /**
     * Empties every cache (position back to 0, contents zeroed) while keeping
     * the caches allocated.
     */
    reset() {
        this.caches.forEach(cache => {
            cache.reset();
        });
    }
    /**
     * Disposes every cache and forgets it, so a later `update` reports a missing
     * cache (allowing callers to recreate it) instead of touching freed variables.
     */
    dispose() {
        this.caches.forEach(cache => {
            cache.dispose();
        });
        this.caches.clear();
    }
    /**
     * Number of tokens stored, or 0 when the container holds no cache.
     */
    get size() {
        // the size of all KV caches are expected to be the same, just use the first one
        const first_cache = this.caches.values().next().value;
        return first_cache?.size ?? 0;
    }
    get maxSequenceLength() {
        return this.max_sequence_length;
    }
}
|
|
60
|
-
/**
 * A fixed-size key/value cache for a single attention layer.
 *
 * Keys and values live in two non-trainable, zero-initialized variables of
 * shape `[batchSize, numHeads, maxSequenceLength, headDim]`. Each `update`
 * writes the new entries at the current position and advances it.
 */
export class KvCache {
    key_cache;
    value_cache;
    // the size of the KV cache, represents the number of tokens since the first chat token
    current_position = 0;
    batch_size;
    max_sequence_length;
    num_kv_heads;
    head_dim;
    /**
     * @param {object} args
     * @param {number} args.batchSize batch size the cache is allocated for.
     * @param {number} args.maxSequenceLength maximum number of tokens stored.
     * @param {number} args.numHeads number of key/value heads.
     * @param {number} args.headDim size of each head.
     * @param {string} [args.dtype="float32"] dtype of the cache variables.
     */
    constructor({ batchSize, maxSequenceLength, numHeads, headDim, dtype = "float32" }) {
        const cache_shape = [batchSize, numHeads, maxSequenceLength, headDim];
        // the variable keeps its own reference to the initial value's data, so the
        // temporary zeros tensor is created inside `tf.tidy` and released there
        // (`tf.tidy` never disposes variables)
        this.key_cache = tf.tidy(() => tf.variable(tf.zeros(cache_shape, dtype), false));
        this.value_cache = tf.tidy(() => tf.variable(tf.zeros(cache_shape, dtype), false));
        this.batch_size = batchSize;
        this.max_sequence_length = maxSequenceLength;
        this.num_kv_heads = numHeads;
        this.head_dim = headDim;
    }
    /**
     * Writes `key`/`value` at the current position and advances the position.
     *
     * The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
     *
     * @returns {{keyCache: tf.Variable, valueCache: tf.Variable}} the full cache variables.
     * @throws {Error} if the batch size differs from the cache's, or if the new
     *   tokens would exceed the maximum sequence length.
     */
    update(key, value) {
        const [batch_size, , seq_len] = key.shape;
        const cache_batch_size = this.key_cache.shape[0];
        // the merge concatenates full-batch slices of the cache with the new
        // tensors, so any other batch size (smaller included) cannot be merged
        if (batch_size !== cache_batch_size) {
            throw new Error(`The current KV cache has been set up with a batch size of` +
                ` ${cache_batch_size}, but found new key tensors with batch size ${batch_size}`);
        }
        if (this.current_position + seq_len > this.max_sequence_length) {
            throw new Error(`The KV cache has exceeded its maximum sequence length of ${this.max_sequence_length}. Use a larger value.`);
        }
        const new_key_cache = this.mergeIntoCache(key, this.key_cache);
        const new_value_cache = this.mergeIntoCache(value, this.value_cache);
        this.key_cache.assign(new_key_cache);
        this.value_cache.assign(new_value_cache);
        new_key_cache.dispose();
        new_value_cache.dispose();
        // advance the pointer to reflect the updated cache's current
        this.current_position += seq_len;
        return {
            keyCache: this.key_cache,
            valueCache: this.value_cache,
        };
    }
    /**
     * Returns a new tensor equal to `current_cache` with `new_value` written
     * along axis 2 starting at the current position.
     */
    mergeIntoCache(new_value, current_cache) {
        const seq_len = new_value.shape[2];
        return tf.tidy(() => {
            const historical = current_cache.slice([0, 0, 0, 0], [this.batch_size, this.num_kv_heads, this.current_position, this.head_dim]);
            const future = current_cache.slice([0, 0, this.current_position + seq_len, 0], [this.batch_size, this.num_kv_heads, this.max_sequence_length - this.current_position - seq_len, this.head_dim]);
            // merge the new tensor into the current cache to create a new, larger, cache,
            // this is different from Python implementations because TFJS tensors are immutable,
            // because we cannot update a slice, we must slice and concat
            return tf.concat([historical, new_value, future], 2);
        });
    }
    /**
     * Moves the position back to 0 and zeroes both caches, keeping their
     * shape and dtype.
     */
    reset() {
        this.current_position = 0;
        tf.tidy(() => {
            // pass the cache dtype: `assign` rejects a value whose dtype differs,
            // and `tf.zeros` would otherwise default to float32
            this.key_cache.assign(tf.zeros(this.key_cache.shape, this.key_cache.dtype));
            this.value_cache.assign(tf.zeros(this.value_cache.shape, this.value_cache.dtype));
        });
    }
    dispose() {
        this.key_cache.dispose();
        this.value_cache.dispose();
    }
    /**
     * The size of the KV cache, also the number of tokens since the first one.
     */
    get size() {
        return this.current_position;
    }
}
|
|
135
|
-
//# sourceMappingURL=kv_cache.js.map
|
package/dist/src/kv_cache.js.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"kv_cache.js","sourceRoot":"","sources":["../../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAYvC;;GAEG;AACH,MAAM,OAAO,gBAAgB;IACf,MAAM,GAAG,IAAI,GAAG,EAAmB,CAAC;IACpC,mBAAmB,CAAS;IAGtC,YAAY,iBAAyB;QACjC,IAAI,CAAC,iBAAiB,EAAE,CAAC;YACrB,MAAM,KAAK,CAAC,0FAA0F,MAAM,CAAC,iBAAiB,CAAC,EAAE,CAAC,CAAC;QACvI,CAAC;QAED,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;IACjD,CAAC;IAGM,MAAM,CAAC,EAAU,EAAE,IAA4C;QAClE,MAAM,SAAS,GAAG,IAAI,OAAO,CAAC;YAC1B,GAAG,IAAI;YACP,iBAAiB,EAAE,IAAI,CAAC,mBAAmB;SAC9C,CAAC,CAAC;QAEH,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,EAAE,SAAS,CAAC,CAAC;IACnC,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,EAAU,EAAE,GAAgB,EAAE,KAAkB;QAC1D,MAAM,QAAQ,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;QAErC,IAAI,CAAC,QAAQ,EAAE,CAAC;YACZ,OAAO,SAAS,CAAC;QACrB,CAAC;QAED,MAAM,EAAE,QAAQ,EAAE,UAAU,EAAE,GAAG,QAAQ,CAAC,MAAM,CAAC,GAAG,EAAE,KAAK,CAAC,CAAC;QAE7D,uEAAuE;QACvE,sEAAsE;QACtE,iBAAiB;QACjB,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,OAAO,GAAG,QAAQ,CAAC,KAAK,CAC1B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAC9E,MAAM,OAAO,GAAG,UAAU,CAAC,KAAK,CAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAEpF,OAAO;gBACH,QAAQ,EAAE,OAAO;gBACjB,UAAU,EAAE,OAAO;aACtB,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,KAAK,EAAE,CAAC;QAClB,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,OAAO,EAAE,CAAC;QACpB,CAAC,CAAC,CAAA;IACN,CAAC;IAGD,IAAW,IAAI;QACX,gFAAgF;QAChF,OAAO,IAAI,CAAC,MAAM,CAAC,OAAO,EAAE,CAAC,IAAI,EAAE,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,IAAI,CAAC,CAAC;IAC7D,CAAC;IAGD,IAAW,iBAAiB;QACxB,OAAO,IAAI,CAAC,mBAAmB,CAAC;IACpC,CAAC;CACJ;AAGD,MAAM,OAAO,OAAO;I
AEN,SAAS,CAA0B;IACnC,WAAW,CAAyB;IAE9C,uFAAuF;IAC7E,gBAAgB,GAAW,CAAC,CAAC;IAE7B,UAAU,CAAS;IACnB,mBAAmB,CAAS;IAC5B,YAAY,CAAS;IACrB,QAAQ,CAAS;IAE3B,YAAY,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAK,GAAG,SAAS,EAAe;QAC3F,MAAM,WAAW,GAAG,CAAC,SAAS,EAAE,QAAQ,EAAE,iBAAiB,EAAE,OAAO,CAAqC,CAAC;QAE1G,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAClE,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAEpE,IAAI,CAAC,UAAU,GAAG,SAAS,CAAC;QAC5B,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;QAC7C,IAAI,CAAC,YAAY,GAAG,QAAQ,CAAC;QAC7B,IAAI,CAAC,QAAQ,GAAG,OAAO,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,GAAgB,EAAE,KAAkB;QAC9C,MAAM,UAAU,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAChC,MAAM,OAAO,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAE7B,IAAI,UAAU,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC;YACvC,MAAM,KAAK,CAAC,2DAA2D;gBACnE,IAAI,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,+CAA+C,UAAU,EAAE,CAAC,CAAA;QAC/F,CAAC;QAED,IAAI,IAAI,CAAC,gBAAgB,GAAG,OAAO,GAAG,IAAI,CAAC,mBAAmB,EAAE,CAAC;YAC7D,MAAM,KAAK,CAAC,4DAA4D,IAAI,CAAC,mBAAmB,uBAAuB,CAAC,CAAC;QAC7H,CAAC;QAED,MAAM,aAAa,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;QAC/D,MAAM,eAAe,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC;QAErE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,aAAa,CAAC,CAAC;QACrC,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,eAAe,CAAC,CAAC;QAEzC,aAAa,CAAC,OAAO,EAAE,CAAC;QACxB,eAAe,CAAC,OAAO,EAAE,CAAC;QAE1B,6DAA6D;QAC7D,IAAI,CAAC,gBAAgB,IAAI,OAAO,CAAC;QAEjC,OAAO;YACH,QAAQ,EAAE,IAAI,CAAC,SAAS;YACxB,UAAU,EAAE,IAAI,CAAC,WAAW;SAC/B,CAAA;IACL,CAAC;IAGS,cAAc,CAAC,SAAsB,EAAE,aAA0B;QACvE,MAAM,OAAO,GAAG,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAEnC,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAEhB,MAAM,UAAU,GAAG,aAAa,CAAC,KAAK,CAClC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,gBAAgB,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAEhF,MAAM,MAAM,GAAG,aAAa,CAAC,KAAK,CAC9B,CAAC,CAAC,EAAE,CAAC,EAAE,IAAI,CAA
C,gBAAgB,GAAG,OAAO,EAAE,CAAC,CAAC,EAC1C,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,gBAAgB,GAAG,OAAO,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAErH,8EAA8E;YAC9E,qFAAqF;YACrF,6DAA6D;YAC7D,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,UAAU,EAAE,SAAS,EAAE,MAAM,CAAC,EAAE,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,gBAAgB,GAAG,CAAC,CAAC;QAE1B,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YACT,MAAM,eAAe,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC;YAC7C,MAAM,iBAAiB,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC;YAEjD,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,eAAe,CAAC,CAAC,CAAC;YACjD,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAC;IACP,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,SAAS,CAAC,OAAO,EAAE,CAAC;QACzB,IAAI,CAAC,WAAW,CAAC,OAAO,EAAE,CAAC;IAC/B,CAAC;IAGD;;OAEG;IACH,IAAI,IAAI;QACJ,OAAO,IAAI,CAAC,gBAAgB,CAAC;IACjC,CAAC;CAEJ"}
|
|
@@ -1,31 +0,0 @@
|
|
|
1
|
-
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import { KvCacheContainer } from "@/kv_cache";
|
|
3
|
-
import { MultiHeadAttention, type MultiHeadAttentionArgs } from '@/layers/multihead_attention';
|
|
4
|
-
import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
|
|
5
|
-
/**
 * MultiHeadAttention with RoPE and KV caching. If using KV caching, this layer
 * should be used in a custom training loop because it requires the cache to be
 * passed through the `kwargs.kvCache` argument during the `layer.apply()`
 * forward propagation.
 *
 * If a KV cache is not provided, then this layer operates as MultiHeadAttention with RoPE.
 */
export declare class CachedRoPEMultiHeadAttention extends MultiHeadAttention {
    /** Serialization class name of the layer. */
    static className: string;
    /** Rotary position embedding layer applied to the query and key heads. */
    protected rope: tf.layers.Layer;
    constructor(args: MultiHeadAttentionArgs);
    /**
     * Attention forward pass. When `kwargs.training` is not `true` and
     * `kwargs.kvCache` is set, the new keys/values are merged into that cache and
     * attention runs over the cached keys/values.
     */
    protected forward(query_input: tf.Tensor, key_input: tf.Tensor, value_input: tf.Tensor, packing_mask: tf.Tensor | null, causal_mask: tf.Tensor | null, kwargs: Kwargs): tf.Tensor;
    /**
     * Updates this layer's cache in `kv_container` (creating it on first use)
     * with the split key/value heads and returns the cached keys/values.
     *
     * NOTE(review): the implementation returns slices produced by
     * `KvCacheContainer.update`, which are plain tensors rather than
     * `tf.Variable`s; confirm the declared return type.
     */
    protected getCachedKV(kv_container: KvCacheContainer, key_split: tf.Tensor4D, value_split: tf.Tensor4D): {
        keyCache: tf.Variable<tf.Rank.R4>;
        valueCache: tf.Variable<tf.Rank.R4>;
    };
    /**
     * Adds RoPE position encoding right after splitting heads.
     *
     * RoPE is applied to the query and key only; the value heads are split
     * without rotation. `shuffle` is the permutation applied after the split.
     */
    protected splitHeads(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, shuffle: number[]): {
        query_split: tf.Tensor4D;
        key_split: tf.Tensor4D;
        value_split: tf.Tensor4D;
    };
}
|
|
31
|
-
//# sourceMappingURL=cached_rope_multihead_attention.d.ts.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"cached_rope_multihead_attention.d.ts","sourceRoot":"","sources":["../../../src/layers/cached_rope_multihead_attention.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,gBAAgB,EAAE,MAAM,YAAY,CAAC;AAC9C,OAAO,EAAE,kBAAkB,EAAE,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAE/F,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAGjE;;;;;;;GAOG;AACH,qBAAa,4BAA6B,SAAQ,kBAAkB;IAChE,MAAM,CAAC,SAAS,SAAkC;IAElD,SAAS,CAAC,IAAI,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;gBAEpB,IAAI,EAAE,sBAAsB;cAMrB,OAAO,CACtB,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,SAAS,EAAE,EAAE,CAAC,MAAM,EACpB,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,YAAY,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC9B,WAAW,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC7B,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAuC9B,SAAS,CAAC,WAAW,CAAC,YAAY,EAAE,gBAAgB,EAAE,SAAS,EAAE,EAAE,CAAC,QAAQ,EAAE,WAAW,EAAE,EAAE,CAAC,QAAQ;;;;IAqBtG;;OAEG;cACgB,UAAU,CAAC,KAAK,EAAE,EAAE,CAAC,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC,MAAM,EAAE,OAAO,EAAE,MAAM,EAAE;qBAO5D,EAAE,CAAC,QAAQ;mBAEX,EAAE,CAAC,QAAQ;qBACwB,EAAE,CAAC,QAAQ;;CAIxF"}
|
|
@@ -1,76 +0,0 @@
|
|
|
1
|
-
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import { MultiHeadAttention } from '@/layers/multihead_attention';
|
|
3
|
-
import { RotaryPositionEmbedding } from '@/layers/rotary_position_embedding';
|
|
4
|
-
/**
 * MultiHeadAttention with RoPE and KV caching. If using KV caching, this layer
 * should be used in a custom training loop because it requires the cache to be
 * passed through the `kwargs.kvCache` argument during the `layer.apply()`
 * forward propagation.
 *
 * If a KV cache is not provided, then this layer operates as MultiHeadAttention with RoPE.
 */
export class CachedRoPEMultiHeadAttention extends MultiHeadAttention {
    static className = "CachedRoPEMultiHeadAttention";
    rope;
    constructor(args) {
        super(args);
        this.rope = new RotaryPositionEmbedding({ dim: Math.floor(this.embedDim / this.numHeads) });
    }
    /**
     * Projects the inputs, splits heads (applying RoPE to query and key),
     * optionally merges the new key/value into the KV cache, then applies
     * attention and the output projection.
     *
     * NOTE(review): during cached decoding, one token at a time, RoPE is applied
     * to each new chunk on its own. Confirm that `RotaryPositionEmbedding` accounts
     * for the number of tokens already cached. Otherwise new tokens may be
     * rotated as if they were at the start of the sequence.
     */
    forward(query_input, key_input, value_input, packing_mask, causal_mask, kwargs) {
        return tf.tidy(() => {
            const { query, key, value } = this.applyInputProjections(query_input, key_input, value_input);
            // swap the seq and heads dimensions: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
            const move_head_dim_forward = [0, 2, 1, 3];
            const split = this.splitHeads(query, key, value, move_head_dim_forward);
            const query_split = split.query_split;
            let key_split = split.key_split;
            let value_split = split.value_split;
            if (kwargs.training !== true && kwargs.kvCache) {
                // runs on inference, updates the KV cache and get the historical key and value
                const cached_kv = this.getCachedKV(kwargs.kvCache, key_split, value_split);
                key_split = cached_kv.keyCache;
                value_split = cached_kv.valueCache;
            }
            const spda = MultiHeadAttention.scaledDotProductionAttention(query_split, key_split, value_split, kwargs.attentionMask ?? null, packing_mask, causal_mask, this.dropout, this.causal, kwargs);
            // the attention output (presumably [batch, heads, seq, head_dim]) is transposed
            // back to [batch, seq, heads, head_dim], the heads are merged into
            // [batch, seq, embedDim], and the output projection is applied
            const output = this.outputProjection.apply(spda.transpose(move_head_dim_forward).reshape([query_input.shape[0], query_input.shape[1], this.embedDim]));
            return output;
        });
    }
    /**
     * Writes the split key/value heads into this layer's cache (keyed by the
     * layer name) and returns the cached keys/values. The cache is created on
     * the first call.
     *
     * @throws {Error} wrapping any cache error; the original error is kept as `cause`.
     */
    getCachedKV(kv_container, key_split, value_split) {
        try {
            const existing = kv_container.update(this.name, key_split, value_split);
            if (existing) {
                return existing;
            }
            // first call for this layer: allocate its cache, then store the first entries
            kv_container.create(this.name, {
                batchSize: key_split.shape[0],
                numHeads: this.numHeads,
                headDim: this.embedDim / this.numHeads,
            });
            return kv_container.update(this.name, key_split, value_split);
        }
        catch (error) {
            // `String` also handles non-Error throwables; `cause` keeps the original stack
            throw new Error(`${this.getClassName()}::getCachedKV ${this.name} ${String(error)}`, { cause: error });
        }
    }
    /**
     * Adds RoPE position encoding right after splitting heads.
     *
     * Inputs are reshaped to [batch, seq, numHeads, embedDim / numHeads], RoPE
     * is applied to the query and key only, and all three are then permuted
     * by `shuffle`.
     */
    splitHeads(query, key, value, shuffle) {
        const batch_size = query.shape[0];
        const split_heads = [batch_size, -1, this.numHeads, this.embedDim / this.numHeads];
        return tf.tidy(() => {
            return {
                query_split: this.rope.apply(query.reshape(split_heads))
                    .transpose(shuffle),
                key_split: this.rope.apply(key.reshape(split_heads))
                    .transpose(shuffle),
                value_split: value.reshape(split_heads).transpose(shuffle)
            };
        });
    }
}
tf.serialization.registerClass(CachedRoPEMultiHeadAttention);
|
|
76
|
-
//# sourceMappingURL=cached_rope_multihead_attention.js.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"cached_rope_multihead_attention.js","sourceRoot":"","sources":["../../../src/layers/cached_rope_multihead_attention.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,kBAAkB,EAA+B,MAAM,8BAA8B,CAAC;AAC/F,OAAO,EAAE,uBAAuB,EAAE,MAAM,oCAAoC,CAAC;AAI7E;;;;;;;GAOG;AACH,MAAM,OAAO,4BAA6B,SAAQ,kBAAkB;IAChE,MAAM,CAAC,SAAS,GAAG,8BAA8B,CAAC;IAExC,IAAI,CAAkB;IAEhC,YAAY,IAA4B;QACpC,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,IAAI,CAAC,IAAI,GAAG,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC;IAChG,CAAC;IAGkB,OAAO,CACtB,WAAsB,EACtB,SAAoB,EACpB,WAAsB,EACtB,YAA8B,EAC9B,WAA6B,EAC7B,MAAc;QAEd,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,EAAE,KAAK,EAAE,GAAG,EAAE,KAAK,EAAE,GAAG,IAAI,CAAC,qBAAqB,CAAC,WAAW,EAAE,SAAS,EAAE,WAAW,CAAC,CAAC;YAE9F,oGAAoG;YACpG,MAAM,qBAAqB,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;YAE3C,MAAM,KAAK,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,EAAE,qBAAqB,CAAC,CAAC;YAExE,MAAM,WAAW,GAAG,KAAK,CAAC,WAAW,CAAC;YACtC,IAAI,SAAS,GAAG,KAAK,CAAC,SAAS,CAAC;YAChC,IAAI,WAAW,GAAG,KAAK,CAAC,WAAW,CAAC;YAEpC,IAAI,MAAM,CAAC,QAAQ,KAAK,IAAI,IAAI,MAAM,CAAC,OAAO,EAAE,CAAC;gBAC7C,+EAA+E;gBAC/E,MAAM,SAAS,GAAG,IAAI,CAAC,WAAW,CAC9B,MAAM,CAAC,OAA2B,EAAE,SAAS,EAAE,WAAW,CAAC,CAAC;gBAEhE,SAAS,GAAG,SAAS,CAAC,QAAQ,CAAC;gBAC/B,WAAW,GAAG,SAAS,CAAC,UAAU,CAAC;YACvC,CAAC;YAED,gFAAgF;YAChF,MAAM,IAAI,GAAG,kBAAkB,CAAC,4BAA4B,CACxD,WAAW,EAAE,SAAS,EAAE,WAAW,EACnC,MAAM,CAAC,aAAa,IAAI,IAAI,EAAE,YAAY,EAAE,WAAW,EACvD,IAAI,CAAC,OAAO,EAAE,IAAI,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAEvC,+CAA+C;YAC/C,MAAM,MAAM,GAAG,IAAI,CAAC,gBAAgB,CAAC,KAAK,CACtC,IAAI,CAAC,SAAS,CAAC,qBAAqB,CAAC,CAAC,OAAO,CACzC,CAAC,WAAW,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,WAAW,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC;YAEvE,OAAO,MAAmB,CAAC;QAC/B,CAAC,CAAC,CAAA;IACN,CAAC;IAGS,WAAW,CAAC,YAA8B,EAAE,SAAsB,EAAE,WAAwB;QAClG,IAAI,CAAC;YACD,IAAI,QAAQ,GAAG,YAAY,CAAC,MAAM,CAAC,IAAI,CAAC,IAAI,EAAE,SAAS,EAAE,WAAW,CAAC,CAAC;YAEtE,IAAI,C
AAC,QAAQ,EAAE,CAAC;gBACZ,YAAY,CAAC,MAAM,CAAC,IAAI,CAAC,IAAI,EAAE;oBAC3B,SAAS,EAAE,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC;oBAC7B,QAAQ,EAAE,IAAI,CAAC,QAAQ;oBACvB,OAAO,EAAE,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ;iBACzC,CAAC,CAAA;gBAEF,QAAQ,GAAG,YAAY,CAAC,MAAM,CAAC,IAAI,CAAC,IAAI,EAAE,SAAS,EAAE,WAAW,CAAE,CAAC;YACvE,CAAC;YAED,OAAO,QAAS,CAAC;QACrB,CAAC;QAAC,OAAO,KAAU,EAAE,CAAC;YAClB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,IAAI,KAAK,CAAC,QAAQ,EAAE,EAAE,CAAC,CAAC;QACxF,CAAC;IACL,CAAC;IAGD;;OAEG;IACgB,UAAU,CAAC,KAAgB,EAAE,GAAc,EAAE,KAAgB,EAAE,OAAiB;QAC/F,MAAM,UAAU,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAClC,MAAM,WAAW,GAAG,CAAC,UAAU,EAAE,CAAC,CAAC,EAAE,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC;QAEnF,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,OAAO;gBACH,WAAW,EAAG,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAe;qBAClE,SAAS,CAAC,OAAO,CAAgB;gBACtC,SAAS,EAAG,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,WAAW,CAAC,CAAe;qBAC9D,SAAS,CAAC,OAAO,CAAgB;gBACtC,WAAW,EAAE,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,SAAS,CAAC,OAAO,CAAgB;aAC5E,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,4BAA4B,CAAC,CAAC"}
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"cached_rope_multihead_attention.test.d.ts","sourceRoot":"","sources":["../../../src/layers/cached_rope_multihead_attention.test.ts"],"names":[],"mappings":""}
|
|
@@ -1,43 +0,0 @@
|
|
|
1
|
-
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import { KvCacheContainer } from '@/kv_cache';
|
|
3
|
-
import { CachedRoPEMultiHeadAttention } from '@/layers/cached_rope_multihead_attention';
|
|
4
|
-
// disables warning for using the faster node backend,
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
tf.env().set('IS_NODE', false);
describe("CachedRoPEMultiHeadAttention tests", () => {
    test("aggregate forward passes output are identical normal multihead attention", () => {
        compareNormalWithCachedAttention(tf.randomUniform([2, 10, 16]), 123);
        compareNormalWithCachedAttention(tf.randomUniform([1, 10, 16]), 123);
        compareNormalWithCachedAttention(tf.randomUniform([1, 1, 16]), 123);
        compareNormalWithCachedAttention(tf.randomUniform([3, 2, 16]), 123);
        // input exceeds KV cache size
        expect(() => compareNormalWithCachedAttention(tf.randomUniform([1, 10, 16]), 5)).toThrow();
        /**
         * Runs the whole input through a non-cached layer, then feeds it token by
         * token through two cached layers with the same weights that share one
         * KV cache container, and checks that all outputs agree.
         */
        function compareNormalWithCachedAttention(input, max_sequence_length) {
            const embed_dim = input.shape[2];
            const batch = input.shape[0];
            const seq_len = input.shape[1];
            const heads = 2;
            // element-wise tolerance for float32 differences between the two paths
            const tolerance = 1e-5;
            const kv_cache = new KvCacheContainer(max_sequence_length);
            const normal_mha = new CachedRoPEMultiHeadAttention({ numHeads: heads, embedDim: embed_dim, causal: true });
            const normal_mha_output = normal_mha.apply(input);
            // initialize cached attention with identical configuration and weights
            const cached_mha1 = new CachedRoPEMultiHeadAttention({ ...normal_mha.getConfig(), name: "cache_test1" });
            cached_mha1.build(input.shape);
            cached_mha1.setWeights(normal_mha.getWeights());
            const cached_mha2 = new CachedRoPEMultiHeadAttention({ ...normal_mha.getConfig(), name: "cache_test2" });
            cached_mha2.build(input.shape);
            cached_mha2.setWeights(normal_mha.getWeights());
            const cached_mha_outputs1 = [];
            const cached_mha_outputs2 = [];
            for (let i = 0; i < seq_len; i++) {
                const current_token = input.slice([0, i, 0], [batch, 1, embed_dim]);
                cached_mha_outputs1.push(cached_mha1.apply(current_token, { kvCache: kv_cache }));
                cached_mha_outputs2.push(cached_mha2.apply(current_token, { kvCache: kv_cache }));
            }
            // each layer owns one cache in the shared container; both must hold every token
            expect(kv_cache.caches.get("cache_test1").size).toBe(seq_len);
            expect(kv_cache.caches.get("cache_test2").size).toBe(seq_len);
            expect(kv_cache.size).toBe(seq_len);
            // compare the largest absolute difference: a signed sum can cancel out
            // or go negative and pass even when the outputs differ
            expect(normal_mha_output.sub(tf.concat(cached_mha_outputs1, 1)).abs().max().dataSync()[0]).toBeLessThan(tolerance);
            expect(normal_mha_output.sub(tf.concat(cached_mha_outputs2, 1)).abs().max().dataSync()[0]).toBeLessThan(tolerance);
        }
    });
});
|
|
43
|
-
//# sourceMappingURL=cached_rope_multihead_attention.test.js.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"cached_rope_multihead_attention.test.js","sourceRoot":"","sources":["../../../src/layers/cached_rope_multihead_attention.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,gBAAgB,EAAE,MAAM,YAAY,CAAC;AAC9C,OAAO,EAAE,4BAA4B,EAAE,MAAM,0CAA0C,CAAC;AAGxF,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,oCAAoC,EAAE,GAAG,EAAE;IAEhD,IAAI,CAAC,0EAA0E,EAAE,GAAG,EAAE;QAClF,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QACjF,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QACjF,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QAChF,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QAEhF,6BAA6B;QAC7B,MAAM,CAAC,GAAG,EAAE,CAAC,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEvG,SAAS,gCAAgC,CAAC,KAAkB,EAAE,mBAA2B;YACrF,MAAM,SAAS,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YACjC,MAAM,KAAK,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAC7B,MAAM,KAAK,GAAG,CAAC,CAAC;YAEhB,MAAM,QAAQ,GAAG,IAAI,gBAAgB,CAAC,mBAAmB,CAAC,CAAC;YAE3D,MAAM,UAAU,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,KAAK,EAAE,QAAQ,EAAE,SAAS,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;YAC5G,MAAM,iBAAiB,GAAG,UAAU,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC;YAE/D,uEAAuE;YACvE,MAAM,WAAW,GAAG,IAAI,4BAA4B,CAAC,EAAE,GAAG,UAAU,CAAC,SAAS,EAAE,EAAE,IAAI,EAAE,aAAa,EAAE,CAAC,CAAC;YACzG,WAAW,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;YAC/B,WAAW,CAAC,UAAU,CAAC,UAAU,CAAC,UAAU,EAAE,CAAC,CAAC;YAEhD,MAAM,WAAW,GAAG,IAAI,4BAA4B,CAAC,EAAE,GAAG,UAAU,CAAC,SAAS,EAAE,EAAE,IAAI,EAAE,aAAa,EAAE,CAAC,CAAC;YACzG,WAAW,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;YAC/B,WAAW,CAAC,UAAU,CAAC,UAAU,CAAC,UAAU,EAAE,CAAC,CAAC;YAEhD,MAAM,mBAAmB,GAAgB,EAAE,CAAC;YAC5C,MAAM,mBAAmB,GAAgB,EAAE,CAAC;YAE5C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CA
AC,EAAE,CAAC,EAAE,EAAE,CAAC;gBACtC,MAAM,aAAa,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,KAAK,EAAE,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;gBAEpE,mBAAmB,CAAC,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,aAAa,EAAE,EAAE,OAAO,EAAE,QAAQ,EAAE,CAAc,CAAC,CAAC;gBAC/F,mBAAmB,CAAC,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,aAAa,EAAE,EAAE,OAAO,EAAE,QAAQ,EAAE,CAAc,CAAC,CAAC;YACnG,CAAC;YAED,MAAM,CAAC,QAAQ,CAAC,IAAI,IAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;YACvC,MAAM,CAAC,QAAQ,CAAC,IAAI,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;YAExC,MAAM,CAAC,iBAAiB,CAAC,GAAG,CAAC,EAAE,CAAC,MAAM,CAAC,mBAAmB,EAAE,CAAC,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;YACxG,MAAM,CAAC,iBAAiB,CAAC,GAAG,CAAC,EAAE,CAAC,MAAM,CAAC,mBAAmB,EAAE,CAAC,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;QAC5G,CAAC;IACL,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAC"}
|
|
@@ -1,34 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
|
|
3
|
-
import { type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
|
|
4
|
-
import { TransformerDecoder, type TransformerDecoderArgs } from "@/layers/transformer_decoder";
|
|
5
|
-
export interface GPTDecoderBlockArgs extends Omit<MultiHeadAttentionArgs, "causal"> {
|
|
6
|
-
dimsFeedForward?: number;
|
|
7
|
-
}
|
|
8
|
-
/**
|
|
9
|
-
* This implements the GPT-2 transformer block by modifying the transformer
|
|
10
|
-
* decoder block to use pre-layer-normalization and replacing ReLU activation
|
|
11
|
-
* with GELU.
|
|
12
|
-
*
|
|
13
|
-
* @param numHeads number of attention heads to use
|
|
14
|
-
* @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
|
|
15
|
-
* @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
|
|
16
|
-
* @param dropout use dropout during the attention calculations, default `0.1`
|
|
17
|
-
* @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
|
|
18
|
-
* @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
|
|
19
|
-
*/
|
|
20
|
-
export declare class GPT2DecoderBlock extends TransformerDecoder {
|
|
21
|
-
static className: string;
|
|
22
|
-
constructor(args: TransformerDecoderArgs);
|
|
23
|
-
/**
|
|
24
|
-
* Attention sub-block which is similar to the original transformer except
|
|
25
|
-
* layer normalization is applied beginning
|
|
26
|
-
*/
|
|
27
|
-
protected causalSelfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
|
|
28
|
-
/**
|
|
29
|
-
* Feedforward sub-block which is similar to the original transformer except
|
|
30
|
-
* layer normalization is applied at the beginning and gelu activation is used
|
|
31
|
-
*/
|
|
32
|
-
protected feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
|
|
33
|
-
}
|
|
34
|
-
//# sourceMappingURL=gpt_decoder_block.d.ts.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"gpt_decoder_block.d.ts","sourceRoot":"","sources":["../../../src/layers/gpt_decoder_block.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAEjE,OAAO,EAAE,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAC3E,OAAO,EAAE,kBAAkB,EAAE,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAG/F,MAAM,WAAW,mBAAoB,SAAQ,IAAI,CAAC,sBAAsB,EAAE,QAAQ,CAAC;IAC/E,eAAe,CAAC,EAAE,MAAM,CAAC;CAC5B;AAGD;;;;;;;;;;;GAWG;AACH,qBAAa,gBAAiB,SAAQ,kBAAkB;IACpD,MAAM,CAAC,SAAS,SAAsB;gBAG1B,IAAI,EAAE,sBAAsB;IAKxC;;;OAGG;cACgB,wBAAwB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAcpF;;;OAGG;cACgB,gBAAgB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;CAkB/E"}
|
|
@@ -1,51 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
import { TransformerDecoder } from "@/layers/transformer_decoder";
|
|
3
|
-
/**
|
|
4
|
-
* This implements the GPT-2 transformer block by modifying the transformer
|
|
5
|
-
* decoder block to use pre-layer-normalization and replacing ReLU activation
|
|
6
|
-
* with GELU.
|
|
7
|
-
*
|
|
8
|
-
* @param numHeads number of attention heads to use
|
|
9
|
-
* @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
|
|
10
|
-
* @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
|
|
11
|
-
* @param dropout use dropout during the attention calculations, default `0.1`
|
|
12
|
-
* @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
|
|
13
|
-
* @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
|
|
14
|
-
*/
|
|
15
|
-
export class GPT2DecoderBlock extends TransformerDecoder {
|
|
16
|
-
static className = "GPT2DecoderBlock";
|
|
17
|
-
constructor(args) {
|
|
18
|
-
super(args);
|
|
19
|
-
}
|
|
20
|
-
/**
|
|
21
|
-
* Attention sub-block which is similar to the original transformer except
|
|
22
|
-
* layer normalization is applied beginning
|
|
23
|
-
*/
|
|
24
|
-
causalSelfAttentionBlock(x, kwargs) {
|
|
25
|
-
return tf.tidy(() => {
|
|
26
|
-
const residual = x;
|
|
27
|
-
let attention = this.causalSelfAttentionNorm.apply(x, kwargs);
|
|
28
|
-
attention = this.causalSelfAttention.apply(attention, kwargs);
|
|
29
|
-
attention = this.causalSelfAttentionDropout.apply(attention, kwargs);
|
|
30
|
-
attention = tf.add(attention, residual);
|
|
31
|
-
return attention;
|
|
32
|
-
});
|
|
33
|
-
}
|
|
34
|
-
/**
|
|
35
|
-
* Feedforward sub-block which is similar to the original transformer except
|
|
36
|
-
* layer normalization is applied at the beginning and gelu activation is used
|
|
37
|
-
*/
|
|
38
|
-
feedForwardBlock(x, kwargs) {
|
|
39
|
-
return tf.tidy(() => {
|
|
40
|
-
const residual = x;
|
|
41
|
-
let feedForward = this.feedFowardNorm.apply(x, kwargs);
|
|
42
|
-
feedForward = this.feedforward1.apply(feedForward, kwargs);
|
|
43
|
-
feedForward = this.feedforward2.apply(feedForward, kwargs);
|
|
44
|
-
feedForward = this.feedForwardDropout.apply(feedForward, kwargs);
|
|
45
|
-
feedForward = tf.add(feedForward, residual);
|
|
46
|
-
return feedForward;
|
|
47
|
-
});
|
|
48
|
-
}
|
|
49
|
-
}
|
|
50
|
-
tf.serialization.registerClass(GPT2DecoderBlock);
|
|
51
|
-
//# sourceMappingURL=gpt_decoder_block.js.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"gpt_decoder_block.js","sourceRoot":"","sources":["../../../src/layers/gpt_decoder_block.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAIvC,OAAO,EAAE,kBAAkB,EAA+B,MAAM,8BAA8B,CAAC;AAQ/F;;;;;;;;;;;GAWG;AACH,MAAM,OAAO,gBAAiB,SAAQ,kBAAkB;IACpD,MAAM,CAAC,SAAS,GAAG,kBAAkB,CAAC;IAGtC,YAAY,IAA4B;QACpC,KAAK,CAAC,IAAI,CAAC,CAAC;IAChB,CAAC;IAGD;;;OAGG;IACgB,wBAAwB,CAAC,CAAY,EAAE,MAAc;QACpE,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,SAAS,GAAG,IAAI,CAAC,uBAAuB,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAc,CAAC;YAC3E,SAAS,GAAG,IAAI,CAAC,mBAAmB,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAC3E,SAAS,GAAG,IAAI,CAAC,0BAA0B,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAClF,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,QAAQ,CAAC,CAAC;YAExC,OAAO,SAAS,CAAC;QACrB,CAAC,CAAC,CAAC;IACP,CAAC;IAGD;;;OAGG;IACgB,gBAAgB,CAAC,CAAY,EAAE,MAAc;QAC5D,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,WAAW,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YACvD,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAC,CAAC;YAC3D,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAC,CAAC;YAC3D,WAAW,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAC9E,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,QAAQ,CAAC,CAAC;YAE5C,OAAO,WAAW,CAAC;QACvB,CAAC,CAAC,CAAC;IACP,CAAC;;AASL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,gBAAgB,CAAC,CAAC"}
|
|
@@ -1,17 +0,0 @@
|
|
|
1
|
-
import { CachedRoPEMultiHeadAttention } from "./cached_rope_multihead_attention";
|
|
2
|
-
import { GPT2DecoderBlock, GPTDecoderBlockArgs } from "./gpt_decoder_block";
|
|
3
|
-
import { MultiHeadAttention, MultiHeadAttentionArgs } from "./multihead_attention";
|
|
4
|
-
import { PositionalEncoding, PositionalEncodingArgs } from "./positional_encoding";
|
|
5
|
-
import { RotaryPositionEmbedding, RotaryPositionEmbeddingArgs } from "./rotary_position_embedding";
|
|
6
|
-
import { TokenAndPositionalEmbedding, TokenAndPositionalEmbeddingArgs } from "./token_and_positional_embedding";
|
|
7
|
-
import { TransformerDecoder, TransformerDecoderArgs } from "./transformer_decoder";
|
|
8
|
-
import { TransformerEncoder, TransformerEncoderArgs } from "./transformer_encoder";
|
|
9
|
-
export declare function tokenAndPositionalEmbedding(args: TokenAndPositionalEmbeddingArgs): TokenAndPositionalEmbedding;
|
|
10
|
-
export declare function transformerEncoder(args: TransformerEncoderArgs): TransformerEncoder;
|
|
11
|
-
export declare function transformerDecoder(args: TransformerDecoderArgs): TransformerDecoder;
|
|
12
|
-
export declare function multiheadAttention(args: MultiHeadAttentionArgs): MultiHeadAttention;
|
|
13
|
-
export declare function cachedRopeMultiheadAttention(args: MultiHeadAttentionArgs): CachedRoPEMultiHeadAttention;
|
|
14
|
-
export declare function positionalEncoding(args: PositionalEncodingArgs): PositionalEncoding;
|
|
15
|
-
export declare function gpt2DecoderBlock(args: GPTDecoderBlockArgs): GPT2DecoderBlock;
|
|
16
|
-
export declare function rotaryPositionEmbedding(args: RotaryPositionEmbeddingArgs): RotaryPositionEmbedding;
|
|
17
|
-
//# sourceMappingURL=index.d.ts.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../../src/layers/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,4BAA4B,EAAE,MAAM,mCAAmC,CAAC;AACjF,OAAO,EAAE,gBAAgB,EAAE,mBAAmB,EAAE,MAAM,qBAAqB,CAAC;AAC5E,OAAO,EAAE,kBAAkB,EAAE,sBAAsB,EAAE,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,kBAAkB,EAAE,sBAAsB,EAAE,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,uBAAuB,EAAE,2BAA2B,EAAE,MAAM,6BAA6B,CAAC;AACnG,OAAO,EAAE,2BAA2B,EAAE,+BAA+B,EAAE,MAAM,kCAAkC,CAAC;AAChH,OAAO,EAAE,kBAAkB,EAAE,sBAAsB,EAAE,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,kBAAkB,EAAE,sBAAsB,EAAE,MAAM,uBAAuB,CAAC;AAGnF,wBAAgB,2BAA2B,CAAC,IAAI,EAAE,+BAA+B,+BAEhF;AAGD,wBAAgB,kBAAkB,CAAC,IAAI,EAAE,sBAAsB,sBAE9D;AAGD,wBAAgB,kBAAkB,CAAC,IAAI,EAAE,sBAAsB,sBAE9D;AAGD,wBAAgB,kBAAkB,CAAC,IAAI,EAAE,sBAAsB,sBAE9D;AAGD,wBAAgB,4BAA4B,CAAC,IAAI,EAAE,sBAAsB,gCAExE;AAGD,wBAAgB,kBAAkB,CAAC,IAAI,EAAE,sBAAsB,sBAE9D;AAGD,wBAAgB,gBAAgB,CAAC,IAAI,EAAE,mBAAmB,oBAEzD;AAGD,wBAAgB,uBAAuB,CAAC,IAAI,EAAE,2BAA2B,2BAExE"}
|
package/dist/src/layers/index.js
DELETED
|
@@ -1,33 +0,0 @@
|
|
|
1
|
-
import { CachedRoPEMultiHeadAttention } from "./cached_rope_multihead_attention";
|
|
2
|
-
import { GPT2DecoderBlock } from "./gpt_decoder_block";
|
|
3
|
-
import { MultiHeadAttention } from "./multihead_attention";
|
|
4
|
-
import { PositionalEncoding } from "./positional_encoding";
|
|
5
|
-
import { RotaryPositionEmbedding } from "./rotary_position_embedding";
|
|
6
|
-
import { TokenAndPositionalEmbedding } from "./token_and_positional_embedding";
|
|
7
|
-
import { TransformerDecoder } from "./transformer_decoder";
|
|
8
|
-
import { TransformerEncoder } from "./transformer_encoder";
|
|
9
|
-
export function tokenAndPositionalEmbedding(args) {
|
|
10
|
-
return new TokenAndPositionalEmbedding(args);
|
|
11
|
-
}
|
|
12
|
-
export function transformerEncoder(args) {
|
|
13
|
-
return new TransformerEncoder(args);
|
|
14
|
-
}
|
|
15
|
-
export function transformerDecoder(args) {
|
|
16
|
-
return new TransformerDecoder(args);
|
|
17
|
-
}
|
|
18
|
-
export function multiheadAttention(args) {
|
|
19
|
-
return new MultiHeadAttention(args);
|
|
20
|
-
}
|
|
21
|
-
export function cachedRopeMultiheadAttention(args) {
|
|
22
|
-
return new CachedRoPEMultiHeadAttention(args);
|
|
23
|
-
}
|
|
24
|
-
export function positionalEncoding(args) {
|
|
25
|
-
return new PositionalEncoding(args);
|
|
26
|
-
}
|
|
27
|
-
export function gpt2DecoderBlock(args) {
|
|
28
|
-
return new GPT2DecoderBlock(args);
|
|
29
|
-
}
|
|
30
|
-
export function rotaryPositionEmbedding(args) {
|
|
31
|
-
return new RotaryPositionEmbedding(args);
|
|
32
|
-
}
|
|
33
|
-
//# sourceMappingURL=index.js.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"index.js","sourceRoot":"","sources":["../../../src/layers/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,4BAA4B,EAAE,MAAM,mCAAmC,CAAC;AACjF,OAAO,EAAE,gBAAgB,EAAuB,MAAM,qBAAqB,CAAC;AAC5E,OAAO,EAAE,kBAAkB,EAA0B,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,kBAAkB,EAA0B,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,uBAAuB,EAA+B,MAAM,6BAA6B,CAAC;AACnG,OAAO,EAAE,2BAA2B,EAAmC,MAAM,kCAAkC,CAAC;AAChH,OAAO,EAAE,kBAAkB,EAA0B,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,kBAAkB,EAA0B,MAAM,uBAAuB,CAAC;AAGnF,MAAM,UAAU,2BAA2B,CAAC,IAAqC;IAC7E,OAAO,IAAI,2BAA2B,CAAC,IAAI,CAAC,CAAC;AACjD,CAAC;AAGD,MAAM,UAAU,kBAAkB,CAAC,IAA4B;IAC3D,OAAO,IAAI,kBAAkB,CAAC,IAAI,CAAC,CAAC;AACxC,CAAC;AAGD,MAAM,UAAU,kBAAkB,CAAC,IAA4B;IAC3D,OAAO,IAAI,kBAAkB,CAAC,IAAI,CAAC,CAAC;AACxC,CAAC;AAGD,MAAM,UAAU,kBAAkB,CAAC,IAA4B;IAC3D,OAAO,IAAI,kBAAkB,CAAC,IAAI,CAAC,CAAC;AACxC,CAAC;AAGD,MAAM,UAAU,4BAA4B,CAAC,IAA4B;IACrE,OAAO,IAAI,4BAA4B,CAAC,IAAI,CAAC,CAAC;AAClD,CAAC;AAGD,MAAM,UAAU,kBAAkB,CAAC,IAA4B;IAC3D,OAAO,IAAI,kBAAkB,CAAC,IAAI,CAAC,CAAC;AACxC,CAAC;AAGD,MAAM,UAAU,gBAAgB,CAAC,IAAyB;IACtD,OAAO,IAAI,gBAAgB,CAAC,IAAI,CAAC,CAAC;AACtC,CAAC;AAGD,MAAM,UAAU,uBAAuB,CAAC,IAAiC;IACrE,OAAO,IAAI,uBAAuB,CAAC,IAAI,CAAC,CAAC;AAC7C,CAAC"}
|
|
@@ -1,106 +0,0 @@
|
|
|
1
|
-
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
|
|
3
|
-
import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
|
|
4
|
-
export interface MultiHeadAttentionArgs extends LayerArgs {
|
|
5
|
-
numHeads: number;
|
|
6
|
-
embedDim: number;
|
|
7
|
-
useBias?: boolean;
|
|
8
|
-
dropout?: number;
|
|
9
|
-
causal?: boolean;
|
|
10
|
-
}
|
|
11
|
-
export interface ScaledDotProductionAttentionKwargs {
|
|
12
|
-
training?: boolean;
|
|
13
|
-
dropout?: number;
|
|
14
|
-
causal?: boolean;
|
|
15
|
-
scaling_factor?: number;
|
|
16
|
-
}
|
|
17
|
-
/**
|
|
18
|
-
* This MultiHead Attention layer implements the algorithm as described in
|
|
19
|
-
* the paper "Attention is all you Need" Vaswani et al., 2017.
|
|
20
|
-
*
|
|
21
|
-
* @param numHeads number of attention heads to use
|
|
22
|
-
* @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
|
|
23
|
-
* @param causal use causal masking, default `false`
|
|
24
|
-
* @param dropout use dropout during the attention calculations, default `0.0`
|
|
25
|
-
* @param useBias use bias for the dense sublayers, default `true`
|
|
26
|
-
*
|
|
27
|
-
* The TensorFlow version uses tf.einsum, whose gradient op has not yet been
|
|
28
|
-
* implemented (https://github.com/tensorflow/tfjs/pull/4955#discussion_r619219334),
|
|
29
|
-
* therefore we follow the PyTorch implementation described in:
|
|
30
|
-
* https://docs.pytorch.org/tutorials/intermediate/transformer_building_blocks.html#multiheadattention
|
|
31
|
-
* https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
|
32
|
-
*
|
|
33
|
-
* This implementation is different from TensorFlow's whose attention weights
|
|
34
|
-
* are shaped [embed dim, heads, embed dim] where as PyTorch and OpenAI's attention weights
|
|
35
|
-
* are shaped [embed dim, embed dim]
|
|
36
|
-
* https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/nn/modules/activation.py#L1080
|
|
37
|
-
* https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/model.py#L53
|
|
38
|
-
*
|
|
39
|
-
* TODO: implement a fast track for self attention (query = key = value)
|
|
40
|
-
* where a single dense layer combines and replaces the query, key and projection layers
|
|
41
|
-
*
|
|
42
|
-
* TODO: add kDim and vDim to accept key and values whose embedding dimensions differ from query's.
|
|
43
|
-
*/
|
|
44
|
-
export declare class MultiHeadAttention extends tf.layers.Layer {
|
|
45
|
-
static className: string;
|
|
46
|
-
protected readonly numHeads: number;
|
|
47
|
-
protected readonly embedDim: number;
|
|
48
|
-
protected readonly useBias: boolean;
|
|
49
|
-
protected readonly dropout: number;
|
|
50
|
-
protected readonly causal: boolean;
|
|
51
|
-
protected readonly queryProjection: tf.layers.Layer;
|
|
52
|
-
protected readonly keyProjection: tf.layers.Layer;
|
|
53
|
-
protected readonly valueProjection: tf.layers.Layer;
|
|
54
|
-
protected readonly outputProjection: tf.layers.Layer;
|
|
55
|
-
constructor({ numHeads, embedDim, useBias, dropout, causal, ...args }: MultiHeadAttentionArgs);
|
|
56
|
-
/**
|
|
57
|
-
* Forward propagation. Provide one input tensor or three identical tensors to self-attention.
|
|
58
|
-
* @param inputs a single tensor for self-attention or an array of exactly three
|
|
59
|
-
* tensors that are either identical (self-attention) or different (cross-attention)
|
|
60
|
-
* @param kwargs.packingMask a mask to prevent tokens from attending across document boundaries
|
|
61
|
-
*/
|
|
62
|
-
call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs & {
|
|
63
|
-
packingMask?: tf.Tensor;
|
|
64
|
-
causalMask?: tf.Tensor;
|
|
65
|
-
}): tf.Tensor | tf.Tensor[];
|
|
66
|
-
/**
|
|
67
|
-
* Forward propagation
|
|
68
|
-
*/
|
|
69
|
-
protected forward(query_input: tf.Tensor, key_input: tf.Tensor, value_input: tf.Tensor, packing_mask: tf.Tensor | null, causal_mask: tf.Tensor | null, kwargs: Kwargs): tf.Tensor;
|
|
70
|
-
protected applyInputProjections(query_input: tf.Tensor, key_input: tf.Tensor, value_input: tf.Tensor): {
|
|
71
|
-
query: tf.Tensor;
|
|
72
|
-
key: tf.Tensor;
|
|
73
|
-
value: tf.Tensor;
|
|
74
|
-
};
|
|
75
|
-
protected splitHeads(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, shuffle: number[]): {
|
|
76
|
-
query_split: tf.Tensor4D;
|
|
77
|
-
key_split: tf.Tensor4D;
|
|
78
|
-
value_split: tf.Tensor4D;
|
|
79
|
-
};
|
|
80
|
-
/**
|
|
81
|
-
* Applies the scaled dot-product formula: softmax(QK_t / sqrt(d_k))V,
|
|
82
|
-
* formula (1) of the 2017 paper Attention Is All You Need
|
|
83
|
-
*
|
|
84
|
-
* @param attentionMask a mask to prevent tokens from being
|
|
85
|
-
* attended to (usually for padding tokens). It should have the shape
|
|
86
|
-
* [batch, head, query_sequence_len, key_sequence_len]. To use in
|
|
87
|
-
* conjunction with causal masking, the tensor should be a boolean type
|
|
88
|
-
* where false indicates a masked token.
|
|
89
|
-
* @param packingMask a mask to prevent tokens from attending across document boundaries
|
|
90
|
-
*/
|
|
91
|
-
static scaledDotProductionAttention(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, attentionMask: tf.Tensor | null, packingMask: tf.Tensor | null, causalMask: tf.Tensor | null, dropout: number, causal: boolean, kwargs?: ScaledDotProductionAttentionKwargs): tf.Tensor;
|
|
92
|
-
build(inputShape: tf.Shape | tf.Shape[]): void;
|
|
93
|
-
/**
|
|
94
|
-
* MultiHead attention's output is the same shape the query's.
|
|
95
|
-
*/
|
|
96
|
-
computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[];
|
|
97
|
-
getConfig(): {
|
|
98
|
-
numHeads: number;
|
|
99
|
-
embedDim: number;
|
|
100
|
-
useBias: boolean;
|
|
101
|
-
causal: boolean;
|
|
102
|
-
dropout: number;
|
|
103
|
-
name: string;
|
|
104
|
-
};
|
|
105
|
-
}
|
|
106
|
-
//# sourceMappingURL=multihead_attention.d.ts.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"multihead_attention.d.ts","sourceRoot":"","sources":["../../../src/layers/multihead_attention.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,SAAS,EAAE,MAAM,8CAA8C,CAAC;AAC9E,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAIjE,MAAM,WAAW,sBAAuB,SAAQ,SAAS;IACrD,QAAQ,EAAE,MAAM,CAAC;IACjB,QAAQ,EAAE,MAAM,CAAC;IACjB,OAAO,CAAC,EAAE,OAAO,CAAC;IAClB,OAAO,CAAC,EAAE,MAAM,CAAC;IACjB,MAAM,CAAC,EAAE,OAAO,CAAC;CACpB;AAGD,MAAM,WAAW,kCAAkC;IAC/C,QAAQ,CAAC,EAAE,OAAO,CAAC;IACnB,OAAO,CAAC,EAAE,MAAM,CAAC;IACjB,MAAM,CAAC,EAAE,OAAO,CAAC;IACjB,cAAc,CAAC,EAAE,MAAM,CAAC;CAC3B;AAGD;;;;;;;;;;;;;;;;;;;;;;;;;;GA0BG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IACxC,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,OAAO,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,MAAM,CAAC;IACnC,SAAS,CAAC,QAAQ,CAAC,MAAM,EAAE,OAAO,CAAC;IAInC,SAAS,CAAC,QAAQ,CAAC,eAAe,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACpD,SAAS,CAAC,QAAQ,CAAC,aAAa,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAClD,SAAS,CAAC,QAAQ,CAAC,eAAe,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACpD,SAAS,CAAC,QAAQ,CAAC,gBAAgB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;gBAGzC,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAc,EAAE,OAAa,EAAE,MAAc,EAAE,GAAG,IAAI,EAAE,EAAE,sBAAsB;IA0BlH;;;;;OAKG;IACM,IAAI,CACT,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAC/B,MAAM,EAAE,MAAM,GAAG;QACb,WAAW,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC;QACxB,UAAU,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC;KAC1B,GACF,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IA6B1B;;OAEG;IACH,SAAS,CAAC,OAAO,CACb,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,SAAS,EAAE,EAAE,CAAC,MAAM,EACpB,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,YAAY,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC9B,WAAW,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC7B,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IA+B9B,SAAS,CAAC,qBAAqB,CAAC,WAAW,EAAE,EAAE,CAAC,MAAM,EAAE,SAAS,EAAE,EAAE,CAAC,MAAM,EAAE,WAAW,EAAE,EAAE,CAAC,MAAM;eAMtC,EAAE,CAAC,MAAM;aACf,EAAE,CAAC,MAAM;eACH,EAAE,CAAC,MAAM;;IAMvE,SAAS,CAAC,UAAU,CAAC,KAA
K,EAAE,EAAE,CAAC,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC,MAAM,EAAE,OAAO,EAAE,MAAM,EAAE;qBAShB,EAAE,CAAC,QAAQ;mBACf,EAAE,CAAC,QAAQ;qBACP,EAAE,CAAC,QAAQ;;IAMrF;;;;;;;;;;OAUG;IACH,MAAM,CAAC,4BAA4B,CAC/B,KAAK,EAAE,EAAE,CAAC,MAAM,EAChB,GAAG,EAAE,EAAE,CAAC,MAAM,EACd,KAAK,EAAE,EAAE,CAAC,MAAM,EAChB,aAAa,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC/B,WAAW,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC7B,UAAU,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC5B,OAAO,EAAE,MAAM,EACf,MAAM,EAAE,OAAO,EACf,MAAM,GAAE,kCAAuC,GAChD,EAAE,CAAC,MAAM;IA0EH,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IA4CvD;;OAEG;IACM,kBAAkB,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;IAK5E,SAAS;;;;;;;;CAgBrB"}
|