@stellarapp/tfjs-stellar 1.0.4 → 1.0.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +17 -0
- package/dist/index.d.ts +2 -1
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +2 -1
- package/dist/index.js.map +1 -1
- package/dist/kv_cache.d.ts +2 -0
- package/dist/kv_cache.d.ts.map +1 -1
- package/dist/kv_cache.js +6 -0
- package/dist/kv_cache.js.map +1 -1
- package/dist/masks.test.d.ts +2 -0
- package/dist/masks.test.d.ts.map +1 -0
- package/dist/masks.test.js +55 -0
- package/dist/masks.test.js.map +1 -0
- package/dist/models/index.d.ts +2 -1
- package/dist/models/index.d.ts.map +1 -1
- package/dist/models/index.js +2 -1
- package/dist/models/index.js.map +1 -1
- package/dist/utils.test.js +0 -15
- package/dist/utils.test.js.map +1 -1
- package/package.json +1 -1
- package/dist/jest.config.d.ts +0 -8
- package/dist/jest.config.d.ts.map +0 -1
- package/dist/jest.config.js +0 -147
- package/dist/jest.config.js.map +0 -1
- package/dist/src/index.d.ts +0 -6
- package/dist/src/index.d.ts.map +0 -1
- package/dist/src/index.js +0 -6
- package/dist/src/index.js.map +0 -1
- package/dist/src/kv_cache.d.ts +0 -53
- package/dist/src/kv_cache.d.ts.map +0 -1
- package/dist/src/kv_cache.js +0 -135
- package/dist/src/kv_cache.js.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.d.ts +0 -31
- package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.js +0 -76
- package/dist/src/layers/cached_rope_multihead_attention.js.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +0 -2
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.test.js +0 -43
- package/dist/src/layers/cached_rope_multihead_attention.test.js.map +0 -1
- package/dist/src/layers/gpt_decoder_block.d.ts +0 -34
- package/dist/src/layers/gpt_decoder_block.d.ts.map +0 -1
- package/dist/src/layers/gpt_decoder_block.js +0 -51
- package/dist/src/layers/gpt_decoder_block.js.map +0 -1
- package/dist/src/layers/index.d.ts +0 -17
- package/dist/src/layers/index.d.ts.map +0 -1
- package/dist/src/layers/index.js +0 -33
- package/dist/src/layers/index.js.map +0 -1
- package/dist/src/layers/multihead_attention.d.ts +0 -106
- package/dist/src/layers/multihead_attention.d.ts.map +0 -1
- package/dist/src/layers/multihead_attention.js +0 -269
- package/dist/src/layers/multihead_attention.js.map +0 -1
- package/dist/src/layers/multihead_attention.test.d.ts +0 -2
- package/dist/src/layers/multihead_attention.test.d.ts.map +0 -1
- package/dist/src/layers/multihead_attention.test.js +0 -160
- package/dist/src/layers/multihead_attention.test.js.map +0 -1
- package/dist/src/layers/positional_encoding.d.ts +0 -37
- package/dist/src/layers/positional_encoding.d.ts.map +0 -1
- package/dist/src/layers/positional_encoding.js +0 -115
- package/dist/src/layers/positional_encoding.js.map +0 -1
- package/dist/src/layers/positional_encoding.test.d.ts +0 -2
- package/dist/src/layers/positional_encoding.test.d.ts.map +0 -1
- package/dist/src/layers/positional_encoding.test.js +0 -95
- package/dist/src/layers/positional_encoding.test.js.map +0 -1
- package/dist/src/layers/rotary_position_embedding.d.ts +0 -39
- package/dist/src/layers/rotary_position_embedding.d.ts.map +0 -1
- package/dist/src/layers/rotary_position_embedding.js +0 -99
- package/dist/src/layers/rotary_position_embedding.js.map +0 -1
- package/dist/src/layers/rotary_position_embedding.test.d.ts +0 -2
- package/dist/src/layers/rotary_position_embedding.test.d.ts.map +0 -1
- package/dist/src/layers/rotary_position_embedding.test.js +0 -88
- package/dist/src/layers/rotary_position_embedding.test.js.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.d.ts +0 -47
- package/dist/src/layers/token_and_positional_embedding.d.ts.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.js +0 -109
- package/dist/src/layers/token_and_positional_embedding.js.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.test.d.ts +0 -2
- package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.test.js +0 -58
- package/dist/src/layers/token_and_positional_embedding.test.js.map +0 -1
- package/dist/src/layers/transformer_decoder.d.ts +0 -69
- package/dist/src/layers/transformer_decoder.d.ts.map +0 -1
- package/dist/src/layers/transformer_decoder.js +0 -182
- package/dist/src/layers/transformer_decoder.js.map +0 -1
- package/dist/src/layers/transformer_decoder.test.d.ts +0 -2
- package/dist/src/layers/transformer_decoder.test.d.ts.map +0 -1
- package/dist/src/layers/transformer_decoder.test.js +0 -72
- package/dist/src/layers/transformer_decoder.test.js.map +0 -1
- package/dist/src/layers/transformer_encoder.d.ts +0 -55
- package/dist/src/layers/transformer_encoder.d.ts.map +0 -1
- package/dist/src/layers/transformer_encoder.js +0 -175
- package/dist/src/layers/transformer_encoder.js.map +0 -1
- package/dist/src/layers/transformer_encoder.test.d.ts +0 -2
- package/dist/src/layers/transformer_encoder.test.d.ts.map +0 -1
- package/dist/src/layers/transformer_encoder.test.js +0 -58
- package/dist/src/layers/transformer_encoder.test.js.map +0 -1
- package/dist/src/losses/dice.d.ts +0 -30
- package/dist/src/losses/dice.d.ts.map +0 -1
- package/dist/src/losses/dice.js +0 -93
- package/dist/src/losses/dice.js.map +0 -1
- package/dist/src/losses/index.d.ts +0 -2
- package/dist/src/losses/index.d.ts.map +0 -1
- package/dist/src/losses/index.js +0 -2
- package/dist/src/losses/index.js.map +0 -1
- package/dist/src/masks.d.ts +0 -20
- package/dist/src/masks.d.ts.map +0 -1
- package/dist/src/masks.js +0 -37
- package/dist/src/masks.js.map +0 -1
- package/dist/src/metrics.d.ts +0 -20
- package/dist/src/metrics.d.ts.map +0 -1
- package/dist/src/metrics.js +0 -28
- package/dist/src/metrics.js.map +0 -1
- package/dist/src/models/gpt_model.d.ts +0 -94
- package/dist/src/models/gpt_model.d.ts.map +0 -1
- package/dist/src/models/gpt_model.js +0 -154
- package/dist/src/models/gpt_model.js.map +0 -1
- package/dist/src/models/index.d.ts +0 -3
- package/dist/src/models/index.d.ts.map +0 -1
- package/dist/src/models/index.js +0 -3
- package/dist/src/models/index.js.map +0 -1
- package/dist/src/models/llm_model.d.ts +0 -87
- package/dist/src/models/llm_model.d.ts.map +0 -1
- package/dist/src/models/llm_model.js +0 -245
- package/dist/src/models/llm_model.js.map +0 -1
- package/dist/src/models/u_net.d.ts +0 -40
- package/dist/src/models/u_net.d.ts.map +0 -1
- package/dist/src/models/u_net.js +0 -151
- package/dist/src/models/u_net.js.map +0 -1
- package/dist/src/tfjs_types.d.ts +0 -10
- package/dist/src/tfjs_types.d.ts.map +0 -1
- package/dist/src/tfjs_types.js +0 -2
- package/dist/src/tfjs_types.js.map +0 -1
- package/dist/src/utils.d.ts +0 -28
- package/dist/src/utils.d.ts.map +0 -1
- package/dist/src/utils.js +0 -63
- package/dist/src/utils.js.map +0 -1
- package/dist/src/utils.test.d.ts +0 -2
- package/dist/src/utils.test.d.ts.map +0 -1
- package/dist/src/utils.test.js +0 -73
- package/dist/src/utils.test.js.map +0 -1
|
@@ -1,58 +0,0 @@
|
|
|
1
|
-
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import { TokenAndPositionalEmbedding } from '@/layers/token_and_positional_embedding';
|
|
3
|
-
// Disable the warning that suggests switching to the faster node backend,
// see https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
tf.env().set('IS_NODE', false);
/**
 * Unit tests for the TokenAndPositionalEmbedding layer: constructor argument
 * validation, forward propagation, build-time shape validation, config
 * serialization and output shape computation.
 */
