@stellarapp/tfjs-stellar 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/jest.config.ts +203 -0
- package/package.json +24 -0
- package/src/index.ts +93 -0
- package/src/kv_cache.ts +205 -0
- package/src/layers/cached_rope_multihead_attention.test.ts +59 -0
- package/src/layers/cached_rope_multihead_attention.ts +113 -0
- package/src/layers/gpt_decoder_block.ts +77 -0
- package/src/layers/multihead_attention.test.ts +212 -0
- package/src/layers/multihead_attention.ts +371 -0
- package/src/layers/positional_encoding.test.ts +113 -0
- package/src/layers/positional_encoding.ts +158 -0
- package/src/layers/rotary_position_embedding.test.ts +107 -0
- package/src/layers/rotary_position_embedding.ts +163 -0
- package/src/layers/token_and_positional_embedding.test.ts +81 -0
- package/src/layers/token_and_positional_embedding.ts +149 -0
- package/src/layers/transformer_decoder.test.ts +100 -0
- package/src/layers/transformer_decoder.ts +236 -0
- package/src/layers/transformer_encoder.test.ts +85 -0
- package/src/layers/transformer_encoder.ts +224 -0
- package/src/losses/dice.ts +156 -0
- package/src/losses/index.ts +1 -0
- package/src/metrics.ts +32 -0
- package/src/models/gpt_model.ts +232 -0
- package/src/models/index.ts +2 -0
- package/src/models/llm_model.ts +355 -0
- package/src/models/u_net.ts +240 -0
- package/src/packing_mask.ts +28 -0
- package/src/testing.ts +1 -0
- package/src/tfjs_types.ts +15 -0
- package/src/utils.test.ts +101 -0
- package/src/utils.ts +86 -0
- package/tsconfig.json +49 -0
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
|
|
3
|
+
import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
|
|
4
|
+
|
|
5
|
+
import { type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
|
|
6
|
+
import { CachedRoPEMultiHeadAttention } from "@/layers/cached_rope_multihead_attention";
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
/**
 * Constructor options of the transformer decoder layer.
 *
 * Every option of `MultiHeadAttentionArgs` except `causal` is inherited;
 * `causal` is re-declared here as an optional flag.
 */
export interface TransformerDecoderArgs extends Omit<MultiHeadAttentionArgs, "causal"> {
    /** use causal mask for attention on inputs */
    causal?: boolean;

    /** activation of the intermediate feed forward layer */
    activation?: "relu" | "gelu";

    /** size (units) of the intermediate feed forward layer */
    dimsFeedForward?: number;
}
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
/**
 * This class implements the transformer decoder architecture from
 * the 2017 paper "Attention Is All You Need".
 *
 * This decoder-only transformer layer accepts one tensor input.
 * The input tensor should have the shape
 * `[ batch, sequences, embedding dims ]`.
 *
 * The layer chains two sub-blocks. Each sub-block applies its transformation,
 * then dropout, then adds the residual (the sub-block's input) and finally
 * applies layer normalization:
 *   1. self attention (a `CachedRoPEMultiHeadAttention` layer)
 *   2. a two-layer feed forward network
 *
 * Causal masking is enabled by default for the initial attention sub-layer.
 *
 * @param numHeads number of attention heads to use
 * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
 * @param causal use causal masking on inputs (masks future inputs to prevent looking ahead), default `true`
 * @param dropout dropout rate used by the attention layer and after each sub-block, must be within `[0, 1)`, default `0.1`
 * @param activation the activation of the intermediate feed forward layer, default `relu`
 * @param dimsFeedForward the size of the intermediate feed forward layer, default `embedDim * 4`
 * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
 */
export class TransformerDecoder extends tf.layers.Layer {
    static className = "TransformerDecoder";

    protected readonly causalSelfAttention: tf.layers.Layer;
    protected readonly causalSelfAttentionDropout: tf.layers.Layer;
    protected readonly causalSelfAttentionNorm: tf.layers.Layer;

    protected readonly feedforward1: tf.layers.Layer;
    protected readonly feedforward2: tf.layers.Layer;
    protected readonly feedForwardDropout: tf.layers.Layer;
    protected readonly feedFowardNorm: tf.layers.Layer;

    protected readonly numHeads: number;
    protected readonly embedDim: number;
    protected readonly causal: boolean;
    protected readonly useBias: boolean;
    protected readonly dropout: number;
    protected readonly activation: ActivationIdentifier;
    protected readonly dimsFeedForward: number;

    /**
     * Creates the sublayers; their weights are only allocated in `build`.
     *
     * @throws Error when `dropout` is not within `[0, 1)`
     */
    constructor({ numHeads, embedDim, causal, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerDecoderArgs) {
        super(args);

        this.numHeads = numHeads;
        this.embedDim = embedDim;
        this.causal = causal ?? true;
        this.useBias = useBias ?? true;
        this.dropout = dropout ?? 0.1;
        this.activation = activation ?? "relu";

        // written as a negated range test so that negative values and NaN are rejected too
        if (!(this.dropout >= 0 && this.dropout < 1)) {
            throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
        }

        // in the paper section 3.3, d_model=512 (embedDim) and first dense layer outputs d_ff=2048,
        // i.e. four times the embedding size
        this.dimsFeedForward = dimsFeedForward ?? embedDim * 4;

        // self attention sub-block
        this.causalSelfAttention = new CachedRoPEMultiHeadAttention({
            numHeads: this.numHeads, embedDim: this.embedDim,
            useBias: this.useBias, dropout: this.dropout,
            causal: this.causal
        });
        this.causalSelfAttentionDropout = tf.layers.dropout({ rate: this.dropout });
        this.causalSelfAttentionNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });

        // feed forward sub-block
        this.feedforward1 = tf.layers.dense({
            units: this.dimsFeedForward,
            activation: this.activation,
            useBias: this.useBias,
        });
        this.feedforward2 = tf.layers.dense({
            units: this.embedDim,
            activation: "linear",
            useBias: this.useBias
        });
        this.feedForwardDropout = tf.layers.dropout({ rate: this.dropout });
        this.feedFowardNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
    }


    /**
     * Forward propagation
     *
     * @param inputs input tensor, or a list whose first element is the input tensor
     * @param kwargs forwarded to the sublayers
     * @return the output tensor
     * @throws Error when a list of inputs holds neither one nor two tensors
     */
    override call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[] {
        // validate the input tensors: a list of one or two tensors is tolerated,
        // but only the first tensor is consumed below
        if (Array.isArray(inputs) && inputs.length !== 1 && inputs.length !== 2) {
            throw Error(`${this.getClassName()}::call ${this.name} expects one input tensor, got ${inputs.length} inputs.`);
        }

        const x = Array.isArray(inputs) ? (inputs[0] as tf.Tensor) : inputs;

        // perform forward propagation
        return tf.tidy(() => {
            const attended = this.causalSelfAttentionBlock(x, kwargs);

            return this.feedForwardBlock(attended, kwargs);
        });
    }


    /**
     * Self attention sub-block: attention, dropout, residual add, layer norm.
     */
    protected causalSelfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
        return tf.tidy(() => {
            const attended = this.causalSelfAttention.apply(x, kwargs) as tf.Tensor;
            const dropped = this.causalSelfAttentionDropout.apply(attended, kwargs) as tf.Tensor;
            const summed = tf.add(dropped, x); // residual connection

            return this.causalSelfAttentionNorm.apply(summed, kwargs) as tf.Tensor;
        });
    }


    /**
     * Feed forward sub-block: two dense layers, dropout, residual add, layer norm.
     */
    protected feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
        return tf.tidy(() => {
            const hidden = this.feedforward1.apply(x, kwargs) as tf.Tensor;
            const projected = this.feedforward2.apply(hidden, kwargs) as tf.Tensor;
            const dropped = this.feedForwardDropout.apply(projected, kwargs) as tf.Tensor;
            const summed = tf.add(dropped, x); // residual connection

            return this.feedFowardNorm.apply(summed, kwargs) as tf.Tensor;
        });
    }


    /**
     * Initialize the sublayers' weights and track them to enable serialization
     *
     * @param inputShape a single shape, or a list of one or two shapes whose
     *        first entry is the `[batch, seq, embed_dim]` shape of the input
     * @throws Error when the first shape is missing or is not of rank 3
     */
    override build(inputShape: tf.Shape | tf.Shape[]): void {
        let inputShapes: tf.Shape[] = [];

        if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
            // input is an array of shapes
            inputShapes = inputShape as tf.Shape[];
        } else if (inputShape.length !== 0) {
            // input is a single shape
            inputShapes = [inputShape as tf.Shape];
        }

        if (inputShapes.length !== 1 && inputShapes.length !== 2) {
            throw Error(`${this.getClassName()}::build ${this.name} expects an input shape` +
                ` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`);
        }

        const [decoderInputShape] = inputShapes;

        if (decoderInputShape?.length !== 3) {
            throw Error(`${this.getClassName()}::build ${this.name} expects an input shape` +
                ` of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`);
        }

        // initialize causal self attention sub-block's weights
        this.causalSelfAttention.build(decoderInputShape);
        this.causalSelfAttentionNorm.build(this.causalSelfAttention.computeOutputShape(decoderInputShape));

        // initialize feedforward sub-block's weights
        const feedforward1OutputShape = this.feedforward1.computeOutputShape(decoderInputShape);
        const feedforward2OutputShape = this.feedforward2.computeOutputShape(feedforward1OutputShape);

        this.feedforward1.build(decoderInputShape);
        this.feedforward2.build(feedforward1OutputShape);
        this.feedFowardNorm.build(feedforward2OutputShape);

        // track sublayers' weights
        this.trainableWeights = [
            ...this.causalSelfAttention.trainableWeights,
            ...this.causalSelfAttentionDropout.trainableWeights,
            ...this.causalSelfAttentionNorm.trainableWeights,
            ...this.feedforward1.trainableWeights,
            ...this.feedforward2.trainableWeights,
            ...this.feedForwardDropout.trainableWeights,
            ...this.feedFowardNorm.trainableWeights
        ];

        // rename the weights otherwise they'll take on the default naming and overlap
        // each other which breaks model loading due to duplicate weight names
        this.trainableWeights.forEach((weight, index) => {
            const uniqueName = `${this.getClassName()}_${index}`;
            (weight as any).name += uniqueName;
            (weight as any).originalName += uniqueName;
        });

        super.build(inputShape);
    }


    /**
     * Save the layer's hyperparameters for serialization
     *
     * @return the hyperparameters of this layer merged with the base layer's config
     */
    override getConfig() {
        const baseConfig = super.getConfig();

        const config = {
            numHeads: this.numHeads,
            embedDim: this.embedDim,
            causal: this.causal,
            useBias: this.useBias,
            dropout: this.dropout,
            activation: this.activation,
            dimsFeedForward: this.dimsFeedForward
        };

        // entries of the base config take precedence on key collisions
        Object.assign(config, baseConfig);

        return config;
    }

}


