@stellarapp/tfjs-stellar 1.0.0 → 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +47 -0
  3. package/dist/index.d.ts +7 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +7 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/jest.config.d.ts +8 -0
  8. package/dist/jest.config.d.ts.map +1 -0
  9. package/{jest.config.ts → dist/jest.config.js} +8 -64
  10. package/dist/jest.config.js.map +1 -0
  11. package/dist/kv_cache.d.ts +53 -0
  12. package/dist/kv_cache.d.ts.map +1 -0
  13. package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
  14. package/dist/kv_cache.js.map +1 -0
  15. package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
  16. package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  17. package/dist/layers/cached_rope_multihead_attention.js +76 -0
  18. package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
  19. package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  20. package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  21. package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
  22. package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
  23. package/dist/layers/gpt_decoder_block.d.ts +34 -0
  24. package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
  25. package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
  26. package/dist/layers/gpt_decoder_block.js.map +1 -0
  27. package/dist/layers/index.d.ts +17 -0
  28. package/dist/layers/index.d.ts.map +1 -0
  29. package/dist/layers/index.js +33 -0
  30. package/dist/layers/index.js.map +1 -0
  31. package/dist/layers/multihead_attention.d.ts +106 -0
  32. package/dist/layers/multihead_attention.d.ts.map +1 -0
  33. package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
  34. package/dist/layers/multihead_attention.js.map +1 -0
  35. package/dist/layers/multihead_attention.test.d.ts +2 -0
  36. package/dist/layers/multihead_attention.test.d.ts.map +1 -0
  37. package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
  38. package/dist/layers/multihead_attention.test.js.map +1 -0
  39. package/dist/layers/positional_encoding.d.ts +37 -0
  40. package/dist/layers/positional_encoding.d.ts.map +1 -0
  41. package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
  42. package/dist/layers/positional_encoding.js.map +1 -0
  43. package/dist/layers/positional_encoding.test.d.ts +2 -0
  44. package/dist/layers/positional_encoding.test.d.ts.map +1 -0
  45. package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
  46. package/dist/layers/positional_encoding.test.js.map +1 -0
  47. package/dist/layers/rotary_position_embedding.d.ts +39 -0
  48. package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
  49. package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
  50. package/dist/layers/rotary_position_embedding.js.map +1 -0
  51. package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
  52. package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
  53. package/dist/layers/rotary_position_embedding.test.js +88 -0
  54. package/dist/layers/rotary_position_embedding.test.js.map +1 -0
  55. package/dist/layers/token_and_positional_embedding.d.ts +47 -0
  56. package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
  57. package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
  58. package/dist/layers/token_and_positional_embedding.js.map +1 -0
  59. package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
  60. package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  61. package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
  62. package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
  63. package/dist/layers/transformer_decoder.d.ts +69 -0
  64. package/dist/layers/transformer_decoder.d.ts.map +1 -0
  65. package/dist/layers/transformer_decoder.js +182 -0
  66. package/dist/layers/transformer_decoder.js.map +1 -0
  67. package/dist/layers/transformer_decoder.test.d.ts +2 -0
  68. package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
  69. package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
  70. package/dist/layers/transformer_decoder.test.js.map +1 -0
  71. package/dist/layers/transformer_encoder.d.ts +55 -0
  72. package/dist/layers/transformer_encoder.d.ts.map +1 -0
  73. package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
  74. package/dist/layers/transformer_encoder.js.map +1 -0
  75. package/dist/layers/transformer_encoder.test.d.ts +2 -0
  76. package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
  77. package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
  78. package/dist/layers/transformer_encoder.test.js.map +1 -0
  79. package/dist/losses/dice.d.ts +30 -0
  80. package/dist/losses/dice.d.ts.map +1 -0
  81. package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
  82. package/dist/losses/dice.js.map +1 -0
  83. package/dist/losses/index.d.ts +2 -0
  84. package/dist/losses/index.d.ts.map +1 -0
  85. package/dist/losses/index.js +2 -0
  86. package/dist/losses/index.js.map +1 -0
  87. package/dist/masks.d.ts +20 -0
  88. package/dist/masks.d.ts.map +1 -0
  89. package/{src/packing_mask.ts → dist/masks.js} +16 -7
  90. package/dist/masks.js.map +1 -0
  91. package/dist/metrics.d.ts +20 -0
  92. package/dist/metrics.d.ts.map +1 -0
  93. package/{src/metrics.ts → dist/metrics.js} +8 -12
  94. package/dist/metrics.js.map +1 -0
  95. package/dist/models/gpt_model.d.ts +94 -0
  96. package/dist/models/gpt_model.d.ts.map +1 -0
  97. package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
  98. package/dist/models/gpt_model.js.map +1 -0
  99. package/dist/models/index.d.ts +7 -0
  100. package/dist/models/index.d.ts.map +1 -0
  101. package/dist/models/index.js +13 -0
  102. package/dist/models/index.js.map +1 -0
  103. package/dist/models/llm_model.d.ts +87 -0
  104. package/dist/models/llm_model.d.ts.map +1 -0
  105. package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
  106. package/dist/models/llm_model.js.map +1 -0
  107. package/dist/models/u_net.d.ts +40 -0
  108. package/dist/models/u_net.d.ts.map +1 -0
  109. package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
  110. package/dist/models/u_net.js.map +1 -0
  111. package/dist/src/index.d.ts +6 -0
  112. package/dist/src/index.d.ts.map +1 -0
  113. package/dist/src/index.js +6 -0
  114. package/dist/src/index.js.map +1 -0
  115. package/dist/src/kv_cache.d.ts +53 -0
  116. package/dist/src/kv_cache.d.ts.map +1 -0
  117. package/dist/src/kv_cache.js +135 -0
  118. package/dist/src/kv_cache.js.map +1 -0
  119. package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
  120. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  121. package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
  122. package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
  123. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  124. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  125. package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
  126. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
  127. package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
  128. package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
  129. package/dist/src/layers/gpt_decoder_block.js +51 -0
  130. package/dist/src/layers/gpt_decoder_block.js.map +1 -0
  131. package/dist/src/layers/index.d.ts +17 -0
  132. package/dist/src/layers/index.d.ts.map +1 -0
  133. package/dist/src/layers/index.js +33 -0
  134. package/dist/src/layers/index.js.map +1 -0
  135. package/dist/src/layers/multihead_attention.d.ts +106 -0
  136. package/dist/src/layers/multihead_attention.d.ts.map +1 -0
  137. package/dist/src/layers/multihead_attention.js +269 -0
  138. package/dist/src/layers/multihead_attention.js.map +1 -0
  139. package/dist/src/layers/multihead_attention.test.d.ts +2 -0
  140. package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
  141. package/dist/src/layers/multihead_attention.test.js +160 -0
  142. package/dist/src/layers/multihead_attention.test.js.map +1 -0
  143. package/dist/src/layers/positional_encoding.d.ts +37 -0
  144. package/dist/src/layers/positional_encoding.d.ts.map +1 -0
  145. package/dist/src/layers/positional_encoding.js +115 -0
  146. package/dist/src/layers/positional_encoding.js.map +1 -0
  147. package/dist/src/layers/positional_encoding.test.d.ts +2 -0
  148. package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
  149. package/dist/src/layers/positional_encoding.test.js +95 -0
  150. package/dist/src/layers/positional_encoding.test.js.map +1 -0
  151. package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
  152. package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
  153. package/dist/src/layers/rotary_position_embedding.js +99 -0
  154. package/dist/src/layers/rotary_position_embedding.js.map +1 -0
  155. package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
  156. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
  157. package/dist/src/layers/rotary_position_embedding.test.js +88 -0
  158. package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
  159. package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
  160. package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
  161. package/dist/src/layers/token_and_positional_embedding.js +109 -0
  162. package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
  163. package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
  164. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  165. package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
  166. package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
  167. package/dist/src/layers/transformer_decoder.d.ts +69 -0
  168. package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
  169. package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
  170. package/dist/src/layers/transformer_decoder.js.map +1 -0
  171. package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
  172. package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
  173. package/dist/src/layers/transformer_decoder.test.js +72 -0
  174. package/dist/src/layers/transformer_decoder.test.js.map +1 -0
  175. package/dist/src/layers/transformer_encoder.d.ts +55 -0
  176. package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
  177. package/dist/src/layers/transformer_encoder.js +175 -0
  178. package/dist/src/layers/transformer_encoder.js.map +1 -0
  179. package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
  180. package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
  181. package/dist/src/layers/transformer_encoder.test.js +58 -0
  182. package/dist/src/layers/transformer_encoder.test.js.map +1 -0
  183. package/dist/src/losses/dice.d.ts +30 -0
  184. package/dist/src/losses/dice.d.ts.map +1 -0
  185. package/dist/src/losses/dice.js +93 -0
  186. package/dist/src/losses/dice.js.map +1 -0
  187. package/dist/src/losses/index.d.ts +2 -0
  188. package/dist/src/losses/index.d.ts.map +1 -0
  189. package/dist/src/losses/index.js +2 -0
  190. package/dist/src/losses/index.js.map +1 -0
  191. package/dist/src/masks.d.ts +20 -0
  192. package/dist/src/masks.d.ts.map +1 -0
  193. package/dist/src/masks.js +37 -0
  194. package/dist/src/masks.js.map +1 -0
  195. package/dist/src/metrics.d.ts +20 -0
  196. package/dist/src/metrics.d.ts.map +1 -0
  197. package/dist/src/metrics.js +28 -0
  198. package/dist/src/metrics.js.map +1 -0
  199. package/dist/src/models/gpt_model.d.ts +94 -0
  200. package/dist/src/models/gpt_model.d.ts.map +1 -0
  201. package/dist/src/models/gpt_model.js +154 -0
  202. package/dist/src/models/gpt_model.js.map +1 -0
  203. package/dist/src/models/index.d.ts +3 -0
  204. package/dist/src/models/index.d.ts.map +1 -0
  205. package/{src/models/index.ts → dist/src/models/index.js} +1 -0
  206. package/dist/src/models/index.js.map +1 -0
  207. package/dist/src/models/llm_model.d.ts +87 -0
  208. package/dist/src/models/llm_model.d.ts.map +1 -0
  209. package/dist/src/models/llm_model.js +245 -0
  210. package/dist/src/models/llm_model.js.map +1 -0
  211. package/dist/src/models/u_net.d.ts +40 -0
  212. package/dist/src/models/u_net.d.ts.map +1 -0
  213. package/dist/src/models/u_net.js +151 -0
  214. package/dist/src/models/u_net.js.map +1 -0
  215. package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
  216. package/dist/src/tfjs_types.d.ts.map +1 -0
  217. package/dist/src/tfjs_types.js +2 -0
  218. package/dist/src/tfjs_types.js.map +1 -0
  219. package/dist/src/utils.d.ts +28 -0
  220. package/dist/src/utils.d.ts.map +1 -0
  221. package/{src/utils.ts → dist/src/utils.js} +10 -33
  222. package/dist/src/utils.js.map +1 -0
  223. package/dist/src/utils.test.d.ts +2 -0
  224. package/dist/src/utils.test.d.ts.map +1 -0
  225. package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
  226. package/dist/src/utils.test.js.map +1 -0
  227. package/dist/tfjs_types.d.ts +10 -0
  228. package/dist/tfjs_types.d.ts.map +1 -0
  229. package/dist/tfjs_types.js +2 -0
  230. package/dist/tfjs_types.js.map +1 -0
  231. package/dist/utils.d.ts +28 -0
  232. package/dist/utils.d.ts.map +1 -0
  233. package/dist/utils.js +63 -0
  234. package/dist/utils.js.map +1 -0
  235. package/dist/utils.test.d.ts +2 -0
  236. package/dist/utils.test.d.ts.map +1 -0
  237. package/dist/utils.test.js +73 -0
  238. package/dist/utils.test.js.map +1 -0
  239. package/package.json +14 -4
  240. package/src/index.ts +0 -93
  241. package/src/layers/rotary_position_embedding.test.ts +0 -107
  242. package/src/losses/index.ts +0 -1
  243. package/src/testing.ts +0 -1
  244. package/tsconfig.json +0 -49
