@stellarapp/tfjs-stellar 1.0.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +47 -0
  3. package/dist/index.d.ts +7 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +7 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/jest.config.d.ts +8 -0
  8. package/dist/jest.config.d.ts.map +1 -0
  9. package/{jest.config.ts → dist/jest.config.js} +8 -64
  10. package/dist/jest.config.js.map +1 -0
  11. package/dist/kv_cache.d.ts +53 -0
  12. package/dist/kv_cache.d.ts.map +1 -0
  13. package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
  14. package/dist/kv_cache.js.map +1 -0
  15. package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
  16. package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  17. package/dist/layers/cached_rope_multihead_attention.js +76 -0
  18. package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
  19. package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  20. package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  21. package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
  22. package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
  23. package/dist/layers/gpt_decoder_block.d.ts +34 -0
  24. package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
  25. package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
  26. package/dist/layers/gpt_decoder_block.js.map +1 -0
  27. package/dist/layers/index.d.ts +17 -0
  28. package/dist/layers/index.d.ts.map +1 -0
  29. package/dist/layers/index.js +33 -0
  30. package/dist/layers/index.js.map +1 -0
  31. package/dist/layers/multihead_attention.d.ts +106 -0
  32. package/dist/layers/multihead_attention.d.ts.map +1 -0
  33. package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
  34. package/dist/layers/multihead_attention.js.map +1 -0
  35. package/dist/layers/multihead_attention.test.d.ts +2 -0
  36. package/dist/layers/multihead_attention.test.d.ts.map +1 -0
  37. package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
  38. package/dist/layers/multihead_attention.test.js.map +1 -0
  39. package/dist/layers/positional_encoding.d.ts +37 -0
  40. package/dist/layers/positional_encoding.d.ts.map +1 -0
  41. package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
  42. package/dist/layers/positional_encoding.js.map +1 -0
  43. package/dist/layers/positional_encoding.test.d.ts +2 -0
  44. package/dist/layers/positional_encoding.test.d.ts.map +1 -0
  45. package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
  46. package/dist/layers/positional_encoding.test.js.map +1 -0
  47. package/dist/layers/rotary_position_embedding.d.ts +39 -0
  48. package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
  49. package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
  50. package/dist/layers/rotary_position_embedding.js.map +1 -0
  51. package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
  52. package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
  53. package/dist/layers/rotary_position_embedding.test.js +88 -0
  54. package/dist/layers/rotary_position_embedding.test.js.map +1 -0
  55. package/dist/layers/token_and_positional_embedding.d.ts +47 -0
  56. package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
  57. package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
  58. package/dist/layers/token_and_positional_embedding.js.map +1 -0
  59. package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
  60. package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  61. package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
  62. package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
  63. package/dist/layers/transformer_decoder.d.ts +69 -0
  64. package/dist/layers/transformer_decoder.d.ts.map +1 -0
  65. package/dist/layers/transformer_decoder.js +182 -0
  66. package/dist/layers/transformer_decoder.js.map +1 -0
  67. package/dist/layers/transformer_decoder.test.d.ts +2 -0
  68. package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
  69. package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
  70. package/dist/layers/transformer_decoder.test.js.map +1 -0
  71. package/dist/layers/transformer_encoder.d.ts +55 -0
  72. package/dist/layers/transformer_encoder.d.ts.map +1 -0
  73. package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
  74. package/dist/layers/transformer_encoder.js.map +1 -0
  75. package/dist/layers/transformer_encoder.test.d.ts +2 -0
  76. package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
  77. package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
  78. package/dist/layers/transformer_encoder.test.js.map +1 -0
  79. package/dist/losses/dice.d.ts +30 -0
  80. package/dist/losses/dice.d.ts.map +1 -0
  81. package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
  82. package/dist/losses/dice.js.map +1 -0
  83. package/dist/losses/index.d.ts +2 -0
  84. package/dist/losses/index.d.ts.map +1 -0
  85. package/dist/losses/index.js +2 -0
  86. package/dist/losses/index.js.map +1 -0
  87. package/dist/masks.d.ts +20 -0
  88. package/dist/masks.d.ts.map +1 -0
  89. package/{src/packing_mask.ts → dist/masks.js} +16 -7
  90. package/dist/masks.js.map +1 -0
  91. package/dist/metrics.d.ts +20 -0
  92. package/dist/metrics.d.ts.map +1 -0
  93. package/{src/metrics.ts → dist/metrics.js} +8 -12
  94. package/dist/metrics.js.map +1 -0
  95. package/dist/models/gpt_model.d.ts +94 -0
  96. package/dist/models/gpt_model.d.ts.map +1 -0
  97. package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
  98. package/dist/models/gpt_model.js.map +1 -0
  99. package/dist/models/index.d.ts +7 -0
  100. package/dist/models/index.d.ts.map +1 -0
  101. package/dist/models/index.js +13 -0
  102. package/dist/models/index.js.map +1 -0
  103. package/dist/models/llm_model.d.ts +87 -0
  104. package/dist/models/llm_model.d.ts.map +1 -0
  105. package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
  106. package/dist/models/llm_model.js.map +1 -0
  107. package/dist/models/u_net.d.ts +40 -0
  108. package/dist/models/u_net.d.ts.map +1 -0
  109. package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
  110. package/dist/models/u_net.js.map +1 -0
  111. package/dist/src/index.d.ts +6 -0
  112. package/dist/src/index.d.ts.map +1 -0
  113. package/dist/src/index.js +6 -0
  114. package/dist/src/index.js.map +1 -0
  115. package/dist/src/kv_cache.d.ts +53 -0
  116. package/dist/src/kv_cache.d.ts.map +1 -0
  117. package/dist/src/kv_cache.js +135 -0
  118. package/dist/src/kv_cache.js.map +1 -0
  119. package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
  120. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  121. package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
  122. package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
  123. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  124. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  125. package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
  126. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
  127. package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
  128. package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
  129. package/dist/src/layers/gpt_decoder_block.js +51 -0
  130. package/dist/src/layers/gpt_decoder_block.js.map +1 -0
  131. package/dist/src/layers/index.d.ts +17 -0
  132. package/dist/src/layers/index.d.ts.map +1 -0
  133. package/dist/src/layers/index.js +33 -0
  134. package/dist/src/layers/index.js.map +1 -0
  135. package/dist/src/layers/multihead_attention.d.ts +106 -0
  136. package/dist/src/layers/multihead_attention.d.ts.map +1 -0
  137. package/dist/src/layers/multihead_attention.js +269 -0
  138. package/dist/src/layers/multihead_attention.js.map +1 -0
  139. package/dist/src/layers/multihead_attention.test.d.ts +2 -0
  140. package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
  141. package/dist/src/layers/multihead_attention.test.js +160 -0
  142. package/dist/src/layers/multihead_attention.test.js.map +1 -0
  143. package/dist/src/layers/positional_encoding.d.ts +37 -0
  144. package/dist/src/layers/positional_encoding.d.ts.map +1 -0
  145. package/dist/src/layers/positional_encoding.js +115 -0
  146. package/dist/src/layers/positional_encoding.js.map +1 -0
  147. package/dist/src/layers/positional_encoding.test.d.ts +2 -0
  148. package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
  149. package/dist/src/layers/positional_encoding.test.js +95 -0
  150. package/dist/src/layers/positional_encoding.test.js.map +1 -0
  151. package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
  152. package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
  153. package/dist/src/layers/rotary_position_embedding.js +99 -0
  154. package/dist/src/layers/rotary_position_embedding.js.map +1 -0
  155. package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
  156. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
  157. package/dist/src/layers/rotary_position_embedding.test.js +88 -0
  158. package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
  159. package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
  160. package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
  161. package/dist/src/layers/token_and_positional_embedding.js +109 -0
  162. package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
  163. package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
  164. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  165. package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
  166. package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
  167. package/dist/src/layers/transformer_decoder.d.ts +69 -0
  168. package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
  169. package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
  170. package/dist/src/layers/transformer_decoder.js.map +1 -0
  171. package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
  172. package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
  173. package/dist/src/layers/transformer_decoder.test.js +72 -0
  174. package/dist/src/layers/transformer_decoder.test.js.map +1 -0
  175. package/dist/src/layers/transformer_encoder.d.ts +55 -0
  176. package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
  177. package/dist/src/layers/transformer_encoder.js +175 -0
  178. package/dist/src/layers/transformer_encoder.js.map +1 -0
  179. package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
  180. package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
  181. package/dist/src/layers/transformer_encoder.test.js +58 -0
  182. package/dist/src/layers/transformer_encoder.test.js.map +1 -0
  183. package/dist/src/losses/dice.d.ts +30 -0
  184. package/dist/src/losses/dice.d.ts.map +1 -0
  185. package/dist/src/losses/dice.js +93 -0
  186. package/dist/src/losses/dice.js.map +1 -0
  187. package/dist/src/losses/index.d.ts +2 -0
  188. package/dist/src/losses/index.d.ts.map +1 -0
  189. package/dist/src/losses/index.js +2 -0
  190. package/dist/src/losses/index.js.map +1 -0
  191. package/dist/src/masks.d.ts +20 -0
  192. package/dist/src/masks.d.ts.map +1 -0
  193. package/dist/src/masks.js +37 -0
  194. package/dist/src/masks.js.map +1 -0
  195. package/dist/src/metrics.d.ts +20 -0
  196. package/dist/src/metrics.d.ts.map +1 -0
  197. package/dist/src/metrics.js +28 -0
  198. package/dist/src/metrics.js.map +1 -0
  199. package/dist/src/models/gpt_model.d.ts +94 -0
  200. package/dist/src/models/gpt_model.d.ts.map +1 -0
  201. package/dist/src/models/gpt_model.js +154 -0
  202. package/dist/src/models/gpt_model.js.map +1 -0
  203. package/dist/src/models/index.d.ts +3 -0
  204. package/dist/src/models/index.d.ts.map +1 -0
  205. package/{src/models/index.ts → dist/src/models/index.js} +1 -0
  206. package/dist/src/models/index.js.map +1 -0
  207. package/dist/src/models/llm_model.d.ts +87 -0
  208. package/dist/src/models/llm_model.d.ts.map +1 -0
  209. package/dist/src/models/llm_model.js +245 -0
  210. package/dist/src/models/llm_model.js.map +1 -0
  211. package/dist/src/models/u_net.d.ts +40 -0
  212. package/dist/src/models/u_net.d.ts.map +1 -0
  213. package/dist/src/models/u_net.js +151 -0
  214. package/dist/src/models/u_net.js.map +1 -0
  215. package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
  216. package/dist/src/tfjs_types.d.ts.map +1 -0
  217. package/dist/src/tfjs_types.js +2 -0
  218. package/dist/src/tfjs_types.js.map +1 -0
  219. package/dist/src/utils.d.ts +28 -0
  220. package/dist/src/utils.d.ts.map +1 -0
  221. package/{src/utils.ts → dist/src/utils.js} +10 -33
  222. package/dist/src/utils.js.map +1 -0
  223. package/dist/src/utils.test.d.ts +2 -0
  224. package/dist/src/utils.test.d.ts.map +1 -0
  225. package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
  226. package/dist/src/utils.test.js.map +1 -0
  227. package/dist/tfjs_types.d.ts +10 -0
  228. package/dist/tfjs_types.d.ts.map +1 -0
  229. package/dist/tfjs_types.js +2 -0
  230. package/dist/tfjs_types.js.map +1 -0
  231. package/dist/utils.d.ts +28 -0
  232. package/dist/utils.d.ts.map +1 -0
  233. package/dist/utils.js +63 -0
  234. package/dist/utils.js.map +1 -0
  235. package/dist/utils.test.d.ts +2 -0
  236. package/dist/utils.test.d.ts.map +1 -0
  237. package/dist/utils.test.js +73 -0
  238. package/dist/utils.test.js.map +1 -0
  239. package/package.json +10 -4
  240. package/src/index.ts +0 -93
  241. package/src/layers/rotary_position_embedding.test.ts +0 -107
  242. package/src/losses/index.ts +0 -1
  243. package/src/testing.ts +0 -1
  244. package/tsconfig.json +0 -49
