@stellarapp/tfjs-stellar 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/jest.config.ts +203 -0
- package/package.json +24 -0
- package/src/index.ts +93 -0
- package/src/kv_cache.ts +205 -0
- package/src/layers/cached_rope_multihead_attention.test.ts +59 -0
- package/src/layers/cached_rope_multihead_attention.ts +113 -0
- package/src/layers/gpt_decoder_block.ts +77 -0
- package/src/layers/multihead_attention.test.ts +212 -0
- package/src/layers/multihead_attention.ts +371 -0
- package/src/layers/positional_encoding.test.ts +113 -0
- package/src/layers/positional_encoding.ts +158 -0
- package/src/layers/rotary_position_embedding.test.ts +107 -0
- package/src/layers/rotary_position_embedding.ts +163 -0
- package/src/layers/token_and_positional_embedding.test.ts +81 -0
- package/src/layers/token_and_positional_embedding.ts +149 -0
- package/src/layers/transformer_decoder.test.ts +100 -0
- package/src/layers/transformer_decoder.ts +236 -0
- package/src/layers/transformer_encoder.test.ts +85 -0
- package/src/layers/transformer_encoder.ts +224 -0
- package/src/losses/dice.ts +156 -0
- package/src/losses/index.ts +1 -0
- package/src/metrics.ts +32 -0
- package/src/models/gpt_model.ts +232 -0
- package/src/models/index.ts +2 -0
- package/src/models/llm_model.ts +355 -0
- package/src/models/u_net.ts +240 -0
- package/src/packing_mask.ts +28 -0
- package/src/testing.ts +1 -0
- package/src/tfjs_types.ts +15 -0
- package/src/utils.test.ts +101 -0
- package/src/utils.ts +86 -0
- package/tsconfig.json +49 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
export * from "./dice";
|
package/src/metrics.ts
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
import { metrics, Tensor } from "@tensorflow/tfjs";
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
/**
 * Recall computed on thresholded predictions: every prediction value that is
 * greater than or equal to `threshold` counts as a positive prediction.
 *
 * @param y_true the label tensor
 * @param y_pred the prediction tensor
 * @param threshold minimum prediction value treated as a positive prediction, defaults to `0.5`
 * @returns the result of TFJS `metrics.recall` on the labels and the thresholded predictions
 */
export function recall(y_true: Tensor, y_pred: Tensor, threshold: number = 0.5) {
    const predicted_positive = y_pred.greaterEqual(threshold);
    return metrics.recall(y_true, predicted_positive);
}

// TFJS relies on the function name, so pin it in case a minifier renames the function
Object.defineProperty(recall, "name", { value: "recall", configurable: false });
|
|
18
|
+
|
|
19
|
+
/**
 * Precision computed on thresholded predictions: every prediction value that is
 * greater than or equal to `threshold` counts as a positive prediction.
 *
 * @param y_true the label tensor
 * @param y_pred the prediction tensor
 * @param threshold minimum prediction value treated as a positive prediction, defaults to `0.5`
 * @returns the result of TFJS `metrics.precision` on the labels and the thresholded predictions
 */
export function precision(y_true: Tensor, y_pred: Tensor, threshold: number = 0.5) {
    const predicted_positive = y_pred.greaterEqual(threshold);
    return metrics.precision(y_true, predicted_positive);
}

// TFJS relies on the function name, so pin it in case a minifier renames the function
Object.defineProperty(precision, "name", { value: "precision", configurable: false });
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import * as tfc from "@/.";
|
|
3
|
+
import { type LossOrMetricFn } from "@/tfjs_types";
|
|
4
|
+
import { LlmModel, type LlmModelArgs } from "@/models/llm_model";
|
|
5
|
+
import { KvCacheContainer } from "@/kv_cache";
|
|
6
|
+
import { type DisposeResult } from "@tensorflow/tfjs-layers/dist/engine/topology";
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
/**
 * Constructor arguments for `GptModel`, on top of the `LlmModelArgs` it extends.
 */
export interface GptModelArgs extends LlmModelArgs {
    /** How many attention heads each attention layer has. */
    numHeads: number;

    /** How many GPT decoder blocks are stacked in the model. */
    numLayers: number;

    /** Size of the embedding vector produced for every token. */
    embedDim: number;

    /**
     * The tokenizer's vocabulary size. It is also the embedding layer's vocab
     * size and the output layer's unit count (before any padding).
     */
    vocabSize: number;

