@stellarapp/tfjs-stellar 1.0.4 → 1.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +17 -0
- package/dist/index.d.ts +2 -1
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +2 -1
- package/dist/index.js.map +1 -1
- package/dist/kv_cache.d.ts +2 -0
- package/dist/kv_cache.d.ts.map +1 -1
- package/dist/kv_cache.js +6 -0
- package/dist/kv_cache.js.map +1 -1
- package/dist/models/index.d.ts +2 -1
- package/dist/models/index.d.ts.map +1 -1
- package/dist/models/index.js +2 -1
- package/dist/models/index.js.map +1 -1
- package/package.json +1 -1
- package/dist/jest.config.d.ts +0 -8
- package/dist/jest.config.d.ts.map +0 -1
- package/dist/jest.config.js +0 -147
- package/dist/jest.config.js.map +0 -1
- package/dist/src/index.d.ts +0 -6
- package/dist/src/index.d.ts.map +0 -1
- package/dist/src/index.js +0 -6
- package/dist/src/index.js.map +0 -1
- package/dist/src/kv_cache.d.ts +0 -53
- package/dist/src/kv_cache.d.ts.map +0 -1
- package/dist/src/kv_cache.js +0 -135
- package/dist/src/kv_cache.js.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.d.ts +0 -31
- package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.js +0 -76
- package/dist/src/layers/cached_rope_multihead_attention.js.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +0 -2
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.test.js +0 -43
- package/dist/src/layers/cached_rope_multihead_attention.test.js.map +0 -1
- package/dist/src/layers/gpt_decoder_block.d.ts +0 -34
- package/dist/src/layers/gpt_decoder_block.d.ts.map +0 -1
- package/dist/src/layers/gpt_decoder_block.js +0 -51
- package/dist/src/layers/gpt_decoder_block.js.map +0 -1
- package/dist/src/layers/index.d.ts +0 -17
- package/dist/src/layers/index.d.ts.map +0 -1
- package/dist/src/layers/index.js +0 -33
- package/dist/src/layers/index.js.map +0 -1
- package/dist/src/layers/multihead_attention.d.ts +0 -106
- package/dist/src/layers/multihead_attention.d.ts.map +0 -1
- package/dist/src/layers/multihead_attention.js +0 -269
- package/dist/src/layers/multihead_attention.js.map +0 -1
- package/dist/src/layers/multihead_attention.test.d.ts +0 -2
- package/dist/src/layers/multihead_attention.test.d.ts.map +0 -1
- package/dist/src/layers/multihead_attention.test.js +0 -160
- package/dist/src/layers/multihead_attention.test.js.map +0 -1
- package/dist/src/layers/positional_encoding.d.ts +0 -37
- package/dist/src/layers/positional_encoding.d.ts.map +0 -1
- package/dist/src/layers/positional_encoding.js +0 -115
- package/dist/src/layers/positional_encoding.js.map +0 -1
- package/dist/src/layers/positional_encoding.test.d.ts +0 -2
- package/dist/src/layers/positional_encoding.test.d.ts.map +0 -1
- package/dist/src/layers/positional_encoding.test.js +0 -95
- package/dist/src/layers/positional_encoding.test.js.map +0 -1
- package/dist/src/layers/rotary_position_embedding.d.ts +0 -39
- package/dist/src/layers/rotary_position_embedding.d.ts.map +0 -1
- package/dist/src/layers/rotary_position_embedding.js +0 -99
- package/dist/src/layers/rotary_position_embedding.js.map +0 -1
- package/dist/src/layers/rotary_position_embedding.test.d.ts +0 -2
- package/dist/src/layers/rotary_position_embedding.test.d.ts.map +0 -1
- package/dist/src/layers/rotary_position_embedding.test.js +0 -88
- package/dist/src/layers/rotary_position_embedding.test.js.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.d.ts +0 -47
- package/dist/src/layers/token_and_positional_embedding.d.ts.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.js +0 -109
- package/dist/src/layers/token_and_positional_embedding.js.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.test.d.ts +0 -2
- package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.test.js +0 -58
- package/dist/src/layers/token_and_positional_embedding.test.js.map +0 -1
- package/dist/src/layers/transformer_decoder.d.ts +0 -69
- package/dist/src/layers/transformer_decoder.d.ts.map +0 -1
- package/dist/src/layers/transformer_decoder.js +0 -182
- package/dist/src/layers/transformer_decoder.js.map +0 -1
- package/dist/src/layers/transformer_decoder.test.d.ts +0 -2
- package/dist/src/layers/transformer_decoder.test.d.ts.map +0 -1
- package/dist/src/layers/transformer_decoder.test.js +0 -72
- package/dist/src/layers/transformer_decoder.test.js.map +0 -1
- package/dist/src/layers/transformer_encoder.d.ts +0 -55
- package/dist/src/layers/transformer_encoder.d.ts.map +0 -1
- package/dist/src/layers/transformer_encoder.js +0 -175
- package/dist/src/layers/transformer_encoder.js.map +0 -1
- package/dist/src/layers/transformer_encoder.test.d.ts +0 -2
- package/dist/src/layers/transformer_encoder.test.d.ts.map +0 -1
- package/dist/src/layers/transformer_encoder.test.js +0 -58
- package/dist/src/layers/transformer_encoder.test.js.map +0 -1
- package/dist/src/losses/dice.d.ts +0 -30
- package/dist/src/losses/dice.d.ts.map +0 -1
- package/dist/src/losses/dice.js +0 -93
- package/dist/src/losses/dice.js.map +0 -1
- package/dist/src/losses/index.d.ts +0 -2
- package/dist/src/losses/index.d.ts.map +0 -1
- package/dist/src/losses/index.js +0 -2
- package/dist/src/losses/index.js.map +0 -1
- package/dist/src/masks.d.ts +0 -20
- package/dist/src/masks.d.ts.map +0 -1
- package/dist/src/masks.js +0 -37
- package/dist/src/masks.js.map +0 -1
- package/dist/src/metrics.d.ts +0 -20
- package/dist/src/metrics.d.ts.map +0 -1
- package/dist/src/metrics.js +0 -28
- package/dist/src/metrics.js.map +0 -1
- package/dist/src/models/gpt_model.d.ts +0 -94
- package/dist/src/models/gpt_model.d.ts.map +0 -1
- package/dist/src/models/gpt_model.js +0 -154
- package/dist/src/models/gpt_model.js.map +0 -1
- package/dist/src/models/index.d.ts +0 -3
- package/dist/src/models/index.d.ts.map +0 -1
- package/dist/src/models/index.js +0 -3
- package/dist/src/models/index.js.map +0 -1
- package/dist/src/models/llm_model.d.ts +0 -87
- package/dist/src/models/llm_model.d.ts.map +0 -1
- package/dist/src/models/llm_model.js +0 -245
- package/dist/src/models/llm_model.js.map +0 -1
- package/dist/src/models/u_net.d.ts +0 -40
- package/dist/src/models/u_net.d.ts.map +0 -1
- package/dist/src/models/u_net.js +0 -151
- package/dist/src/models/u_net.js.map +0 -1
- package/dist/src/tfjs_types.d.ts +0 -10
- package/dist/src/tfjs_types.d.ts.map +0 -1
- package/dist/src/tfjs_types.js +0 -2
- package/dist/src/tfjs_types.js.map +0 -1
- package/dist/src/utils.d.ts +0 -28
- package/dist/src/utils.d.ts.map +0 -1
- package/dist/src/utils.js +0 -63
- package/dist/src/utils.js.map +0 -1
- package/dist/src/utils.test.d.ts +0 -2
- package/dist/src/utils.test.d.ts.map +0 -1
- package/dist/src/utils.test.js +0 -73
- package/dist/src/utils.test.js.map +0 -1
|
@@ -1,269 +0,0 @@
|
|
|
1
|
-
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import { causal as generateCausalMask } from "@/masks";
|
|
3
|
-
/**
 * This MultiHead Attention layer implements the algorithm as described in
 * the paper "Attention Is All You Need", Vaswani et al., 2017.
 *
 * @param numHeads number of attention heads to use, `embedDim` must be divisible by it
 * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
 * @param causal use causal masking, default `false`
 * @param dropout dropout rate used during the attention calculations, within [0, 1), default `0.0`
 * @param useBias use bias for the dense sublayers, default `true`
 *
 * The TensorFlow version uses tf.einsum, whose gradient op has not yet been
 * implemented (https://github.com/tensorflow/tfjs/pull/4955#discussion_r619219334),
 * therefore we follow the PyTorch implementation described in:
 * https://docs.pytorch.org/tutorials/intermediate/transformer_building_blocks.html#multiheadattention
 * https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
 *
 * This implementation is different from TensorFlow's, whose attention weights
 * are shaped [embed dim, heads, embed dim], whereas PyTorch's and OpenAI's attention weights
 * are shaped [embed dim, embed dim]
 * https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/nn/modules/activation.py#L1080
 * https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/model.py#L53
 *
 * TODO: implement a fast track for self attention (query = key = value)
 * where a single dense layer combines and replaces the query, key and value projection layers
 *
 * TODO: add kDim and vDim to accept key and values whose embedding dimensions differ from query's.
 */