tf.serialization.registerClass(TransformerDecoder);
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
|
|
3
|
+
import { TransformerEncoder } from "@/layers/transformer_encoder";
|
|
4
|
+
|
|
5
|
+
// disables warning for using the faster node backend,
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
tf.env().set('IS_NODE', false);


describe("TransformerEncoder tests", () => {
    it("should return an output with the same shape as the input", () => {
        const input = tf.randomUniform([2, 3, 10]);

        const encoder = new TransformerEncoder({
            numHeads: 2, embedDim: input.shape.at(-1)!,
            dropout: 0.5, activation: "gelu", dimsFeedForward: 512, useBias: true
        });

        const output = encoder.apply(input) as tf.Tensor;

        // compare every dimension, not only the rank
        expect(output.shape).toEqual(input.shape);
    });


    it("should accept correct forward calls", () => {
        const input = tf.randomUniform([2, 3, 10]);

        const encoder = new TransformerEncoder({ numHeads: 2, embedDim: input.shape.at(-1)! });
        expect(() => encoder.apply(input)).not.toThrow();
        expect(() => encoder.apply([input])).not.toThrow();

        const causal = new TransformerEncoder({ numHeads: 2, embedDim: input.shape.at(-1)!, causal: true });
        expect(() => causal.apply(input)).not.toThrow();
        expect(() => causal.apply([input])).not.toThrow();
    });


    it("should fail to instantiate a layer if the input's embedding dimension is not divisible by the heads count", () => {
        const input = tf.randomUniform([2, 3, 10]);

        // 10 is not a multiple of 3, but it is a multiple of 5
        expect(() => new TransformerEncoder({ numHeads: 3, embedDim: input.shape.at(-1)! })).toThrow();
        expect(() => new TransformerEncoder({ numHeads: 5, embedDim: input.shape.at(-1)! })).not.toThrow();
    });


    it("should not accept non-rank 3 tensor inputs", () => {
        const incorrect_input = tf.randomUniform([2, 3, 10, 10]);
        const incorrect_input2 = tf.randomUniform([2, 3]);
        const correct_input = tf.randomUniform([2, 3, 10]);

        const encoder = new TransformerEncoder({ numHeads: 2, embedDim: correct_input.shape.at(-1)! });
        expect(() => encoder.apply([correct_input, correct_input])).toThrow();

        expect(() => encoder.apply(incorrect_input)).toThrow();
        expect(() => encoder.apply(incorrect_input2)).toThrow();

        expect(() => encoder.apply([correct_input, incorrect_input])).toThrow();
        expect(() => encoder.apply([incorrect_input, correct_input])).toThrow();

        expect(() => encoder.apply([correct_input, incorrect_input2])).toThrow();
        expect(() => encoder.apply([incorrect_input2, correct_input])).toThrow();
    });


    it("should accept exactly one input", () => {
        const input = tf.randomUniform([2, 3, 10]);

        const encoder = new TransformerEncoder({ numHeads: 1, embedDim: input.shape.at(-1)! });
        expect(() => encoder.apply(input)).not.toThrow();
        expect(() => encoder.apply([input])).not.toThrow();

        expect(() => encoder.apply([])).toThrow();
        expect(() => encoder.apply([input, input])).toThrow();
        expect(() => encoder.apply([input, input, input])).toThrow();
    });


    it("should return a non-empty config dict", () => {
        const input = tf.randomUniform([2, 3, 10]);

        const encoder = new TransformerEncoder({ numHeads: 1, embedDim: input.shape.at(-1)! });

        // assert on the number of keys; comparing the key array itself with 0 can never fail
        expect(Object.keys(encoder.getConfig()).length).toBeGreaterThan(0);
    });
});
|
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { type Kwargs } from "@tensorflow/tfjs-layers/dist/types";
|
|
3
|
+
import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
|
|
4
|
+
|
|
5
|
+
import { MultiHeadAttention, type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
/**
 * Constructor options of the transformer encoder layer: every option of
 * `MultiHeadAttentionArgs` plus the feed forward settings below.
 */
export interface TransformerEncoderArgs extends MultiHeadAttentionArgs {
    /** size (units) of the intermediate feed forward layer */
    dimsFeedForward?: number;

    /** activation of the intermediate feed forward layer */
    activation?: "relu" | "gelu";
}
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
/**
 * This class implements the transformer encoder architecture from the 2017 paper
 * Attention Is All You Need.
 *
 * This layer accepts exactly one tensor input with the shape
 * `[ batch, sequences, embedding dims ]`.
 *
 * The layer chains two sub-blocks. Each sub-block applies its transformation,
 * then dropout, then adds the residual (the sub-block's input) and finally
 * applies layer normalization:
 *   1. self attention (a `MultiHeadAttention` layer)
 *   2. a two-layer feed forward network
 *
 * @param numHeads number of attention heads to use
 * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
 * @param causal use causal masking, default `false` for encoders
 * @param dropout dropout rate used by the attention layer and after each sub-block, must be within `[0, 1)`, default `0.1`
 * @param activation the activation of the intermediate feed forward layer, default `relu`
 * @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
 * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
 */
export class TransformerEncoder extends tf.layers.Layer {
    static className = "TransformerEncoder";

    private readonly selfAttention: tf.layers.Layer;
    private readonly selfAttentionDropout: tf.layers.Layer;
    private readonly selfAttentionNorm: tf.layers.Layer;

    private readonly reluLayer: tf.layers.Layer;
    private readonly linearLayer: tf.layers.Layer;
    private readonly feedForwardDropout: tf.layers.Layer;
    private readonly feedFowardNorm: tf.layers.Layer;

    private readonly numHeads: number;
    private readonly embedDim: number;
    private readonly causal: boolean;
    private readonly useBias: boolean;
    private readonly dropout: number;
    private readonly activation: ActivationIdentifier;
    private readonly dimsFeedForward: number;


    /**
     * Creates the sublayers; their weights are only allocated in `build`.
     *
     * @throws Error when `dropout` is not within `[0, 1)`
     */
    constructor({ numHeads, embedDim, causal, useBias, dropout, activation, dimsFeedForward, ...args }: TransformerEncoderArgs) {
        super(args);

        this.numHeads = numHeads;
        this.embedDim = embedDim;
        this.causal = causal ?? false;
        this.useBias = useBias ?? true;
        this.dropout = dropout ?? 0.1;
        this.activation = activation ?? "relu";
        this.dimsFeedForward = dimsFeedForward ?? 2048;

        // written as a negated range test so that negative values and NaN are rejected too
        if (!(this.dropout >= 0 && this.dropout < 1)) {
            throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
        }

        // self attention sub-block
        this.selfAttention = new MultiHeadAttention({
            numHeads: this.numHeads, embedDim: this.embedDim, useBias: this.useBias,
            dropout: this.dropout, causal: this.causal
        });
        this.selfAttentionDropout = tf.layers.dropout({ rate: this.dropout });
        this.selfAttentionNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });

        // feed forward sub-block
        this.reluLayer = tf.layers.dense({
            units: this.dimsFeedForward, activation: this.activation,
            useBias: this.useBias
        });
        this.linearLayer = tf.layers.dense({
            units: this.embedDim, activation: "linear",
            useBias: this.useBias
        });
        this.feedForwardDropout = tf.layers.dropout({ rate: this.dropout });
        this.feedFowardNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
    }


    /**
     * Forward propagation
     *
     * @param inputs the input tensor, or a list holding exactly that one tensor
     * @param kwargs forwarded to the attention and dropout sublayers
     * @return the output tensor
     * @throws Error when a list of inputs does not hold exactly one tensor
     */
    override call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor | tf.Tensor[] {
        // validate the input tensors
        if (Array.isArray(inputs) && inputs.length !== 1) {
            throw Error(`${this.getClassName()}::call ${this.name} expects exactly 1 tensor` +
                ` input, got ${inputs.length} inputs instead.`);
        }

        const input = Array.isArray(inputs) ? inputs[0] : inputs;

        // perform forward propagation
        return tf.tidy(() => {
            const attention = this.selfAttentionBlock(input, kwargs);

            return this.feedForwardBlock(attention, kwargs);
        });
    }


    /**
     * Self attention sub-block: attention, dropout, residual add, layer norm.
     */
    private selfAttentionBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
        return tf.tidy(() => {
            const attended = this.selfAttention.apply(x, kwargs) as tf.Tensor;
            const dropped = this.selfAttentionDropout.apply(attended, kwargs) as tf.Tensor;
            const summed = tf.add(dropped, x); // residual connection

            return this.selfAttentionNorm.apply(summed) as tf.Tensor;
        });
    }


    /**
     * Feed forward sub-block: two dense layers, dropout, residual add, layer norm.
     */
    private feedForwardBlock(x: tf.Tensor, kwargs: Kwargs): tf.Tensor {
        return tf.tidy(() => {
            const hidden = this.reluLayer.apply(x) as tf.Tensor;
            const projected = this.linearLayer.apply(hidden) as tf.Tensor;
            const dropped = this.feedForwardDropout.apply(projected, kwargs) as tf.Tensor;
            const summed = tf.add(dropped, x); // residual connection

            return this.feedFowardNorm.apply(summed) as tf.Tensor;
        });
    }


    /**
     * Initialize the sublayers' weights and track them to enable backpropagation.
     *
     * @param inputShape a single `[batch, seq, embed_dim]` shape, or a list
     *        holding exactly that one shape
     * @throws Error when not exactly one rank 3 shape is given
     */
    override build(inputShape: tf.Shape | tf.Shape[]): void {
        let inputShapes: tf.Shape[] = [];

        if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
            // input is an array of shapes
            inputShapes = inputShape as tf.Shape[];
        } else if (inputShape.length !== 0) {
            // input is a single shape
            inputShapes = [inputShape as tf.Shape];
        }

        const [encoderInputShape] = inputShapes;

        // expects only 1 rank 3 tensor input
        if (inputShapes.length !== 1 || encoderInputShape?.length !== 3) {
            throw Error(`${this.getClassName()}::build ${this.name} expects a single input shape of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`);
        }

        // the sublayers are built with the unwrapped shape so that a
        // one-element list of shapes works as well as a bare shape

        // initialize self attention sub-block's weights
        this.selfAttention.build(encoderInputShape);
        this.selfAttentionNorm.build(encoderInputShape);

        // initialize feedforward sub-block's weights
        const reluLayerOutputShape = this.reluLayer.computeOutputShape(encoderInputShape);
        const linearLayerOutputShape = this.linearLayer.computeOutputShape(reluLayerOutputShape);

        this.reluLayer.build(encoderInputShape);
        this.linearLayer.build(reluLayerOutputShape);
        this.feedFowardNorm.build(linearLayerOutputShape);

        // track sublayers' weights
        this.trainableWeights = [
            ...this.selfAttention.trainableWeights,
            ...this.selfAttentionDropout.trainableWeights,
            ...this.selfAttentionNorm.trainableWeights,
            ...this.reluLayer.trainableWeights,
            ...this.linearLayer.trainableWeights,
            ...this.feedForwardDropout.trainableWeights,
            ...this.feedFowardNorm.trainableWeights
        ];

        // rename the weights otherwise they'll take on the default naming and overlap
        // each other which breaks model loading due to duplicate weight names
        this.trainableWeights.forEach((weight, index) => {
            const uniqueName = `${this.getClassName()}_${index}`;
            (weight as any).name += uniqueName;
            (weight as any).originalName += uniqueName;
        });

        super.build(inputShape);
    }


    /**
     * Save the layer's hyperparameters for serialization
     *
     * @return the hyperparameters of this layer merged with the base layer's config
     */
    override getConfig(): tf.serialization.ConfigDict {
        const baseConfig = super.getConfig();

        const config = {
            numHeads: this.numHeads,
            embedDim: this.embedDim,
            causal: this.causal,
            useBias: this.useBias,
            dropout: this.dropout,
            activation: this.activation,
            dimsFeedForward: this.dimsFeedForward
        };

        // entries of the base config take precedence on key collisions
        Object.assign(config, baseConfig);

        return config;
    }
}


