@stellarapp/tfjs-stellar 1.0.0 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +47 -0
- package/dist/index.d.ts +7 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +7 -0
- package/dist/index.js.map +1 -0
- package/dist/jest.config.d.ts +8 -0
- package/dist/jest.config.d.ts.map +1 -0
- package/{jest.config.ts → dist/jest.config.js} +8 -64
- package/dist/jest.config.js.map +1 -0
- package/dist/kv_cache.d.ts +53 -0
- package/dist/kv_cache.d.ts.map +1 -0
- package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
- package/dist/kv_cache.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.js +76 -0
- package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
- package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
- package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
- package/dist/layers/gpt_decoder_block.js.map +1 -0
- package/dist/layers/index.d.ts +17 -0
- package/dist/layers/index.d.ts.map +1 -0
- package/dist/layers/index.js +33 -0
- package/dist/layers/index.js.map +1 -0
- package/dist/layers/multihead_attention.d.ts +106 -0
- package/dist/layers/multihead_attention.d.ts.map +1 -0
- package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
- package/dist/layers/multihead_attention.js.map +1 -0
- package/dist/layers/multihead_attention.test.d.ts +2 -0
- package/dist/layers/multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
- package/dist/layers/multihead_attention.test.js.map +1 -0
- package/dist/layers/positional_encoding.d.ts +37 -0
- package/dist/layers/positional_encoding.d.ts.map +1 -0
- package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
- package/dist/layers/positional_encoding.js.map +1 -0
- package/dist/layers/positional_encoding.test.d.ts +2 -0
- package/dist/layers/positional_encoding.test.d.ts.map +1 -0
- package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
- package/dist/layers/positional_encoding.test.js.map +1 -0
- package/dist/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
- package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
- package/dist/layers/rotary_position_embedding.js.map +1 -0
- package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/layers/rotary_position_embedding.test.js +88 -0
- package/dist/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
- package/dist/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
- package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/layers/transformer_decoder.d.ts +69 -0
- package/dist/layers/transformer_decoder.d.ts.map +1 -0
- package/dist/layers/transformer_decoder.js +182 -0
- package/dist/layers/transformer_decoder.js.map +1 -0
- package/dist/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
- package/dist/layers/transformer_decoder.test.js.map +1 -0
- package/dist/layers/transformer_encoder.d.ts +55 -0
- package/dist/layers/transformer_encoder.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
- package/dist/layers/transformer_encoder.js.map +1 -0
- package/dist/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
- package/dist/layers/transformer_encoder.test.js.map +1 -0
- package/dist/losses/dice.d.ts +30 -0
- package/dist/losses/dice.d.ts.map +1 -0
- package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
- package/dist/losses/dice.js.map +1 -0
- package/dist/losses/index.d.ts +2 -0
- package/dist/losses/index.d.ts.map +1 -0
- package/dist/losses/index.js +2 -0
- package/dist/losses/index.js.map +1 -0
- package/dist/masks.d.ts +20 -0
- package/dist/masks.d.ts.map +1 -0
- package/{src/packing_mask.ts → dist/masks.js} +16 -7
- package/dist/masks.js.map +1 -0
- package/dist/metrics.d.ts +20 -0
- package/dist/metrics.d.ts.map +1 -0
- package/{src/metrics.ts → dist/metrics.js} +8 -12
- package/dist/metrics.js.map +1 -0
- package/dist/models/gpt_model.d.ts +94 -0
- package/dist/models/gpt_model.d.ts.map +1 -0
- package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
- package/dist/models/gpt_model.js.map +1 -0
- package/dist/models/index.d.ts +7 -0
- package/dist/models/index.d.ts.map +1 -0
- package/dist/models/index.js +13 -0
- package/dist/models/index.js.map +1 -0
- package/dist/models/llm_model.d.ts +87 -0
- package/dist/models/llm_model.d.ts.map +1 -0
- package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
- package/dist/models/llm_model.js.map +1 -0
- package/dist/models/u_net.d.ts +40 -0
- package/dist/models/u_net.d.ts.map +1 -0
- package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
- package/dist/models/u_net.js.map +1 -0
- package/dist/src/index.d.ts +6 -0
- package/dist/src/index.d.ts.map +1 -0
- package/dist/src/index.js +6 -0
- package/dist/src/index.js.map +1 -0
- package/dist/src/kv_cache.d.ts +53 -0
- package/dist/src/kv_cache.d.ts.map +1 -0
- package/dist/src/kv_cache.js +135 -0
- package/dist/src/kv_cache.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
- package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
- package/dist/src/layers/gpt_decoder_block.js +51 -0
- package/dist/src/layers/gpt_decoder_block.js.map +1 -0
- package/dist/src/layers/index.d.ts +17 -0
- package/dist/src/layers/index.d.ts.map +1 -0
- package/dist/src/layers/index.js +33 -0
- package/dist/src/layers/index.js.map +1 -0
- package/dist/src/layers/multihead_attention.d.ts +106 -0
- package/dist/src/layers/multihead_attention.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.js +269 -0
- package/dist/src/layers/multihead_attention.js.map +1 -0
- package/dist/src/layers/multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.test.js +160 -0
- package/dist/src/layers/multihead_attention.test.js.map +1 -0
- package/dist/src/layers/positional_encoding.d.ts +37 -0
- package/dist/src/layers/positional_encoding.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.js +115 -0
- package/dist/src/layers/positional_encoding.js.map +1 -0
- package/dist/src/layers/positional_encoding.test.d.ts +2 -0
- package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.test.js +95 -0
- package/dist/src/layers/positional_encoding.test.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.js +99 -0
- package/dist/src/layers/rotary_position_embedding.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.js +88 -0
- package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.js +109 -0
- package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
- package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/src/layers/transformer_decoder.d.ts +69 -0
- package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
- package/dist/src/layers/transformer_decoder.js.map +1 -0
- package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_decoder.test.js +72 -0
- package/dist/src/layers/transformer_decoder.test.js.map +1 -0
- package/dist/src/layers/transformer_encoder.d.ts +55 -0
- package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.js +175 -0
- package/dist/src/layers/transformer_encoder.js.map +1 -0
- package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.test.js +58 -0
- package/dist/src/layers/transformer_encoder.test.js.map +1 -0
- package/dist/src/losses/dice.d.ts +30 -0
- package/dist/src/losses/dice.d.ts.map +1 -0
- package/dist/src/losses/dice.js +93 -0
- package/dist/src/losses/dice.js.map +1 -0
- package/dist/src/losses/index.d.ts +2 -0
- package/dist/src/losses/index.d.ts.map +1 -0
- package/dist/src/losses/index.js +2 -0
- package/dist/src/losses/index.js.map +1 -0
- package/dist/src/masks.d.ts +20 -0
- package/dist/src/masks.d.ts.map +1 -0
- package/dist/src/masks.js +37 -0
- package/dist/src/masks.js.map +1 -0
- package/dist/src/metrics.d.ts +20 -0
- package/dist/src/metrics.d.ts.map +1 -0
- package/dist/src/metrics.js +28 -0
- package/dist/src/metrics.js.map +1 -0
- package/dist/src/models/gpt_model.d.ts +94 -0
- package/dist/src/models/gpt_model.d.ts.map +1 -0
- package/dist/src/models/gpt_model.js +154 -0
- package/dist/src/models/gpt_model.js.map +1 -0
- package/dist/src/models/index.d.ts +3 -0
- package/dist/src/models/index.d.ts.map +1 -0
- package/{src/models/index.ts → dist/src/models/index.js} +1 -0
- package/dist/src/models/index.js.map +1 -0
- package/dist/src/models/llm_model.d.ts +87 -0
- package/dist/src/models/llm_model.d.ts.map +1 -0
- package/dist/src/models/llm_model.js +245 -0
- package/dist/src/models/llm_model.js.map +1 -0
- package/dist/src/models/u_net.d.ts +40 -0
- package/dist/src/models/u_net.d.ts.map +1 -0
- package/dist/src/models/u_net.js +151 -0
- package/dist/src/models/u_net.js.map +1 -0
- package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
- package/dist/src/tfjs_types.d.ts.map +1 -0
- package/dist/src/tfjs_types.js +2 -0
- package/dist/src/tfjs_types.js.map +1 -0
- package/dist/src/utils.d.ts +28 -0
- package/dist/src/utils.d.ts.map +1 -0
- package/{src/utils.ts → dist/src/utils.js} +10 -33
- package/dist/src/utils.js.map +1 -0
- package/dist/src/utils.test.d.ts +2 -0
- package/dist/src/utils.test.d.ts.map +1 -0
- package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
- package/dist/src/utils.test.js.map +1 -0
- package/dist/tfjs_types.d.ts +10 -0
- package/dist/tfjs_types.d.ts.map +1 -0
- package/dist/tfjs_types.js +2 -0
- package/dist/tfjs_types.js.map +1 -0
- package/dist/utils.d.ts +28 -0
- package/dist/utils.d.ts.map +1 -0
- package/dist/utils.js +63 -0
- package/dist/utils.js.map +1 -0
- package/dist/utils.test.d.ts +2 -0
- package/dist/utils.test.d.ts.map +1 -0
- package/dist/utils.test.js +73 -0
- package/dist/utils.test.js.map +1 -0
- package/package.json +10 -4
- package/src/index.ts +0 -93
- package/src/layers/rotary_position_embedding.test.ts +0 -107
- package/src/losses/index.ts +0 -1
- package/src/testing.ts +0 -1
- package/tsconfig.json +0 -49
|
@@ -1,86 +1,58 @@
|
|
|
1
1
|
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import {
|
|
3
|
-
import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
|
|
4
|
-
import { generateCausalAttentionMask } from "@/utils";
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
export interface MultiHeadAttentionArgs extends LayerArgs {
|
|
8
|
-
numHeads: number;
|
|
9
|
-
embedDim: number;
|
|
10
|
-
useBias?: boolean;
|
|
11
|
-
dropout?: number;
|
|
12
|
-
causal?: boolean;
|
|
13
|
-
}
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
export interface ScaledDotProductionAttentionKwargs {
|
|
17
|
-
training?: boolean;
|
|
18
|
-
dropout?: number;
|
|
19
|
-
causal?: boolean;
|
|
20
|
-
scaling_factor?: number;
|
|
21
|
-
}
|
|
22
|
-
|
|
23
|
-
|
|
2
|
+
import { causal as generateCausalMask } from "../masks";
|
|
24
3
|
/**
|
|
25
4
|
* This MultiHead Attention layer implements the algorithm as described in
|
|
26
5
|
* the paper "Attention is all you Need" Vaswani et al., 2017.
|
|
27
|
-
*
|
|
6
|
+
*
|
|
28
7
|
* @param numHeads number of attention heads to use
|
|
29
8
|
* @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
|
|
30
9
|
* @param causal use causal masking, default `false`
|
|
31
10
|
* @param dropout use dropout during the attention calculations, default `0.0`
|
|
32
11
|
* @param useBias use bias for the dense sublayers, default `true`
|
|
33
|
-
*
|
|
12
|
+
*
|
|
34
13
|
* The TensorFlow version uses tf.einsum, whose gradient op has not yet been
|
|
35
14
|
* implemented (https://github.com/tensorflow/tfjs/pull/4955#discussion_r619219334),
|
|
36
15
|
* therefore we follow the PyTorch implementation described in:
|
|
37
16
|
* https://docs.pytorch.org/tutorials/intermediate/transformer_building_blocks.html#multiheadattention
|
|
38
17
|
* https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
|
39
|
-
*
|
|
18
|
+
*
|
|
40
19
|
* This implementation is different from TensorFlow's whose attention weights
|
|
41
20
|
* are shaped [embed dim, heads, embed dim] where as PyTorch and OpenAI's attention weights
|
|
42
21
|
* are shaped [embed dim, embed dim]
|
|
43
22
|
* https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/nn/modules/activation.py#L1080
|
|
44
23
|
* https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/model.py#L53
|
|
45
|
-
*
|
|
24
|
+
*
|
|
46
25
|
* TODO: implement a fast track for self attention (query = key = value)
|
|
47
26
|
* where a single dense layer combines and replaces the query, key and projection layers
|
|
48
|
-
*
|
|
27
|
+
*
|
|
49
28
|
* TODO: add kDim and vDim to accept key and values whose embedding dimensions differ from query's.
|
|
50
29
|
*/
|
|
51
30
|
export class MultiHeadAttention extends tf.layers.Layer {
|
|
52
31
|
static className = "MultiHeadAttention";
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
32
|
+
numHeads;
|
|
33
|
+
embedDim; // size of embedding dim of inputs, also per attention head
|
|
34
|
+
useBias;
|
|
35
|
+
dropout;
|
|
36
|
+
causal; // use causal attention to mask future words
|
|
59
37
|
// projection simply means matrix multiplying query, key, and value
|
|
60
38
|
// with weights to create a representation of the inputs
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
constructor({ numHeads, embedDim, useBias = true, dropout = 0.0, causal = false, ...args }: MultiHeadAttentionArgs) {
|
|
39
|
+
queryProjection;
|
|
40
|
+
keyProjection;
|
|
41
|
+
valueProjection;
|
|
42
|
+
outputProjection;
|
|
43
|
+
constructor({ numHeads, embedDim, useBias = true, dropout = 0.0, causal = false, ...args }) {
|
|
68
44
|
super(args);
|
|
69
|
-
|
|
70
45
|
if (embedDim % numHeads != 0) {
|
|
71
46
|
throw Error(`${this.getClassName()}::constructor ${this.name} embedDim (${embedDim}) is not divisible by numHeads (${numHeads})`);
|
|
72
47
|
}
|
|
73
|
-
|
|
74
48
|
this.numHeads = numHeads;
|
|
75
49
|
this.embedDim = embedDim;
|
|
76
50
|
this.useBias = useBias;
|
|
77
51
|
this.dropout = dropout;
|
|
78
52
|
this.causal = causal;
|
|
79
|
-
|
|
80
53
|
if (this.dropout >= 1) {
|
|
81
54
|
throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
|
|
82
55
|
}
|
|
83
|
-
|
|
84
56
|
// intialize the projection weights, this should be in the
|
|
85
57
|
// build() function but is done here to avoid linting complaints
|
|
86
58
|
this.queryProjection = tf.layers.dense({ useBias, units: embedDim });
|
|
@@ -88,188 +60,134 @@ export class MultiHeadAttention extends tf.layers.Layer {
|
|
|
88
60
|
this.valueProjection = tf.layers.dense({ useBias, units: embedDim });
|
|
89
61
|
this.outputProjection = tf.layers.dense({ useBias, units: embedDim });
|
|
90
62
|
}
|
|
91
|
-
|
|
92
|
-
|
|
93
63
|
/**
|
|
94
64
|
* Forward propagation. Provide one input tensor or three identical tensors to self-attention.
|
|
95
65
|
* @param inputs a single tensor for self-attention or an array of exactly three
|
|
96
66
|
* tensors that are either identical (self-attention) or different (cross-attention)
|
|
97
67
|
* @param kwargs.packingMask a mask to prevent tokens from attending across document boundaries
|
|
98
68
|
*/
|
|
99
|
-
|
|
100
|
-
inputs: tf.Tensor | tf.Tensor[],
|
|
101
|
-
kwargs: Kwargs & {
|
|
102
|
-
packingMask?: tf.Tensor,
|
|
103
|
-
causalMask?: tf.Tensor,
|
|
104
|
-
}
|
|
105
|
-
): tf.Tensor | tf.Tensor[] {
|
|
69
|
+
call(inputs, kwargs) {
|
|
106
70
|
// validate the input tensors
|
|
107
71
|
if (!Array.isArray(inputs)) {
|
|
108
72
|
inputs = [inputs];
|
|
109
73
|
}
|
|
110
|
-
|
|
111
74
|
// accept only 1 input (self attention) or 3 inputs (self or cross attention)
|
|
112
75
|
if (inputs.length != 1 && inputs.length != 3) {
|
|
113
76
|
throw Error(`${this.getClassName()}::call ${this.name} expects exactly one or three input tensors, ${inputs.length} were provided`);
|
|
114
77
|
}
|
|
115
|
-
|
|
116
78
|
for (const input of inputs) {
|
|
117
79
|
if (input.shape.length != 3) {
|
|
118
80
|
throw Error(`${this.getClassName()}::call ${this.name} expected input shapes of [batch, seq, embed_dim], got ${JSON.stringify(input.shape)}`);
|
|
119
81
|
}
|
|
120
82
|
}
|
|
121
|
-
|
|
122
83
|
const [query, key, value] = inputs;
|
|
123
84
|
const packingMask = kwargs.packingMask ?? null;
|
|
124
85
|
const causalMask = kwargs.causalMask ?? null;
|
|
125
|
-
|
|
126
86
|
return inputs.length == 3
|
|
127
87
|
// cross-attention
|
|
128
|
-
? this.forward(query
|
|
88
|
+
? this.forward(query, key, value, packingMask, causalMask, kwargs)
|
|
129
89
|
// self-attention
|
|
130
|
-
: this.forward(query
|
|
90
|
+
: this.forward(query, query, query, packingMask, causalMask, kwargs);
|
|
131
91
|
}
|
|
132
|
-
|
|
133
|
-
|
|
134
92
|
/**
|
|
135
93
|
* Forward propagation
|
|
136
94
|
*/
|
|
137
|
-
|
|
138
|
-
query_input: tf.Tensor,
|
|
139
|
-
key_input: tf.Tensor,
|
|
140
|
-
value_input: tf.Tensor,
|
|
141
|
-
packing_mask: tf.Tensor | null,
|
|
142
|
-
causal_mask: tf.Tensor | null,
|
|
143
|
-
kwargs: Kwargs): tf.Tensor {
|
|
144
|
-
|
|
95
|
+
forward(query_input, key_input, value_input, packing_mask, causal_mask, kwargs) {
|
|
145
96
|
// dimensions abbreviations
|
|
146
97
|
// batch = the number of sequences in the input
|
|
147
98
|
// seq = the length of each sequence in the input
|
|
148
99
|
// dims = the size of each token's embedding
|
|
149
100
|
return tf.tidy(() => {
|
|
150
101
|
const { query, key, value } = this.applyInputProjections(query_input, key_input, value_input);
|
|
151
|
-
|
|
152
102
|
// swap the seq and heads dimensions: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
|
|
153
103
|
const move_head_dim_forward = [0, 2, 1, 3];
|
|
154
|
-
|
|
155
|
-
const {
|
|
156
|
-
query_split, key_split, value_split
|
|
157
|
-
} = this.splitHeads(query, key, value, move_head_dim_forward);
|
|
158
|
-
|
|
104
|
+
const { query_split, key_split, value_split } = this.splitHeads(query, key, value, move_head_dim_forward);
|
|
159
105
|
// apply scaled dot production attention to get [batch, seq, numHeads, embedDim]
|
|
160
|
-
const spda = MultiHeadAttention.scaledDotProductionAttention(
|
|
161
|
-
query_split, key_split, value_split,
|
|
162
|
-
kwargs.attentionMask ?? null, packing_mask, causal_mask,
|
|
163
|
-
this.dropout, this.causal, kwargs);
|
|
164
|
-
|
|
106
|
+
const spda = MultiHeadAttention.scaledDotProductionAttention(query_split, key_split, value_split, kwargs.attentionMask ?? null, packing_mask, causal_mask, this.dropout, this.causal, kwargs);
|
|
165
107
|
// concat heads and apply the output projection
|
|
166
|
-
const output = this.outputProjection.apply(
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
return output as tf.Tensor;
|
|
170
|
-
})
|
|
108
|
+
const output = this.outputProjection.apply(spda.transpose(move_head_dim_forward).reshape([query_input.shape[0], -1, this.embedDim]));
|
|
109
|
+
return output;
|
|
110
|
+
});
|
|
171
111
|
}
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
protected applyInputProjections(query_input: tf.Tensor, key_input: tf.Tensor, value_input: tf.Tensor) {
|
|
112
|
+
applyInputProjections(query_input, key_input, value_input) {
|
|
175
113
|
// apply input projections, this is a batched matrix multiplication operated on the last
|
|
176
114
|
// dimension of query_input and first dimension of the dense layer weights,
|
|
177
115
|
// [batch, seq, dims] x [dims, dims] = [batch x seq, dims] x [dims, dims] = [batch x seq, dims] = [batch, seq, dims]
|
|
178
116
|
return tf.tidy(() => {
|
|
179
117
|
return {
|
|
180
|
-
query: this.queryProjection.apply(query_input)
|
|
181
|
-
key: this.keyProjection.apply(key_input)
|
|
182
|
-
value: this.valueProjection.apply(value_input)
|
|
183
|
-
}
|
|
184
|
-
})
|
|
118
|
+
query: this.queryProjection.apply(query_input),
|
|
119
|
+
key: this.keyProjection.apply(key_input),
|
|
120
|
+
value: this.valueProjection.apply(value_input)
|
|
121
|
+
};
|
|
122
|
+
});
|
|
185
123
|
}
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
protected splitHeads(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, shuffle: number[]) {
|
|
124
|
+
splitHeads(query, key, value, shuffle) {
|
|
189
125
|
// split heads and prepare for scaled dot product attention by splitting the
|
|
190
126
|
// last dimension to get the heads, bring the heads forward
|
|
191
127
|
// [batch, seq, dims] -> [batch, seq, heads, dims / heads] -> [batch, heads, seq, head_dim]
|
|
192
128
|
const batch_size = query.shape[0];
|
|
193
129
|
const split_heads = [batch_size, -1, this.numHeads, this.embedDim / this.numHeads];
|
|
194
|
-
|
|
195
130
|
return tf.tidy(() => {
|
|
196
131
|
return {
|
|
197
|
-
query_split: query.reshape(split_heads).transpose(shuffle)
|
|
198
|
-
key_split: key.reshape(split_heads).transpose(shuffle)
|
|
199
|
-
value_split: value.reshape(split_heads).transpose(shuffle)
|
|
200
|
-
}
|
|
201
|
-
})
|
|
132
|
+
query_split: query.reshape(split_heads).transpose(shuffle),
|
|
133
|
+
key_split: key.reshape(split_heads).transpose(shuffle),
|
|
134
|
+
value_split: value.reshape(split_heads).transpose(shuffle)
|
|
135
|
+
};
|
|
136
|
+
});
|
|
202
137
|
}
|
|
203
|
-
|
|
204
|
-
|
|
205
138
|
/**
|
|
206
139
|
* Applies the scaled dot-product formula: softmax(QK_t / sqrt(d_k))V,
|
|
207
140
|
* formula (1) of the 2017 paper Attention Is All You Need
|
|
208
|
-
*
|
|
209
|
-
* @param attentionMask a mask to prevent tokens from being
|
|
141
|
+
*
|
|
142
|
+
* @param attentionMask a mask to prevent tokens from being
|
|
210
143
|
* attended to (usually for padding tokens). It should have the shape
|
|
211
144
|
* [batch, head, query_sequence_len, key_sequence_len]. To use in
|
|
212
145
|
* conjunction with causal masking, the tensor should be a boolean type
|
|
213
146
|
* where false indicates a masked token.
|
|
214
147
|
* @param packingMask a mask to prevent tokens from attending across document boundaries
|
|
215
148
|
*/
|
|
216
|
-
static scaledDotProductionAttention(
|
|
217
|
-
query: tf.Tensor,
|
|
218
|
-
key: tf.Tensor,
|
|
219
|
-
value: tf.Tensor,
|
|
220
|
-
attentionMask: tf.Tensor | null,
|
|
221
|
-
packingMask: tf.Tensor | null,
|
|
222
|
-
causalMask: tf.Tensor | null,
|
|
223
|
-
dropout: number,
|
|
224
|
-
causal: boolean,
|
|
225
|
-
kwargs: ScaledDotProductionAttentionKwargs = {}
|
|
226
|
-
): tf.Tensor {
|
|
149
|
+
static scaledDotProductionAttention(query, key, value, attentionMask, packingMask, causalMask, dropout, causal, kwargs = {}) {
|
|
227
150
|
return tf.tidy(() => {
|
|
228
151
|
const { training = false, scaling_factor } = kwargs;
|
|
229
|
-
|
|
230
152
|
key.shape.forEach((val, index) => {
|
|
231
153
|
if (key.shape[index] != value.shape[index]) {
|
|
232
154
|
throw Error(`scaledDotProductionAttention: expected key and value` +
|
|
233
155
|
` to have the same shape, got ${JSON.stringify(key.shape)} (key) and` +
|
|
234
156
|
` ${JSON.stringify(value.shape)} (value)`);
|
|
235
157
|
}
|
|
236
|
-
})
|
|
237
|
-
|
|
238
|
-
|
|
158
|
+
});
|
|
239
159
|
// mask's shape is [..., seq, seq] where seq is the number of words/tokens in the input,
|
|
240
160
|
// not adding the batch dimension yet to lessen the calculations
|
|
241
161
|
const causal_mask_shape = [
|
|
242
162
|
query.shape[query.shape.length - 2],
|
|
243
|
-
key.shape[key.shape.length - 2]
|
|
244
|
-
|
|
163
|
+
key.shape[key.shape.length - 2]
|
|
164
|
+
];
|
|
245
165
|
let mask = tf.zeros(causal_mask_shape);
|
|
246
|
-
|
|
247
166
|
if (causal && causal_mask_shape[0] > 1) {
|
|
248
167
|
if (attentionMask && attentionMask.dtype != "bool") {
|
|
249
168
|
throw Error(`scaledDotProductionAttention: the attention mask must be undefined or a boolean type if used with causal attention`);
|
|
250
169
|
}
|
|
251
|
-
|
|
252
170
|
// apply a causal attention mask so that tokens can only attend to preceding tokens,
|
|
253
171
|
// prevents looking at head
|
|
254
172
|
if (causalMask) {
|
|
255
173
|
mask = causalMask;
|
|
256
|
-
}
|
|
257
|
-
|
|
174
|
+
}
|
|
175
|
+
else {
|
|
176
|
+
mask = generateCausalMask(causal_mask_shape[0], causal_mask_shape[1]);
|
|
258
177
|
}
|
|
259
178
|
}
|
|
260
|
-
|
|
261
179
|
if (attentionMask) {
|
|
262
180
|
if (attentionMask.dtype == "bool") {
|
|
263
181
|
// convert the boolean mask to float
|
|
264
182
|
// warning: do not use 1e9, it will overflow, use something smaller like 1e7
|
|
265
183
|
mask = mask.add(attentionMask.cast("float32").sub(1).mul(1e7));
|
|
266
|
-
}
|
|
184
|
+
}
|
|
185
|
+
else {
|
|
267
186
|
// this will occur only when not using causal masking,
|
|
268
187
|
// if the attention mask is not boolean, it's assumed the masking is already calculated,
|
|
269
188
|
mask = attentionMask;
|
|
270
189
|
}
|
|
271
190
|
}
|
|
272
|
-
|
|
273
191
|
// 1. matrix multiply query and transposed key
|
|
274
192
|
// 2. divide by scaling factor
|
|
275
193
|
// 3. apply softmax to the result
|
|
@@ -280,42 +198,33 @@ export class MultiHeadAttention extends tf.layers.Layer {
|
|
|
280
198
|
.matMul(key, false, true)
|
|
281
199
|
.div(Math.sqrt(scaling_factor ?? key.shape[key.shape.length - 1]))
|
|
282
200
|
.add(mask);
|
|
283
|
-
|
|
284
201
|
if (packingMask) {
|
|
285
202
|
// packing mask is added separately because each mask within a batch may be different,
|
|
286
203
|
// so it cannot be broadcasted
|
|
287
204
|
pre_softmax = pre_softmax.add(packingMask);
|
|
288
205
|
}
|
|
289
|
-
|
|
290
206
|
const spda = tf.softmax(pre_softmax);
|
|
291
|
-
|
|
292
207
|
const spda_dropout = tf.dropout(spda, training ? dropout : 0);
|
|
293
208
|
const attention = spda_dropout.matMul(value);
|
|
294
|
-
|
|
295
209
|
return attention;
|
|
296
210
|
});
|
|
297
211
|
}
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
override build(inputShape: tf.Shape | tf.Shape[]): void {
|
|
301
|
-
let input_shape: tf.Shape[] = [];
|
|
302
|
-
|
|
212
|
+
build(inputShape) {
|
|
213
|
+
let input_shape = [];
|
|
303
214
|
if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
|
|
304
|
-
input_shape = inputShape
|
|
305
|
-
}
|
|
306
|
-
|
|
215
|
+
input_shape = inputShape;
|
|
216
|
+
}
|
|
217
|
+
else {
|
|
218
|
+
input_shape = [inputShape, inputShape, inputShape];
|
|
307
219
|
}
|
|
308
|
-
|
|
309
220
|
if (input_shape.length != 1 && input_shape.length != 3) {
|
|
310
221
|
throw Error(`${this.getClassName()}::build ${this.name} accepts either exactly one or three inputs, received ${JSON.stringify(inputShape)}`);
|
|
311
222
|
}
|
|
312
|
-
|
|
313
223
|
// initialize the sublayer weights
|
|
314
224
|
this.queryProjection.build(input_shape[0]);
|
|
315
225
|
this.keyProjection.build(input_shape[1]);
|
|
316
226
|
this.valueProjection.build(input_shape[2]);
|
|
317
227
|
this.outputProjection.build(input_shape[0]);
|
|
318
|
-
|
|
319
228
|
// the sublayer weights need to be tracked by this layer otherwise
|
|
320
229
|
// backpropagation will complain about no trainable parameters found,
|
|
321
230
|
// this is an extra step that TF's Python version does not need
|
|
@@ -325,33 +234,25 @@ export class MultiHeadAttention extends tf.layers.Layer {
|
|
|
325
234
|
...this.valueProjection.trainableWeights,
|
|
326
235
|
...this.outputProjection.trainableWeights
|
|
327
236
|
];
|
|
328
|
-
|
|
329
237
|
// rename the weights otherwise they'll take on the default naming and overlap
|
|
330
238
|
// each other which breaks model loading due to duplicate weight names
|
|
331
239
|
let indexing = 0;
|
|
332
|
-
|
|
333
240
|
for (const weight of this.trainableWeights) {
|
|
334
241
|
const unique_name = `${this.getClassName()}_${indexing}`;
|
|
335
|
-
|
|
336
|
-
|
|
242
|
+
weight.name += unique_name;
|
|
243
|
+
weight.originalName += unique_name;
|
|
337
244
|
indexing++;
|
|
338
245
|
}
|
|
339
|
-
|
|
340
246
|
super.build(inputShape);
|
|
341
247
|
}
|
|
342
|
-
|
|
343
|
-
|
|
344
248
|
/**
|
|
345
249
|
* MultiHead attention's output is the same shape the query's.
|
|
346
250
|
*/
|
|
347
|
-
|
|
251
|
+
computeOutputShape(inputShape) {
|
|
348
252
|
return Array.isArray(inputShape) && Array.isArray(inputShape[0]) ? inputShape[0] : inputShape;
|
|
349
253
|
}
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
override getConfig() {
|
|
254
|
+
getConfig() {
|
|
353
255
|
const base_config = super.getConfig();
|
|
354
|
-
|
|
355
256
|
const config = {
|
|
356
257
|
numHeads: this.numHeads,
|
|
357
258
|
embedDim: this.embedDim,
|
|
@@ -359,13 +260,10 @@ export class MultiHeadAttention extends tf.layers.Layer {
|
|
|
359
260
|
causal: this.causal,
|
|
360
261
|
dropout: this.dropout,
|
|
361
262
|
name: this.name,
|
|
362
|
-
}
|
|
363
|
-
|
|
263
|
+
};
|
|
364
264
|
Object.assign(config, base_config);
|
|
365
|
-
|
|
366
265
|
return config;
|
|
367
266
|
}
|
|
368
267
|
}
|
|
369
|
-
|
|
370
|
-
|
|
371
268
|
tf.serialization.registerClass(MultiHeadAttention);
|
|
269
|
+
//# sourceMappingURL=multihead_attention.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"multihead_attention.js","sourceRoot":"","sources":["../../src/layers/multihead_attention.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC,OAAO,EAAE,MAAM,IAAI,kBAAkB,EAAE,MAAM,UAAU,CAAC;AAoBxD;;;;;;;;;;;;;;;;;;;;;;;;;;GA0BG;AACH,MAAM,OAAO,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,GAAG,oBAAoB,CAAC;IACrB,QAAQ,CAAS;IACjB,QAAQ,CAAS,CAAC,2DAA2D;IAC7E,OAAO,CAAU;IACjB,OAAO,CAAS;IAChB,MAAM,CAAU,CAAC,4CAA4C;IAEhF,mEAAmE;IACnE,wDAAwD;IACrC,eAAe,CAAkB;IACjC,aAAa,CAAkB;IAC/B,eAAe,CAAkB;IACjC,gBAAgB,CAAkB;IAGrD,YAAY,EAAE,QAAQ,EAAE,QAAQ,EAAE,OAAO,GAAG,IAAI,EAAE,OAAO,GAAG,GAAG,EAAE,MAAM,GAAG,KAAK,EAAE,GAAG,IAAI,EAA0B;QAC9G,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,QAAQ,GAAG,QAAQ,IAAI,CAAC,EAAE,CAAC;YAC3B,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,cAAc,QAAQ,mCAAmC,QAAQ,GAAG,CAAC,CAAC;QACtI,CAAC;QAED,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,OAAO,GAAG,OAAO,CAAC;QACvB,IAAI,CAAC,OAAO,GAAG,OAAO,CAAC;QACvB,IAAI,CAAC,MAAM,GAAG,MAAM,CAAC;QAErB,IAAI,IAAI,CAAC,OAAO,IAAI,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,6CAA6C,CAAC,CAAC;QACrF,CAAC;QAED,0DAA0D;QAC1D,gEAAgE;QAChE,IAAI,CAAC,eAAe,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,QAAQ,EAAE,CAAC,CAAC;QACrE,IAAI,CAAC,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,QAAQ,EAAE,CAAC,CAAC;QACnE,IAAI,CAAC,eAAe,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,QAAQ,EAAE,CAAC,CAAC;QACrE,IAAI,CAAC,gBAAgB,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,QAAQ,EAAE,CAAC,CAAC;IAC1E,CAAC;IAGD;;;;;OAKG;IACM,IAAI,CACT,MAA+B,EAC/B,MAGC;QAED,6BAA6B;QAC7B,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC;YACzB,MAAM,GAAG,CAAC,MAAM,CAAC,CAAC;QACtB,CAAC;QAED,6EAA6E;QAC7E,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAC3C,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,gDAAgD,MAAM,CAAC,MAAM,gBAAgB,CAAC,CAAC;QACxI,CAAC;QAED,KAAK,MAAM,KAAK,IAAI,MAAM,EAAE,CAAC;YACzB,IAAI,KAAK,CAAC,KAAK,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;gBAC1B,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,0DAA0D,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;YAClJ,CAAC;QACL,CAAC;QAED,MAAM,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,CAAC,GAAG,MAAM,CAAC;QACnC,MAAM,WAAW,GAAG,MAAM,CAAC,WAAW,IAAI,IAAI,CAAC;QAC/C,MAAM,UAAU,GAAG,MAAM,CAAC,UAAU,IAAI,IAAI,CAAC;QAE7C,OAAO,MAAM,CAAC,MAAM,IAAI,CAAC;YACrB,kBAAkB;YAClB,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,KAAM,EAAE,GAAI,EAAE,KAAM,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,CAAC;YACrE,iBAAiB;YACjB,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,KAAM,EAAE,KAAM,EAAE,KAAM,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,CAAC,CAAC;IAChF,CAAC;IAGD;;OAEG;IACO,OAAO,CACb,WAAsB,EACtB,SAAoB,EACpB,WAAsB,EACtB,YAA8B,EAC9B,WAA6B,EAC7B,MAAc;QAEd,2BAA2B;QAC3B,+CAA+C;QAC/C,iDAAiD;QACjD,4CAA4C;QAC5C,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,EAAE,KAAK,EAAE,GAAG,EAAE,KAAK,EAAE,GAAG,IAAI,CAAC,qBAAqB,CAAC,WAAW,EAAE,SAAS,EAAE,WAAW,CAAC,CAAC;YAE9F,oGAAoG;YACpG,MAAM,qBAAqB,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;YAE3C,MAAM,EACF,WAAW,EAAE,SAAS,EAAE,WAAW,EACtC,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,EAAE,GAAG,EAAE,KAAK,EAAE,qBAAqB,CAAC,CAAC;YAE9D,gFAAgF;YAChF,MAAM,IAAI,GAAG,kBAAkB,CAAC,4BAA4B,CACxD,WAAW,EAAE,SAAS,EAAE,WAAW,EACnC,MAAM,CAAC,aAAa,IAAI,IAAI,EAAE,YAAY,EAAE,WAAW,EACvD,IAAI,CAAC,OAAO,EAAE,IAAI,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAEvC,+CAA+C;YAC/C,MAAM,MAAM,GAAG,IAAI,CAAC,gBAAgB,CAAC,KAAK,CACtC,IAAI,CAAC,SAAS,CAAC,qBAAqB,CAAC,CAAC,OAAO,CAAC,CAAC,WAAW,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC;YAE9F,OAAO,MAAmB,CAAC;QAC/B,CAAC,CAAC,CAAA;IACN,CAAC;IAGS,qBAAqB,CAAC,WAAsB,EAAE,SAAoB,EAAE,WAAsB;QAChG,wFAAwF;QACxF,2EAA2E;QAC3E,oHAAoH;QACpH,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,OAAO;gBACH,KAAK,EAAE,IAAI,CAAC,eAAe,CAAC,KAAK,CAAC,WAAW,CAAc;gBAC3D,GAAG,EAAE,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,SAAS,CAAc;gBACrD,KAAK,EAAE,IAAI,CAAC,eAAe,CAAC,KAAK,CAAC,WAAW,CAAc;aAC9D,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGS,UAAU,CAAC,KAAgB,EAAE,GAAc,EAAE,KAAgB,EAAE,OAAiB;QACtF,4EAA4E;QAC5E,2DAA2D;QAC3D,2FAA2F;QAC3F,MAAM,UAAU,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAClC,MAAM,WAAW,GAAG,CAAC,UAAU,EAAE,CAAC,CAAC,EAAE,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC;QAEnF,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,OAAO;gBACH,WAAW,EAAE,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,SAAS,CAAC,OAAO,CAAgB;gBACzE,SAAS,EAAE,GAAG,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,SAAS,CAAC,OAAO,CAAgB;gBACrE,WAAW,EAAE,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,SAAS,CAAC,OAAO,CAAgB;aAC5E,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGD;;;;;;;;;;OAUG;IACH,MAAM,CAAC,4BAA4B,CAC/B,KAAgB,EAChB,GAAc,EACd,KAAgB,EAChB,aAA+B,EAC/B,WAA6B,EAC7B,UAA4B,EAC5B,OAAe,EACf,MAAe,EACf,SAA6C,EAAE;QAE/C,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,EAAE,QAAQ,GAAG,KAAK,EAAE,cAAc,EAAE,GAAG,MAAM,CAAC;YAEpD,GAAG,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,KAAK,EAAE,EAAE;gBAC7B,IAAI,GAAG,CAAC,KAAK,CAAC,KAAK,CAAC,IAAI,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC;oBACzC,MAAM,KAAK,CAAC,sDAAsD;wBAC9D,gCAAgC,IAAI,CAAC,SAAS,CAAC,GAAG,CAAC,KAAK,CAAC,YAAY;wBACrE,IAAI,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;gBACnD,CAAC;YACL,CAAC,CAAC,CAAA;YAGF,wFAAwF;YACxF,gEAAgE;YAChE,MAAM,iBAAiB,GAAG;gBACtB,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;gBACnC,GAAG,CAAC,KAAK,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;aAAC,CAAC;YAErC,IAAI,IAAI,GAAG,EAAE,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC;YAEvC,IAAI,MAAM,IAAI,iBAAiB,CAAC,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC;gBACrC,IAAI,aAAa,IAAI,aAAa,CAAC,KAAK,IAAI,MAAM,EAAE,CAAC;oBACjD,MAAM,KAAK,CAAC,oHAAoH,CAAC,CAAC;gBACtI,CAAC;gBAED,oFAAoF;gBACpF,2BAA2B;gBAC3B,IAAI,UAAU,EAAE,CAAC;oBACb,IAAI,GAAG,UAAU,CAAC;gBACtB,CAAC;qBAAM,CAAC;oBACJ,IAAI,GAAG,kBAAkB,CAAC,iBAAiB,CAAC,CAAC,CAAC,EAAE,iBAAiB,CAAC,CAAC,CAAC,CAAC,CAAC;gBAC1E,CAAC;YACL,CAAC;YAED,IAAI,aAAa,EAAE,CAAC;gBAChB,IAAI,aAAa,CAAC,KAAK,IAAI,MAAM,EAAE,CAAC;oBAChC,oCAAoC;oBACpC,4EAA4E;oBAC5E,IAAI,GAAG,IAAI,CAAC,GAAG,CAAC,aAAa,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;gBACnE,CAAC;qBAAM,CAAC;oBACJ,sDAAsD;oBACtD,wFAAwF;oBACxF,IAAI,GAAG,aAAa,CAAC;gBACzB,CAAC;YACL,CAAC;YAED,8CAA8C;YAC9C,8BAA8B;YAC9B,iCAAiC;YACjC,wCAAwC;YACxC,mBAAmB;YACnB,+CAA+C;YAC/C,IAAI,WAAW,GAAG,KAAK;iBAClB,MAAM,CAAC,GAAG,EAAE,KAAK,EAAE,IAAI,CAAC;iBACxB,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,cAAc,IAAI,GAAG,CAAC,KAAK,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC;iBACjE,GAAG,CAAC,IAAI,CAAC,CAAC;YAEf,IAAI,WAAW,EAAE,CAAC;gBACd,sFAAsF;gBACtF,8BAA8B;gBAC9B,WAAW,GAAG,WAAW,CAAC,GAAG,CAAC,WAAW,CAAC,CAAC;YAC/C,CAAC;YAED,MAAM,IAAI,GAAG,EAAE,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC;YAErC,MAAM,YAAY,GAAG,EAAE,CAAC,OAAO,CAAC,IAAI,EAAE,QAAQ,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAC9D,MAAM,SAAS,GAAG,YAAY,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;YAE7C,OAAO,SAAS,CAAC;QACrB,CAAC,CAAC,CAAC;IACP,CAAC;IAGQ,KAAK,CAAC,UAAiC;QAC5C,IAAI,WAAW,GAAe,EAAE,CAAC;QAEjC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAC5D,WAAW,GAAG,UAAwB,CAAC;QAC3C,CAAC;aAAM,CAAC;YACJ,WAAW,GAAG,CAAC,UAAsB,EAAE,UAAsB,EAAE,UAAsB,CAAC,CAAC;QAC3F,CAAC;QAED,IAAI,WAAW,CAAC,MAAM,IAAI,CAAC,IAAI,WAAW,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YACrD,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,yDAAyD,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAC;QACjJ,CAAC;QAED,kCAAkC;QAClC,IAAI,CAAC,eAAe,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3C,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QACzC,IAAI,CAAC,eAAe,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3C,IAAI,CAAC,gBAAgB,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAE5C,kEAAkE;QAClE,qEAAqE;QACrE,+DAA+D;QAC/D,IAAI,CAAC,gBAAgB,GAAG;YACpB,GAAG,IAAI,CAAC,eAAe,CAAC,gBAAgB;YACxC,GAAG,IAAI,CAAC,aAAa,CAAC,gBAAgB;YACtC,GAAG,IAAI,CAAC,eAAe,CAAC,gBAAgB;YACxC,GAAG,IAAI,CAAC,gBAAgB,CAAC,gBAAgB;SAC5C,CAAC;QAEF,8EAA8E;QAC9E,sEAAsE;QACtE,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEjB,KAAK,MAAM,MAAM,IAAI,IAAI,CAAC,gBAAgB,EAAE,CAAC;YACzC,MAAM,WAAW,GAAG,GAAG,IAAI,CAAC,YAAY,EAAE,IAAI,QAAQ,EAAE,CAAC;YACxD,MAAc,CAAC,IAAI,IAAI,WAAW,CAAC;YACnC,MAAc,CAAC,YAAY,IAAI,WAAW,CAAC;YAC5C,QAAQ,EAAE,CAAC;QACf,CAAC;QAED,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACM,kBAAkB,CAAC,UAAiC;QACzD,OAAO,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC;IAClG,CAAC;IAGQ,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,MAAM,EAAE,IAAI,CAAC,MAAM;YACnB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,IAAI,EAAE,IAAI,CAAC,IAAI;SAClB,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"multihead_attention.test.d.ts","sourceRoot":"","sources":["../../src/layers/multihead_attention.test.ts"],"names":[],"mappings":""}
|