export class MultiHeadAttention extends tf.layers.Layer {
    static className = "MultiHeadAttention";
    numHeads;
    embedDim; // size of embedding dim of inputs, split evenly across the attention heads
    useBias;
    dropout;
    causal; // use causal attention to mask future words
    // projection simply means matrix multiplying query, key, and value
    // with weights to create a representation of the inputs
    queryProjection;
    keyProjection;
    valueProjection;
    outputProjection;

    /**
     * Validates the hyperparameters and creates the four dense projection sublayers.
     * Any remaining properties of the argument object are forwarded to the base layer.
     *
     * @throws Error if `embedDim` is not divisible by `numHeads`
     * @throws Error if `dropout` is outside of [0, 1)
     */
    constructor({ numHeads, embedDim, useBias = true, dropout = 0.0, causal = false, ...args }) {
        super(args);
        if (embedDim % numHeads !== 0) {
            throw new Error(`${this.getClassName()}::constructor ${this.name} embedDim (${embedDim}) is not divisible by numHeads (${numHeads})`);
        }
        this.numHeads = numHeads;
        this.embedDim = embedDim;
        this.useBias = useBias;
        this.dropout = dropout;
        this.causal = causal;
        // written as a negated range test so that negative and NaN rates are rejected too
        if (!(this.dropout >= 0 && this.dropout < 1)) {
            throw new Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
        }
        // initialize the projection sublayers, this should be in the
        // build() function but is done here to avoid linting complaints
        this.queryProjection = tf.layers.dense({ useBias, units: embedDim });
        this.keyProjection = tf.layers.dense({ useBias, units: embedDim });
        this.valueProjection = tf.layers.dense({ useBias, units: embedDim });
        this.outputProjection = tf.layers.dense({ useBias, units: embedDim });
    }

    /**
     * Forward propagation. Provide one input tensor or three identical tensors for self-attention.
     * @param inputs a single tensor for self-attention or an array of exactly three
     * tensors that are either identical (self-attention) or different (cross-attention)
     * @param kwargs.packingMask a mask to prevent tokens from attending across document boundaries
     * @param kwargs.causalMask a precomputed causal mask, one is generated when it is not provided
     * @param kwargs.attentionMask forwarded to scaledDotProductionAttention
     * @throws Error if the number of inputs is not one or three, or if an input is not rank 3
     */
    call(inputs, kwargs = {}) {
        // tolerate a missing or null kwargs so that the layer can be called directly
        const options = kwargs ?? {};
        // validate the input tensors
        const tensors = Array.isArray(inputs) ? inputs : [inputs];
        // accept only 1 input (self attention) or 3 inputs (self or cross attention)
        if (tensors.length !== 1 && tensors.length !== 3) {
            throw new Error(`${this.getClassName()}::call ${this.name} expects exactly one or three input tensors, ${tensors.length} were provided`);
        }
        for (const input of tensors) {
            if (input.shape.length !== 3) {
                throw new Error(`${this.getClassName()}::call ${this.name} expected input shapes of [batch, seq, embed_dim], got ${JSON.stringify(input.shape)}`);
            }
        }
        const [query, key, value] = tensors;
        const packingMask = options.packingMask ?? null;
        const causalMask = options.causalMask ?? null;
        return tensors.length === 3
            // cross-attention
            ? this.forward(query, key, value, packingMask, causalMask, options)
            // self-attention
            : this.forward(query, query, query, packingMask, causalMask, options);
    }

    /**
     * Forward propagation: project the inputs, split them into heads, apply
     * scaled dot-product attention, merge the heads and apply the output projection.
     *
     * @returns the attention output, reshaped to [batch, -1, embedDim]
     */
    forward(query_input, key_input, value_input, packing_mask, causal_mask, kwargs = {}) {
        // dimensions abbreviations
        // batch = the number of sequences in the input
        // seq = the length of each sequence in the input
        // dims = the size of each token's embedding
        return tf.tidy(() => {
            const { query, key, value } = this.applyInputProjections(query_input, key_input, value_input);
            // swap the seq and heads dimensions: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim],
            // the permutation is its own inverse so it is reused to swap the dimensions back
            const move_head_dim_forward = [0, 2, 1, 3];
            const { query_split, key_split, value_split } = this.splitHeads(query, key, value, move_head_dim_forward);
            // apply scaled dot product attention to get [batch, heads, seq, head_dim]
            const spda = MultiHeadAttention.scaledDotProductionAttention(query_split, key_split, value_split, kwargs.attentionMask ?? null, packing_mask, causal_mask, this.dropout, this.causal, kwargs);
            // concat heads and apply the output projection
            const output = this.outputProjection.apply(spda.transpose(move_head_dim_forward).reshape([query_input.shape[0], -1, this.embedDim]));
            return output;
        });
    }

    /**
     * Applies the query, key and value dense projections to their respective inputs.
     *
     * @returns an object holding the projected `query`, `key` and `value`
     */
    applyInputProjections(query_input, key_input, value_input) {
        // apply input projections, this is a batched matrix multiplication operated on the last
        // dimension of query_input and first dimension of the dense layer weights,
        // [batch, seq, dims] x [dims, dims] = [batch x seq, dims] x [dims, dims] = [batch x seq, dims] = [batch, seq, dims]
        return tf.tidy(() => {
            return {
                query: this.queryProjection.apply(query_input),
                key: this.keyProjection.apply(key_input),
                value: this.valueProjection.apply(value_input)
            };
        });
    }

