@stellarapp/tfjs-stellar 1.0.0 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +47 -0
- package/dist/index.d.ts +7 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +7 -0
- package/dist/index.js.map +1 -0
- package/dist/jest.config.d.ts +8 -0
- package/dist/jest.config.d.ts.map +1 -0
- package/{jest.config.ts → dist/jest.config.js} +8 -64
- package/dist/jest.config.js.map +1 -0
- package/dist/kv_cache.d.ts +53 -0
- package/dist/kv_cache.d.ts.map +1 -0
- package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
- package/dist/kv_cache.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.js +76 -0
- package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
- package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
- package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
- package/dist/layers/gpt_decoder_block.js.map +1 -0
- package/dist/layers/index.d.ts +17 -0
- package/dist/layers/index.d.ts.map +1 -0
- package/dist/layers/index.js +33 -0
- package/dist/layers/index.js.map +1 -0
- package/dist/layers/multihead_attention.d.ts +106 -0
- package/dist/layers/multihead_attention.d.ts.map +1 -0
- package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
- package/dist/layers/multihead_attention.js.map +1 -0
- package/dist/layers/multihead_attention.test.d.ts +2 -0
- package/dist/layers/multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
- package/dist/layers/multihead_attention.test.js.map +1 -0
- package/dist/layers/positional_encoding.d.ts +37 -0
- package/dist/layers/positional_encoding.d.ts.map +1 -0
- package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
- package/dist/layers/positional_encoding.js.map +1 -0
- package/dist/layers/positional_encoding.test.d.ts +2 -0
- package/dist/layers/positional_encoding.test.d.ts.map +1 -0
- package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
- package/dist/layers/positional_encoding.test.js.map +1 -0
- package/dist/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
- package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
- package/dist/layers/rotary_position_embedding.js.map +1 -0
- package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/layers/rotary_position_embedding.test.js +88 -0
- package/dist/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
- package/dist/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
- package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/layers/transformer_decoder.d.ts +69 -0
- package/dist/layers/transformer_decoder.d.ts.map +1 -0
- package/dist/layers/transformer_decoder.js +182 -0
- package/dist/layers/transformer_decoder.js.map +1 -0
- package/dist/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
- package/dist/layers/transformer_decoder.test.js.map +1 -0
- package/dist/layers/transformer_encoder.d.ts +55 -0
- package/dist/layers/transformer_encoder.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
- package/dist/layers/transformer_encoder.js.map +1 -0
- package/dist/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
- package/dist/layers/transformer_encoder.test.js.map +1 -0
- package/dist/losses/dice.d.ts +30 -0
- package/dist/losses/dice.d.ts.map +1 -0
- package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
- package/dist/losses/dice.js.map +1 -0
- package/dist/losses/index.d.ts +2 -0
- package/dist/losses/index.d.ts.map +1 -0
- package/dist/losses/index.js +2 -0
- package/dist/losses/index.js.map +1 -0
- package/dist/masks.d.ts +20 -0
- package/dist/masks.d.ts.map +1 -0
- package/{src/packing_mask.ts → dist/masks.js} +16 -7
- package/dist/masks.js.map +1 -0
- package/dist/metrics.d.ts +20 -0
- package/dist/metrics.d.ts.map +1 -0
- package/{src/metrics.ts → dist/metrics.js} +8 -12
- package/dist/metrics.js.map +1 -0
- package/dist/models/gpt_model.d.ts +94 -0
- package/dist/models/gpt_model.d.ts.map +1 -0
- package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
- package/dist/models/gpt_model.js.map +1 -0
- package/dist/models/index.d.ts +7 -0
- package/dist/models/index.d.ts.map +1 -0
- package/dist/models/index.js +13 -0
- package/dist/models/index.js.map +1 -0
- package/dist/models/llm_model.d.ts +87 -0
- package/dist/models/llm_model.d.ts.map +1 -0
- package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
- package/dist/models/llm_model.js.map +1 -0
- package/dist/models/u_net.d.ts +40 -0
- package/dist/models/u_net.d.ts.map +1 -0
- package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
- package/dist/models/u_net.js.map +1 -0
- package/dist/src/index.d.ts +6 -0
- package/dist/src/index.d.ts.map +1 -0
- package/dist/src/index.js +6 -0
- package/dist/src/index.js.map +1 -0
- package/dist/src/kv_cache.d.ts +53 -0
- package/dist/src/kv_cache.d.ts.map +1 -0
- package/dist/src/kv_cache.js +135 -0
- package/dist/src/kv_cache.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
- package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
- package/dist/src/layers/gpt_decoder_block.js +51 -0
- package/dist/src/layers/gpt_decoder_block.js.map +1 -0
- package/dist/src/layers/index.d.ts +17 -0
- package/dist/src/layers/index.d.ts.map +1 -0
- package/dist/src/layers/index.js +33 -0
- package/dist/src/layers/index.js.map +1 -0
- package/dist/src/layers/multihead_attention.d.ts +106 -0
- package/dist/src/layers/multihead_attention.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.js +269 -0
- package/dist/src/layers/multihead_attention.js.map +1 -0
- package/dist/src/layers/multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.test.js +160 -0
- package/dist/src/layers/multihead_attention.test.js.map +1 -0
- package/dist/src/layers/positional_encoding.d.ts +37 -0
- package/dist/src/layers/positional_encoding.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.js +115 -0
- package/dist/src/layers/positional_encoding.js.map +1 -0
- package/dist/src/layers/positional_encoding.test.d.ts +2 -0
- package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.test.js +95 -0
- package/dist/src/layers/positional_encoding.test.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.js +99 -0
- package/dist/src/layers/rotary_position_embedding.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.js +88 -0
- package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.js +109 -0
- package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
- package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/src/layers/transformer_decoder.d.ts +69 -0
- package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
- package/dist/src/layers/transformer_decoder.js.map +1 -0
- package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_decoder.test.js +72 -0
- package/dist/src/layers/transformer_decoder.test.js.map +1 -0
- package/dist/src/layers/transformer_encoder.d.ts +55 -0
- package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.js +175 -0
- package/dist/src/layers/transformer_encoder.js.map +1 -0
- package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.test.js +58 -0
- package/dist/src/layers/transformer_encoder.test.js.map +1 -0
- package/dist/src/losses/dice.d.ts +30 -0
- package/dist/src/losses/dice.d.ts.map +1 -0
- package/dist/src/losses/dice.js +93 -0
- package/dist/src/losses/dice.js.map +1 -0
- package/dist/src/losses/index.d.ts +2 -0
- package/dist/src/losses/index.d.ts.map +1 -0
- package/dist/src/losses/index.js +2 -0
- package/dist/src/losses/index.js.map +1 -0
- package/dist/src/masks.d.ts +20 -0
- package/dist/src/masks.d.ts.map +1 -0
- package/dist/src/masks.js +37 -0
- package/dist/src/masks.js.map +1 -0
- package/dist/src/metrics.d.ts +20 -0
- package/dist/src/metrics.d.ts.map +1 -0
- package/dist/src/metrics.js +28 -0
- package/dist/src/metrics.js.map +1 -0
- package/dist/src/models/gpt_model.d.ts +94 -0
- package/dist/src/models/gpt_model.d.ts.map +1 -0
- package/dist/src/models/gpt_model.js +154 -0
- package/dist/src/models/gpt_model.js.map +1 -0
- package/dist/src/models/index.d.ts +3 -0
- package/dist/src/models/index.d.ts.map +1 -0
- package/{src/models/index.ts → dist/src/models/index.js} +1 -0
- package/dist/src/models/index.js.map +1 -0
- package/dist/src/models/llm_model.d.ts +87 -0
- package/dist/src/models/llm_model.d.ts.map +1 -0
- package/dist/src/models/llm_model.js +245 -0
- package/dist/src/models/llm_model.js.map +1 -0
- package/dist/src/models/u_net.d.ts +40 -0
- package/dist/src/models/u_net.d.ts.map +1 -0
- package/dist/src/models/u_net.js +151 -0
- package/dist/src/models/u_net.js.map +1 -0
- package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
- package/dist/src/tfjs_types.d.ts.map +1 -0
- package/dist/src/tfjs_types.js +2 -0
- package/dist/src/tfjs_types.js.map +1 -0
- package/dist/src/utils.d.ts +28 -0
- package/dist/src/utils.d.ts.map +1 -0
- package/{src/utils.ts → dist/src/utils.js} +10 -33
- package/dist/src/utils.js.map +1 -0
- package/dist/src/utils.test.d.ts +2 -0
- package/dist/src/utils.test.d.ts.map +1 -0
- package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
- package/dist/src/utils.test.js.map +1 -0
- package/dist/tfjs_types.d.ts +10 -0
- package/dist/tfjs_types.d.ts.map +1 -0
- package/dist/tfjs_types.js +2 -0
- package/dist/tfjs_types.js.map +1 -0
- package/dist/utils.d.ts +28 -0
- package/dist/utils.d.ts.map +1 -0
- package/dist/utils.js +63 -0
- package/dist/utils.js.map +1 -0
- package/dist/utils.test.d.ts +2 -0
- package/dist/utils.test.d.ts.map +1 -0
- package/dist/utils.test.js +73 -0
- package/dist/utils.test.js.map +1 -0
- package/package.json +10 -4
- package/src/index.ts +0 -93
- package/src/layers/rotary_position_embedding.test.ts +0 -107
- package/src/losses/index.ts +0 -1
- package/src/testing.ts +0 -1
- package/tsconfig.json +0 -49
|
@@ -1,113 +1,95 @@
|
|
|
1
1
|
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
|
|
3
2
|
import { PositionalEncoding } from '@/layers/positional_encoding';
|
|
4
|
-
|
|
5
3
|
// disables warning for using the faster node backend,
|
|
6
4
|
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
|
|
7
5
|
tf.env().set('IS_NODE', false);
|
|
8
|
-
|
|
9
|
-
|
|
10
6
|
describe("PositionalEncoding tests", () => {
|
|
11
7
|
it("should fail to instantiate a layer", () => {
|
|
12
8
|
expect(() => new PositionalEncoding({ maxSequenceLength: 3, embedDim: 0 })).toThrow();
|
|
13
9
|
expect(() => new PositionalEncoding({ maxSequenceLength: 3, embedDim: -1 })).toThrow();
|
|
14
10
|
expect(() => new PositionalEncoding({ maxSequenceLength: 0, embedDim: 32 })).toThrow();
|
|
15
11
|
expect(() => new PositionalEncoding({ maxSequenceLength: -1, embedDim: 32 })).toThrow();
|
|
16
|
-
})
|
|
17
|
-
|
|
18
|
-
|
|
12
|
+
});
|
|
19
13
|
test("successfull forward calls", () => {
|
|
20
14
|
const embed_dims = 32;
|
|
21
15
|
const sequences = 4;
|
|
22
16
|
const input = tf.randomUniform([2, sequences, embed_dims]);
|
|
23
|
-
|
|
24
17
|
const positional = new PositionalEncoding({ embedDim: embed_dims });
|
|
25
18
|
expect(() => positional.apply(input)).not.toThrow();
|
|
26
19
|
expect(() => positional.apply([input])).not.toThrow();
|
|
27
20
|
expect(positional.computeOutputShape(input.shape)).toEqual(input.shape);
|
|
28
|
-
})
|
|
29
|
-
|
|
30
|
-
|
|
21
|
+
});
|
|
31
22
|
it("should throw when input sequences are too large, embedding dims don't match, input aren't rank 3", () => {
|
|
32
23
|
const sequences_too_long = tf.randomUniform([100, 32]);
|
|
33
24
|
const embeddings_too_large = tf.randomUniform([32, 100]);
|
|
34
25
|
const wrong_rank = tf.randomUniform([10, 32, 32]);
|
|
35
|
-
|
|
36
26
|
const positional = new PositionalEncoding({ maxSequenceLength: 10, embedDim: 32 });
|
|
37
|
-
|
|
38
27
|
expect(() => positional.apply(sequences_too_long)).toThrow();
|
|
39
28
|
expect(() => positional.apply(embeddings_too_large)).toThrow();
|
|
40
29
|
expect(() => positional.apply(wrong_rank)).toThrow();
|
|
41
|
-
})
|
|
42
|
-
|
|
43
|
-
|
|
30
|
+
});
|
|
44
31
|
it("should return a non-empty config dict", () => {
|
|
45
32
|
const attention = new PositionalEncoding({ embedDim: 32 });
|
|
46
33
|
expect(Object.keys(attention.getConfig())).not.toBe(0);
|
|
47
|
-
})
|
|
48
|
-
|
|
49
|
-
|
|
34
|
+
});
|
|
50
35
|
// PyTorch implementation at found at
|
|
51
36
|
// https://pytorch-tutorials-preview.netlify.app/beginner/transformer_tutorial.html
|
|
52
37
|
it("should be within 1e-6 of PyTorch's implementation", () => {
|
|
53
38
|
const pytorch_embed4 = tf.tensor([
|
|
54
39
|
[[0.0000000, 1.0000000, 0.0000000, 1.0000000],
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
40
|
+
[0.8414710, 0.5403023, 0.0099998, 0.9999500],
|
|
41
|
+
[0.9092974, -0.4161468, 0.0199987, 0.9998000],
|
|
42
|
+
[0.1411200, -0.9899925, 0.0299955, 0.9995500],
|
|
43
|
+
[-0.7568025, -0.6536436, 0.0399893, 0.9992001],
|
|
44
|
+
[-0.9589243, 0.2836622, 0.0499792, 0.9987503],
|
|
45
|
+
[-0.2794155, 0.9601703, 0.0599640, 0.9982005],
|
|
46
|
+
[0.6569866, 0.7539023, 0.0699428, 0.9975510],
|
|
47
|
+
[0.9893582, -0.1455000, 0.0799147, 0.9968017],
|
|
48
|
+
[0.4121185, -0.9111302, 0.0898785, 0.9959527]]
|
|
49
|
+
]);
|
|
65
50
|
const pytorch_embed8 = tf.tensor([
|
|
66
51
|
[[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.0000000e+00,
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
52
|
+
0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.0000000e+00],
|
|
53
|
+
[8.4147096e-01, 5.4030234e-01, 9.9833414e-02, 9.9500418e-01,
|
|
54
|
+
9.9998331e-03, 9.9994999e-01, 9.9999981e-04, 9.9999952e-01],
|
|
55
|
+
[9.0929741e-01, -4.1614684e-01, 1.9866931e-01, 9.8006660e-01,
|
|
56
|
+
1.9998666e-02, 9.9980003e-01, 1.9999985e-03, 9.9999803e-01],
|
|
57
|
+
[1.4112000e-01, -9.8999250e-01, 2.9552019e-01, 9.5533651e-01,
|
|
58
|
+
2.9995499e-02, 9.9955004e-01, 2.9999954e-03, 9.9999553e-01],
|
|
59
|
+
[-7.5680250e-01, -6.5364361e-01, 3.8941833e-01, 9.2106098e-01,
|
|
60
|
+
3.9989334e-02, 9.9920011e-01, 3.9999890e-03, 9.9999201e-01],
|
|
61
|
+
[-9.5892429e-01, 2.8366220e-01, 4.7942552e-01, 8.7758255e-01,
|
|
62
|
+
4.9979165e-02, 9.9875027e-01, 4.9999789e-03, 9.9998754e-01],
|
|
63
|
+
[-2.7941549e-01, 9.6017027e-01, 5.6464243e-01, 8.2533562e-01,
|
|
64
|
+
5.9964005e-02, 9.9820054e-01, 5.9999637e-03, 9.9998200e-01],
|
|
65
|
+
[6.5698659e-01, 7.5390226e-01, 6.4421761e-01, 7.6484221e-01,
|
|
66
|
+
6.9942847e-02, 9.9755102e-01, 6.9999420e-03, 9.9997550e-01],
|
|
67
|
+
[9.8935825e-01, -1.4550003e-01, 7.1735609e-01, 6.9670677e-01,
|
|
68
|
+
7.9914689e-02, 9.9680167e-01, 7.9999138e-03, 9.9996799e-01],
|
|
69
|
+
[4.1211849e-01, -9.1113025e-01, 7.8332686e-01, 6.2160999e-01,
|
|
70
|
+
8.9878544e-02, 9.9595273e-01, 8.9998785e-03, 9.9995953e-01]]
|
|
71
|
+
]);
|
|
87
72
|
const positional4 = new PositionalEncoding({ embedDim: 4, maxSequenceLength: 10 });
|
|
88
73
|
positional4.build([]);
|
|
89
|
-
|
|
90
74
|
const positional8 = new PositionalEncoding({ embedDim: 8, maxSequenceLength: 10 });
|
|
91
75
|
positional8.build([]);
|
|
92
|
-
|
|
93
76
|
const margin_of_error = 1e-6;
|
|
94
|
-
|
|
95
77
|
// the difference between this and PyTorch's implementation
|
|
96
78
|
//should be insignificantly small
|
|
97
|
-
expect(
|
|
79
|
+
expect(positional4.getWeights()[0]
|
|
98
80
|
.sub(pytorch_embed4)
|
|
99
81
|
.abs()
|
|
100
|
-
.arraySync()
|
|
82
|
+
.arraySync()
|
|
101
83
|
.flat(2)
|
|
102
84
|
.filter(i => i > margin_of_error)
|
|
103
85
|
.length).toBe(0);
|
|
104
|
-
|
|
105
|
-
expect((positional8.getWeights()[0]
|
|
86
|
+
expect(positional8.getWeights()[0]
|
|
106
87
|
.sub(pytorch_embed8)
|
|
107
88
|
.abs()
|
|
108
|
-
.arraySync()
|
|
89
|
+
.arraySync()
|
|
109
90
|
.flat(2)
|
|
110
91
|
.filter(i => i > margin_of_error)
|
|
111
92
|
.length).toBe(0);
|
|
112
93
|
});
|
|
113
94
|
});
|
|
95
|
+
//# sourceMappingURL=positional_encoding.test.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"positional_encoding.test.js","sourceRoot":"","sources":["../../src/layers/positional_encoding.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,kBAAkB,EAAE,MAAM,8BAA8B,CAAC;AAElE,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,EAAE,CAAC,oCAAoC,EAAE,GAAG,EAAE;QAC1C,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACtF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,CAAC,EAAE,QAAQ,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,CAAC,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAC5F,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,2BAA2B,EAAE,GAAG,EAAE;QACnC,MAAM,UAAU,GAAG,EAAE,CAAC;QACtB,MAAM,SAAS,GAAG,CAAC,CAAC;QACpB,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,SAAS,EAAE,UAAU,CAAC,CAAC,CAAC;QAE3D,MAAM,UAAU,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,CAAC,CAAC;QACpE,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACpD,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACtD,MAAM,CAAC,UAAU,CAAC,kBAAkB,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;IAC5E,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,kGAAkG,EAAE,GAAG,EAAE;QACxG,MAAM,kBAAkB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,GAAG,EAAE,EAAE,CAAC,CAAC,CAAC;QACvD,MAAM,oBAAoB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,GAAG,CAAC,CAAC,CAAC;QACzD,MAAM,UAAU,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;QAElD,MAAM,UAAU,GAAG,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC;QAEnF,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,kBAAkB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG,EAAE
,CAAC,UAAU,CAAC,KAAK,CAAC,oBAAoB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC/D,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,SAAS,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC;QAC3D,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,SAAS,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAA;IAGF,qCAAqC;IACrC,mFAAmF;IACnF,EAAE,CAAC,mDAAmD,EAAE,GAAG,EAAE;QACzD,MAAM,cAAc,GAAG,EAAE,CAAC,MAAM,CAAC;YAC7B,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC5C,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9C,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC5C,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;SAAC,CAAC,CAAC;QAErD,MAAM,cAAc,GAAG,EAAE,CAAC,MAAM,CAAC;YAC7B,CAAC,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACvD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACzD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACvD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE
,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC,CAAC;SAAC,CAAC,CAAC;QAEvE,MAAM,WAAW,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC;QACnF,WAAW,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;QAEtB,MAAM,WAAW,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC;QACnF,WAAW,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;QAEtB,MAAM,eAAe,GAAG,IAAI,CAAC;QAE7B,2DAA2D;QAC3D,iCAAiC;QACjC,MAAM,CAAE,WAAW,CAAC,UAAU,EAAE,CAAC,CAAC,CAAC;aAC9B,GAAG,CAAC,cAAc,CAAC;aACnB,GAAG,EAAE;aACL,SAAS,EAAS;aAClB,IAAI,CAAC,CAAC,CAAC;aACP,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,eAAe,CAAC;aAChC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAErB,MAAM,CAAE,WAAW,CAAC,UAAU,EAAE,CAAC,CAAC,CAAC;aAC9B,GAAG,CAAC,cAAc,CAAC;aACnB,GAAG,EAAE;aACL,SAAS,EAAS;aAClB,IAAI,CAAC,CAAC,CAAC;aACP,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,eAAe,CAAC;aAChC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACzB,CAAC,CAAC,CAAC;AACP,CAAC,CAAC,CAAC"}
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { type LayerArgs } from "@tensorflow/tfjs-layers/dist/engine/topology";
|
|
3
|
+
export declare function applyRope(x: tf.Tensor, dim: number, cosine_cache: tf.Tensor, sine_cache: tf.Tensor): tf.Tensor<tf.Rank>;
|
|
4
|
+
export declare function rotateHalf(x: tf.Tensor, dim: number): tf.Tensor;
|
|
5
|
+
export declare function createRoPECache(dim: number, max_sequence_length: number, theta?: number): tf.Tensor<tf.Rank>[];
|
|
6
|
+
export interface RotaryPositionEmbeddingArgs extends LayerArgs {
|
|
7
|
+
/**
|
|
8
|
+
* The dimension of each head (rounded down), e.g. `Math.floor(embedDim / numHeads)`
|
|
9
|
+
*/
|
|
10
|
+
dim: number;
|
|
11
|
+
/**
|
|
12
|
+
* The RoPE cache will be pre-calculated up to the max sequence length, and re-caculated as needed. Defaults to `4096`.
|
|
13
|
+
*/
|
|
14
|
+
maxSequenceLength?: number;
|
|
15
|
+
/**
|
|
16
|
+
* The base for the geometric progression used to compute the rotation angles. Defaults to `10_000`.
|
|
17
|
+
*/
|
|
18
|
+
theta?: number;
|
|
19
|
+
}
|
|
20
|
+
/**
|
|
21
|
+
* Implements RoPE from the RoFormer: Enhanced Transformer with Rotary Position Embedding paper
|
|
22
|
+
* Inspired by: https://meta-pytorch.org/torchtune/stable/_modules/torchtune/modules/position_embeddings.html#RotaryPositionalEmbeddings
|
|
23
|
+
*/
|
|
24
|
+
export declare class RotaryPositionEmbedding extends tf.layers.Layer {
|
|
25
|
+
static className: string;
|
|
26
|
+
protected dim: number;
|
|
27
|
+
protected max_sequence_length: number;
|
|
28
|
+
protected theta: number;
|
|
29
|
+
protected cosine_cache: tf.LayerVariable;
|
|
30
|
+
protected sine_cache: tf.LayerVariable;
|
|
31
|
+
constructor({ dim, maxSequenceLength, theta, ...args }: RotaryPositionEmbeddingArgs);
|
|
32
|
+
call(inputs: tf.Tensor | tf.Tensor[], kwargs: any): tf.Tensor | tf.Tensor[];
|
|
33
|
+
build(input_shape: tf.Shape | tf.Shape[]): void;
|
|
34
|
+
/**
|
|
35
|
+
* Output shape: [batch, head, sequence, head_dim]
|
|
36
|
+
*/
|
|
37
|
+
computeOutputShape(input_shape: tf.Shape | tf.Shape[]): tf.Shape;
|
|
38
|
+
}
|
|
39
|
+
//# sourceMappingURL=rotary_position_embedding.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"rotary_position_embedding.d.ts","sourceRoot":"","sources":["../../src/layers/rotary_position_embedding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,SAAS,EAAE,MAAM,8CAA8C,CAAC;AAG9E,wBAAgB,SAAS,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,GAAG,EAAE,MAAM,EAAE,YAAY,EAAE,EAAE,CAAC,MAAM,EAAE,UAAU,EAAE,EAAE,CAAC,MAAM,sBAalG;AAGD,wBAAgB,UAAU,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,GAAG,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,CAgB/D;AAGD,wBAAgB,eAAe,CAAC,GAAG,EAAE,MAAM,EAAE,mBAAmB,EAAE,MAAM,EAAE,KAAK,GAAE,MAAe,wBAqB/F;AAGD,MAAM,WAAW,2BAA4B,SAAQ,SAAS;IAC1D;;OAEG;IACH,GAAG,EAAE,MAAM,CAAC;IACZ;;OAEG;IACH,iBAAiB,CAAC,EAAE,MAAM,CAAC;IAC3B;;OAEG;IACH,KAAK,CAAC,EAAE,MAAM,CAAC;CAClB;AAGD;;;GAGG;AACH,qBAAa,uBAAwB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACxD,MAAM,CAAC,SAAS,SAA6B;IAE7C,SAAS,CAAC,GAAG,EAAE,MAAM,CAAC;IACtB,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;IACtC,SAAS,CAAC,KAAK,EAAE,MAAM,CAAC;IAGxB,SAAS,CAAC,YAAY,EAAE,EAAE,CAAC,aAAa,CAAC;IACzC,SAAS,CAAC,UAAU,EAAE,EAAE,CAAC,aAAa,CAAC;gBAE3B,EAAE,GAAG,EAAE,iBAAwB,EAAE,KAAc,EAAE,GAAG,IAAI,EAAE,EAAE,2BAA2B;IAqB1F,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,GAAG,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAkB3E,KAAK,CAAC,WAAW,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;IAmBjD;;OAEG;IACI,kBAAkB,CAAC,WAAW,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;CAK/D"}
|
package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js}
RENAMED
|
@@ -1,163 +1,99 @@
|
|
|
1
1
|
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
export function applyRope(x: tf.Tensor, dim: number, cosine_cache: tf.Tensor, sine_cache: tf.Tensor) {
|
|
2
|
+
export function applyRope(x, dim, cosine_cache, sine_cache) {
|
|
6
3
|
return tf.tidy(() => {
|
|
7
|
-
const seq_length = x.shape[2]
|
|
8
|
-
|
|
4
|
+
const seq_length = x.shape[2];
|
|
9
5
|
// get a slice of the pre-computed cache, up to the input's sequence length
|
|
10
6
|
const cosine = cosine_cache.slice([0, 0, 0, 0], [1, 1, seq_length, dim]);
|
|
11
7
|
const sine = sine_cache.slice([0, 0, 0, 0], [1, 1, seq_length, dim]);
|
|
12
|
-
|
|
13
8
|
// apply RoPE formula (x1 * cosine) + (rotate(-x2) * sine)
|
|
14
9
|
const rotated_x = rotateHalf(x, dim);
|
|
15
|
-
|
|
16
10
|
return tf.add(tf.mul(x, cosine), tf.mul(rotated_x, sine));
|
|
17
11
|
});
|
|
18
12
|
}
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
export function rotateHalf(x: tf.Tensor, dim: number): tf.Tensor {
|
|
13
|
+
export function rotateHalf(x, dim) {
|
|
22
14
|
return tf.tidy(() => {
|
|
23
15
|
// reshape the last dimension such that adjacent coordinates are paired together
|
|
24
16
|
// [x1, x2, x3, x4] -> [[x1, x2], [x3, x4]]
|
|
25
17
|
// the leading dimensions are flattened because TFJS has issues during
|
|
26
18
|
// backpropagation with 5D slicing
|
|
27
19
|
const reshaped = x.reshape([-1, dim / 2, 2]);
|
|
28
|
-
|
|
29
20
|
const x1 = reshaped.slice([0, 0, 0], [-1, -1, 1]);
|
|
30
21
|
const x2 = reshaped.slice([0, 0, 1], [-1, -1, 1]);
|
|
31
|
-
|
|
32
22
|
// [x1, x2] -> [-x2, x1]
|
|
33
23
|
const rotated = tf.concat([tf.neg(x2), x1], -1);
|
|
34
|
-
|
|
35
24
|
return rotated.reshape(x.shape);
|
|
36
25
|
});
|
|
37
26
|
}
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
export function createRoPECache(dim: number, max_sequence_length: number, theta: number = 10_000) {
|
|
27
|
+
export function createRoPECache(dim, max_sequence_length, theta = 10_000) {
|
|
41
28
|
return tf.tidy(() => {
|
|
42
29
|
// [dim]
|
|
43
|
-
const inv_frequencies = tf.div
|
|
44
|
-
theta,
|
|
45
|
-
tf.range(0, Math.floor(dim / 2) * 2, 2, "float32").div(dim)));
|
|
46
|
-
|
|
30
|
+
const inv_frequencies = tf.div(1, tf.pow(theta, tf.range(0, Math.floor(dim / 2) * 2, 2, "float32").div(dim)));
|
|
47
31
|
// [max_sequene_length]
|
|
48
32
|
const sequence_indices = tf.range(0, max_sequence_length);
|
|
49
33
|
//
|
|
50
34
|
const freq = tf.outerProduct(sequence_indices, inv_frequencies);
|
|
51
|
-
|
|
52
35
|
// cache final shape [max_sequence_length, dim]
|
|
53
36
|
const freq_pairs = tf.stack([freq, freq], -1)
|
|
54
37
|
.reshape([max_sequence_length, dim]);
|
|
55
|
-
|
|
56
38
|
return [
|
|
57
39
|
tf.keep(tf.cos(freq_pairs).expandDims(0).expandDims(0)),
|
|
58
40
|
tf.keep(tf.sin(freq_pairs).expandDims(0).expandDims(0))
|
|
59
|
-
]
|
|
41
|
+
];
|
|
60
42
|
});
|
|
61
43
|
}
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
/**
 * Constructor arguments for `RotaryPositionEmbedding`.
 */