    /**
     * When `true` (the default), round the embedding vocab size and the output
     * layer's unit count up to the next multiple of 64 for hardware efficiency.
     * For example, a tokenizer with 50,257 tokens yields a model vocab size and
     * output unit count of 50,304.
     */
    padToMultipleOf64?: boolean;
}
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
/**
 * A GPT-like language model built on `LlmModel` (and therefore `tf.Sequential`).
 * It automatically pads the vocab size for hardware efficiency and masks the
 * padded positions out of the predictions.
 *
 * Example:
 *
 * ```javascript
 * const model = new GptModel({ numLayers: 1, numHeads: 1, embedDim: 16, vocabSize: 64 });
 * model.compile({ loss: "sparseCategoricalCrossentropy", optimizer: "adam" });
 *
 * // use fitDataset() instead of fit() for masking support
 * model.fitDataset(your_batched_generator_dataset, { epochs: 1 });
 *
 * const kv_cache = new KvCacheContainer(your_preferred_max_sequence_length);
 *
 * // use generate() and predictNextToken() instead of predict() for masking and auto memory cleanup
 * model.generate(tokenized_tensor1d_input, kv_cache, onPredict_callback);
 * ```
 */
export class GptModel extends LlmModel {
    static className = "GptModel";

    protected readonly numHeads: number;
    protected readonly numLayers: number;
    protected readonly embedDim: number;
    protected readonly vocabSize: number;
    protected readonly padToMultipleOf64: boolean;

    // Kept (and serialized) for reproducibility and model history. It can always
    // be recomputed from vocabSize and padToMultipleOf64.
    protected readonly vocabSizePadded: number;

    // Added to the final dense layer's output. Padded vocab positions get -1e7,
    // so softmax effectively ignores them. Only set when padding was actually
    // added (see build()).
    protected vocab_padding_mask?: tf.Tensor1D | undefined;

    /**
     * DO NOT add layers in the constructor or it will break tf.loadLayersModel().
     * It should be done in build() instead.
     */
    constructor(args: GptModelArgs) {
        const { numHeads, numLayers, embedDim, vocabSize, padToMultipleOf64 = true, ...rest } = args;

        super({ name: "model", ...rest });

        this.numHeads = numHeads;
        this.numLayers = numLayers;
        this.embedDim = embedDim;
        this.vocabSize = vocabSize;
        this.padToMultipleOf64 = padToMultipleOf64;
        this.vocabSizePadded = this.padToMultipleOf64
            ? Math.ceil(this.vocabSize / 64) * 64
            : this.vocabSize;
    }

    /**
     * Runs one forward and backward pass on a training batch.
     *
     * This differs from `LlmModel.fitBatch` in two ways, both applied before the
     * loss is computed. First, the vocab padding mask (if any) is added to the
     * model output. Second, softmax is applied, because the final dense layer
     * has no activation.
     *
     * Should be called within a `tf.tidy()`. The returned `y_pred` is the softmax
     * output and is kept with `tf.keep()`, so the caller must dispose it.
     *
     * @param xs the sample/input tensor
     * @param ys the label/target tensor
     * @param loss_mask optional mask multiplied into the per-position loss
     * @param loss_function the model's loss function
     * @param other_masks other masks forwarded to the model's layers, e.g. packing mask, causal mask
     */
    protected override fitBatch(
        xs: tf.Tensor,
        ys: tf.Tensor,
        loss_mask: tf.Tensor | undefined,
        loss_function: LossOrMetricFn,
        other_masks?: { [key: string]: tf.Tensor | undefined }
    ) {
        let y_pred: tf.Tensor;

        // forward pass, calculate loss
        const { value: loss, grads } = tf.variableGrads(() => {
            y_pred = this.apply(xs, {
                training: true,
                ...other_masks
            }) as tf.Tensor;

            // push padded vocab positions towards zero probability
            if (this.vocab_padding_mask) {
                y_pred = y_pred.add(this.vocab_padding_mask);
            }

            y_pred = tf.softmax(y_pred);

            // keep y_pred alive past variableGrads' internal cleanup; the caller disposes it
            tf.keep(y_pred);

            const loss = loss_mask
                ? loss_function(ys, y_pred).mul(loss_mask)
                : loss_function(ys, y_pred);

            return loss.mean() as tf.Scalar;
        });

        // backpropagation
        this.optimizer.applyGradients(grads);

        return {
            y_pred: y_pred!,
            loss
        };
    }