    /**
     * Splits the last dimension of query, key and value into heads and permutes the result.
     *
     * @param shuffle the permutation applied after the reshape
     * @returns an object holding `query_split`, `key_split` and `value_split`
     */
    splitHeads(query, key, value, shuffle) {
        // split heads and prepare for scaled dot product attention by splitting the
        // last dimension to get the heads, bring the heads forward
        // [batch, seq, dims] -> [batch, seq, heads, dims / heads] -> [batch, heads, seq, head_dim]
        const batch_size = query.shape[0];
        const split_heads = [batch_size, -1, this.numHeads, this.embedDim / this.numHeads];
        return tf.tidy(() => {
            return {
                query_split: query.reshape(split_heads).transpose(shuffle),
                key_split: key.reshape(split_heads).transpose(shuffle),
                value_split: value.reshape(split_heads).transpose(shuffle)
            };
        });
    }

    /**
     * Applies the scaled dot-product formula: softmax(QK_t / sqrt(d_k))V,
     * formula (1) of the 2017 paper Attention Is All You Need
     *
     * @param attentionMask a mask to prevent tokens from being
     * attended to (usually for padding tokens). It should have the shape
     * [batch, head, query_sequence_len, key_sequence_len]. To use in
     * conjunction with causal masking, the tensor should be a boolean type
     * where false indicates a masked token.
     * @param packingMask a mask to prevent tokens from attending across document boundaries
     * @param causalMask a precomputed causal mask, only used when `causal` is true and the
     * query holds more than one position; one is generated when it is not provided
     * @param dropout dropout rate applied to the attention weights, only when `kwargs.training` is true
     * @param causal whether to apply causal masking
     * @param kwargs.training enables dropout, default `false`
     * @param kwargs.scaling_factor overrides d_k (the key's last dimension) in the sqrt(d_k) divisor
     * @throws Error if key and value do not have the same shape
     * @throws Error if causal masking is combined with a non-boolean attention mask
     */
    static scaledDotProductionAttention(query, key, value, attentionMask, packingMask, causalMask, dropout, causal, kwargs = {}) {
        return tf.tidy(() => {
            const { training = false, scaling_factor } = kwargs;
            // compare the ranks as well, otherwise extra dimensions on value would go unnoticed
            const shapes_differ = key.shape.length !== value.shape.length
                || key.shape.some((dim, index) => dim !== value.shape[index]);
            if (shapes_differ) {
                throw new Error(`scaledDotProductionAttention: expected key and value` +
                    ` to have the same shape, got ${JSON.stringify(key.shape)} (key) and` +
                    ` ${JSON.stringify(value.shape)} (value)`);
            }
            // mask's shape is [..., seq, seq] where seq is the number of words/tokens in the input,
            // not adding the batch dimension yet to lessen the calculations
            const causal_mask_shape = [
                query.shape[query.shape.length - 2],
                key.shape[key.shape.length - 2]
            ];
            let mask = tf.zeros(causal_mask_shape);
            // a query holding a single position is left unmasked
            if (causal && causal_mask_shape[0] > 1) {
                if (attentionMask && attentionMask.dtype !== "bool") {
                    throw new Error(`scaledDotProductionAttention: the attention mask must be undefined or a boolean type if used with causal attention`);
                }
                // apply a causal attention mask so that tokens can only attend to preceding tokens,
                // prevents looking ahead
                if (causalMask) {
                    mask = causalMask;
                }
                else {
                    mask = generateCausalMask(causal_mask_shape[0], causal_mask_shape[1]);
                }
            }
            if (attentionMask) {
                if (attentionMask.dtype === "bool") {
                    // convert the boolean mask to float: true -> 0, false -> -1e7
                    // warning: do not use 1e9, it will overflow, use something smaller like 1e7
                    mask = mask.add(attentionMask.cast("float32").sub(1).mul(1e7));
                }
                else {
                    // this will occur only when not using causal masking,
                    // if the attention mask is not boolean, it's assumed the masking is already calculated,
                    mask = attentionMask;
                }
            }
            // 1. matrix multiply query and transposed key
            // 2. divide by scaling factor
            // 3. add the attention and/or causal mask, then the packing mask
            // 4. apply softmax to the result
            // 5. apply dropout
            // 6. matrix multiply softmax result with value
            let pre_softmax = query
                .matMul(key, false, true)
                .div(Math.sqrt(scaling_factor ?? key.shape[key.shape.length - 1]))
                .add(mask);
            if (packingMask) {
                // packing mask is added separately because each mask within a batch may be different,
                // so it cannot be broadcasted
                pre_softmax = pre_softmax.add(packingMask);
            }
            const spda = tf.softmax(pre_softmax);
            const spda_dropout = tf.dropout(spda, training ? dropout : 0);
            const attention = spda_dropout.matMul(value);
            return attention;
        });
    }