tf.serialization.registerClass(TransformerEncoder);
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { categoricalCrossentropy, binaryCrossentropy } from "@tensorflow/tfjs-layers/dist/losses";
|
|
3
|
+
|
|
4
|
+
// smoothing term added to both the numerator and the denominator of the dice ratios
const epsilon = 1e-7;

// Reduction axes shared by the losses in this file.
// The letters follow the [B]atch, [H]eight, [W]idth, [C]hannel axis order.
const REDUCE_HW = [1, 2]; // height and width
const REDUCE_BHW = [0, 1, 2]; // batch, height and width
const REDUCE_BHWC = [0, 1, 2, 3]; // all four axes
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
/**
 * Standard (Sorensen) Dice Loss, computed separately for every sample.
 *
 * Both tensors are flattened to `[ batch, -1 ]`, so any rank of at least 1
 * is accepted as long as the first axis is the batch axis.
 *
 * The computation runs inside `tf.tidy` so that no intermediate tensor leaks
 * when the loss is called outside of a tidy scope.
 *
 * @param y_true the label tensor
 * @param y_pred the prediction tensor, same shape as `y_true`
 * @returns `1 - dice` as a tensor of shape `[ batch ]`
 */
export function diceBinaryStandard(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
    return tf.tidy(() => {
        const y_true_flat = tf.reshape(y_true, [y_true.shape[0], -1]);
        const y_pred_flat = tf.reshape(y_pred, [y_pred.shape[0], -1]);

        // per-sample sums along the flattened axis
        const intersection = tf.sum(tf.mul(y_true_flat, y_pred_flat), 1);
        const union = tf.add(tf.sum(y_true_flat, 1), tf.sum(y_pred_flat, 1));

        const dice = tf.div(
            intersection.mul(2).add(epsilon),
            union.add(epsilon)
        );

        return tf.scalar(1).sub(dice);
    });
}