    /**
     * Predicts the next token from the last sequence position.
     *
     * This overrides `LlmModel.predictNextToken` to add the vocab padding mask, so
     * padded positions can never be selected. It also applies softmax to the
     * logits before `argMax`, because the final dense layer has no activation.
     *
     * TODO: implement temperature and multinomial sampling so that the model has varied outputs
     *
     * @param input token tensor with a batch size of exactly 1
     * @param kv_cache the KV cache container forwarded to the model's layers
     * @returns the predicted token index, shaped `[1, 1]`
     */
    override predictNextToken(input: tf.Tensor2D, kv_cache: KvCacheContainer): tf.Tensor2D {
        if (input.shape[0] != 1) {
            throw Error(`GptModel.predictNextToken: ${this.name} expects an input with a batch size of 1`);
        }

        return tf.tidy(() => {
            // comes back as [batch, sequence_length, vocab_size]
            const prediction = this.apply(input, { kvCache: kv_cache }) as tf.Tensor;

            const [batch_size, sequence_length, vocab_size] = prediction.shape;

            // only the last position's scores matter for the next token
            let logits = prediction.slice([0, sequence_length - 1, 0], [batch_size, 1, vocab_size]);

            if (this.vocab_padding_mask != undefined) {
                logits = logits.add(this.vocab_padding_mask);
            }

            return logits.softmax().argMax(2) as tf.Tensor2D;
        });
    }

    /**
     * Adds the model's layers if none exist yet, then (re)creates the vocab padding
     * mask. Layers are created here rather than in the constructor so that
     * tf.loadLayersModel(), which adds deserialized layers itself, keeps working.
     */
    override build(inputShape?: tf.Shape | tf.Shape[]): void {
        // always assigned in the constructor
        const actual_vocab_size = this.vocabSizePadded;

        if (this.layers.length == 0) {
            const layers = [
                tf.layers.embedding({ inputDim: actual_vocab_size, outputDim: this.embedDim, batchInputShape: [null, null] }),
                ...Array(this.numLayers).fill(0).map(_ => tfc.gpt2DecoderBlock({ numHeads: this.numHeads, embedDim: this.embedDim })),
                // no activation: softmax is applied in fitBatch()/predictNextToken()
                tf.layers.dense({ units: actual_vocab_size })
            ];

            for (const layer of layers) {
                this.add(layer);
            }
        }

        // release any mask from a previous build() before (re)creating it
        this.vocab_padding_mask?.dispose();
        this.vocab_padding_mask = undefined;

        if (this.padToMultipleOf64 && actual_vocab_size > this.vocabSize) {
            // Positions at index >= vocabSize are padding. They are set to -1e7 (and the
            // real vocab positions to 0), so after this mask is added to the final dense
            // layer's output, softmax ignores the padding.
            this.vocab_padding_mask = tf.tidy(() => tf.where<tf.Tensor1D>(
                tf.range(0, actual_vocab_size).greaterEqual(this.vocabSize), -1e7, 0).toFloat());
        }

        super.build(inputShape);
    }

    /**
     * Disposes the vocab padding mask in addition to everything the base class disposes.
     */
    override dispose(): DisposeResult {
        this.vocab_padding_mask?.dispose();
        this.vocab_padding_mask = undefined;
        return super.dispose();
    }

    /**
     * Serializes the GPT hyperparameters together with the base `tf.Sequential`
     * config. Base config keys win on a name clash.
     */
    override getConfig() {
        const base_config = super.getConfig();

        const config = {
            numHeads: this.numHeads,
            numLayers: this.numLayers,
            embedDim: this.embedDim,
            vocabSize: this.vocabSize,
            vocabSizePadded: this.vocabSizePadded,
            padToMultipleOf64: this.padToMultipleOf64
        };

        Object.assign(config, base_config);

        return config;
    }
}

tf.serialization.registerClass(GptModel);
|
|
@@ -0,0 +1,355 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import * as tfc from "@/index";
|
|
3
|
+
import { sparseCategoricalCrossentropy } from "@tensorflow/tfjs-layers/dist/losses";
|
|
4
|
+
import { Dataset, type LossOrMetricFn } from "@/tfjs_types";
|
|
5
|
+
import { generateCausalAttentionMask } from "@/utils";
|
|
6
|
+
import { KvCacheContainer } from "@/kv_cache";
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
/**
 * Constructor arguments for `LlmModel`. These are currently exactly `tf.SequentialArgs`;
 * the interface exists so subclasses have a stable type to extend.
 */