@@ -1,24 +1,8 @@
1
1
  import * as tf from "@tensorflow/tfjs";
2
- import * as tfc from "@/index";
3
2
  import { sparseCategoricalCrossentropy } from "@tensorflow/tfjs-layers/dist/losses";
4
- import { Dataset, type LossOrMetricFn } from "@/tfjs_types";
5
- import { generateCausalAttentionMask } from "@/utils";
6
- import { KvCacheContainer } from "@/kv_cache";
7
-
8
-
9
- // eslint-disable-next-line
10
- export interface LlmModelArgs extends tf.SequentialArgs {
11
- };
12
-
13
-
14
- interface DatasetArgs extends tf.TensorContainerObject {
15
- xs: tf.Tensor;
16
- ys: tf.Tensor;
17
- loss_mask?: tf.Tensor;
18
- packing_mask?: tf.Tensor;
19
- }
20
-
21
-
3
+ import { causal as generateCausalMask } from "../masks";
4
+ import * as losses from "../losses";
5
+ ;
22
6
  /**
23
7
  * This class overrides the `fitDataset()` function of tf.Sequential to support loss
24
8
  * and packing masking. Use the `generate()` function to autoregressively predict the
@@ -26,42 +10,33 @@ interface DatasetArgs extends tf.TensorContainerObject {
26
10
  */