@@ -1,156 +1,93 @@
1
1
  import * as tf from "@tensorflow/tfjs";
2
2
  import { categoricalCrossentropy, binaryCrossentropy } from "@tensorflow/tfjs-layers/dist/losses";
3
-
4
3
  const epsilon = 1e-7;
5
-
6
4
  const REDUCE_HW = [1, 2]; // reduce over width and height
7
5
  const REDUCE_BHW = [0, 1, 2]; // reduce over batch, width, height
8
6
  const REDUCE_BHWC = [0, 1, 2, 3]; // reduce all dimensions
9
-
10
-
11
7
  // Standard (Sorensen) Dice Loss
12
- export function diceBinaryStandard(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
13
-
8
+ export function diceBinaryStandard(y_true, y_pred) {
14
9
  const y_true_flat = tf.reshape(y_true, [y_true.shape[0], -1]);
15
10
  const y_pred_flat = tf.reshape(y_pred, [y_pred.shape[0], -1]);
16
-
17
11
  const intersection = tf.sum(tf.mul(y_true_flat, y_pred_flat), 1);
18
12
  const union = tf.add(tf.sum(y_true_flat, 1), tf.sum(y_pred_flat, 1));
19
-
20
- const dice = tf.div(
21
- intersection.mul(2).add(epsilon),
22
- union.add(epsilon)
23
- );
24
-
13
+ const dice = tf.div(intersection.mul(2).add(epsilon), union.add(epsilon));
25
14
  return tf.scalar(1).sub(dice);
26
15
  }
27
-
28
-
29
16
  // prevents minification of function name which TFJS relies on
30
17
  Object.defineProperty(diceBinaryStandard, "name", { value: "diceBinaryStandard", configurable: false });
31
-
32
-
33
18
  // https://github.com/keras-team/keras/blob/v3.3.3/keras/src/losses/losses.py#L1983-L2010
34
- export function diceBinaryGlobal(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
19
+ export function diceBinaryGlobal(y_true, y_pred) {
35
20
  const y_true_flat = tf.reshape(y_true, [-1]);
36
21
  const y_pred_flat = tf.reshape(y_pred, [-1]);
37
-
38
22
  const intersection = tf.sum(tf.mul(y_true_flat, y_pred_flat));
39
23
  const union = tf.add(tf.sum(y_true_flat), tf.sum(y_pred_flat));
40
-
41
- const dice = tf.div(
42
- intersection.mul(2).add(epsilon),
43
- union.add(epsilon)
44
- );
45
-
24
+ const dice = tf.div(intersection.mul(2).add(epsilon), union.add(epsilon));
46
25
  return tf.scalar(1).sub(dice);
47
26
  }
48
-
49
-
50
27
  // prevents minification of function name which TFJS relies on
51
28
  Object.defineProperty(diceBinaryGlobal, "name", { value: "diceBinaryGlobal", configurable: false });
52
-
53
-
54
- export function diceCategoricalStandard(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
29
+ export function diceCategoricalStandard(y_true, y_pred) {
55
30
  const intersection = tf.sum(tf.mul(y_true, y_pred), REDUCE_HW);
56
31
  const union = tf.add(y_true, y_pred).sum(REDUCE_HW);
57
-
58
- const dice = tf.div(
59
- intersection.mul(2).add(epsilon),
60
- union.add(epsilon)
61
- );
62
-
32
+ const dice = tf.div(intersection.mul(2).add(epsilon), union.add(epsilon));
63
33
  return tf.scalar(1).sub(tf.mean(dice, -1));
64
34
  }
65
-
66
-
67
35
  // prevents minification of function name which TFJS relies on
68
36
  Object.defineProperty(diceCategoricalStandard, "name", { value: "diceCategoricalStandard", configurable: false });
69
-
70
-
71
- export function diceCategoricalGeneralized(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
72
-
37
+ export function diceCategoricalGeneralized(y_true, y_pred) {
73
38
  // this is done twice so we calculate it once
74
39
  const y_true_sum = y_true.sum(REDUCE_BHW);
75
-
76
40
  const weighting = tf.div(1, y_true_sum.square().add(epsilon));
77
-
78
41
  const intersection = tf.sum(tf.mul(y_true, y_pred), REDUCE_BHW).mul(weighting).sum();
79
42
  const union = tf.add(y_true_sum, y_pred.sum(REDUCE_BHW)).mul(weighting).sum();
80
-
81
- const dice = tf.div(
82
- intersection.mul(2).add(epsilon),
83
- union.add(epsilon)
84
- );
85
-
43
+ const dice = tf.div(intersection.mul(2).add(epsilon), union.add(epsilon));
86
44
  return tf.scalar(1).sub(dice);
87
45
  }
88
-
89
-
90
46
  // prevents minification of function name which TFJS relies on
91
47
  Object.defineProperty(diceCategoricalGeneralized, "name", { value: "diceCategoricalGeneralized", configurable: false });
92
-
93
-
94
- export function diceCategoricalGlobal(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
95
-
48
+ export function diceCategoricalGlobal(y_true, y_pred) {
96
49
  const intersection = tf.sum(tf.mul(y_true, y_pred), REDUCE_BHWC);
97
50
  const union = tf.add(tf.sum(y_true, REDUCE_BHWC), tf.sum(y_pred, REDUCE_BHWC));
98
-
99
- const dice = tf.div(
100
- intersection.mul(2).add(epsilon),
101
- union.add(epsilon)
102
- );
103
-
51
+ const dice = tf.div(intersection.mul(2).add(epsilon), union.add(epsilon));
104
52
  return tf.scalar(1).sub(dice);
105
53
  }
106
-
107
-
108
54
  // prevents minification of function name which TFJS relies on
109
55
  Object.defineProperty(diceCategoricalGlobal, "name", { value: "diceCategoricalGlobal", configurable: false });
110
-
111
-
112
56
  /**
113
57
  * Calculates the Sorensen-Dice coefficient and the binary cross entropy losses.
114
58
  * Both have equal weight.
115
- *
59
+ *
116
60
  * @param y_true the label tensor
117
61
  * @param y_pred the prediction tensor (not sparse)
118
62
  * @returns a tensor of shape `[ batch ]`
119
63
  */
120
- export function diceBinaryCrossentropy(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
64
+ export function diceBinaryCrossentropy(y_true, y_pred) {
121
65
  // reduce cross entropy shape from [B, H, W] to [B] to match dice
122
66
  const bce = binaryCrossentropy(y_true, y_pred).mean(REDUCE_HW);
123
67
  const dice = diceBinaryStandard(y_true, y_pred);
124
-
125
68
  return tf.add(bce.mul(0.5), dice.mul(0.5));
126
69
  }
127
-
128
-
129
70
  // prevents minification of function name which TFJS relies on
130
71
  Object.defineProperty(diceBinaryCrossentropy, "name", { value: "diceBinaryCrossentropy", configurable: false });
131
-
132
-
133
72
  /**
134
73
  * Calculates the Sorensen-Dice coefficient and the categorical cross entropy losses.
135
74
  * Both have equal weight. Expects dense (non-sparse) label tensors.
136
- *
75
+ *
137
76
  * This does not support sparse tensors because TFJS's
138
77
  * sparseCategoricalCrossentropy loss onehots the label
139
78
  * and calls categoricalCrossentropy. See
140
79
  * https://github.com/tensorflow/tfjs/blob/0fc04d958ea592f3b8db79a8b3b497b5c8904097/tfjs-layers/src/losses.ts#L143-L146
141
- *
142
- * @param y_true the label
80
+ *
81
+ * @param y_true the label
143
82
  * @param y_pred the prediction tensor (not sparse)
144
83
  * @returns a tensor of shape `[ batch ]`
145
84
  */
146
- export function diceCategoricalCrossentropy(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
85
+ export function diceCategoricalCrossentropy(y_true, y_pred) {
147
86
  // reduce cross entropy shape from [B, H, W] to [B] to match dice
148
87
  const cce = categoricalCrossentropy(y_true, y_pred).mean(REDUCE_HW);
149
88
  const dice = diceCategoricalStandard(y_true, y_pred);
150
-
151
89
  return tf.add(cce.mul(0.5), dice.mul(0.5));
152
90
  }
153
-
154
-
155
91
  // prevents minification of function name which TFJS relies on
156
92
  Object.defineProperty(diceCategoricalCrossentropy, "name", { value: "diceCategoricalCrossentropy", configurable: false });
93
+ //# sourceMappingURL=dice.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"dice.js","sourceRoot":"","sources":["../../src/losses/dice.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,uBAAuB,EAAE,kBAAkB,EAAE,MAAM,qCAAqC,CAAC;AAElG,MAAM,OAAO,GAAG,IAAI,CAAC;AAErB,MAAM,SAAS,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,+BAA+B;AACzD,MAAM,UAAU,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,mCAAmC;AACjE,MAAM,WAAW,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,wBAAwB;AAG1D,gCAAgC;AAChC,MAAM,UAAU,kBAAkB,CAAC,MAAiB,EAAE,MAAiB;IAEnE,MAAM,WAAW,GAAG,EAAE,CAAC,OAAO,CAAC,MAAM,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;IAC9D,MAAM,WAAW,GAAG,EAAE,CAAC,OAAO,CAAC,MAAM,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;IAE9D,MAAM,YAAY,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,WAAW,CAAC,EAAE,CAAC,CAAC,CAAC;IACjE,MAAM,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,CAAC,CAAC,EAAE,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,CAAC,CAAC,CAAC,CAAC;IAErE,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CACf,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,CAAC,EAChC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,CACrB,CAAC;IAEF,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;AAClC,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,kBAAkB,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,oBAAoB,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAGxG,yFAAyF;AACzF,MAAM,UAAU,gBAAgB,CAAC,MAAiB,EAAE,MAAiB;IACjE,MAAM,WAAW,GAAG,EAAE,CAAC,OAAO,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAC7C,MAAM,WAAW,GAAG,EAAE,CAAC,OAAO,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAE7C,MAAM,YAAY,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,WAAW,CAAC,CAAC,CAAC;IAC9D,MAAM,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,WAAW,CAAC,EAAE,EAAE,CAAC,GAAG,CAAC,WAAW,CAAC,CAAC,CAAC;IAE/D,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CACf,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,CAAC,EAChC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,CACrB,CAAC;IAEF,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;AAClC,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC
,cAAc,CAAC,gBAAgB,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,kBAAkB,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAGpG,MAAM,UAAU,uBAAuB,CAAC,MAAiB,EAAE,MAAiB;IACxE,MAAM,YAAY,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,EAAE,SAAS,CAAC,CAAC;IAC/D,MAAM,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC,GAAG,CAAC,SAAS,CAAC,CAAC;IAEpD,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CACf,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,CAAC,EAChC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,CACrB,CAAC;IAEF,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;AAC/C,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,uBAAuB,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,yBAAyB,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAGlH,MAAM,UAAU,0BAA0B,CAAC,MAAiB,EAAE,MAAiB;IAE3E,6CAA6C;IAC7C,MAAM,UAAU,GAAG,MAAM,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC;IAE1C,MAAM,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,CAAC,EAAE,UAAU,CAAC,MAAM,EAAE,CAAC,GAAG,CAAC,OAAO,CAAC,CAAC,CAAC;IAE9D,MAAM,YAAY,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,EAAE,UAAU,CAAC,CAAC,GAAG,CAAC,SAAS,CAAC,CAAC,GAAG,EAAE,CAAC;IACrF,MAAM,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,UAAU,EAAE,MAAM,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC,CAAC,GAAG,CAAC,SAAS,CAAC,CAAC,GAAG,EAAE,CAAC;IAE9E,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CACf,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,CAAC,EAChC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,CACrB,CAAC;IAEF,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;AAClC,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,0BAA0B,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,4BAA4B,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAGxH,MAAM,UAAU,qBAAqB,CAAC,MAAiB,EAAE,MAAiB;IAEtE,MAAM,YAAY,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,EAAE,WAAW,CAAC,CAAC;IACjE,MAAM,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,WAAW,CAAC,EAAE,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC,CAAC;IAE/E,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CACf,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,CAAC,EAChC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,CACrB,CAAC;IAEF,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC
,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;AAClC,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,qBAAqB,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,uBAAuB,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAG9G;;;;;;;GAOG;AACH,MAAM,UAAU,sBAAsB,CAAC,MAAiB,EAAE,MAAiB;IACvE,iEAAiE;IACjE,MAAM,GAAG,GAAG,kBAAkB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;IAC/D,MAAM,IAAI,GAAG,kBAAkB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;IAEhD,OAAO,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;AAC/C,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,sBAAsB,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,wBAAwB,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAGhH;;;;;;;;;;;;GAYG;AACH,MAAM,UAAU,2BAA2B,CAAC,MAAiB,EAAE,MAAiB;IAC5E,iEAAiE;IACjE,MAAM,GAAG,GAAG,uBAAuB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;IACpE,MAAM,IAAI,GAAG,uBAAuB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;IAErD,OAAO,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;AAC/C,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,2BAA2B,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,6BAA6B,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC"}
@@ -0,0 +1,2 @@
1
+ export * from "./dice";
2
+ //# sourceMappingURL=index.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../src/losses/index.ts"],"names":[],"mappings":"AAAA,cAAc,QAAQ,CAAC"}
@@ -0,0 +1,2 @@
1
+ export * from "./dice";
2
+ //# sourceMappingURL=index.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"index.js","sourceRoot":"","sources":["../../src/losses/index.ts"],"names":[],"mappings":"AAAA,cAAc,QAAQ,CAAC"}
@@ -0,0 +1,20 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ /**
3
+ * Generate a causal mask used in self-attention to prevent tokens from looking
4
+ * ahead. The values in the upper right portion of the mask matrix are set to
5
+ * -1e7 so that they have no impact during scaled dot product attention.
6
+ */
7
+ export declare function causal(query_seq_length: number, key_seq_length: number): tf.Tensor<tf.Rank>;
8
+ /**
9
+ * Generate a self-attention mask that prevents packed sequences from crossing document
10
+ * boundaries and attending to each other. The result is a tensor of diagonally
11
+ * positioned blocks of zeroes (allow attention) and -1e7 values (prevent attention).
12
+ * The latter is scored zero during the scaled dot product attention's softmax operation.
13
+ *
14
+ * @param boundaries an array of ones (denotes start of a new sample or document) and zeroes
15
+ *
16
+ * Example boundary of 3 samples that are packed into one:
17
+ * `[1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]`
18
+ */
19
+ export declare function packing(boundaries: Int32Array): tf.Tensor<tf.Rank>;
20
+ //# sourceMappingURL=masks.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"masks.d.ts","sourceRoot":"","sources":["../src/masks.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC;;;;GAIG;AACH,wBAAgB,MAAM,CAAC,gBAAgB,EAAE,MAAM,EAAE,cAAc,EAAE,MAAM,sBAItE;AAGD;;;;;;;;;;GAUG;AACH,wBAAgB,OAAO,CAAC,UAAU,EAAE,UAAU,sBAc7C"}
@@ -1,28 +1,37 @@
1
1
  import * as tf from "@tensorflow/tfjs";
2
-
2
+ /**
3
+ * Generate a causal mask used in self-attention to prevent tokens from looking
4
+ * ahead. The values in the upper right portion of the mask matrix are set to
5
+ * -1e7 so that they have no impact during scaled dot product attention.
6
+ */
7
+ export function causal(query_seq_length, key_seq_length) {
8
+ return tf.linalg.bandPart(tf.ones([query_seq_length, key_seq_length]), -1, 0)
9
+ .sub(1)
10
+ .mul(1e7);
11
+ }
3
12
  /**
4
13
  * Generate a self-attention mask that prevents packed sequences from crossing document
5
14
  * boundaries and attending to each other. The result is a tensor of diagonally
6
15
  * positioned blocks of zeroes (allow attention) and -1e7 values (prevent attention).
7
16
  * The latter is scored zero during the scaled dot product attention's softmax operation.
8
- *
17
+ *
9
18
  * @param boundaries an array of ones (denotes start of a new sample or document) and zeroes
10
- *
11
- * Example boundary of 3 samples are packed into one:
19
+ *
20
+ * Example boundary of 3 samples that are packed into one:
12
21
  * `[1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]`
13
22
  */
14
- export function generatePackingSelfAttentionMask(boundaries: Int32Array) {
23
+ export function packing(boundaries) {
15
24
  // see images at
16
25
  // https://reddit.com/r/LocalLLaMA/comments/197efaz/training_llama_mistral_and_mixtralmoe_faster_with/
17
26
  return tf.tidy(() => {
18
27
  // cumsum transforms the tensor such that each sequence in the pack gets its own id,
19
28
  const partitions = tf.tensor1d(boundaries).cumsum();
20
-
21
29
  return partitions.expandDims(1)
22
30
  .equal(partitions.expandDims(0))
23
31
  .sub(1)
24
32
  .mul(1e7)
25
33
  // introduce a head dimension so it can be broadcasted
26
34
  .expandDims(0);
27
- })
35
+ });
28
36
  }
37
+ //# sourceMappingURL=masks.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"masks.js","sourceRoot":"","sources":["../src/masks.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC;;;;GAIG;AACH,MAAM,UAAU,MAAM,CAAC,gBAAwB,EAAE,cAAsB;IACnE,OAAO,EAAE,CAAC,MAAM,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,CAAC,gBAAgB,EAAE,cAAc,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC;SACxE,GAAG,CAAC,CAAC,CAAC;SACN,GAAG,CAAC,GAAG,CAAC,CAAC;AAClB,CAAC;AAGD;;;;;;;;;;GAUG;AACH,MAAM,UAAU,OAAO,CAAC,UAAsB;IAC1C,gBAAgB;IAChB,sGAAsG;IACtG,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;QAChB,oFAAoF;QACpF,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,UAAU,CAAC,CAAC,MAAM,EAAE,CAAC;QAEpD,OAAO,UAAU,CAAC,UAAU,CAAC,CAAC,CAAC;aAC1B,KAAK,CAAC,UAAU,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;aAC/B,GAAG,CAAC,CAAC,CAAC;aACN,GAAG,CAAC,GAAG,CAAC;YACT,sDAAsD;aACrD,UAAU,CAAC,CAAC,CAAC,CAAC;IACvB,CAAC,CAAC,CAAA;AACN,CAAC"}
@@ -0,0 +1,20 @@
1
+ import { Tensor } from "@tensorflow/tfjs";
2
+ /**
3
+ * Applies the recall metric with the prediction rounded based on a threshold
4
+ *
5
+ * @param y_true the label tensor
6
+ * @param y_pred the prediction tensor
7
+ * @param threshold threshold value to be considered a positive prediction, defaults to `0.5`
8
+ * @returns
9
+ */
10
+ export declare function recall(y_true: Tensor, y_pred: Tensor, threshold?: number): Tensor<import("@tensorflow/tfjs-core").Rank>;
11
+ /**
12
+ * Applies the precision metric with the prediction rounded based on a threshold
13
+ *
14
+ * @param y_true the label tensor
15
+ * @param y_pred the prediction tensor
16
+ * @param threshold threshold value to be considered a positive prediction, defaults to `0.5`
17
+ * @returns
18
+ */
19
+ export declare function precision(y_true: Tensor, y_pred: Tensor, threshold?: number): Tensor<import("@tensorflow/tfjs-core").Rank>;
20
+ //# sourceMappingURL=metrics.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"metrics.d.ts","sourceRoot":"","sources":["../src/metrics.ts"],"names":[],"mappings":"AAAA,OAAO,EAAW,MAAM,EAAE,MAAM,kBAAkB,CAAC;AAGnD;;;;;;;GAOG;AACH,wBAAgB,MAAM,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,SAAS,GAAE,MAAY,gDAE7E;AAKD;;;;;;;GAOG;AACH,wBAAgB,SAAS,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,SAAS,GAAE,MAAY,gDAEhF"}
@@ -1,32 +1,28 @@
1
- import { metrics, Tensor } from "@tensorflow/tfjs";
2
-
3
-
1
+ import { metrics } from "@tensorflow/tfjs";
4
2
  /**
5
3
  * Applies the recall metric with the prediction rounded based on a threshold
6
- *
4
+ *
7
5
  * @param y_true the label tensor
8
6
  * @param y_pred the prediction tensor
9
7
  * @param threshold threshold value to be considered a positive prediction, defaults to `0.5`
10
- * @returns
8
+ * @returns
11
9
  */
12
- export function recall(y_true: Tensor, y_pred: Tensor, threshold: number = 0.5) {
10
+ export function recall(y_true, y_pred, threshold = 0.5) {
13
11
  return metrics.recall(y_true, y_pred.greaterEqual(threshold));
14
12
  }
15
-
16
13
  // prevents minification of function name which TFJS relies on
17
14
  Object.defineProperty(recall, "name", { value: "recall", configurable: false });
18
-
19
15
  /**
20
16
  * Applies the precision metric with the prediction rounded based on a threshold
21
- *
17
+ *
22
18
  * @param y_true the label tensor
23
19
  * @param y_pred the prediction tensor
24
20
  * @param threshold threshold value to be considered a positive prediction, defaults to `0.5`
25
- * @returns
21
+ * @returns
26
22
  */
27
- export function precision(y_true: Tensor, y_pred: Tensor, threshold: number = 0.5) {
23
+ export function precision(y_true, y_pred, threshold = 0.5) {
28
24
  return metrics.precision(y_true, y_pred.greaterEqual(threshold));
29
25
  }
30
-
31
26
  // prevents minification of function name which TFJS relies on
32
27
  Object.defineProperty(precision, "name", { value: "precision", configurable: false });
28
+ //# sourceMappingURL=metrics.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"metrics.js","sourceRoot":"","sources":["../src/metrics.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,OAAO,EAAU,MAAM,kBAAkB,CAAC;AAGnD;;;;;;;GAOG;AACH,MAAM,UAAU,MAAM,CAAC,MAAc,EAAE,MAAc,EAAE,YAAoB,GAAG;IAC1E,OAAO,OAAO,CAAC,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,YAAY,CAAC,SAAS,CAAC,CAAC,CAAC;AAClE,CAAC;AAED,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,QAAQ,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAEhF;;;;;;;GAOG;AACH,MAAM,UAAU,SAAS,CAAC,MAAc,EAAE,MAAc,EAAE,YAAoB,GAAG;IAC7E,OAAO,OAAO,CAAC,SAAS,CAAC,MAAM,EAAE,MAAM,CAAC,YAAY,CAAC,SAAS,CAAC,CAAC,CAAC;AACrE,CAAC;AAED,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,SAAS,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,WAAW,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC"}
@@ -0,0 +1,94 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { type LossOrMetricFn } from "../tfjs_types";
3
+ import { LlmModel, type LlmModelArgs } from "../models/llm_model";
4
+ import { KvCacheContainer } from "../kv_cache";
5
+ import { type DisposeResult } from "@tensorflow/tfjs-layers/dist/engine/topology";
6
+ export interface GptModelArgs extends LlmModelArgs {
7
+ /**
8
+ * Number of heads per attention layer.
9
+ */
10
+ numHeads: number;
11
+ /**
12
+ * Number of GPT decoder blocks.
13
+ */
14
+ numLayers: number;
15
+ /**
16
+ * The embedding size of each token.
17
+ */
18
+ embedDim: number;
19
+ /**
20
+ * The vocabulary size of the embedding layer and number of units of the output
21
+ * layer. This is also the tokenizer vocabulary size.
22
+ */
23
+ vocabSize: number;
24
+ /**
25
+ * Pad the embeddings' vocab size and output layer's units to the next nearest
26
+ * multiple of 64 to optimize hardware efficiency. Defaults to `true`.
27
+ *
28
+ * For example: if a tokenizer has 50,257 tokens, the model uses 50,304 for the
29
+ * vocab size and output units count.
30
+ */
31
+ padToMultipleOf64?: boolean;
32
+ }
33
+ /**
34
+ * This is a subclass of tf.Sequential that creates a GPT-like model and
35
+ * automatically handles padding (and masking) the vocab size for hardware
36
+ * efficiency.
37
+ *
38
+ * Example:
39
+ *
40
+ * ```javascript
41
+ *
42
+ * const model = new GptModel({ numLayers: 1, numHeads: 1, embedDim: 16, vocabSize: 64 });
43
+ * model.compile({ loss: "sparseCategoricalCrossentropy", optimizer: "adam" });
44
+ *
45
+ * // use fitDataset() instead of fit for masking support
46
+ * model.fitDataset(your_batched_generator_dataset, { epochs: 1 });
47
+ *
48
+ * const kv_cache = new KvCacheContainer(your_preferred_max_sequence_length);
49
+ *
50
+ * // use generate() and predictNextToken() instead of predict() for masking and auto memory cleanup
51
+ * model.generate(tokenized_tensor1d_input, kv_cache, onPredict_callback)
52
+ *
53
+ *
54
+ * ```
55
+ */
56
+ export declare class GptModel extends LlmModel {
57
+ static className: string;
58
+ protected readonly numHeads: number;
59
+ protected readonly numLayers: number;
60
+ protected readonly embedDim: number;
61
+ protected readonly vocabSize: number;
62
+ protected readonly padToMultipleOf64: boolean;
63
+ protected readonly vocabSizePadded: number;
64
+ protected vocab_padding_mask?: tf.Tensor1D;
65
+ /**
66
+ * DO NOT add layers in the constructor or it will break tf.loadLayersModel().
67
+ * It should be done in build() instead.
68
+ */
69
+ constructor(args: GptModelArgs);
70
+ protected fitBatch(xs: tf.Tensor, ys: tf.Tensor, loss_mask: tf.Tensor | undefined, loss_function: LossOrMetricFn, other_masks?: {
71
+ [key: string]: tf.Tensor | undefined;
72
+ }): {
73
+ y_pred: tf.Tensor<tf.Rank>;
74
+ loss: tf.Scalar;
75
+ };
76
+ /**
77
+ * Overrides LlmModel.predictNextToken to add softmax before argMax because the final
78
+ * dense layer doesn't have an activation.
79
+ *
80
+ * TODO: implement temperature and multinomial sampling so that the model has varied outputs
81
+ */
82
+ predictNextToken(input: tf.Tensor2D, kv_cache: KvCacheContainer): tf.Tensor2D;
83
+ build(inputShape?: tf.Shape | tf.Shape[]): void;
84
+ dispose(): DisposeResult;
85
+ getConfig(): {
86
+ numHeads: number;
87
+ numLayers: number;
88
+ embedDim: number;
89
+ vocabSize: number;
90
+ vocabSizePadded: number;
91
+ padToMultipleOf64: boolean;
92
+ };
93
+ }
94
+ //# sourceMappingURL=gpt_model.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"gpt_model.d.ts","sourceRoot":"","sources":["../../src/models/gpt_model.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,cAAc,EAAE,MAAM,eAAe,CAAC;AACpD,OAAO,EAAE,QAAQ,EAAE,KAAK,YAAY,EAAE,MAAM,qBAAqB,CAAC;AAClE,OAAO,EAAE,gBAAgB,EAAE,MAAM,aAAa,CAAC;AAC/C,OAAO,EAAE,KAAK,aAAa,EAAE,MAAM,8CAA8C,CAAC;AAIlF,MAAM,WAAW,YAAa,SAAQ,YAAY;IAC9C;;OAEG;IACH,QAAQ,EAAE,MAAM,CAAC;IACjB;;OAEG;IACH,SAAS,EAAE,MAAM,CAAC;IAClB;;OAEG;IACH,QAAQ,EAAE,MAAM,CAAC;IACjB;;;OAGG;IACH,SAAS,EAAE,MAAM,CAAC;IAClB;;;;;;OAMG;IACH,iBAAiB,CAAC,EAAE,OAAO,CAAC;CAC/B;AAGD;;;;;;;;;;;;;;;;;;;;;;GAsBG;AACH,qBAAa,QAAS,SAAQ,QAAQ;IAClC,MAAM,CAAC,SAAS,SAAc;IAE9B,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,SAAS,EAAE,MAAM,CAAC;IACrC,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,SAAS,EAAE,MAAM,CAAC;IACrC,SAAS,CAAC,QAAQ,CAAC,iBAAiB,EAAE,OAAO,CAAC;IAI9C,SAAS,CAAC,QAAQ,CAAC,eAAe,EAAE,MAAM,CAAC;IAG3C,SAAS,CAAC,kBAAkB,CAAC,EAAE,EAAE,CAAC,QAAQ,CAAC;IAG3C;;;OAGG;gBACS,IAAI,EAAE,YAAY;cAgBX,QAAQ,CACvB,EAAE,EAAE,EAAE,CAAC,MAAM,EACb,EAAE,EAAE,EAAE,CAAC,MAAM,EACb,SAAS,EAAE,EAAE,CAAC,MAAM,GAAG,SAAS,EAChC,aAAa,EAAE,cAAc,EAC7B,WAAW,CAAC,EAAE;QAAE,CAAC,GAAG,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,GAAG,SAAS,CAAA;KAAE;;;;IAsC1D;;;;;OAKG;IACM,gBAAgB,CAAC,KAAK,EAAE,EAAE,CAAC,QAAQ,EAAE,QAAQ,EAAE,gBAAgB,GAAG,EAAE,CAAC,QAAQ;IA4B7E,KAAK,CAAC,UAAU,CAAC,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IA+B/C,OAAO,IAAI,aAAa;IAMxB,SAAS;;;;;;;;CAiBrB"}