// prevents minification of function name which TFJS relies on
Object.defineProperty(diceBinaryStandard, "name", { value: "diceBinaryStandard", configurable: false });
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
/**
 * Dice loss computed over all elements of the batch at once: both tensors
 * are flattened to one dimension before the sums are taken.
 *
 * See https://github.com/keras-team/keras/blob/v3.3.3/keras/src/losses/losses.py#L1983-L2010
 *
 * The computation runs inside `tf.tidy` so that no intermediate tensor leaks
 * when the loss is called outside of a tidy scope.
 *
 * @param y_true the label tensor
 * @param y_pred the prediction tensor, same number of elements as `y_true`
 * @returns `1 - dice` as a scalar tensor
 */
export function diceBinaryGlobal(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
    return tf.tidy(() => {
        const y_true_flat = tf.reshape(y_true, [-1]);
        const y_pred_flat = tf.reshape(y_pred, [-1]);

        const intersection = tf.sum(tf.mul(y_true_flat, y_pred_flat));
        const union = tf.add(tf.sum(y_true_flat), tf.sum(y_pred_flat));

        const dice = tf.div(
            intersection.mul(2).add(epsilon),
            union.add(epsilon)
        );

        return tf.scalar(1).sub(dice);
    });
}


// prevents minification of function name which TFJS relies on
Object.defineProperty(diceBinaryGlobal, "name", { value: "diceBinaryGlobal", configurable: false });
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
/**
 * Dice loss for dense (one-hot style) labels: a dice coefficient is computed
 * for every sample and every entry of the last axis by summing over axes
 * 1 and 2, then the coefficients are averaged over the last axis.
 *
 * NOTE(review): assumes rank 4 inputs laid out as
 * `[ batch, height, width, channels ]` - confirm against the callers.
 *
 * The computation runs inside `tf.tidy` so that no intermediate tensor leaks
 * when the loss is called outside of a tidy scope.
 *
 * @param y_true the label tensor
 * @param y_pred the prediction tensor, same shape as `y_true`
 * @returns `1 - mean dice`; a tensor of shape `[ batch ]` for rank 4 inputs
 */
