@stellarapp/tfjs-stellar 1.0.0 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +47 -0
- package/dist/index.d.ts +7 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +7 -0
- package/dist/index.js.map +1 -0
- package/dist/jest.config.d.ts +8 -0
- package/dist/jest.config.d.ts.map +1 -0
- package/{jest.config.ts → dist/jest.config.js} +8 -64
- package/dist/jest.config.js.map +1 -0
- package/dist/kv_cache.d.ts +53 -0
- package/dist/kv_cache.d.ts.map +1 -0
- package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
- package/dist/kv_cache.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.js +76 -0
- package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
- package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
- package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
- package/dist/layers/gpt_decoder_block.js.map +1 -0
- package/dist/layers/index.d.ts +17 -0
- package/dist/layers/index.d.ts.map +1 -0
- package/dist/layers/index.js +33 -0
- package/dist/layers/index.js.map +1 -0
- package/dist/layers/multihead_attention.d.ts +106 -0
- package/dist/layers/multihead_attention.d.ts.map +1 -0
- package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
- package/dist/layers/multihead_attention.js.map +1 -0
- package/dist/layers/multihead_attention.test.d.ts +2 -0
- package/dist/layers/multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
- package/dist/layers/multihead_attention.test.js.map +1 -0
- package/dist/layers/positional_encoding.d.ts +37 -0
- package/dist/layers/positional_encoding.d.ts.map +1 -0
- package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
- package/dist/layers/positional_encoding.js.map +1 -0
- package/dist/layers/positional_encoding.test.d.ts +2 -0
- package/dist/layers/positional_encoding.test.d.ts.map +1 -0
- package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
- package/dist/layers/positional_encoding.test.js.map +1 -0
- package/dist/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
- package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
- package/dist/layers/rotary_position_embedding.js.map +1 -0
- package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/layers/rotary_position_embedding.test.js +88 -0
- package/dist/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
- package/dist/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
- package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/layers/transformer_decoder.d.ts +69 -0
- package/dist/layers/transformer_decoder.d.ts.map +1 -0
- package/dist/layers/transformer_decoder.js +182 -0
- package/dist/layers/transformer_decoder.js.map +1 -0
- package/dist/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
- package/dist/layers/transformer_decoder.test.js.map +1 -0
- package/dist/layers/transformer_encoder.d.ts +55 -0
- package/dist/layers/transformer_encoder.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
- package/dist/layers/transformer_encoder.js.map +1 -0
- package/dist/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
- package/dist/layers/transformer_encoder.test.js.map +1 -0
- package/dist/losses/dice.d.ts +30 -0
- package/dist/losses/dice.d.ts.map +1 -0
- package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
- package/dist/losses/dice.js.map +1 -0
- package/dist/losses/index.d.ts +2 -0
- package/dist/losses/index.d.ts.map +1 -0
- package/dist/losses/index.js +2 -0
- package/dist/losses/index.js.map +1 -0
- package/dist/masks.d.ts +20 -0
- package/dist/masks.d.ts.map +1 -0
- package/{src/packing_mask.ts → dist/masks.js} +16 -7
- package/dist/masks.js.map +1 -0
- package/dist/metrics.d.ts +20 -0
- package/dist/metrics.d.ts.map +1 -0
- package/{src/metrics.ts → dist/metrics.js} +8 -12
- package/dist/metrics.js.map +1 -0
- package/dist/models/gpt_model.d.ts +94 -0
- package/dist/models/gpt_model.d.ts.map +1 -0
- package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
- package/dist/models/gpt_model.js.map +1 -0
- package/dist/models/index.d.ts +7 -0
- package/dist/models/index.d.ts.map +1 -0
- package/dist/models/index.js +13 -0
- package/dist/models/index.js.map +1 -0
- package/dist/models/llm_model.d.ts +87 -0
- package/dist/models/llm_model.d.ts.map +1 -0
- package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
- package/dist/models/llm_model.js.map +1 -0
- package/dist/models/u_net.d.ts +40 -0
- package/dist/models/u_net.d.ts.map +1 -0
- package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
- package/dist/models/u_net.js.map +1 -0
- package/dist/src/index.d.ts +6 -0
- package/dist/src/index.d.ts.map +1 -0
- package/dist/src/index.js +6 -0
- package/dist/src/index.js.map +1 -0
- package/dist/src/kv_cache.d.ts +53 -0
- package/dist/src/kv_cache.d.ts.map +1 -0
- package/dist/src/kv_cache.js +135 -0
- package/dist/src/kv_cache.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
- package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
- package/dist/src/layers/gpt_decoder_block.js +51 -0
- package/dist/src/layers/gpt_decoder_block.js.map +1 -0
- package/dist/src/layers/index.d.ts +17 -0
- package/dist/src/layers/index.d.ts.map +1 -0
- package/dist/src/layers/index.js +33 -0
- package/dist/src/layers/index.js.map +1 -0
- package/dist/src/layers/multihead_attention.d.ts +106 -0
- package/dist/src/layers/multihead_attention.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.js +269 -0
- package/dist/src/layers/multihead_attention.js.map +1 -0
- package/dist/src/layers/multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.test.js +160 -0
- package/dist/src/layers/multihead_attention.test.js.map +1 -0
- package/dist/src/layers/positional_encoding.d.ts +37 -0
- package/dist/src/layers/positional_encoding.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.js +115 -0
- package/dist/src/layers/positional_encoding.js.map +1 -0
- package/dist/src/layers/positional_encoding.test.d.ts +2 -0
- package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.test.js +95 -0
- package/dist/src/layers/positional_encoding.test.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.js +99 -0
- package/dist/src/layers/rotary_position_embedding.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.js +88 -0
- package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.js +109 -0
- package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
- package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/src/layers/transformer_decoder.d.ts +69 -0
- package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
- package/dist/src/layers/transformer_decoder.js.map +1 -0
- package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_decoder.test.js +72 -0
- package/dist/src/layers/transformer_decoder.test.js.map +1 -0
- package/dist/src/layers/transformer_encoder.d.ts +55 -0
- package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.js +175 -0
- package/dist/src/layers/transformer_encoder.js.map +1 -0
- package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.test.js +58 -0
- package/dist/src/layers/transformer_encoder.test.js.map +1 -0
- package/dist/src/losses/dice.d.ts +30 -0
- package/dist/src/losses/dice.d.ts.map +1 -0
- package/dist/src/losses/dice.js +93 -0
- package/dist/src/losses/dice.js.map +1 -0
- package/dist/src/losses/index.d.ts +2 -0
- package/dist/src/losses/index.d.ts.map +1 -0
- package/dist/src/losses/index.js +2 -0
- package/dist/src/losses/index.js.map +1 -0
- package/dist/src/masks.d.ts +20 -0
- package/dist/src/masks.d.ts.map +1 -0
- package/dist/src/masks.js +37 -0
- package/dist/src/masks.js.map +1 -0
- package/dist/src/metrics.d.ts +20 -0
- package/dist/src/metrics.d.ts.map +1 -0
- package/dist/src/metrics.js +28 -0
- package/dist/src/metrics.js.map +1 -0
- package/dist/src/models/gpt_model.d.ts +94 -0
- package/dist/src/models/gpt_model.d.ts.map +1 -0
- package/dist/src/models/gpt_model.js +154 -0
- package/dist/src/models/gpt_model.js.map +1 -0
- package/dist/src/models/index.d.ts +3 -0
- package/dist/src/models/index.d.ts.map +1 -0
- package/{src/models/index.ts → dist/src/models/index.js} +1 -0
- package/dist/src/models/index.js.map +1 -0
- package/dist/src/models/llm_model.d.ts +87 -0
- package/dist/src/models/llm_model.d.ts.map +1 -0
- package/dist/src/models/llm_model.js +245 -0
- package/dist/src/models/llm_model.js.map +1 -0
- package/dist/src/models/u_net.d.ts +40 -0
- package/dist/src/models/u_net.d.ts.map +1 -0
- package/dist/src/models/u_net.js +151 -0
- package/dist/src/models/u_net.js.map +1 -0
- package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
- package/dist/src/tfjs_types.d.ts.map +1 -0
- package/dist/src/tfjs_types.js +2 -0
- package/dist/src/tfjs_types.js.map +1 -0
- package/dist/src/utils.d.ts +28 -0
- package/dist/src/utils.d.ts.map +1 -0
- package/{src/utils.ts → dist/src/utils.js} +10 -33
- package/dist/src/utils.js.map +1 -0
- package/dist/src/utils.test.d.ts +2 -0
- package/dist/src/utils.test.d.ts.map +1 -0
- package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
- package/dist/src/utils.test.js.map +1 -0
- package/dist/tfjs_types.d.ts +10 -0
- package/dist/tfjs_types.d.ts.map +1 -0
- package/dist/tfjs_types.js +2 -0
- package/dist/tfjs_types.js.map +1 -0
- package/dist/utils.d.ts +28 -0
- package/dist/utils.d.ts.map +1 -0
- package/dist/utils.js +63 -0
- package/dist/utils.js.map +1 -0
- package/dist/utils.test.d.ts +2 -0
- package/dist/utils.test.d.ts.map +1 -0
- package/dist/utils.test.js +73 -0
- package/dist/utils.test.js.map +1 -0
- package/package.json +10 -4
- package/src/index.ts +0 -93
- package/src/layers/rotary_position_embedding.test.ts +0 -107
- package/src/losses/index.ts +0 -1
- package/src/testing.ts +0 -1
- package/tsconfig.json +0 -49
package/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 rkuang9
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
package/README.md
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# tfjs-stellar
|
|
2
|
+
An extension of TensorFlow.js for implementing large language models.
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
# Layers
|
|
6
|
+
- MultiHeadAttention
|
|
7
|
+
- CachedRopeMultiHeadAttention
|
|
8
|
+
- TransformerDecoder
|
|
9
|
+
- TransformerEncoder
|
|
10
|
+
- GPT2DecoderBlock
|
|
11
|
+
- RotaryPositionEmbedding
|
|
12
|
+
- PositionalEncoding
|
|
13
|
+
- TokenAndPositionalEmbedding
|
|
14
|
+
|
|
15
|
+
> **Warning**:
|
|
16
|
+
> These layers are not one-to-one replications of the TensorFlow Keras equivalents
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
## Models
|
|
20
|
+
- LlmModel
|
|
21
|
+
- GptModel
|
|
22
|
+
- KvCache
|
|
23
|
+
- UNetModel
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
## Masks
|
|
27
|
+
- Causal
|
|
28
|
+
- Packing
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
## Example
|
|
32
|
+
|
|
33
|
+
```ts
|
|
34
|
+
import * as tfs from "@stellarapp/tfjs-stellar";
|
|
35
|
+
import * as tf from "@tensorflow/tfjs";
|
|
36
|
+
|
|
37
|
+
const attention = tfs.layers.multiheadAttention({ numHeads: 1, embedDim: 64 });
|
|
38
|
+
const output = attention.apply(tf.randomUniform([1, 5, 64]));
|
|
39
|
+
|
|
40
|
+
const gpt_model = tfs.models.gptModel({ numLayers: 1, numHeads: 1, embedDim: 64, vocabSize: 128 });
|
|
41
|
+
gpt_model.compile({ loss: "sparseCategoricalCrossentropy", optimizer: "adam" });
|
|
42
|
+
gpt_model.summary();
|
|
43
|
+
|
|
44
|
+
// see https://js.tensorflow.org/api/latest/#data.generator
|
|
45
|
+
// on how to create a generator dataset
|
|
46
|
+
//gpt_model.fitDataset(your_generator_dataset, { epochs: 1 });
|
|
47
|
+
```
|
package/dist/index.d.ts
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
export * as layers from "./layers";
|
|
2
|
+
export * as models from "./models";
|
|
3
|
+
export * as losses from "./losses";
|
|
4
|
+
export * as masks from "./masks";
|
|
5
|
+
export { KvCache as kvCache, KvCacheContainer as kvCacheContainer } from "./kv_cache";
|
|
6
|
+
export * as metrics from "./metrics";
|
|
7
|
+
//# sourceMappingURL=index.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,OAAO,KAAK,KAAK,MAAM,SAAS,CAAC;AACjC,OAAO,EAAE,OAAO,IAAI,OAAO,EAAE,gBAAgB,IAAI,gBAAgB,EAAE,MAAM,YAAY,CAAC;AACtF,OAAO,KAAK,OAAO,MAAM,WAAW,CAAC"}
|
package/dist/index.js
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
export * as layers from "./layers";
|
|
2
|
+
export * as models from "./models";
|
|
3
|
+
export * as losses from "./losses";
|
|
4
|
+
export * as masks from "./masks";
|
|
5
|
+
export { KvCache as kvCache, KvCacheContainer as kvCacheContainer } from "./kv_cache";
|
|
6
|
+
export * as metrics from "./metrics";
|
|
7
|
+
//# sourceMappingURL=index.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,OAAO,KAAK,KAAK,MAAM,SAAS,CAAC;AACjC,OAAO,EAAE,OAAO,IAAI,OAAO,EAAE,gBAAgB,IAAI,gBAAgB,EAAE,MAAM,YAAY,CAAC;AACtF,OAAO,KAAK,OAAO,MAAM,WAAW,CAAC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"jest.config.d.ts","sourceRoot":"","sources":["../jest.config.ts"],"names":[],"mappings":"AAAA;;;GAGG;AAEH,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,MAAM,CAAC;AAEnC,QAAA,MAAM,MAAM,EAAE,MAiMb,CAAC;AAEF,eAAe,MAAM,CAAC"}
|
|
@@ -2,60 +2,43 @@
|
|
|
2
2
|
* For a detailed explanation regarding each configuration property, visit:
|
|
3
3
|
* https://jestjs.io/docs/configuration
|
|
4
4
|
*/
|
|
5
|
-
|
|
6
|
-
import type { Config } from 'jest';
|
|
7
|
-
|
|
8
|
-
const config: Config = {
|
|
5
|
+
const config = {
|
|
9
6
|
setupFiles: [],
|
|
10
|
-
|
|
11
7
|
extensionsToTreatAsEsm: [".ts"],
|
|
12
|
-
|
|
13
8
|
// A map from regular expressions to paths to transformers
|
|
14
9
|
transform: {
|
|
15
10
|
"^.+\.ts?$": ["ts-jest", {
|
|
16
|
-
|
|
17
|
-
|
|
11
|
+
useESM: true
|
|
12
|
+
}],
|
|
18
13
|
},
|
|
19
|
-
|
|
20
14
|
// An array of regexp pattern strings that are matched against all test paths, matched tests are skipped
|
|
21
|
-
testPathIgnorePatterns: [
|
|
22
|
-
|
|
23
|
-
],
|
|
24
|
-
|
|
15
|
+
// testPathIgnorePatterns: [
|
|
16
|
+
//
|
|
17
|
+
// ],
|
|
25
18
|
// A map from regular expressions to module names or to arrays of module names that allow to stub out resources with a single module
|
|
26
19
|
moduleNameMapper: {
|
|
27
20
|
"^@/(.*$)": "<rootDir>/src/$1"
|
|
28
21
|
},
|
|
29
|
-
|
|
30
22
|
// All imported modules in your tests should be mocked automatically
|
|
31
23
|
// automock: false,
|
|
32
|
-
|
|
33
24
|
// Stop running tests after `n` failures
|
|
34
25
|
// bail: 0,
|
|
35
|
-
|
|
36
26
|
// The directory where Jest should store its cached dependency information
|
|
37
27
|
// cacheDirectory: "/private/var/folders/8x/0jgq0fqx5qzgtm1zc8xtrdc80000gn/T/jest_dx",
|
|
38
|
-
|
|
39
28
|
// Automatically clear mock calls, instances, contexts and results before every test
|
|
40
29
|
clearMocks: true,
|
|
41
|
-
|
|
42
30
|
// Indicates whether the coverage information should be collected while executing the test
|
|
43
31
|
collectCoverage: false,
|
|
44
|
-
|
|
45
32
|
// An array of glob patterns indicating a set of files for which coverage information should be collected
|
|
46
33
|
// collectCoverageFrom: undefined,
|
|
47
|
-
|
|
48
34
|
// The directory where Jest should output its coverage files
|
|
49
35
|
coverageDirectory: "coverage",
|
|
50
|
-
|
|
51
36
|
// An array of regexp pattern strings used to skip coverage collection
|
|
52
37
|
// coveragePathIgnorePatterns: [
|
|
53
38
|
// "/node_modules/"
|
|
54
39
|
// ],
|
|
55
|
-
|
|
56
40
|
// Indicates which provider should be used to instrument code for coverage
|
|
57
41
|
// coverageProvider: "babel",
|
|
58
|
-
|
|
59
42
|
// A list of reporter names that Jest uses when writing coverage reports
|
|
60
43
|
// coverageReporters: [
|
|
61
44
|
// "json",
|
|
@@ -63,36 +46,26 @@ const config: Config = {
|
|
|
63
46
|
// "lcov",
|
|
64
47
|
// "clover"
|
|
65
48
|
// ],
|
|
66
|
-
|
|
67
49
|
// An object that configures minimum threshold enforcement for coverage results
|
|
68
50
|
// coverageThreshold: undefined,
|
|
69
|
-
|
|
70
51
|
// A path to a custom dependency extractor
|
|
71
52
|
// dependencyExtractor: undefined,
|
|
72
|
-
|
|
73
53
|
// Make calling deprecated APIs throw helpful error messages
|
|
74
54
|
// errorOnDeprecated: false,
|
|
75
|
-
|
|
76
55
|
// The default configuration for fake timers
|
|
77
56
|
// fakeTimers: {
|
|
78
57
|
// "enableGlobally": false
|
|
79
58
|
// },
|
|
80
|
-
|
|
81
59
|
// Force coverage collection from ignored files using an array of glob patterns
|
|
82
60
|
// forceCoverageMatch: [],
|
|
83
|
-
|
|
84
61
|
// A path to a module which exports an async function that is triggered once before all test suites
|
|
85
62
|
// globalSetup: undefined,
|
|
86
|
-
|
|
87
63
|
// A path to a module which exports an async function that is triggered once after all test suites
|
|
88
64
|
// globalTeardown: undefined,
|
|
89
|
-
|
|
90
65
|
// A set of global variables that need to be available in all test environments
|
|
91
66
|
// globals: {},
|
|
92
|
-
|
|
93
67
|
// The maximum amount of workers used to run your tests. Can be specified as % or a number. E.g. maxWorkers: 10% will use 10% of your CPU amount + 1 as the maximum worker number. maxWorkers: 2 will use a maximum of 2 workers.
|
|
94
68
|
// maxWorkers: "50%",
|
|
95
|
-
|
|
96
69
|
// An array of file extensions your modules use
|
|
97
70
|
// moduleFileExtensions: [
|
|
98
71
|
// "js",
|
|
@@ -104,100 +77,71 @@ const config: Config = {
|
|
|
104
77
|
// "json",
|
|
105
78
|
// "node"
|
|
106
79
|
// ],
|
|
107
|
-
|
|
108
80
|
// An array of regexp pattern strings, matched against all module paths before considered 'visible' to the module loader
|
|
109
|
-
|
|
110
|
-
|
|
81
|
+
modulePathIgnorePatterns: ["<rootDir>/dist/"],
|
|
111
82
|
// Activates notifications for test results
|
|
112
83
|
// notify: false,
|
|
113
|
-
|
|
114
84
|
// An enum that specifies notification mode. Requires { notify: true }
|
|
115
85
|
// notifyMode: "failure-change",
|
|
116
|
-
|
|
117
86
|
// A preset that is used as a base for Jest's configuration
|
|
118
87
|
// preset: undefined,
|
|
119
|
-
|
|
120
88
|
// Run tests from one or more projects
|
|
121
89
|
// projects: undefined,
|
|
122
|
-
|
|
123
90
|
// Use this configuration option to add custom reporters to Jest
|
|
124
91
|
// reporters: undefined,
|
|
125
|
-
|
|
126
92
|
// Automatically reset mock state before every test
|
|
127
93
|
// resetMocks: false,
|
|
128
|
-
|
|
129
94
|
// Reset the module registry before running each individual test
|
|
130
95
|
// resetModules: false,
|
|
131
|
-
|
|
132
96
|
// A path to a custom resolver
|
|
133
97
|
// resolver: undefined,
|
|
134
|
-
|
|
135
98
|
// Automatically restore mock state and implementation before every test
|
|
136
99
|
// restoreMocks: false,
|
|
137
|
-
|
|
138
100
|
// The root directory that Jest should scan for tests and modules within
|
|
139
101
|
// rootDir: undefined,
|
|
140
|
-
|
|
141
102
|
// A list of paths to directories that Jest should use to search for files in
|
|
142
103
|
// roots: [
|
|
143
104
|
// "<rootDir>"
|
|
144
105
|
// ],
|
|
145
|
-
|
|
146
106
|
// Allows you to use a custom runner instead of Jest's default test runner
|
|
147
107
|
// runner: "jest-runner",
|
|
148
|
-
|
|
149
108
|
// The paths to modules that run some code to configure or set up the testing environment before each test
|
|
150
|
-
|
|
151
109
|
// A list of paths to modules that run some code to configure or set up the testing framework before each test
|
|
152
110
|
// setupFilesAfterEnv: [],
|
|
153
|
-
|
|
154
111
|
// The number of seconds after which a test is considered as slow and reported as such in the results.
|
|
155
112
|
// slowTestThreshold: 5,
|
|
156
|
-
|
|
157
113
|
// A list of paths to snapshot serializer modules Jest should use for snapshot testing
|
|
158
114
|
// snapshotSerializers: [],
|
|
159
|
-
|
|
160
115
|
// The test environment that will be used for testing
|
|
161
116
|
testEnvironment: "node",
|
|
162
|
-
|
|
163
117
|
// Options that will be passed to the testEnvironment
|
|
164
118
|
// testEnvironmentOptions: {},
|
|
165
|
-
|
|
166
119
|
// Adds a location field to test results
|
|
167
120
|
// testLocationInResults: false,
|
|
168
|
-
|
|
169
121
|
// The glob patterns Jest uses to detect test files
|
|
170
122
|
// testMatch: [
|
|
171
123
|
// "**/__tests__/**/*.[jt]s?(x)",
|
|
172
124
|
// "**/?(*.)+(spec|test).[tj]s?(x)"
|
|
173
125
|
// ],
|
|
174
|
-
|
|
175
126
|
// The regexp pattern or array of patterns that Jest uses to detect test files
|
|
176
127
|
// testRegex: [],
|
|
177
|
-
|
|
178
128
|
// This option allows the use of a custom results processor
|
|
179
129
|
// testResultsProcessor: undefined,
|
|
180
|
-
|
|
181
130
|
// This option allows use of a custom test runner
|
|
182
131
|
// testRunner: "jest-circus/runner",
|
|
183
|
-
|
|
184
132
|
// An array of regexp pattern strings that are matched against all source file paths, matched files will skip transformation
|
|
185
133
|
// transformIgnorePatterns: [
|
|
186
134
|
// "/node_modules/",
|
|
187
135
|
// "\\.pnp\\.[^\\/]+$"
|
|
188
136
|
// ],
|
|
189
|
-
|
|
190
137
|
// An array of regexp pattern strings that are matched against all modules before the module loader will automatically return a mock for them
|
|
191
138
|
// unmockedModulePathPatterns: undefined,
|
|
192
|
-
|
|
193
139
|
// Indicates whether each individual test should be reported during the run
|
|
194
140
|
// verbose: undefined,
|
|
195
|
-
|
|
196
141
|
// An array of regexp patterns that are matched against all source file paths before re-running tests in watch mode
|
|
197
142
|
// watchPathIgnorePatterns: [],
|
|
198
|
-
|
|
199
143
|
// Whether to use watchman for file crawling
|
|
200
144
|
// watchman: true,
|
|
201
145
|
};
|
|
202
|
-
|
|
203
146
|
export default config;
|
|
147
|
+
//# sourceMappingURL=jest.config.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"jest.config.js","sourceRoot":"","sources":["../jest.config.ts"],"names":[],"mappings":"AAAA;;;GAGG;AAIH,MAAM,MAAM,GAAW;IACnB,UAAU,EAAE,EAAE;IAEd,sBAAsB,EAAE,CAAC,KAAK,CAAC;IAE/B,0DAA0D;IAC1D,SAAS,EAAE;QACP,WAAW,EAAE,CAAC,SAAS,EAAE;gBACrB,MAAM,EAAE,IAAI;aACf,CAAC;KACL;IAED,wGAAwG;IACxG,4BAA4B;IAC5B,OAAO;IACP,KAAK;IAEL,oIAAoI;IACpI,gBAAgB,EAAE;QACd,UAAU,EAAE,kBAAkB;KACjC;IAED,oEAAoE;IACpE,mBAAmB;IAEnB,wCAAwC;IACxC,WAAW;IAEX,0EAA0E;IAC1E,sFAAsF;IAEtF,oFAAoF;IACpF,UAAU,EAAE,IAAI;IAEhB,0FAA0F;IAC1F,eAAe,EAAE,KAAK;IAEtB,yGAAyG;IACzG,kCAAkC;IAElC,4DAA4D;IAC5D,iBAAiB,EAAE,UAAU;IAE7B,sEAAsE;IACtE,gCAAgC;IAChC,qBAAqB;IACrB,KAAK;IAEL,0EAA0E;IAC1E,6BAA6B;IAE7B,wEAAwE;IACxE,uBAAuB;IACvB,YAAY;IACZ,YAAY;IACZ,YAAY;IACZ,aAAa;IACb,KAAK;IAEL,+EAA+E;IAC/E,gCAAgC;IAEhC,0CAA0C;IAC1C,kCAAkC;IAElC,4DAA4D;IAC5D,4BAA4B;IAE5B,4CAA4C;IAC5C,gBAAgB;IAChB,4BAA4B;IAC5B,KAAK;IAEL,+EAA+E;IAC/E,0BAA0B;IAE1B,mGAAmG;IACnG,0BAA0B;IAE1B,kGAAkG;IAClG,6BAA6B;IAE7B,+EAA+E;IAC/E,eAAe;IAEf,iOAAiO;IACjO,qBAAqB;IAErB,+CAA+C;IAC/C,0BAA0B;IAC1B,UAAU;IACV,WAAW;IACX,WAAW;IACX,WAAW;IACX,UAAU;IACV,WAAW;IACX,YAAY;IACZ,WAAW;IACX,KAAK;IAEL,wHAAwH;IACxH,wBAAwB,EAAE,CAAC,iBAAiB,CAAC;IAE7C,2CAA2C;IAC3C,iBAAiB;IAEjB,sEAAsE;IACtE,gCAAgC;IAEhC,2DAA2D;IAC3D,qBAAqB;IAErB,sCAAsC;IACtC,uBAAuB;IAEvB,gEAAgE;IAChE,wBAAwB;IAExB,mDAAmD;IACnD,qBAAqB;IAErB,gEAAgE;IAChE,uBAAuB;IAEvB,8BAA8B;IAC9B,uBAAuB;IAEvB,wEAAwE;IACxE,uBAAuB;IAEvB,wEAAwE;IACxE,sBAAsB;IAEtB,6EAA6E;IAC7E,WAAW;IACX,gBAAgB;IAChB,KAAK;IAEL,0EAA0E;IAC1E,yBAAyB;IAEzB,0GAA0G;IAE1G,8GAA8G;IAC9G,0BAA0B;IAE1B,sGAAsG;IACtG,wBAAwB;IAExB,sFAAsF;IACtF,2BAA2B;IAE3B,qDAAqD;IACrD,eAAe,EAAE,MAAM;IAEvB,qDAAqD;IACrD,8BAA8B;IAE9B,wCAAwC;IACxC,gCAAgC;IAEhC,mDAAmD;IACnD,eAAe;IACf,mCAAmC;IACnC,qCAAqC;IACrC,KAAK;IAEL,8EAA8E;IAC9E,iBAAiB;IAEjB,2DAA2D;IAC3D,mCAAmC;IAEnC,iDAAiD;IACjD,oCAAoC;IAEpC,4HAA4H;IAC5H,6BAA6B;IAC7B,sBAAsB;IACtB,wBAAwB;IACxB,KAAK;IAEL,6IAA6I;IAC7I,yCAAyC;IAEzC,2EAA2E;IAC3E,sBAAsB;IAEtB,mHAAmH;IACnH,+BAA+B;IAE/B,4CAA4C;IAC5C,kBAAkB;CACrB,CAAC;AAEF,eAAe,MAAM,CAAC"}
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
export interface KvCacheArgs {
|
|
3
|
+
batchSize: number;
|
|
4
|
+
maxSequenceLength: number;
|
|
5
|
+
numHeads: number;
|
|
6
|
+
headDim: number;
|
|
7
|
+
dtype?: tf.DataType;
|
|
8
|
+
}
|
|
9
|
+
/**
|
|
10
|
+
* A container for KV caches. A model should initialize one KV cache
|
|
11
|
+
*/
|
|
12
|
+
export declare class KvCacheContainer {
|
|
13
|
+
protected caches: Map<string, KvCache>;
|
|
14
|
+
protected max_sequence_length: number;
|
|
15
|
+
constructor(maxSequenceLength: number);
|
|
16
|
+
create(id: string, args: Omit<KvCacheArgs, "maxSequenceLength">): void;
|
|
17
|
+
/**
|
|
18
|
+
* The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
|
|
19
|
+
*/
|
|
20
|
+
update(id: string, key: tf.Tensor4D, value: tf.Tensor4D): {
|
|
21
|
+
keyCache: tf.Variable<tf.Rank.R4>;
|
|
22
|
+
valueCache: tf.Variable<tf.Rank.R4>;
|
|
23
|
+
} | undefined;
|
|
24
|
+
reset(): void;
|
|
25
|
+
dispose(): void;
|
|
26
|
+
get size(): number;
|
|
27
|
+
get maxSequenceLength(): number;
|
|
28
|
+
}
|
|
29
|
+
export declare class KvCache {
|
|
30
|
+
protected key_cache: tf.Variable<tf.Rank.R4>;
|
|
31
|
+
protected value_cache: tf.Variable<tf.Rank.R4>;
|
|
32
|
+
protected current_position: number;
|
|
33
|
+
protected batch_size: number;
|
|
34
|
+
protected max_sequence_length: number;
|
|
35
|
+
protected num_kv_heads: number;
|
|
36
|
+
protected head_dim: number;
|
|
37
|
+
constructor({ batchSize, maxSequenceLength, numHeads, headDim, dtype }: KvCacheArgs);
|
|
38
|
+
/**
|
|
39
|
+
* The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
|
|
40
|
+
*/
|
|
41
|
+
update(key: tf.Tensor4D, value: tf.Tensor4D): {
|
|
42
|
+
keyCache: tf.Variable<tf.Rank.R4>;
|
|
43
|
+
valueCache: tf.Variable<tf.Rank.R4>;
|
|
44
|
+
};
|
|
45
|
+
protected mergeIntoCache(new_value: tf.Tensor4D, current_cache: tf.Tensor4D): tf.Tensor4D;
|
|
46
|
+
reset(): void;
|
|
47
|
+
dispose(): void;
|
|
48
|
+
/**
|
|
49
|
+
* The size of the KV cache, also the number of tokens since the first one.
|
|
50
|
+
*/
|
|
51
|
+
get size(): number;
|
|
52
|
+
}
|
|
53
|
+
//# sourceMappingURL=kv_cache.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"kv_cache.d.ts","sourceRoot":"","sources":["../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC,MAAM,WAAW,WAAW;IACxB,SAAS,EAAE,MAAM,CAAC;IAClB,iBAAiB,EAAE,MAAM,CAAC;IAC1B,QAAQ,EAAE,MAAM,CAAC;IACjB,OAAO,EAAE,MAAM,CAAC;IAChB,KAAK,CAAC,EAAE,EAAE,CAAC,QAAQ,CAAA;CACtB;AAGD;;GAEG;AACH,qBAAa,gBAAgB;IACzB,SAAS,CAAC,MAAM,uBAA8B;IAC9C,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;gBAG1B,iBAAiB,EAAE,MAAM;IAS9B,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,IAAI,EAAE,IAAI,CAAC,WAAW,EAAE,mBAAmB,CAAC;IAUtE;;OAEG;IACI,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IA4BvD,KAAK;IAOL,OAAO;IAOd,IAAW,IAAI,WAGd;IAGD,IAAW,iBAAiB,WAE3B;CACJ;AAGD,qBAAa,OAAO;IAEhB,SAAS,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;IAC7C,SAAS,CAAC,WAAW,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAA;IAG9C,SAAS,CAAC,gBAAgB,EAAE,MAAM,CAAK;IAEvC,SAAS,CAAC,UAAU,EAAE,MAAM,CAAC;IAC7B,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;IACtC,SAAS,CAAC,YAAY,EAAE,MAAM,CAAC;IAC/B,SAAS,CAAC,QAAQ,EAAE,MAAM,CAAC;gBAEf,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAiB,EAAE,EAAE,WAAW;IAa/F;;OAEG;IACI,MAAM,CAAC,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IAgClD,SAAS,CAAC,cAAc,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,EAAE,aAAa,EAAE,EAAE,CAAC,QAAQ;IAqBpE,KAAK,IAAI,IAAI;IAab,OAAO,IAAI,IAAI;IAMtB;;OAEG;IACH,IAAI,IAAI,IAAI,MAAM,CAEjB;CAEJ"}
|
|
@@ -1,205 +1,135 @@
|
|
|
1
1
|
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
export interface KvCacheArgs {
|
|
5
|
-
batchSize: number;
|
|
6
|
-
maxSequenceLength: number;
|
|
7
|
-
numHeads: number;
|
|
8
|
-
headDim: number;
|
|
9
|
-
dtype?: tf.DataType
|
|
10
|
-
}
|
|
11
|
-
|
|
12
|
-
|
|
13
2
|
/**
|
|
14
3
|
* A container for KV caches. A model should initialize one KV cache
|
|
15
4
|
*/
|
|
16
5
|
export class KvCacheContainer {
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
constructor(maxSequenceLength: number) {
|
|
6
|
+
caches = new Map();
|
|
7
|
+
max_sequence_length;
|
|
8
|
+
constructor(maxSequenceLength) {
|
|
22
9
|
if (!maxSequenceLength) {
|
|
23
10
|
throw Error(`KvCacheContainer: expected KV cache maximum sequence length to be greater than 0, got: ${String(maxSequenceLength)}`);
|
|
24
11
|
}
|
|
25
|
-
|
|
26
12
|
this.max_sequence_length = maxSequenceLength;
|
|
27
13
|
}
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
public create(id: string, args: Omit<KvCacheArgs, "maxSequenceLength">) {
|
|
14
|
+
create(id, args) {
|
|
31
15
|
const new_cache = new KvCache({
|
|
32
16
|
...args,
|
|
33
17
|
maxSequenceLength: this.max_sequence_length
|
|
34
18
|
});
|
|
35
|
-
|
|
36
19
|
this.caches.set(id, new_cache);
|
|
37
20
|
}
|
|
38
|
-
|
|
39
|
-
|
|
40
21
|
/**
|
|
41
22
|
* The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
|
|
42
23
|
*/
|
|
43
|
-
|
|
24
|
+
update(id, key, value) {
|
|
44
25
|
const kv_cache = this.caches.get(id);
|
|
45
|
-
|
|
46
26
|
if (!kv_cache) {
|
|
47
27
|
return undefined;
|
|
48
28
|
}
|
|
49
|
-
|
|
50
29
|
const { keyCache, valueCache } = kv_cache.update(key, value);
|
|
51
|
-
|
|
52
30
|
// slicing to get only the past key and value projections, but normally
|
|
53
31
|
// in TensorFlow and PyTorch the full cache is returned and masked for
|
|
54
32
|
// graph purposes
|
|
55
33
|
return tf.tidy(() => {
|
|
56
|
-
const k_cache = keyCache.slice(
|
|
57
|
-
|
|
58
|
-
[keyCache.shape[0], keyCache.shape[1], kv_cache.size, keyCache.shape[3]]);
|
|
59
|
-
const v_cache = valueCache.slice(
|
|
60
|
-
[0, 0, 0, 0],
|
|
61
|
-
[valueCache.shape[0], valueCache.shape[1], kv_cache.size, valueCache.shape[3]]);
|
|
62
|
-
|
|
34
|
+
const k_cache = keyCache.slice([0, 0, 0, 0], [keyCache.shape[0], keyCache.shape[1], kv_cache.size, keyCache.shape[3]]);
|
|
35
|
+
const v_cache = valueCache.slice([0, 0, 0, 0], [valueCache.shape[0], valueCache.shape[1], kv_cache.size, valueCache.shape[3]]);
|
|
63
36
|
return {
|
|
64
37
|
keyCache: k_cache,
|
|
65
38
|
valueCache: v_cache
|
|
66
|
-
}
|
|
67
|
-
})
|
|
39
|
+
};
|
|
40
|
+
});
|
|
68
41
|
}
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
public reset() {
|
|
42
|
+
reset() {
|
|
72
43
|
this.caches.forEach(cache => {
|
|
73
44
|
cache.reset();
|
|
74
|
-
})
|
|
45
|
+
});
|
|
75
46
|
}
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
public dispose() {
|
|
47
|
+
dispose() {
|
|
79
48
|
this.caches.forEach(cache => {
|
|
80
49
|
cache.dispose();
|
|
81
|
-
})
|
|
50
|
+
});
|
|
82
51
|
}
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
public get size() {
|
|
52
|
+
get size() {
|
|
86
53
|
// the size of all KV caches are expected to be the same, just use the first one
|
|
87
54
|
return this.caches.entries().next().value?.[1].size ?? 0;
|
|
88
55
|
}
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
public get maxSequenceLength() {
|
|
56
|
+
get maxSequenceLength() {
|
|
92
57
|
return this.max_sequence_length;
|
|
93
58
|
}
|
|
94
59
|
}
|
|
95
|
-
|
|
96
|
-
|
|
97
60
|
export class KvCache {
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
protected value_cache: tf.Variable<tf.Rank.R4>
|
|
101
|
-
|
|
61
|
+
key_cache;
|
|
62
|
+
value_cache;
|
|
102
63
|
// the size of the KV cache, represents the number of tokens since the first chat token
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
constructor({ batchSize, maxSequenceLength, numHeads, headDim, dtype = "float32" }: KvCacheArgs) {
|
|
111
|
-
const cache_shape = [batchSize, numHeads, maxSequenceLength, headDim] as [number, number, number, number];
|
|
112
|
-
|
|
64
|
+
current_position = 0;
|
|
65
|
+
batch_size;
|
|
66
|
+
max_sequence_length;
|
|
67
|
+
num_kv_heads;
|
|
68
|
+
head_dim;
|
|
69
|
+
constructor({ batchSize, maxSequenceLength, numHeads, headDim, dtype = "float32" }) {
|
|
70
|
+
const cache_shape = [batchSize, numHeads, maxSequenceLength, headDim];
|
|
113
71
|
this.key_cache = tf.variable(tf.zeros(cache_shape, dtype), false);
|
|
114
72
|
this.value_cache = tf.variable(tf.zeros(cache_shape, dtype), false);
|
|
115
|
-
|
|
116
73
|
this.batch_size = batchSize;
|
|
117
74
|
this.max_sequence_length = maxSequenceLength;
|
|
118
75
|
this.num_kv_heads = numHeads;
|
|
119
76
|
this.head_dim = headDim;
|
|
120
77
|
}
|
|
121
|
-
|
|
122
|
-
|
|
123
78
|
/**
|
|
124
79
|
* The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
|
|
125
80
|
*/
|
|
126
|
-
|
|
81
|
+
update(key, value) {
|
|
127
82
|
const batch_size = key.shape[0];
|
|
128
83
|
const seq_len = key.shape[2];
|
|
129
|
-
|
|
130
84
|
if (batch_size > this.key_cache.shape[0]) {
|
|
131
85
|
throw Error(`The current KV cache has been set up with a batch size of` +
|
|
132
|
-
` ${this.key_cache.shape[0]}, but found new key tensors with batch size ${batch_size}`)
|
|
86
|
+
` ${this.key_cache.shape[0]}, but found new key tensors with batch size ${batch_size}`);
|
|
133
87
|
}
|
|
134
|
-
|
|
135
88
|
if (this.current_position + seq_len > this.max_sequence_length) {
|
|
136
89
|
throw Error(`The KV cache has exceeded its maximum sequence length of ${this.max_sequence_length}. Use a larger value.`);
|
|
137
90
|
}
|
|
138
|
-
|
|
139
91
|
const new_key_cache = this.mergeIntoCache(key, this.key_cache);
|
|
140
92
|
const new_value_cache = this.mergeIntoCache(value, this.value_cache);
|
|
141
|
-
|
|
142
93
|
this.key_cache.assign(new_key_cache);
|
|
143
94
|
this.value_cache.assign(new_value_cache);
|
|
144
|
-
|
|
145
95
|
new_key_cache.dispose();
|
|
146
96
|
new_value_cache.dispose();
|
|
147
|
-
|
|
148
97
|
// advance the pointer to reflect the updated cache's current
|
|
149
98
|
this.current_position += seq_len;
|
|
150
|
-
|
|
151
99
|
return {
|
|
152
100
|
keyCache: this.key_cache,
|
|
153
101
|
valueCache: this.value_cache,
|
|
154
|
-
}
|
|
102
|
+
};
|
|
155
103
|
}
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
protected mergeIntoCache(new_value: tf.Tensor4D, current_cache: tf.Tensor4D) {
|
|
104
|
+
mergeIntoCache(new_value, current_cache) {
|
|
159
105
|
const seq_len = new_value.shape[2];
|
|
160
|
-
|
|
161
106
|
return tf.tidy(() => {
|
|
162
|
-
|
|
163
|
-
const
|
|
164
|
-
[0, 0, 0, 0],
|
|
165
|
-
[this.batch_size, this.num_kv_heads, this.current_position, this.head_dim]);
|
|
166
|
-
|
|
167
|
-
const future = current_cache.slice(
|
|
168
|
-
[0, 0, this.current_position + seq_len, 0],
|
|
169
|
-
[this.batch_size, this.num_kv_heads, this.max_sequence_length - this.current_position - seq_len, this.head_dim]);
|
|
170
|
-
|
|
107
|
+
const historical = current_cache.slice([0, 0, 0, 0], [this.batch_size, this.num_kv_heads, this.current_position, this.head_dim]);
|
|
108
|
+
const future = current_cache.slice([0, 0, this.current_position + seq_len, 0], [this.batch_size, this.num_kv_heads, this.max_sequence_length - this.current_position - seq_len, this.head_dim]);
|
|
171
109
|
// merge the new tensor into the current cache to create a new, larger, cache,
|
|
172
110
|
// this is different from Python immplementations because TFJS tensors are immutable,
|
|
173
111
|
// because we cannot update a slice, we must slice and concat
|
|
174
112
|
return tf.concat([historical, new_value, future], 2);
|
|
175
|
-
})
|
|
113
|
+
});
|
|
176
114
|
}
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
public reset(): void {
|
|
115
|
+
reset() {
|
|
180
116
|
this.current_position = 0;
|
|
181
|
-
|
|
182
117
|
tf.tidy(() => {
|
|
183
118
|
const key_cache_shape = this.key_cache.shape;
|
|
184
119
|
const value_cache_shape = this.value_cache.shape;
|
|
185
|
-
|
|
186
120
|
this.key_cache.assign(tf.zeros(key_cache_shape));
|
|
187
121
|
this.value_cache.assign(tf.zeros(value_cache_shape));
|
|
188
122
|
});
|
|
189
123
|
}
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
public dispose(): void {
|
|
124
|
+
dispose() {
|
|
193
125
|
this.key_cache.dispose();
|
|
194
126
|
this.value_cache.dispose();
|
|
195
127
|
}
|
|
196
|
-
|
|
197
|
-
|
|
198
128
|
/**
|
|
199
129
|
* The size of the KV cache, also the number of tokens since the first one.
|
|
200
130
|
*/
|
|
201
|
-
get size()
|
|
131
|
+
get size() {
|
|
202
132
|
return this.current_position;
|
|
203
133
|
}
|
|
204
|
-
|
|
205
134
|
}
|
|
135
|
+
//# sourceMappingURL=kv_cache.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"kv_cache.js","sourceRoot":"","sources":["../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAYvC;;GAEG;AACH,MAAM,OAAO,gBAAgB;IACf,MAAM,GAAG,IAAI,GAAG,EAAmB,CAAC;IACpC,mBAAmB,CAAS;IAGtC,YAAY,iBAAyB;QACjC,IAAI,CAAC,iBAAiB,EAAE,CAAC;YACrB,MAAM,KAAK,CAAC,0FAA0F,MAAM,CAAC,iBAAiB,CAAC,EAAE,CAAC,CAAC;QACvI,CAAC;QAED,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;IACjD,CAAC;IAGM,MAAM,CAAC,EAAU,EAAE,IAA4C;QAClE,MAAM,SAAS,GAAG,IAAI,OAAO,CAAC;YAC1B,GAAG,IAAI;YACP,iBAAiB,EAAE,IAAI,CAAC,mBAAmB;SAC9C,CAAC,CAAC;QAEH,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,EAAE,SAAS,CAAC,CAAC;IACnC,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,EAAU,EAAE,GAAgB,EAAE,KAAkB;QAC1D,MAAM,QAAQ,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;QAErC,IAAI,CAAC,QAAQ,EAAE,CAAC;YACZ,OAAO,SAAS,CAAC;QACrB,CAAC;QAED,MAAM,EAAE,QAAQ,EAAE,UAAU,EAAE,GAAG,QAAQ,CAAC,MAAM,CAAC,GAAG,EAAE,KAAK,CAAC,CAAC;QAE7D,uEAAuE;QACvE,sEAAsE;QACtE,iBAAiB;QACjB,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,OAAO,GAAG,QAAQ,CAAC,KAAK,CAC1B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAC9E,MAAM,OAAO,GAAG,UAAU,CAAC,KAAK,CAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAEpF,OAAO;gBACH,QAAQ,EAAE,OAAO;gBACjB,UAAU,EAAE,OAAO;aACtB,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,KAAK,EAAE,CAAC;QAClB,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,OAAO,EAAE,CAAC;QACpB,CAAC,CAAC,CAAA;IACN,CAAC;IAGD,IAAW,IAAI;QACX,gFAAgF;QAChF,OAAO,IAAI,CAAC,MAAM,CAAC,OAAO,EAAE,CAAC,IAAI,EAAE,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,IAAI,CAAC,CAAC;IAC7D,CAAC;IAGD,IAAW,iBAAiB;QACxB,OAAO,IAAI,CAAC,mBAAmB,CAAC;IACpC,CAAC;CACJ;AAGD,MAAM,OAAO,OAAO;IAEN,SAAS,CAA0B;IACnC,WAAW,CAAyB;IAE9C,uFAAuF;IAC7E,gBAAgB,GAAW,CAAC,CAAC;IAE7B,UAAU,CAAS;IACnB,mBAAmB,CAAS;IAC5B,YAAY,CAAS;IACrB,QAAQ,CAAS;IAE3B,YAAY,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAK,GAAG,SAAS,EAAe;QAC3F,MAAM,WAAW,GAAG,CAAC,SAAS,EAAE,QAAQ,EAAE,iBAAiB,EAAE,OAAO,CAAqC,CAAC;QAE1G,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAClE,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAEpE,IAAI,CAAC,UAAU,GAAG,SAAS,CAAC;QAC5B,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;QAC7C,IAAI,CAAC,YAAY,GAAG,QAAQ,CAAC;QAC7B,IAAI,CAAC,QAAQ,GAAG,OAAO,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,GAAgB,EAAE,KAAkB;QAC9C,MAAM,UAAU,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAChC,MAAM,OAAO,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAE7B,IAAI,UAAU,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC;YACvC,MAAM,KAAK,CAAC,2DAA2D;gBACnE,IAAI,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,+CAA+C,UAAU,EAAE,CAAC,CAAA;QAC/F,CAAC;QAED,IAAI,IAAI,CAAC,gBAAgB,GAAG,OAAO,GAAG,IAAI,CAAC,mBAAmB,EAAE,CAAC;YAC7D,MAAM,KAAK,CAAC,4DAA4D,IAAI,CAAC,mBAAmB,uBAAuB,CAAC,CAAC;QAC7H,CAAC;QAED,MAAM,aAAa,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;QAC/D,MAAM,eAAe,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC;QAErE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,aAAa,CAAC,CAAC;QACrC,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,eAAe,CAAC,CAAC;QAEzC,aAAa,CAAC,OAAO,EAAE,CAAC;QACxB,eAAe,CAAC,OAAO,EAAE,CAAC;QAE1B,6DAA6D;QAC7D,IAAI,CAAC,gBAAgB,IAAI,OAAO,CAAC;QAEjC,OAAO;YACH,QAAQ,EAAE,IAAI,CAAC,SAAS;YACxB,UAAU,EAAE,IAAI,CAAC,WAAW;SAC/B,CAAA;IACL,CAAC;IAGS,cAAc,CAAC,SAAsB,EAAE,aAA0B;QACvE,MAAM,OAAO,GAAG,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAEnC,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAEhB,MAAM,UAAU,GAAG,aAAa,CAAC,KAAK,CAClC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,gBAAgB,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAEhF,MAAM,MAAM,GAAG,aAAa,CAAC,KAAK,CAC9B,CAAC,CAAC,EAAE,CAAC,EAAE,IAAI,CAAC,gBAAgB,GAAG,OAAO,EAAE,CAAC,CAAC,EAC1C,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,gBAAgB,GAAG,OAAO,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAErH,8EAA8E;YAC9E,qFAAqF;YACrF,6DAA6D;YAC7D,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,UAAU,EAAE,SAAS,EAAE,MAAM,CAAC,EAAE,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,gBAAgB,GAAG,CAAC,CAAC;QAE1B,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YACT,MAAM,eAAe,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC;YAC7C,MAAM,iBAAiB,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC;YAEjD,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,eAAe,CAAC,CAAC,CAAC;YACjD,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAC;IACP,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,SAAS,CAAC,OAAO,EAAE,CAAC;QACzB,IAAI,CAAC,WAAW,CAAC,OAAO,EAAE,CAAC;IAC/B,CAAC;IAGD;;OAEG;IACH,IAAI,IAAI;QACJ,OAAO,IAAI,CAAC,gBAAgB,CAAC;IACjC,CAAC;CAEJ"}
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { KvCacheContainer } from "../kv_cache";
|
|
3
|
+
import { MultiHeadAttention, type MultiHeadAttentionArgs } from '../layers/multihead_attention';
|
|
4
|
+
import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
|
|
5
|
+
/**
|
|
6
|
+
* MultiHeadAttention with RoPE and KV caching. If using KV caching, this layer
|
|
7
|
+
* should be used in a custom training loop because it requires the cache to be
|
|
8
|
+
* passed through the `kwargs.kvCache` argument during the `layer.apply()`
|
|
9
|
+
* forward propagation.
|
|
10
|
+
*
|
|
11
|
+
* If a KV cache is not provided, then this layer operates as MultiHeadAttention with RoPE.
|
|
12
|
+
*/
|
|
13
|
+
export declare class CachedRoPEMultiHeadAttention extends MultiHeadAttention {
|
|
14
|
+
static className: string;
|
|
15
|
+
protected rope: tf.layers.Layer;
|
|
16
|
+
constructor(args: MultiHeadAttentionArgs);
|
|
17
|
+
protected forward(query_input: tf.Tensor, key_input: tf.Tensor, value_input: tf.Tensor, packing_mask: tf.Tensor | null, causal_mask: tf.Tensor | null, kwargs: Kwargs): tf.Tensor;
|
|
18
|
+
protected getCachedKV(kv_container: KvCacheContainer, key_split: tf.Tensor4D, value_split: tf.Tensor4D): {
|
|
19
|
+
keyCache: tf.Variable<tf.Rank.R4>;
|
|
20
|
+
valueCache: tf.Variable<tf.Rank.R4>;
|
|
21
|
+
};
|
|
22
|
+
/**
|
|
23
|
+
* Adds RoPE position encoding right after splitting heads.
|
|
24
|
+
*/
|
|
25
|
+
protected splitHeads(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, shuffle: number[]): {
|
|
26
|
+
query_split: tf.Tensor4D;
|
|
27
|
+
key_split: tf.Tensor4D;
|
|
28
|
+
value_split: tf.Tensor4D;
|
|
29
|
+
};
|
|
30
|
+
}
|
|
31
|
+
//# sourceMappingURL=cached_rope_multihead_attention.d.ts.map
|