27
11
  export class LlmModel extends tf.Sequential {
28
12
  static className = "LlmModel";
29
-
30
- private stopPredicting_: boolean = true;
31
-
32
- constructor(args: LlmModelArgs) {
13
+ stopPredicting_ = true;
14
+ constructor(args) {
33
15
  args.name = args.name ?? "model";
34
16
  super(args);
35
17
  }
36
-
37
-
38
18
  /**
39
19
  * Returns the metric functions and names so that metrics can be reported
40
20
  * as they are in the base version of model.fitDataset
41
- *
21
+ *
42
22
  * e.g. "categoricalAccuracy" should be reported as "acc"
43
23
  */
44
- protected getMetricFunctions() {
24
+ getMetricFunctions() {
45
25
  const [loss, ...metric_fn_names] = this.metricsNames;
46
-
47
26
  return this.metricsTensors.map((metric_tensor, index) => ({
48
27
  metric_fn: metric_tensor[0],
49
28
  metric_label: metric_fn_names[index]
50
- }))
29
+ }));
51
30
  }
52
-
53
-
54
31
  /**
55
32
  * Get exactly one loss function from the loss function provided in `model.compile()`.
56
33
  * If a string identifier was used, convert it to the actual loss function.
57
34
  */
58
- protected getLossFunction(): LossOrMetricFn {
35
+ getLossFunction() {
59
36
  let loss = this.loss;
60
-
61
37
  if (Array.isArray(loss)) {
62
38
  loss = loss[0];
63
39
  }
64
-
65
40
  if (typeof loss == "string") {
66
41
  if (loss == "sparseCategoricalCrossentropy") {
67
42
  return sparseCategoricalCrossentropy;
@@ -70,249 +45,174 @@ export class LlmModel extends tf.Sequential {
70
45
  " Use categoricalCrossentropy instead. See" +
71
46
  " https://github.com/tensorflow/tfjs/blob/0fc04d958ea592f3b8db79a8b3b497b5c8904097/tfjs-layers/src/losses.ts#L143-L146"); */
72
47
  }
73
-
74
- const loss_id = loss as string;
75
-
76
- const loss_fn =
77
- ((tfc.losses as Record<string, any>)[loss_id] ??
78
- (tf.losses as Record<string, any>)[loss_id] ??
79
- (tf.metrics as Record<string, any>)[loss_id]) as LossOrMetricFn
80
-
48
+ const loss_id = loss;
49
+ const loss_fn = (losses[loss_id] ??
50
+ tf.losses[loss_id] ??
51
+ tf.metrics[loss_id]);
81
52
  if (loss_fn) {
82
- return loss_fn
83
- } else {
53
+ return loss_fn;
54
+ }
55
+ else {
84
56
  throw Error(`LlmModel.getLossFunction: ${loss_id} is not a valid loss function`);
85
57
  }
86
- } else if (typeof loss == "function") {
58
+ }
59
+ else if (typeof loss == "function") {
87
60
  return loss;
88
61
  }
89
-
90
62
  throw Error("LlmModel.getLossFunction: the loss function's type should be string or function");
91
63
  }
92
-
93
-
94
64
  /**
95
65
  * Train on a `tf.data.generator` dataset. See https://js.tensorflow.org/api/latest/#data.generator.
96
- *
66
+ *
97
67
  * The generator should yield `xs`, `ys`, `loss_mask` (if fine-tuning), and
98
68
  * `packing_mask` (if sequence packing was done)
99
- *
69
+ *
100
70
  * @param tfdataset an instance of a `tf.Dataset` generator
101
71
  * @param args a ModelFitDatasetArgs
102
72
  */
103
- override async fitDataset<T = DatasetArgs>(tfdataset: Dataset<T>, args: tf.ModelFitDatasetArgs<T>): Promise<any> {
73
+ async fitDataset(tfdataset, args) {
104
74
  this.stopTraining = false;
105
-
106
- const dataset = tfdataset as tf.data.Dataset<DatasetArgs>;
75
+ const dataset = tfdataset;
107
76
  const { epochs, callbacks } = args;
108
-
109
77
  const metric_functions = this.getMetricFunctions();
110
78
  const loss_function = this.getLossFunction();
111
79
  this.lossFunctions = [loss_function];
112
-
113
- const {
114
- onBatchBegin,
115
- onBatchEnd,
116
- onEpochBegin,
117
- onEpochEnd,
118
- onTrainBegin,
119
- onTrainEnd,
120
- } = callbacks as tf.CustomCallbackArgs ?? {};
121
-
80
+ const { onBatchBegin, onBatchEnd, onEpochBegin, onEpochEnd, onTrainBegin, onTrainEnd, } = callbacks ?? {};
122
81
  await onTrainBegin?.();
123
-
124
- let cached_causal_mask: tf.Tensor | undefined = undefined;
125
-
82
+ let cached_causal_mask = undefined;
126
83
  for (let epoch = 0; epoch < epochs; epoch++) {
127
84
  await onEpochBegin?.(epoch);
128
-
129
85
  let batch = 0;
130
86
  let total_samples = 0;
131
- const accumulated_epoch_metrics: { [metric: string]: number } = {};
132
-
87
+ const accumulated_epoch_metrics = {};
133
88
  // loop through dataset using its iterator
134
89
  const iterator = await dataset.iterator();
135
90
  let sample = await iterator.next();
136
-
137
91
  while (!sample.done) {
138
- const batch_metrics: { [metric: string]: number } = { batch };
139
-
92
+ const batch_metrics = { batch };
140
93
  const { xs, ys, loss_mask, packing_mask } = sample.value;
141
94
  const batch_size = xs.shape[0];
142
95
  total_samples += batch_size; // for epoch metrics averaging
143
-
144
96
  if (xs.shape.length != 2) {
145
97
  throw Error(`LlmModel.fitDataset: ${this.name} the generator dataset should be batched, run: dataset.batch(batch_size)`);
146
98
  }
147
-
148
99
  // pre-calculate the causal attention mask and reuse it for all attention layers,
149
100
  const seq_length = xs.shape[xs.shape.length - 1];
150
-
151
101
  if (!cached_causal_mask || cached_causal_mask.shape[0] != seq_length) {
152
- cached_causal_mask = generateCausalAttentionMask(seq_length, seq_length);
102
+ cached_causal_mask = generateCausalMask(seq_length, seq_length);
153
103
  }
154
-
155
104
  await onBatchBegin?.(batch);
156
-
157
105
  tf.tidy(() => {
158
106
  const { y_pred, loss } = this.fitBatch(xs, ys, loss_mask, loss_function, {
159
107
  packingMask: packing_mask,
160
108
  causalMask: cached_causal_mask
161
- })
162
-
109
+ });
163
110
  const loss_value = (loss.dataSync())[0];
164
-
165
111
  batch_metrics.loss = loss_value;
166
112
  accumulated_epoch_metrics.loss = (accumulated_epoch_metrics.loss || 0) + loss_value * batch_size;
167
-
168
113
  // calculate and store metrics
169
114
  for (const { metric_fn, metric_label } of metric_functions) {
170
- const metric_sum = metric_fn(ys, y_pred!).mean();
171
-
115
+ const metric_sum = metric_fn(ys, y_pred).mean();
172
116
  const metric_value = (metric_sum.dataSync())[0];
173
-
174
- batch_metrics[metric_label] = metric_value// / batch_size;
117
+ batch_metrics[metric_label] = metric_value; // / batch_size;
175
118
  accumulated_epoch_metrics[metric_label] = (accumulated_epoch_metrics[metric_label] || 0) + metric_value * batch_size;
176
119
  }
177
-
178
- tf.dispose(y_pred!);
179
- })
180
-
120
+ tf.dispose(y_pred);
121
+ });
181
122
  tf.dispose(xs);
182
123
  tf.dispose(ys);
183
124
  tf.dispose(loss_mask);
184
-
185
125
  if (packing_mask) {
186
126
  tf.dispose(packing_mask);
187
127
  }
188
-
189
128
  await onBatchEnd?.(batch, batch_metrics);
190
-
191
129
  // so that stop training works
192
130
  await tf.nextFrame();
193
-
194
131
  if (this.stopTraining) {
195
132
  break;
196
133
  }
197
-
198
134
  sample = await iterator.next();
199
135
  batch++;
200
136
  }
