@stellarapp/tfjs-stellar 1.0.0 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +47 -0
- package/dist/index.d.ts +7 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +7 -0
- package/dist/index.js.map +1 -0
- package/dist/jest.config.d.ts +8 -0
- package/dist/jest.config.d.ts.map +1 -0
- package/{jest.config.ts → dist/jest.config.js} +8 -64
- package/dist/jest.config.js.map +1 -0
- package/dist/kv_cache.d.ts +53 -0
- package/dist/kv_cache.d.ts.map +1 -0
- package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
- package/dist/kv_cache.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.js +76 -0
- package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
- package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
- package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
- package/dist/layers/gpt_decoder_block.js.map +1 -0
- package/dist/layers/index.d.ts +17 -0
- package/dist/layers/index.d.ts.map +1 -0
- package/dist/layers/index.js +33 -0
- package/dist/layers/index.js.map +1 -0
- package/dist/layers/multihead_attention.d.ts +106 -0
- package/dist/layers/multihead_attention.d.ts.map +1 -0
- package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
- package/dist/layers/multihead_attention.js.map +1 -0
- package/dist/layers/multihead_attention.test.d.ts +2 -0
- package/dist/layers/multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
- package/dist/layers/multihead_attention.test.js.map +1 -0
- package/dist/layers/positional_encoding.d.ts +37 -0
- package/dist/layers/positional_encoding.d.ts.map +1 -0
- package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
- package/dist/layers/positional_encoding.js.map +1 -0
- package/dist/layers/positional_encoding.test.d.ts +2 -0
- package/dist/layers/positional_encoding.test.d.ts.map +1 -0
- package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
- package/dist/layers/positional_encoding.test.js.map +1 -0
- package/dist/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
- package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
- package/dist/layers/rotary_position_embedding.js.map +1 -0
- package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/layers/rotary_position_embedding.test.js +88 -0
- package/dist/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
- package/dist/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
- package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/layers/transformer_decoder.d.ts +69 -0
- package/dist/layers/transformer_decoder.d.ts.map +1 -0
- package/dist/layers/transformer_decoder.js +182 -0
- package/dist/layers/transformer_decoder.js.map +1 -0
- package/dist/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
- package/dist/layers/transformer_decoder.test.js.map +1 -0
- package/dist/layers/transformer_encoder.d.ts +55 -0
- package/dist/layers/transformer_encoder.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
- package/dist/layers/transformer_encoder.js.map +1 -0
- package/dist/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
- package/dist/layers/transformer_encoder.test.js.map +1 -0
- package/dist/losses/dice.d.ts +30 -0
- package/dist/losses/dice.d.ts.map +1 -0
- package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
- package/dist/losses/dice.js.map +1 -0
- package/dist/losses/index.d.ts +2 -0
- package/dist/losses/index.d.ts.map +1 -0
- package/dist/losses/index.js +2 -0
- package/dist/losses/index.js.map +1 -0
- package/dist/masks.d.ts +20 -0
- package/dist/masks.d.ts.map +1 -0
- package/{src/packing_mask.ts → dist/masks.js} +16 -7
- package/dist/masks.js.map +1 -0
- package/dist/metrics.d.ts +20 -0
- package/dist/metrics.d.ts.map +1 -0
- package/{src/metrics.ts → dist/metrics.js} +8 -12
- package/dist/metrics.js.map +1 -0
- package/dist/models/gpt_model.d.ts +94 -0
- package/dist/models/gpt_model.d.ts.map +1 -0
- package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
- package/dist/models/gpt_model.js.map +1 -0
- package/dist/models/index.d.ts +7 -0
- package/dist/models/index.d.ts.map +1 -0
- package/dist/models/index.js +13 -0
- package/dist/models/index.js.map +1 -0
- package/dist/models/llm_model.d.ts +87 -0
- package/dist/models/llm_model.d.ts.map +1 -0
- package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
- package/dist/models/llm_model.js.map +1 -0
- package/dist/models/u_net.d.ts +40 -0
- package/dist/models/u_net.d.ts.map +1 -0
- package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
- package/dist/models/u_net.js.map +1 -0
- package/dist/src/index.d.ts +6 -0
- package/dist/src/index.d.ts.map +1 -0
- package/dist/src/index.js +6 -0
- package/dist/src/index.js.map +1 -0
- package/dist/src/kv_cache.d.ts +53 -0
- package/dist/src/kv_cache.d.ts.map +1 -0
- package/dist/src/kv_cache.js +135 -0
- package/dist/src/kv_cache.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
- package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
- package/dist/src/layers/gpt_decoder_block.js +51 -0
- package/dist/src/layers/gpt_decoder_block.js.map +1 -0
- package/dist/src/layers/index.d.ts +17 -0
- package/dist/src/layers/index.d.ts.map +1 -0
- package/dist/src/layers/index.js +33 -0
- package/dist/src/layers/index.js.map +1 -0
- package/dist/src/layers/multihead_attention.d.ts +106 -0
- package/dist/src/layers/multihead_attention.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.js +269 -0
- package/dist/src/layers/multihead_attention.js.map +1 -0
- package/dist/src/layers/multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.test.js +160 -0
- package/dist/src/layers/multihead_attention.test.js.map +1 -0
- package/dist/src/layers/positional_encoding.d.ts +37 -0
- package/dist/src/layers/positional_encoding.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.js +115 -0
- package/dist/src/layers/positional_encoding.js.map +1 -0
- package/dist/src/layers/positional_encoding.test.d.ts +2 -0
- package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.test.js +95 -0
- package/dist/src/layers/positional_encoding.test.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.js +99 -0
- package/dist/src/layers/rotary_position_embedding.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.js +88 -0
- package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.js +109 -0
- package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
- package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/src/layers/transformer_decoder.d.ts +69 -0
- package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
- package/dist/src/layers/transformer_decoder.js.map +1 -0
- package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_decoder.test.js +72 -0
- package/dist/src/layers/transformer_decoder.test.js.map +1 -0
- package/dist/src/layers/transformer_encoder.d.ts +55 -0
- package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.js +175 -0
- package/dist/src/layers/transformer_encoder.js.map +1 -0
- package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.test.js +58 -0
- package/dist/src/layers/transformer_encoder.test.js.map +1 -0
- package/dist/src/losses/dice.d.ts +30 -0
- package/dist/src/losses/dice.d.ts.map +1 -0
- package/dist/src/losses/dice.js +93 -0
- package/dist/src/losses/dice.js.map +1 -0
- package/dist/src/losses/index.d.ts +2 -0
- package/dist/src/losses/index.d.ts.map +1 -0
- package/dist/src/losses/index.js +2 -0
- package/dist/src/losses/index.js.map +1 -0
- package/dist/src/masks.d.ts +20 -0
- package/dist/src/masks.d.ts.map +1 -0
- package/dist/src/masks.js +37 -0
- package/dist/src/masks.js.map +1 -0
- package/dist/src/metrics.d.ts +20 -0
- package/dist/src/metrics.d.ts.map +1 -0
- package/dist/src/metrics.js +28 -0
- package/dist/src/metrics.js.map +1 -0
- package/dist/src/models/gpt_model.d.ts +94 -0
- package/dist/src/models/gpt_model.d.ts.map +1 -0
- package/dist/src/models/gpt_model.js +154 -0
- package/dist/src/models/gpt_model.js.map +1 -0
- package/dist/src/models/index.d.ts +3 -0
- package/dist/src/models/index.d.ts.map +1 -0
- package/{src/models/index.ts → dist/src/models/index.js} +1 -0
- package/dist/src/models/index.js.map +1 -0
- package/dist/src/models/llm_model.d.ts +87 -0
- package/dist/src/models/llm_model.d.ts.map +1 -0
- package/dist/src/models/llm_model.js +245 -0
- package/dist/src/models/llm_model.js.map +1 -0
- package/dist/src/models/u_net.d.ts +40 -0
- package/dist/src/models/u_net.d.ts.map +1 -0
- package/dist/src/models/u_net.js +151 -0
- package/dist/src/models/u_net.js.map +1 -0
- package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
- package/dist/src/tfjs_types.d.ts.map +1 -0
- package/dist/src/tfjs_types.js +2 -0
- package/dist/src/tfjs_types.js.map +1 -0
- package/dist/src/utils.d.ts +28 -0
- package/dist/src/utils.d.ts.map +1 -0
- package/{src/utils.ts → dist/src/utils.js} +10 -33
- package/dist/src/utils.js.map +1 -0
- package/dist/src/utils.test.d.ts +2 -0
- package/dist/src/utils.test.d.ts.map +1 -0
- package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
- package/dist/src/utils.test.js.map +1 -0
- package/dist/tfjs_types.d.ts +10 -0
- package/dist/tfjs_types.d.ts.map +1 -0
- package/dist/tfjs_types.js +2 -0
- package/dist/tfjs_types.js.map +1 -0
- package/dist/utils.d.ts +28 -0
- package/dist/utils.d.ts.map +1 -0
- package/dist/utils.js +63 -0
- package/dist/utils.js.map +1 -0
- package/dist/utils.test.d.ts +2 -0
- package/dist/utils.test.d.ts.map +1 -0
- package/dist/utils.test.js +73 -0
- package/dist/utils.test.js.map +1 -0
- package/package.json +10 -4
- package/src/index.ts +0 -93
- package/src/layers/rotary_position_embedding.test.ts +0 -107
- package/src/losses/index.ts +0 -1
- package/src/testing.ts +0 -1
- package/tsconfig.json +0 -49
|
@@ -1,145 +1,73 @@
|
|
|
1
1
|
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
export interface UNetArgs {
|
|
6
|
-
/**
|
|
7
|
-
* The starting number of filters.
|
|
8
|
-
*/
|
|
9
|
-
filters: number;
|
|
10
|
-
/**
|
|
11
|
-
* The number of categories. For binary segmentation, `units=1`.
|
|
12
|
-
*/
|
|
13
|
-
units: number;
|
|
14
|
-
/**
|
|
15
|
-
* The activation of the final output convolution layer. Defaults to `sigmoid` if `categories=1`, else `softmax`.
|
|
16
|
-
*/
|
|
17
|
-
activation?: ActivationIdentifier;
|
|
18
|
-
/**
|
|
19
|
-
* The depth of the U-Net or the number of contractions and the number of expansions.
|
|
20
|
-
*/
|
|
21
|
-
depth: number;
|
|
22
|
-
/**
|
|
23
|
-
* Adds residual connections to transform the model into a ResUNet. Defaults to `false`.
|
|
24
|
-
*/
|
|
25
|
-
residual?: boolean;
|
|
26
|
-
/**
|
|
27
|
-
* Adds batch normalization to convolutions. Best used for batched inputs. Defaults to `false`.
|
|
28
|
-
*/
|
|
29
|
-
batchNorm?: boolean;
|
|
30
|
-
/**
|
|
31
|
-
* Set the unbatched input shape of the U-Net in the format `[height, width, channels]`. Defaults to `[null, null, 3]`. If set, only channels is mandatory.
|
|
32
|
-
*/
|
|
33
|
-
inputShape?: [number | null, number | null, number];
|
|
34
|
-
}
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
export type UNetModelArgs = UNetArgs & Omit<tf.SequentialArgs, "layers">;
|
|
38
|
-
|
|
39
|
-
|
|
40
2
|
export class UNetModel extends tf.Sequential {
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
const {
|
|
44
|
-
filters,
|
|
45
|
-
units,
|
|
46
|
-
activation = units == 1 ? "sigmoid" : "softmax",
|
|
47
|
-
depth,
|
|
48
|
-
residual = false,
|
|
49
|
-
batchNorm = false,
|
|
50
|
-
inputShape = [null, null, 3],
|
|
51
|
-
...sequentialArgs
|
|
52
|
-
} = args;
|
|
53
|
-
|
|
3
|
+
constructor(args) {
|
|
4
|
+
const { filters, units, activation = units == 1 ? "sigmoid" : "softmax", depth, residual = false, batchNorm = false, inputShape = [null, null, 3], ...sequentialArgs } = args;
|
|
54
5
|
sequentialArgs.name = sequentialArgs.name ?? "unet_model";
|
|
55
|
-
|
|
56
6
|
super({
|
|
57
7
|
...sequentialArgs,
|
|
58
8
|
// calling user should not modify the layers after instantiation
|
|
59
9
|
layers: [createUNet({ filters, units, activation, depth, residual, batchNorm, inputShape })]
|
|
60
10
|
});
|
|
61
11
|
}
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
override summary(lineLength?: number, positions?: number[], printFn?: (message?: any, ...optionalParams: any[]) => void): void {
|
|
12
|
+
summary(lineLength, positions, printFn) {
|
|
65
13
|
super.summary(lineLength, positions, printFn);
|
|
66
|
-
|
|
14
|
+
this.layers[0].summary(lineLength, positions, printFn);
|
|
67
15
|
}
|
|
68
16
|
}
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
export function createUNet({ filters, depth, units, activation, residual = false, batchNorm = false, inputShape = [null, null, 3] }: UNetModelArgs) {
|
|
17
|
+
export function createUNet({ filters, depth, units, activation, residual = false, batchNorm = false, inputShape = [null, null, 3] }) {
|
|
72
18
|
if (units < 1) {
|
|
73
19
|
throw Error(`createUNet: units should be >= 1, got ${units}`);
|
|
74
20
|
}
|
|
75
|
-
|
|
76
21
|
const [image_height, image_width] = inputShape;
|
|
77
22
|
const divisble_by = 2 ** depth;
|
|
78
|
-
|
|
79
23
|
if ((image_height != null && image_height % divisble_by != 0) ||
|
|
80
24
|
image_width != null && image_width % divisble_by != 0) {
|
|
81
|
-
throw Error(`createUNet: the input height and width must be divisible by 2^depth (${divisble_by})`)
|
|
25
|
+
throw Error(`createUNet: the input height and width must be divisible by 2^depth (${divisble_by})`);
|
|
82
26
|
}
|
|
83
|
-
|
|
84
27
|
const input = tf.input({ shape: inputShape });
|
|
85
|
-
|
|
86
|
-
const skip_connection: tf.SymbolicTensor[] = [];
|
|
87
|
-
|
|
28
|
+
const skip_connection = [];
|
|
88
29
|
let x = input;
|
|
89
|
-
|
|
90
30
|
// calculate the filter sizes for each level
|
|
91
31
|
const filter_sizes = Array.from({ length: depth }, (_, i) => filters * (2 ** i));
|
|
92
|
-
|
|
93
32
|
for (const filter_size of filter_sizes) {
|
|
94
33
|
const contraction = contractionBlock(x, filter_size, residual, batchNorm, `contraction-f${filter_size}`);
|
|
95
|
-
|
|
96
|
-
x = tf.layers.maxPooling2d({ poolSize: 2, strides: 2, name: `pool-f${filter_size}` }).apply(contraction) as tf.SymbolicTensor;
|
|
34
|
+
x = tf.layers.maxPooling2d({ poolSize: 2, strides: 2, name: `pool-f${filter_size}` }).apply(contraction);
|
|
97
35
|
skip_connection.push(contraction);
|
|
98
36
|
}
|
|
99
|
-
|
|
100
|
-
x = contractionBlock(x, filter_sizes.at(-1)! * 2, residual, batchNorm, "bottleneck");
|
|
101
|
-
|
|
37
|
+
x = contractionBlock(x, filter_sizes.at(-1) * 2, residual, batchNorm, "bottleneck");
|
|
102
38
|
for (let i = filter_sizes.length - 1; i >= 0; i--) {
|
|
103
39
|
x = expansionBlock(x, skip_connection[i], filter_sizes[i], residual, batchNorm, `expansion-f${filter_sizes[i]}`);
|
|
104
40
|
}
|
|
105
|
-
|
|
106
41
|
const output = tf.layers.conv2d({
|
|
107
42
|
filters: units,
|
|
108
43
|
kernelSize: 1,
|
|
109
44
|
padding: "same",
|
|
110
45
|
activation: activation ?? (units == 1 ? "sigmoid" : "softmax"),
|
|
111
46
|
name: "output-conv"
|
|
112
|
-
}).apply(x)
|
|
113
|
-
|
|
47
|
+
}).apply(x);
|
|
114
48
|
return tf.model({ inputs: input, outputs: output, name: "u_net" });
|
|
115
49
|
}
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
export async function loadUNetModel(pathOrIOHandler: string | tf.io.IOHandler, options?: tf.io.LoadOptions) {
|
|
50
|
+
export async function loadUNetModel(pathOrIOHandler, options) {
|
|
119
51
|
const model = await tf.loadLayersModel(pathOrIOHandler, options);
|
|
120
52
|
const unet = createUNet({ depth: 1, filters: 4, units: 1 }); // these are dummy args that are overwritten
|
|
121
53
|
const { name, ...rest } = model;
|
|
122
54
|
Object.assign(unet, rest);
|
|
123
|
-
|
|
124
55
|
return unet;
|
|
125
56
|
}
|
|
126
|
-
|
|
127
|
-
|
|
128
57
|
/**
|
|
129
58
|
* The contraction block of a U-Net
|
|
130
|
-
*
|
|
59
|
+
*
|
|
131
60
|
* Conv > BN > ReLU > Conv > BN + residual > ReLU
|
|
132
|
-
*
|
|
61
|
+
*
|
|
133
62
|
* TODO: for residual, change order to (BN > ReLU > Conv)x2 + residual
|
|
134
|
-
*
|
|
63
|
+
*
|
|
135
64
|
* @param x a previous layer's symbolic output
|
|
136
65
|
* @param filters the number of filters, usually half the previous expansion block's
|
|
137
66
|
* @param residual includes a residual connection
|
|
138
67
|
* @param batchNorm applies batch normalization before ReLU activation
|
|
139
68
|
* @param name a unique name for the contraction block
|
|
140
69
|
*/
|
|
141
|
-
function contractionBlock(x
|
|
142
|
-
|
|
70
|
+
function contractionBlock(x, filters, residual, batchNorm, name) {
|
|
143
71
|
const conv1 = tf.layers.conv2d({
|
|
144
72
|
filters,
|
|
145
73
|
kernelSize: 3,
|
|
@@ -149,7 +77,6 @@ function contractionBlock(x: tf.SymbolicTensor, filters: number, residual: boole
|
|
|
149
77
|
name: `${name}-1-conv2d`
|
|
150
78
|
});
|
|
151
79
|
const relu1 = tf.layers.reLU({ name: `${name}-1-relu` });
|
|
152
|
-
|
|
153
80
|
const conv2 = tf.layers.conv2d({
|
|
154
81
|
filters,
|
|
155
82
|
kernelSize: 3,
|
|
@@ -159,24 +86,17 @@ function contractionBlock(x: tf.SymbolicTensor, filters: number, residual: boole
|
|
|
159
86
|
name: `${name}-2-conv2d`
|
|
160
87
|
});
|
|
161
88
|
const relu2 = tf.layers.reLU({ name: `${name}-2-relu` });
|
|
162
|
-
|
|
163
89
|
let forward = conv1.apply(x);
|
|
164
|
-
|
|
165
90
|
if (batchNorm) {
|
|
166
91
|
forward = tf.layers.batchNormalization({ name: `${name}-1-batchnorm` }).apply(forward);
|
|
167
92
|
}
|
|
168
|
-
|
|
169
93
|
forward = relu1.apply(forward);
|
|
170
|
-
|
|
171
94
|
forward = conv2.apply(forward);
|
|
172
|
-
|
|
173
95
|
if (batchNorm) {
|
|
174
96
|
forward = tf.layers.batchNormalization({ name: `${name}-2-batchnorm` }).apply(forward);
|
|
175
97
|
}
|
|
176
|
-
|
|
177
98
|
if (residual) {
|
|
178
99
|
let residual_skip = x;
|
|
179
|
-
|
|
180
100
|
if (x.shape[x.shape.length - 1] != filters) {
|
|
181
101
|
// a 1x1 convolution on the input to ensure the residual connection's
|
|
182
102
|
// channels/filters dim matches the convolution output
|
|
@@ -187,32 +107,26 @@ function contractionBlock(x: tf.SymbolicTensor, filters: number, residual: boole
|
|
|
187
107
|
useBias: !batchNorm,
|
|
188
108
|
kernelInitializer: "heNormal",
|
|
189
109
|
name: `${name}-residual`
|
|
190
|
-
}).apply(x)
|
|
110
|
+
}).apply(x);
|
|
191
111
|
}
|
|
192
|
-
|
|
193
112
|
if (batchNorm) {
|
|
194
113
|
residual_skip = tf.layers.batchNormalization({
|
|
195
114
|
name: `${name}-residual-batchnorm`
|
|
196
|
-
}).apply(residual_skip)
|
|
115
|
+
}).apply(residual_skip);
|
|
197
116
|
}
|
|
198
|
-
|
|
199
117
|
forward = tf.layers.add().apply([
|
|
200
|
-
residual_skip
|
|
201
|
-
forward
|
|
202
|
-
])
|
|
118
|
+
residual_skip,
|
|
119
|
+
forward
|
|
120
|
+
]);
|
|
203
121
|
}
|
|
204
|
-
|
|
205
122
|
forward = relu2.apply(forward);
|
|
206
|
-
|
|
207
|
-
return forward as tf.SymbolicTensor;
|
|
123
|
+
return forward;
|
|
208
124
|
}
|
|
209
|
-
|
|
210
|
-
|
|
211
125
|
/**
|
|
212
126
|
* The expansion block of a U-Net
|
|
213
|
-
*
|
|
127
|
+
*
|
|
214
128
|
* Upconv + skip > contraction block
|
|
215
|
-
*
|
|
129
|
+
*
|
|
216
130
|
* @param x a previous layer's symbolic output
|
|
217
131
|
* @param skip the corresponding contraction block's output (before pool), shape matches `x`
|
|
218
132
|
* @param filters the number of filters, usually half the previous expansion block's
|
|
@@ -220,8 +134,7 @@ function contractionBlock(x: tf.SymbolicTensor, filters: number, residual: boole
|
|
|
220
134
|
* @param batchNorm apply batch normalization, should be `false` when batch size is `1`
|
|
221
135
|
* @param name a unique name for the contraction block
|
|
222
136
|
*/
|
|
223
|
-
function expansionBlock(x
|
|
224
|
-
|
|
137
|
+
function expansionBlock(x, skip, filters, residual, batchNorm, name) {
|
|
225
138
|
const upconv = tf.layers.conv2dTranspose({
|
|
226
139
|
filters,
|
|
227
140
|
padding: "same",
|
|
@@ -230,11 +143,9 @@ function expansionBlock(x: tf.SymbolicTensor, skip: tf.SymbolicTensor, filters:
|
|
|
230
143
|
kernelInitializer: "heNormal",
|
|
231
144
|
name: `${name}-upconv`
|
|
232
145
|
});
|
|
233
|
-
|
|
234
146
|
const concat = tf.layers.concatenate({ axis: -1, name: `${name}-concat-upconv-skip` });
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
forward = concat.apply([forward, skip]) as tf.SymbolicTensor;
|
|
238
|
-
|
|
147
|
+
let forward = upconv.apply(x);
|
|
148
|
+
forward = concat.apply([forward, skip]);
|
|
239
149
|
return contractionBlock(forward, filters, residual, batchNorm, name);
|
|
240
150
|
}
|
|
151
|
+
//# sourceMappingURL=u_net.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"u_net.js","sourceRoot":"","sources":["../../src/models/u_net.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAuCvC,MAAM,OAAO,SAAU,SAAQ,EAAE,CAAC,UAAU;IAExC,YAAY,IAAmB;QAC3B,MAAM,EACF,OAAO,EACP,KAAK,EACL,UAAU,GAAG,KAAK,IAAI,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,SAAS,EAC/C,KAAK,EACL,QAAQ,GAAG,KAAK,EAChB,SAAS,GAAG,KAAK,EACjB,UAAU,GAAG,CAAC,IAAI,EAAE,IAAI,EAAE,CAAC,CAAC,EAC5B,GAAG,cAAc,EACpB,GAAG,IAAI,CAAC;QAET,cAAc,CAAC,IAAI,GAAG,cAAc,CAAC,IAAI,IAAI,YAAY,CAAC;QAE1D,KAAK,CAAC;YACF,GAAG,cAAc;YACjB,gEAAgE;YAChE,MAAM,EAAE,CAAC,UAAU,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,UAAU,EAAE,KAAK,EAAE,QAAQ,EAAE,SAAS,EAAE,UAAU,EAAE,CAAC,CAAC;SAC/F,CAAC,CAAC;IACP,CAAC;IAGQ,OAAO,CAAC,UAAmB,EAAE,SAAoB,EAAE,OAA2D;QACnH,KAAK,CAAC,OAAO,CAAC,UAAU,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;QAC7C,IAAI,CAAC,MAAM,CAAC,CAAC,CAAoB,CAAC,OAAO,CAAC,UAAU,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;IAC/E,CAAC;CACJ;AAGD,MAAM,UAAU,UAAU,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,KAAK,EAAE,UAAU,EAAE,QAAQ,GAAG,KAAK,EAAE,SAAS,GAAG,KAAK,EAAE,UAAU,GAAG,CAAC,IAAI,EAAE,IAAI,EAAE,CAAC,CAAC,EAAiB;IAC9I,IAAI,KAAK,GAAG,CAAC,EAAE,CAAC;QACZ,MAAM,KAAK,CAAC,yCAAyC,KAAK,EAAE,CAAC,CAAC;IAClE,CAAC;IAED,MAAM,CAAC,YAAY,EAAE,WAAW,CAAC,GAAG,UAAU,CAAC;IAC/C,MAAM,WAAW,GAAG,CAAC,IAAI,KAAK,CAAC;IAE/B,IAAI,CAAC,YAAY,IAAI,IAAI,IAAI,YAAY,GAAG,WAAW,IAAI,CAAC,CAAC;QACzD,WAAW,IAAI,IAAI,IAAI,WAAW,GAAG,WAAW,IAAI,CAAC,EAAE,CAAC;QACxD,MAAM,KAAK,CAAC,wEAAwE,WAAW,GAAG,CAAC,CAAA;IACvG,CAAC;IAED,MAAM,KAAK,GAAG,EAAE,CAAC,KAAK,CAAC,EAAE,KAAK,EAAE,UAAU,EAAE,CAAC,CAAC;IAE9C,MAAM,eAAe,GAAwB,EAAE,CAAC;IAEhD,IAAI,CAAC,GAAG,KAAK,CAAC;IAEd,4CAA4C;IAC5C,MAAM,YAAY,GAAG,KAAK,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,OAAO,GAAG,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAEjF,KAAK,MAAM,WAAW,IAAI,YAAY,EAAE,CAAC;QACrC,MAAM,WAAW,GAAG,gBAAgB,CAAC,CAAC,EAAE,WAAW,EAAE,QAAQ,EAAE,SAAS,EAAE,gBAAgB,WAAW,EAAE,CAAC,CAAC;QAEzG,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,YAAY,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,OAAO,EAAE,CAAC,EAAE,IAAI,EAAE,SAAS,WAAW,EAAE,EAAE,CAAC,CAAC,
KAAK,CAAC,WAAW,CAAsB,CAAC;QAC9H,eAAe,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC;IACtC,CAAC;IAED,CAAC,GAAG,gBAAgB,CAAC,CAAC,EAAE,YAAY,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,GAAG,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,YAAY,CAAC,CAAC;IAErF,KAAK,IAAI,CAAC,GAAG,YAAY,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;QAChD,CAAC,GAAG,cAAc,CAAC,CAAC,EAAE,eAAe,CAAC,CAAC,CAAC,EAAE,YAAY,CAAC,CAAC,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,cAAc,YAAY,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC;IACrH,CAAC;IAED,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;QAC5B,OAAO,EAAE,KAAK;QACd,UAAU,EAAE,CAAC;QACb,OAAO,EAAE,MAAM;QACf,UAAU,EAAE,UAAU,IAAI,CAAC,KAAK,IAAI,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,SAAS,CAAC;QAC9D,IAAI,EAAE,aAAa;KACtB,CAAC,CAAC,KAAK,CAAC,CAAC,CAAsB,CAAC;IAEjC,OAAO,EAAE,CAAC,KAAK,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,OAAO,EAAE,MAAM,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAC;AACvE,CAAC;AAGD,MAAM,CAAC,KAAK,UAAU,aAAa,CAAC,eAAyC,EAAE,OAA2B;IACtG,MAAM,KAAK,GAAG,MAAM,EAAE,CAAC,eAAe,CAAC,eAAe,EAAE,OAAO,CAAC,CAAC;IACjE,MAAM,IAAI,GAAG,UAAU,CAAC,EAAE,KAAK,EAAE,CAAC,EAAE,OAAO,EAAE,CAAC,EAAE,KAAK,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,4CAA4C;IACzG,MAAM,EAAE,IAAI,EAAE,GAAG,IAAI,EAAE,GAAG,KAAK,CAAC;IAChC,MAAM,CAAC,MAAM,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC;IAE1B,OAAO,IAAI,CAAC;AAChB,CAAC;AAGD;;;;;;;;;;;;GAYG;AACH,SAAS,gBAAgB,CAAC,CAAoB,EAAE,OAAe,EAAE,QAAiB,EAAE,SAAkB,EAAE,IAAY;IAEhH,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;QAC3B,OAAO;QACP,UAAU,EAAE,CAAC;QACb,OAAO,EAAE,MAAM;QACf,OAAO,EAAE,CAAC,SAAS;QACnB,iBAAiB,EAAE,UAAU;QAC7B,IAAI,EAAE,GAAG,IAAI,WAAW;KAC3B,CAAC,CAAC;IACH,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC,IAAI,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,SAAS,EAAE,CAAC,CAAC;IAEzD,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;QAC3B,OAAO;QACP,UAAU,EAAE,CAAC;QACb,OAAO,EAAE,MAAM;QACf,OAAO,EAAE,CAAC,SAAS;QACnB,iBAAiB,EAAE,UAAU;QAC7B,IAAI,EAAE,GAAG,IAAI,WAAW;KAC3B,CAAC,CAAC;IACH,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC,IAAI,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,SAAS,EAAE,CAAC,CAAC;IAEzD,IAAI,OAAO,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IAE7B,IAAI,SAAS,EAAE,CAAC;QACZ,OAAO,GAAG,EAAE,CAAC,MAAM
,CAAC,kBAAkB,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,cAAc,EAAE,CAAC,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAC3F,CAAC;IAED,OAAO,GAAG,KAAK,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAE/B,OAAO,GAAG,KAAK,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAE/B,IAAI,SAAS,EAAE,CAAC;QACZ,OAAO,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,cAAc,EAAE,CAAC,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAC3F,CAAC;IAED,IAAI,QAAQ,EAAE,CAAC;QACX,IAAI,aAAa,GAAG,CAAC,CAAC;QAEtB,IAAI,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,IAAI,OAAO,EAAE,CAAC;YACzC,qEAAqE;YACrE,sDAAsD;YACtD,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;gBAC7B,OAAO;gBACP,UAAU,EAAE,CAAC;gBACb,OAAO,EAAE,MAAM;gBACf,OAAO,EAAE,CAAC,SAAS;gBACnB,iBAAiB,EAAE,UAAU;gBAC7B,IAAI,EAAE,GAAG,IAAI,WAAW;aAC3B,CAAC,CAAC,KAAK,CAAC,CAAC,CAAsB,CAAC;QACrC,CAAC;QAED,IAAI,SAAS,EAAE,CAAC;YACZ,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC;gBACzC,IAAI,EAAE,GAAG,IAAI,qBAAqB;aACrC,CAAC,CAAC,KAAK,CAAC,aAAa,CAAsB,CAAC;QACjD,CAAC;QAED,OAAO,GAAG,EAAE,CAAC,MAAM,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC;YAC5B,aAAkC;YAClC,OAA4B;SAC/B,CAAC,CAAA;IACN,CAAC;IAED,OAAO,GAAG,KAAK,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAE/B,OAAO,OAA4B,CAAC;AACxC,CAAC;AAGD;;;;;;;;;;;GAWG;AACH,SAAS,cAAc,CAAC,CAAoB,EAAE,IAAuB,EAAE,OAAe,EAAE,QAAiB,EAAE,SAAkB,EAAE,IAAY;IAEvI,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,eAAe,CAAC;QACrC,OAAO;QACP,OAAO,EAAE,MAAM;QACf,UAAU,EAAE,CAAC;QACb,OAAO,EAAE,CAAC;QACV,iBAAiB,EAAE,UAAU;QAC7B,IAAI,EAAE,GAAG,IAAI,SAAS;KACzB,CAAC,CAAC;IAEH,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,WAAW,CAAC,EAAE,IAAI,EAAE,CAAC,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,qBAAqB,EAAE,CAAC,CAAC;IAEvF,IAAI,OAAO,GAAG,MAAM,CAAC,KAAK,CAAC,CAAC,CAAsB,CAAC;IACnD,OAAO,GAAG,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,EAAE,IAAI,CAAC,CAAsB,CAAC;IAE7D,OAAO,gBAAgB,CAAC,OAAO,EAAE,OAAO,EAAE,QAAQ,EAAE,SAAS,EAAE,IAAI,CAAC,CAAC;AACzE,CAAC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,cAAc,YAAY,CAAC;AAC3B,cAAc,WAAW,CAAC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"index.js","sourceRoot":"","sources":["../../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,cAAc,YAAY,CAAC;AAC3B,cAAc,WAAW,CAAC"}
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
export interface KvCacheArgs {
|
|
3
|
+
batchSize: number;
|
|
4
|
+
maxSequenceLength: number;
|
|
5
|
+
numHeads: number;
|
|
6
|
+
headDim: number;
|
|
7
|
+
dtype?: tf.DataType;
|
|
8
|
+
}
|
|
9
|
+
/**
|
|
10
|
+
* A container for KV caches. A model should initialize one KV cache
|
|
11
|
+
*/
|
|
12
|
+
export declare class KvCacheContainer {
|
|
13
|
+
protected caches: Map<string, KvCache>;
|
|
14
|
+
protected max_sequence_length: number;
|
|
15
|
+
constructor(maxSequenceLength: number);
|
|
16
|
+
create(id: string, args: Omit<KvCacheArgs, "maxSequenceLength">): void;
|
|
17
|
+
/**
|
|
18
|
+
* The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
|
|
19
|
+
*/
|
|
20
|
+
update(id: string, key: tf.Tensor4D, value: tf.Tensor4D): {
|
|
21
|
+
keyCache: tf.Variable<tf.Rank.R4>;
|
|
22
|
+
valueCache: tf.Variable<tf.Rank.R4>;
|
|
23
|
+
} | undefined;
|
|
24
|
+
reset(): void;
|
|
25
|
+
dispose(): void;
|
|
26
|
+
get size(): number;
|
|
27
|
+
get maxSequenceLength(): number;
|
|
28
|
+
}
|
|
29
|
+
/**
 * A single key/value cache backed by two non-trainable variables of shape
 * `[batch, heads, maxSequenceLength, head_dim]`, filled front to back.
 */
export declare class KvCache {
    protected key_cache: tf.Variable<tf.Rank.R4>;
    protected value_cache: tf.Variable<tf.Rank.R4>;
    /** Next write position along the sequence axis (equals `size`). */
    protected current_position: number;
    protected batch_size: number;
    protected max_sequence_length: number;
    protected num_kv_heads: number;
    protected head_dim: number;
    /** Allocates zero-filled key and value buffers from the given dimensions. */
    constructor({ batchSize, maxSequenceLength, numHeads, headDim, dtype }: KvCacheArgs);
    /**
     * Writes `key`/`value` at the current position and advances it by their
     * sequence length.
     *
     * The key and value tensors should have the shape (post head split, etc)
     * `[batch, heads, seq, head_dim]`.
     *
     * @returns The full cache variables; only the first `size` positions along
     *   the sequence axis are populated.
     * @throws Error when the incoming batch size is incompatible with the cache
     *   or the write would exceed the maximum sequence length.
     */
    update(key: tf.Tensor4D, value: tf.Tensor4D): {
        keyCache: tf.Variable<tf.Rank.R4>;
        valueCache: tf.Variable<tf.Rank.R4>;
    };
    /**
     * Returns a copy of `current_cache` whose sequence positions starting at
     * the current position are overwritten with `new_value`.
     */
    protected mergeIntoCache(new_value: tf.Tensor4D, current_cache: tf.Tensor4D): tf.Tensor4D;
    /** Rewinds the write position to 0 and zero-fills both buffers. */
    reset(): void;
    /** Releases the key and value variables. */
    dispose(): void;
    /**
     * The size of the KV cache, also the number of tokens since the first one.
     */
    get size(): number;
}
|
|
53
|
+
//# sourceMappingURL=kv_cache.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"kv_cache.d.ts","sourceRoot":"","sources":["../../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC,MAAM,WAAW,WAAW;IACxB,SAAS,EAAE,MAAM,CAAC;IAClB,iBAAiB,EAAE,MAAM,CAAC;IAC1B,QAAQ,EAAE,MAAM,CAAC;IACjB,OAAO,EAAE,MAAM,CAAC;IAChB,KAAK,CAAC,EAAE,EAAE,CAAC,QAAQ,CAAA;CACtB;AAGD;;GAEG;AACH,qBAAa,gBAAgB;IACzB,SAAS,CAAC,MAAM,uBAA8B;IAC9C,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;gBAG1B,iBAAiB,EAAE,MAAM;IAS9B,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,IAAI,EAAE,IAAI,CAAC,WAAW,EAAE,mBAAmB,CAAC;IAUtE;;OAEG;IACI,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IA4BvD,KAAK;IAOL,OAAO;IAOd,IAAW,IAAI,WAGd;IAGD,IAAW,iBAAiB,WAE3B;CACJ;AAGD,qBAAa,OAAO;IAEhB,SAAS,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;IAC7C,SAAS,CAAC,WAAW,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAA;IAG9C,SAAS,CAAC,gBAAgB,EAAE,MAAM,CAAK;IAEvC,SAAS,CAAC,UAAU,EAAE,MAAM,CAAC;IAC7B,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;IACtC,SAAS,CAAC,YAAY,EAAE,MAAM,CAAC;IAC/B,SAAS,CAAC,QAAQ,EAAE,MAAM,CAAC;gBAEf,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAiB,EAAE,EAAE,WAAW;IAa/F;;OAEG;IACI,MAAM,CAAC,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IAgClD,SAAS,CAAC,cAAc,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,EAAE,aAAa,EAAE,EAAE,CAAC,QAAQ;IAqBpE,KAAK,IAAI,IAAI;IAab,OAAO,IAAI,IAAI;IAMtB;;OAEG;IACH,IAAI,IAAI,IAAI,MAAM,CAEjB;CAEJ"}
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
/**
 * A container for KV caches, each registered under a string id. Every cache
 * created through the container shares the container's maximum sequence
 * length; a model is intended to own a single container holding its caches.
 */
export class KvCacheContainer {
    caches = new Map();
    max_sequence_length;
    /**
     * @param maxSequenceLength - Capacity, in tokens, of every cache created by
     *   this container. Must be a positive integer.
     * @throws Error when `maxSequenceLength` is not a positive integer.
     */
    constructor(maxSequenceLength) {
        // A plain `!maxSequenceLength` check would let negative or fractional
        // lengths through, only to fail later (and less clearly) in tf.zeros.
        if (!Number.isInteger(maxSequenceLength) || maxSequenceLength <= 0) {
            throw Error(`KvCacheContainer: expected KV cache maximum sequence length to be a positive integer, got: ${String(maxSequenceLength)}`);
        }
        this.max_sequence_length = maxSequenceLength;
    }
    /**
     * Creates the cache registered under `id`, replacing (and disposing) any
     * cache previously registered under the same id.
     *
     * @param id - Key later passed to `update`.
     * @param args - Cache dimensions; the sequence length always comes from
     *   the container.
     */
    create(id, args) {
        const new_cache = new KvCache({
            ...args,
            maxSequenceLength: this.max_sequence_length
        });
        const previous = this.caches.get(id);
        this.caches.set(id, new_cache);
        // Overwriting the Map entry alone would leak the old cache's variables.
        previous?.dispose();
    }
    /**
     * Writes `key` and `value` into the cache registered under `id`.
     *
     * The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
     *
     * @returns `undefined` when no cache exists for `id`; otherwise new tensors
     *   holding the cached keys and values for positions `[0, size)`, which the
     *   caller owns and should dispose.
     */
    update(id, key, value) {
        const kv_cache = this.caches.get(id);
        if (!kv_cache) {
            return undefined;
        }
        const { keyCache, valueCache } = kv_cache.update(key, value);
        // slicing to get only the past key and value projections, but normally
        // in TensorFlow and PyTorch the full cache is returned and masked for
        // graph purposes
        return tf.tidy(() => {
            const k_cache = keyCache.slice([0, 0, 0, 0], [keyCache.shape[0], keyCache.shape[1], kv_cache.size, keyCache.shape[3]]);
            const v_cache = valueCache.slice([0, 0, 0, 0], [valueCache.shape[0], valueCache.shape[1], kv_cache.size, valueCache.shape[3]]);
            return {
                keyCache: k_cache,
                valueCache: v_cache
            };
        });
    }
    /** Resets every registered cache (position back to 0, buffers zeroed). */
    reset() {
        this.caches.forEach(cache => {
            cache.reset();
        });
    }
    /** Disposes the tensors of every registered cache. */
    dispose() {
        this.caches.forEach(cache => {
            cache.dispose();
        });
    }
    /**
     * Size of the first registered cache, or 0 when none has been created.
     */
    get size() {
        // the size of all KV caches are expected to be the same, just use the first one
        return this.caches.values().next().value?.size ?? 0;
    }
    /** The capacity, in tokens, shared by every cache in this container. */
    get maxSequenceLength() {
        return this.max_sequence_length;
    }
}
|
|
60
|
+
export class KvCache {
    key_cache;
    value_cache;
    // the size of the KV cache, represents the number of tokens since the first chat token
    current_position = 0;
    batch_size;
    max_sequence_length;
    num_kv_heads;
    head_dim;
    /**
     * Allocates zero-filled, non-trainable key and value variables of shape
     * `[batchSize, numHeads, maxSequenceLength, headDim]`.
     *
     * @param args - Cache dimensions; `dtype` defaults to "float32".
     */
    constructor({ batchSize, maxSequenceLength, numHeads, headDim, dtype = "float32" }) {
        const cache_shape = [batchSize, numHeads, maxSequenceLength, headDim];
        // tf.variable() shares the data of its initial tensor, so leaving the
        // temporary zeros tensor alive would keep the buffer allocated even after
        // dispose(). tidy() frees the temporary; it never disposes variables.
        this.key_cache = tf.tidy(() => tf.variable(tf.zeros(cache_shape, dtype), false));
        this.value_cache = tf.tidy(() => tf.variable(tf.zeros(cache_shape, dtype), false));
        this.batch_size = batchSize;
        this.max_sequence_length = maxSequenceLength;
        this.num_kv_heads = numHeads;
        this.head_dim = headDim;
    }
    /**
     * Writes `key` and `value` at the current position and advances the
     * position by their sequence length.
     *
     * The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
     *
     * @returns The full cache variables; only the first `size` positions along
     *   the sequence axis are populated.
     * @throws Error when the tensors do not match the cache's batch size, head
     *   count or head size, when key and value shapes differ, or when the write
     *   would exceed the maximum sequence length.
     */
    update(key, value) {
        const [batch_size, num_heads, seq_len, head_dim] = key.shape;
        // mergeIntoCache concatenates the new tensor with slices taken using the
        // cache's own dimensions, so every non-sequence dimension must match
        // exactly; otherwise tf.concat fails with an opaque shape error.
        if (batch_size !== this.key_cache.shape[0]) {
            throw Error(`The current KV cache has been set up with a batch size of` +
                ` ${this.key_cache.shape[0]}, but found new key tensors with batch size ${batch_size}`);
        }
        if (num_heads !== this.num_kv_heads || head_dim !== this.head_dim) {
            throw Error(`The current KV cache has been set up with ${this.num_kv_heads} heads of size ${this.head_dim},` +
                ` but found new key tensors with ${num_heads} heads of size ${head_dim}`);
        }
        // The position advances by the key's sequence length only, so a value
        // tensor of a different shape would desynchronise the two caches.
        if (value.shape.length !== key.shape.length || value.shape.some((dim, i) => dim !== key.shape[i])) {
            throw Error(`Expected key and value tensors to have the same shape, got [${key.shape}] and [${value.shape}]`);
        }
        if (this.current_position + seq_len > this.max_sequence_length) {
            throw Error(`The KV cache has exceeded its maximum sequence length of ${this.max_sequence_length}. Use a larger value.`);
        }
        const new_key_cache = this.mergeIntoCache(key, this.key_cache);
        const new_value_cache = this.mergeIntoCache(value, this.value_cache);
        this.key_cache.assign(new_key_cache);
        this.value_cache.assign(new_value_cache);
        // assign() keeps its own reference to the data, so the merged tensors can go
        new_key_cache.dispose();
        new_value_cache.dispose();
        // advance the pointer to reflect the updated cache's current
        this.current_position += seq_len;
        return {
            keyCache: this.key_cache,
            valueCache: this.value_cache,
        };
    }
    /**
     * Returns a new tensor equal to `current_cache` with the sequence positions
     * `[current_position, current_position + seq)` replaced by `new_value`.
     */
    mergeIntoCache(new_value, current_cache) {
        const seq_len = new_value.shape[2];
        return tf.tidy(() => {
            const historical = current_cache.slice([0, 0, 0, 0], [this.batch_size, this.num_kv_heads, this.current_position, this.head_dim]);
            const future = current_cache.slice([0, 0, this.current_position + seq_len, 0], [this.batch_size, this.num_kv_heads, this.max_sequence_length - this.current_position - seq_len, this.head_dim]);
            // merge the new tensor into the current cache to create a new, larger, cache,
            // this is different from Python implementations because TFJS tensors are immutable,
            // because we cannot update a slice, we must slice and concat
            return tf.concat([historical, new_value, future], 2);
        });
    }
    /** Rewinds the write position to 0 and zero-fills both buffers. */
    reset() {
        this.current_position = 0;
        tf.tidy(() => {
            // Pass the cache dtype explicitly: tf.zeros defaults to float32 and
            // Variable.assign() rejects a dtype mismatch.
            this.key_cache.assign(tf.zeros(this.key_cache.shape, this.key_cache.dtype));
            this.value_cache.assign(tf.zeros(this.value_cache.shape, this.value_cache.dtype));
        });
    }
    /** Releases the key and value variables. */
    dispose() {
        this.key_cache.dispose();
        this.value_cache.dispose();
    }
    /**
     * The size of the KV cache, also the number of tokens since the first one.
     */
    get size() {
        return this.current_position;
    }
}
|
|
135
|
+
//# sourceMappingURL=kv_cache.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"kv_cache.js","sourceRoot":"","sources":["../../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAYvC;;GAEG;AACH,MAAM,OAAO,gBAAgB;IACf,MAAM,GAAG,IAAI,GAAG,EAAmB,CAAC;IACpC,mBAAmB,CAAS;IAGtC,YAAY,iBAAyB;QACjC,IAAI,CAAC,iBAAiB,EAAE,CAAC;YACrB,MAAM,KAAK,CAAC,0FAA0F,MAAM,CAAC,iBAAiB,CAAC,EAAE,CAAC,CAAC;QACvI,CAAC;QAED,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;IACjD,CAAC;IAGM,MAAM,CAAC,EAAU,EAAE,IAA4C;QAClE,MAAM,SAAS,GAAG,IAAI,OAAO,CAAC;YAC1B,GAAG,IAAI;YACP,iBAAiB,EAAE,IAAI,CAAC,mBAAmB;SAC9C,CAAC,CAAC;QAEH,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,EAAE,SAAS,CAAC,CAAC;IACnC,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,EAAU,EAAE,GAAgB,EAAE,KAAkB;QAC1D,MAAM,QAAQ,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;QAErC,IAAI,CAAC,QAAQ,EAAE,CAAC;YACZ,OAAO,SAAS,CAAC;QACrB,CAAC;QAED,MAAM,EAAE,QAAQ,EAAE,UAAU,EAAE,GAAG,QAAQ,CAAC,MAAM,CAAC,GAAG,EAAE,KAAK,CAAC,CAAC;QAE7D,uEAAuE;QACvE,sEAAsE;QACtE,iBAAiB;QACjB,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,OAAO,GAAG,QAAQ,CAAC,KAAK,CAC1B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAC9E,MAAM,OAAO,GAAG,UAAU,CAAC,KAAK,CAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAEpF,OAAO;gBACH,QAAQ,EAAE,OAAO;gBACjB,UAAU,EAAE,OAAO;aACtB,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,KAAK,EAAE,CAAC;QAClB,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,OAAO,EAAE,CAAC;QACpB,CAAC,CAAC,CAAA;IACN,CAAC;IAGD,IAAW,IAAI;QACX,gFAAgF;QAChF,OAAO,IAAI,CAAC,MAAM,CAAC,OAAO,EAAE,CAAC,IAAI,EAAE,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,IAAI,CAAC,CAAC;IAC7D,CAAC;IAGD,IAAW,iBAAiB;QACxB,OAAO,IAAI,CAAC,mBAAmB,CAAC;IACpC,CAAC;CACJ;AAGD,MAAM,OAAO,OAAO;I
AEN,SAAS,CAA0B;IACnC,WAAW,CAAyB;IAE9C,uFAAuF;IAC7E,gBAAgB,GAAW,CAAC,CAAC;IAE7B,UAAU,CAAS;IACnB,mBAAmB,CAAS;IAC5B,YAAY,CAAS;IACrB,QAAQ,CAAS;IAE3B,YAAY,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAK,GAAG,SAAS,EAAe;QAC3F,MAAM,WAAW,GAAG,CAAC,SAAS,EAAE,QAAQ,EAAE,iBAAiB,EAAE,OAAO,CAAqC,CAAC;QAE1G,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAClE,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAEpE,IAAI,CAAC,UAAU,GAAG,SAAS,CAAC;QAC5B,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;QAC7C,IAAI,CAAC,YAAY,GAAG,QAAQ,CAAC;QAC7B,IAAI,CAAC,QAAQ,GAAG,OAAO,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,GAAgB,EAAE,KAAkB;QAC9C,MAAM,UAAU,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAChC,MAAM,OAAO,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAE7B,IAAI,UAAU,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC;YACvC,MAAM,KAAK,CAAC,2DAA2D;gBACnE,IAAI,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,+CAA+C,UAAU,EAAE,CAAC,CAAA;QAC/F,CAAC;QAED,IAAI,IAAI,CAAC,gBAAgB,GAAG,OAAO,GAAG,IAAI,CAAC,mBAAmB,EAAE,CAAC;YAC7D,MAAM,KAAK,CAAC,4DAA4D,IAAI,CAAC,mBAAmB,uBAAuB,CAAC,CAAC;QAC7H,CAAC;QAED,MAAM,aAAa,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;QAC/D,MAAM,eAAe,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC;QAErE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,aAAa,CAAC,CAAC;QACrC,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,eAAe,CAAC,CAAC;QAEzC,aAAa,CAAC,OAAO,EAAE,CAAC;QACxB,eAAe,CAAC,OAAO,EAAE,CAAC;QAE1B,6DAA6D;QAC7D,IAAI,CAAC,gBAAgB,IAAI,OAAO,CAAC;QAEjC,OAAO;YACH,QAAQ,EAAE,IAAI,CAAC,SAAS;YACxB,UAAU,EAAE,IAAI,CAAC,WAAW;SAC/B,CAAA;IACL,CAAC;IAGS,cAAc,CAAC,SAAsB,EAAE,aAA0B;QACvE,MAAM,OAAO,GAAG,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAEnC,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAEhB,MAAM,UAAU,GAAG,aAAa,CAAC,KAAK,CAClC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,gBAAgB,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAEhF,MAAM,MAAM,GAAG,aAAa,CAAC,KAAK,CAC9B,CAAC,CAAC,EAAE,CAAC,EAAE,IAAI,CAA
C,gBAAgB,GAAG,OAAO,EAAE,CAAC,CAAC,EAC1C,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,gBAAgB,GAAG,OAAO,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAErH,8EAA8E;YAC9E,qFAAqF;YACrF,6DAA6D;YAC7D,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,UAAU,EAAE,SAAS,EAAE,MAAM,CAAC,EAAE,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,gBAAgB,GAAG,CAAC,CAAC;QAE1B,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YACT,MAAM,eAAe,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC;YAC7C,MAAM,iBAAiB,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC;YAEjD,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,eAAe,CAAC,CAAC,CAAC;YACjD,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAC;IACP,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,SAAS,CAAC,OAAO,EAAE,CAAC;QACzB,IAAI,CAAC,WAAW,CAAC,OAAO,EAAE,CAAC;IAC/B,CAAC;IAGD;;OAEG;IACH,IAAI,IAAI;QACJ,OAAO,IAAI,CAAC,gBAAgB,CAAC;IACjC,CAAC;CAEJ"}
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
// Relative specifiers: the `@/` path alias exists only in the package's own
// tsconfig, tsc does not rewrite it, and consumers cannot resolve it.
import { KvCacheContainer } from "../kv_cache";
import { MultiHeadAttention, type MultiHeadAttentionArgs } from './multihead_attention';
import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
/**
 * MultiHeadAttention with RoPE and KV caching. If using KV caching, this layer
 * should be used in a custom training loop because it requires the cache to be
 * passed through the `kwargs.kvCache` argument during the `layer.apply()`
 * forward propagation.
 *
 * If a KV cache is not provided, then this layer operates as MultiHeadAttention with RoPE.
 */
export declare class CachedRoPEMultiHeadAttention extends MultiHeadAttention {
    static className: string;
    /** Layer applying the rotary position encoding (see `splitHeads`). */
    protected rope: tf.layers.Layer;
    constructor(args: MultiHeadAttentionArgs);
    protected forward(query_input: tf.Tensor, key_input: tf.Tensor, value_input: tf.Tensor, packing_mask: tf.Tensor | null, causal_mask: tf.Tensor | null, kwargs: Kwargs): tf.Tensor;
    /**
     * Stores the split key/value heads in `kv_container` and returns the cached
     * keys and values. NOTE(review): typed as `tf.Variable`, but
     * `KvCacheContainer.update` returns sliced tensors — confirm against the
     * implementation.
     */
    protected getCachedKV(kv_container: KvCacheContainer, key_split: tf.Tensor4D, value_split: tf.Tensor4D): {
        keyCache: tf.Variable<tf.Rank.R4>;
        valueCache: tf.Variable<tf.Rank.R4>;
    };
    /**
     * Adds RoPE position encoding right after splitting heads.
     */
    protected splitHeads(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, shuffle: number[]): {
        query_split: tf.Tensor4D;
        key_split: tf.Tensor4D;
        value_split: tf.Tensor4D;
    };
}
|
|
31
|
+
//# sourceMappingURL=cached_rope_multihead_attention.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"cached_rope_multihead_attention.d.ts","sourceRoot":"","sources":["../../../src/layers/cached_rope_multihead_attention.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,gBAAgB,EAAE,MAAM,YAAY,CAAC;AAC9C,OAAO,EAAE,kBAAkB,EAAE,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAE/F,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAGjE;;;;;;;GAOG;AACH,qBAAa,4BAA6B,SAAQ,kBAAkB;IAChE,MAAM,CAAC,SAAS,SAAkC;IAElD,SAAS,CAAC,IAAI,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;gBAEpB,IAAI,EAAE,sBAAsB;cAMrB,OAAO,CACtB,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,SAAS,EAAE,EAAE,CAAC,MAAM,EACpB,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,YAAY,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC9B,WAAW,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC7B,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAuC9B,SAAS,CAAC,WAAW,CAAC,YAAY,EAAE,gBAAgB,EAAE,SAAS,EAAE,EAAE,CAAC,QAAQ,EAAE,WAAW,EAAE,EAAE,CAAC,QAAQ;;;;IAqBtG;;OAEG;cACgB,UAAU,CAAC,KAAK,EAAE,EAAE,CAAC,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC,MAAM,EAAE,OAAO,EAAE,MAAM,EAAE;qBAO5D,EAAE,CAAC,QAAQ;mBAEX,EAAE,CAAC,QAAQ;qBACwB,EAAE,CAAC,QAAQ;;CAIxF"}
|