export interface RotaryPositionEmbeddingArgs extends LayerArgs {
    /**
     * The dimension of each head, e.g. `Math.floor(embedDim / numHeads)`.
     * Must be even: coordinates are rotated in adjacent pairs, and the layer's
     * constructor throws for an odd value.
     */
    dim: number,
    /**
     * The RoPE cache is pre-calculated up to this sequence length and re-calculated
     * (grown to the next power of two) when a longer input sequence is seen.
     * Defaults to `4096`.
     */
    maxSequenceLength?: number,
    /**
     * The base for the geometric progression used to compute the rotation angles. Defaults to `10_000`.
     */
    theta?: number,
}
|
|
78
|
-
|
|
79
|
-
|
|
80
44
|
/**
 * Implements RoPE from the RoFormer: Enhanced Transformer with Rotary Position Embedding paper
 * Inspired by: https://meta-pytorch.org/torchtune/stable/_modules/torchtune/modules/position_embeddings.html#RotaryPositionalEmbeddings
 *
 * Inputs are `[batch, head, sequence, head_dim]`; the sequence length is read from
 * axis 2. The cosine/sine tables are stored as two non-trainable weights of shape
 * `[1, 1, max_sequence_length, dim]`. When a longer sequence arrives they are
 * rebuilt for the next power of two.
 */