export function diceCategoricalStandard(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
    return tf.tidy(() => {
        const intersection = tf.sum(tf.mul(y_true, y_pred), REDUCE_HW);
        const union = tf.add(y_true, y_pred).sum(REDUCE_HW);

        const dice = tf.div(
            intersection.mul(2).add(epsilon),
            union.add(epsilon)
        );

        return tf.scalar(1).sub(tf.mean(dice, -1));
    });
}


// prevents minification of function name which TFJS relies on
Object.defineProperty(diceCategoricalStandard, "name", { value: "diceCategoricalStandard", configurable: false });
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
/**
 * Generalized dice loss: sums over axes 0, 1 and 2 give one value per entry
 * of the last axis, each of which is weighted by
 * `1 / (sum(y_true)^2 + epsilon)` before the weighted intersection and union
 * are summed up into a single dice coefficient.
 *
 * NOTE(review): assumes rank 4 inputs laid out as
 * `[ batch, height, width, channels ]` - confirm against the callers.
 *
 * The computation runs inside `tf.tidy` so that no intermediate tensor leaks
 * when the loss is called outside of a tidy scope.
 *
 * @param y_true the label tensor
 * @param y_pred the prediction tensor, same shape as `y_true`
 * @returns `1 - dice` as a scalar tensor
 */
