@stellarapp/tfjs-stellar 1.0.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +47 -0
  3. package/dist/index.d.ts +7 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +7 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/jest.config.d.ts +8 -0
  8. package/dist/jest.config.d.ts.map +1 -0
  9. package/{jest.config.ts → dist/jest.config.js} +8 -64
  10. package/dist/jest.config.js.map +1 -0
  11. package/dist/kv_cache.d.ts +53 -0
  12. package/dist/kv_cache.d.ts.map +1 -0
  13. package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
  14. package/dist/kv_cache.js.map +1 -0
  15. package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
  16. package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  17. package/dist/layers/cached_rope_multihead_attention.js +76 -0
  18. package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
  19. package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  20. package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  21. package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
  22. package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
  23. package/dist/layers/gpt_decoder_block.d.ts +34 -0
  24. package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
  25. package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
  26. package/dist/layers/gpt_decoder_block.js.map +1 -0
  27. package/dist/layers/index.d.ts +17 -0
  28. package/dist/layers/index.d.ts.map +1 -0
  29. package/dist/layers/index.js +33 -0
  30. package/dist/layers/index.js.map +1 -0
  31. package/dist/layers/multihead_attention.d.ts +106 -0
  32. package/dist/layers/multihead_attention.d.ts.map +1 -0
  33. package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
  34. package/dist/layers/multihead_attention.js.map +1 -0
  35. package/dist/layers/multihead_attention.test.d.ts +2 -0
  36. package/dist/layers/multihead_attention.test.d.ts.map +1 -0
  37. package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
  38. package/dist/layers/multihead_attention.test.js.map +1 -0
  39. package/dist/layers/positional_encoding.d.ts +37 -0
  40. package/dist/layers/positional_encoding.d.ts.map +1 -0
  41. package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
  42. package/dist/layers/positional_encoding.js.map +1 -0
  43. package/dist/layers/positional_encoding.test.d.ts +2 -0
  44. package/dist/layers/positional_encoding.test.d.ts.map +1 -0
  45. package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
  46. package/dist/layers/positional_encoding.test.js.map +1 -0
  47. package/dist/layers/rotary_position_embedding.d.ts +39 -0
  48. package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
  49. package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
  50. package/dist/layers/rotary_position_embedding.js.map +1 -0
  51. package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
  52. package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
  53. package/dist/layers/rotary_position_embedding.test.js +88 -0
  54. package/dist/layers/rotary_position_embedding.test.js.map +1 -0
  55. package/dist/layers/token_and_positional_embedding.d.ts +47 -0
  56. package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
  57. package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
  58. package/dist/layers/token_and_positional_embedding.js.map +1 -0
  59. package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
  60. package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  61. package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
  62. package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
  63. package/dist/layers/transformer_decoder.d.ts +69 -0
  64. package/dist/layers/transformer_decoder.d.ts.map +1 -0
  65. package/dist/layers/transformer_decoder.js +182 -0
  66. package/dist/layers/transformer_decoder.js.map +1 -0
  67. package/dist/layers/transformer_decoder.test.d.ts +2 -0
  68. package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
  69. package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
  70. package/dist/layers/transformer_decoder.test.js.map +1 -0
  71. package/dist/layers/transformer_encoder.d.ts +55 -0
  72. package/dist/layers/transformer_encoder.d.ts.map +1 -0
  73. package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
  74. package/dist/layers/transformer_encoder.js.map +1 -0
  75. package/dist/layers/transformer_encoder.test.d.ts +2 -0
  76. package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
  77. package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
  78. package/dist/layers/transformer_encoder.test.js.map +1 -0
  79. package/dist/losses/dice.d.ts +30 -0
  80. package/dist/losses/dice.d.ts.map +1 -0
  81. package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
  82. package/dist/losses/dice.js.map +1 -0
  83. package/dist/losses/index.d.ts +2 -0
  84. package/dist/losses/index.d.ts.map +1 -0
  85. package/dist/losses/index.js +2 -0
  86. package/dist/losses/index.js.map +1 -0
  87. package/dist/masks.d.ts +20 -0
  88. package/dist/masks.d.ts.map +1 -0
  89. package/{src/packing_mask.ts → dist/masks.js} +16 -7
  90. package/dist/masks.js.map +1 -0
  91. package/dist/metrics.d.ts +20 -0
  92. package/dist/metrics.d.ts.map +1 -0
  93. package/{src/metrics.ts → dist/metrics.js} +8 -12
  94. package/dist/metrics.js.map +1 -0
  95. package/dist/models/gpt_model.d.ts +94 -0
  96. package/dist/models/gpt_model.d.ts.map +1 -0
  97. package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
  98. package/dist/models/gpt_model.js.map +1 -0
  99. package/dist/models/index.d.ts +7 -0
  100. package/dist/models/index.d.ts.map +1 -0
  101. package/dist/models/index.js +13 -0
  102. package/dist/models/index.js.map +1 -0
  103. package/dist/models/llm_model.d.ts +87 -0
  104. package/dist/models/llm_model.d.ts.map +1 -0
  105. package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
  106. package/dist/models/llm_model.js.map +1 -0
  107. package/dist/models/u_net.d.ts +40 -0
  108. package/dist/models/u_net.d.ts.map +1 -0
  109. package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
  110. package/dist/models/u_net.js.map +1 -0
  111. package/dist/src/index.d.ts +6 -0
  112. package/dist/src/index.d.ts.map +1 -0
  113. package/dist/src/index.js +6 -0
  114. package/dist/src/index.js.map +1 -0
  115. package/dist/src/kv_cache.d.ts +53 -0
  116. package/dist/src/kv_cache.d.ts.map +1 -0
  117. package/dist/src/kv_cache.js +135 -0
  118. package/dist/src/kv_cache.js.map +1 -0
  119. package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
  120. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  121. package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
  122. package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
  123. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  124. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  125. package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
  126. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
  127. package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
  128. package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
  129. package/dist/src/layers/gpt_decoder_block.js +51 -0
  130. package/dist/src/layers/gpt_decoder_block.js.map +1 -0
  131. package/dist/src/layers/index.d.ts +17 -0
  132. package/dist/src/layers/index.d.ts.map +1 -0
  133. package/dist/src/layers/index.js +33 -0
  134. package/dist/src/layers/index.js.map +1 -0
  135. package/dist/src/layers/multihead_attention.d.ts +106 -0
  136. package/dist/src/layers/multihead_attention.d.ts.map +1 -0
  137. package/dist/src/layers/multihead_attention.js +269 -0
  138. package/dist/src/layers/multihead_attention.js.map +1 -0
  139. package/dist/src/layers/multihead_attention.test.d.ts +2 -0
  140. package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
  141. package/dist/src/layers/multihead_attention.test.js +160 -0
  142. package/dist/src/layers/multihead_attention.test.js.map +1 -0
  143. package/dist/src/layers/positional_encoding.d.ts +37 -0
  144. package/dist/src/layers/positional_encoding.d.ts.map +1 -0
  145. package/dist/src/layers/positional_encoding.js +115 -0
  146. package/dist/src/layers/positional_encoding.js.map +1 -0
  147. package/dist/src/layers/positional_encoding.test.d.ts +2 -0
  148. package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
  149. package/dist/src/layers/positional_encoding.test.js +95 -0
  150. package/dist/src/layers/positional_encoding.test.js.map +1 -0
  151. package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
  152. package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
  153. package/dist/src/layers/rotary_position_embedding.js +99 -0
  154. package/dist/src/layers/rotary_position_embedding.js.map +1 -0
  155. package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
  156. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
  157. package/dist/src/layers/rotary_position_embedding.test.js +88 -0
  158. package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
  159. package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
  160. package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
  161. package/dist/src/layers/token_and_positional_embedding.js +109 -0
  162. package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
  163. package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
  164. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  165. package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
  166. package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
  167. package/dist/src/layers/transformer_decoder.d.ts +69 -0
  168. package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
  169. package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
  170. package/dist/src/layers/transformer_decoder.js.map +1 -0
  171. package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
  172. package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
  173. package/dist/src/layers/transformer_decoder.test.js +72 -0
  174. package/dist/src/layers/transformer_decoder.test.js.map +1 -0
  175. package/dist/src/layers/transformer_encoder.d.ts +55 -0
  176. package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
  177. package/dist/src/layers/transformer_encoder.js +175 -0
  178. package/dist/src/layers/transformer_encoder.js.map +1 -0
  179. package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
  180. package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
  181. package/dist/src/layers/transformer_encoder.test.js +58 -0
  182. package/dist/src/layers/transformer_encoder.test.js.map +1 -0
  183. package/dist/src/losses/dice.d.ts +30 -0
  184. package/dist/src/losses/dice.d.ts.map +1 -0
  185. package/dist/src/losses/dice.js +93 -0
  186. package/dist/src/losses/dice.js.map +1 -0
  187. package/dist/src/losses/index.d.ts +2 -0
  188. package/dist/src/losses/index.d.ts.map +1 -0
  189. package/dist/src/losses/index.js +2 -0
  190. package/dist/src/losses/index.js.map +1 -0
  191. package/dist/src/masks.d.ts +20 -0
  192. package/dist/src/masks.d.ts.map +1 -0
  193. package/dist/src/masks.js +37 -0
  194. package/dist/src/masks.js.map +1 -0
  195. package/dist/src/metrics.d.ts +20 -0
  196. package/dist/src/metrics.d.ts.map +1 -0
  197. package/dist/src/metrics.js +28 -0
  198. package/dist/src/metrics.js.map +1 -0
  199. package/dist/src/models/gpt_model.d.ts +94 -0
  200. package/dist/src/models/gpt_model.d.ts.map +1 -0
  201. package/dist/src/models/gpt_model.js +154 -0
  202. package/dist/src/models/gpt_model.js.map +1 -0
  203. package/dist/src/models/index.d.ts +3 -0
  204. package/dist/src/models/index.d.ts.map +1 -0
  205. package/{src/models/index.ts → dist/src/models/index.js} +1 -0
  206. package/dist/src/models/index.js.map +1 -0
  207. package/dist/src/models/llm_model.d.ts +87 -0
  208. package/dist/src/models/llm_model.d.ts.map +1 -0
  209. package/dist/src/models/llm_model.js +245 -0
  210. package/dist/src/models/llm_model.js.map +1 -0
  211. package/dist/src/models/u_net.d.ts +40 -0
  212. package/dist/src/models/u_net.d.ts.map +1 -0
  213. package/dist/src/models/u_net.js +151 -0
  214. package/dist/src/models/u_net.js.map +1 -0
  215. package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
  216. package/dist/src/tfjs_types.d.ts.map +1 -0
  217. package/dist/src/tfjs_types.js +2 -0
  218. package/dist/src/tfjs_types.js.map +1 -0
  219. package/dist/src/utils.d.ts +28 -0
  220. package/dist/src/utils.d.ts.map +1 -0
  221. package/{src/utils.ts → dist/src/utils.js} +10 -33
  222. package/dist/src/utils.js.map +1 -0
  223. package/dist/src/utils.test.d.ts +2 -0
  224. package/dist/src/utils.test.d.ts.map +1 -0
  225. package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
  226. package/dist/src/utils.test.js.map +1 -0
  227. package/dist/tfjs_types.d.ts +10 -0
  228. package/dist/tfjs_types.d.ts.map +1 -0
  229. package/dist/tfjs_types.js +2 -0
  230. package/dist/tfjs_types.js.map +1 -0
  231. package/dist/utils.d.ts +28 -0
  232. package/dist/utils.d.ts.map +1 -0
  233. package/dist/utils.js +63 -0
  234. package/dist/utils.js.map +1 -0
  235. package/dist/utils.test.d.ts +2 -0
  236. package/dist/utils.test.d.ts.map +1 -0
  237. package/dist/utils.test.js +73 -0
  238. package/dist/utils.test.js.map +1 -0
  239. package/package.json +10 -4
  240. package/src/index.ts +0 -93
  241. package/src/layers/rotary_position_embedding.test.ts +0 -107
  242. package/src/losses/index.ts +0 -1
  243. package/src/testing.ts +0 -1
  244. package/tsconfig.json +0 -49
