@stellarapp/tfjs-stellar 1.0.3 → 1.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134) hide show
  1. package/README.md +17 -0
  2. package/dist/index.d.ts +3 -1
  3. package/dist/index.d.ts.map +1 -1
  4. package/dist/index.js +3 -1
  5. package/dist/index.js.map +1 -1
  6. package/dist/kv_cache.d.ts +2 -0
  7. package/dist/kv_cache.d.ts.map +1 -1
  8. package/dist/kv_cache.js +6 -0
  9. package/dist/kv_cache.js.map +1 -1
  10. package/dist/models/index.d.ts +2 -1
  11. package/dist/models/index.d.ts.map +1 -1
  12. package/dist/models/index.js +2 -1
  13. package/dist/models/index.js.map +1 -1
  14. package/package.json +1 -1
  15. package/dist/jest.config.d.ts +0 -8
  16. package/dist/jest.config.d.ts.map +0 -1
  17. package/dist/jest.config.js +0 -147
  18. package/dist/jest.config.js.map +0 -1
  19. package/dist/src/index.d.ts +0 -6
  20. package/dist/src/index.d.ts.map +0 -1
  21. package/dist/src/index.js +0 -6
  22. package/dist/src/index.js.map +0 -1
  23. package/dist/src/kv_cache.d.ts +0 -53
  24. package/dist/src/kv_cache.d.ts.map +0 -1
  25. package/dist/src/kv_cache.js +0 -135
  26. package/dist/src/kv_cache.js.map +0 -1
  27. package/dist/src/layers/cached_rope_multihead_attention.d.ts +0 -31
  28. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +0 -1
  29. package/dist/src/layers/cached_rope_multihead_attention.js +0 -76
  30. package/dist/src/layers/cached_rope_multihead_attention.js.map +0 -1
  31. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +0 -2
  32. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +0 -1
  33. package/dist/src/layers/cached_rope_multihead_attention.test.js +0 -43
  34. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +0 -1
  35. package/dist/src/layers/gpt_decoder_block.d.ts +0 -34
  36. package/dist/src/layers/gpt_decoder_block.d.ts.map +0 -1
  37. package/dist/src/layers/gpt_decoder_block.js +0 -51
  38. package/dist/src/layers/gpt_decoder_block.js.map +0 -1
  39. package/dist/src/layers/index.d.ts +0 -17
  40. package/dist/src/layers/index.d.ts.map +0 -1
  41. package/dist/src/layers/index.js +0 -33
  42. package/dist/src/layers/index.js.map +0 -1
  43. package/dist/src/layers/multihead_attention.d.ts +0 -106
  44. package/dist/src/layers/multihead_attention.d.ts.map +0 -1
  45. package/dist/src/layers/multihead_attention.js +0 -269
  46. package/dist/src/layers/multihead_attention.js.map +0 -1
  47. package/dist/src/layers/multihead_attention.test.d.ts +0 -2
  48. package/dist/src/layers/multihead_attention.test.d.ts.map +0 -1
  49. package/dist/src/layers/multihead_attention.test.js +0 -160
  50. package/dist/src/layers/multihead_attention.test.js.map +0 -1
  51. package/dist/src/layers/positional_encoding.d.ts +0 -37
  52. package/dist/src/layers/positional_encoding.d.ts.map +0 -1
  53. package/dist/src/layers/positional_encoding.js +0 -115
  54. package/dist/src/layers/positional_encoding.js.map +0 -1
  55. package/dist/src/layers/positional_encoding.test.d.ts +0 -2
  56. package/dist/src/layers/positional_encoding.test.d.ts.map +0 -1
  57. package/dist/src/layers/positional_encoding.test.js +0 -95
  58. package/dist/src/layers/positional_encoding.test.js.map +0 -1
  59. package/dist/src/layers/rotary_position_embedding.d.ts +0 -39
  60. package/dist/src/layers/rotary_position_embedding.d.ts.map +0 -1
  61. package/dist/src/layers/rotary_position_embedding.js +0 -99
  62. package/dist/src/layers/rotary_position_embedding.js.map +0 -1
  63. package/dist/src/layers/rotary_position_embedding.test.d.ts +0 -2
  64. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +0 -1
  65. package/dist/src/layers/rotary_position_embedding.test.js +0 -88
  66. package/dist/src/layers/rotary_position_embedding.test.js.map +0 -1
  67. package/dist/src/layers/token_and_positional_embedding.d.ts +0 -47
  68. package/dist/src/layers/token_and_positional_embedding.d.ts.map +0 -1
  69. package/dist/src/layers/token_and_positional_embedding.js +0 -109
  70. package/dist/src/layers/token_and_positional_embedding.js.map +0 -1
  71. package/dist/src/layers/token_and_positional_embedding.test.d.ts +0 -2
  72. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +0 -1
  73. package/dist/src/layers/token_and_positional_embedding.test.js +0 -58
  74. package/dist/src/layers/token_and_positional_embedding.test.js.map +0 -1
  75. package/dist/src/layers/transformer_decoder.d.ts +0 -69
  76. package/dist/src/layers/transformer_decoder.d.ts.map +0 -1
  77. package/dist/src/layers/transformer_decoder.js +0 -182
  78. package/dist/src/layers/transformer_decoder.js.map +0 -1
  79. package/dist/src/layers/transformer_decoder.test.d.ts +0 -2
  80. package/dist/src/layers/transformer_decoder.test.d.ts.map +0 -1
  81. package/dist/src/layers/transformer_decoder.test.js +0 -72
  82. package/dist/src/layers/transformer_decoder.test.js.map +0 -1
  83. package/dist/src/layers/transformer_encoder.d.ts +0 -55
  84. package/dist/src/layers/transformer_encoder.d.ts.map +0 -1
  85. package/dist/src/layers/transformer_encoder.js +0 -175
  86. package/dist/src/layers/transformer_encoder.js.map +0 -1
  87. package/dist/src/layers/transformer_encoder.test.d.ts +0 -2
  88. package/dist/src/layers/transformer_encoder.test.d.ts.map +0 -1
  89. package/dist/src/layers/transformer_encoder.test.js +0 -58
  90. package/dist/src/layers/transformer_encoder.test.js.map +0 -1
  91. package/dist/src/losses/dice.d.ts +0 -30
  92. package/dist/src/losses/dice.d.ts.map +0 -1
  93. package/dist/src/losses/dice.js +0 -93
  94. package/dist/src/losses/dice.js.map +0 -1
  95. package/dist/src/losses/index.d.ts +0 -2
  96. package/dist/src/losses/index.d.ts.map +0 -1
  97. package/dist/src/losses/index.js +0 -2
  98. package/dist/src/losses/index.js.map +0 -1
  99. package/dist/src/masks.d.ts +0 -20
  100. package/dist/src/masks.d.ts.map +0 -1
  101. package/dist/src/masks.js +0 -37
  102. package/dist/src/masks.js.map +0 -1
  103. package/dist/src/metrics.d.ts +0 -20
  104. package/dist/src/metrics.d.ts.map +0 -1
  105. package/dist/src/metrics.js +0 -28
  106. package/dist/src/metrics.js.map +0 -1
  107. package/dist/src/models/gpt_model.d.ts +0 -94
  108. package/dist/src/models/gpt_model.d.ts.map +0 -1
  109. package/dist/src/models/gpt_model.js +0 -154
  110. package/dist/src/models/gpt_model.js.map +0 -1
  111. package/dist/src/models/index.d.ts +0 -3
  112. package/dist/src/models/index.d.ts.map +0 -1
  113. package/dist/src/models/index.js +0 -3
  114. package/dist/src/models/index.js.map +0 -1
  115. package/dist/src/models/llm_model.d.ts +0 -87
  116. package/dist/src/models/llm_model.d.ts.map +0 -1
  117. package/dist/src/models/llm_model.js +0 -245
  118. package/dist/src/models/llm_model.js.map +0 -1
  119. package/dist/src/models/u_net.d.ts +0 -40
  120. package/dist/src/models/u_net.d.ts.map +0 -1
  121. package/dist/src/models/u_net.js +0 -151
  122. package/dist/src/models/u_net.js.map +0 -1
  123. package/dist/src/tfjs_types.d.ts +0 -10
  124. package/dist/src/tfjs_types.d.ts.map +0 -1
  125. package/dist/src/tfjs_types.js +0 -2
  126. package/dist/src/tfjs_types.js.map +0 -1
  127. package/dist/src/utils.d.ts +0 -28
  128. package/dist/src/utils.d.ts.map +0 -1
  129. package/dist/src/utils.js +0 -63
  130. package/dist/src/utils.js.map +0 -1
  131. package/dist/src/utils.test.d.ts +0 -2
  132. package/dist/src/utils.test.d.ts.map +0 -1
  133. package/dist/src/utils.test.js +0 -73
  134. package/dist/src/utils.test.js.map +0 -1