export function diceCategoricalGeneralized(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
    return tf.tidy(() => {
        // needed by both the weighting and the union, so it is calculated once
        const y_true_sum = y_true.sum(REDUCE_BHW);

        const weighting = tf.div(1, y_true_sum.square().add(epsilon));

        const intersection = tf.sum(tf.mul(y_true, y_pred), REDUCE_BHW).mul(weighting).sum();
        const union = tf.add(y_true_sum, y_pred.sum(REDUCE_BHW)).mul(weighting).sum();

        const dice = tf.div(
            intersection.mul(2).add(epsilon),
            union.add(epsilon)
        );

        return tf.scalar(1).sub(dice);
    });
}


// prevents minification of function name which TFJS relies on
Object.defineProperty(diceCategoricalGeneralized, "name", { value: "diceCategoricalGeneralized", configurable: false });
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
/**
 * Dice loss computed over every element at once: intersection and union are
 * summed over axes 0 to 3, which requires rank 4 inputs.
 *
 * The computation runs inside `tf.tidy` so that no intermediate tensor leaks
 * when the loss is called outside of a tidy scope.
 *
 * @param y_true the label tensor
 * @param y_pred the prediction tensor, same shape as `y_true`
 * @returns `1 - dice` as a scalar tensor
 */
export function diceCategoricalGlobal(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
    return tf.tidy(() => {
        const intersection = tf.sum(tf.mul(y_true, y_pred), REDUCE_BHWC);
        const union = tf.add(tf.sum(y_true, REDUCE_BHWC), tf.sum(y_pred, REDUCE_BHWC));

        const dice = tf.div(
            intersection.mul(2).add(epsilon),
            union.add(epsilon)
        );

        return tf.scalar(1).sub(dice);
    });
}