export class RotaryPositionEmbedding extends tf.layers.Layer {
    static className = "RotaryPositionEmbedding";

    protected dim: number;
    protected max_sequence_length: number;
    protected theta: number;

    // cached sine and cosine frequencies, untrainable weights
    protected cosine_cache: tf.LayerVariable;
    protected sine_cache: tf.LayerVariable;

    constructor({ dim, maxSequenceLength = 4096, theta = 10_000, ...args }: RotaryPositionEmbeddingArgs) {
        super(args);

        if (dim % 2 !== 0) {
            throw new Error(`${this.getClassName()}::constructor ${this.name} expected dim to be even, got ${dim}`);
        }

        this.dim = dim;
        this.max_sequence_length = maxSequenceLength;
        this.theta = theta;

        // Zero-filled placeholders until build() computes the real tables.
        // Each weight is registered under the name of the field it is stored in
        // (previously the two names were swapped).
        const cacheShape: tf.Shape = [1, 1, maxSequenceLength, this.dim];
        this.cosine_cache = this.addWeight("cosine_cache", cacheShape,
            "float32", tf.initializers.zeros(), undefined, false);
        this.sine_cache = this.addWeight("sine_cache", cacheShape,
            "float32", tf.initializers.zeros(), undefined, false);
    }

    override call(inputs: tf.Tensor | tf.Tensor[], kwargs: any): tf.Tensor | tf.Tensor[] {
        const x = Array.isArray(inputs) ? inputs[0] : inputs;
        const seq_length = x.shape[2];

        if (seq_length > this.max_sequence_length) {
            // expand cache to the nearest power of 2
            this.max_sequence_length = Math.pow(2, Math.ceil(Math.log2(seq_length)));
            this.build([]);
        }

        return applyRope(x, this.dim, this.cosine_cache.read(), this.sine_cache.read());
    }