201
-
202
137
  for (const metric in accumulated_epoch_metrics) {
203
138
  accumulated_epoch_metrics[metric] = accumulated_epoch_metrics[metric] / total_samples;
204
139
  }
205
-
206
140
  await onEpochEnd?.(epoch, accumulated_epoch_metrics);
207
-
208
141
  if (this.stopTraining) {
209
142
  break;
210
143
  }
211
144
  }
212
-
213
145
  tf.dispose(cached_causal_mask);
214
- await onTrainEnd?.()
215
-
146
+ await onTrainEnd?.();
216
147
  return {};
217
148
  }
218
-
219
-
220
149
  /**
221
150
  * Run the core forward and backward propagation on one training batch. This
222
151
  * should be called within a `tf.tidy()`.
223
- *
152
+ *
224
153
  * @param xs the sample/input tensor
225
154
  * @param ys the label/target tensor
226
155
  * @param loss_mask a loss mask to ignore the prediction's non-assistant tokens
227
156
  * @param loss_function the model's loss function
228
157
  * @param other_masks other masks used by the model's layers e.g. packing mask, causal mask
229
158
  */
230
- protected fitBatch(
231
- xs: tf.Tensor,
232
- ys: tf.Tensor,
233
- loss_mask: tf.Tensor | undefined,
234
- loss_function: LossOrMetricFn,
235
- other_masks?: { [key: string]: tf.Tensor | undefined }
236
- ): {
237
- y_pred: tf.Tensor<tf.Rank>;
238
- loss: tf.Scalar;
239
- } {
240
- let y_pred: tf.Tensor;
241
-
159
+ fitBatch(xs, ys, loss_mask, loss_function, other_masks) {
160
+ let y_pred;
242
161
  // forward pass, calculate loss
243
162
  const { value: loss, grads } = tf.variableGrads(() => {
244
163
  // prediction has shape [batch, sequence_length, vocab_size]
245
164
  y_pred = this.apply(xs, {
246
165
  training: true,
247
166
  ...other_masks
248
- }) as tf.Tensor;
249
-
167
+ });
250
168
  // manually dispose later instead of the built-in disposal from variableGrads
251
169
  tf.keep(y_pred);
252
-
253
170
  const loss = loss_mask
254
171
  ? loss_function(ys, y_pred).mul(loss_mask)
255
172
  : loss_function(ys, y_pred);
256
-
257
- return loss.mean() as tf.Scalar;
173
+ return loss.mean();
258
174
  });
