@stellarapp/tfjs-stellar 1.0.0 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +47 -0
- package/dist/index.d.ts +7 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +7 -0
- package/dist/index.js.map +1 -0
- package/dist/jest.config.d.ts +8 -0
- package/dist/jest.config.d.ts.map +1 -0
- package/{jest.config.ts → dist/jest.config.js} +8 -64
- package/dist/jest.config.js.map +1 -0
- package/dist/kv_cache.d.ts +53 -0
- package/dist/kv_cache.d.ts.map +1 -0
- package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
- package/dist/kv_cache.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.js +76 -0
- package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
- package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
- package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
- package/dist/layers/gpt_decoder_block.js.map +1 -0
- package/dist/layers/index.d.ts +17 -0
- package/dist/layers/index.d.ts.map +1 -0
- package/dist/layers/index.js +33 -0
- package/dist/layers/index.js.map +1 -0
- package/dist/layers/multihead_attention.d.ts +106 -0
- package/dist/layers/multihead_attention.d.ts.map +1 -0
- package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
- package/dist/layers/multihead_attention.js.map +1 -0
- package/dist/layers/multihead_attention.test.d.ts +2 -0
- package/dist/layers/multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
- package/dist/layers/multihead_attention.test.js.map +1 -0
- package/dist/layers/positional_encoding.d.ts +37 -0
- package/dist/layers/positional_encoding.d.ts.map +1 -0
- package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
- package/dist/layers/positional_encoding.js.map +1 -0
- package/dist/layers/positional_encoding.test.d.ts +2 -0
- package/dist/layers/positional_encoding.test.d.ts.map +1 -0
- package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
- package/dist/layers/positional_encoding.test.js.map +1 -0
- package/dist/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
- package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
- package/dist/layers/rotary_position_embedding.js.map +1 -0
- package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/layers/rotary_position_embedding.test.js +88 -0
- package/dist/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
- package/dist/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
- package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/layers/transformer_decoder.d.ts +69 -0
- package/dist/layers/transformer_decoder.d.ts.map +1 -0
- package/dist/layers/transformer_decoder.js +182 -0
- package/dist/layers/transformer_decoder.js.map +1 -0
- package/dist/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
- package/dist/layers/transformer_decoder.test.js.map +1 -0
- package/dist/layers/transformer_encoder.d.ts +55 -0
- package/dist/layers/transformer_encoder.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
- package/dist/layers/transformer_encoder.js.map +1 -0
- package/dist/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
- package/dist/layers/transformer_encoder.test.js.map +1 -0
- package/dist/losses/dice.d.ts +30 -0
- package/dist/losses/dice.d.ts.map +1 -0
- package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
- package/dist/losses/dice.js.map +1 -0
- package/dist/losses/index.d.ts +2 -0
- package/dist/losses/index.d.ts.map +1 -0
- package/dist/losses/index.js +2 -0
- package/dist/losses/index.js.map +1 -0
- package/dist/masks.d.ts +20 -0
- package/dist/masks.d.ts.map +1 -0
- package/{src/packing_mask.ts → dist/masks.js} +16 -7
- package/dist/masks.js.map +1 -0
- package/dist/metrics.d.ts +20 -0
- package/dist/metrics.d.ts.map +1 -0
- package/{src/metrics.ts → dist/metrics.js} +8 -12
- package/dist/metrics.js.map +1 -0
- package/dist/models/gpt_model.d.ts +94 -0
- package/dist/models/gpt_model.d.ts.map +1 -0
- package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
- package/dist/models/gpt_model.js.map +1 -0
- package/dist/models/index.d.ts +7 -0
- package/dist/models/index.d.ts.map +1 -0
- package/dist/models/index.js +13 -0
- package/dist/models/index.js.map +1 -0
- package/dist/models/llm_model.d.ts +87 -0
- package/dist/models/llm_model.d.ts.map +1 -0
- package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
- package/dist/models/llm_model.js.map +1 -0
- package/dist/models/u_net.d.ts +40 -0
- package/dist/models/u_net.d.ts.map +1 -0
- package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
- package/dist/models/u_net.js.map +1 -0
- package/dist/src/index.d.ts +6 -0
- package/dist/src/index.d.ts.map +1 -0
- package/dist/src/index.js +6 -0
- package/dist/src/index.js.map +1 -0
- package/dist/src/kv_cache.d.ts +53 -0
- package/dist/src/kv_cache.d.ts.map +1 -0
- package/dist/src/kv_cache.js +135 -0
- package/dist/src/kv_cache.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
- package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
- package/dist/src/layers/gpt_decoder_block.js +51 -0
- package/dist/src/layers/gpt_decoder_block.js.map +1 -0
- package/dist/src/layers/index.d.ts +17 -0
- package/dist/src/layers/index.d.ts.map +1 -0
- package/dist/src/layers/index.js +33 -0
- package/dist/src/layers/index.js.map +1 -0
- package/dist/src/layers/multihead_attention.d.ts +106 -0
- package/dist/src/layers/multihead_attention.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.js +269 -0
- package/dist/src/layers/multihead_attention.js.map +1 -0
- package/dist/src/layers/multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.test.js +160 -0
- package/dist/src/layers/multihead_attention.test.js.map +1 -0
- package/dist/src/layers/positional_encoding.d.ts +37 -0
- package/dist/src/layers/positional_encoding.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.js +115 -0
- package/dist/src/layers/positional_encoding.js.map +1 -0
- package/dist/src/layers/positional_encoding.test.d.ts +2 -0
- package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.test.js +95 -0
- package/dist/src/layers/positional_encoding.test.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.js +99 -0
- package/dist/src/layers/rotary_position_embedding.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.js +88 -0
- package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.js +109 -0
- package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
- package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/src/layers/transformer_decoder.d.ts +69 -0
- package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
- package/dist/src/layers/transformer_decoder.js.map +1 -0
- package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_decoder.test.js +72 -0
- package/dist/src/layers/transformer_decoder.test.js.map +1 -0
- package/dist/src/layers/transformer_encoder.d.ts +55 -0
- package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.js +175 -0
- package/dist/src/layers/transformer_encoder.js.map +1 -0
- package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.test.js +58 -0
- package/dist/src/layers/transformer_encoder.test.js.map +1 -0
- package/dist/src/losses/dice.d.ts +30 -0
- package/dist/src/losses/dice.d.ts.map +1 -0
- package/dist/src/losses/dice.js +93 -0
- package/dist/src/losses/dice.js.map +1 -0
- package/dist/src/losses/index.d.ts +2 -0
- package/dist/src/losses/index.d.ts.map +1 -0
- package/dist/src/losses/index.js +2 -0
- package/dist/src/losses/index.js.map +1 -0
- package/dist/src/masks.d.ts +20 -0
- package/dist/src/masks.d.ts.map +1 -0
- package/dist/src/masks.js +37 -0
- package/dist/src/masks.js.map +1 -0
- package/dist/src/metrics.d.ts +20 -0
- package/dist/src/metrics.d.ts.map +1 -0
- package/dist/src/metrics.js +28 -0
- package/dist/src/metrics.js.map +1 -0
- package/dist/src/models/gpt_model.d.ts +94 -0
- package/dist/src/models/gpt_model.d.ts.map +1 -0
- package/dist/src/models/gpt_model.js +154 -0
- package/dist/src/models/gpt_model.js.map +1 -0
- package/dist/src/models/index.d.ts +3 -0
- package/dist/src/models/index.d.ts.map +1 -0
- package/{src/models/index.ts → dist/src/models/index.js} +1 -0
- package/dist/src/models/index.js.map +1 -0
- package/dist/src/models/llm_model.d.ts +87 -0
- package/dist/src/models/llm_model.d.ts.map +1 -0
- package/dist/src/models/llm_model.js +245 -0
- package/dist/src/models/llm_model.js.map +1 -0
- package/dist/src/models/u_net.d.ts +40 -0
- package/dist/src/models/u_net.d.ts.map +1 -0
- package/dist/src/models/u_net.js +151 -0
- package/dist/src/models/u_net.js.map +1 -0
- package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
- package/dist/src/tfjs_types.d.ts.map +1 -0
- package/dist/src/tfjs_types.js +2 -0
- package/dist/src/tfjs_types.js.map +1 -0
- package/dist/src/utils.d.ts +28 -0
- package/dist/src/utils.d.ts.map +1 -0
- package/{src/utils.ts → dist/src/utils.js} +10 -33
- package/dist/src/utils.js.map +1 -0
- package/dist/src/utils.test.d.ts +2 -0
- package/dist/src/utils.test.d.ts.map +1 -0
- package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
- package/dist/src/utils.test.js.map +1 -0
- package/dist/tfjs_types.d.ts +10 -0
- package/dist/tfjs_types.d.ts.map +1 -0
- package/dist/tfjs_types.js +2 -0
- package/dist/tfjs_types.js.map +1 -0
- package/dist/utils.d.ts +28 -0
- package/dist/utils.d.ts.map +1 -0
- package/dist/utils.js +63 -0
- package/dist/utils.js.map +1 -0
- package/dist/utils.test.d.ts +2 -0
- package/dist/utils.test.d.ts.map +1 -0
- package/dist/utils.test.js +73 -0
- package/dist/utils.test.js.map +1 -0
- package/package.json +10 -4
- package/src/index.ts +0 -93
- package/src/layers/rotary_position_embedding.test.ts +0 -107
- package/src/losses/index.ts +0 -1
- package/src/testing.ts +0 -1
- package/tsconfig.json +0 -49
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { TokenAndPositionalEmbedding } from '@/layers/token_and_positional_embedding';
|
|
3
|
+
// hides tfjs's console warning that recommends installing the faster tfjs-node backend,
|
|
4
|
+
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
|
|
5
|
+
tf.env().set('IS_NODE', false);
|
|
6
|
+
describe("TokenAndPositionalEmbedding tests", () => {
    test("layer initialization", () => {
        // invalid hyperparameters must be rejected at construction time
        expect(() => new TokenAndPositionalEmbedding({ maxSequenceLength: 0, embedDim: 10, vocabularySize: 10_000 })).toThrow();
        expect(() => new TokenAndPositionalEmbedding({ embedDim: 0, vocabularySize: 10_000 })).toThrow();
        expect(() => new TokenAndPositionalEmbedding({ embedDim: 10, vocabularySize: 0 })).toThrow();
        // valid configurations: minimal arguments, and with the optional
        // sequence length and dropout also set (previously a duplicated line)
        expect(() => new TokenAndPositionalEmbedding({ embedDim: 10, vocabularySize: 10_000 })).not.toThrow();
        expect(() => new TokenAndPositionalEmbedding({ embedDim: 10, vocabularySize: 10_000, maxSequenceLength: 100, dropout: 0.1 })).not.toThrow();
    });
    test("successful forward calls", () => {
        const embedDims = 32;
        const sequences = 4;
        const vocabSize = 10_000;
        // token ids are integers in [0, vocabSize); float noise in [0, 1)
        // is not a meaningful token sequence for an embedding lookup
        const input = tf.randomUniform([2, sequences], 0, vocabSize, "int32");
        const embedding = new TokenAndPositionalEmbedding({ embedDim: embedDims, dropout: 0.1, vocabularySize: vocabSize });
        expect(() => embedding.apply(input)).not.toThrow();
        // a single-element tensor array is accepted as well
        expect(() => embedding.apply([input])).not.toThrow();
    });
    test("layer build", () => {
        // every check builds a fresh layer, so a rejected shape is tested on an
        // unbuilt layer rather than on one left over from an earlier build
        const makeEmbedding = () => new TokenAndPositionalEmbedding({ embedDim: 32, maxSequenceLength: 500, vocabularySize: 1_000 });
        const shapeOk = [2, 4];
        const shapeTooManyWords = [2, 700]; // exceeds maxSequenceLength = 500
        const shapeIsImage = [1, 32, 32, 3]; // wrong rank
        expect(() => makeEmbedding().build(shapeOk)).not.toThrow();
        expect(() => makeEmbedding().build([shapeOk, shapeOk])).not.toThrow();
        expect(() => makeEmbedding().build(shapeTooManyWords)).toThrow();
        expect(() => makeEmbedding().build(shapeIsImage)).toThrow();
    });
    test("should throw when more than one input provided, input sequences are too large, or incorrect input rank", () => {
        const sequencesTooLong = tf.randomUniform([10, 1000]);
        const multipleCorrectInputs = [tf.randomUniform([2, 3]), tf.randomUniform([2, 3])];
        const wrongRank = tf.randomUniform([10, 32, 32]);
        const positional = new TokenAndPositionalEmbedding({ maxSequenceLength: 10, embedDim: 32, vocabularySize: 10_000 });
        positional.build([2, 3]); // get past the initial build call to test forward prop
        expect(() => positional.apply(sequencesTooLong)).toThrow();
        expect(() => positional.apply(multipleCorrectInputs)).toThrow();
        expect(() => positional.apply(wrongRank)).toThrow();
    });
    test("should return a non-empty config dict", () => {
        const embedding = new TokenAndPositionalEmbedding({ embedDim: 32, vocabularySize: 10_000 });
        // check the key count: comparing the key array itself against 0 can never fail
        expect(Object.keys(embedding.getConfig()).length).toBeGreaterThan(0);
    });
    test("should return an output shape of [batch, sequences, embed dims]", () => {
        const words = 100;
        const batch = 2;
        const embedDims = 64;
        const embedding = new TokenAndPositionalEmbedding({ embedDim: embedDims, vocabularySize: 10_000 });
        expect(embedding.computeOutputShape([batch, words])).toEqual([batch, words, embedDims]);
    });
});
|
|
58
|
+
//# sourceMappingURL=token_and_positional_embedding.test.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"token_and_positional_embedding.test.js","sourceRoot":"","sources":["../../../src/layers/token_and_positional_embedding.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,2BAA2B,EAAE,MAAM,yCAAyC,CAAC;AAEtF,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,IAAI,CAAC,sBAAsB,EAAE,GAAG,EAAE;QAC9B,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,iBAAiB,EAAE,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACxH,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACjG,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAE7F,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACtG,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC1G,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,2BAA2B,EAAE,GAAG,EAAE;QACnC,MAAM,UAAU,GAAG,EAAE,CAAC;QACtB,MAAM,SAAS,GAAG,CAAC,CAAC;QACpB,MAAM,UAAU,GAAG,MAAM,CAAC;QAC1B,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;QAE/C,MAAM,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,OAAO,EAAE,GAAG,EAAE,cAAc,EAAE,UAAU,EAAE,CAAC,CAAC;QACtH,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,aAAa,EAAE,GAAG,EAAE;QACrB,MAAM,QAAQ,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC1C,MAAM,oBAAoB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC,CAAC;QACxD,MAAM,cAAc,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;QAExD,IAAI,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,iBAAiB,EA
AE,GAAG,EAAE,cAAc,EAAE,KAAK,EAAE,CAAC,CAAC;QACjH,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAE5D,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,iBAAiB,EAAE,GAAG,EAAE,cAAc,EAAE,KAAK,EAAE,CAAC,CAAC;QAC7G,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,QAAQ,CAAC,KAAK,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAE9E,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,iBAAiB,EAAE,GAAG,EAAE,cAAc,EAAE,KAAK,EAAE,CAAC,CAAC;QACjG,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,oBAAoB,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACpE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,cAAc,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAClE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,wGAAwG,EAAE,GAAG,EAAE;QAC9G,MAAM,kBAAkB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,IAAI,CAAC,CAAC,CAAC;QACxD,MAAM,uBAAuB,GAAG,CAAC,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QACrF,MAAM,UAAU,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;QAElD,MAAM,UAAU,GAAG,IAAI,2BAA2B,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC;QACpH,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,uDAAuD;QAEjF,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,kBAAkB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,uBAAuB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAClE,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC;QAC5F,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,SAAS,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,iEAAiE,EAAE,GAAG,EAAE;QACvE,MAAM,KAAK,GAAG,GAAG,CAAC;QAClB,MAAM,KAAK,GAAG,CAAC,CAAC;QAChB,MAAM,UAAU,GAAG,EAAE,CAAC;QAEtB,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC;QAE/C,MAAM,S
AAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC;QAEpG,MAAM,CAAC,SAAS,CAAC,kBAAkB,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,UAAU,CAAC,CAAC,CAAC;IAC1F,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAC"}
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
|
|
3
|
+
import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
|
|
4
|
+
import { type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
|
|
5
|
+
/**
 * Constructor options for the `TransformerDecoder` layer.
 *
 * Every multi-head attention option is inherited except `causal`, which is
 * re-declared below as an optional flag.
 */
export interface TransformerDecoderArgs extends Omit<MultiHeadAttentionArgs, "causal"> {
    /**
     * Request causal masking (future positions hidden) in self-attention.
     * NOTE(review): the 1.0.1 constructor does not read this field and always
     * builds its attention sub-layer with `causal: true` — confirm intent.
     */
    causal?: boolean;
    /** Width of the intermediate feed-forward layer (the decoder uses `4 * embedDim` when omitted). */
    dimsFeedForward?: number;
    /** Activation of the intermediate feed-forward layer (the decoder uses `"relu"` when omitted). */
    activation?: "relu" | "gelu";
}
|
|
10
|
+
/**
 * Transformer decoder block following the architecture of the 2017 paper
 * "Attention Is All You Need".
 *
 * This decoder-only layer takes one input tensor of shape
 * `[ batch, sequences, embedding dims ]` and applies a self-attention
 * sub-block followed by a feed-forward sub-block.
 *
 * Causal masking is enabled for the self-attention sub-layer.
 *
 * @param numHeads number of attention heads to use
 * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
 * @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`.
 *   NOTE(review): the self-attention sub-layer is always created with `causal: true` and this
 *   option is not read, so passing `false` does not disable masking — confirm intent.
 * @param dropout dropout rate, default `0.1`; passed to the attention sub-layer and used by the
 *   dropout layers after the attention and feed-forward sub-blocks. Must be below `1`.
 * @param activation the activation of the intermediate feed forward layer, default `relu`
 * @param dimsFeedForward the size of the intermediate feed forward layer, default `4 * embedDim`
 *   (the paper's ratio: d_model = 512 gives d_ff = 2048)
 * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
 * @throws Error from the constructor when `dropout >= 1`
 */
export declare class TransformerDecoder extends tf.layers.Layer {
    /** Class name reported for this layer (`"TransformerDecoder"`). */
    static className: string;
    /** Self-attention sub-layer (a `CachedRoPEMultiHeadAttention` built with `causal: true`). */
    protected readonly causalSelfAttention: tf.layers.Layer;
    /** Dropout applied to the attention output before the residual add. */
    protected readonly causalSelfAttentionDropout: tf.layers.Layer;
    /** Layer normalization (epsilon `1e-6`) applied after the attention residual add. */
    protected readonly causalSelfAttentionNorm: tf.layers.Layer;
    /** First feed-forward dense layer (`dimsFeedForward` units). */
    protected readonly feedforward1: tf.layers.Layer;
    /** Second feed-forward dense layer. */
    protected readonly feedforward2: tf.layers.Layer;
    /** Dropout layer of the feed-forward sub-block. */
    protected readonly feedForwardDropout: tf.layers.Layer;
    /**
     * Layer normalization (epsilon `1e-6`) of the feed-forward sub-block.
     * The misspelled name is part of the protected API and is kept for compatibility.
     */
    protected readonly feedFowardNorm: tf.layers.Layer;
    // hyperparameters, resolved to their defaults in the constructor
    protected readonly numHeads: number;
    protected readonly embedDim: number;
    protected readonly useBias: boolean;
    protected readonly dropout: number;
    protected readonly activation: ActivationIdentifier;
    protected readonly dimsFeedForward: number;
    constructor({ numHeads, embedDim, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerDecoderArgs);
    /**
     * Forward propagation: the self-attention sub-block, then the feed-forward sub-block.
     *
     * @param inputs input tensor; an array of one or two tensors is also accepted,
     *   in which case only the first tensor is used
     * @param kwargs keyword arguments forwarded to the sublayers' `apply` calls
     * @return the output tensor
     * @throws Error when given an array whose length is neither 1 nor 2
     */
    call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[];
    /** Self-attention sub-block: attention, dropout, residual add, then layer normalization. */
    protected causalSelfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
    /** Feed-forward sub-block, applied by `call` to the self-attention sub-block's output. */
    protected feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
    /**
     * Initialize the sublayers' weights and track them to enable serialization
     */
    build(inputShape: tf.Shape | tf.Shape[]): void;
    /**
     * Save the layer's hyperparameters for serialization
     */
    getConfig(): {
        numHeads: number;
        embedDim: number;
        useBias: boolean;
        dropout: number;
        activation: ActivationIdentifier;
        dimsFeedForward: number;
    };
}
|
|
69
|
+
//# sourceMappingURL=transformer_decoder.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"transformer_decoder.d.ts","sourceRoot":"","sources":["../../../src/layers/transformer_decoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AACjE,OAAO,EAAE,KAAK,oBAAoB,EAAE,MAAM,6DAA6D,CAAC;AAExG,OAAO,EAAE,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAI3E,MAAM,WAAW,sBAAuB,SAAQ,IAAI,CAAC,sBAAsB,EAAE,QAAQ,CAAC;IAClF,UAAU,CAAC,EAAE,MAAM,GAAG,MAAM,CAAC;IAC7B,eAAe,CAAC,EAAE,MAAM,CAAC;IACzB,MAAM,CAAC,EAAE,OAAO,CAAC;CACpB;AAGD;;;;;;;;;;;;;;;;;GAiBG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IAExC,SAAS,CAAC,QAAQ,CAAC,mBAAmB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACxD,SAAS,CAAC,QAAQ,CAAC,0BAA0B,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAC/D,SAAS,CAAC,QAAQ,CAAC,uBAAuB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAE5D,SAAS,CAAC,QAAQ,CAAC,YAAY,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACjD,SAAS,CAAC,QAAQ,CAAC,YAAY,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACjD,SAAS,CAAC,QAAQ,CAAC,kBAAkB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACvD,SAAS,CAAC,QAAQ,CAAC,cAAc,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAEnD,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,OAAO,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,MAAM,CAAC;IACnC,SAAS,CAAC,QAAQ,CAAC,UAAU,EAAE,oBAAoB,CAAC;IACpD,SAAS,CAAC,QAAQ,CAAC,eAAe,EAAE,MAAM,CAAC;gBAE/B,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAAE,EAAE,sBAAsB;IAyClH;;;;;OAKG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAoBvF,SAAS,CAAC,wBAAwB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAc3E,SAAS,CAAC,gBAAgB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAenE;;OAEG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IA6DvD;;OAEG;IACM,SAAS;;;;;;;;CAiBrB"}
|
|
@@ -1,28 +1,15 @@
|
|
|
1
1
|
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
|
|
3
|
-
import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
|
|
4
|
-
|
|
5
|
-
import { type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
|
|
6
2
|
import { CachedRoPEMultiHeadAttention } from "@/layers/cached_rope_multihead_attention";
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
export interface TransformerDecoderArgs extends Omit<MultiHeadAttentionArgs, "causal"> {
|
|
10
|
-
activation?: "relu" | "gelu";
|
|
11
|
-
dimsFeedForward?: number;
|
|
12
|
-
causal?: boolean; // use causal mask for attention on inputs
|
|
13
|
-
}
|
|
14
|
-
|
|
15
|
-
|
|
16
3
|
/**
|
|
17
4
|
* This class implements the transformer decoder architecture from
|
|
18
5
|
* the 2017 paper "Attention Is All You Need".
|
|
19
|
-
*
|
|
6
|
+
*
|
|
20
7
|
* This decoder-only transformer layer accepts one tensor input.
|
|
21
8
|
* The input tensor should have the shape
|
|
22
9
|
* `[ batch, sequences, embedding dims ]`.
|
|
23
|
-
*
|
|
10
|
+
*
|
|
24
11
|
* Causal masking is enabled by default for the initial attention sub-layer.
|
|
25
|
-
*
|
|
12
|
+
*
|
|
26
13
|
* @param numHeads number of attention heads to use
|
|
27
14
|
* @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
|
|
28
15
|
* @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
|
|
@@ -33,48 +20,39 @@ export interface TransformerDecoderArgs extends Omit<MultiHeadAttentionArgs, "ca
|
|
|
33
20
|
*/
|
|
34
21
|
export class TransformerDecoder extends tf.layers.Layer {
|
|
35
22
|
static className = "TransformerDecoder";
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
protected readonly activation: ActivationIdentifier;
|
|
51
|
-
protected readonly dimsFeedForward: number;
|
|
52
|
-
|
|
53
|
-
constructor({ numHeads, embedDim, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerDecoderArgs) {
|
|
23
|
+
causalSelfAttention;
|
|
24
|
+
causalSelfAttentionDropout;
|
|
25
|
+
causalSelfAttentionNorm;
|
|
26
|
+
feedforward1;
|
|
27
|
+
feedforward2;
|
|
28
|
+
feedForwardDropout;
|
|
29
|
+
feedFowardNorm;
|
|
30
|
+
numHeads;
|
|
31
|
+
embedDim;
|
|
32
|
+
useBias;
|
|
33
|
+
dropout;
|
|
34
|
+
activation;
|
|
35
|
+
dimsFeedForward;
|
|
36
|
+
constructor({ numHeads, embedDim, useBias, dropout, activation, dimsFeedForward, ...args }) {
|
|
54
37
|
super(args);
|
|
55
|
-
|
|
56
38
|
this.numHeads = numHeads;
|
|
57
39
|
this.embedDim = embedDim;
|
|
58
40
|
this.useBias = useBias ?? true;
|
|
59
41
|
this.dropout = dropout ?? 0.1;
|
|
60
42
|
this.activation = activation ?? "relu";
|
|
61
|
-
|
|
62
43
|
if (this.dropout >= 1) {
|
|
63
44
|
throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
|
|
64
45
|
}
|
|
65
|
-
|
|
66
46
|
// in the paper section 3.3, d_model=512 (embedDim) and first dense layer outputs d_ff=2048
|
|
67
47
|
this.dimsFeedForward = dimsFeedForward ?? embedDim * 4;
|
|
68
|
-
|
|
69
48
|
// self attention sub-block
|
|
70
49
|
this.causalSelfAttention = new CachedRoPEMultiHeadAttention({
|
|
71
50
|
numHeads: this.numHeads, embedDim: this.embedDim,
|
|
72
51
|
useBias: this.useBias, dropout: this.dropout,
|
|
73
52
|
causal: true
|
|
74
53
|
});
|
|
75
|
-
this.causalSelfAttentionDropout = tf.layers.dropout({ rate: this.dropout })
|
|
54
|
+
this.causalSelfAttentionDropout = tf.layers.dropout({ rate: this.dropout });
|
|
76
55
|
this.causalSelfAttentionNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
|
|
77
|
-
|
|
78
56
|
// feed forward sub-block
|
|
79
57
|
this.feedforward1 = tf.layers.dense({
|
|
80
58
|
units: this.dimsFeedForward,
|
|
@@ -89,101 +67,79 @@ export class TransformerDecoder extends tf.layers.Layer {
|
|
|
89
67
|
this.feedForwardDropout = tf.layers.dropout({ rate: this.dropout });
|
|
90
68
|
this.feedFowardNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
|
|
91
69
|
}
|
|
92
|
-
|
|
93
|
-
|
|
94
70
|
/**
|
|
95
71
|
* Forward propagation
|
|
96
|
-
*
|
|
72
|
+
*
|
|
97
73
|
* @param inputs input tensor
|
|
98
74
|
* @return the output tensor
|
|
99
75
|
*/
|
|
100
|
-
|
|
76
|
+
call(inputs, kwargs) {
|
|
101
77
|
// validate the input tensors
|
|
102
78
|
if (Array.isArray(inputs) && inputs.length != 1 && inputs.length != 2) {
|
|
103
79
|
throw Error(`${this.getClassName()}::call ${this.name} expects one input tensor, got ${inputs.length} inputs.`);
|
|
104
80
|
}
|
|
105
|
-
|
|
106
81
|
if (Array.isArray(inputs)) {
|
|
107
|
-
inputs = inputs[0]
|
|
82
|
+
inputs = inputs[0];
|
|
108
83
|
}
|
|
109
|
-
|
|
110
84
|
// perform forward propagation
|
|
111
85
|
return tf.tidy(() => {
|
|
112
86
|
let output = this.causalSelfAttentionBlock(inputs, kwargs);
|
|
113
87
|
output = this.feedForwardBlock(output, kwargs);
|
|
114
|
-
|
|
115
88
|
return output;
|
|
116
89
|
});
|
|
117
90
|
}
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
protected causalSelfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
|
|
91
|
+
causalSelfAttentionBlock(x, kwargs) {
|
|
121
92
|
return tf.tidy(() => {
|
|
122
93
|
const residual = x;
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
attention = this.causalSelfAttentionDropout.apply(attention, kwargs) as tf.Tensor;
|
|
94
|
+
let attention = this.causalSelfAttention.apply(x, kwargs);
|
|
95
|
+
attention = this.causalSelfAttentionDropout.apply(attention, kwargs);
|
|
126
96
|
attention = tf.add(attention, residual);
|
|
127
|
-
attention = this.causalSelfAttentionNorm.apply(attention, kwargs)
|
|
128
|
-
|
|
97
|
+
attention = this.causalSelfAttentionNorm.apply(attention, kwargs);
|
|
129
98
|
return attention;
|
|
130
99
|
});
|
|
131
100
|
}
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
protected feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
|
|
101
|
+
feedForwardBlock(x, kwargs) {
|
|
135
102
|
return tf.tidy(() => {
|
|
136
103
|
const residual = x;
|
|
137
|
-
|
|
138
104
|
let feedForward = this.feedforward1.apply(x, kwargs);
|
|
139
105
|
feedForward = this.feedforward2.apply(feedForward, kwargs);
|
|
140
|
-
feedForward = this.feedForwardDropout.apply(feedForward, kwargs)
|
|
106
|
+
feedForward = this.feedForwardDropout.apply(feedForward, kwargs);
|
|
141
107
|
feedForward = tf.add(feedForward, residual);
|
|
142
|
-
feedForward = this.feedFowardNorm.apply(feedForward, kwargs)
|
|
143
|
-
|
|
108
|
+
feedForward = this.feedFowardNorm.apply(feedForward, kwargs);
|
|
144
109
|
return feedForward;
|
|
145
110
|
});
|
|
146
111
|
}
|
|
147
|
-
|
|
148
|
-
|
|
149
112
|
/**
|
|
150
113
|
* Initialize the sublayers' weights and track them to enable serialization
|
|
151
114
|
*/
|
|
152
|
-
|
|
153
|
-
let input_shapes
|
|
154
|
-
|
|
115
|
+
build(inputShape) {
|
|
116
|
+
let input_shapes = [];
|
|
155
117
|
if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
|
|
156
118
|
// input is an array of shapes
|
|
157
|
-
input_shapes = inputShape
|
|
158
|
-
}
|
|
119
|
+
input_shapes = inputShape;
|
|
120
|
+
}
|
|
121
|
+
else if (inputShape.length != 0) {
|
|
159
122
|
// input is a single shape
|
|
160
|
-
input_shapes = [inputShape
|
|
123
|
+
input_shapes = [inputShape];
|
|
161
124
|
}
|
|
162
|
-
|
|
163
125
|
if (input_shapes.length != 1 && input_shapes.length != 2) {
|
|
164
126
|
throw Error(`${this.getClassName()}::build ${this.name} expects an input shape` +
|
|
165
|
-
` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`)
|
|
127
|
+
` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`);
|
|
166
128
|
}
|
|
167
|
-
|
|
168
129
|
const [decoderInputShape] = input_shapes;
|
|
169
|
-
|
|
170
130
|
if (decoderInputShape?.length != 3) {
|
|
171
131
|
throw Error(`${this.getClassName()}::build ${this.name} expects an input shape` +
|
|
172
|
-
` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`)
|
|
132
|
+
` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`);
|
|
173
133
|
}
|
|
174
|
-
|
|
175
134
|
// initialize causal self attention sub-block's weights
|
|
176
135
|
this.causalSelfAttention.build(decoderInputShape);
|
|
177
136
|
this.causalSelfAttentionNorm.build(this.causalSelfAttention.computeOutputShape(decoderInputShape));
|
|
178
|
-
|
|
179
137
|
// initialize feedforward sub-block's weights
|
|
180
138
|
const feedforward1OutputShape = this.feedforward1.computeOutputShape(decoderInputShape);
|
|
181
139
|
const feedforward2OutputShape = this.feedforward2.computeOutputShape(feedforward1OutputShape);
|
|
182
|
-
|
|
183
140
|
this.feedforward1.build(decoderInputShape);
|
|
184
141
|
this.feedforward2.build(feedforward1OutputShape);
|
|
185
142
|
this.feedFowardNorm.build(feedforward2OutputShape);
|
|
186
|
-
|
|
187
143
|
// track sublayers' weights
|
|
188
144
|
this.trainableWeights = [
|
|
189
145
|
...this.causalSelfAttention.trainableWeights,
|
|
@@ -194,28 +150,22 @@ export class TransformerDecoder extends tf.layers.Layer {
|
|
|
194
150
|
...this.feedForwardDropout.trainableWeights,
|
|
195
151
|
...this.feedFowardNorm.trainableWeights
|
|
196
152
|
];
|
|
197
|
-
|
|
198
153
|
// rename the weights otherwise they'll take on the default naming and overlap
|
|
199
154
|
// each other which breaks model loading due to duplicate weight names
|
|
200
155
|
let indexing = 0;
|
|
201
|
-
|
|
202
156
|
for (const weight of this.trainableWeights) {
|
|
203
157
|
const unique_name = `${this.getClassName()}_${indexing}`;
|
|
204
|
-
|
|
205
|
-
|
|
158
|
+
weight.name += unique_name;
|
|
159
|
+
weight.originalName += unique_name;
|
|
206
160
|
indexing++;
|
|
207
161
|
}
|
|
208
|
-
|
|
209
162
|
super.build(inputShape);
|
|
210
163
|
}
|
|
211
|
-
|
|
212
|
-
|
|
213
164
|
/**
|
|
214
165
|
* Save the layer's hyperparameters for serialization
|
|
215
166
|
*/
|
|
216
|
-
|
|
167
|
+
getConfig() {
|
|
217
168
|
const base_config = super.getConfig();
|
|
218
|
-
|
|
219
169
|
const config = {
|
|
220
170
|
numHeads: this.numHeads,
|
|
221
171
|
embedDim: this.embedDim,
|
|
@@ -223,14 +173,10 @@ export class TransformerDecoder extends tf.layers.Layer {
|
|
|
223
173
|
dropout: this.dropout,
|
|
224
174
|
activation: this.activation,
|
|
225
175
|
dimsFeedForward: this.dimsFeedForward
|
|
226
|
-
}
|
|
227
|
-
|
|
176
|
+
};
|
|
228
177
|
Object.assign(config, base_config);
|
|
229
|
-
|
|
230
178
|
return config;
|
|
231
179
|
}
|
|
232
|
-
|
|
233
180
|
}
|
|
234
|
-
|
|
235
|
-
|
|
236
181
|
tf.serialization.registerClass(TransformerDecoder);
|
|
182
|
+
//# sourceMappingURL=transformer_decoder.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"transformer_decoder.js","sourceRoot":"","sources":["../../../src/layers/transformer_decoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAKvC,OAAO,EAAE,4BAA4B,EAAE,MAAM,0CAA0C,CAAC;AAUxF;;;;;;;;;;;;;;;;;GAiBG;AACH,MAAM,OAAO,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,GAAG,oBAAoB,CAAC;IAErB,mBAAmB,CAAkB;IACrC,0BAA0B,CAAkB;IAC5C,uBAAuB,CAAkB;IAEzC,YAAY,CAAkB;IAC9B,YAAY,CAAkB;IAC9B,kBAAkB,CAAkB;IACpC,cAAc,CAAkB;IAEhC,QAAQ,CAAS;IACjB,QAAQ,CAAS;IACjB,OAAO,CAAU;IACjB,OAAO,CAAS;IAChB,UAAU,CAAuB;IACjC,eAAe,CAAS;IAE3C,YAAY,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAA0B;QAC9G,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,IAAI,CAAC;QAC/B,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,GAAG,CAAC;QAC9B,IAAI,CAAC,UAAU,GAAG,UAAU,IAAI,MAAM,CAAC;QAEvC,IAAI,IAAI,CAAC,OAAO,IAAI,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,6CAA6C,CAAC,CAAC;QACrF,CAAC;QAED,2FAA2F;QAC3F,IAAI,CAAC,eAAe,GAAG,eAAe,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEvD,2BAA2B;QAC3B,IAAI,CAAC,mBAAmB,GAAG,IAAI,4BAA4B,CAAC;YACxD,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,QAAQ,EAAE,IAAI,CAAC,QAAQ;YAChD,OAAO,EAAE,IAAI,CAAC,OAAO,EAAE,OAAO,EAAE,IAAI,CAAC,OAAO;YAC5C,MAAM,EAAE,IAAI;SACf,CAAC,CAAC;QACH,IAAI,CAAC,0BAA0B,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAA;QAC3E,IAAI,CAAC,uBAAuB,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAAO,EAAE,IAAI,EAAE,CAAC,CAAC;QAE/E,yBAAyB;QACzB,IAAI,CAAC,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAChC,KAAK,EAAE,IAAI,CAAC,eAAe;YAC3B,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAChC,KAAK,EAAE,IAAI,CAAC,QAAQ;YACpB,UAAU,EAAE,QAAQ;YACpB,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,kBAAkB,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC;QACpE,IAAI,CAAC,cAAc,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAAO,
EAAE,IAAI,EAAE,CAAC,CAAC;IAC1E,CAAC;IAGD;;;;;OAKG;IACM,IAAI,CAAC,MAA+B,EAAE,MAAc;QACzD,6BAA6B;QAC7B,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YACpE,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,kCAAkC,MAAM,CAAC,MAAM,UAAU,CAAC,CAAC;QACpH,CAAC;QAED,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC;YACxB,MAAM,GAAG,MAAM,CAAC,CAAC,CAAc,CAAC;QACpC,CAAC;QAED,8BAA8B;QAC9B,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,IAAI,MAAM,GAAG,IAAI,CAAC,wBAAwB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAC3D,MAAM,GAAG,IAAI,CAAC,gBAAgB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAE/C,OAAO,MAAM,CAAC;QAClB,CAAC,CAAC,CAAC;IACP,CAAC;IAGS,wBAAwB,CAAC,CAAY,EAAE,MAAc;QAC3D,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,SAAS,GAAG,IAAI,CAAC,mBAAmB,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAc,CAAC;YACvE,SAAS,GAAG,IAAI,CAAC,0BAA0B,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAClF,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,QAAQ,CAAC,CAAC;YACxC,SAAS,GAAG,IAAI,CAAC,uBAAuB,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAE/E,OAAO,SAAS,CAAC;QACrB,CAAC,CAAC,CAAC;IACP,CAAC;IAGS,gBAAgB,CAAC,CAAY,EAAE,MAAc;QACnD,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YACrD,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAC,CAAC;YAC3D,WAAW,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAC9E,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,QAAQ,CAAC,CAAC;YAC5C,WAAW,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAE1E,OAAO,WAAW,CAAC;QACvB,CAAC,CAAC,CAAC;IACP,CAAC;IAGD;;OAEG;IACM,KAAK,CAAC,UAAiC;QAC5C,IAAI,YAAY,GAAe,EAAE,CAAC;QAElC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAC5D,8BAA8B;YAC9B,YAAY,GAAG,UAAwB,CAAC;QAC5C,CAAC;aAAM,IAAI,UAAU,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAChC,0BAA0B;YAC1B,YAAY,GAAG,CAAC,UAAsB,CAAC,CAAC;QAC5C,CAAC;QAED,IAAI,YAAY,CAAC,MAAM,IAAI,CAAC,IAAI,YAAY,CAAC,MAAM,IAAI,CAAC,EAAE,CAA
C;YACvD,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,yBAAyB;gBAC3E,oCAAoC,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAA;QACzE,CAAC;QAED,MAAM,CAAC,iBAAiB,CAAC,GAAG,YAAY,CAAC;QAEzC,IAAI,iBAAiB,EAAE,MAAM,IAAI,CAAC,EAAE,CAAC;YACjC,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,yBAAyB;gBAC3E,oCAAoC,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAA;QACzE,CAAC;QAED,uDAAuD;QACvD,IAAI,CAAC,mBAAmB,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC;QAClD,IAAI,CAAC,uBAAuB,CAAC,KAAK,CAAC,IAAI,CAAC,mBAAmB,CAAC,kBAAkB,CAAC,iBAAiB,CAAC,CAAC,CAAC;QAEnG,6CAA6C;QAC7C,MAAM,uBAAuB,GAAG,IAAI,CAAC,YAAY,CAAC,kBAAkB,CAAC,iBAAiB,CAAC,CAAC;QACxF,MAAM,uBAAuB,GAAG,IAAI,CAAC,YAAY,CAAC,kBAAkB,CAAC,uBAAuB,CAAC,CAAC;QAE9F,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC;QAC3C,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,uBAAuB,CAAC,CAAC;QACjD,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,uBAAuB,CAAC,CAAC;QAEnD,2BAA2B;QAC3B,IAAI,CAAC,gBAAgB,GAAG;YACpB,GAAG,IAAI,CAAC,mBAAmB,CAAC,gBAAgB;YAC5C,GAAG,IAAI,CAAC,0BAA0B,CAAC,gBAAgB;YACnD,GAAG,IAAI,CAAC,uBAAuB,CAAC,gBAAgB;YAChD,GAAG,IAAI,CAAC,YAAY,CAAC,gBAAgB;YACrC,GAAG,IAAI,CAAC,YAAY,CAAC,gBAAgB;YACrC,GAAG,IAAI,CAAC,kBAAkB,CAAC,gBAAgB;YAC3C,GAAG,IAAI,CAAC,cAAc,CAAC,gBAAgB;SAC1C,CAAC;QAEF,8EAA8E;QAC9E,sEAAsE;QACtE,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEjB,KAAK,MAAM,MAAM,IAAI,IAAI,CAAC,gBAAgB,EAAE,CAAC;YACzC,MAAM,WAAW,GAAG,GAAG,IAAI,CAAC,YAAY,EAAE,IAAI,QAAQ,EAAE,CAAC;YACxD,MAAc,CAAC,IAAI,IAAI,WAAW,CAAC;YACnC,MAAc,CAAC,YAAY,IAAI,WAAW,CAAC;YAC5C,QAAQ,EAAE,CAAC;QACf,CAAC;QAED,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACM,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,eAAe,EAAE,IAAI,CAAC,eAAe;SACxC,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAKL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"transformer_decoder.test.d.ts","sourceRoot":"","sources":["../../../src/layers/transformer_decoder.test.ts"],"names":[],"mappings":""}
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { TransformerDecoder } from '@/layers/transformer_decoder';
|
|
3
|
+
// disables warning for using the faster node backend,
|
|
4
|
+
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
|
|
5
|
+
tf.env().set('IS_NODE', false);
|
|
6
|
+
describe("TransformerDecoder tests", () => {
|
|
7
|
+
it("should return an output with the same shape as the input", () => {
|
|
8
|
+
const input = tf.randomUniform([2, 3, 12]);
|
|
9
|
+
const decoder = new TransformerDecoder({
|
|
10
|
+
numHeads: 2, embedDim: input.shape.at(-1),
|
|
11
|
+
dropout: 0.5, activation: "gelu", dimsFeedForward: 321, useBias: false
|
|
12
|
+
});
|
|
13
|
+
const output = decoder.apply(input);
|
|
14
|
+
expect(output.shape.length).toBe(input.shape.length);
|
|
15
|
+
});
|
|
16
|
+
test("forward calls", () => {
|
|
17
|
+
const input = tf.randomUniform([2, 3, 12]);
|
|
18
|
+
const mask = tf.randomUniform([input.shape[0], input.shape[1]], -1, 2, "bool");
|
|
19
|
+
const incorrect_mask = tf.randomUniform([2, 5, 12], -1, 2, "bool");
|
|
20
|
+
const decoder = new TransformerDecoder({ numHeads: 2, embedDim: input.shape.at(-1) });
|
|
21
|
+
expect(() => decoder.apply(input)).not.toThrow();
|
|
22
|
+
expect(() => decoder.apply([input])).not.toThrow();
|
|
23
|
+
// causal masking
|
|
24
|
+
const causal = new TransformerDecoder({ numHeads: 2, embedDim: input.shape.at(-1), causal: true });
|
|
25
|
+
expect(() => causal.apply(input)).not.toThrow();
|
|
26
|
+
expect(() => causal.apply([input])).not.toThrow();
|
|
27
|
+
});
|
|
28
|
+
it("should fail to instantiate a layer if heads count is not divisible by the input's embedding dimension", () => {
|
|
29
|
+
const input = tf.randomUniform([2, 3, 12]);
|
|
30
|
+
expect(() => new TransformerDecoder({ numHeads: 3, embedDim: input.shape.at(-1) })).not.toThrow();
|
|
31
|
+
expect(() => new TransformerDecoder({ numHeads: 5, embedDim: input.shape.at(-1) })).toThrow();
|
|
32
|
+
});
|
|
33
|
+
it("should not accept non-rank 3 tensor inputs", () => {
|
|
34
|
+
const embed_dim = 12;
|
|
35
|
+
const BAD_RANK4 = tf.randomUniform([2, 3, 12, embed_dim]);
|
|
36
|
+
const BAD_RANK2 = tf.randomUniform([2, embed_dim]);
|
|
37
|
+
const GOOD = tf.randomUniform([2, 3, embed_dim]);
|
|
38
|
+
const mask = tf.randomUniform([GOOD.shape[0], GOOD.shape[1]], -1, 2, "bool");
|
|
39
|
+
let decoder = new TransformerDecoder({ numHeads: 2, embedDim: embed_dim });
|
|
40
|
+
// BAD
|
|
41
|
+
expect(() => decoder.apply(BAD_RANK4)).toThrow();
|
|
42
|
+
expect(() => decoder.apply(BAD_RANK2)).toThrow();
|
|
43
|
+
// OK
|
|
44
|
+
decoder = new TransformerDecoder({ numHeads: 2, embedDim: embed_dim });
|
|
45
|
+
expect(() => decoder.apply(GOOD)).not.toThrow();
|
|
46
|
+
expect(() => decoder.apply([GOOD])).not.toThrow();
|
|
47
|
+
expect(() => decoder.apply([GOOD, mask])).not.toThrow();
|
|
48
|
+
});
|
|
49
|
+
it("should not accept inputs that are less or more than 1 and 2 tensors", () => {
|
|
50
|
+
const input = tf.randomUniform([2, 3, 12]);
|
|
51
|
+
let decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
|
|
52
|
+
// OK
|
|
53
|
+
expect(() => decoder.apply(input)).not.toThrow();
|
|
54
|
+
expect(() => decoder.apply([input])).not.toThrow();
|
|
55
|
+
// BAD
|
|
56
|
+
decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
|
|
57
|
+
expect(() => decoder.apply([])).toThrow(); // stops at build()
|
|
58
|
+
decoder.apply(input); // get past the initial build
|
|
59
|
+
expect(() => decoder.apply([input, input, input])).toThrow();
|
|
60
|
+
expect(() => decoder.apply([input, input, input, input])).toThrow();
|
|
61
|
+
// BAD (tests build())
|
|
62
|
+
decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
|
|
63
|
+
expect(() => decoder.apply([input, input, input])).toThrow();
|
|
64
|
+
expect(() => decoder.apply([input, input, input, input])).toThrow();
|
|
65
|
+
});
|
|
66
|
+
it("should return a non-empty config dict", () => {
|
|
67
|
+
const input = tf.randomUniform([2, 3, 12]);
|
|
68
|
+
const decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
|
|
69
|
+
expect(Object.keys(decoder.getConfig())).not.toBe(0);
|
|
70
|
+
});
|
|
71
|
+
});
|
|
72
|
+
//# sourceMappingURL=transformer_decoder.test.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"transformer_decoder.test.js","sourceRoot":"","sources":["../../../src/layers/transformer_decoder.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,kBAAkB,EAAE,MAAM,8BAA8B,CAAC;AAElE,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,EAAE,CAAC,0DAA0D,EAAE,GAAG,EAAE;QAChE,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC;YACnC,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE;YAC1C,OAAO,EAAE,GAAG,EAAE,UAAU,EAAE,MAAM,EAAE,eAAe,EAAE,GAAG,EAAE,OAAO,EAAE,KAAK;SACzE,CAAC,CAAC;QAEH,MAAM,MAAM,GAAG,OAAO,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC;QAEjD,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,eAAe,EAAE,GAAG,EAAE;QACvB,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3C,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,CAAC,CAAC;QACjF,MAAM,cAAc,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,CAAC,CAAC;QAGnE,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,iBAAiB;QACjB,MAAM,MAAM,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QACpG,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACtD,CAAC,CAAC,CAA
A;IAGF,EAAE,CAAC,uGAAuG,EAAE,GAAG,EAAE;QAC7G,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACnG,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACnG,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,4CAA4C,EAAE,GAAG,EAAE;QAClD,MAAM,SAAS,GAAG,EAAE,CAAC;QAErB,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,SAAS,CAAC,CAAC,CAAC;QAC1D,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;QACnD,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;QACjD,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,CAAC,CAAC;QAE/E,IAAI,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,CAAC,CAAC;QAE3E,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEjD,KAAK;QACL,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,CAAC,CAAC;QACvE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAClD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC5D,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,qEAAqE,EAAE,GAAG,EAAE;QAC3E,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,IAAI,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACrF,KAAK;QACL,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,KA
AK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,MAAM;QACN,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACjF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC,CAAC,mBAAmB;QAC9D,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,6BAA6B;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEpE,sBAAsB;QACtB,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACjF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACxE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,OAAO,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAA"}
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
|
|
3
|
+
import { type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
|
|
4
|
+
export interface TransformerEncoderArgs extends MultiHeadAttentionArgs {
|
|
5
|
+
activation?: "relu" | "gelu";
|
|
6
|
+
dimsFeedForward?: number;
|
|
7
|
+
}
|
|
8
|
+
/**
|
|
9
|
+
* This class implements the transformer encoder architecture from the 2017 paper
|
|
10
|
+
* Attention Is All You Need.
|
|
11
|
+
*
|
|
12
|
+
* This layer accepts exactly one tensor input with the shape
|
|
13
|
+
* `[ batch, sequences, embedding dims ]`.
|
|
14
|
+
*
|
|
15
|
+
* @param numHeads number of attention heads to use
|
|
16
|
+
* @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
|
|
17
|
+
* @param causal use causal masking, default `false` for encoders
|
|
18
|
+
* @param dropout use dropout during the attention calculations, default `0.1`
|
|
19
|
+
* @param activation the activation of the intermediate feed forward layer, default `relu`
|
|
20
|
+
* @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
|
|
21
|
+
* @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
|
|
22
|
+
*/
|
|
23
|
+
export declare class TransformerEncoder extends tf.layers.Layer {
|
|
24
|
+
static className: string;
|
|
25
|
+
private readonly selfAttention;
|
|
26
|
+
private readonly selfAttentionDropout;
|
|
27
|
+
private readonly selfAttentionNorm;
|
|
28
|
+
private readonly reluLayer;
|
|
29
|
+
private readonly linearLayer;
|
|
30
|
+
private readonly feedForwardDropout;
|
|
31
|
+
private readonly feedFowardNorm;
|
|
32
|
+
private readonly numHeads;
|
|
33
|
+
private readonly embedDim;
|
|
34
|
+
private readonly causal;
|
|
35
|
+
private readonly useBias;
|
|
36
|
+
private readonly dropout;
|
|
37
|
+
private readonly activation;
|
|
38
|
+
private readonly dimsFeedForward;
|
|
39
|
+
constructor({ numHeads, embedDim, causal, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerEncoderArgs);
|
|
40
|
+
/**
|
|
41
|
+
* Forward propagation
|
|
42
|
+
*/
|
|
43
|
+
call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[];
|
|
44
|
+
private selfAttentionBlock;
|
|
45
|
+
private feedForwardBlock;
|
|
46
|
+
/**
|
|
47
|
+
* Initialize the sublayers' weights and track them to enable backpropagation.
|
|
48
|
+
*/
|
|
49
|
+
build(inputShape: tf.Shape | tf.Shape[]): void;
|
|
50
|
+
/**
|
|
51
|
+
* Save the layer's hyperparameters for serialization
|
|
52
|
+
*/
|
|
53
|
+
getConfig(): tf.serialization.ConfigDict;
|
|
54
|
+
}
|
|
55
|
+
//# sourceMappingURL=transformer_encoder.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"transformer_encoder.d.ts","sourceRoot":"","sources":["../../../src/layers/transformer_encoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAGjE,OAAO,EAAsB,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAG/F,MAAM,WAAW,sBAAuB,SAAQ,sBAAsB;IAClE,UAAU,CAAC,EAAE,MAAM,GAAG,MAAM,CAAC;IAC7B,eAAe,CAAC,EAAE,MAAM,CAAC;CAC5B;AAGD;;;;;;;;;;;;;;GAcG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IAExC,OAAO,CAAC,QAAQ,CAAC,aAAa,CAAkB;IAChD,OAAO,CAAC,QAAQ,CAAC,oBAAoB,CAAkB;IACvD,OAAO,CAAC,QAAQ,CAAC,iBAAiB,CAAkB;IAEpD,OAAO,CAAC,QAAQ,CAAC,SAAS,CAAkB;IAC5C,OAAO,CAAC,QAAQ,CAAC,WAAW,CAAkB;IAC9C,OAAO,CAAC,QAAQ,CAAC,kBAAkB,CAAkB;IACrD,OAAO,CAAC,QAAQ,CAAC,cAAc,CAAkB;IAEjD,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,QAAQ,CAAC,MAAM,CAAU;IACjC,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAU;IAClC,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAS;IACjC,OAAO,CAAC,QAAQ,CAAC,UAAU,CAAuB;IAClD,OAAO,CAAC,QAAQ,CAAC,eAAe,CAAS;gBAG7B,EAAE,QAAQ,EAAE,QAAQ,EAAE,MAAM,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAAE,EAAE,sBAAsB;IAqC1H;;OAEG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAyBvF,OAAO,CAAC,kBAAkB;IAc1B,OAAO,CAAC,gBAAgB;IAexB;;OAEG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IAsDvD;;OAEG;IACM,SAAS,IAAI,EAAE,CAAC,aAAa,CAAC,UAAU;CAiBpD"}
|