    /**
     * (Re)computes the cosine/sine tables for the current `max_sequence_length`.
     * It then installs them as the layer's non-trainable weights. The input shape is not used.
     */
    override build(input_shape: tf.Shape | tf.Shape[]) {
        const [cosine, sine] = createRoPECache(this.dim, this.max_sequence_length, this.theta);

        // Release the tables being replaced. These are the same variables that are
        // exposed as nonTrainableWeights, so nothing is leaked across rebuilds.
        this.cosine_cache.dispose();
        this.sine_cache.dispose();

        // trainable=false: LayerVariable defaults to trainable, which would let
        // optimizers that update all trainable variables modify the RoPE tables.
        this.cosine_cache = new tf.LayerVariable(cosine, "float32", "cosine_cache", false);
        this.sine_cache = new tf.LayerVariable(sine, "float32", "sine_cache", false);

        // The variables now own the data; drop the temporary tensors kept by createRoPECache.
        cosine.dispose();
        sine.dispose();

        // Expose exactly the variables that call() reads, so that getWeights()/setWeights()
        // and serialization operate on the tensors actually in use.
        this.nonTrainableWeights = [this.cosine_cache, this.sine_cache];
    }

    /**
     * Output shape: [batch, head, sequence, head_dim]
     */
    override computeOutputShape(input_shape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
        return Array.isArray(input_shape[0])
            ? input_shape[0] as tf.Shape
            : input_shape as tf.Shape;
    }
}