    /**
     * Builds the projection sublayers and registers their weights on this layer.
     *
     * @param inputShape a single shape (self-attention) or a list of one or three shapes
     * ordered as [query, key, value]
     * @throws Error if a list of shapes does not contain exactly one or three shapes
     */
    build(inputShape) {
        let input_shape = [];
        if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
            input_shape = inputShape;
        }
        else {
            input_shape = [inputShape, inputShape, inputShape];
        }
        if (input_shape.length !== 1 && input_shape.length !== 3) {
            throw new Error(`${this.getClassName()}::build ${this.name} accepts either exactly one or three inputs, received ${JSON.stringify(inputShape)}`);
        }
        if (input_shape.length === 1) {
            // a single wrapped shape means self-attention, reuse it for query, key and value
            // so that the key and value projections are not built with an undefined shape
            input_shape = [input_shape[0], input_shape[0], input_shape[0]];
        }
        // initialize the sublayer weights
        this.queryProjection.build(input_shape[0]);
        this.keyProjection.build(input_shape[1]);
        this.valueProjection.build(input_shape[2]);
        this.outputProjection.build(input_shape[0]);
        // the sublayer weights need to be tracked by this layer otherwise
        // backpropagation will complain about no trainable parameters found,
        // this is an extra step that TF's Python version does not need
        this.trainableWeights = [
            ...this.queryProjection.trainableWeights,
            ...this.keyProjection.trainableWeights,
            ...this.valueProjection.trainableWeights,
            ...this.outputProjection.trainableWeights
        ];
        // rename the weights otherwise they'll take on the default naming and overlap
        // each other which breaks model loading due to duplicate weight names
        let indexing = 0;
        for (const weight of this.trainableWeights) {
            const unique_name = `${this.getClassName()}_${indexing}`;
            weight.name += unique_name;
            weight.originalName += unique_name;
            indexing++;
        }
        super.build(inputShape);
    }

    /**
     * MultiHead attention's output has the same shape as the query's.
     */
    computeOutputShape(inputShape) {
        return Array.isArray(inputShape) && Array.isArray(inputShape[0]) ? inputShape[0] : inputShape;
    }

    /**
     * Returns the layer's configuration for serialization. Values from the
     * base layer's configuration take precedence when keys overlap.
     */
    getConfig() {
        const base_config = super.getConfig();
        const config = {
            numHeads: this.numHeads,
            embedDim: this.embedDim,
            useBias: this.useBias,
            causal: this.causal,
            dropout: this.dropout,
            name: this.name,
        };
        return { ...config, ...base_config };
    }
}
|
|
268
|
-
// register the layer class (identified by its static className) with the TensorFlow.js serialization registry
tf.serialization.registerClass(MultiHeadAttention);
|
|
269
|
-
//# sourceMappingURL=multihead_attention.js.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"multihead_attention.js","sourceRoot":"","sources":["../../../src/layers/multihead_attention.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC,OAAO,EAAE,MAAM,IAAI,kBAAkB,EAAE,MAAM,SAAS,CAAC;AAoBvD;;;;;;;;;;;;;;;;;;;;;;;;;;GA0BG;AACH,MAAM,OAAO,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,GAAG,oBAAoB,CAAC;IACrB,QAAQ,CAAS;IACjB,QAAQ,CAAS,CAAC,2DAA2D;IAC7E,OAAO,CAAU;IACjB,OAAO,CAAS;IAChB,MAAM,CAAU,CAAC,4CAA4C;IAEhF,mEAAmE;IACnE,wDAAwD;IACrC,eAAe,CAAkB;IACjC,aAAa,CAAkB;IAC/B,eAAe,CAAkB;IACjC,gBAAgB,CAAkB;IAGrD,YAAY,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAO,GAAG,IAAI,EAAE,OAAO,GAAG,GAAG,EAAE,MAAM,GAAG,KAAK,EAAE,GAAG,IAAI,EAA0B;QAC9G,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,QAAQ,GAAG,QAAQ,IAAI,CAAC,EAAE,CAAC;YAC3B,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,cAAc,QAAQ,mCAAmC,QAAQ,GAAG,CAAC,CAAC;QACtI,CAAC;QAED,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,OAAO,GAAG,OAAO,CAAC;QACvB,IAAI,CAAC,OAAO,GAAG,OAAO,CAAC;QACvB,IAAI,CAAC,MAAM,GAAG,MAAM,CAAC;QAErB,IAAI,IAAI,CAAC,OAAO,IAAI,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,6CAA6C,CAAC,CAAC;QACrF,CAAC;QAED,0DAA0D;QAC1D,gEAAgE;QAChE,IAAI,CAAC,eAAe,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,QAAQ,EAAE,CAAC,CAAC;QACrE,IAAI,CAAC,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,QAAQ,EAAE,CAAC,CAAC;QACnE,IAAI,CAAC,eAAe,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,QAAQ,EAAE,CAAC,CAAC;QACrE,IAAI,CAAC,gBAAgB,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,QAAQ,EAAE,CAAC,CAAC;IAC1E,CAAC;IAGD;;;;;OAKG;IACM,IAAI,CACT,MAA+B,EAC/B,MAGC;QAED,6BAA6B;QAC7B,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC;YACzB,MAAM,GAAG,CAAC,MAAM,CAAC,CAAC;QACtB,CAAC;QAED,6EAA6E;QAC7E,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAC3C,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,gDAAgD,MAAM,CAAC,MAAM,gBAAgB,CAAC,CAAC;QACxI,CAAC;QAED,KAAK,MAAM,KAAK,IAAI,MAAM,E
AAE,CAAC;YACzB,IAAI,KAAK,CAAC,KAAK,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;gBAC1B,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,0DAA0D,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;YAClJ,CAAC;QACL,CAAC;QAED,MAAM,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,CAAC,GAAG,MAAM,CAAC;QACnC,MAAM,WAAW,GAAG,MAAM,CAAC,WAAW,IAAI,IAAI,CAAC;QAC/C,MAAM,UAAU,GAAG,MAAM,CAAC,UAAU,IAAI,IAAI,CAAC;QAE7C,OAAO,MAAM,CAAC,MAAM,IAAI,CAAC;YACrB,kBAAkB;YAClB,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,KAAM,EAAE,GAAI,EAAE,KAAM,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,CAAC;YACrE,iBAAiB;YACjB,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,KAAM,EAAE,KAAM,EAAE,KAAM,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,CAAC,CAAC;IAChF,CAAC;IAGD;;OAEG;IACO,OAAO,CACb,WAAsB,EACtB,SAAoB,EACpB,WAAsB,EACtB,YAA8B,EAC9B,WAA6B,EAC7B,MAAc;QAEd,2BAA2B;QAC3B,+CAA+C;QAC/C,iDAAiD;QACjD,4CAA4C;QAC5C,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,EAAE,KAAK,EAAE,GAAG,EAAE,KAAK,EAAE,GAAG,IAAI,CAAC,qBAAqB,CAAC,WAAW,EAAE,SAAS,EAAE,WAAW,CAAC,CAAC;YAE9F,oGAAoG;YACpG,MAAM,qBAAqB,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;YAE3C,MAAM,EACF,WAAW,EAAE,SAAS,EAAE,WAAW,EACtC,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,EAAE,qBAAqB,CAAC,CAAC;YAE9D,gFAAgF;YAChF,MAAM,IAAI,GAAG,kBAAkB,CAAC,4BAA4B,CACxD,WAAW,EAAE,SAAS,EAAE,WAAW,EACnC,MAAM,CAAC,aAAa,IAAI,IAAI,EAAE,YAAY,EAAE,WAAW,EACvD,IAAI,CAAC,OAAO,EAAE,IAAI,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAEvC,+CAA+C;YAC/C,MAAM,MAAM,GAAG,IAAI,CAAC,gBAAgB,CAAC,KAAK,CACtC,IAAI,CAAC,SAAS,CAAC,qBAAqB,CAAC,CAAC,OAAO,CAAC,CAAC,WAAW,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC;YAE9F,OAAO,MAAmB,CAAC;QAC/B,CAAC,CAAC,CAAA;IACN,CAAC;IAGS,qBAAqB,CAAC,WAAsB,EAAE,SAAoB,EAAE,WAAsB;QAChG,wFAAwF;QACxF,2EAA2E;QAC3E,oHAAoH;QACpH,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,OAAO;gBACH,KAAK,EAAE,IAAI,CAAC,eAAe,CAAC,KAAK,CAAC,WAAW,CAAc;gBAC3D,GAAG,EAAE,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,SAAS,CAAc;gBACrD,KAAK,EAAE,IAAI,CAAC,eAAe,CAAC,KAAK,CAAC,WAAW,CAAc;aAC9D,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGS,UAAU,CAAC,KAAgB,EAAE,GAAc,EAAE,KAAgB,EAAE,OAAiB;QACt
F,4EAA4E;QAC5E,2DAA2D;QAC3D,2FAA2F;QAC3F,MAAM,UAAU,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAClC,MAAM,WAAW,GAAG,CAAC,UAAU,EAAE,CAAC,CAAC,EAAE,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC;QAEnF,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,OAAO;gBACH,WAAW,EAAE,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,SAAS,CAAC,OAAO,CAAgB;gBACzE,SAAS,EAAE,GAAG,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,SAAS,CAAC,OAAO,CAAgB;gBACrE,WAAW,EAAE,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,SAAS,CAAC,OAAO,CAAgB;aAC5E,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGD;;;;;;;;;;OAUG;IACH,MAAM,CAAC,4BAA4B,CAC/B,KAAgB,EAChB,GAAc,EACd,KAAgB,EAChB,aAA+B,EAC/B,WAA6B,EAC7B,UAA4B,EAC5B,OAAe,EACf,MAAe,EACf,SAA6C,EAAE;QAE/C,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,EAAE,QAAQ,GAAG,KAAK,EAAE,cAAc,EAAE,GAAG,MAAM,CAAC;YAEpD,GAAG,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,KAAK,EAAE,EAAE;gBAC7B,IAAI,GAAG,CAAC,KAAK,CAAC,KAAK,CAAC,IAAI,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC;oBACzC,MAAM,KAAK,CAAC,sDAAsD;wBAC9D,gCAAgC,IAAI,CAAC,SAAS,CAAC,GAAG,CAAC,KAAK,CAAC,YAAY;wBACrE,IAAI,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;gBACnD,CAAC;YACL,CAAC,CAAC,CAAA;YAGF,wFAAwF;YACxF,gEAAgE;YAChE,MAAM,iBAAiB,GAAG;gBACtB,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;gBACnC,GAAG,CAAC,KAAK,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;aAAC,CAAC;YAErC,IAAI,IAAI,GAAG,EAAE,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC;YAEvC,IAAI,MAAM,IAAI,iBAAiB,CAAC,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC;gBACrC,IAAI,aAAa,IAAI,aAAa,CAAC,KAAK,IAAI,MAAM,EAAE,CAAC;oBACjD,MAAM,KAAK,CAAC,oHAAoH,CAAC,CAAC;gBACtI,CAAC;gBAED,oFAAoF;gBACpF,2BAA2B;gBAC3B,IAAI,UAAU,EAAE,CAAC;oBACb,IAAI,GAAG,UAAU,CAAC;gBACtB,CAAC;qBAAM,CAAC;oBACJ,IAAI,GAAG,kBAAkB,CAAC,iBAAiB,CAAC,CAAC,CAAC,EAAE,iBAAiB,CAAC,CAAC,CAAC,CAAC,CAAC;gBAC1E,CAAC;YACL,CAAC;YAED,IAAI,aAAa,EAAE,CAAC;gBAChB,IAAI,aAAa,CAAC,KAAK,IAAI,MAAM,EAAE,CAAC;oBAChC,oCAAoC;oBACpC,4EAA4E;oBAC5E,IAAI,GAAG,IAAI,CAAC,GAAG,CAAC,aAAa,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;gBACnE,CAAC;qBAAM,CAAC;oBACJ,sDA
AsD;oBACtD,wFAAwF;oBACxF,IAAI,GAAG,aAAa,CAAC;gBACzB,CAAC;YACL,CAAC;YAED,8CAA8C;YAC9C,8BAA8B;YAC9B,iCAAiC;YACjC,wCAAwC;YACxC,mBAAmB;YACnB,+CAA+C;YAC/C,IAAI,WAAW,GAAG,KAAK;iBAClB,MAAM,CAAC,GAAG,EAAE,KAAK,EAAE,IAAI,CAAC;iBACxB,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,cAAc,IAAI,GAAG,CAAC,KAAK,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC;iBACjE,GAAG,CAAC,IAAI,CAAC,CAAC;YAEf,IAAI,WAAW,EAAE,CAAC;gBACd,sFAAsF;gBACtF,8BAA8B;gBAC9B,WAAW,GAAG,WAAW,CAAC,GAAG,CAAC,WAAW,CAAC,CAAC;YAC/C,CAAC;YAED,MAAM,IAAI,GAAG,EAAE,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC;YAErC,MAAM,YAAY,GAAG,EAAE,CAAC,OAAO,CAAC,IAAI,EAAE,QAAQ,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAC9D,MAAM,SAAS,GAAG,YAAY,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;YAE7C,OAAO,SAAS,CAAC;QACrB,CAAC,CAAC,CAAC;IACP,CAAC;IAGQ,KAAK,CAAC,UAAiC;QAC5C,IAAI,WAAW,GAAe,EAAE,CAAC;QAEjC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAC5D,WAAW,GAAG,UAAwB,CAAC;QAC3C,CAAC;aAAM,CAAC;YACJ,WAAW,GAAG,CAAC,UAAsB,EAAE,UAAsB,EAAE,UAAsB,CAAC,CAAC;QAC3F,CAAC;QAED,IAAI,WAAW,CAAC,MAAM,IAAI,CAAC,IAAI,WAAW,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YACrD,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,yDAAyD,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAC;QACjJ,CAAC;QAED,kCAAkC;QAClC,IAAI,CAAC,eAAe,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3C,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QACzC,IAAI,CAAC,eAAe,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3C,IAAI,CAAC,gBAAgB,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAE5C,kEAAkE;QAClE,qEAAqE;QACrE,+DAA+D;QAC/D,IAAI,CAAC,gBAAgB,GAAG;YACpB,GAAG,IAAI,CAAC,eAAe,CAAC,gBAAgB;YACxC,GAAG,IAAI,CAAC,aAAa,CAAC,gBAAgB;YACtC,GAAG,IAAI,CAAC,eAAe,CAAC,gBAAgB;YACxC,GAAG,IAAI,CAAC,gBAAgB,CAAC,gBAAgB;SAC5C,CAAC;QAEF,8EAA8E;QAC9E,sEAAsE;QACtE,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEjB,KAAK,MAAM,MAAM,IAAI,IAAI,CAAC,gBAAgB,EAAE,CAAC;YACzC,MAAM,WAAW,GAAG,GAAG,IAAI,CAAC,YAAY,EAAE,IAAI,QAAQ,EAAE,CAAC;YACxD,MAAc,CAAC,IAAI,IAAI,WAAW,CAAC;YACnC,MAAc,CAAC,YAAY,IAAI,WAAW,CAAC;YAC5C,QAAQ,EAAE,CAA
C;QACf,CAAC;QAED,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACM,kBAAkB,CAAC,UAAiC;QACzD,OAAO,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC;IAClG,CAAC;IAGQ,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,MAAM,EAAE,IAAI,CAAC,MAAM;YACnB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,IAAI,EAAE,IAAI,CAAC,IAAI;SAClB,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC"}
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"multihead_attention.test.d.ts","sourceRoot":"","sources":["../../../src/layers/multihead_attention.test.ts"],"names":[],"mappings":""}
|
|
@@ -1,160 +0,0 @@
|
|
|
1
|
-
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import { CachedRoPEMultiHeadAttention } from '@/layers/cached_rope_multihead_attention';
|
|
3
|
-
import { causal as generateCausalMask } from "@/masks";
|
|
4
|
-
import { MultiHeadAttention } from '@/layers/multihead_attention';
|
|
5
|
-
// disables warning for using the faster node backend,
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
tf.env().set('IS_NODE', false);