// eslint-disable-next-line
export interface LlmModelArgs extends tf.SequentialArgs {
}
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
/**
 * One element yielded by the batched generator dataset passed to
 * `LlmModel.fitDataset()`.
 */
interface DatasetArgs extends tf.TensorContainerObject {
    /** Batched input tensor; fitDataset() requires it to be rank 2. */
    xs: tf.Tensor;

    /** Batched label/target tensor, passed to the loss and metric functions. */
    ys: tf.Tensor;

    /** Optional mask multiplied into the per-position loss, e.g. to ignore non-assistant tokens when fine-tuning. */
    loss_mask?: tf.Tensor;

    /** Optional mask forwarded to the model's layers as `packingMask` when sequence packing was done. */
    packing_mask?: tf.Tensor;
}
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
/**
 * Base class for autoregressive language models built on `tf.Sequential`.
 *
 * It overrides `fitDataset()` to support loss masking and packing masking. It
 * also adds `generate()`, which autoregressively predicts one token after
 * another until `stopPredicting` is set to `true` or the KV cache is full.
 */
export class LlmModel extends tf.Sequential {
    static className = "LlmModel";

    // polled by generate(); setting it to true stops the generation loop
    private stopPredicting_: boolean = true;

    constructor(args: LlmModelArgs) {
        // default the model name to "model" without mutating the caller's args object
        super({ ...args, name: args.name ?? "model" });
    }

    /**
     * Returns the metric functions and names so that metrics can be reported
     * the same way as in the base version of model.fitDataset.
     *
     * e.g. "categoricalAccuracy" should be reported as "acc"
     */
    protected getMetricFunctions() {
        // the first name is the loss; the rest line up with metricsTensors
        const metric_labels = this.metricsNames.slice(1);

        return this.metricsTensors.map(([metric_fn], index) => ({
            metric_fn,
            metric_label: metric_labels[index]
        }));
    }

    /**
     * Gets exactly one loss function from the loss provided in `model.compile()`.
     * If several losses were given, only the first is used. If a string identifier
     * was used, it is converted to the actual loss function.
     *
     * @throws Error if the identifier is unknown or the loss is neither a string nor a function
     */
    protected getLossFunction(): LossOrMetricFn {
        let loss = this.loss;

        if (Array.isArray(loss)) {
            loss = loss[0];
        }

        if (typeof loss == "function") {
            return loss;
        }

        if (typeof loss != "string") {
            throw Error("LlmModel.getLossFunction: the loss function's type should be string or function");
        }

        if (loss == "sparseCategoricalCrossentropy") {
            // tfjs-layers' version is not truly sparse: it converts the integer labels
            // to one-hot internally before computing categorical crossentropy
            return sparseCategoricalCrossentropy;
        }

        // look the identifier up in this package's losses first, then TFJS's losses and metrics
        const loss_fn = ((tfc.losses as Record<string, any>)[loss] ??
            (tf.losses as Record<string, any>)[loss] ??
            (tf.metrics as Record<string, any>)[loss]) as LossOrMetricFn | undefined;

        if (!loss_fn) {
            throw Error(`LlmModel.getLossFunction: ${loss} is not a valid loss function`);
        }

        return loss_fn;
    }