tf.serialization.registerClass(RotaryPositionEmbedding);
|
|
99
|
+
//# sourceMappingURL=rotary_position_embedding.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"rotary_position_embedding.js","sourceRoot":"","sources":["../../src/layers/rotary_position_embedding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAIvC,MAAM,UAAU,SAAS,CAAC,CAAY,EAAE,GAAW,EAAE,YAAuB,EAAE,UAAqB;IAC/F,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;QAChB,MAAM,UAAU,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC;QAE/B,2EAA2E;QAC3E,MAAM,MAAM,GAAG,YAAY,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,UAAU,EAAE,GAAG,CAAC,CAAC,CAAC;QACzE,MAAM,IAAI,GAAG,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,UAAU,EAAE,GAAG,CAAC,CAAC,CAAC;QAErE,0DAA0D;QAC1D,MAAM,SAAS,GAAG,UAAU,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QAErC,OAAO,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,EAAE,MAAM,CAAC,EAAE,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,IAAI,CAAC,CAAC,CAAC;IAC9D,CAAC,CAAC,CAAC;AACP,CAAC;AAGD,MAAM,UAAU,UAAU,CAAC,CAAY,EAAE,GAAW;IAChD,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;QAChB,gFAAgF;QAChF,2CAA2C;QAC3C,sEAAsE;QACtE,kCAAkC;QAClC,MAAM,QAAQ,GAAG,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,EAAE,GAAG,GAAG,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAE7C,MAAM,EAAE,GAAG,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAClD,MAAM,EAAE,GAAG,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAElD,wBAAwB;QACxB,MAAM,OAAO,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAEhD,OAAO,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;IACpC,CAAC,CAAC,CAAC;AACP,CAAC;AAGD,MAAM,UAAU,eAAe,CAAC,GAAW,EAAE,mBAA2B,EAAE,QAAgB,MAAM;IAC5F,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;QAChB,QAAQ;QACR,MAAM,eAAe,GAAG,EAAE,CAAC,GAAG,CAAc,CAAC,EAAE,EAAE,CAAC,GAAG,CACjD,KAAK,EACL,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,KAAK,CAAC,GAAG,GAAG,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,SAAS,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;QAElE,uBAAuB;QACvB,MAAM,gBAAgB,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,EAA
E,mBAAmB,CAAC,CAAC;QAC1D,GAAG;QACH,MAAM,IAAI,GAAG,EAAE,CAAC,YAAY,CAAC,gBAAgB,EAAE,eAAe,CAAC,CAAC;QAEhE,+CAA+C;QAC/C,MAAM,UAAU,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC;aACxC,OAAO,CAAC,CAAC,mBAAmB,EAAE,GAAG,CAAC,CAAC,CAAC;QAEzC,OAAO;YACH,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;YACvD,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;SAC1D,CAAA;IACL,CAAC,CAAC,CAAC;AACP,CAAC;AAmBD;;;GAGG;AACH,MAAM,OAAO,uBAAwB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACxD,MAAM,CAAC,SAAS,GAAG,yBAAyB,CAAC;IAEnC,GAAG,CAAS;IACZ,mBAAmB,CAAS;IAC5B,KAAK,CAAS;IAExB,0DAA0D;IAChD,YAAY,CAAmB;IAC/B,UAAU,CAAmB;IAEvC,YAAY,EAAE,GAAG,EAAE,iBAAiB,GAAG,IAAI,EAAE,KAAK,GAAG,MAAM,EAAE,GAAG,IAAI,EAA+B;QAC/F,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,GAAG,GAAG,CAAC,KAAK,CAAC,EAAE,CAAC;YAChB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,iCAAiC,GAAG,EAAE,CAAC,CAAC;QACxG,CAAC;QAED,IAAI,CAAC,GAAG,GAAG,GAAG,CAAC;QACf,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;QAC7C,IAAI,CAAC,KAAK,GAAG,KAAK,CAAC;QAEnB,IAAI,CAAC,YAAY,GAAG,IAAI,CAAC,SAAS,CAAC,YAAY,EAC3C,CAAC,CAAC,EAAE,CAAC,EAAE,iBAAiB,EAAE,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,EAC/C,SAAS,EAAE,EAAE,CAAC,YAAY,CAAC,KAAK,EAAE,EAAE,SAAS,EAAE,KAAK,CAAC,CAAC;QAE1D,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,SAAS,CAAC,cAAc,EAC3C,CAAC,CAAC,EAAE,CAAC,EAAE,iBAAiB,EAAE,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,EAC/C,SAAS,EAAE,EAAE,CAAC,YAAY,CAAC,KAAK,EAAE,EAAE,SAAS,EAAE,KAAK,CAAC,CAAC;IAC9D,CAAC;IAGQ,IAAI,CAAC,MAA+B,EAAE,MAAW;QACtD,MAAM,KAAK,GAAG,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,MAAM,CAAC,KAAK,CAAC;QACrE,MAAM,UAAU,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC;QAE5B,IAAI,UAAU,GAAG,IAAI,CAAC,mBAAmB,EAAE,CAAC;YACxC,yCAAyC;YACzC,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;YACzE,IAAI,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;QACnB,CAAC;QAED,OAAO,SAAS,CAC
Z,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,EAC1C,IAAI,CAAC,GAAG,EACR,IAAI,CAAC,YAAY,CAAC,IAAI,EAAE,EACxB,IAAI,CAAC,UAAU,CAAC,IAAI,EAAE,CAAC,CAAA;IAC/B,CAAC;IAGQ,KAAK,CAAC,WAAkC;QAC7C,MAAM,CAAC,MAAM,EAAE,IAAI,CAAC,GAAG,eAAe,CAClC,IAAI,CAAC,GAAG,EAAE,IAAI,CAAC,mBAAmB,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC;QAEpD,IAAI,CAAC,YAAY,CAAC,OAAO,EAAE,CAAC;QAC5B,IAAI,CAAC,UAAU,CAAC,OAAO,EAAE,CAAC;QAE1B,IAAI,CAAC,YAAY,GAAG,IAAI,EAAE,CAAC,aAAa,CAAC,MAAM,CAAC,CAAC;QACjD,IAAI,CAAC,UAAU,GAAG,IAAI,EAAE,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;QAE7C,IAAI,CAAC,mBAAmB,GAAG;YACvB,IAAI,EAAE,CAAC,aAAa,CAAC,MAAM,CAAC;YAC5B,IAAI,EAAE,CAAC,aAAa,CAAC,IAAI,CAAC;SAC7B,CAAC;QAEF,IAAI,CAAC,UAAU,CAAC,CAAC,MAAM,EAAE,IAAI,CAAC,CAAC,CAAC;IACpC,CAAC;IAGD;;OAEG;IACI,kBAAkB,CAAC,WAAkC;QACxD,OAAO,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC;YAChC,CAAC,CAAC,WAAW,CAAC,CAAC,CAAa;YAC5B,CAAC,CAAC,WAAuB,CAAC;IAClC,CAAC;;AAGL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,uBAAuB,CAAC,CAAC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"rotary_position_embedding.test.d.ts","sourceRoot":"","sources":["../../src/layers/rotary_position_embedding.test.ts"],"names":[],"mappings":""}
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
import { RotaryPositionEmbedding } from "@/layers/rotary_position_embedding";
import * as tf from "@tensorflow/tfjs";
// disables warning for using the faster node backend,
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
tf.env().set('IS_NODE', false);