/**
 * Shared body of the two "hard coded values" tests: builds a single-head
 * CachedRoPEMultiHeadAttention layer, sets every weight to 0.05, runs a fixed
 * [2, 3, 10] input through it and compares the output with a fixed tensor.
 *
 * @param {boolean} causal value passed as the layer's `causal` argument
 */
function expectHardCodedOutput(causal) {
    // input and output shapes: [2, 3, 10]
    const input = tf.tensor([
        [[0.2109915, 0.6158954, 0.6012088, 0.9867562, 0.8728716, 0.7496274, 0.8173883, 0.2958342, 0.9650571, 0.2075207],
            [0.2946285, 0.9779906, 0.3203818, 0.4037617, 0.3762881, 0.9863171, 0.6655593, 0.7707329, 0.3216831, 0.7984023],
            [0.9080769, 0.0026282, 0.379492, 0.0162054, 0.1939302, 0.2201049, 0.8190675, 0.0203963, 0.0114392, 0.5015539]],
        [[0.6241482, 0.7631097, 0.6687831, 0.7259795, 0.0457698, 0.6889264, 0.0853676, 0.8697655, 0.3637198, 0.2105307],
            [0.5221761, 0.4476321, 0.1244729, 0.8863543, 0.7319002, 0.2954829, 0.3200496, 0.0905503, 0.607977, 0.1309131],
            [0.4693873, 0.4609751, 0.9170766, 0.7065565, 0.4795104, 0.3225758, 0.1353116, 0.7083887, 0.1928891, 0.967386]]
    ]);
    const expected = tf.tensor([
        [[0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344, 0.2055344],
            [0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376, 0.205376],
            [0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539, 0.2042539]],
        [[0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718, 0.1966718],
            [0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268, 0.1966268],
            [0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877, 0.1966877]]
    ]);
    const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1), causal });
    attention.build(input.shape);
    attention.setWeights(attention.getWeights().map(weight => tf.onesLike(weight).mul(0.05)));
    // NOTE(review): this compares the SIGNED sum of the differences with 1e-6, so
    // any output whose total exceeds `expected` passes. The causal and non-causal
    // tests also share the same `expected` tensor. Confirm the reference values and
    // consider an absolute, element-wise tolerance instead.
    expect(expected.sub(attention.apply(input)).sum().dataSync()[0]).toBeLessThan(1e-6);
}