    /**
     * Train on a `tf.data.generator` dataset. See https://js.tensorflow.org/api/latest/#data.generator.
     *
     * The generator should yield `xs`, `ys`, `loss_mask` (if fine-tuning), and
     * `packing_mask` (if sequence packing was done). The batch tensors are disposed
     * after each batch.
     *
     * `args.callbacks` is treated as a single `tf.CustomCallbackArgs` object.
     * Other forms, such as arrays of callbacks, are not handled.
     *
     * @param tfdataset an instance of a `tf.Dataset` generator
     * @param args a ModelFitDatasetArgs
     * @throws Error if a yielded `xs` is not rank 2 (i.e. the dataset was not batched)
     */
    override async fitDataset<T = DatasetArgs>(tfdataset: Dataset<T>, args: tf.ModelFitDatasetArgs<T>): Promise<any> {
        this.stopTraining = false;

        const dataset = tfdataset as tf.data.Dataset<DatasetArgs>;
        const { epochs, callbacks } = args;

        const metric_functions = this.getMetricFunctions();
        const loss_function = this.getLossFunction();
        this.lossFunctions = [loss_function];

        const {
            onBatchBegin,
            onBatchEnd,
            onEpochBegin,
            onEpochEnd,
            onTrainBegin,
            onTrainEnd,
        } = (callbacks as tf.CustomCallbackArgs) ?? {};

        await onTrainBegin?.();

        // The causal attention mask only depends on the sequence length. It is generated
        // once and reused by every batch (and all attention layers) until the length
        // changes. The length is tracked separately so the check does not depend on
        // the mask's shape.
        let cached_causal_mask: tf.Tensor | undefined = undefined;
        let cached_causal_mask_length = -1;

        try {
            for (let epoch = 0; epoch < epochs; epoch++) {
                await onEpochBegin?.(epoch);

                let batch = 0;
                let total_samples = 0;
                const accumulated_epoch_metrics: { [metric: string]: number } = {};

                // loop through dataset using its iterator
                const iterator = await dataset.iterator();
                let sample = await iterator.next();

                while (!sample.done) {
                    const batch_metrics: { [metric: string]: number } = { batch };

                    const { xs, ys, loss_mask, packing_mask } = sample.value;

                    try {
                        if (xs.shape.length != 2) {
                            throw Error(`LlmModel.fitDataset: ${this.name} the generator dataset should be batched, run: dataset.batch(batch_size)`);
                        }

                        const batch_size = xs.shape[0];
                        total_samples += batch_size; // for epoch metrics averaging

                        const seq_length = xs.shape[xs.shape.length - 1];

                        if (cached_causal_mask === undefined || cached_causal_mask_length != seq_length) {
                            // release the previous mask before replacing it, otherwise it leaks
                            tf.dispose(cached_causal_mask);
                            cached_causal_mask = generateCausalAttentionMask(seq_length, seq_length);
                            cached_causal_mask_length = seq_length;
                        }

                        await onBatchBegin?.(batch);

                        tf.tidy(() => {
                            const { y_pred, loss } = this.fitBatch(xs, ys, loss_mask, loss_function, {
                                packingMask: packing_mask,
                                causalMask: cached_causal_mask
                            });

                            const loss_value = loss.dataSync()[0];

                            batch_metrics.loss = loss_value;
                            accumulated_epoch_metrics.loss = (accumulated_epoch_metrics.loss || 0) + loss_value * batch_size;

                            // calculate and store metrics
                            for (const { metric_fn, metric_label } of metric_functions) {
                                const metric_value = metric_fn(ys, y_pred).mean().dataSync()[0];

                                batch_metrics[metric_label] = metric_value;
                                accumulated_epoch_metrics[metric_label] = (accumulated_epoch_metrics[metric_label] || 0) + metric_value * batch_size;
                            }

                            // fitBatch() kept y_pred with tf.keep(), so tidy will not release it
                            tf.dispose(y_pred);
                        });
                    } finally {
                        // release the batch tensors even if training on this batch failed
                        tf.dispose(xs);
                        tf.dispose(ys);
                        tf.dispose(loss_mask);
                        tf.dispose(packing_mask);
                    }

                    await onBatchEnd?.(batch, batch_metrics);

                    // yield so that a callback setting stopTraining can take effect
                    await tf.nextFrame();

                    if (this.stopTraining) {
                        break;
                    }

                    sample = await iterator.next();
                    batch++;
                }

                // convert the sample-weighted sums into per-sample averages
                for (const metric of Object.keys(accumulated_epoch_metrics)) {
                    accumulated_epoch_metrics[metric] /= total_samples;
                }

                await onEpochEnd?.(epoch, accumulated_epoch_metrics);

                if (this.stopTraining) {
                    break;
                }
            }
        } finally {
            tf.dispose(cached_causal_mask);
        }

        await onTrainEnd?.();

        return {};
    }

    /**
     * Runs the core forward and backward propagation on one training batch. This
     * should be called within a `tf.tidy()`. The returned `y_pred` is kept with
     * `tf.keep()`, so the caller is responsible for disposing it.
     *
     * @param xs the sample/input tensor
     * @param ys the label/target tensor
     * @param loss_mask a loss mask to ignore the prediction's non-assistant tokens
     * @param loss_function the model's loss function
     * @param other_masks other masks used by the model's layers e.g. packing mask, causal mask
     * @returns the model prediction and the scalar (mean) loss
     */
    protected fitBatch(
        xs: tf.Tensor,
        ys: tf.Tensor,
        loss_mask: tf.Tensor | undefined,
        loss_function: LossOrMetricFn,
        other_masks?: { [key: string]: tf.Tensor | undefined }
    ): {
        y_pred: tf.Tensor<tf.Rank>;
        loss: tf.Scalar;
    } {
        let y_pred: tf.Tensor;

        // forward pass, calculate loss
        const { value: loss, grads } = tf.variableGrads(() => {
            // prediction has shape [batch, sequence_length, vocab_size]
            y_pred = this.apply(xs, {
                training: true,
                ...other_masks
            }) as tf.Tensor;

            // keep y_pred alive past variableGrads' internal cleanup; the caller disposes it
            tf.keep(y_pred);

            // NOTE(review): if loss_mask holds 0/1 values, masked positions still count
            // towards the mean's denominator below
            const loss = loss_mask
                ? loss_function(ys, y_pred).mul(loss_mask)
                : loss_function(ys, y_pred);

            return loss.mean() as tf.Scalar;
        });

        // backpropagation
        this.optimizer.applyGradients(grads);

        return {
            y_pred: y_pred!,
            loss
        };
    }

