@stellarapp/tfjs-stellar 1.0.4 → 1.0.6
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- package/README.md +17 -0
- package/dist/index.d.ts +2 -1
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +2 -1
- package/dist/index.js.map +1 -1
- package/dist/kv_cache.d.ts +2 -0
- package/dist/kv_cache.d.ts.map +1 -1
- package/dist/kv_cache.js +6 -0
- package/dist/kv_cache.js.map +1 -1
- package/dist/masks.test.d.ts +2 -0
- package/dist/masks.test.d.ts.map +1 -0
- package/dist/masks.test.js +55 -0
- package/dist/masks.test.js.map +1 -0
- package/dist/models/index.d.ts +2 -1
- package/dist/models/index.d.ts.map +1 -1
- package/dist/models/index.js +2 -1
- package/dist/models/index.js.map +1 -1
- package/dist/utils.test.js +0 -15
- package/dist/utils.test.js.map +1 -1
- package/package.json +1 -1
- package/dist/jest.config.d.ts +0 -8
- package/dist/jest.config.d.ts.map +0 -1
- package/dist/jest.config.js +0 -147
- package/dist/jest.config.js.map +0 -1
- package/dist/src/index.d.ts +0 -6
- package/dist/src/index.d.ts.map +0 -1
- package/dist/src/index.js +0 -6
- package/dist/src/index.js.map +0 -1
- package/dist/src/kv_cache.d.ts +0 -53
- package/dist/src/kv_cache.d.ts.map +0 -1
- package/dist/src/kv_cache.js +0 -135
- package/dist/src/kv_cache.js.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.d.ts +0 -31
- package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.js +0 -76
- package/dist/src/layers/cached_rope_multihead_attention.js.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +0 -2
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +0 -1
- package/dist/src/layers/cached_rope_multihead_attention.test.js +0 -43
- package/dist/src/layers/cached_rope_multihead_attention.test.js.map +0 -1
- package/dist/src/layers/gpt_decoder_block.d.ts +0 -34
- package/dist/src/layers/gpt_decoder_block.d.ts.map +0 -1
- package/dist/src/layers/gpt_decoder_block.js +0 -51
- package/dist/src/layers/gpt_decoder_block.js.map +0 -1
- package/dist/src/layers/index.d.ts +0 -17
- package/dist/src/layers/index.d.ts.map +0 -1
- package/dist/src/layers/index.js +0 -33
- package/dist/src/layers/index.js.map +0 -1
- package/dist/src/layers/multihead_attention.d.ts +0 -106
- package/dist/src/layers/multihead_attention.d.ts.map +0 -1
- package/dist/src/layers/multihead_attention.js +0 -269
- package/dist/src/layers/multihead_attention.js.map +0 -1
- package/dist/src/layers/multihead_attention.test.d.ts +0 -2
- package/dist/src/layers/multihead_attention.test.d.ts.map +0 -1
- package/dist/src/layers/multihead_attention.test.js +0 -160
- package/dist/src/layers/multihead_attention.test.js.map +0 -1
- package/dist/src/layers/positional_encoding.d.ts +0 -37
- package/dist/src/layers/positional_encoding.d.ts.map +0 -1
- package/dist/src/layers/positional_encoding.js +0 -115
- package/dist/src/layers/positional_encoding.js.map +0 -1
- package/dist/src/layers/positional_encoding.test.d.ts +0 -2
- package/dist/src/layers/positional_encoding.test.d.ts.map +0 -1
- package/dist/src/layers/positional_encoding.test.js +0 -95
- package/dist/src/layers/positional_encoding.test.js.map +0 -1
- package/dist/src/layers/rotary_position_embedding.d.ts +0 -39
- package/dist/src/layers/rotary_position_embedding.d.ts.map +0 -1
- package/dist/src/layers/rotary_position_embedding.js +0 -99
- package/dist/src/layers/rotary_position_embedding.js.map +0 -1
- package/dist/src/layers/rotary_position_embedding.test.d.ts +0 -2
- package/dist/src/layers/rotary_position_embedding.test.d.ts.map +0 -1
- package/dist/src/layers/rotary_position_embedding.test.js +0 -88
- package/dist/src/layers/rotary_position_embedding.test.js.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.d.ts +0 -47
- package/dist/src/layers/token_and_positional_embedding.d.ts.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.js +0 -109
- package/dist/src/layers/token_and_positional_embedding.js.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.test.d.ts +0 -2
- package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +0 -1
- package/dist/src/layers/token_and_positional_embedding.test.js +0 -58
- package/dist/src/layers/token_and_positional_embedding.test.js.map +0 -1
- package/dist/src/layers/transformer_decoder.d.ts +0 -69
- package/dist/src/layers/transformer_decoder.d.ts.map +0 -1
- package/dist/src/layers/transformer_decoder.js +0 -182
- package/dist/src/layers/transformer_decoder.js.map +0 -1
- package/dist/src/layers/transformer_decoder.test.d.ts +0 -2
- package/dist/src/layers/transformer_decoder.test.d.ts.map +0 -1
- package/dist/src/layers/transformer_decoder.test.js +0 -72
- package/dist/src/layers/transformer_decoder.test.js.map +0 -1
- package/dist/src/layers/transformer_encoder.d.ts +0 -55
- package/dist/src/layers/transformer_encoder.d.ts.map +0 -1
- package/dist/src/layers/transformer_encoder.js +0 -175
- package/dist/src/layers/transformer_encoder.js.map +0 -1
- package/dist/src/layers/transformer_encoder.test.d.ts +0 -2
- package/dist/src/layers/transformer_encoder.test.d.ts.map +0 -1
- package/dist/src/layers/transformer_encoder.test.js +0 -58
- package/dist/src/layers/transformer_encoder.test.js.map +0 -1
- package/dist/src/losses/dice.d.ts +0 -30
- package/dist/src/losses/dice.d.ts.map +0 -1
- package/dist/src/losses/dice.js +0 -93
- package/dist/src/losses/dice.js.map +0 -1
- package/dist/src/losses/index.d.ts +0 -2
- package/dist/src/losses/index.d.ts.map +0 -1
- package/dist/src/losses/index.js +0 -2
- package/dist/src/losses/index.js.map +0 -1
- package/dist/src/masks.d.ts +0 -20
- package/dist/src/masks.d.ts.map +0 -1
- package/dist/src/masks.js +0 -37
- package/dist/src/masks.js.map +0 -1
- package/dist/src/metrics.d.ts +0 -20
- package/dist/src/metrics.d.ts.map +0 -1
- package/dist/src/metrics.js +0 -28
- package/dist/src/metrics.js.map +0 -1
- package/dist/src/models/gpt_model.d.ts +0 -94
- package/dist/src/models/gpt_model.d.ts.map +0 -1
- package/dist/src/models/gpt_model.js +0 -154
- package/dist/src/models/gpt_model.js.map +0 -1
- package/dist/src/models/index.d.ts +0 -3
- package/dist/src/models/index.d.ts.map +0 -1
- package/dist/src/models/index.js +0 -3
- package/dist/src/models/index.js.map +0 -1
- package/dist/src/models/llm_model.d.ts +0 -87
- package/dist/src/models/llm_model.d.ts.map +0 -1
- package/dist/src/models/llm_model.js +0 -245
- package/dist/src/models/llm_model.js.map +0 -1
- package/dist/src/models/u_net.d.ts +0 -40
- package/dist/src/models/u_net.d.ts.map +0 -1
- package/dist/src/models/u_net.js +0 -151
- package/dist/src/models/u_net.js.map +0 -1
- package/dist/src/tfjs_types.d.ts +0 -10
- package/dist/src/tfjs_types.d.ts.map +0 -1
- package/dist/src/tfjs_types.js +0 -2
- package/dist/src/tfjs_types.js.map +0 -1
- package/dist/src/utils.d.ts +0 -28
- package/dist/src/utils.d.ts.map +0 -1
- package/dist/src/utils.js +0 -63
- package/dist/src/utils.js.map +0 -1
- package/dist/src/utils.test.d.ts +0 -2
- package/dist/src/utils.test.d.ts.map +0 -1
- package/dist/src/utils.test.js +0 -73
- package/dist/src/utils.test.js.map +0 -1
package/README.md
CHANGED
|
@@ -44,4 +44,21 @@ gpt_model.summary();
|
|
|
44
44
|
// see https://js.tensorflow.org/api/latest/#data.generator
|
|
45
45
|
// on how to create a generator dataset
|
|
46
46
|
//gpt_model.fitDataset(your_generator_dataset, { epochs: 1 });
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
## Jest Unit Testing
|
|
50
|
+
|
|
51
|
+
If you plan to use this library in your Jest unit tests, you may need to add the following settings to the `config` object in your `jest.config.ts` file:
|
|
52
|
+
|
|
53
|
+
```ts
|
|
54
|
+
// A map from regular expressions to paths to transformers
|
|
55
|
+
transform: {
|
|
56
|
+
'^.+\\.[jt]s?$': ["ts-jest", {
|
|
57
|
+
useESM: true,
|
|
58
|
+
}]
|
|
59
|
+
},
|
|
60
|
+
|
|
61
|
+
transformIgnorePatterns: [
|
|
62
|
+
"/node_modules/(?!(@stellarapp/tfjs-stellar|@tensorflow/tfjs))"
|
|
63
|
+
],
|
|
47
64
|
```
|
package/dist/index.d.ts
CHANGED
|
@@ -2,7 +2,8 @@ export * as layers from "./layers";
|
|
|
2
2
|
export * as models from "./models";
|
|
3
3
|
export * as losses from "./losses";
|
|
4
4
|
export * as masks from "./masks";
|
|
5
|
-
export
|
|
5
|
+
export * from "./kv_cache";
|
|
6
6
|
export * as metrics from "./metrics";
|
|
7
7
|
export * as utils from "./utils";
|
|
8
|
+
export { loadUNetModel } from "./models/u_net";
|
|
8
9
|
//# sourceMappingURL=index.d.ts.map
|
package/dist/index.d.ts.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,OAAO,KAAK,KAAK,MAAM,SAAS,CAAC;AACjC,
|
|
1
|
+
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,OAAO,KAAK,KAAK,MAAM,SAAS,CAAC;AACjC,cAAc,YAAY,CAAC;AAC3B,OAAO,KAAK,OAAO,MAAM,WAAW,CAAC;AACrC,OAAO,KAAK,KAAK,MAAM,SAAS,CAAC;AACjC,OAAO,EAAE,aAAa,EAAE,MAAM,gBAAgB,CAAC"}
|
package/dist/index.js
CHANGED
|
@@ -2,7 +2,8 @@ export * as layers from "./layers";
|
|
|
2
2
|
export * as models from "./models";
|
|
3
3
|
export * as losses from "./losses";
|
|
4
4
|
export * as masks from "./masks";
|
|
5
|
-
export
|
|
5
|
+
export * from "./kv_cache";
|
|
6
6
|
export * as metrics from "./metrics";
|
|
7
7
|
export * as utils from "./utils";
|
|
8
|
+
export { loadUNetModel } from "./models/u_net";
|
|
8
9
|
//# sourceMappingURL=index.js.map
|
package/dist/index.js.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,OAAO,KAAK,KAAK,MAAM,SAAS,CAAC;AACjC,
|
|
1
|
+
{"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,OAAO,KAAK,KAAK,MAAM,SAAS,CAAC;AACjC,cAAc,YAAY,CAAC;AAC3B,OAAO,KAAK,OAAO,MAAM,WAAW,CAAC;AACrC,OAAO,KAAK,KAAK,MAAM,SAAS,CAAC;AACjC,OAAO,EAAE,aAAa,EAAE,MAAM,gBAAgB,CAAC"}
|
package/dist/kv_cache.d.ts
CHANGED
|
@@ -6,6 +6,8 @@ export interface KvCacheArgs {
|
|
|
6
6
|
headDim: number;
|
|
7
7
|
dtype?: tf.DataType;
|
|
8
8
|
}
|
|
9
|
+
export declare function kvCacheContainer(maxSequenceLength: number): KvCacheContainer;
|
|
10
|
+
export declare function kvCache(args: KvCacheArgs): KvCache;
|
|
9
11
|
/**
|
|
10
12
|
* A container for KV caches. A model should initialize one KV cache
|
|
11
13
|
*/
|
package/dist/kv_cache.d.ts.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"kv_cache.d.ts","sourceRoot":"","sources":["../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC,MAAM,WAAW,WAAW;IACxB,SAAS,EAAE,MAAM,CAAC;IAClB,iBAAiB,EAAE,MAAM,CAAC;IAC1B,QAAQ,EAAE,MAAM,CAAC;IACjB,OAAO,EAAE,MAAM,CAAC;IAChB,KAAK,CAAC,EAAE,EAAE,CAAC,QAAQ,CAAA;CACtB;AAGD;;GAEG;AACH,qBAAa,gBAAgB;IACzB,SAAS,CAAC,MAAM,uBAA8B;IAC9C,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;gBAG1B,iBAAiB,EAAE,MAAM;IAS9B,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,IAAI,EAAE,IAAI,CAAC,WAAW,EAAE,mBAAmB,CAAC;IAUtE;;OAEG;IACI,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IA4BvD,KAAK;IAOL,OAAO;IAOd,IAAW,IAAI,WAGd;IAGD,IAAW,iBAAiB,WAE3B;CACJ;AAGD,qBAAa,OAAO;IAEhB,SAAS,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;IAC7C,SAAS,CAAC,WAAW,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAA;IAG9C,SAAS,CAAC,gBAAgB,EAAE,MAAM,CAAK;IAEvC,SAAS,CAAC,UAAU,EAAE,MAAM,CAAC;IAC7B,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;IACtC,SAAS,CAAC,YAAY,EAAE,MAAM,CAAC;IAC/B,SAAS,CAAC,QAAQ,EAAE,MAAM,CAAC;gBAEf,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAiB,EAAE,EAAE,WAAW;IAa/F;;OAEG;IACI,MAAM,CAAC,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IAgClD,SAAS,CAAC,cAAc,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,EAAE,aAAa,EAAE,EAAE,CAAC,QAAQ;IAqBpE,KAAK,IAAI,IAAI;IAab,OAAO,IAAI,IAAI;IAMtB;;OAEG;IACH,IAAI,IAAI,IAAI,MAAM,CAEjB;CAEJ"}
|
|
1
|
+
{"version":3,"file":"kv_cache.d.ts","sourceRoot":"","sources":["../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC,MAAM,WAAW,WAAW;IACxB,SAAS,EAAE,MAAM,CAAC;IAClB,iBAAiB,EAAE,MAAM,CAAC;IAC1B,QAAQ,EAAE,MAAM,CAAC;IACjB,OAAO,EAAE,MAAM,CAAC;IAChB,KAAK,CAAC,EAAE,EAAE,CAAC,QAAQ,CAAA;CACtB;AAGD,wBAAgB,gBAAgB,CAAC,iBAAiB,EAAE,MAAM,oBAEzD;AAGD,wBAAgB,OAAO,CAAC,IAAI,EAAE,WAAW,WAExC;AAGD;;GAEG;AACH,qBAAa,gBAAgB;IACzB,SAAS,CAAC,MAAM,uBAA8B;IAC9C,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;gBAG1B,iBAAiB,EAAE,MAAM;IAS9B,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,IAAI,EAAE,IAAI,CAAC,WAAW,EAAE,mBAAmB,CAAC;IAUtE;;OAEG;IACI,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IA4BvD,KAAK;IAOL,OAAO;IAOd,IAAW,IAAI,WAGd;IAGD,IAAW,iBAAiB,WAE3B;CACJ;AAGD,qBAAa,OAAO;IAEhB,SAAS,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;IAC7C,SAAS,CAAC,WAAW,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAA;IAG9C,SAAS,CAAC,gBAAgB,EAAE,MAAM,CAAK;IAEvC,SAAS,CAAC,UAAU,EAAE,MAAM,CAAC;IAC7B,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;IACtC,SAAS,CAAC,YAAY,EAAE,MAAM,CAAC;IAC/B,SAAS,CAAC,QAAQ,EAAE,MAAM,CAAC;gBAEf,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAiB,EAAE,EAAE,WAAW;IAa/F;;OAEG;IACI,MAAM,CAAC,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IAgClD,SAAS,CAAC,cAAc,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,EAAE,aAAa,EAAE,EAAE,CAAC,QAAQ;IAqBpE,KAAK,IAAI,IAAI;IAab,OAAO,IAAI,IAAI;IAMtB;;OAEG;IACH,IAAI,IAAI,IAAI,MAAM,CAEjB;CAEJ"}
|
package/dist/kv_cache.js
CHANGED
|
@@ -1,4 +1,10 @@
|
|
|
1
1
|
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
export function kvCacheContainer(maxSequenceLength) {
|
|
3
|
+
return new KvCacheContainer(maxSequenceLength);
|
|
4
|
+
}
|
|
5
|
+
export function kvCache(args) {
|
|
6
|
+
return new KvCache(args);
|
|
7
|
+
}
|
|
2
8
|
/**
|
|
3
9
|
* A container for KV caches. A model should initialize one KV cache
|
|
4
10
|
*/
|
package/dist/kv_cache.js.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"kv_cache.js","sourceRoot":"","sources":["../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAYvC;;GAEG;AACH,MAAM,OAAO,gBAAgB;IACf,MAAM,GAAG,IAAI,GAAG,EAAmB,CAAC;IACpC,mBAAmB,CAAS;IAGtC,YAAY,iBAAyB;QACjC,IAAI,CAAC,iBAAiB,EAAE,CAAC;YACrB,MAAM,KAAK,CAAC,0FAA0F,MAAM,CAAC,iBAAiB,CAAC,EAAE,CAAC,CAAC;QACvI,CAAC;QAED,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;IACjD,CAAC;IAGM,MAAM,CAAC,EAAU,EAAE,IAA4C;QAClE,MAAM,SAAS,GAAG,IAAI,OAAO,CAAC;YAC1B,GAAG,IAAI;YACP,iBAAiB,EAAE,IAAI,CAAC,mBAAmB;SAC9C,CAAC,CAAC;QAEH,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,EAAE,SAAS,CAAC,CAAC;IACnC,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,EAAU,EAAE,GAAgB,EAAE,KAAkB;QAC1D,MAAM,QAAQ,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;QAErC,IAAI,CAAC,QAAQ,EAAE,CAAC;YACZ,OAAO,SAAS,CAAC;QACrB,CAAC;QAED,MAAM,EAAE,QAAQ,EAAE,UAAU,EAAE,GAAG,QAAQ,CAAC,MAAM,CAAC,GAAG,EAAE,KAAK,CAAC,CAAC;QAE7D,uEAAuE;QACvE,sEAAsE;QACtE,iBAAiB;QACjB,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,OAAO,GAAG,QAAQ,CAAC,KAAK,CAC1B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAC9E,MAAM,OAAO,GAAG,UAAU,CAAC,KAAK,CAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAEpF,OAAO;gBACH,QAAQ,EAAE,OAAO;gBACjB,UAAU,EAAE,OAAO;aACtB,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,KAAK,EAAE,CAAC;QAClB,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,OAAO,EAAE,CAAC;QACpB,CAAC,CAAC,CAAA;IACN,CAAC;IAGD,IAAW,IAAI;QACX,gFAAgF;QAChF,OAAO,IAAI,CAAC,MAAM,CAAC,OAAO,EAAE,CAAC,IAAI,EAAE,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,IAAI,CAAC,CAAC;IAC7D,CAAC;IAGD,IAAW,iBAAiB;QACxB,OAAO,IAAI,CAAC,mBAAmB,CAAC;IACpC,CAAC;CACJ;AAGD,MAAM,OAAO,OAAO;IAEN
,SAAS,CAA0B;IACnC,WAAW,CAAyB;IAE9C,uFAAuF;IAC7E,gBAAgB,GAAW,CAAC,CAAC;IAE7B,UAAU,CAAS;IACnB,mBAAmB,CAAS;IAC5B,YAAY,CAAS;IACrB,QAAQ,CAAS;IAE3B,YAAY,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAK,GAAG,SAAS,EAAe;QAC3F,MAAM,WAAW,GAAG,CAAC,SAAS,EAAE,QAAQ,EAAE,iBAAiB,EAAE,OAAO,CAAqC,CAAC;QAE1G,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAClE,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAEpE,IAAI,CAAC,UAAU,GAAG,SAAS,CAAC;QAC5B,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;QAC7C,IAAI,CAAC,YAAY,GAAG,QAAQ,CAAC;QAC7B,IAAI,CAAC,QAAQ,GAAG,OAAO,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,GAAgB,EAAE,KAAkB;QAC9C,MAAM,UAAU,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAChC,MAAM,OAAO,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAE7B,IAAI,UAAU,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC;YACvC,MAAM,KAAK,CAAC,2DAA2D;gBACnE,IAAI,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,+CAA+C,UAAU,EAAE,CAAC,CAAA;QAC/F,CAAC;QAED,IAAI,IAAI,CAAC,gBAAgB,GAAG,OAAO,GAAG,IAAI,CAAC,mBAAmB,EAAE,CAAC;YAC7D,MAAM,KAAK,CAAC,4DAA4D,IAAI,CAAC,mBAAmB,uBAAuB,CAAC,CAAC;QAC7H,CAAC;QAED,MAAM,aAAa,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;QAC/D,MAAM,eAAe,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC;QAErE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,aAAa,CAAC,CAAC;QACrC,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,eAAe,CAAC,CAAC;QAEzC,aAAa,CAAC,OAAO,EAAE,CAAC;QACxB,eAAe,CAAC,OAAO,EAAE,CAAC;QAE1B,6DAA6D;QAC7D,IAAI,CAAC,gBAAgB,IAAI,OAAO,CAAC;QAEjC,OAAO;YACH,QAAQ,EAAE,IAAI,CAAC,SAAS;YACxB,UAAU,EAAE,IAAI,CAAC,WAAW;SAC/B,CAAA;IACL,CAAC;IAGS,cAAc,CAAC,SAAsB,EAAE,aAA0B;QACvE,MAAM,OAAO,GAAG,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAEnC,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAEhB,MAAM,UAAU,GAAG,aAAa,CAAC,KAAK,CAClC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,gBAAgB,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAEhF,MAAM,MAAM,GAAG,aAAa,CAAC,KAAK,CAC9B,CAAC,CAAC,EAAE,CAAC,EAAE,IAAI,CAAC,g
BAAgB,GAAG,OAAO,EAAE,CAAC,CAAC,EAC1C,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,gBAAgB,GAAG,OAAO,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAErH,8EAA8E;YAC9E,qFAAqF;YACrF,6DAA6D;YAC7D,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,UAAU,EAAE,SAAS,EAAE,MAAM,CAAC,EAAE,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,gBAAgB,GAAG,CAAC,CAAC;QAE1B,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YACT,MAAM,eAAe,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC;YAC7C,MAAM,iBAAiB,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC;YAEjD,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,eAAe,CAAC,CAAC,CAAC;YACjD,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAC;IACP,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,SAAS,CAAC,OAAO,EAAE,CAAC;QACzB,IAAI,CAAC,WAAW,CAAC,OAAO,EAAE,CAAC;IAC/B,CAAC;IAGD;;OAEG;IACH,IAAI,IAAI;QACJ,OAAO,IAAI,CAAC,gBAAgB,CAAC;IACjC,CAAC;CAEJ"}
|
|
1
|
+
{"version":3,"file":"kv_cache.js","sourceRoot":"","sources":["../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAYvC,MAAM,UAAU,gBAAgB,CAAC,iBAAyB;IACtD,OAAO,IAAI,gBAAgB,CAAC,iBAAiB,CAAC,CAAC;AACnD,CAAC;AAGD,MAAM,UAAU,OAAO,CAAC,IAAiB;IACrC,OAAO,IAAI,OAAO,CAAC,IAAI,CAAC,CAAC;AAC7B,CAAC;AAGD;;GAEG;AACH,MAAM,OAAO,gBAAgB;IACf,MAAM,GAAG,IAAI,GAAG,EAAmB,CAAC;IACpC,mBAAmB,CAAS;IAGtC,YAAY,iBAAyB;QACjC,IAAI,CAAC,iBAAiB,EAAE,CAAC;YACrB,MAAM,KAAK,CAAC,0FAA0F,MAAM,CAAC,iBAAiB,CAAC,EAAE,CAAC,CAAC;QACvI,CAAC;QAED,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;IACjD,CAAC;IAGM,MAAM,CAAC,EAAU,EAAE,IAA4C;QAClE,MAAM,SAAS,GAAG,IAAI,OAAO,CAAC;YAC1B,GAAG,IAAI;YACP,iBAAiB,EAAE,IAAI,CAAC,mBAAmB;SAC9C,CAAC,CAAC;QAEH,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,EAAE,SAAS,CAAC,CAAC;IACnC,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,EAAU,EAAE,GAAgB,EAAE,KAAkB;QAC1D,MAAM,QAAQ,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;QAErC,IAAI,CAAC,QAAQ,EAAE,CAAC;YACZ,OAAO,SAAS,CAAC;QACrB,CAAC;QAED,MAAM,EAAE,QAAQ,EAAE,UAAU,EAAE,GAAG,QAAQ,CAAC,MAAM,CAAC,GAAG,EAAE,KAAK,CAAC,CAAC;QAE7D,uEAAuE;QACvE,sEAAsE;QACtE,iBAAiB;QACjB,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,OAAO,GAAG,QAAQ,CAAC,KAAK,CAC1B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAC9E,MAAM,OAAO,GAAG,UAAU,CAAC,KAAK,CAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAEpF,OAAO;gBACH,QAAQ,EAAE,OAAO;gBACjB,UAAU,EAAE,OAAO;aACtB,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,KAAK,EAAE,CAAC;QAClB,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,OAAO,EAAE,CAAC;QACpB,CAAC,CAAC,CAAA;IACN,CAAC;IAGD,IAAW,IAAI;QACX,gFAAgF;QAChF,OAAO,IAAI,CAAC,MAAM,CAAC,OAAO,EAA
E,CAAC,IAAI,EAAE,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,IAAI,CAAC,CAAC;IAC7D,CAAC;IAGD,IAAW,iBAAiB;QACxB,OAAO,IAAI,CAAC,mBAAmB,CAAC;IACpC,CAAC;CACJ;AAGD,MAAM,OAAO,OAAO;IAEN,SAAS,CAA0B;IACnC,WAAW,CAAyB;IAE9C,uFAAuF;IAC7E,gBAAgB,GAAW,CAAC,CAAC;IAE7B,UAAU,CAAS;IACnB,mBAAmB,CAAS;IAC5B,YAAY,CAAS;IACrB,QAAQ,CAAS;IAE3B,YAAY,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAK,GAAG,SAAS,EAAe;QAC3F,MAAM,WAAW,GAAG,CAAC,SAAS,EAAE,QAAQ,EAAE,iBAAiB,EAAE,OAAO,CAAqC,CAAC;QAE1G,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAClE,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAEpE,IAAI,CAAC,UAAU,GAAG,SAAS,CAAC;QAC5B,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;QAC7C,IAAI,CAAC,YAAY,GAAG,QAAQ,CAAC;QAC7B,IAAI,CAAC,QAAQ,GAAG,OAAO,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,GAAgB,EAAE,KAAkB;QAC9C,MAAM,UAAU,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAChC,MAAM,OAAO,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAE7B,IAAI,UAAU,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC;YACvC,MAAM,KAAK,CAAC,2DAA2D;gBACnE,IAAI,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,+CAA+C,UAAU,EAAE,CAAC,CAAA;QAC/F,CAAC;QAED,IAAI,IAAI,CAAC,gBAAgB,GAAG,OAAO,GAAG,IAAI,CAAC,mBAAmB,EAAE,CAAC;YAC7D,MAAM,KAAK,CAAC,4DAA4D,IAAI,CAAC,mBAAmB,uBAAuB,CAAC,CAAC;QAC7H,CAAC;QAED,MAAM,aAAa,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;QAC/D,MAAM,eAAe,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC;QAErE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,aAAa,CAAC,CAAC;QACrC,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,eAAe,CAAC,CAAC;QAEzC,aAAa,CAAC,OAAO,EAAE,CAAC;QACxB,eAAe,CAAC,OAAO,EAAE,CAAC;QAE1B,6DAA6D;QAC7D,IAAI,CAAC,gBAAgB,IAAI,OAAO,CAAC;QAEjC,OAAO;YACH,QAAQ,EAAE,IAAI,CAAC,SAAS;YACxB,UAAU,EAAE,IAAI,CAAC,WAAW;SAC/B,CAAA;IACL,CAAC;IAGS,cAAc,CAAC,SAAsB,EAAE,aAA0B;QACvE,MAAM,OAAO,GAAG,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAEnC,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAEhB,MAAM,UAAU,GAAG,aAAa,CAAC,KAAK,CAClC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CA
AC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,gBAAgB,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAEhF,MAAM,MAAM,GAAG,aAAa,CAAC,KAAK,CAC9B,CAAC,CAAC,EAAE,CAAC,EAAE,IAAI,CAAC,gBAAgB,GAAG,OAAO,EAAE,CAAC,CAAC,EAC1C,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,gBAAgB,GAAG,OAAO,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAErH,8EAA8E;YAC9E,qFAAqF;YACrF,6DAA6D;YAC7D,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,UAAU,EAAE,SAAS,EAAE,MAAM,CAAC,EAAE,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,gBAAgB,GAAG,CAAC,CAAC;QAE1B,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YACT,MAAM,eAAe,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC;YAC7C,MAAM,iBAAiB,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC;YAEjD,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,eAAe,CAAC,CAAC,CAAC;YACjD,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAC;IACP,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,SAAS,CAAC,OAAO,EAAE,CAAC;QACzB,IAAI,CAAC,WAAW,CAAC,OAAO,EAAE,CAAC;IAC/B,CAAC;IAGD;;OAEG;IACH,IAAI,IAAI;QACJ,OAAO,IAAI,CAAC,gBAAgB,CAAC;IACjC,CAAC;CAEJ"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"masks.test.d.ts","sourceRoot":"","sources":["../src/masks.test.ts"],"names":[],"mappings":""}
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import * as masks from "./masks";
|
|
3
|
+
tf.env().set('IS_NODE', false);
|
|
4
|
+
describe(" mask tests", () => {
|
|
5
|
+
test("packing mask for self-attention", () => {
|
|
6
|
+
const boundaries = new Int32Array([1, 0, 0, 1, 0, 0, 1, 0, 0]);
|
|
7
|
+
const packing_mask = masks.packing(boundaries);
|
|
8
|
+
const expected_mask = tf.tensor([
|
|
9
|
+
[0, 0, 0, -10000000, -10000000, -10000001, -10000000, -10000000, -10000002],
|
|
10
|
+
[0, 0, 0, -10000001, -10000001, -10000002, -10000000, -10000000, -10000001],
|
|
11
|
+
[0, 0, 0, -10000003, -10000000, -10000000, -10000002, -10000002, -10000000],
|
|
12
|
+
[-10000000, -10000002, -10000000, 0, 0, 0, -10000001, -10000001, -10000000],
|
|
13
|
+
[-10000000, -10000000, -10000001, 0, 0, 0, -10000003, -10000000, -10000001],
|
|
14
|
+
[-10000000, -10000002, -10000000, 0, 0, 0, -10000000, -10000001, -10000000],
|
|
15
|
+
[-10000000, -10000000, -10000000, -10000000, -10000000, -10000000, 0, 0, 0],
|
|
16
|
+
[-10000000, -10000002, -10000000, -10000002, -10000000, -10000000, 0, 0, 0],
|
|
17
|
+
[-10000000, -10000001, -10000000, -10000000, -10000002, -10000003, 0, 0, 0]
|
|
18
|
+
]);
|
|
19
|
+
// The mask uses -1e7 on masked positions which introduces extra integers on
|
|
20
|
+
// some values in the float32 tensor. Ideally it should check that the sum is equal to 0,
|
|
21
|
+
// but since there are 54 masked positions, we'll just check that it's less than 108
|
|
22
|
+
expect(packing_mask.sub(expected_mask).sum().arraySync()).toBeLessThan(108);
|
|
23
|
+
});
|
|
24
|
+
test("packing mask for non-packed sequence", () => {
|
|
25
|
+
const boundaries = new Int32Array([1, 0, 0, 0, 0]);
|
|
26
|
+
const packing_mask = masks.packing(boundaries);
|
|
27
|
+
expect(packing_mask.sum().arraySync()).toEqual(0);
|
|
28
|
+
});
|
|
29
|
+
test("causal mask size 4", async () => {
|
|
30
|
+
const seq_len = 4;
|
|
31
|
+
const causal_mask = masks.causal(seq_len, seq_len);
|
|
32
|
+
const _ = -1e7;
|
|
33
|
+
const expected_mask = tf.tensor([
|
|
34
|
+
[0, _, _, _],
|
|
35
|
+
[0, 0, _, _],
|
|
36
|
+
[0, 0, 0, _],
|
|
37
|
+
[0, 0, 0, 0]
|
|
38
|
+
]);
|
|
39
|
+
// this might fail due to precision issues on the masked positions,
|
|
40
|
+
// in which case compare with toBeLessThanOrEqual and a tolerance of 6 or 12 (the number of masked positions, or twice that)
|
|
41
|
+
expect((await causal_mask.sub(expected_mask).sum().data())[0]).toEqual(0);
|
|
42
|
+
});
|
|
43
|
+
test("causal mask size 5", () => {
|
|
44
|
+
const expected_mask = tf.tensor([
|
|
45
|
+
[0, -10000000, -10000000, -10000000, -10000000],
|
|
46
|
+
[0, 0, -10000000, -10000000, -10000000],
|
|
47
|
+
[0, 0, 0, -10000000, -10000000],
|
|
48
|
+
[0, 0, 0, 0, -10000000],
|
|
49
|
+
[0, 0, 0, 0, 0]
|
|
50
|
+
]);
|
|
51
|
+
const causal_mask = masks.causal(5, 5);
|
|
52
|
+
expect(causal_mask.equal(expected_mask).sum().arraySync()).toBe(25);
|
|
53
|
+
});
|
|
54
|
+
});
|
|
55
|
+
//# sourceMappingURL=masks.test.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"masks.test.js","sourceRoot":"","sources":["../src/masks.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,KAAK,KAAK,MAAM,SAAS,CAAC;AAEjC,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,aAAa,EAAE,GAAG,EAAE;IACzB,IAAI,CAAC,iCAAiC,EAAE,GAAG,EAAE;QACzC,MAAM,UAAU,GAAG,IAAI,UAAU,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC/D,MAAM,YAAY,GAAG,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC;QAE/C,MAAM,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC;YAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YAC3E,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YAC3E,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YAC3E,CAAC,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YAC3E,CAAC,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YAC3E,CAAC,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YAC3E,CAAC,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YAC3E,CAAC,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YAC3E,CAAC,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;SAC9E,CAAC,CAAC;QAEH,4EAA4E;QAC5E,yFAAyF;QACzF,oFAAoF;QACpF,MAAM,CAAE,YAAY,CAAC,GAAG,CAAC,aAAa,CAAC,CAAC,GAAG,EAAE,CAAC,SAAS,EAAa,CAAC,CAAC,YAAY,CAAC,GAAG,CAAC,CAAC;IAC5F,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,sCAAsC,EAAE,GAAG,EAAE;QAC9C,MAAM,UAAU,GAAG,IAAI,UAAU,CAAC,C
AAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACnD,MAAM,YAAY,GAAG,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC;QAE/C,MAAM,CAAE,YAAY,CAAC,GAAG,EAAE,CAAC,SAAS,EAAa,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IAClE,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,oBAAoB,EAAE,KAAK,IAAI,EAAE;QAClC,MAAM,OAAO,GAAG,CAAC,CAAC;QAClB,MAAM,WAAW,GAAG,KAAK,CAAC,MAAM,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC;QAEnD,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC;QACf,MAAM,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC;YAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;SACf,CAAC,CAAC;QAEH,mEAAmE;QACnE,uEAAuE;QACvE,MAAM,CAAC,CAAC,MAAM,WAAW,CAAC,GAAG,CAAC,aAAa,CAAC,CAAC,GAAG,EAAE,CAAC,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IAC9E,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,oBAAoB,EAAE,GAAG,EAAE;QAC5B,MAAM,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC;YAC5B,CAAC,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YAC/C,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YACvC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YAC/B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,CAAC;YACvB,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;SAAC,CAAC,CAAC;QAEtB,MAAM,WAAW,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;QAEvC,MAAM,CAAC,WAAW,CAAC,KAAK,CAAC,aAAa,CAAC,CAAC,GAAG,EAAE,CAAC,SAAS,EAAE,CAAC,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;IACxE,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAC"}
|
package/dist/models/index.d.ts
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import { LlmModel, type LlmModelArgs } from "./llm_model";
|
|
2
2
|
import { GptModel, type GptModelArgs } from "../models/gpt_model";
|
|
3
|
-
import { type UNetArgs } from "../models/u_net";
|
|
3
|
+
import { UNetModel, type UNetArgs } from "../models/u_net";
|
|
4
|
+
export { LlmModel, LlmModelArgs, GptModel, GptModelArgs, UNetModel, UNetArgs };
|
|
4
5
|
export declare function llmModel(args: LlmModelArgs): LlmModel;
|
|
5
6
|
export declare function gptModel(args: GptModelArgs): GptModel;
|
|
6
7
|
export declare function unetModel(args: UNetArgs): import("@tensorflow/tfjs-layers").LayersModel;
|
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../src/models/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,QAAQ,EAAE,KAAK,YAAY,EAAE,MAAM,aAAa,CAAC;AAC1D,OAAO,EAAE,QAAQ,EAAE,KAAK,YAAY,EAAE,MAAM,qBAAqB,CAAC;AAClE,OAAO,EAAc,KAAK,QAAQ,EAAE,MAAM,iBAAiB,CAAC;
|
|
1
|
+
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../src/models/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,QAAQ,EAAE,KAAK,YAAY,EAAE,MAAM,aAAa,CAAC;AAC1D,OAAO,EAAE,QAAQ,EAAE,KAAK,YAAY,EAAE,MAAM,qBAAqB,CAAC;AAClE,OAAO,EAAc,SAAS,EAAE,KAAK,QAAQ,EAAE,MAAM,iBAAiB,CAAC;AACvE,OAAO,EACH,QAAQ,EAAE,YAAY,EACtB,QAAQ,EAAE,YAAY,EACtB,SAAS,EAAE,QAAQ,EACtB,CAAA;AAED,wBAAgB,QAAQ,CAAC,IAAI,EAAE,YAAY,YAE1C;AAGD,wBAAgB,QAAQ,CAAC,IAAI,EAAE,YAAY,YAE1C;AAGD,wBAAgB,SAAS,CAAC,IAAI,EAAE,QAAQ,iDAEvC"}
|
package/dist/models/index.js
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import { LlmModel } from "./llm_model";
|
|
2
2
|
import { GptModel } from "../models/gpt_model";
|
|
3
|
-
import { createUNet } from "../models/u_net";
|
|
3
|
+
import { createUNet, UNetModel } from "../models/u_net";
|
|
4
|
+
export { LlmModel, GptModel, UNetModel };
|
|
4
5
|
export function llmModel(args) {
|
|
5
6
|
return new LlmModel(args);
|
|
6
7
|
}
|
package/dist/models/index.js.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"index.js","sourceRoot":"","sources":["../../src/models/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,QAAQ,EAAqB,MAAM,aAAa,CAAC;AAC1D,OAAO,EAAE,QAAQ,EAAqB,MAAM,qBAAqB,CAAC;AAClE,OAAO,EAAE,UAAU,EAAiB,MAAM,iBAAiB,CAAC;
|
|
1
|
+
{"version":3,"file":"index.js","sourceRoot":"","sources":["../../src/models/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,QAAQ,EAAqB,MAAM,aAAa,CAAC;AAC1D,OAAO,EAAE,QAAQ,EAAqB,MAAM,qBAAqB,CAAC;AAClE,OAAO,EAAE,UAAU,EAAE,SAAS,EAAiB,MAAM,iBAAiB,CAAC;AACvE,OAAO,EACH,QAAQ,EACR,QAAQ,EACR,SAAS,EACZ,CAAA;AAED,MAAM,UAAU,QAAQ,CAAC,IAAkB;IACvC,OAAO,IAAI,QAAQ,CAAC,IAAI,CAAC,CAAC;AAC9B,CAAC;AAGD,MAAM,UAAU,QAAQ,CAAC,IAAkB;IACvC,OAAO,IAAI,QAAQ,CAAC,IAAI,CAAC,CAAC;AAC9B,CAAC;AAGD,MAAM,UAAU,SAAS,CAAC,IAAc;IACpC,OAAO,UAAU,CAAC,IAAI,CAAC,CAAC;AAC5B,CAAC"}
|
package/dist/utils.test.js
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import * as tf from "@tensorflow/tfjs";
|
|
2
2
|
import { getScaleShape, getRandomCropStart } from "./utils";
|
|
3
|
-
import { causal } from "./masks";
|
|
4
3
|
// avoid TFJS node message during Jest testing
|
|
5
4
|
tf.env().set('IS_NODE', false);
|
|
6
5
|
describe("test custom TFJS utility functions", () => {
|
|
@@ -55,19 +54,5 @@ describe("test custom TFJS utility functions", () => {
|
|
|
55
54
|
const [scale8_h, scale8_w] = getScaleShape([555, 777], [333, 555]);
|
|
56
55
|
expect(scale8_h).toBeLessThan(scale8_w);
|
|
57
56
|
});
|
|
58
|
-
test("causal attention map", async () => {
|
|
59
|
-
const seq_len = 4;
|
|
60
|
-
const causal_mask = causal(seq_len, seq_len);
|
|
61
|
-
const _ = -1e7;
|
|
62
|
-
const expected_mask = tf.tensor([
|
|
63
|
-
[0, _, _, _],
|
|
64
|
-
[0, 0, _, _],
|
|
65
|
-
[0, 0, 0, _],
|
|
66
|
-
[0, 0, 0, 0]
|
|
67
|
-
]);
|
|
68
|
-
// this might fail due to precision issues on the masked positions,
|
|
69
|
-
// in which case use less <= to 6 or 12 (number of masked positions x2)
|
|
70
|
-
expect((await causal_mask.sub(expected_mask).sum().data())[0]).toEqual(0);
|
|
71
|
-
});
|
|
72
57
|
});
|
|
73
58
|
//# sourceMappingURL=utils.test.js.map
|
package/dist/utils.test.js.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"utils.test.js","sourceRoot":"","sources":["../src/utils.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,aAAa,EAAE,kBAAkB,EAAE,MAAM,SAAS,CAAC;
|
|
1
|
+
{"version":3,"file":"utils.test.js","sourceRoot":"","sources":["../src/utils.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,aAAa,EAAE,kBAAkB,EAAE,MAAM,SAAS,CAAC;AAG5D,8CAA8C;AAC9C,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,oCAAoC,EAAE,GAAG,EAAE;IAEhD,IAAI,CAAC,2DAA2D,EAAE,KAAK,IAAI,EAAE;QACzE,sCAAsC;QACtC,MAAM,QAAQ,GAAG,CAAC,GAAG,EAAE,EAAE,CAAqB,CAAC;QAC/C,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,EAAE,CAAqB,CAAC;QAElD,MAAM,CAAC,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACzE,CAAC,CAAC,CAAC;IAGH,EAAE,CAAC,6CAA6C,EAAE,KAAK,IAAI,EAAE;QACzD,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACzE,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,qBAAqB,EAAE,KAAK,IAAI,EAAE;QACnC,mCAAmC;QACnC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;YAC3B,MAAM,QAAQ,GAAG,CAAC,IAAI,EAAE,GAAG,CAAqB,CAAC;YACjD,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAEnD,MAAM,CAAC,YAAY,EAAE,YAAY,EAAE,QAAQ,CAAC,GAAG,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAA;YAExF,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;YACvE,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3E,CAAC;QAED,mCAAmC;QACnC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;YAC3B,MAAM,QAAQ,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAChD,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAEnD,MAAM,CAAC,YAAY,EAAE,YAAY,EAAE,QAAQ,CAAC,GAAG,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAA;YAExF,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;YACvE,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3E,CAAC;IACL,CAAC,CAAC,CAAC;IAGH,IAAI,CAAC,sCAAsC,EAAE,KAAK,IAAI,EAAE;QACpD,MAAM,KAAK,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACnD,MAAM,CAAC,KAAK,CA
AC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;IACtC,CAAC,CAAC,CAAC;IAGH,IAAI,CAAC,oBAAoB,EAAE,KAAK,IAAI,EAAE;QAClC,oCAAoC;QACpC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,0CAA0C;QAC1C,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,IAAI,EAAE,GAAG,CAAC,CAAC,CAAA;QACrD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,IAAI,EAAE,GAAG,CAAC,CAAC,CAAC;QAEpC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,IAAI,CAAC,CAAC,CAAA;QACrD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,IAAI,CAAC,CAAC,CAAC;QAEpC,MAAM,CAAC,QAAQ,EAAE,QAAQ,CAAC,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QAClE,MAAM,CAAC,QAAQ,CAAC,CAAC,eAAe,CAAC,QAAQ,CAAC,CAAC;QAE3C,MAAM,CAAC,QAAQ,EAAE,QAAQ,CAAC,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QAClE,MAAM,CAAC,QAAQ,CAAC,CAAC,YAAY,CAAC,QAAQ,CAAC,CAAC;IAC5C,CAAC,CAAC,CAAC;AAEP,CAAC,CAAC,CAAC"}
|
package/package.json
CHANGED
package/dist/jest.config.d.ts
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"jest.config.d.ts","sourceRoot":"","sources":["../jest.config.ts"],"names":[],"mappings":"AAAA;;;GAGG;AAEH,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,MAAM,CAAC;AAEnC,QAAA,MAAM,MAAM,EAAE,MAiMb,CAAC;AAEF,eAAe,MAAM,CAAC"}
|
package/dist/jest.config.js
DELETED
|
@@ -1,147 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* For a detailed explanation regarding each configuration property, visit:
|
|
3
|
-
* https://jestjs.io/docs/configuration
|
|
4
|
-
*/
|
|
5
|
-
const config = {
|
|
6
|
-
setupFiles: [],
|
|
7
|
-
extensionsToTreatAsEsm: [".ts"],
|
|
8
|
-
// A map from regular expressions to paths to transformers
|
|
9
|
-
transform: {
|
|
10
|
-
"^.+\.ts?$": ["ts-jest", {
|
|
11
|
-
useESM: true
|
|
12
|
-
}],
|
|
13
|
-
},
|
|
14
|
-
// An array of regexp pattern strings that are matched against all test paths, matched tests are skipped
|
|
15
|
-
// testPathIgnorePatterns: [
|
|
16
|
-
//
|
|
17
|
-
// ],
|
|
18
|
-
// A map from regular expressions to module names or to arrays of module names that allow to stub out resources with a single module
|
|
19
|
-
moduleNameMapper: {
|
|
20
|
-
"^@/(.*$)": "<rootDir>/src/$1"
|
|
21
|
-
},
|
|
22
|
-
// All imported modules in your tests should be mocked automatically
|
|
23
|
-
// automock: false,
|
|
24
|
-
// Stop running tests after `n` failures
|
|
25
|
-
// bail: 0,
|
|
26
|
-
// The directory where Jest should store its cached dependency information
|
|
27
|
-
// cacheDirectory: "/private/var/folders/8x/0jgq0fqx5qzgtm1zc8xtrdc80000gn/T/jest_dx",
|
|
28
|
-
// Automatically clear mock calls, instances, contexts and results before every test
|
|
29
|
-
clearMocks: true,
|
|
30
|
-
// Indicates whether the coverage information should be collected while executing the test
|
|
31
|
-
collectCoverage: false,
|
|
32
|
-
// An array of glob patterns indicating a set of files for which coverage information should be collected
|
|
33
|
-
// collectCoverageFrom: undefined,
|
|
34
|
-
// The directory where Jest should output its coverage files
|
|
35
|
-
coverageDirectory: "coverage",
|
|
36
|
-
// An array of regexp pattern strings used to skip coverage collection
|
|
37
|
-
// coveragePathIgnorePatterns: [
|
|
38
|
-
// "/node_modules/"
|
|
39
|
-
// ],
|
|
40
|
-
// Indicates which provider should be used to instrument code for coverage
|
|
41
|
-
// coverageProvider: "babel",
|
|
42
|
-
// A list of reporter names that Jest uses when writing coverage reports
|
|
43
|
-
// coverageReporters: [
|
|
44
|
-
// "json",
|
|
45
|
-
// "text",
|
|
46
|
-
// "lcov",
|
|
47
|
-
// "clover"
|
|
48
|
-
// ],
|
|
49
|
-
// An object that configures minimum threshold enforcement for coverage results
|
|
50
|
-
// coverageThreshold: undefined,
|
|
51
|
-
// A path to a custom dependency extractor
|
|
52
|
-
// dependencyExtractor: undefined,
|
|
53
|
-
// Make calling deprecated APIs throw helpful error messages
|
|
54
|
-
// errorOnDeprecated: false,
|
|
55
|
-
// The default configuration for fake timers
|
|
56
|
-
// fakeTimers: {
|
|
57
|
-
// "enableGlobally": false
|
|
58
|
-
// },
|
|
59
|
-
// Force coverage collection from ignored files using an array of glob patterns
|
|
60
|
-
// forceCoverageMatch: [],
|
|
61
|
-
// A path to a module which exports an async function that is triggered once before all test suites
|
|
62
|
-
// globalSetup: undefined,
|
|
63
|
-
// A path to a module which exports an async function that is triggered once after all test suites
|
|
64
|
-
// globalTeardown: undefined,
|
|
65
|
-
// A set of global variables that need to be available in all test environments
|
|
66
|
-
// globals: {},
|
|
67
|
-
// The maximum amount of workers used to run your tests. Can be specified as % or a number. E.g. maxWorkers: 10% will use 10% of your CPU amount + 1 as the maximum worker number. maxWorkers: 2 will use a maximum of 2 workers.
|
|
68
|
-
// maxWorkers: "50%",
|
|
69
|
-
// An array of file extensions your modules use
|
|
70
|
-
// moduleFileExtensions: [
|
|
71
|
-
// "js",
|
|
72
|
-
// "mjs",
|
|
73
|
-
// "cjs",
|
|
74
|
-
// "jsx",
|
|
75
|
-
// "ts",
|
|
76
|
-
// "tsx",
|
|
77
|
-
// "json",
|
|
78
|
-
// "node"
|
|
79
|
-
// ],
|
|
80
|
-
// An array of regexp pattern strings, matched against all module paths before considered 'visible' to the module loader
|
|
81
|
-
modulePathIgnorePatterns: ["<rootDir>/dist/"],
|
|
82
|
-
// Activates notifications for test results
|
|
83
|
-
// notify: false,
|
|
84
|
-
// An enum that specifies notification mode. Requires { notify: true }
|
|
85
|
-
// notifyMode: "failure-change",
|
|
86
|
-
// A preset that is used as a base for Jest's configuration
|
|
87
|
-
// preset: undefined,
|
|
88
|
-
// Run tests from one or more projects
|
|
89
|
-
// projects: undefined,
|
|
90
|
-
// Use this configuration option to add custom reporters to Jest
|
|
91
|
-
// reporters: undefined,
|
|
92
|
-
// Automatically reset mock state before every test
|
|
93
|
-
// resetMocks: false,
|
|
94
|
-
// Reset the module registry before running each individual test
|
|
95
|
-
// resetModules: false,
|
|
96
|
-
// A path to a custom resolver
|
|
97
|
-
// resolver: undefined,
|
|
98
|
-
// Automatically restore mock state and implementation before every test
|
|
99
|
-
// restoreMocks: false,
|
|
100
|
-
// The root directory that Jest should scan for tests and modules within
|
|
101
|
-
// rootDir: undefined,
|
|
102
|
-
// A list of paths to directories that Jest should use to search for files in
|
|
103
|
-
// roots: [
|
|
104
|
-
// "<rootDir>"
|
|
105
|
-
// ],
|
|
106
|
-
// Allows you to use a custom runner instead of Jest's default test runner
|
|
107
|
-
// runner: "jest-runner",
|
|
108
|
-
// The paths to modules that run some code to configure or set up the testing environment before each test
|
|
109
|
-
// A list of paths to modules that run some code to configure or set up the testing framework before each test
|
|
110
|
-
// setupFilesAfterEnv: [],
|
|
111
|
-
// The number of seconds after which a test is considered as slow and reported as such in the results.
|
|
112
|
-
// slowTestThreshold: 5,
|
|
113
|
-
// A list of paths to snapshot serializer modules Jest should use for snapshot testing
|
|
114
|
-
// snapshotSerializers: [],
|
|
115
|
-
// The test environment that will be used for testing
|
|
116
|
-
testEnvironment: "node",
|
|
117
|
-
// Options that will be passed to the testEnvironment
|
|
118
|
-
// testEnvironmentOptions: {},
|
|
119
|
-
// Adds a location field to test results
|
|
120
|
-
// testLocationInResults: false,
|
|
121
|
-
// The glob patterns Jest uses to detect test files
|
|
122
|
-
// testMatch: [
|
|
123
|
-
// "**/__tests__/**/*.[jt]s?(x)",
|
|
124
|
-
// "**/?(*.)+(spec|test).[tj]s?(x)"
|
|
125
|
-
// ],
|
|
126
|
-
// The regexp pattern or array of patterns that Jest uses to detect test files
|
|
127
|
-
// testRegex: [],
|
|
128
|
-
// This option allows the use of a custom results processor
|
|
129
|
-
// testResultsProcessor: undefined,
|
|
130
|
-
// This option allows use of a custom test runner
|
|
131
|
-
// testRunner: "jest-circus/runner",
|
|
132
|
-
// An array of regexp pattern strings that are matched against all source file paths, matched files will skip transformation
|
|
133
|
-
// transformIgnorePatterns: [
|
|
134
|
-
// "/node_modules/",
|
|
135
|
-
// "\\.pnp\\.[^\\/]+$"
|
|
136
|
-
// ],
|
|
137
|
-
// An array of regexp pattern strings that are matched against all modules before the module loader will automatically return a mock for them
|
|
138
|
-
// unmockedModulePathPatterns: undefined,
|
|
139
|
-
// Indicates whether each individual test should be reported during the run
|
|
140
|
-
// verbose: undefined,
|
|
141
|
-
// An array of regexp patterns that are matched against all source file paths before re-running tests in watch mode
|
|
142
|
-
// watchPathIgnorePatterns: [],
|
|
143
|
-
// Whether to use watchman for file crawling
|
|
144
|
-
// watchman: true,
|
|
145
|
-
};
|
|
146
|
-
export default config;
|
|
147
|
-
//# sourceMappingURL=jest.config.js.map
|
package/dist/jest.config.js.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"jest.config.js","sourceRoot":"","sources":["../jest.config.ts"],"names":[],"mappings":"AAAA;;;GAGG;AAIH,MAAM,MAAM,GAAW;IACnB,UAAU,EAAE,EAAE;IAEd,sBAAsB,EAAE,CAAC,KAAK,CAAC;IAE/B,0DAA0D;IAC1D,SAAS,EAAE;QACP,WAAW,EAAE,CAAC,SAAS,EAAE;gBACrB,MAAM,EAAE,IAAI;aACf,CAAC;KACL;IAED,wGAAwG;IACxG,4BAA4B;IAC5B,OAAO;IACP,KAAK;IAEL,oIAAoI;IACpI,gBAAgB,EAAE;QACd,UAAU,EAAE,kBAAkB;KACjC;IAED,oEAAoE;IACpE,mBAAmB;IAEnB,wCAAwC;IACxC,WAAW;IAEX,0EAA0E;IAC1E,sFAAsF;IAEtF,oFAAoF;IACpF,UAAU,EAAE,IAAI;IAEhB,0FAA0F;IAC1F,eAAe,EAAE,KAAK;IAEtB,yGAAyG;IACzG,kCAAkC;IAElC,4DAA4D;IAC5D,iBAAiB,EAAE,UAAU;IAE7B,sEAAsE;IACtE,gCAAgC;IAChC,qBAAqB;IACrB,KAAK;IAEL,0EAA0E;IAC1E,6BAA6B;IAE7B,wEAAwE;IACxE,uBAAuB;IACvB,YAAY;IACZ,YAAY;IACZ,YAAY;IACZ,aAAa;IACb,KAAK;IAEL,+EAA+E;IAC/E,gCAAgC;IAEhC,0CAA0C;IAC1C,kCAAkC;IAElC,4DAA4D;IAC5D,4BAA4B;IAE5B,4CAA4C;IAC5C,gBAAgB;IAChB,4BAA4B;IAC5B,KAAK;IAEL,+EAA+E;IAC/E,0BAA0B;IAE1B,mGAAmG;IACnG,0BAA0B;IAE1B,kGAAkG;IAClG,6BAA6B;IAE7B,+EAA+E;IAC/E,eAAe;IAEf,iOAAiO;IACjO,qBAAqB;IAErB,+CAA+C;IAC/C,0BAA0B;IAC1B,UAAU;IACV,WAAW;IACX,WAAW;IACX,WAAW;IACX,UAAU;IACV,WAAW;IACX,YAAY;IACZ,WAAW;IACX,KAAK;IAEL,wHAAwH;IACxH,wBAAwB,EAAE,CAAC,iBAAiB,CAAC;IAE7C,2CAA2C;IAC3C,iBAAiB;IAEjB,sEAAsE;IACtE,gCAAgC;IAEhC,2DAA2D;IAC3D,qBAAqB;IAErB,sCAAsC;IACtC,uBAAuB;IAEvB,gEAAgE;IAChE,wBAAwB;IAExB,mDAAmD;IACnD,qBAAqB;IAErB,gEAAgE;IAChE,uBAAuB;IAEvB,8BAA8B;IAC9B,uBAAuB;IAEvB,wEAAwE;IACxE,uBAAuB;IAEvB,wEAAwE;IACxE,sBAAsB;IAEtB,6EAA6E;IAC7E,WAAW;IACX,gBAAgB;IAChB,KAAK;IAEL,0EAA0E;IAC1E,yBAAyB;IAEzB,0GAA0G;IAE1G,8GAA8G;IAC9G,0BAA0B;IAE1B,sGAAsG;IACtG,wBAAwB;IAExB,sFAAsF;IACtF,2BAA2B;IAE3B,qDAAqD;IACrD,eAAe,EAAE,MAAM;IAEvB,qDAAqD;IACrD,8BAA8B;IAE9B,wCAAwC;IACxC,gCAAgC;IAEhC,mDAAmD;IACnD,eAAe;IACf,mCAAmC;IACnC,qCAAqC;IACrC,KAAK;IAEL,8EAA8E;IAC9E,iBAAiB;IAEjB,2DAA2D;IAC3D,mCAAmC;IAEnC,iDAAiD;IACjD,oCAAoC;IAEpC,4HAA4H;IAC5H,6BAA6B;IAC7B,sBAAsB;IACtB,wBAAwB;IACxB,KAAK;IAEL,6IAA6I;IAC7I,yCAAyC;IAEzC,2EAA2E;IAC3E,sBAAsB;IAEtB,mHAAmH;IACnH,+BAA+B;IAE/B,4CAA4C;IAC5C,kBAAkB;CACrB,CAAC;AAEF,eAAe,M
AAM,CAAC"}
|
package/dist/src/index.d.ts
DELETED
package/dist/src/index.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,cAAc,YAAY,CAAC;AAC3B,cAAc,WAAW,CAAC"}
|
package/dist/src/index.js
DELETED
package/dist/src/index.js.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"index.js","sourceRoot":"","sources":["../../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,cAAc,YAAY,CAAC;AAC3B,cAAc,WAAW,CAAC"}
|
package/dist/src/kv_cache.d.ts
DELETED
|
@@ -1,53 +0,0 @@
|
|
|
1
|
-
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
export interface KvCacheArgs {
|
|
3
|
-
batchSize: number;
|
|
4
|
-
maxSequenceLength: number;
|
|
5
|
-
numHeads: number;
|
|
6
|
-
headDim: number;
|
|
7
|
-
dtype?: tf.DataType;
|
|
8
|
-
}
|
|
9
|
-
/**
|
|
10
|
-
* A container for KV caches. A model should initialize one KV cache
|
|
11
|
-
*/
|
|
12
|
-
export declare class KvCacheContainer {
|
|
13
|
-
protected caches: Map<string, KvCache>;
|
|
14
|
-
protected max_sequence_length: number;
|
|
15
|
-
constructor(maxSequenceLength: number);
|
|
16
|
-
create(id: string, args: Omit<KvCacheArgs, "maxSequenceLength">): void;
|
|
17
|
-
/**
|
|
18
|
-
* The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
|
|
19
|
-
*/
|
|
20
|
-
update(id: string, key: tf.Tensor4D, value: tf.Tensor4D): {
|
|
21
|
-
keyCache: tf.Variable<tf.Rank.R4>;
|
|
22
|
-
valueCache: tf.Variable<tf.Rank.R4>;
|
|
23
|
-
} | undefined;
|
|
24
|
-
reset(): void;
|
|
25
|
-
dispose(): void;
|
|
26
|
-
get size(): number;
|
|
27
|
-
get maxSequenceLength(): number;
|
|
28
|
-
}
|
|
29
|
-
export declare class KvCache {
|
|
30
|
-
protected key_cache: tf.Variable<tf.Rank.R4>;
|
|
31
|
-
protected value_cache: tf.Variable<tf.Rank.R4>;
|
|
32
|
-
protected current_position: number;
|
|
33
|
-
protected batch_size: number;
|
|
34
|
-
protected max_sequence_length: number;
|
|
35
|
-
protected num_kv_heads: number;
|
|
36
|
-
protected head_dim: number;
|
|
37
|
-
constructor({ batchSize, maxSequenceLength, numHeads, headDim, dtype }: KvCacheArgs);
|
|
38
|
-
/**
|
|
39
|
-
* The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
|
|
40
|
-
*/
|
|
41
|
-
update(key: tf.Tensor4D, value: tf.Tensor4D): {
|
|
42
|
-
keyCache: tf.Variable<tf.Rank.R4>;
|
|
43
|
-
valueCache: tf.Variable<tf.Rank.R4>;
|
|
44
|
-
};
|
|
45
|
-
protected mergeIntoCache(new_value: tf.Tensor4D, current_cache: tf.Tensor4D): tf.Tensor4D;
|
|
46
|
-
reset(): void;
|
|
47
|
-
dispose(): void;
|
|
48
|
-
/**
|
|
49
|
-
* The size of the KV cache, also the number of tokens since the first one.
|
|
50
|
-
*/
|
|
51
|
-
get size(): number;
|
|
52
|
-
}
|
|
53
|
-
//# sourceMappingURL=kv_cache.d.ts.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"kv_cache.d.ts","sourceRoot":"","sources":["../../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC,MAAM,WAAW,WAAW;IACxB,SAAS,EAAE,MAAM,CAAC;IAClB,iBAAiB,EAAE,MAAM,CAAC;IAC1B,QAAQ,EAAE,MAAM,CAAC;IACjB,OAAO,EAAE,MAAM,CAAC;IAChB,KAAK,CAAC,EAAE,EAAE,CAAC,QAAQ,CAAA;CACtB;AAGD;;GAEG;AACH,qBAAa,gBAAgB;IACzB,SAAS,CAAC,MAAM,uBAA8B;IAC9C,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;gBAG1B,iBAAiB,EAAE,MAAM;IAS9B,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,IAAI,EAAE,IAAI,CAAC,WAAW,EAAE,mBAAmB,CAAC;IAUtE;;OAEG;IACI,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IA4BvD,KAAK;IAOL,OAAO;IAOd,IAAW,IAAI,WAGd;IAGD,IAAW,iBAAiB,WAE3B;CACJ;AAGD,qBAAa,OAAO;IAEhB,SAAS,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;IAC7C,SAAS,CAAC,WAAW,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAA;IAG9C,SAAS,CAAC,gBAAgB,EAAE,MAAM,CAAK;IAEvC,SAAS,CAAC,UAAU,EAAE,MAAM,CAAC;IAC7B,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;IACtC,SAAS,CAAC,YAAY,EAAE,MAAM,CAAC;IAC/B,SAAS,CAAC,QAAQ,EAAE,MAAM,CAAC;gBAEf,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAiB,EAAE,EAAE,WAAW;IAa/F;;OAEG;IACI,MAAM,CAAC,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IAgClD,SAAS,CAAC,cAAc,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,EAAE,aAAa,EAAE,EAAE,CAAC,QAAQ;IAqBpE,KAAK,IAAI,IAAI;IAab,OAAO,IAAI,IAAI;IAMtB;;OAEG;IACH,IAAI,IAAI,IAAI,MAAM,CAEjB;CAEJ"}
|