describe("MultiHeadAttention tests", () => {
    it("should fail to instantiate a layer if heads count is not divisible by the input's embedding dimension", () => {
        expect(() => new CachedRoPEMultiHeadAttention({ numHeads: 3, embedDim: 10 })).toThrow();
        expect(() => new CachedRoPEMultiHeadAttention({ numHeads: 15, embedDim: 60 })).not.toThrow();
    });

    test("successfull forward calls", () => {
        const input = tf.randomUniform([2, 3, 12]);
        const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1) });
        expect(() => attention.apply(input)).not.toThrow();
        expect(() => attention.apply([input])).not.toThrow();
        const causal = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1), causal: true });
        expect(() => causal.apply(input)).not.toThrow();
        expect(() => causal.apply([input])).not.toThrow();
    });

    test("query and value must have the same shape for scaled dot product attention to succeed", () => {
        const query = tf.randomUniform([2, 3, 12]);
        const key = tf.randomUniform([2, 3, 12]);
        const value = tf.randomUniform([2, 3, 12]);
        const value_thats_too_long = tf.randomUniform([2, 100, 12]);
        const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: query.shape.at(-1) });
        expect(() => attention.apply([query, key, value])).not.toThrow();
        expect(() => attention.apply([query, key, value_thats_too_long])).toThrow();
    });

    it("should only accept rank 3 tensors", () => {
        const embed_dims = 12;
        const BAD_RANK2 = tf.randomUniform([2, embed_dims]);
        const GOOD = tf.randomUniform([2, 3, embed_dims]);
        const BAD_RANK4 = tf.randomUniform([2, 3, 10, embed_dims]);
        const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: embed_dims });
        // BAD
        expect(() => attention.apply(BAD_RANK2)).toThrow();
        expect(() => attention.apply([BAD_RANK2])).toThrow();
        expect(() => attention.apply([BAD_RANK2, BAD_RANK2, BAD_RANK2])).toThrow();
        // OK
        expect(() => attention.apply(GOOD)).not.toThrow();
        expect(() => attention.apply([GOOD])).not.toThrow();
        expect(() => attention.apply([GOOD, GOOD, GOOD])).not.toThrow();
        // BAD
        expect(() => attention.apply(BAD_RANK4)).toThrow();
        expect(() => attention.apply([BAD_RANK4])).toThrow();
        expect(() => attention.apply([BAD_RANK4, BAD_RANK4, BAD_RANK4])).toThrow();
        // BAD
        expect(() => attention.apply([GOOD, BAD_RANK2, BAD_RANK4])).toThrow();
        expect(() => attention.apply([BAD_RANK2, GOOD, BAD_RANK4])).toThrow();
        expect(() => attention.apply([BAD_RANK2, BAD_RANK4, GOOD])).toThrow();
        expect(() => attention.apply([BAD_RANK2, GOOD, GOOD])).toThrow();
        expect(() => attention.apply([GOOD, GOOD, BAD_RANK4])).toThrow();
    });

    it("should only 1 or 3 inputs total", () => {
        const input = tf.randomUniform([2, 3, 12]);
        let attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1) });
        // OK
        expect(() => attention.apply(input, { packingMask: undefined })).not.toThrow();
        expect(() => attention.apply([input])).not.toThrow();
        // reinitialize to rerun build()
        attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1) });
        expect(() => attention.apply([input, input, input])).not.toThrow();
        // BAD
        expect(() => attention.apply([])).toThrow();
        expect(() => attention.apply([input, input])).toThrow();
        // reinitialize to rerun build()
        attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: input.shape.at(-1) });
        expect(() => attention.apply([input, input, input, input])).toThrow();
    });

    test("attention masking", () => {
        const query = tf.randomUniform([2, 3, 12]);
        const key = tf.randomUniform([2, 3, 12]);
        const value = tf.randomUniform([2, 3, 12]);
        const attention = new CachedRoPEMultiHeadAttention({ numHeads: 2, embedDim: query.shape.at(-1), causal: true });
        expect(() => attention.call(query, {})).not.toThrow();
        // cross attention
        expect(() => attention.call([query, key, value], {})).not.toThrow();
        const query5 = tf.randomUniform([2, 5, 10]);
        const key4 = tf.randomUniform([2, 4, 10]);
        const value5 = tf.randomUniform([2, 4, 10]);
        const expected_mask = tf.tensor([[
                // vertical represents query, false means that token cannot attend to the keys
                // horizontal represents key, false means that token cannot attend to the queries
                [false, false, false, false],
                [true, true, true, false,],
                [true, true, true, false,],
                [false, false, false, false],
                [true, true, true, false,],
            ]]);
        const packing_mask = tf.tensor([
            [0, 0, 0, -1e7, -1e7],
            [0, 0, 0, -1e7, -1e7],
            [0, 0, 0, -1e7, -1e7],
            [-1e7, -1e7, -1e7, 0, 0],
            [-1e7, -1e7, -1e7, 0, 0]
        ]);
        // for causal attention, the attention mask must be boolean
        expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, key4, value5, expected_mask.asType("float32"), null, null, 0.1, true, { scaling_factor: 10 })).toThrow();
        // for causal attention, using pre-calculated causal mask
        expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, key4, value5, expected_mask.asType("float32"), null, generateCausalMask(query5.shape[1], key4.shape[1]), 0.2, true, { scaling_factor: 10 })).toThrow();
        // when not using causal attention, the attention mask can be a float32 tensor
        expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, key4, value5, expected_mask.asType("float32"), null, null, 0, false)).not.toThrow();
        // packing mask for self attention
        expect(() => MultiHeadAttention.scaledDotProductionAttention(query5, query5, query5, null, packing_mask, null, 0.9, true)).not.toThrow();
    });

    it("should return a non-empty config dict", () => {
        const input = tf.randomUniform([2, 3, 10]);
        const attention = new CachedRoPEMultiHeadAttention({ numHeads: 1, embedDim: input.shape.at(-1) });
        // Assert on the number of keys: the previous form compared the key array
        // itself with 0 via `.not.toBe(0)`, which can never fail.
        expect(Object.keys(attention.getConfig()).length).toBeGreaterThan(0);
    });

    test("causal attention hard coded values", () => {
        expectHardCodedOutput(true);
    });

    test("non-causal attention hard coded values", () => {
        expectHardCodedOutput(false);
    });
});
|
|
160
|
-
//# sourceMappingURL=multihead_attention.test.js.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"multihead_attention.test.js","sourceRoot":"","sources":["../../../src/layers/multihead_attention.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,4BAA4B,EAAE,MAAM,0CAA0C,CAAC;AACxF,OAAO,EAAE,MAAM,IAAI,kBAAkB,EAAE,MAAM,SAAS,CAAC;AACvD,OAAO,EAAE,kBAAkB,EAAE,MAAM,8BAA8B,CAAC;AAElE,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,EAAE,CAAC,uGAAuG,EAAE,GAAG,EAAE;QAC7G,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACxF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACjG,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,2BAA2B,EAAE,GAAG,EAAE;QACnC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACnG,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAErD,MAAM,MAAM,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QAC9G,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACtD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,sFAAsF,EAAE,GAAG,EAAE;QAC9F,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3C,MAAM,GAAG,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QACzC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3C,MAAM,oBAAoB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,GAAG,EAAE,EAAE,CAAC,CAAC,CAAC;QAE5D,M
AAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACnG,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,GAAG,EAAE,oBAAoB,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAChF,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,mCAAmC,EAAE,GAAG,EAAE;QACzC,MAAM,UAAU,GAAG,EAAE,CAAC;QAEtB,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC;QACpD,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC;QAClD,MAAM,SAAS,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,UAAU,CAAC,CAAC,CAAC;QAE3D,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,CAAC,CAAC;QAE1F,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACrD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAE3E,KAAK;QACL,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAClD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACpD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEhE,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACnD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACrD,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAE3E,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACtE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE
,CAAC;QACtE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACtE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACjE,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,IAAI,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACrE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,iCAAiC,EAAE,GAAG,EAAE;QACvC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,IAAI,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QAEjG,KAAK;QACL,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,EAAE,EAAE,WAAW,EAAE,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAC/E,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACrD,gCAAgC;QAChC,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QAC7F,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnE,MAAM;QACN,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC5C,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACxD,gCAAgC;QAChC,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QAC7F,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAC1E,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,mBAAmB,EAAE,GAAG,EAAE;QAC3B,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3C,MAAM,GAAG,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QACzC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KA
AK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QAEjH,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,IAAI,CAAC,KAAK,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEtD,kBAAkB;QAClB,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,IAAI,CAAC,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAGpE,MAAM,MAAM,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC5C,MAAM,IAAI,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC1C,MAAM,MAAM,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE5C,MAAM,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC;gBAC7B,8EAA8E;gBAC9E,iFAAiF;gBACjF,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC;gBAC5B,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,EAAE,KAAK,EAAE;gBAC1B,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,EAAE,KAAK,EAAE;gBAC1B,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC;gBAC5B,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,EAAE,KAAK,EAAE;aAC7B,CAAC,CAAC,CAAC;QAEJ,MAAM,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC;YAC3B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC;YACrB,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC;YACrB,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC;YACrB,CAAC,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,CAAC;YACxB,CAAC,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,CAAC;SAC3B,CAAC,CAAA;QAEF,2DAA2D;QAC3D,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,4BAA4B,CAAC,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,aAAa,CAAC,MAAM,CAAC,SAAS,CAAC,EAAE,IAAI,EAAE,IAAI,EAAE,GAAG,EAAE,IAAI,EAAE,EAAE,cAAc,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC9K,yDAAyD;QACzD,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,4BAA4B,CAAC,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,aAAa,CAAC,MAAM,CAAC,SAAS,CAAC,EAAE,IAAI,EAAE,kBAAkB,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,CAAE,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC,EAAE,GAAG,EAAE,IAAI,EAAE,EAAE,cAAc,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC9N,8EAA8E;QAC9E,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,4BAA4B,CAAC,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,aAA
a,CAAC,MAAM,CAAC,SAAS,CAAC,EAAE,IAAI,EAAE,IAAI,EAAE,CAAC,EAAE,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACzJ,kCAAkC;QAClC,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,4BAA4B,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,IAAI,EAAE,YAAY,EAAE,IAAI,EAAE,GAAG,EAAE,IAAI,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC7I,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACnG,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,SAAS,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,oCAAoC,EAAE,GAAG,EAAE;QAC5C,sCAAsC;QACtC,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC;YACpB,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9G,CAAC,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;YAE9G,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,CAAC;gBAC7G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,CAAC,CAAC;SACjH,CAAC,CAAC;QAEH,MAAM,QAAQ,GAAG,EAAE,CAAC,MAAM,CAAC;YACvB,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,CAAC;gBACpG,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;YAE/G,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,C
AAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;SAClH,CAAC,CAAC;QAGH,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QACjH,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;QAC7B,SAAS,CAAC,UAAU,CAAC,SAAS,CAAC,UAAU,EAAE,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAE1F,MAAM,CAAC,QAAQ,CAAC,GAAG,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;IACrG,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,wCAAwC,EAAE,GAAG,EAAE;QAChD,sCAAsC;QACtC,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC;YACpB,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9G,CAAC,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;YAE9G,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,CAAC;gBAC7G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,CAAC,CAAC;SACjH,CAAC,CAAC;QAGH,MAAM,QAAQ,GAAG,EAAE,CAAC,MAAM,CAAC;YACvB,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,EAAE,QAAQ,CAAC;gBACpG,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;YAE/G,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE
,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC/G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9G,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;SAClH,CAAC,CAAC;QAEH,MAAM,SAAS,GAAG,IAAI,4BAA4B,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,KAAK,EAAE,CAAC,CAAC;QAClH,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;QAC7B,SAAS,CAAC,UAAU,CAAC,SAAS,CAAC,UAAU,EAAE,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAE1F,MAAM,CAAC,QAAQ,CAAC,GAAG,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;IACrG,CAAC,CAAC,CAAC;AACP,CAAC,CAAC,CAAC"}
|
|
@@ -1,37 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
|
|
3
|
-
import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
|
|
4
|
-
/**
 * Constructor arguments of the PositionalEncoding layer, on top of the
 * generic tfjs `LayerArgs`.
 */
