@stellarapp/tfjs-stellar 1.0.0 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +47 -0
- package/dist/index.d.ts +7 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +7 -0
- package/dist/index.js.map +1 -0
- package/dist/jest.config.d.ts +8 -0
- package/dist/jest.config.d.ts.map +1 -0
- package/{jest.config.ts → dist/jest.config.js} +8 -64
- package/dist/jest.config.js.map +1 -0
- package/dist/kv_cache.d.ts +53 -0
- package/dist/kv_cache.d.ts.map +1 -0
- package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
- package/dist/kv_cache.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.js +76 -0
- package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
- package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
- package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
- package/dist/layers/gpt_decoder_block.js.map +1 -0
- package/dist/layers/index.d.ts +17 -0
- package/dist/layers/index.d.ts.map +1 -0
- package/dist/layers/index.js +33 -0
- package/dist/layers/index.js.map +1 -0
- package/dist/layers/multihead_attention.d.ts +106 -0
- package/dist/layers/multihead_attention.d.ts.map +1 -0
- package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
- package/dist/layers/multihead_attention.js.map +1 -0
- package/dist/layers/multihead_attention.test.d.ts +2 -0
- package/dist/layers/multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
- package/dist/layers/multihead_attention.test.js.map +1 -0
- package/dist/layers/positional_encoding.d.ts +37 -0
- package/dist/layers/positional_encoding.d.ts.map +1 -0
- package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
- package/dist/layers/positional_encoding.js.map +1 -0
- package/dist/layers/positional_encoding.test.d.ts +2 -0
- package/dist/layers/positional_encoding.test.d.ts.map +1 -0
- package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
- package/dist/layers/positional_encoding.test.js.map +1 -0
- package/dist/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
- package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
- package/dist/layers/rotary_position_embedding.js.map +1 -0
- package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/layers/rotary_position_embedding.test.js +88 -0
- package/dist/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
- package/dist/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
- package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/layers/transformer_decoder.d.ts +69 -0
- package/dist/layers/transformer_decoder.d.ts.map +1 -0
- package/dist/layers/transformer_decoder.js +182 -0
- package/dist/layers/transformer_decoder.js.map +1 -0
- package/dist/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
- package/dist/layers/transformer_decoder.test.js.map +1 -0
- package/dist/layers/transformer_encoder.d.ts +55 -0
- package/dist/layers/transformer_encoder.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
- package/dist/layers/transformer_encoder.js.map +1 -0
- package/dist/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
- package/dist/layers/transformer_encoder.test.js.map +1 -0
- package/dist/losses/dice.d.ts +30 -0
- package/dist/losses/dice.d.ts.map +1 -0
- package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
- package/dist/losses/dice.js.map +1 -0
- package/dist/losses/index.d.ts +2 -0
- package/dist/losses/index.d.ts.map +1 -0
- package/dist/losses/index.js +2 -0
- package/dist/losses/index.js.map +1 -0
- package/dist/masks.d.ts +20 -0
- package/dist/masks.d.ts.map +1 -0
- package/{src/packing_mask.ts → dist/masks.js} +16 -7
- package/dist/masks.js.map +1 -0
- package/dist/metrics.d.ts +20 -0
- package/dist/metrics.d.ts.map +1 -0
- package/{src/metrics.ts → dist/metrics.js} +8 -12
- package/dist/metrics.js.map +1 -0
- package/dist/models/gpt_model.d.ts +94 -0
- package/dist/models/gpt_model.d.ts.map +1 -0
- package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
- package/dist/models/gpt_model.js.map +1 -0
- package/dist/models/index.d.ts +7 -0
- package/dist/models/index.d.ts.map +1 -0
- package/dist/models/index.js +13 -0
- package/dist/models/index.js.map +1 -0
- package/dist/models/llm_model.d.ts +87 -0
- package/dist/models/llm_model.d.ts.map +1 -0
- package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
- package/dist/models/llm_model.js.map +1 -0
- package/dist/models/u_net.d.ts +40 -0
- package/dist/models/u_net.d.ts.map +1 -0
- package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
- package/dist/models/u_net.js.map +1 -0
- package/dist/src/index.d.ts +6 -0
- package/dist/src/index.d.ts.map +1 -0
- package/dist/src/index.js +6 -0
- package/dist/src/index.js.map +1 -0
- package/dist/src/kv_cache.d.ts +53 -0
- package/dist/src/kv_cache.d.ts.map +1 -0
- package/dist/src/kv_cache.js +135 -0
- package/dist/src/kv_cache.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
- package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
- package/dist/src/layers/gpt_decoder_block.js +51 -0
- package/dist/src/layers/gpt_decoder_block.js.map +1 -0
- package/dist/src/layers/index.d.ts +17 -0
- package/dist/src/layers/index.d.ts.map +1 -0
- package/dist/src/layers/index.js +33 -0
- package/dist/src/layers/index.js.map +1 -0
- package/dist/src/layers/multihead_attention.d.ts +106 -0
- package/dist/src/layers/multihead_attention.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.js +269 -0
- package/dist/src/layers/multihead_attention.js.map +1 -0
- package/dist/src/layers/multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.test.js +160 -0
- package/dist/src/layers/multihead_attention.test.js.map +1 -0
- package/dist/src/layers/positional_encoding.d.ts +37 -0
- package/dist/src/layers/positional_encoding.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.js +115 -0
- package/dist/src/layers/positional_encoding.js.map +1 -0
- package/dist/src/layers/positional_encoding.test.d.ts +2 -0
- package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.test.js +95 -0
- package/dist/src/layers/positional_encoding.test.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.js +99 -0
- package/dist/src/layers/rotary_position_embedding.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.js +88 -0
- package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.js +109 -0
- package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
- package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/src/layers/transformer_decoder.d.ts +69 -0
- package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
- package/dist/src/layers/transformer_decoder.js.map +1 -0
- package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_decoder.test.js +72 -0
- package/dist/src/layers/transformer_decoder.test.js.map +1 -0
- package/dist/src/layers/transformer_encoder.d.ts +55 -0
- package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.js +175 -0
- package/dist/src/layers/transformer_encoder.js.map +1 -0
- package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.test.js +58 -0
- package/dist/src/layers/transformer_encoder.test.js.map +1 -0
- package/dist/src/losses/dice.d.ts +30 -0
- package/dist/src/losses/dice.d.ts.map +1 -0
- package/dist/src/losses/dice.js +93 -0
- package/dist/src/losses/dice.js.map +1 -0
- package/dist/src/losses/index.d.ts +2 -0
- package/dist/src/losses/index.d.ts.map +1 -0
- package/dist/src/losses/index.js +2 -0
- package/dist/src/losses/index.js.map +1 -0
- package/dist/src/masks.d.ts +20 -0
- package/dist/src/masks.d.ts.map +1 -0
- package/dist/src/masks.js +37 -0
- package/dist/src/masks.js.map +1 -0
- package/dist/src/metrics.d.ts +20 -0
- package/dist/src/metrics.d.ts.map +1 -0
- package/dist/src/metrics.js +28 -0
- package/dist/src/metrics.js.map +1 -0
- package/dist/src/models/gpt_model.d.ts +94 -0
- package/dist/src/models/gpt_model.d.ts.map +1 -0
- package/dist/src/models/gpt_model.js +154 -0
- package/dist/src/models/gpt_model.js.map +1 -0
- package/dist/src/models/index.d.ts +3 -0
- package/dist/src/models/index.d.ts.map +1 -0
- package/{src/models/index.ts → dist/src/models/index.js} +1 -0
- package/dist/src/models/index.js.map +1 -0
- package/dist/src/models/llm_model.d.ts +87 -0
- package/dist/src/models/llm_model.d.ts.map +1 -0
- package/dist/src/models/llm_model.js +245 -0
- package/dist/src/models/llm_model.js.map +1 -0
- package/dist/src/models/u_net.d.ts +40 -0
- package/dist/src/models/u_net.d.ts.map +1 -0
- package/dist/src/models/u_net.js +151 -0
- package/dist/src/models/u_net.js.map +1 -0
- package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
- package/dist/src/tfjs_types.d.ts.map +1 -0
- package/dist/src/tfjs_types.js +2 -0
- package/dist/src/tfjs_types.js.map +1 -0
- package/dist/src/utils.d.ts +28 -0
- package/dist/src/utils.d.ts.map +1 -0
- package/{src/utils.ts → dist/src/utils.js} +10 -33
- package/dist/src/utils.js.map +1 -0
- package/dist/src/utils.test.d.ts +2 -0
- package/dist/src/utils.test.d.ts.map +1 -0
- package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
- package/dist/src/utils.test.js.map +1 -0
- package/dist/tfjs_types.d.ts +10 -0
- package/dist/tfjs_types.d.ts.map +1 -0
- package/dist/tfjs_types.js +2 -0
- package/dist/tfjs_types.js.map +1 -0
- package/dist/utils.d.ts +28 -0
- package/dist/utils.d.ts.map +1 -0
- package/dist/utils.js +63 -0
- package/dist/utils.js.map +1 -0
- package/dist/utils.test.d.ts +2 -0
- package/dist/utils.test.d.ts.map +1 -0
- package/dist/utils.test.js +73 -0
- package/dist/utils.test.js.map +1 -0
- package/package.json +10 -4
- package/src/index.ts +0 -93
- package/src/layers/rotary_position_embedding.test.ts +0 -107
- package/src/losses/index.ts +0 -1
- package/src/testing.ts +0 -1
- package/tsconfig.json +0 -49
|
@@ -1,24 +1,13 @@
|
|
|
1
1
|
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import {
|
|
3
|
-
import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
|
|
4
|
-
|
|
5
|
-
import { PositionalEncoding, type PositionalEncodingArgs } from '@/layers/positional_encoding';
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
export interface TokenAndPositionalEmbeddingArgs extends LayerArgs, PositionalEncodingArgs {
|
|
9
|
-
vocabularySize: number;
|
|
10
|
-
dropout?: number
|
|
11
|
-
}
|
|
12
|
-
|
|
13
|
-
|
|
2
|
+
import { PositionalEncoding } from '../layers/positional_encoding';
|
|
14
3
|
/**
|
|
15
4
|
* This class implements combines sinusoidal positional encoding from the
|
|
16
5
|
* 2017 paper "Attention Is All You Need" with a normal embedding layer to
|
|
17
6
|
* form a simplified single embedding layer.
|
|
18
|
-
*
|
|
7
|
+
*
|
|
19
8
|
* This layer accepts tokenized inputs of the shape `[ batch, tokens ]` and runs
|
|
20
9
|
* it through an embedding layer before adding sinusoidal positional encoding.
|
|
21
|
-
*
|
|
10
|
+
*
|
|
22
11
|
* @param embedDim the size of each token/word's embedding
|
|
23
12
|
* @param vocabularySize the number of tokens to embed
|
|
24
13
|
* @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`
|
|
@@ -26,124 +15,95 @@ export interface TokenAndPositionalEmbeddingArgs extends LayerArgs, PositionalEn
|
|
|
26
15
|
*/
|
|
27
16
|
export class TokenAndPositionalEmbedding extends tf.layers.Layer {
|
|
28
17
|
static className = "TokenAndPositionalEmbedding";
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
private dropoutLayer: tf.layers.Layer;
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
constructor({ embedDim, vocabularySize, maxSequenceLength, dropout, ...args }: TokenAndPositionalEmbeddingArgs) {
|
|
18
|
+
embedDim;
|
|
19
|
+
vocabularySize;
|
|
20
|
+
embedding;
|
|
21
|
+
positional;
|
|
22
|
+
maxSequenceLength;
|
|
23
|
+
dropout;
|
|
24
|
+
dropoutLayer;
|
|
25
|
+
constructor({ embedDim, vocabularySize, maxSequenceLength, dropout, ...args }) {
|
|
42
26
|
super(args);
|
|
43
|
-
|
|
44
27
|
this.embedDim = embedDim;
|
|
45
28
|
this.vocabularySize = vocabularySize;
|
|
46
29
|
this.maxSequenceLength = maxSequenceLength ?? 5120;
|
|
47
30
|
this.dropout = dropout ?? 0.1;
|
|
48
|
-
|
|
49
31
|
if (this.dropout >= 1) {
|
|
50
32
|
throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
|
|
51
33
|
}
|
|
52
|
-
|
|
53
34
|
this.embedding = tf.layers.embedding({
|
|
54
35
|
inputDim: this.vocabularySize,
|
|
55
36
|
outputDim: this.embedDim,
|
|
56
37
|
});
|
|
57
|
-
|
|
58
38
|
this.positional = new PositionalEncoding({
|
|
59
39
|
maxSequenceLength: this.maxSequenceLength,
|
|
60
40
|
embedDim: this.embedDim,
|
|
61
41
|
});
|
|
62
|
-
|
|
63
42
|
this.dropoutLayer = tf.layers.dropout({ rate: this.dropout });
|
|
64
43
|
}
|
|
65
|
-
|
|
66
|
-
|
|
67
44
|
/**
|
|
68
|
-
* Forward propagation.
|
|
45
|
+
* Forward propagation.
|
|
69
46
|
*/
|
|
70
|
-
|
|
47
|
+
call(inputs, kwargs) {
|
|
71
48
|
if (Array.isArray(inputs) && inputs.length != 1) {
|
|
72
49
|
throw Error(`${this.getClassName()}::call ${this.name} expects exactly` +
|
|
73
50
|
` 1 tensor input, received ${inputs.length}`);
|
|
74
51
|
}
|
|
75
|
-
|
|
76
52
|
return tf.tidy(() => {
|
|
77
|
-
let output = this.positional.apply(this.embedding.apply(inputs))
|
|
78
|
-
output = this.dropoutLayer.apply(output)
|
|
79
|
-
|
|
53
|
+
let output = this.positional.apply(this.embedding.apply(inputs));
|
|
54
|
+
output = this.dropoutLayer.apply(output);
|
|
80
55
|
return output;
|
|
81
|
-
})
|
|
56
|
+
});
|
|
82
57
|
}
|
|
83
|
-
|
|
84
|
-
|
|
85
58
|
/**
|
|
86
59
|
* Build the sublayers and enable serialization
|
|
87
60
|
*/
|
|
88
|
-
|
|
89
|
-
let input_shapes
|
|
90
|
-
|
|
61
|
+
build(inputShape) {
|
|
62
|
+
let input_shapes = [];
|
|
91
63
|
// only consider the first shape if multiple provided
|
|
92
64
|
if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
|
|
93
65
|
// input is an array of shapes
|
|
94
|
-
input_shapes = inputShape
|
|
95
|
-
}
|
|
66
|
+
input_shapes = inputShape;
|
|
67
|
+
}
|
|
68
|
+
else if (inputShape.length != 0) {
|
|
96
69
|
// input is a single shape
|
|
97
|
-
input_shapes = [inputShape
|
|
70
|
+
input_shapes = [inputShape];
|
|
98
71
|
}
|
|
99
|
-
|
|
100
|
-
if (input_shapes[0].length != 2 || input_shapes[0][1]! > this.maxSequenceLength) {
|
|
72
|
+
if (input_shapes[0].length != 2 || input_shapes[0][1] > this.maxSequenceLength) {
|
|
101
73
|
throw Error(`${this.getClassName()}::build ${this.name} expected an input of` +
|
|
102
74
|
` shape [batch, tokens] where tokens < ${this.maxSequenceLength},` +
|
|
103
75
|
` received ${JSON.stringify(input_shapes[0])}`);
|
|
104
76
|
}
|
|
105
|
-
|
|
106
77
|
// initialize the sublayers' weights
|
|
107
78
|
this.embedding.build(input_shapes[0]);
|
|
108
79
|
this.positional.build(this.embedding.computeOutputShape(input_shapes[0]));
|
|
109
|
-
|
|
110
80
|
// no need to rename weights, haven't found a case where their names collide
|
|
111
81
|
this.trainableWeights = [
|
|
112
82
|
...this.embedding.trainableWeights,
|
|
113
83
|
...this.positional.trainableWeights
|
|
114
84
|
];
|
|
115
|
-
|
|
116
85
|
super.build(input_shapes[0]);
|
|
117
86
|
}
|
|
118
|
-
|
|
119
|
-
|
|
120
87
|
/**
|
|
121
88
|
* The output shape, for an input shape of [batch, sequences], is
|
|
122
89
|
* [batch, sequences, embedDim]
|
|
123
90
|
*/
|
|
124
|
-
|
|
91
|
+
computeOutputShape(inputShape) {
|
|
125
92
|
const embedding_shape = this.embedding.computeOutputShape(inputShape);
|
|
126
93
|
const positional_shape = this.positional.computeOutputShape(embedding_shape);
|
|
127
|
-
|
|
128
94
|
return positional_shape;
|
|
129
95
|
}
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
override getConfig(): tf.serialization.ConfigDict {
|
|
96
|
+
getConfig() {
|
|
133
97
|
const base_config = super.getConfig();
|
|
134
|
-
|
|
135
98
|
const config = {
|
|
136
99
|
embedDim: this.embedDim,
|
|
137
100
|
vocabularySize: this.vocabularySize,
|
|
138
101
|
maxSequenceLength: this.maxSequenceLength,
|
|
139
102
|
dropout: this.dropout,
|
|
140
|
-
}
|
|
141
|
-
|
|
103
|
+
};
|
|
142
104
|
Object.assign(config, base_config);
|
|
143
|
-
|
|
144
105
|
return config;
|
|
145
106
|
}
|
|
146
107
|
}
|
|
147
|
-
|
|
148
|
-
|
|
149
108
|
tf.serialization.registerClass(TokenAndPositionalEmbedding);
|
|
109
|
+
//# sourceMappingURL=token_and_positional_embedding.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"token_and_positional_embedding.js","sourceRoot":"","sources":["../../src/layers/token_and_positional_embedding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAIvC,OAAO,EAAE,kBAAkB,EAA+B,MAAM,+BAA+B,CAAC;AAShG;;;;;;;;;;;;GAYG;AACH,MAAM,OAAO,2BAA4B,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IAC5D,MAAM,CAAC,SAAS,GAAG,6BAA6B,CAAC;IAEhC,QAAQ,CAAS;IACjB,cAAc,CAAS;IAChC,SAAS,CAAkB;IAE3B,UAAU,CAAiB;IAClB,iBAAiB,CAAS;IAC1B,OAAO,CAAS;IAEzB,YAAY,CAAkB;IAGtC,YAAY,EAAE,QAAQ,EAAE,cAAc,EAAE,iBAAiB,EAAE,OAAO,EAAE,GAAG,IAAI,EAAmC;QAC1G,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,cAAc,GAAG,cAAc,CAAC;QACrC,IAAI,CAAC,iBAAiB,GAAG,iBAAiB,IAAI,IAAI,CAAC;QACnD,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,GAAG,CAAC;QAE9B,IAAI,IAAI,CAAC,OAAO,IAAI,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,6CAA6C,CAAC,CAAC;QACrF,CAAC;QAED,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,MAAM,CAAC,SAAS,CAAC;YACjC,QAAQ,EAAE,IAAI,CAAC,cAAc;YAC7B,SAAS,EAAE,IAAI,CAAC,QAAQ;SAC3B,CAAC,CAAC;QAEH,IAAI,CAAC,UAAU,GAAG,IAAI,kBAAkB,CAAC;YACrC,iBAAiB,EAAE,IAAI,CAAC,iBAAiB;YACzC,QAAQ,EAAE,IAAI,CAAC,QAAQ;SAC1B,CAAC,CAAC;QAEH,IAAI,CAAC,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC;IAClE,CAAC;IAGD;;OAEG;IACM,IAAI,CAAC,MAA+B,EAAE,MAAc;QACzD,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAC9C,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,kBAAkB;gBACnE,6BAA6B,MAAM,CAAC,MAAM,EAAE,CAAC,CAAC;QACtD,CAAC;QAED,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,IAAI,MAAM,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,MAAM,CAAC,CAAc,CAAC;YAC9E,MAAM,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,MAAM,CAAc,CAAC;YAEtD,OAAO,MAAM,CAAC;QAClB,CAAC,CAAC,CAAA;IACN,CAAC;IAGD;;OAEG;IACM,KAAK,CAAC,UAAiC;QAC5C,IAAI,YAAY,GAAe,EAAE,CAAC;QAElC,qDAAqD;QACrD,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAC5D,8BAA8B;YAC9B,YAAY,GAAG,UAAwB,CAAC;QAC5C,CAAC;aAAM,IAAI,UAAU,CAAC,MAAM,IAAI,CAAC,
EAAE,CAAC;YAChC,0BAA0B;YAC1B,YAAY,GAAG,CAAC,UAAsB,CAAC,CAAC;QAC5C,CAAC;QAED,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC,MAAM,IAAI,CAAC,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC,CAAE,GAAG,IAAI,CAAC,iBAAiB,EAAE,CAAC;YAC9E,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,uBAAuB;gBACzE,yCAAyC,IAAI,CAAC,iBAAiB,GAAG;gBAClE,aAAa,IAAI,CAAC,SAAS,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC;QACxD,CAAC;QAED,oCAAoC;QACpC,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,IAAI,CAAC,SAAS,CAAC,kBAAkB,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAE1E,4EAA4E;QAC5E,IAAI,CAAC,gBAAgB,GAAG;YACpB,GAAG,IAAI,CAAC,SAAS,CAAC,gBAAgB;YAClC,GAAG,IAAI,CAAC,UAAU,CAAC,gBAAgB;SACtC,CAAC;QAEF,KAAK,CAAC,KAAK,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC;IACjC,CAAC;IAGD;;;OAGG;IACM,kBAAkB,CAAC,UAAiC;QACzD,MAAM,eAAe,GAAG,IAAI,CAAC,SAAS,CAAC,kBAAkB,CAAC,UAAU,CAAC,CAAC;QACtE,MAAM,gBAAgB,GAAG,IAAI,CAAC,UAAU,CAAC,kBAAkB,CAAC,eAAe,CAAC,CAAC;QAE7E,OAAO,gBAAgB,CAAC;IAC5B,CAAC;IAGQ,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,cAAc,EAAE,IAAI,CAAC,cAAc;YACnC,iBAAiB,EAAE,IAAI,CAAC,iBAAiB;YACzC,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,2BAA2B,CAAC,CAAC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"token_and_positional_embedding.test.d.ts","sourceRoot":"","sources":["../../src/layers/token_and_positional_embedding.test.ts"],"names":[],"mappings":""}
|
|
@@ -1,81 +1,58 @@
|
|
|
1
1
|
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
|
|
3
2
|
import { TokenAndPositionalEmbedding } from '@/layers/token_and_positional_embedding';
|
|
4
|
-
|
|
5
3
|
// disables warning for using the faster node backend,
|
|
6
4
|
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
|
|
7
5
|
tf.env().set('IS_NODE', false);
|
|
8
|
-
|
|
9
|
-
|
|
10
6
|
describe("PositionalEncoding tests", () => {
|
|
11
7
|
test("layer initialization", () => {
|
|
12
8
|
expect(() => new TokenAndPositionalEmbedding({ maxSequenceLength: 0, embedDim: 10, vocabularySize: 10_000 })).toThrow();
|
|
13
9
|
expect(() => new TokenAndPositionalEmbedding({ embedDim: 0, vocabularySize: 10_000 })).toThrow();
|
|
14
10
|
expect(() => new TokenAndPositionalEmbedding({ embedDim: 10, vocabularySize: 0 })).toThrow();
|
|
15
|
-
|
|
16
11
|
expect(() => new TokenAndPositionalEmbedding({ embedDim: 10, vocabularySize: 10_000 })).not.toThrow();
|
|
17
12
|
expect(() => new TokenAndPositionalEmbedding({ embedDim: 10, vocabularySize: 10_000 })).not.toThrow();
|
|
18
|
-
})
|
|
19
|
-
|
|
20
|
-
|
|
13
|
+
});
|
|
21
14
|
test("successfull forward calls", () => {
|
|
22
15
|
const embed_dims = 32;
|
|
23
16
|
const sequences = 4;
|
|
24
17
|
const vocab_size = 10_000;
|
|
25
18
|
const input = tf.randomUniform([2, sequences]);
|
|
26
|
-
|
|
27
19
|
const embedding = new TokenAndPositionalEmbedding({ embedDim: embed_dims, dropout: 0.1, vocabularySize: vocab_size });
|
|
28
20
|
expect(() => embedding.apply(input)).not.toThrow();
|
|
29
21
|
expect(() => embedding.apply([input])).not.toThrow();
|
|
30
|
-
})
|
|
31
|
-
|
|
32
|
-
|
|
22
|
+
});
|
|
33
23
|
test("layer build", () => {
|
|
34
24
|
const input_ok = tf.randomUniform([2, 4]);
|
|
35
25
|
const input_too_many_words = tf.randomUniform([2, 700]);
|
|
36
26
|
const input_is_image = tf.randomUniform([1, 32, 32, 3]);
|
|
37
|
-
|
|
38
27
|
let embedding = new TokenAndPositionalEmbedding({ embedDim: 32, maxSequenceLength: 500, vocabularySize: 1_000 });
|
|
39
28
|
expect(() => embedding.build(input_ok.shape)).not.toThrow();
|
|
40
|
-
|
|
41
29
|
embedding = new TokenAndPositionalEmbedding({ embedDim: 32, maxSequenceLength: 500, vocabularySize: 1_000 });
|
|
42
30
|
expect(() => embedding.build([input_ok.shape, input_ok.shape])).not.toThrow();
|
|
43
|
-
|
|
44
31
|
new TokenAndPositionalEmbedding({ embedDim: 32, maxSequenceLength: 500, vocabularySize: 1_000 });
|
|
45
32
|
expect(() => embedding.build(input_too_many_words.shape)).toThrow();
|
|
46
33
|
expect(() => embedding.build(input_is_image.shape)).toThrow();
|
|
47
|
-
})
|
|
48
|
-
|
|
49
|
-
|
|
34
|
+
});
|
|
50
35
|
it("should throw when more than one input provided, input sequences are too large, or incorrect input rank", () => {
|
|
51
36
|
const sequences_too_long = tf.randomUniform([10, 1000]);
|
|
52
37
|
const multiple_correct_inputs = [tf.randomUniform([2, 3]), tf.randomUniform([2, 3])];
|
|
53
38
|
const wrong_rank = tf.randomUniform([10, 32, 32]);
|
|
54
|
-
|
|
55
39
|
const positional = new TokenAndPositionalEmbedding({ maxSequenceLength: 10, embedDim: 32, vocabularySize: 10_000 });
|
|
56
40
|
positional.build([2, 3]); // get past the initial build call to test forward prop
|
|
57
|
-
|
|
58
41
|
expect(() => positional.apply(sequences_too_long)).toThrow();
|
|
59
42
|
expect(() => positional.apply(multiple_correct_inputs)).toThrow();
|
|
60
43
|
expect(() => positional.apply(wrong_rank)).toThrow();
|
|
61
|
-
})
|
|
62
|
-
|
|
63
|
-
|
|
44
|
+
});
|
|
64
45
|
it("should return a non-empty config dict", () => {
|
|
65
46
|
const embedding = new TokenAndPositionalEmbedding({ embedDim: 32, vocabularySize: 10_000 });
|
|
66
47
|
expect(Object.keys(embedding.getConfig())).not.toBe(0);
|
|
67
|
-
})
|
|
68
|
-
|
|
69
|
-
|
|
48
|
+
});
|
|
70
49
|
it("should return an output shape of [batch, sequences, embed dims]", () => {
|
|
71
50
|
const words = 100;
|
|
72
51
|
const batch = 2;
|
|
73
52
|
const embed_dims = 64;
|
|
74
|
-
|
|
75
53
|
const input = tf.randomUniform([batch, words]);
|
|
76
|
-
|
|
77
54
|
const embedding = new TokenAndPositionalEmbedding({ embedDim: embed_dims, vocabularySize: 10_000 });
|
|
78
|
-
|
|
79
55
|
expect(embedding.computeOutputShape(input.shape)).toEqual([batch, words, embed_dims]);
|
|
80
|
-
})
|
|
56
|
+
});
|
|
81
57
|
});
|
|
58
|
+
//# sourceMappingURL=token_and_positional_embedding.test.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"token_and_positional_embedding.test.js","sourceRoot":"","sources":["../../src/layers/token_and_positional_embedding.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,2BAA2B,EAAE,MAAM,yCAAyC,CAAC;AAEtF,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,IAAI,CAAC,sBAAsB,EAAE,GAAG,EAAE;QAC9B,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,iBAAiB,EAAE,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACxH,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACjG,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAE7F,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACtG,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC1G,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,2BAA2B,EAAE,GAAG,EAAE;QACnC,MAAM,UAAU,GAAG,EAAE,CAAC;QACtB,MAAM,SAAS,GAAG,CAAC,CAAC;QACpB,MAAM,UAAU,GAAG,MAAM,CAAC;QAC1B,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;QAE/C,MAAM,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,OAAO,EAAE,GAAG,EAAE,cAAc,EAAE,UAAU,EAAE,CAAC,CAAC;QACtH,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,aAAa,EAAE,GAAG,EAAE;QACrB,MAAM,QAAQ,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC1C,MAAM,oBAAoB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC,CAAC;QACxD,MAAM,cAAc,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;QAExD,IAAI,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,iBAAiB,EAAE,
GAAG,EAAE,cAAc,EAAE,KAAK,EAAE,CAAC,CAAC;QACjH,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAE5D,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,iBAAiB,EAAE,GAAG,EAAE,cAAc,EAAE,KAAK,EAAE,CAAC,CAAC;QAC7G,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,QAAQ,CAAC,KAAK,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAE9E,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,iBAAiB,EAAE,GAAG,EAAE,cAAc,EAAE,KAAK,EAAE,CAAC,CAAC;QACjG,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,oBAAoB,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACpE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,cAAc,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAClE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,wGAAwG,EAAE,GAAG,EAAE;QAC9G,MAAM,kBAAkB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,IAAI,CAAC,CAAC,CAAC;QACxD,MAAM,uBAAuB,GAAG,CAAC,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QACrF,MAAM,UAAU,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;QAElD,MAAM,UAAU,GAAG,IAAI,2BAA2B,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC;QACpH,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,uDAAuD;QAEjF,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,kBAAkB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,uBAAuB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAClE,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,SAAS,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC;QAC5F,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,SAAS,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,iEAAiE,EAAE,GAAG,EAAE;QACvE,MAAM,KAAK,GAAG,GAAG,CAAC;QAClB,MAAM,KAAK,GAAG,CAAC,CAAC;QAChB,MAAM,UAAU,GAAG,EAAE,CAAC;QAEtB,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC;QAE/C,MAAM,SAAS
,GAAG,IAAI,2BAA2B,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,cAAc,EAAE,MAAM,EAAE,CAAC,CAAC;QAEpG,MAAM,CAAC,SAAS,CAAC,kBAAkB,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,UAAU,CAAC,CAAC,CAAC;IAC1F,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAC"}
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
|
|
3
|
+
import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
|
|
4
|
+
import { type MultiHeadAttentionArgs } from "../layers/multihead_attention";
|
|
5
|
+
export interface TransformerDecoderArgs extends Omit<MultiHeadAttentionArgs, "causal"> {
    /** activation of the intermediate feed forward layer, default `relu` */
    activation?: "relu" | "gelu";
    /** size of the intermediate feed forward layer, default `4 * embedDim` */
    dimsFeedForward?: number;
    /** use causal masking in the self attention sub-layer, default `true` */
    causal?: boolean;
}
/**
 * This class implements the transformer decoder architecture from
 * the 2017 paper "Attention Is All You Need".
 *
 * This decoder-only transformer layer accepts one tensor input.
 * The input tensor should have the shape
 * `[ batch, sequences, embedding dims ]`.
 *
 * Causal masking is enabled by default for the initial attention sub-layer.
 *
 * Each of the two sub-blocks (self attention, then feed forward) applies dropout
 * to its output, adds the sub-block's input as a residual, then layer-normalizes.
 *
 * @param numHeads number of attention heads to use
 * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
 * @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
 * @param dropout dropout rate used by the attention sub-layer and on each sub-block's output,
 *   must be within [0, 1), default `0.1`
 * @param activation the activation of the intermediate feed forward layer, default `relu`
 * @param dimsFeedForward the size of the intermediate feed forward layer, default `4 * embedDim`
 * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
 */