// Allowed element-wise deviation from the reference values.
const TOLERANCE = 1e-5;

/**
 * Largest element-wise |a - b|.
 * Summing the signed differences (the previous check) lets errors cancel out,
 * and any negative total passes, so it could not catch a wrong result.
 */
async function maxAbsDiff(a, b) {
    const diff = tf.tidy(() => a.sub(b).abs().max());
    const value = await diff.array();
    diff.dispose();
    return value;
}

describe("RotaryPositionEmbedding tests", () => {
    test("create cache", async () => {
        const rope = new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 15 });
        rope.build([]);
        const expected_cosine_cache = tf.tensor([[[
            [1, 1, 1, 1, 1, 1, 1, 1],
            [0.5403022766113281, 0.5403022766113281, 0.9950041770935059, 0.9950041770935059, 0.9999499917030334, 0.9999499917030334, 0.9999995231628418, 0.9999995231628418],
            [-0.416146844625473, -0.416146844625473, 0.9800665974617004, 0.9800665974617004, 0.9998000264167786, 0.9998000264167786, 0.9999979734420776, 0.9999979734420776],
            [-0.9899924993515015, -0.9899924993515015, 0.9553365111351013, 0.9553365111351013, 0.9995500445365906, 0.9995500445365906, 0.9999955296516418, 0.9999955296516418],
            [-0.6536436080932617, -0.6536436080932617, 0.9210609793663025, 0.9210609793663025, 0.9992001056671143, 0.9992001056671143, 0.9999920129776001, 0.9999920129776001],
            [0.28366219997406006, 0.28366219997406006, 0.8775825500488281, 0.8775825500488281, 0.9987502694129944, 0.9987502694129944, 0.9999874830245972, 0.9999874830245972],
            [0.9601702690124512, 0.9601702690124512, 0.8253356218338013, 0.8253356218338013, 0.998200535774231, 0.998200535774231, 0.9999819993972778, 0.9999819993972778],
            [0.7539022564888, 0.7539022564888, 0.7648422122001648, 0.7648422122001648, 0.9975510239601135, 0.9975510239601135, 0.9999755024909973, 0.9999755024909973],
            [-0.1455000340938568, -0.1455000340938568, 0.6967067122459412, 0.6967067122459412, 0.9968017339706421, 0.9968017339706421, 0.9999679923057556, 0.9999679923057556],
            [-0.9111302495002747, -0.9111302495002747, 0.6216099262237549, 0.6216099262237549, 0.9959527254104614, 0.9959527254104614, 0.9999595284461975, 0.9999595284461975],
            [-0.83907151222229, -0.83907151222229, 0.5403022766113281, 0.5403022766113281, 0.9950041770935059, 0.9950041770935059, 0.9999499917030334, 0.9999499917030334],
            [0.004425697959959507, 0.004425697959959507, 0.4535960853099823, 0.4535960853099823, 0.9939560890197754, 0.9939560890197754, 0.999939501285553, 0.999939501285553],
            [0.8438539505004883, 0.8438539505004883, 0.3623577058315277, 0.3623577058315277, 0.9928086400032043, 0.9928086400032043, 0.9999279975891113, 0.9999279975891113],
            [0.9074468016624451, 0.9074468016624451, 0.26749876141548157, 0.26749876141548157, 0.9915618896484375, 0.9915618896484375, 0.9999154806137085, 0.9999154806137085],
            [0.13673721253871918, 0.13673721253871918, 0.1699671596288681, 0.1699671596288681, 0.9902160167694092, 0.9902160167694092, 0.9999020099639893, 0.9999020099639893]
        ]]]);
        const expected_sine_cache = tf.tensor([[[
            [0, 0, 0, 0, 0, 0, 0, 0],
            [0.8414709568023682, 0.8414709568023682, 0.0998334214091301, 0.0998334214091301, 0.009999833069741726, 0.009999833069741726, 0.0009999999310821295, 0.0009999999310821295],
            [0.9092974066734314, 0.9092974066734314, 0.19866932928562164, 0.19866932928562164, 0.019998665899038315, 0.019998665899038315, 0.0019999986980110407, 0.0019999986980110407],
            [0.14112000167369843, 0.14112000167369843, 0.29552021622657776, 0.29552021622657776, 0.029995499178767204, 0.029995499178767204, 0.0029999956022948027, 0.0029999956022948027],
            [-0.756802499294281, -0.756802499294281, 0.3894183337688446, 0.3894183337688446, 0.03998933359980583, 0.03998933359980583, 0.003999989479780197, 0.003999989479780197],
            [-0.9589242935180664, -0.9589242935180664, 0.4794255495071411, 0.4794255495071411, 0.04997916519641876, 0.04997916519641876, 0.0049999793991446495, 0.0049999793991446495],
            [-0.279415488243103, -0.279415488243103, 0.5646424889564514, 0.5646424889564514, 0.059964004904031754, 0.059964004904031754, 0.0059999641962349415, 0.0059999641962349415],
            [0.6569865942001343, 0.6569865942001343, 0.6442176699638367, 0.6442176699638367, 0.06994284689426422, 0.06994284689426422, 0.0069999429397284985, 0.0069999429397284985],
            [0.9893582463264465, 0.9893582463264465, 0.7173560857772827, 0.7173560857772827, 0.07991468906402588, 0.07991468906402588, 0.007999914698302746, 0.007999914698302746],
            [0.41211849451065063, 0.41211849451065063, 0.7833269238471985, 0.7833269238471985, 0.08987854421138763, 0.08987854421138763, 0.008999879471957684, 0.008999879471957684],
            [-0.5440211296081543, -0.5440211296081543, 0.8414709568023682, 0.8414709568023682, 0.0998334139585495, 0.0998334139585495, 0.0099998340010643, 0.0099998340010643],
            [-0.9999902248382568, -0.9999902248382568, 0.8912073969841003, 0.8912073969841003, 0.10977829992771149, 0.10977829992771149, 0.010999779216945171, 0.010999779216945171],
            [-0.5365729331970215, -0.5365729331970215, 0.9320390820503235, 0.9320390820503235, 0.11971220374107361, 0.11971220374107361, 0.011999712325632572, 0.011999712325632572],
            [0.4201670289039612, 0.4201670289039612, 0.9635581970214844, 0.9635581970214844, 0.12963414192199707, 0.12963414192199707, 0.012999634258449078, 0.012999634258449078],
            [0.9906073808670044, 0.9906073808670044, 0.9854497313499451, 0.9854497313499451, 0.13954311609268188, 0.13954311609268188, 0.013999543152749538, 0.013999543152749538]
        ]]]);
        const weights = rope.getWeights();
        expect(weights).toHaveLength(2);
        const [cosine_cache, sine_cache] = weights;
        // compare shapes first: broadcasting in sub() would otherwise hide a mismatch
        expect(cosine_cache.shape).toEqual(expected_cosine_cache.shape);
        expect(sine_cache.shape).toEqual(expected_sine_cache.shape);
        expect(await maxAbsDiff(cosine_cache, expected_cosine_cache)).toBeLessThanOrEqual(TOLERANCE);
        expect(await maxAbsDiff(sine_cache, expected_sine_cache)).toBeLessThanOrEqual(TOLERANCE);
    });
    test("rotate inputs", async () => {
        const rope = new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 15 });
        const x = tf.tensor([[[
                    [0.0766048, 0.5706575, 0.6705932, 0.5273118, 0.4794086, 0.9378104, 0.9888024, 0.6926053],
                    [0.9064133, 0.5875182, 0.1681865, 0.3833345, 0.9901192, 0.4677338, 0.3353315, 0.02699],
                    [0.3033573, 0.4139377, 0.4062586, 0.9705839, 0.3582608, 0.328775, 0.1340587, 0.2193414],
                    [0.5565202, 0.4334963, 0.9912352, 0.3388563, 0.7991487, 0.1911893, 0.1140554, 0.6949552]
                ]]
        ]); // batch=1, heads=1, seq=4, head_dim=8
        const expected_output = tf.tensor([[[
                    [0.07660479843616486, 0.57065749168396, 0.6705932021141052, 0.5273118019104004, 0.4794085919857025, 0.9378104209899902, 0.9888023734092712, 0.6926053166389465],
                    [-0.004642367362976074, 1.08015775680542, 0.12907665967941284, 0.39820998907089233, 0.9853923320770264, 0.47761136293411255, 0.33530429005622864, 0.027325313538312912],
                    [-0.5026336908340454, 0.10358311235904694, 0.20533521473407745, 1.0319478511810303, 0.3516140580177307, 0.33587393164634705, 0.1336197406053543, 0.21960905194282532],
                    [-0.6121258735656738, -0.3506217896938324, 0.8468242287635803, 0.6166517734527588, 0.7930541634559631, 0.2150741070508957, 0.11197001487016678, 0.695294201374054]
                ]]]);
        const output = rope.apply(x);
        expect(output.shape).toEqual(expected_output.shape);
        expect(await maxAbsDiff(expected_output, output)).toBeLessThan(TOLERANCE);
        expect(rope.computeOutputShape(x.shape)).toEqual(x.shape);
        expect(rope.computeOutputShape([x.shape])).toEqual(x.shape);
    });
    test("expand cache when input sequences are larger than rope's max sequence length", async () => {
        const dim = 8;
        const rope = new RotaryPositionEmbedding({ dim, maxSequenceLength: 15, theta: 1_000_000 });
        const larger_sequence = 20;
        const even_larger_sequence = 50;
        rope.apply(tf.randomUniform([1, 1, larger_sequence, dim]));
        rope.getWeights().forEach(weight => {
            expect(weight.shape).toEqual([1, 1, 32, dim]);
        });
        rope.apply([tf.randomUniform([1, 1, even_larger_sequence, dim])]);
        rope.getWeights().forEach(weight => {
            expect(weight.shape).toEqual([1, 1, 64, dim]);
        });
    });
    test("create layer", async () => {
        // dim must be even
        expect(() => new RotaryPositionEmbedding({ dim: 7, maxSequenceLength: 15 })).toThrow();
        expect(() => new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 25 })).not.toThrow();
    });
});
|
|
88
|
+
//# sourceMappingURL=rotary_position_embedding.test.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"rotary_position_embedding.test.js","sourceRoot":"","sources":["../../src/layers/rotary_position_embedding.test.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,uBAAuB,EAAE,MAAM,oCAAoC,CAAC;AAC7E,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,+BAA+B,EAAE,GAAG,EAAE;IAC3C,IAAI,CAAC,cAAc,EAAE,KAAK,IAAI,EAAE;QAC5B,MAAM,IAAI,GAAG,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC;QAC5E,IAAI,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;QAEf,MAAM,qBAAqB,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC;oBACtC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;oBACxB,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAChK,CAAC,CAAC,iBAAiB,EAAE,CAAC,iBAAiB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAChK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,iBAAiB,EAAE,iBAAiB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAC9J,CAAC,eAAe,EAAE,eAAe,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAC1J,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,CAAC,gBAAgB,EAAE,CAAC,gBAAgB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAC9J,CAAC,oBAAoB,EAAE,oBAAoB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,iBAAiB,EAAE,iBAAiB,CAAC;oBAClK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBA
AkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAChK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;iBACrK,CAAC,CAAC,CAAC,CAAC;QAEL,MAAM,mBAAmB,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC;oBACpC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;oBACxB,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC1K,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC5K,CAAC,mBAAmB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC9K,CAAC,CAAC,iBAAiB,EAAE,CAAC,iBAAiB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACtK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC1K,CAAC,CAAC,iBAAiB,EAAE,CAAC,iBAAiB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC1K,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBACxK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACtK,CAAC,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACxK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACxK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACxK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACtK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,C
AAC;iBACzK,CAAC,CAAC,CAAC,CAAC;QAEL,MAAM,CAAC,YAAY,EAAE,UAAU,CAAC,GAAG,IAAI,CAAC,UAAU,EAAE,CAAC;QAErD,MAAM,CAAC,MAAM,YAAY,EAAE,GAAG,CAAC,qBAAqB,CAAC,CAAC,GAAG,EAAE,CAAC,KAAK,EAAY,CAAC,CAAC,mBAAmB,CAAC,IAAI,CAAC,CAAC;QACzG,MAAM,CAAC,MAAM,UAAU,EAAE,GAAG,CAAC,mBAAmB,CAAC,CAAC,GAAG,EAAE,CAAC,KAAK,EAAY,CAAC,CAAC,mBAAmB,CAAC,IAAI,CAAC,CAAC;IACzG,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,eAAe,EAAE,KAAK,IAAI,EAAE;QAC7B,MAAM,IAAI,GAAG,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC;QAE5E,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC;oBAClB,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;oBACxF,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,OAAO,CAAC;oBACtF,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,EAAE,SAAS,CAAC;oBACvF,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;iBAAC,CAAC;SAC7F,CAAC,CAAC,CAAC,wCAAwC;QAE5C,MAAM,eAAe,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC;oBAChC,CAAC,mBAAmB,EAAE,gBAAgB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAC/J,CAAC,CAAC,oBAAoB,EAAE,gBAAgB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,CAAC;oBACvK,CAAC,CAAC,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,mBAAmB,CAAC;oBACrK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,iBAAiB,CAAC;iBACrK,CAAC,CAAC,CAAC,CAAC;QAEL,MAAM,MAAM,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAc,CAAC;QAE1C,MAAM,CAAC,MAAM,eAAe,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,GAAG,EAAE,CAAC,KAAK,EAAY,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;QACrF,MAAM,CAAC,IAAI,CAAC,kBAAkB,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QAC1D,MAAM,CAAC,IAAI,CAAC,kBAAkB,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;IAChE,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,8EAA8E,EAAE,KAAK,IAAI,EAAE;QAC5F,MAAM,GAAG,GAAG,CAAC,CAAC;QACd,MAAM,IAAI,GAAG,IAAI,uBAAuB,CAAC,EA
AE,GAAG,EAAE,iBAAiB,EAAE,EAAE,EAAE,KAAK,EAAE,SAAS,EAAE,CAAC,CAAC;QAC3F,MAAM,eAAe,GAAG,EAAE,CAAC;QAC3B,MAAM,oBAAoB,GAAG,EAAE,CAAC;QAEhC,IAAI,CAAC,KAAK,CAAC,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,eAAe,EAAE,GAAG,CAAC,CAAC,CAAC,CAAC;QAE3D,IAAI,CAAC,UAAU,EAAE,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;YAC/B,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,GAAG,CAAC,CAAC,CAAC;QAClD,CAAC,CAAC,CAAC;QAEH,IAAI,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,oBAAoB,EAAE,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC;QAElE,IAAI,CAAC,UAAU,EAAE,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;YAC/B,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,GAAG,CAAC,CAAC,CAAC;QAClD,CAAC,CAAC,CAAC;IACP,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,cAAc,EAAE,KAAK,IAAI,EAAE;QAC5B,mBAAmB;QACnB,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC/F,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAC"}
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
|
|
3
|
+
import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
|
|
4
|
+
import { type PositionalEncodingArgs } from '../layers/positional_encoding';
|
|
5
|
+
/**
 * Constructor arguments for `TokenAndPositionalEmbedding`, on top of the
 * positional-encoding and generic layer arguments.
 */
