@stellarapp/tfjs-stellar 1.0.0 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +47 -0
- package/dist/index.d.ts +7 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +7 -0
- package/dist/index.js.map +1 -0
- package/dist/jest.config.d.ts +8 -0
- package/dist/jest.config.d.ts.map +1 -0
- package/{jest.config.ts → dist/jest.config.js} +8 -64
- package/dist/jest.config.js.map +1 -0
- package/dist/kv_cache.d.ts +53 -0
- package/dist/kv_cache.d.ts.map +1 -0
- package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
- package/dist/kv_cache.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.js +76 -0
- package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
- package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
- package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
- package/dist/layers/gpt_decoder_block.js.map +1 -0
- package/dist/layers/index.d.ts +17 -0
- package/dist/layers/index.d.ts.map +1 -0
- package/dist/layers/index.js +33 -0
- package/dist/layers/index.js.map +1 -0
- package/dist/layers/multihead_attention.d.ts +106 -0
- package/dist/layers/multihead_attention.d.ts.map +1 -0
- package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
- package/dist/layers/multihead_attention.js.map +1 -0
- package/dist/layers/multihead_attention.test.d.ts +2 -0
- package/dist/layers/multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
- package/dist/layers/multihead_attention.test.js.map +1 -0
- package/dist/layers/positional_encoding.d.ts +37 -0
- package/dist/layers/positional_encoding.d.ts.map +1 -0
- package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
- package/dist/layers/positional_encoding.js.map +1 -0
- package/dist/layers/positional_encoding.test.d.ts +2 -0
- package/dist/layers/positional_encoding.test.d.ts.map +1 -0
- package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
- package/dist/layers/positional_encoding.test.js.map +1 -0
- package/dist/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
- package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
- package/dist/layers/rotary_position_embedding.js.map +1 -0
- package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/layers/rotary_position_embedding.test.js +88 -0
- package/dist/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
- package/dist/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
- package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/layers/transformer_decoder.d.ts +69 -0
- package/dist/layers/transformer_decoder.d.ts.map +1 -0
- package/dist/layers/transformer_decoder.js +182 -0
- package/dist/layers/transformer_decoder.js.map +1 -0
- package/dist/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
- package/dist/layers/transformer_decoder.test.js.map +1 -0
- package/dist/layers/transformer_encoder.d.ts +55 -0
- package/dist/layers/transformer_encoder.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
- package/dist/layers/transformer_encoder.js.map +1 -0
- package/dist/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
- package/dist/layers/transformer_encoder.test.js.map +1 -0
- package/dist/losses/dice.d.ts +30 -0
- package/dist/losses/dice.d.ts.map +1 -0
- package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
- package/dist/losses/dice.js.map +1 -0
- package/dist/losses/index.d.ts +2 -0
- package/dist/losses/index.d.ts.map +1 -0
- package/dist/losses/index.js +2 -0
- package/dist/losses/index.js.map +1 -0
- package/dist/masks.d.ts +20 -0
- package/dist/masks.d.ts.map +1 -0
- package/{src/packing_mask.ts → dist/masks.js} +16 -7
- package/dist/masks.js.map +1 -0
- package/dist/metrics.d.ts +20 -0
- package/dist/metrics.d.ts.map +1 -0
- package/{src/metrics.ts → dist/metrics.js} +8 -12
- package/dist/metrics.js.map +1 -0
- package/dist/models/gpt_model.d.ts +94 -0
- package/dist/models/gpt_model.d.ts.map +1 -0
- package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
- package/dist/models/gpt_model.js.map +1 -0
- package/dist/models/index.d.ts +7 -0
- package/dist/models/index.d.ts.map +1 -0
- package/dist/models/index.js +13 -0
- package/dist/models/index.js.map +1 -0
- package/dist/models/llm_model.d.ts +87 -0
- package/dist/models/llm_model.d.ts.map +1 -0
- package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
- package/dist/models/llm_model.js.map +1 -0
- package/dist/models/u_net.d.ts +40 -0
- package/dist/models/u_net.d.ts.map +1 -0
- package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
- package/dist/models/u_net.js.map +1 -0
- package/dist/src/index.d.ts +6 -0
- package/dist/src/index.d.ts.map +1 -0
- package/dist/src/index.js +6 -0
- package/dist/src/index.js.map +1 -0
- package/dist/src/kv_cache.d.ts +53 -0
- package/dist/src/kv_cache.d.ts.map +1 -0
- package/dist/src/kv_cache.js +135 -0
- package/dist/src/kv_cache.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
- package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
- package/dist/src/layers/gpt_decoder_block.js +51 -0
- package/dist/src/layers/gpt_decoder_block.js.map +1 -0
- package/dist/src/layers/index.d.ts +17 -0
- package/dist/src/layers/index.d.ts.map +1 -0
- package/dist/src/layers/index.js +33 -0
- package/dist/src/layers/index.js.map +1 -0
- package/dist/src/layers/multihead_attention.d.ts +106 -0
- package/dist/src/layers/multihead_attention.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.js +269 -0
- package/dist/src/layers/multihead_attention.js.map +1 -0
- package/dist/src/layers/multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.test.js +160 -0
- package/dist/src/layers/multihead_attention.test.js.map +1 -0
- package/dist/src/layers/positional_encoding.d.ts +37 -0
- package/dist/src/layers/positional_encoding.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.js +115 -0
- package/dist/src/layers/positional_encoding.js.map +1 -0
- package/dist/src/layers/positional_encoding.test.d.ts +2 -0
- package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.test.js +95 -0
- package/dist/src/layers/positional_encoding.test.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.js +99 -0
- package/dist/src/layers/rotary_position_embedding.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.js +88 -0
- package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.js +109 -0
- package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
- package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/src/layers/transformer_decoder.d.ts +69 -0
- package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
- package/dist/src/layers/transformer_decoder.js.map +1 -0
- package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_decoder.test.js +72 -0
- package/dist/src/layers/transformer_decoder.test.js.map +1 -0
- package/dist/src/layers/transformer_encoder.d.ts +55 -0
- package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.js +175 -0
- package/dist/src/layers/transformer_encoder.js.map +1 -0
- package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.test.js +58 -0
- package/dist/src/layers/transformer_encoder.test.js.map +1 -0
- package/dist/src/losses/dice.d.ts +30 -0
- package/dist/src/losses/dice.d.ts.map +1 -0
- package/dist/src/losses/dice.js +93 -0
- package/dist/src/losses/dice.js.map +1 -0
- package/dist/src/losses/index.d.ts +2 -0
- package/dist/src/losses/index.d.ts.map +1 -0
- package/dist/src/losses/index.js +2 -0
- package/dist/src/losses/index.js.map +1 -0
- package/dist/src/masks.d.ts +20 -0
- package/dist/src/masks.d.ts.map +1 -0
- package/dist/src/masks.js +37 -0
- package/dist/src/masks.js.map +1 -0
- package/dist/src/metrics.d.ts +20 -0
- package/dist/src/metrics.d.ts.map +1 -0
- package/dist/src/metrics.js +28 -0
- package/dist/src/metrics.js.map +1 -0
- package/dist/src/models/gpt_model.d.ts +94 -0
- package/dist/src/models/gpt_model.d.ts.map +1 -0
- package/dist/src/models/gpt_model.js +154 -0
- package/dist/src/models/gpt_model.js.map +1 -0
- package/dist/src/models/index.d.ts +3 -0
- package/dist/src/models/index.d.ts.map +1 -0
- package/{src/models/index.ts → dist/src/models/index.js} +1 -0
- package/dist/src/models/index.js.map +1 -0
- package/dist/src/models/llm_model.d.ts +87 -0
- package/dist/src/models/llm_model.d.ts.map +1 -0
- package/dist/src/models/llm_model.js +245 -0
- package/dist/src/models/llm_model.js.map +1 -0
- package/dist/src/models/u_net.d.ts +40 -0
- package/dist/src/models/u_net.d.ts.map +1 -0
- package/dist/src/models/u_net.js +151 -0
- package/dist/src/models/u_net.js.map +1 -0
- package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
- package/dist/src/tfjs_types.d.ts.map +1 -0
- package/dist/src/tfjs_types.js +2 -0
- package/dist/src/tfjs_types.js.map +1 -0
- package/dist/src/utils.d.ts +28 -0
- package/dist/src/utils.d.ts.map +1 -0
- package/{src/utils.ts → dist/src/utils.js} +10 -33
- package/dist/src/utils.js.map +1 -0
- package/dist/src/utils.test.d.ts +2 -0
- package/dist/src/utils.test.d.ts.map +1 -0
- package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
- package/dist/src/utils.test.js.map +1 -0
- package/dist/tfjs_types.d.ts +10 -0
- package/dist/tfjs_types.d.ts.map +1 -0
- package/dist/tfjs_types.js +2 -0
- package/dist/tfjs_types.js.map +1 -0
- package/dist/utils.d.ts +28 -0
- package/dist/utils.d.ts.map +1 -0
- package/dist/utils.js +63 -0
- package/dist/utils.js.map +1 -0
- package/dist/utils.test.d.ts +2 -0
- package/dist/utils.test.d.ts.map +1 -0
- package/dist/utils.test.js +73 -0
- package/dist/utils.test.js.map +1 -0
- package/package.json +10 -4
- package/src/index.ts +0 -93
- package/src/layers/rotary_position_embedding.test.ts +0 -107
- package/src/losses/index.ts +0 -1
- package/src/testing.ts +0 -1
- package/tsconfig.json +0 -49
|
@@ -1,212 +1,160 @@
|
|
|
1
1
|
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
|
|
3
2
|
import { CachedRoPEMultiHeadAttention } from '@/layers/cached_rope_multihead_attention';
|
|
4
|
-
import {
|
|
3
|
+
import { causal as generateCausalMask } from "@/masks";
|
|
5
4
|
import { MultiHeadAttention } from '@/layers/multihead_attention';
|
|
6
|
-
|
|
7
5
|
// disables warning for using the faster node backend,
|
|
8
6
|
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
|
|
9
7
|
tf.env().set('IS_NODE', false);
|
|
10
|
-
|
|
11
|
-
|
|
12
8
|
describe("MultiHeadAttention tests", () => {
|
|
13
9
|
it("should fail to instantiate a layer if heads count is not divisible by the input's embedding dimension", () => {
|
|
14
10
|
expect(() => new CachedRoPEMultiHeadAttention({ numHeads: 3, embedDim: 10 })).toThrow();
|
|
15
11
|
expect(() => new CachedRoPEMultiHeadAttention({ numHeads: 15, embedDim: 60 })).not.toThrow();
|
|
16
|
-
})
|
|
17
|
-
|
|
18
|
-
|
|
12
|
+
});
|
|
19
13
|
test("successfull forward calls", () => {
|
|
20
14
|
const input = tf.randomUniform([2, 3, 12]);
|
|
21
|
-
|
|
22
|
-
const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1)! });
|
|
15
|
+
const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1) });
|
|
23
16
|
expect(() => attention.apply(input)).not.toThrow();
|
|
24
17
|
expect(() => attention.apply([input])).not.toThrow();
|
|
25
|
-
|
|
26
|
-
const causal = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1)!, causal: true });
|
|
18
|
+
const causal = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1), causal: true });
|
|
27
19
|
expect(() => causal.apply(input)).not.toThrow();
|
|
28
20
|
expect(() => causal.apply([input])).not.toThrow();
|
|
29
|
-
})
|
|
30
|
-
|
|
31
|
-
|
|
21
|
+
});
|
|
32
22
|
test("query and value must have the same shape for scaled dot product attention to succeed", () => {
|
|
33
23
|
const query = tf.randomUniform([2, 3, 12]);
|
|
34
24
|
const key = tf.randomUniform([2, 3, 12]);
|
|
35
25
|
const value = tf.randomUniform([2, 3, 12]);
|
|
36
26
|
const value_thats_too_long = tf.randomUniform([2, 100, 12]);
|
|
37
|
-
|
|
38
|
-
const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: query.shape.at(-1)! });
|
|
27
|
+
const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: query.shape.at(-1) });
|
|
39
28
|
expect(() => attention.apply([query, key, value])).not.toThrow();
|
|
40
29
|
expect(() => attention.apply([query, key, value_thats_too_long])).toThrow();
|
|
41
|
-
})
|
|
42
|
-
|
|
43
|
-
|
|
30
|
+
});
|
|
44
31
|
it("should only accept rank 3 tensors", () => {
|
|
45
32
|
const embed_dims = 12;
|
|
46
|
-
|
|
47
33
|
const BAD_RANK2 = tf.randomUniform([2, embed_dims]);
|
|
48
34
|
const GOOD = tf.randomUniform([2, 3, embed_dims]);
|
|
49
35
|
const BAD_RANK4 = tf.randomUniform([2, 3, 10, embed_dims]);
|
|
50
|
-
|
|
51
36
|
const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: embed_dims });
|
|
52
|
-
|
|
53
37
|
// BAD
|
|
54
38
|
expect(() => attention.apply(BAD_RANK2)).toThrow();
|
|
55
39
|
expect(() => attention.apply([BAD_RANK2])).toThrow();
|
|
56
40
|
expect(() => attention.apply([BAD_RANK2, BAD_RANK2, BAD_RANK2])).toThrow();
|
|
57
|
-
|
|
58
41
|
// OK
|
|
59
42
|
expect(() => attention.apply(GOOD)).not.toThrow();
|
|
60
43
|
expect(() => attention.apply([GOOD])).not.toThrow();
|
|
61
44
|
expect(() => attention.apply([GOOD, GOOD, GOOD])).not.toThrow();
|
|
62
|
-
|
|
63
45
|
// BAD
|
|
64
46
|
expect(() => attention.apply(BAD_RANK4)).toThrow();
|
|
65
47
|
expect(() => attention.apply([BAD_RANK4])).toThrow();
|
|
66
48
|
expect(() => attention.apply([BAD_RANK4, BAD_RANK4, BAD_RANK4])).toThrow();
|
|
67
|
-
|
|
68
49
|
// BAD
|
|
69
50
|
expect(() => attention.apply([GOOD, BAD_RANK2, BAD_RANK4])).toThrow();
|
|
70
51
|
expect(() => attention.apply([BAD_RANK2, GOOD, BAD_RANK4])).toThrow();
|
|
71
52
|
expect(() => attention.apply([BAD_RANK2, BAD_RANK4, GOOD])).toThrow();
|
|
72
53
|
expect(() => attention.apply([BAD_RANK2, GOOD, GOOD])).toThrow();
|
|
73
54
|
expect(() => attention.apply([GOOD, GOOD, BAD_RANK4])).toThrow();
|
|
74
|
-
})
|
|
75
|
-
|
|
76
|
-
|
|
55
|
+
});
|
|
77
56
|
it("should only 1 or 3 inputs total", () => {
|
|
78
57
|
const input = tf.randomUniform([2, 3, 12]);
|
|
79
|
-
|
|
80
|
-
let attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1)! });
|
|
81
|
-
|
|
58
|
+
let attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1) });
|
|
82
59
|
// OK
|
|
83
60
|
expect(() => attention.apply(input, { packingMask: undefined })).not.toThrow();
|
|
84
61
|
expect(() => attention.apply([input])).not.toThrow();
|
|
85
62
|
// reinitialize to rerun build()
|
|
86
|
-
attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1)
|
|
63
|
+
attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1) });
|
|
87
64
|
expect(() => attention.apply([input, input, input])).not.toThrow();
|
|
88
|
-
|
|
89
65
|
// BAD
|
|
90
66
|
expect(() => attention.apply([])).toThrow();
|
|
91
67
|
expect(() => attention.apply([input, input])).toThrow();
|
|
92
68
|
// reinitialize to rerun build()
|
|
93
|
-
attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1)
|
|
69
|
+
attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1) });
|
|
94
70
|
expect(() => attention.apply([input, input, input, input])).toThrow();
|
|
95
|
-
})
|
|
96
|
-
|
|
97
|
-
|
|
71
|
+
});
|
|
98
72
|
test("attention masking", () => {
|
|
99
73
|
const query = tf.randomUniform([2, 3, 12]);
|
|
100
74
|
const key = tf.randomUniform([2, 3, 12]);
|
|
101
75
|
const value = tf.randomUniform([2, 3, 12]);
|
|
102
|
-
|
|
103
|
-
const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: query.shape.at(-1)!, causal: true });
|
|
104
|
-
|
|
76
|
+
const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: query.shape.at(-1), causal: true });
|
|
105
77
|
expect(() => attention.call(query, {})).not.toThrow();
|
|
106
|
-
|
|
107
78
|
// cross attention
|
|
108
79
|
expect(() => attention.call([query, key, value], {})).not.toThrow();
|
|
109
|
-
|
|
110
|
-
|
|
111
80
|
const query5 = tf.randomUniform([2, 5, 10]);
|
|
112
81
|
const key4 = tf.randomUniform([2, 4, 10]);
|
|
113
82
|
const value5 = tf.randomUniform([2, 4, 10]);
|
|
114
|
-
|
|
115
83
|
const expected_mask = tf.tensor([[
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
84
|
+
// vertical represents query, false means that token cannot attend to the keys
|
|
85
|
+
// horizontal represents key, false means that token cannot attend to the queries
|
|
86
|
+
[false, false, false, false],
|
|
87
|
+
[true, true, true, false,],
|
|
88
|
+
[true, true, true, false,],
|
|
89
|
+
[false, false, false, false],
|
|
90
|
+
[true, true, true, false,],
|
|
91
|
+
]]);
|
|
125
92
|
const packing_mask = tf.tensor([
|
|
126
93
|
[0, 0, 0, -1e7, -1e7],
|
|
127
94
|
[0, 0, 0, -1e7, -1e7],
|
|
128
95
|
[0, 0, 0, -1e7, -1e7],
|
|
129
96
|
[-1e7, -1e7, -1e7, 0, 0],
|
|
130
97
|
[-1e7, -1e7, -1e7, 0, 0]
|
|
131
|
-
])
|
|
132
|
-
|
|
98
|
+
]);
|
|
133
99
|
// for causal attention, the attention mask must be boolean
|
|
134
100
|
expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, key4, value5, expected_mask.asType("float32"), null, null, 0.1, true, { scaling_factor: 10 })).toThrow();
|
|
135
101
|
// for causal attention, using pre-calculated causal mask
|
|
136
|
-
expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, key4, value5, expected_mask.asType("float32"), null,
|
|
102
|
+
expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, key4, value5, expected_mask.asType("float32"), null, generateCausalMask(query5.shape[1], key4.shape[1]), 0.2, true, { scaling_factor: 10 })).toThrow();
|
|
137
103
|
// when not using causal attention, the attention mask can be a float32 tensor
|
|
138
104
|
expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, key4, value5, expected_mask.asType("float32"), null, null, 0, false)).not.toThrow();
|
|
139
105
|
// packing mask for self attention
|
|
140
106
|
expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, query5, query5, null, packing_mask, null, 0.9, true)).not.toThrow();
|
|
141
|
-
})
|
|
142
|
-
|
|
143
|
-
|
|
107
|
+
});
|
|
144
108
|
it("should return a non-empty config dict", () => {
|
|
145
109
|
const input = tf.randomUniform([2, 3, 10]);
|
|
146
|
-
|
|
147
|
-
const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1)! });
|
|
110
|
+
const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1) });
|
|
148
111
|
expect(Object.keys(attention.getConfig())).not.toBe(0);
|
|
149
|
-
})
|
|
150
|
-
|
|
151
|
-
|
|
112
|
+
});
|
|
152
113
|
test("causal attention hard coded values", () => {
|
|
153
114
|
// input and output shapes: [2, 3, 10]
|
|
154
115
|
const input = tf.tensor([
|
|
155
116
|
[[0.2109915, 0.6158954, 0.6012088, 0.9867562, 0.8728716, 0.7496274, 0.8173883, 0.2958342, 0.9650571, 0.2075207],
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
117
|
+
[0.2946285, 0.9779906, 0.3203818, 0.4037617, 0.3762881, 0.9863171, 0.6655593, 0.7707329, 0.3216831, 0.7984023],
|
|
118
|
+
[0.9080769, 0.0026282, 0.379492, 0.0162054, 0.1939302, 0.2201049, 0.8190675, 0.0203963, 0.0114392, 0.5015539]],
|
|
159
119
|
[[0.6241482, 0.7631097, 0.6687831, 0.7259795, 0.0457698, 0.6889264, 0.0853676, 0.8697655, 0.3637198, 0.2105307],
|
|
160
|
-
|
|
161
|
-
|
|
120
|
+
[0.5221761, 0.4476321, 0.1244729, 0.8863543, 0.7319002, 0.2954829, 0.3200496, 0.0905503, 0.607977, 0.1309131],
|
|
121
|
+
[0.4693873, 0.4609751, 0.9170766, 0.7065565, 0.4795104, 0.3225758, 0.1353116, 0.7083887, 0.1928891, 0.967386]]
|
|
162
122
|
]);
|
|
163
|
-
|
|
164
123
|
const expected = tf.tensor([
|
|
165
124
|
[[0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344],
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
125
|
+
[0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376],
|
|
126
|
+
[0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539]],
|
|
169
127
|
[[0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718],
|
|
170
|
-
|
|
171
|
-
|
|
128
|
+
[0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268],
|
|
129
|
+
[0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877]]
|
|
172
130
|
]);
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1)!, causal: true });
|
|
131
|
+
const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1), causal: true });
|
|
176
132
|
attention.build(input.shape);
|
|
177
133
|
attention.setWeights(attention.getWeights().map(weight => tf.onesLike(weight).mul(0.05)));
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
})
|
|
181
|
-
|
|
182
|
-
|
|
134
|
+
expect(expected.sub(attention.apply(input)).sum().dataSync()[0]).toBeLessThan(1e-6);
|
|
135
|
+
});
|
|
183
136
|
test("non-causal attention hard coded values", () => {
|
|
184
137
|
// input and output shapes: [2, 3, 10]
|
|
185
138
|
const input = tf.tensor([
|
|
186
139
|
[[0.2109915, 0.6158954, 0.6012088, 0.9867562, 0.8728716, 0.7496274, 0.8173883, 0.2958342, 0.9650571, 0.2075207],
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
140
|
+
[0.2946285, 0.9779906, 0.3203818, 0.4037617, 0.3762881, 0.9863171, 0.6655593, 0.7707329, 0.3216831, 0.7984023],
|
|
141
|
+
[0.9080769, 0.0026282, 0.379492, 0.0162054, 0.1939302, 0.2201049, 0.8190675, 0.0203963, 0.0114392, 0.5015539]],
|
|
190
142
|
[[0.6241482, 0.7631097, 0.6687831, 0.7259795, 0.0457698, 0.6889264, 0.0853676, 0.8697655, 0.3637198, 0.2105307],
|
|
191
|
-
|
|
192
|
-
|
|
143
|
+
[0.5221761, 0.4476321, 0.1244729, 0.8863543, 0.7319002, 0.2954829, 0.3200496, 0.0905503, 0.607977, 0.1309131],
|
|
144
|
+
[0.4693873, 0.4609751, 0.9170766, 0.7065565, 0.4795104, 0.3225758, 0.1353116, 0.7083887, 0.1928891, 0.967386]]
|
|
193
145
|
]);
|
|
194
|
-
|
|
195
|
-
|
|
196
146
|
const expected = tf.tensor([
|
|
197
147
|
[[0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344],
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
148
|
+
[0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376],
|
|
149
|
+
[0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539]],
|
|
201
150
|
[[0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718],
|
|
202
|
-
|
|
203
|
-
|
|
151
|
+
[0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268],
|
|
152
|
+
[0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877]]
|
|
204
153
|
]);
|
|
205
|
-
|
|
206
|
-
const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1)!, causal: false });
|
|
154
|
+
const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1), causal: false });
|
|
207
155
|
attention.build(input.shape);
|
|
208
156
|
attention.setWeights(attention.getWeights().map(weight => tf.onesLike(weight).mul(0.05)));
|
|
209
|
-
|
|
210
|
-
expect(expected.sub(attention.apply(input) as tf.Tensor).sum().dataSync()[0]).toBeLessThan(1e-6);
|
|
157
|
+
expect(expected.sub(attention.apply(input)).sum().dataSync()[0]).toBeLessThan(1e-6);
|
|
211
158
|
});
|
|
212
159
|
});
|
|
160
|
+
//# sourceMappingURL=multihead_attention.test.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"multihead_attention.test.js","sourceRoot":"","sources":["../../src/layers/multihead_attention.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,4BAA4B,EAAE,MAAM,0CAA0C,CAAC;AACxF,OAAO,EAAE,MAAM,IAAI,kBAAkB,EAAE,MAAM,SAAS,CAAC;AACvD,OAAO,EAAE,kBAAkB,EAAE,MAAM,8BAA8B,CAAC;AAElE,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,EAAE,CAAC,uGAAuG,EAAE,GAAG,EAAE;QAC7G,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACxF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACjG,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,2BAA2B,EAAE,GAAG,EAAE;QACnC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACnG,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAErD,MAAM,MAAM,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QAC9G,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACtD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,sFAAsF,EAAE,GAAG,EAAE;QAC9F,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3C,MAAM,GAAG,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QACzC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3C,MAAM,oBAAoB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,GAAG,EAAE,EAAE,CAAC,CAAC,CAAC;QAE5D,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACnG,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,GAAG,EAAE,oBAAoB,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAChF,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,mCAAmC,EAAE,GAAG,EAAE;QACzC,MAAM,UAAU,GAAG,EAAE,CAAC;QAEtB,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC;QACpD,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC;QAClD,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,UAAU,CAAC,CAAC,CAAC;QAE3D,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,CAAC,CAAC;QAE1F,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACrD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAE3E,KAAK;QACL,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAClD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACpD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEhE,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACrD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAE3E,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACtE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACtE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACtE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACjE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,IAAI,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACrE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,iCAAiC,EAAE,GAAG,EAAE;QACvC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,IAAI,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QAEjG,KAAK;QACL,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,EAAE,EAAE,WAAW,EAAE,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAC/E,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACrD,gCAAgC;QAChC,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QAC7F,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnE,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC5C,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACxD,gCAAgC;QAChC,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QAC7F,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAC1E,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,mBAAmB,EAAE,GAAG,EAAE;QAC3B,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3C,MAAM,GAAG,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QACzC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QAEjH,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,IAAI,CAAC,KAAK,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEtD,kBAAkB;QAClB,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,IAAI,CAAC,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAGpE,MAAM,MAAM,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC5C,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC1C,MAAM,MAAM,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE5C,MAAM,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC;gBAC7B,8EAA8E;gBAC9E,iFAAiF;gBACjF,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC;gBAC5B,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,EAAE,KAAK,EAAE;gBAC1B,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,EAAE,KAAK,EAAE;gBAC1B,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC;gBAC5B,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,EAAE,KAAK,EAAE;aAC7B,CAAC,CAAC,CAAC;QAEJ,MAAM,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC;YAC3B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC;YACrB,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC;YACrB,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC;YACrB,CAAC,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,CAAC;YACxB,CAAC,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,CAAC;SAC3B,CAAC,CAAA;QAEF,2DAA2D;QAC3D,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,4BAA4B,CAAC,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,aAAa,CAAC,MAAM,CAAC,SAAS,CAAC,EAAE,IAAI,EAAE,IAAI,EAAE,GAAG,EAAE,IAAI,EAAE,EAAE,cAAc,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC9K,yDAAyD;QACzD,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,4BAA4B,CAAC,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,aAAa,CAAC,MAAM,CAAC,SAAS,CAAC,EAAE,IAAI,EAAE,kBAAkB,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC,EAAE,GAAG,EAAE,IAAI,EAAE,EAAE,cAAc,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC9N,8EAA8E;QAC9E,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,4BAA4B,CAAC,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,aAAa,CAAC,MAAM,CAAC,SAAS,CAAC,EAAE,IAAI,EAAE,IAAI,EAAE,CAAC,EAAE,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACzJ,kCAAkC;QAClC,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,4BAA4B,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,IAAI,EAAE,YAAY,EAAE,IAAI,EAAE,GAAG,EAAE,IAAI,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC7I,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACnG,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,SAAS,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,oCAAoC,EAAE,GAAG,EAAE;QAC5C,sCAAsC;QACtC,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC;YACpB,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9G,CAAC,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;YAE9G,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,CAAC;gBAC7G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,CAAC,CAAC;SACjH,CAAC,CAAC;QAEH,MAAM,QAAQ,GAAG,EAAE,CAAC,MAAM,CAAC;YACvB,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,CAAC;gBACpG,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;YAE/G,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;SAClH,CAAC,CAAC;QAGH,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QACjH,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;QAC7B,SAAS,CAAC,UAAU,CAAC,SAAS,CAAC,UAAU,EAAE,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAE1F,MAAM,CAAC,QAAQ,CAAC,GAAG,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;IACrG,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,wCAAwC,EAAE,GAAG,EAAE;QAChD,sCAAsC;QACtC,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC;YACpB,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9G,CAAC,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;YAE9G,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,CAAC;gBAC7G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,CAAC,CAAC;SACjH,CAAC,CAAC;QAGH,MAAM,QAAQ,GAAG,EAAE,CAAC,MAAM,CAAC;YACvB,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,CAAC;gBACpG,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;YAE/G,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;SAClH,CAAC,CAAC;QAEH,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,KAAK,EAAE,CAAC,CAAC;QAClH,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;QAC7B,SAAS,CAAC,UAAU,CAAC,SAAS,CAAC,UAAU,EAAE,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAE1F,MAAM,CAAC,QAAQ,CAAC,GAAG,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;IACrG,CAAC,CAAC,CAAC;AACP,CAAC,CAAC,CAAC"}
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
|
|
3
|
+
import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
|
|
4
|
+
export interface PositionalEncodingArgs extends LayerArgs {
|
|
5
|
+
embedDim: number;
|
|
6
|
+
maxSequenceLength?: number;
|
|
7
|
+
}
|
|
8
|
+
/**
|
|
9
|
+
* This class implements the position encoding logic described in the
|
|
10
|
+
* 2017 paper "Attention Is All You Need".
|
|
11
|
+
*
|
|
12
|
+
* This layer is untrainable and accepts inputs of shape `[ batch, sequences, embedding dims ]`
|
|
13
|
+
* and adds positional encoding to return an output tensor of the same shape.
|
|
14
|
+
*
|
|
15
|
+
* @param embedDim the size of each token/word's embedding
|
|
16
|
+
* @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`
|
|
17
|
+
*/
|
|
18
|
+
export declare class PositionalEncoding extends tf.layers.Layer {
|
|
19
|
+
static className: string;
|
|
20
|
+
private readonly maxSequenceLength;
|
|
21
|
+
private readonly embedDim;
|
|
22
|
+
private positionalEncodings;
|
|
23
|
+
constructor(args: PositionalEncodingArgs);
|
|
24
|
+
/**
|
|
25
|
+
* Forward propagation. Injects positional encoding to the input embeddings
|
|
26
|
+
*/
|
|
27
|
+
call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[];
|
|
28
|
+
/**
|
|
29
|
+
* Generate the positional encoding from the paper Attention Is All You Need.
|
|
30
|
+
* Note that because the inner term of the position formula is the same for both even
|
|
31
|
+
* and odd indices, we only create half of it and apply sine and cosine individually.
|
|
32
|
+
*/
|
|
33
|
+
build(inputShape: tf.Shape | tf.Shape[]): void;
|
|
34
|
+
computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[];
|
|
35
|
+
getConfig(): tf.serialization.ConfigDict;
|
|
36
|
+
}
|
|
37
|
+
//# sourceMappingURL=positional_encoding.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"positional_encoding.d.ts","sourceRoot":"","sources":["../../src/layers/positional_encoding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,SAAS,EAAE,MAAM,8CAA8C,CAAC;AAC9E,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAGjE,MAAM,WAAW,sBAAuB,SAAQ,SAAS;IAErD,QAAQ,EAAE,MAAM,CAAC;IAEjB,iBAAiB,CAAC,EAAE,MAAM,CAAC;CAC9B;AAGD;;;;;;;;;GASG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IACxC,OAAO,CAAC,QAAQ,CAAC,iBAAiB,CAAS;IAC3C,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,mBAAmB,CAAmB;gBAGlC,IAAI,EAAE,sBAAsB;IAuBxC;;OAEG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAyBvF;;;;OAIG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IAmD9C,kBAAkB,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;IAK5E,SAAS,IAAI,EAAE,CAAC,aAAa,CAAC,UAAU;CAYpD"}
|
|
@@ -1,112 +1,81 @@
|
|
|
1
1
|
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
|
|
3
|
-
import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
export interface PositionalEncodingArgs extends LayerArgs {
|
|
7
|
-
// embedding size of each word/token, aka d_model from the paper
|
|
8
|
-
embedDim: number;
|
|
9
|
-
// the max length of each sentence, any more or less are truncated or padded
|
|
10
|
-
maxSequenceLength?: number;
|
|
11
|
-
}
|
|
12
|
-
|
|
13
|
-
|
|
14
2
|
/**
|
|
15
3
|
* This class implements the position encoding logic described in the
|
|
16
4
|
* 2017 paper "Attention Is All You Need".
|
|
17
|
-
*
|
|
5
|
+
*
|
|
18
6
|
* This layer is untrainable and accepts inputs of shape `[ batch, sequences, embedding dims ]`
|
|
19
7
|
* and adds positional encoding to return an output tensor of the same shape.
|
|
20
|
-
*
|
|
8
|
+
*
|
|
21
9
|
* @param embedDim the size of each token/word's embedding
|
|
22
10
|
* @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`
|
|
23
11
|
*/
|
|
24
12
|
export class PositionalEncoding extends tf.layers.Layer {
|
|
25
13
|
static className = "PositionalEncoding";
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
constructor(args: PositionalEncodingArgs) {
|
|
14
|
+
maxSequenceLength;
|
|
15
|
+
embedDim;
|
|
16
|
+
positionalEncodings;
|
|
17
|
+
constructor(args) {
|
|
32
18
|
super(args);
|
|
33
|
-
|
|
34
19
|
this.maxSequenceLength = args.maxSequenceLength ?? 5120;
|
|
35
20
|
this.embedDim = args.embedDim;
|
|
36
|
-
|
|
37
21
|
if (this.maxSequenceLength < 1) {
|
|
38
22
|
throw Error(`${this.getClassName()}::constructor ${this.name} maxSequenceLength` +
|
|
39
23
|
` (${args.maxSequenceLength}) must be greater than 0`);
|
|
40
24
|
}
|
|
41
|
-
|
|
42
25
|
if (this.embedDim < 1) {
|
|
43
26
|
throw Error(`${this.getClassName()}::constructor ${this.name} embedDim` +
|
|
44
27
|
` (${args.embedDim}) must be greater than 0`);
|
|
45
28
|
}
|
|
46
|
-
|
|
47
29
|
// positional encodings are not trainable
|
|
48
|
-
this.positionalEncodings = this.addWeight('positional_encodings',
|
|
49
|
-
[this.maxSequenceLength, this.embedDim], "float32",
|
|
50
|
-
tf.initializers.zeros(), undefined, false);
|
|
30
|
+
this.positionalEncodings = this.addWeight('positional_encodings', [this.maxSequenceLength, this.embedDim], "float32", tf.initializers.zeros(), undefined, false);
|
|
51
31
|
}
|
|
52
|
-
|
|
53
|
-
|
|
54
32
|
/**
|
|
55
33
|
* Forward propagation. Injects positional encoding to the input embeddings
|
|
56
34
|
*/
|
|
57
|
-
|
|
35
|
+
call(inputs, kwargs) {
|
|
58
36
|
// validate the input tensors
|
|
59
37
|
const input = Array.isArray(inputs) ? inputs[0] : inputs;
|
|
60
|
-
const sequences = input.shape[1]
|
|
61
|
-
|
|
38
|
+
const sequences = input.shape[1];
|
|
62
39
|
if (input.shape.length != 3 || input.shape[2] != this.embedDim) {
|
|
63
40
|
throw Error(`${this.getClassName()}::call ${this.name} expected an input shape of` +
|
|
64
41
|
` [batch, (up to ${this.maxSequenceLength}), ${this.embedDim}], instead got ${input.shape}`);
|
|
65
42
|
}
|
|
66
|
-
|
|
67
43
|
if (sequences > this.maxSequenceLength) {
|
|
68
44
|
// unexpected sequence length
|
|
69
45
|
throw Error(`${this.getClassName()}::call ${this.name} received an input with` +
|
|
70
46
|
` sequence length (${sequences}) which is greater than the max sequence length` +
|
|
71
47
|
` ${this.maxSequenceLength}`);
|
|
72
48
|
}
|
|
73
|
-
|
|
74
49
|
// perform forward propagation
|
|
75
50
|
return tf.tidy(() => {
|
|
76
51
|
return input.add(this.positionalEncodings.read()
|
|
77
52
|
.slice([0, 0], [sequences, this.embedDim]) // gets the first "sequences" rows
|
|
78
53
|
.expandDims(0)); // introduce the batch dimension and let add() broadcast it
|
|
79
|
-
})
|
|
54
|
+
});
|
|
80
55
|
}
|
|
81
|
-
|
|
82
56
|
/**
|
|
83
57
|
* Generate the positional encoding from the paper Attention Is All You Need.
|
|
84
58
|
* Note that because the inner term of the position formula is the same for both even
|
|
85
59
|
* and odd indices, we only create half of it and apply sine and cosine individually.
|
|
86
60
|
*/
|
|
87
|
-
|
|
61
|
+
build(inputShape) {
|
|
88
62
|
tf.tidy(() => {
|
|
89
63
|
const embedDimHalved = Math.ceil(this.embedDim / 2);
|
|
90
|
-
|
|
91
64
|
// create the position matrix as [ 0, 1, 2, 3, etc ],
|
|
92
65
|
// and broadcast it horizontally to match the number of embeddings,
|
|
93
66
|
const numerator = tf.range(0, this.maxSequenceLength, 1)
|
|
94
67
|
.reshape([this.maxSequenceLength, 1])
|
|
95
68
|
// this creates an extra, unsued positional encoding column later on for odd embedding sizes
|
|
96
69
|
.broadcastTo([this.maxSequenceLength, embedDimHalved]);
|
|
97
|
-
|
|
98
70
|
// the inner term's denominator's exponent's numerator is created as
|
|
99
71
|
// [ 0, 0, 2, 2, 4, 4, etc ] ( technically [0, 2, 4] as explained above ) and not
|
|
100
72
|
// [ 0, 2, 4, 6, 8, 10, etc ] because the even and odd indices are counted as pairs
|
|
101
73
|
// when incrementing "i",
|
|
102
74
|
// the denominator formula is 10_000^(2i/d_model) where each "i" is a sine cosine pair
|
|
103
75
|
const denominator = tf.pow(10_000, tf.range(0, this.embedDim, 2).div(this.embedDim));
|
|
104
|
-
|
|
105
76
|
const inner_term = numerator.div(denominator);
|
|
106
|
-
|
|
107
77
|
const sine = tf.sin(inner_term);
|
|
108
78
|
const cosine = tf.cos(inner_term);
|
|
109
|
-
|
|
110
79
|
// horizontally interweave the sine and cosine columns together to form
|
|
111
80
|
// [sin, cos, sin, cos, etc]
|
|
112
81
|
// [sin, cos, sin, cos, etc]
|
|
@@ -115,44 +84,32 @@ export class PositionalEncoding extends tf.layers.Layer {
|
|
|
115
84
|
const ALL_ROWS = -1;
|
|
116
85
|
const ONE_COL = 1;
|
|
117
86
|
const FIRST_ROW = 0;
|
|
118
|
-
|
|
119
87
|
for (let targetCol = 0; targetCol < this.embedDim / 2; targetCol++) {
|
|
120
|
-
interweaved.push(sine.slice([FIRST_ROW, targetCol], [ALL_ROWS, ONE_COL]))
|
|
121
|
-
|
|
88
|
+
interweaved.push(sine.slice([FIRST_ROW, targetCol], [ALL_ROWS, ONE_COL]));
|
|
122
89
|
if (targetCol != Math.floor(this.embedDim / 2)) {
|
|
123
90
|
// for odd numbered embedDim sizes skip the last cosine column
|
|
124
91
|
// e.g. if embedDim = 5, create [ i=0 (sin), i=0 (cos), i=1 (sin), i=1 (cos), i=2 (sin) ]
|
|
125
92
|
// and the final i=2 (cos) is ignored
|
|
126
|
-
interweaved.push(cosine.slice([FIRST_ROW, targetCol], [ALL_ROWS, ONE_COL]))
|
|
93
|
+
interweaved.push(cosine.slice([FIRST_ROW, targetCol], [ALL_ROWS, ONE_COL]));
|
|
127
94
|
}
|
|
128
95
|
}
|
|
129
|
-
|
|
130
96
|
// add the positional encoding
|
|
131
97
|
this.setWeights([tf.concat(interweaved, 1)]);
|
|
132
98
|
});
|
|
133
|
-
|
|
134
99
|
super.build(inputShape);
|
|
135
100
|
}
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
override computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
|
|
101
|
+
computeOutputShape(inputShape) {
|
|
139
102
|
return inputShape;
|
|
140
103
|
}
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
override getConfig(): tf.serialization.ConfigDict {
|
|
104
|
+
getConfig() {
|
|
144
105
|
const base_config = super.getConfig();
|
|
145
|
-
|
|
146
106
|
const config = {
|
|
147
107
|
maxSequenceLength: this.maxSequenceLength,
|
|
148
108
|
embedDim: this.embedDim,
|
|
149
|
-
}
|
|
150
|
-
|
|
109
|
+
};
|
|
151
110
|
Object.assign(config, base_config);
|
|
152
|
-
|
|
153
111
|
return config;
|
|
154
112
|
}
|
|
155
113
|
}
|
|
156
|
-
|
|
157
|
-
|
|
158
114
|
tf.serialization.registerClass(PositionalEncoding);
|
|
115
|
+
//# sourceMappingURL=positional_encoding.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"positional_encoding.js","sourceRoot":"","sources":["../../src/layers/positional_encoding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAavC;;;;;;;;;GASG;AACH,MAAM,OAAO,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,GAAG,oBAAoB,CAAC;IACvB,iBAAiB,CAAS;IAC1B,QAAQ,CAAS;IAC1B,mBAAmB,CAAmB;IAG9C,YAAY,IAA4B;QACpC,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,iBAAiB,GAAG,IAAI,CAAC,iBAAiB,IAAI,IAAI,CAAC;QACxD,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;QAE9B,IAAI,IAAI,CAAC,iBAAiB,GAAG,CAAC,EAAE,CAAC;YAC7B,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,oBAAoB;gBAC5E,KAAK,IAAI,CAAC,iBAAiB,0BAA0B,CAAC,CAAC;QAC/D,CAAC;QAED,IAAI,IAAI,CAAC,QAAQ,GAAG,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,WAAW;gBACnE,KAAK,IAAI,CAAC,QAAQ,0BAA0B,CAAC,CAAC;QACtD,CAAC;QAED,yCAAyC;QACzC,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,SAAS,CAAC,sBAAsB,EAC5D,CAAC,IAAI,CAAC,iBAAiB,EAAE,IAAI,CAAC,QAAQ,CAAC,EAAE,SAAS,EAClD,EAAE,CAAC,YAAY,CAAC,KAAK,EAAE,EAAE,SAAS,EAAE,KAAK,CAAC,CAAC;IACnD,CAAC;IAGD;;OAEG;IACM,IAAI,CAAC,MAA+B,EAAE,MAAc;QACzD,6BAA6B;QAC7B,MAAM,KAAK,GAAG,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;QACzD,MAAM,SAAS,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC;QAElC,IAAI,KAAK,CAAC,KAAK,CAAC,MAAM,IAAI,CAAC,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,IAAI,CAAC,QAAQ,EAAE,CAAC;YAC7D,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,6BAA6B;gBAC9E,mBAAmB,IAAI,CAAC,iBAAiB,MAAM,IAAI,CAAC,QAAQ,kBAAkB,KAAK,CAAC,KAAK,EAAE,CAAC,CAAC;QACrG,CAAC;QAED,IAAI,SAAS,GAAG,IAAI,CAAC,iBAAiB,EAAE,CAAC;YACrC,6BAA6B;YAC7B,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,yBAAyB;gBAC1E,qBAAqB,SAAS,iDAAiD;gBAC/E,IAAI,IAAI,CAAC,iBAAiB,EAAE,CAAC,CAAC;QACtC,CAAC;QAED,8BAA8B;QAC9B,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,OAAO,KAAK,CAAC,GAAG,CAAC,IAAI,CAAC,mBAAmB,CAAC,IAAI,EAAE;iBAC3C,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,SAAS,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,kCAAkC;iBAC5E,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,2DAA2D;QACpF,CAAC,CAAC,CAAA;IACN,CAAC;IAED;;;;OAIG;IACM,KAAK,CAAC,UAAiC;QAC5C,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YACT,MAAM,cAAc,GAAG,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,GAAG,CAAC,CAAC,CAAC;YAEpD,qDAAqD;YACrD,mEAAmE;YACnE,MAAM,SAAS,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,iBAAiB,EAAE,CAAC,CAAC;iBACnD,OAAO,CAAC,CAAC,IAAI,CAAC,iBAAiB,EAAE,CAAC,CAAC,CAAC;gBACrC,4FAA4F;iBAC3F,WAAW,CAAC,CAAC,IAAI,CAAC,iBAAiB,EAAE,cAAc,CAAC,CAAC,CAAC;YAE3D,oEAAoE;YACpE,iFAAiF;YACjF,mFAAmF;YACnF,yBAAyB;YACzB,sFAAsF;YACtF,MAAM,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAErF,MAAM,UAAU,GAAG,SAAS,CAAC,GAAG,CAAC,WAAW,CAAC,CAAC;YAE9C,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC;YAChC,MAAM,MAAM,GAAG,EAAE,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC;YAElC,uEAAuE;YACvE,4BAA4B;YAC5B,4BAA4B;YAC5B,MAAM;YACN,MAAM,WAAW,GAAG,EAAE,CAAC;YACvB,MAAM,QAAQ,GAAG,CAAC,CAAC,CAAC;YACpB,MAAM,OAAO,GAAG,CAAC,CAAC;YAClB,MAAM,SAAS,GAAG,CAAC,CAAC;YAEpB,KAAK,IAAI,SAAS,GAAG,CAAC,EAAE,SAAS,GAAG,IAAI,CAAC,QAAQ,GAAG,CAAC,EAAE,SAAS,EAAE,EAAE,CAAC;gBACjE,WAAW,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,CAAC,EAAE,CAAC,QAAQ,EAAE,OAAO,CAAC,CAAC,CAAC,CAAA;gBAEzE,IAAI,SAAS,IAAI,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,QAAQ,GAAG,CAAC,CAAC,EAAE,CAAC;oBAC7C,8DAA8D;oBAC9D,yFAAyF;oBACzF,qCAAqC;oBACrC,WAAW,CAAC,IAAI,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,CAAC,EAAE,CAAC,QAAQ,EAAE,OAAO,CAAC,CAAC,CAAC,CAAA;gBAC/E,CAAC;YACL,CAAC;YAED,8BAA8B;YAC9B,IAAI,CAAC,UAAU,CAAC,CAAC,EAAE,CAAC,MAAM,CAAC,WAAW,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QACjD,CAAC,CAAC,CAAC;QAEH,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGQ,kBAAkB,CAAC,UAAiC;QACzD,OAAO,UAAU,CAAC;IACtB,CAAC;IAGQ,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,iBAAiB,EAAE,IAAI,CAAC,iBAAiB;YACzC,QAAQ,EAAE,IAAI,CAAC,QAAQ;SAC1B,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"positional_encoding.test.d.ts","sourceRoot":"","sources":["../../src/layers/positional_encoding.test.ts"],"names":[],"mappings":""}
|