// prevents minification of function name which TFJS relies on
Object.defineProperty(diceCategoricalGlobal, "name", { value: "diceCategoricalGlobal", configurable: false });
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
/**
 * Calculates the Sorensen-Dice coefficient and the binary cross entropy losses.
 * Both have equal weight.
 *
 * The computation runs inside `tf.tidy` so that no intermediate tensor leaks
 * when the loss is called outside of a tidy scope.
 *
 * @param y_true the label tensor
 * @param y_pred the prediction tensor (not sparse)
 * @returns a tensor of shape `[ batch ]`
 */
export function diceBinaryCrossentropy(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
    return tf.tidy(() => {
        // reduce cross entropy shape from [B, H, W] to [B] to match dice
        const bce = binaryCrossentropy(y_true, y_pred).mean(REDUCE_HW);
        const dice = diceBinaryStandard(y_true, y_pred);

        // equal weighting of both losses
        return tf.add(bce.mul(0.5), dice.mul(0.5));
    });
}


// prevents minification of function name which TFJS relies on
Object.defineProperty(diceBinaryCrossentropy, "name", { value: "diceBinaryCrossentropy", configurable: false });
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
/**
 * Calculates the Sorensen-Dice coefficient and the categorical cross entropy losses.
 * Both have equal weight. Expects dense (non-sparse) label tensors.
 *
 * This does not support sparse tensors because TFJS's
 * sparseCategoricalCrossentropy loss onehots the label
 * and calls categoricalCrossentropy. See
 * https://github.com/tensorflow/tfjs/blob/0fc04d958ea592f3b8db79a8b3b497b5c8904097/tfjs-layers/src/losses.ts#L143-L146
 *
 * The computation runs inside `tf.tidy` so that no intermediate tensor leaks
 * when the loss is called outside of a tidy scope.
 *
 * @param y_true the label
 * @param y_pred the prediction tensor (not sparse)
 * @returns a tensor of shape `[ batch ]`
 */
export function diceCategoricalCrossentropy(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor {
    return tf.tidy(() => {
        // reduce cross entropy shape from [B, H, W] to [B] to match dice
        const cce = categoricalCrossentropy(y_true, y_pred).mean(REDUCE_HW);
        const dice = diceCategoricalStandard(y_true, y_pred);

        // equal weighting of both losses
        return tf.add(cce.mul(0.5), dice.mul(0.5));
    });
}


// prevents minification of function name which TFJS relies on
Object.defineProperty(diceCategoricalCrossentropy, "name", { value: "diceCategoricalCrossentropy", configurable: false });
|