@stellarapp/tfjs-stellar 1.0.0 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +47 -0
- package/dist/index.d.ts +7 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +7 -0
- package/dist/index.js.map +1 -0
- package/dist/jest.config.d.ts +8 -0
- package/dist/jest.config.d.ts.map +1 -0
- package/{jest.config.ts → dist/jest.config.js} +8 -64
- package/dist/jest.config.js.map +1 -0
- package/dist/kv_cache.d.ts +53 -0
- package/dist/kv_cache.d.ts.map +1 -0
- package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
- package/dist/kv_cache.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.js +76 -0
- package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
- package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
- package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
- package/dist/layers/gpt_decoder_block.js.map +1 -0
- package/dist/layers/index.d.ts +17 -0
- package/dist/layers/index.d.ts.map +1 -0
- package/dist/layers/index.js +33 -0
- package/dist/layers/index.js.map +1 -0
- package/dist/layers/multihead_attention.d.ts +106 -0
- package/dist/layers/multihead_attention.d.ts.map +1 -0
- package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
- package/dist/layers/multihead_attention.js.map +1 -0
- package/dist/layers/multihead_attention.test.d.ts +2 -0
- package/dist/layers/multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
- package/dist/layers/multihead_attention.test.js.map +1 -0
- package/dist/layers/positional_encoding.d.ts +37 -0
- package/dist/layers/positional_encoding.d.ts.map +1 -0
- package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
- package/dist/layers/positional_encoding.js.map +1 -0
- package/dist/layers/positional_encoding.test.d.ts +2 -0
- package/dist/layers/positional_encoding.test.d.ts.map +1 -0
- package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
- package/dist/layers/positional_encoding.test.js.map +1 -0
- package/dist/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
- package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
- package/dist/layers/rotary_position_embedding.js.map +1 -0
- package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/layers/rotary_position_embedding.test.js +88 -0
- package/dist/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
- package/dist/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
- package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/layers/transformer_decoder.d.ts +69 -0
- package/dist/layers/transformer_decoder.d.ts.map +1 -0
- package/dist/layers/transformer_decoder.js +182 -0
- package/dist/layers/transformer_decoder.js.map +1 -0
- package/dist/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
- package/dist/layers/transformer_decoder.test.js.map +1 -0
- package/dist/layers/transformer_encoder.d.ts +55 -0
- package/dist/layers/transformer_encoder.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
- package/dist/layers/transformer_encoder.js.map +1 -0
- package/dist/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
- package/dist/layers/transformer_encoder.test.js.map +1 -0
- package/dist/losses/dice.d.ts +30 -0
- package/dist/losses/dice.d.ts.map +1 -0
- package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
- package/dist/losses/dice.js.map +1 -0
- package/dist/losses/index.d.ts +2 -0
- package/dist/losses/index.d.ts.map +1 -0
- package/dist/losses/index.js +2 -0
- package/dist/losses/index.js.map +1 -0
- package/dist/masks.d.ts +20 -0
- package/dist/masks.d.ts.map +1 -0
- package/{src/packing_mask.ts → dist/masks.js} +16 -7
- package/dist/masks.js.map +1 -0
- package/dist/metrics.d.ts +20 -0
- package/dist/metrics.d.ts.map +1 -0
- package/{src/metrics.ts → dist/metrics.js} +8 -12
- package/dist/metrics.js.map +1 -0
- package/dist/models/gpt_model.d.ts +94 -0
- package/dist/models/gpt_model.d.ts.map +1 -0
- package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
- package/dist/models/gpt_model.js.map +1 -0
- package/dist/models/index.d.ts +7 -0
- package/dist/models/index.d.ts.map +1 -0
- package/dist/models/index.js +13 -0
- package/dist/models/index.js.map +1 -0
- package/dist/models/llm_model.d.ts +87 -0
- package/dist/models/llm_model.d.ts.map +1 -0
- package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
- package/dist/models/llm_model.js.map +1 -0
- package/dist/models/u_net.d.ts +40 -0
- package/dist/models/u_net.d.ts.map +1 -0
- package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
- package/dist/models/u_net.js.map +1 -0
- package/dist/src/index.d.ts +6 -0
- package/dist/src/index.d.ts.map +1 -0
- package/dist/src/index.js +6 -0
- package/dist/src/index.js.map +1 -0
- package/dist/src/kv_cache.d.ts +53 -0
- package/dist/src/kv_cache.d.ts.map +1 -0
- package/dist/src/kv_cache.js +135 -0
- package/dist/src/kv_cache.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
- package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
- package/dist/src/layers/gpt_decoder_block.js +51 -0
- package/dist/src/layers/gpt_decoder_block.js.map +1 -0
- package/dist/src/layers/index.d.ts +17 -0
- package/dist/src/layers/index.d.ts.map +1 -0
- package/dist/src/layers/index.js +33 -0
- package/dist/src/layers/index.js.map +1 -0
- package/dist/src/layers/multihead_attention.d.ts +106 -0
- package/dist/src/layers/multihead_attention.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.js +269 -0
- package/dist/src/layers/multihead_attention.js.map +1 -0
- package/dist/src/layers/multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.test.js +160 -0
- package/dist/src/layers/multihead_attention.test.js.map +1 -0
- package/dist/src/layers/positional_encoding.d.ts +37 -0
- package/dist/src/layers/positional_encoding.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.js +115 -0
- package/dist/src/layers/positional_encoding.js.map +1 -0
- package/dist/src/layers/positional_encoding.test.d.ts +2 -0
- package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.test.js +95 -0
- package/dist/src/layers/positional_encoding.test.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.js +99 -0
- package/dist/src/layers/rotary_position_embedding.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.js +88 -0
- package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.js +109 -0
- package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
- package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/src/layers/transformer_decoder.d.ts +69 -0
- package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
- package/dist/src/layers/transformer_decoder.js.map +1 -0
- package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_decoder.test.js +72 -0
- package/dist/src/layers/transformer_decoder.test.js.map +1 -0
- package/dist/src/layers/transformer_encoder.d.ts +55 -0
- package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.js +175 -0
- package/dist/src/layers/transformer_encoder.js.map +1 -0
- package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.test.js +58 -0
- package/dist/src/layers/transformer_encoder.test.js.map +1 -0
- package/dist/src/losses/dice.d.ts +30 -0
- package/dist/src/losses/dice.d.ts.map +1 -0
- package/dist/src/losses/dice.js +93 -0
- package/dist/src/losses/dice.js.map +1 -0
- package/dist/src/losses/index.d.ts +2 -0
- package/dist/src/losses/index.d.ts.map +1 -0
- package/dist/src/losses/index.js +2 -0
- package/dist/src/losses/index.js.map +1 -0
- package/dist/src/masks.d.ts +20 -0
- package/dist/src/masks.d.ts.map +1 -0
- package/dist/src/masks.js +37 -0
- package/dist/src/masks.js.map +1 -0
- package/dist/src/metrics.d.ts +20 -0
- package/dist/src/metrics.d.ts.map +1 -0
- package/dist/src/metrics.js +28 -0
- package/dist/src/metrics.js.map +1 -0
- package/dist/src/models/gpt_model.d.ts +94 -0
- package/dist/src/models/gpt_model.d.ts.map +1 -0
- package/dist/src/models/gpt_model.js +154 -0
- package/dist/src/models/gpt_model.js.map +1 -0
- package/dist/src/models/index.d.ts +3 -0
- package/dist/src/models/index.d.ts.map +1 -0
- package/{src/models/index.ts → dist/src/models/index.js} +1 -0
- package/dist/src/models/index.js.map +1 -0
- package/dist/src/models/llm_model.d.ts +87 -0
- package/dist/src/models/llm_model.d.ts.map +1 -0
- package/dist/src/models/llm_model.js +245 -0
- package/dist/src/models/llm_model.js.map +1 -0
- package/dist/src/models/u_net.d.ts +40 -0
- package/dist/src/models/u_net.d.ts.map +1 -0
- package/dist/src/models/u_net.js +151 -0
- package/dist/src/models/u_net.js.map +1 -0
- package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
- package/dist/src/tfjs_types.d.ts.map +1 -0
- package/dist/src/tfjs_types.js +2 -0
- package/dist/src/tfjs_types.js.map +1 -0
- package/dist/src/utils.d.ts +28 -0
- package/dist/src/utils.d.ts.map +1 -0
- package/{src/utils.ts → dist/src/utils.js} +10 -33
- package/dist/src/utils.js.map +1 -0
- package/dist/src/utils.test.d.ts +2 -0
- package/dist/src/utils.test.d.ts.map +1 -0
- package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
- package/dist/src/utils.test.js.map +1 -0
- package/dist/tfjs_types.d.ts +10 -0
- package/dist/tfjs_types.d.ts.map +1 -0
- package/dist/tfjs_types.js +2 -0
- package/dist/tfjs_types.js.map +1 -0
- package/dist/utils.d.ts +28 -0
- package/dist/utils.d.ts.map +1 -0
- package/dist/utils.js +63 -0
- package/dist/utils.js.map +1 -0
- package/dist/utils.test.d.ts +2 -0
- package/dist/utils.test.d.ts.map +1 -0
- package/dist/utils.test.js +73 -0
- package/dist/utils.test.js.map +1 -0
- package/package.json +10 -4
- package/src/index.ts +0 -93
- package/src/layers/rotary_position_embedding.test.ts +0 -107
- package/src/losses/index.ts +0 -1
- package/src/testing.ts +0 -1
- package/tsconfig.json +0 -49
|
@@ -1,100 +1,72 @@
|
|
|
1
1
|
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
|
|
3
2
|
import { TransformerDecoder } from '@/layers/transformer_decoder';
|
|
4
|
-
|
|
5
3
|
// disables warning for using the faster node backend,
|
|
6
4
|
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
|
|
7
5
|
tf.env().set('IS_NODE', false);
|
|
8
|
-
|
|
9
|
-
|
|
10
6
|
describe("TransformerDecoder tests", () => {
|
|
11
7
|
it("should return an output with the same shape as the input", () => {
|
|
12
8
|
const input = tf.randomUniform([2, 3, 12]);
|
|
13
|
-
|
|
14
9
|
const decoder = new TransformerDecoder({
|
|
15
|
-
numHeads: 2, embedDim: input.shape.at(-1)
|
|
10
|
+
numHeads: 2, embedDim: input.shape.at(-1),
|
|
16
11
|
dropout: 0.5, activation: "gelu", dimsFeedForward: 321, useBias: false
|
|
17
12
|
});
|
|
18
|
-
|
|
19
|
-
const output = decoder.apply(input) as tf.Tensor;
|
|
20
|
-
|
|
13
|
+
const output = decoder.apply(input);
|
|
21
14
|
expect(output.shape.length).toBe(input.shape.length);
|
|
22
|
-
})
|
|
23
|
-
|
|
24
|
-
|
|
15
|
+
});
|
|
25
16
|
test("forward calls", () => {
|
|
26
17
|
const input = tf.randomUniform([2, 3, 12]);
|
|
27
|
-
const mask = tf.randomUniform([input.shape[0]
|
|
18
|
+
const mask = tf.randomUniform([input.shape[0], input.shape[1]], -1, 2, "bool");
|
|
28
19
|
const incorrect_mask = tf.randomUniform([2, 5, 12], -1, 2, "bool");
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
const decoder = new TransformerDecoder({ numHeads: 2, embedDim: input.shape.at(-1)! });
|
|
20
|
+
const decoder = new TransformerDecoder({ numHeads: 2, embedDim: input.shape.at(-1) });
|
|
32
21
|
expect(() => decoder.apply(input)).not.toThrow();
|
|
33
22
|
expect(() => decoder.apply([input])).not.toThrow();
|
|
34
|
-
|
|
35
23
|
// causal masking
|
|
36
|
-
const causal = new TransformerDecoder({ numHeads: 2, embedDim: input.shape.at(-1)
|
|
24
|
+
const causal = new TransformerDecoder({ numHeads: 2, embedDim: input.shape.at(-1), causal: true });
|
|
37
25
|
expect(() => causal.apply(input)).not.toThrow();
|
|
38
26
|
expect(() => causal.apply([input])).not.toThrow();
|
|
39
|
-
})
|
|
40
|
-
|
|
41
|
-
|
|
27
|
+
});
|
|
42
28
|
it("should fail to instantiate a layer if heads count is not divisible by the input's embedding dimension", () => {
|
|
43
29
|
const input = tf.randomUniform([2, 3, 12]);
|
|
44
|
-
|
|
45
|
-
expect(() => new TransformerDecoder({ numHeads:
|
|
46
|
-
|
|
47
|
-
})
|
|
48
|
-
|
|
49
|
-
|
|
30
|
+
expect(() => new TransformerDecoder({ numHeads: 3, embedDim: input.shape.at(-1) })).not.toThrow();
|
|
31
|
+
expect(() => new TransformerDecoder({ numHeads: 5, embedDim: input.shape.at(-1) })).toThrow();
|
|
32
|
+
});
|
|
50
33
|
it("should not accept non-rank 3 tensor inputs", () => {
|
|
51
34
|
const embed_dim = 12;
|
|
52
|
-
|
|
53
35
|
const BAD_RANK4 = tf.randomUniform([2, 3, 12, embed_dim]);
|
|
54
36
|
const BAD_RANK2 = tf.randomUniform([2, embed_dim]);
|
|
55
37
|
const GOOD = tf.randomUniform([2, 3, embed_dim]);
|
|
56
|
-
const mask = tf.randomUniform([GOOD.shape[0]
|
|
57
|
-
|
|
38
|
+
const mask = tf.randomUniform([GOOD.shape[0], GOOD.shape[1]], -1, 2, "bool");
|
|
58
39
|
let decoder = new TransformerDecoder({ numHeads: 2, embedDim: embed_dim });
|
|
59
|
-
|
|
60
40
|
// BAD
|
|
61
41
|
expect(() => decoder.apply(BAD_RANK4)).toThrow();
|
|
62
42
|
expect(() => decoder.apply(BAD_RANK2)).toThrow();
|
|
63
|
-
|
|
64
43
|
// OK
|
|
65
44
|
decoder = new TransformerDecoder({ numHeads: 2, embedDim: embed_dim });
|
|
66
45
|
expect(() => decoder.apply(GOOD)).not.toThrow();
|
|
67
46
|
expect(() => decoder.apply([GOOD])).not.toThrow();
|
|
68
47
|
expect(() => decoder.apply([GOOD, mask])).not.toThrow();
|
|
69
|
-
})
|
|
70
|
-
|
|
71
|
-
|
|
48
|
+
});
|
|
72
49
|
it("should not accept inputs that are less or more than 1 and 2 tensors", () => {
|
|
73
50
|
const input = tf.randomUniform([2, 3, 12]);
|
|
74
|
-
|
|
75
|
-
let decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1)! });
|
|
51
|
+
let decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
|
|
76
52
|
// OK
|
|
77
53
|
expect(() => decoder.apply(input)).not.toThrow();
|
|
78
54
|
expect(() => decoder.apply([input])).not.toThrow();
|
|
79
|
-
|
|
80
55
|
// BAD
|
|
81
|
-
decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1)
|
|
56
|
+
decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
|
|
82
57
|
expect(() => decoder.apply([])).toThrow(); // stops at build()
|
|
83
58
|
decoder.apply(input); // get past the initial build
|
|
84
59
|
expect(() => decoder.apply([input, input, input])).toThrow();
|
|
85
60
|
expect(() => decoder.apply([input, input, input, input])).toThrow();
|
|
86
|
-
|
|
87
61
|
// BAD (tests build())
|
|
88
|
-
decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1)
|
|
62
|
+
decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
|
|
89
63
|
expect(() => decoder.apply([input, input, input])).toThrow();
|
|
90
64
|
expect(() => decoder.apply([input, input, input, input])).toThrow();
|
|
91
|
-
})
|
|
92
|
-
|
|
93
|
-
|
|
65
|
+
});
|
|
94
66
|
it("should return a non-empty config dict", () => {
|
|
95
67
|
const input = tf.randomUniform([2, 3, 12]);
|
|
96
|
-
|
|
97
|
-
const decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1)! });
|
|
68
|
+
const decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
|
|
98
69
|
expect(Object.keys(decoder.getConfig())).not.toBe(0);
|
|
99
|
-
})
|
|
100
|
-
})
|
|
70
|
+
});
|
|
71
|
+
});
|
|
72
|
+
//# sourceMappingURL=transformer_decoder.test.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"transformer_decoder.test.js","sourceRoot":"","sources":["../../src/layers/transformer_decoder.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,kBAAkB,EAAE,MAAM,8BAA8B,CAAC;AAElE,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,EAAE,CAAC,0DAA0D,EAAE,GAAG,EAAE;QAChE,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC;YACnC,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE;YAC1C,OAAO,EAAE,GAAG,EAAE,UAAU,EAAE,MAAM,EAAE,eAAe,EAAE,GAAG,EAAE,OAAO,EAAE,KAAK;SACzE,CAAC,CAAC;QAEH,MAAM,MAAM,GAAG,OAAO,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC;QAEjD,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,eAAe,EAAE,GAAG,EAAE;QACvB,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3C,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,CAAC,CAAC;QACjF,MAAM,cAAc,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,CAAC,CAAC;QAGnE,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,iBAAiB;QACjB,MAAM,MAAM,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QACpG,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACtD,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uGAAuG,EAAE,GAAG,EAAE;QAC7G,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACnG,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACnG,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,4CAA4C,EAAE,GAAG,EAAE;QAClD,MAAM,SAAS,GAAG,EAAE,CAAC;QAErB,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,SAAS,CAAC,CAAC,CAAC;QAC1D,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;QACnD,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;QACjD,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,CAAC,CAAC;QAE/E,IAAI,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,CAAC,CAAC;QAE3E,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEjD,KAAK;QACL,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,CAAC,CAAC;QACvE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAClD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC5D,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,qEAAqE,EAAE,GAAG,EAAE;QAC3E,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,IAAI,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACrF,KAAK;QACL,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,MAAM;QACN,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACjF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC,CAAC,mBAAmB;QAC9D,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,6BAA6B;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEpE,sBAAsB;QACtB,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACjF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACxE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,OAAO,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAA"}
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
|
|
3
|
+
import { type MultiHeadAttentionArgs } from "../layers/multihead_attention";
|
|
4
|
+
export interface TransformerEncoderArgs extends MultiHeadAttentionArgs {
|
|
5
|
+
activation?: "relu" | "gelu";
|
|
6
|
+
dimsFeedForward?: number;
|
|
7
|
+
}
|
|
8
|
+
/**
|
|
9
|
+
* This class implements the transformer encoder architecture from the 2017 paper
|
|
10
|
+
* Attention Is All You Need.
|
|
11
|
+
*
|
|
12
|
+
* This layer accepts exactly one tensor input with the shape
|
|
13
|
+
* `[ batch, sequences, embedding dims ]`.
|
|
14
|
+
*
|
|
15
|
+
* @param numHeads number of attention heads to use
|
|
16
|
+
* @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
|
|
17
|
+
* @param causal use causal masking, default `false` for encoders
|
|
18
|
+
* @param dropout use dropout during the attention calculations, default `0.1`
|
|
19
|
+
* @param activation the activation of the intermediate feed forward layer, default `relu`
|
|
20
|
+
* @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
|
|
21
|
+
* @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
|
|
22
|
+
*/
|
|
23
|
+
export declare class TransformerEncoder extends tf.layers.Layer {
|
|
24
|
+
static className: string;
|
|
25
|
+
private readonly selfAttention;
|
|
26
|
+
private readonly selfAttentionDropout;
|
|
27
|
+
private readonly selfAttentionNorm;
|
|
28
|
+
private readonly reluLayer;
|
|
29
|
+
private readonly linearLayer;
|
|
30
|
+
private readonly feedForwardDropout;
|
|
31
|
+
private readonly feedFowardNorm;
|
|
32
|
+
private readonly numHeads;
|
|
33
|
+
private readonly embedDim;
|
|
34
|
+
private readonly causal;
|
|
35
|
+
private readonly useBias;
|
|
36
|
+
private readonly dropout;
|
|
37
|
+
private readonly activation;
|
|
38
|
+
private readonly dimsFeedForward;
|
|
39
|
+
constructor({ numHeads, embedDim, causal, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerEncoderArgs);
|
|
40
|
+
/**
|
|
41
|
+
* Forward propagation
|
|
42
|
+
*/
|
|
43
|
+
call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[];
|
|
44
|
+
private selfAttentionBlock;
|
|
45
|
+
private feedForwardBlock;
|
|
46
|
+
/**
|
|
47
|
+
* Initialize the sublayers' weights and track them to enable backpropagation.
|
|
48
|
+
*/
|
|
49
|
+
build(inputShape: tf.Shape | tf.Shape[]): void;
|
|
50
|
+
/**
|
|
51
|
+
* Save the layer's hyperparameters for serialization
|
|
52
|
+
*/
|
|
53
|
+
getConfig(): tf.serialization.ConfigDict;
|
|
54
|
+
}
|
|
55
|
+
//# sourceMappingURL=transformer_encoder.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"transformer_encoder.d.ts","sourceRoot":"","sources":["../../src/layers/transformer_encoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAGjE,OAAO,EAAsB,KAAK,sBAAsB,EAAE,MAAM,+BAA+B,CAAC;AAGhG,MAAM,WAAW,sBAAuB,SAAQ,sBAAsB;IAClE,UAAU,CAAC,EAAE,MAAM,GAAG,MAAM,CAAC;IAC7B,eAAe,CAAC,EAAE,MAAM,CAAC;CAC5B;AAGD;;;;;;;;;;;;;;GAcG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IAExC,OAAO,CAAC,QAAQ,CAAC,aAAa,CAAkB;IAChD,OAAO,CAAC,QAAQ,CAAC,oBAAoB,CAAkB;IACvD,OAAO,CAAC,QAAQ,CAAC,iBAAiB,CAAkB;IAEpD,OAAO,CAAC,QAAQ,CAAC,SAAS,CAAkB;IAC5C,OAAO,CAAC,QAAQ,CAAC,WAAW,CAAkB;IAC9C,OAAO,CAAC,QAAQ,CAAC,kBAAkB,CAAkB;IACrD,OAAO,CAAC,QAAQ,CAAC,cAAc,CAAkB;IAEjD,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,QAAQ,CAAC,MAAM,CAAU;IACjC,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAU;IAClC,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAS;IACjC,OAAO,CAAC,QAAQ,CAAC,UAAU,CAAuB;IAClD,OAAO,CAAC,QAAQ,CAAC,eAAe,CAAS;gBAG7B,EAAE,QAAQ,EAAE,QAAQ,EAAE,MAAM,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAAE,EAAE,sBAAsB;IAqC1H;;OAEG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAyBvF,OAAO,CAAC,kBAAkB;IAc1B,OAAO,CAAC,gBAAgB;IAexB;;OAEG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IAsDvD;;OAEG;IACM,SAAS,IAAI,EAAE,CAAC,aAAa,CAAC,UAAU;CAiBpD"}
|
|
@@ -1,23 +1,12 @@
|
|
|
1
1
|
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
import {
|
|
3
|
-
import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
|
|
4
|
-
|
|
5
|
-
import { MultiHeadAttention, type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
export interface TransformerEncoderArgs extends MultiHeadAttentionArgs {
|
|
9
|
-
activation?: "relu" | "gelu";
|
|
10
|
-
dimsFeedForward?: number;
|
|
11
|
-
}
|
|
12
|
-
|
|
13
|
-
|
|
2
|
+
import { MultiHeadAttention } from "../layers/multihead_attention";
|
|
14
3
|
/**
|
|
15
4
|
* This class implements the transformer encoder architecture from the 2017 paper
|
|
16
5
|
* Attention Is All You Need.
|
|
17
|
-
*
|
|
6
|
+
*
|
|
18
7
|
* This layer accepts exactly one tensor input with the shape
|
|
19
8
|
* `[ batch, sequences, embedding dims ]`.
|
|
20
|
-
*
|
|
9
|
+
*
|
|
21
10
|
* @param numHeads number of attention heads to use
|
|
22
11
|
* @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
|
|
23
12
|
* @param causal use causal masking, default `false` for encoders
|
|
@@ -28,28 +17,22 @@ export interface TransformerEncoderArgs extends MultiHeadAttentionArgs {
|
|
|
28
17
|
*/
|
|
29
18
|
export class TransformerEncoder extends tf.layers.Layer {
|
|
30
19
|
static className = "TransformerEncoder";
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
private readonly activation: ActivationIdentifier;
|
|
47
|
-
private readonly dimsFeedForward: number;
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
constructor({ numHeads, embedDim, causal, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerEncoderArgs) {
|
|
20
|
+
selfAttention;
|
|
21
|
+
selfAttentionDropout;
|
|
22
|
+
selfAttentionNorm;
|
|
23
|
+
reluLayer;
|
|
24
|
+
linearLayer;
|
|
25
|
+
feedForwardDropout;
|
|
26
|
+
feedFowardNorm;
|
|
27
|
+
numHeads;
|
|
28
|
+
embedDim;
|
|
29
|
+
causal;
|
|
30
|
+
useBias;
|
|
31
|
+
dropout;
|
|
32
|
+
activation;
|
|
33
|
+
dimsFeedForward;
|
|
34
|
+
constructor({ numHeads, embedDim, causal, useBias, dropout, activation, dimsFeedForward, ...args }) {
|
|
51
35
|
super(args);
|
|
52
|
-
|
|
53
36
|
this.numHeads = numHeads;
|
|
54
37
|
this.embedDim = embedDim;
|
|
55
38
|
this.causal = causal ?? false;
|
|
@@ -57,19 +40,16 @@ export class TransformerEncoder extends tf.layers.Layer {
|
|
|
57
40
|
this.dropout = dropout ?? 0.1;
|
|
58
41
|
this.activation = activation ?? "relu";
|
|
59
42
|
this.dimsFeedForward = dimsFeedForward ?? 2048;
|
|
60
|
-
|
|
61
43
|
if (this.dropout >= 1) {
|
|
62
44
|
throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
|
|
63
45
|
}
|
|
64
|
-
|
|
65
46
|
// self attention sub-block
|
|
66
47
|
this.selfAttention = new MultiHeadAttention({
|
|
67
48
|
numHeads: this.numHeads, embedDim: this.embedDim, useBias: this.useBias,
|
|
68
49
|
dropout: this.dropout, causal: this.causal
|
|
69
50
|
});
|
|
70
|
-
this.selfAttentionDropout = tf.layers.dropout({ rate: this.dropout })
|
|
51
|
+
this.selfAttentionDropout = tf.layers.dropout({ rate: this.dropout });
|
|
71
52
|
this.selfAttentionNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
|
|
72
|
-
|
|
73
53
|
// feed forward sub-block
|
|
74
54
|
this.reluLayer = tf.layers.dense({
|
|
75
55
|
units: this.dimsFeedForward, activation: this.activation,
|
|
@@ -82,96 +62,76 @@ export class TransformerEncoder extends tf.layers.Layer {
|
|
|
82
62
|
this.feedForwardDropout = tf.layers.dropout({ rate: this.dropout });
|
|
83
63
|
this.feedFowardNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
|
|
84
64
|
}
|
|
85
|
-
|
|
86
|
-
|
|
87
65
|
/**
|
|
88
66
|
* Forward propagation
|
|
89
67
|
*/
|
|
90
|
-
|
|
68
|
+
call(inputs, kwargs) {
|
|
91
69
|
// validate the input tensors
|
|
92
|
-
let input
|
|
93
|
-
|
|
70
|
+
let input;
|
|
94
71
|
if (Array.isArray(inputs)) {
|
|
95
72
|
if (inputs.length != 1) {
|
|
96
73
|
throw Error(`${this.getClassName}::call ${this.name} expects exactly 1 tensor` +
|
|
97
74
|
` input, got ${inputs.length} inputs instead.`);
|
|
98
75
|
}
|
|
99
|
-
|
|
100
76
|
input = inputs[0];
|
|
101
|
-
}
|
|
77
|
+
}
|
|
78
|
+
else {
|
|
102
79
|
input = inputs;
|
|
103
80
|
}
|
|
104
|
-
|
|
105
81
|
// perform forward propagation
|
|
106
82
|
return tf.tidy(() => {
|
|
107
83
|
const attention = this.selfAttentionBlock(input, kwargs);
|
|
108
84
|
const feedforward = this.feedForwardBlock(attention, kwargs);
|
|
109
|
-
|
|
110
85
|
return feedforward;
|
|
111
86
|
});
|
|
112
87
|
}
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
private selfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
|
|
88
|
+
selfAttentionBlock(x, kwargs) {
|
|
116
89
|
return tf.tidy(() => {
|
|
117
90
|
const residual = x;
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
attention = this.selfAttentionDropout.apply(attention, kwargs) as tf.Tensor;
|
|
91
|
+
let attention = this.selfAttention.apply(x, kwargs);
|
|
92
|
+
attention = this.selfAttentionDropout.apply(attention, kwargs);
|
|
121
93
|
attention = tf.add(attention, residual);
|
|
122
|
-
attention = this.selfAttentionNorm.apply(attention)
|
|
123
|
-
|
|
94
|
+
attention = this.selfAttentionNorm.apply(attention);
|
|
124
95
|
return attention;
|
|
125
96
|
});
|
|
126
97
|
}
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
private feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
|
|
98
|
+
feedForwardBlock(x, kwargs) {
|
|
130
99
|
return tf.tidy(() => {
|
|
131
100
|
const residual = x;
|
|
132
|
-
|
|
133
101
|
let feedForward = this.reluLayer.apply(x);
|
|
134
102
|
feedForward = this.linearLayer.apply(feedForward);
|
|
135
|
-
feedForward = this.feedForwardDropout.apply(feedForward, kwargs)
|
|
103
|
+
feedForward = this.feedForwardDropout.apply(feedForward, kwargs);
|
|
136
104
|
feedForward = tf.add(feedForward, residual);
|
|
137
|
-
feedForward = this.feedFowardNorm.apply(feedForward)
|
|
138
|
-
|
|
105
|
+
feedForward = this.feedFowardNorm.apply(feedForward);
|
|
139
106
|
return feedForward;
|
|
140
107
|
});
|
|
141
108
|
}
|
|
142
|
-
|
|
143
|
-
|
|
144
109
|
/**
|
|
145
110
|
* Initialize the sublayers' weights and track them to enable backpropagation.
|
|
146
111
|
*/
|
|
147
|
-
|
|
148
|
-
let input_shapes
|
|
149
|
-
|
|
112
|
+
build(inputShape) {
|
|
113
|
+
let input_shapes = [];
|
|
150
114
|
if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
|
|
151
115
|
// input is an array of shapes
|
|
152
|
-
input_shapes = inputShape
|
|
153
|
-
}
|
|
116
|
+
input_shapes = inputShape;
|
|
117
|
+
}
|
|
118
|
+
else if (inputShape.length != 0) {
|
|
154
119
|
// input is a single shape
|
|
155
|
-
input_shapes = [inputShape
|
|
120
|
+
input_shapes = [inputShape];
|
|
156
121
|
}
|
|
157
|
-
|
|
158
122
|
// expects only 1 rank 3 tensor input
|
|
159
123
|
if (input_shapes.length != 1 || input_shapes[0].length != 3) {
|
|
160
|
-
throw Error(`${this.getClassName()}::build ${this.name} expects a single input shape of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`)
|
|
124
|
+
throw Error(`${this.getClassName()}::build ${this.name} expects a single input shape of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`);
|
|
161
125
|
}
|
|
162
|
-
|
|
163
126
|
// initialize self attention sub-block's weights
|
|
164
127
|
this.selfAttention.build(inputShape);
|
|
165
128
|
this.selfAttentionNorm.build(inputShape);
|
|
166
|
-
|
|
167
129
|
// inintialize feedforward sub-block's weights
|
|
168
130
|
const reluLayerOutputShape = this.reluLayer.computeOutputShape(inputShape);
|
|
169
131
|
const linearLayerOutputShape = this.linearLayer.computeOutputShape(reluLayerOutputShape);
|
|
170
|
-
|
|
171
132
|
this.reluLayer.build(inputShape);
|
|
172
133
|
this.linearLayer.build(reluLayerOutputShape);
|
|
173
134
|
this.feedFowardNorm.build(linearLayerOutputShape);
|
|
174
|
-
|
|
175
135
|
// track sublayers' weights
|
|
176
136
|
this.trainableWeights = [
|
|
177
137
|
...this.selfAttention.trainableWeights,
|
|
@@ -182,28 +142,22 @@ export class TransformerEncoder extends tf.layers.Layer {
|
|
|
182
142
|
...this.feedForwardDropout.trainableWeights,
|
|
183
143
|
...this.feedFowardNorm.trainableWeights
|
|
184
144
|
];
|
|
185
|
-
|
|
186
145
|
// rename the weights otherwise they'll take on the default naming and overlap
|
|
187
146
|
// each other which breaks model loading due to duplicate weight names
|
|
188
147
|
let indexing = 0;
|
|
189
|
-
|
|
190
148
|
for (const weight of this.trainableWeights) {
|
|
191
149
|
const unique_name = `${this.getClassName()}_${indexing}`;
|
|
192
|
-
|
|
193
|
-
|
|
150
|
+
weight.name += unique_name;
|
|
151
|
+
weight.originalName += unique_name;
|
|
194
152
|
indexing++;
|
|
195
153
|
}
|
|
196
|
-
|
|
197
154
|
super.build(inputShape);
|
|
198
155
|
}
|
|
199
|
-
|
|
200
|
-
|
|
201
156
|
/**
|
|
202
157
|
* Save the layer's hyperparameters for serialization
|
|
203
158
|
*/
|
|
204
|
-
|
|
159
|
+
getConfig() {
|
|
205
160
|
const base_config = super.getConfig();
|
|
206
|
-
|
|
207
161
|
const config = {
|
|
208
162
|
numHeads: this.numHeads,
|
|
209
163
|
embedDim: this.embedDim,
|
|
@@ -213,12 +167,9 @@ export class TransformerEncoder extends tf.layers.Layer {
|
|
|
213
167
|
activation: this.activation,
|
|
214
168
|
dimsFeedForward: this.dimsFeedForward
|
|
215
169
|
};
|
|
216
|
-
|
|
217
170
|
Object.assign(config, base_config);
|
|
218
|
-
|
|
219
171
|
return config;
|
|
220
172
|
}
|
|
221
173
|
}
|
|
222
|
-
|
|
223
|
-
|
|
224
174
|
tf.serialization.registerClass(TransformerEncoder);
|
|
175
|
+
//# sourceMappingURL=transformer_encoder.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"transformer_encoder.js","sourceRoot":"","sources":["../../src/layers/transformer_encoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAIvC,OAAO,EAAE,kBAAkB,EAA+B,MAAM,+BAA+B,CAAC;AAShG;;;;;;;;;;;;;;GAcG;AACH,MAAM,OAAO,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,GAAG,oBAAoB,CAAC;IAEvB,aAAa,CAAkB;IAC/B,oBAAoB,CAAkB;IACtC,iBAAiB,CAAkB;IAEnC,SAAS,CAAkB;IAC3B,WAAW,CAAkB;IAC7B,kBAAkB,CAAkB;IACpC,cAAc,CAAkB;IAEhC,QAAQ,CAAS;IACjB,QAAQ,CAAS;IACjB,MAAM,CAAU;IAChB,OAAO,CAAU;IACjB,OAAO,CAAS;IAChB,UAAU,CAAuB;IACjC,eAAe,CAAS;IAGzC,YAAY,EAAE,QAAQ,EAAE,QAAQ,EAAE,MAAM,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAA0B;QACtH,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,MAAM,GAAG,MAAM,IAAI,KAAK,CAAC;QAC9B,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,IAAI,CAAC;QAC/B,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,GAAG,CAAC;QAC9B,IAAI,CAAC,UAAU,GAAG,UAAU,IAAI,MAAM,CAAC;QACvC,IAAI,CAAC,eAAe,GAAG,eAAe,IAAI,IAAI,CAAC;QAE/C,IAAI,IAAI,CAAC,OAAO,IAAI,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,6CAA6C,CAAC,CAAC;QACrF,CAAC;QAED,2BAA2B;QAC3B,IAAI,CAAC,aAAa,GAAG,IAAI,kBAAkB,CAAC;YACxC,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,OAAO,EAAE,IAAI,CAAC,OAAO;YACvE,OAAO,EAAE,IAAI,CAAC,OAAO,EAAE,MAAM,EAAE,IAAI,CAAC,MAAM;SAC7C,CAAC,CAAC;QACH,IAAI,CAAC,oBAAoB,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAA;QACrE,IAAI,CAAC,iBAAiB,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAAO,EAAE,IAAI,EAAE,CAAC,CAAC;QAEzE,yBAAyB;QACzB,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAC7B,KAAK,EAAE,IAAI,CAAC,eAAe,EAAE,UAAU,EAAE,IAAI,CAAC,UAAU;YACxD,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAC/B,KAAK,EAAE,IAAI,CAAC,QAAQ,EAAE,UAAU,EAAE,QAAQ;YAC1C,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,kBAAkB,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC;QACpE,IAAI,CAAC,cAAc,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAAO,EAAE,IAAI,EAAE,CAAC,CAAC;IAC1E,CAAC;IAGD;;OAEG;IACM,IAAI,CAAC,MAA+B,EAAE,MAAc;QACzD,6BAA6B;QAC7B,IAAI,KAAgB,CAAC;QAErB,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC;YACxB,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;gBACrB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,UAAU,IAAI,CAAC,IAAI,2BAA2B;oBAC1E,eAAe,MAAM,CAAC,MAAM,kBAAkB,CAAC,CAAC;YACxD,CAAC;YAED,KAAK,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;QACtB,CAAC;aAAM,CAAC;YACJ,KAAK,GAAG,MAAM,CAAC;QACnB,CAAC;QAED,8BAA8B;QAC9B,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,SAAS,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC;YACzD,MAAM,WAAW,GAAG,IAAI,CAAC,gBAAgB,CAAC,SAAS,EAAE,MAAM,CAAC,CAAC;YAE7D,OAAO,WAAW,CAAC;QACvB,CAAC,CAAC,CAAC;IACP,CAAC;IAGO,kBAAkB,CAAC,CAAY,EAAE,MAAc;QACnD,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,SAAS,GAAG,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAc,CAAC;YACjE,SAAS,GAAG,IAAI,CAAC,oBAAoB,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAC5E,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,QAAQ,CAAC,CAAC;YACxC,SAAS,GAAG,IAAI,CAAC,iBAAiB,CAAC,KAAK,CAAC,SAAS,CAAc,CAAC;YAEjE,OAAO,SAAS,CAAC;QACrB,CAAC,CAAC,CAAC;IACP,CAAC;IAGO,gBAAgB,CAAC,CAAY,EAAE,MAAc;QACjD,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,WAAW,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAC1C,WAAW,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC;YAClD,WAAW,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAC9E,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,QAAQ,CAAC,CAAC;YAC5C,WAAW,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,WAAW,CAAc,CAAC;YAElE,OAAO,WAAW,CAAC;QACvB,CAAC,CAAC,CAAC;IACP,CAAC;IAGD;;OAEG;IACM,KAAK,CAAC,UAAiC;QAC5C,IAAI,YAAY,GAAe,EAAE,CAAC;QAElC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAC5D,8BAA8B;YAC9B,YAAY,GAAG,UAAwB,CAAC;QAC5C,CAAC;aAAM,IAAI,UAAU,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAChC,0BAA0B;YAC1B,YAAY,GAAG,CAAC,UAAsB,CAAC,CAAC;QAC5C,CAAC;QAED,qCAAqC;QACrC,IAAI,YAAY,CAAC,MAAM,IAAI,CAAC,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAC1D,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,iEAAiE,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAA;QACxJ,CAAC;QAED,gDAAgD;QAChD,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;QACrC,IAAI,CAAC,iBAAiB,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;QAEzC,8CAA8C;QAC9C,MAAM,oBAAoB,GAAG,IAAI,CAAC,SAAS,CAAC,kBAAkB,CAAC,UAAU,CAAC,CAAC;QAC3E,MAAM,sBAAsB,GAAG,IAAI,CAAC,WAAW,CAAC,kBAAkB,CAAC,oBAAoB,CAAC,CAAC;QAEzF,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;QACjC,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,oBAAoB,CAAC,CAAC;QAC7C,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,sBAAsB,CAAC,CAAC;QAElD,2BAA2B;QAC3B,IAAI,CAAC,gBAAgB,GAAG;YACpB,GAAG,IAAI,CAAC,aAAa,CAAC,gBAAgB;YACtC,GAAG,IAAI,CAAC,oBAAoB,CAAC,gBAAgB;YAC7C,GAAG,IAAI,CAAC,iBAAiB,CAAC,gBAAgB;YAC1C,GAAG,IAAI,CAAC,SAAS,CAAC,gBAAgB;YAClC,GAAG,IAAI,CAAC,WAAW,CAAC,gBAAgB;YACpC,GAAG,IAAI,CAAC,kBAAkB,CAAC,gBAAgB;YAC3C,GAAG,IAAI,CAAC,cAAc,CAAC,gBAAgB;SAC1C,CAAC;QAEF,8EAA8E;QAC9E,sEAAsE;QACtE,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEjB,KAAK,MAAM,MAAM,IAAI,IAAI,CAAC,gBAAgB,EAAE,CAAC;YACzC,MAAM,WAAW,GAAG,GAAG,IAAI,CAAC,YAAY,EAAE,IAAI,QAAQ,EAAE,CAAC;YACxD,MAAc,CAAC,IAAI,IAAI,WAAW,CAAC;YACnC,MAAc,CAAC,YAAY,IAAI,WAAW,CAAC;YAC5C,QAAQ,EAAE,CAAC;QACf,CAAC;QAED,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACM,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,MAAM,EAAE,IAAI,CAAC,MAAM;YACnB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,eAAe,EAAE,IAAI,CAAC,eAAe;SACxC,CAAC;QAEF,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"transformer_encoder.test.d.ts","sourceRoot":"","sources":["../../src/layers/transformer_encoder.test.ts"],"names":[],"mappings":""}
|
|
@@ -1,85 +1,58 @@
|
|
|
1
1
|
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
|
|
3
2
|
import { TransformerEncoder } from "@/layers/transformer_encoder";
|
|
4
|
-
|
|
5
3
|
// disables warning for using the faster node backend,
|
|
6
4
|
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
|
|
7
5
|
tf.env().set('IS_NODE', false);
|
|
8
|
-
|
|
9
|
-
|
|
10
6
|
describe("TransformerEncoder tests", () => {
|
|
11
7
|
it("should return an output with the same shape as the input", () => {
|
|
12
8
|
const input = tf.randomUniform([2, 3, 10]);
|
|
13
|
-
|
|
14
9
|
const decoder = new TransformerEncoder({
|
|
15
|
-
numHeads: 2, embedDim: input.shape.at(-1)
|
|
10
|
+
numHeads: 2, embedDim: input.shape.at(-1),
|
|
16
11
|
dropout: 0.5, activation: "gelu", dimsFeedForward: 512, useBias: true
|
|
17
12
|
});
|
|
18
|
-
|
|
19
|
-
const output = decoder.apply(input) as tf.Tensor;
|
|
20
|
-
|
|
13
|
+
const output = decoder.apply(input);
|
|
21
14
|
expect(output.shape.length).toBe(input.shape.length);
|
|
22
|
-
})
|
|
23
|
-
|
|
24
|
-
|
|
15
|
+
});
|
|
25
16
|
test("correct forward calls", () => {
|
|
26
17
|
const input = tf.randomUniform([2, 3, 10]);
|
|
27
|
-
|
|
28
|
-
const encoder = new TransformerEncoder({ numHeads: 2, embedDim: input.shape.at(-1)! });
|
|
18
|
+
const encoder = new TransformerEncoder({ numHeads: 2, embedDim: input.shape.at(-1) });
|
|
29
19
|
expect(() => encoder.apply(input)).not.toThrow();
|
|
30
20
|
expect(() => encoder.apply([input])).not.toThrow();
|
|
31
|
-
|
|
32
|
-
const causal = new TransformerEncoder({ numHeads: 2, embedDim: input.shape.at(-1)!, causal: true });
|
|
21
|
+
const causal = new TransformerEncoder({ numHeads: 2, embedDim: input.shape.at(-1), causal: true });
|
|
33
22
|
expect(() => causal.apply(input)).not.toThrow();
|
|
34
23
|
expect(() => causal.apply([input])).not.toThrow();
|
|
35
|
-
})
|
|
36
|
-
|
|
37
|
-
|
|
24
|
+
});
|
|
38
25
|
it("should fail to instantiate a layer if heads count is not divisible by the input's embedding dimension", () => {
|
|
39
26
|
const input = tf.randomUniform([2, 3, 10]);
|
|
40
|
-
|
|
41
|
-
expect(() => new TransformerEncoder({ numHeads:
|
|
42
|
-
|
|
43
|
-
})
|
|
44
|
-
|
|
45
|
-
|
|
27
|
+
expect(() => new TransformerEncoder({ numHeads: 3, embedDim: input.shape.at(-1) })).toThrow();
|
|
28
|
+
expect(() => new TransformerEncoder({ numHeads: 5, embedDim: input.shape.at(-1) })).not.toThrow();
|
|
29
|
+
});
|
|
46
30
|
it("should not accept non-rank 3 tensor inputs", () => {
|
|
47
31
|
const incorrect_input = tf.randomUniform([2, 3, 10, 10]);
|
|
48
32
|
const incorrect_input2 = tf.randomUniform([2, 3]);
|
|
49
33
|
const correct_input = tf.randomUniform([2, 3, 10]);
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
const encoder = new TransformerEncoder({ numHeads: 2, embedDim: incorrect_input.shape.at(-1)! });
|
|
34
|
+
const encoder = new TransformerEncoder({ numHeads: 2, embedDim: incorrect_input.shape.at(-1) });
|
|
53
35
|
expect(() => encoder.apply([correct_input, correct_input])).toThrow();
|
|
54
|
-
|
|
55
36
|
expect(() => encoder.apply(incorrect_input)).toThrow();
|
|
56
37
|
expect(() => encoder.apply(incorrect_input2)).toThrow();
|
|
57
|
-
|
|
58
38
|
expect(() => encoder.apply([correct_input, incorrect_input])).toThrow();
|
|
59
39
|
expect(() => encoder.apply([incorrect_input, correct_input])).toThrow();
|
|
60
|
-
|
|
61
40
|
expect(() => encoder.apply([correct_input, incorrect_input2])).toThrow();
|
|
62
41
|
expect(() => encoder.apply([incorrect_input2, correct_input])).toThrow();
|
|
63
|
-
})
|
|
64
|
-
|
|
65
|
-
|
|
42
|
+
});
|
|
66
43
|
it("should accept exactly one input", () => {
|
|
67
44
|
const input = tf.randomUniform([2, 3, 10]);
|
|
68
|
-
|
|
69
|
-
const encoder = new TransformerEncoder({ numHeads: 1, embedDim: input.shape.at(-1)! });
|
|
45
|
+
const encoder = new TransformerEncoder({ numHeads: 1, embedDim: input.shape.at(-1) });
|
|
70
46
|
expect(() => encoder.apply(input)).not.toThrow();
|
|
71
47
|
expect(() => encoder.apply([input])).not.toThrow();
|
|
72
|
-
|
|
73
48
|
expect(() => encoder.apply([])).toThrow();
|
|
74
49
|
expect(() => encoder.apply([input, input])).toThrow();
|
|
75
|
-
expect(() => encoder.apply([input, input, input])).toThrow()
|
|
76
|
-
})
|
|
77
|
-
|
|
78
|
-
|
|
50
|
+
expect(() => encoder.apply([input, input, input])).toThrow();
|
|
51
|
+
});
|
|
79
52
|
it("should return a non-empty config dict", () => {
|
|
80
53
|
const input = tf.randomUniform([2, 3, 10]);
|
|
81
|
-
|
|
82
|
-
const encoder = new TransformerEncoder({ numHeads: 1, embedDim: input.shape.at(-1)! });
|
|
54
|
+
const encoder = new TransformerEncoder({ numHeads: 1, embedDim: input.shape.at(-1) });
|
|
83
55
|
expect(Object.keys(encoder.getConfig())).not.toBe(0);
|
|
84
|
-
})
|
|
85
|
-
})
|
|
56
|
+
});
|
|
57
|
+
});
|
|
58
|
+
//# sourceMappingURL=transformer_encoder.test.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"transformer_encoder.test.js","sourceRoot":"","sources":["../../src/layers/transformer_encoder.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,kBAAkB,EAAE,MAAM,8BAA8B,CAAC;AAElE,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,EAAE,CAAC,0DAA0D,EAAE,GAAG,EAAE;QAChE,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC;YACnC,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE;YAC1C,OAAO,EAAE,GAAG,EAAE,UAAU,EAAE,MAAM,EAAE,eAAe,EAAE,GAAG,EAAE,OAAO,EAAE,IAAI;SACxE,CAAC,CAAC;QAEH,MAAM,MAAM,GAAG,OAAO,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC;QAEjD,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,uBAAuB,EAAE,GAAG,EAAE;QAC/B,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,MAAM,MAAM,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QACpG,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACtD,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uGAAuG,EAAE,GAAG,EAAE;QAC7G,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC/F,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACvG,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,4CAA4C,EAAE,GAAG,EAAE;QAClD,MAAM,eAAe,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;QACzD,MAAM,gBAAgB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAClD,MAAM,aAAa,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAGnD,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,eAAe,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACjG,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,aAAa,EAAE,aAAa,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEtE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,eAAe,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACvD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,gBAAgB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAExD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,aAAa,EAAE,eAAe,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACxE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,eAAe,EAAE,aAAa,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAExE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,aAAa,EAAE,gBAAgB,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACzE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,gBAAgB,EAAE,aAAa,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAC7E,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,iCAAiC,EAAE,GAAG,EAAE;QACvC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC1C,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACtD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAA;IAChE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,OAAO,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAA"}
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
export declare function diceBinaryStandard(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
|
|
3
|
+
export declare function diceBinaryGlobal(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
|
|
4
|
+
export declare function diceCategoricalStandard(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
|
|
5
|
+
export declare function diceCategoricalGeneralized(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
|
|
6
|
+
export declare function diceCategoricalGlobal(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
|
|
7
|
+
/**
|
|
8
|
+
* Calculates the Sorensen-Dice coefficient and the binary cross entropy losses.
|
|
9
|
+
* Both have equal weight.
|
|
10
|
+
*
|
|
11
|
+
* @param y_true the label tensor
|
|
12
|
+
* @param y_pred the prediction tensor (not sparse)
|
|
13
|
+
* @returns a tensor of shape `[ batch ]`
|
|
14
|
+
*/
|
|
15
|
+
export declare function diceBinaryCrossentropy(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
|
|
16
|
+
/**
|
|
17
|
+
* Calculates the Sorensen-Dice coefficient and the categorical cross entropy losses.
|
|
18
|
+
* Both have equal weight. Expects dense (non-sparse) label tensors.
|
|
19
|
+
*
|
|
20
|
+
* This does not support sparse tensors because TFJS's
|
|
21
|
+
* sparseCategoricalCrossentropy loss onehots the label
|
|
22
|
+
* and calls categoricalCrossentropy. See
|
|
23
|
+
* https://github.com/tensorflow/tfjs/blob/0fc04d958ea592f3b8db79a8b3b497b5c8904097/tfjs-layers/src/losses.ts#L143-L146
|
|
24
|
+
*
|
|
25
|
+
* @param y_true the label
|
|
26
|
+
* @param y_pred the prediction tensor (not sparse)
|
|
27
|
+
* @returns a tensor of shape `[ batch ]`
|
|
28
|
+
*/
|
|
29
|
+
export declare function diceCategoricalCrossentropy(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
|
|
30
|
+
//# sourceMappingURL=dice.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"dice.d.ts","sourceRoot":"","sources":["../../src/losses/dice.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAWvC,wBAAgB,kBAAkB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAclF;AAQD,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAahF;AAOD,wBAAgB,uBAAuB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAUvF;AAOD,wBAAgB,0BAA0B,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAgB1F;AAOD,wBAAgB,qBAAqB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAWrF;AAOD;;;;;;;GAOG;AACH,wBAAgB,sBAAsB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAMtF;AAOD;;;;;;;;;;;;GAYG;AACH,wBAAgB,2BAA2B,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAM3F"}
|