@stellarapp/tfjs-stellar 1.0.0 → 1.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +47 -0
- package/dist/index.d.ts +7 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +7 -0
- package/dist/index.js.map +1 -0
- package/dist/jest.config.d.ts +8 -0
- package/dist/jest.config.d.ts.map +1 -0
- package/{jest.config.ts → dist/jest.config.js} +8 -64
- package/dist/jest.config.js.map +1 -0
- package/dist/kv_cache.d.ts +53 -0
- package/dist/kv_cache.d.ts.map +1 -0
- package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
- package/dist/kv_cache.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.js +76 -0
- package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
- package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
- package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
- package/dist/layers/gpt_decoder_block.js.map +1 -0
- package/dist/layers/index.d.ts +17 -0
- package/dist/layers/index.d.ts.map +1 -0
- package/dist/layers/index.js +33 -0
- package/dist/layers/index.js.map +1 -0
- package/dist/layers/multihead_attention.d.ts +106 -0
- package/dist/layers/multihead_attention.d.ts.map +1 -0
- package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
- package/dist/layers/multihead_attention.js.map +1 -0
- package/dist/layers/multihead_attention.test.d.ts +2 -0
- package/dist/layers/multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
- package/dist/layers/multihead_attention.test.js.map +1 -0
- package/dist/layers/positional_encoding.d.ts +37 -0
- package/dist/layers/positional_encoding.d.ts.map +1 -0
- package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
- package/dist/layers/positional_encoding.js.map +1 -0
- package/dist/layers/positional_encoding.test.d.ts +2 -0
- package/dist/layers/positional_encoding.test.d.ts.map +1 -0
- package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
- package/dist/layers/positional_encoding.test.js.map +1 -0
- package/dist/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
- package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
- package/dist/layers/rotary_position_embedding.js.map +1 -0
- package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/layers/rotary_position_embedding.test.js +88 -0
- package/dist/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
- package/dist/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
- package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/layers/transformer_decoder.d.ts +69 -0
- package/dist/layers/transformer_decoder.d.ts.map +1 -0
- package/dist/layers/transformer_decoder.js +182 -0
- package/dist/layers/transformer_decoder.js.map +1 -0
- package/dist/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
- package/dist/layers/transformer_decoder.test.js.map +1 -0
- package/dist/layers/transformer_encoder.d.ts +55 -0
- package/dist/layers/transformer_encoder.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
- package/dist/layers/transformer_encoder.js.map +1 -0
- package/dist/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
- package/dist/layers/transformer_encoder.test.js.map +1 -0
- package/dist/losses/dice.d.ts +30 -0
- package/dist/losses/dice.d.ts.map +1 -0
- package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
- package/dist/losses/dice.js.map +1 -0
- package/dist/losses/index.d.ts +2 -0
- package/dist/losses/index.d.ts.map +1 -0
- package/dist/losses/index.js +2 -0
- package/dist/losses/index.js.map +1 -0
- package/dist/masks.d.ts +20 -0
- package/dist/masks.d.ts.map +1 -0
- package/{src/packing_mask.ts → dist/masks.js} +16 -7
- package/dist/masks.js.map +1 -0
- package/dist/metrics.d.ts +20 -0
- package/dist/metrics.d.ts.map +1 -0
- package/{src/metrics.ts → dist/metrics.js} +8 -12
- package/dist/metrics.js.map +1 -0
- package/dist/models/gpt_model.d.ts +94 -0
- package/dist/models/gpt_model.d.ts.map +1 -0
- package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
- package/dist/models/gpt_model.js.map +1 -0
- package/dist/models/index.d.ts +7 -0
- package/dist/models/index.d.ts.map +1 -0
- package/dist/models/index.js +13 -0
- package/dist/models/index.js.map +1 -0
- package/dist/models/llm_model.d.ts +87 -0
- package/dist/models/llm_model.d.ts.map +1 -0
- package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
- package/dist/models/llm_model.js.map +1 -0
- package/dist/models/u_net.d.ts +40 -0
- package/dist/models/u_net.d.ts.map +1 -0
- package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
- package/dist/models/u_net.js.map +1 -0
- package/dist/src/index.d.ts +6 -0
- package/dist/src/index.d.ts.map +1 -0
- package/dist/src/index.js +6 -0
- package/dist/src/index.js.map +1 -0
- package/dist/src/kv_cache.d.ts +53 -0
- package/dist/src/kv_cache.d.ts.map +1 -0
- package/dist/src/kv_cache.js +135 -0
- package/dist/src/kv_cache.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
- package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
- package/dist/src/layers/gpt_decoder_block.js +51 -0
- package/dist/src/layers/gpt_decoder_block.js.map +1 -0
- package/dist/src/layers/index.d.ts +17 -0
- package/dist/src/layers/index.d.ts.map +1 -0
- package/dist/src/layers/index.js +33 -0
- package/dist/src/layers/index.js.map +1 -0
- package/dist/src/layers/multihead_attention.d.ts +106 -0
- package/dist/src/layers/multihead_attention.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.js +269 -0
- package/dist/src/layers/multihead_attention.js.map +1 -0
- package/dist/src/layers/multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.test.js +160 -0
- package/dist/src/layers/multihead_attention.test.js.map +1 -0
- package/dist/src/layers/positional_encoding.d.ts +37 -0
- package/dist/src/layers/positional_encoding.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.js +115 -0
- package/dist/src/layers/positional_encoding.js.map +1 -0
- package/dist/src/layers/positional_encoding.test.d.ts +2 -0
- package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.test.js +95 -0
- package/dist/src/layers/positional_encoding.test.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.js +99 -0
- package/dist/src/layers/rotary_position_embedding.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.js +88 -0
- package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.js +109 -0
- package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
- package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/src/layers/transformer_decoder.d.ts +69 -0
- package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
- package/dist/src/layers/transformer_decoder.js.map +1 -0
- package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_decoder.test.js +72 -0
- package/dist/src/layers/transformer_decoder.test.js.map +1 -0
- package/dist/src/layers/transformer_encoder.d.ts +55 -0
- package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.js +175 -0
- package/dist/src/layers/transformer_encoder.js.map +1 -0
- package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.test.js +58 -0
- package/dist/src/layers/transformer_encoder.test.js.map +1 -0
- package/dist/src/losses/dice.d.ts +30 -0
- package/dist/src/losses/dice.d.ts.map +1 -0
- package/dist/src/losses/dice.js +93 -0
- package/dist/src/losses/dice.js.map +1 -0
- package/dist/src/losses/index.d.ts +2 -0
- package/dist/src/losses/index.d.ts.map +1 -0
- package/dist/src/losses/index.js +2 -0
- package/dist/src/losses/index.js.map +1 -0
- package/dist/src/masks.d.ts +20 -0
- package/dist/src/masks.d.ts.map +1 -0
- package/dist/src/masks.js +37 -0
- package/dist/src/masks.js.map +1 -0
- package/dist/src/metrics.d.ts +20 -0
- package/dist/src/metrics.d.ts.map +1 -0
- package/dist/src/metrics.js +28 -0
- package/dist/src/metrics.js.map +1 -0
- package/dist/src/models/gpt_model.d.ts +94 -0
- package/dist/src/models/gpt_model.d.ts.map +1 -0
- package/dist/src/models/gpt_model.js +154 -0
- package/dist/src/models/gpt_model.js.map +1 -0
- package/dist/src/models/index.d.ts +3 -0
- package/dist/src/models/index.d.ts.map +1 -0
- package/{src/models/index.ts → dist/src/models/index.js} +1 -0
- package/dist/src/models/index.js.map +1 -0
- package/dist/src/models/llm_model.d.ts +87 -0
- package/dist/src/models/llm_model.d.ts.map +1 -0
- package/dist/src/models/llm_model.js +245 -0
- package/dist/src/models/llm_model.js.map +1 -0
- package/dist/src/models/u_net.d.ts +40 -0
- package/dist/src/models/u_net.d.ts.map +1 -0
- package/dist/src/models/u_net.js +151 -0
- package/dist/src/models/u_net.js.map +1 -0
- package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
- package/dist/src/tfjs_types.d.ts.map +1 -0
- package/dist/src/tfjs_types.js +2 -0
- package/dist/src/tfjs_types.js.map +1 -0
- package/dist/src/utils.d.ts +28 -0
- package/dist/src/utils.d.ts.map +1 -0
- package/{src/utils.ts → dist/src/utils.js} +10 -33
- package/dist/src/utils.js.map +1 -0
- package/dist/src/utils.test.d.ts +2 -0
- package/dist/src/utils.test.d.ts.map +1 -0
- package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
- package/dist/src/utils.test.js.map +1 -0
- package/dist/tfjs_types.d.ts +10 -0
- package/dist/tfjs_types.d.ts.map +1 -0
- package/dist/tfjs_types.js +2 -0
- package/dist/tfjs_types.js.map +1 -0
- package/dist/utils.d.ts +28 -0
- package/dist/utils.d.ts.map +1 -0
- package/dist/utils.js +63 -0
- package/dist/utils.js.map +1 -0
- package/dist/utils.test.d.ts +2 -0
- package/dist/utils.test.d.ts.map +1 -0
- package/dist/utils.test.js +73 -0
- package/dist/utils.test.js.map +1 -0
- package/package.json +14 -4
- package/src/index.ts +0 -93
- package/src/layers/rotary_position_embedding.test.ts +0 -107
- package/src/losses/index.ts +0 -1
- package/src/testing.ts +0 -1
- package/tsconfig.json +0 -49
|
@@ -1,113 +1,76 @@
|
|
|
1
1
|
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import {
|
|
3
|
-
import { MultiHeadAttention, type MultiHeadAttentionArgs } from '@/layers/multihead_attention';
|
|
2
|
+
import { MultiHeadAttention } from '@/layers/multihead_attention';
|
|
4
3
|
import { RotaryPositionEmbedding } from '@/layers/rotary_position_embedding';
|
|
5
|
-
import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
|
|
6
|
-
|
|
7
|
-
|
|
8
4
|
/**
|
|
9
5
|
* MultiHeadAttention with RoPE and KV caching. If using KV caching, this layer
|
|
10
6
|
* should be used in a custom training loop because it requires the cache to be
|
|
11
7
|
* passed through the `kwargs.kvCache` argument during the `layer.apply()`
|
|
12
8
|
* forward propagation.
|
|
13
|
-
*
|
|
9
|
+
*
|
|
14
10
|
* If a KV cache is not provided, then this layer operates as MultiHeadAttention with RoPE.
|
|
15
11
|
*/
|
|
16
12
|
export class CachedRoPEMultiHeadAttention extends MultiHeadAttention {
|
|
17
13
|
static className = "CachedRoPEMultiHeadAttention";
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
constructor(args: MultiHeadAttentionArgs) {
|
|
14
|
+
rope;
|
|
15
|
+
constructor(args) {
|
|
22
16
|
super(args);
|
|
23
17
|
this.rope = new RotaryPositionEmbedding({ dim: Math.floor(this.embedDim / this.numHeads) });
|
|
24
18
|
}
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
protected override forward(
|
|
28
|
-
query_input: tf.Tensor,
|
|
29
|
-
key_input: tf.Tensor,
|
|
30
|
-
value_input: tf.Tensor,
|
|
31
|
-
packing_mask: tf.Tensor | null,
|
|
32
|
-
causal_mask: tf.Tensor | null,
|
|
33
|
-
kwargs: Kwargs): tf.Tensor {
|
|
34
|
-
|
|
19
|
+
forward(query_input, key_input, value_input, packing_mask, causal_mask, kwargs) {
|
|
35
20
|
return tf.tidy(() => {
|
|
36
21
|
const { query, key, value } = this.applyInputProjections(query_input, key_input, value_input);
|
|
37
|
-
|
|
38
22
|
// swap the seq and heads dimensions: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
|
|
39
23
|
const move_head_dim_forward = [0, 2, 1, 3];
|
|
40
|
-
|
|
41
24
|
const split = this.splitHeads(query, key, value, move_head_dim_forward);
|
|
42
|
-
|
|
43
25
|
const query_split = split.query_split;
|
|
44
26
|
let key_split = split.key_split;
|
|
45
27
|
let value_split = split.value_split;
|
|
46
|
-
|
|
47
28
|
if (kwargs.training !== true && kwargs.kvCache) {
|
|
48
29
|
// runs at inference: updates the KV cache and gets the historical keys and values
|
|
49
|
-
const cached_kv = this.getCachedKV(
|
|
50
|
-
kwargs.kvCache as KvCacheContainer, key_split, value_split);
|
|
51
|
-
|
|
30
|
+
const cached_kv = this.getCachedKV(kwargs.kvCache, key_split, value_split);
|
|
52
31
|
key_split = cached_kv.keyCache;
|
|
53
32
|
value_split = cached_kv.valueCache;
|
|
54
33
|
}
|
|
55
|
-
|
|
56
34
|
// apply scaled dot-product attention to get [batch, seq, numHeads, embedDim]
|
|
57
|
-
const spda = MultiHeadAttention.scaledDotProductionAttention(
|
|
58
|
-
query_split, key_split, value_split,
|
|
59
|
-
kwargs.attentionMask ?? null, packing_mask, causal_mask,
|
|
60
|
-
this.dropout, this.causal, kwargs);
|
|
61
|
-
|
|
35
|
+
const spda = MultiHeadAttention.scaledDotProductionAttention(query_split, key_split, value_split, kwargs.attentionMask ?? null, packing_mask, causal_mask, this.dropout, this.causal, kwargs);
|
|
62
36
|
// concat heads and apply the output projection
|
|
63
|
-
const output = this.outputProjection.apply(
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
return output as tf.Tensor;
|
|
68
|
-
})
|
|
37
|
+
const output = this.outputProjection.apply(spda.transpose(move_head_dim_forward).reshape([query_input.shape[0], query_input.shape[1], this.embedDim]));
|
|
38
|
+
return output;
|
|
39
|
+
});
|
|
69
40
|
}
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
protected getCachedKV(kv_container: KvCacheContainer, key_split: tf.Tensor4D, value_split: tf.Tensor4D) {
|
|
41
|
+
getCachedKV(kv_container, key_split, value_split) {
|
|
73
42
|
try {
|
|
74
43
|
let kv_cache = kv_container.update(this.name, key_split, value_split);
|
|
75
|
-
|
|
76
44
|
if (!kv_cache) {
|
|
77
45
|
kv_container.create(this.name, {
|
|
78
46
|
batchSize: key_split.shape[0],
|
|
79
47
|
numHeads: this.numHeads,
|
|
80
48
|
headDim: this.embedDim / this.numHeads,
|
|
81
|
-
})
|
|
82
|
-
|
|
83
|
-
kv_cache = kv_container.update(this.name, key_split, value_split)!;
|
|
49
|
+
});
|
|
50
|
+
kv_cache = kv_container.update(this.name, key_split, value_split);
|
|
84
51
|
}
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
52
|
+
return kv_cache;
|
|
53
|
+
}
|
|
54
|
+
catch (error) {
|
|
88
55
|
throw Error(`${this.getClassName()}::getCachedKV ${this.name} ${error.toString()}`);
|
|
89
56
|
}
|
|
90
57
|
}
|
|
91
|
-
|
|
92
|
-
|
|
93
58
|
/**
|
|
94
59
|
* Adds RoPE position encoding right after splitting heads.
|
|
95
60
|
*/
|
|
96
|
-
|
|
61
|
+
splitHeads(query, key, value, shuffle) {
|
|
97
62
|
const batch_size = query.shape[0];
|
|
98
63
|
const split_heads = [batch_size, -1, this.numHeads, this.embedDim / this.numHeads];
|
|
99
|
-
|
|
100
64
|
return tf.tidy(() => {
|
|
101
65
|
return {
|
|
102
|
-
query_split:
|
|
103
|
-
.transpose(shuffle)
|
|
104
|
-
key_split:
|
|
105
|
-
.transpose(shuffle)
|
|
106
|
-
value_split: value.reshape(split_heads).transpose(shuffle)
|
|
107
|
-
}
|
|
108
|
-
})
|
|
66
|
+
query_split: this.rope.apply(query.reshape(split_heads))
|
|
67
|
+
.transpose(shuffle),
|
|
68
|
+
key_split: this.rope.apply(key.reshape(split_heads))
|
|
69
|
+
.transpose(shuffle),
|
|
70
|
+
value_split: value.reshape(split_heads).transpose(shuffle)
|
|
71
|
+
};
|
|
72
|
+
});
|
|
109
73
|
}
|
|
110
74
|
}
|
|
111
|
-
|
|
112
|
-
|
|
113
75
|
tf.serialization.registerClass(CachedRoPEMultiHeadAttention);
|
|
76
|
+
//# sourceMappingURL=cached_rope_multihead_attention.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"cached_rope_multihead_attention.js","sourceRoot":"","sources":["../../../src/layers/cached_rope_multihead_attention.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,kBAAkB,EAA+B,MAAM,8BAA8B,CAAC;AAC/F,OAAO,EAAE,uBAAuB,EAAE,MAAM,oCAAoC,CAAC;AAI7E;;;;;;;GAOG;AACH,MAAM,OAAO,4BAA6B,SAAQ,kBAAkB;IAChE,MAAM,CAAC,SAAS,GAAG,8BAA8B,CAAC;IAExC,IAAI,CAAkB;IAEhC,YAAY,IAA4B;QACpC,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,IAAI,CAAC,IAAI,GAAG,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC;IAChG,CAAC;IAGkB,OAAO,CACtB,WAAsB,EACtB,SAAoB,EACpB,WAAsB,EACtB,YAA8B,EAC9B,WAA6B,EAC7B,MAAc;QAEd,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,EAAE,KAAK,EAAE,GAAG,EAAE,KAAK,EAAE,GAAG,IAAI,CAAC,qBAAqB,CAAC,WAAW,EAAE,SAAS,EAAE,WAAW,CAAC,CAAC;YAE9F,oGAAoG;YACpG,MAAM,qBAAqB,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;YAE3C,MAAM,KAAK,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,EAAE,qBAAqB,CAAC,CAAC;YAExE,MAAM,WAAW,GAAG,KAAK,CAAC,WAAW,CAAC;YACtC,IAAI,SAAS,GAAG,KAAK,CAAC,SAAS,CAAC;YAChC,IAAI,WAAW,GAAG,KAAK,CAAC,WAAW,CAAC;YAEpC,IAAI,MAAM,CAAC,QAAQ,KAAK,IAAI,IAAI,MAAM,CAAC,OAAO,EAAE,CAAC;gBAC7C,+EAA+E;gBAC/E,MAAM,SAAS,GAAG,IAAI,CAAC,WAAW,CAC9B,MAAM,CAAC,OAA2B,EAAE,SAAS,EAAE,WAAW,CAAC,CAAC;gBAEhE,SAAS,GAAG,SAAS,CAAC,QAAQ,CAAC;gBAC/B,WAAW,GAAG,SAAS,CAAC,UAAU,CAAC;YACvC,CAAC;YAED,gFAAgF;YAChF,MAAM,IAAI,GAAG,kBAAkB,CAAC,4BAA4B,CACxD,WAAW,EAAE,SAAS,EAAE,WAAW,EACnC,MAAM,CAAC,aAAa,IAAI,IAAI,EAAE,YAAY,EAAE,WAAW,EACvD,IAAI,CAAC,OAAO,EAAE,IAAI,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAEvC,+CAA+C;YAC/C,MAAM,MAAM,GAAG,IAAI,CAAC,gBAAgB,CAAC,KAAK,CACtC,IAAI,CAAC,SAAS,CAAC,qBAAqB,CAAC,CAAC,OAAO,CACzC,CAAC,WAAW,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,WAAW,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC;YAEvE,OAAO,MAAmB,CAAC;QAC/B,CAAC,CAAC,CAAA;IACN,CAAC;IAGS,WAAW,CAAC,YAA8B,EAAE,SAAsB,EAAE,WAAwB;QAClG,IAAI,CAAC;YACD,IAAI,QAAQ,GAAG,YAAY,CAAC,MAAM,CAAC,IAAI,CAAC,IAAI,EAAE,SAAS,EAAE,WAAW,CAAC,CAAC;YAEtE,IAAI,C
AAC,QAAQ,EAAE,CAAC;gBACZ,YAAY,CAAC,MAAM,CAAC,IAAI,CAAC,IAAI,EAAE;oBAC3B,SAAS,EAAE,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC;oBAC7B,QAAQ,EAAE,IAAI,CAAC,QAAQ;oBACvB,OAAO,EAAE,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ;iBACzC,CAAC,CAAA;gBAEF,QAAQ,GAAG,YAAY,CAAC,MAAM,CAAC,IAAI,CAAC,IAAI,EAAE,SAAS,EAAE,WAAW,CAAE,CAAC;YACvE,CAAC;YAED,OAAO,QAAS,CAAC;QACrB,CAAC;QAAC,OAAO,KAAU,EAAE,CAAC;YAClB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,IAAI,KAAK,CAAC,QAAQ,EAAE,EAAE,CAAC,CAAC;QACxF,CAAC;IACL,CAAC;IAGD;;OAEG;IACgB,UAAU,CAAC,KAAgB,EAAE,GAAc,EAAE,KAAgB,EAAE,OAAiB;QAC/F,MAAM,UAAU,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAClC,MAAM,WAAW,GAAG,CAAC,UAAU,EAAE,CAAC,CAAC,EAAE,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC;QAEnF,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,OAAO;gBACH,WAAW,EAAG,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAe;qBAClE,SAAS,CAAC,OAAO,CAAgB;gBACtC,SAAS,EAAG,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,WAAW,CAAC,CAAe;qBAC9D,SAAS,CAAC,OAAO,CAAgB;gBACtC,WAAW,EAAE,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,SAAS,CAAC,OAAO,CAAgB;aAC5E,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,4BAA4B,CAAC,CAAC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"cached_rope_multihead_attention.test.d.ts","sourceRoot":"","sources":["../../../src/layers/cached_rope_multihead_attention.test.ts"],"names":[],"mappings":""}
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { KvCacheContainer } from '@/kv_cache';
|
|
3
|
+
import { CachedRoPEMultiHeadAttention } from '@/layers/cached_rope_multihead_attention';
|
|
4
|
+
// disables warning for using the faster node backend,
|
|
5
|
+
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
|
|
6
|
+
tf.env().set('IS_NODE', false);
|
|
7
|
+
describe("CachedRoPEMultiHeadAttention tests", () => {
|
|
8
|
+
test("aggregate forward passes output are identical normal multihead attention", () => {
|
|
9
|
+
compareNormalWithCachedAttention(tf.randomUniform([2, 10, 16]), 123);
|
|
10
|
+
compareNormalWithCachedAttention(tf.randomUniform([1, 10, 16]), 123);
|
|
11
|
+
compareNormalWithCachedAttention(tf.randomUniform([1, 1, 16]), 123);
|
|
12
|
+
compareNormalWithCachedAttention(tf.randomUniform([3, 2, 16]), 123);
|
|
13
|
+
// input exceeds KV cache size
|
|
14
|
+
expect(() => compareNormalWithCachedAttention(tf.randomUniform([1, 10, 16]), 5)).toThrow();
|
|
15
|
+
function compareNormalWithCachedAttention(input, max_sequence_length) {
|
|
16
|
+
const embed_dim = input.shape[2];
|
|
17
|
+
const batch = input.shape[0];
|
|
18
|
+
const heads = 2;
|
|
19
|
+
const kv_cache = new KvCacheContainer(max_sequence_length);
|
|
20
|
+
const normal_mha = new CachedRoPEMultiHeadAttention({ numHeads: heads, embedDim: embed_dim, causal: true });
|
|
21
|
+
const normal_mha_output = normal_mha.apply(input);
|
|
22
|
+
// initialize cached attention with identical configuration and weights
|
|
23
|
+
const cached_mha1 = new CachedRoPEMultiHeadAttention({ ...normal_mha.getConfig(), name: "cache_test1" });
|
|
24
|
+
cached_mha1.build(input.shape);
|
|
25
|
+
cached_mha1.setWeights(normal_mha.getWeights());
|
|
26
|
+
const cached_mha2 = new CachedRoPEMultiHeadAttention({ ...normal_mha.getConfig(), name: "cache_test2" });
|
|
27
|
+
cached_mha2.build(input.shape);
|
|
28
|
+
cached_mha2.setWeights(normal_mha.getWeights());
|
|
29
|
+
const cached_mha_outputs1 = [];
|
|
30
|
+
const cached_mha_outputs2 = [];
|
|
31
|
+
for (let i = 0; i < input.shape[1]; i++) {
|
|
32
|
+
const current_token = input.slice([0, i, 0], [batch, 1, embed_dim]);
|
|
33
|
+
cached_mha_outputs1.push(cached_mha1.apply(current_token, { kvCache: kv_cache }));
|
|
34
|
+
cached_mha_outputs2.push(cached_mha2.apply(current_token, { kvCache: kv_cache }));
|
|
35
|
+
}
|
|
36
|
+
expect(kv_cache.size == input.shape[1]);
|
|
37
|
+
expect(kv_cache.size == input.shape[1]);
|
|
38
|
+
expect(normal_mha_output.sub(tf.concat(cached_mha_outputs1, 1)).sum().dataSync()[0]).toBeLessThan(1e-6);
|
|
39
|
+
expect(normal_mha_output.sub(tf.concat(cached_mha_outputs2, 1)).sum().dataSync()[0]).toBeLessThan(1e-6);
|
|
40
|
+
}
|
|
41
|
+
});
|
|
42
|
+
});
|
|
43
|
+
//# sourceMappingURL=cached_rope_multihead_attention.test.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"cached_rope_multihead_attention.test.js","sourceRoot":"","sources":["../../../src/layers/cached_rope_multihead_attention.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,gBAAgB,EAAE,MAAM,YAAY,CAAC;AAC9C,OAAO,EAAE,4BAA4B,EAAE,MAAM,0CAA0C,CAAC;AAGxF,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,oCAAoC,EAAE,GAAG,EAAE;IAEhD,IAAI,CAAC,0EAA0E,EAAE,GAAG,EAAE;QAClF,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QACjF,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QACjF,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QAChF,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QAEhF,6BAA6B;QAC7B,MAAM,CAAC,GAAG,EAAE,CAAC,gCAAgC,CAAC,EAAE,CAAC,aAAa,CAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEvG,SAAS,gCAAgC,CAAC,KAAkB,EAAE,mBAA2B;YACrF,MAAM,SAAS,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YACjC,MAAM,KAAK,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAC7B,MAAM,KAAK,GAAG,CAAC,CAAC;YAEhB,MAAM,QAAQ,GAAG,IAAI,gBAAgB,CAAC,mBAAmB,CAAC,CAAC;YAE3D,MAAM,UAAU,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,KAAK,EAAE,QAAQ,EAAE,SAAS,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;YAC5G,MAAM,iBAAiB,GAAG,UAAU,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC;YAE/D,uEAAuE;YACvE,MAAM,WAAW,GAAG,IAAI,4BAA4B,CAAC,EAAE,GAAG,UAAU,CAAC,SAAS,EAAE,EAAE,IAAI,EAAE,aAAa,EAAE,CAAC,CAAC;YACzG,WAAW,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;YAC/B,WAAW,CAAC,UAAU,CAAC,UAAU,CAAC,UAAU,EAAE,CAAC,CAAC;YAEhD,MAAM,WAAW,GAAG,IAAI,4BAA4B,CAAC,EAAE,GAAG,UAAU,CAAC,SAAS,EAAE,EAAE,IAAI,EAAE,aAAa,EAAE,CAAC,CAAC;YACzG,WAAW,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;YAC/B,WAAW,CAAC,UAAU,CAAC,UAAU,CAAC,UAAU,EAAE,CAAC,CAAC;YAEhD,MAAM,mBAAmB,GAAgB,EAAE,CAAC;YAC5C,MAAM,mBAAmB,GAAgB,EAAE,CAAC;YAE5C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CA
AC,EAAE,CAAC,EAAE,EAAE,CAAC;gBACtC,MAAM,aAAa,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,KAAK,EAAE,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;gBAEpE,mBAAmB,CAAC,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,aAAa,EAAE,EAAE,OAAO,EAAE,QAAQ,EAAE,CAAc,CAAC,CAAC;gBAC/F,mBAAmB,CAAC,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,aAAa,EAAE,EAAE,OAAO,EAAE,QAAQ,EAAE,CAAc,CAAC,CAAC;YACnG,CAAC;YAED,MAAM,CAAC,QAAQ,CAAC,IAAI,IAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;YACvC,MAAM,CAAC,QAAQ,CAAC,IAAI,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;YAExC,MAAM,CAAC,iBAAiB,CAAC,GAAG,CAAC,EAAE,CAAC,MAAM,CAAC,mBAAmB,EAAE,CAAC,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;YACxG,MAAM,CAAC,iBAAiB,CAAC,GAAG,CAAC,EAAE,CAAC,MAAM,CAAC,mBAAmB,EAAE,CAAC,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;QAC5G,CAAC;IACL,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAC"}
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
|
|
3
|
+
import { type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
|
|
4
|
+
import { TransformerDecoder, type TransformerDecoderArgs } from "@/layers/transformer_decoder";
|
|
5
|
+
export interface GPTDecoderBlockArgs extends Omit<MultiHeadAttentionArgs, "causal"> {
|
|
6
|
+
dimsFeedForward?: number;
|
|
7
|
+
}
|
|
8
|
+
/**
|
|
9
|
+
* This implements the GPT-2 transformer block by modifying the transformer
|
|
10
|
+
* decoder block to use pre-layer-normalization and replacing ReLU activation
|
|
11
|
+
* with GELU.
|
|
12
|
+
*
|
|
13
|
+
* @param numHeads number of attention heads to use
|
|
14
|
+
* @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
|
|
15
|
+
* @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
|
|
16
|
+
* @param dropout use dropout during the attention calculations, default `0.1`
|
|
17
|
+
* @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
|
|
18
|
+
* @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
|
|
19
|
+
*/
|
|
20
|
+
export declare class GPT2DecoderBlock extends TransformerDecoder {
|
|
21
|
+
static className: string;
|
|
22
|
+
constructor(args: TransformerDecoderArgs);
|
|
23
|
+
/**
|
|
24
|
+
* Attention sub-block which is similar to the original transformer except
|
|
25
|
+
* layer normalization is applied at the beginning
|
|
26
|
+
*/
|
|
27
|
+
protected causalSelfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
|
|
28
|
+
/**
|
|
29
|
+
* Feedforward sub-block which is similar to the original transformer except
|
|
30
|
+
* layer normalization is applied at the beginning and gelu activation is used
|
|
31
|
+
*/
|
|
32
|
+
protected feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
|
|
33
|
+
}
|
|
34
|
+
//# sourceMappingURL=gpt_decoder_block.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"gpt_decoder_block.d.ts","sourceRoot":"","sources":["../../../src/layers/gpt_decoder_block.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAEjE,OAAO,EAAE,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAC3E,OAAO,EAAE,kBAAkB,EAAE,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAG/F,MAAM,WAAW,mBAAoB,SAAQ,IAAI,CAAC,sBAAsB,EAAE,QAAQ,CAAC;IAC/E,eAAe,CAAC,EAAE,MAAM,CAAC;CAC5B;AAGD;;;;;;;;;;;GAWG;AACH,qBAAa,gBAAiB,SAAQ,kBAAkB;IACpD,MAAM,CAAC,SAAS,SAAsB;gBAG1B,IAAI,EAAE,sBAAsB;IAKxC;;;OAGG;cACgB,wBAAwB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAcpF;;;OAGG;cACgB,gBAAgB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;CAkB/E"}
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { TransformerDecoder } from "./transformer_decoder"; // relative path: the "@/" tsconfig alias is not rewritten on emit and cannot be resolved from the published dist
|
|
3
|
+
/**
|
|
4
|
+
* This implements the GPT-2 transformer block by modifying the transformer
|
|
5
|
+
* decoder block to use pre-layer-normalization and replacing ReLU activation
|
|
6
|
+
* with GELU.
|
|
7
|
+
*
|
|
8
|
+
* @param numHeads number of attention heads to use
|
|
9
|
+
* @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
|
|
10
|
+
* @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
|
|
11
|
+
* @param dropout use dropout during the attention calculations, default `0.1`
|
|
12
|
+
* @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
|
|
13
|
+
* @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
|
|
14
|
+
*/
|
|
15
|
+
export class GPT2DecoderBlock extends TransformerDecoder {
    /** Name under which this class is registered for serialization. */
    static className = "GPT2DecoderBlock";

    /**
     * @param args constructor arguments, passed unchanged to `TransformerDecoder`
     */
    constructor(args) {
        super(args);
    }

    /**
     * Self-attention sub-block in pre-normalization order: layer
     * normalization is applied at the beginning, followed by attention,
     * dropout and finally the residual connection.
     *
     * @param x input tensor; also used as the residual that is added back
     * @param kwargs forwarded unchanged to every sublayer's `apply`
     * @returns the sum of the dropout output and `x`
     */
    causalSelfAttentionBlock(x, kwargs) {
        // tf.tidy releases the intermediate tensors created inside the callback.
        return tf.tidy(() => {
            const normalized = this.causalSelfAttentionNorm.apply(x, kwargs);
            const attended = this.causalSelfAttention.apply(normalized, kwargs);
            const regularized = this.causalSelfAttentionDropout.apply(attended, kwargs);
            return tf.add(regularized, x);
        });
    }

    /**
     * Feed-forward sub-block in pre-normalization order: layer normalization
     * is applied at the beginning, followed by the two feed-forward layers,
     * dropout and finally the residual connection.
     *
     * NOTE(review): no activation is applied explicitly here; presumably the
     * GELU activation is configured on the feed-forward sublayers where they
     * are created (outside this file) — confirm there.
     *
     * @param x input tensor; also used as the residual that is added back
     * @param kwargs forwarded unchanged to every sublayer's `apply`
     * @returns the sum of the dropout output and `x`
     */
    feedForwardBlock(x, kwargs) {
        return tf.tidy(() => {
            // `feedFowardNorm` (sic) is the property name this block reads; it is defined outside this file.
            const normalized = this.feedFowardNorm.apply(x, kwargs);
            const expanded = this.feedforward1.apply(normalized, kwargs);
            const projected = this.feedforward2.apply(expanded, kwargs);
            const regularized = this.feedForwardDropout.apply(projected, kwargs);
            return tf.add(regularized, x);
        });
    }
}
|
|
50
|
+
tf.serialization.registerClass(GPT2DecoderBlock);
|
|
51
|
+
//# sourceMappingURL=gpt_decoder_block.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"gpt_decoder_block.js","sourceRoot":"","sources":["../../../src/layers/gpt_decoder_block.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAIvC,OAAO,EAAE,kBAAkB,EAA+B,MAAM,8BAA8B,CAAC;AAQ/F;;;;;;;;;;;GAWG;AACH,MAAM,OAAO,gBAAiB,SAAQ,kBAAkB;IACpD,MAAM,CAAC,SAAS,GAAG,kBAAkB,CAAC;IAGtC,YAAY,IAA4B;QACpC,KAAK,CAAC,IAAI,CAAC,CAAC;IAChB,CAAC;IAGD;;;OAGG;IACgB,wBAAwB,CAAC,CAAY,EAAE,MAAc;QACpE,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,SAAS,GAAG,IAAI,CAAC,uBAAuB,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAc,CAAC;YAC3E,SAAS,GAAG,IAAI,CAAC,mBAAmB,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAC3E,SAAS,GAAG,IAAI,CAAC,0BAA0B,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAClF,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,QAAQ,CAAC,CAAC;YAExC,OAAO,SAAS,CAAC;QACrB,CAAC,CAAC,CAAC;IACP,CAAC;IAGD;;;OAGG;IACgB,gBAAgB,CAAC,CAAY,EAAE,MAAc;QAC5D,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,WAAW,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YACvD,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAC,CAAC;YAC3D,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAC,CAAC;YAC3D,WAAW,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAC9E,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,QAAQ,CAAC,CAAC;YAE5C,OAAO,WAAW,CAAC;QACvB,CAAC,CAAC,CAAC;IACP,CAAC;;AASL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,gBAAgB,CAAC,CAAC"}
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import { CachedRoPEMultiHeadAttention } from "./cached_rope_multihead_attention";
|
|
2
|
+
import { GPT2DecoderBlock, GPTDecoderBlockArgs } from "./gpt_decoder_block";
|
|
3
|
+
import { MultiHeadAttention, MultiHeadAttentionArgs } from "./multihead_attention";
|
|
4
|
+
import { PositionalEncoding, PositionalEncodingArgs } from "./positional_encoding";
|
|
5
|
+
import { RotaryPositionEmbedding, RotaryPositionEmbeddingArgs } from "./rotary_position_embedding";
|
|
6
|
+
import { TokenAndPositionalEmbedding, TokenAndPositionalEmbeddingArgs } from "./token_and_positional_embedding";
|
|
7
|
+
import { TransformerDecoder, TransformerDecoderArgs } from "./transformer_decoder";
|
|
8
|
+
import { TransformerEncoder, TransformerEncoderArgs } from "./transformer_encoder";
|
|
9
|
+
export declare function tokenAndPositionalEmbedding(args: TokenAndPositionalEmbeddingArgs): TokenAndPositionalEmbedding;
|
|
10
|
+
export declare function transformerEncoder(args: TransformerEncoderArgs): TransformerEncoder;
|
|
11
|
+
export declare function transformerDecoder(args: TransformerDecoderArgs): TransformerDecoder;
|
|
12
|
+
export declare function multiheadAttention(args: MultiHeadAttentionArgs): MultiHeadAttention;
|
|
13
|
+
export declare function cachedRopeMultiheadAttention(args: MultiHeadAttentionArgs): CachedRoPEMultiHeadAttention;
|
|
14
|
+
export declare function positionalEncoding(args: PositionalEncodingArgs): PositionalEncoding;
|
|
15
|
+
export declare function gpt2DecoderBlock(args: GPTDecoderBlockArgs): GPT2DecoderBlock;
|
|
16
|
+
export declare function rotaryPositionEmbedding(args: RotaryPositionEmbeddingArgs): RotaryPositionEmbedding;
|
|
17
|
+
//# sourceMappingURL=index.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../../src/layers/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,4BAA4B,EAAE,MAAM,mCAAmC,CAAC;AACjF,OAAO,EAAE,gBAAgB,EAAE,mBAAmB,EAAE,MAAM,qBAAqB,CAAC;AAC5E,OAAO,EAAE,kBAAkB,EAAE,sBAAsB,EAAE,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,kBAAkB,EAAE,sBAAsB,EAAE,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,uBAAuB,EAAE,2BAA2B,EAAE,MAAM,6BAA6B,CAAC;AACnG,OAAO,EAAE,2BAA2B,EAAE,+BAA+B,EAAE,MAAM,kCAAkC,CAAC;AAChH,OAAO,EAAE,kBAAkB,EAAE,sBAAsB,EAAE,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,kBAAkB,EAAE,sBAAsB,EAAE,MAAM,uBAAuB,CAAC;AAGnF,wBAAgB,2BAA2B,CAAC,IAAI,EAAE,+BAA+B,+BAEhF;AAGD,wBAAgB,kBAAkB,CAAC,IAAI,EAAE,sBAAsB,sBAE9D;AAGD,wBAAgB,kBAAkB,CAAC,IAAI,EAAE,sBAAsB,sBAE9D;AAGD,wBAAgB,kBAAkB,CAAC,IAAI,EAAE,sBAAsB,sBAE9D;AAGD,wBAAgB,4BAA4B,CAAC,IAAI,EAAE,sBAAsB,gCAExE;AAGD,wBAAgB,kBAAkB,CAAC,IAAI,EAAE,sBAAsB,sBAE9D;AAGD,wBAAgB,gBAAgB,CAAC,IAAI,EAAE,mBAAmB,oBAEzD;AAGD,wBAAgB,uBAAuB,CAAC,IAAI,EAAE,2BAA2B,2BAExE"}
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import { CachedRoPEMultiHeadAttention } from "./cached_rope_multihead_attention";
|
|
2
|
+
import { GPT2DecoderBlock } from "./gpt_decoder_block";
|
|
3
|
+
import { MultiHeadAttention } from "./multihead_attention";
|
|
4
|
+
import { PositionalEncoding } from "./positional_encoding";
|
|
5
|
+
import { RotaryPositionEmbedding } from "./rotary_position_embedding";
|
|
6
|
+
import { TokenAndPositionalEmbedding } from "./token_and_positional_embedding";
|
|
7
|
+
import { TransformerDecoder } from "./transformer_decoder";
|
|
8
|
+
import { TransformerEncoder } from "./transformer_encoder";
|
|
9
|
+
/**
 * Factory for `TokenAndPositionalEmbedding`.
 *
 * @param args constructor arguments, forwarded unchanged
 * @returns a newly constructed `TokenAndPositionalEmbedding`
 */
export function tokenAndPositionalEmbedding(args) {
    const instance = new TokenAndPositionalEmbedding(args);
    return instance;
}
|
|
12
|
+
/**
 * Factory for `TransformerEncoder`.
 *
 * @param args constructor arguments, forwarded unchanged
 * @returns a newly constructed `TransformerEncoder`
 */
export function transformerEncoder(args) {
    const instance = new TransformerEncoder(args);
    return instance;
}
|
|
15
|
+
/**
 * Factory for `TransformerDecoder`.
 *
 * @param args constructor arguments, forwarded unchanged
 * @returns a newly constructed `TransformerDecoder`
 */
export function transformerDecoder(args) {
    const instance = new TransformerDecoder(args);
    return instance;
}
|
|
18
|
+
/**
 * Factory for `MultiHeadAttention`.
 *
 * @param args constructor arguments, forwarded unchanged
 * @returns a newly constructed `MultiHeadAttention`
 */
export function multiheadAttention(args) {
    const instance = new MultiHeadAttention(args);
    return instance;
}
|
|
21
|
+
/**
 * Factory for `CachedRoPEMultiHeadAttention`.
 *
 * @param args constructor arguments, forwarded unchanged
 * @returns a newly constructed `CachedRoPEMultiHeadAttention`
 */
export function cachedRopeMultiheadAttention(args) {
    const instance = new CachedRoPEMultiHeadAttention(args);
    return instance;
}
|
|
24
|
+
/**
 * Factory for `PositionalEncoding`.
 *
 * @param args constructor arguments, forwarded unchanged
 * @returns a newly constructed `PositionalEncoding`
 */
export function positionalEncoding(args) {
    const instance = new PositionalEncoding(args);
    return instance;
}
|
|
27
|
+
/**
 * Factory for `GPT2DecoderBlock`.
 *
 * @param args constructor arguments, forwarded unchanged
 * @returns a newly constructed `GPT2DecoderBlock`
 */
export function gpt2DecoderBlock(args) {
    const instance = new GPT2DecoderBlock(args);
    return instance;
}
|
|
30
|
+
/**
 * Factory for `RotaryPositionEmbedding`.
 *
 * @param args constructor arguments, forwarded unchanged
 * @returns a newly constructed `RotaryPositionEmbedding`
 */
export function rotaryPositionEmbedding(args) {
    const instance = new RotaryPositionEmbedding(args);
    return instance;
}
|
|
33
|
+
//# sourceMappingURL=index.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"index.js","sourceRoot":"","sources":["../../../src/layers/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,4BAA4B,EAAE,MAAM,mCAAmC,CAAC;AACjF,OAAO,EAAE,gBAAgB,EAAuB,MAAM,qBAAqB,CAAC;AAC5E,OAAO,EAAE,kBAAkB,EAA0B,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,kBAAkB,EAA0B,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,uBAAuB,EAA+B,MAAM,6BAA6B,CAAC;AACnG,OAAO,EAAE,2BAA2B,EAAmC,MAAM,kCAAkC,CAAC;AAChH,OAAO,EAAE,kBAAkB,EAA0B,MAAM,uBAAuB,CAAC;AACnF,OAAO,EAAE,kBAAkB,EAA0B,MAAM,uBAAuB,CAAC;AAGnF,MAAM,UAAU,2BAA2B,CAAC,IAAqC;IAC7E,OAAO,IAAI,2BAA2B,CAAC,IAAI,CAAC,CAAC;AACjD,CAAC;AAGD,MAAM,UAAU,kBAAkB,CAAC,IAA4B;IAC3D,OAAO,IAAI,kBAAkB,CAAC,IAAI,CAAC,CAAC;AACxC,CAAC;AAGD,MAAM,UAAU,kBAAkB,CAAC,IAA4B;IAC3D,OAAO,IAAI,kBAAkB,CAAC,IAAI,CAAC,CAAC;AACxC,CAAC;AAGD,MAAM,UAAU,kBAAkB,CAAC,IAA4B;IAC3D,OAAO,IAAI,kBAAkB,CAAC,IAAI,CAAC,CAAC;AACxC,CAAC;AAGD,MAAM,UAAU,4BAA4B,CAAC,IAA4B;IACrE,OAAO,IAAI,4BAA4B,CAAC,IAAI,CAAC,CAAC;AAClD,CAAC;AAGD,MAAM,UAAU,kBAAkB,CAAC,IAA4B;IAC3D,OAAO,IAAI,kBAAkB,CAAC,IAAI,CAAC,CAAC;AACxC,CAAC;AAGD,MAAM,UAAU,gBAAgB,CAAC,IAAyB;IACtD,OAAO,IAAI,gBAAgB,CAAC,IAAI,CAAC,CAAC;AACtC,CAAC;AAGD,MAAM,UAAU,uBAAuB,CAAC,IAAiC;IACrE,OAAO,IAAI,uBAAuB,CAAC,IAAI,CAAC,CAAC;AAC7C,CAAC"}
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
|
|
3
|
+
import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
|
|
4
|
+
export interface MultiHeadAttentionArgs extends LayerArgs {
|
|
5
|
+
numHeads: number;
|
|
6
|
+
embedDim: number;
|
|
7
|
+
useBias?: boolean;
|
|
8
|
+
dropout?: number;
|
|
9
|
+
causal?: boolean;
|
|
10
|
+
}
|
|
11
|
+
export interface ScaledDotProductionAttentionKwargs {
|
|
12
|
+
training?: boolean;
|
|
13
|
+
dropout?: number;
|
|
14
|
+
causal?: boolean;
|
|
15
|
+
scaling_factor?: number;
|
|
16
|
+
}
|
|
17
|
+
/**
|
|
18
|
+
* This MultiHead Attention layer implements the algorithm as described in
|
|
19
|
+
* the paper "Attention is all you Need" Vaswani et al., 2017.
|
|
20
|
+
*
|
|
21
|
+
* @param numHeads number of attention heads to use
|
|
22
|
+
* @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
|
|
23
|
+
* @param causal use causal masking, default `false`
|
|
24
|
+
* @param dropout use dropout during the attention calculations, default `0.0`
|
|
25
|
+
* @param useBias use bias for the dense sublayers, default `true`
|
|
26
|
+
*
|
|
27
|
+
* The TensorFlow version uses tf.einsum, whose gradient op has not yet been
|
|
28
|
+
* implemented (https://github.com/tensorflow/tfjs/pull/4955#discussion_r619219334),
|
|
29
|
+
* therefore we follow the PyTorch implementation described in:
|
|
30
|
+
* https://docs.pytorch.org/tutorials/intermediate/transformer_building_blocks.html#multiheadattention
|
|
31
|
+
* https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
|
32
|
+
*
|
|
33
|
+
* This implementation is different from TensorFlow's whose attention weights
|
|
34
|
+
* are shaped [embed dim, heads, embed dim], whereas PyTorch and OpenAI's attention weights
|
|
35
|
+
* are shaped [embed dim, embed dim]
|
|
36
|
+
* https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/nn/modules/activation.py#L1080
|
|
37
|
+
* https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/model.py#L53
|
|
38
|
+
*
|
|
39
|
+
* TODO: implement a fast track for self attention (query = key = value)
|
|
40
|
+
* where a single dense layer combines and replaces the query, key and value projection layers
|
|
41
|
+
*
|
|
42
|
+
* TODO: add kDim and vDim to accept key and values whose embedding dimensions differ from query's.
|
|
43
|
+
*/
|
|
44
|
+
export declare class MultiHeadAttention extends tf.layers.Layer {
|
|
45
|
+
static className: string;
|
|
46
|
+
protected readonly numHeads: number;
|
|
47
|
+
protected readonly embedDim: number;
|
|
48
|
+
protected readonly useBias: boolean;
|
|
49
|
+
protected readonly dropout: number;
|
|
50
|
+
protected readonly causal: boolean;
|
|
51
|
+
protected readonly queryProjection: tf.layers.Layer;
|
|
52
|
+
protected readonly keyProjection: tf.layers.Layer;
|
|
53
|
+
protected readonly valueProjection: tf.layers.Layer;
|
|
54
|
+
protected readonly outputProjection: tf.layers.Layer;
|
|
55
|
+
constructor({ numHeads, embedDim, useBias, dropout, causal, ...args }: MultiHeadAttentionArgs);
|
|
56
|
+
/**
|
|
57
|
+
* Forward propagation. Provide one input tensor or three identical tensors for self-attention.
|
|
58
|
+
* @param inputs a single tensor for self-attention or an array of exactly three
|
|
59
|
+
* tensors that are either identical (self-attention) or different (cross-attention)
|
|
60
|
+
* @param kwargs.packingMask a mask to prevent tokens from attending across document boundaries
|
|
61
|
+
*/
|
|
62
|
+
call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs & {
|
|
63
|
+
packingMask?: tf.Tensor;
|
|
64
|
+
causalMask?: tf.Tensor;
|
|
65
|
+
}): tf.Tensor | tf.Tensor[];
|
|
66
|
+
/**
|
|
67
|
+
* Forward propagation
|
|
68
|
+
*/
|
|
69
|
+
protected forward(query_input: tf.Tensor, key_input: tf.Tensor, value_input: tf.Tensor, packing_mask: tf.Tensor | null, causal_mask: tf.Tensor | null, kwargs: Kwargs): tf.Tensor;
|
|
70
|
+
protected applyInputProjections(query_input: tf.Tensor, key_input: tf.Tensor, value_input: tf.Tensor): {
|
|
71
|
+
query: tf.Tensor;
|
|
72
|
+
key: tf.Tensor;
|
|
73
|
+
value: tf.Tensor;
|
|
74
|
+
};
|
|
75
|
+
protected splitHeads(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, shuffle: number[]): {
|
|
76
|
+
query_split: tf.Tensor4D;
|
|
77
|
+
key_split: tf.Tensor4D;
|
|
78
|
+
value_split: tf.Tensor4D;
|
|
79
|
+
};
|
|
80
|
+
/**
|
|
81
|
+
* Applies the scaled dot-product formula: softmax(QK_t / sqrt(d_k))V,
|
|
82
|
+
* formula (1) of the 2017 paper Attention Is All You Need
|
|
83
|
+
*
|
|
84
|
+
* @param attentionMask a mask to prevent tokens from being
|
|
85
|
+
* attended to (usually for padding tokens). It should have the shape
|
|
86
|
+
* [batch, head, query_sequence_len, key_sequence_len]. To use in
|
|
87
|
+
* conjunction with causal masking, the tensor should be a boolean type
|
|
88
|
+
* where false indicates a masked token.
|
|
89
|
+
* @param packingMask a mask to prevent tokens from attending across document boundaries
|
|
90
|
+
*/
|
|
91
|
+
static scaledDotProductionAttention(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, attentionMask: tf.Tensor | null, packingMask: tf.Tensor | null, causalMask: tf.Tensor | null, dropout: number, causal: boolean, kwargs?: ScaledDotProductionAttentionKwargs): tf.Tensor;
|
|
92
|
+
build(inputShape: tf.Shape | tf.Shape[]): void;
|
|
93
|
+
/**
|
|
94
|
+
* MultiHead attention's output has the same shape as the query's.
|
|
95
|
+
*/
|
|
96
|
+
computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[];
|
|
97
|
+
getConfig(): {
|
|
98
|
+
numHeads: number;
|
|
99
|
+
embedDim: number;
|
|
100
|
+
useBias: boolean;
|
|
101
|
+
causal: boolean;
|
|
102
|
+
dropout: number;
|
|
103
|
+
name: string;
|
|
104
|
+
};
|
|
105
|
+
}
|
|
106
|
+
//# sourceMappingURL=multihead_attention.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"multihead_attention.d.ts","sourceRoot":"","sources":["../../../src/layers/multihead_attention.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,SAAS,EAAE,MAAM,8CAA8C,CAAC;AAC9E,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAIjE,MAAM,WAAW,sBAAuB,SAAQ,SAAS;IACrD,QAAQ,EAAE,MAAM,CAAC;IACjB,QAAQ,EAAE,MAAM,CAAC;IACjB,OAAO,CAAC,EAAE,OAAO,CAAC;IAClB,OAAO,CAAC,EAAE,MAAM,CAAC;IACjB,MAAM,CAAC,EAAE,OAAO,CAAC;CACpB;AAGD,MAAM,WAAW,kCAAkC;IAC/C,QAAQ,CAAC,EAAE,OAAO,CAAC;IACnB,OAAO,CAAC,EAAE,MAAM,CAAC;IACjB,MAAM,CAAC,EAAE,OAAO,CAAC;IACjB,cAAc,CAAC,EAAE,MAAM,CAAC;CAC3B;AAGD;;;;;;;;;;;;;;;;;;;;;;;;;;GA0BG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IACxC,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,OAAO,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,MAAM,CAAC;IACnC,SAAS,CAAC,QAAQ,CAAC,MAAM,EAAE,OAAO,CAAC;IAInC,SAAS,CAAC,QAAQ,CAAC,eAAe,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACpD,SAAS,CAAC,QAAQ,CAAC,aAAa,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAClD,SAAS,CAAC,QAAQ,CAAC,eAAe,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACpD,SAAS,CAAC,QAAQ,CAAC,gBAAgB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;gBAGzC,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAc,EAAE,OAAa,EAAE,MAAc,EAAE,GAAG,IAAI,EAAE,EAAE,sBAAsB;IA0BlH;;;;;OAKG;IACM,IAAI,CACT,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAC/B,MAAM,EAAE,MAAM,GAAG;QACb,WAAW,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC;QACxB,UAAU,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC;KAC1B,GACF,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IA6B1B;;OAEG;IACH,SAAS,CAAC,OAAO,CACb,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,SAAS,EAAE,EAAE,CAAC,MAAM,EACpB,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,YAAY,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC9B,WAAW,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC7B,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IA+B9B,SAAS,CAAC,qBAAqB,CAAC,WAAW,EAAE,EAAE,CAAC,MAAM,EAAE,SAAS,EAAE,EAAE,CAAC,MAAM,EAAE,WAAW,EAAE,EAAE,CAAC,MAAM;eAMtC,EAAE,CAAC,MAAM;aACf,EAAE,CAAC,MAAM;eACH,EAAE,CAAC,MAAM;;IAMvE,SAAS,CAAC,UAAU,CAAC,KAA
K,EAAE,EAAE,CAAC,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC,MAAM,EAAE,OAAO,EAAE,MAAM,EAAE;qBAShB,EAAE,CAAC,QAAQ;mBACf,EAAE,CAAC,QAAQ;qBACP,EAAE,CAAC,QAAQ;;IAMrF;;;;;;;;;;OAUG;IACH,MAAM,CAAC,4BAA4B,CAC/B,KAAK,EAAE,EAAE,CAAC,MAAM,EAChB,GAAG,EAAE,EAAE,CAAC,MAAM,EACd,KAAK,EAAE,EAAE,CAAC,MAAM,EAChB,aAAa,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC/B,WAAW,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC7B,UAAU,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC5B,OAAO,EAAE,MAAM,EACf,MAAM,EAAE,OAAO,EACf,MAAM,GAAE,kCAAuC,GAChD,EAAE,CAAC,MAAM;IA0EH,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IA4CvD;;OAEG;IACM,kBAAkB,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;IAK5E,SAAS;;;;;;;;CAgBrB"}
|