export interface PositionalEncodingArgs extends LayerArgs {
    /** Size of each token/word's embedding. */
    embedDim: number;
    /** Max number of tokens (words) per input (sentence); optional. */
    maxSequenceLength?: number;
}
|
|
8
|
-
/**
 * Untrainable layer implementing the position encoding logic described in the
 * 2017 paper "Attention Is All You Need".
 *
 * It accepts inputs of shape `[ batch, sequences, embedding dims ]` and adds
 * positional encoding, returning an output tensor of the same shape.
 *
 * @param embedDim the size of each token/word's embedding
 * @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`
 */
export declare class PositionalEncoding extends tf.layers.Layer {
    static className: string;
    private readonly embedDim;
    private readonly maxSequenceLength;
    private positionalEncodings;
    constructor(args: PositionalEncodingArgs);
    /**
     * Generates the positional encoding from the paper Attention Is All You Need.
     * The inner term of the position formula is the same for the even and the odd
     * index of a pair, so only half of it is created and sine and cosine are
     * applied to it individually.
     */
    build(inputShape: tf.Shape | tf.Shape[]): void;
    /**
     * Forward propagation: injects positional encoding into the input embeddings.
     */
    call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[];
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[];
    getConfig(): tf.serialization.ConfigDict;
}
|
|
37
|
-
//# sourceMappingURL=positional_encoding.d.ts.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"positional_encoding.d.ts","sourceRoot":"","sources":["../../../src/layers/positional_encoding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,SAAS,EAAE,MAAM,8CAA8C,CAAC;AAC9E,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAGjE,MAAM,WAAW,sBAAuB,SAAQ,SAAS;IAErD,QAAQ,EAAE,MAAM,CAAC;IAEjB,iBAAiB,CAAC,EAAE,MAAM,CAAC;CAC9B;AAGD;;;;;;;;;GASG;AACH,qBAAa,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,SAAwB;IACxC,OAAO,CAAC,QAAQ,CAAC,iBAAiB,CAAS;IAC3C,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,mBAAmB,CAAmB;gBAGlC,IAAI,EAAE,sBAAsB;IAuBxC;;OAEG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAyBvF;;;;OAIG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IAmD9C,kBAAkB,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;IAK5E,SAAS,IAAI,EAAE,CAAC,aAAa,CAAC,UAAU;CAYpD"}
|
|
@@ -1,115 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
/**
 * This class implements the position encoding logic described in the
 * 2017 paper "Attention Is All You Need".
 *
 * This layer is untrainable and accepts inputs of shape `[ batch, sequences, embedding dims ]`
 * and adds positional encoding to return an output tensor of the same shape.
 *
 * @param embedDim the size of each token/word's embedding
 * @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`
 */