export interface TokenAndPositionalEmbeddingArgs extends LayerArgs, PositionalEncodingArgs {
    /** The number of tokens to embed. */
    vocabularySize: number;
    /** Dropout applied to the positionally encoded embeddings, default `0.1`. */
    dropout?: number;
}
|
|
9
|
+
/**
 * This class combines the sinusoidal positional encoding from the
 * 2017 paper "Attention Is All You Need" with a normal embedding layer to
 * form a simplified single embedding layer.
 *
 * This layer accepts tokenized inputs of the shape `[ batch, tokens ]` and runs
 * them through an embedding layer before adding sinusoidal positional encoding.
 *
 * @param embedDim the size of each token/word's embedding
 * @param vocabularySize the number of tokens to embed
 * @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`
 * @param dropout applies dropout to the positionally encoded embeddings, default `0.1`
 */
export declare class TokenAndPositionalEmbedding extends tf.layers.Layer {
    static className: string;
    private readonly embedDim;
    private readonly vocabularySize;
    private embedding;
    private positional;
    private readonly maxSequenceLength;
    private readonly dropout;
    private dropoutLayer;
    constructor({ embedDim, vocabularySize, maxSequenceLength, dropout, ...args }: TokenAndPositionalEmbeddingArgs);
    /**
     * Forward propagation.
     */
    call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor<tf.Rank>;
    /**
     * Build the sublayers and enable serialization.
     */
    build(inputShape: tf.Shape | tf.Shape[]): void;
    /**
     * The output shape, for an input shape of [batch, sequences], is
     * [batch, sequences, embedDim].
     */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[];
    getConfig(): tf.serialization.ConfigDict;
}
|
|
47
|
+
//# sourceMappingURL=token_and_positional_embedding.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"token_and_positional_embedding.d.ts","sourceRoot":"","sources":["../../src/layers/token_and_positional_embedding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,SAAS,EAAE,MAAM,8CAA8C,CAAC;AAC9E,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAEjE,OAAO,EAAsB,KAAK,sBAAsB,EAAE,MAAM,+BAA+B,CAAC;AAGhG,MAAM,WAAW,+BAAgC,SAAQ,SAAS,EAAE,sBAAsB;IACtF,cAAc,EAAE,MAAM,CAAC;IACvB,OAAO,CAAC,EAAE,MAAM,CAAA;CACnB;AAGD;;;;;;;;;;;;GAYG;AACH,qBAAa,2BAA4B,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IAC5D,MAAM,CAAC,SAAS,SAAiC;IAEjD,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,QAAQ,CAAC,cAAc,CAAS;IACxC,OAAO,CAAC,SAAS,CAAkB;IAEnC,OAAO,CAAC,UAAU,CAAiB;IACnC,OAAO,CAAC,QAAQ,CAAC,iBAAiB,CAAS;IAC3C,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAS;IAEjC,OAAO,CAAC,YAAY,CAAkB;gBAG1B,EAAE,QAAQ,EAAE,cAAc,EAAE,iBAAiB,EAAE,OAAO,EAAE,GAAG,IAAI,EAAE,EAAE,+BAA+B;IA0B9G;;OAEG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM;IAe7D;;OAEG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IAgCvD;;;OAGG;IACM,kBAAkB,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;IAQ5E,SAAS,IAAI,EAAE,CAAC,aAAa,CAAC,UAAU;CAcpD"}
|