describe("TokenAndPositionalEmbedding tests", () => {
    test("layer initialization", () => {
        // non-positive sizes must be rejected by the constructor
        expect(() => new TokenAndPositionalEmbedding({ maxSequenceLength: 0, embedDim: 10, vocabularySize: 10_000 })).toThrow();
        expect(() => new TokenAndPositionalEmbedding({ embedDim: 0, vocabularySize: 10_000 })).toThrow();
        expect(() => new TokenAndPositionalEmbedding({ embedDim: 10, vocabularySize: 0 })).toThrow();
        // valid configurations: required arguments only, then with the optional ones set
        // (the original repeated the first assertion verbatim)
        expect(() => new TokenAndPositionalEmbedding({ embedDim: 10, vocabularySize: 10_000 })).not.toThrow();
        expect(() => new TokenAndPositionalEmbedding({ embedDim: 10, vocabularySize: 10_000, maxSequenceLength: 500, dropout: 0.1 })).not.toThrow();
    });
    test("successful forward calls", () => {
        const embed_dims = 32;
        const sequences = 4;
        const vocab_size = 10_000;
        const input = tf.randomUniform([2, sequences]);
        const embedding = new TokenAndPositionalEmbedding({ embedDim: embed_dims, dropout: 0.1, vocabularySize: vocab_size });
        // both a bare tensor and a single-element array are accepted
        expect(() => embedding.apply(input)).not.toThrow();
        expect(() => embedding.apply([input])).not.toThrow();
    });
    test("layer build", () => {
        const input_ok = tf.randomUniform([2, 4]);
        const input_too_many_words = tf.randomUniform([2, 700]);
        const input_is_image = tf.randomUniform([1, 32, 32, 3]);
        const makeEmbedding = () => new TokenAndPositionalEmbedding({ embedDim: 32, maxSequenceLength: 500, vocabularySize: 1_000 });
        // a single shape and an array of shapes are both accepted
        expect(() => makeEmbedding().build(input_ok.shape)).not.toThrow();
        expect(() => makeEmbedding().build([input_ok.shape, input_ok.shape])).not.toThrow();
        // each invalid shape is checked against a fresh, unbuilt layer so the failure
        // comes from shape validation and not from state left by an earlier build
        expect(() => makeEmbedding().build(input_too_many_words.shape)).toThrow();
        expect(() => makeEmbedding().build(input_is_image.shape)).toThrow();
    });
    it("should throw when more than one input provided, input sequences are too large, or incorrect input rank", () => {
        const sequences_too_long = tf.randomUniform([10, 1000]);
        const multiple_correct_inputs = [tf.randomUniform([2, 3]), tf.randomUniform([2, 3])];
        const wrong_rank = tf.randomUniform([10, 32, 32]);
        const positional = new TokenAndPositionalEmbedding({ maxSequenceLength: 10, embedDim: 32, vocabularySize: 10_000 });
        positional.build([2, 3]); // get past the initial build call to test forward prop
        expect(() => positional.apply(sequences_too_long)).toThrow();
        expect(() => positional.apply(multiple_correct_inputs)).toThrow();
        expect(() => positional.apply(wrong_rank)).toThrow();
    });
    it("should return a non-empty config dict", () => {
        const embedding = new TokenAndPositionalEmbedding({ embedDim: 32, vocabularySize: 10_000 });
        // compare the number of keys; comparing the key array itself with 0 can never fail
        expect(Object.keys(embedding.getConfig()).length).toBeGreaterThan(0);
    });
    it("should return an output shape of [batch, sequences, embed dims]", () => {
        const words = 100;
        const batch = 2;
        const embed_dims = 64;
        const input = tf.randomUniform([batch, words]);
        const embedding = new TokenAndPositionalEmbedding({ embedDim: embed_dims, vocabularySize: 10_000 });
        expect(embedding.computeOutputShape(input.shape)).toEqual([batch, words, embed_dims]);
    });
});
|
|
58
|
-
//# sourceMappingURL=token_and_positional_embedding.test.js.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"token_and_positional_embedding.test.js","sourceRoot":"","sources":["../../../src/layers/token_and_positional_embedding.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,2BAA2B,EAAE,MAAM,yCAAyC,CAAC;AAEtF,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,IAAI,CAAC,sBAAsB,EAAE,GAAG,EAAE;QAC9B,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,iBAAiB,EAAE,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACxH,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACjG,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAE7F,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACtG,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC1G,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,2BAA2B,EAAE,GAAG,EAAE;QACnC,MAAM,UAAU,GAAG,EAAE,CAAC;QACtB,MAAM,SAAS,GAAG,CAAC,CAAC;QACpB,MAAM,UAAU,GAAG,MAAM,CAAC;QAC1B,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;QAE/C,MAAM,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,OAAO,EAAE,GAAG,EAAE,cAAc,EAAE,UAAU,EAAE,CAAC,CAAC;QACtH,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,aAAa,EAAE,GAAG,EAAE;QACrB,MAAM,QAAQ,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC1C,MAAM,oBAAoB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC,CAAC;QACxD,MAAM,cAAc,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;QAExD,IAAI,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,iBAAiB,EA
AE,GAAG,EAAE,cAAc,EAAE,KAAK,EAAE,CAAC,CAAC;QACjH,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAE5D,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,iBAAiB,EAAE,GAAG,EAAE,cAAc,EAAE,KAAK,EAAE,CAAC,CAAC;QAC7G,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,QAAQ,CAAC,KAAK,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAE9E,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,iBAAiB,EAAE,GAAG,EAAE,cAAc,EAAE,KAAK,EAAE,CAAC,CAAC;QACjG,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,oBAAoB,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACpE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,cAAc,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAClE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,wGAAwG,EAAE,GAAG,EAAE;QAC9G,MAAM,kBAAkB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,IAAI,CAAC,CAAC,CAAC;QACxD,MAAM,uBAAuB,GAAG,CAAC,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QACrF,MAAM,UAAU,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;QAElD,MAAM,UAAU,GAAG,IAAI,2BAA2B,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC;QACpH,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,uDAAuD;QAEjF,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,kBAAkB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,uBAAuB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAClE,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC;QAC5F,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,SAAS,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,iEAAiE,EAAE,GAAG,EAAE;QACvE,MAAM,KAAK,GAAG,GAAG,CAAC;QAClB,MAAM,KAAK,GAAG,CAAC,CAAC;QAChB,MAAM,UAAU,GAAG,EAAE,CAAC;QAEtB,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC;QAE/C,MAAM,S
AAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC;QAEpG,MAAM,CAAC,SAAS,CAAC,kBAAkB,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,UAAU,CAAC,CAAC,CAAC;IAC1F,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAC"}
|
|
@@ -1,69 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
|
|
3
|
-
import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
|
|
4
|
-
import { type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
|
|
5
|
-
/**
 * Constructor options for `TransformerDecoder`.
 *
 * Every multi-head attention option is inherited except `causal`, which is
 * re-declared below as optional.
 */
export interface TransformerDecoderArgs extends Omit<MultiHeadAttentionArgs, "causal"> {
    /** whether the self-attention sub-layer masks future positions (defaults to `true`) */
    causal?: boolean;
    /** width of the intermediate feed-forward dense layer (defaults to `4 * embedDim`) */
    dimsFeedForward?: number;
    /** activation of the intermediate feed-forward dense layer (defaults to `"relu"`) */
    activation?: "relu" | "gelu";
}
|
|
10
|
-
/**
 * This class implements the transformer decoder architecture from
 * the 2017 paper "Attention Is All You Need", used as a decoder-only
 * layer (self-attention followed by a feed-forward sub-block, no cross-attention).
 *
 * The layer takes one tensor input of shape `[ batch, sequences, embedding dims ]`.
 * An array holding one or two tensors is also accepted; only the first is used.
 *
 * Each sub-block applies dropout, adds the residual connection and then
 * layer normalization (post-norm, as in the paper).
 *
 * Causal masking is enabled by default for the self-attention sub-layer.
 *
 * @param numHeads number of attention heads to use
 * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
 * @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
 * @param dropout dropout rate of the attention and feed-forward sub-blocks, must be within `[0, 1)`, default `0.1`
 * @param activation the activation of the intermediate feed forward layer, default `relu`
 * @param dimsFeedForward the size of the intermediate feed forward layer, default `4 * embedDim`
 *   (the paper's 2048 corresponds to embedDim = 512)
 * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
 */
export declare class TransformerDecoder extends tf.layers.Layer {
    static className: string;
    /** self-attention sub-block: attention, dropout, layer normalization */
    protected readonly causalSelfAttention: tf.layers.Layer;
    protected readonly causalSelfAttentionDropout: tf.layers.Layer;
    protected readonly causalSelfAttentionNorm: tf.layers.Layer;
    /** feed-forward sub-block: expanding dense, projecting dense, dropout, layer normalization */
    protected readonly feedforward1: tf.layers.Layer;
    protected readonly feedforward2: tf.layers.Layer;
    protected readonly feedForwardDropout: tf.layers.Layer;
    /** NOTE(review): misspelled ("Foward"); renaming would break subclasses that reference it */
    protected readonly feedFowardNorm: tf.layers.Layer;
    protected readonly numHeads: number;
    protected readonly embedDim: number;
    protected readonly useBias: boolean;
    protected readonly dropout: number;
    protected readonly activation: ActivationIdentifier;
    protected readonly dimsFeedForward: number;
    /**
     * Creates the sub-layers; weights are allocated later in `build`.
     *
     * @throws Error if `dropout` is `>= 1`
     */
    constructor({ numHeads, embedDim, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerDecoderArgs);
    /**
     * Forward propagation
     *
     * @param inputs input tensor, or an array of one or two tensors (only the first is used)
     * @param kwargs keyword arguments forwarded to every sub-layer's `apply`
     * @return the output tensor
     * @throws Error if an array with a length other than one or two is given
     */
    call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[];
    /** self-attention, dropout, residual add, then layer normalization */
    protected causalSelfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
    /** dense, dense, dropout, residual add, then layer normalization */
    protected feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
    /**
     * Initialize the sublayers' weights and track them to enable serialization
     *
     * @param inputShape `[batch, seq, embed_dim]`, or an array of one or two such shapes (the first is used)
     * @throws Error if the (first) shape is not rank 3
     */
    build(inputShape: tf.Shape | tf.Shape[]): void;
    /**
     * Save the layer's hyperparameters for serialization; the result is merged
     * with the base `Layer` config.
     */
    getConfig(): {
        numHeads: number;
        embedDim: number;
        useBias: boolean;
        dropout: number;
        activation: ActivationIdentifier;
        dimsFeedForward: number;
    };
}
|
|
69
|
-
//# sourceMappingURL=transformer_decoder.d.ts.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"transformer_decoder.d.ts","sourceRoot":"","sources":["../../../src/layers/transformer_decoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AACjE,OAAO,EAAE,KAAK,oBAAoB,EAAE,MAAM,6DAA6D,CAAC;AAExG,OAAO,EAAE,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAI3E,MAAM,WAAW,sBAAuB,SAAQ,IAAI,CAAC,sBAAsB,EAAE,QAAQ,CAAC;IAClF,UAAU,CAAC,EAAE,MAAM,GAAG,MAAM,CAAC;IAC7B,eAAe,CAAC,EAAE,MAAM,CAAC;IACzB,MAAM,CAAC,EAAE,OAAO,CAAC;CACpB;AAGD;;;;;;;;;;;;;;;;;GAiBG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IAExC,SAAS,CAAC,QAAQ,CAAC,mBAAmB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACxD,SAAS,CAAC,QAAQ,CAAC,0BAA0B,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAC/D,SAAS,CAAC,QAAQ,CAAC,uBAAuB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAE5D,SAAS,CAAC,QAAQ,CAAC,YAAY,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACjD,SAAS,CAAC,QAAQ,CAAC,YAAY,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACjD,SAAS,CAAC,QAAQ,CAAC,kBAAkB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACvD,SAAS,CAAC,QAAQ,CAAC,cAAc,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAEnD,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,OAAO,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,MAAM,CAAC;IACnC,SAAS,CAAC,QAAQ,CAAC,UAAU,EAAE,oBAAoB,CAAC;IACpD,SAAS,CAAC,QAAQ,CAAC,eAAe,EAAE,MAAM,CAAC;gBAE/B,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAAE,EAAE,sBAAsB;IAyClH;;;;;OAKG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAoBvF,SAAS,CAAC,wBAAwB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAc3E,SAAS,CAAC,gBAAgB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAenE;;OAEG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IA6DvD;;OAEG;IACM,SAAS;;;;;;;;CAiBrB"}
|
|
@@ -1,182 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
import { CachedRoPEMultiHeadAttention } from "@/layers/cached_rope_multihead_attention";
|
|
3
|
-
/**
 * This class implements the transformer decoder architecture from
 * the 2017 paper "Attention Is All You Need", used as a decoder-only
 * layer (self-attention followed by a feed-forward sub-block, no cross-attention).
 *
 * The layer takes one tensor input of shape `[ batch, sequences, embedding dims ]`.
 * An array holding one or two tensors is also accepted; only the first is used.
 *
 * Each sub-block applies dropout, adds the residual connection and then
 * layer normalization (post-norm, as in the paper).
 *
 * @param numHeads number of attention heads to use
 * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
 * @param causal use causal masking in the self-attention sub-layer (masks future inputs
 *   to prevent looking ahead), default `true`
 * @param dropout dropout rate of the attention and feed-forward sub-blocks,
 *   must be within `[0, 1)`, default `0.1`
 * @param activation the activation of the intermediate feed forward layer, default `relu`
 * @param dimsFeedForward the size of the intermediate feed forward layer, default `4 * embedDim`
 * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
 */
export class TransformerDecoder extends tf.layers.Layer {
    static className = "TransformerDecoder";
    // self-attention sub-block
    causalSelfAttention;
    causalSelfAttentionDropout;
    causalSelfAttentionNorm;
    // feed-forward sub-block
    feedforward1;
    feedforward2;
    feedForwardDropout;
    feedFowardNorm;
    // hyperparameters, serialized by getConfig()
    numHeads;
    embedDim;
    useBias;
    dropout;
    activation;
    dimsFeedForward;
    causal;
    /**
     * Creates the sub-layers; their weights are allocated in build().
     *
     * @throws Error if `dropout` is outside `[0, 1)`
     */
    constructor({ numHeads, embedDim, useBias, dropout, activation, dimsFeedForward, causal, ...args }) {
        // `causal` is destructured so that it reaches the attention sub-layer
        // instead of being forwarded (and ignored) in the base Layer args
        super(args);
        this.numHeads = numHeads;
        this.embedDim = embedDim;
        this.useBias = useBias ?? true;
        this.dropout = dropout ?? 0.1;
        this.activation = activation ?? "relu";
        this.causal = causal ?? true;
        if (this.dropout < 0 || this.dropout >= 1) {
            throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
        }
        // in the paper section 3.3, d_model=512 (embedDim) and first dense layer outputs d_ff=2048
        this.dimsFeedForward = dimsFeedForward ?? embedDim * 4;
        // self attention sub-block
        this.causalSelfAttention = new CachedRoPEMultiHeadAttention({
            numHeads: this.numHeads, embedDim: this.embedDim,
            useBias: this.useBias, dropout: this.dropout,
            causal: this.causal
        });
        this.causalSelfAttentionDropout = tf.layers.dropout({ rate: this.dropout });
        this.causalSelfAttentionNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
        // feed forward sub-block: expand to dimsFeedForward, then project back to embedDim
        this.feedforward1 = tf.layers.dense({
            units: this.dimsFeedForward,
            activation: this.activation,
            useBias: this.useBias,
        });
        this.feedforward2 = tf.layers.dense({
            units: this.embedDim,
            activation: "linear",
            useBias: this.useBias
        });
        this.feedForwardDropout = tf.layers.dropout({ rate: this.dropout });
        this.feedFowardNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
    }
    /**
     * Forward propagation
     *
     * @param inputs input tensor, or an array of one or two tensors (only the first is used)
     * @param kwargs keyword arguments forwarded to every sub-layer's apply()
     * @return the output tensor
     * @throws Error if an array with a length other than one or two is given
     */
    call(inputs, kwargs) {
        // validate the input tensors
        if (Array.isArray(inputs) && inputs.length != 1 && inputs.length != 2) {
            throw Error(`${this.getClassName()}::call ${this.name} expects one or two input tensors, got ${inputs.length} inputs.`);
        }
        if (Array.isArray(inputs)) {
            inputs = inputs[0];
        }
        // perform forward propagation
        return tf.tidy(() => {
            const attended = this.causalSelfAttentionBlock(inputs, kwargs);
            return this.feedForwardBlock(attended, kwargs);
        });
    }
    /** self-attention, dropout, residual add, then layer normalization */
    causalSelfAttentionBlock(x, kwargs) {
        return tf.tidy(() => {
            const residual = x;
            let attention = this.causalSelfAttention.apply(x, kwargs);
            attention = this.causalSelfAttentionDropout.apply(attention, kwargs);
            attention = tf.add(attention, residual);
            attention = this.causalSelfAttentionNorm.apply(attention, kwargs);
            return attention;
        });
    }
    /** dense (expand), dense (project), dropout, residual add, then layer normalization */
    feedForwardBlock(x, kwargs) {
        return tf.tidy(() => {
            const residual = x;
            let feedForward = this.feedforward1.apply(x, kwargs);
            feedForward = this.feedforward2.apply(feedForward, kwargs);
            feedForward = this.feedForwardDropout.apply(feedForward, kwargs);
            feedForward = tf.add(feedForward, residual);
            feedForward = this.feedFowardNorm.apply(feedForward, kwargs);
            return feedForward;
        });
    }
    /**
     * Initialize the sublayers' weights and track them to enable serialization
     *
     * @param inputShape `[batch, seq, embed_dim]`, or an array of one or two such shapes (the first is used)
     * @throws Error if the (first) shape is not rank 3
     */
    build(inputShape) {
        // normalize to an array of shapes
        let input_shapes = [];
        if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
            // input is an array of shapes
            input_shapes = inputShape;
        }
        else if (inputShape.length != 0) {
            // input is a single shape
            input_shapes = [inputShape];
        }
        if (input_shapes.length != 1 && input_shapes.length != 2) {
            throw Error(`${this.getClassName()}::build ${this.name} expects an input shape` +
                ` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`);
        }
        const [decoderInputShape] = input_shapes;
        if (decoderInputShape?.length != 3) {
            throw Error(`${this.getClassName()}::build ${this.name} expects an input shape` +
                ` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`);
        }
        // initialize causal self attention sub-block's weights
        this.causalSelfAttention.build(decoderInputShape);
        this.causalSelfAttentionNorm.build(this.causalSelfAttention.computeOutputShape(decoderInputShape));
        // initialize feedforward sub-block's weights
        const feedforward1OutputShape = this.feedforward1.computeOutputShape(decoderInputShape);
        const feedforward2OutputShape = this.feedforward2.computeOutputShape(feedforward1OutputShape);
        this.feedforward1.build(decoderInputShape);
        this.feedforward2.build(feedforward1OutputShape);
        this.feedFowardNorm.build(feedforward2OutputShape);
        // track sublayers' weights
        this.trainableWeights = [
            ...this.causalSelfAttention.trainableWeights,
            ...this.causalSelfAttentionDropout.trainableWeights,
            ...this.causalSelfAttentionNorm.trainableWeights,
            ...this.feedforward1.trainableWeights,
            ...this.feedforward2.trainableWeights,
            ...this.feedForwardDropout.trainableWeights,
            ...this.feedFowardNorm.trainableWeights
        ];
        // rename the weights otherwise they'll take on the default naming and overlap
        // each other which breaks model loading due to duplicate weight names.
        // The suffix is appended without a separator; keep this exact scheme so
        // previously saved models still resolve their weight names.
        let indexing = 0;
        for (const weight of this.trainableWeights) {
            const unique_name = `${this.getClassName()}_${indexing}`;
            weight.name += unique_name;
            weight.originalName += unique_name;
            indexing++;
        }
        super.build(inputShape);
    }
    /**
     * Save the layer's hyperparameters for serialization, merged with the base Layer config
     */
    getConfig() {
        const base_config = super.getConfig();
        const config = {
            numHeads: this.numHeads,
            embedDim: this.embedDim,
            useBias: this.useBias,
            dropout: this.dropout,
            activation: this.activation,
            dimsFeedForward: this.dimsFeedForward,
            causal: this.causal
        };
        Object.assign(config, base_config);
        return config;
    }
}
tf.serialization.registerClass(TransformerDecoder);
|
|
182
|
-
//# sourceMappingURL=transformer_decoder.js.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"transformer_decoder.js","sourceRoot":"","sources":["../../../src/layers/transformer_decoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAKvC,OAAO,EAAE,4BAA4B,EAAE,MAAM,0CAA0C,CAAC;AAUxF;;;;;;;;;;;;;;;;;GAiBG;AACH,MAAM,OAAO,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,GAAG,oBAAoB,CAAC;IAErB,mBAAmB,CAAkB;IACrC,0BAA0B,CAAkB;IAC5C,uBAAuB,CAAkB;IAEzC,YAAY,CAAkB;IAC9B,YAAY,CAAkB;IAC9B,kBAAkB,CAAkB;IACpC,cAAc,CAAkB;IAEhC,QAAQ,CAAS;IACjB,QAAQ,CAAS;IACjB,OAAO,CAAU;IACjB,OAAO,CAAS;IAChB,UAAU,CAAuB;IACjC,eAAe,CAAS;IAE3C,YAAY,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAA0B;QAC9G,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,IAAI,CAAC;QAC/B,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,GAAG,CAAC;QAC9B,IAAI,CAAC,UAAU,GAAG,UAAU,IAAI,MAAM,CAAC;QAEvC,IAAI,IAAI,CAAC,OAAO,IAAI,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,6CAA6C,CAAC,CAAC;QACrF,CAAC;QAED,2FAA2F;QAC3F,IAAI,CAAC,eAAe,GAAG,eAAe,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEvD,2BAA2B;QAC3B,IAAI,CAAC,mBAAmB,GAAG,IAAI,4BAA4B,CAAC;YACxD,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,QAAQ,EAAE,IAAI,CAAC,QAAQ;YAChD,OAAO,EAAE,IAAI,CAAC,OAAO,EAAE,OAAO,EAAE,IAAI,CAAC,OAAO;YAC5C,MAAM,EAAE,IAAI;SACf,CAAC,CAAC;QACH,IAAI,CAAC,0BAA0B,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAA;QAC3E,IAAI,CAAC,uBAAuB,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAAO,EAAE,IAAI,EAAE,CAAC,CAAC;QAE/E,yBAAyB;QACzB,IAAI,CAAC,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAChC,KAAK,EAAE,IAAI,CAAC,eAAe;YAC3B,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAChC,KAAK,EAAE,IAAI,CAAC,QAAQ;YACpB,UAAU,EAAE,QAAQ;YACpB,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,kBAAkB,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC;QACpE,IAAI,CAAC,cAAc,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAAO,
EAAE,IAAI,EAAE,CAAC,CAAC;IAC1E,CAAC;IAGD;;;;;OAKG;IACM,IAAI,CAAC,MAA+B,EAAE,MAAc;QACzD,6BAA6B;QAC7B,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YACpE,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,kCAAkC,MAAM,CAAC,MAAM,UAAU,CAAC,CAAC;QACpH,CAAC;QAED,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC;YACxB,MAAM,GAAG,MAAM,CAAC,CAAC,CAAc,CAAC;QACpC,CAAC;QAED,8BAA8B;QAC9B,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,IAAI,MAAM,GAAG,IAAI,CAAC,wBAAwB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAC3D,MAAM,GAAG,IAAI,CAAC,gBAAgB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAE/C,OAAO,MAAM,CAAC;QAClB,CAAC,CAAC,CAAC;IACP,CAAC;IAGS,wBAAwB,CAAC,CAAY,EAAE,MAAc;QAC3D,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,SAAS,GAAG,IAAI,CAAC,mBAAmB,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAc,CAAC;YACvE,SAAS,GAAG,IAAI,CAAC,0BAA0B,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAClF,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,QAAQ,CAAC,CAAC;YACxC,SAAS,GAAG,IAAI,CAAC,uBAAuB,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAE/E,OAAO,SAAS,CAAC;QACrB,CAAC,CAAC,CAAC;IACP,CAAC;IAGS,gBAAgB,CAAC,CAAY,EAAE,MAAc;QACnD,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YACrD,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAC,CAAC;YAC3D,WAAW,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAC9E,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,QAAQ,CAAC,CAAC;YAC5C,WAAW,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAE1E,OAAO,WAAW,CAAC;QACvB,CAAC,CAAC,CAAC;IACP,CAAC;IAGD;;OAEG;IACM,KAAK,CAAC,UAAiC;QAC5C,IAAI,YAAY,GAAe,EAAE,CAAC;QAElC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAC5D,8BAA8B;YAC9B,YAAY,GAAG,UAAwB,CAAC;QAC5C,CAAC;aAAM,IAAI,UAAU,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAChC,0BAA0B;YAC1B,YAAY,GAAG,CAAC,UAAsB,CAAC,CAAC;QAC5C,CAAC;QAED,IAAI,YAAY,CAAC,MAAM,IAAI,CAAC,IAAI,YAAY,CAAC,MAAM,IAAI,CAAC,EAAE,CAA
C;YACvD,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,yBAAyB;gBAC3E,oCAAoC,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAA;QACzE,CAAC;QAED,MAAM,CAAC,iBAAiB,CAAC,GAAG,YAAY,CAAC;QAEzC,IAAI,iBAAiB,EAAE,MAAM,IAAI,CAAC,EAAE,CAAC;YACjC,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,yBAAyB;gBAC3E,oCAAoC,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAA;QACzE,CAAC;QAED,uDAAuD;QACvD,IAAI,CAAC,mBAAmB,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC;QAClD,IAAI,CAAC,uBAAuB,CAAC,KAAK,CAAC,IAAI,CAAC,mBAAmB,CAAC,kBAAkB,CAAC,iBAAiB,CAAC,CAAC,CAAC;QAEnG,6CAA6C;QAC7C,MAAM,uBAAuB,GAAG,IAAI,CAAC,YAAY,CAAC,kBAAkB,CAAC,iBAAiB,CAAC,CAAC;QACxF,MAAM,uBAAuB,GAAG,IAAI,CAAC,YAAY,CAAC,kBAAkB,CAAC,uBAAuB,CAAC,CAAC;QAE9F,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC;QAC3C,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,uBAAuB,CAAC,CAAC;QACjD,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,uBAAuB,CAAC,CAAC;QAEnD,2BAA2B;QAC3B,IAAI,CAAC,gBAAgB,GAAG;YACpB,GAAG,IAAI,CAAC,mBAAmB,CAAC,gBAAgB;YAC5C,GAAG,IAAI,CAAC,0BAA0B,CAAC,gBAAgB;YACnD,GAAG,IAAI,CAAC,uBAAuB,CAAC,gBAAgB;YAChD,GAAG,IAAI,CAAC,YAAY,CAAC,gBAAgB;YACrC,GAAG,IAAI,CAAC,YAAY,CAAC,gBAAgB;YACrC,GAAG,IAAI,CAAC,kBAAkB,CAAC,gBAAgB;YAC3C,GAAG,IAAI,CAAC,cAAc,CAAC,gBAAgB;SAC1C,CAAC;QAEF,8EAA8E;QAC9E,sEAAsE;QACtE,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEjB,KAAK,MAAM,MAAM,IAAI,IAAI,CAAC,gBAAgB,EAAE,CAAC;YACzC,MAAM,WAAW,GAAG,GAAG,IAAI,CAAC,YAAY,EAAE,IAAI,QAAQ,EAAE,CAAC;YACxD,MAAc,CAAC,IAAI,IAAI,WAAW,CAAC;YACnC,MAAc,CAAC,YAAY,IAAI,WAAW,CAAC;YAC5C,QAAQ,EAAE,CAAC;QACf,CAAC;QAED,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACM,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,eAAe,EAAE,IAAI,CAAC,eAAe;SACxC,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAKL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC"}
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"transformer_decoder.test.d.ts","sourceRoot":"","sources":["../../../src/layers/transformer_decoder.test.ts"],"names":[],"mappings":""}
|
|
@@ -1,72 +0,0 @@
|
|
|
1
|
-
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import { TransformerDecoder } from '@/layers/transformer_decoder';
|
|
3
|
-
// disables warning for using the faster node backend,
|
|
4
|
-
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
|
|
5
|
-
tf.env().set('IS_NODE', false);
|
|
6
|
-
describe("TransformerDecoder tests", () => {
|
|
7
|
-
it("should return an output with the same shape as the input", () => {
|
|
8
|
-
const input = tf.randomUniform([2, 3, 12]);
|
|
9
|
-
const decoder = new TransformerDecoder({
|
|
10
|
-
numHeads: 2, embedDim: input.shape.at(-1),
|
|
11
|
-
dropout: 0.5, activation: "gelu", dimsFeedForward: 321, useBias: false
|
|
12
|
-
});
|
|
13
|
-
const output = decoder.apply(input);
|
|
14
|
-
expect(output.shape.length).toBe(input.shape.length);
|
|
15
|
-
});
|
|
16
|
-
test("forward calls", () => {
|
|
17
|
-
const input = tf.randomUniform([2, 3, 12]);
|
|
18
|
-
const mask = tf.randomUniform([input.shape[0], input.shape[1]], -1, 2, "bool");
|
|
19
|
-
const incorrect_mask = tf.randomUniform([2, 5, 12], -1, 2, "bool");
|
|
20
|
-
const decoder = new TransformerDecoder({ numHeads: 2, embedDim: input.shape.at(-1) });
|
|
21
|
-
expect(() => decoder.apply(input)).not.toThrow();
|
|
22
|
-
expect(() => decoder.apply([input])).not.toThrow();
|
|
23
|
-
// causal masking
|
|
24
|
-
const causal = new TransformerDecoder({ numHeads: 2, embedDim: input.shape.at(-1), causal: true });
|
|
25
|
-
expect(() => causal.apply(input)).not.toThrow();
|
|
26
|
-
expect(() => causal.apply([input])).not.toThrow();
|
|
27
|
-
});
|
|
28
|
-
it("should fail to instantiate a layer if heads count is not divisible by the input's embedding dimension", () => {
|
|
29
|
-
const input = tf.randomUniform([2, 3, 12]);
|
|
30
|
-
expect(() => new TransformerDecoder({ numHeads: 3, embedDim: input.shape.at(-1) })).not.toThrow();
|
|
31
|
-
expect(() => new TransformerDecoder({ numHeads: 5, embedDim: input.shape.at(-1) })).toThrow();
|
|
32
|
-
});
|
|
33
|
-
it("should not accept non-rank 3 tensor inputs", () => {
|
|
34
|
-
const embed_dim = 12;
|
|
35
|
-
const BAD_RANK4 = tf.randomUniform([2, 3, 12, embed_dim]);
|
|
36
|
-
const BAD_RANK2 = tf.randomUniform([2, embed_dim]);
|
|
37
|
-
const GOOD = tf.randomUniform([2, 3, embed_dim]);
|
|
38
|
-
const mask = tf.randomUniform([GOOD.shape[0], GOOD.shape[1]], -1, 2, "bool");
|
|
39
|
-
let decoder = new TransformerDecoder({ numHeads: 2, embedDim: embed_dim });
|
|
40
|
-
// BAD
|
|
41
|
-
expect(() => decoder.apply(BAD_RANK4)).toThrow();
|
|
42
|
-
expect(() => decoder.apply(BAD_RANK2)).toThrow();
|
|
43
|
-
// OK
|
|
44
|
-
decoder = new TransformerDecoder({ numHeads: 2, embedDim: embed_dim });
|
|
45
|
-
expect(() => decoder.apply(GOOD)).not.toThrow();
|
|
46
|
-
expect(() => decoder.apply([GOOD])).not.toThrow();
|
|
47
|
-
expect(() => decoder.apply([GOOD, mask])).not.toThrow();
|
|
48
|
-
});
|
|
49
|
-
it("should not accept inputs that are less or more than 1 and 2 tensors", () => {
|
|
50
|
-
const input = tf.randomUniform([2, 3, 12]);
|
|
51
|
-
let decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
|
|
52
|
-
// OK
|
|
53
|
-
expect(() => decoder.apply(input)).not.toThrow();
|
|
54
|
-
expect(() => decoder.apply([input])).not.toThrow();
|
|
55
|
-
// BAD
|
|
56
|
-
decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
|
|
57
|
-
expect(() => decoder.apply([])).toThrow(); // stops at build()
|
|
58
|
-
decoder.apply(input); // get past the initial build
|
|
59
|
-
expect(() => decoder.apply([input, input, input])).toThrow();
|
|
60
|
-
expect(() => decoder.apply([input, input, input, input])).toThrow();
|
|
61
|
-
// BAD (tests build())
|
|
62
|
-
decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
|
|
63
|
-
expect(() => decoder.apply([input, input, input])).toThrow();
|
|
64
|
-
expect(() => decoder.apply([input, input, input, input])).toThrow();
|
|
65
|
-
});
|
|
66
|
-
it("should return a non-empty config dict", () => {
|
|
67
|
-
const input = tf.randomUniform([2, 3, 12]);
|
|
68
|
-
const decoder = new TransformerDecoder({ numHeads: 1, embedDim: input.shape.at(-1) });
|
|
69
|
-
expect(Object.keys(decoder.getConfig())).not.toBe(0);
|
|
70
|
-
});
|
|
71
|
-
});
|
|
72
|
-
//# sourceMappingURL=transformer_decoder.test.js.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"transformer_decoder.test.js","sourceRoot":"","sources":["../../../src/layers/transformer_decoder.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,kBAAkB,EAAE,MAAM,8BAA8B,CAAC;AAElE,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,EAAE,CAAC,0DAA0D,EAAE,GAAG,EAAE;QAChE,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC;YACnC,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE;YAC1C,OAAO,EAAE,GAAG,EAAE,UAAU,EAAE,MAAM,EAAE,eAAe,EAAE,GAAG,EAAE,OAAO,EAAE,KAAK;SACzE,CAAC,CAAC;QAEH,MAAM,MAAM,GAAG,OAAO,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC;QAEjD,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,eAAe,EAAE,GAAG,EAAE;QACvB,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3C,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,CAAC,CAAC;QACjF,MAAM,cAAc,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,CAAC,CAAC;QAGnE,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,iBAAiB;QACjB,MAAM,MAAM,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QACpG,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACtD,CAAC,CAAC,CAA
A;IAGF,EAAE,CAAC,uGAAuG,EAAE,GAAG,EAAE;QAC7G,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACnG,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACnG,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,4CAA4C,EAAE,GAAG,EAAE;QAClD,MAAM,SAAS,GAAG,EAAE,CAAC;QAErB,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,SAAS,CAAC,CAAC,CAAC;QAC1D,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;QACnD,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;QACjD,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,CAAC,CAAC;QAE/E,IAAI,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,CAAC,CAAC;QAE3E,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEjD,KAAK;QACL,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,CAAC,CAAC;QACvE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAClD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC5D,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,qEAAqE,EAAE,GAAG,EAAE;QAC3E,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,IAAI,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACrF,KAAK;QACL,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,KA
AK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,MAAM;QACN,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACjF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC,CAAC,mBAAmB;QAC9D,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,6BAA6B;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEpE,sBAAsB;QACtB,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACjF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACxE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,OAAO,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAA"}
|
|
@@ -1,55 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
|
|
3
|
-
import { type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
|
|
4
|
-
export interface TransformerEncoderArgs extends MultiHeadAttentionArgs {
|
|
5
|
-
activation?: "relu" | "gelu";
|
|
6
|
-
dimsFeedForward?: number;
|
|
7
|
-
}
|
|
8
|
-
/**
|
|
9
|
-
* This class implements the transformer encoder architecture from the 2017 paper
|
|
10
|
-
* Attention Is All You Need.
|
|
11
|
-
*
|
|
12
|
-
* This layer accepts exactly one tensor input with the shape
|
|
13
|
-
* `[ batch, sequences, embedding dims ]`.
|
|
14
|
-
*
|
|
15
|
-
* @param numHeads number of attention heads to use
|
|
16
|
-
* @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
|
|
17
|
-
* @param causal use causal masking, default `false` for encoders
|
|
18
|
-
* @param dropout use dropout during the attention calculations, default `0.1`
|
|
19
|
-
* @param activation the activation of the intermediate feed forward layer, default `relu`
|
|
20
|
-
* @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
|
|
21
|
-
* @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
|
|
22
|
-
*/
|
|
23
|
-
export declare class TransformerEncoder extends tf.layers.Layer {
|
|
24
|
-
static className: string;
|
|
25
|
-
private readonly selfAttention;
|
|
26
|
-
private readonly selfAttentionDropout;
|
|
27
|
-
private readonly selfAttentionNorm;
|
|
28
|
-
private readonly reluLayer;
|
|
29
|
-
private readonly linearLayer;
|
|
30
|
-
private readonly feedForwardDropout;
|
|
31
|
-
private readonly feedFowardNorm;
|
|
32
|
-
private readonly numHeads;
|
|
33
|
-
private readonly embedDim;
|
|
34
|
-
private readonly causal;
|
|
35
|
-
private readonly useBias;
|
|
36
|
-
private readonly dropout;
|
|
37
|
-
private readonly activation;
|
|
38
|
-
private readonly dimsFeedForward;
|
|
39
|
-
constructor({ numHeads, embedDim, causal, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerEncoderArgs);
|
|
40
|
-
/**
|
|
41
|
-
* Forward propagation
|
|
42
|
-
*/
|
|
43
|
-
call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[];
|
|
44
|
-
private selfAttentionBlock;
|
|
45
|
-
private feedForwardBlock;
|
|
46
|
-
/**
|
|
47
|
-
* Initialize the sublayers' weights and track them to enable backpropagation.
|
|
48
|
-
*/
|
|
49
|
-
build(inputShape: tf.Shape | tf.Shape[]): void;
|
|
50
|
-
/**
|
|
51
|
-
* Save the layer's hyperparameters for serialization
|
|
52
|
-
*/
|
|
53
|
-
getConfig(): tf.serialization.ConfigDict;
|
|
54
|
-
}
|
|
55
|
-
//# sourceMappingURL=transformer_encoder.d.ts.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"transformer_encoder.d.ts","sourceRoot":"","sources":["../../../src/layers/transformer_encoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAGjE,OAAO,EAAsB,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAG/F,MAAM,WAAW,sBAAuB,SAAQ,sBAAsB;IAClE,UAAU,CAAC,EAAE,MAAM,GAAG,MAAM,CAAC;IAC7B,eAAe,CAAC,EAAE,MAAM,CAAC;CAC5B;AAGD;;;;;;;;;;;;;;GAcG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IAExC,OAAO,CAAC,QAAQ,CAAC,aAAa,CAAkB;IAChD,OAAO,CAAC,QAAQ,CAAC,oBAAoB,CAAkB;IACvD,OAAO,CAAC,QAAQ,CAAC,iBAAiB,CAAkB;IAEpD,OAAO,CAAC,QAAQ,CAAC,SAAS,CAAkB;IAC5C,OAAO,CAAC,QAAQ,CAAC,WAAW,CAAkB;IAC9C,OAAO,CAAC,QAAQ,CAAC,kBAAkB,CAAkB;IACrD,OAAO,CAAC,QAAQ,CAAC,cAAc,CAAkB;IAEjD,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,QAAQ,CAAC,MAAM,CAAU;IACjC,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAU;IAClC,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAS;IACjC,OAAO,CAAC,QAAQ,CAAC,UAAU,CAAuB;IAClD,OAAO,CAAC,QAAQ,CAAC,eAAe,CAAS;gBAG7B,EAAE,QAAQ,EAAE,QAAQ,EAAE,MAAM,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAAE,EAAE,sBAAsB;IAqC1H;;OAEG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAyBvF,OAAO,CAAC,kBAAkB;IAc1B,OAAO,CAAC,gBAAgB;IAexB;;OAEG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IAsDvD;;OAEG;IACM,SAAS,IAAI,EAAE,CAAC,aAAa,CAAC,UAAU;CAiBpD"}
|