259
-
260
175
  // backpropagation
261
176
  this.optimizer.applyGradients(grads);
262
-
263
177
  return {
264
- y_pred: y_pred!,
178
+ y_pred: y_pred,
265
179
  loss
266
180
  };
267
181
  }
268
-
269
-
270
- override compile(args: tf.ModelCompileArgs): void {
182
+ compile(args) {
271
183
  if (args.loss == "categoricalCrossentropy") {
272
- throw Error(`LlmModel.compile: use sparseCategoricalCrossentropy loss (along with onehot encoded labels) instead of categoricalCrossEntropy`)
184
+ throw Error(`LlmModel.compile: use sparseCategoricalCrossentropy loss (along with onehot encoded labels) instead of categoricalCrossEntropy`);
273
185
  }
274
-
275
186
  super.compile(args);
276
187
  }
277
-
278
-
279
188
  /**
280
189
  * Autoregressively generate the next token until `model.stopPredicting` is set
281
190
  * to `true` or the KV cache reaches its maximum sequence length. For a single chat
282
191
  * session, the input should only be the most recent prompt(s). The KV cache stores
283
192
  * the prior chat history up until the most recent chat.
284
- *
193
+ *
285
194
  * @param input tokenized input of the newest chat
286
195
  * @param kv_cache an instance of a KV cache container
287
196
  * @param onPredict callback function to receive the most recent token predicted
288
197
  */
289
- public async generate(input: tf.Tensor1D, kv_cache: KvCacheContainer, onPredict: (token: tf.Tensor) => Promise<void>) {
198
+ async generate(input, kv_cache, onPredict) {
290
199
  if (kv_cache.size >= kv_cache.maxSequenceLength) {
291
200
  throw Error(`LlmModel.generate: ${this.name} KV cache's size reached the maxSequenceLength (${kv_cache.maxSequenceLength})`);
292
201
  }
293
-
294
202
  this.stopPredicting = false;
295
-
296
- let current_token: tf.Tensor2D = tf.tidy(() => input.expandDims(0)) as tf.Tensor2D; // it's 2D because of the required batch dimension
297
-
203
+ let current_token = tf.tidy(() => input.expandDims(0)); // it's 2D because of the required batch dimension
298
204
  while (!this.stopPredicting && kv_cache.size < kv_cache.maxSequenceLength) {
299
205
  // add a batch dimension because forward pass requires inputs batched
300
206
  const next_token = tf.tidy(() => this.predictNextToken(current_token, kv_cache));
301
-
302
207
  // pass back the predicted token, without the batch dim,
303
208
  const unbatched_next_token = tf.tidy(() => next_token.squeeze([0]));
304
209
  await onPredict(unbatched_next_token);
305
-
306
210
  unbatched_next_token.dispose();
307
-
308
211
  current_token.dispose();
309
212
  current_token = next_token;
310
213
  }
311
-
312
214
  tf.dispose(current_token);
313
215
  }
314
-
315
-
316
216
  /**
317
217
  * Given a tokenized sentence, predict the next token (word).
318
218
  * A normal prediction is ran to get an output with the shape
@@ -321,35 +221,25 @@ export class LlmModel extends tf.Sequential {
321
221
  * position of `sentence_length` is returned as the next predicted
322
222
  * token.
323
223
  */
324
- public predictNextToken(input: tf.Tensor2D, kv_cache: KvCacheContainer) {
224
+ predictNextToken(input, kv_cache) {
325
225
  if (input.shape[0] != 1) {
326
226
  throw Error(`LlmModel.predictNextToken: ${this.name} expects an input with a batch size of 1`);
327
227
  }
328
-
329
228
  return tf.tidy(() => {
330
229
  // comes back as [batch, sequence_length, vocab_size]
331
- const prediction = this.apply(input, { kvCache: kv_cache }) as tf.Tensor;
332
-
230
+ const prediction = this.apply(input, { kvCache: kv_cache });
333
231
  const [batch_size, sequence_length, vocab_size] = prediction.shape;
334
-
335
232
  // get the last token
336
- const next_token = prediction.slice([0, sequence_length - 1, 0], [batch_size, 1, vocab_size]).argMax(2)
337
-
338
- return next_token as tf.Tensor2D;
339
- })
233
+ const next_token = prediction.slice([0, sequence_length - 1, 0], [batch_size, 1, vocab_size]).argMax(2);
234
+ return next_token;
235
+ });
340
236
  }
341
-
342
-
343
237
  get stopPredicting() {
344
238
  return this.stopPredicting_;
345
239
  }
346
-
347
-
348
- set stopPredicting(stop: boolean) {
240
+ set stopPredicting(stop) {
349
241
  this.stopPredicting_ = stop;
350
242
  }
351
-
352
243
  }
353
-
354
-
355
244
  tf.serialization.registerClass(LlmModel);
245
+ //# sourceMappingURL=llm_model.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"llm_model.js","sourceRoot":"","sources":["../../src/models/llm_model.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,6BAA6B,EAAE,MAAM,qCAAqC,CAAC;AAEpF,OAAO,EAAE,MAAM,IAAI,kBAAkB,EAAE,MAAM,UAAU,CAAC;AAExD,OAAO,KAAK,MAAM,MAAM,WAAW,CAAC;AAKnC,CAAC;AAWF;;;;GAIG;AACH,MAAM,OAAO,QAAS,SAAQ,EAAE,CAAC,UAAU;IACvC,MAAM,CAAC,SAAS,GAAG,UAAU,CAAC;IAEtB,eAAe,GAAY,IAAI,CAAC;IAExC,YAAY,IAAkB;QAC1B,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC,IAAI,IAAI,OAAO,CAAC;QACjC,KAAK,CAAC,IAAI,CAAC,CAAC;IAChB,CAAC;IAGD;;;;;OAKG;IACO,kBAAkB;QACxB,MAAM,CAAC,IAAI,EAAE,GAAG,eAAe,CAAC,GAAG,IAAI,CAAC,YAAY,CAAC;QAErD,OAAO,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,CAAC,aAAa,EAAE,KAAK,EAAE,EAAE,CAAC,CAAC;YACtD,SAAS,EAAE,aAAa,CAAC,CAAC,CAAC;YAC3B,YAAY,EAAE,eAAe,CAAC,KAAK,CAAC;SACvC,CAAC,CAAC,CAAA;IACP,CAAC;IAGD;;;OAGG;IACO,eAAe;QACrB,IAAI,IAAI,GAAG,IAAI,CAAC,IAAI,CAAC;QAErB,IAAI,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,EAAE,CAAC;YACtB,IAAI,GAAG,IAAI,CAAC,CAAC,CAAC,CAAC;QACnB,CAAC;QAED,IAAI,OAAO,IAAI,IAAI,QAAQ,EAAE,CAAC;YAC1B,IAAI,IAAI,IAAI,+BAA+B,EAAE,CAAC;gBAC1C,OAAO,6BAA6B,CAAC;gBACrC;;;gJAGgI;YACpI,CAAC;YAED,MAAM,OAAO,GAAG,IAAc,CAAC;YAE/B,MAAM,OAAO,GACT,CAAE,MAA8B,CAAC,OAAO,CAAC;gBACpC,EAAE,CAAC,MAA8B,CAAC,OAAO,CAAC;gBAC1C,EAAE,CAAC,OAA+B,CAAC,OAAO,CAAC,CAAmB,CAAA;YAEvE,IAAI,OAAO,EAAE,CAAC;gBACV,OAAO,OAAO,CAAA;YAClB,CAAC;iBAAM,CAAC;gBACJ,MAAM,KAAK,CAAC,6BAA6B,OAAO,+BAA+B,CAAC,CAAC;YACrF,CAAC;QACL,CAAC;aAAM,IAAI,OAAO,IAAI,IAAI,UAAU,EAAE,CAAC;YACnC,OAAO,IAAI,CAAC;QAChB,CAAC;QAED,MAAM,KAAK,CAAC,iFAAiF,CAAC,CAAC;IACnG,CAAC;IAGD;;;;;;;;OAQG;IACM,KAAK,CAAC,UAAU,CAAkB,SAAqB,EAAE,IAA+B;QAC7F,IAAI,CAAC,YAAY,GAAG,KAAK,CAAC;QAE1B,MAAM,OAAO,GAAG,SAAyC,CAAC;QAC1D,MAAM,EAAE,MAAM,EAAE,SAAS,EAAE,GAAG,IAAI,CAAC;QAEnC,MAAM,gBAAgB,GAAG,IAAI,CAAC,kBAAkB,EAAE,CAAC;QACnD,MAAM,aAAa,GAAG,IAAI,CAAC,eAAe,EAAE,CAAC;QAC7C,IAAI,CAAC,aAAa,GAAG,CAAC,aAAa,CAAC,CAAC;QAErC,MAAM,EACF,YAAY,EACZ,UAAU,EACV,YAAY,EACZ,UAAU,EACV,YAAY,EACZ,UAAU,GACb,GAAG,SAAkC,IAAI,EAAE,CAAC;QAE7C,MAAM,YAAY,EAAE,EAAE,CAAC;QAEvB,IAAI,kBAAkB,GAA0B,SAAS,CAAC;QAE1D,KAAK,IAAI,KAAK,GAAG,CAAC,EAAE,KAAK,GAAG,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC;YAC1C,MAAM,YAAY,EAAE,CAAC,KAAK,CAAC,CAAC;YAE5B,IAAI,KAAK,GAAG,CAAC,CAAC;YACd,IAAI,aAAa,GAAG,CAAC,CAAC;YACtB,MAAM,yBAAyB,GAAiC,EAAE,CAAC;YAEnE,0CAA0C;YAC1C,MAAM,QAAQ,GAAG,MAAM,OAAO,CAAC,QAAQ,EAAE,CAAC;YAC1C,IAAI,MAAM,GAAG,MAAM,QAAQ,CAAC,IAAI,EAAE,CAAC;YAEnC,OAAO,CAAC,MAAM,CAAC,IAAI,EAAE,CAAC;gBAClB,MAAM,aAAa,GAAiC,EAAE,KAAK,EAAE,CAAC;gBAE9D,MAAM,EAAE,EAAE,EAAE,EAAE,EAAE,SAAS,EAAE,YAAY,EAAE,GAAG,MAAM,CAAC,KAAK,CAAC;gBACzD,MAAM,UAAU,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;gBAC/B,aAAa,IAAI,UAAU,CAAC,CAAC,8BAA8B;gBAE3D,IAAI,EAAE,CAAC,KAAK,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;oBACvB,MAAM,KAAK,CAAC,wBAAwB,IAAI,CAAC,IAAI,0EAA0E,CAAC,CAAC;gBAC7H,CAAC;gBAED,iFAAiF;gBACjF,MAAM,UAAU,GAAG,EAAE,CAAC,KAAK,CAAC,EAAE,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;gBAEjD,IAAI,CAAC,kBAAkB,IAAI,kBAAkB,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,UAAU,EAAE,CAAC;oBACnE,kBAAkB,GAAG,kBAAkB,CAAC,UAAU,EAAE,UAAU,CAAC,CAAC;gBACpE,CAAC;gBAED,MAAM,YAAY,EAAE,CAAC,KAAK,CAAC,CAAC;gBAE5B,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;oBACT,MAAM,EAAE,MAAM,EAAE,IAAI,EAAE,GAAG,IAAI,CAAC,QAAQ,CAAC,EAAE,EAAE,EAAE,EAAE,SAAS,EAAE,aAAa,EAAE;wBACrE,WAAW,EAAE,YAAY;wBACzB,UAAU,EAAE,kBAAkB;qBACjC,CAAC,CAAA;oBAEF,MAAM,UAAU,GAAG,CAAC,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;oBAExC,aAAa,CAAC,IAAI,GAAG,UAAU,CAAC;oBAChC,yBAAyB,CAAC,IAAI,GAAG,CAAC,yBAAyB,CAAC,IAAI,IAAI,CAAC,CAAC,GAAG,UAAU,GAAG,UAAU,CAAC;oBAEjG,8BAA8B;oBAC9B,KAAK,MAAM,EAAE,SAAS,EAAE,YAAY,EAAE,IAAI,gBAAgB,EAAE,CAAC;wBACzD,MAAM,UAAU,GAAG,SAAS,CAAC,EAAE,EAAE,MAAO,CAAC,CAAC,IAAI,EAAE,CAAC;wBAEjD,MAAM,YAAY,GAAG,CAAC,UAAU,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;wBAEhD,aAAa,CAAC,YAAY,CAAC,GAAG,YAAY,CAAA,CAAA,gBAAgB;wBAC1D,yBAAyB,CAAC,YAAY,CAAC,GAAG,CAAC,yBAAyB,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC,GAAG,YAAY,GAAG,UAAU,CAAC;oBACzH,CAAC;oBAED,EAAE,CAAC,OAAO,CAAC,MAAO,CAAC,CAAC;gBACxB,CAAC,CAAC,CAAA;gBAEF,EAAE,CAAC,OAAO,CAAC,EAAE,CAAC,CAAC;gBACf,EAAE,CAAC,OAAO,CAAC,EAAE,CAAC,CAAC;gBACf,EAAE,CAAC,OAAO,CAAC,SAAS,CAAC,CAAC;gBAEtB,IAAI,YAAY,EAAE,CAAC;oBACf,EAAE,CAAC,OAAO,CAAC,YAAY,CAAC,CAAC;gBAC7B,CAAC;gBAED,MAAM,UAAU,EAAE,CAAC,KAAK,EAAE,aAAa,CAAC,CAAC;gBAEzC,8BAA8B;gBAC9B,MAAM,EAAE,CAAC,SAAS,EAAE,CAAC;gBAErB,IAAI,IAAI,CAAC,YAAY,EAAE,CAAC;oBACpB,MAAM;gBACV,CAAC;gBAED,MAAM,GAAG,MAAM,QAAQ,CAAC,IAAI,EAAE,CAAC;gBAC/B,KAAK,EAAE,CAAC;YACZ,CAAC;YAED,KAAK,MAAM,MAAM,IAAI,yBAAyB,EAAE,CAAC;gBAC7C,yBAAyB,CAAC,MAAM,CAAC,GAAG,yBAAyB,CAAC,MAAM,CAAC,GAAG,aAAa,CAAC;YAC1F,CAAC;YAED,MAAM,UAAU,EAAE,CAAC,KAAK,EAAE,yBAAyB,CAAC,CAAC;YAErD,IAAI,IAAI,CAAC,YAAY,EAAE,CAAC;gBACpB,MAAM;YACV,CAAC;QACL,CAAC;QAED,EAAE,CAAC,OAAO,CAAC,kBAAkB,CAAC,CAAC;QAC/B,MAAM,UAAU,EAAE,EAAE,CAAA;QAEpB,OAAO,EAAE,CAAC;IACd,CAAC;IAGD;;;;;;;;;OASG;IACO,QAAQ,CACd,EAAa,EACb,EAAa,EACb,SAAgC,EAChC,aAA6B,EAC7B,WAAsD;QAKtD,IAAI,MAAiB,CAAC;QAEtB,+BAA+B;QAC/B,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,CAAC,aAAa,CAAC,GAAG,EAAE;YACjD,4DAA4D;YAC5D,MAAM,GAAG,IAAI,CAAC,KAAK,CAAC,EAAE,EAAE;gBACpB,QAAQ,EAAE,IAAI;gBACd,GAAG,WAAW;aACjB,CAAc,CAAC;YAEhB,6EAA6E;YAC7E,EAAE,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;YAEhB,MAAM,IAAI,GAAG,SAAS;gBAClB,CAAC,CAAC,aAAa,CAAC,EAAE,EAAE,MAAM,CAAC,CAAC,GAAG,CAAC,SAAS,CAAC;gBAC1C,CAAC,CAAC,aAAa,CAAC,EAAE,EAAE,MAAM,CAAC,CAAC;YAEhC,OAAO,IAAI,CAAC,IAAI,EAAe,CAAC;QACpC,CAAC,CAAC,CAAC;QAEH,kBAAkB;QAClB,IAAI,CAAC,SAAS,CAAC,cAAc,CAAC,KAAK,CAAC,CAAC;QAErC,OAAO;YACH,MAAM,EAAE,MAAO;YACf,IAAI;SACP,CAAC;IACN,CAAC;IAGQ,OAAO,CAAC,IAAyB;QACtC,IAAI,IAAI,CAAC,IAAI,IAAI,yBAAyB,EAAE,CAAC;YACzC,MAAM,KAAK,CAAC,gIAAgI,CAAC,CAAA;QACjJ,CAAC;QAED,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,CAAC;IACxB,CAAC;IAGD;;;;;;;;;OASG;IACI,KAAK,CAAC,QAAQ,CAAC,KAAkB,EAAE,QAA0B,EAAE,SAA8C;QAChH,IAAI,QAAQ,CAAC,IAAI,IAAI,QAAQ,CAAC,iBAAiB,EAAE,CAAC;YAC9C,MAAM,KAAK,CAAC,sBAAsB,IAAI,CAAC,IAAI,mDAAmD,QAAQ,CAAC,iBAAiB,GAAG,CAAC,CAAC;QACjI,CAAC;QAED,IAAI,CAAC,cAAc,GAAG,KAAK,CAAC;QAE5B,IAAI,aAAa,GAAgB,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,CAAgB,CAAC,CAAC,kDAAkD;QAEtI,OAAO,CAAC,IAAI,CAAC,cAAc,IAAI,QAAQ,CAAC,IAAI,GAAG,QAAQ,CAAC,iBAAiB,EAAE,CAAC;YACxE,qEAAqE;YACrE,MAAM,UAAU,GAAG,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,IAAI,CAAC,gBAAgB,CAAC,aAAa,EAAE,QAAQ,CAAC,CAAC,CAAC;YAEjF,wDAAwD;YACxD,MAAM,oBAAoB,GAAG,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YACpE,MAAM,SAAS,CAAC,oBAAoB,CAAC,CAAC;YAEtC,oBAAoB,CAAC,OAAO,EAAE,CAAC;YAE/B,aAAa,CAAC,OAAO,EAAE,CAAC;YACxB,aAAa,GAAG,UAAU,CAAC;QAC/B,CAAC;QAED,EAAE,CAAC,OAAO,CAAC,aAAa,CAAC,CAAC;IAC9B,CAAC;IAGD;;;;;;;OAOG;IACI,gBAAgB,CAAC,KAAkB,EAAE,QAA0B;QAClE,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,CAAC,EAAE,CAAC;YACtB,MAAM,KAAK,CAAC,8BAA8B,IAAI,CAAC,IAAI,0CAA0C,CAAC,CAAC;QACnG,CAAC;QAED,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,qDAAqD;YACrD,MAAM,UAAU,GAAG,IAAI,CAAC,KAAK,CAAC,KAAK,EAAE,EAAE,OAAO,EAAE,QAAQ,EAAE,CAAc,CAAC;YAEzE,MAAM,CAAC,UAAU,EAAE,eAAe,EAAE,UAAU,CAAC,GAAG,UAAU,CAAC,KAAK,CAAC;YAEnE,qBAAqB;YACrB,MAAM,UAAU,GAAG,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,eAAe,GAAG,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,UAAU,EAAE,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAA;YAEvG,OAAO,UAAyB,CAAC;QACrC,CAAC,CAAC,CAAA;IACN,CAAC;IAGD,IAAI,cAAc;QACd,OAAO,IAAI,CAAC,eAAe,CAAC;IAChC,CAAC;IAGD,IAAI,cAAc,CAAC,IAAa;QAC5B,IAAI,CAAC,eAAe,GAAG,IAAI,CAAC;IAChC,CAAC;;AAKL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,QAAQ,CAAC,CAAC"}
@@ -0,0 +1,40 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
3
+ export interface UNetArgs {
4
+ /**
5
+ * The starting number of filters.
6
+ */
7
+ filters: number;
8
+ /**
9
+ * The number of categories. For binary segmentation, `units=1`.
10
+ */
11
+ units: number;
12
+ /**
13
+ * The activation of the final output convolution layer. Defaults to `sigmoid` if `categories=1`, else `softmax`.
14
+ */
15
+ activation?: ActivationIdentifier;
16
+ /**
17
+ * The depth of the U-Net or the number of contractions and the number of expansions.
18
+ */
19
+ depth: number;
20
+ /**
21
+ * Adds residual connections to transform the model into a ResUNet. Defaults to `false`.
22
+ */
23
+ residual?: boolean;
24
+ /**
25
+ * Adds batch normalization to convolutions. Best used for batched inputs. Defaults to `false`.
26
+ */
27
+ batchNorm?: boolean;
28
+ /**
29
+ * Set the unbatched input shape of the U-Net in the format `[height, width, channels]`. Defaults to `[null, null, 3]`. If set, only channels is mandatory.
30
+ */
31
+ inputShape?: [number | null, number | null, number];
32
+ }
33
+ export type UNetModelArgs = UNetArgs & Omit<tf.SequentialArgs, "layers">;
34
+ export declare class UNetModel extends tf.Sequential {
35
+ constructor(args: UNetModelArgs);
36
+ summary(lineLength?: number, positions?: number[], printFn?: (message?: any, ...optionalParams: any[]) => void): void;
37
+ }
38
+ export declare function createUNet({ filters, depth, units, activation, residual, batchNorm, inputShape }: UNetModelArgs): tf.LayersModel;
39
+ export declare function loadUNetModel(pathOrIOHandler: string | tf.io.IOHandler, options?: tf.io.LoadOptions): Promise<tf.LayersModel>;
40
+ //# sourceMappingURL=u_net.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"u_net.d.ts","sourceRoot":"","sources":["../../src/models/u_net.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,oBAAoB,EAAE,MAAM,6DAA6D,CAAC;AAGxG,MAAM,WAAW,QAAQ;IACrB;;OAEG;IACH,OAAO,EAAE,MAAM,CAAC;IAChB;;OAEG;IACH,KAAK,EAAE,MAAM,CAAC;IACd;;OAEG;IACH,UAAU,CAAC,EAAE,oBAAoB,CAAC;IAClC;;OAEG;IACH,KAAK,EAAE,MAAM,CAAC;IACd;;OAEG;IACH,QAAQ,CAAC,EAAE,OAAO,CAAC;IACnB;;OAEG;IACH,SAAS,CAAC,EAAE,OAAO,CAAC;IACpB;;OAEG;IACH,UAAU,CAAC,EAAE,CAAC,MAAM,GAAG,IAAI,EAAE,MAAM,GAAG,IAAI,EAAE,MAAM,CAAC,CAAC;CACvD;AAGD,MAAM,MAAM,aAAa,GAAG,QAAQ,GAAG,IAAI,CAAC,EAAE,CAAC,cAAc,EAAE,QAAQ,CAAC,CAAC;AAGzE,qBAAa,SAAU,SAAQ,EAAE,CAAC,UAAU;gBAE5B,IAAI,EAAE,aAAa;IAsBtB,OAAO,CAAC,UAAU,CAAC,EAAE,MAAM,EAAE,SAAS,CAAC,EAAE,MAAM,EAAE,EAAE,OAAO,CAAC,EAAE,CAAC,OAAO,CAAC,EAAE,GAAG,EAAE,GAAG,cAAc,EAAE,GAAG,EAAE,KAAK,IAAI,GAAG,IAAI;CAIjI;AAGD,wBAAgB,UAAU,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,KAAK,EAAE,UAAU,EAAE,QAAgB,EAAE,SAAiB,EAAE,UAA4B,EAAE,EAAE,aAAa,kBA4CjJ;AAGD,wBAAsB,aAAa,CAAC,eAAe,EAAE,MAAM,GAAG,EAAE,CAAC,EAAE,CAAC,SAAS,EAAE,OAAO,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,WAAW,2BAOzG"}