export declare class TransformerDecoder extends tf.layers.Layer {
    static className: string;
    protected readonly causalSelfAttention: tf.layers.Layer;
    protected readonly causalSelfAttentionDropout: tf.layers.Layer;
    protected readonly causalSelfAttentionNorm: tf.layers.Layer;
    protected readonly feedforward1: tf.layers.Layer;
    protected readonly feedforward2: tf.layers.Layer;
    protected readonly feedForwardDropout: tf.layers.Layer;
    protected readonly feedFowardNorm: tf.layers.Layer;
    protected readonly numHeads: number;
    protected readonly embedDim: number;
    protected readonly useBias: boolean;
    protected readonly dropout: number;
    protected readonly activation: ActivationIdentifier;
    protected readonly dimsFeedForward: number;
    constructor({ numHeads, embedDim, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerDecoderArgs);
    /**
     * Forward propagation
     *
     * @param inputs input tensor; an array of one or two tensors is also accepted,
     *   in which case only the first tensor is used
     * @param kwargs forwarded unchanged to every sub-layer's `apply`
     * @return the output tensor
     */
    call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[];
    /** Self attention sub-block: attention -> dropout -> residual add -> layer norm */
    protected causalSelfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
    /** Feed forward sub-block: dense -> dense -> dropout -> residual add -> layer norm */
    protected feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor;
    /**
     * Initialize the sublayers' weights and track them to enable serialization
     */
    build(inputShape: tf.Shape | tf.Shape[]): void;
    /**
     * Save the layer's hyperparameters for serialization
     */
    getConfig(): {
        numHeads: number;
        embedDim: number;
        useBias: boolean;
        dropout: number;
        activation: ActivationIdentifier;
        dimsFeedForward: number;
    };
}
|
|
69
|
+
//# sourceMappingURL=transformer_decoder.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"transformer_decoder.d.ts","sourceRoot":"","sources":["../../src/layers/transformer_decoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AACjE,OAAO,EAAE,KAAK,oBAAoB,EAAE,MAAM,6DAA6D,CAAC;AAExG,OAAO,EAAE,KAAK,sBAAsB,EAAE,MAAM,+BAA+B,CAAC;AAI5E,MAAM,WAAW,sBAAuB,SAAQ,IAAI,CAAC,sBAAsB,EAAE,QAAQ,CAAC;IAClF,UAAU,CAAC,EAAE,MAAM,GAAG,MAAM,CAAC;IAC7B,eAAe,CAAC,EAAE,MAAM,CAAC;IACzB,MAAM,CAAC,EAAE,OAAO,CAAC;CACpB;AAGD;;;;;;;;;;;;;;;;;GAiBG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IAExC,SAAS,CAAC,QAAQ,CAAC,mBAAmB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACxD,SAAS,CAAC,QAAQ,CAAC,0BAA0B,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAC/D,SAAS,CAAC,QAAQ,CAAC,uBAAuB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAE5D,SAAS,CAAC,QAAQ,CAAC,YAAY,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACjD,SAAS,CAAC,QAAQ,CAAC,YAAY,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACjD,SAAS,CAAC,QAAQ,CAAC,kBAAkB,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IACvD,SAAS,CAAC,QAAQ,CAAC,cAAc,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;IAEnD,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,OAAO,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,OAAO,EAAE,MAAM,CAAC;IACnC,SAAS,CAAC,QAAQ,CAAC,UAAU,EAAE,oBAAoB,CAAC;IACpD,SAAS,CAAC,QAAQ,CAAC,eAAe,EAAE,MAAM,CAAC;gBAE/B,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAAE,EAAE,sBAAsB;IAyClH;;;;;OAKG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAoBvF,SAAS,CAAC,wBAAwB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAc3E,SAAS,CAAC,gBAAgB,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAenE;;OAEG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IA6DvD;;OAEG;IACM,SAAS;;;;;;;;CAiBrB"}
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { CachedRoPEMultiHeadAttention } from "../layers/cached_rope_multihead_attention";
|
|
3
|
+
/**
 * This class implements the transformer decoder architecture from
 * the 2017 paper "Attention Is All You Need".
 *
 * This decoder-only transformer layer accepts one tensor input.
 * The input tensor should have the shape
 * `[ batch, sequences, embedding dims ]`.
 *
 * Causal masking is enabled by default for the initial attention sub-layer.
 *
 * Each of the two sub-blocks (self attention, then feed forward) applies dropout
 * to its output, adds the sub-block's input as a residual, then layer-normalizes.
 *
 * @param numHeads number of attention heads to use
 * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
 * @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
 * @param dropout dropout rate used by the attention sub-layer and on each sub-block's output,
 *   must be within [0, 1), default `0.1`
 * @param activation the activation of the intermediate feed forward layer, default `relu`
 * @param dimsFeedForward the size of the intermediate feed forward layer, default `4 * embedDim`
 * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
 */
export class TransformerDecoder extends tf.layers.Layer {
    static className = "TransformerDecoder";