@@ -1,145 +1,73 @@
1
1
  import * as tf from "@tensorflow/tfjs";
2
- import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
3
-
4
-
5
- export interface UNetArgs {
6
- /**
7
- * The starting number of filters.
8
- */
9
- filters: number;
10
- /**
11
- * The number of categories. For binary segmentation, `units=1`.
12
- */
13
- units: number;
14
- /**
15
- * The activation of the final output convolution layer. Defaults to `sigmoid` if `categories=1`, else `softmax`.
16
- */
17
- activation?: ActivationIdentifier;
18
- /**
19
- * The depth of the U-Net or the number of contractions and the number of expansions.
20
- */
21
- depth: number;
22
- /**
23
- * Adds residual connections to transform the model into a ResUNet. Defaults to `false`.
24
- */
25
- residual?: boolean;
26
- /**
27
- * Adds batch normalization to convolutions. Best used for batched inputs. Defaults to `false`.
28
- */
29
- batchNorm?: boolean;
30
- /**
31
- * Set the unbatched input shape of the U-Net in the format `[height, width, channels]`. Defaults to `[null, null, 3]`. If set, only channels is mandatory.
32
- */
33
- inputShape?: [number | null, number | null, number];
34
- }
35
-
36
-
37
- export type UNetModelArgs = UNetArgs & Omit<tf.SequentialArgs, "layers">;
38
-
39
-
40
2
  export class UNetModel extends tf.Sequential {
41
-
42
- constructor(args: UNetModelArgs) {
43
- const {
44
- filters,
45
- units,
46
- activation = units == 1 ? "sigmoid" : "softmax",
47
- depth,
48
- residual = false,
49
- batchNorm = false,
50
- inputShape = [null, null, 3],
51
- ...sequentialArgs
52
- } = args;
53
-
3
+ constructor(args) {
4
+ const { filters, units, activation = units == 1 ? "sigmoid" : "softmax", depth, residual = false, batchNorm = false, inputShape = [null, null, 3], ...sequentialArgs } = args;
54
5
  sequentialArgs.name = sequentialArgs.name ?? "unet_model";
55
-
56
6
  super({
57
7
  ...sequentialArgs,
58
8
  // calling user should not modify the layers after instantiation
59
9
  layers: [createUNet({ filters, units, activation, depth, residual, batchNorm, inputShape })]
60
10
  });
61
11
  }
62
-
63
-
64
- override summary(lineLength?: number, positions?: number[], printFn?: (message?: any, ...optionalParams: any[]) => void): void {
12
+ summary(lineLength, positions, printFn) {
65
13
  super.summary(lineLength, positions, printFn);
66
- (this.layers[0] as tf.LayersModel).summary(lineLength, positions, printFn);
14
+ this.layers[0].summary(lineLength, positions, printFn);
67
15
  }
68
16
  }
69
-
70
-
71
- export function createUNet({ filters, depth, units, activation, residual = false, batchNorm = false, inputShape = [null, null, 3] }: UNetModelArgs) {
17
+ export function createUNet({ filters, depth, units, activation, residual = false, batchNorm = false, inputShape = [null, null, 3] }) {
72
18
  if (units < 1) {
73
19
  throw Error(`createUNet: units should be >= 1, got ${units}`);
74
20
  }
75
-
76
21
  const [image_height, image_width] = inputShape;
77
22
  const divisble_by = 2 ** depth;
78
-
79
23
  if ((image_height != null && image_height % divisble_by != 0) ||
80
24
  image_width != null && image_width % divisble_by != 0) {
81
- throw Error(`createUNet: the input height and width must be divisible by 2^depth (${divisble_by})`)
25
+ throw Error(`createUNet: the input height and width must be divisible by 2^depth (${divisble_by})`);
82
26
  }
83
-
84
27
  const input = tf.input({ shape: inputShape });
85
-
86
- const skip_connection: tf.SymbolicTensor[] = [];
87
-
28
+ const skip_connection = [];
88
29
  let x = input;
89
-
90
30
  // calculate the filter sizes for each level
91
31
  const filter_sizes = Array.from({ length: depth }, (_, i) => filters * (2 ** i));
92
-
93
32
  for (const filter_size of filter_sizes) {
94
33
  const contraction = contractionBlock(x, filter_size, residual, batchNorm, `contraction-f${filter_size}`);
95
-
96
- x = tf.layers.maxPooling2d({ poolSize: 2, strides: 2, name: `pool-f${filter_size}` }).apply(contraction) as tf.SymbolicTensor;
34
+ x = tf.layers.maxPooling2d({ poolSize: 2, strides: 2, name: `pool-f${filter_size}` }).apply(contraction);
97
35
  skip_connection.push(contraction);
98
36
  }
99
-
100
- x = contractionBlock(x, filter_sizes.at(-1)! * 2, residual, batchNorm, "bottleneck");
101
-
37
+ x = contractionBlock(x, filter_sizes.at(-1) * 2, residual, batchNorm, "bottleneck");
102
38
  for (let i = filter_sizes.length - 1; i >= 0; i--) {
103
39
  x = expansionBlock(x, skip_connection[i], filter_sizes[i], residual, batchNorm, `expansion-f${filter_sizes[i]}`);
104
40
  }
105
-
106
41
  const output = tf.layers.conv2d({
107
42
  filters: units,
108
43
  kernelSize: 1,
109
44
  padding: "same",
110
45
  activation: activation ?? (units == 1 ? "sigmoid" : "softmax"),
111
46
  name: "output-conv"
112
- }).apply(x) as tf.SymbolicTensor;
113
-
47
+ }).apply(x);
114
48
  return tf.model({ inputs: input, outputs: output, name: "u_net" });
115
49
  }
116
-
117
-
118
- export async function loadUNetModel(pathOrIOHandler: string | tf.io.IOHandler, options?: tf.io.LoadOptions) {
50
+ export async function loadUNetModel(pathOrIOHandler, options) {
119
51
  const model = await tf.loadLayersModel(pathOrIOHandler, options);
120
52
  const unet = createUNet({ depth: 1, filters: 4, units: 1 }); // these are dummy args that are overwritten
121
53
  const { name, ...rest } = model;
122
54
  Object.assign(unet, rest);
123
-
124
55
  return unet;
125
56
  }
126
-
127
-
128
57
  /**
129
58
  * The contraction block of a U-Net
130
- *
59
+ *
131
60
  * Conv > BN > ReLU > Conv > BN + residual > ReLU
132
- *
61
+ *
133
62
  * TODO: for residual, change order to (BN > ReLU > Conv)x2 + residual
134
- *
63
+ *
135
64
  * @param x a previous layer's symbolic output
136
65
  * @param filters the number of filters, usually half the previous expansion block's
137
66
  * @param residual includes a residual connection
138
67
  * @param batchNorm applies batch normalization before ReLU activation
139
68
  * @param name a unique name for the contraction block
140
69
  */
141
- function contractionBlock(x: tf.SymbolicTensor, filters: number, residual: boolean, batchNorm: boolean, name: string) {
142
-
70
+ function contractionBlock(x, filters, residual, batchNorm, name) {
143
71
  const conv1 = tf.layers.conv2d({
144
72
  filters,
145
73
  kernelSize: 3,
@@ -149,7 +77,6 @@ function contractionBlock(x: tf.SymbolicTensor, filters: number, residual: boole
149
77
  name: `${name}-1-conv2d`
150
78
  });
151
79
  const relu1 = tf.layers.reLU({ name: `${name}-1-relu` });
152
-
153
80
  const conv2 = tf.layers.conv2d({
154
81
  filters,
155
82
  kernelSize: 3,
@@ -159,24 +86,17 @@ function contractionBlock(x: tf.SymbolicTensor, filters: number, residual: boole
159
86
  name: `${name}-2-conv2d`
160
87
  });
161
88
  const relu2 = tf.layers.reLU({ name: `${name}-2-relu` });
162
-
163
89
  let forward = conv1.apply(x);
164
-
165
90
  if (batchNorm) {
166
91
  forward = tf.layers.batchNormalization({ name: `${name}-1-batchnorm` }).apply(forward);
167
92
  }
168
-
169
93
  forward = relu1.apply(forward);
170
-
171
94
  forward = conv2.apply(forward);
172
-
173
95
  if (batchNorm) {
174
96
  forward = tf.layers.batchNormalization({ name: `${name}-2-batchnorm` }).apply(forward);
175
97
  }
176
-
177
98
  if (residual) {
178
99
  let residual_skip = x;
179
-
180
100
  if (x.shape[x.shape.length - 1] != filters) {
181
101
  // a 1x1 convolution on the input to ensure the residual connection's
182
102
  // channels/filters dim matches the convolution output
@@ -187,32 +107,26 @@ function contractionBlock(x: tf.SymbolicTensor, filters: number, residual: boole
187
107
  useBias: !batchNorm,
188
108
  kernelInitializer: "heNormal",
189
109
  name: `${name}-residual`
190
- }).apply(x) as tf.SymbolicTensor;
110
+ }).apply(x);
191
111
  }
192
-
193
112
  if (batchNorm) {
194
113
  residual_skip = tf.layers.batchNormalization({
195
114
  name: `${name}-residual-batchnorm`
196
- }).apply(residual_skip) as tf.SymbolicTensor;
115
+ }).apply(residual_skip);
197
116
  }
198
-
199
117
  forward = tf.layers.add().apply([
200
- residual_skip as tf.SymbolicTensor,
201
- forward as tf.SymbolicTensor
202
- ])
118
+ residual_skip,
119
+ forward
120
+ ]);
203
121
  }