export class PositionalEncoding extends tf.layers.Layer {
    static className = "PositionalEncoding";
    maxSequenceLength;
    embedDim;
    positionalEncodings;

    /**
     * @param args layer arguments; `args.embedDim` is required and
     *     `args.maxSequenceLength` falls back to `5120` when null or undefined
     * @throws {Error} if either size is not an integer greater than 0
     */
    constructor(args) {
        super(args);
        this.maxSequenceLength = args.maxSequenceLength ?? 5120;
        this.embedDim = args.embedDim;
        // `undefined < 1` and `NaN < 1` are both false, so a plain range check lets a
        // missing or non-numeric size through to addWeight(); require an integer.
        if (!Number.isInteger(this.maxSequenceLength) || this.maxSequenceLength < 1) {
            throw Error(`${this.getClassName()}::constructor ${this.name} maxSequenceLength` +
                ` (${args.maxSequenceLength}) must be an integer greater than 0`);
        }
        if (!Number.isInteger(this.embedDim) || this.embedDim < 1) {
            throw Error(`${this.getClassName()}::constructor ${this.name} embedDim` +
                ` (${args.embedDim}) must be an integer greater than 0`);
        }
        // positional encodings are not trainable
        this.positionalEncodings = this.addWeight('positional_encodings', [this.maxSequenceLength, this.embedDim], "float32", tf.initializers.zeros(), undefined, false);
    }

    /**
     * Forward propagation. Injects positional encoding to the input embeddings.
     *
     * @param inputs a tensor, or an array whose first element is the tensor to
     *     encode; it must be rank 3 with a last dimension equal to `embedDim`
     * @param kwargs unused
     * @returns the input with the first `sequences` positional encoding rows added
     * @throws {Error} on an unexpected shape or a sequence length above `maxSequenceLength`
     */
    call(inputs, kwargs) {
        // validate the input tensors
        const input = Array.isArray(inputs) ? inputs[0] : inputs;
        const sequences = input.shape[1];
        if (input.shape.length !== 3 || input.shape[2] !== this.embedDim) {
            throw Error(`${this.getClassName()}::call ${this.name} expected an input shape of` +
                ` [batch, (up to ${this.maxSequenceLength}), ${this.embedDim}], instead got ${input.shape}`);
        }
        if (sequences > this.maxSequenceLength) {
            // unexpected sequence length
            throw Error(`${this.getClassName()}::call ${this.name} received an input with` +
                ` sequence length (${sequences}) which is greater than the max sequence length` +
                ` ${this.maxSequenceLength}`);
        }
        // perform forward propagation
        return tf.tidy(() => {
            return input.add(this.positionalEncodings.read()
                .slice([0, 0], [sequences, this.embedDim]) // gets the first "sequences" rows
                .expandDims(0)); // introduce the batch dimension and let add() broadcast it
        });
    }

    /**
     * Generate the positional encoding from the paper Attention Is All You Need.
     * Note that because the inner term of the position formula is the same for both even
     * and odd indices, we only create half of it and apply sine and cosine individually.
     */
    build(inputShape) {
        tf.tidy(() => {
            // number of (sine, cosine) pairs; rounds up for odd embedding sizes
            const pairCount = Math.ceil(this.embedDim / 2);
            // create the position matrix as [ 0, 1, 2, 3, etc ],
            // and broadcast it horizontally to match the number of pairs
            const numerator = tf.range(0, this.maxSequenceLength, 1)
                .reshape([this.maxSequenceLength, 1])
                .broadcastTo([this.maxSequenceLength, pairCount]);
            // the denominator formula is 10_000^(2i/d_model) where each "i" is a sine
            // cosine pair, hence the exponent numerators [ 0, 2, 4, etc ]
            const denominator = tf.pow(10_000, tf.range(0, this.embedDim, 2).div(this.embedDim));
            const innerTerm = numerator.div(denominator);
            // horizontally interweave the sine and cosine columns together to form
            // [sin, cos, sin, cos, etc] on every row: stacking on a new last axis gives
            // [rows, pairs, 2], which flattens row-major to the interleaved order.
            // For odd embedDim the trailing slice drops the last, unused cosine column,
            // e.g. embedDim = 5 keeps [ i=0 (sin), i=0 (cos), i=1 (sin), i=1 (cos), i=2 (sin) ].
            const interweaved = tf.stack([tf.sin(innerTerm), tf.cos(innerTerm)], 2)
                .reshape([this.maxSequenceLength, 2 * pairCount])
                .slice([0, 0], [this.maxSequenceLength, this.embedDim]);
            // store the positional encoding in the layer's weight
            this.setWeights([interweaved]);
        });
        super.build(inputShape);
    }

    /** The output has the same shape as the input. */
    computeOutputShape(inputShape) {
        return inputShape;
    }

    /**
     * @returns this layer's `maxSequenceLength` and `embedDim` merged with the
     *     base layer config (base entries win on a key clash)
     */
    getConfig() {
        return {
            maxSequenceLength: this.maxSequenceLength,
            embedDim: this.embedDim,
            ...super.getConfig(),
        };
    }
}
tf.serialization.registerClass(PositionalEncoding);
|
|
115
|
-
//# sourceMappingURL=positional_encoding.js.map
|