    // self attention sub-block layers
    causalSelfAttention;
    causalSelfAttentionDropout;
    causalSelfAttentionNorm;

    // feed forward sub-block layers
    feedforward1;
    feedforward2;
    feedForwardDropout;
    feedFowardNorm;

    // hyperparameters (serialized by getConfig)
    numHeads;
    embedDim;
    useBias;
    dropout;
    activation;
    dimsFeedForward;
    causal;

    constructor({ numHeads, embedDim, useBias, dropout, activation, dimsFeedForward, causal, ...args }) {
        super(args);
        this.numHeads = numHeads;
        this.embedDim = embedDim;
        this.useBias = useBias ?? true;
        this.dropout = dropout ?? 0.1;
        this.activation = activation ?? "relu";
        // `causal` is an explicit option of TransformerDecoderArgs; it used to fall into
        // `...args` and be ignored. Honour it, keeping the documented default of `true`.
        this.causal = causal ?? true;

        // negated range check so negative rates and NaN are rejected too,
        // matching the [0, 1) range promised by the error message
        if (!(this.dropout >= 0 && this.dropout < 1)) {
            throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
        }

        // in the paper section 3.3, d_model=512 (embedDim) and first dense layer outputs d_ff=2048
        this.dimsFeedForward = dimsFeedForward ?? embedDim * 4;

        // self attention sub-block
        this.causalSelfAttention = new CachedRoPEMultiHeadAttention({
            numHeads: this.numHeads, embedDim: this.embedDim,
            useBias: this.useBias, dropout: this.dropout,
            causal: this.causal
        });
        this.causalSelfAttentionDropout = tf.layers.dropout({ rate: this.dropout });
        this.causalSelfAttentionNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });

        // feed forward sub-block
        this.feedforward1 = tf.layers.dense({
            units: this.dimsFeedForward,
            activation: this.activation,
            useBias: this.useBias,
        });
        this.feedforward2 = tf.layers.dense({
            units: this.embedDim,
            activation: "linear",
            useBias: this.useBias
        });
        this.feedForwardDropout = tf.layers.dropout({ rate: this.dropout });
        this.feedFowardNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
    }

    /**
     * Forward propagation
     *
     * @param inputs input tensor; an array of one or two tensors is also accepted,
     *   in which case only the first tensor is used
     * @param kwargs forwarded unchanged to every sub-layer's `apply`
     * @return the output tensor
     */
    call(inputs, kwargs) {
        // validate the input tensors
        if (Array.isArray(inputs) && inputs.length !== 1 && inputs.length !== 2) {
            throw Error(`${this.getClassName()}::call ${this.name} expects one input tensor, got ${inputs.length} inputs.`);
        }
        const decoderInput = Array.isArray(inputs) ? inputs[0] : inputs;

        // perform forward propagation
        return tf.tidy(() => {
            const attended = this.causalSelfAttentionBlock(decoderInput, kwargs);
            return this.feedForwardBlock(attended, kwargs);
        });
    }

    /**
     * Self attention sub-block: attention -> dropout -> residual add -> layer norm.
     */
    causalSelfAttentionBlock(x, kwargs) {
        return tf.tidy(() => {
            const residual = x;
            let attention = this.causalSelfAttention.apply(x, kwargs);
            attention = this.causalSelfAttentionDropout.apply(attention, kwargs);
            attention = tf.add(attention, residual);
            return this.causalSelfAttentionNorm.apply(attention, kwargs);
        });
    }

    /**
     * Feed forward sub-block: dense(activation) -> dense(linear) -> dropout
     * -> residual add -> layer norm.
     */
    feedForwardBlock(x, kwargs) {
        return tf.tidy(() => {
            const residual = x;
            let feedForward = this.feedforward1.apply(x, kwargs);
            feedForward = this.feedforward2.apply(feedForward, kwargs);
            feedForward = this.feedForwardDropout.apply(feedForward, kwargs);
            feedForward = tf.add(feedForward, residual);
            return this.feedFowardNorm.apply(feedForward, kwargs);
        });
    }

    /**
     * Initialize the sublayers' weights and track them to enable serialization
     *
     * @param inputShape a single `[batch, seq, embed_dim]` shape, or an array of one or
     *   two shapes whose first entry is the decoder input shape
     * @throws Error when the shape(s) do not match the expected layout
     */
    build(inputShape) {
        let inputShapes = [];
        if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
            // input is an array of shapes
            inputShapes = inputShape;
        }
        else if (inputShape.length !== 0) {
            // input is a single shape
            inputShapes = [inputShape];
        }
        if (inputShapes.length !== 1 && inputShapes.length !== 2) {
            throw Error(`${this.getClassName()}::build ${this.name} expects an input shape` +
                ` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`);
        }

        const [decoderInputShape] = inputShapes;
        if (decoderInputShape?.length !== 3) {
            throw Error(`${this.getClassName()}::build ${this.name} expects an input shape` +
                ` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`);
        }

        // initialize causal self attention sub-block's weights
        this.causalSelfAttention.build(decoderInputShape);
        this.causalSelfAttentionNorm.build(this.causalSelfAttention.computeOutputShape(decoderInputShape));

        // initialize feedforward sub-block's weights
        const feedforward1OutputShape = this.feedforward1.computeOutputShape(decoderInputShape);
        const feedforward2OutputShape = this.feedforward2.computeOutputShape(feedforward1OutputShape);
        this.feedforward1.build(decoderInputShape);
        this.feedforward2.build(feedforward1OutputShape);
        this.feedFowardNorm.build(feedforward2OutputShape);

        // track sublayers' weights
        const subLayerWeights = [
            ...this.causalSelfAttention.trainableWeights,
            ...this.causalSelfAttentionDropout.trainableWeights,
            ...this.causalSelfAttentionNorm.trainableWeights,
            ...this.feedforward1.trainableWeights,
            ...this.feedforward2.trainableWeights,
            ...this.feedForwardDropout.trainableWeights,
            ...this.feedFowardNorm.trainableWeights
        ];
        this.trainableWeights = subLayerWeights;

        // rename the weights otherwise they'll take on the default naming and overlap
        // each other which breaks model loading due to duplicate weight names.
        // Iterate the local list rather than reading back the inherited `trainableWeights`
        // getter, which may filter what it returns (NOTE(review): in tf.js it returns []
        // for a layer constructed with `trainable: false`; verify for the installed version).
        subLayerWeights.forEach((weight, index) => {
            const uniqueName = `${this.getClassName()}_${index}`;
            weight.name += uniqueName;
            weight.originalName += uniqueName;
        });

        super.build(inputShape);
    }

    /**
     * Save the layer's hyperparameters for serialization
     *
     * @return this layer's hyperparameters merged with the base layer config
     */
    getConfig() {
        const baseConfig = super.getConfig();
        const config = {
            numHeads: this.numHeads,
            embedDim: this.embedDim,
            useBias: this.useBias,
            dropout: this.dropout,
            activation: this.activation,
            dimsFeedForward: this.dimsFeedForward,
            // serialized so a reloaded layer keeps the same masking behaviour
            causal: this.causal
        };
        Object.assign(config, baseConfig);
        return config;
    }
}
tf.serialization.registerClass(TransformerDecoder);
|
|
182
|
+
//# sourceMappingURL=transformer_decoder.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"transformer_decoder.js","sourceRoot":"","sources":["../../src/layers/transformer_decoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAKvC,OAAO,EAAE,4BAA4B,EAAE,MAAM,2CAA2C,CAAC;AAUzF;;;;;;;;;;;;;;;;;GAiBG;AACH,MAAM,OAAO,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,GAAG,oBAAoB,CAAC;IAErB,mBAAmB,CAAkB;IACrC,0BAA0B,CAAkB;IAC5C,uBAAuB,CAAkB;IAEzC,YAAY,CAAkB;IAC9B,YAAY,CAAkB;IAC9B,kBAAkB,CAAkB;IACpC,cAAc,CAAkB;IAEhC,QAAQ,CAAS;IACjB,QAAQ,CAAS;IACjB,OAAO,CAAU;IACjB,OAAO,CAAS;IAChB,UAAU,CAAuB;IACjC,eAAe,CAAS;IAE3C,YAAY,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAA0B;QAC9G,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,IAAI,CAAC;QAC/B,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,GAAG,CAAC;QAC9B,IAAI,CAAC,UAAU,GAAG,UAAU,IAAI,MAAM,CAAC;QAEvC,IAAI,IAAI,CAAC,OAAO,IAAI,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,6CAA6C,CAAC,CAAC;QACrF,CAAC;QAED,2FAA2F;QAC3F,IAAI,CAAC,eAAe,GAAG,eAAe,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEvD,2BAA2B;QAC3B,IAAI,CAAC,mBAAmB,GAAG,IAAI,4BAA4B,CAAC;YACxD,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,QAAQ,EAAE,IAAI,CAAC,QAAQ;YAChD,OAAO,EAAE,IAAI,CAAC,OAAO,EAAE,OAAO,EAAE,IAAI,CAAC,OAAO;YAC5C,MAAM,EAAE,IAAI;SACf,CAAC,CAAC;QACH,IAAI,CAAC,0BAA0B,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAA;QAC3E,IAAI,CAAC,uBAAuB,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAAO,EAAE,IAAI,EAAE,CAAC,CAAC;QAE/E,yBAAyB;QACzB,IAAI,CAAC,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAChC,KAAK,EAAE,IAAI,CAAC,eAAe;YAC3B,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAChC,KAAK,EAAE,IAAI,CAAC,QAAQ;YACpB,UAAU,EAAE,QAAQ;YACpB,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,kBAAkB,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC;QACpE,IAAI,CAAC,cAAc,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAAO,EAA
E,IAAI,EAAE,CAAC,CAAC;IAC1E,CAAC;IAGD;;;;;OAKG;IACM,IAAI,CAAC,MAA+B,EAAE,MAAc;QACzD,6BAA6B;QAC7B,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YACpE,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,kCAAkC,MAAM,CAAC,MAAM,UAAU,CAAC,CAAC;QACpH,CAAC;QAED,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC;YACxB,MAAM,GAAG,MAAM,CAAC,CAAC,CAAc,CAAC;QACpC,CAAC;QAED,8BAA8B;QAC9B,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,IAAI,MAAM,GAAG,IAAI,CAAC,wBAAwB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAC3D,MAAM,GAAG,IAAI,CAAC,gBAAgB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAE/C,OAAO,MAAM,CAAC;QAClB,CAAC,CAAC,CAAC;IACP,CAAC;IAGS,wBAAwB,CAAC,CAAY,EAAE,MAAc;QAC3D,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,SAAS,GAAG,IAAI,CAAC,mBAAmB,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAc,CAAC;YACvE,SAAS,GAAG,IAAI,CAAC,0BAA0B,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAClF,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,QAAQ,CAAC,CAAC;YACxC,SAAS,GAAG,IAAI,CAAC,uBAAuB,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAE/E,OAAO,SAAS,CAAC;QACrB,CAAC,CAAC,CAAC;IACP,CAAC;IAGS,gBAAgB,CAAC,CAAY,EAAE,MAAc;QACnD,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YACrD,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAC,CAAC;YAC3D,WAAW,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAC9E,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,QAAQ,CAAC,CAAC;YAC5C,WAAW,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAE1E,OAAO,WAAW,CAAC;QACvB,CAAC,CAAC,CAAC;IACP,CAAC;IAGD;;OAEG;IACM,KAAK,CAAC,UAAiC;QAC5C,IAAI,YAAY,GAAe,EAAE,CAAC;QAElC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAC5D,8BAA8B;YAC9B,YAAY,GAAG,UAAwB,CAAC;QAC5C,CAAC;aAAM,IAAI,UAAU,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAChC,0BAA0B;YAC1B,YAAY,GAAG,CAAC,UAAsB,CAAC,CAAC;QAC5C,CAAC;QAED,IAAI,YAAY,CAAC,MAAM,IAAI,CAAC,IAAI,YAAY,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;Y
ACvD,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,yBAAyB;gBAC3E,oCAAoC,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAA;QACzE,CAAC;QAED,MAAM,CAAC,iBAAiB,CAAC,GAAG,YAAY,CAAC;QAEzC,IAAI,iBAAiB,EAAE,MAAM,IAAI,CAAC,EAAE,CAAC;YACjC,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,yBAAyB;gBAC3E,oCAAoC,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAA;QACzE,CAAC;QAED,uDAAuD;QACvD,IAAI,CAAC,mBAAmB,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC;QAClD,IAAI,CAAC,uBAAuB,CAAC,KAAK,CAAC,IAAI,CAAC,mBAAmB,CAAC,kBAAkB,CAAC,iBAAiB,CAAC,CAAC,CAAC;QAEnG,6CAA6C;QAC7C,MAAM,uBAAuB,GAAG,IAAI,CAAC,YAAY,CAAC,kBAAkB,CAAC,iBAAiB,CAAC,CAAC;QACxF,MAAM,uBAAuB,GAAG,IAAI,CAAC,YAAY,CAAC,kBAAkB,CAAC,uBAAuB,CAAC,CAAC;QAE9F,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC;QAC3C,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,uBAAuB,CAAC,CAAC;QACjD,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,uBAAuB,CAAC,CAAC;QAEnD,2BAA2B;QAC3B,IAAI,CAAC,gBAAgB,GAAG;YACpB,GAAG,IAAI,CAAC,mBAAmB,CAAC,gBAAgB;YAC5C,GAAG,IAAI,CAAC,0BAA0B,CAAC,gBAAgB;YACnD,GAAG,IAAI,CAAC,uBAAuB,CAAC,gBAAgB;YAChD,GAAG,IAAI,CAAC,YAAY,CAAC,gBAAgB;YACrC,GAAG,IAAI,CAAC,YAAY,CAAC,gBAAgB;YACrC,GAAG,IAAI,CAAC,kBAAkB,CAAC,gBAAgB;YAC3C,GAAG,IAAI,CAAC,cAAc,CAAC,gBAAgB;SAC1C,CAAC;QAEF,8EAA8E;QAC9E,sEAAsE;QACtE,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEjB,KAAK,MAAM,MAAM,IAAI,IAAI,CAAC,gBAAgB,EAAE,CAAC;YACzC,MAAM,WAAW,GAAG,GAAG,IAAI,CAAC,YAAY,EAAE,IAAI,QAAQ,EAAE,CAAC;YACxD,MAAc,CAAC,IAAI,IAAI,WAAW,CAAC;YACnC,MAAc,CAAC,YAAY,IAAI,WAAW,CAAC;YAC5C,QAAQ,EAAE,CAAC;QACf,CAAC;QAED,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACM,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,eAAe,EAAE,IAAI,CAAC,eAAe;SACxC,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAKL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"transformer_decoder.test.d.ts","sourceRoot":"","sources":["../../src/layers/transformer_decoder.test.ts"],"names":[],"mappings":""}
|