    /**
     * Rejects `categoricalCrossentropy`, then delegates to `tf.Sequential.compile`.
     *
     * @throws Error if `args.loss` is "categoricalCrossentropy"
     */
    override compile(args: tf.ModelCompileArgs): void {
        if (args.loss == "categoricalCrossentropy") {
            throw Error(`LlmModel.compile: use sparseCategoricalCrossentropy loss (along with integer token index labels) instead of categoricalCrossentropy`);
        }

        super.compile(args);
    }

    /**
     * Autoregressively generates the next token until `model.stopPredicting` is set
     * to `true` or the KV cache reaches its maximum sequence length. For a single chat
     * session, the input should only be the most recent prompt(s). The KV cache stores
     * the prior chat history up until the most recent chat.
     *
     * The token passed to `onPredict` is disposed once the callback resolves. All
     * intermediate tensors are released even if `onPredict` throws.
     *
     * @param input tokenized input of the newest chat
     * @param kv_cache an instance of a KV cache container
     * @param onPredict callback function to receive the most recent token predicted
     * @throws Error if the KV cache is already full
     */
    public async generate(input: tf.Tensor1D, kv_cache: KvCacheContainer, onPredict: (token: tf.Tensor) => Promise<void>) {
        if (kv_cache.size >= kv_cache.maxSequenceLength) {
            throw Error(`LlmModel.generate: ${this.name} KV cache's size reached the maxSequenceLength (${kv_cache.maxSequenceLength})`);
        }

        this.stopPredicting = false;

        // it's 2D because the forward pass requires a batch dimension
        let current_token: tf.Tensor2D = tf.tidy(() => input.expandDims(0)) as tf.Tensor2D;

        try {
            while (!this.stopPredicting && kv_cache.size < kv_cache.maxSequenceLength) {
                const next_token = tf.tidy(() => this.predictNextToken(current_token, kv_cache));

                // the previous token is no longer needed once the next one is predicted
                current_token.dispose();
                current_token = next_token;

                // pass back the predicted token without the batch dim
                const unbatched_next_token = tf.tidy(() => next_token.squeeze([0]));

                try {
                    await onPredict(unbatched_next_token);
                } finally {
                    unbatched_next_token.dispose();
                }
            }
        } finally {
            tf.dispose(current_token);
        }
    }

    /**
     * Given a tokenized sentence, predicts the next token (word).
     * A normal prediction is run to get an output with the shape
     * `[ batch_size, sentence_length, vocab_size ]`. The index along `vocab_size`
     * with the highest score at the last position of `sentence_length` is
     * returned as the next predicted token.
     *
     * @throws Error if the input's batch size is not 1
     */
    public predictNextToken(input: tf.Tensor2D, kv_cache: KvCacheContainer) {
        if (input.shape[0] != 1) {
            throw Error(`LlmModel.predictNextToken: ${this.name} expects an input with a batch size of 1`);
        }

        return tf.tidy(() => {
            // comes back as [batch, sequence_length, vocab_size]
            const prediction = this.apply(input, { kvCache: kv_cache }) as tf.Tensor;

            const [batch_size, sequence_length, vocab_size] = prediction.shape;

            // get the last token
            const next_token = prediction.slice([0, sequence_length - 1, 0], [batch_size, 1, vocab_size]).argMax(2);

            return next_token as tf.Tensor2D;
        });
    }

    /** Whether generate() should stop after the current step. */
    get stopPredicting() {
        return this.stopPredicting_;
    }

    set stopPredicting(stop: boolean) {
        this.stopPredicting_ = stop;
    }
}

tf.serialization.registerClass(LlmModel);
|