204
-
205
122
  forward = relu2.apply(forward);
206
-
207
- return forward as tf.SymbolicTensor;
123
+ return forward;
208
124
  }
209
-
210
-
211
125
  /**
212
126
  * The expansion block of a U-Net
213
- *
127
+ *
214
128
  * Upconv + skip > contraction block
215
- *
129
+ *
216
130
  * @param x a previous layer's symbolic output
217
131
  * @param skip the corresponding contraction block's output (before pool), shape matches `x`
218
132
  * @param filters the number of filters, usually half the previous expansion block's
@@ -220,8 +134,7 @@ function contractionBlock(x: tf.SymbolicTensor, filters: number, residual: boole
220
134
  * @param batchNorm apply batch normalization, should be `false` when batch size is `1`
221
135
  * @param name a unique name for the contraction block
222
136
  */
223
- function expansionBlock(x: tf.SymbolicTensor, skip: tf.SymbolicTensor, filters: number, residual: boolean, batchNorm: boolean, name: string) {
224
-
137
+ function expansionBlock(x, skip, filters, residual, batchNorm, name) {
225
138
  const upconv = tf.layers.conv2dTranspose({
226
139
  filters,
227
140
  padding: "same",
@@ -230,11 +143,9 @@ function expansionBlock(x: tf.SymbolicTensor, skip: tf.SymbolicTensor, filters:
230
143
  kernelInitializer: "heNormal",
231
144
  name: `${name}-upconv`
232
145
  });
233
-
234
146
  const concat = tf.layers.concatenate({ axis: -1, name: `${name}-concat-upconv-skip` });
235
-
236
- let forward = upconv.apply(x) as tf.SymbolicTensor;
237
- forward = concat.apply([forward, skip]) as tf.SymbolicTensor;
238
-
147
+ let forward = upconv.apply(x);
148
+ forward = concat.apply([forward, skip]);
239
149
  return contractionBlock(forward, filters, residual, batchNorm, name);
240
150
  }
151
+ //# sourceMappingURL=u_net.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"u_net.js","sourceRoot":"","sources":["../../src/models/u_net.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAuCvC,MAAM,OAAO,SAAU,SAAQ,EAAE,CAAC,UAAU;IAExC,YAAY,IAAmB;QAC3B,MAAM,EACF,OAAO,EACP,KAAK,EACL,UAAU,GAAG,KAAK,IAAI,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,SAAS,EAC/C,KAAK,EACL,QAAQ,GAAG,KAAK,EAChB,SAAS,GAAG,KAAK,EACjB,UAAU,GAAG,CAAC,IAAI,EAAE,IAAI,EAAE,CAAC,CAAC,EAC5B,GAAG,cAAc,EACpB,GAAG,IAAI,CAAC;QAET,cAAc,CAAC,IAAI,GAAG,cAAc,CAAC,IAAI,IAAI,YAAY,CAAC;QAE1D,KAAK,CAAC;YACF,GAAG,cAAc;YACjB,gEAAgE;YAChE,MAAM,EAAE,CAAC,UAAU,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,UAAU,EAAE,KAAK,EAAE,QAAQ,EAAE,SAAS,EAAE,UAAU,EAAE,CAAC,CAAC;SAC/F,CAAC,CAAC;IACP,CAAC;IAGQ,OAAO,CAAC,UAAmB,EAAE,SAAoB,EAAE,OAA2D;QACnH,KAAK,CAAC,OAAO,CAAC,UAAU,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;QAC7C,IAAI,CAAC,MAAM,CAAC,CAAC,CAAoB,CAAC,OAAO,CAAC,UAAU,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;IAC/E,CAAC;CACJ;AAGD,MAAM,UAAU,UAAU,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,KAAK,EAAE,UAAU,EAAE,QAAQ,GAAG,KAAK,EAAE,SAAS,GAAG,KAAK,EAAE,UAAU,GAAG,CAAC,IAAI,EAAE,IAAI,EAAE,CAAC,CAAC,EAAiB;IAC9I,IAAI,KAAK,GAAG,CAAC,EAAE,CAAC;QACZ,MAAM,KAAK,CAAC,yCAAyC,KAAK,EAAE,CAAC,CAAC;IAClE,CAAC;IAED,MAAM,CAAC,YAAY,EAAE,WAAW,CAAC,GAAG,UAAU,CAAC;IAC/C,MAAM,WAAW,GAAG,CAAC,IAAI,KAAK,CAAC;IAE/B,IAAI,CAAC,YAAY,IAAI,IAAI,IAAI,YAAY,GAAG,WAAW,IAAI,CAAC,CAAC;QACzD,WAAW,IAAI,IAAI,IAAI,WAAW,GAAG,WAAW,IAAI,CAAC,EAAE,CAAC;QACxD,MAAM,KAAK,CAAC,wEAAwE,WAAW,GAAG,CAAC,CAAA;IACvG,CAAC;IAED,MAAM,KAAK,GAAG,EAAE,CAAC,KAAK,CAAC,EAAE,KAAK,EAAE,UAAU,EAAE,CAAC,CAAC;IAE9C,MAAM,eAAe,GAAwB,EAAE,CAAC;IAEhD,IAAI,CAAC,GAAG,KAAK,CAAC;IAEd,4CAA4C;IAC5C,MAAM,YAAY,GAAG,KAAK,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,OAAO,GAAG,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAEjF,KAAK,MAAM,WAAW,IAAI,YAAY,EAAE,CAAC;QACrC,MAAM,WAAW,GAAG,gBAAgB,CAAC,CAAC,EAAE,WAAW,EAAE,QAAQ,EAAE,SAAS,EAAE,gBAAgB,WAAW,EAAE,CAAC,CAAC;QAEzG,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,YAAY,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,OAAO,EAAE,CAAC,EAAE,IAAI,EAAE,SAAS,WAAW,EAAE,EAAE,CAAC,CAA
C,KAAK,CAAC,WAAW,CAAsB,CAAC;QAC9H,eAAe,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC;IACtC,CAAC;IAED,CAAC,GAAG,gBAAgB,CAAC,CAAC,EAAE,YAAY,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,GAAG,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,YAAY,CAAC,CAAC;IAErF,KAAK,IAAI,CAAC,GAAG,YAAY,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;QAChD,CAAC,GAAG,cAAc,CAAC,CAAC,EAAE,eAAe,CAAC,CAAC,CAAC,EAAE,YAAY,CAAC,CAAC,CAAC,EAAE,QAAQ,EAAE,SAAS,EAAE,cAAc,YAAY,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC;IACrH,CAAC;IAED,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;QAC5B,OAAO,EAAE,KAAK;QACd,UAAU,EAAE,CAAC;QACb,OAAO,EAAE,MAAM;QACf,UAAU,EAAE,UAAU,IAAI,CAAC,KAAK,IAAI,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,SAAS,CAAC;QAC9D,IAAI,EAAE,aAAa;KACtB,CAAC,CAAC,KAAK,CAAC,CAAC,CAAsB,CAAC;IAEjC,OAAO,EAAE,CAAC,KAAK,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,OAAO,EAAE,MAAM,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAC;AACvE,CAAC;AAGD,MAAM,CAAC,KAAK,UAAU,aAAa,CAAC,eAAyC,EAAE,OAA2B;IACtG,MAAM,KAAK,GAAG,MAAM,EAAE,CAAC,eAAe,CAAC,eAAe,EAAE,OAAO,CAAC,CAAC;IACjE,MAAM,IAAI,GAAG,UAAU,CAAC,EAAE,KAAK,EAAE,CAAC,EAAE,OAAO,EAAE,CAAC,EAAE,KAAK,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,4CAA4C;IACzG,MAAM,EAAE,IAAI,EAAE,GAAG,IAAI,EAAE,GAAG,KAAK,CAAC;IAChC,MAAM,CAAC,MAAM,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC;IAE1B,OAAO,IAAI,CAAC;AAChB,CAAC;AAGD;;;;;;;;;;;;GAYG;AACH,SAAS,gBAAgB,CAAC,CAAoB,EAAE,OAAe,EAAE,QAAiB,EAAE,SAAkB,EAAE,IAAY;IAEhH,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;QAC3B,OAAO;QACP,UAAU,EAAE,CAAC;QACb,OAAO,EAAE,MAAM;QACf,OAAO,EAAE,CAAC,SAAS;QACnB,iBAAiB,EAAE,UAAU;QAC7B,IAAI,EAAE,GAAG,IAAI,WAAW;KAC3B,CAAC,CAAC;IACH,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC,IAAI,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,SAAS,EAAE,CAAC,CAAC;IAEzD,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;QAC3B,OAAO;QACP,UAAU,EAAE,CAAC;QACb,OAAO,EAAE,MAAM;QACf,OAAO,EAAE,CAAC,SAAS;QACnB,iBAAiB,EAAE,UAAU;QAC7B,IAAI,EAAE,GAAG,IAAI,WAAW;KAC3B,CAAC,CAAC;IACH,MAAM,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC,IAAI,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,SAAS,EAAE,CAAC,CAAC;IAEzD,IAAI,OAAO,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IAE7B,IAAI,SAAS,EAAE,CAAC;QACZ,OAAO,GAAG,EAAE,CAAC,MA
AM,CAAC,kBAAkB,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,cAAc,EAAE,CAAC,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAC3F,CAAC;IAED,OAAO,GAAG,KAAK,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAE/B,OAAO,GAAG,KAAK,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAE/B,IAAI,SAAS,EAAE,CAAC;QACZ,OAAO,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,cAAc,EAAE,CAAC,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAC3F,CAAC;IAED,IAAI,QAAQ,EAAE,CAAC;QACX,IAAI,aAAa,GAAG,CAAC,CAAC;QAEtB,IAAI,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,IAAI,OAAO,EAAE,CAAC;YACzC,qEAAqE;YACrE,sDAAsD;YACtD,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;gBAC7B,OAAO;gBACP,UAAU,EAAE,CAAC;gBACb,OAAO,EAAE,MAAM;gBACf,OAAO,EAAE,CAAC,SAAS;gBACnB,iBAAiB,EAAE,UAAU;gBAC7B,IAAI,EAAE,GAAG,IAAI,WAAW;aAC3B,CAAC,CAAC,KAAK,CAAC,CAAC,CAAsB,CAAC;QACrC,CAAC;QAED,IAAI,SAAS,EAAE,CAAC;YACZ,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC;gBACzC,IAAI,EAAE,GAAG,IAAI,qBAAqB;aACrC,CAAC,CAAC,KAAK,CAAC,aAAa,CAAsB,CAAC;QACjD,CAAC;QAED,OAAO,GAAG,EAAE,CAAC,MAAM,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC;YAC5B,aAAkC;YAClC,OAA4B;SAC/B,CAAC,CAAA;IACN,CAAC;IAED,OAAO,GAAG,KAAK,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;IAE/B,OAAO,OAA4B,CAAC;AACxC,CAAC;AAGD;;;;;;;;;;;GAWG;AACH,SAAS,cAAc,CAAC,CAAoB,EAAE,IAAuB,EAAE,OAAe,EAAE,QAAiB,EAAE,SAAkB,EAAE,IAAY;IAEvI,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,eAAe,CAAC;QACrC,OAAO;QACP,OAAO,EAAE,MAAM;QACf,UAAU,EAAE,CAAC;QACb,OAAO,EAAE,CAAC;QACV,iBAAiB,EAAE,UAAU;QAC7B,IAAI,EAAE,GAAG,IAAI,SAAS;KACzB,CAAC,CAAC;IAEH,MAAM,MAAM,GAAG,EAAE,CAAC,MAAM,CAAC,WAAW,CAAC,EAAE,IAAI,EAAE,CAAC,CAAC,EAAE,IAAI,EAAE,GAAG,IAAI,qBAAqB,EAAE,CAAC,CAAC;IAEvF,IAAI,OAAO,GAAG,MAAM,CAAC,KAAK,CAAC,CAAC,CAAsB,CAAC;IACnD,OAAO,GAAG,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,EAAE,IAAI,CAAC,CAAsB,CAAC;IAE7D,OAAO,gBAAgB,CAAC,OAAO,EAAE,OAAO,EAAE,QAAQ,EAAE,SAAS,EAAE,IAAI,CAAC,CAAC;AACzE,CAAC"}
@@ -0,0 +1,6 @@
1
+ export * as layers from "@/layers";
2
+ export * as models from "@/models";
3
+ export * as losses from "@/losses";
4
+ export * from "@/kv_cache";
5
+ export * from "@/metrics";
6
+ //# sourceMappingURL=index.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,cAAc,YAAY,CAAC;AAC3B,cAAc,WAAW,CAAC"}
@@ -0,0 +1,6 @@
1
+ export * as layers from "@/layers";
2
+ export * as models from "@/models";
3
+ export * as losses from "@/losses";
4
+ export * from "@/kv_cache";
5
+ export * from "@/metrics";
6
+ //# sourceMappingURL=index.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"index.js","sourceRoot":"","sources":["../../src/index.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AACnC,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAEnC,cAAc,YAAY,CAAC;AAC3B,cAAc,WAAW,CAAC"}
@@ -0,0 +1,53 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
/** Constructor arguments for a single {@link KvCache}. */
export interface KvCacheArgs {
    /** Batch size the cache is allocated for (first dimension of each cache tensor). */
    batchSize: number;
    /** Number of sequence positions the cache can hold (third dimension). */
    maxSequenceLength: number;
    /** Number of key/value heads (second dimension). */
    numHeads: number;
    /** Size of each head (fourth dimension). */
    headDim: number;
    /** Data type of the cache tensors; the implementation defaults it to `"float32"`. */
    dtype?: tf.DataType;
}
/**
 * A container of KV caches registered under string ids. Every cache in the
 * container is allocated with the container's maximum sequence length.
 */
export declare class KvCacheContainer {
    /** Caches keyed by the id passed to `create()`. */
    protected caches: Map<string, KvCache>;
    /** Maximum sequence length shared by every cache in the container. */
    protected max_sequence_length: number;
    /**
     * @param maxSequenceLength maximum number of positions each cache can hold.
     * @throws Error when `maxSequenceLength` is 0 or otherwise missing.
     */
    constructor(maxSequenceLength: number);
    /**
     * Allocates a cache under `id` using this container's maximum sequence length,
     * replacing any cache previously registered under the same id.
     */
    create(id: string, args: Omit<KvCacheArgs, "maxSequenceLength">): void;
    /**
     * Appends `key`/`value` to the cache registered under `id`.
     *
     * The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
     *
     * @returns the cached keys/values for the positions filled so far, or `undefined`
     *   when no cache was created for `id`.
     *   NOTE(review): the implementation returns slices of the cache variables
     *   (plain tensors owned by the caller), not the variables themselves; the
     *   `tf.Variable` typing is inherited from tfjs' `slice` signature — confirm
     *   before relying on any Variable-only member.
     */
    update(id: string, key: tf.Tensor4D, value: tf.Tensor4D): {
        keyCache: tf.Variable<tf.Rank.R4>;
        valueCache: tf.Variable<tf.Rank.R4>;
    } | undefined;
    /** Resets every cache in the container to an empty state. */
    reset(): void;
    /** Disposes the tensors of every cache in the container. */
    dispose(): void;
    /**
     * Number of positions filled so far, read from the first registered cache
     * (all caches are expected to advance together); 0 when no cache exists.
     */
    get size(): number;
    /** Maximum sequence length shared by every cache in the container. */
    get maxSequenceLength(): number;
}
/**
 * A fixed-capacity key/value cache backed by two non-trainable variables of
 * shape `[batchSize, numHeads, maxSequenceLength, headDim]`, filled from
 * position 0 onwards by successive `update()` calls.
 */
export declare class KvCache {
    protected key_cache: tf.Variable<tf.Rank.R4>;
    protected value_cache: tf.Variable<tf.Rank.R4>;
    /** Next free sequence position; equals the number of cached tokens. */
    protected current_position: number;
    protected batch_size: number;
    protected max_sequence_length: number;
    protected num_kv_heads: number;
    protected head_dim: number;
    constructor({ batchSize, maxSequenceLength, numHeads, headDim, dtype }: KvCacheArgs);
    /**
     * Writes `key`/`value` at the next free positions and advances the write position.
     *
     * The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
     *
     * @returns the full, un-sliced cache variables.
     * @throws Error when the input batch does not fit the cache, or when the write
     *   would exceed the maximum sequence length.
     */
    update(key: tf.Tensor4D, value: tf.Tensor4D): {
        keyCache: tf.Variable<tf.Rank.R4>;
        valueCache: tf.Variable<tf.Rank.R4>;
    };
    /**
     * Returns a new tensor equal to `current_cache` with `new_value` written along
     * the sequence axis starting at the current position.
     */
    protected mergeIntoCache(new_value: tf.Tensor4D, current_cache: tf.Tensor4D): tf.Tensor4D;
    /** Rewinds the write position to 0 and zeroes both caches. */
    reset(): void;
    /** Disposes both cache variables. */
    dispose(): void;
    /**
     * The size of the KV cache, also the number of tokens since the first one.
     */
    get size(): number;
}
53
+ //# sourceMappingURL=kv_cache.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"kv_cache.d.ts","sourceRoot":"","sources":["../../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC,MAAM,WAAW,WAAW;IACxB,SAAS,EAAE,MAAM,CAAC;IAClB,iBAAiB,EAAE,MAAM,CAAC;IAC1B,QAAQ,EAAE,MAAM,CAAC;IACjB,OAAO,EAAE,MAAM,CAAC;IAChB,KAAK,CAAC,EAAE,EAAE,CAAC,QAAQ,CAAA;CACtB;AAGD;;GAEG;AACH,qBAAa,gBAAgB;IACzB,SAAS,CAAC,MAAM,uBAA8B;IAC9C,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;gBAG1B,iBAAiB,EAAE,MAAM;IAS9B,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,IAAI,EAAE,IAAI,CAAC,WAAW,EAAE,mBAAmB,CAAC;IAUtE;;OAEG;IACI,MAAM,CAAC,EAAE,EAAE,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IA4BvD,KAAK;IAOL,OAAO;IAOd,IAAW,IAAI,WAGd;IAGD,IAAW,iBAAiB,WAE3B;CACJ;AAGD,qBAAa,OAAO;IAEhB,SAAS,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;IAC7C,SAAS,CAAC,WAAW,EAAE,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAA;IAG9C,SAAS,CAAC,gBAAgB,EAAE,MAAM,CAAK;IAEvC,SAAS,CAAC,UAAU,EAAE,MAAM,CAAC;IAC7B,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;IACtC,SAAS,CAAC,YAAY,EAAE,MAAM,CAAC;IAC/B,SAAS,CAAC,QAAQ,EAAE,MAAM,CAAC;gBAEf,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAiB,EAAE,EAAE,WAAW;IAa/F;;OAEG;IACI,MAAM,CAAC,GAAG,EAAE,EAAE,CAAC,QAAQ,EAAE,KAAK,EAAE,EAAE,CAAC,QAAQ;;;;IAgClD,SAAS,CAAC,cAAc,CAAC,SAAS,EAAE,EAAE,CAAC,QAAQ,EAAE,aAAa,EAAE,EAAE,CAAC,QAAQ;IAqBpE,KAAK,IAAI,IAAI;IAab,OAAO,IAAI,IAAI;IAMtB;;OAEG;IACH,IAAI,IAAI,IAAI,MAAM,CAEjB;CAEJ"}
@@ -0,0 +1,135 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
/**
 * A container of KV caches registered under string ids. Every cache in the
 * container is allocated with the container's maximum sequence length, so the
 * caches can be advanced together and queried through a single `size`.
 */
export class KvCacheContainer {
    // caches keyed by the id passed to `create()`
    caches = new Map();
    // maximum sequence length shared by every cache in the container
    max_sequence_length;
    /**
     * @param maxSequenceLength maximum number of positions each cache can hold.
     * @throws Error when `maxSequenceLength` is not a number greater than 0
     *   (0, negative values, NaN, null and undefined are all rejected).
     */
    constructor(maxSequenceLength) {
        // `!(x > 0)` rejects 0, negatives, NaN, null and undefined alike; a plain
        // falsy check would let negative lengths through despite the message below.
        if (!(maxSequenceLength > 0)) {
            throw Error(`KvCacheContainer: expected KV cache maximum sequence length to be greater than 0, got: ${String(maxSequenceLength)}`);
        }
        this.max_sequence_length = maxSequenceLength;
    }
    /**
     * Allocates a new cache under `id`, sized with this container's maximum
     * sequence length.
     *
     * If a cache is already registered under `id`, it is disposed and replaced;
     * otherwise its variables would be unreachable from the container and never freed.
     */
    create(id, args) {
        // build the replacement first so a failing constructor leaves the old cache intact
        const new_cache = new KvCache({
            ...args,
            maxSequenceLength: this.max_sequence_length
        });
        this.caches.get(id)?.dispose();
        this.caches.set(id, new_cache);
    }
    /**
     * Appends `key`/`value` to the cache registered under `id`.
     *
     * The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
     *
     * @returns slices of the updated caches covering only the filled positions
     *   `[0, size)`. These are new tensors the caller is responsible for disposing.
     *   Returns `undefined` when no cache was created for `id`.
     */
    update(id, key, value) {
        const kv_cache = this.caches.get(id);
        if (!kv_cache) {
            return undefined;
        }
        const { keyCache, valueCache } = kv_cache.update(key, value);
        // slicing to get only the past key and value projections, but normally
        // in TensorFlow and PyTorch the full cache is returned and masked for
        // graph purposes
        return tf.tidy(() => {
            const k_cache = keyCache.slice([0, 0, 0, 0], [keyCache.shape[0], keyCache.shape[1], kv_cache.size, keyCache.shape[3]]);
            const v_cache = valueCache.slice([0, 0, 0, 0], [valueCache.shape[0], valueCache.shape[1], kv_cache.size, valueCache.shape[3]]);
            return {
                keyCache: k_cache,
                valueCache: v_cache
            };
        });
    }
    /** Resets every cache in the container to an empty state. */
    reset() {
        for (const cache of this.caches.values()) {
            cache.reset();
        }
    }
    /** Disposes the tensors of every cache in the container. */
    dispose() {
        for (const cache of this.caches.values()) {
            cache.dispose();
        }
    }
    /**
     * Number of positions filled so far. All caches are expected to advance
     * together, so the first registered cache is used; 0 when no cache exists.
     */
    get size() {
        return this.caches.values().next().value?.size ?? 0;
    }
    /** Maximum sequence length shared by every cache in the container. */
    get maxSequenceLength() {
        return this.max_sequence_length;
    }
}
60
/**
 * A fixed-capacity key/value cache backed by two non-trainable variables of
 * shape `[batchSize, numHeads, maxSequenceLength, headDim]`. Successive calls to
 * `update()` fill it along the sequence axis starting at position 0.
 */
export class KvCache {
    key_cache;
    value_cache;
    // the size of the KV cache, represents the number of tokens since the first chat token
    current_position = 0;
    batch_size;
    max_sequence_length;
    num_kv_heads;
    head_dim;
    /**
     * Allocates zero-filled key and value caches.
     *
     * @param args.batchSize first dimension of the caches.
     * @param args.maxSequenceLength third (sequence) dimension of the caches.
     * @param args.numHeads second dimension of the caches.
     * @param args.headDim fourth dimension of the caches.
     * @param args.dtype cache dtype, `"float32"` by default.
     */
    constructor({ batchSize, maxSequenceLength, numHeads, headDim, dtype = "float32" }) {
        const cache_shape = [batchSize, numHeads, maxSequenceLength, headDim];
        // tf.tidy disposes the temporary zeros tensors (variables are not cleaned up
        // by tidy), so only the variables keep the initial buffers alive and they are
        // freed on the first assign() / dispose() instead of leaking.
        this.key_cache = tf.tidy(() => tf.variable(tf.zeros(cache_shape, dtype), false));
        this.value_cache = tf.tidy(() => tf.variable(tf.zeros(cache_shape, dtype), false));
        this.batch_size = batchSize;
        this.max_sequence_length = maxSequenceLength;
        this.num_kv_heads = numHeads;
        this.head_dim = headDim;
    }
    /**
     * Writes `key`/`value` at the next free positions and advances the write position.
     *
     * The key and value tensors should have the shape (post head split, etc) `[batch, heads, seq, head_dim]`
     *
     * @returns the full (un-sliced) cache variables, including positions not filled yet.
     * @throws Error when the key batch size differs from the cache's, when the key
     *   heads/head size don't match the cache layout, when value and key shapes
     *   differ, or when the write would exceed the maximum sequence length.
     */
    update(key, value) {
        const [batch_size, num_heads, seq_len, head_dim] = key.shape;
        // mergeIntoCache concatenates slices of the full cache batch with the new
        // tensor, so any other batch size (smaller included) can only fail later
        // inside tf.concat with an opaque shape error; reject it here instead.
        if (batch_size !== this.batch_size) {
            throw Error(`The current KV cache has been set up with a batch size of` +
                ` ${this.key_cache.shape[0]}, but found new key tensors with batch size ${batch_size}`);
        }
        if (num_heads !== this.num_kv_heads || head_dim !== this.head_dim) {
            throw Error(`KvCache: expected key tensors of shape [${this.batch_size}, ${this.num_kv_heads}, seq, ${this.head_dim}], got [${key.shape.join(", ")}]`);
        }
        // both caches advance by the key's sequence length, so value must line up
        // with key exactly; checking up front avoids updating the key cache and
        // then failing on the value cache
        if (value.shape.length !== key.shape.length || value.shape.some((dim, i) => dim !== key.shape[i])) {
            throw Error(`KvCache: expected value tensors to have the same shape as key tensors [${key.shape.join(", ")}], got [${value.shape.join(", ")}]`);
        }
        if (this.current_position + seq_len > this.max_sequence_length) {
            throw Error(`The KV cache has exceeded its maximum sequence length of ${this.max_sequence_length}. Use a larger value.`);
        }
        const new_key_cache = this.mergeIntoCache(key, this.key_cache);
        const new_value_cache = this.mergeIntoCache(value, this.value_cache);
        // assign() takes its own reference to the data, so the merged tensors can be released
        this.key_cache.assign(new_key_cache);
        this.value_cache.assign(new_value_cache);
        new_key_cache.dispose();
        new_value_cache.dispose();
        // advance the pointer to reflect the updated cache's current
        this.current_position += seq_len;
        return {
            keyCache: this.key_cache,
            valueCache: this.value_cache,
        };
    }
    /**
     * Returns a new tensor equal to `current_cache` with `new_value` written along
     * the sequence axis at `[current_position, current_position + seq)`.
     */
    mergeIntoCache(new_value, current_cache) {
        const seq_len = new_value.shape[2];
        return tf.tidy(() => {
            const historical = current_cache.slice([0, 0, 0, 0], [this.batch_size, this.num_kv_heads, this.current_position, this.head_dim]);
            const future = current_cache.slice([0, 0, this.current_position + seq_len, 0], [this.batch_size, this.num_kv_heads, this.max_sequence_length - this.current_position - seq_len, this.head_dim]);
            // merge the new tensor into the current cache to create a new, larger, cache,
            // this is different from Python implementations because TFJS tensors are immutable,
            // because we cannot update a slice, we must slice and concat
            return tf.concat([historical, new_value, future], 2);
        });
    }
    /** Rewinds the write position to 0 and zeroes both caches. */
    reset() {
        this.current_position = 0;
        // zerosLike keeps each cache's dtype; tf.zeros(shape) would default to
        // float32 and make assign() throw for caches created with another dtype
        tf.tidy(() => {
            this.key_cache.assign(tf.zerosLike(this.key_cache));
            this.value_cache.assign(tf.zerosLike(this.value_cache));
        });
    }
    /** Disposes both cache variables. */
    dispose() {
        this.key_cache.dispose();
        this.value_cache.dispose();
    }
    /**
     * The size of the KV cache, also the number of tokens since the first one.
     */
    get size() {
        return this.current_position;
    }
}
135
+ //# sourceMappingURL=kv_cache.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"kv_cache.js","sourceRoot":"","sources":["../../src/kv_cache.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAYvC;;GAEG;AACH,MAAM,OAAO,gBAAgB;IACf,MAAM,GAAG,IAAI,GAAG,EAAmB,CAAC;IACpC,mBAAmB,CAAS;IAGtC,YAAY,iBAAyB;QACjC,IAAI,CAAC,iBAAiB,EAAE,CAAC;YACrB,MAAM,KAAK,CAAC,0FAA0F,MAAM,CAAC,iBAAiB,CAAC,EAAE,CAAC,CAAC;QACvI,CAAC;QAED,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;IACjD,CAAC;IAGM,MAAM,CAAC,EAAU,EAAE,IAA4C;QAClE,MAAM,SAAS,GAAG,IAAI,OAAO,CAAC;YAC1B,GAAG,IAAI;YACP,iBAAiB,EAAE,IAAI,CAAC,mBAAmB;SAC9C,CAAC,CAAC;QAEH,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,EAAE,SAAS,CAAC,CAAC;IACnC,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,EAAU,EAAE,GAAgB,EAAE,KAAkB;QAC1D,MAAM,QAAQ,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;QAErC,IAAI,CAAC,QAAQ,EAAE,CAAC;YACZ,OAAO,SAAS,CAAC;QACrB,CAAC;QAED,MAAM,EAAE,QAAQ,EAAE,UAAU,EAAE,GAAG,QAAQ,CAAC,MAAM,CAAC,GAAG,EAAE,KAAK,CAAC,CAAC;QAE7D,uEAAuE;QACvE,sEAAsE;QACtE,iBAAiB;QACjB,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,OAAO,GAAG,QAAQ,CAAC,KAAK,CAC1B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAC9E,MAAM,OAAO,GAAG,UAAU,CAAC,KAAK,CAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,IAAI,EAAE,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAEpF,OAAO;gBACH,QAAQ,EAAE,OAAO;gBACjB,UAAU,EAAE,OAAO;aACtB,CAAA;QACL,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,KAAK,EAAE,CAAC;QAClB,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YACxB,KAAK,CAAC,OAAO,EAAE,CAAC;QACpB,CAAC,CAAC,CAAA;IACN,CAAC;IAGD,IAAW,IAAI;QACX,gFAAgF;QAChF,OAAO,IAAI,CAAC,MAAM,CAAC,OAAO,EAAE,CAAC,IAAI,EAAE,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,IAAI,CAAC,CAAC;IAC7D,CAAC;IAGD,IAAW,iBAAiB;QACxB,OAAO,IAAI,CAAC,mBAAmB,CAAC;IACpC,CAAC;CACJ;AAGD,MAAM,OAAO,OAAO
;IAEN,SAAS,CAA0B;IACnC,WAAW,CAAyB;IAE9C,uFAAuF;IAC7E,gBAAgB,GAAW,CAAC,CAAC;IAE7B,UAAU,CAAS;IACnB,mBAAmB,CAAS;IAC5B,YAAY,CAAS;IACrB,QAAQ,CAAS;IAE3B,YAAY,EAAE,SAAS,EAAE,iBAAiB,EAAE,QAAQ,EAAE,OAAO,EAAE,KAAK,GAAG,SAAS,EAAe;QAC3F,MAAM,WAAW,GAAG,CAAC,SAAS,EAAE,QAAQ,EAAE,iBAAiB,EAAE,OAAO,CAAqC,CAAC;QAE1G,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAClE,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,KAAK,CAAC,WAAW,EAAE,KAAK,CAAC,EAAE,KAAK,CAAC,CAAC;QAEpE,IAAI,CAAC,UAAU,GAAG,SAAS,CAAC;QAC5B,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;QAC7C,IAAI,CAAC,YAAY,GAAG,QAAQ,CAAC;QAC7B,IAAI,CAAC,QAAQ,GAAG,OAAO,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACI,MAAM,CAAC,GAAgB,EAAE,KAAkB;QAC9C,MAAM,UAAU,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAChC,MAAM,OAAO,GAAG,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAE7B,IAAI,UAAU,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC;YACvC,MAAM,KAAK,CAAC,2DAA2D;gBACnE,IAAI,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,+CAA+C,UAAU,EAAE,CAAC,CAAA;QAC/F,CAAC;QAED,IAAI,IAAI,CAAC,gBAAgB,GAAG,OAAO,GAAG,IAAI,CAAC,mBAAmB,EAAE,CAAC;YAC7D,MAAM,KAAK,CAAC,4DAA4D,IAAI,CAAC,mBAAmB,uBAAuB,CAAC,CAAC;QAC7H,CAAC;QAED,MAAM,aAAa,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;QAC/D,MAAM,eAAe,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC;QAErE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,aAAa,CAAC,CAAC;QACrC,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,eAAe,CAAC,CAAC;QAEzC,aAAa,CAAC,OAAO,EAAE,CAAC;QACxB,eAAe,CAAC,OAAO,EAAE,CAAC;QAE1B,6DAA6D;QAC7D,IAAI,CAAC,gBAAgB,IAAI,OAAO,CAAC;QAEjC,OAAO;YACH,QAAQ,EAAE,IAAI,CAAC,SAAS;YACxB,UAAU,EAAE,IAAI,CAAC,WAAW;SAC/B,CAAA;IACL,CAAC;IAGS,cAAc,CAAC,SAAsB,EAAE,aAA0B;QACvE,MAAM,OAAO,GAAG,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAEnC,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAEhB,MAAM,UAAU,GAAG,aAAa,CAAC,KAAK,CAClC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EACZ,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,gBAAgB,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAEhF,MAAM,MAAM,GAAG,aAAa,CAAC,KAAK,CAC9B,CAAC,CAAC,EAAE,CAAC,EAAE,IAAI,C
AAC,gBAAgB,GAAG,OAAO,EAAE,CAAC,CAAC,EAC1C,CAAC,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,YAAY,EAAE,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,gBAAgB,GAAG,OAAO,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAErH,8EAA8E;YAC9E,qFAAqF;YACrF,6DAA6D;YAC7D,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,UAAU,EAAE,SAAS,EAAE,MAAM,CAAC,EAAE,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAA;IACN,CAAC;IAGM,KAAK;QACR,IAAI,CAAC,gBAAgB,GAAG,CAAC,CAAC;QAE1B,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YACT,MAAM,eAAe,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC;YAC7C,MAAM,iBAAiB,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC;YAEjD,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,eAAe,CAAC,CAAC,CAAC;YACjD,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC,CAAC;QACzD,CAAC,CAAC,CAAC;IACP,CAAC;IAGM,OAAO;QACV,IAAI,CAAC,SAAS,CAAC,OAAO,EAAE,CAAC;QACzB,IAAI,CAAC,WAAW,CAAC,OAAO,EAAE,CAAC;IAC/B,CAAC;IAGD;;OAEG;IACH,IAAI,IAAI;QACJ,OAAO,IAAI,CAAC,gBAAgB,CAAC;IACjC,CAAC;CAEJ"}
@@ -0,0 +1,31 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { KvCacheContainer } from "@/kv_cache";
3
+ import { MultiHeadAttention, type MultiHeadAttentionArgs } from '@/layers/multihead_attention';
4
+ import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
5
/**
 * Multi-head attention with rotary position embeddings (RoPE) and optional KV caching.
 *
 * A cache is handed to the layer on each call through `kwargs.kvCache` in
 * `layer.apply()`. That is why cached operation is meant for a custom training
 * loop rather than the standard fit API.
 *
 * When no KV cache is supplied, the layer behaves as plain RoPE multi-head attention.
 */
export declare class CachedRoPEMultiHeadAttention extends MultiHeadAttention {
    /** Class name of this layer type. */
    static className: string;
    /** Layer that applies the rotary position encoding to the split heads. */
    protected rope: tf.layers.Layer;
    constructor(args: MultiHeadAttentionArgs);
    /**
     * Attention forward pass. `kwargs` is where an optional KV cache container is
     * passed in.
     */
    protected forward(
        query_input: tf.Tensor,
        key_input: tf.Tensor,
        value_input: tf.Tensor,
        packing_mask: tf.Tensor | null,
        causal_mask: tf.Tensor | null,
        kwargs: Kwargs
    ): tf.Tensor;
    /**
     * Combines the freshly split key/value heads with the contents of `kv_container`.
     * NOTE(review): presumably delegates to `KvCacheContainer.update()`; the
     * implementation is not visible here — confirm.
     */
    protected getCachedKV(
        kv_container: KvCacheContainer,
        key_split: tf.Tensor4D,
        value_split: tf.Tensor4D
    ): {
        keyCache: tf.Variable<tf.Rank.R4>;
        valueCache: tf.Variable<tf.Rank.R4>;
    };
    /**
     * Splits query, key and value into heads, then applies the RoPE position
     * encoding immediately after the split.
     */
    protected splitHeads(
        query: tf.Tensor,
        key: tf.Tensor,
        value: tf.Tensor,
        shuffle: number[]
    ): {
        query_split: tf.Tensor4D;
        key_split: tf.Tensor4D;
        value_split: tf.Tensor4D;
    };
}
31
+ //# sourceMappingURL=cached_rope_multihead_attention.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"cached_rope_multihead_attention.d.ts","sourceRoot":"","sources":["../../../src/layers/cached_rope_multihead_attention.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,gBAAgB,EAAE,MAAM,YAAY,CAAC;AAC9C,OAAO,EAAE,kBAAkB,EAAE,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAE/F,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAGjE;;;;;;;GAOG;AACH,qBAAa,4BAA6B,SAAQ,kBAAkB;IAChE,MAAM,CAAC,SAAS,SAAkC;IAElD,SAAS,CAAC,IAAI,EAAE,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;gBAEpB,IAAI,EAAE,sBAAsB;cAMrB,OAAO,CACtB,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,SAAS,EAAE,EAAE,CAAC,MAAM,EACpB,WAAW,EAAE,EAAE,CAAC,MAAM,EACtB,YAAY,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC9B,WAAW,EAAE,EAAE,CAAC,MAAM,GAAG,IAAI,EAC7B,MAAM,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM;IAuC9B,SAAS,CAAC,WAAW,CAAC,YAAY,EAAE,gBAAgB,EAAE,SAAS,EAAE,EAAE,CAAC,QAAQ,EAAE,WAAW,EAAE,EAAE,CAAC,QAAQ;;;;IAqBtG;;OAEG;cACgB,UAAU,CAAC,KAAK,EAAE,EAAE,CAAC,MAAM,EAAE,GAAG,EAAE,EAAE,CAAC,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC,MAAM,EAAE,OAAO,EAAE,MAAM,EAAE;qBAO5D,EAAE,CAAC,QAAQ;mBAEX,EAAE,CAAC,QAAQ;qBACwB,EAAE,CAAC,QAAQ;;CAIxF"}