package/README.md CHANGED
@@ -44,4 +44,21 @@ gpt_model.summary();
44
44
  // see https://js.tensorflow.org/api/latest/#data.generator
45
45
  // on how to create a generator dataset
46
46
  //gpt_model.fitDataset(your_generator_dataset, { epochs: 1 });
47
+ ```
48
+
49
+ ## Jest Unit Testing
50
+
51
+ If you plan to use this library in your Jest unit tests, you may need to add the following configurations to your `jest.config.ts` file's `config`
52
+
53
+ ```ts
54
+ // A map from regular expressions to paths to transformers
55
+ transform: {
56
+ '^.+\\.[jt]s?$': ["ts-jest", {
57
+ useESM: true,
58
+ }]
59
+ },
60
+
61
+ transformIgnorePatterns: [
62
+ "/node_modules/(?!(@stellarapp/tfjs-stellar|@tensorflow/tfjs))"
63
+ ],
47
64
  ```
package/dist/index.d.ts CHANGED
@@ -2,6 +2,8 @@ export * as layers from "./layers";
2
2
  export * as models from "./models";
3
3
  export * as losses from "./losses";
4
4
  export * as masks from "./masks";
5
- export { KvCache as kvCache, KvCacheContainer as kvCacheContainer } from "./kv_cache";
5
+ export * from "./kv_cache";
6
6
  export * as metrics from "./metrics";
7
+ export * as utils from "./utils";
8
+ export { loadUNetModel } from "./models/u_net";
7
9
  //# sourceMappingURL=index.d.ts.map
@@ -1 +1 @@
1
- {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,OAAO,KAAK,KAAK,MAAM,SAAS,CAAC;AACjC,OAAO,EAAE,OAAO,IAAI,OAAO,EAAE,gBAAgB,IAAI,gBAAgB,EAAE,MAAM,YAAY,CAAC;AACtF,OAAO,KAAK,OAAO,MAAM,WAAW,CAAC"}
1
+ {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,OAAO,KAAK,KAAK,MAAM,SAAS,CAAC;AACjC,cAAc,YAAY,CAAC;AAC3B,OAAO,KAAK,OAAO,MAAM,WAAW,CAAC;AACrC,OAAO,KAAK,KAAK,MAAM,SAAS,CAAC;AACjC,OAAO,EAAE,aAAa,EAAE,MAAM,gBAAgB,CAAC"}
package/dist/index.js CHANGED
@@ -2,6 +2,8 @@ export * as layers from "./layers";
2
2
  export * as models from "./models";
3
3
  export * as losses from "./losses";
4
4
  export * as masks from "./masks";
5
- export { KvCache as kvCache, KvCacheContainer as kvCacheContainer } from "./kv_cache";
5
+ export * from "./kv_cache";
6
6
  export * as metrics from "./metrics";
7
+ export * as utils from "./utils";
8
+ export { loadUNetModel } from "./models/u_net";
7
9
  //# sourceMappingURL=index.js.map
package/dist/index.js.map CHANGED
@@ -1 +1 @@
1
- {"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,OAAO,KAAK,KAAK,MAAM,SAAS,CAAC;AACjC,OAAO,EAAE,OAAO,IAAI,OAAO,EAAE,gBAAgB,IAAI,gBAAgB,EAAE,MAAM,YAAY,CAAC;AACtF,OAAO,KAAK,OAAO,MAAM,WAAW,CAAC"}
1
+ {"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,OAAO,KAAK,KAAK,MAAM,SAAS,CAAC;AACjC,cAAc,YAAY,CAAC;AAC3B,OAAO,KAAK,OAAO,MAAM,WAAW,CAAC;AACrC,OAAO,KAAK,KAAK,MAAM,SAAS,CAAC;AACjC,OAAO,EAAE,aAAa,EAAE,MAAM,gBAAgB,CAAC"}
@@ -6,6 +6,8 @@ export interface KvCacheArgs {
6
6
  headDim: number;
7
7
  dtype?: tf.DataType;
8
8
  }
9
+ export declare function kvCacheContainer(maxSequenceLength: number): KvCacheContainer;
10
+ export declare function kvCache(args: KvCacheArgs): KvCache;
9
11
  /**
10
12
  * A container for KV caches. A model should initialize one KV cache
11
13
  */
@@ -1 +1 @@
1
- {"version":3,"file":"kv_cache.d.ts","sourceRoot":"","sources":["../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC,MAAM,WAAW,WAAW;IACxB,SAAS,EAAE,MAAM,CAAC;IAClB,iBAAiB,EAAE,MAAM,CAAC;IAC1B,QAAQ,EAAE,MAAM,CAAC;IACjB,OAAO,EAAE,MAAM,CAAC;IAChB,KAAK,CAAC,EAAE,EAAE,CAAC,QAAQ,CAAA;CACtB;AAGD;;GAEG;AACH,qBAAa,gBAAgB;IACzB,SAAS,CAAC,MAAM,uBAA8B;IAC9C,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;gBAG1B,iBAAiB,EAAE,MAAM;IAS9B,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,IAAI,EAAE,IAAI,CAAC,WAAW,EAAE,mBAAmB,CAAC;IAUtE;;OAEG;IACI,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IA4BvD,KAAK;IAOL,OAAO;IAOd,IAAW,IAAI,WAGd;IAGD,IAAW,iBAAiB,WAE3B;CACJ;AAGD,qBAAa,OAAO;IAEhB,SAAS,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;IAC7C,SAAS,CAAC,WAAW,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAA;IAG9C,SAAS,CAAC,gBAAgB,EAAE,MAAM,CAAK;IAEvC,SAAS,CAAC,UAAU,EAAE,MAAM,CAAC;IAC7B,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;IACtC,SAAS,CAAC,YAAY,EAAE,MAAM,CAAC;IAC/B,SAAS,CAAC,QAAQ,EAAE,MAAM,CAAC;gBAEf,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAiB,EAAE,EAAE,WAAW;IAa/F;;OAEG;IACI,MAAM,CAAC,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IAgClD,SAAS,CAAC,cAAc,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,EAAE,aAAa,EAAE,EAAE,CAAC,QAAQ;IAqBpE,KAAK,IAAI,IAAI;IAab,OAAO,IAAI,IAAI;IAMtB;;OAEG;IACH,IAAI,IAAI,IAAI,MAAM,CAEjB;CAEJ"}
1
+ {"version":3,"file":"kv_cache.d.ts","sourceRoot":"","sources":["../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC,MAAM,WAAW,WAAW;IACxB,SAAS,EAAE,MAAM,CAAC;IAClB,iBAAiB,EAAE,MAAM,CAAC;IAC1B,QAAQ,EAAE,MAAM,CAAC;IACjB,OAAO,EAAE,MAAM,CAAC;IAChB,KAAK,CAAC,EAAE,EAAE,CAAC,QAAQ,CAAA;CACtB;AAGD,wBAAgB,gBAAgB,CAAC,iBAAiB,EAAE,MAAM,oBAEzD;AAGD,wBAAgB,OAAO,CAAC,IAAI,EAAE,WAAW,WAExC;AAGD;;GAEG;AACH,qBAAa,gBAAgB;IACzB,SAAS,CAAC,MAAM,uBAA8B;IAC9C,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;gBAG1B,iBAAiB,EAAE,MAAM;IAS9B,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,IAAI,EAAE,IAAI,CAAC,WAAW,EAAE,mBAAmB,CAAC;IAUtE;;OAEG;IACI,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IA4BvD,KAAK;IAOL,OAAO;IAOd,IAAW,IAAI,WAGd;IAGD,IAAW,iBAAiB,WAE3B;CACJ;AAGD,qBAAa,OAAO;IAEhB,SAAS,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;IAC7C,SAAS,CAAC,WAAW,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAA;IAG9C,SAAS,CAAC,gBAAgB,EAAE,MAAM,CAAK;IAEvC,SAAS,CAAC,UAAU,EAAE,MAAM,CAAC;IAC7B,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;IACtC,SAAS,CAAC,YAAY,EAAE,MAAM,CAAC;IAC/B,SAAS,CAAC,QAAQ,EAAE,MAAM,CAAC;gBAEf,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAiB,EAAE,EAAE,WAAW;IAa/F;;OAEG;IACI,MAAM,CAAC,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IAgClD,SAAS,CAAC,cAAc,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,EAAE,aAAa,EAAE,EAAE,CAAC,QAAQ;IAqBpE,KAAK,IAAI,IAAI;IAab,OAAO,IAAI,IAAI;IAMtB;;OAEG;IACH,IAAI,IAAI,IAAI,MAAM,CAEjB;CAEJ"}
package/dist/kv_cache.js CHANGED
@@ -1,4 +1,10 @@
1
1
  import * as tf from "@tensorflow/tfjs";
2
+ export function kvCacheContainer(maxSequenceLength) {
3
+ return new KvCacheContainer(maxSequenceLength);
4
+ }
5
+ export function kvCache(args) {
6
+ return new KvCache(args);
7
+ }
2
8
  /**
3
9
  * A container for KV caches. A model should initialize one KV cache
4
10
  */
@@ -1 +1 @@
1
- {"version":3,"file":"kv_cache.js","sourceRoot":"","sources":["../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAYvC;;GAEG;AACH,MAAM,OAAO,gBAAgB;IACf,MAAM,GAAG,IAAI,GAAG,EAAmB,CAAC;IACpC,mBAAmB,CAAS;IAGtC,YAAY,iBAAyB;QACjC,IAAI,CAAC,iBAAiB,EAAE,CAAC;YACrB,MAAM,KAAK,CAAC,0FAA0F,MAAM,CAAC,iBAAiB,CAAC,EAAE,CAAC,CAAC;QACvI,CAAC;QAED,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;IACjD,CAAC;IAGM,MAAM,CAAC,EAAU,EAAE,IAA4C;QAClE,MAAM,SAAS,GAAG,IAAI,OAAO,CAAC;YAC1B,GAAG,IAAI;YACP,iBAAiB,EAAE,IAAI,CAAC,mBAAmB;SAC9C,CAAC,CAAC;QAEH,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,EAAE,SAAS,CAAC,CAAC;IACnC,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,EAAU,EAAE,GAAgB,EAAE,KAAkB;QAC1D,MAAM,QAAQ,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;QAErC,IAAI,CAAC,QAAQ,EAAE,CAAC;YACZ,OAAO,SAAS,CAAC;QACrB,CAAC;QAED,MAAM,EAAE,QAAQ,EAAE,UAAU,EAAE,GAAG,QAAQ,CAAC,MAAM,CAAC,GAAG,EAAE,KAAK,CAAC,CAAC;QAE7D,uEAAuE;QACvE,sEAAsE;QACtE,iBAAiB;QACjB,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,OAAO,GAAG,QAAQ,CAAC,KAAK,CAC1B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAC9E,MAAM,OAAO,GAAG,UAAU,CAAC,KAAK,CAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAEpF,OAAO;gBACH,QAAQ,EAAE,OAAO;gBACjB,UAAU,EAAE,OAAO;aACtB,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,KAAK,EAAE,CAAC;QAClB,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,OAAO,EAAE,CAAC;QACpB,CAAC,CAAC,CAAA;IACN,CAAC;IAGD,IAAW,IAAI;QACX,gFAAgF;QAChF,OAAO,IAAI,CAAC,MAAM,CAAC,OAAO,EAAE,CAAC,IAAI,EAAE,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,IAAI,CAAC,CAAC;IAC7D,CAAC;IAGD,IAAW,iBAAiB;QACxB,OAAO,IAAI,CAAC,mBAAmB,CAAC;IACpC,CAAC;CACJ;AAGD,MAAM,OAAO,OAAO;IAEN,SAAS,CAA0B;IACnC,WAAW,CAAyB;IAE9C,uFAAuF;IAC7E,gBAAgB,GAAW,CAAC,CAAC;IAE7B,UAAU,CAAS;IACnB,mBAAmB,CAAS;IAC5B,YAAY,CAAS;IACrB,QAAQ,CAAS;IAE3B,YAAY,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAK,GAAG,SAAS,EAAe;QAC3F,MAAM,WAAW,GAAG,CAAC,SAAS,EAAE,QAAQ,EAAE,iBAAiB,EAAE,OAAO,CAAqC,CAAC;QAE1G,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAClE,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAEpE,IAAI,CAAC,UAAU,GAAG,SAAS,CAAC;QAC5B,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;QAC7C,IAAI,CAAC,YAAY,GAAG,QAAQ,CAAC;QAC7B,IAAI,CAAC,QAAQ,GAAG,OAAO,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,GAAgB,EAAE,KAAkB;QAC9C,MAAM,UAAU,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAChC,MAAM,OAAO,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAE7B,IAAI,UAAU,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC;YACvC,MAAM,KAAK,CAAC,2DAA2D;gBACnE,IAAI,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,+CAA+C,UAAU,EAAE,CAAC,CAAA;QAC/F,CAAC;QAED,IAAI,IAAI,CAAC,gBAAgB,GAAG,OAAO,GAAG,IAAI,CAAC,mBAAmB,EAAE,CAAC;YAC7D,MAAM,KAAK,CAAC,4DAA4D,IAAI,CAAC,mBAAmB,uBAAuB,CAAC,CAAC;QAC7H,CAAC;QAED,MAAM,aAAa,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;QAC/D,MAAM,eAAe,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC;QAErE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,aAAa,CAAC,CAAC;QACrC,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,eAAe,CAAC,CAAC;QAEzC,aAAa,CAAC,OAAO,EAAE,CAAC;QACxB,eAAe,CAAC,OAAO,EAAE,CAAC;QAE1B,6DAA6D;QAC7D,IAAI,CAAC,gBAAgB,IAAI,OAAO,CAAC;QAEjC,OAAO;YACH,QAAQ,EAAE,IAAI,CAAC,SAAS;YACxB,UAAU,EAAE,IAAI,CAAC,WAAW;SAC/B,CAAA;IACL,CAAC;IAGS,cAAc,CAAC,SAAsB,EAAE,aAA0B;QACvE,MAAM,OAAO,GAAG,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAEnC,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAEhB,MAAM,UAAU,GAAG,aAAa,CAAC,KAAK,CAClC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,gBAAgB,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAEhF,MAAM,MAAM,GAAG,aAAa,CAAC,KAAK,CAC9B,CAAC,CAAC,EAAE,CAAC,EAAE,IAAI,CAAC,gBAAgB,GAAG,OAAO,EAAE,CAAC,CAAC,EAC1C,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,gBAAgB,GAAG,OAAO,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAErH,8EAA8E;YAC9E,qFAAqF;YACrF,6DAA6D;YAC7D,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,UAAU,EAAE,SAAS,EAAE,MAAM,CAAC,EAAE,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,gBAAgB,GAAG,CAAC,CAAC;QAE1B,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YACT,MAAM,eAAe,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC;YAC7C,MAAM,iBAAiB,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC;YAEjD,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,eAAe,CAAC,CAAC,CAAC;YACjD,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAC;IACP,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,SAAS,CAAC,OAAO,EAAE,CAAC;QACzB,IAAI,CAAC,WAAW,CAAC,OAAO,EAAE,CAAC;IAC/B,CAAC;IAGD;;OAEG;IACH,IAAI,IAAI;QACJ,OAAO,IAAI,CAAC,gBAAgB,CAAC;IACjC,CAAC;CAEJ"}
1
+ {"version":3,"file":"kv_cache.js","sourceRoot":"","sources":["../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAYvC,MAAM,UAAU,gBAAgB,CAAC,iBAAyB;IACtD,OAAO,IAAI,gBAAgB,CAAC,iBAAiB,CAAC,CAAC;AACnD,CAAC;AAGD,MAAM,UAAU,OAAO,CAAC,IAAiB;IACrC,OAAO,IAAI,OAAO,CAAC,IAAI,CAAC,CAAC;AAC7B,CAAC;AAGD;;GAEG;AACH,MAAM,OAAO,gBAAgB;IACf,MAAM,GAAG,IAAI,GAAG,EAAmB,CAAC;IACpC,mBAAmB,CAAS;IAGtC,YAAY,iBAAyB;QACjC,IAAI,CAAC,iBAAiB,EAAE,CAAC;YACrB,MAAM,KAAK,CAAC,0FAA0F,MAAM,CAAC,iBAAiB,CAAC,EAAE,CAAC,CAAC;QACvI,CAAC;QAED,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;IACjD,CAAC;IAGM,MAAM,CAAC,EAAU,EAAE,IAA4C;QAClE,MAAM,SAAS,GAAG,IAAI,OAAO,CAAC;YAC1B,GAAG,IAAI;YACP,iBAAiB,EAAE,IAAI,CAAC,mBAAmB;SAC9C,CAAC,CAAC;QAEH,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,EAAE,SAAS,CAAC,CAAC;IACnC,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,EAAU,EAAE,GAAgB,EAAE,KAAkB;QAC1D,MAAM,QAAQ,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;QAErC,IAAI,CAAC,QAAQ,EAAE,CAAC;YACZ,OAAO,SAAS,CAAC;QACrB,CAAC;QAED,MAAM,EAAE,QAAQ,EAAE,UAAU,EAAE,GAAG,QAAQ,CAAC,MAAM,CAAC,GAAG,EAAE,KAAK,CAAC,CAAC;QAE7D,uEAAuE;QACvE,sEAAsE;QACtE,iBAAiB;QACjB,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,OAAO,GAAG,QAAQ,CAAC,KAAK,CAC1B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAC9E,MAAM,OAAO,GAAG,UAAU,CAAC,KAAK,CAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAEpF,OAAO;gBACH,QAAQ,EAAE,OAAO;gBACjB,UAAU,EAAE,OAAO;aACtB,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,KAAK,EAAE,CAAC;QAClB,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,OAAO,EAAE,CAAC;QACpB,CAAC,CAAC,CAAA;IACN,CAAC;IAGD,IAAW,IAAI;QACX,gFAAgF;QAChF,OAAO,IAAI,CAAC,MAAM,CAAC,OAAO,EAAE,CAAC,IAAI,EAAE,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,IAAI,CAAC,CAAC;IAC7D,CAAC;IAGD,IAAW,iBAAiB;QACxB,OAAO,IAAI,CAAC,mBAAmB,CAAC;IACpC,CAAC;CACJ;AAGD,MAAM,OAAO,OAAO;IAEN,SAAS,CAA0B;IACnC,WAAW,CAAyB;IAE9C,uFAAuF;IAC7E,gBAAgB,GAAW,CAAC,CAAC;IAE7B,UAAU,CAAS;IACnB,mBAAmB,CAAS;IAC5B,YAAY,CAAS;IACrB,QAAQ,CAAS;IAE3B,YAAY,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAK,GAAG,SAAS,EAAe;QAC3F,MAAM,WAAW,GAAG,CAAC,SAAS,EAAE,QAAQ,EAAE,iBAAiB,EAAE,OAAO,CAAqC,CAAC;QAE1G,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAClE,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAEpE,IAAI,CAAC,UAAU,GAAG,SAAS,CAAC;QAC5B,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;QAC7C,IAAI,CAAC,YAAY,GAAG,QAAQ,CAAC;QAC7B,IAAI,CAAC,QAAQ,GAAG,OAAO,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,GAAgB,EAAE,KAAkB;QAC9C,MAAM,UAAU,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAChC,MAAM,OAAO,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAE7B,IAAI,UAAU,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC;YACvC,MAAM,KAAK,CAAC,2DAA2D;gBACnE,IAAI,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,+CAA+C,UAAU,EAAE,CAAC,CAAA;QAC/F,CAAC;QAED,IAAI,IAAI,CAAC,gBAAgB,GAAG,OAAO,GAAG,IAAI,CAAC,mBAAmB,EAAE,CAAC;YAC7D,MAAM,KAAK,CAAC,4DAA4D,IAAI,CAAC,mBAAmB,uBAAuB,CAAC,CAAC;QAC7H,CAAC;QAED,MAAM,aAAa,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;QAC/D,MAAM,eAAe,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC;QAErE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,aAAa,CAAC,CAAC;QACrC,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,eAAe,CAAC,CAAC;QAEzC,aAAa,CAAC,OAAO,EAAE,CAAC;QACxB,eAAe,CAAC,OAAO,EAAE,CAAC;QAE1B,6DAA6D;QAC7D,IAAI,CAAC,gBAAgB,IAAI,OAAO,CAAC;QAEjC,OAAO;YACH,QAAQ,EAAE,IAAI,CAAC,SAAS;YACxB,UAAU,EAAE,IAAI,CAAC,WAAW;SAC/B,CAAA;IACL,CAAC;IAGS,cAAc,CAAC,SAAsB,EAAE,aAA0B;QACvE,MAAM,OAAO,GAAG,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAEnC,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAEhB,MAAM,UAAU,GAAG,aAAa,CAAC,KAAK,CAClC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,gBAAgB,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAEhF,MAAM,MAAM,GAAG,aAAa,CAAC,KAAK,CAC9B,CAAC,CAAC,EAAE,CAAC,EAAE,IAAI,CAAC,gBAAgB,GAAG,OAAO,EAAE,CAAC,CAAC,EAC1C,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,gBAAgB,GAAG,OAAO,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAErH,8EAA8E;YAC9E,qFAAqF;YACrF,6DAA6D;YAC7D,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,UAAU,EAAE,SAAS,EAAE,MAAM,CAAC,EAAE,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,gBAAgB,GAAG,CAAC,CAAC;QAE1B,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YACT,MAAM,eAAe,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC;YAC7C,MAAM,iBAAiB,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC;YAEjD,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,eAAe,CAAC,CAAC,CAAC;YACjD,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAC;IACP,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,SAAS,CAAC,OAAO,EAAE,CAAC;QACzB,IAAI,CAAC,WAAW,CAAC,OAAO,EAAE,CAAC;IAC/B,CAAC;IAGD;;OAEG;IACH,IAAI,IAAI;QACJ,OAAO,IAAI,CAAC,gBAAgB,CAAC;IACjC,CAAC;CAEJ"}
@@ -1,6 +1,7 @@
1
1
  import { LlmModel, type LlmModelArgs } from "./llm_model";
2
2
  import { GptModel, type GptModelArgs } from "../models/gpt_model";
3
- import { type UNetArgs } from "../models/u_net";
3
+ import { UNetModel, type UNetArgs } from "../models/u_net";
4
+ export { LlmModel, LlmModelArgs, GptModel, GptModelArgs, UNetModel, UNetArgs };
4
5
  export declare function llmModel(args: LlmModelArgs): LlmModel;
5
6
  export declare function gptModel(args: GptModelArgs): GptModel;
6
7
  export declare function unetModel(args: UNetArgs): import("@tensorflow/tfjs-layers").LayersModel;
@@ -1 +1 @@
1
- {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../src/models/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,QAAQ,EAAE,KAAK,YAAY,EAAE,MAAM,aAAa,CAAC;AAC1D,OAAO,EAAE,QAAQ,EAAE,KAAK,YAAY,EAAE,MAAM,qBAAqB,CAAC;AAClE,OAAO,EAAc,KAAK,QAAQ,EAAE,MAAM,iBAAiB,CAAC;AAG5D,wBAAgB,QAAQ,CAAC,IAAI,EAAE,YAAY,YAE1C;AAGD,wBAAgB,QAAQ,CAAC,IAAI,EAAE,YAAY,YAE1C;AAGD,wBAAgB,SAAS,CAAC,IAAI,EAAE,QAAQ,iDAEvC"}
1
+ {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../src/models/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,QAAQ,EAAE,KAAK,YAAY,EAAE,MAAM,aAAa,CAAC;AAC1D,OAAO,EAAE,QAAQ,EAAE,KAAK,YAAY,EAAE,MAAM,qBAAqB,CAAC;AAClE,OAAO,EAAc,SAAS,EAAE,KAAK,QAAQ,EAAE,MAAM,iBAAiB,CAAC;AACvE,OAAO,EACH,QAAQ,EAAE,YAAY,EACtB,QAAQ,EAAE,YAAY,EACtB,SAAS,EAAE,QAAQ,EACtB,CAAA;AAED,wBAAgB,QAAQ,CAAC,IAAI,EAAE,YAAY,YAE1C;AAGD,wBAAgB,QAAQ,CAAC,IAAI,EAAE,YAAY,YAE1C;AAGD,wBAAgB,SAAS,CAAC,IAAI,EAAE,QAAQ,iDAEvC"}
@@ -1,6 +1,7 @@
1
1
  import { LlmModel } from "./llm_model";
2
2
  import { GptModel } from "../models/gpt_model";
3
- import { createUNet } from "../models/u_net";
3
+ import { createUNet, UNetModel } from "../models/u_net";
4
+ export { LlmModel, GptModel, UNetModel };
4
5
  export function llmModel(args) {
5
6
  return new LlmModel(args);
6
7
  }
@@ -1 +1 @@
1
- {"version":3,"file":"index.js","sourceRoot":"","sources":["../../src/models/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,QAAQ,EAAqB,MAAM,aAAa,CAAC;AAC1D,OAAO,EAAE,QAAQ,EAAqB,MAAM,qBAAqB,CAAC;AAClE,OAAO,EAAE,UAAU,EAAiB,MAAM,iBAAiB,CAAC;AAG5D,MAAM,UAAU,QAAQ,CAAC,IAAkB;IACvC,OAAO,IAAI,QAAQ,CAAC,IAAI,CAAC,CAAC;AAC9B,CAAC;AAGD,MAAM,UAAU,QAAQ,CAAC,IAAkB;IACvC,OAAO,IAAI,QAAQ,CAAC,IAAI,CAAC,CAAC;AAC9B,CAAC;AAGD,MAAM,UAAU,SAAS,CAAC,IAAc;IACpC,OAAO,UAAU,CAAC,IAAI,CAAC,CAAC;AAC5B,CAAC"}
1
+ {"version":3,"file":"index.js","sourceRoot":"","sources":["../../src/models/index.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,QAAQ,EAAqB,MAAM,aAAa,CAAC;AAC1D,OAAO,EAAE,QAAQ,EAAqB,MAAM,qBAAqB,CAAC;AAClE,OAAO,EAAE,UAAU,EAAE,SAAS,EAAiB,MAAM,iBAAiB,CAAC;AACvE,OAAO,EACH,QAAQ,EACR,QAAQ,EACR,SAAS,EACZ,CAAA;AAED,MAAM,UAAU,QAAQ,CAAC,IAAkB;IACvC,OAAO,IAAI,QAAQ,CAAC,IAAI,CAAC,CAAC;AAC9B,CAAC;AAGD,MAAM,UAAU,QAAQ,CAAC,IAAkB;IACvC,OAAO,IAAI,QAAQ,CAAC,IAAI,CAAC,CAAC;AAC9B,CAAC;AAGD,MAAM,UAAU,SAAS,CAAC,IAAc;IACpC,OAAO,UAAU,CAAC,IAAI,CAAC,CAAC;AAC5B,CAAC"}
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@stellarapp/tfjs-stellar",
3
- "version": "1.0.3",
3
+ "version": "1.0.5",
4
4
  "description": "An extension of TensorFlow.js for implementing large language models.",
5
5
  "license": "ISC",
6
6
  "author": "",
@@ -1,8 +0,0 @@
1
- /**
2
- * For a detailed explanation regarding each configuration property, visit:
3
- * https://jestjs.io/docs/configuration
4
- */
5
- import type { Config } from 'jest';
6
- declare const config: Config;
7
- export default config;
8
- //# sourceMappingURL=jest.config.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"jest.config.d.ts","sourceRoot":"","sources":["../jest.config.ts"],"names":[],"mappings":"AAAA;;;GAGG;AAEH,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,MAAM,CAAC;AAEnC,QAAA,MAAM,MAAM,EAAE,MAiMb,CAAC;AAEF,eAAe,MAAM,CAAC"}
@@ -1,147 +0,0 @@
1
- /**
2
- * For a detailed explanation regarding each configuration property, visit:
3
- * https://jestjs.io/docs/configuration
4
- */
5
- const config = {
6
- setupFiles: [],
7
- extensionsToTreatAsEsm: [".ts"],
8
- // A map from regular expressions to paths to transformers
9
- transform: {
10
- "^.+\.ts?$": ["ts-jest", {
11
- useESM: true
12
- }],
13
- },
14
- // An array of regexp pattern strings that are matched against all test paths, matched tests are skipped
15
- // testPathIgnorePatterns: [
16
- //
17
- // ],
18
- // A map from regular expressions to module names or to arrays of module names that allow to stub out resources with a single module
19
- moduleNameMapper: {
20
- "^@/(.*$)": "<rootDir>/src/$1"
21
- },
22
- // All imported modules in your tests should be mocked automatically
23
- // automock: false,
24
- // Stop running tests after `n` failures
25
- // bail: 0,
26
- // The directory where Jest should store its cached dependency information
27
- // cacheDirectory: "/private/var/folders/8x/0jgq0fqx5qzgtm1zc8xtrdc80000gn/T/jest_dx",
28
- // Automatically clear mock calls, instances, contexts and results before every test
29
- clearMocks: true,
30
- // Indicates whether the coverage information should be collected while executing the test
31
- collectCoverage: false,
32
- // An array of glob patterns indicating a set of files for which coverage information should be collected
33
- // collectCoverageFrom: undefined,
34
- // The directory where Jest should output its coverage files
35
- coverageDirectory: "coverage",
36
- // An array of regexp pattern strings used to skip coverage collection
37
- // coveragePathIgnorePatterns: [
38
- // "/node_modules/"
39
- // ],
40
- // Indicates which provider should be used to instrument code for coverage
41
- // coverageProvider: "babel",
42
- // A list of reporter names that Jest uses when writing coverage reports
43
- // coverageReporters: [
44
- // "json",
45
- // "text",
46
- // "lcov",
47
- // "clover"
48
- // ],
49
- // An object that configures minimum threshold enforcement for coverage results
50
- // coverageThreshold: undefined,
51
- // A path to a custom dependency extractor
52
- // dependencyExtractor: undefined,
53
- // Make calling deprecated APIs throw helpful error messages
54
- // errorOnDeprecated: false,
55
- // The default configuration for fake timers
56
- // fakeTimers: {
57
- // "enableGlobally": false
58
- // },
59
- // Force coverage collection from ignored files using an array of glob patterns
60
- // forceCoverageMatch: [],
61
- // A path to a module which exports an async function that is triggered once before all test suites
62
- // globalSetup: undefined,
63
- // A path to a module which exports an async function that is triggered once after all test suites
64
- // globalTeardown: undefined,
65
- // A set of global variables that need to be available in all test environments
66
- // globals: {},
67
- // The maximum amount of workers used to run your tests. Can be specified as % or a number. E.g. maxWorkers: 10% will use 10% of your CPU amount + 1 as the maximum worker number. maxWorkers: 2 will use a maximum of 2 workers.
68
- // maxWorkers: "50%",
69
- // An array of file extensions your modules use
70
- // moduleFileExtensions: [
71
- // "js",
72
- // "mjs",
73
- // "cjs",
74
- // "jsx",
75
- // "ts",
76
- // "tsx",
77
- // "json",
78
- // "node"
79
- // ],
80
- // An array of regexp pattern strings, matched against all module paths before considered 'visible' to the module loader
81
- modulePathIgnorePatterns: ["<rootDir>/dist/"],
82
- // Activates notifications for test results
83
- // notify: false,
84
- // An enum that specifies notification mode. Requires { notify: true }
85
- // notifyMode: "failure-change",
86
- // A preset that is used as a base for Jest's configuration
87
- // preset: undefined,
88
- // Run tests from one or more projects
89
- // projects: undefined,
90
- // Use this configuration option to add custom reporters to Jest
91
- // reporters: undefined,
92
- // Automatically reset mock state before every test
93
- // resetMocks: false,
94
- // Reset the module registry before running each individual test
95
- // resetModules: false,
96
- // A path to a custom resolver
97
- // resolver: undefined,
98
- // Automatically restore mock state and implementation before every test
99
- // restoreMocks: false,
100
- // The root directory that Jest should scan for tests and modules within
101
- // rootDir: undefined,
102
- // A list of paths to directories that Jest should use to search for files in
103
- // roots: [
104
- // "<rootDir>"
105
- // ],
106
- // Allows you to use a custom runner instead of Jest's default test runner
107
- // runner: "jest-runner",
108
- // The paths to modules that run some code to configure or set up the testing environment before each test
109
- // A list of paths to modules that run some code to configure or set up the testing framework before each test
110
- // setupFilesAfterEnv: [],
111
- // The number of seconds after which a test is considered as slow and reported as such in the results.
112
- // slowTestThreshold: 5,
113
- // A list of paths to snapshot serializer modules Jest should use for snapshot testing
114
- // snapshotSerializers: [],
115
- // The test environment that will be used for testing
116
- testEnvironment: "node",
117
- // Options that will be passed to the testEnvironment
118
- // testEnvironmentOptions: {},
119
- // Adds a location field to test results
120
- // testLocationInResults: false,
121
- // The glob patterns Jest uses to detect test files
122
- // testMatch: [
123
- // "**/__tests__/**/*.[jt]s?(x)",
124
- // "**/?(*.)+(spec|test).[tj]s?(x)"
125
- // ],
126
- // The regexp pattern or array of patterns that Jest uses to detect test files
127
- // testRegex: [],
128
- // This option allows the use of a custom results processor
129
- // testResultsProcessor: undefined,
130
- // This option allows use of a custom test runner
131
- // testRunner: "jest-circus/runner",
132
- // An array of regexp pattern strings that are matched against all source file paths, matched files will skip transformation
133
- // transformIgnorePatterns: [
134
- // "/node_modules/",
135
- // "\\.pnp\\.[^\\/]+$"
136
- // ],
137
- // An array of regexp pattern strings that are matched against all modules before the module loader will automatically return a mock for them
138
- // unmockedModulePathPatterns: undefined,
139
- // Indicates whether each individual test should be reported during the run
140
- // verbose: undefined,
141
- // An array of regexp patterns that are matched against all source file paths before re-running tests in watch mode
142
- // watchPathIgnorePatterns: [],
143
- // Whether to use watchman for file crawling
144
- // watchman: true,
145
- };
146
- export default config;
147
- //# sourceMappingURL=jest.config.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"jest.config.js","sourceRoot":"","sources":["../jest.config.ts"],"names":[],"mappings":"AAAA;;;GAGG;AAIH,MAAM,MAAM,GAAW;IACnB,UAAU,EAAE,EAAE;IAEd,sBAAsB,EAAE,CAAC,KAAK,CAAC;IAE/B,0DAA0D;IAC1D,SAAS,EAAE;QACP,WAAW,EAAE,CAAC,SAAS,EAAE;gBACrB,MAAM,EAAE,IAAI;aACf,CAAC;KACL;IAED,wGAAwG;IACxG,4BAA4B;IAC5B,OAAO;IACP,KAAK;IAEL,oIAAoI;IACpI,gBAAgB,EAAE;QACd,UAAU,EAAE,kBAAkB;KACjC;IAED,oEAAoE;IACpE,mBAAmB;IAEnB,wCAAwC;IACxC,WAAW;IAEX,0EAA0E;IAC1E,sFAAsF;IAEtF,oFAAoF;IACpF,UAAU,EAAE,IAAI;IAEhB,0FAA0F;IAC1F,eAAe,EAAE,KAAK;IAEtB,yGAAyG;IACzG,kCAAkC;IAElC,4DAA4D;IAC5D,iBAAiB,EAAE,UAAU;IAE7B,sEAAsE;IACtE,gCAAgC;IAChC,qBAAqB;IACrB,KAAK;IAEL,0EAA0E;IAC1E,6BAA6B;IAE7B,wEAAwE;IACxE,uBAAuB;IACvB,YAAY;IACZ,YAAY;IACZ,YAAY;IACZ,aAAa;IACb,KAAK;IAEL,+EAA+E;IAC/E,gCAAgC;IAEhC,0CAA0C;IAC1C,kCAAkC;IAElC,4DAA4D;IAC5D,4BAA4B;IAE5B,4CAA4C;IAC5C,gBAAgB;IAChB,4BAA4B;IAC5B,KAAK;IAEL,+EAA+E;IAC/E,0BAA0B;IAE1B,mGAAmG;IACnG,0BAA0B;IAE1B,kGAAkG;IAClG,6BAA6B;IAE7B,+EAA+E;IAC/E,eAAe;IAEf,iOAAiO;IACjO,qBAAqB;IAErB,+CAA+C;IAC/C,0BAA0B;IAC1B,UAAU;IACV,WAAW;IACX,WAAW;IACX,WAAW;IACX,UAAU;IACV,WAAW;IACX,YAAY;IACZ,WAAW;IACX,KAAK;IAEL,wHAAwH;IACxH,wBAAwB,EAAE,CAAC,iBAAiB,CAAC;IAE7C,2CAA2C;IAC3C,iBAAiB;IAEjB,sEAAsE;IACtE,gCAAgC;IAEhC,2DAA2D;IAC3D,qBAAqB;IAErB,sCAAsC;IACtC,uBAAuB;IAEvB,gEAAgE;IAChE,wBAAwB;IAExB,mDAAmD;IACnD,qBAAqB;IAErB,gEAAgE;IAChE,uBAAuB;IAEvB,8BAA8B;IAC9B,uBAAuB;IAEvB,wEAAwE;IACxE,uBAAuB;IAEvB,wEAAwE;IACxE,sBAAsB;IAEtB,6EAA6E;IAC7E,WAAW;IACX,gBAAgB;IAChB,KAAK;IAEL,0EAA0E;IAC1E,yBAAyB;IAEzB,0GAA0G;IAE1G,8GAA8G;IAC9G,0BAA0B;IAE1B,sGAAsG;IACtG,wBAAwB;IAExB,sFAAsF;IACtF,2BAA2B;IAE3B,qDAAqD;IACrD,eAAe,EAAE,MAAM;IAEvB,qDAAqD;IACrD,8BAA8B;IAE9B,wCAAwC;IACxC,gCAAgC;IAEhC,mDAAmD;IACnD,eAAe;IACf,mCAAmC;IACnC,qCAAqC;IACrC,KAAK;IAEL,8EAA8E;IAC9E,iBAAiB;IAEjB,2DAA2D;IAC3D,mCAAmC;IAEnC,iDAAiD;IACjD,oCAAoC;IAEpC,4HAA4H;IAC5H,6BAA6B;IAC7B,sBAAsB;IACtB,wBAAwB;IACxB,KAAK;IAEL,6IAA6I;IAC7I,yCAAyC;IAEzC,2EAA2E;IAC3E,sBAAsB;IAEtB,mHAAmH;IACnH,+BAA+B;IAE/B,4CAA4C;IAC5C,kBAAkB;CACrB,CAAC;AAEF,eAAe,MAAM,CAAC"}
@@ -1,6 +0,0 @@
1
- export * as layers from "@/layers";
2
- export * as models from "@/models";
3
- export * as losses from "@/losses";
4
- export * from "@/kv_cache";
5
- export * from "@/metrics";
6
- //# sourceMappingURL=index.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,cAAc,YAAY,CAAC;AAC3B,cAAc,WAAW,CAAC"}
package/dist/src/index.js DELETED
@@ -1,6 +0,0 @@
1
- export * as layers from "@/layers";
2
- export * as models from "@/models";
3
- export * as losses from "@/losses";
4
- export * from "@/kv_cache";
5
- export * from "@/metrics";
6
- //# sourceMappingURL=index.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"index.js","sourceRoot":"","sources":["../../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,cAAc,YAAY,CAAC;AAC3B,cAAc,WAAW,CAAC"}
@@ -1,53 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- export interface KvCacheArgs {
3
- batchSize: number;
4
- maxSequenceLength: number;
5
- numHeads: number;
6
- headDim: number;
7
- dtype?: tf.DataType;
8
- }
9
- /**
10
- * A container for KV caches. A model should initialize one KV cache
11
- */
12
- export declare class KvCacheContainer {
13
- protected caches: Map<string, KvCache>;
14
- protected max_sequence_length: number;
15
- constructor(maxSequenceLength: number);
16
- create(id: string, args: Omit<KvCacheArgs, "maxSequenceLength">): void;
17
- /**
18
- * The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
19
- */
20
- update(id: string, key: tf.Tensor4D, value: tf.Tensor4D): {
21
- keyCache: tf.Variable<tf.Rank.R4>;
22
- valueCache: tf.Variable<tf.Rank.R4>;
23
- } | undefined;
24
- reset(): void;
25
- dispose(): void;
26
- get size(): number;
27
- get maxSequenceLength(): number;
28
- }
29
- export declare class KvCache {
30
- protected key_cache: tf.Variable<tf.Rank.R4>;
31
- protected value_cache: tf.Variable<tf.Rank.R4>;
32
- protected current_position: number;
33
- protected batch_size: number;
34
- protected max_sequence_length: number;
35
- protected num_kv_heads: number;
36
- protected head_dim: number;
37
- constructor({ batchSize, maxSequenceLength, numHeads, headDim, dtype }: KvCacheArgs);
38
- /**
39
- * The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
40
- */
41
- update(key: tf.Tensor4D, value: tf.Tensor4D): {
42
- keyCache: tf.Variable<tf.Rank.R4>;
43
- valueCache: tf.Variable<tf.Rank.R4>;
44
- };
45
- protected mergeIntoCache(new_value: tf.Tensor4D, current_cache: tf.Tensor4D): tf.Tensor4D;
46
- reset(): void;
47
- dispose(): void;
48
- /**
49
- * The size of the KV cache, also the number of tokens since the first one.
50
- */
51
- get size(): number;
52
- }
53
- //# sourceMappingURL=kv_cache.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"kv_cache.d.ts","sourceRoot":"","sources":["../../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC,MAAM,WAAW,WAAW;IACxB,SAAS,EAAE,MAAM,CAAC;IAClB,iBAAiB,EAAE,MAAM,CAAC;IAC1B,QAAQ,EAAE,MAAM,CAAC;IACjB,OAAO,EAAE,MAAM,CAAC;IAChB,KAAK,CAAC,EAAE,EAAE,CAAC,QAAQ,CAAA;CACtB;AAGD;;GAEG;AACH,qBAAa,gBAAgB;IACzB,SAAS,CAAC,MAAM,uBAA8B;IAC9C,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;gBAG1B,iBAAiB,EAAE,MAAM;IAS9B,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,IAAI,EAAE,IAAI,CAAC,WAAW,EAAE,mBAAmB,CAAC;IAUtE;;OAEG;IACI,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IA4BvD,KAAK;IAOL,OAAO;IAOd,IAAW,IAAI,WAGd;IAGD,IAAW,iBAAiB,WAE3B;CACJ;AAGD,qBAAa,OAAO;IAEhB,SAAS,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;IAC7C,SAAS,CAAC,WAAW,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAA;IAG9C,SAAS,CAAC,gBAAgB,EAAE,MAAM,CAAK;IAEvC,SAAS,CAAC,UAAU,EAAE,MAAM,CAAC;IAC7B,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;IACtC,SAAS,CAAC,YAAY,EAAE,MAAM,CAAC;IAC/B,SAAS,CAAC,QAAQ,EAAE,MAAM,CAAC;gBAEf,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAiB,EAAE,EAAE,WAAW;IAa/F;;OAEG;IACI,MAAM,CAAC,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IAgClD,SAAS,CAAC,cAAc,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,EAAE,aAAa,EAAE,EAAE,CAAC,QAAQ;IAqBpE,KAAK,IAAI,IAAI;IAab,OAAO,IAAI,IAAI;IAMtB;;OAEG;IACH,IAAI,IAAI,IAAI,MAAM,CAEjB;CAEJ"}
@@ -1,135 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- /**
3
- * A container for KV caches. A model should initialize one KV cache
4
- */
5
- export class KvCacheContainer {
6
- caches = new Map();
7
- max_sequence_length;
8
- constructor(maxSequenceLength) {
9
- if (!maxSequenceLength) {
10
- throw Error(`KvCacheContainer: expected KV cache maximum sequence length to be greater than 0, got: ${String(maxSequenceLength)}`);
11
- }
12
- this.max_sequence_length = maxSequenceLength;
13
- }
14
- create(id, args) {
15
- const new_cache = new KvCache({
16
- ...args,
17
- maxSequenceLength: this.max_sequence_length
18
- });
19
- this.caches.set(id, new_cache);
20
- }
21
- /**
22
- * The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
23
- */
24
- update(id, key, value) {
25
- const kv_cache = this.caches.get(id);
26
- if (!kv_cache) {
27
- return undefined;
28
- }
29
- const { keyCache, valueCache } = kv_cache.update(key, value);
30
- // slicing to get only the past key and value projections, but normally
31
- // in TensorFlow and PyTorch the full cache is returned and masked for
32
- // graph purposes
33
- return tf.tidy(() => {
34
- const k_cache = keyCache.slice([0, 0, 0, 0], [keyCache.shape[0], keyCache.shape[1], kv_cache.size, keyCache.shape[3]]);
35
- const v_cache = valueCache.slice([0, 0, 0, 0], [valueCache.shape[0], valueCache.shape[1], kv_cache.size, valueCache.shape[3]]);
36
- return {
37
- keyCache: k_cache,
38
- valueCache: v_cache
39
- };
40
- });
41
- }
42
- reset() {
43
- this.caches.forEach(cache => {
44
- cache.reset();
45
- });
46
- }
47
- dispose() {
48
- this.caches.forEach(cache => {
49
- cache.dispose();
50
- });
51
- }
52
- get size() {
53
- // the size of all KV caches are expected to be the same, just use the first one
54
- return this.caches.entries().next().value?.[1].size ?? 0;
55
- }
56
- get maxSequenceLength() {
57
- return this.max_sequence_length;
58
- }
59
- }
60
- export class KvCache {
61
- key_cache;
62
- value_cache;
63
- // the size of the KV cache, represents the number of tokens since the first chat token
64
- current_position = 0;
65
- batch_size;
66
- max_sequence_length;
67
- num_kv_heads;
68
- head_dim;
69
- constructor({ batchSize, maxSequenceLength, numHeads, headDim, dtype = "float32" }) {
70
- const cache_shape = [batchSize, numHeads, maxSequenceLength, headDim];
71
- this.key_cache = tf.variable(tf.zeros(cache_shape, dtype), false);
72
- this.value_cache = tf.variable(tf.zeros(cache_shape, dtype), false);
73
- this.batch_size = batchSize;
74
- this.max_sequence_length = maxSequenceLength;
75
- this.num_kv_heads = numHeads;
76
- this.head_dim = headDim;
77
- }
78
- /**
79
- * The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
80
- */
81
- update(key, value) {
82
- const batch_size = key.shape[0];
83
- const seq_len = key.shape[2];
84
- if (batch_size > this.key_cache.shape[0]) {
85
- throw Error(`The current KV cache has been set up with a batch size of` +
86
- ` ${this.key_cache.shape[0]}, but found new key tensors with batch size ${batch_size}`);
87
- }
88
- if (this.current_position + seq_len > this.max_sequence_length) {
89
- throw Error(`The KV cache has exceeded its maximum sequence length of ${this.max_sequence_length}. Use a larger value.`);
90
- }
91
- const new_key_cache = this.mergeIntoCache(key, this.key_cache);
92
- const new_value_cache = this.mergeIntoCache(value, this.value_cache);
93
- this.key_cache.assign(new_key_cache);
94
- this.value_cache.assign(new_value_cache);
95
- new_key_cache.dispose();
96
- new_value_cache.dispose();
97
- // advance the pointer to reflect the updated cache's current
98
- this.current_position += seq_len;
99
- return {
100
- keyCache: this.key_cache,
101
- valueCache: this.value_cache,
102
- };
103
- }
104
- mergeIntoCache(new_value, current_cache) {
105
- const seq_len = new_value.shape[2];
106
- return tf.tidy(() => {
107
- const historical = current_cache.slice([0, 0, 0, 0], [this.batch_size, this.num_kv_heads, this.current_position, this.head_dim]);
108
- const future = current_cache.slice([0, 0, this.current_position + seq_len, 0], [this.batch_size, this.num_kv_heads, this.max_sequence_length - this.current_position - seq_len, this.head_dim]);
109
- // merge the new tensor into the current cache to create a new, larger, cache,
110
- // this is different from Python immplementations because TFJS tensors are immutable,
111
- // because we cannot update a slice, we must slice and concat
112
- return tf.concat([historical, new_value, future], 2);
113
- });
114
- }
115
- reset() {
116
- this.current_position = 0;
117
- tf.tidy(() => {
118
- const key_cache_shape = this.key_cache.shape;
119
- const value_cache_shape = this.value_cache.shape;
120
- this.key_cache.assign(tf.zeros(key_cache_shape));
121
- this.value_cache.assign(tf.zeros(value_cache_shape));
122
- });
123
- }
124
- dispose() {
125
- this.key_cache.dispose();
126
- this.value_cache.dispose();
127
- }
128
- /**
129
- * The size of the KV cache, also the number of tokens since the first one.
130
- */
131
- get size() {
132
- return this.current_position;
133
- }
134
- }
135
- //# sourceMappingURL=kv_cache.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"kv_cache.js","sourceRoot":"","sources":["../../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAYvC;;GAEG;AACH,MAAM,OAAO,gBAAgB;IACf,MAAM,GAAG,IAAI,GAAG,EAAmB,CAAC;IACpC,mBAAmB,CAAS;IAGtC,YAAY,iBAAyB;QACjC,IAAI,CAAC,iBAAiB,EAAE,CAAC;YACrB,MAAM,KAAK,CAAC,0FAA0F,MAAM,CAAC,iBAAiB,CAAC,EAAE,CAAC,CAAC;QACvI,CAAC;QAED,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;IACjD,CAAC;IAGM,MAAM,CAAC,EAAU,EAAE,IAA4C;QAClE,MAAM,SAAS,GAAG,IAAI,OAAO,CAAC;YAC1B,GAAG,IAAI;YACP,iBAAiB,EAAE,IAAI,CAAC,mBAAmB;SAC9C,CAAC,CAAC;QAEH,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,EAAE,SAAS,CAAC,CAAC;IACnC,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,EAAU,EAAE,GAAgB,EAAE,KAAkB;QAC1D,MAAM,QAAQ,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;QAErC,IAAI,CAAC,QAAQ,EAAE,CAAC;YACZ,OAAO,SAAS,CAAC;QACrB,CAAC;QAED,MAAM,EAAE,QAAQ,EAAE,UAAU,EAAE,GAAG,QAAQ,CAAC,MAAM,CAAC,GAAG,EAAE,KAAK,CAAC,CAAC;QAE7D,uEAAuE;QACvE,sEAAsE;QACtE,iBAAiB;QACjB,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,OAAO,GAAG,QAAQ,CAAC,KAAK,CAC1B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAC9E,MAAM,OAAO,GAAG,UAAU,CAAC,KAAK,CAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAEpF,OAAO;gBACH,QAAQ,EAAE,OAAO;gBACjB,UAAU,EAAE,OAAO;aACtB,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,KAAK,EAAE,CAAC;QAClB,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,OAAO,EAAE,CAAC;QACpB,CAAC,CAAC,CAAA;IACN,CAAC;IAGD,IAAW,IAAI;QACX,gFAAgF;QAChF,OAAO,IAAI,CAAC,MAAM,CAAC,OAAO,EAAE,CAAC,IAAI,EAAE,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,IAAI,CAAC,CAAC;IAC7D,CAAC;IAGD,IAAW,iBAAiB;QACxB,OAAO,IAAI,CAAC,mBAAmB,CAAC;IACpC,CAAC;CACJ;AAGD,MAAM,OAAO,OAAO;IAEN,SAAS,CAA0B;IACnC,WAAW,CAAyB;IAE9C,uFAAuF;IAC7E,gBAAgB,GAAW,CAAC,CAAC;IAE7B,UAAU,CAAS;IACnB,mBAAmB,CAAS;IAC5B,YAAY,CAAS;IACrB,QAAQ,CAAS;IAE3B,YAAY,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAK,GAAG,SAAS,EAAe;QAC3F,MAAM,WAAW,GAAG,CAAC,SAAS,EAAE,QAAQ,EAAE,iBAAiB,EAAE,OAAO,CAAqC,CAAC;QAE1G,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAClE,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAEpE,IAAI,CAAC,UAAU,GAAG,SAAS,CAAC;QAC5B,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;QAC7C,IAAI,CAAC,YAAY,GAAG,QAAQ,CAAC;QAC7B,IAAI,CAAC,QAAQ,GAAG,OAAO,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,GAAgB,EAAE,KAAkB;QAC9C,MAAM,UAAU,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAChC,MAAM,OAAO,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAE7B,IAAI,UAAU,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC;YACvC,MAAM,KAAK,CAAC,2DAA2D;gBACnE,IAAI,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,+CAA+C,UAAU,EAAE,CAAC,CAAA;QAC/F,CAAC;QAED,IAAI,IAAI,CAAC,gBAAgB,GAAG,OAAO,GAAG,IAAI,CAAC,mBAAmB,EAAE,CAAC;YAC7D,MAAM,KAAK,CAAC,4DAA4D,IAAI,CAAC,mBAAmB,uBAAuB,CAAC,CAAC;QAC7H,CAAC;QAED,MAAM,aAAa,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;QAC/D,MAAM,eAAe,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC;QAErE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,aAAa,CAAC,CAAC;QACrC,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,eAAe,CAAC,CAAC;QAEzC,aAAa,CAAC,OAAO,EAAE,CAAC;QACxB,eAAe,CAAC,OAAO,EAAE,CAAC;QAE1B,6DAA6D;QAC7D,IAAI,CAAC,gBAAgB,IAAI,OAAO,CAAC;QAEjC,OAAO;YACH,QAAQ,EAAE,IAAI,CAAC,SAAS;YACxB,UAAU,EAAE,IAAI,CAAC,WAAW;SAC/B,CAAA;IACL,CAAC;IAGS,cAAc,CAAC,SAAsB,EAAE,aAA0B;QACvE,MAAM,OAAO,GAAG,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAEnC,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAEhB,MAAM,UAAU,GAAG,aAAa,CAAC,KAAK,CAClC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,gBAAgB,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAEhF,MAAM,MAAM,GAAG,aAAa,CAAC,KAAK,CAC9B,CAAC,CAAC,EAAE,CAAC,EAAE,IAAI,CAAC,gBAAgB,GAAG,OAAO,EAAE,CAAC,CAAC,EAC1C,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,gBAAgB,GAAG,OAAO,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAErH,8EAA8E;YAC9E,qFAAqF;YACrF,6DAA6D;YAC7D,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,UAAU,EAAE,SAAS,EAAE,MAAM,CAAC,EAAE,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,gBAAgB,GAAG,CAAC,CAAC;QAE1B,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YACT,MAAM,eAAe,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC;YAC7C,MAAM,iBAAiB,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC;YAEjD,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,eAAe,CAAC,CAAC,CAAC;YACjD,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAC;IACP,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,SAAS,CAAC,OAAO,EAAE,CAAC;QACzB,IAAI,CAAC,WAAW,CAAC,OAAO,EAAE,CAAC;IAC/B,CAAC;IAGD;;OAEG;IACH,IAAI,IAAI;QACJ,OAAO,IAAI,CAAC,gBAAgB,CAAC;IACjC,CAAC;CAEJ"}
@@ -1,31 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import { KvCacheContainer } from "@/kv_cache";
3
- import { MultiHeadAttention, type MultiHeadAttentionArgs } from '@/layers/multihead_attention';
4
- import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
5
- /**
6
- * MultiHeadAttention with RoPE and KV caching. If using KV caching, this layer
7
- * should be used in a custom training loop because it requires the cache to be
8
- * passed through the `kwargs.kvCache` argument during the `layer.apply()`
9
- * forward propagation.
10
- *
11
- * If a KV cache is not provided, then this layer operates as MultiHeadAttention with RoPE.
12
- */
13
- export declare class CachedRoPEMultiHeadAttention extends MultiHeadAttention {
14
- static className: string;
15
- protected rope: tf.layers.Layer;
16
- constructor(args: MultiHeadAttentionArgs);
17
- protected forward(query_input: tf.Tensor, key_input: tf.Tensor, value_input: tf.Tensor, packing_mask: tf.Tensor | null, causal_mask: tf.Tensor | null, kwargs: Kwargs): tf.Tensor;
18
- protected getCachedKV(kv_container: KvCacheContainer, key_split: tf.Tensor4D, value_split: tf.Tensor4D): {
19
- keyCache: tf.Variable<tf.Rank.R4>;
20
- valueCache: tf.Variable<tf.Rank.R4>;
21
- };
22
- /**
23
- * Adds RoPE position encoding right after splitting heads.
24
- */
25
- protected splitHeads(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, shuffle: number[]): {
26
- query_split: tf.Tensor4D;
27
- key_split: tf.Tensor4D;
28
- value_split: tf.Tensor4D;
29
- };
30
- }
31
- //# sourceMappingURL=cached_rope_multihead_attention.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"cached_rope_multihead_attention.d.ts","sourceRoot":"","sources":["../../../src/layers/cached_rope_multihead_attention.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,gBAAgB,EAAE,MAAM,YAAY,CAAC;AAC9C,OAAO,EAAE,kBAAkB,EAAE,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAE/F,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAGjE;;;;;;;GAOG;AACH,qBAAa,4BAA6B,SAAQ,kBAAkB;IAChE,MAAM,CAAC,SAAS,SAAkC;IAElD,SAAS,CAAC,IAAI,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;gBAEpB,IAAI,EAAE,sBAAsB;cAMrB,OAAO,CACtB,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,SAAS,EAAE,EAAE,CAAC,MAAM,EACpB,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,YAAY,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC9B,WAAW,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC7B,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAuC9B,SAAS,CAAC,WAAW,CAAC,YAAY,EAAE,gBAAgB,EAAE,SAAS,EAAE,EAAE,CAAC,QAAQ,EAAE,WAAW,EAAE,EAAE,CAAC,QAAQ;;;;IAqBtG;;OAEG;cACgB,UAAU,CAAC,KAAK,EAAE,EAAE,CAAC,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC,MAAM,EAAE,OAAO,EAAE,MAAM,EAAE;qBAO5D,EAAE,CAAC,QAAQ;mBAEX,EAAE,CAAC,QAAQ;qBACwB,EAAE,CAAC,QAAQ;;CAIxF"}