@stellarapp/tfjs-stellar 1.0.0 → 1.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +47 -0
- package/dist/index.d.ts +7 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +7 -0
- package/dist/index.js.map +1 -0
- package/dist/jest.config.d.ts +8 -0
- package/dist/jest.config.d.ts.map +1 -0
- package/{jest.config.ts → dist/jest.config.js} +8 -64
- package/dist/jest.config.js.map +1 -0
- package/dist/kv_cache.d.ts +53 -0
- package/dist/kv_cache.d.ts.map +1 -0
- package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
- package/dist/kv_cache.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.js +76 -0
- package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
- package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
- package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
- package/dist/layers/gpt_decoder_block.js.map +1 -0
- package/dist/layers/index.d.ts +17 -0
- package/dist/layers/index.d.ts.map +1 -0
- package/dist/layers/index.js +33 -0
- package/dist/layers/index.js.map +1 -0
- package/dist/layers/multihead_attention.d.ts +106 -0
- package/dist/layers/multihead_attention.d.ts.map +1 -0
- package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
- package/dist/layers/multihead_attention.js.map +1 -0
- package/dist/layers/multihead_attention.test.d.ts +2 -0
- package/dist/layers/multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
- package/dist/layers/multihead_attention.test.js.map +1 -0
- package/dist/layers/positional_encoding.d.ts +37 -0
- package/dist/layers/positional_encoding.d.ts.map +1 -0
- package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
- package/dist/layers/positional_encoding.js.map +1 -0
- package/dist/layers/positional_encoding.test.d.ts +2 -0
- package/dist/layers/positional_encoding.test.d.ts.map +1 -0
- package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
- package/dist/layers/positional_encoding.test.js.map +1 -0
- package/dist/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
- package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
- package/dist/layers/rotary_position_embedding.js.map +1 -0
- package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/layers/rotary_position_embedding.test.js +88 -0
- package/dist/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
- package/dist/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
- package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/layers/transformer_decoder.d.ts +69 -0
- package/dist/layers/transformer_decoder.d.ts.map +1 -0
- package/dist/layers/transformer_decoder.js +182 -0
- package/dist/layers/transformer_decoder.js.map +1 -0
- package/dist/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
- package/dist/layers/transformer_decoder.test.js.map +1 -0
- package/dist/layers/transformer_encoder.d.ts +55 -0
- package/dist/layers/transformer_encoder.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
- package/dist/layers/transformer_encoder.js.map +1 -0
- package/dist/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
- package/dist/layers/transformer_encoder.test.js.map +1 -0
- package/dist/losses/dice.d.ts +30 -0
- package/dist/losses/dice.d.ts.map +1 -0
- package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
- package/dist/losses/dice.js.map +1 -0
- package/dist/losses/index.d.ts +2 -0
- package/dist/losses/index.d.ts.map +1 -0
- package/dist/losses/index.js +2 -0
- package/dist/losses/index.js.map +1 -0
- package/dist/masks.d.ts +20 -0
- package/dist/masks.d.ts.map +1 -0
- package/{src/packing_mask.ts → dist/masks.js} +16 -7
- package/dist/masks.js.map +1 -0
- package/dist/metrics.d.ts +20 -0
- package/dist/metrics.d.ts.map +1 -0
- package/{src/metrics.ts → dist/metrics.js} +8 -12
- package/dist/metrics.js.map +1 -0
- package/dist/models/gpt_model.d.ts +94 -0
- package/dist/models/gpt_model.d.ts.map +1 -0
- package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
- package/dist/models/gpt_model.js.map +1 -0
- package/dist/models/index.d.ts +7 -0
- package/dist/models/index.d.ts.map +1 -0
- package/dist/models/index.js +13 -0
- package/dist/models/index.js.map +1 -0
- package/dist/models/llm_model.d.ts +87 -0
- package/dist/models/llm_model.d.ts.map +1 -0
- package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
- package/dist/models/llm_model.js.map +1 -0
- package/dist/models/u_net.d.ts +40 -0
- package/dist/models/u_net.d.ts.map +1 -0
- package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
- package/dist/models/u_net.js.map +1 -0
- package/dist/src/index.d.ts +6 -0
- package/dist/src/index.d.ts.map +1 -0
- package/dist/src/index.js +6 -0
- package/dist/src/index.js.map +1 -0
- package/dist/src/kv_cache.d.ts +53 -0
- package/dist/src/kv_cache.d.ts.map +1 -0
- package/dist/src/kv_cache.js +135 -0
- package/dist/src/kv_cache.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
- package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
- package/dist/src/layers/gpt_decoder_block.js +51 -0
- package/dist/src/layers/gpt_decoder_block.js.map +1 -0
- package/dist/src/layers/index.d.ts +17 -0
- package/dist/src/layers/index.d.ts.map +1 -0
- package/dist/src/layers/index.js +33 -0
- package/dist/src/layers/index.js.map +1 -0
- package/dist/src/layers/multihead_attention.d.ts +106 -0
- package/dist/src/layers/multihead_attention.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.js +269 -0
- package/dist/src/layers/multihead_attention.js.map +1 -0
- package/dist/src/layers/multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.test.js +160 -0
- package/dist/src/layers/multihead_attention.test.js.map +1 -0
- package/dist/src/layers/positional_encoding.d.ts +37 -0
- package/dist/src/layers/positional_encoding.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.js +115 -0
- package/dist/src/layers/positional_encoding.js.map +1 -0
- package/dist/src/layers/positional_encoding.test.d.ts +2 -0
- package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.test.js +95 -0
- package/dist/src/layers/positional_encoding.test.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.js +99 -0
- package/dist/src/layers/rotary_position_embedding.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.js +88 -0
- package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.js +109 -0
- package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
- package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/src/layers/transformer_decoder.d.ts +69 -0
- package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
- package/dist/src/layers/transformer_decoder.js.map +1 -0
- package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_decoder.test.js +72 -0
- package/dist/src/layers/transformer_decoder.test.js.map +1 -0
- package/dist/src/layers/transformer_encoder.d.ts +55 -0
- package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.js +175 -0
- package/dist/src/layers/transformer_encoder.js.map +1 -0
- package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.test.js +58 -0
- package/dist/src/layers/transformer_encoder.test.js.map +1 -0
- package/dist/src/losses/dice.d.ts +30 -0
- package/dist/src/losses/dice.d.ts.map +1 -0
- package/dist/src/losses/dice.js +93 -0
- package/dist/src/losses/dice.js.map +1 -0
- package/dist/src/losses/index.d.ts +2 -0
- package/dist/src/losses/index.d.ts.map +1 -0
- package/dist/src/losses/index.js +2 -0
- package/dist/src/losses/index.js.map +1 -0
- package/dist/src/masks.d.ts +20 -0
- package/dist/src/masks.d.ts.map +1 -0
- package/dist/src/masks.js +37 -0
- package/dist/src/masks.js.map +1 -0
- package/dist/src/metrics.d.ts +20 -0
- package/dist/src/metrics.d.ts.map +1 -0
- package/dist/src/metrics.js +28 -0
- package/dist/src/metrics.js.map +1 -0
- package/dist/src/models/gpt_model.d.ts +94 -0
- package/dist/src/models/gpt_model.d.ts.map +1 -0
- package/dist/src/models/gpt_model.js +154 -0
- package/dist/src/models/gpt_model.js.map +1 -0
- package/dist/src/models/index.d.ts +3 -0
- package/dist/src/models/index.d.ts.map +1 -0
- package/{src/models/index.ts → dist/src/models/index.js} +1 -0
- package/dist/src/models/index.js.map +1 -0
- package/dist/src/models/llm_model.d.ts +87 -0
- package/dist/src/models/llm_model.d.ts.map +1 -0
- package/dist/src/models/llm_model.js +245 -0
- package/dist/src/models/llm_model.js.map +1 -0
- package/dist/src/models/u_net.d.ts +40 -0
- package/dist/src/models/u_net.d.ts.map +1 -0
- package/dist/src/models/u_net.js +151 -0
- package/dist/src/models/u_net.js.map +1 -0
- package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
- package/dist/src/tfjs_types.d.ts.map +1 -0
- package/dist/src/tfjs_types.js +2 -0
- package/dist/src/tfjs_types.js.map +1 -0
- package/dist/src/utils.d.ts +28 -0
- package/dist/src/utils.d.ts.map +1 -0
- package/{src/utils.ts → dist/src/utils.js} +10 -33
- package/dist/src/utils.js.map +1 -0
- package/dist/src/utils.test.d.ts +2 -0
- package/dist/src/utils.test.d.ts.map +1 -0
- package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
- package/dist/src/utils.test.js.map +1 -0
- package/dist/tfjs_types.d.ts +10 -0
- package/dist/tfjs_types.d.ts.map +1 -0
- package/dist/tfjs_types.js +2 -0
- package/dist/tfjs_types.js.map +1 -0
- package/dist/utils.d.ts +28 -0
- package/dist/utils.d.ts.map +1 -0
- package/dist/utils.js +63 -0
- package/dist/utils.js.map +1 -0
- package/dist/utils.test.d.ts +2 -0
- package/dist/utils.test.d.ts.map +1 -0
- package/dist/utils.test.js +73 -0
- package/dist/utils.test.js.map +1 -0
- package/package.json +14 -4
- package/src/index.ts +0 -93
- package/src/layers/rotary_position_embedding.test.ts +0 -107
- package/src/losses/index.ts +0 -1
- package/src/testing.ts +0 -1
- package/tsconfig.json +0 -49
|
@@ -1,24 +1,8 @@
|
|
|
1
1
|
import * as tf from "@tensorflow/tfjs";
|
|
2
|
-
import * as tfc from "@/index";
|
|
3
2
|
import { sparseCategoricalCrossentropy } from "@tensorflow/tfjs-layers/dist/losses";
|
|
4
|
-
import {
|
|
5
|
-
import
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
// eslint-disable-next-line
|
|
10
|
-
export interface LlmModelArgs extends tf.SequentialArgs {
|
|
11
|
-
};
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
interface DatasetArgs extends tf.TensorContainerObject {
|
|
15
|
-
xs: tf.Tensor;
|
|
16
|
-
ys: tf.Tensor;
|
|
17
|
-
loss_mask?: tf.Tensor;
|
|
18
|
-
packing_mask?: tf.Tensor;
|
|
19
|
-
}
|
|
20
|
-
|
|
21
|
-
|
|
3
|
+
import { causal as generateCausalMask } from "../masks";
|
|
4
|
+
import * as losses from "../losses";
|
|
5
|
+
;
|
|
22
6
|
/**
|
|
23
7
|
* This class overrides the `fitDataset()` function of tf.Sequential to support loss
|
|
24
8
|
* and packing masking. Use the `generate()` function to autoregressively predict the
|
|
@@ -26,42 +10,33 @@ interface DatasetArgs extends tf.TensorContainerObject {
|
|
|
26
10
|
*/
|
|
27
11
|
export class LlmModel extends tf.Sequential {
|
|
28
12
|
static className = "LlmModel";
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
constructor(args: LlmModelArgs) {
|
|
13
|
+
stopPredicting_ = true;
|
|
14
|
+
constructor(args) {
|
|
33
15
|
args.name = args.name ?? "model";
|
|
34
16
|
super(args);
|
|
35
17
|
}
|
|
36
|
-
|
|
37
|
-
|
|
38
18
|
/**
|
|
39
19
|
* Returns the metric functions and names so that metrics can be reported
|
|
40
20
|
* as they are in the base version of model.fitDataset
|
|
41
|
-
*
|
|
21
|
+
*
|
|
42
22
|
* e.g. "categoricalAccuracy" should be reported as "acc"
|
|
43
23
|
*/
|
|
44
|
-
|
|
24
|
+
getMetricFunctions() {
|
|
45
25
|
const [loss, ...metric_fn_names] = this.metricsNames;
|
|
46
|
-
|
|
47
26
|
return this.metricsTensors.map((metric_tensor, index) => ({
|
|
48
27
|
metric_fn: metric_tensor[0],
|
|
49
28
|
metric_label: metric_fn_names[index]
|
|
50
|
-
}))
|
|
29
|
+
}));
|
|
51
30
|
}
|
|
52
|
-
|
|
53
|
-
|
|
54
31
|
/**
|
|
55
32
|
* Get exactly one loss function from the loss function provided in `model.compile()`.
|
|
56
33
|
* If a string identifier was used, convert it to the actual loss function.
|
|
57
34
|
*/
|
|
58
|
-
|
|
35
|
+
getLossFunction() {
|
|
59
36
|
let loss = this.loss;
|
|
60
|
-
|
|
61
37
|
if (Array.isArray(loss)) {
|
|
62
38
|
loss = loss[0];
|
|
63
39
|
}
|
|
64
|
-
|
|
65
40
|
if (typeof loss == "string") {
|
|
66
41
|
if (loss == "sparseCategoricalCrossentropy") {
|
|
67
42
|
return sparseCategoricalCrossentropy;
|
|
@@ -70,249 +45,174 @@ export class LlmModel extends tf.Sequential {
|
|
|
70
45
|
" Use categoricalCrossentropy instead. See" +
|
|
71
46
|
" https://github.com/tensorflow/tfjs/blob/0fc04d958ea592f3b8db79a8b3b497b5c8904097/tfjs-layers/src/losses.ts#L143-L146"); */
|
|
72
47
|
}
|
|
73
|
-
|
|
74
|
-
const
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
((tfc.losses as Record<string, any>)[loss_id] ??
|
|
78
|
-
(tf.losses as Record<string, any>)[loss_id] ??
|
|
79
|
-
(tf.metrics as Record<string, any>)[loss_id]) as LossOrMetricFn
|
|
80
|
-
|
|
48
|
+
const loss_id = loss;
|
|
49
|
+
const loss_fn = (losses[loss_id] ??
|
|
50
|
+
tf.losses[loss_id] ??
|
|
51
|
+
tf.metrics[loss_id]);
|
|
81
52
|
if (loss_fn) {
|
|
82
|
-
return loss_fn
|
|
83
|
-
}
|
|
53
|
+
return loss_fn;
|
|
54
|
+
}
|
|
55
|
+
else {
|
|
84
56
|
throw Error(`LlmModel.getLossFunction: ${loss_id} is not a valid loss function`);
|
|
85
57
|
}
|
|
86
|
-
}
|
|
58
|
+
}
|
|
59
|
+
else if (typeof loss == "function") {
|
|
87
60
|
return loss;
|
|
88
61
|
}
|
|
89
|
-
|
|
90
62
|
throw Error("LlmModel.getLossFunction: the loss function's type should be string or function");
|
|
91
63
|
}
|
|
92
|
-
|
|
93
|
-
|
|
94
64
|
/**
|
|
95
65
|
* Train on a `tf.data.generator` dataset. See https://js.tensorflow.org/api/latest/#data.generator.
|
|
96
|
-
*
|
|
66
|
+
*
|
|
97
67
|
* The generator should yield `xs`, `ys`, `loss_mask` (if fine-tuning), and
|
|
98
68
|
* `packing_mask` (if sequence packing was done)
|
|
99
|
-
*
|
|
69
|
+
*
|
|
100
70
|
* @param tfdataset an instance of a `tf.Dataset` generator
|
|
101
71
|
* @param args a ModelFitDatasetArgs
|
|
102
72
|
*/
|
|
103
|
-
|
|
73
|
+
async fitDataset(tfdataset, args) {
|
|
104
74
|
this.stopTraining = false;
|
|
105
|
-
|
|
106
|
-
const dataset = tfdataset as tf.data.Dataset<DatasetArgs>;
|
|
75
|
+
const dataset = tfdataset;
|
|
107
76
|
const { epochs, callbacks } = args;
|
|
108
|
-
|
|
109
77
|
const metric_functions = this.getMetricFunctions();
|
|
110
78
|
const loss_function = this.getLossFunction();
|
|
111
79
|
this.lossFunctions = [loss_function];
|
|
112
|
-
|
|
113
|
-
const {
|
|
114
|
-
onBatchBegin,
|
|
115
|
-
onBatchEnd,
|
|
116
|
-
onEpochBegin,
|
|
117
|
-
onEpochEnd,
|
|
118
|
-
onTrainBegin,
|
|
119
|
-
onTrainEnd,
|
|
120
|
-
} = callbacks as tf.CustomCallbackArgs ?? {};
|
|
121
|
-
|
|
80
|
+
const { onBatchBegin, onBatchEnd, onEpochBegin, onEpochEnd, onTrainBegin, onTrainEnd, } = callbacks ?? {};
|
|
122
81
|
await onTrainBegin?.();
|
|
123
|
-
|
|
124
|
-
let cached_causal_mask: tf.Tensor | undefined = undefined;
|
|
125
|
-
|
|
82
|
+
let cached_causal_mask = undefined;
|
|
126
83
|
for (let epoch = 0; epoch < epochs; epoch++) {
|
|
127
84
|
await onEpochBegin?.(epoch);
|
|
128
|
-
|
|
129
85
|
let batch = 0;
|
|
130
86
|
let total_samples = 0;
|
|
131
|
-
const accumulated_epoch_metrics
|
|
132
|
-
|
|
87
|
+
const accumulated_epoch_metrics = {};
|
|
133
88
|
// loop through dataset using its iterator
|
|
134
89
|
const iterator = await dataset.iterator();
|
|
135
90
|
let sample = await iterator.next();
|
|
136
|
-
|
|
137
91
|
while (!sample.done) {
|
|
138
|
-
const batch_metrics
|
|
139
|
-
|
|
92
|
+
const batch_metrics = { batch };
|
|
140
93
|
const { xs, ys, loss_mask, packing_mask } = sample.value;
|
|
141
94
|
const batch_size = xs.shape[0];
|
|
142
95
|
total_samples += batch_size; // for epoch metrics averaging
|
|
143
|
-
|
|
144
96
|
if (xs.shape.length != 2) {
|
|
145
97
|
throw Error(`LlmModel.fitDataset: ${this.name} the generator dataset should be batched, run: dataset.batch(batch_size)`);
|
|
146
98
|
}
|
|
147
|
-
|
|
148
99
|
// pre-calculate the causal attention mask and reuse it for all attention layers,
|
|
149
100
|
const seq_length = xs.shape[xs.shape.length - 1];
|
|
150
|
-
|
|
151
101
|
if (!cached_causal_mask || cached_causal_mask.shape[0] != seq_length) {
|
|
152
|
-
cached_causal_mask =
|
|
102
|
+
cached_causal_mask = generateCausalMask(seq_length, seq_length);
|
|
153
103
|
}
|
|
154
|
-
|
|
155
104
|
await onBatchBegin?.(batch);
|
|
156
|
-
|
|
157
105
|
tf.tidy(() => {
|
|
158
106
|
const { y_pred, loss } = this.fitBatch(xs, ys, loss_mask, loss_function, {
|
|
159
107
|
packingMask: packing_mask,
|
|
160
108
|
causalMask: cached_causal_mask
|
|
161
|
-
})
|
|
162
|
-
|
|
109
|
+
});
|
|
163
110
|
const loss_value = (loss.dataSync())[0];
|
|
164
|
-
|
|
165
111
|
batch_metrics.loss = loss_value;
|
|
166
112
|
accumulated_epoch_metrics.loss = (accumulated_epoch_metrics.loss || 0) + loss_value * batch_size;
|
|
167
|
-
|
|
168
113
|
// calculate and store metrics
|
|
169
114
|
for (const { metric_fn, metric_label } of metric_functions) {
|
|
170
|
-
const metric_sum = metric_fn(ys, y_pred
|
|
171
|
-
|
|
115
|
+
const metric_sum = metric_fn(ys, y_pred).mean();
|
|
172
116
|
const metric_value = (metric_sum.dataSync())[0];
|
|
173
|
-
|
|
174
|
-
batch_metrics[metric_label] = metric_value// / batch_size;
|
|
117
|
+
batch_metrics[metric_label] = metric_value; // / batch_size;
|
|
175
118
|
accumulated_epoch_metrics[metric_label] = (accumulated_epoch_metrics[metric_label] || 0) + metric_value * batch_size;
|
|
176
119
|
}
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
})
|
|
180
|
-
|
|
120
|
+
tf.dispose(y_pred);
|
|
121
|
+
});
|
|
181
122
|
tf.dispose(xs);
|
|
182
123
|
tf.dispose(ys);
|
|
183
124
|
tf.dispose(loss_mask);
|
|
184
|
-
|
|
185
125
|
if (packing_mask) {
|
|
186
126
|
tf.dispose(packing_mask);
|
|
187
127
|
}
|
|
188
|
-
|
|
189
128
|
await onBatchEnd?.(batch, batch_metrics);
|
|
190
|
-
|
|
191
129
|
// so that stop training works
|
|
192
130
|
await tf.nextFrame();
|
|
193
|
-
|
|
194
131
|
if (this.stopTraining) {
|
|
195
132
|
break;
|
|
196
133
|
}
|
|
197
|
-
|
|
198
134
|
sample = await iterator.next();
|
|
199
135
|
batch++;
|
|
200
136
|
}
|
|
201
|
-
|
|
202
137
|
for (const metric in accumulated_epoch_metrics) {
|
|
203
138
|
accumulated_epoch_metrics[metric] = accumulated_epoch_metrics[metric] / total_samples;
|
|
204
139
|
}
|
|
205
|
-
|
|
206
140
|
await onEpochEnd?.(epoch, accumulated_epoch_metrics);
|
|
207
|
-
|
|
208
141
|
if (this.stopTraining) {
|
|
209
142
|
break;
|
|
210
143
|
}
|
|
211
144
|
}
|
|
212
|
-
|
|
213
145
|
tf.dispose(cached_causal_mask);
|
|
214
|
-
await onTrainEnd?.()
|
|
215
|
-
|
|
146
|
+
await onTrainEnd?.();
|
|
216
147
|
return {};
|
|
217
148
|
}
|
|
218
|
-
|
|
219
|
-
|
|
220
149
|
/**
|
|
221
150
|
* Run the core forward and backward propagation on one training batch. This
|
|
222
151
|
* should be called within a `tf.tidy()`.
|
|
223
|
-
*
|
|
152
|
+
*
|
|
224
153
|
* @param xs the sample/input tensor
|
|
225
154
|
* @param ys the label/target tensor
|
|
226
155
|
* @param loss_mask a loss mask to ignore the prediction's non-assistant tokens
|
|
227
156
|
* @param loss_function the model's loss function
|
|
228
157
|
* @param other_masks other masks used by the model's layers e.g. packing mask, causal mask
|
|
229
158
|
*/
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
ys: tf.Tensor,
|
|
233
|
-
loss_mask: tf.Tensor | undefined,
|
|
234
|
-
loss_function: LossOrMetricFn,
|
|
235
|
-
other_masks?: { [key: string]: tf.Tensor | undefined }
|
|
236
|
-
): {
|
|
237
|
-
y_pred: tf.Tensor<tf.Rank>;
|
|
238
|
-
loss: tf.Scalar;
|
|
239
|
-
} {
|
|
240
|
-
let y_pred: tf.Tensor;
|
|
241
|
-
|
|
159
|
+
fitBatch(xs, ys, loss_mask, loss_function, other_masks) {
|
|
160
|
+
let y_pred;
|
|
242
161
|
// forward pass, calculate loss
|
|
243
162
|
const { value: loss, grads } = tf.variableGrads(() => {
|
|
244
163
|
// prediction has shape [batch, sequence_length, vocab_size]
|
|
245
164
|
y_pred = this.apply(xs, {
|
|
246
165
|
training: true,
|
|
247
166
|
...other_masks
|
|
248
|
-
})
|
|
249
|
-
|
|
167
|
+
});
|
|
250
168
|
// manually dispose later instead of the built-in disposal from variableGrads
|
|
251
169
|
tf.keep(y_pred);
|
|
252
|
-
|
|
253
170
|
const loss = loss_mask
|
|
254
171
|
? loss_function(ys, y_pred).mul(loss_mask)
|
|
255
172
|
: loss_function(ys, y_pred);
|
|
256
|
-
|
|
257
|
-
return loss.mean() as tf.Scalar;
|
|
173
|
+
return loss.mean();
|
|
258
174
|
});
|
|
259
|
-
|
|
260
175
|
// backpropagation
|
|
261
176
|
this.optimizer.applyGradients(grads);
|
|
262
|
-
|
|
263
177
|
return {
|
|
264
|
-
y_pred: y_pred
|
|
178
|
+
y_pred: y_pred,
|
|
265
179
|
loss
|
|
266
180
|
};
|
|
267
181
|
}
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
override compile(args: tf.ModelCompileArgs): void {
|
|
182
|
+
compile(args) {
|
|
271
183
|
if (args.loss == "categoricalCrossentropy") {
|
|
272
|
-
throw Error(`LlmModel.compile: use sparseCategoricalCrossentropy loss (along with onehot encoded labels) instead of categoricalCrossEntropy`)
|
|
184
|
+
throw Error(`LlmModel.compile: use sparseCategoricalCrossentropy loss (along with onehot encoded labels) instead of categoricalCrossEntropy`);
|
|
273
185
|
}
|
|
274
|
-
|
|
275
186
|
super.compile(args);
|
|
276
187
|
}
|
|
277
|
-
|
|
278
|
-
|
|
279
188
|
/**
|
|
280
189
|
* Autoregressively generate the next token until `model.stopPredicting` is set
|
|
281
190
|
* to `true` or the KV cache reaches its maximum sequence length. For a single chat
|
|
282
191
|
* session, the input should only be the most recent prompt(s). The KV cache stores
|
|
283
192
|
* the prior chat history up until the most recent chat.
|
|
284
|
-
*
|
|
193
|
+
*
|
|
285
194
|
* @param input tokenized input of the newest chat
|
|
286
195
|
* @param kv_cache an instance of a KV cache container
|
|
287
196
|
* @param onPredict callback function to receive the most recent token predicted
|
|
288
197
|
*/
|
|
289
|
-
|
|
198
|
+
async generate(input, kv_cache, onPredict) {
|
|
290
199
|
if (kv_cache.size >= kv_cache.maxSequenceLength) {
|
|
291
200
|
throw Error(`LlmModel.generate: ${this.name} KV cache's size reached the maxSequenceLength (${kv_cache.maxSequenceLength})`);
|
|
292
201
|
}
|
|
293
|
-
|
|
294
202
|
this.stopPredicting = false;
|
|
295
|
-
|
|
296
|
-
let current_token: tf.Tensor2D = tf.tidy(() => input.expandDims(0)) as tf.Tensor2D; // it's 2D because of the required batch dimension
|
|
297
|
-
|
|
203
|
+
let current_token = tf.tidy(() => input.expandDims(0)); // it's 2D because of the required batch dimension
|
|
298
204
|
while (!this.stopPredicting && kv_cache.size < kv_cache.maxSequenceLength) {
|
|
299
205
|
// add a batch dimension because forward pass requires inputs batched
|
|
300
206
|
const next_token = tf.tidy(() => this.predictNextToken(current_token, kv_cache));
|
|
301
|
-
|
|
302
207
|
// pass back the predicted token, without the batch dim,
|
|
303
208
|
const unbatched_next_token = tf.tidy(() => next_token.squeeze([0]));
|
|
304
209
|
await onPredict(unbatched_next_token);
|
|
305
|
-
|
|
306
210
|
unbatched_next_token.dispose();
|
|
307
|
-
|
|
308
211
|
current_token.dispose();
|
|
309
212
|
current_token = next_token;
|
|
310
213
|
}
|
|
311
|
-
|
|
312
214
|
tf.dispose(current_token);
|
|
313
215
|
}
|
|
314
|
-
|
|
315
|
-
|
|
316
216
|
/**
|
|
317
217
|
* Given a tokenized sentence, predict the next token (word).
|
|
318
218
|
* A normal prediction is ran to get an output with the shape
|
|
@@ -321,35 +221,25 @@ export class LlmModel extends tf.Sequential {
|
|
|
321
221
|
* position of `sentence_length` is returned as the next predicted
|
|
322
222
|
* token.
|
|
323
223
|
*/
|
|
324
|
-
|
|
224
|
+
predictNextToken(input, kv_cache) {
|
|
325
225
|
if (input.shape[0] != 1) {
|
|
326
226
|
throw Error(`LlmModel.predictNextToken: ${this.name} expects an input with a batch size of 1`);
|
|
327
227
|
}
|
|
328
|
-
|
|
329
228
|
return tf.tidy(() => {
|
|
330
229
|
// comes back as [batch, sequence_length, vocab_size]
|
|
331
|
-
const prediction = this.apply(input, { kvCache: kv_cache })
|
|
332
|
-
|
|
230
|
+
const prediction = this.apply(input, { kvCache: kv_cache });
|
|
333
231
|
const [batch_size, sequence_length, vocab_size] = prediction.shape;
|
|
334
|
-
|
|
335
232
|
// get the last token
|
|
336
|
-
const next_token = prediction.slice([0, sequence_length - 1, 0], [batch_size, 1, vocab_size]).argMax(2)
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
})
|
|
233
|
+
const next_token = prediction.slice([0, sequence_length - 1, 0], [batch_size, 1, vocab_size]).argMax(2);
|
|
234
|
+
return next_token;
|
|
235
|
+
});
|
|
340
236
|
}
|
|
341
|
-
|
|
342
|
-
|
|
343
237
|
get stopPredicting() {
|
|
344
238
|
return this.stopPredicting_;
|
|
345
239
|
}
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
set stopPredicting(stop: boolean) {
|
|
240
|
+
set stopPredicting(stop) {
|
|
349
241
|
this.stopPredicting_ = stop;
|
|
350
242
|
}
|
|
351
|
-
|
|
352
243
|
}
|
|
353
|
-
|
|
354
|
-
|
|
355
244
|
tf.serialization.registerClass(LlmModel);
|
|
245
|
+
//# sourceMappingURL=llm_model.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"llm_model.js","sourceRoot":"","sources":["../../src/models/llm_model.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,6BAA6B,EAAE,MAAM,qCAAqC,CAAC;AAEpF,OAAO,EAAE,MAAM,IAAI,kBAAkB,EAAE,MAAM,UAAU,CAAC;AAExD,OAAO,KAAK,MAAM,MAAM,WAAW,CAAC;AAKnC,CAAC;AAWF;;;;GAIG;AACH,MAAM,OAAO,QAAS,SAAQ,EAAE,CAAC,UAAU;IACvC,MAAM,CAAC,SAAS,GAAG,UAAU,CAAC;IAEtB,eAAe,GAAY,IAAI,CAAC;IAExC,YAAY,IAAkB;QAC1B,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC,IAAI,IAAI,OAAO,CAAC;QACjC,KAAK,CAAC,IAAI,CAAC,CAAC;IAChB,CAAC;IAGD;;;;;OAKG;IACO,kBAAkB;QACxB,MAAM,CAAC,IAAI,EAAE,GAAG,eAAe,CAAC,GAAG,IAAI,CAAC,YAAY,CAAC;QAErD,OAAO,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,CAAC,aAAa,EAAE,KAAK,EAAE,EAAE,CAAC,CAAC;YACtD,SAAS,EAAE,aAAa,CAAC,CAAC,CAAC;YAC3B,YAAY,EAAE,eAAe,CAAC,KAAK,CAAC;SACvC,CAAC,CAAC,CAAA;IACP,CAAC;IAGD;;;OAGG;IACO,eAAe;QACrB,IAAI,IAAI,GAAG,IAAI,CAAC,IAAI,CAAC;QAErB,IAAI,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,EAAE,CAAC;YACtB,IAAI,GAAG,IAAI,CAAC,CAAC,CAAC,CAAC;QACnB,CAAC;QAED,IAAI,OAAO,IAAI,IAAI,QAAQ,EAAE,CAAC;YAC1B,IAAI,IAAI,IAAI,+BAA+B,EAAE,CAAC;gBAC1C,OAAO,6BAA6B,CAAC;gBACrC;;;gJAGgI;YACpI,CAAC;YAED,MAAM,OAAO,GAAG,IAAc,CAAC;YAE/B,MAAM,OAAO,GACT,CAAE,MAA8B,CAAC,OAAO,CAAC;gBACpC,EAAE,CAAC,MAA8B,CAAC,OAAO,CAAC;gBAC1C,EAAE,CAAC,OAA+B,CAAC,OAAO,CAAC,CAAmB,CAAA;YAEvE,IAAI,OAAO,EAAE,CAAC;gBACV,OAAO,OAAO,CAAA;YAClB,CAAC;iBAAM,CAAC;gBACJ,MAAM,KAAK,CAAC,6BAA6B,OAAO,+BAA+B,CAAC,CAAC;YACrF,CAAC;QACL,CAAC;aAAM,IAAI,OAAO,IAAI,IAAI,UAAU,EAAE,CAAC;YACnC,OAAO,IAAI,CAAC;QAChB,CAAC;QAED,MAAM,KAAK,CAAC,iFAAiF,CAAC,CAAC;IACnG,CAAC;IAGD;;;;;;;;OAQG;IACM,KAAK,CAAC,UAAU,CAAkB,SAAqB,EAAE,IAA+B;QAC7F,IAAI,CAAC,YAAY,GAAG,KAAK,CAAC;QAE1B,MAAM,OAAO,GAAG,SAAyC,CAAC;QAC1D,MAAM,EAAE,MAAM,EAAE,SAAS,EAAE,GAAG,IAAI,CAAC;QAEnC,MAAM,gBAAgB,GAAG,IAAI,CAAC,kBAAkB,EAAE,CAAC;QACnD,MAAM,aAAa,GAAG,IAAI,CAAC,eAAe,EAAE,CAAC;QAC7C,IAAI,CAAC,aAAa,GAAG,CAAC,aAAa,CAAC,CAAC;QAErC,MAAM,EACF,YAAY,EACZ,UAAU,EACV,YAAY,EACZ,UAAU,EACV,YAAY,EACZ,UAAU,GACb,GAAG,SAAkC,IAAI,EAAE,CAAC;QAE7C,MAAM,YAAY,EAAE,EAAE,CAAC;QAEvB,IAAI,kBAAkB,GAA0B,SAAS,CAAC;QAE1D,KAAK,IAAI,KAAK,GAAG,CAAC,EAAE,KAAK,GAAG,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC;YAC1C,MAAM,YAAY,EAAE,CAAC,KAAK,CAAC,CAAC;YAE5B,IAAI,KAAK,GAAG,CAAC,CAAC;YACd,IAAI,aAAa,GAAG,CAAC,CAAC;YACtB,MAAM,yBAAyB,GAAiC,EAAE,CAAC;YAEnE,0CAA0C;YAC1C,MAAM,QAAQ,GAAG,MAAM,OAAO,CAAC,QAAQ,EAAE,CAAC;YAC1C,IAAI,MAAM,GAAG,MAAM,QAAQ,CAAC,IAAI,EAAE,CAAC;YAEnC,OAAO,CAAC,MAAM,CAAC,IAAI,EAAE,CAAC;gBAClB,MAAM,aAAa,GAAiC,EAAE,KAAK,EAAE,CAAC;gBAE9D,MAAM,EAAE,EAAE,EAAE,EAAE,EAAE,SAAS,EAAE,YAAY,EAAE,GAAG,MAAM,CAAC,KAAK,CAAC;gBACzD,MAAM,UAAU,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;gBAC/B,aAAa,IAAI,UAAU,CAAC,CAAC,8BAA8B;gBAE3D,IAAI,EAAE,CAAC,KAAK,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;oBACvB,MAAM,KAAK,CAAC,wBAAwB,IAAI,CAAC,IAAI,0EAA0E,CAAC,CAAC;gBAC7H,CAAC;gBAED,iFAAiF;gBACjF,MAAM,UAAU,GAAG,EAAE,CAAC,KAAK,CAAC,EAAE,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;gBAEjD,IAAI,CAAC,kBAAkB,IAAI,kBAAkB,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,UAAU,EAAE,CAAC;oBACnE,kBAAkB,GAAG,kBAAkB,CAAC,UAAU,EAAE,UAAU,CAAC,CAAC;gBACpE,CAAC;gBAED,MAAM,YAAY,EAAE,CAAC,KAAK,CAAC,CAAC;gBAE5B,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;oBACT,MAAM,EAAE,MAAM,EAAE,IAAI,EAAE,GAAG,IAAI,CAAC,QAAQ,CAAC,EAAE,EAAE,EAAE,EAAE,SAAS,EAAE,aAAa,EAAE;wBACrE,WAAW,EAAE,YAAY;wBACzB,UAAU,EAAE,kBAAkB;qBACjC,CAAC,CAAA;oBAEF,MAAM,UAAU,GAAG,CAAC,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;oBAExC,aAAa,CAAC,IAAI,GAAG,UAAU,CAAC;oBAChC,yBAAyB,CAAC,IAAI,GAAG,CAAC,yBAAyB,CAAC,IAAI,IAAI,CAAC,CAAC,GAAG,UAAU,GAAG,UAAU,CAAC;oBAEjG,8BAA8B;oBAC9B,KAAK,MAAM,EAAE,SAAS,EAAE,YAAY,EAAE,IAAI,gBAAgB,EAAE,CAAC;wBACzD,MAAM,UAAU,GAAG,SAAS,CAAC,EAAE,EAAE,MAAO,CAAC,CAAC,IAAI,EAAE,CAAC;wBAEjD,MAAM,YAAY,GAAG,CAAC,UAAU,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;wBAEhD,aAAa,CAAC,YAAY,CAAC,GAAG,YAAY,CAAA,CAAA,gBAAgB;wBAC1D,yBAAyB,CAAC,YAAY,CAAC,GAAG,CAAC,yBAAyB,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC,GAAG,YAAY,GAAG,UAAU,CAAC;oBACzH,CAAC;oBAED,EAAE,CAAC,OAAO,CAAC,MAAO,CAAC,CAAC;gBACxB,CAAC,CAAC,CAAA;gBAEF,EAAE,CAAC,OAAO,CAAC,EAAE,CAAC,CAAC;gBACf,EAAE,CAAC,OAAO,CAAC,EAAE,CAAC,CAAC;gBACf,EAAE,CAAC,OAAO,CAAC,SAAS,CAAC,CAAC;gBAEtB,IAAI,YAAY,EAAE,CAAC;oBACf,EAAE,CAAC,OAAO,CAAC,YAAY,CAAC,CAAC;gBAC7B,CAAC;gBAED,MAAM,UAAU,EAAE,CAAC,KAAK,EAAE,aAAa,CAAC,CAAC;gBAEzC,8BAA8B;gBAC9B,MAAM,EAAE,CAAC,SAAS,EAAE,CAAC;gBAErB,IAAI,IAAI,CAAC,YAAY,EAAE,CAAC;oBACpB,MAAM;gBACV,CAAC;gBAED,MAAM,GAAG,MAAM,QAAQ,CAAC,IAAI,EAAE,CAAC;gBAC/B,KAAK,EAAE,CAAC;YACZ,CAAC;YAED,KAAK,MAAM,MAAM,IAAI,yBAAyB,EAAE,CAAC;gBAC7C,yBAAyB,CAAC,MAAM,CAAC,GAAG,yBAAyB,CAAC,MAAM,CAAC,GAAG,aAAa,CAAC;YAC1F,CAAC;YAED,MAAM,UAAU,EAAE,CAAC,KAAK,EAAE,yBAAyB,CAAC,CAAC;YAErD,IAAI,IAAI,CAAC,YAAY,EAAE,CAAC;gBACpB,MAAM;YACV,CAAC;QACL,CAAC;QAED,EAAE,CAAC,OAAO,CAAC,kBAAkB,CAAC,CAAC;QAC/B,MAAM,UAAU,EAAE,EAAE,CAAA;QAEpB,OAAO,EAAE,CAAC;IACd,CAAC;IAGD;;;;;;;;;OASG;IACO,QAAQ,CACd,EAAa,EACb,EAAa,EACb,SAAgC,EAChC,aAA6B,EAC7B,WAAsD;QAKtD,IAAI,MAAiB,CAAC;QAEtB,+BAA+B;QAC/B,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,CAAC,aAAa,CAAC,GAAG,EAAE;YACjD,4DAA4D;YAC5D,MAAM,GAAG,IAAI,CAAC,KAAK,CAAC,EAAE,EAAE;gBACpB,QAAQ,EAAE,IAAI;gBACd,GAAG,WAAW;aACjB,CAAc,CAAC;YAEhB,6EAA6E;YAC7E,EAAE,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;YAEhB,MAAM,IAAI,GAAG,SAAS;gBAClB,CAAC,CAAC,aAAa,CAAC,EAAE,EAAE,MAAM,CAAC,CAAC,GAAG,CAAC,SAAS,CAAC;gBAC1C,CAAC,CAAC,aAAa,CAAC,EAAE,EAAE,MAAM,CAAC,CAAC;YAEhC,OAAO,IAAI,CAAC,IAAI,EAAe,CAAC;QACpC,CAAC,CAAC,CAAC;QAEH,kBAAkB;QAClB,IAAI,CAAC,SAAS,CAAC,cAAc,CAAC,KAAK,CAAC,CAAC;QAErC,OAAO;YACH,MAAM,EAAE,MAAO;YACf,IAAI;SACP,CAAC;IACN,CAAC;IAGQ,OAAO,CAAC,IAAyB;QACtC,IAAI,IAAI,CAAC,IAAI,IAAI,yBAAyB,EAAE,CAAC;YACzC,MAAM,KAAK,CAAC,gIAAgI,CAAC,CAAA;QACjJ,CAAC;QAED,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,CAAC;IACxB,CAAC;IAGD;;;;;;;;;OASG;IACI,KAAK,CAAC,QAAQ,CAAC,KAAkB,EAAE,QAA0B,EAAE,SAA8C;QAChH,IAAI,QAAQ,CAAC,IAAI,IAAI,QAAQ,CAAC,iBAAiB,EAAE,CAAC;YAC9C,MAAM,KAAK,CAAC,sBAAsB,IAAI,CAAC,IAAI,mDAAmD,QAAQ,CAAC,iBAAiB,GAAG,CAAC,CAAC;QACjI,CAAC;QAED,IAAI,CAAC,cAAc,GAAG,KAAK,CAAC;QAE5B,IAAI,aAAa,GAAgB,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,CAAgB,CAAC,CAAC,kDAAkD;QAEtI,OAAO,CAAC,IAAI,CAAC,cAAc,IAAI,QAAQ,CAAC,IAAI,GAAG,QAAQ,CAAC,iBAAiB,EAAE,CAAC;YACxE,qEAAqE;YACrE,MAAM,UAAU,GAAG,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,IAAI,CAAC,gBAAgB,CAAC,aAAa,EAAE,QAAQ,CAAC,CAAC,CAAC;YAEjF,wDAAwD;YACxD,MAAM,oBAAoB,GAAG,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YACpE,MAAM,SAAS,CAAC,oBAAoB,CAAC,CAAC;YAEtC,oBAAoB,CAAC,OAAO,EAAE,CAAC;YAE/B,aAAa,CAAC,OAAO,EAAE,CAAC;YACxB,aAAa,GAAG,UAAU,CAAC;QAC/B,CAAC;QAED,EAAE,CAAC,OAAO,CAAC,aAAa,CAAC,CAAC;IAC9B,CAAC;IAGD;;;;;;;OAOG;IACI,gBAAgB,CAAC,KAAkB,EAAE,QAA0B;QAClE,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,CAAC,EAAE,CAAC;YACtB,MAAM,KAAK,CAAC,8BAA8B,IAAI,CAAC,IAAI,0CAA0C,CAAC,CAAC;QACnG,CAAC;QAED,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,qDAAqD;YACrD,MAAM,UAAU,GAAG,IAAI,CAAC,KAAK,CAAC,KAAK,EAAE,EAAE,OAAO,EAAE,QAAQ,EAAE,CAAc,CAAC;YAEzE,MAAM,CAAC,UAAU,EAAE,eAAe,EAAE,UAAU,CAAC,GAAG,UAAU,CAAC,KAAK,CAAC;YAEnE,qBAAqB;YACrB,MAAM,UAAU,GAAG,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,eAAe,GAAG,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,UAAU,EAAE,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAA;YAEvG,OAAO,UAAyB,CAAC;QACrC,CAAC,CAAC,CAAA;IACN,CAAC;IAGD,IAAI,cAAc;QACd,OAAO,IAAI,CAAC,eAAe,CAAC;IAChC,CAAC;IAGD,IAAI,cAAc,CAAC,IAAa;QAC5B,IAAI,CAAC,eAAe,GAAG,IAAI,CAAC;IAChC,CAAC;;AAKL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,QAAQ,CAAC,CAAC"}
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
|
|
3
|
+
export interface UNetArgs {
|
|
4
|
+
/**
|
|
5
|
+
* The starting number of filters.
|
|
6
|
+
*/
|
|
7
|
+
filters: number;
|
|
8
|
+
/**
|
|
9
|
+
* The number of categories. For binary segmentation, `units=1`.
|
|
10
|
+
*/
|
|
11
|
+
units: number;
|
|
12
|
+
/**
|
|
13
|
+
* The activation of the final output convolution layer. Defaults to `sigmoid` if `categories=1`, else `softmax`.
|
|
14
|
+
*/
|
|
15
|
+
activation?: ActivationIdentifier;
|
|
16
|
+
/**
|
|
17
|
+
* The depth of the U-Net or the number of contractions and the number of expansions.
|
|
18
|
+
*/
|
|
19
|
+
depth: number;
|
|
20
|
+
/**
|
|
21
|
+
* Adds residual connections to transform the model into a ResUNet. Defaults to `false`.
|
|
22
|
+
*/
|
|
23
|
+
residual?: boolean;
|
|
24
|
+
/**
|
|
25
|
+
* Adds batch normalization to convolutions. Best used for batched inputs. Defaults to `false`.
|
|
26
|
+
*/
|
|
27
|
+
batchNorm?: boolean;
|
|
28
|
+
/**
|
|
29
|
+
* Set the unbatched input shape of the U-Net in the format `[height, width, channels]`. Defaults to `[null, null, 3]`. If set, only channels is mandatory.
|
|
30
|
+
*/
|
|
31
|
+
inputShape?: [number | null, number | null, number];
|
|
32
|
+
}
|
|
33
|
+
export type UNetModelArgs = UNetArgs & Omit<tf.SequentialArgs, "layers">;
|
|
34
|
+
export declare class UNetModel extends tf.Sequential {
|
|
35
|
+
constructor(args: UNetModelArgs);
|
|
36
|
+
summary(lineLength?: number, positions?: number[], printFn?: (message?: any, ...optionalParams: any[]) => void): void;
|
|
37
|
+
}
|
|
38
|
+
export declare function createUNet({ filters, depth, units, activation, residual, batchNorm, inputShape }: UNetModelArgs): tf.LayersModel;
|
|
39
|
+
export declare function loadUNetModel(pathOrIOHandler: string | tf.io.IOHandler, options?: tf.io.LoadOptions): Promise<tf.LayersModel>;
|
|
40
|
+
//# sourceMappingURL=u_net.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"u_net.d.ts","sourceRoot":"","sources":["../../src/models/u_net.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,oBAAoB,EAAE,MAAM,6DAA6D,CAAC;AAGxG,MAAM,WAAW,QAAQ;IACrB;;OAEG;IACH,OAAO,EAAE,MAAM,CAAC;IAChB;;OAEG;IACH,KAAK,EAAE,MAAM,CAAC;IACd;;OAEG;IACH,UAAU,CAAC,EAAE,oBAAoB,CAAC;IAClC;;OAEG;IACH,KAAK,EAAE,MAAM,CAAC;IACd;;OAEG;IACH,QAAQ,CAAC,EAAE,OAAO,CAAC;IACnB;;OAEG;IACH,SAAS,CAAC,EAAE,OAAO,CAAC;IACpB;;OAEG;IACH,UAAU,CAAC,EAAE,CAAC,MAAM,GAAG,IAAI,EAAE,MAAM,GAAG,IAAI,EAAE,MAAM,CAAC,CAAC;CACvD;AAGD,MAAM,MAAM,aAAa,GAAG,QAAQ,GAAG,IAAI,CAAC,EAAE,CAAC,cAAc,EAAE,QAAQ,CAAC,CAAC;AAGzE,qBAAa,SAAU,SAAQ,EAAE,CAAC,UAAU;gBAE5B,IAAI,EAAE,aAAa;IAsBtB,OAAO,CAAC,UAAU,CAAC,EAAE,MAAM,EAAE,SAAS,CAAC,EAAE,MAAM,EAAE,EAAE,OAAO,CAAC,EAAE,CAAC,OAAO,CAAC,EAAE,GAAG,EAAE,GAAG,cAAc,EAAE,GAAG,EAAE,KAAK,IAAI,GAAG,IAAI;CAIjI;AAGD,wBAAgB,UAAU,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,KAAK,EAAE,UAAU,EAAE,QAAgB,EAAE,SAAiB,EAAE,UAA4B,EAAE,EAAE,aAAa,kBA4CjJ;AAGD,wBAAsB,aAAa,CAAC,eAAe,EAAE,MAAM,GAAG,EAAE,CAAC,EAAE,CAAC,SAAS,EAAE,OAAO,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,WAAW,2BAOzG"}
|